Transformer Architecture

Multi-Head Attention

Multi-Head Attention runs scaled dot-product attention in parallel across multiple lower-dimensional subspaces, concatenates the results, and projects them back to the model dimension. This lets the model capture several different kinds of relationships between positions at once.


type: component
title: "Multi-Head Attention"
tags: [attention, component, transformer]
part_of: ["Transformer"]
created: 2025-01-01


Purpose

Multi-Head Attention runs Self-Attention (or Masked Self-Attention in decoder contexts) multiple times in parallel over lower-dimensional subspaces of the input, allowing the model to jointly attend to information from different representation subspaces at different positions. The results are concatenated and projected back to the model dimension.

Inputs and Outputs

  • Input: A sequence of vectors of shape (seq_len, d_model).
  • Output: A sequence of vectors of the same shape (seq_len, d_model).

Internally, the input is projected into h sets of queries, keys, and values, one set per head. Each set has shape (seq_len, d_k), with d_k = d_v = d_model / h.
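Written out per head, the standard formulation is as follows. Here n = seq_len, the softmax is applied row-wise, and $W_Q^{(i)}, W_K^{(i)}, W_V^{(i)}$ denote the $(d_{model} \times d_k)$ projection matrices of head $i$. This notation is introduced for this note only.

```latex
\begin{aligned}
Q_i &= X W_Q^{(i)}, \quad K_i = X W_K^{(i)}, \quad V_i = X W_V^{(i)} && \in \mathbb{R}^{n \times d_k} \\
\mathrm{head}_i &= \operatorname{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_k}}\right) V_i && \in \mathbb{R}^{n \times d_k} \\
\operatorname{MultiHead}(X) &= \operatorname{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W_O && \in \mathbb{R}^{n \times d_{model}}
\end{aligned}
```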

Implementation

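# Simplified multi-head attention (NumPy; no batch dimension, biases, or masking)
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max along the axis for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    # X: (seq_len, d_model); W_Q, W_K, W_V, W_O: (d_model, d_model)
    seq_len, d_model = X.shape
    assert d_model % h == 0, "d_model must be divisible by h"
    d_k = d_model // h

    # Project inputs to Q, K, V for all heads simultaneously
    Q = X @ W_Q  # (seq_len, d_model)
    K = X @ W_K
    V = X @ W_V

    # Split into heads: (seq_len, d_model) -> (seq_len, h, d_k) -> (h, seq_len, d_k)
    Q = Q.reshape(seq_len, h, d_k).transpose(1, 0, 2)
    K = K.reshape(seq_len, h, d_k).transpose(1, 0, 2)
    V = V.reshape(seq_len, h, d_k).transpose(1, 0, 2)

    # Scaled dot-product attention per head; softmax over the key axis
    weights = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))  # (h, seq_len, seq_len)
    head_outputs = weights @ V  # (h, seq_len, d_k)

    # Concatenate heads along the feature axis and project
    concat = head_outputs.transpose(1, 0, 2).reshape(seq_len, d_model)  # (seq_len, d_model)
    output = concat @ W_O  # (seq_len, d_model)
    return output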

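Splitting into heads is implemented as a reshape of the projected Q, K, V matrices from (seq_len, d_model) to (seq_len, h, d_k). This is usually followed by a transpose that moves the head axis first. For GPT-2 small, which has 12 attention heads and d_model = 768, each head operates on 64-dimensional query, key, and value vectors.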
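The following is a quick shape check of the split using GPT-2 small's dimensions (h = 12, d_model = 768). The sequence length of 5 and the random matrix standing in for X @ W_Q are arbitrary. The reshape-then-transpose layout shown here is one common convention, not the only possible one.

```python
import numpy as np

seq_len, d_model, h = 5, 768, 12   # GPT-2 small: 12 heads, d_model = 768
d_k = d_model // h                 # 64

rng = np.random.default_rng(0)
Q = rng.standard_normal((seq_len, d_model))  # stand-in for X @ W_Q

# (seq_len, d_model) -> (seq_len, h, d_k) -> (h, seq_len, d_k)
Q_heads = Q.reshape(seq_len, h, d_k).transpose(1, 0, 2)

print(d_k, Q_heads.shape)  # 64 (12, 5, 64)

# Head i sees exactly feature columns i*d_k .. (i+1)*d_k - 1 of Q
i = 3
assert np.array_equal(Q_heads[i], Q[:, i * d_k:(i + 1) * d_k])

# Merging is the inverse: transpose back and flatten the last two axes
Q_merged = Q_heads.transpose(1, 0, 2).reshape(seq_len, d_model)
assert np.array_equal(Q_merged, Q)
```

Because the split is only a reshape and transpose, no extra parameters are involved. Each head simply reads its own contiguous slice of the feature axis.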

Merging heads: After each head independently computes its attention output over its d_k-dimensional subspace, the outputs are concatenated along the feature dimension to produce a vector of dimension d_model. This concatenated vector is then multiplied by a learned output projection matrix W_O of shape (d_model, d_model) to produce the final multi-head attention output.
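Because W_O acts on the concatenated vector, the output projection can also be read as follows: each head writes its contribution through its own row block of W_O, and the contributions are summed. The sketch below checks this identity numerically. The sizes and random values are arbitrary stand-ins, not taken from any real model.

```python
import numpy as np

seq_len, d_model, h = 4, 16, 4
d_k = d_model // h

rng = np.random.default_rng(1)
head_outputs = rng.standard_normal((h, seq_len, d_k))  # stand-in per-head attention outputs
W_O = rng.standard_normal((d_model, d_model))

# Concatenate heads along the feature axis: h arrays of (seq_len, d_k) -> (seq_len, h * d_k)
concat = np.concatenate(list(head_outputs), axis=-1)
out_concat = concat @ W_O

# Equivalent view: head i multiplies its own row block of W_O, and the results are summed
out_sum = sum(head_outputs[i] @ W_O[i * d_k:(i + 1) * d_k, :] for i in range(h))

print(concat.shape, out_concat.shape)  # (4, 16) (4, 16)
assert np.allclose(out_concat, out_sum)
```

This per-head view shows how W_O mixes the heads: every output feature can draw on every head's subspace.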

Design Choices and Hyperparameters

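  • Number of heads (h): GPT-2 small uses 12 heads. With d_model fixed, more heads give more but narrower subspaces (smaller d_k). Increasing h therefore trades per-head capacity for the number of distinct attention patterns the layer can compute at once.
  • Head dimension (d_k = d_model / h): Setting d_k = d_model / h keeps the combined width of all heads equal to d_model. As a result, W_Q, W_K, W_V, and W_O stay d_model × d_model (4·d_model² weights, excluding biases) regardless of h, and the total cost is similar to single-head attention at full width.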
  • Output projection (W_O): A learned linear transformation applied after concatenating all head outputs; this allows the model to learn how to best combine information from the different heads.
  • Masking: In decoder blocks, an attention mask is added to the scores before softmax to enforce causal (left-to-right) attention. See Masked Self-Attention.
  • KV Caching: At inference time in autoregressive models, the key and value projections for previously seen tokens can be cached per layer and per head to avoid redundant computation.
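To illustrate masking and KV caching together, the sketch below runs a single head (h = 1 for brevity; the same logic applies per head) in two ways. The first pass processes the whole sequence with an additive causal mask. The second processes it token by token with a growing key/value cache, as an autoregressive decoder would at inference time. The helper names `causal_attention` and `decode_with_kv_cache` and all sizes are hypothetical, chosen for illustration. A real implementation would also batch inputs, keep a cache per layer and per head, and usually preallocate the cache.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(X, W_Q, W_K, W_V):
    """Full-sequence attention with an additive causal mask (hypothetical helper)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)                   # (seq_len, seq_len)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True above the diagonal
    scores = scores + np.where(future, -np.inf, 0.0)  # mask is added before softmax
    return softmax(scores) @ V

def decode_with_kv_cache(X, W_Q, W_K, W_V):
    """One token at a time, reusing cached keys/values of earlier tokens (hypothetical helper)."""
    K_cache, V_cache, outputs = [], [], []
    for x_t in X:                        # x_t: (d_model,)
        q = x_t @ W_Q
        K_cache.append(x_t @ W_K)        # only the new token's K and V are computed
        V_cache.append(x_t @ W_V)
        K, V = np.stack(K_cache), np.stack(V_cache)  # (t + 1, d_k)
        weights = softmax(q @ K.T / np.sqrt(q.shape[-1]))
        outputs.append(weights @ V)
    return np.stack(outputs)

rng = np.random.default_rng(2)
seq_len, d_model, d_k = 6, 8, 8
X = rng.standard_normal((seq_len, d_model))
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))

full = causal_attention(X, W_Q, W_K, W_V)
incremental = decode_with_kv_cache(X, W_Q, W_K, W_V)
assert np.allclose(full, incremental)
```

The two paths agree because the causal mask already prevents each position from seeing later tokens. The cache therefore only needs to store past keys and values, and each decoding step computes a single new row of scores instead of recomputing the whole matrix.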

Notes

  • The "splitting into heads" operation is conceptually a reshaping: the full Q, K, V projections are computed once and then divided across heads by reshaping, rather than computed with h separate matrix multiplications. The two are mathematically equivalent, but one large matrix multiplication typically runs more efficiently on hardware than h small ones.
  • GPT-2 small has 12 heads × 64 dimensions = 768 d_model, confirming the relationship d_k = d_model / h.
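  • A common misunderstanding is that fusing the projections means the heads share weights. They do not. The fused W_Q (and likewise W_K, W_V) is the column-wise concatenation of h per-head matrices of shape (d_model, d_k), so head i uses only its own columns, from i·d_k up to (i+1)·d_k. Fusion changes how the weights are stored and multiplied, not how many there are.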
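The equivalence described in these notes can be checked directly. Projecting with the fused W_Q and then splitting gives the same per-head queries as multiplying by each head's own column block. The sizes and random weights are arbitrary, and `W_Q_heads` is a hypothetical name for the per-head matrices.

```python
import numpy as np

seq_len, d_model, h = 5, 24, 3
d_k = d_model // h

rng = np.random.default_rng(3)
X = rng.standard_normal((seq_len, d_model))
W_Q = rng.standard_normal((d_model, d_model))  # fused query projection for all heads

# Fused path: one large matmul, then reshape into heads -> (h, seq_len, d_k)
Q_fused = (X @ W_Q).reshape(seq_len, h, d_k).transpose(1, 0, 2)

# Per-head path: head i owns the column block W_Q[:, i*d_k:(i+1)*d_k]
W_Q_heads = [W_Q[:, i * d_k:(i + 1) * d_k] for i in range(h)]  # h matrices of (d_model, d_k)
Q_separate = np.stack([X @ W_i for W_i in W_Q_heads])            # (h, seq_len, d_k)

assert np.allclose(Q_fused, Q_separate)
# Same parameter count either way: h * d_model * d_k == d_model * d_model
assert sum(W.size for W in W_Q_heads) == W_Q.size
```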