Multi-Head Attention
Multi-Head Attention runs scaled dot-product attention in parallel across multiple lower-dimensional subspaces, concatenates the results, and projects back to model dimension, enabling the model to capture diverse relational patterns simultaneously.
type: component
title: "Multi-Head Attention"
tags: [attention, component, transformer]
part_of: ["Transformer"]
created: 2025-01-01
Purpose
Multi-Head Attention runs Self-Attention (or Masked Self-Attention in decoder contexts) multiple times in parallel over lower-dimensional subspaces of the input, allowing the model to jointly attend to information from different representation subspaces at different positions. The results are concatenated and projected back to the model dimension.
Inputs and Outputs
- Input: a sequence of vectors of shape (seq_len, d_model).
- Output: a sequence of vectors of the same shape (seq_len, d_model).
Internally, the input is projected into h sets of query, key, and value matrices, each of dimension d_k = d_v = d_model / h.
Implementation
# Simplified multi-head attention (NumPy)
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    # X: (seq_len, d_model); all weight matrices: (d_model, d_model)
    seq_len, d_model = X.shape
    d_k = d_model // h
    # Project inputs to Q, K, V for all heads simultaneously
    Q = X @ W_Q  # (seq_len, d_model)
    K = X @ W_K
    V = X @ W_V
    # Split into heads: (seq_len, d_model) -> (h, seq_len, d_k)
    def split_heads(M):
        return M.reshape(seq_len, h, d_k).transpose(1, 0, 2)
    Q, K, V = split_heads(Q), split_heads(K), split_heads(V)
    # Scaled dot-product attention per head
    scores = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))  # (h, seq_len, seq_len)
    head_outputs = scores @ V  # (h, seq_len, d_k)
    # Concatenate heads back to (seq_len, d_model) and project
    concat = head_outputs.transpose(1, 0, 2).reshape(seq_len, d_model)
    output = concat @ W_O  # (seq_len, d_model)
    return output
Splitting into heads is implemented as a reshape operation on the Q, K, V vectors. For GPT-2 small, which has 12 attention heads and d_model=768, each head operates on 64-dimensional Q, K, V vectors.
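For GPT-2 small's dimensions, this reshape view of head splitting can be checked with a short sketch (the shapes follow the conventions used in the implementation above):

```python
import numpy as np

h, d_model, seq_len = 12, 768, 10   # GPT-2 small head count and width
d_k = d_model // h                  # 64 dimensions per head

rng = np.random.default_rng(0)
Q = rng.standard_normal((seq_len, d_model))

# Splitting is a reshape plus a transpose: no parameters involved
Q_heads = Q.reshape(seq_len, h, d_k).transpose(1, 0, 2)
print(Q_heads.shape)                # (12, 10, 64): one 64-dim slice per head

# The operation is lossless: merging the heads recovers the original tensor
Q_back = Q_heads.transpose(1, 0, 2).reshape(seq_len, d_model)
assert np.array_equal(Q, Q_back)
```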
Merging heads: After each head independently computes its attention output over its d_k-dimensional subspace, the outputs are concatenated along the feature dimension to produce a vector of dimension d_model. This concatenated vector is then multiplied by a learned output projection matrix W_O of shape (d_model, d_model) to produce the final multi-head attention output.
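Because W_O multiplies the concatenated vector, it can equivalently be viewed as h row blocks of shape (d_k, d_model), one per head, whose per-head contributions are summed. A quick numerical check of this decomposition (random data, not from the source):

```python
import numpy as np

rng = np.random.default_rng(2)
seq_len, h, d_k = 5, 12, 64
d_model = h * d_k

heads = rng.standard_normal((h, seq_len, d_k))   # per-head attention outputs
W_O = rng.standard_normal((d_model, d_model))

# Concatenate along the feature dimension, then project
concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
out_concat = concat @ W_O

# Equivalent: each head hits its own (d_k, d_model) row block of W_O; sum results
out_sum = sum(heads[i] @ W_O[i * d_k:(i + 1) * d_k, :] for i in range(h))

assert np.allclose(out_concat, out_sum)
```

This view makes it explicit that W_O is where information from different heads is mixed.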
Design Choices and Hyperparameters
- Number of heads (h): GPT-2 small uses 12 heads. More heads allow the model to attend to more distinct subspaces simultaneously.
- Head dimension (d_k = d_model / h): kept proportional to model size so the total parameter count scales predictably.
- Output projection (W_O): a learned linear transformation applied after concatenating all head outputs; it lets the model learn how best to combine information from the different heads.
- Masking: in decoder blocks, an attention mask is added to the scores before softmax to enforce causal (left-to-right) attention. See Masked Self-Attention.
- KV Caching: At inference time in autoregressive models, the key and value projections for previously seen tokens can be cached per layer and per head to avoid redundant computation.
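A minimal single-head sketch of KV caching during autoregressive decoding (the function name and cache layout are illustrative assumptions, not from a specific library; per-head caching reshapes the same arrays as shown earlier):

```python
import numpy as np

def cached_attention_step(x_new, W_Q, W_K, W_V, cache):
    """Attend from one new token, reusing cached K/V from earlier steps.

    x_new: (d_model,) embedding of the newest token.
    cache: dict with "K" and "V" arrays of shape (t, d_model); may start empty.
    """
    q = x_new @ W_Q
    k = x_new @ W_K
    v = x_new @ W_V
    # Append the new key/value; earlier tokens are never re-projected
    cache["K"] = np.vstack([cache["K"], k]) if cache["K"].size else k[None, :]
    cache["V"] = np.vstack([cache["V"], v]) if cache["V"].size else v[None, :]
    d_k = q.shape[-1]
    scores = cache["K"] @ q / np.sqrt(d_k)       # (t+1,) one score per past token
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ cache["V"]                  # (d_model,) attention output
```

Causality is automatic here: the new token can only attend to keys already in the cache.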
Related Concepts
- Self-Attention: The underlying attention mechanism that each head computes.
- Masked Self-Attention: The causally masked variant used in decoder blocks.
- Key-Query-Value Projection: The three learned projections that produce the Q, K, V tensors for each head.
- Softmax Temperature: the 1/√d_k scaling applied before softmax within each head to prevent gradient saturation.
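The 1/√d_k scaling can be motivated empirically: for random unit-variance queries and keys, dot products have variance d_k, so dividing by √d_k keeps pre-softmax scores at roughly unit scale regardless of head size. A quick check (random data, illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
for d_k in (16, 64, 256):
    q = rng.standard_normal((10000, d_k))
    k = rng.standard_normal((10000, d_k))
    raw = (q * k).sum(axis=1)            # unscaled dot products
    scaled = raw / np.sqrt(d_k)
    # raw std grows like sqrt(d_k); scaled std stays near 1
    print(d_k, round(raw.std(), 1), round(scaled.std(), 2))
```

Without the scaling, larger heads would push softmax into its saturated regime, where gradients vanish.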
Notes
- The "splitting into heads" operation is conceptually a reshaping: the full Q, K, V projections are computed once and then divided across heads by reshaping, rather than using h separate weight matrices. In practice, this is equivalent but more computationally efficient.
- GPT-2 small has 12 heads × 64 dimensions = 768 = d_model, confirming the relationship d_k = d_model / h.
- A common misunderstanding is that each head has entirely separate weight matrices. In most implementations, a single large matrix multiplication computes all heads' Q (or K, or V) projections at once, and the result is reshaped.
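The equivalence is easy to verify numerically: slicing the fused W_Q into h column blocks of shape (d_model, d_k) gives exactly the per-head projection matrices (a quick check with random data):

```python
import numpy as np

rng = np.random.default_rng(1)
seq_len, d_model, h = 5, 768, 12
d_k = d_model // h

X = rng.standard_normal((seq_len, d_model))
W_Q = rng.standard_normal((d_model, d_model))

# Fused: one big matmul, then reshape into heads
Q_fused = (X @ W_Q).reshape(seq_len, h, d_k).transpose(1, 0, 2)

# Per-head: h separate (d_model, d_k) matrices taken as column blocks of W_Q
Q_separate = np.stack([X @ W_Q[:, i * d_k:(i + 1) * d_k] for i in range(h)])

assert np.allclose(Q_fused, Q_separate)
```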