大多数深度学习教程在讲数据加载时,直接给你一行代码:

dataset = torchvision.datasets.CIFAR10(root='./data', download=True)

然后继续往下讲模型。你照着写,能跑,但你不知道这行代码背后发生了什么,也不知道换成自己的数据集该怎么办。

等到真正开始做项目,数据是自己采集的图片,或者是一个 CSV 文件,或者是某种奇怪格式的标注文件——你才发现完全不知道从哪下手。

这篇文章就是为了解决这个问题。我们不用任何内置数据集,从零开始写一个图片分类的数据加载管道,同时把每一个设计决策背后的原因讲清楚。


数据加载需要解决什么问题

训练一个模型,数据需要经历这些步骤:

硬盘上的原始文件 --> 读取文件(图片解码 / CSV 解析 / ...)--> 预处理(resize、normalize、数据增强 ...)--> 组成 batch(把多个样本打包)--> 搬运到 GPU --> 喂给模型

这条流水线如果设计得不好,GPU 会大量时间在等数据——GPU 算完一个 batch,下一个 batch 还没准备好,白白浪费算力。

PyTorch 用两个类来分工解决这件事:

  • Dataset:负责"知道数据在哪,怎么读取单个样本";
  • DataLoader:负责"怎么把单个样本组织成 batch,怎么并行预处理,怎么打乱顺序"。

Dataset 只管单样本,DataLoader 只管批量调度。这种设计让你可以专注在数据处理逻辑上,不用操心并行和批处理的细节。


第一步:准备一个真实的图片数据集目录

先搭建测试用的目录结构。真实项目里,图片分类数据集最常见的组织方式是"每个类别一个文件夹":

data/
├── train/
│   ├── cat/
│   │   ├── cat_001.jpg
│   │   ├── cat_002.jpg
│   │   └── ...
│   ├── dog/
│   │   ├── dog_001.jpg
│   │   └── ...
│   └── bird/
│       └── ...
└── val/
    ├── cat/
    ├── dog/
    └── bird/

一个包含 train/val/ 目录的数据集。这是真实项目里最普遍的数据组织形式。


第二步:从零写一个 Dataset

PyTorch 的 Dataset 是一个抽象基类,定义在 torch.utils.data 里。你的自定义 Dataset 需要继承它,并实现两个方法:

  • __len__:返回数据集的总样本数
  • __getitem__:给定一个索引 idx,返回第 idx 个样本

就这两个,没有别的强制要求。

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from pathlib import Path
import os

class ImageFolderDataset(Dataset):
    """
    自定义图片文件夹数据集。
    目录结构:root_dir/class_name/image.jpg
    """

    def __init__(self, root_dir: str, transform=None):
        """
        初始化:扫描目录,建立索引,不读取任何图片文件。

        Args:
            root_dir:  数据根目录(如 './data/train')
            transform: 图片预处理/增强的 transforms,默认为 None
        """
        self.root_dir  = Path(root_dir)
        self.transform = transform

        # ---- 建立类别到整数标签的映射 ----
        # sorted() 保证每次运行顺序一致,避免随机性导致标签混乱
        self.classes = sorted([
            d.name for d in self.root_dir.iterdir() if d.is_dir()
        ])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        # {'bird': 0, 'cat': 1, 'dog': 2}

        # ---- 扫描所有图片,建立 (图片路径, 标签) 的列表 ----
        # 这一步只记录路径,不读取文件内容
        # 好处:初始化速度快,内存占用极小,数百万张图片也没问题
        self.samples = []
        VALID_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}

        for cls_name in self.classes:
            cls_dir = self.root_dir / cls_name
            label   = self.class_to_idx[cls_name]
            for img_path in cls_dir.iterdir():
                if img_path.suffix.lower() in VALID_EXTENSIONS:
                    self.samples.append((img_path, label))

        print(f"数据集初始化完成:{len(self.samples)} 张图片,{len(self.classes)} 个类别")
        print(f"类别映射:{self.class_to_idx}")

    def __len__(self):
        """
        返回数据集的总样本数。
        DataLoader 通过这个方法知道数据集有多大,从而决定要生成多少个 batch。
        """
        return len(self.samples)

    def __getitem__(self, idx):
        """
        给定索引 idx,返回第 idx 个样本。
        DataLoader 在需要某个样本时就会调用这个方法。

        返回值:(image_tensor, label)
            image_tensor: 经过 transform 处理后的图片张量
            label:        整数标签
        """
        img_path, label = self.samples[idx]

        # ---- 读取图片 ----
        # 这里才真正发生 I/O,打开并解码图片文件
        # 用 PIL 打开是最通用的方式,支持各种格式
        img = Image.open(img_path).convert('RGB')
        # convert('RGB') 是个好习惯:强制转为 3 通道
        # 有些图片是 RGBA(4 通道)或者灰度图,不转换的话后续处理可能出错

        # ---- 应用 transform ----
        if self.transform is not None:
            img = self.transform(img)
        # transform 通常会把 PIL Image 转成 Tensor,返回 (C, H, W) 的浮点张量

        return img, label

注意 __init__ 里的设计决策:只记录路径,不读取图片

这是非常重要的原则。如果你在初始化时就把所有图片读进内存,10 万张 224×224 的图片大概需要 50GB 内存,根本装不下。正确的做法是只存路径(一个字符串列表),真正需要某张图片时才在 __getitem__ 里读取。

这种"懒加载"的设计使得 Dataset 可以处理任意规模的数据集,哪怕数据集大小远超内存。


第三步:理解 getitem 的调用时机

__getitem__ 不是你手动调用的,是 DataLoader 在需要某个样本时自动调用的。

# 你可以直接用索引访问 Dataset(但通常不这么做)
dataset = ImageFolderDataset('./data/train')

# 这会触发 __getitem__(0)
sample = dataset[0]
print(type(sample))          # <class 'tuple'>
print(len(sample))           # 2  ← (image, label)
print(type(sample[0]))       # <class 'PIL.Image.Image'>(还没加 transform)
print(sample[1])             # 整数标签,如 1

这里有一个值得注意的地方:__getitem__ 在训练时会被调用数百万次(每个 epoch 里每个样本都被访问一次)。它必须足够高效,因为它是整个数据加载管道里最频繁的操作。

常见的性能建议:

  • 图片读取用 PIL 或者 OpenCV,不要用纯 Python 循环逐像素处理
  • 如果数据集很小(比如几千张图),可以在 __init__ 里把所有图片预先读进内存,__getitem__ 直接从内存返回
  • 如果数据集极大,考虑把数据预处理成 LMDB 或者 HDF5 格式,随机访问更快

第四步:transforms——在 Dataset 里做预处理

transform 是一个可调用对象(callable),接收一个样本,返回处理后的样本。torchvision.transforms 提供了大量现成的变换。

from torchvision import transforms

# 训练集的 transform:包含数据增强
train_transform = transforms.Compose([
    # ---- 几何变换(数据增强) ----
    transforms.Resize(256),                # 先缩放到 256(短边)
    transforms.RandomCrop(224),            # 随机裁剪到 224×224
    transforms.RandomHorizontalFlip(p=0.5), # 50% 概率水平翻转

    # ---- 颜色增强(可选,对颜色敏感的任务慎用) ----
    transforms.ColorJitter(
        brightness=0.2,   # 随机调整亮度 ±20%
        contrast=0.2,     # 随机调整对比度 ±20%
        saturation=0.2,   # 随机调整饱和度 ±20%
    ),

    # ---- 转为 Tensor + 归一化 ----
    transforms.ToTensor(),
    # ToTensor 做了两件事:
    # 1. PIL Image (H, W, C) uint8 → Tensor (C, H, W) float32
    # 2. 像素值 [0, 255] → [0.0, 1.0]

    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],   # ImageNet 的均值
        std=[0.229, 0.224, 0.225]     # ImageNet 的标准差
    ),
    # Normalize 做的是:output = (input - mean) / std
    # 让每个通道的数值分布大致在 [-2, 2] 区间,有助于训练稳定
])

# 验证集的 transform:不做随机增强,只做确定性预处理
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),    # 中心裁剪,而不是随机裁剪
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

训练集和验证集用不同的 transform,这是一个必须遵守的原则:

训练集:需要数据增强(随机翻转、随机裁剪、颜色抖动),人为增加样本多样性,提升泛化能力。

验证集:绝对不能用随机增强。验证是为了评估模型的真实性能,如果每次验证时图片内容因随机性而不同,指标就失去了可比性,你没法判断模型是在进步还是运气好。验证集只用确定性变换(Resize + CenterCrop + ToTensor + Normalize)。

ToTensor 和 Normalize 的细节

from torchvision import transforms
from PIL import Image
import torch

# 创建一个测试图片
img = Image.fromarray(
    (torch.randint(0, 256, (64, 64, 3))).numpy().astype('uint8')
)
print(f"PIL Image 大小:{img.size},模式:{img.mode}")
# PIL Image 大小:(64, 64),模式:RGB

# ToTensor
to_tensor = transforms.ToTensor()
tensor = to_tensor(img)
print(f"Tensor 形状:{tensor.shape}")     # torch.Size([3, 64, 64])
print(f"值域:[{tensor.min():.3f}, {tensor.max():.3f}]")  # [0.000, 1.000]
# 注意:PIL 是 (H, W, C),ToTensor 会自动变成 (C, H, W)
# 注意:uint8 [0,255] 被自动转为 float32 [0.0, 1.0]

# Normalize
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
normalized = normalize(tensor)
print(f"归一化后值域:[{normalized.min():.3f}, {normalized.max():.3f}]")
# [-1.000, 1.000]
# (0.0 - 0.5) / 0.5 = -1.0,(1.0 - 0.5) / 0.5 = 1.0

一个常见的错误:先 Normalize 再 ToTensor,或者 Normalize 的顺序弄错Normalize 要求输入已经是 float Tensor(值域 [0, 1]),必须放在 ToTensor 之后。


第五步:把 Dataset 和 DataLoader 接起来

# 实例化 Dataset
train_dataset = ImageFolderDataset('./data/train', transform=train_transform)
val_dataset   = ImageFolderDataset('./data/val',   transform=val_transform)

print(f"训练集大小:{len(train_dataset)}")   # 300(3类 × 100张)
print(f"验证集大小:{len(val_dataset)}")     # 60(3类 × 20张)

# 验证单个样本的形状
img, label = train_dataset[0]
print(f"单个样本 - 图片形状:{img.shape},标签:{label}")
# 单个样本 - 图片形状:torch.Size([3, 224, 224]),标签:0

现在把 DataLoader 接上:

from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset     = train_dataset,
    batch_size  = 32,       # 每个 batch 包含 32 个样本
    shuffle     = True,     # 每个 epoch 开始前打乱顺序(训练集必须 True)
    num_workers = 4,        # 用 4 个子进程并行加载数据
    pin_memory  = True,     # 把数据锁定在内存,加速 CPU→GPU 的数据传输
    drop_last   = True,     # 丢弃最后一个不满 batch_size 的 batch
)

val_loader = DataLoader(
    dataset     = val_dataset,
    batch_size  = 64,       # 验证时不需要梯度,显存充裕,batch 可以更大
    shuffle     = False,    # 验证集不需要打乱
    num_workers = 4,
    pin_memory  = True,
    drop_last   = False,    # 验证集要评估所有样本,不能丢
)

# 查看一个 batch 的形状
images, labels = next(iter(train_loader))
print(f"一个 batch - 图片形状:{images.shape},标签形状:{labels.shape}")
# 一个 batch - 图片形状:torch.Size([32, 3, 224, 224]),标签形状:torch.Size([32])

DataLoader 从 Dataset 里取出单个样本,然后沿着 batch 维度把它们堆叠起来:

Dataset[0]  → (img_tensor [3,224,224], label 1)
Dataset[1]  → (img_tensor [3,224,224], label 0)
...
Dataset[31] → (img_tensor [3,224,224], label 2)
                              ↓ DataLoader 拼接
(images [32,3,224,224], labels [32])

这个"拼接"操作在 DataLoader 里叫做 collate,默认用 torch.stack 把所有样本的 Tensor 堆叠成第 0 维(batch 维)。


第六步:DataLoader 的每个参数都在做什么

batch_size

最直接的参数。每次迭代返回 batch_size 个样本。

batch size 的选择是一个权衡:

  • 太小(如 1 或 4):每个 batch 的梯度估计噪声太大,训练不稳定;GPU 利用率低,大量时间浪费在 kernel 启动开销上
  • 太大:需要更多显存;研究表明过大的 batch 有时会损害泛化性能(陷入锐利极值,参考第 01 篇的宽极值讨论)
  • 常见的选择是 32、64、128,要是 2 的幂次(配合硬件内存对齐效率更高)

shuffle

shuffle=True:每个 epoch 开始前,重新随机排列所有样本的顺序
这保证了每个 batch 里的样本是随机组合的,模型不会记住"第 N 个 batch 总是这些样本"

为什么训练集要 shuffle?
假设你的数据按类别排列:前 1000 张全是猫,接着 1000 张全是狗
不 shuffle 的话,前几个 batch 全是猫,梯度会极度偏向猫类
参数更新方向完全失衡,训练会很混乱

验证集不需要 shuffle:
验证只是在算指标,顺序不影响结果,保持固定顺序还方便调试
 

num_workers

这是理解 DataLoader 工作原理的关键参数。

num_workers=0:主进程单线程加载(默认值)
每次需要一个样本时,主进程暂停训练,去读图片,处理完再继续
GPU 等待期间完全空闲——这是最常见的训练瓶颈

num_workers=4:启动 4 个独立的子进程来加载数据
主进程在 GPU 计算当前 batch 的同时,子进程已经在预读取下一个 batch
GPU 几乎不需要等待,计算效率大幅提升
 

Windows 用户注意:在 Windows 上使用 num_workers > 0 时,必须把 DataLoader 的创建放在 if __name__ == '__main__': 保护块里,否则子进程会递归启动导致报错:

# Windows 上的正确写法
if __name__ == '__main__':
    loader = DataLoader(dataset, batch_size=32, num_workers=4)
    for images, labels in loader:
        ...

pin_memory

pin_memory=True:把加载好的 Tensor 放在"锁页内存"(page-locked memory)
普通内存页面可能被操作系统换出到硬盘(swap),锁页内存保证始终在物理内存里
从锁页内存传输数据到 GPU,可以启用 DMA(直接内存访问),
不需要先把数据拷贝到一个临时缓冲区,速度更快

什么时候开:有 GPU 训练时,几乎总是应该开
什么时候不开:CPU-only 训练,或内存非常紧张时

drop_last

假设训练集有 305 个样本,batch_size=32
305 / 32 = 9 个完整 batch(288 个样本)+ 1 个不完整 batch(17 个样本)

drop_last=True:丢弃最后那个 17 个样本的 batch
好处:所有 batch 大小一致,不会因为最后一个 batch 太小导致统计不稳定(BN 层对 batch size 敏感,最后一个小 batch 可能让 BN 的统计量失真)

drop_last=False(默认):保留最后一个小 batch
验证集通常用 False,确保每个样本都被评估到
 

collate_fn:自定义 batch 打包逻辑

默认的 collate 逻辑是用 torch.stack 把样本堆叠,要求所有样本的形状完全相同。但如果你的数据形状不一样(比如不同长度的文本序列,或者不同尺寸的图片),就需要自定义 collate_fn

def custom_collate_fn(batch):
    """
    batch 是一个列表,每个元素是 __getitem__ 返回的单个样本。
    batch = [(img1, label1), (img2, label2), ..., (img32, label32)]

    默认的 collate 会直接 torch.stack,要求形状相同。
    自定义 collate 可以做 padding、过滤空样本等操作。
    """
    # 过滤掉 __getitem__ 返回 None 的样本(有时候文件损坏时这样处理)
    batch = [b for b in batch if b is not None]

    if len(batch) == 0:
        return None

    images = torch.stack([b[0] for b in batch])    # (B, C, H, W)
    labels = torch.tensor([b[1] for b in batch])   # (B,)
    return images, labels

loader = DataLoader(train_dataset, batch_size=32, collate_fn=custom_collate_fn)

第七步:处理一个真实的边界情况——文件损坏

真实数据集里总有几张坏掉的图片。如果 __getitem__ 直接抛出异常,整个训练就崩了。

防御性写法:

class RobustImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir  = Path(root_dir)
        self.transform = transform
        self.classes   = sorted([d.name for d in self.root_dir.iterdir() if d.is_dir()])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.samples = []
        VALID_EXT = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
        for cls_name in self.classes:
            cls_dir = self.root_dir / cls_name
            label   = self.class_to_idx[cls_name]
            for p in cls_dir.iterdir():
                if p.suffix.lower() in VALID_EXT:
                    self.samples.append((p, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]

        try:
            img = Image.open(img_path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            return img, label

        except Exception as e:
            # 记录错误,但不崩溃
            print(f"警告:无法读取图片 {img_path},错误:{e}")
            # 返回 None,配合自定义 collate_fn 过滤掉
            return None

配合自定义 collate_fn 过滤 None

def collate_skip_none(batch):
    """过滤掉损坏样本(__getitem__ 返回 None 的)"""
    batch = [b for b in batch if b is not None]
    if not batch:
        return None
    images = torch.stack([b[0] for b in batch])
    labels = torch.tensor([b[1] for b in batch])
    return images, labels

robust_loader = DataLoader(
    RobustImageDataset('./data/train', transform=train_transform),
    batch_size=32,
    collate_fn=collate_skip_none,
    num_workers=4,
)

# 训练循环里处理 None batch
for batch in robust_loader:
    if batch is None:
        continue   # 跳过全坏的 batch
    images, labels = batch
    # 正常训练...

第八步:完整的训练循环

把上面所有东西组合成一个完整可运行的训练例子:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from pathlib import Path

# -------- Dataset --------
class ImageFolderDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir  = Path(root_dir)
        self.transform = transform
        self.classes   = sorted([d.name for d in self.root_dir.iterdir() if d.is_dir()])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.samples = []
        VALID_EXT = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
        for cls_name in self.classes:
            label = self.class_to_idx[cls_name]
            for p in (self.root_dir / cls_name).iterdir():
                if p.suffix.lower() in VALID_EXT:
                    self.samples.append((p, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

# -------- Transforms --------
train_transform = transforms.Compose([
    transforms.Resize(80),
    transforms.RandomCrop(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
val_transform = transforms.Compose([
    transforms.Resize(80),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# -------- DataLoader --------
train_dataset = ImageFolderDataset('./data/train', transform=train_transform)
val_dataset   = ImageFolderDataset('./data/val',   transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                          num_workers=2, pin_memory=True, drop_last=True)
val_loader   = DataLoader(val_dataset,   batch_size=64, shuffle=False,
                          num_workers=2, pin_memory=True)

# -------- 模型、优化器、损失函数 --------
device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = len(train_dataset.classes)

# 用一个简单的 CNN(不用预训练,方便观察从随机初始化开始的训练过程)
model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # 64→32
    nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 32→16
    nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), # 16→8
    nn.Flatten(),
    nn.Linear(128 * 8 * 8, 256), nn.ReLU(), nn.Dropout(0.5),
    nn.Linear(256, num_classes),
).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# -------- 训练与验证函数 --------
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for images, labels in loader:
        # 把数据搬到 GPU(如果有)
        images = images.to(device, non_blocking=True)
        # non_blocking=True:配合 pin_memory,让数据传输和 GPU 计算异步进行
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        outputs = model(images)              # 前向传播
        loss    = criterion(outputs, labels) # 计算损失
        loss.backward()                      # 反向传播
        optimizer.step()                     # 参数更新

        # 统计指标
        total_loss += loss.item() * images.size(0)  # loss.item() 是这个 batch 的均值
        _, predicted = outputs.max(dim=1)           # 取概率最大的类别
        correct += predicted.eq(labels).sum().item()
        total   += labels.size(0)

    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0

    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        outputs = model(images)
        loss    = criterion(outputs, labels)

        total_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(dim=1)
        correct += predicted.eq(labels).sum().item()
        total   += labels.size(0)

    return total_loss / total, correct / total

# -------- 主训练循环 --------
num_epochs = 10
print(f"开始训练,设备:{device},类别:{train_dataset.classes}\n")

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss,   val_acc   = evaluate(model, val_loader, criterion, device)

    print(f"Epoch [{epoch:02d}/{num_epochs}]  "
          f"Train Loss: {train_loss:.4f}  Train Acc: {train_acc:.3f}  |  "
          f"Val Loss: {val_loss:.4f}  Val Acc: {val_acc:.3f}")

注意两个细节:

non_blocking=True:配合 pin_memory=True 使用,让 CPU→GPU 的数据传输和 GPU 的计算可以异步进行,进一步减少等待时间。

loss.item() * images.size(0)loss.item() 返回的是这个 batch 的平均 loss,乘以 batch size 才是总 loss,最后除以总样本数才是真正的平均 loss。如果直接累加 loss.item() 再除以 batch 数,当最后一个 batch 大小不同时会有偏差。


第九步:一个更常见的数据组织方式——用 CSV 管理数据集

"每个类别一个文件夹"的方式很直观,但不够灵活。真实项目里更常见的是用一个 CSV 文件来管理数据集,记录每个样本的路径和标签:

# dataset.csv
image_path,label,split
images/0001.jpg,cat,train
images/0002.jpg,dog,train
images/0003.jpg,cat,val
...

对应的 Dataset 写法:

import pandas as pd

class CSVDataset(Dataset):
    """
    从 CSV 文件加载数据集。
    CSV 格式:image_path, label, split
    """
    def __init__(self, csv_path: str, split: str, transform=None):
        df = pd.read_csv(csv_path)

        # 只保留当前 split(train/val/test)的数据
        df = df[df['split'] == split].reset_index(drop=True)

        self.image_paths = df['image_path'].tolist()
        self.transform   = transform

        # 建立标签到整数的映射
        unique_labels    = sorted(df['label'].unique().tolist())
        self.class_to_idx = {cls: idx for idx, cls in enumerate(unique_labels)}
        self.labels      = [self.class_to_idx[l] for l in df['label'].tolist()]

        print(f"[{split}] 加载 {len(self.image_paths)} 个样本,"
              f"{len(unique_labels)} 个类别")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

CSV 方式的好处:

  • 方便做数据集划分(只需修改 CSV 里的 split 列)
  • 可以存储额外的元信息(拍摄日期、来源、权重等)
  • 便于数据版本管理,只需要版本化 CSV 文件

第十步:Dataset 的一些实用技巧

用 Subset 切出子集

from torch.utils.data import Subset
import random

full_dataset = ImageFolderDataset('./data/train', transform=train_transform)

# 随机取 50 个样本做快速调试
indices = random.sample(range(len(full_dataset)), 50)
small_dataset = Subset(full_dataset, indices)

print(len(small_dataset))   # 50

# 调试时先用小数据集跑通,再换回完整数据集
debug_loader = DataLoader(small_dataset, batch_size=8, shuffle=True)

这个技巧非常实用。每次写完新代码,先用一个几十个样本的小数据集跑一遍,确认 loss 在下降、没有 shape 错误,再换成完整数据集跑正式训练,可以省去很多等待时间。

用 random_split 划分训练/验证集

from torch.utils.data import random_split

full_dataset = ImageFolderDataset('./data/train', transform=train_transform)
total = len(full_dataset)

# 80% 训练,20% 验证
train_size = int(0.8 * total)
val_size   = total - train_size

train_sub, val_sub = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)  # 固定随机种子,保证可复现
)

print(f"训练子集:{len(train_sub)},验证子集:{len(val_sub)}")

注意:random_split 只是在索引层面做划分,两个子集共享同一个 Dataset 对象,因此用的是同一个 transform。这意味着验证子集会用训练的数据增强 transform,这通常不是你想要的。

正确做法是为训练和验证分别创建 Dataset 实例(指向同一份数据,但用不同 transform):

# 或者用 indices 手动实现,给不同 split 用不同 transform
all_indices = list(range(len(ImageFolderDataset('./data/train'))))
random.shuffle(all_indices)
split_point = int(0.8 * len(all_indices))

train_indices = all_indices[:split_point]
val_indices   = all_indices[split_point:]

# 两个 Dataset 实例,transform 不同
train_full = ImageFolderDataset('./data/train', transform=train_transform)
val_full   = ImageFolderDataset('./data/train', transform=val_transform)

train_sub = Subset(train_full, train_indices)
val_sub   = Subset(val_full,   val_indices)

用 WeightedRandomSampler 处理类别不均衡

如果你的数据集类别不均衡(比如猫 1000 张,狗 100 张),直接训练模型会偏向多数类。WeightedRandomSampler 可以让少数类被更频繁地采样:

from torch.utils.data import WeightedRandomSampler

# 统计每个类别的样本数
class_counts = [0] * len(train_dataset.classes)
for _, label in train_dataset.samples:
    class_counts[label] += 1
print(f"各类别样本数:{class_counts}")
# [100, 100, 100](我们的假数据集是均衡的,真实场景可能差很多)

# 计算每个样本的采样权重:类别越少,权重越大
# 权重 = 1 / 该类别的样本数
class_weights = [1.0 / c for c in class_counts]
sample_weights = [class_weights[label] for _, label in train_dataset.samples]

# 创建 Sampler
sampler = WeightedRandomSampler(
    weights     = sample_weights,
    num_samples = len(train_dataset),   # 每个 epoch 采样多少个
    replacement = True,                  # 有放回采样(少数类可以被重复采样)
)

# 注意:使用自定义 sampler 时,shuffle 参数必须为 False(它们互斥)
balanced_loader = DataLoader(
    train_dataset,
    batch_size  = 32,
    sampler     = sampler,    # 用 sampler 替代 shuffle
    num_workers = 2,
)
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐