PyTorch碎片:F.pad的图文透彻理解
1. F.pad函数定义
F.pad是pytorch内置的tensor扩充函数,便于对数据集图像或中间层特征进行维度扩充,下面是pytorch官方给出的函数定义。
torch.nn.functional.pad(input, pad, mode='constant', value=0)
函数变量说明:
- input
需要扩充的tensor,可以是图像数据,抑或是特征矩阵数据 - pad
扩充维度,用于预先定义出某维度上的扩充参数 - mode
扩充方法,’constant‘, ‘reflect’ or ‘replicate’三种模式,分别表示常量,反射,复制 - value
扩充时指定补充值,但是value只在mode='constant’有效,即使用value填充在扩充出的新维度位置,而在’reflect’和’replicate’模式下,value不可赋值
2. F.pad透彻理解
为了方便从可视角度上分析F.pad的实际效果,首先给出空值矩阵,并且为了能够让宁能复现效果,实际代码全部给出,并最小化解释复杂度。
import torch
import torch.nn.functional as F
t4d = torch.empty(1, 3, 5, 3)
其中t4d中维度分别表示(batchsize, channel, height, width),为了透彻理解扩充参数pad的定义和实际效果,从随后的小节中我将逐层提高扩充的复杂度,建议跟我的思路来。
2.1 最后一维的扩充
为了便于理解,只观察输入矩阵t4d的最后两维,即一个5行3列的矩阵, 如图1
如果F.pad中第二个参数pad只定义两个参数,表示只对输入矩阵的最后一个维度进行扩充,不会对前两个维度造成任何影响,所以此处直接忽略前两个维度。
p1d = (1, 2)
t1 = F.pad(t4d, p1d, 'constant', 1)
先输出看一下t1的维度变化:
>>> print('原始矩阵大小为:', t4d.shape)
'''
原始矩阵大小为:torch.Size([1, 3, 5, 3])
'''
>>> print('t1矩阵大小为:', t1.shape)
'''
t1矩阵大小为:torch.Size([1, 3, 5, 6])
'''
接下来,从可视化的角度分析一下,原始矩阵全为0值,扩充维度全部用1值填充,这样易于理解。
从图2可以明显看出,左侧扩充了1列,右侧扩充了2列,即原始矩阵大小从5×3扩充到5×6,则p1d的参数设置意义为
p1d = (1, 2)
# p1d = (左边填充数, 右边填充数)
此外,在实际项目中,为了保持代码的可扩展性,按下面定义,也可以获取同样的效果。
p1d_ = (1, 2, 0, 0)
t1 = F.pad(t4d, p1d_, 'constant', 1)
2.2 两维扩充
# p1d = (1, 2) # 与p1d做对比
p2d = (1, 2, 3, 4)
t2 = F.pad(t4d, p2d, 'constant', 2)
同样的,先分析下原始矩阵的维度变化情况:
>>> print('原始矩阵大小为:', t4d.shape)
'''
原始矩阵大小为:torch.Size([1, 3, 5, 3])
'''
>>> print('t2矩阵大小为:', t2.shape)
'''
t2矩阵大小为:torch.Size([1, 3, 12, 6])
'''
这里给出的是两维的扩充代码,为了便于理解,看一下实际的扩充效果,如图3
看图实际一目了然,对左侧扩充了1列,右侧扩充了2列,上边扩充了3行,下边扩充了4行。也就是说,前两个参数对最后一个维度有效,后两个参数对倒数第二维有效。接下来就可以看一下p2d参数的实际意义:
p2d = (1, 2, 3, 4)
# p2d = (左边填充数, 右边填充数, 上边填充数, 下边填充数)
2.3 三维扩充
# p1d = (1, 2) # 与p1d做对比
# p2d = (1, 2, 3, 4) # 与p2d做对比
p3d = (1, 2, 3, 4, 5, 6)
t3 = F.pad(t4d, p3d, 'constant', 3)
仍然先分析下原始矩阵的维度变化情况:
>>> print('原始矩阵大小为:', t4d.shape)
'''
原始矩阵大小为:torch.Size([1, 3, 5, 3])
'''
>>> print('t3矩阵大小为:', t3.shape)
'''
t3矩阵大小为:torch.Size([1, 14, 12, 6])
'''
从可视化角度分析,如图4所示。
根据p3d = (1, 2, 3, 4, 5, 6)中,前4个参数完成了在高和宽维度上的扩张,后两个参数则完成了对通道维度上的扩充。接下来就可以看一下p3d参数的实际意义:
p3d = (1, 2, 3, 4, 5, 6)
# p3d = (左边填充数, 右边填充数, 上边填充数, 下边填充数, 前边填充数,后边填充数)
3. 小结
这回小节,简单用一个图完美表述F.pad()的应用。
微信公众号搜索 pytorch_star ,搜集的很多AI资料,会有用的!
交个朋友,来吧!
更多推荐
所有评论(0)