0. 问题背景

在使用 torchtext.datasets.IMDB 时,我们经常会写:

from torchtext.datasets import IMDB

train_dataset = IMDB(split="train")

刚开始很容易以为 train_dataset 是一个普通列表,可以直接:

len(train_dataset)
train_dataset[0]

但实际测试可能会发现:

len(train_dataset)      # 可以用
train_dataset[0]        # 报错

这说明:这个对象知道自己有多少条数据,但不支持按索引直接取值。

要理解这个现象,需要搞清楚几个核心问题:

  1. 数据到底是什么格式?
  2. 数据是存储在硬盘里,还是内存里?
  3. IterableDataset 为什么省内存?
  4. list(train_dataset) 到底做了什么?
  5. __len__()__getitem__() 这些函数为什么有特殊作用?

1. 数据格式到底是什么?

Python 里常见的数据格式可以分成几类。

1.1 基础数据类型

int:整数

label = 1

常用于标签、索引、计数。

例如 IMDB 情感分类数据里,标签可能是:

1
2

代表负面或正面评论。


float:小数

loss = 0.35
learning_rate = 0.001

常用于损失函数、学习率、实验指标。


str:字符串

text = "This movie is very good."

文本数据通常是字符串。

但是神经网络不能直接处理字符串,所以文本最终要转换成数字,比如:

[12, 45, 7, 91]

再变成:

torch.tensor([12, 45, 7, 91])

2. 常见容器数据结构

2.1 list:列表,用 []

data = [1, 2, 3, 4]

特点:

有顺序
可以修改
可以用下标访问

例如:

data[0]

输出:

1

在 IMDB 中,如果我们把数据集转成列表,可能长这样:

train_list = [
    (1, "This movie is good."),
    (2, "This movie is bad."),
    (1, "I like this film.")
]

整体是 list,里面每个元素是一个样本。


2.2 tuple:元组,用 ()

sample = (1, "This movie is good.")

特点:

有顺序
通常不修改
可以用下标访问

例如:

sample[0]
sample[1]

分别得到:

1
"This movie is good."

在 IMDB 里,一个样本通常是:

(label, text)

也就是:

(标签, 评论文本)

例如:

(1, "This movie is good.")

2.3 dict:字典,用 {}

sample = {
    "label": 1,
    "text": "This movie is good."
}

特点是用 key 取值:

sample["label"]
sample["text"]

字典适合表示有明确字段名的数据。

比如图像任务中,一个样本可能是:

sample = {
    "image": image_tensor,
    "label": 3,
    "filename": "cat_001.jpg"
}

2.4 set:集合,用 {}

vocab = {"good", "bad", "movie"}

特点:

元素不重复
通常不关心顺序

常用于去重,比如构建词表前统计所有出现过的词。


3. IMDB 数据集返回的是什么?

train_dataset = IMDB(split="train")

它通常不是普通的 list,而是一个 IterableDataset / DataPipe 风格的数据流对象

它的特点是:

for label, text in train_dataset:
    ...

通常可以用。

但是:

train_dataset[0]

不一定可以用。

也就是说,它更像一个“数据流”,而不是一个已经完整展开的列表。


4. IterableDataset / DataPipe 是什么?

4.1 普通列表像一个盒子

train_list = [
    (1, "review text 1"),
    (2, "review text 2"),
    (1, "review text 3"),
]

它已经完整放在内存里。

所以你可以:

train_list[0]
len(train_list)

它像一个有编号的样本库:

编号 0 -> 第 0 个样本
编号 1 -> 第 1 个样本
编号 2 -> 第 2 个样本

4.2 IterableDataset 像一条水流

train_dataset = IMDB(split="train")

它更像:

数据源 -> 样本1 -> 样本2 -> 样本3 -> ...

你可以顺着数据流一个一个取:

for sample in train_dataset:
    print(sample)
    break

但是通常不能直接说:

train_dataset[100]

因为它不一定支持随机访问。


5. 为什么 len(train_dataset) 可以,但 train_dataset[0] 不行?

这个现象说明:

它实现了 __len__()
但没有实现真正可用的 __getitem__()

也就是说,它知道自己有多少条数据,但不支持按索引直接跳到第几个样本。

类比一下:

一个排队窗口知道今天一共有 25000 个人
但你不能直接把第 10000 个人拎出来
只能按顺序一个一个叫号

所以:

len(train_dataset)

可以。

但:

train_dataset[0]

不行。


6. 如何检查一个数据对象是什么格式?

拿到任何数据集,可以先运行下面几步。

6.1 看对象类型

print(type(train_dataset))

可能输出类似:

<class 'torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe'>

这说明它是 DataPipe / 可迭代数据流对象。


6.2 看能不能用 len

try:
    print(len(train_dataset))
except Exception as e:
    print("len 不可用:", type(e).__name__, e)

6.3 看能不能用下标

try:
    print(train_dataset[0])
except Exception as e:
    print("索引不可用:", type(e).__name__, e)

6.4 看能不能遍历

for sample in train_dataset:
    print(sample)
    break

如果能输出一个样本,说明它是可迭代对象。


6.5 看一个样本内部是什么结构

sample = next(iter(train_dataset))

print(sample)
print(type(sample))

如果输出:

(1, "This movie is good.")
<class 'tuple'>

说明每个样本是一个 tuple

进一步拆开:

label, text = sample

print(label, type(label))
print(text, type(text))

可能得到:

1 <class 'int'>
"This movie is good." <class 'str'>

7. list(train_dataset) 到底做了什么?

如果写:

train_list = list(train_dataset)

它会把数据流里的样本一个一个取出来,然后全部装进一个 Python 列表。

等价于:

train_list = []

for sample in train_dataset:
    train_list.append(sample)

最后得到:

[
    (1, "review text 1"),
    (2, "review text 2"),
    (1, "review text 3"),
    ...
]

这时候:

len(train_list)
train_list[0]

都可以用了。


8. 为什么 IterableDataset 更省内存?

因为它不会一开始就把所有数据加载到内存里。

8.1 全量加载

train_list = list(IMDB(split="train"))

这会把整个训练集全部读出来,放到内存中。

如果数据集很小,比如 IMDB,通常问题不大。

但如果数据集有 100GB,这种方式可能会爆内存。


8.2 流式读取

train_dataset = IMDB(split="train")

for label, text in train_dataset:
    ...

这种方式是需要一个样本,就读取一个样本。

数据流大概是:

硬盘上的数据
    ↓
读取一个样本
    ↓
处理一个样本
    ↓
继续读取下一个样本

内存里不需要同时保存全部数据。

所以:

list / map-style Dataset:
访问方便,但可能占内存大

IterableDataset / DataPipe:
更省内存,但不一定支持随机访问

9. 什么叫“数据加载到内存中”?

数据长期存储在硬盘或 SSD 中。

但是程序运行时,CPU/GPU 通常不能直接在硬盘上计算,而是要先把数据读到内存。

可以这样理解:

硬盘 SSD/HDD:长期仓库
内存 RAM:临时工作台
CPU/GPU:真正干活的人

当你写:

text = open("review.txt", "r", encoding="utf-8").read()

发生的是:

1. Python 找到硬盘上的 review.txt
2. 操作系统读取文件内容
3. 数据被复制到内存 RAM
4. Python 创建一个 str 字符串对象
5. 变量 text 指向这个对象

10. 一行代码过去之后,数据会消失吗?

不一定。

关键看有没有变量引用它。

10.1 有变量接住

x = [1, 2, 3]

这一行执行完后,列表还在内存里,因为变量 x 指向它。

可以理解为:

x ───> [1, 2, 3]

所以后面还能用:

print(x)

10.2 没有变量接住

[1, 2, 3]

这一行也会创建一个列表。

但是没有变量保存它。

所以执行完之后,这个对象没有引用,后面会被 Python 回收。


11. Python 程序结束后,内存会释放吗?

会。

当 Python 程序结束时,它占用的 RAM 内存通常会被操作系统释放。

但是要分清楚:

硬盘上的原始数据文件:不会被删除
内存中的临时对象:会被释放

例如:

train_list = list(IMDB(split="train"))

程序运行时,train_list 在内存中。

程序结束后,train_list 消失。

但是硬盘上的 IMDB 数据还在。


11.1 Jupyter Notebook 要注意

在 Jupyter 中,一个 cell 运行结束,不等于 Python 程序结束。

只要 kernel 没有重启,变量还在内存中。

例如运行:

train_list = list(IMDB(split="train"))

这个 cell 结束后,train_list 仍然可以用。

但是如果你点击:

Restart Kernel

变量就没了,内存被释放。


12. 如果有 100GB 数据,内存必须大于 100GB 吗?

不一定。

关键看你是不是一次性全读入内存。

12.1 如果一次性全加载

data = list(dataset)

如果数据真的有 100GB,那么内存通常要大于 100GB。

而且还要考虑:

Python 对象开销
中间计算结果
batch 数据
模型参数
系统占用

所以实际可能需要 128GB、192GB 甚至更多。


12.2 如果分批读取

深度学习一般不是一次性读取所有数据,而是用 DataLoader 分批读取:

for batch in dataloader:
    ...

数据流是:

硬盘中的 100GB 数据
    ↓
每次读取一个 batch
    ↓
送入模型
    ↓
计算完成后继续下一个 batch

内存中同时存在的只是当前 batch 和少量缓存。

所以即使数据集有 100GB,也可能用 32GB 或 64GB 内存训练。


13. DataLoader 是什么?

DataLoader 是 PyTorch 的批量加载器。

例如:

from torch.utils.data import DataLoader

dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_batch
)

它的作用是把单个样本组成 batch。

单个样本:

(label, text)

一个 batch:

[
    (label1, text1),
    (label2, text2),
    (label3, text3),
    (label4, text4)
]

然后交给 collate_fn 处理成模型能吃的张量。


14. random_split(list(train_dataset), [20000, 5000]) 是什么意思?

代码:

from torch.utils.data import random_split

train_dataset, valid_dataset = random_split(
    list(train_dataset), [20000, 5000]
)

这句可以分成三步理解。


14.1 第一步:转成 list

full_train_list = list(train_dataset)

这一步把原来的数据流对象全部展开成列表:

[
    (1, "review text 1"),
    (2, "review text 2"),
    ...
]

14.2 第二步:随机切分

train_subset, valid_subset = random_split(
    full_train_list,
    [20000, 5000]
)

表示把 25000 条数据随机分成:

训练集:20000 条
验证集:5000 条

14.3 第三步:变量覆盖

train_dataset = train_subset
valid_dataset = valid_subset

注意:这里新的 train_dataset 已经不是原来的 IMDB 数据流对象了,而是一个 Subset 对象。


15. Subset 是什么?

Subset 表示原始数据集的一个子集。

它不是把数据完整复制一份,而是保存:

原始 dataset
选中的 indices

例如:

data = [
    (1, "text A"),  # index 0
    (2, "text B"),  # index 1
    (1, "text C"),  # index 2
    (2, "text D"),  # index 3
]

如果:

from torch.utils.data import Subset

subset = Subset(data, [2, 0])

那么:

subset[0]

实际返回:

data[2]

也就是:

(1, "text C")

而:

subset[1]

实际返回:

data[0]

也就是:

(1, "text A")

16. __len__()__getitem__() 到底是什么?

这是 Python 里非常重要的机制。

像:

__len__()
__getitem__()
__iter__()
__next__()

这种前后都有两个下划线的方法,叫做 特殊方法,也叫 dunder method

dunder 是 double underscore 的意思。

这些名字不是随便起的,而是 Python 语言规定好的。


17. 为什么函数名有下划线,但调用时不用写?

因为你写:

len(obj)

Python 背后会自动调用:

obj.__len__()

你写:

obj[0]

Python 背后会自动调用:

obj.__getitem__(0)

也就是说:

外部语法:len(obj)
背后机制:obj.__len__()

外部语法:obj[0]
背后机制:obj.__getitem__(0)

18. 自定义类实现 __len__

例如:

class MyDataset:
    def __len__(self):
        return 3

使用:

dataset = MyDataset()

print(len(dataset))

输出:

3

因为:

len(dataset)

会自动调用:

dataset.__len__()

19. 自定义类实现 __getitem__

例如:

class MyDataset:
    def __init__(self):
        self.data = ["A", "B", "C"]

    def __getitem__(self, idx):
        return self.data[idx]

使用:

dataset = MyDataset()

print(dataset[0])
print(dataset[1])

输出:

A
B

因为:

dataset[0]

背后调用:

dataset.__getitem__(0)

20. 同时实现 __len____getitem__

class MyDataset:
    def __init__(self):
        self.data = ["A", "B", "C"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


dataset = MyDataset()

print(len(dataset))
print(dataset[0])

输出:

3
A

这个对象就同时支持:

len(dataset)
dataset[0]

21. 方法名可以自己随便取吗?

普通方法可以随便取名,但特殊语法只认 Python 规定的名字。

例如:

class MyDataset:
    def __init__(self):
        self.data = ["A", "B", "C"]

    def length(self):
        return len(self.data)

    def get_item(self, idx):
        return self.data[idx]

这样你可以手动调用:

dataset.length()
dataset.get_item(0)

但是不能自动支持:

len(dataset)
dataset[0]

因为 Python 只认:

__len__
__getitem__

不认:

length
get_item

22. 常见特殊方法总结

外部写法 背后调用
len(obj) obj.__len__()
obj[i] obj.__getitem__(i)
obj[i] = x obj.__setitem__(i, x)
for x in obj: obj.__iter__()
next(obj) obj.__next__()
obj() obj.__call__()
print(obj) obj.__str__()obj.__repr__()
a + b a.__add__(b)
a * b a.__mul__(b)

23. 这和 PyTorch Dataset 有什么关系?

PyTorch 的 Dataset 就是利用这套机制。

通常自定义数据集要写:

from torch.utils.data import Dataset

class IMDBMapDataset(Dataset):
    def __init__(self, data_list):
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

这样 PyTorch 就知道:

这个数据集有多长
第 idx 个样本怎么取

于是 DataLoader 才能工作:

DataLoader 想组成一个 batch
    ↓
决定取哪些索引
    ↓
调用 dataset[5], dataset[12], dataset[90]
    ↓
拿到多个样本
    ↓
组成 batch

24. 为什么有时要自己包装 Dataset?

如果你直接有一个 list:

train_list = list(IMDB(split="train"))

其实可以直接交给 DataLoader

dataloader = DataLoader(
    train_list,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_batch
)

但是如果你想在取样本时自动处理标签和文本,可以自己包装:

class IMDBMapDataset(Dataset):
    def __init__(self, data_list, label_pipeline=None, text_pipeline=None):
        self.data = data_list
        self.label_pipeline = label_pipeline
        self.text_pipeline = text_pipeline

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label, text = self.data[idx]

        if self.label_pipeline is not None:
            label = self.label_pipeline(label)

        if self.text_pipeline is not None:
            text = self.text_pipeline(text)

        return label, text

这样每次:

dataset[0]

都会自动执行:

取出原始 label, text
    ↓
处理 label
    ↓
处理 text
    ↓
返回处理后的样本

25. IMDB 数据的完整流动过程

以 IMDB 文本分类为例,完整流程可以写成:

硬盘中的 IMDB 原始数据
        ↓
IMDB(split="train") 创建数据流对象
        ↓
for 循环或 list(...) 触发读取
        ↓
得到单个样本 (label, text)
        ↓
label_pipeline 处理标签
        ↓
text_pipeline 处理文本
        ↓
DataLoader 组成 batch
        ↓
collate_fn padding 并转成 Tensor
        ↓
送入神经网络模型

26. 最重要的几个判断标准

以后看到一个数据对象,先问四个问题。

26.1 它是什么类型?

print(type(data))

26.2 它能不能取长度?

len(data)

如果可以,说明实现了:

__len__()

26.3 它能不能按索引取值?

data[0]

如果可以,说明实现了:

__getitem__()

26.4 它能不能遍历?

for sample in data:
    print(sample)
    break

如果可以,说明它是可迭代对象,通常实现了:

__iter__()

27. 三种常见情况对比

类型 len() [0] for 循环 特点
list 可以 可以 可以 数据已经装进内存
Dataset 通常可以 通常可以 可以 map-style 数据集
IterableDataset / DataPipe 不一定,可以或不可以 通常不可以 可以 数据流,省内存
generator 通常不可以 不可以 可以 yield 逐个产生

你的情况是:

IMDB(split="train")
len() 可以
[0] 不可以
for 可以

所以它是:

有长度信息的可迭代数据流对象

28. 一句话总括

Python / PyTorch 中的数据处理,本质上是在处理三件事:

1. 数据是什么结构?
   list、tuple、dict、Tensor、Dataset、IterableDataset

2. 数据在哪里?
   硬盘、内存、GPU 显存

3. 数据怎么流动?
   一次性加载,还是按 batch / 按样本流式读取

__len__()__getitem__() 这些特殊方法,则是 Python 提供的标准接口。

只要你的类实现了这些接口,它就可以支持:

len(obj)
obj[0]
for x in obj:
    ...

这就是为什么 PyTorch 的 DatasetDataLoader 能够以统一方式处理各种不同数据来源。


29. 最核心的学习结论

list:
像一个已经装满数据的盒子,可以按编号取。

IterableDataset / DataPipe:
像一条数据水流,需要时一个一个吐出,更省内存。

Dataset:
PyTorch 的标准数据集接口,通常通过 __len__ 和 __getitem__ 定义。

DataLoader:
批量加载器,把单个样本组成 batch。

torch.Tensor:
模型真正能够计算的数据格式。

__len__:
让对象支持 len(obj)。

__getitem__:
让对象支持 obj[index]。

list(dataset):
把数据流全部展开,加载进内存,变成 Python 列表。
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐