The Problem: Attention is Memory-Hungry
Standard Transformer attention is computationally elegant but memory-inefficient. Two specific problems limited its use in 2023, and the obvious workarounds (covered as a third problem below) did not solve them:
Problem 1: The KV Cache Blowup
Multi-Head Attention (MHA) requires:
- n_heads independent K and V caches
- Each cache stores d_head floats per token per layer
For a single decoder layer with 32 heads of dimension 128, generating 8,192 tokens and storing each value in float32 (4 bytes):
KV cache = 2 (K and V) × 32 heads × 128 dims × 8,192 tokens
= 67,108,864 floats
= 256 MB per sample
Scale to 32 layers (the depth of models like LLaMA 2 7B), and one inference sample requires:
Total KV cache = 32 layers × 256 MB = 8,192 MB ≈ 8 GB per sample
Run a batch of 32 samples (a typical inference server batch), and you need 256 GB of GPU memory just for KV caches — before loading model weights.
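The arithmetic above can be reproduced with a small helper. This is a minimal sketch: the function name `kv_cache_bytes` and its parameters are illustrative, and `bytes_per_value=4` assumes float32 storage, which is what the 256 MB per-layer figure implies. With float16 every result would be halved.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, n_tokens,
                   batch_size=1, bytes_per_value=4):
    """Bytes needed to cache keys and values for every layer, head, and token."""
    # 2 = one key vector plus one value vector per head, per token.
    values_per_token = 2 * n_kv_heads * d_head
    return n_layers * values_per_token * n_tokens * batch_size * bytes_per_value


MIB = 1024 ** 2
GIB = 1024 ** 3

# Assumed float32 storage (4 bytes per value), as in the text's figures.
one_layer = kv_cache_bytes(n_layers=1, n_kv_heads=32, d_head=128, n_tokens=8192)
full_model = kv_cache_bytes(n_layers=32, n_kv_heads=32, d_head=128, n_tokens=8192)
batch = kv_cache_bytes(n_layers=32, n_kv_heads=32, d_head=128, n_tokens=8192,
                       batch_size=32)

print(one_layer / MIB)   # 256.0
print(full_model / GIB)  # 8.0
print(batch / GIB)       # 256.0
```

Every factor enters linearly, so doubling the context length or the batch size doubles the cache.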
Why this is a real problem:
- A single GPU (H100 with 80 GB) can hold a 7B model's weights (~14 GB in float16) and, at roughly 8 GB per sample, KV caches for only ~8 samples at full context
- Longer sequences → linearly larger KV cache → proportionally fewer samples per GPU
- Mobile/edge inference becomes impractical
Problem 2: Quadratic Attention Complexity
Standard attention computes Q @ K^T where both Q and K have length n (sequence length).
This is O(n²) — the number of pairs to attend to grows quadratically.
For a 32-token sequence: 32 × 32 = 1,024 pairs
For an 8,192-token sequence: 8,192 × 8,192 = 67,108,864 pairs
For longer documents this becomes prohibitive, and not only in memory: it is also slow, because each of the 32 layers performs its own O(n²) computation.
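To make the growth rate concrete, the pair counts above can be computed directly. This is a small illustrative sketch; `attention_pairs` is a hypothetical helper name, and it counts every query-key pair of full, unmasked attention for one head in one layer.

```python
def attention_pairs(n_tokens):
    """Query-key pairs scored by full (unmasked) attention for one head."""
    return n_tokens * n_tokens


short = attention_pairs(32)
long = attention_pairs(8192)

print(short)          # 1024
print(long)           # 67108864
# The sequence is 256x longer, but the number of pairs is 256^2 = 65536x larger.
print(long // short)  # 65536
```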
Problem 3: Naive Solutions Don’t Work
Why can’t we just use smaller models? LLaMA 2 7B exists, but it underperforms 13B on reasoning tasks. The gap is real:
- LLaMA 2 7B: 43.35% on GSM8k (grade-school math)
- LLaMA 2 13B: 56.44% on GSM8k (about 13 points higher, a 30% relative improvement)
Why can’t we use Multi-Query Attention (MQA)? MQA solves the KV cache problem by sharing a single KV head across all query heads:
Standard MHA: n_heads KV heads (e.g., 32)
MQA: 1 KV head
KV cache reduction: 32×
But MQA costs quality. Forcing all query heads to attend to the same key-value pairs reduces expressiveness, so an MQA 7B model underperforms a standard MHA 7B model.
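The difference between the two schemes comes down to which KV head each query head reads from. The sketch below illustrates that mapping. `kv_head_for` is a hypothetical helper, not part of any library, and it assumes query heads are assigned to KV heads in contiguous blocks.

```python
def kv_head_for(query_head, n_heads, n_kv_heads):
    """Index of the KV head that a given query head reads from."""
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_kv_heads must divide n_heads")
    queries_per_kv_head = n_heads // n_kv_heads
    return query_head // queries_per_kv_head


n_heads = 32
# MHA: every query head has its own KV head.
mha = [kv_head_for(h, n_heads, n_kv_heads=32) for h in range(n_heads)]
# MQA: every query head shares KV head 0.
mqa = [kv_head_for(h, n_heads, n_kv_heads=1) for h in range(n_heads)]

print(mha[:4])  # [0, 1, 2, 3]
print(mqa[:4])  # [0, 0, 0, 0]
# Distinct KV heads to cache: 32 for MHA versus 1 for MQA, a 32x reduction.
print(len(set(mha)) // len(set(mqa)))  # 32
```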
Why can’t we just limit context to short sequences? Many tasks require reasoning over longer contexts:
- Summarising documents (4K–32K tokens)
- Code understanding (full files can be thousands of tokens)
- Few-shot learning (context examples + prompt > 1K tokens)
A 512-token limit feels like writing novels on paper napkins.
The insight Mistral had
The solution isn’t to eliminate KV heads or ignore attention — it’s to compromise intelligently:
- Grouped Query Attention (GQA): Share KV heads, but not completely. Use n_kv_heads < n_heads, where n_heads / n_kv_heads = 4 or 8. This preserves most of the expressiveness of MHA while achieving a 4–8× KV cache reduction.
- Sliding Window Attention (SWA): Most of the information a token needs comes from nearby tokens, not the entire history. Limit each token to attend only to the last W tokens (e.g., W = 4,096). This reduces attention from O(n²) to O(n × W).
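Both mechanisms can be sketched in a few lines of index arithmetic. The helper names `gqa_groups` and `visible_positions` are illustrative. The sketch assumes query heads are grouped in contiguous blocks and that the window of W tokens includes the current token; an actual implementation may use different conventions.

```python
def gqa_groups(n_heads, n_kv_heads):
    """Map each KV head to the list of query heads that share it."""
    size = n_heads // n_kv_heads
    return {kv: list(range(kv * size, (kv + 1) * size))
            for kv in range(n_kv_heads)}


def visible_positions(token_index, window):
    """Positions a token may attend to under a causal sliding window.

    Assumption: the window counts the current token itself.
    """
    start = max(0, token_index - window + 1)
    return range(start, token_index + 1)


groups = gqa_groups(n_heads=32, n_kv_heads=8)
print(len(groups))  # 8
print(groups[0])    # [0, 1, 2, 3]

print(list(visible_positions(2, window=4)))       # [0, 1, 2]
print(list(visible_positions(10, window=4)))      # [7, 8, 9, 10]
print(len(visible_positions(8191, window=4096)))  # 4096
```

Early tokens see fewer than W positions because there is not yet that much history. From position W − 1 onward, every token sees exactly W.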
The two ideas work together. GQA cuts the KV cache by 4–8×, and SWA bounds attention compute at O(n × W), a saving that grows with sequence length once n exceeds W, all while maintaining or even improving quality.
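A rough estimate shows how the two savings behave at different context lengths. This is a sketch under stated assumptions: 32 query heads with 8 KV heads (a ratio of 4), W = 4,096, causal attention, and a window that includes the current token. `causal_pairs` is a hypothetical helper that counts scored query-key pairs per head per layer.

```python
def causal_pairs(n_tokens, window=None):
    """Query-key pairs under causal attention, optionally limited to a window."""
    if window is None or window >= n_tokens:
        return n_tokens * (n_tokens + 1) // 2
    # The first `window` tokens see all their predecessors;
    # every later token sees exactly `window` positions.
    return window * (window + 1) // 2 + (n_tokens - window) * window


n_heads, n_kv_heads, window = 32, 8, 4096

# KV cache reduction from sharing KV heads.
cache_reduction = n_heads / n_kv_heads
print(cache_reduction)  # 4.0

# Attention compute reduction from the window, at two context lengths.
ratio_8k = causal_pairs(8192) / causal_pairs(8192, window)
ratio_32k = causal_pairs(32768) / causal_pairs(32768, window)
print(round(ratio_8k, 2))   # 1.33
print(round(ratio_32k, 2))  # 4.27
```

Under these assumptions the cache saving from GQA is fixed by the head ratio. The compute saving from the window is modest at 8K tokens and grows as the context gets longer.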
The rest of the paper shows how.