KV Cache
The KV cache is the inference-time optimization that makes autoregressive
generation tractable. Without it, the attention work to generate N tokens
grows as O(N^3); with it, O(N^2). Every production
LLM inference engine uses one.
llama2.c/run.c is a minimal C
implementation you can read end to end.
Why generation is wasteful without it
At inference, generating one new token requires running the full transformer forward on the current prefix. For each attention layer, you compute:
- Q from all T positions
- K from all T positions
- V from all T positions
- The attention scores Q @ K.T of shape (T, T)
If T = 100 (positions 0..99) and you want to generate token 101, you compute
attention from scratch over positions 0..99. To generate token 102, you
compute attention from scratch over positions 0..100. The first 100
positions of K and V are identical to the previous step, so you're
redoing the same work.
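To put rough numbers on the waste, here is a minimal sketch that counts query-key dot products (per layer, per head) needed to generate n tokens. The function names are hypothetical, and the count deliberately ignores the Q/K/V projections and the MLP, so treat it as an illustration of the growth rate rather than a cost model.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: dot products needed when every step recomputes the
 * full (T, T) score matrix over its prefix of length T = 1..n. */
uint64_t scores_without_cache(uint64_t n) {
    assert(n < 1000000);      /* keep the sum well inside uint64_t */
    uint64_t total = 0;
    for (uint64_t t = 1; t <= n; t++) {
        total += t * t;       /* T * T scores at this step */
    }
    return total;
}

/* Same count when each step computes only the newest query's row. */
uint64_t scores_with_cache(uint64_t n) {
    assert(n < 1000000);
    uint64_t total = 0;
    for (uint64_t t = 1; t <= n; t++) {
        total += t;           /* one query against t keys */
    }
    return total;
}
```

For n = 100 the two functions return 338350 and 5050, which is the cubic-versus-quadratic gap in miniature.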
The cache
Insight: for any position i, the K and V vectors at that
position depend only on the residual stream at position i,
which depends only on tokens 0..i. So once you've computed
K[i] and V[i], they never change.
Cache them.
// from llama2.c/run.c
typedef struct {
    // ...
    float* key_cache;   // (layer, seq_len, dim)
    float* value_cache; // (layer, seq_len, dim)
} RunState;
Two big arrays sized (n_layers, seq_len, dim). Per layer, per
position, you store the K and V. At each new token:
- Compute Q, K, V for the new token only (one position, not all T).
- Write the new K and V into the cache at position t.
- Compute attention from the new Q against all cached K and V up to position t.
Per-step attention cost drops from O(T^2) to O(T): you do one row of
attention (the new query against all cached keys), not a full
T × T attention matrix.
What Q looks like in inference
Only one query position per step:
// from llama2.c/run.c forward()
// qkv matmuls for this position
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
// save key,value at this time step (pos) to our kv cache
int loff = l * p->seq_len * kv_dim;
float* key_cache_row = s->key_cache + loff + pos * kv_dim;
float* value_cache_row = s->value_cache + loff + pos * kv_dim;
memcpy(key_cache_row, s->k, kv_dim*sizeof(*key_cache_row));
memcpy(value_cache_row, s->v, kv_dim*sizeof(*value_cache_row));
s->q is a single vector (one position). The cache holds
the history.
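The offset arithmetic in the excerpt treats each cache as one flat array laid out as (n_layers, seq_len, kv_dim). A small sketch of that indexing, with a hypothetical helper name, may make the layout easier to see. It uses `size_t` on the assumption that the product could exceed the range of `int` for larger configurations.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper mirroring the excerpt's arithmetic: element offset of
 * the row for (layer l, position pos) in a flat (n_layers, seq_len, kv_dim)
 * array. */
size_t kv_row_offset(int l, int pos, int seq_len, int kv_dim) {
    assert(l >= 0 && pos >= 0 && pos < seq_len && kv_dim > 0);
    size_t loff = (size_t)l * seq_len * kv_dim;   /* start of layer l */
    return loff + (size_t)pos * kv_dim;           /* start of row pos */
}
```

With seq_len = 8 and kv_dim = 4, layer 2 at position 3 starts at element 2 * 32 + 3 * 4 = 76, and consecutive positions within a layer are kv_dim floats apart.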
Why no KV cache during training
Training uses teacher forcing: you have the entire
ground-truth target sequence, so you compute all positions in parallel in
a single forward pass. There's nothing to cache because you're not
generating sequentially. The (T, T) attention matrix is
computed once and dropped.
Flash attention matters during training because you do compute the full T × T matrix.
The KV cache matters during inference because you don't.
Memory cost of the cache
For a model with n_layer layers and kv_dim
channels in K and V (which equals dim for vanilla multi-head
and is smaller for
grouped-query attention):
cache size = 2 (K and V) * n_layer * seq_len * kv_dim * sizeof(dtype)
For Llama 2 7B with n_layer = 32, kv_dim = 4096
(no GQA), fp16 (2 bytes per element), seq_len 4096:
2 * 32 * 4096 * 4096 * 2 bytes = 2,147,483,648 bytes = 2 GiB per sequence
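The formula is easy to evaluate with a small helper. This is a sketch with a hypothetical function name; it counks only the K and V arrays for one sequence and ignores weights, activations, and allocator overhead.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: bytes needed for one sequence's KV cache. */
uint64_t kv_cache_bytes(uint64_t n_layer, uint64_t seq_len,
                        uint64_t kv_dim, uint64_t bytes_per_elem) {
    assert(bytes_per_elem > 0);
    return 2 /* K and V */ * n_layer * seq_len * kv_dim * bytes_per_elem;
}
```

`kv_cache_bytes(32, 4096, 4096, 2)` returns 2147483648, which is 2 GiB for a single full-length sequence. Passing kv_dim = 1024 instead (a hypothetical 4× reduction from grouped-query attention) returns 536870912, or 512 MiB.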
This is why batch sizes for inference are so much smaller than for
training: the KV cache for each user's sequence dominates memory. Modern
tricks like GQA (grouped-query attention) and MQA (multi-query attention)
shrink kv_dim, typically by 4–8× for GQA and by the full number of heads
for MQA, to fit more sequences in memory.
Paged KV cache
The naive cache layout ((n_layer, max_seq_len, dim) arrays)
wastes a lot of memory when most sequences are short.
PagedAttention (the optimization that powers vLLM) treats
the cache like virtual memory — physical KV blocks are allocated on demand
and indexed by a page table per sequence. Not in llama2.c or
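llm.c but worth knowing about. The vLLM paper reports 2-4× higher
serving throughput than earlier systems at similar latency, largely because
less cache memory is wasted and more sequences fit on one GPU.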
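The page-table idea can be sketched in a few lines of C. This is a toy illustration and not vLLM's implementation: the block size, pool size, type names, and the bump allocator (which never frees blocks) are all assumptions made for the example.

```c
#include <assert.h>

#define BLOCK_TOKENS 16   /* tokens per physical block (assumed) */
#define NUM_BLOCKS   8    /* physical blocks in the shared pool (assumed) */
#define MAX_LOGICAL  8    /* max logical blocks per sequence (assumed) */

/* Shared pool state: a bump counter over physical blocks, no freeing. */
typedef struct {
    int next_free;
} BlockPool;

/* Per-sequence page table: logical block index -> physical block index. */
typedef struct {
    int table[MAX_LOGICAL];
    int n_blocks;   /* logical blocks allocated so far */
} PageTable;

/* Return the physical slot (in token units) for token position pos,
 * allocating physical blocks on demand. Returns -1 if the pool or the
 * page table is exhausted. */
int paged_slot(BlockPool *pool, PageTable *pt, int pos) {
    assert(pos >= 0);
    int logical = pos / BLOCK_TOKENS;
    while (pt->n_blocks <= logical) {
        if (pool->next_free >= NUM_BLOCKS || pt->n_blocks >= MAX_LOGICAL) {
            return -1;
        }
        pt->table[pt->n_blocks++] = pool->next_free++;
    }
    return pt->table[logical] * BLOCK_TOKENS + pos % BLOCK_TOKENS;
}
```

A sequence only claims a new block when it crosses a BLOCK_TOKENS boundary, so a 5-token sequence holds one block rather than a full max_seq_len slab. Multiplying the returned slot by kv_dim would give the element offset of that token's K or V row in the pool.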
Related
- attention: what gets cached
- sampling: what consumes the cache one token at a time
- repos/llama2-c: the C implementation
- transformer-block: where the cache lives, conceptually