27  The Attention Mechanism

Attention is the core of Transformers. Let’s understand it deeply.

27.1 The Intuition

Consider: “The cat sat on the mat because it was tired.”

What does “it” refer to? The cat.

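Attention lets each word “look at” every word in the sequence, including itself, and gather context from the most relevant ones. Here, “it” should draw most of its context from “cat”.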

27.2 Query, Key, Value

Attention uses three projections:

  • Query (Q): What am I looking for?
  • Key (K): What do I contain?
  • Value (V): What do I return?

flowchart LR
    X[Input] --> Q[Query]
    X --> K[Key]
    X --> V[Value]
    Q --> Attn[Attention Scores]
    K --> Attn
    Attn --> Out[Weighted Sum]
    V --> Out
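Here is a minimal NumPy sketch of the three roles for one position. A single query vector is scored against every key, and the resulting weights mix the value rows. The random arrays stand in for learned projections, and the helper name `attend_one` is hypothetical. The scores are already scaled by √d_k, as in the formula of 27.3.

```python
import numpy as np

def attend_one(q, K, V):
    """Attention output for a single query vector q (hypothetical helper).

    q: (d_k,)      what this position is looking for
    K: (seq, d_k)  what each position contains
    V: (seq, d_v)  what each position returns
    """
    scores = K @ q / np.sqrt(q.shape[0])     # one score per key, shape (seq,)
    weights = np.exp(scores - scores.max())  # numerically stable softmax numerator
    weights /= weights.sum()                 # normalize: weights sum to 1
    return weights @ V, weights              # (d_v,) weighted mix of value rows

rng = np.random.default_rng(0)
seq, d_k, d_v = 4, 8, 3
q = rng.standard_normal(d_k)
K = rng.standard_normal((seq, d_k))
V = rng.standard_normal((seq, d_v))

out, w = attend_one(q, K, V)
print(w.round(3), out.shape)
```

The query only decides *where* to look. What comes back is always a convex combination of the value rows.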

27.3 The Attention Formula

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\]

Step by step:

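  1. Compute scores: \(QK^T\) gives the dot-product similarity of each query with each key
  2. Scale: Divide by \(\sqrt{d_k}\) so scores do not grow with dimension and saturate the softmax
  3. Softmax: Apply along each row (over the keys) so each query's weights are non-negative and sum to 1
  4. Weighted sum: Multiply the weights by \(V\), so each output is a weighted average of the value vectors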
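The four steps can be traced by hand on a toy example. This NumPy sketch uses three positions and hand-picked 2-dimensional queries, keys, and values. The numbers are chosen for readability and do not come from any model.

```python
import numpy as np

# Toy inputs: 3 positions, d_k = d_v = 2 (values chosen by hand)
Q = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
K = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
V = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [0.5, 0.5]])
d_k = Q.shape[-1]

# 1. Scores: similarity of every query with every key, shape (3, 3)
scores = Q @ K.T
# 2. Scale by sqrt(d_k)
scaled = scores / np.sqrt(d_k)
# 3. Softmax over keys (each row), subtracting the row max for stability
exp = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
weights = exp / exp.sum(axis=-1, keepdims=True)
# 4. Weighted sum of value rows, shape (3, 2)
output = weights @ V

print(scores)
print(weights.round(3))
print(output.round(3))
```

The raw scores are `[[1, 0, 1], [0, 1, 1], [1, 1, 2]]`. Position 2's query (1, 1) matches its own key best (score 2), so row 2 of `weights` puts its largest weight on position 2.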

27.4 Implementation

import numpy as np
from tensorweaver import Tensor
from tensorweaver.nn.functional import softmax

def attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention.

    Args:
        Q: Queries tensor (batch, seq, d_k)
        K: Keys tensor (batch, seq, d_k)
        V: Values tensor (batch, seq, d_v)
        mask: Optional tensor broadcastable to (batch, seq, seq),
            with 1 where attention is blocked and 0 where it is allowed

    Returns:
        Attention output tensor (batch, seq, d_v)
    """
    d_k = Q.shape[-1]

    # Compute attention scores: Q @ K^T
    scores = Q @ K.transpose(-2, -1)

    # Scale by sqrt(d_k) to prevent softmax saturation
    scores = scores / np.sqrt(d_k)
    # scores: (batch, seq, seq)

    # Apply mask: blocked positions (mask == 1) get a score near -inf,
    # so their softmax weight becomes 0 (used for causal attention)
    if mask is not None:
        scores = scores + mask * (-1e9)

    # Softmax over keys (last dimension)
    weights = softmax(scores, axis=-1)
    # weights: (batch, seq, seq) - each row sums to 1

    # Weighted sum of values
    output = weights @ V
    # output: (batch, seq, d_v)

    return output

All inputs and outputs are Tensor objects, maintaining the computational graph for backpropagation.
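The function above depends on tensorweaver's `Tensor` and `softmax`. To check its logic in isolation, here is a sketch that mirrors it with plain NumPy arrays, using the same mask convention (1 = blocked). A slow per-query loop serves as a reference for the vectorized version. The names `softmax_np`, `attention_np`, and `attention_loop` are hypothetical, and the loop version assumes a 2-D `(seq, seq)` mask.

```python
import numpy as np

def softmax_np(x, axis=-1):
    # Subtract the max along the axis so exp() cannot overflow
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_np(Q, K, V, mask=None):
    """Vectorized scaled dot-product attention on NumPy arrays."""
    d_k = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -2, -1) / np.sqrt(d_k)  # (batch, seq_q, seq_k)
    if mask is not None:
        scores = scores + mask * (-1e9)                  # 1 = blocked
    return softmax_np(scores, axis=-1) @ V

def attention_loop(Q, K, V, mask=None):
    """Reference version: one query at a time, assumes a 2-D (seq, seq) mask."""
    batch, seq_q, d_k = Q.shape
    out = np.zeros((batch, seq_q, V.shape[-1]))
    for b in range(batch):
        for i in range(seq_q):
            s = np.array([Q[b, i] @ K[b, j] for j in range(K.shape[1])]) / np.sqrt(d_k)
            if mask is not None:
                s = s + mask[i] * (-1e9)
            w = np.exp(s - s.max())
            out[b, i] = (w / w.sum()) @ V[b]
    return out

rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((2, 5, 4)) for _ in range(3))
mask = np.triu(np.ones((5, 5)), k=1)  # causal: block key j > query i

print(np.allclose(attention_np(Q, K, V, mask), attention_loop(Q, K, V, mask)))
```

With the causal mask, position 0 can attend only to itself, so its output is exactly its own value vector.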

27.5 Self-Attention

Self-attention is the case where Q, K, and V are all computed from the same input x:

class SelfAttention(Module):
    """Single-head self-attention."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        self.W_q = Linear(d_model, d_model, bias=False)
        self.W_k = Linear(d_model, d_model, bias=False)
        self.W_v = Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input (batch, seq, d_model)
            mask: Attention mask

        Returns:
            Output (batch, seq, d_model)
        """
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        return attention(Q, K, V, mask)
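Here is a NumPy sketch of the same forward pass. It assumes a bias-free `Linear` layer computes `x @ W` for a weight matrix `W` of shape (d_model, d_model). The storage layout inside tensorweaver's `Linear` may differ. The class name `SelfAttentionNP` and its random initialization are illustrative only.

```python
import numpy as np

def softmax_np(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

class SelfAttentionNP:
    """Single-head self-attention with plain NumPy weight matrices."""

    def __init__(self, d_model, seed=0):
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(d_model)  # keep projected values near unit scale
        # Stand-ins for the three bias-free Linear layers
        self.W_q = rng.standard_normal((d_model, d_model)) * scale
        self.W_k = rng.standard_normal((d_model, d_model)) * scale
        self.W_v = rng.standard_normal((d_model, d_model)) * scale

    def forward(self, x, mask=None):
        # Q, K and V are all projections of the same input x
        Q, K, V = x @ self.W_q, x @ self.W_k, x @ self.W_v
        scores = Q @ np.swapaxes(K, -2, -1) / np.sqrt(Q.shape[-1])
        if mask is not None:
            scores = scores + mask * (-1e9)  # 1 = blocked, as in attention()
        return softmax_np(scores, axis=-1) @ V

layer = SelfAttentionNP(d_model=8)
x = np.random.default_rng(2).standard_normal((2, 6, 8))  # (batch, seq, d_model)
print(layer.forward(x).shape)  # (2, 6, 8)
```

Without a mask, nothing in this computation depends on token order. Shuffling the input positions shuffles the output rows in the same way, which is why Transformers add positional information separately.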

27.6 Causal Masking

Autoregressive language models predict each token from the tokens before it, so position i must not attend to any position after i (the future):

def create_causal_mask(seq_len):
    """
    Create causal attention mask.

    Position i can only attend to positions <= i.
    """
    mask = np.triu(np.ones((seq_len, seq_len)), k=1)
    return Tensor(mask)

Visualization of the effective additive mask, mask * (-1e9), where -1e9 acts as -∞:

       pos0  pos1  pos2  pos3
pos0  [  0     -∞    -∞    -∞  ]  ← can only see self
pos1  [  0      0    -∞    -∞  ]  ← can see pos0, pos1
pos2  [  0      0     0    -∞  ]  ← can see pos0-2
pos3  [  0      0     0     0  ]  ← can see all
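A short NumPy sketch of what the mask does to softmax weights. It builds the same upper-triangular mask as `create_causal_mask`, adds it to random stand-in scores, and normalizes each row.

```python
import numpy as np

seq_len = 4
mask = np.triu(np.ones((seq_len, seq_len)), k=1)  # 1 above the diagonal: row i blocks columns j > i
print(mask)

scores = np.random.default_rng(3).standard_normal((seq_len, seq_len))
masked = scores + mask * (-1e9)                   # blocked entries become roughly -1e9
e = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)
print(weights.round(2))                           # zeros above the diagonal
```

A large finite constant works better than infinity in this multiplicative form. In IEEE floating point, 0 * -inf is NaN, so `mask * (-np.inf)` would corrupt the allowed positions. `mask * (-1e9)` leaves them at exactly 0.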

27.7 Attention Weights Visualization

Illustrative weights for “The cat sat” (no causal mask, so each word can attend to every word):

        The   cat   sat
The   [ 0.8   0.1   0.1 ]
cat   [ 0.3   0.5   0.2 ]
sat   [ 0.2   0.4   0.4 ]

Each row sums to 1 (softmax over the keys). Row i shows how much position i attends to each position, including itself.
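One way to produce a table like this from actual scores is to apply softmax to a score matrix and print it with token labels. For brevity, this sketch makes two assumptions: it uses random stand-in embeddings, and it sets Q = K = x with no learned projections. Its numbers will therefore not match the illustrative table above.

```python
import numpy as np

tokens = ["The", "cat", "sat"]
rng = np.random.default_rng(4)
x = rng.standard_normal((len(tokens), 8))  # random stand-in embeddings

scores = x @ x.T / np.sqrt(x.shape[-1])     # Q = K = x for simplicity
e = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)

# Print a labelled table: row = query token, column = key token
print("      " + "".join(f"{t:>6}" for t in tokens))
for t, row in zip(tokens, weights):
    print(f"{t:<6}" + "".join(f"{w:6.2f}" for w in row))
```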

27.8 Why Scale by √d_k?

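Without scaling: if the entries of q and k are independent with mean 0 and variance 1, then q·k is a sum of d_k products that each have variance 1. Its variance is therefore d_k, and its standard deviation is √d_k.

import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax in plain NumPy (arrays, not Tensors)
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

d_k = 64
Q = np.random.randn(1, 10, d_k)
K = np.random.randn(1, 10, d_k)

scores = Q @ K.transpose(0, 2, 1)  # (1, 10, 10); std ≈ √d_k = 8
# Softmax saturates: rows tend to put almost all weight on one key
probs = softmax(scores, axis=-1)

With scaling:

scores = (Q @ K.transpose(0, 2, 1)) / np.sqrt(d_k)  # std ≈ 1
probs = softmax(scores, axis=-1)  # Smoother distribution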
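The variance argument can be checked empirically. This NumPy sketch uses standard-normal stand-ins for queries and keys, which is an assumption: learned Q and K need not have unit variance. It measures the spread of raw and scaled dot products, then compares how peaked the softmax rows are in each case. `mean_max_weight` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(5)
d_k, n = 64, 10_000
q = rng.standard_normal((n, d_k))  # entries: mean 0, variance 1
k = rng.standard_normal((n, d_k))

dots = np.einsum("nd,nd->n", q, k)  # n independent dot products q·k
print(dots.std())                    # close to sqrt(64) = 8
print((dots / np.sqrt(d_k)).std())   # close to 1

def mean_max_weight(S):
    """Average, over rows, of the largest softmax weight in each row."""
    e = np.exp(S - S.max(axis=-1, keepdims=True))
    w = e / e.sum(axis=-1, keepdims=True)
    return w.max(axis=-1).mean()

Q = rng.standard_normal((500, 10, d_k))
K = rng.standard_normal((500, 10, d_k))
S = Q @ K.transpose(0, 2, 1)               # (500, 10, 10) unscaled scores
print(mean_max_weight(S))                  # larger: rows are peaked
print(mean_max_weight(S / np.sqrt(d_k)))   # smaller: rows are smoother
```

A peaked softmax also has near-zero gradients for most entries, which is the practical reason saturation hurts training.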

27.9 Complete Causal Self-Attention

class CausalSelfAttention(Module):
    """Causal self-attention for language models."""

    def __init__(self, d_model, max_seq_len=1024):
        super().__init__()
        self.d_model = d_model

        self.W_q = Linear(d_model, d_model, bias=False)
        self.W_k = Linear(d_model, d_model, bias=False)
        self.W_v = Linear(d_model, d_model, bias=False)

        # Precompute causal mask
        mask = np.triu(np.ones((max_seq_len, max_seq_len)), k=1)
        self.register_buffer('mask', mask)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Attention with causal mask
        scores = (Q @ K.transpose(-2, -1)) / np.sqrt(self.d_model)
        scores = scores + self.mask[:seq_len, :seq_len] * (-1e9)
        weights = softmax(scores, axis=-1)
        output = weights @ V

        return output
Note

Code Reference: See src/tensorweaver/layers/causal_self_attention.py
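This plain-NumPy sketch of the class above can check causality without tensorweaver. It assumes bias-free `Linear` layers act as `x @ W` and replaces them with random matrices. It also raises a `ValueError` when the input is longer than `max_seq_len`, a check the original leaves implicit. The demo at the end changes only the last two tokens and verifies that the outputs for earlier positions stay the same. The class name `CausalSelfAttentionNP` is hypothetical.

```python
import numpy as np

class CausalSelfAttentionNP:
    """NumPy sketch of causal self-attention with a precomputed mask."""

    def __init__(self, d_model, max_seq_len=1024, seed=0):
        rng = np.random.default_rng(seed)
        s = 1.0 / np.sqrt(d_model)
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        # Random stand-ins for the three bias-free Linear layers
        self.W_q = rng.standard_normal((d_model, d_model)) * s
        self.W_k = rng.standard_normal((d_model, d_model)) * s
        self.W_v = rng.standard_normal((d_model, d_model)) * s
        # Precompute the causal mask once: 1 above the diagonal
        self.mask = np.triu(np.ones((max_seq_len, max_seq_len)), k=1)

    def forward(self, x):
        _, seq_len, _ = x.shape
        if seq_len > self.max_seq_len:
            raise ValueError("sequence longer than the precomputed mask")
        Q, K, V = x @ self.W_q, x @ self.W_k, x @ self.W_v
        scores = Q @ np.swapaxes(K, -2, -1) / np.sqrt(self.d_model)
        scores = scores + self.mask[:seq_len, :seq_len] * (-1e9)
        e = np.exp(scores - scores.max(axis=-1, keepdims=True))
        return (e / e.sum(axis=-1, keepdims=True)) @ V

layer = CausalSelfAttentionNP(d_model=8, max_seq_len=16)
rng = np.random.default_rng(6)
x = rng.standard_normal((1, 5, 8))
x_changed = x.copy()
x_changed[:, 3:] = rng.standard_normal((1, 2, 8))  # alter only positions 3 and 4

out, out_changed = layer.forward(x), layer.forward(x_changed)
print(np.allclose(out[:, :3], out_changed[:, :3]))  # True: earlier positions unaffected
```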

27.10 Summary

  • Attention computes weighted sums based on similarity
  • Q, K, V: Query what to find, Key what’s available, Value what to return
  • Scaling: Divide by √d_k for stable softmax
  • Causal mask: Prevent looking at future tokens
  • Self-attention scales quadratically with sequence length n: the n × n score matrix costs O(n²·d_k) time and O(n²) memory
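A back-of-the-envelope look at the quadratic term. This sketch assumes float32 scores and counts only the (seq_len, seq_len) score matrix for one head and one layer. The helper name `score_matrix_bytes` is hypothetical.

```python
def score_matrix_bytes(seq_len, batch=1, bytes_per_value=4):
    """Memory for one (batch, seq_len, seq_len) score matrix (float32 by default)."""
    return batch * seq_len * seq_len * bytes_per_value

for n in (512, 1024, 2048, 4096):
    mib = score_matrix_bytes(n) / 2**20
    print(f"seq_len={n:5d}: {mib:6.1f} MiB")
```

Each doubling of the sequence length quadruples the memory: 1, 4, 16, and 64 MiB here, before multiplying by batch size, number of heads, and number of layers.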

Next: Multi-head attention for richer representations.