이 코드는 MultiHeadSelfAttention
클래스를 정의하고 있으며, 이는 Transformer 아키텍처에서 중요한 역할을 하는 Multi-Head Self-Attention 메커니즘을 구현한 것입니다. 이 메커니즘은 입력 시퀀스에서 각 단어가 다른 단어들과의 관계를 학습할 수 있게 도와줍니다. 코드를 단계별로 설명하겠습니다.
__init__
method)def __init__(self, num_heads, hidden_size):
super().__init__()
self.num_heads = num_heads
self.attn_head_size = int(hidden_size / num_heads)
self.head_size = self.num_heads * self.attn_head_size
self.Q = nn.Linear(hidden_size, self.head_size)
self.K = nn.Linear(hidden_size, self.head_size)
self.V = nn.Linear(hidden_size, self.head_size)
self.dense = nn.Linear(self.head_size, hidden_size)
num_heads
: Attention heads의 수를 나타냅니다. Multi-head attention에서 서로 다른 부분의 정보를 캡처하기 위해 여러 개의 attention heads가 사용됩니다.hidden_size
: 입력의 hidden dimension 크기입니다.attn_head_size
: 각 head가 처리할 hidden dimension의 크기입니다. hidden_size
를 num_heads
로 나눈 값입니다.head_size
: 모든 head의 dimension 합입니다. num_heads
와 attn_head_size
의 곱입니다.self.Q
, self.K
, self.V
: 각각 Query, Key, Value를 계산하기 위한 선형 변환입니다.self.dense
: 모든 head의 출력을 결합한 후 다시 원래의 hidden_size로 변환하는 선형 변환입니다.tp_attn
method)def tp_attn(self, x):
x_shape = x.size()[:-1] + (self.num_heads, self.attn_head_size)
x = x.view(*x_shape)
return x.permute(0, 2, 1, 3)
tp_attn
함수: 입력 텐서를 (배치 크기, sequence 길이, hidden_size)에서 (배치 크기, num_heads, sequence 길이, attn_head_size)로 변환합니다. 그런 다음, 이를 permute
하여 차원을 (배치 크기, num_heads, sequence 길이, attn_head_size) 순서로 재배치합니다. 이는 이후 attention 계산을 위해 각 head 별로 데이터를 분리하는 과정입니다.forward
method)def forward(self, hidden_states):
Q, K, V = self.Q(hidden_states), self.K(hidden_states), self.V(hidden_states)
Q_layer, K_layer, V_layer = self.tp_attn(Q), self.tp_attn(K), self.tp_attn(V)
attn = torch.matmul(Q_layer, K_layer.transpose(-1, -2)) / math.sqrt(self.attn_head_size)
attn = nn.Softmax(dim=-1)(attn)
output = torch.matmul(attn, V_layer)
output = output.permute(0, 2, 1, 3).contiguous()
output_shape = output.size()[:-2] + (self.head_size,)
output = output.view(*output_shape)
Z = self.dense(output)
return Z
hidden_states
로부터 Query, Key, Value 벡터를 각각 Q
, K
, V
로 변환합니다.tp_attn
메서드를 사용하여 각 벡터(Q, K, V)를 multi-head 구조에 맞게 재배열합니다.attn
: Q_layer
와 K_layer
의 내적을 계산한 후, attn_head_size
의 제곱근으로 나누어 스케일을 조정합니다. 그런 다음, Softmax를 적용하여 각 Query에 대해 Key가 얼마나 관련 있는지를 나타내는 attention 값을 구합니다.output
: Attention 값과 V_layer
를 곱하여 최종 attention 출력을 얻습니다.permute
및 view
: 각 head의 출력을 결합하여 원래의 hidden_size로 변환할 수 있도록 차원을 변환합니다.Z
: 결합된 출력을 dense
레이어를 통해 최종 출력을 생성합니다.