torch.tril

torch.tril

  • torch.tril(input, diagonal=0, *, out=None)

  • 행렬의 아래쪽 삼각형 부분 (2 차원 텐서) 또는 행렬의 배치 input 을 반환합니다 .[행렬의 오른쪽 부분을(0으로 만듬)]

  • attention 구조의 mask를 만들 때 많이 사용되는 함수입니다.

  • 무슨 말인지 이해가 잘 안가실겁니다. 예제 출력 코드를 보면 바로 이해가 갈겁니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
a = torch.ones((5, 5))
torch.tril(a)

tensor([[1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.]])


a = torch.ones((5, 5))
torch.tril(a, diagonal=1)

tensor([[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])

Author

InhwanCho

Posted on

2023-01-07

Updated on

2023-01-07

Licensed under

Comments