本文核心讲解如何基于PyTorch的DatasetDataLoader构建高效、可无缝集成到深度学习模型的图像数据加载流水线,同时覆盖自定义数据集处理与内置数据集快速使用两大场景。

一、教程核心目标与示例数据集

1. 核心目标

构建完整的图像数据处理流水线,实现:

  • 数据集的结构化划分(训练集+验证集)
  • 图像数据的加载与自动化预处理/增强
  • 按批次高效读取数据并输入模型

2. 示例数据集

使用5类花卉数据集,包含:郁金香(Tulips)、雏菊(Daisy)、蒲公英(Dandelion)、玫瑰(Roses)、向日葵(Sunflowers),原始数据按"类别文件夹+图像文件"的结构组织。

二、开发环境配置

1. 必需依赖包

通过pip一键安装:

pip install torch torchvision matplotlib opencv-contrib-python imutils

2. 备选方案

若本地环境配置困难,可通过PyImageSearch University获取预配置的Google Colab笔记本,支持Windows/macOS/Linux全平台直接运行,无需本地安装依赖。

三、项目结构与文件说明

下载并解压教程资源后,项目目录结构如下:

├── build_dataset.py          # 数据集划分脚本
├── builtin_dataset.py        # PyTorch内置数据集加载脚本
├── flower_photos/            # 原始花卉数据集
│   ├── daisy/
│   ├── dandelion/
│   ├── roses/
│   ├── sunflowers/
│   └── tulips/
├── load_and_visualize.py     # 数据加载与批次可视化脚本
└── pyimagesearch/
    ├── config.py             # 全局配置文件
    └── __init__.py           # 包初始化文件

四、全局配置文件(config.py)

集中管理所有可配置参数,避免硬编码:

# specify path to the flowers and mnist dataset
FLOWERS_DATASET_PATH = "flower_photos"
MNIST_DATASET_PATH = "mnist"
# specify the paths to our training and validation set 
TRAIN = "train"
VAL = "val"
# set the input height and width
INPUT_HEIGHT = 128
INPUT_WIDTH = 128
# set the batch size and validation data split
BATCH_SIZE = 8
VAL_SPLIT = 0.1

参数名

取值/含义

FLOWERS_DATASET_PATH

原始花卉数据集根目录

MNIST_DATASET_PATH

MNIST内置数据集保存目录

TRAIN/VAL

训练集/验证集输出目录名

INPUT_HEIGHT/WIDTH

模型输入图像尺寸(128×128)

BATCH_SIZE

数据批次大小(8)

VAL_SPLIT

验证集占总数据集的比例(0.1,即10%)

五、数据集划分(build_dataset.py)

将原始数据集按比例划分为训练集和验证集,保持类别分布均匀。

# USAGE  
# python build_dataset.py
# import necessary packages
from pyimagesearch import config  
from imutils import paths  
import numpy as np  
import shutil  
import os

1. 核心函数:copy_images

  • 功能:接收图像路径列表和目标目录,自动按类别创建子文件夹并复制图像
  • 关键逻辑:从图像路径中提取类别名(路径倒数第二层),在目标目录下创建对应类别文件夹,再将图像复制到对应位置
def copy_images(imagePaths, folder):
	# check if the destination folder exists and if not create it
	if not os.path.exists(folder):
		os.makedirs(folder)
	# loop over the image paths
	for path in imagePaths:
		# grab image name and its label from the path and create
		# a placeholder corresponding to the separate label folder
		imageName = path.split(os.path.sep)[-1]
		label = path.split(os.path.sep)[-2]
		labelFolder = os.path.join(folder, label)
		# check to see if the label folder exists and if not create it
		if not os.path.exists(labelFolder):
			os.makedirs(labelFolder)
		# construct the destination image path and copy the current
		# image to it
		destination = os.path.join(labelFolder, imageName)
		shutil.copy(path, destination)

2. 划分流程

  1. 加载所有图像路径并通过np.random.shuffle()随机打乱,保证训练集和验证集的类别分布一致
  2. VAL_SPLIT计算验证集和训练集的图像数量
  3. 分别将训练集、验证集图像复制到train/val/目录
# load all the image paths and randomly shuffle them
print("[INFO] loading image paths...")
imagePaths = list(paths.list_images(config.FLOWERS_DATASET_PATH))
np.random.shuffle(imagePaths)
# generate training and validation paths
valPathsLen = int(len(imagePaths) * config.VAL_SPLIT)
trainPathsLen = len(imagePaths) - valPathsLen
trainPaths = imagePaths[:trainPathsLen]
valPaths = imagePaths[trainPathsLen:]
# copy the training and validation images to their respective
# directories
print("[INFO] copying training and validation images...")
copy_images(trainPaths, config.TRAIN)
copy_images(valPaths, config.VAL)

3. 最终目录结构

划分后生成独立的训练集和验证集目录,均保持"根目录→类别文件夹→图像文件"的结构:

├── train/
│   ├── daisy/
│   ├── dandelion/
│   ├── roses/
│   ├── sunflowers/
│   └── tulips/
└── val/
    ├── daisy/
    ├── dandelion/
    ├── roses/
    ├── sunflowers/
    └── tulips/

六、PyTorch Dataset与DataLoader核心实现(load_and_visualize.py)

这是教程的核心部分,讲解如何加载数据、应用增强并构建可迭代的DataLoader。

1. 关键导入

  • ImageFolder:PyTorch内置的图像数据集类,用于加载按类别分文件夹的数据集
  • DataLoader:将Dataset包装为可迭代对象,实现按批次加载
  • transforms:提供图像预处理和数据增强的内置函数
# USAGE
# python load_and_visualize.
# import necessary packages
from pyimagesearch import config
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import torch

2. 批次可视化函数:visualize_batch

  • 输入:数据批次、类别列表、数据集类型(train/val)
  • 核心处理:
    1. 将PyTorch张量(通道优先格式:C×H×W)转换为numpy数组(通道最后格式:H×W×C)
    2. 将归一化到[0,1]的像素值还原为[0,255]的整数格式
    3. 绘制批次中所有图像,并标注对应的类别名称
def visualize_batch(batch, classes, dataset_type):
	# initialize a figure
	fig = plt.figure("{} batch".format(dataset_type),
		figsize=(config.BATCH_SIZE, config.BATCH_SIZE))
	# loop over the batch size
	for i in range(0, config.BATCH_SIZE):
		# create a subplot
		ax = plt.subplot(2, 4, i + 1)
		# grab the image, convert it from channels first ordering to
		# channels last ordering, and scale the raw pixel intensities
		# to the range [0, 255]
		image = batch[0][i].cpu().numpy()
		image = image.transpose((1, 2, 0))
		image = (image * 255.0).astype("uint8")
		# grab the label id and get the label from the classes list
		idx = batch[1][i]
		label = classes[idx]
		# show the image along with the label
		plt.imshow(image)
		plt.title(label)
		plt.axis("off")
	# show the plot
	plt.tight_layout()
	plt.show()

3. 数据预处理与增强

针对训练集和验证集设计不同的变换流水线(验证集不使用数据增强,仅做必要的格式转换):

# 训练集变换:包含数据增强
trainTransforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(p=0.25),  # 25%概率水平翻转
    transforms.RandomVerticalFlip(p=0.25),    # 25%概率垂直翻转
    transforms.RandomRotation(degrees=15),    # 随机旋转±15度
    transforms.ToTensor()                     # 转换为张量并归一化到[0,1]
])

# 验证集变换:仅调整尺寸和格式转换
valTransforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])
  • 关键说明:ToTensor()不仅完成数据类型转换,还会自动将PIL图像或numpy数组的像素值从[0,255]归一化到[0,1]区间。

4. ImageFolder数据集创建

  • 输入要求:数据集必须遵循"根目录/类别名/图像文件"的结构
  • 自动功能:识别所有唯一类别并映射为整数标签(0~4对应5类花卉)
  • 代码实现:
trainDataset = ImageFolder(root=config.TRAIN, transform=trainTransforms)
valDataset = ImageFolder(root=config.VAL, transform=valTransforms)
  • 核心方法:
    • __len__():返回数据集的总样本数
    • __getitem__(index):通过索引获取单个样本,返回格式为(图像张量, 整数标签)

5. DataLoader配置

将Dataset包装为可迭代对象,支持按批次加载和并行处理:

# 训练集DataLoader:开启shuffle打乱样本,优化梯度下降收敛
trainDataLoader = DataLoader(trainDataset, batch_size=8, shuffle=True)
# 验证集DataLoader:无需打乱
valDataLoader = DataLoader(valDataset, batch_size=8)
  • 核心作用:将数据集划分为固定大小的批次,支持模型批量处理;训练集开启shuffle=True可避免模型学习到数据的顺序特征。

6. 批次获取与可视化

通过iter()将DataLoader转换为迭代器,再通过next()获取单个批次,调用visualize_batch函数展示批次中的图像和标签。

七、PyTorch内置数据集使用(builtin_dataset.py)

PyTorch的torchvision.datasets模块提供了大量常用计算机视觉数据集的一键下载和加载功能,包括MNIST、CIFAR-10、CIFAR-100、CelebA等。

1. MNIST数据集加载示例

# 加载训练集(自动下载)
trainDataset = MNIST(root=config.MNIST_DATASET_PATH, train=True, download=True, transform=transforms.ToTensor())
# 加载测试集
valDataset = MNIST(root=config.MNIST_DATASET_PATH, train=False, download=True, transform=transforms.ToTensor())
  • 关键参数:
    • train=True:加载训练集;train=False:加载测试集
    • download=True:自动将数据集下载到指定的root目录

2. 适配调整

  • 可视化函数修改:MNIST为单通道灰度图,绘制时需指定cmap="gray"
  • DataLoader配置与自定义数据集完全一致

八、教程总结

本文完整实现了PyTorch图像数据加载的全流程:

  1. 完成了自定义数据集的结构化划分,保证训练集和验证集的类别分布均匀
  2. 掌握了transforms模块的使用,实现了数据预处理和训练集数据增强
  3. 理解了ImageFolderDataLoader的工作原理,构建了高效的批次数据加载流水线
  4. 学会了快速加载PyTorch内置数据集,简化常用数据集的使用流程

最终构建的数据加载流水线可直接无缝集成到任意PyTorch深度学习模型中,用于模型训练和验证。

参考文章:Image Data Loaders in PyTorch - PyImageSearch

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐