pytorch初学笔记(五):torchvision中dataset的最详细使用(以CIFAR10和MNIST为例)
目录
一、torchvision介绍
1. 作用与结构
torchvision — Torchvision main documentation
torchvision是pytorch下的一个包,主要由计算机视觉中的流行数据集、模型体系结构和常见图像转换等模块组成。
常用的包:
- Transforming and augmenting images:进行图片变换等。
- Models and pre-trained weights:提供一些预训练好的神经网络或权重参数等。
- Dataset :提供常用的数据集。
2. torchvision中常用数据集
Datasets — Torchvision main documentation
Datasets模块提供了需要常用的数据集以及其具体的使用方法,比如下图所示的图像分类中常用的CIFAR10数据集,图像检测中常用的COCO数据集等。
下面具体说明如何对CIFAR10进行下载和使用。
二、CIFAR10的介绍
1. 数据集简介
CIFAR-10 and CIFAR-100 datasets (toronto.edu)
- CIFAR-10是一个更接近普适物体的彩色图像的小型数据集。
- 一共包含10 个类别的RGB 彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
- 每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。
2. 使用该数据集的所需参数
CIFAR10 — Torchvision main documentation
需要设定的5个参数:
1. root(字符串型):把数据集下载到的位置路径。
2. train(布尔型):是否把该数据集作为训练数据集使用。
- True: 作为训练数据集创建
- False:不作为训练数据集,作为测试数据集创建
3. transform:图像需要进行的变换操作,一般使用compose把所需的transforms结合起来。
4. target_transform:对于标签需要做的变换
5. download(布尔型):是否下载数据集。
- True:把数据集下载到root指定的对应位置;如果数据集以及进行过下载,则不会再一次下载
- False:不下载数据集
3. 数据集下载
3.1 pycharm在线下载(下载速度较快时)
1. 导入torchvision包,然后依次创建训练数据集和测试数据集。
注意:训练数据集的train参数要设置为True,测试数据集的train设置为False
import torchvision
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,download=True)
2. 点击运行,等待一段时间后显示下载成功
3. 观察项目包目录,可以发现自动创建了名为dataset3的文件夹,下载的解压文件和解压好的数据集都在其中。
3.2 第三方下载
如果在pycharm中下载速度很慢的话,可以找到pycharm所用的下载链接,然后自己使用迅雷等下载软件进行快速下载。
如何找到下载链接?
- 把鼠标移动到想要下载的数据集名称上,然后Ctrl+C,进入该数据集的帮助文档。
2. 可以看到对应的下载文件名和下载链接。
3. 使用迅雷或者浏览器下载,然后把下载过后的压缩文件按照root中定义的路径创建文件夹,然后把文件放入文件夹中,注意,自己创建的文件夹一定要和root中定义的文件夹姓名相同才行,否则后期扫描不到该数据集。
4. 运行上面在线下载中定义的语句,可以发现程序不会再次下载数据集文件,而是会帮你解压好数据集。
3.3 数据库的下载总结
无论是否需要在线下载数据集,都推荐把download参数值设为True。
因为程序可以帮你自动完成下载解压工作,就算自己下载过文件,也可以提供解压功能,因此更加方便。
三、 CIFAR10的具体使用
1. 数据集对象的显示(PIL型)
import torchvision
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,download=True)
#1. 查看数据集的图片
#输出所有类别
print(test_set.classes)
#输出数据集第一张图片的类型
print(test_set[0])
#输出图片的PIL型格式和标签
img,label = test_set[0]
print(label,test_set.classes[label])
img.show()
1. 数据集所有类别的查看
图片有十个类,对应的类别名称存储在dataset.classes列表中。
2. 数据集中单个具体对象的查看
想要输出数据集中具体的某一张图片,使用下标调用方式dataset[x]即可显示第x+1张图片;输出的对象类型为一个元组,里面第一项是PIL类型的图片,第二项是图片的标签。
3. 数据集中图片对象和标签的定义
可以使用 img,label = dataset[x] 的方式接收对象中的图片和label,然后可以用print进行对label的输出,也可以用 dataset. classes[label]的格式进行对该类别名称的显示。
4. 数据集中图片的可视化
使用img.show()方法进行图片的可视化显示。
输出结果如下:
打开的对应图片如下图所示,由于数据集中的图片较小,所以不清晰,但是可以看出来是一只小猫的图片。
2. 把数据集中的图片对象转换为tensor型
2.1 转换所需transform的定义
因为需要完成数据集中所有图片类型从PIL到tensor的转换,我们需要用到transforms工具,也需要设定数据集中的transform参数。
我们在数据集定义的语句之前定义我们需要的transform,由于一般需要对图像做的变换不止一个,所以我们使用compose来对多个transforms进行组合。在这里我们只需要一个ToTensor即可。
下面代码给出使用compose定义transform和不使用compose的两个版本,都可以完成成功运行。
- 使用compose:
import torchvision
#定义transforms
dataset_transform = torchvision.transforms.Compose([
#定义totensor
torchvision.transforms.ToTensor()
])
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,transform=dataset_transform,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,transform=dataset_transform,download=True)
- 不使用compose:
import torchvision
#定义transforms
from torch.utils.tensorboard import SummaryWriter
trans_totensor_tool = torchvision.transforms.ToTensor()
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,transform=trans_totensor_tool,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,transform=trans_totensor_tool,download=True)
、
2.2 使用tensorboard进行图片显示
完成了transform和数据集的定义后,即可使用add_image()方法完成图片显示。在这里我们使用for循环进行10张图片的显示。
import torchvision
#定义transforms
from torch.utils.tensorboard import SummaryWriter
trans_totensor_tool = torchvision.transforms.ToTensor()
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,transform=trans_totensor_tool,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,transform=trans_totensor_tool,download=True)
#使用tensorboard进行显示
writer = SummaryWriter("logs")
#for循环完成10张图片的显示
for i in range(10):
img,label=test_set[i]
writer.add_image("dataset",img,i)
writer.close()
结果如下所示。可以看到一共step=9,成功显示了数据集中第1-10张图片。
四、练习:MNIST数据集的下载和使用
1. 可能的报错和修改
使用上面做过的练习对MNIST数据集进行相同的操作,注意在下载数据集后可能会爆“UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors.” 的错误,按照博文的方法修改即可。
2. 代码实现
对于PIL对象:
- 完成数据集所有类别的输出(classes)
- 输出数据集中的第一个对象
- 完成前10张图片对应类别的输出
- 完成第10张图片的显示(show方法)
对于tensor对象:
- 把数据集中所有图片类型从PIL型转换为tensor型,重定义图片大小为10*10(使用Compose,ToTensor和Resize)
- 输出前10张图片
2.1 PIL对象实现
import torchvision
from torch.utils.tensorboard import SummaryWriter
train_set = torchvision.datasets.MNIST(root="./MNIST_test",train=True,download=True)
test_set = torchvision.datasets.MNIST(root="./MNIST_test",train=False,download=True)
#pil型对象显示
print(test_set.classes)
print(test_set[0])
for i in range(10):
img,label=test_set[i]
print(test_set.classes[label])
img.show()
2.2 tensor对象实现
import torchvision
from torch.utils.tensorboard import SummaryWriter
trans_tool = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Resize((10,10))
])
train_set = torchvision.datasets.MNIST(root="./MNIST_test",train=True,transform=trans_tool,download=True)
test_set = torchvision.datasets.MNIST(root="./MNIST_test",train=False,transform=trans_tool,download=True)
#tensor型对象显示
writer = SummaryWriter("logs")
for i in range(10):
img,label=test_set[i]
writer.add_image("MNIST",img,i)
print(img.shape)
writer.close()
3. 运行结果
数据集下载并创建成功:
显示第10张图片:
print的显示结果:
在未改变大小之前的维度是(1,28,28),resize后可见tensor的维度变成了(1,10,10 )
,
tensoeboard显示结果:
更多推荐
所有评论(0)