torch.tril
torch.tril
torch.tril(input, diagonal=0, *, out=None)
행렬의 아래쪽 삼각형 부분 (2 차원 텐서) 또는 행렬의 배치 input 을 반환합니다 .[행렬의 오른쪽 부분을(0으로 만듬)]
attention 구조의 mask를 만들 때 많이 사용되는 함수입니다.
무슨 말인지 이해가 잘 안가실겁니다. 예제 출력 코드를 보면 바로 이해가 갈겁니다.
1 | a = torch.ones((5, 5)) |