pytorch中collate_fn函数的使用&如何向collate_fn函数传参
1.为什么要使用collate_fn
这里先从dataset的运行机制讲起.
在dataloader按照batch进行取数据的时候, 是取出大小等同于batch size的index列表; 然后将列表列表中的index输入到dataset的getitem()函数中,取出该index对应的数据; 最后, 对每个index对应的数据进行堆叠, 就形成了一个batch的数据.
⚠️ 在最后一步堆叠的时候可能会出现问题: 如果一条数据中所含有的每个数据元的长度不同, 那么将无法进行堆叠. 如: multi-hot类型的数据, 序列数据.
在使用这些数据时, 通常需要先进行长度上的补齐, 再进行堆叠. 以现在的流程, 是没有办法加入该操作的.
⭐️ 此外, 某些优化方法是要对一个batch的数据进行操作.
collate-fn函数就是手动将抽取出的样本堆叠起来的函数
2.collate_fn的用法
loader = Dataloader(dataset, batch_size, shuffle, collate_fn, ...)
collate_fn函数是实例化dataloader的时候, 以函数形式传递给loader.
既然是collate_fn是以函数作为参数进行传递, 那么其一定有默认参数. 这个默认参数就是getitem函数返回的数据项的batch形成的列表.
先假设, datase类是如下形式:
class testData(Dataset):
def __init__(self):
super().__init__()
def __getitem__(self, index):
return x, y
可以看到, 假设的dataset返回两个数据项: x和y. 那么, 传入collate_fn的参数定义为data, 则其shape为(batch_size, 2,…).
知道了输入参数的形式, 就可以去定义collate_fn函数了:
def collate_fn(data):
for unit in data:
unit_x.append(unit[0])
unit_y.append(unit[1])
...
return {x: torch.tensor(unit_x), y: torch.tensor(unit_y)}
可以看到我对collate_fn函数的定义,最后返回的是一个字典. 这也是collate_fn函数最大的一个好处: 可以自定义取出一个batch数据的格式. 该函数的输出就是对dataloader进行遍历, 取出一个batch的数据.
3.如何给collate_fn函数传参
在collate_fn的使用过程中, 我发现只输入data有时候是非常不方便的, 需要额外的参数来传递其他变量.
这里有两个方法可以解决以上问题:
(1)使用lambda函数
info = args.info # info是已经定义过的
loader = Dataloader(collate_fn=lambda x: collate_fn(x, info))
这里巧用lambda函数, 相当于使用collate_fn函数再定义了一个匿名函数.
(2)创建可被调用的类
class collater():
def __init__(self, *params):
self. params = params
def __call__(self, data):
'''在这里重写collate_fn函数'''
collate_fn = collater(*params)
loader = Dataloader(collate_fn=collate_fn)
4.总结
collate_fn的用处:
- 自定义数据堆叠过程
- 自定义batch数据的输出形式
collate_fn的使用
- 定义一个以data为输入的函数
- 注意, 输入输出分别域getitem函数和loader调用时对应
更多推荐
所有评论(0)