pytorch中的reshape()、view()、nn.flatten()和flatten()
在使用pytorch定义神经网络结构时,经常会看到类似如下的.view() / flatten()用法,这里对其用法做出讲解与演示。
torch.reshape用法
reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用,
其作用是在不改变tensor元素数目的情况下改变tensor的shape。
torch.reshape() 需要两个参数,一个是待被改变的张量tensor,一个是想要改变的形状。
torch.reshape(input, shape) → Tensor
input(Tensor)-要重塑的张量
shape(python的元组:ints)-新形状`
案例1
输入:
import torch
a = torch.tensor([[0,1],[2,3]])
x = torch.reshape(a,(-1,))
print (x)
b = torch.arange(4.)
Y = torch.reshape(a,(2,2))
print(Y)
结果:
tensor([0, 1, 2, 3])
tensor([[0, 1],
[2, 3]])
torch.view用法
view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。
view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。
view(h,w),h代表行(想要变为几行),当不知道要变为几行,但知道要变为几列时可取-1;w代表的是列(想要变为几列),当不知道要变为几列,但知道要变为几行时可取-1。
一、普通用法(手动调整)
view()相当于reshape、resize,重新调整Tensor的形状。
案例2.
输入
import torch
a1 = torch.arange(0,16)
print(a1)
输出
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
输入
a2 = a1.view(8, 2)
a3 = a1.view(2, 8)
a4 = a1.view(4, 4)
print(a2)
print(a3)
print(a4)
输出
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
二、特殊用法:参数-1(自动调整size)
view中一个参数定为-1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。
输入
import torch
a1 = torch.arange(0,16)
print(a1)
输出
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
输入
a2 = a1.view(-1, 16)
a3 = a1.view(-1, 8)
a4 = a1.view(-1, 4)
a5 = a1.view(-1, 2)
a6 = a1.view(4*4, -1)
a7 = a1.view(1*4, -1)
a8 = a1.view(2*4, -1)
print(a2)
print(a3)
print(a4)
print(a5)
print(a6)
print(a7)
print(a8)
输出
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
torch.nn.Flatten(start_dim=1,end_dim=-1)
start_dim与end_dim分别表示开始的维度和终止的维度,默认值为1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)。
因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。
使用nn.Flatten(),使用默认参数
官方给出的示例:
input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
#torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
#torch.Size([160, 5])
#开头的代码是注释
整段代码的意思是:给定一个维度为(32,1,5,5)的随机数据。
1.先使用一次nn.Flatten(),使用默认参数:
m = nn.Flatten()
也就是说从第一维度展平到最后一个维度,数据的维度是从0开始的,第一维度实际上是数据的第二位置代表的维度,也就是样例中的1。
因此进行展平后的结果也就是[32,155]→[32,25]
2.接着再使用一次指定参数的nn.Flatten(),即
m = nn.Flatten(0,2)
也就是说从第0维度展平到第2维度,0~2,对应的也就是前三个维度。
因此结果就是[3215,5]→[160,25]
torch.flatten
torch.flatten()函数经常用于写分类神经网络的时候,经过最后一个卷积层之后,一般会再接一个自适应的池化层,输出一个BCHW的向量。这时候就需要用到torch.flatten()函数将这个向量拉平成一个Bx的向量(其中,x = CHW),然后送入到FC层中。
语句结构
torch.flatten(input, start_dim=0, end_dim=-1)
input: 一个 tensor,即要被“摊平”的 tensor。
start_dim: “摊平”的起始维度。
end_dim: “摊平”的结束维度。
作用与 torch.nn.flatten 类似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是类,其默认开始维度为第 0 维。例1:
import torch
data_pool = torch.randn(2,2,3,3) # 模拟经过最后一个池化层或自适应池化层之后的输出,Batchsize*c*h*w
print(data_pool)
y=torch.flatten(data_pool,1)
print(y)
输出结果:
结果是一个B*x的向量。
更多推荐
所有评论(0)