第 04 篇:数据加载的完整链路——Dataset 和 DataLoader
大多数深度学习教程在讲数据加载时,直接给你一行代码:
dataset = torchvision.datasets.CIFAR10(root='./data', download=True)
然后继续往下讲模型。你照着写,能跑,但你不知道这行代码背后发生了什么,也不知道换成自己的数据集该怎么办。
等到真正开始做项目,数据是自己采集的图片,或者是一个 CSV 文件,或者是某种奇怪格式的标注文件——你才发现完全不知道从哪下手。
这篇文章就是为了解决这个问题。我们不用任何内置数据集,从零开始写一个图片分类的数据加载管道,同时把每一个设计决策背后的原因讲清楚。
数据加载需要解决什么问题
训练一个模型,数据需要经历这些步骤:
硬盘上的原始文件 --> 读取文件(图片解码 / CSV 解析 / ...)--> 预处理(resize、normalize、数据增强 ...)--> 组成 batch(把多个样本打包)--> 搬运到 GPU --> 喂给模型
这条流水线如果设计得不好,GPU 会大量时间在等数据——GPU 算完一个 batch,下一个 batch 还没准备好,白白浪费算力。
PyTorch 用两个类来分工解决这件事:
- Dataset:负责"知道数据在哪,怎么读取单个样本";
- DataLoader:负责"怎么把单个样本组织成 batch,怎么并行预处理,怎么打乱顺序"。
Dataset 只管单样本,DataLoader 只管批量调度。这种设计让你可以专注在数据处理逻辑上,不用操心并行和批处理的细节。
第一步:准备一个真实的图片数据集目录
先搭建测试用的目录结构。真实项目里,图片分类数据集最常见的组织方式是"每个类别一个文件夹":
data/
├── train/
│ ├── cat/
│ │ ├── cat_001.jpg
│ │ ├── cat_002.jpg
│ │ └── ...
│ ├── dog/
│ │ ├── dog_001.jpg
│ │ └── ...
│ └── bird/
│ └── ...
└── val/
├── cat/
├── dog/
└── bird/
一个包含 train/ 和 val/ 目录的数据集。这是真实项目里最普遍的数据组织形式。
第二步:从零写一个 Dataset
PyTorch 的 Dataset 是一个抽象基类,定义在 torch.utils.data 里。你的自定义 Dataset 需要继承它,并实现两个方法:
__len__:返回数据集的总样本数__getitem__:给定一个索引idx,返回第idx个样本
就这两个,没有别的强制要求。
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from pathlib import Path
import os
class ImageFolderDataset(Dataset):
"""
自定义图片文件夹数据集。
目录结构:root_dir/class_name/image.jpg
"""
def __init__(self, root_dir: str, transform=None):
"""
初始化:扫描目录,建立索引,不读取任何图片文件。
Args:
root_dir: 数据根目录(如 './data/train')
transform: 图片预处理/增强的 transforms,默认为 None
"""
self.root_dir = Path(root_dir)
self.transform = transform
# ---- 建立类别到整数标签的映射 ----
# sorted() 保证每次运行顺序一致,避免随机性导致标签混乱
self.classes = sorted([
d.name for d in self.root_dir.iterdir() if d.is_dir()
])
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
# {'bird': 0, 'cat': 1, 'dog': 2}
# ---- 扫描所有图片,建立 (图片路径, 标签) 的列表 ----
# 这一步只记录路径,不读取文件内容
# 好处:初始化速度快,内存占用极小,数百万张图片也没问题
self.samples = []
VALID_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
for cls_name in self.classes:
cls_dir = self.root_dir / cls_name
label = self.class_to_idx[cls_name]
for img_path in cls_dir.iterdir():
if img_path.suffix.lower() in VALID_EXTENSIONS:
self.samples.append((img_path, label))
print(f"数据集初始化完成:{len(self.samples)} 张图片,{len(self.classes)} 个类别")
print(f"类别映射:{self.class_to_idx}")
def __len__(self):
"""
返回数据集的总样本数。
DataLoader 通过这个方法知道数据集有多大,从而决定要生成多少个 batch。
"""
return len(self.samples)
def __getitem__(self, idx):
"""
给定索引 idx,返回第 idx 个样本。
DataLoader 在需要某个样本时就会调用这个方法。
返回值:(image_tensor, label)
image_tensor: 经过 transform 处理后的图片张量
label: 整数标签
"""
img_path, label = self.samples[idx]
# ---- 读取图片 ----
# 这里才真正发生 I/O,打开并解码图片文件
# 用 PIL 打开是最通用的方式,支持各种格式
img = Image.open(img_path).convert('RGB')
# convert('RGB') 是个好习惯:强制转为 3 通道
# 有些图片是 RGBA(4 通道)或者灰度图,不转换的话后续处理可能出错
# ---- 应用 transform ----
if self.transform is not None:
img = self.transform(img)
# transform 通常会把 PIL Image 转成 Tensor,返回 (C, H, W) 的浮点张量
return img, label
注意 __init__ 里的设计决策:只记录路径,不读取图片。
这是非常重要的原则。如果你在初始化时就把所有图片读进内存,10 万张 224×224 的图片大概需要 50GB 内存,根本装不下。正确的做法是只存路径(一个字符串列表),真正需要某张图片时才在 __getitem__ 里读取。
这种"懒加载"的设计使得 Dataset 可以处理任意规模的数据集,哪怕数据集大小远超内存。
第三步:理解 getitem 的调用时机
__getitem__ 不是你手动调用的,是 DataLoader 在需要某个样本时自动调用的。
# 你可以直接用索引访问 Dataset(但通常不这么做)
dataset = ImageFolderDataset('./data/train')
# 这会触发 __getitem__(0)
sample = dataset[0]
print(type(sample)) # <class 'tuple'>
print(len(sample)) # 2 ← (image, label)
print(type(sample[0])) # <class 'PIL.Image.Image'>(还没加 transform)
print(sample[1]) # 整数标签,如 1
这里有一个值得注意的地方:__getitem__ 在训练时会被调用数百万次(每个 epoch 里每个样本都被访问一次)。它必须足够高效,因为它是整个数据加载管道里最频繁的操作。
常见的性能建议:
- 图片读取用 PIL 或者 OpenCV,不要用纯 Python 循环逐像素处理
- 如果数据集很小(比如几千张图),可以在
__init__里把所有图片预先读进内存,__getitem__直接从内存返回 - 如果数据集极大,考虑把数据预处理成 LMDB 或者 HDF5 格式,随机访问更快
第四步:transforms——在 Dataset 里做预处理
transform 是一个可调用对象(callable),接收一个样本,返回处理后的样本。torchvision.transforms 提供了大量现成的变换。
from torchvision import transforms
# 训练集的 transform:包含数据增强
train_transform = transforms.Compose([
# ---- 几何变换(数据增强) ----
transforms.Resize(256), # 先缩放到 256(短边)
transforms.RandomCrop(224), # 随机裁剪到 224×224
transforms.RandomHorizontalFlip(p=0.5), # 50% 概率水平翻转
# ---- 颜色增强(可选,对颜色敏感的任务慎用) ----
transforms.ColorJitter(
brightness=0.2, # 随机调整亮度 ±20%
contrast=0.2, # 随机调整对比度 ±20%
saturation=0.2, # 随机调整饱和度 ±20%
),
# ---- 转为 Tensor + 归一化 ----
transforms.ToTensor(),
# ToTensor 做了两件事:
# 1. PIL Image (H, W, C) uint8 → Tensor (C, H, W) float32
# 2. 像素值 [0, 255] → [0.0, 1.0]
transforms.Normalize(
mean=[0.485, 0.456, 0.406], # ImageNet 的均值
std=[0.229, 0.224, 0.225] # ImageNet 的标准差
),
# Normalize 做的是:output = (input - mean) / std
# 让每个通道的数值分布大致在 [-2, 2] 区间,有助于训练稳定
])
# 验证集的 transform:不做随机增强,只做确定性预处理
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224), # 中心裁剪,而不是随机裁剪
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
训练集和验证集用不同的 transform,这是一个必须遵守的原则:
训练集:需要数据增强(随机翻转、随机裁剪、颜色抖动),人为增加样本多样性,提升泛化能力。
验证集:绝对不能用随机增强。验证是为了评估模型的真实性能,如果每次验证时图片内容因随机性而不同,指标就失去了可比性,你没法判断模型是在进步还是运气好。验证集只用确定性变换(Resize + CenterCrop + ToTensor + Normalize)。
ToTensor 和 Normalize 的细节
from torchvision import transforms
from PIL import Image
import torch
# 创建一个测试图片
img = Image.fromarray(
(torch.randint(0, 256, (64, 64, 3))).numpy().astype('uint8')
)
print(f"PIL Image 大小:{img.size},模式:{img.mode}")
# PIL Image 大小:(64, 64),模式:RGB
# ToTensor
to_tensor = transforms.ToTensor()
tensor = to_tensor(img)
print(f"Tensor 形状:{tensor.shape}") # torch.Size([3, 64, 64])
print(f"值域:[{tensor.min():.3f}, {tensor.max():.3f}]") # [0.000, 1.000]
# 注意:PIL 是 (H, W, C),ToTensor 会自动变成 (C, H, W)
# 注意:uint8 [0,255] 被自动转为 float32 [0.0, 1.0]
# Normalize
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
normalized = normalize(tensor)
print(f"归一化后值域:[{normalized.min():.3f}, {normalized.max():.3f}]")
# [-1.000, 1.000]
# (0.0 - 0.5) / 0.5 = -1.0,(1.0 - 0.5) / 0.5 = 1.0
一个常见的错误:先 Normalize 再 ToTensor,或者 Normalize 的顺序弄错。Normalize 要求输入已经是 float Tensor(值域 [0, 1]),必须放在 ToTensor 之后。
第五步:把 Dataset 和 DataLoader 接起来
# 实例化 Dataset
train_dataset = ImageFolderDataset('./data/train', transform=train_transform)
val_dataset = ImageFolderDataset('./data/val', transform=val_transform)
print(f"训练集大小:{len(train_dataset)}") # 300(3类 × 100张)
print(f"验证集大小:{len(val_dataset)}") # 60(3类 × 20张)
# 验证单个样本的形状
img, label = train_dataset[0]
print(f"单个样本 - 图片形状:{img.shape},标签:{label}")
# 单个样本 - 图片形状:torch.Size([3, 224, 224]),标签:0
现在把 DataLoader 接上:
from torch.utils.data import DataLoader
train_loader = DataLoader(
dataset = train_dataset,
batch_size = 32, # 每个 batch 包含 32 个样本
shuffle = True, # 每个 epoch 开始前打乱顺序(训练集必须 True)
num_workers = 4, # 用 4 个子进程并行加载数据
pin_memory = True, # 把数据锁定在内存,加速 CPU→GPU 的数据传输
drop_last = True, # 丢弃最后一个不满 batch_size 的 batch
)
val_loader = DataLoader(
dataset = val_dataset,
batch_size = 64, # 验证时不需要梯度,显存充裕,batch 可以更大
shuffle = False, # 验证集不需要打乱
num_workers = 4,
pin_memory = True,
drop_last = False, # 验证集要评估所有样本,不能丢
)
# 查看一个 batch 的形状
images, labels = next(iter(train_loader))
print(f"一个 batch - 图片形状:{images.shape},标签形状:{labels.shape}")
# 一个 batch - 图片形状:torch.Size([32, 3, 224, 224]),标签形状:torch.Size([32])
DataLoader 从 Dataset 里取出单个样本,然后沿着 batch 维度把它们堆叠起来:
Dataset[0] → (img_tensor [3,224,224], label 1)
Dataset[1] → (img_tensor [3,224,224], label 0)
...
Dataset[31] → (img_tensor [3,224,224], label 2)
↓ DataLoader 拼接
(images [32,3,224,224], labels [32])
这个"拼接"操作在 DataLoader 里叫做 collate,默认用 torch.stack 把所有样本的 Tensor 堆叠成第 0 维(batch 维)。
第六步:DataLoader 的每个参数都在做什么
batch_size
最直接的参数。每次迭代返回 batch_size 个样本。
batch size 的选择是一个权衡:
- 太小(如 1 或 4):每个 batch 的梯度估计噪声太大,训练不稳定;GPU 利用率低,大量时间浪费在 kernel 启动开销上
- 太大:需要更多显存;研究表明过大的 batch 有时会损害泛化性能(陷入锐利极值,参考第 01 篇的宽极值讨论)
- 常见的选择是 32、64、128,要是 2 的幂次(配合硬件内存对齐效率更高)
shuffle
shuffle=True:每个 epoch 开始前,重新随机排列所有样本的顺序
这保证了每个 batch 里的样本是随机组合的,模型不会记住"第 N 个 batch 总是这些样本"
为什么训练集要 shuffle?
假设你的数据按类别排列:前 1000 张全是猫,接着 1000 张全是狗
不 shuffle 的话,前几个 batch 全是猫,梯度会极度偏向猫类
参数更新方向完全失衡,训练会很混乱
验证集不需要 shuffle:
验证只是在算指标,顺序不影响结果,保持固定顺序还方便调试
num_workers
这是理解 DataLoader 工作原理的关键参数。
num_workers=0:主进程单线程加载(默认值)
每次需要一个样本时,主进程暂停训练,去读图片,处理完再继续
GPU 等待期间完全空闲——这是最常见的训练瓶颈
num_workers=4:启动 4 个独立的子进程来加载数据
主进程在 GPU 计算当前 batch 的同时,子进程已经在预读取下一个 batch
GPU 几乎不需要等待,计算效率大幅提升
Windows 用户注意:在 Windows 上使用 num_workers > 0 时,必须把 DataLoader 的创建放在 if __name__ == '__main__': 保护块里,否则子进程会递归启动导致报错:
# Windows 上的正确写法
if __name__ == '__main__':
loader = DataLoader(dataset, batch_size=32, num_workers=4)
for images, labels in loader:
...
pin_memory
pin_memory=True:把加载好的 Tensor 放在"锁页内存"(page-locked memory)
普通内存页面可能被操作系统换出到硬盘(swap),锁页内存保证始终在物理内存里
从锁页内存传输数据到 GPU,可以启用 DMA(直接内存访问),
不需要先把数据拷贝到一个临时缓冲区,速度更快
什么时候开:有 GPU 训练时,几乎总是应该开
什么时候不开:CPU-only 训练,或内存非常紧张时
drop_last
假设训练集有 305 个样本,batch_size=32
305 / 32 = 9 个完整 batch(288 个样本)+ 1 个不完整 batch(17 个样本)
drop_last=True:丢弃最后那个 17 个样本的 batch
好处:所有 batch 大小一致,不会因为最后一个 batch 太小导致统计不稳定(BN 层对 batch size 敏感,最后一个小 batch 可能让 BN 的统计量失真)
drop_last=False(默认):保留最后一个小 batch
验证集通常用 False,确保每个样本都被评估到
collate_fn:自定义 batch 打包逻辑
默认的 collate 逻辑是用 torch.stack 把样本堆叠,要求所有样本的形状完全相同。但如果你的数据形状不一样(比如不同长度的文本序列,或者不同尺寸的图片),就需要自定义 collate_fn:
def custom_collate_fn(batch):
"""
batch 是一个列表,每个元素是 __getitem__ 返回的单个样本。
batch = [(img1, label1), (img2, label2), ..., (img32, label32)]
默认的 collate 会直接 torch.stack,要求形状相同。
自定义 collate 可以做 padding、过滤空样本等操作。
"""
# 过滤掉 __getitem__ 返回 None 的样本(有时候文件损坏时这样处理)
batch = [b for b in batch if b is not None]
if len(batch) == 0:
return None
images = torch.stack([b[0] for b in batch]) # (B, C, H, W)
labels = torch.tensor([b[1] for b in batch]) # (B,)
return images, labels
loader = DataLoader(train_dataset, batch_size=32, collate_fn=custom_collate_fn)
第七步:处理一个真实的边界情况——文件损坏
真实数据集里总有几张坏掉的图片。如果 __getitem__ 直接抛出异常,整个训练就崩了。
防御性写法:
class RobustImageDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = Path(root_dir)
self.transform = transform
self.classes = sorted([d.name for d in self.root_dir.iterdir() if d.is_dir()])
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
self.samples = []
VALID_EXT = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
for cls_name in self.classes:
cls_dir = self.root_dir / cls_name
label = self.class_to_idx[cls_name]
for p in cls_dir.iterdir():
if p.suffix.lower() in VALID_EXT:
self.samples.append((p, label))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
img_path, label = self.samples[idx]
try:
img = Image.open(img_path).convert('RGB')
if self.transform:
img = self.transform(img)
return img, label
except Exception as e:
# 记录错误,但不崩溃
print(f"警告:无法读取图片 {img_path},错误:{e}")
# 返回 None,配合自定义 collate_fn 过滤掉
return None
配合自定义 collate_fn 过滤 None:
def collate_skip_none(batch):
"""过滤掉损坏样本(__getitem__ 返回 None 的)"""
batch = [b for b in batch if b is not None]
if not batch:
return None
images = torch.stack([b[0] for b in batch])
labels = torch.tensor([b[1] for b in batch])
return images, labels
robust_loader = DataLoader(
RobustImageDataset('./data/train', transform=train_transform),
batch_size=32,
collate_fn=collate_skip_none,
num_workers=4,
)
# 训练循环里处理 None batch
for batch in robust_loader:
if batch is None:
continue # 跳过全坏的 batch
images, labels = batch
# 正常训练...
第八步:完整的训练循环
把上面所有东西组合成一个完整可运行的训练例子:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from pathlib import Path
# -------- Dataset --------
class ImageFolderDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = Path(root_dir)
self.transform = transform
self.classes = sorted([d.name for d in self.root_dir.iterdir() if d.is_dir()])
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
self.samples = []
VALID_EXT = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
for cls_name in self.classes:
label = self.class_to_idx[cls_name]
for p in (self.root_dir / cls_name).iterdir():
if p.suffix.lower() in VALID_EXT:
self.samples.append((p, label))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
img_path, label = self.samples[idx]
img = Image.open(img_path).convert('RGB')
if self.transform:
img = self.transform(img)
return img, label
# -------- Transforms --------
train_transform = transforms.Compose([
transforms.Resize(80),
transforms.RandomCrop(64),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
val_transform = transforms.Compose([
transforms.Resize(80),
transforms.CenterCrop(64),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# -------- DataLoader --------
train_dataset = ImageFolderDataset('./data/train', transform=train_transform)
val_dataset = ImageFolderDataset('./data/val', transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False,
num_workers=2, pin_memory=True)
# -------- 模型、优化器、损失函数 --------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = len(train_dataset.classes)
# 用一个简单的 CNN(不用预训练,方便观察从随机初始化开始的训练过程)
model = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), # 64→32
nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), # 32→16
nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), # 16→8
nn.Flatten(),
nn.Linear(128 * 8 * 8, 256), nn.ReLU(), nn.Dropout(0.5),
nn.Linear(256, num_classes),
).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# -------- 训练与验证函数 --------
def train_one_epoch(model, loader, optimizer, criterion, device):
model.train()
total_loss, correct, total = 0.0, 0, 0
for images, labels in loader:
# 把数据搬到 GPU(如果有)
images = images.to(device, non_blocking=True)
# non_blocking=True:配合 pin_memory,让数据传输和 GPU 计算异步进行
labels = labels.to(device, non_blocking=True)
optimizer.zero_grad()
outputs = model(images) # 前向传播
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 参数更新
# 统计指标
total_loss += loss.item() * images.size(0) # loss.item() 是这个 batch 的均值
_, predicted = outputs.max(dim=1) # 取概率最大的类别
correct += predicted.eq(labels).sum().item()
total += labels.size(0)
avg_loss = total_loss / total
accuracy = correct / total
return avg_loss, accuracy
@torch.no_grad()
def evaluate(model, loader, criterion, device):
model.eval()
total_loss, correct, total = 0.0, 0, 0
for images, labels in loader:
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
outputs = model(images)
loss = criterion(outputs, labels)
total_loss += loss.item() * images.size(0)
_, predicted = outputs.max(dim=1)
correct += predicted.eq(labels).sum().item()
total += labels.size(0)
return total_loss / total, correct / total
# -------- 主训练循环 --------
num_epochs = 10
print(f"开始训练,设备:{device},类别:{train_dataset.classes}\n")
for epoch in range(1, num_epochs + 1):
train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
val_loss, val_acc = evaluate(model, val_loader, criterion, device)
print(f"Epoch [{epoch:02d}/{num_epochs}] "
f"Train Loss: {train_loss:.4f} Train Acc: {train_acc:.3f} | "
f"Val Loss: {val_loss:.4f} Val Acc: {val_acc:.3f}")
注意两个细节:
non_blocking=True:配合 pin_memory=True 使用,让 CPU→GPU 的数据传输和 GPU 的计算可以异步进行,进一步减少等待时间。
loss.item() * images.size(0):loss.item() 返回的是这个 batch 的平均 loss,乘以 batch size 才是总 loss,最后除以总样本数才是真正的平均 loss。如果直接累加 loss.item() 再除以 batch 数,当最后一个 batch 大小不同时会有偏差。
第九步:一个更常见的数据组织方式——用 CSV 管理数据集
"每个类别一个文件夹"的方式很直观,但不够灵活。真实项目里更常见的是用一个 CSV 文件来管理数据集,记录每个样本的路径和标签:
# dataset.csv
image_path,label,split
images/0001.jpg,cat,train
images/0002.jpg,dog,train
images/0003.jpg,cat,val
...
对应的 Dataset 写法:
import pandas as pd
class CSVDataset(Dataset):
"""
从 CSV 文件加载数据集。
CSV 格式:image_path, label, split
"""
def __init__(self, csv_path: str, split: str, transform=None):
df = pd.read_csv(csv_path)
# 只保留当前 split(train/val/test)的数据
df = df[df['split'] == split].reset_index(drop=True)
self.image_paths = df['image_path'].tolist()
self.transform = transform
# 建立标签到整数的映射
unique_labels = sorted(df['label'].unique().tolist())
self.class_to_idx = {cls: idx for idx, cls in enumerate(unique_labels)}
self.labels = [self.class_to_idx[l] for l in df['label'].tolist()]
print(f"[{split}] 加载 {len(self.image_paths)} 个样本,"
f"{len(unique_labels)} 个类别")
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img = Image.open(self.image_paths[idx]).convert('RGB')
label = self.labels[idx]
if self.transform:
img = self.transform(img)
return img, label
CSV 方式的好处:
- 方便做数据集划分(只需修改 CSV 里的 split 列)
- 可以存储额外的元信息(拍摄日期、来源、权重等)
- 便于数据版本管理,只需要版本化 CSV 文件
第十步:Dataset 的一些实用技巧
用 Subset 切出子集
from torch.utils.data import Subset
import random
full_dataset = ImageFolderDataset('./data/train', transform=train_transform)
# 随机取 50 个样本做快速调试
indices = random.sample(range(len(full_dataset)), 50)
small_dataset = Subset(full_dataset, indices)
print(len(small_dataset)) # 50
# 调试时先用小数据集跑通,再换回完整数据集
debug_loader = DataLoader(small_dataset, batch_size=8, shuffle=True)
这个技巧非常实用。每次写完新代码,先用一个几十个样本的小数据集跑一遍,确认 loss 在下降、没有 shape 错误,再换成完整数据集跑正式训练,可以省去很多等待时间。
用 random_split 划分训练/验证集
from torch.utils.data import random_split
full_dataset = ImageFolderDataset('./data/train', transform=train_transform)
total = len(full_dataset)
# 80% 训练,20% 验证
train_size = int(0.8 * total)
val_size = total - train_size
train_sub, val_sub = random_split(
full_dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(42) # 固定随机种子,保证可复现
)
print(f"训练子集:{len(train_sub)},验证子集:{len(val_sub)}")
注意:random_split 只是在索引层面做划分,两个子集共享同一个 Dataset 对象,因此用的是同一个 transform。这意味着验证子集会用训练的数据增强 transform,这通常不是你想要的。
正确做法是为训练和验证分别创建 Dataset 实例(指向同一份数据,但用不同 transform):
# 或者用 indices 手动实现,给不同 split 用不同 transform
all_indices = list(range(len(ImageFolderDataset('./data/train'))))
random.shuffle(all_indices)
split_point = int(0.8 * len(all_indices))
train_indices = all_indices[:split_point]
val_indices = all_indices[split_point:]
# 两个 Dataset 实例,transform 不同
train_full = ImageFolderDataset('./data/train', transform=train_transform)
val_full = ImageFolderDataset('./data/train', transform=val_transform)
train_sub = Subset(train_full, train_indices)
val_sub = Subset(val_full, val_indices)
用 WeightedRandomSampler 处理类别不均衡
如果你的数据集类别不均衡(比如猫 1000 张,狗 100 张),直接训练模型会偏向多数类。WeightedRandomSampler 可以让少数类被更频繁地采样:
from torch.utils.data import WeightedRandomSampler
# 统计每个类别的样本数
class_counts = [0] * len(train_dataset.classes)
for _, label in train_dataset.samples:
class_counts[label] += 1
print(f"各类别样本数:{class_counts}")
# [100, 100, 100](我们的假数据集是均衡的,真实场景可能差很多)
# 计算每个样本的采样权重:类别越少,权重越大
# 权重 = 1 / 该类别的样本数
class_weights = [1.0 / c for c in class_counts]
sample_weights = [class_weights[label] for _, label in train_dataset.samples]
# 创建 Sampler
sampler = WeightedRandomSampler(
weights = sample_weights,
num_samples = len(train_dataset), # 每个 epoch 采样多少个
replacement = True, # 有放回采样(少数类可以被重复采样)
)
# 注意:使用自定义 sampler 时,shuffle 参数必须为 False(它们互斥)
balanced_loader = DataLoader(
train_dataset,
batch_size = 32,
sampler = sampler, # 用 sampler 替代 shuffle
num_workers = 2,
)
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)