【Torch API】pytorch中的torch.where()函数详解
前言
本文主要讲述torch.where()
的两种用法,第一种是最常规的,也是官方文档所注明的;第二种就是配合bool
型张量的计算
1、torch.where()常规用法
我们先看官方文档的解释:
torch.where(condition, x, y)
根据条件,也就是condiction,返回从x或y中选择的元素的张量(这里会创建一个新的张量,新张量的元素就是从x或y中选的,形状要符合x和y的广播条件)。
Parameters解释如下:
1、condition (bool型张量) :当condition为真,返回x的值,否则返回y的值
2、x (张量或标量):当condition=True时选x的值
2、y (张量或标量):当condition=False时选y的值
1.1 形状相同
先演示形状相同的情况:
import torch
x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
z = torch.where(x > 5, x, y)
print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 5 = {x > 5}')
print(f'=========================')
print(f'z = {z}')
>print result:
x = tensor([[1, 2, 3],
[3, 4, 5],
[5, 6, 7]])
=========================
y = tensor([[ 5, 6, 7],
[ 7, 8, 9],
[ 9, 10, 11]])
=========================
x > 5 = tensor([[False, False, False],
[False, False, False],
[False, True, True]])
=========================
z = tensor([[5, 6, 7],
[7, 8, 9],
[9, 6, 7]])
上面定义了x和y,两者的形状shape=(3, 3)相同,然后condition = x > 5也是就x中的每个元素值都要大于5,这里就能看到x中第0行和第1行都是False,只有第2行的1、2列是True,然后前面说了,为True时使用的是x中的值,为False时使用的是y中的值,那么新创建的z前两行和第2行0列使用的是y中的值,剩下两个使用x中的值,z的shape也是(3, 3)。
1.2 标量情况
x = 3
y = torch.tensor([[1, 5, 7]])
z = torch.where(y > 2, y, x)
print(f'y > 2 = {y > 2}')
print(f'=========================')
print(f'z = {z}')
print(f'y > 2 = {y > 2}')
print(f'=========================')
print(f'z = {z}')
>print result:
y > 2 = tensor([[False, True, True]])
=========================
z = tensor([[3, 5, 7]])
在这里,x
是一个标量,condition = y > 2
,你要是问我为什么不把condition
设为condition = x > 2
,很简单,x > 2
不是bool Tensor
。这里标量和张量是可以进行广播的!!
example:
a = torch.tensor([1, 5, 7])
b = 3
c = a + b
d = torch.tensor([3, 3, 3])
e = a + d
print(f'c = {c}')
print(f'e = {e}')
>print result:
c = tensor([ 4, 8, 10])
d = tensor([ 4, 8, 10])
其实就是把b = 3
拉成了[3, 3, 3]
,也是就d
那样。
1.3 形状不同
其实标量那里也算是形状不同了,这里我再啰嗦一下吧,看例子:
x = torch.tensor([[1, 3, 5]])
y = torch.tensor([[2], [4], [6]])
z = torch.where(x > 2, x, y)
print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 2 = {x > 2}')
print(f'=========================')
print(f'z = {z}')
>print result:
x = tensor([[1, 3, 5]])
=========================
y = tensor([[2],
[4],
[6]])
=========================
x > 2 = tensor([[False, True, True]])
=========================
z = tensor([[2, 3, 5],
[4, 3, 5],
[6, 3, 5]])
上面x.shape=(1, 3) y.shape=(3, 1),然后condition = x > 2的shape=(1, 3),是可广播的,所以运算也能成功,在计算torch.where(x > 2, x, y)时,分别对x、y、condition进行广播,x.shape=(3, 3),y.shape=(3, 3),condition.shape=(3, 3)
所以y
的值替换第0列,第1、2列为x
的值。
更多的广播形式请读者朋友自行尝试
2、torch.where()特殊用法
torch.where(a & b)
a
和b
都是bool Tensor
,返回的是一个元组,元组第一项是a、b
中都为True
的行的index
的Tensor
,第二项是a、b
都为True
列的index
的Tensor
请看例子:
a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
c = torch.where(a & b)
print(f'a = {a}')
print(f'=========================')
print(f'b = {b}')
print(f'=========================')
print(f'c = {c}')
>print result:
a = tensor([[False, True, True],
[ True, False, False],
[False, False, True]])
=========================
b = tensor([[True, True, True],
[True, True, True],
[True, True, True]])
=========================
c = (tensor([0, 0, 1, 2]), tensor([1, 2, 0, 2]))
c
就是一个元组,第0项是a、b
都为True
的行标,第1项是a、b
都为True
的列标
总结
以上就是torch.where()的两种用法,看起来比较麻烦,多练练也就是那样,特别一点的就是一个广播机制一个特殊用法,欢迎评论指正!
更多推荐
所有评论(0)