pytorch之dataloader,enumerate

from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3],
                  [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b)#封装数据a与标签b

# 切片输出
print(train_ids[0:2])
print('='* 80)
#  循环取数据
for x_train, y_label in train_ids:
     print(x_train, y_label)
# DataLoader进行数据封装
print('=' * 80)
 
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader):  
    # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
    x_data, label = data
    print(' batch:{0}\n x_data:{1}\nlabel: {2}'.format(i, x_data, label))
    

for i, data in enumerate(train_loader,5):  
    # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
    x_data, label = data
    print(' batch:{0}\n x_data:{1}\nlabel: {2}'.format(i, x_data, label))


Dataloader:传入数据(这个数据包括:训练数据和标签),
batchsize代表的是每次取出4个样本数据。本例题中一共12个样本,因此迭代3次即可全部取出,迭代结束。
enumerate:返回值有两个:一个是序号,一个是数据train_ids
输出结果如下图:
在这里插入图片描述
在这里插入图片描述
也可如下代码,进行迭代:

for i, data in enumerate(train_loader,5):  
    # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
    x_data, label = data
    print(' batch:{0}\n x_data:{1}\nlabel: {2}'.format(i, x_data, label))

for i, data in enumerate(train_loader,1):此代码中5,是batch从5开始,batch仍然是3个。运行结果如下:
在这里插入图片描述

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐