[Analysis] How to Learn torch.einsum() Elegantly
Einsum is a condensed form of a summation formula, so when reading an einsum equation you can work backwards to the original computation.
e.g. `C = torch.einsum("ijk,kl->ijl", A, B)` corresponds to $C_{ijl} = \sum_k A_{ijk} B_{kl} = A_{ijk} B_{kl}$
Free indices: $ijl$; dummy index: $k$. The free indices are emitted in the order given after `->`, here $ijl$.
Hence einsum can combine operands of different ranks, e.g. vector & matrix or matrix & tensor, as long as a dummy index pairs them; the output dimensions are iterated in the order of the output indices.
Conceptually, an einsum evaluation is equivalent to the following four steps:
- Dimension alignment: sort all labels alphabetically, then transpose and pad each input tensor to that order, so that every processed tensor carries the same set of dimension labels
- Broadcast multiply: element-wise multiply with broadcasting, indexed in the order of the output labels (if the output labels are empty, the inputs are fully summed and a scalar is returned)
  ⚠️ Note: with no `->` and no output labels, the output dimensions are inferred automatically in alphabetical order
  ⚠️ Note: with `->` but no output labels, the inputs are fully summed and a scalar is returned
- Dimension reduction: sum out and eliminate the dimensions corresponding to dummy indices (in the order the dummy indices appear in the equation)
- Transposed output: if output labels are given, emit the result in that order (again, empty output labels mean full summation to a scalar)
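To make the four steps concrete, here is a minimal sketch (shapes chosen arbitrarily) that reproduces `torch.einsum('ijk,kl->ijl', A, B)` by hand:

```python
import torch

A = torch.rand(3, 4, 5)
B = torch.rand(5, 6)

# Steps 1-2: align dimensions, then broadcast-multiply.
# A: (3,4,5) -> (3,4,5,1) labeled ijk_ ; B: (5,6) broadcasts as __kl
prod = A.unsqueeze(-1) * B        # shape (3, 4, 5, 6), labeled ijkl
# Step 3: reduce the dummy index k (dim 2) by summation
C = prod.sum(dim=2)               # shape (3, 4, 6), labeled ijl
# Step 4: the output labels 'ijl' already match, so no final transpose
assert torch.allclose(C, torch.einsum('ijk,kl->ijl', A, B))
```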
The einsum Method Explained
While developing general relativity, Einstein had to handle a large number of summations; to simplify this tedious notation he proposed the summation convention, which went on to shape tensor analysis. einsum is an elegant method available in PyTorch, TensorFlow, and NumPy. It can express vector, matrix, and tensor operations, including transposes, sums, column/row sums, matrix-vector multiplication, and matrix-matrix multiplication. Used well, einsum is a real asset in research work and can stand in for virtually every other matrix-computation routine.
1. einsum Equation Conventions
Einstein summation is a concise, efficient notation for sums. Its principle: when an index variable appears twice in a term, the summation symbol over that index can be omitted.
For example, the matrix-product formula:

$M_{ij} = \sum_k A_{ik} B_{kj} = A_{ik} B_{kj}$, i.e. `M = einsum('ik,kj->ij', A, B)`
Dummy index: must be repeated exactly once, and may not appear more than twice in any term; it is a temporary, "virtual" index that vanishes once the summation is carried out.
Free index: appears exactly once in every term of the expression, written with the same letter throughout; it indexes the equations or components and is not summed over.
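A quick numeric check of the convention (a sketch; any compatible shapes work):

```python
import torch

A, B = torch.rand(2, 3), torch.rand(3, 4)
# k repeats, so it is the dummy index and is summed out;
# i and j are free and survive into the output
M = torch.einsum('ik,kj->ij', A, B)
assert torch.allclose(M, A @ B)
```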
Einsum labeling conventions
- Dimension labels: tensor dimensions are labeled with English letters, e.g. 'ijk' labels a tensor's three dimensions as i, j, k (in torch.einsum, upper- and lower-case letters are distinct labels)
- Label-to-operand correspondence: label groups are separated by `,` and correspond one-to-one, in order, to the input operands
- Broadcast dimensions: an ellipsis `...` denotes broadcast dimensions; e.g. 'i...j' means every dimension except the first and last is broadcast-aligned
- Free and dummy indices: a label appearing exactly once in the input labels is a free index; a repeated label is a dummy index, and the dimensions of dummy indices are reduced away
- Output: the output dimensions can either be inferred automatically from the input labels or specified explicitly with output labels
  - Automatic inference: broadcast dimensions are placed in the highest positions of the output shape, free-index dimensions follow in alphabetical order in the lower positions, and dummy-index dimensions are not emitted
  - Explicit output: if the output contains broadcast dimensions, the output labels must include `...`; a dummy index that appears in the output labels is promoted to a free index; a free index absent from the output labels is demoted to a dummy index and summed out
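These inference rules are easy to verify with a small sketch:

```python
import torch

a = torch.rand(2, 3)
# without '->', the free indices are emitted in alphabetical order ('ij'),
# so labeling the dims 'ji' yields the transpose
assert torch.equal(torch.einsum('ji', a), a.T)
# a free index omitted from an explicit output is demoted and summed out
assert torch.allclose(torch.einsum('ij->i', a), a.sum(dim=1))
```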
For example, `C = einsum('ijk,kl->ijl', A, B)` from the opening example expands into the loops below:
```python
import torch

A = torch.rand(3, 4, 5)
B = torch.rand(5, 6)
M = torch.empty(3, 4, 6)

# free indices i, j, l drive the output loops
for i in range(3):
    for j in range(4):
        for l in range(6):
            # dummy index k: the summation loop
            total = 0
            for k in range(5):
                total += A[i, j, k] * B[k, l]
            M[i, j, l] = total
```
2. How torch.einsum() Works
Sums the product of the elements of the input operands along dimensions specified using a notation based on the Einstein summation convention.
The einsum method leverages this concise, efficient Einstein notation to drive arbitrarily complex tensor computations. The basic pattern is:
```python
C = einsum('ij,jk->ik', A, B)
```
The call above computes the matrix product of matrices A and B.
The arguments fall into two parts:
- equation (str): the summation-notation string that specifies the operation
- operands (Tensor, [Tensor, ...]): the input tensors; their number and dimensions must match the equation
3. Vector Operations
Let A and B be two 1D arrays of compatible shapes (meaning the lengths of the axes we pair together are either equal, or one of them has length 1):
| Arguments | Math | Description |
| --- | --- | --- |
| `('i', A)` | $A$ | returns a view of A |
| `('i->', A)` | $sum(A)$ | sum of the elements of A |
| `('i,i->i', A, B)` | $A * B$ | element-wise product of A and B |
| `('i,i', A, B)` | $inner(A, B)$ | dot (inner) product of A and B |
| `('i,j->ij', A, B)` | $outer(A, B)$ | outer product of A and B |
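The table entries can be checked directly against the corresponding torch functions (a sketch with arbitrary lengths):

```python
import torch

A, B = torch.rand(4), torch.rand(4)
C = torch.rand(3)
assert torch.allclose(torch.einsum('i,i->i', A, B), A * B)               # element-wise
assert torch.allclose(torch.einsum('i,i', A, B), torch.dot(A, B))        # inner product
assert torch.allclose(torch.einsum('i,j->ij', A, C), torch.outer(A, C))  # outer product
```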
4. Matrix Operations
Now let A and B be two 2D arrays with compatible shapes:
| Arguments | Math | Description |
| --- | --- | --- |
| `('ij', A)` | $A$ | returns a view of A |
| `('ji', A)` | $A^T$ | transpose of A |
| `('ii->i', A)` | $diag(A)$ | main diagonal of A |
| `('ii', A)` | $trace(A)$ | trace of A |
| `('ij->', A)` | $sum(A)$ | sum of all elements of A |
| `('ij->i', A)` | $sum(A, axis=1)$ | row sums of A (sum along the horizontal axis) |
| `('ij->j', A)` | $sum(A, axis=0)$ | column sums of A (sum along the vertical axis) |
| `('ij,ij->ij', A, B)` | $A * B$ | element-wise product of A and B |
| `('ij,ji->ij', A, B)` | $A * B^T$ | element-wise product of A and the transpose of B |
| `('ij,jk', A, B)` | $dot(A, B)$ | matrix product of A and B |
| `('ij,kj->ik', A, B)` | $inner(A, B)$ | inner product of A and B |
| `('ij,kj->ijk', A, B)` | $A[:, None] * B$ | each row of A multiplied by B |
| `('ij,kl->ijkl', A, B)` | $A[:, :, None, None] * B$ | each element of A multiplied by B |
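Likewise for the matrix entries (a sketch; `S` is square where the operation requires it):

```python
import torch

S = torch.rand(3, 3)
A, B = torch.rand(2, 3), torch.rand(3, 4)
assert torch.allclose(torch.einsum('ji', A), A.T)               # transpose
assert torch.allclose(torch.einsum('ii->i', S), torch.diag(S))  # diagonal
assert torch.allclose(torch.einsum('ii', S), torch.trace(S))    # trace
assert torch.allclose(torch.einsum('ij,jk', A, B), A @ B)       # matrix product
```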
When working with larger numbers of dimensions, keep in mind that einsum allows the ellipsis syntax `...`. This provides a convenient way to label the axes we're not particularly interested in, e.g. `np.einsum('...ij,ji->...', a, b)` would multiply just the last two axes of a with the 2D array b. There are more examples in the documentation.
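For instance (a sketch; torch.einsum supports the same ellipsis syntax):

```python
import torch

a = torch.rand(2, 4, 3, 3)
b = torch.rand(3, 3)
# contract only the last two axes of a with the 2D array b
out = torch.einsum('...ij,ji->...', a, b)
assert out.shape == (2, 4)
```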
5. Examples
einsum is built into both NumPy and PyTorch; the examples below use PyTorch. First define the variables we need:
```python
import torch
from torch import einsum

a = torch.rand((3, 4))
b = torch.rand((4, 5))
c = torch.rand((6, 7, 8))
d = torch.rand((3, 4))
s = torch.rand((4, 4))   # square matrix, needed for the trace / diagonal examples
x, y = torch.randn(5), torch.randn(5)
```
```python
# sum of all elements of a matrix / tensor
einsum('ij->', a)
einsum('ijk->', c)
# trace of a matrix (note: requires rows == cols)
einsum('ii->', s)
# vector formed by the main-diagonal elements (note: requires rows == cols)
einsum('ii->i', s)
# vector-vector multiplication producing a matrix
einsum('i,j->ij', x, y)
# matrix-vector multiplication producing a vector
einsum('ij,j->i', b, y)
# matrix-matrix multiplication
einsum('ij,jk->ik', a, b)
# element-wise product of two matrices
einsum('ij,ij->ij', a, d)
# tensor transposes
einsum('ijk->ikj', c)
einsum('...jk->...kj', c)  # the two forms are equivalent
# bilinear transformation
A = torch.randn(3, 5, 4)
l = torch.randn(2, 5)
r = torch.randn(2, 4)
torch.einsum('bn,anm,bm->ba', l, A, r)
```
5.1 MATRIX TRANSPOSE
$B_{ji} = A_{ij}$
```python
import torch
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->ji', [a])
```
```
tensor([[ 0.,  3.],
        [ 1.,  4.],
        [ 2.,  5.]])
```
5.2 SUM
$b = \sum_i \sum_j A_{ij} = A_{ij}$
```python
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
```
```
tensor(15.)
```
5.3 COLUMN SUM
$b_j = \sum_i A_{ij} = A_{ij}$
```python
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
```
```
tensor([ 3.,  5.,  7.])
```
5.4 ROW SUM
$b_i = \sum_j A_{ij} = A_{ij}$
```python
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->i', [a])
```
```
tensor([  3., 12.])
```
5.5 MATRIX-VECTOR MULTIPLICATION
$c_i = \sum_k A_{ik} b_k = A_{ik} b_k$
```python
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])
```
```
tensor([  5., 14.])
```
5.6 MATRIX-MATRIX MULTIPLICATION
$C_{ij} = \sum_k A_{ik} B_{kj} = A_{ik} B_{kj}$
```python
a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
torch.einsum('ik,kj->ij', [a, b])
```
```
tensor([[  25.,  28.,  31.,  34.,  37.],
        [  70.,  82.,  94., 106., 118.]])
```
5.7 DOT PRODUCT
Vector:
$c = \sum_i a_i b_i = a_i b_i$
```python
a = torch.arange(3)
b = torch.arange(3, 6)  # -- a vector of length 3 containing [3, 4, 5]
torch.einsum('i,i->', [a, b])
```
```
tensor(14.)
```
Matrix:
$c = \sum_i \sum_j A_{ij} B_{ij} = A_{ij} B_{ij}$
```python
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6, 12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
```
```
tensor(145.)
```
5.8 HADAMARD PRODUCT
$C_{ij} = A_{ij} B_{ij}$
```python
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6, 12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
```
```
tensor([[  0.,  7., 16.],
        [ 27., 40., 55.]])
```
5.9 OUTER PRODUCT
$C_{ij} = a_i b_j$
```python
a = torch.arange(3)
b = torch.arange(3, 7)  # -- a vector of length 4 containing [3, 4, 5, 6]
torch.einsum('i,j->ij', [a, b])
```
```
tensor([[ 0.,  0.,  0.,  0.],
        [ 3.,  4.,  5.,  6.],
        [ 6.,  8., 10., 12.]])
```
5.10 BATCH MATRIX MULTIPLICATION
$C_{ijl} = \sum_k A_{ijk} B_{ikl} = A_{ijk} B_{ikl}$
```python
a = torch.randn(3, 2, 5)
b = torch.randn(3, 5, 3)
torch.einsum('ijk,ikl->ijl', [a, b])
```
```
tensor([[[ 1.0886,  0.0214,  1.0690],
         [ 2.0626,  3.2655, -0.1465]],
        [[-6.9294,  0.7499,  1.2976],
         [ 4.2226, -4.5774, -4.8947]],
        [[-2.4289, -0.7804,  5.1385],
         [ 0.8003,  2.9425,  1.7338]]])
```
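As a cross-check, this should match `torch.bmm`, which performs exactly this batched product:

```python
assert torch.allclose(torch.einsum('ijk,ikl->ijl', a, b), torch.bmm(a, b))
```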
5.11 TENSOR CONTRACTION
Batch matrix multiplication is a special case of a tensor contraction. Let's say we have two tensors, an order-n tensor $A \in \mathbb{R}^{I_1 \times \cdots \times I_n}$ and an order-m tensor $B \in \mathbb{R}^{J_1 \times \cdots \times J_m}$. As an example, take $n=4$, $m=5$ and assume that $I_2 = J_3$ and $I_3 = J_5$. We can multiply the two tensors in these two dimensions (2 and 3 for $A$ and 3 and 5 for $B$), resulting in a new tensor $C \in \mathbb{R}^{I_1 \times I_4 \times J_1 \times J_2 \times J_4}$ as follows:
$C_{pstuv} = \sum_q \sum_r A_{pqrs} B_{tuqvr} = A_{pqrs} B_{tuqvr}$
```python
a = torch.randn(2, 3, 5, 7)
b = torch.randn(11, 13, 3, 17, 5)
torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape
```
```
torch.Size([2, 7, 11, 13, 17])
```
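Assuming the shapes above, the same contraction can be cross-checked with `torch.tensordot` (dimensions 1, 2 of `a` paired with dimensions 2, 4 of `b`):

```python
assert torch.allclose(
    torch.einsum('pqrs,tuqvr->pstuv', a, b),
    torch.tensordot(a, b, dims=([1, 2], [2, 4])),
    atol=1e-5,
)
```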
5.12 BILINEAR TRANSFORMATION
As mentioned earlier, einsum can operate on more than two tensors. One example where this is used is bilinear transformation.
$D_{ij} = \sum_k \sum_l A_{ik} B_{jkl} C_{il} = A_{ik} B_{jkl} C_{il}$
```python
a = torch.randn(2, 3)
b = torch.randn(5, 3, 7)
c = torch.randn(2, 7)
torch.einsum('ik,jkl,il->ij', [a, b, c])
```

The same result, written as explicit loops:

```python
import numpy as np

np_out = np.empty((2, 5), dtype=np.float32)
for i in range(2):
    for j in range(5):
        # inner loops over the summation (dummy) indices k and l
        sum_result = 0
        for k in range(3):
            for l in range(7):
                sum_result += a[i, k] * b[j, k, l] * c[i, l]
        np_out[i, j] = sum_result
```

```
tensor([[ 3.8471,  4.7059, -3.0674, -3.2075, -5.2435],
        [-3.5961, -5.2622, -4.1195,  5.5899,  0.4632]])
```
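With the shapes above, this should coincide with `torch.nn.functional.bilinear` (no bias), which computes the same $x_1^T B x_2$ form:

```python
import torch.nn.functional as F

assert torch.allclose(
    torch.einsum('ik,jkl,il->ij', a, b, c),
    F.bilinear(a, c, b),
    atol=1e-5,
)
```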
References
- https://rockt.github.io/2018/04/30/einsum
- https://ajcr.net/Basic-guide-to-einsum/
- https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/einsum_cn.html
- https://blog.csdn.net/qq_32768743/article/details/109131936
- https://dengbocong.blog.csdn.net/article/details/109566151
- https://zhuanlan.zhihu.com/p/46006162