einsum
참고 : <Aladdin youtube>
Einsum
은 Einstein Summation Convention으로 연산을 하는 방법입니다.
연산을 통해 내적(Dot products), 외적(Outer porducts), 전치(Transpose), 행렬곱(Matmul) 등을 표현할 수 있으며,
형태(dim, shape)을 관리할 때 매우 유용하다.
einsum은 numpy, torch, tensorflow에서 사용가능하다.
ex) numpy.einsum(), torch.einsum(), tensorflow.einsum()
- 간단하게 아래처럼 사용할 수 있습니다.(차원 표현으로
ijk...
으로 많이 사용됩니다.)
- a,b 중 같은 차원이라면 동일한 알파벳으로 입력해주기.
einsum의 통상적인 사용방법은 다음과 같습니다. torch인 a.shape==(2,3,4),b.shape(3,4,1)가 있다면,
torch.einsum(‘ijk , jka -> jki’ , [a,b])
결과는 [3,4,2] 라는 식으로 나옵니다.
- 수학적으로 표현하자면 너무 복잡해지니 예시를 통해 간단한 사용 방법을 익혀봅시다.
예시
1 2 3 4 5 6 7 8 9 10 11 12
|
import torch a = torch.arange(6).reshape(2, 3) print(a) torch.einsum('ij->ji', [a])
tensor([[0, 1, 2], [3, 4, 5]]) tensor([[0, 3], [1, 4], [2, 5]])
|
1 2 3 4 5
| a = torch.arange(6).reshape(2, 3) torch.einsum('ij->', [a])
tensor(15)
|
1 2 3 4 5 6 7 8 9
| a = torch.arange(6).reshape(2, 3) print(a) torch.einsum('ij->j', [a])
tensor([[0, 1, 2], [3, 4, 5]])
tensor([3, 5, 7])
|
1 2 3 4 5 6 7 8 9 10
| a = torch.arange(6).reshape(2, 3) print(a) torch.einsum('ij->i', [a])
tensor([[0, 1, 2], [3, 4, 5]])
tensor([ 3, 12])
|
1 2 3 4 5 6 7 8 9 10 11 12
|
a = torch.arange(6).reshape(2, 3) b = torch.arange(3) torch.einsum('ik,k->i', [a, b])
tensor([ 5, 14])
np.matmul(a,b)
tensor([ 5, 14])
|
1 2 3 4 5 6 7 8
|
a = torch.arange(6).reshape(2, 3) b = torch.arange(15).reshape(3, 5) torch.einsum('ik,kj->ij', [a, b])
tensor([[ 25, 28, 31, 34, 37], [ 70, 82, 94, 106, 118]])
|
1 2 3 4 5 6 7
|
a = torch.arange(3) b = torch.arange(3,6) torch.einsum('i,i->', [a, b])
tensor(14)
|
1 2 3 4 5 6 7
|
a = torch.arange(6).reshape(2, 3) b = torch.arange(6,12).reshape(2, 3) torch.einsum('ij,ij->', [a, b])
tensor(145)
|
1 2 3 4 5
|
a = torch.arange(6).reshape(2, 3) b = torch.arange(6,12).reshape(2, 3) torch.einsum('ij,ij->ij', [a, b])
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14
|
a = torch.arange(3) b = torch.arange(3,7) c = torch.einsum('i,j->ij', [a, b]) print(a.shape,b.shape,c.shape) c
torch.Size([3]) torch.Size([4]) torch.Size([3, 4]) tensor([[ 0, 0, 0, 0], [ 3, 4, 5, 6], [ 6, 8, 10, 12]])
|