前言

本文主要讲述torch.where()的两种用法,第一种是最常规的,也是官方文档所注明的;第二种就是配合bool型张量的计算

1、torch.where()常规用法

我们先看官方文档的解释:

torch.where(condition, x, y)
根据条件,也就是condiction,返回从x或y中选择的元素的张量(这里会创建一个新的张量,新张量的元素就是从x或y中选的,形状要符合x和y的广播条件)。
Parameters解释如下:
1、condition (bool型张量) :当condition为真,返回x的值,否则返回y的值
2、x (张量或标量):当condition=True时选x的值
2、y (张量或标量):当condition=False时选y的值

1.1 形状相同

先演示形状相同的情况:

import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
z = torch.where(x > 5, x, y)

print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 5 = {x > 5}')
print(f'=========================')
print(f'z = {z}')

>print result:
x = tensor([[1, 2, 3],
        [3, 4, 5],
        [5, 6, 7]])
=========================
y = tensor([[ 5,  6,  7],
        [ 7,  8,  9],
        [ 9, 10, 11]])
=========================
x > 5 = tensor([[False, False, False],
        [False, False, False],
        [False,  True,  True]])
=========================
z = tensor([[5, 6, 7],
        [7, 8, 9],
        [9, 6, 7]])

上面定义了x和y,两者的形状shape=(3, 3)相同,然后condition = x > 5也是就x中的每个元素值都要大于5,这里就能看到x中第0行和第1行都是False,只有第2行的1、2列是True,然后前面说了,为True时使用的是x中的值,为False时使用的是y中的值,那么新创建的z前两行和第2行0列使用的是y中的值,剩下两个使用x中的值,z的shape也是(3, 3)。
 

1.2 标量情况

x = 3
y = torch.tensor([[1, 5, 7]])
z = torch.where(y > 2, y, x)

print(f'y > 2 = {y > 2}')
print(f'=========================')
print(f'z = {z}')

print(f'y > 2 = {y > 2}')
print(f'=========================')
print(f'z = {z}')

>print result:
y > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[3, 5, 7]])

在这里,x是一个标量,condition = y > 2,你要是问我为什么不把condition设为condition = x > 2,很简单,x > 2不是bool Tensor这里标量和张量是可以进行广播的!!

example:

a = torch.tensor([1, 5, 7])
b = 3
c = a + b
d = torch.tensor([3, 3, 3])
e = a + d

print(f'c = {c}')
print(f'e = {e}')

>print result:
c = tensor([ 4,  8, 10])
d = tensor([ 4,  8, 10])

其实就是把b = 3拉成了[3, 3, 3],也是就d那样。

1.3 形状不同

其实标量那里也算是形状不同了,这里我再啰嗦一下吧,看例子:

x = torch.tensor([[1, 3, 5]])
y = torch.tensor([[2], [4], [6]])
z = torch.where(x > 2, x, y)

print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 2 = {x > 2}')
print(f'=========================')
print(f'z = {z}')

>print result:
x = tensor([[1, 3, 5]])
=========================
y = tensor([[2],
        [4],
        [6]])
=========================
x > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[2, 3, 5],
        [4, 3, 5],
        [6, 3, 5]])

上面x.shape=(1, 3) y.shape=(3, 1),然后condition = x > 2的shape=(1, 3),是可广播的,所以运算也能成功,在计算torch.where(x > 2, x, y)时,分别对x、y、condition进行广播,x.shape=(3, 3),y.shape=(3, 3),condition.shape=(3, 3)
在这里插入图片描述

 所以y的值替换第0列,第1、2列为x的值。
更多的广播形式请读者朋友自行尝试

2、torch.where()特殊用法

torch.where(a & b)
ab都是bool Tensor,返回的是一个元组,元组第一项是a、b中都为TrueindexTensor,第二项是a、b都为TrueindexTensor

 请看例子:

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
c = torch.where(a & b)

print(f'a = {a}')
print(f'=========================')
print(f'b = {b}')
print(f'=========================')
print(f'c = {c}')

>print result:
a = tensor([[False,  True,  True],
        [ True, False, False],
        [False, False,  True]])
=========================
b = tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])
=========================
c = (tensor([0, 0, 1, 2]), tensor([1, 2, 0, 2]))

c就是一个元组,第0项是a、b都为True行标,第1项是a、b都为True列标

总结

以上就是torch.where()的两种用法,看起来比较麻烦,多练练也就是那样,特别一点的就是一个广播机制一个特殊用法,欢迎评论指正!

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐