tp_attn
메서드는 입력 텐서를 재구성하여 multi-head attention 메커니즘에 맞게 데이터를 준비하는 중요한 단계입니다. 이 메서드를 통해 입력된 텐서는 여러 개의 attention head로 분할되고, 각 head는 특정 부분의 정보를 독립적으로 처리할 수 있습니다. 코드를 자세히 살펴보겠습니다.
x_shape
계산x_shape = x.size()[:-1] + (self.num_heads, self.attn_head_size)
x.size()[:-1]
: x
텐서의 마지막 차원(hidden dimension)을 제외한 나머지 차원들(일반적으로 batch 크기와 sequence 길이)을 가져옵니다. 예를 들어, 만약 x
의 크기가 (batch_size, seq_len, hidden_size)
라면, x.size()[:-1]
는 (batch_size, seq_len)
이 됩니다.self.num_heads
: Attention head의 수입니다. Multi-head attention에서는 여러 개의 head가 병렬적으로 동작합니다.self.attn_head_size
: 각 head가 담당할 hidden size입니다. 이것은 hidden_size
를 num_heads
로 나눈 값입니다.따라서, **x_shape
**는 (batch_size, seq_len, num_heads, attn_head_size)
가 됩니다. 이 과정에서 텐서는 원래의 hidden dimension을 각 head가 독립적으로 처리할 수 있도록 나눠서 재구성됩니다.
view
)x = x.view(*x_shape)
x.view(*x_shape)
: 텐서를 x_shape
에 맞게 재배열합니다. view
함수는 텐서의 데이터를 새로운 차원으로 해석하게 만듭니다. 이 경우, hidden_size
로 표현되던 마지막 차원을 num_heads
와 attn_head_size
로 나누어 새롭게 표현합니다.예를 들어, 원래 x
의 크기가 (batch_size, seq_len, hidden_size)
이고, hidden_size가 512이며, num_heads가 8이라고 가정합시다. 그러면 attn_head_size
는 512 / 8 = 64
가 됩니다. view
이후 텐서의 크기는 (batch_size, seq_len, 8, 64)
가 됩니다. 이는 각 sequence에서 512 차원의 데이터를 8개의 head로 나누어 64 차원씩 처리하게 만듭니다.
permute
)return x.permute(0, 2, 1, 3)
x.permute(0, 2, 1, 3)
: 텐서의 차원을 재배열합니다. permute
함수는 지정된 순서에 따라 텐서의 차원을 바꿉니다. 여기서 (0, 2, 1, 3)은:
따라서, 최종적으로 x
의 크기는 (batch_size, num_heads, seq_len, attn_head_size)
가 됩니다. 이 배치는 각 head가 독립적으로 attention을 계산할 수 있도록 데이터가 준비된 상태입니다.
tp_attn
함수는 입력 텐서를 multi-head attention에 적합한 형태로 변환합니다.x
는 (batch_size, num_heads, seq_len, attn_head_size)
의 형태를 가지며, 이는 이후 attention 계산에서 사용됩니다.이렇게 변환된 텐서는 각 head가 각각의 Query, Key, Value를 계산하고, 병렬적으로 attention을 수행하여 다양한 관점에서 입력 데이터의 관계를 학습할 수 있게 됩니다.