28  Multi-Head Attention

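A single attention head computes one set of attention weights for each query position, so it tends to specialize in one kind of relationship. Multiple heads run in parallel, and each can capture a different one.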

28.1 Why Multiple Heads?

A single head might learn one pattern, such as “look at nearby words.”

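Multiple heads can each learn a different pattern, for example:

  • Head 1: Syntactic relationships (subject-verb)
  • Head 2: Coreference (pronouns to nouns)
  • Head 3: Positional (nearby words)
  • Head 4: Semantic similarity

These roles are not assigned in advance; whatever specialization appears emerges during training.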

28.2 The Architecture

flowchart TD
    X[Input] --> Split[Split into heads]
    Split --> H1[Head 1]
    Split --> H2[Head 2]
    Split --> H3[Head 3]
    Split --> H4[Head 4]
    H1 --> Concat[Concatenate]
    H2 --> Concat
    H3 --> Concat
    H4 --> Concat
    Concat --> Proj[Output Projection]
    Proj --> Out[Output]
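The "Split into heads" and "Concatenate" boxes are pure reshapes applied to the projected Q, K, and V: no arithmetic happens there. The minimal NumPy sketch below shows that the round trip is lossless and that head h simply owns a contiguous slice of d_k features. The helper names `split_heads` and `merge_heads` are our own, not part of the book's library.

```python
import numpy as np

def split_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """(batch, seq, d_model) -> (batch, heads, seq, d_k)."""
    batch, seq, d_model = x.shape
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    # Carve the last axis into (heads, d_k), then move the heads axis ahead of seq.
    return x.reshape(batch, seq, num_heads, d_k).transpose(0, 2, 1, 3)

def merge_heads(x: np.ndarray) -> np.ndarray:
    """(batch, heads, seq, d_k) -> (batch, seq, heads * d_k): the 'Concatenate' box."""
    batch, heads, seq, d_k = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch, seq, heads * d_k)

x = np.arange(2 * 3 * 8, dtype=np.float64).reshape(2, 3, 8)
heads = split_heads(x, num_heads=4)
print(heads.shape)                            # (2, 4, 3, 2)
print(np.array_equal(merge_heads(heads), x))  # True
```

Because the split is only a view-style reshape, the heads do not exchange information until the output projection mixes the concatenated features.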

28.3 Implementation

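import numpy as np
# Assumes Module, Linear, and softmax come from the book's framework (tensorweaver).

class MultiHeadAttention(Module):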
    """Multi-head attention mechanism."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head

        # Combined projection for Q, K, V (more efficient)
        self.W_qkv = Linear(d_model, 3 * d_model, bias=False)
        self.W_out = Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        qkv = self.W_qkv(x)  # (batch, seq, 3*d_model)

        # Split into Q, K, V
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        # permute reorders all axes; transpose (used below) swaps exactly two
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, seq, d_k)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Attention per head
        scores = (Q @ K.transpose(-2, -1)) / np.sqrt(self.d_k)

        if mask is not None:
            # mask is 1 where attention is blocked; adding -1e9 drives those weights to ~0
            scores = scores + mask * (-1e9)

        weights = softmax(scores, axis=-1)
        attn_out = weights @ V  # (batch, heads, seq, d_k)

        # Concatenate heads
        attn_out = attn_out.permute(0, 2, 1, 3)  # (batch, seq, heads, d_k)
        attn_out = attn_out.reshape(batch_size, seq_len, self.d_model)

        # Output projection
        output = self.W_out(attn_out)

        return output
Note

Code Reference: See src/tensorweaver/layers/multihead_attention.py for the implementation.
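To sanity-check the reshape and permute bookkeeping outside the framework, here is a minimal pure-NumPy sketch of the same forward pass. Assumptions: weights are plain arrays applied as `x @ W` (so `W_qkv` has shape `(d_model, 3*d_model)`), and `mha_forward` and `mha_loop` are hypothetical helper names. The loop version computes each head separately and concatenates, exactly as the diagram describes; the two should agree to floating-point tolerance.

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max for numerical stability; the result is mathematically unchanged.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mha_forward(x, W_qkv, W_out, num_heads, mask=None):
    """Vectorized multi-head attention. mask: broadcastable to (seq, seq), 1 = blocked."""
    batch, seq, d_model = x.shape
    d_k = d_model // num_heads
    qkv = x @ W_qkv                                      # (batch, seq, 3*d_model)
    qkv = qkv.reshape(batch, seq, 3, num_heads, d_k).transpose(2, 0, 3, 1, 4)
    Q, K, V = qkv[0], qkv[1], qkv[2]                     # each (batch, heads, seq, d_k)
    scores = Q @ K.swapaxes(-2, -1) / np.sqrt(d_k)       # (batch, heads, seq, seq)
    if mask is not None:
        scores = scores + mask * (-1e9)
    out = softmax(scores) @ V                            # (batch, heads, seq, d_k)
    out = out.transpose(0, 2, 1, 3).reshape(batch, seq, d_model)
    return out @ W_out

def mha_loop(x, W_qkv, W_out, num_heads):
    """Same computation one head at a time, used to check the vectorized version."""
    d_model = x.shape[-1]
    d_k = d_model // num_heads
    Wq, Wk, Wv = np.split(W_qkv, 3, axis=1)              # columns: [Q | K | V]
    heads = []
    for h in range(num_heads):
        cols = slice(h * d_k, (h + 1) * d_k)             # head h's slice of each projection
        q, k, v = x @ Wq[:, cols], x @ Wk[:, cols], x @ Wv[:, cols]
        w = softmax(q @ k.swapaxes(-2, -1) / np.sqrt(d_k))
        heads.append(w @ v)
    return np.concatenate(heads, axis=-1) @ W_out

rng = np.random.default_rng(0)
d_model, num_heads = 16, 4
x = rng.normal(size=(2, 5, d_model))
W_qkv = rng.normal(size=(d_model, 3 * d_model)) / np.sqrt(d_model)
W_out = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
y = mha_forward(x, W_qkv, W_out, num_heads)
print(y.shape)                                              # (2, 5, 16)
print(np.allclose(y, mha_loop(x, W_qkv, W_out, num_heads)))  # True
```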

28.4 Dimension Math

For GPT-2 small (d_model=768, num_heads=12):

d_model = 768
num_heads = 12
d_k = 768 / 12 = 64

Input:  (batch, seq, 768)
Q,K,V:  (batch, heads, seq, 64) each
Scores: (batch, heads, seq, seq)
Output: (batch, seq, 768)

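Each head computes attention on its own 64-dimensional slice of Q, K, and V; the heads are combined only when their outputs are concatenated and passed through W_out.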
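The shape table can be checked directly. This sketch pushes zero-filled NumPy arrays with GPT-2 small's sizes through the reshapes (batch=2 and seq=10 are arbitrary choices, and weights are applied as `x @ W`, an assumption about layout). Note that `scores` is the only tensor whose size grows with seq², which is why attention memory is dominated by it for long sequences.

```python
import numpy as np

batch, seq, d_model, num_heads = 2, 10, 768, 12
d_k = d_model // num_heads                                  # 64

x = np.zeros((batch, seq, d_model), dtype=np.float32)
W_qkv = np.zeros((d_model, 3 * d_model), dtype=np.float32)

qkv = (x @ W_qkv).reshape(batch, seq, 3, num_heads, d_k).transpose(2, 0, 3, 1, 4)
Q, K, V = qkv                                               # each (2, 12, 10, 64)
scores = Q @ K.swapaxes(-2, -1)                             # (2, 12, 10, 10)
attn = scores @ V                                           # (2, 12, 10, 64)
out = attn.transpose(0, 2, 1, 3).reshape(batch, seq, d_model)  # (2, 10, 768)

for name, arr in [("x", x), ("Q", Q), ("scores", scores), ("attn", attn), ("out", out)]:
    print(f"{name:>6}: {arr.shape}")
```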

28.5 Efficient Implementation

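Combining the Q, K, and V projections into one matrix multiply is usually faster: one large matmul uses the hardware better than three small ones, and x is read only once. The result is mathematically the same as three separate projections whose weight matrices are stacked side by side to form W_qkv:

# Slow: three separate projections
Q = self.W_q(x)  # Matmul 1
K = self.W_k(x)  # Matmul 2
V = self.W_v(x)  # Matmul 3

# Fast: one combined projection
qkv = self.W_qkv(x)  # Single matmul: (batch, seq, 3*d_model)
Q, K, V = qkv.split(self.d_model, dim=-1)  # three chunks of size d_model (PyTorch-style split)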
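A quick NumPy check that the fused projection gives exactly the same Q, K, and V as three separate ones, assuming weights are applied as `x @ W`. Note the argument convention when splitting: NumPy's `np.split(a, 3, axis=-1)` means "three equal sections", whereas PyTorch's `Tensor.split(n, dim=-1)` means "chunks of size n", so the same-looking call can do different things in different libraries.

```python
import numpy as np

rng = np.random.default_rng(1)
d_model = 8
x = rng.normal(size=(2, 4, d_model))

# Three separate projection matrices, each (d_model, d_model).
W_q, W_k, W_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))

# Separate path: three matmuls.
Q1, K1, V1 = x @ W_q, x @ W_k, x @ W_v

# Fused path: stack the weights side by side, do one matmul,
# then split the last axis into three equal sections of size d_model.
W_qkv = np.concatenate([W_q, W_k, W_v], axis=1)   # (d_model, 3*d_model)
Q2, K2, V2 = np.split(x @ W_qkv, 3, axis=-1)

print(all(np.allclose(a, b) for a, b in [(Q1, Q2), (K1, K2), (V1, V2)]))  # True
```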

28.6 Causal Multi-Head Attention

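GPT generates text left to right, so position i must not attend to any position after i. The class below adds this causal mask and uses only Tensor operations, so gradients flow through every step: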

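import numpy as np
# Assumes Module, Linear, Tensor, and softmax come from the book's framework (tensorweaver).

class CausalMultiHeadAttention(Module):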
    """Multi-head attention with causal mask."""

    def __init__(self, d_model, num_heads, max_seq_len=1024):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_qkv = Linear(d_model, 3 * d_model, bias=False)
        self.W_out = Linear(d_model, d_model, bias=False)

        # Causal mask: 1 above the diagonal marks future positions to block.
        # Stored as a plain NumPy array (a constant), so it is not a trainable parameter.
        self.mask = np.triu(np.ones((max_seq_len, max_seq_len)), k=1)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        assert seq_len <= self.mask.shape[0], "sequence longer than max_seq_len"

        # QKV projection
        qkv = self.W_qkv(x)  # (batch, seq, 3*d_model)

        # Reshape for multi-head attention
        # Use Tensor methods to maintain gradient tracking
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, seq, d_k)

        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention
        scores = (Q @ K.transpose(-2, -1)) / np.sqrt(self.d_k)

        # Apply causal mask
        scores = scores + Tensor(self.mask[:seq_len, :seq_len]) * (-1e9)

        # Softmax and weighted sum
        weights = softmax(scores, axis=-1)
        attn_out = weights @ V  # (batch, heads, seq, d_k)

        # Concatenate heads
        attn_out = attn_out.permute(0, 2, 1, 3)  # (batch, seq, heads, d_k)
        attn_out = attn_out.reshape(batch_size, seq_len, self.d_model)

        # Output projection
        output = self.W_out(attn_out)

        return output
Tip

Key Point: Always use Tensor methods (reshape, permute, transpose) instead of accessing .data directly. This preserves the computational graph for backpropagation.
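Two properties are worth testing for any causal attention: no weight lands on future positions, and earlier outputs do not change when later tokens change. The sketch below checks both in plain NumPy for a single batch element with two heads; `causal_attention` is a hypothetical helper, not the framework's API. It uses the same additive -1e9 mask as the class above. A large finite value is a common choice over -inf because a row masked entirely with -inf would turn into NaN after softmax; with a causal mask no row is fully masked, since the diagonal is always allowed.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(Q, K, V):
    """Q, K, V: (..., seq, d_k). Scaled dot-product attention with a causal mask."""
    seq, d_k = Q.shape[-2], Q.shape[-1]
    mask = np.triu(np.ones((seq, seq)), k=1)          # 1 above the diagonal = future
    scores = Q @ K.swapaxes(-2, -1) / np.sqrt(d_k) + mask * (-1e9)
    weights = softmax(scores)
    return weights @ V, weights

rng = np.random.default_rng(2)
Q, K, V = rng.normal(size=(3, 2, 6, 4))   # each (heads=2, seq=6, d_k=4)
out, w = causal_attention(Q, K, V)

# 1) No weight on future positions, and each row still sums to 1.
print(np.allclose(np.triu(w, k=1), 0), np.allclose(w.sum(-1), 1))   # True True

# 2) Perturbing the last token's K and V must not change earlier outputs.
K2, V2 = K.copy(), V.copy()
K2[:, -1] += 10.0
V2[:, -1] += 10.0
out2, _ = causal_attention(Q, K2, V2)
print(np.allclose(out[:, :-1], out2[:, :-1]))                        # True
```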

28.7 Attention Pattern Visualization

Different heads can learn different patterns. An illustrative sketch:

Head 1 (positional):      Head 2 (syntactic):
    The cat sat              The cat sat
The [▓░░]               The [░▓░]
cat [▓▓░]               cat [▓░░]
sat [░▓▓]               sat [░▓░]

Each row is a query token and each column a key token; darker = higher attention weight.
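Grids like these can be produced from a real weight matrix by bucketing each weight into a few shade characters. A minimal sketch, assuming a (seq, seq) NumPy array of softmax weights with rows as queries; `render` and its three-level `SHADES` string are hypothetical choices for illustration.

```python
import numpy as np

SHADES = "░▒▓"  # light -> dark

def render(weights: np.ndarray, tokens: list[str]) -> str:
    """Render a (seq, seq) attention matrix: rows = queries, columns = keys."""
    # Map each weight in [0, 1] to a shade index 0..len(SHADES)-1.
    idx = np.minimum((weights * len(SHADES)).astype(int), len(SHADES) - 1)
    width = max(len(t) for t in tokens)
    lines = []
    for tok, row in zip(tokens, idx):
        lines.append(f"{tok:<{width}} [" + "".join(SHADES[i] for i in row) + "]")
    return "\n".join(lines)

w = np.array([[1.0, 0.0, 0.0],
              [0.5, 0.5, 0.0],
              [0.0, 0.4, 0.6]])
print(render(w, ["The", "cat", "sat"]))
# The [▓░░]
# cat [▒▒░]
# sat [░▒▒]
```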

28.8 Parameter Count

For d_model=768, num_heads=12 (no bias terms, matching bias=False in the implementation above):

W_qkv: 768 × (3 × 768) = 1,769,472
W_out: 768 × 768 = 589,824
Total: 2,359,296 parameters per attention layer
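The arithmetic generalizes to 3·d_model² + d_model² = 4·d_model² weights, and `num_heads` never enters it: splitting into heads reshapes the projection outputs without adding parameters. A small sketch to check this; `mha_param_count` is a hypothetical helper, and its `bias=True` option is only for comparison with layers that include bias vectors.

```python
def mha_param_count(d_model: int, num_heads: int, bias: bool = False) -> int:
    """Parameters in one multi-head attention layer with a fused QKV projection."""
    assert d_model % num_heads == 0
    w_qkv = d_model * 3 * d_model + (3 * d_model if bias else 0)
    w_out = d_model * d_model + (d_model if bias else 0)
    return w_qkv + w_out  # num_heads only appears in the divisibility check

print(mha_param_count(768, 12))             # 2359296
print(mha_param_count(768, 1))              # 2359296 (same: heads only reshape)
print(mha_param_count(768, 12, bias=True))  # 2362368
```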

28.9 Summary

  • Multi-head: Run multiple attention patterns in parallel
  • Split d_model: Each head gets d_model/num_heads dimensions
  • Concatenate: Combine all head outputs
  • Output projection: Final linear layer
  • Different heads can learn different relationship types

Next: The complete Transformer block.