einsum

einopseinsum 的使用——einsum

einsum是一个强大的执行爱因斯坦求和约定的函数,他几乎可以进行张量的所有运算,加上改变张量形状的einops库,它们可以使代码更加易读。它们的主要区别如下:

  1. 表达能力:
    • einops 提供了更广泛的张量操作,包括但不限于求和。它可以用于重排、合并、拆分等各种操作,使得代码更具灵活性
    • einsum 主要用于执行张量的求和操作,虽然可以通过合适的字符串表示式执行其他操作,但相对而言,其表达能力不如 einops 多样
  2. 直观性:
    • einops 的语法更加直观,它使用一些关键词来描述张量操作,使得代码更易读。
    • einsum 使用字符串表示式,这种表示方式相对于einops比较难懂
  3. 性能:
    • 在一些情况下,einops 可能在性能上稍逊于 einsumeinsum 的实现经过了高度优化,特别是在使用NumPy等库时,它可以通过C语言加速执行,因此在某些情况下可能更快
  4. 使用场景:
    • einops 更适合需要进行各种复杂张量操作的情况,特别是需要频繁改变张量形状的情况
    • einsum 更适合执行相对简单的张量求和操作,尤其是对于数学上常见的张量运算。

einsum的用法:

einsum实现矩阵乘法的例子:

1
2
3
4
a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等价操作 torch.mm(a, b)

einsum 的第一个参数 “ik,kj->ij”,该字符串(下文以 equation 表示)表示了输入和输出张量的维度。equation 中的箭头左边表示输入张量,以逗号分割每个输入张量,箭头右边则表示输出张量。表示维度的字符只能是26个英文字母 ‘a’ - ‘z’

而 einsum 的第二个参数表示实际的输入张量列表,其数量要与 equation 中的输入数量对应。同时对应每个张量的 子 equation 的字符个数要与张量的真实维度对应,比如 “ik,kj->ij” 表示输入和输出张量都是两维的。

基本规则:

  • 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,以上面矩阵乘法为例, “ik,kj->ij”,k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作
  • 规则二,只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引
  • 规则三,equation 箭头右边的索引顺序可以是任意的,比如上面的 “ik,kj->ij” 如果写成 “ik,kj->ji”,那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成

特殊规则

特殊规则有两条:

  • equation 可以不写包括箭头在内的右边部分(不推荐),那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 “ik,kj->ij” 也可以简化为 “ik,kj”,根据默认规则,输出就是 “ij” 与原来一样;
  • equation 中支持 “…” 省略号(不推荐),用于表示并不关心的索引,比如只对一个高维张量的最后两维做转置可以这么写:
1
2
3
a = torch.randn(2,3,5,7,9)
# i = 7, j = 9
b = torch.einsum('...ij->...ji', [a])

矩阵转置:

Bji=Aij\boldsymbol B_{ji} = \boldsymbol A_{ij}

1
2
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->ji', [a])

求和

b=ijAijb=\sum_i \sum_j \boldsymbol A_{ij}

1
2
a = torch.arrange(6).reshape(2, 3)
torch.einsum('ij->', [a])

点积

c=iaibic=\sum_i a_ib_i

1
2
3
a = torch.arrange(3)
b = torch.arrange(3, 6)
torch.einsum('i,i->', [a, b])

Hardmard积

Cij=AijBij\boldsymbol C_{ij}= \boldsymbol A_{ij} \boldsymbol B_{ij}

1
2
3
a = torch.arrange(6).reshape(2, 3)
b = torch.arrange(6, 12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])

外积

Cij=aibj\boldsymbol C_{ij} = \boldsymbol a_i \boldsymbol b_j

1
2
3
a = torch.arrange(3)
b = torch.arrange(3, 7)
torch.einsum('i,j->ij', [a, b])

batch矩阵相乘

Cijl=kAijkAikl\boldsymbol C_{ijl} = \sum_k \boldsymbol A_{ijk}A_{ikl}

1
2
3
a = torch.randn(3,2,5)
b = torch.randn(3,5,3)
torch.einsum('ijk,ikl->ijl', [a, b])

张量缩约(上面例子的一般版本)

Cpstuv=qrApqrsBtuqvr\boldsymbol C_{pstuv} = \sum_q \sum_r \boldsymbol A_{pqrs} \boldsymbol B_{tuqvr}

注意:在相乘维度上元素个数一定要相同

1
2
3
4
a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape
torch.Size([2, 7, 11, 13, 17])

双线性变换(einsum可用于超过两个张量的计算)

Dij=klAikBjklAil\boldsymbol D_{ij} = \sum_k \sum_l \boldsymbol A_{ik} \boldsymbol B_{jkl} \boldsymbol A_{il}

1
2
3
4
a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
torch.einsum('ik,jkl,il->ij', [a, b, c])

案例

img

用传统写法实现这些可要费不少力气,特别是考虑batch实现。einsum是我们的救星!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 参数
# -- [隐藏层维度]
bM, br, w = random_tensors([7], num=3, requires_grad=True)
# -- [隐藏层维度 x 隐藏层维度]
WY, Wh, Wr, Wt = random_tensors([7, 7], num=4, requires_grad=True)

# 注意力机制的单次应用
def attention(Y, ht, rt1):
# [i为batch大小]
tmp = torch.einsum("ik,kl->il", [ht, Wh]) + torch.einsum("ik,kl->il", [rt1, Wr])
Mt = F.tanh(torch.einsum("ijk,kl->ijl", [Y, WY]) + tmp.unsqueeze(1).expand_as(Y) + bM)
# -- [batch大小 x 序列长度]
at = F.softmax(torch.einsum("ijk,k->ij", [Mt, w]))
# -- [batch大小 x 隐藏层维度]
rt = torch.einsum("ijk,ij->ik", [Y, at]) + F.tanh(torch.einsum("ij,jk->ik", [rt1, Wt]) + br)
# -- [batch大小 x 隐藏层维度], [batch大小 x 序列维度]
return rt, at

# 取样仿造输入
# -- [batch大小 x 序列长度 x 隐藏层维度]
Y = random_tensors([3, 5, 7])
# -- [batch大小 x 隐藏层维度]
ht, rt1 = random_tensors([3, 7], num=2)
rt, at = attention(Y, ht, rt1)

总结

einsum是一个函数走天下,用例中可以看到,我们仍然需要在einsum之外应用非线性和构造额外维度。类似地,分割、连接、索引张量仍然需要应用其他库函数(einiops就是很好的选择)。

使用einsum的麻烦之处是你需要手动实例化参数,操心它们的初始化