Worked Example: Grouped Query Attention in Action
This section walks through a complete, tiny example of Grouped Query Attention (GQA) step by step. You can verify every number by hand.
Setup
Parameters:
- Sequence length: n = 3 tokens
- Query heads: n_heads = 4
- KV heads: n_kv_heads = 2
- Dimension per head: d_head = 2
- Group size: n_heads / n_kv_heads = 4 / 2 = 2 (query heads 1–2 share KV head 1; query heads 3–4 share KV head 2)
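The grouping rule in the last bullet can be written as a single integer division. The sketch below is illustrative only: the helper name `kv_head_for` is hypothetical, and it uses 0-indexed heads (so query heads 0–1 map to KV head 0), whereas the text counts from 1. It assumes heads are grouped contiguously, as described above.

```python
# Hypothetical helper: map a 0-indexed query head to its 0-indexed KV head,
# assuming contiguous groups of equal size.
def kv_head_for(query_head: int, n_heads: int = 4, n_kv_heads: int = 2) -> int:
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be a multiple of n_kv_heads")
    group_size = n_heads // n_kv_heads
    return query_head // group_size

mapping = [kv_head_for(h) for h in range(4)]
print(mapping)  # [0, 0, 1, 1]
```

With 32 query heads and 8 KV heads, the same rule sends query heads 28–31 to KV head 7.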
Input embeddings (3 × embedding_dim = 3 × 4):
Token 1: [1, 0, 1, 0]
Token 2: [0, 1, 0, 1]
Token 3: [1, 1, 1, 1]
Projection matrices (we’ll define small ones):
Query projections (4 heads, each maps 4-dim embedding → 2-dim):
W_Q1 = [[1, 0], [0, 1], [0, 0], [0, 0]]
W_Q2 = [[0, 1], [1, 0], [0, 0], [0, 0]]
W_Q3 = [[1, 0], [1, 1], [0, 0], [0, 0]]
W_Q4 = [[0, 1], [1, 1], [0, 0], [0, 0]]
KV projections (2 heads, each maps 4-dim → 2-dim):
W_K1 = [[1, 0], [0, 1], [0, 0], [0, 0]]
W_K2 = [[1, 1], [1, 1], [0, 0], [0, 0]]
W_V1 = [[1, 0], [0, 1], [0, 0], [0, 0]]
W_V2 = [[0, 1], [1, 0], [0, 0], [0, 0]]
Step 1: Project embeddings to Q, K, V
Query projections:
Q1(token 1) = [1, 0, 1, 0] @ W_Q1 = [1, 0]
Q1(token 2) = [0, 1, 0, 1] @ W_Q1 = [0, 1]
Q1(token 3) = [1, 1, 1, 1] @ W_Q1 = [1, 1]
Q2(token 1) = [1, 0, 1, 0] @ W_Q2 = [0, 1]
Q2(token 2) = [0, 1, 0, 1] @ W_Q2 = [1, 0]
Q2(token 3) = [1, 1, 1, 1] @ W_Q2 = [1, 1]
Q3(token 1) = [1, 0, 1, 0] @ W_Q3 = [1, 0]
Q3(token 2) = [0, 1, 0, 1] @ W_Q3 = [1, 1]
Q3(token 3) = [1, 1, 1, 1] @ W_Q3 = [2, 1]
Q4(token 1) = [1, 0, 1, 0] @ W_Q4 = [0, 1]
Q4(token 2) = [0, 1, 0, 1] @ W_Q4 = [1, 1]
Q4(token 3) = [1, 1, 1, 1] @ W_Q4 = [1, 2]
KV projections:
K1(token 1) = [1, 0, 1, 0] @ W_K1 = [1, 0]
K1(token 2) = [0, 1, 0, 1] @ W_K1 = [0, 1]
K1(token 3) = [1, 1, 1, 1] @ W_K1 = [1, 1]
K2(token 1) = [1, 0, 1, 0] @ W_K2 = [1, 1]
K2(token 2) = [0, 1, 0, 1] @ W_K2 = [1, 1]
K2(token 3) = [1, 1, 1, 1] @ W_K2 = [2, 2]
V1(token 1) = [1, 0, 1, 0] @ W_V1 = [1, 0]
V1(token 2) = [0, 1, 0, 1] @ W_V1 = [0, 1]
V1(token 3) = [1, 1, 1, 1] @ W_V1 = [1, 1]
V2(token 1) = [1, 0, 1, 0] @ W_V2 = [0, 1]
V2(token 2) = [0, 1, 0, 1] @ W_V2 = [1, 0]
V2(token 3) = [1, 1, 1, 1] @ W_V2 = [1, 1]
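Step 1 can be reproduced with a few matrix products. This is a minimal NumPy sketch built directly from the Setup matrices; the helper `pad` is hypothetical and only appends the two all-zero rows that every matrix in this example shares. A useful sanity check: token 3's embedding is the sum of tokens 1 and 2, so in every projection the token-3 row must equal the sum of the first two rows.

```python
import numpy as np

X = np.array([[1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 1, 1, 1]])  # 3 tokens x 4 embedding dims

def pad(top_rows):
    """Append the two all-zero rows shared by every matrix in the Setup."""
    return np.array(top_rows + [[0, 0], [0, 0]])

W_Q = [pad([[1, 0], [0, 1]]), pad([[0, 1], [1, 0]]),
       pad([[1, 0], [1, 1]]), pad([[0, 1], [1, 1]])]
W_K = [pad([[1, 0], [0, 1]]), pad([[1, 1], [1, 1]])]
W_V = [pad([[1, 0], [0, 1]]), pad([[0, 1], [1, 0]])]

# Each result is 3 x 2: one row per token.
Q = [X @ W for W in W_Q]
K = [X @ W for W in W_K]
V = [X @ W for W in W_V]

print("Q3:", Q[2].tolist())  # [[1, 0], [1, 1], [2, 1]]
print("Q4:", Q[3].tolist())  # [[0, 1], [1, 1], [1, 2]]
print("K2:", K[1].tolist())  # [[1, 1], [1, 1], [2, 2]]
```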
Step 2: Compute attention for each query head at token 3
We’ll compute attention for token 3 only. Under causal masking, token 3 can attend to every earlier token and to itself, so all three positions contribute.
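The visibility rule can be pictured as a lower-triangular mask. The sketch below assumes a decoder-style causal setup, which is what "previous tokens including itself" suggests; rows are query positions and columns are key positions, both 0-indexed.

```python
import numpy as np

n = 3
# Entry [i, j] is True when query position i may attend to key position j.
causal_mask = np.tril(np.ones((n, n), dtype=bool))

print(causal_mask[2].tolist())  # [True, True, True]    token 3 sees all three positions
print(causal_mask[0].tolist())  # [True, False, False]  token 1 sees only itself
```

Because token 3 is the last position, its row of the mask is all True, which is why no masking appears in the calculations for this token.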
Query Head 1 (attends to KV head 1)
Q1(token 3) = [1, 1]
K1 sequence = [[1, 0], [0, 1], [1, 1]]
V1 sequence = [[1, 0], [0, 1], [1, 1]]
Compute attention scores:
Scores = Q1(3) @ K1^T / √d_head, with d_head = 2
= [1, 1] @ [[1, 0, 1], [0, 1, 1]] / √2
= [1×1 + 1×0, 1×0 + 1×1, 1×1 + 1×1] / √2
= [1, 1, 2] / √2
≈ [0.707, 0.707, 1.414]
Apply softmax:
exp([0.707, 0.707, 1.414]) ≈ [2.028, 2.028, 4.113]
sum = 8.169
softmax ≈ [2.028/8.169, 2.028/8.169, 4.113/8.169]
≈ [0.2483, 0.2483, 0.5035]
Compute output:
Output1(token 3) = [0.2483, 0.2483, 0.5035] @ V1 sequence
= [0.2483, 0.2483, 0.5035] @ [[1, 0], [0, 1], [1, 1]]
= [0.2483×1 + 0.2483×0 + 0.5035×1, 0.2483×0 + 0.2483×1 + 0.5035×1]
= [0.2483 + 0.5035, 0.2483 + 0.5035]
≈ [0.752, 0.752]
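The three sub-steps for head 1 (scores, softmax, weighted sum) fit in one small function. This is a sketch, not a production implementation: the name `attend` is hypothetical, it handles a single query vector, and it subtracts the maximum score before exponentiating, a common stability trick that does not change the result.

```python
import numpy as np

def attend(q, K, V):
    """Scaled dot-product attention for one query vector over all keys."""
    scores = q @ K.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max())  # subtract max for numerical stability
    weights /= weights.sum()
    return weights, weights @ V

q1 = np.array([1.0, 1.0])                            # Q1(token 3)
K1 = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # K1 sequence
V1 = K1.copy()                                       # V1 equals K1 in this example

weights, out = attend(q1, K1, V1)
print(np.round(weights, 4))  # [0.2483 0.2483 0.5035]
print(np.round(out, 3))      # [0.752 0.752]
```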
Query Head 2 (attends to KV head 1 — same as head 1)
Q2(token 3) = [1, 1] (same as Q1 in this example)
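At token 3, Q2 equals Q1, and both heads read the same K1 and V1, so the output is identical:
Output2(token 3) = [0.752, 0.752]
(The match is a coincidence of token 3’s all-ones embedding: at tokens 1 and 2 the two heads already produce different queries. In a real model, Q1 and Q2 would generally differ at every position, so the outputs would differ even though K and V are shared.)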
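To see two heads in the same group diverge, it helps to look at token 2, where the Setup matrices give Q1 = [0, 1] and Q2 = [1, 0]. The sketch below assumes causal masking, so token 2 attends only to tokens 1 and 2; the helper `attend` is hypothetical.

```python
import numpy as np

def attend(q, K, V):
    """Scaled dot-product attention for one query vector."""
    scores = q @ K.T / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V

# Shared KV head 1, restricted to the positions token 2 may see (tokens 1 and 2).
K1 = np.array([[1.0, 0.0], [0.0, 1.0]])
V1 = np.array([[1.0, 0.0], [0.0, 1.0]])

out_head1 = attend(np.array([0.0, 1.0]), K1, V1)  # Q1(token 2)
out_head2 = attend(np.array([1.0, 0.0]), K1, V1)  # Q2(token 2)
print(np.round(out_head1, 3))  # [0.33 0.67]
print(np.round(out_head2, 3))  # [0.67 0.33]
```

Both heads read identical keys and values, yet their outputs are mirror images because their queries differ. Sharing K and V constrains what the heads can look at, not how each one weighs it.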
Query Head 3 (attends to KV head 2)
Q3(token 3) = [2, 1]
K2 sequence = [[1, 1], [1, 1], [2, 2]]
V2 sequence = [[0, 1], [1, 0], [1, 1]]
Compute attention scores:
Scores = Q3(3) @ K2^T / √2
= [2, 1] @ [[1, 1, 2], [1, 1, 2]] / √2
= [2×1 + 1×1, 2×1 + 1×1, 2×2 + 1×2] / √2
= [3, 3, 6] / √2
≈ [2.121, 2.121, 4.243]
Apply softmax:
exp([2.121, 2.121, 4.243]) ≈ [8.342, 8.342, 69.591]
sum ≈ 86.28
softmax ≈ [8.342/86.28, 8.342/86.28, 69.591/86.28]
≈ [0.0967, 0.0967, 0.8066]
Compute output:
Output3(token 3) = [0.0967, 0.0967, 0.8066] @ V2 sequence
= [0.0967, 0.0967, 0.8066] @ [[0, 1], [1, 0], [1, 1]]
= [0.0967×0 + 0.0967×1 + 0.8066×1, 0.0967×1 + 0.0967×0 + 0.8066×1]
= [0.0967 + 0.8066, 0.0967 + 0.8066]
= [0.9033, 0.9033]
Query Head 4 (attends to KV head 2 — same as head 3)
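Q4(token 3) = [1, 2] (not the same as Q3, which is [2, 1])
Scores = [1×1 + 2×1, 1×1 + 2×1, 1×2 + 2×2] / √2 = [3, 3, 6] / √2 ≈ [2.121, 2.121, 4.243]
Both components of every K2 row are equal, so the scores depend only on the sum of the query’s components, which is 3 for both Q3 and Q4. The softmax weights are ≈ [0.0967, 0.0967, 0.8066], and the output matches head 3:
Output4(token 3) = [0.0967 + 0.8066, 0.0967 + 0.8066] = [0.9033, 0.9033]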
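All four heads can be computed in one pass by repeating each KV head once per query head in its group. The sketch below uses `np.repeat` for that expansion, which is one common way to implement the sharing and not the only one; the array names are illustrative, and the inputs are the token-3 queries and the K/V sequences that follow from the Setup matrices.

```python
import numpy as np

# Token-3 queries, one row per query head: token 3 @ W_Q1..W_Q4.
q = np.array([[1.0, 1.0], [1.0, 1.0], [2.0, 1.0], [1.0, 2.0]])  # (n_heads, d_head)

# Cached keys and values per KV head: (n_kv_heads, seq_len, d_head).
K = np.array([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
              [[1.0, 1.0], [1.0, 1.0], [2.0, 2.0]]])
V = np.array([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
              [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]])

group_size = q.shape[0] // K.shape[0]     # 4 // 2 = 2
K_rep = np.repeat(K, group_size, axis=0)  # (4, 3, 2): KV heads in order 1, 1, 2, 2
V_rep = np.repeat(V, group_size, axis=0)

# Per-head scaled dot products, softmax over key positions, weighted sum of values.
scores = np.einsum("hd,hsd->hs", q, K_rep) / np.sqrt(q.shape[-1])
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = np.einsum("hs,hsd->hd", weights, V_rep)  # (n_heads, d_head)

print(np.round(out, 4))
```

Rows 0 and 1 of `out` come out at about 0.7517 in both components, and rows 2 and 3 at about 0.9033. Only two K/V sequences were stored; the repetition happens at compute time.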
Step 3: Concatenate and project
Concatenate all 4 head outputs:
Concatenated = [Output1, Output2, Output3, Output4]
= [0.752, 0.752, 0.752, 0.752, 0.9033, 0.9033, 0.9033, 0.9033]
(8-dimensional, 4 heads × 2 dims each)
This is the raw output of the multi-head GQA layer. In a real model, this would be projected back to the original embedding dimension via a final output projection matrix W_O.
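The shapes involved in the final projection can be checked with a short sketch. The example never defines W_O, so the matrix below is a random placeholder used purely to show the dimensions; it is an assumption and does not come from the example or from any real model. The head outputs are the token-3 values that follow from the Setup matrices, rounded to three decimals.

```python
import numpy as np

# Token-3 outputs of the four heads (rounded), one row per head.
head_outputs = np.array([[0.752, 0.752],
                         [0.752, 0.752],
                         [0.903, 0.903],
                         [0.903, 0.903]])
concatenated = head_outputs.reshape(-1)     # shape (8,): n_heads * d_head

# Placeholder W_O: random values, NOT taken from the example or any real model.
rng = np.random.default_rng(0)
W_O = rng.normal(size=(8, 4))               # (n_heads * d_head, embedding_dim)

projected = concatenated @ W_O              # shape (4,): back in embedding space
print(concatenated.shape, projected.shape)  # (8,) (4,)
```

W_O has the same shape under GQA as under standard MHA, because the number of query heads is unchanged; only the K and V projections shrink.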
Key Insight: Memory Saving
In this example:
- Standard MHA: 4 KV heads, so 4 (K, V) pairs per token; 4 × 2 × d_head = 16 cached values per token
- GQA: 2 KV heads, so 2 (K, V) pairs per token; 2 × 2 × d_head = 8 cached values per token
- Reduction: 2×
Scale this to Mistral 7B (32 query heads, 8 KV heads, a 4× reduction), and the KV cache kept during inference shrinks to a quarter of what standard MHA would need.
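A rough cache-size estimate makes the saving concrete. The helper `kv_cache_bytes` is hypothetical, and the larger configuration rests on assumptions not stated in this section: 32 layers, d_head = 128, 2 bytes per cached value (16-bit floats), and a 4096-token sequence. Only the 32 and 8 head counts come from the text.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, d_head, bytes_per_value=2):
    """Bytes needed to cache K and V for one sequence (hypothetical helper)."""
    # The leading 2 counts the separate K and V tensors.
    return 2 * seq_len * n_layers * n_kv_heads * d_head * bytes_per_value

# Toy example from this section: 3 tokens, one layer, d_head = 2.
toy_mha = kv_cache_bytes(3, 1, n_kv_heads=4, d_head=2)
toy_gqa = kv_cache_bytes(3, 1, n_kv_heads=2, d_head=2)
print(toy_mha, toy_gqa, toy_mha / toy_gqa)  # 96 48 2.0

# Larger configuration (assumed: 32 layers, d_head = 128, 2-byte values, 4096 tokens).
big_mha = kv_cache_bytes(4096, 32, n_kv_heads=32, d_head=128)
big_gqa = kv_cache_bytes(4096, 32, n_kv_heads=8, d_head=128)
print(big_mha // 2**20, big_gqa // 2**20)   # 2048 512  (MiB)
```

The ratio depends only on n_heads / n_kv_heads, so the 4× factor holds at any sequence length or batch size under these assumptions.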
Verification by Hand
You can verify these calculations with a simple Python snippet:
import numpy as np
# Token 3 attention with query head 3 (which shares KV head 2)
Q3_t3 = np.array([2, 1])
K2_seq = np.array([[1, 1], [1, 1], [2, 2]])
V2_seq = np.array([[0, 1], [1, 0], [1, 1]])
# Scaled dot-product scores; d_head = 2
scores = Q3_t3 @ K2_seq.T / np.sqrt(2)
print(f"Scores: {scores}") # ≈ [2.121, 2.121, 4.243]
# Softmax over the three key positions
weights = np.exp(scores) / np.sum(np.exp(scores))
print(f"Weights: {weights}") # ≈ [0.0967, 0.0967, 0.8066]
# Weighted sum of the value vectors
output = weights @ V2_seq
print(f"Output: {output}") # ≈ [0.9033, 0.9033]
This worked example shows how GQA reduces K/V projection parameters and KV-cache memory while leaving the attention computation itself unchanged. The cost is fewer distinct K/V heads (2× fewer in this toy example, 4× fewer in Mistral 7B), and empirically, models trained with GQA have been reported to stay close to standard MHA in quality.