nn.Linear() 函数可二维或三维输入
nn.Linear() 常用于处理二维信息,全连接层的输入与输出一般都设置为二维张量,形状通常为[batch_size, size],不同于卷积层要求输入输出是四维张量。从输入输出的张量的shape角度来理解,相当于一个输入为[batch_size, in_features]的张量变换成了[batch_size, out_features]的输出张量。仔细观察Shape可以发现,nn.Linear
文章共460字 · 阅读需要大约2分钟
一键AI生成摘要,助你高效阅读
问答
·
nn.Linear
PyTorch的nn.Linear()
用于设置网络中的全连接层。
常用于处理二维信息,全连接层的输入与输出一般都设置为二维张量,形状通常为[batch_size, size]
,不同于卷积层要求输入输出是四维张量。从输入输出的张量的shape角度来理解,相当于一个输入为[batch_size, in_features]
的张量变换成了[batch_size, out_features]
的输出张量。
仔细观察Shape可以发现,nn.Linear(in_features,out_features)可以用于处理多维信息。输入三维张量[batch_size, n, in_features],经过线性层,可以得到输出张量[batch_size, n, out_features].
如下图:
参考博客
nn.Linear()函数详解及代码使用_墨晓白的博客-CSDN博客
Pytorch linear 多维 输入的参数_linear参数_又是花落时的博客-CSDN博客
更多推荐
已为社区贡献9条内容
所有评论(0)