The Code: Implementing Grouped Query Attention
This section shows a clean PyTorch implementation of Grouped Query Attention (GQA) and compares its KV-cache memory footprint with standard Multi-Head Attention (MHA). The code is commented step by step.
Code Block 1: GQA Attention Function
import torch
import torch.nn.functional as F

def gqa_attention(Q, K, V, n_heads, n_kv_heads):
    """
    Grouped Query Attention (GQA).

    Q, K and V are assumed to be already projected (no weight matrices here).
    No causal mask is applied.

    Args:
        Q: (batch, seq_len, d_embed), where d_embed = n_heads * d_head
        K: (batch, seq_len, n_kv_heads * d_head)
        V: (batch, seq_len, n_kv_heads * d_head)
        n_heads: number of query heads (e.g., 32)
        n_kv_heads: number of KV heads (e.g., 8); must divide n_heads
    Returns:
        output: (batch, seq_len, d_embed)
    """
    batch, seq_len, d_embed = Q.shape
    assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
    d_head = d_embed // n_heads  # dimension per head
    assert K.shape[-1] == n_kv_heads * d_head, "K and V must have n_kv_heads * d_head features"

    # Split Q into heads: (batch, seq_len, n_heads, d_head) -> (batch, n_heads, seq_len, d_head)
    Q = Q.reshape(batch, seq_len, n_heads, d_head).transpose(1, 2)

    # Split K, V into fewer heads: -> (batch, n_kv_heads, seq_len, d_head)
    K = K.reshape(batch, seq_len, n_kv_heads, d_head).transpose(1, 2)
    V = V.reshape(batch, seq_len, n_kv_heads, d_head).transpose(1, 2)

    # Each KV head serves group_size consecutive query heads
    group_size = n_heads // n_kv_heads

    # Repeat KV heads to match the number of Q heads: (batch, n_heads, seq_len, d_head)
    K = K.repeat_interleave(group_size, dim=1)
    V = V.repeat_interleave(group_size, dim=1)

    # Scaled dot-product scores: (batch, n_heads, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_head ** 0.5)

    # Softmax over the key axis gives the attention weights
    attn_weights = F.softmax(scores, dim=-1)

    # Weighted sum of values: (batch, n_heads, seq_len, d_head)
    output = torch.matmul(attn_weights, V)

    # Merge heads back: (batch, seq_len, n_heads * d_head) = (batch, seq_len, d_embed)
    output = output.transpose(1, 2).contiguous()
    output = output.reshape(batch, seq_len, d_embed)
    return output
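The function above materialises group_size copies of K and V before the matmul. As a minimal sketch of one alternative (assuming inputs already split into heads and no attention mask), each group of query heads can be folded into the row axis so that every KV head is used once without being copied; the function names gqa_repeat and gqa_folded are hypothetical, chosen for this illustration. If q is non-contiguous, reshape copies q, but K and V are still never duplicated.

```python
import torch
import torch.nn.functional as F

def gqa_repeat(q, k, v, group_size):
    # Reference path, as in Code Block 1: copy each KV head group_size times
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v

def gqa_folded(q, k, v, group_size):
    # q: (b, n_heads, s, d); k, v: (b, n_kv_heads, s, d)
    b, n_heads, s, d = q.shape
    n_kv = k.shape[1]
    # Fold each group of consecutive query heads into the row axis:
    # query head h = kv * group_size + j lands in group kv, rows j*s .. j*s + s - 1
    qf = q.reshape(b, n_kv, group_size * s, d)
    # Each KV head is used once for its whole group; K and V are never copied
    scores = qf @ k.transpose(-2, -1) / d ** 0.5   # (b, n_kv, group_size * s, s)
    out = F.softmax(scores, dim=-1) @ v            # (b, n_kv, group_size * s, d)
    # Unfold back to one (s, d) slice per query head
    return out.reshape(b, n_heads, s, d)

torch.manual_seed(0)
q = torch.randn(2, 32, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
out_ref = gqa_repeat(q, k, v, group_size=4)
out_fold = gqa_folded(q, k, v, group_size=4)
print(torch.allclose(out_ref, out_fold, atol=1e-5))  # True
```

With a causal mask, the folded layout would need the mask tiled once per group, which this sketch does not do.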
Code Block 2: Compare GQA vs MHA Memory Usage
import torch
# Requires gqa_attention from Code Block 1

# Configuration (Mistral 7B-like attention dimensions)
batch_size = 1
seq_len = 8192
d_embed = 4096
n_heads = 32
n_kv_heads = 8
d_head = d_embed // n_heads  # 128
bytes_per_elem = 4  # float32

# Standard MHA: cache K and V for all n_heads heads (one layer, one sample)
kv_cache_mha = 2 * n_heads * d_head * seq_len * bytes_per_elem
print(f"MHA KV cache: {kv_cache_mha / (1024**2):.1f} MB")

# GQA: cache K and V for only n_kv_heads heads
kv_cache_gqa = 2 * n_kv_heads * d_head * seq_len * bytes_per_elem
print(f"GQA KV cache: {kv_cache_gqa / (1024**2):.1f} MB")

# Memory reduction = n_heads / n_kv_heads
reduction_factor = kv_cache_mha / kv_cache_gqa
print(f"Memory reduction: {reduction_factor:.1f}x")
# Expected output:
# MHA KV cache: 256.0 MB
# GQA KV cache: 64.0 MB
# Memory reduction: 4.0x

# Run GQA on a shorter sequence: the naive implementation builds a
# (batch, n_heads, seq_len, seq_len) score tensor, which is 8 GiB at 8,192 tokens in float32
demo_len = 1024
Q = torch.randn(batch_size, demo_len, n_heads * d_head)
K = torch.randn(batch_size, demo_len, n_kv_heads * d_head)  # only n_kv_heads heads
V = torch.randn(batch_size, demo_len, n_kv_heads * d_head)
output_gqa = gqa_attention(Q, K, V, n_heads=n_heads, n_kv_heads=n_kv_heads)
print(f"GQA output shape: {output_gqa.shape}")  # torch.Size([1, 1024, 4096])
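In a full model, Q, K and V come from learned linear projections of the same hidden states, and GQA shrinks the K and V projections as well as the cache. A minimal sketch, assuming bias-free nn.Linear projections and Mistral 7B-like sizes; the names q_proj, k_proj and v_proj are illustrative, not taken from the original code.

```python
import torch
import torch.nn as nn

d_embed, n_heads, n_kv_heads = 4096, 32, 8
d_head = d_embed // n_heads  # 128

# Illustrative projection layers (hypothetical names, bias omitted)
q_proj = nn.Linear(d_embed, n_heads * d_head, bias=False)     # 4096 -> 4096
k_proj = nn.Linear(d_embed, n_kv_heads * d_head, bias=False)  # 4096 -> 1024
v_proj = nn.Linear(d_embed, n_kv_heads * d_head, bias=False)  # 4096 -> 1024

x = torch.randn(1, 16, d_embed)  # (batch, seq_len, d_embed) hidden states
Q, K, V = q_proj(x), k_proj(x), v_proj(x)
print(Q.shape, K.shape, V.shape)
# torch.Size([1, 16, 4096]) torch.Size([1, 16, 1024]) torch.Size([1, 16, 1024])

def n_params(layer):
    # Total number of weights in a layer
    return sum(p.numel() for p in layer.parameters())

# K/V projection weights shrink by n_heads / n_kv_heads as well
print(n_params(k_proj) + n_params(v_proj))  # 8388608 (vs 33554432 with MHA)
```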
Code Block 3: Running on Google Colab
To run this code on Google Colab, paste Code Blocks 1 and 2 into one cell (Block 1 first) and execute it. The cell prints the MHA and GQA KV-cache sizes, the reduction factor and the output shape.
PyTorch is pre-installed in Colab, so no additional downloads or installs are needed.
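Colab runtimes have limited memory, and the naive function in Code Block 1 materialises a full (batch, n_heads, seq_len, seq_len) score tensor, plus a softmax output of the same size. A quick estimate before choosing a sequence length, assuming float32 and the 32-head configuration; score_tensor_bytes is a hypothetical helper. Note that GQA does not shrink this tensor: there is still one score matrix per query head.

```python
def score_tensor_bytes(batch, n_heads, seq_len, bytes_per_elem=4):
    # Naive attention materialises one (seq_len x seq_len) score matrix per query head
    return batch * n_heads * seq_len * seq_len * bytes_per_elem

for seq_len in (1024, 4096, 8192):
    gib = score_tensor_bytes(1, 32, seq_len) / 1024**3
    print(f"seq_len={seq_len}: {gib:.3f} GiB")
# seq_len=1024: 0.125 GiB
# seq_len=4096: 2.000 GiB
# seq_len=8192: 8.000 GiB
```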
Explanation
Step-by-step breakdown of the GQA function:
- Reshape to separate heads: the embedding is split into heads so that each head computes attention independently and in parallel. Q is split into n_heads heads, while K and V are split into only n_kv_heads heads.
- Key insight (repeat_interleave): K and V have fewer heads than Q, so each KV head is repeated across a group of Q heads. This is the core of GQA:
  - K starts as (batch, n_kv_heads, seq, d_head): 8 heads
  - K becomes (batch, n_heads, seq, d_head): 32 heads
  - Each of the 8 KV heads appears 4 times in a row, so Q heads 0 to 3 share KV head 0, heads 4 to 7 share KV head 1, and so on. The copies exist only during the computation; the KV cache stores just the 8 original heads.
- Scaled dot-product attention: Q @ K^T / √d_head gives the attention scores. This is the standard formula.
- Softmax: normalises each row of scores into non-negative weights that sum to 1.
- Apply to values: multiplying the weights by V gives each head's output.
- Reshape back: the heads are concatenated back into one d_embed-dimensional vector per token.
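The head sharing and the score, softmax and weighted-sum steps can be checked directly. A small sketch, assuming PyTorch 2.0 or later for F.scaled_dot_product_attention (its default scale is 1/√d_head) and toy sizes of 10 tokens and 64 dimensions per head:

```python
import torch
import torch.nn.functional as F

n_heads, n_kv_heads = 32, 8
group_size = n_heads // n_kv_heads

# Which KV head each query head reads after repeat_interleave along the head axis
kv_for_query = torch.arange(n_kv_heads).repeat_interleave(group_size)
print(kv_for_query[:8].tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]
# Equivalently, query head h uses KV head h // group_size
assert torch.equal(kv_for_query, torch.arange(n_heads) // group_size)

# Scores, softmax and weighted sum written by hand match PyTorch's fused kernel
torch.manual_seed(0)
q = torch.randn(1, n_heads, 10, 64)
k = torch.randn(1, n_kv_heads, 10, 64).repeat_interleave(group_size, dim=1)
v = torch.randn(1, n_kv_heads, 10, 64).repeat_interleave(group_size, dim=1)
manual = F.softmax(q @ k.transpose(-2, -1) / 64 ** 0.5, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual, fused, atol=1e-5))  # True
```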
Memory Calculation Details
For Mistral 7B with GQA:
- n_heads = 32 query heads
- n_kv_heads = 8 key-value heads
- d_head = 128 dimensions per head
- seq_len = 8,192 tokens
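KV cache if the same model used standard MHA (one layer, one sample):
KV_cache = 2 × 32 × 128 × 8,192 = 67,108,864 floats = 256 MB (float32)
GQA KV cache (one layer, one sample):
KV_cache = 2 × 8 × 128 × 8,192 = 16,777,216 floats = 64 MB (float32)
The factor of 2 accounts for storing both K and V. Memory saved: 192 MB per sample per layer. Mistral 7B has 32 layers, so across the whole model the saving is 6 GB per sample in float32 (3 GB with a 16-bit cache). Even for a single layer, a batch of 32 samples saves 6 GB. That headroom allows larger batches or longer contexts on a 40 GB GPU and makes long-context inference more practical on consumer hardware.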
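The formula above counts one layer's cache. A minimal sketch extending the arithmetic to a whole model, assuming Mistral 7B's 32 layers, a full 8,192-token cache and either a float32 or a 16-bit cache; kv_cache_bytes is a hypothetical helper.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch=1, bytes_per_elem=2):
    # 2 tensors (K and V) per layer, each (batch, n_kv_heads, seq_len, d_head)
    return 2 * n_layers * batch * n_kv_heads * d_head * seq_len * bytes_per_elem

GiB = 1024 ** 3
for name, kv_heads in (("MHA", 32), ("GQA", 8)):
    for dtype, nbytes in (("float32", 4), ("float16", 2)):
        size = kv_cache_bytes(n_layers=32, n_kv_heads=kv_heads, d_head=128,
                              seq_len=8192, bytes_per_elem=nbytes) / GiB
        print(f"{name} {dtype}: {size:.1f} GiB per sample")
# MHA float32: 8.0 GiB per sample
# MHA float16: 4.0 GiB per sample
# GQA float32: 2.0 GiB per sample
# GQA float16: 1.0 GiB per sample
```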
Key Takeaway
GQA achieves quality close to MHA while cutting the KV cache by a factor of n_heads / n_kv_heads (4x for Mistral 7B). The attention-side change is small (repeat each KV head across its group of query heads), but the model must be trained, or uptrained from an MHA checkpoint, with shared KV heads; GQA cannot simply be switched on in an existing MHA model.
The Mistral 7B authors highlight GQA, together with sliding window attention, as a key design choice for faster, cheaper inference.
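Converting an existing MHA model requires K and V weights with shared heads. The GQA paper (Ainslie et al., 2023) builds them by mean-pooling the key and value heads within each group and then uptraining briefly. A minimal sketch of the pooling step only, assuming a bias-free projection weight laid out like nn.Linear.weight, i.e. (n_heads * d_head, d_embed); pool_kv_heads is a hypothetical name and the uptraining itself is not shown.

```python
import torch

def pool_kv_heads(w, n_heads, n_kv_heads):
    # w: (n_heads * d_head, d_embed), the weight of an MHA key or value projection
    d_embed = w.shape[1]
    d_head = w.shape[0] // n_heads
    group_size = n_heads // n_kv_heads
    # View as (n_kv_heads, group_size, d_head, d_embed): consecutive heads form a group,
    # matching the repeat_interleave mapping (query head h -> KV head h // group_size)
    w = w.reshape(n_kv_heads, group_size, d_head, d_embed)
    # Average the heads in each group -> (n_kv_heads * d_head, d_embed)
    return w.mean(dim=1).reshape(n_kv_heads * d_head, d_embed)

w_k_mha = torch.randn(32 * 128, 4096)
w_k_gqa = pool_kv_heads(w_k_mha, n_heads=32, n_kv_heads=8)
print(w_k_gqa.shape)  # torch.Size([1024, 4096])
```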