【Pytorch】torch.argmax 函数详解
·
文章目录
一、一个参数时的 torch.argmax 函数
官网链接:TORCH.ARGMAX
1. 介绍
torch.argmax(input) → LongTensor
返回输入张量 input 所有元素中的最大值的下标(如果有多个最大值,则返回第一个最大值的索引)。
2. 实例
import torch
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a))
输出结果:
tensor([[-0.7018, 1.1887, -0.2344, 0.3216],
[ 1.3548, -0.8575, -1.0585, -0.3462],
[ 0.5845, 0.2345, 1.6444, 1.1129],
[-1.1226, -0.5765, -0.4906, 0.0132]])
tensor(10)
在所有的元素中,第11个元素 1.6444 最大,其索引是 10 ,因此返回 tensor(10)。
二、多个参数时的 torch.argmax 函数
1. 介绍
torch.argmax(input, dim=None, keepdim=False)
- 返回一个张量 input 在某一维度 dim 上的最大值的索引(返回 input 的指定维度 dim 上的最大值的序号)。
- input (Tensor) - 输入张量。
- dim (int) - 要减少的维度(指定维度)。如果为None,则返回扁平输入的argmax。dim 的不同值表示不同维度。在二维中,dim=0 表示行,此时要压缩行,找列的最大值;dim=1 表示列,此时要压缩列,找行的最大值。广泛的来说,我们不管一个矩阵是几维的,比如一个矩阵维度如下:(d0,d1,…,dn−1) ,那么 dim=0 就表示对应到d0 也就是第一个维度,dim=1表示对应到也就是第二个维度,依次类推。指定哪个维度,哪个维度就要消失,就要被压缩。
- Keepdim (bool) - 输出张量是否保留dim。如果dim=None 则忽略。
- 返回值:指定维度 dim 消失之后的矩阵。dim (int) – the dimension to reduce。因为在该维度找了最大值,相当于该维度就被压缩了,只保留了其他维度。这样不好理解,接下来看看例子。
2. 实例
实例1:二维矩阵
import torch
a = torch.tensor(
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
])
print(a.shape)
b = torch.argmax(a, dim=0) # 压缩行,返回列最大值的序号
print(b)
print(b.shape)
输出结果:
torch.Size([3, 4])
tensor([1, 2, 0, 1])
torch.Size([4])
指定的维度是 0 ,也就是行,要压缩行,就要找列的最大值。
从 [3, 4] -> [4],可见第一个维度 3 消失了。
import torch
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1)) #压缩列,返回行最大值的序号
输出结果:
tensor([[-1.3736, 0.8958, -0.6470, 1.3395],
[-0.4279, 0.0682, 0.7635, 1.1857],
[ 1.7861, -0.6515, -0.5456, -0.3066],
[ 1.1898, -0.0208, -0.3662, 0.1799]])
tensor([3, 3, 0, 0])
指定的维度是 1 ,也就是列,要压缩列,就要找行的最大值。
实例2:三维矩阵
import torch
a = torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
print(a.shape)
b = torch.argmax(a, dim=0)
print(b)
print(b.shape)
输出结果:
torch.Size([2, 3, 4])
tensor([[0, 1, 0, 0],
[0, 1, 0, 0],
[1, 0, 1, 0]])
torch.Size([3, 4])
从 [2, 3, 4] -> [3, 4],可见第一个维度 2 消失了。
实例3:保留dim
import torch
a = torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
print(a.shape)
b = torch.argmax(a, dim=0, keepdim=True)
print(b)
print(b.shape)
输出结果:
torch.Size([2, 3, 4])
tensor([[[0, 1, 0, 0],
[0, 1, 0, 0],
[1, 0, 1, 0]]])
torch.Size([1, 3, 4])
与实例2的不同之处:加了 keepdim=True 参数,输出从 [3, 4] -> [1, 3, 4],保留了被压缩的第一维,只不过从 2 变成了压缩后的 1 。
参考链接
更多推荐
已为社区贡献18条内容
所有评论(0)