Multi-Head Self-Attention
In a BERT-style encoder, the Transformer blocks drop the causal mask but keep the attention (padding) mask. The attention mask that marks padding tokens has shape (B, T) for batch size B and sequence length T. It is applied to the attention scores before the softmax, and it masks along the key dimension rather than the query dimension: padded positions should receive no attention weight, while queries at padded positions are simply ignored downstream. A useful implementation trick for multi-head attention is to project queries, keys, and values for all heads with a single linear layer, then reshape and transpose the result so attention scores can be computed per head.
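For intuition about the key-side masking, here is a minimal sketch of how a (B, T) padding mask broadcasts against a (B, num_heads, T, T) score tensor; the shapes and values are arbitrary and chosen only for illustration.

import torch

# Minimal sketch: broadcasting a (B, T) padding mask over attention scores.
B, num_heads, T = 2, 4, 5
scores = torch.randn(B, num_heads, T, T)          # raw attention scores
attention_mask = torch.tensor([[1, 1, 1, 0, 0],   # sequence 0: last two tokens are padding
                               [1, 1, 1, 1, 1]])  # sequence 1: no padding

# Reshape to (B, 1, 1, T) so the mask broadcasts over heads and queries,
# suppressing attention to padded *keys* only.
masked = scores.masked_fill(attention_mask.reshape(B, 1, 1, T) == 0, -1e4)
probs = torch.softmax(masked, dim=-1)
print(probs[0, 0, 0])  # last two key positions get ~0 probability

The full module below ties the fused QKV projection, the per-head reshape, and this masking step together.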
import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention (bidirectional) with padding mask.

    Inputs:
        x: [B, T, H]
        attention_mask: [B, T] with 1 for real tokens and 0 for padding.
    Output:
        y: [B, T, H]
    Notes:
        - No causal masking (BERT is bidirectional).
        - Apply `scores.masked_fill(mask == 0, -1e4)` **before** softmax.
        - Apply dropout to attention probabilities and to the final output projection.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float = 0.1):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Single projection producing Q, K, and V for all heads at once.
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)
        self.attn_drop = nn.Dropout(dropout_prob)
        self.proj_drop = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        # Project once, then split into Q, K, V: each [B, T, D].
        Q, K, V = torch.split(self.qkv(x), self.hidden_size, dim=-1)
        # Reshape to [B, num_heads, T, head_dim] for per-head attention.
        Q = Q.reshape(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.reshape(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.reshape(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Scaled dot-product scores: [B, num_heads, T, T].
        scores = Q @ K.permute(0, 1, 3, 2) / self.head_dim ** 0.5
        # Mask padded keys before softmax; mask broadcasts as [B, 1, 1, T].
        scores = scores.masked_fill(attention_mask.reshape(B, 1, 1, T) == 0, -1e4)
        probs = torch.softmax(scores, dim=-1)
        probs = self.attn_drop(probs)
        # Weighted sum of values, then merge heads back to [B, T, D].
        O = probs @ V
        O_ = self.out(O.permute(0, 2, 1, 3).reshape(B, T, D))
        return self.proj_drop(O_)
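A quick usage sketch of the module above; the hidden size, head count, sequence length, and mask pattern here are arbitrary example values, not prescribed by the text.

import torch

torch.manual_seed(0)
mhsa = MultiHeadSelfAttention(hidden_size=64, num_heads=8, dropout_prob=0.1)

B, T, H = 2, 6, 64
x = torch.randn(B, T, H)
attention_mask = torch.ones(B, T, dtype=torch.long)
attention_mask[0, 4:] = 0  # pretend the last two tokens of sequence 0 are padding

y = mhsa(x, attention_mask)
print(y.shape)  # torch.Size([2, 6, 64])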