Attention 코드로 구현하기

스케일드 닷-프로덕트 어텐션

  • 참조 : <위키독스>
  • 닷-프로덕트 어텐션(dot-product attention)에서 스케일링하는 것을 추가하면

    스케일드 닷-프로덕트 어텐션(Scaled dot-product Attention)이라고 합니다
  • scaled_dot_product_attention을 tensorflow로 구현, 살펴보겠습니다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def scaled_dot_product_attention(query, key, value, mask):
# query 크기 : (batch_size, num_heads, query의 문장 길이, d_model/num_heads)
# key 크기 : (batch_size, num_heads, key의 문장 길이, d_model/num_heads)
# value 크기 : (batch_size, num_heads, value의 문장 길이, d_model/num_heads)
# padding_mask : (batch_size, 1, 1, key의 문장 길이)

# Q와 K의 곱. 어텐션 스코어 행렬.
matmul_qk = tf.matmul(query, key, transpose_b=True)

# 스케일링
depth = tf.cast(tf.shape(key)[-1], tf.float32)
logits = matmul_qk / tf.math.sqrt(depth)

# 마스킹, 매우 작은 값이므로 소프트맥스 함수에 의해 0이 된다.
if mask is not None:
logits += (mask * -1e9)

# 소프트맥스 함수는 마지막 차원인 key의 문장 길이 방향으로 수행(axis=-1)
# attention weight : (batch_size, num_heads, query의 문장 길이, key의 문장 길이)
attention_weights = tf.nn.softmax(logits, axis=-1)

# output : (batch_size, num_heads, query의 문장 길이, d_model/num_heads)
output = tf.matmul(attention_weights, value)

return output, attention_weights

테스트

  • temp_q의 값 [0, 10, 0]은 Key에 해당하는 temp_k의 두번째 값 [0, 10, 0]과 일치하게 설정
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import tensorflow as tf
import numpy as np
# 임의의 Query, Key, Value인 Q, K, V 행렬 생성
np.set_printoptions(suppress=True) #옵션 넣어줘야 보기 편함(소수점 반올림)
temp_k = tf.constant([[10,0,0],
[0,10,0],
[0,0,10],
[0,0,10]], dtype=tf.float32) # (4, 3)

temp_v = tf.constant([[1,0],
[10,0],
[100,5],
[1000,6]], dtype=tf.float32) # (4, 2)
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32) # (1, 3) #transpose_b
  • 어텐션 분포는 [0, 1, 0, 0]의 값을 가지며
  • Value의 두번째 값인 [10, 0]이 출력되는 것을 확인할 수 있습니다.
1
2
3
4
5
6
7
8
# 함수 실행
temp_out, temp_attn = scaled_dot_product_attention(temp_q, temp_k, temp_v, None)

print(temp_attn) # 어텐션 분포(어텐션 가중치의 나열)
# tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)

print(temp_out) # 어텐션 값
# tf.Tensor([[10. 0.]], shape=(1, 2), dtype=float32)
Author

InhwanCho

Posted on

2023-01-08

Updated on

2023-01-08

Licensed under

Comments