type: concept
title: "Sparse Attention"
tags: [efficiency, attention, long-context, sparse]
related: ["Self-Attention", "Multi-Head Attention", "Context Window", "Adaptive Attention Span"]
created: 2023-01-27
source: "https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/"

Sparse Attention

Summary

Sparse attention restricts which positions each query can attend to, reducing the $O(L^2)$ cost of full Self-Attention to sub-quadratic complexity, enabling Transformers to handle much longer sequences.

How It Works

Given a connectivity pattern $S = \{S_1, \ldots, S_L\}$ where each $S_i$ records the key positions the $i$-th query attends to:

$$ \text{Attend}(X, S) = \big( a(x_i, S_i) \big)_{i \in \{1, \ldots, L\}} $$
$$ a(x_i, S_i) = \text{softmax}\left(\frac{(x_i W^q)(x_j W^k)^\top_{j \in S_i}}{\sqrt{d_k}}\right)(x_j W^v)_{j \in S_i} $$

Although $|S_i|$ varies, the output $a(x_i, S_i)$ is always of size $d_v$, so $\text{Attend}(X, S) \in \mathbb{R}^{L \times d_v}$.

Full causal attention has $S_i = \{j : j \leq i\}$. Sparse patterns reduce $|S_i|$ from $O(L)$ to $O(\sqrt{L})$ or $O(\log L)$ per query.
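
As a minimal sketch (names, shapes, and parameter values below are illustrative assumptions, not from the source), the formula can be computed densely by encoding the connectivity pattern as a boolean mask $M$ with $M_{ij} = 1$ iff $j \in S_i$. Masking a dense score matrix still costs $O(L^2)$; the sub-quadratic savings require implementations that only materialize the entries in each $S_i$.

```python
import torch
import torch.nn.functional as F

def sparse_attend(X, mask, Wq, Wk, Wv):
    """X: (L, d_model); mask: (L, L) bool with mask[i, j] = True iff j in S_i."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                     # (L, d_k), (L, d_k), (L, d_v)
    scores = (Q @ K.T) / K.shape[-1] ** 0.5              # scaled dot products, (L, L)
    scores = scores.masked_fill(~mask, float("-inf"))    # positions outside S_i get zero weight
    return F.softmax(scores, dim=-1) @ V                 # Attend(X, S), shape (L, d_v)

L, d_model, d_k, d_v = 8, 16, 16, 16
X = torch.randn(L, d_model)
Wq, Wk, Wv = torch.randn(d_model, d_k), torch.randn(d_model, d_k), torch.randn(d_model, d_v)
causal = torch.tril(torch.ones(L, L, dtype=torch.bool))  # full causal: S_i = {j : j <= i}
out = sparse_attend(X, causal, Wq, Wk, Wv)               # (L, d_v)
```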

Role in the Transformer

Sparse attention is a drop-in replacement for the standard scaled dot-product attention in Multi-Head Attention. The rest of the Transformer block (layer norm, feed-forward, residuals) is unchanged.
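
Because the change is confined to which scores are masked, a sparse pattern can be passed straight into a standard attention call. A sketch assuming PyTorch 2.x's `scaled_dot_product_attention`, where a boolean `attn_mask` marks the positions each query may attend to; this applies the pattern but does not by itself give sub-quadratic compute:

```python
import torch
import torch.nn.functional as F

B, H, L, d = 2, 4, 128, 64                               # batch, heads, length, head dim
q, k, v = (torch.randn(B, H, L, d) for _ in range(3))

# Any (L, L) boolean connectivity pattern works here; causal shown for illustration.
pattern = torch.tril(torch.ones(L, L, dtype=torch.bool))
out = F.scaled_dot_product_attention(q, k, v, attn_mask=pattern)  # (B, H, L, d)
```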

Variants

Local / Fixed Window Attention (Image Transformer, Parmar et al., 2018)

Each query attends only to a local neighbourhood:

  • 1D local attention: Input is linearised (raster order); each query block attends to tokens in the same block plus a fixed number of prior tokens (a mask sketch follows this list).
  • 2D local attention: Image is partitioned into rectangular query blocks; each query attends to an extended memory block including top, left, and right neighbours.
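
A minimal sketch of the 1D local (blocked) pattern as a boolean mask; the block size and look-back length are hypothetical parameters, and real implementations batch queries by block rather than looping per position:

```python
import torch

def local_block_mask(L, block_size, num_prev):
    """Each query attends causally within its own block plus `num_prev` tokens before the block."""
    mask = torch.zeros(L, L, dtype=torch.bool)
    for i in range(L):
        block_start = (i // block_size) * block_size
        lo = max(0, block_start - num_prev)
        mask[i, lo : i + 1] = True        # causal: never attend past position i
    return mask

print(local_block_mask(L=12, block_size=4, num_prev=2).int())
```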

Strided Sparse Attention (Sparse Transformer, Child et al., 2019)

Factorized self-attention decomposes $S_i$ into $p$ non-overlapping subsets $A^{(m)}_i$, ensuring every pair $(i, j)$ with $j \leq i$ is connected by an attention path of length at most $p + 1$. Two concrete patterns (mask sketches follow the list):

  • Strided attention (stride $\ell \approx \sqrt{L}$): Each query attends to the previous $\ell$ tokens (local) plus every $\ell$-th prior token (strided). Suited to data with periodic structure (e.g., images, music).
  • Fixed attention: Each query attends within its own block, plus to designated summary positions (the last positions of each block), which propagate information to all future positions. Suited to text, which lacks periodic structure.
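
Both factorized patterns can be written as boolean masks. The sketch below is simplified (it uses a single summary position per block in the fixed pattern, whereas the paper allows several) and is illustrative rather than the paper's implementation:

```python
import torch

def strided_mask(L, stride):
    """Strided pattern: attend to the previous `stride` tokens and to every stride-th earlier token."""
    mask = torch.zeros(L, L, dtype=torch.bool)
    for i in range(L):
        for j in range(i + 1):
            local = (i - j) < stride                  # recent tokens
            periodic = (i - j) % stride == 0          # every stride-th prior token
            mask[i, j] = local or periodic
    return mask

def fixed_mask(L, stride):
    """Fixed pattern: attend within the current block, plus to the last position of every block."""
    mask = torch.zeros(L, L, dtype=torch.bool)
    for i in range(L):
        for j in range(i + 1):
            same_block = (j // stride) == (i // stride)
            summary = (j % stride) == stride - 1      # one summary position per block
            mask[i, j] = same_block or summary
    return mask
```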

Adaptive Attention Span (Sukhbaatar et al., 2019)

Each attention head learns its own optimal span $z \in [0, s]$ via a soft mask function; see Adaptive Attention Span.
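
For reference, the soft mask in Sukhbaatar et al. is $m_z(x) = \text{clamp}\big(\tfrac{1}{R}(R + z - x),\, 0,\, 1\big)$, where $x$ is the distance to the attended position and $R$ controls the width of the linear ramp. A minimal sketch (the span and ramp values below are illustrative):

```python
import torch

def soft_span_mask(distances, z, R=32.0):
    """m_z(x): 1 within the learned span z, decaying linearly to 0 over a ramp of width R."""
    return torch.clamp((R + z - distances) / R, min=0.0, max=1.0)

distances = torch.arange(0, 128, dtype=torch.float32)   # distance i - j to past positions
weights = soft_span_mask(distances, z=40.0)             # multiplies pre-normalization attention weights
```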

Key Papers

  • Parmar et al. (2018), "Image Transformer"
  • Child et al. (2019), "Generating Long Sequences with Sparse Transformers"
  • Sukhbaatar et al. (2019), "Adaptive Attention Span in Transformers"

Notes

The practical benefit of sparse attention depends heavily on hardware support for irregular memory access patterns. Strided and fixed patterns in Sparse Transformer are designed to map efficiently onto GPUs. Many follow-on works (Longformer, BigBird) build directly on these two patterns.