kNN-Augmented Language Model
kNN-augmented language models combine a pretrained Transformer LM with nearest-neighbour retrieval over an external key-value datastore, interpolating or gating retrieved token probabilities with the model's own predictions to extend effective context far beyond the training window.
type: concept
title: "kNN-Augmented Language Model"
tags: [memory, retrieval, long-context, language-modeling]
related: ["Transformer-XL", "Context Window", "Autoregressive Generation"]
created: 2023-01-27
source: "https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/"
Summary
kNN-augmented language models enhance a pretrained Transformer language model with a non-differentiable external key-value datastore, interpolating next-token predictions from the model with predictions derived from nearest-neighbour retrieval, allowing the model to leverage large corpora without retraining.
How It Works
kNN-LM (Khandelwal et al., 2020)
A separate $k$NN model is built over an external key-value store of $(\text{context embedding}, \text{next token})$ pairs. At inference, the next-token probability is a weighted interpolation:
$$ p(y | x) = \lambda \, p_{\text{kNN}}(y | x) + (1 - \lambda) \, p_{\text{LM}}(y | x) $$
$$ p_{\text{kNN}}(y | x) \propto \sum_{(k_i, w_i) \in \mathcal{N}} \mathbf{1}[y = w_i] \exp(-d(k_i, f(x))) $$
where $\mathcal{N}$ is the set of $k$ nearest neighbours, $d(\cdot, \cdot)$ is L2 distance, and $f(x)$ is the LM's context embedding. The datastore can hold any large dataset; retrieval uses libraries such as FAISS. The interpolation scalar $\lambda$ should be larger for out-of-domain data.
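A minimal sketch of the interpolation step, assuming the $k$ nearest keys, their stored next tokens, and their L2 distances have already been retrieved (e.g. from a FAISS index); the function and argument names are illustrative rather than taken from the paper's code:

```python
import numpy as np

def knn_lm_next_token_probs(p_lm, neighbor_tokens, neighbor_dists, vocab_size, lam=0.25):
    """Interpolate the LM's next-token distribution with a kNN distribution.

    p_lm            : (vocab_size,) softmax output of the base LM for the current context
    neighbor_tokens : (k,) token ids w_i stored alongside the retrieved keys
    neighbor_dists  : (k,) L2 distances d(k_i, f(x)) between keys and the query embedding
    lam             : interpolation weight lambda, tuned on held-out data
    """
    # p_kNN(y|x) is proportional to the sum of exp(-distance) over neighbours whose next token is y
    weights = np.exp(-neighbor_dists)
    p_knn = np.zeros(vocab_size)
    np.add.at(p_knn, neighbor_tokens, weights)   # scatter-add weights per vocabulary id
    p_knn /= p_knn.sum()
    # final distribution: lambda * p_kNN + (1 - lambda) * p_LM
    return lam * p_knn + (1.0 - lam) * p_lm
```

Because the combination is a simple mixture of two distributions, the base LM's weights and softmax head are untouched; only the final probabilities change.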
SPALM (Yogatama et al., 2021)
Combines Transformer-XL-style hidden-state memory (short-term) with a $k$NN-LM datastore (long-term). Retrieved token embeddings $\{y_i\}_{i=1}^{k}$ are aggregated via an attention layer using the current hidden state $h_t^R$ as query, then gated with local information:
$$ m_t = \sum_{i=1}^{k} \frac{\exp(y_i^\top h_t^R)}{\sum_j \exp(y_j^\top h_t^R)} \cdot y_i, \quad g_t = \sigma(w_g^\top h_t^R) $$
$$ z_t = (1 - g_t) \odot m_t + g_t \odot h_t^R, \quad p(x_{t+1} | x_{\leq t}) = \text{softmax}(z_t; W) $$
Unlike $k$NN-LM, SPALM does not use retrieval distance in the aggregation. The word embedding matrix $W$ is shared between input and output and is updated during training while key representations remain frozen.
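The aggregation and gating above can be written directly from the equations; the sketch below assumes a single decoding step with illustrative tensor shapes, and names such as spalm_output_distribution are hypothetical:

```python
import torch

def spalm_output_distribution(h_t, retrieved_embs, w_g, W):
    """SPALM-style aggregation and gating for one decoding step (illustrative sketch).

    h_t            : (d,)   current hidden state h_t^R
    retrieved_embs : (k, d) embeddings y_i of tokens retrieved from the long-term store
    w_g            : (d,)   learned gating vector
    W              : (V, d) shared input/output word embedding matrix
    """
    # attention over retrieved token embeddings, with h_t^R as the query
    attn = torch.softmax(retrieved_embs @ h_t, dim=0)            # (k,)
    m_t = (attn.unsqueeze(1) * retrieved_embs).sum(dim=0)        # (d,)
    # sigmoid gate decides how much to rely on retrieved vs. local information
    g_t = torch.sigmoid(w_g @ h_t)                               # scalar
    z_t = (1.0 - g_t) * m_t + g_t * h_t
    # next-token distribution via the shared embedding matrix W
    return torch.softmax(W @ z_t, dim=0)                         # (V,)
```

Note that, as the text says, the retrieval distances play no role here; only the dot products between retrieved embeddings and the current hidden state determine the aggregation.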
Memorizing Transformer (Wu et al., 2022)
Adds a $k$NN-augmented attention layer near the top of a decoder-only Transformer. The layer maintains a Transformer-XL-style FIFO cache of past key-value pairs. The same Q/K/V are used for both local attention and $k$NN lookup. Top-$k$ retrieved pairs are attended to via a standard attention stack, and the result is combined with local attention using a learnable per-head gating parameter. Keys and values in the cache are normalised to prevent distributional shift.
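A rough sketch of one head of such a layer, assuming the top-$k$ key-value pairs have already been retrieved from the cache for each query position (function and variable names are illustrative, not from the paper's code):

```python
import torch

def knn_augmented_attention(q, local_k, local_v, mem_k, mem_v, gate_logit):
    """One head of a kNN-augmented attention layer (illustrative sketch).

    q          : (seq, d)     queries for the current segment
    local_k/v  : (seq, d)     keys/values for standard causal attention
    mem_k/v    : (seq, k, d)  top-k keys/values retrieved per query from the external cache
    gate_logit : scalar learnable gating parameter for this head
    """
    d = q.shape[-1]
    # standard causal local attention over the current segment
    scores = (q @ local_k.T) / d ** 0.5
    mask = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    local_out = torch.softmax(scores, dim=-1) @ local_v
    # attention over the retrieved memories (no causal mask: all entries lie in the past)
    mem_scores = torch.einsum("sd,skd->sk", q, mem_k) / d ** 0.5
    mem_out = torch.einsum("sk,skd->sd", torch.softmax(mem_scores, dim=-1), mem_v)
    # per-head gate blends the two attention results
    g = torch.sigmoid(torch.as_tensor(gate_logit))
    return g * mem_out + (1.0 - g) * local_out
```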
Key empirical findings:
- A Memorizing Transformer with 8k tokens in memory can match the perplexity of a vanilla Transformer with roughly 5× more trainable parameters.
- Gains are consistent up to external memory sizes of 262k tokens.
- Fine-tuning a vanilla Transformer to use memory is competitive with training from scratch with memory.
Role in the Transformer
kNN augmentation is applied at inference (and optionally during fine-tuning) without changing the core Transformer weights. It effectively extends the Context Window by letting predictions draw on an external datastore far larger than the model's attention span.
Variants
- kNN-LM — pure interpolation, no architectural change.
- SPALM — learned gating between short-term recurrent memory and long-term $k$NN store.
- Memorizing Transformer — $k$NN attention layer integrated into the Transformer stack with per-head gating.
Key Papers
- Khandelwal et al. (2020), "Generalization through Memorization: Nearest Neighbor Language Models"
- Yogatama et al. (2021), "Adaptive Semiparametric Language Models"
- Wu et al. (2022), "Memorizing Transformers"
Notes
The datastore in $k$NN-LM is non-differentiable; it functions as a retrieval index rather than a trained component. This makes it easy to update (by re-indexing) without retraining the language model, which is particularly valuable for domain adaptation and keeping knowledge current.
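A minimal sketch of building such an index with FAISS, assuming the $(\text{context embedding}, \text{next token})$ pairs have already been extracted from the corpus (function and variable names are illustrative):

```python
import numpy as np
import faiss  # nearest-neighbour search library commonly used for the kNN-LM datastore

def build_datastore(context_embeddings, next_tokens):
    """context_embeddings: (N, d) float32 array of f(x) for each training context.
    next_tokens:           (N,)   token ids that followed each context."""
    index = faiss.IndexFlatL2(context_embeddings.shape[1])  # exact L2 search; swap in an IVF/PQ index for very large N
    index.add(np.ascontiguousarray(context_embeddings, dtype=np.float32))
    return index, np.asarray(next_tokens)

# Domain adaptation amounts to re-running build_datastore on new text;
# the language model's weights are never touched.
```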