Python / PyTorch 数据格式、数据流与特殊方法:从 IMDB 数据集理解数据是如何被读取和使用的
0. 问题背景
在使用 torchtext.datasets.IMDB 时,我们经常会写:
from torchtext.datasets import IMDB
train_dataset = IMDB(split="train")
刚开始很容易以为 train_dataset 是一个普通列表,可以直接:
len(train_dataset)
train_dataset[0]
但实际测试可能会发现:
len(train_dataset) # 可以用
train_dataset[0] # 报错
这说明:这个对象知道自己有多少条数据,但不支持按索引直接取值。
要理解这个现象,需要搞清楚几个核心问题:
- 数据到底是什么格式?
- 数据是存储在硬盘里,还是内存里?
IterableDataset为什么省内存?list(train_dataset)到底做了什么?__len__()、__getitem__()这些函数为什么有特殊作用?
1. 数据格式到底是什么?
Python 里常见的数据格式可以分成几类。
1.1 基础数据类型
int:整数
label = 1
常用于标签、索引、计数。
例如 IMDB 情感分类数据里,标签可能是:
1
2
代表负面或正面评论。
float:小数
loss = 0.35
learning_rate = 0.001
常用于损失函数、学习率、实验指标。
str:字符串
text = "This movie is very good."
文本数据通常是字符串。
但是神经网络不能直接处理字符串,所以文本最终要转换成数字,比如:
[12, 45, 7, 91]
再变成:
torch.tensor([12, 45, 7, 91])
2. 常见容器数据结构
2.1 list:列表,用 []
data = [1, 2, 3, 4]
特点:
有顺序
可以修改
可以用下标访问
例如:
data[0]
输出:
1
在 IMDB 中,如果我们把数据集转成列表,可能长这样:
train_list = [
(1, "This movie is good."),
(2, "This movie is bad."),
(1, "I like this film.")
]
整体是 list,里面每个元素是一个样本。
2.2 tuple:元组,用 ()
sample = (1, "This movie is good.")
特点:
有顺序
通常不修改
可以用下标访问
例如:
sample[0]
sample[1]
分别得到:
1
"This movie is good."
在 IMDB 里,一个样本通常是:
(label, text)
也就是:
(标签, 评论文本)
例如:
(1, "This movie is good.")
2.3 dict:字典,用 {}
sample = {
"label": 1,
"text": "This movie is good."
}
特点是用 key 取值:
sample["label"]
sample["text"]
字典适合表示有明确字段名的数据。
比如图像任务中,一个样本可能是:
sample = {
"image": image_tensor,
"label": 3,
"filename": "cat_001.jpg"
}
2.4 set:集合,用 {}
vocab = {"good", "bad", "movie"}
特点:
元素不重复
通常不关心顺序
常用于去重,比如构建词表前统计所有出现过的词。
3. IMDB 数据集返回的是什么?
train_dataset = IMDB(split="train")
它通常不是普通的 list,而是一个 IterableDataset / DataPipe 风格的数据流对象。
它的特点是:
for label, text in train_dataset:
...
通常可以用。
但是:
train_dataset[0]
不一定可以用。
也就是说,它更像一个“数据流”,而不是一个已经完整展开的列表。
4. IterableDataset / DataPipe 是什么?
4.1 普通列表像一个盒子
train_list = [
(1, "review text 1"),
(2, "review text 2"),
(1, "review text 3"),
]
它已经完整放在内存里。
所以你可以:
train_list[0]
len(train_list)
它像一个有编号的样本库:
编号 0 -> 第 0 个样本
编号 1 -> 第 1 个样本
编号 2 -> 第 2 个样本
4.2 IterableDataset 像一条水流
train_dataset = IMDB(split="train")
它更像:
数据源 -> 样本1 -> 样本2 -> 样本3 -> ...
你可以顺着数据流一个一个取:
for sample in train_dataset:
print(sample)
break
但是通常不能直接说:
train_dataset[100]
因为它不一定支持随机访问。
5. 为什么 len(train_dataset) 可以,但 train_dataset[0] 不行?
这个现象说明:
它实现了 __len__()
但没有实现真正可用的 __getitem__()
也就是说,它知道自己有多少条数据,但不支持按索引直接跳到第几个样本。
类比一下:
一个排队窗口知道今天一共有 25000 个人
但你不能直接把第 10000 个人拎出来
只能按顺序一个一个叫号
所以:
len(train_dataset)
可以。
但:
train_dataset[0]
不行。
6. 如何检查一个数据对象是什么格式?
拿到任何数据集,可以先运行下面几步。
6.1 看对象类型
print(type(train_dataset))
可能输出类似:
<class 'torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe'>
这说明它是 DataPipe / 可迭代数据流对象。
6.2 看能不能用 len
try:
print(len(train_dataset))
except Exception as e:
print("len 不可用:", type(e).__name__, e)
6.3 看能不能用下标
try:
print(train_dataset[0])
except Exception as e:
print("索引不可用:", type(e).__name__, e)
6.4 看能不能遍历
for sample in train_dataset:
print(sample)
break
如果能输出一个样本,说明它是可迭代对象。
6.5 看一个样本内部是什么结构
sample = next(iter(train_dataset))
print(sample)
print(type(sample))
如果输出:
(1, "This movie is good.")
<class 'tuple'>
说明每个样本是一个 tuple。
进一步拆开:
label, text = sample
print(label, type(label))
print(text, type(text))
可能得到:
1 <class 'int'>
"This movie is good." <class 'str'>
7. list(train_dataset) 到底做了什么?
如果写:
train_list = list(train_dataset)
它会把数据流里的样本一个一个取出来,然后全部装进一个 Python 列表。
等价于:
train_list = []
for sample in train_dataset:
train_list.append(sample)
最后得到:
[
(1, "review text 1"),
(2, "review text 2"),
(1, "review text 3"),
...
]
这时候:
len(train_list)
train_list[0]
都可以用了。
8. 为什么 IterableDataset 更省内存?
因为它不会一开始就把所有数据加载到内存里。
8.1 全量加载
train_list = list(IMDB(split="train"))
这会把整个训练集全部读出来,放到内存中。
如果数据集很小,比如 IMDB,通常问题不大。
但如果数据集有 100GB,这种方式可能会爆内存。
8.2 流式读取
train_dataset = IMDB(split="train")
for label, text in train_dataset:
...
这种方式是需要一个样本,就读取一个样本。
数据流大概是:
硬盘上的数据
↓
读取一个样本
↓
处理一个样本
↓
继续读取下一个样本
内存里不需要同时保存全部数据。
所以:
list / map-style Dataset:
访问方便,但可能占内存大
IterableDataset / DataPipe:
更省内存,但不一定支持随机访问
9. 什么叫“数据加载到内存中”?
数据长期存储在硬盘或 SSD 中。
但是程序运行时,CPU/GPU 通常不能直接在硬盘上计算,而是要先把数据读到内存。
可以这样理解:
硬盘 SSD/HDD:长期仓库
内存 RAM:临时工作台
CPU/GPU:真正干活的人
当你写:
text = open("review.txt", "r", encoding="utf-8").read()
发生的是:
1. Python 找到硬盘上的 review.txt
2. 操作系统读取文件内容
3. 数据被复制到内存 RAM
4. Python 创建一个 str 字符串对象
5. 变量 text 指向这个对象
10. 一行代码过去之后,数据会消失吗?
不一定。
关键看有没有变量引用它。
10.1 有变量接住
x = [1, 2, 3]
这一行执行完后,列表还在内存里,因为变量 x 指向它。
可以理解为:
x ───> [1, 2, 3]
所以后面还能用:
print(x)
10.2 没有变量接住
[1, 2, 3]
这一行也会创建一个列表。
但是没有变量保存它。
所以执行完之后,这个对象没有引用,后面会被 Python 回收。
11. Python 程序结束后,内存会释放吗?
会。
当 Python 程序结束时,它占用的 RAM 内存通常会被操作系统释放。
但是要分清楚:
硬盘上的原始数据文件:不会被删除
内存中的临时对象:会被释放
例如:
train_list = list(IMDB(split="train"))
程序运行时,train_list 在内存中。
程序结束后,train_list 消失。
但是硬盘上的 IMDB 数据还在。
11.1 Jupyter Notebook 要注意
在 Jupyter 中,一个 cell 运行结束,不等于 Python 程序结束。
只要 kernel 没有重启,变量还在内存中。
例如运行:
train_list = list(IMDB(split="train"))
这个 cell 结束后,train_list 仍然可以用。
但是如果你点击:
Restart Kernel
变量就没了,内存被释放。
12. 如果有 100GB 数据,内存必须大于 100GB 吗?
不一定。
关键看你是不是一次性全读入内存。
12.1 如果一次性全加载
data = list(dataset)
如果数据真的有 100GB,那么内存通常要大于 100GB。
而且还要考虑:
Python 对象开销
中间计算结果
batch 数据
模型参数
系统占用
所以实际可能需要 128GB、192GB 甚至更多。
12.2 如果分批读取
深度学习一般不是一次性读取所有数据,而是用 DataLoader 分批读取:
for batch in dataloader:
...
数据流是:
硬盘中的 100GB 数据
↓
每次读取一个 batch
↓
送入模型
↓
计算完成后继续下一个 batch
内存中同时存在的只是当前 batch 和少量缓存。
所以即使数据集有 100GB,也可能用 32GB 或 64GB 内存训练。
13. DataLoader 是什么?
DataLoader 是 PyTorch 的批量加载器。
例如:
from torch.utils.data import DataLoader
dataloader = DataLoader(
train_dataset,
batch_size=4,
shuffle=True,
collate_fn=collate_batch
)
它的作用是把单个样本组成 batch。
单个样本:
(label, text)
一个 batch:
[
(label1, text1),
(label2, text2),
(label3, text3),
(label4, text4)
]
然后交给 collate_fn 处理成模型能吃的张量。
14. random_split(list(train_dataset), [20000, 5000]) 是什么意思?
代码:
from torch.utils.data import random_split
train_dataset, valid_dataset = random_split(
list(train_dataset), [20000, 5000]
)
这句可以分成三步理解。
14.1 第一步:转成 list
full_train_list = list(train_dataset)
这一步把原来的数据流对象全部展开成列表:
[
(1, "review text 1"),
(2, "review text 2"),
...
]
14.2 第二步:随机切分
train_subset, valid_subset = random_split(
full_train_list,
[20000, 5000]
)
表示把 25000 条数据随机分成:
训练集:20000 条
验证集:5000 条
14.3 第三步:变量覆盖
train_dataset = train_subset
valid_dataset = valid_subset
注意:这里新的 train_dataset 已经不是原来的 IMDB 数据流对象了,而是一个 Subset 对象。
15. Subset 是什么?
Subset 表示原始数据集的一个子集。
它不是把数据完整复制一份,而是保存:
原始 dataset
选中的 indices
例如:
data = [
(1, "text A"), # index 0
(2, "text B"), # index 1
(1, "text C"), # index 2
(2, "text D"), # index 3
]
如果:
from torch.utils.data import Subset
subset = Subset(data, [2, 0])
那么:
subset[0]
实际返回:
data[2]
也就是:
(1, "text C")
而:
subset[1]
实际返回:
data[0]
也就是:
(1, "text A")
16. __len__() 和 __getitem__() 到底是什么?
这是 Python 里非常重要的机制。
像:
__len__()
__getitem__()
__iter__()
__next__()
这种前后都有两个下划线的方法,叫做 特殊方法,也叫 dunder method。
dunder 是 double underscore 的意思。
这些名字不是随便起的,而是 Python 语言规定好的。
17. 为什么函数名有下划线,但调用时不用写?
因为你写:
len(obj)
Python 背后会自动调用:
obj.__len__()
你写:
obj[0]
Python 背后会自动调用:
obj.__getitem__(0)
也就是说:
外部语法:len(obj)
背后机制:obj.__len__()
外部语法:obj[0]
背后机制:obj.__getitem__(0)
18. 自定义类实现 __len__
例如:
class MyDataset:
def __len__(self):
return 3
使用:
dataset = MyDataset()
print(len(dataset))
输出:
3
因为:
len(dataset)
会自动调用:
dataset.__len__()
19. 自定义类实现 __getitem__
例如:
class MyDataset:
def __init__(self):
self.data = ["A", "B", "C"]
def __getitem__(self, idx):
return self.data[idx]
使用:
dataset = MyDataset()
print(dataset[0])
print(dataset[1])
输出:
A
B
因为:
dataset[0]
背后调用:
dataset.__getitem__(0)
20. 同时实现 __len__ 和 __getitem__
class MyDataset:
def __init__(self):
self.data = ["A", "B", "C"]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
dataset = MyDataset()
print(len(dataset))
print(dataset[0])
输出:
3
A
这个对象就同时支持:
len(dataset)
dataset[0]
21. 方法名可以自己随便取吗?
普通方法可以随便取名,但特殊语法只认 Python 规定的名字。
例如:
class MyDataset:
def __init__(self):
self.data = ["A", "B", "C"]
def length(self):
return len(self.data)
def get_item(self, idx):
return self.data[idx]
这样你可以手动调用:
dataset.length()
dataset.get_item(0)
但是不能自动支持:
len(dataset)
dataset[0]
因为 Python 只认:
__len__
__getitem__
不认:
length
get_item
22. 常见特殊方法总结
| 外部写法 | 背后调用 |
|---|---|
len(obj) |
obj.__len__() |
obj[i] |
obj.__getitem__(i) |
obj[i] = x |
obj.__setitem__(i, x) |
for x in obj: |
obj.__iter__() |
next(obj) |
obj.__next__() |
obj() |
obj.__call__() |
print(obj) |
obj.__str__() 或 obj.__repr__() |
a + b |
a.__add__(b) |
a * b |
a.__mul__(b) |
23. 这和 PyTorch Dataset 有什么关系?
PyTorch 的 Dataset 就是利用这套机制。
通常自定义数据集要写:
from torch.utils.data import Dataset
class IMDBMapDataset(Dataset):
def __init__(self, data_list):
self.data = data_list
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
这样 PyTorch 就知道:
这个数据集有多长
第 idx 个样本怎么取
于是 DataLoader 才能工作:
DataLoader 想组成一个 batch
↓
决定取哪些索引
↓
调用 dataset[5], dataset[12], dataset[90]
↓
拿到多个样本
↓
组成 batch
24. 为什么有时要自己包装 Dataset?
如果你直接有一个 list:
train_list = list(IMDB(split="train"))
其实可以直接交给 DataLoader:
dataloader = DataLoader(
train_list,
batch_size=4,
shuffle=True,
collate_fn=collate_batch
)
但是如果你想在取样本时自动处理标签和文本,可以自己包装:
class IMDBMapDataset(Dataset):
def __init__(self, data_list, label_pipeline=None, text_pipeline=None):
self.data = data_list
self.label_pipeline = label_pipeline
self.text_pipeline = text_pipeline
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
label, text = self.data[idx]
if self.label_pipeline is not None:
label = self.label_pipeline(label)
if self.text_pipeline is not None:
text = self.text_pipeline(text)
return label, text
这样每次:
dataset[0]
都会自动执行:
取出原始 label, text
↓
处理 label
↓
处理 text
↓
返回处理后的样本
25. IMDB 数据的完整流动过程
以 IMDB 文本分类为例,完整流程可以写成:
硬盘中的 IMDB 原始数据
↓
IMDB(split="train") 创建数据流对象
↓
for 循环或 list(...) 触发读取
↓
得到单个样本 (label, text)
↓
label_pipeline 处理标签
↓
text_pipeline 处理文本
↓
DataLoader 组成 batch
↓
collate_fn padding 并转成 Tensor
↓
送入神经网络模型
26. 最重要的几个判断标准
以后看到一个数据对象,先问四个问题。
26.1 它是什么类型?
print(type(data))
26.2 它能不能取长度?
len(data)
如果可以,说明实现了:
__len__()
26.3 它能不能按索引取值?
data[0]
如果可以,说明实现了:
__getitem__()
26.4 它能不能遍历?
for sample in data:
print(sample)
break
如果可以,说明它是可迭代对象,通常实现了:
__iter__()
27. 三种常见情况对比
| 类型 | len() |
[0] |
for 循环 |
特点 |
|---|---|---|---|---|
list |
可以 | 可以 | 可以 | 数据已经装进内存 |
Dataset |
通常可以 | 通常可以 | 可以 | map-style 数据集 |
IterableDataset / DataPipe |
不一定,可以或不可以 | 通常不可以 | 可以 | 数据流,省内存 |
generator |
通常不可以 | 不可以 | 可以 | 用 yield 逐个产生 |
你的情况是:
IMDB(split="train")
len() 可以
[0] 不可以
for 可以
所以它是:
有长度信息的可迭代数据流对象
28. 一句话总括
Python / PyTorch 中的数据处理,本质上是在处理三件事:
1. 数据是什么结构?
list、tuple、dict、Tensor、Dataset、IterableDataset
2. 数据在哪里?
硬盘、内存、GPU 显存
3. 数据怎么流动?
一次性加载,还是按 batch / 按样本流式读取
而 __len__()、__getitem__() 这些特殊方法,则是 Python 提供的标准接口。
只要你的类实现了这些接口,它就可以支持:
len(obj)
obj[0]
for x in obj:
...
这就是为什么 PyTorch 的 Dataset、DataLoader 能够以统一方式处理各种不同数据来源。
29. 最核心的学习结论
list:
像一个已经装满数据的盒子,可以按编号取。
IterableDataset / DataPipe:
像一条数据水流,需要时一个一个吐出,更省内存。
Dataset:
PyTorch 的标准数据集接口,通常通过 __len__ 和 __getitem__ 定义。
DataLoader:
批量加载器,把单个样本组成 batch。
torch.Tensor:
模型真正能够计算的数据格式。
__len__:
让对象支持 len(obj)。
__getitem__:
让对象支持 obj[index]。
list(dataset):
把数据流全部展开,加载进内存,变成 Python 列表。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)