Einsum (Einstein Summation)

einsum

  • 참고 : <Aladdin youtube>

  • Einsum은 Einstein Summation Convention으로 연산을 하는 방법입니다.

  • 연산을 통해 내적(Dot products), 외적(Outer porducts), 전치(Transpose), 행렬곱(Matmul) 등을 표현할 수 있으며,

  • 형태(dim, shape)을 관리할 때 매우 유용하다.

einsum은 numpy, torch, tensorflow에서 사용가능하다.
ex) numpy.einsum(), torch.einsum(), tensorflow.einsum()

  • 간단하게 아래처럼 사용할 수 있습니다.(차원 표현으로 ijk... 으로 많이 사용됩니다.)
  • a,b 중 같은 차원이라면 동일한 알파벳으로 입력해주기.

einsum의 통상적인 사용방법은 다음과 같습니다. torch인 a.shape==(2,3,4),b.shape(3,4,1)가 있다면,

torch.einsum(‘ijk , jka -> jki’ , [a,b])

결과는 [3,4,2] 라는 식으로 나옵니다.

  • 수학적으로 표현하자면 너무 복잡해지니 예시를 통해 간단한 사용 방법을 익혀봅시다.

예시

1
2
3
4
5
6
7
8
9
10
11
12
# MATRIX TRANSPOSE

import torch
a = torch.arange(6).reshape(2, 3)
print(a)
torch.einsum('ij->ji', [a])

tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[0, 3],
[1, 4],
[2, 5]])
1
2
3
4
5
# SUM
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])

tensor(15) # 6!
1
2
3
4
5
6
7
8
9
#  COLUMN SUM
a = torch.arange(6).reshape(2, 3)
print(a)
torch.einsum('ij->j', [a])

tensor([[0, 1, 2],
[3, 4, 5]])
# 0+3 , 1+4, 2+5
tensor([3, 5, 7])
1
2
3
4
5
6
7
8
9
10
# ROW SUM
a = torch.arange(6).reshape(2, 3)
print(a)
torch.einsum('ij->i', [a])


tensor([[0, 1, 2], #0+1+2->3
[3, 4, 5]]) #3+4+5->12

tensor([ 3, 12])
1
2
3
4
5
6
7
8
9
10
11
12
# MATRIX-VECTOR MULTIPLICATION

a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])

tensor([ 5, 14])

# 행렬곱과 값이 동일
np.matmul(a,b)

tensor([ 5, 14])
1
2
3
4
5
6
7
8
# MATRIX-MATRIX MULTIPLICATION

a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
torch.einsum('ik,kj->ij', [a, b])

tensor([[ 25, 28, 31, 34, 37],
[ 70, 82, 94, 106, 118]])
1
2
3
4
5
6
7
# DOT PRODUCT(vector)

a = torch.arange(3)
b = torch.arange(3,6) # [3, 4, 5]
torch.einsum('i,i->', [a, b])

tensor(14)
1
2
3
4
5
6
7
# DOT PRODUCT(matrix)

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])

tensor(145)
1
2
3
4
5
# HADAMARD PRODUCT

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# OUTER PRODUCT

a = torch.arange(3)
b = torch.arange(3,7) #[3, 4, 5, 6]
c = torch.einsum('i,j->ij', [a, b])
print(a.shape,b.shape,c.shape)
c

torch.Size([3]) torch.Size([4]) torch.Size([3, 4])
tensor([[ 0, 0, 0, 0],
[ 3, 4, 5, 6],
[ 6, 8, 10, 12]])


Author

InhwanCho

Posted on

2023-01-05

Updated on

2023-01-05

Licensed under

Comments