28 Multi-Head Attention
One attention head captures one pattern. Multiple heads capture more.
28.1 Why Multiple Heads?
A single attention head might learn just one pattern: “look at nearby words.”
Multiple heads can learn different patterns:

- Head 1: Syntactic relationships (subject-verb)
- Head 2: Coreference (pronouns to nouns)
- Head 3: Positional (nearby words)
- Head 4: Semantic similarity
28.2 The Architecture

flowchart TD
X[Input] --> Split[Split into heads]
Split --> H1[Head 1]
Split --> H2[Head 2]
Split --> H3[Head 3]
Split --> H4[Head 4]
H1 --> Concat[Concatenate]
H2 --> Concat
H3 --> Concat
H4 --> Concat
Concat --> Proj[Output Projection]
Proj --> Out[Output]
28.3 Implementation
class MultiHeadAttention(Module):
    """Multi-head attention mechanism."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head

        # Combined projection for Q, K, V (more efficient)
        self.W_qkv = Linear(d_model, 3 * d_model, bias=False)
        self.W_out = Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        qkv = self.W_qkv(x)  # (batch, seq, 3*d_model)

        # Split into Q, K, V
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.transpose(2, 0, 3, 1, 4)  # (3, batch, heads, seq, d_k)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Attention per head
        scores = (Q @ K.transpose(-2, -1)) / np.sqrt(self.d_k)
        if mask is not None:
            scores = scores + mask * (-1e9)
        weights = softmax(scores, axis=-1)
        attn_out = weights @ V  # (batch, heads, seq, d_k)

        # Concatenate heads
        attn_out = attn_out.transpose(0, 2, 1, 3)  # (batch, seq, heads, d_k)
        attn_out = attn_out.reshape(batch_size, seq_len, self.d_model)

        # Output projection
        output = self.W_out(attn_out)
        return output

Code Reference: See src/tensorweaver/layers/multihead_attention.py for the implementation.
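A quick usage sketch: wrapping an array with Tensor(...) and calling a Module like a function are both used elsewhere in this chapter (Tensor(self.mask[...]) and self.W_qkv(x)), but the exact Tensor constructor signature is an assumption here, so treat this as illustrative:

import numpy as np

mha = MultiHeadAttention(d_model=768, num_heads=12)

# Wrap a random batch in the framework's Tensor (constructor assumed; compare
# Tensor(self.mask[...]) in the causal variant below).
x = Tensor(np.random.randn(2, 10, 768))

out = mha(x)          # Modules are callable, as with self.W_qkv(x) above
print(out.shape)      # expected: (2, 10, 768)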
28.4 Dimension Math
For GPT-2 small (d_model=768, num_heads=12):
d_model = 768
num_heads = 12
d_k = 768 / 12 = 64
Input: (batch, seq, 768)
Q,K,V: (batch, heads, seq, 64) each
Scores: (batch, heads, seq, seq)
Output: (batch, seq, 768)
Each head operates on 64 dimensions independently.
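To make the bookkeeping concrete, here is a plain-NumPy sketch (not the framework's Tensor) that traces these shapes for a small batch:

import numpy as np

batch, seq, d_model, num_heads = 2, 10, 768, 12
d_k = d_model // num_heads                     # 64

x = np.random.randn(batch, seq, d_model)       # (2, 10, 768)
W_qkv = np.random.randn(d_model, 3 * d_model)

qkv = x @ W_qkv                                # (2, 10, 2304)
qkv = qkv.reshape(batch, seq, 3, num_heads, d_k)
qkv = qkv.transpose(2, 0, 3, 1, 4)             # (3, 2, 12, 10, 64)
Q, K, V = qkv[0], qkv[1], qkv[2]               # each (2, 12, 10, 64)

scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)    # (2, 12, 10, 10)
e = np.exp(scores - scores.max(-1, keepdims=True))     # softmax over last axis
weights = e / e.sum(-1, keepdims=True)

out = (weights @ V).transpose(0, 2, 1, 3).reshape(batch, seq, d_model)
print(out.shape)                               # (2, 10, 768)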
28.5 Efficient Implementation
Combining the Q, K, V projections into a single matrix multiply is faster than running three separate ones, because one large matmul uses the hardware better than three small ones:

# Slow: three separate projections (three matmuls)
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)

# Fast: one combined projection (a single matmul)
qkv = self.W_qkv(x)
Q, K, V = qkv.split(3, dim=-1)
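A small plain-NumPy check (not the framework's Tensor) showing why the fused projection is safe: stacking the three weight matrices column-wise gives exactly the same Q, K, V:

import numpy as np

d_model = 8
x = np.random.randn(2, 5, d_model)             # (batch, seq, d_model)

# Three separate projection matrices...
W_q = np.random.randn(d_model, d_model)
W_k = np.random.randn(d_model, d_model)
W_v = np.random.randn(d_model, d_model)

# ...stacked into one (d_model, 3*d_model) matrix.
W_qkv = np.concatenate([W_q, W_k, W_v], axis=1)

qkv = x @ W_qkv                                # single matmul
Q, K, V = np.split(qkv, 3, axis=-1)            # three equal slices

assert np.allclose(Q, x @ W_q)
assert np.allclose(K, x @ W_k)
assert np.allclose(V, x @ W_v)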
28.6 Causal Multi-Head Attention

For GPT, we need causal masking. Here’s a clean implementation that maintains gradient flow:
class CausalMultiHeadAttention(Module):
    """Multi-head attention with causal mask."""

    def __init__(self, d_model, num_heads, max_seq_len=1024):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_qkv = Linear(d_model, 3 * d_model, bias=False)
        self.W_out = Linear(d_model, d_model, bias=False)

        # Causal mask (stored as a constant attribute, not a trainable parameter)
        mask = np.triu(np.ones((max_seq_len, max_seq_len)), k=1)
        self.mask = mask

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # QKV projection
        qkv = self.W_qkv(x)  # (batch, seq, 3*d_model)

        # Reshape for multi-head attention
        # Use Tensor methods to maintain gradient tracking
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, seq, d_k)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention
        scores = (Q @ K.transpose(-2, -1)) / np.sqrt(self.d_k)

        # Apply causal mask
        scores = scores + Tensor(self.mask[:seq_len, :seq_len]) * (-1e9)

        # Softmax and weighted sum
        weights = softmax(scores, axis=-1)
        attn_out = weights @ V  # (batch, heads, seq, d_k)

        # Concatenate heads
        attn_out = attn_out.permute(0, 2, 1, 3)  # (batch, seq, heads, d_k)
        attn_out = attn_out.reshape(batch_size, seq_len, self.d_model)

        # Output projection
        output = self.W_out(attn_out)
        return output

Key Point: Always use Tensor methods (reshape, permute, transpose) instead of accessing .data directly. This preserves the computational graph for backpropagation.
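To see exactly what the mask does, here is a small NumPy check for seq_len = 4; the (seq, seq) mask broadcasts across the batch and head dimensions of the scores:

import numpy as np

mask = np.triu(np.ones((4, 4)), k=1)
print(mask)
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]]
#
# The 1s mark future positions. Adding mask * (-1e9) pushes those scores far
# negative, so softmax gives them effectively zero weight: each token attends
# only to itself and earlier tokens.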
28.7 Attention Pattern Visualization
Different heads learn different patterns:
Head 1 (positional):     Head 2 (syntactic):
The cat sat              The cat sat
The [▓░░]                The [░▓░]
cat [▓▓░]                cat [▓░░]
sat [░▓▓]                sat [░▓░]
Darker = higher attention weight.
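The grids are just attention-weight matrices rendered as shades. A toy NumPy version for one head over three tokens (random projections, purely illustrative):

import numpy as np

np.random.seed(0)
d_k = 4
Q = np.random.randn(3, d_k)        # one head, three tokens: The, cat, sat
K = np.random.randn(3, d_k)

scores = Q @ K.T / np.sqrt(d_k)
e = np.exp(scores - scores.max(-1, keepdims=True))
weights = e / e.sum(-1, keepdims=True)

print(np.round(weights, 2))        # each row sums to 1; larger = "darker" cell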
28.8 Parameter Count
For d_model=768, num_heads=12:
W_qkv: 768 × (3 × 768) = 1,769,472
W_out: 768 × 768 = 589,824
Total: 2,359,296 parameters per attention layer
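The same arithmetic as a quick check (bias=False, so there are no bias terms):

d_model = 768

w_qkv = d_model * (3 * d_model)    # 1,769,472
w_out = d_model * d_model          #   589,824
print(w_qkv + w_out)               # 2,359,296
# Splitting into heads is just a reshape of the same d_model-wide projections,
# so num_heads does not change the parameter count.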
28.9 Summary
- Multi-head: Run multiple attention patterns in parallel
- Split d_model: Each head gets d_model/num_heads dimensions
- Concatenate: Combine all head outputs
- Output projection: Final linear layer
- Different heads learn different relationship types
Next: The complete Transformer block.