이 코드는 MultiHeadSelfAttention 클래스를 정의하고 있으며, 이는 Transformer 아키텍처에서 중요한 역할을 하는 Multi-Head Self-Attention 메커니즘을 구현한 것입니다. 이 메커니즘은 입력 시퀀스에서 각 단어가 다른 단어들과의 관계를 학습할 수 있게 도와줍니다. 코드를 단계별로 설명하겠습니다.

1. 클래스 초기화 (__init__ method)

def __init__(self, num_heads, hidden_size):
    super().__init__()
    self.num_heads = num_heads
    self.attn_head_size = int(hidden_size / num_heads)
    self.head_size = self.num_heads * self.attn_head_size

    self.Q = nn.Linear(hidden_size, self.head_size)
    self.K = nn.Linear(hidden_size, self.head_size)
    self.V = nn.Linear(hidden_size, self.head_size)

    self.dense = nn.Linear(self.head_size, hidden_size)

2. 텐서 변환 (tp_attn method)

def tp_attn(self, x):
    x_shape = x.size()[:-1] + (self.num_heads, self.attn_head_size)
    x = x.view(*x_shape)
    return x.permute(0, 2, 1, 3)

3. Forward Pass (forward method)

def forward(self, hidden_states):
    Q, K, V = self.Q(hidden_states), self.K(hidden_states), self.V(hidden_states)
    Q_layer, K_layer, V_layer = self.tp_attn(Q), self.tp_attn(K), self.tp_attn(V)

    attn = torch.matmul(Q_layer, K_layer.transpose(-1, -2)) / math.sqrt(self.attn_head_size)
    attn = nn.Softmax(dim=-1)(attn)
    output = torch.matmul(attn, V_layer)

    output = output.permute(0, 2, 1, 3).contiguous()
    output_shape = output.size()[:-2] + (self.head_size,)
    output = output.view(*output_shape)

    Z = self.dense(output)

    return Z

전체 요약