The Math: Formal Algorithm and Complexity Analysis
This section formalises Ring Attention with precise notation and complexity analysis.
Prerequisites
Before reading this section, review:
- Scaled Dot-Product Attention
- Attention Complexity and KV Cache
- Online Softmax and Numerically Stable Attention
- Distributed Computing Basics
Part 1: Ring Attention Algorithm
Notation
- n: Total sequence length
- P: Number of GPUs
- i: GPU index (0 ≤ i < P)
- d: Embedding dimension
- d_head: Dimension per attention head
- n_heads: Number of attention heads
- Q_i ∈ ℝ^((n/P) × d_head): Query chunk on GPU i
- K_i ∈ ℝ^((n/P) × d_head): Key chunk on GPU i
- V_i ∈ ℝ^((n/P) × d_head): Value chunk on GPU i
Algorithm: Standard (Bidirectional) Attention
RING_ATTENTION(Q, K, V, n_heads, P):
Input: Query, Key, Value matrices (distributed across P GPUs)
Output: Attention output (distributed, each GPU stores its Q chunk's output)
// Runs on every GPU i in parallel; (K_cur, V_cur) is the KV chunk the GPU currently holds
K_cur, V_cur = K_i, V_i
m = -∞, l = 0, o = 0 // Running max, normaliser and output, one entry per query row
for round = 0 to P-1 do
// Index of the KV chunk held in this round (circular)
kv_idx = (i - round) mod P
// Blockwise attention: scores of Q_i against chunk kv_idx
S = Q_i @ K_cur^T / √d_head
// Fold this block into the running statistics with online softmax (Part 2).
// A plain sum of per-block softmax outputs would be wrong, because each block
// would be normalised by its own sum.
(m, l, o) = ONLINE_SOFTMAX_UPDATE(m, l, o, S, V_cur)
// Communicate: pass the held KV chunk to the next GPU, receive from the previous one
// (the exchange after the last round only returns chunks to their owners and can be skipped)
SEND(K_cur, V_cur) → GPU (i + 1) mod P
RECV(K_cur, V_cur) ← GPU (i - 1) mod P
OUTPUT_i = o
return OUTPUT
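The ring schedule can be checked on a single machine. The sketch below is a NumPy simulation, not a distributed implementation: the "devices" are list entries, the ring exchange is a list rotation, and the names `ring_attention_sim` and `full_attention` are illustrative. It assumes one head and equal-sized chunks, and it keeps the accumulator unnormalised and divides once at the end, which is an equivalent variant of the online softmax update.

```python
import numpy as np

def full_attention(Q, K, V):
    """Reference: ordinary softmax attention computed in one piece."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    W = np.exp(S - S.max(axis=1, keepdims=True))
    return (W / W.sum(axis=1, keepdims=True)) @ V

def ring_attention_sim(Q_chunks, K_chunks, V_chunks):
    """Simulate P devices in one process; returns one output chunk per device."""
    P = len(Q_chunks)
    d_head = Q_chunks[0].shape[1]
    d_v = V_chunks[0].shape[1]
    held = list(zip(K_chunks, V_chunks))  # held[i]: KV chunk currently on device i
    m = [np.full(q.shape[0], -np.inf) for q in Q_chunks]   # running row max
    l = [np.zeros(q.shape[0]) for q in Q_chunks]           # running normaliser
    acc = [np.zeros((q.shape[0], d_v)) for q in Q_chunks]  # unnormalised output
    for _ in range(P):
        for i in range(P):
            K_cur, V_cur = held[i]
            S = Q_chunks[i] @ K_cur.T / np.sqrt(d_head)
            m_new = np.maximum(m[i], S.max(axis=1))
            scale = np.exp(m[i] - m_new)        # rescales what was accumulated so far
            p = np.exp(S - m_new[:, None])
            l[i] = l[i] * scale + p.sum(axis=1)
            acc[i] = acc[i] * scale[:, None] + p @ V_cur
            m[i] = m_new
        # Ring step: device i receives the chunk that device i-1 was holding.
        held = [held[(i - 1) % P] for i in range(P)]
    return [acc[i] / l[i][:, None] for i in range(P)]

rng = np.random.default_rng(0)
n, P, d_head = 8, 2, 4
Q, K, V = (rng.standard_normal((n, d_head)) for _ in range(3))
out = np.concatenate(ring_attention_sim(np.split(Q, P), np.split(K, P), np.split(V, P)))
print(np.allclose(out, full_attention(Q, K, V)))
```

In a real deployment the inner loop over `i` runs concurrently on separate GPUs and the rotation becomes a point-to-point send and receive.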
Algorithm: Causal (Autoregressive) Attention
In autoregressive generation (language modelling), query position t cannot attend to key position s if s > t.
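RING_ATTENTION_CAUSAL(Q, K, V, n_heads, P):
// Runs on every GPU i in parallel; (K_cur, V_cur) is the KV chunk the GPU currently holds
K_cur, V_cur = K_i, V_i
m = -∞, l = 0, o = 0 // Running max, normaliser and output, one entry per query row
// Global token positions of the local queries
position_q = i × (n/P) + [0, ..., n/P - 1]
for round = 0 to P-1 do
kv_idx = (i - round) mod P
// Global token positions of the keys in the held chunk
position_kv = kv_idx × (n/P) + [0, ..., n/P - 1]
// Key addition: causal mask, true where the query may attend to the key
causal_mask = (position_q[:, None] >= position_kv[None, :])
// Compute scores
S = Q_i @ K_cur^T / √d_head
// Apply causal mask: set future tokens to -∞
S[~causal_mask] = -∞
// Blockwise attention with online softmax (Part 2).
// Chunks with kv_idx > i lie entirely in the future, are fully masked, and are skipped.
if kv_idx <= i then
(m, l, o) = ONLINE_SOFTMAX_UPDATE(m, l, o, S, V_cur)
// Communication: same ring exchange as in the bidirectional algorithm
SEND(...) → GPU (i + 1) mod P
RECV(...) ← GPU (i - 1) mod P
OUTPUT_i = o
return OUTPUT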
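A minimal NumPy sketch of the causal variant follows. It is a single-process simulation under stated assumptions: one head, equal-sized chunks, and direct indexing of `K_chunks[kv_idx]` in place of an actual ring exchange. The function name `causal_ring_attention_sim` is illustrative. Chunks that lie entirely in the future are skipped, which also avoids rows whose scores are all -∞.

```python
import numpy as np

def causal_ring_attention_sim(Q_chunks, K_chunks, V_chunks):
    """Causal blockwise attention; each outer iteration plays the role of one device."""
    P = len(Q_chunks)
    c, d_head = Q_chunks[0].shape            # c = n / P tokens per chunk
    d_v = V_chunks[0].shape[1]
    outputs = []
    for i in range(P):
        pos_q = i * c + np.arange(c)          # global positions of local queries
        m = np.full(c, -np.inf)
        l = np.zeros(c)
        acc = np.zeros((c, d_v))
        for rnd in range(P):
            kv_idx = (i - rnd) % P
            if kv_idx > i:
                continue                      # chunk is entirely in the future
            pos_kv = kv_idx * c + np.arange(c)
            mask = pos_q[:, None] >= pos_kv[None, :]
            S = Q_chunks[i] @ K_chunks[kv_idx].T / np.sqrt(d_head)
            S = np.where(mask, S, -np.inf)    # future keys get zero weight
            m_new = np.maximum(m, S.max(axis=1))
            scale = np.exp(m - m_new)
            p = np.exp(S - m_new[:, None])
            l = l * scale + p.sum(axis=1)
            acc = acc * scale[:, None] + p @ V_chunks[kv_idx]
            m = m_new
        outputs.append(acc / l[:, None])
    return outputs

rng = np.random.default_rng(0)
n, P, d_head = 8, 4, 4
Q, K, V = (rng.standard_normal((n, d_head)) for _ in range(3))
out = np.concatenate(
    causal_ring_attention_sim(np.split(Q, P), np.split(K, P), np.split(V, P))
)
```

Because round 0 always processes the diagonal chunk (`kv_idx == i`), every query row sees at least its own key before any rescaling happens, so the running maximum is finite from the first update onward. Note that device 0 does useful work in only one round while device P-1 works in all of them.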
Part 2: Online Softmax (Numerical Stability)
Standard softmax applied to each block independently:
softmax_block_b = exp(scores_b) / sum(exp(scores_b))
Problem: If you compute softmax for block 1, then block 2, the results cannot simply be added together, because each block is normalised by its own sum rather than by the sum over all keys. A naive fix would recompute the full softmax across all blocks.
Solution: Online (Incremental) Softmax
Maintain running statistics:
- m (maximum): The largest score seen so far
- l (normalisation): The running sum of exponentials, each taken relative to the current maximum m
- o (output): The accumulated attention output
Initialise m = -∞, l = 0, o = 0. Then, for each block b:
// Step 1: Compute block-local statistics
S_b = Q @ K_b^T / √d_head
m_b = max(S_b, axis=1) // Per-query (row-wise) max
exp_scores_b = exp(S_b - m_b) // Stabilised
l_b = sum(exp_scores_b, axis=1) // Per-query sum
// Step 2: Update running max and normalisation
m_old = m, l_old = l, o_old = o // Previous running statistics
m_new = max(m_old, m_b) // New running max
l_new = l_old × exp(m_old - m_new) + l_b × exp(m_b - m_new)
// Rescale both the old and the new contribution to the new max
// Step 3: Update output. o is kept normalised, so o_old is first multiplied back by l_old
o_new = (o_old × l_old × exp(m_old - m_new) + (exp_scores_b @ V_b) × exp(m_b - m_new)) / l_new
m, l, o = m_new, l_new, o_new
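After P blocks, the running output o contains the same result (up to floating-point rounding) as computing:
softmax(concat([scores_1, ..., scores_P])) @ concat([V_1, ..., V_P])
The difference is that it is computed incrementally and in a numerically stable way, and partial results from different blocks can be merged in any order, so the work parallelises across blocks.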
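The update rule can be verified numerically. The following is a small NumPy sketch, not a tuned kernel; the function names are illustrative. It keeps `o` normalised after every block, so the previous output is multiplied back by its normaliser `l` before being rescaled, and it compares the result with a softmax over all keys at once.

```python
import numpy as np

def online_softmax_attention(Q, K_blocks, V_blocks):
    """Attention over a list of key/value blocks, folding in one block at a time."""
    n_q, d_head = Q.shape
    m = np.full(n_q, -np.inf)                   # running row-wise max
    l = np.zeros(n_q)                           # running normaliser
    o = np.zeros((n_q, V_blocks[0].shape[1]))   # running normalised output
    for K_b, V_b in zip(K_blocks, V_blocks):
        S = Q @ K_b.T / np.sqrt(d_head)
        m_b = S.max(axis=1)
        p = np.exp(S - m_b[:, None])            # block-local stabilised exponentials
        l_b = p.sum(axis=1)
        m_new = np.maximum(m, m_b)
        alpha = np.exp(m - m_new)               # rescales the old contribution (0 on the first block)
        beta = np.exp(m_b - m_new)              # rescales the new contribution
        l_new = l * alpha + l_b * beta
        o = (o * (l * alpha)[:, None] + (p @ V_b) * beta[:, None]) / l_new[:, None]
        m, l = m_new, l_new
    return o

def full_softmax_attention(Q, K, V):
    """Reference: softmax over all keys at once."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    W = np.exp(S - S.max(axis=1, keepdims=True))
    return (W / W.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(1)
Q = rng.standard_normal((5, 4))
K = rng.standard_normal((12, 4))
V = rng.standard_normal((12, 4))
blockwise = online_softmax_attention(Q, np.split(K, 3), np.split(V, 3))
print(np.allclose(blockwise, full_softmax_attention(Q, K, V)))
```

Dropping the `l` factor from the old contribution makes the two results disagree as soon as there is more than one block, which is an easy way to see why that factor is needed.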
Part 3: Complexity Analysis
Memory Complexity Per GPU
Standard single-GPU full attention:
Memory = KV_cache + Q cache + intermediate activations
= 2 × n × d + n × d + O(n × d)
= O(n × d)
Ring Attention with P GPUs (per GPU):
Memory per GPU = KV_cache for chunk + Q chunk + activations
= 2 × (n/P) × d + (n/P) × d + O((n/P) × d)
= O((n/P) × d)
Memory reduction factor: P
For 1M tokens, d = 4096, float16 (2 bytes per element), P=8 GPUs, per layer:
Single GPU: 1M × 4096 × 2 bytes ≈ 8 GB (key) + 8 GB (value) = 16 GB
Ring Attention: (1M/8) × 4096 × 2 bytes = 125K × 8192 bytes ≈ 1 GB (key) + 1 GB (value) = 2 GB per GPU
Saving: 8× per GPU
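The memory formula is easy to parametrise. The helper below is a hypothetical convenience function (`qkv_bytes_per_gpu` is not from any library) that counts only the Q, K and V chunks of one layer and ignores intermediate activations. The element size is an explicit argument because the totals depend directly on it.

```python
def qkv_bytes_per_gpu(n, d, P=1, bytes_per_element=2):
    """Bytes held by the Q, K and V chunks of one layer on one GPU."""
    tokens_per_gpu = n // P                  # assumes P divides n
    return 3 * tokens_per_gpu * d * bytes_per_element

single = qkv_bytes_per_gpu(1_000_000, 4096)          # float16, one GPU
ring = qkv_bytes_per_gpu(1_000_000, 4096, P=8)       # float16, 8 GPUs
print(f"single GPU: {single / 1e9:.1f} GB, ring: {ring / 1e9:.1f} GB, ratio: {single / ring:.0f}x")
```

Whatever element size is chosen, the ratio between the two cases stays at P.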
Compute Complexity
Standard single-GPU attention:
Attention compute per layer:
Q @ K^T: O(n²d) operations
softmax @ V: O(n²d) operations
Total: O(n²d)
For n=1M, d=4096: 1M² × 4096 ≈ 4.1 × 10^15 multiply-adds for each of the two matrix products
Ring Attention with P GPUs:
Each GPU computes (n/P) × n attention scores, each a d-dimensional dot product
Per GPU: O((n/P) × n × d) = O(n²d / P)
But only for its own queries. Across all P GPUs:
Total compute: P × O(n²d / P) = O(n²d)
No reduction in total computation — it’s just distributed across P devices.
Communication Complexity
Ring Attention:
In each of P rounds, each GPU sends its current K and V chunks, 2 × (n/P) × d elements, to the next GPU
and receives 2 × (n/P) × d elements from the previous GPU.
Per round: O((n/P) × d) data per GPU
P rounds: O(n × d) total data per GPU
But in practice, network bandwidth limits this:
Time = (2 × n × d × bytes_per_element) / (network_bandwidth)
For n=1M, d=4096, bytes=2 (float16), bandwidth=200GB/s:
Time = (2 × 1M × 4096 × 2 bytes) / (200 GB/s)
= 16 GB / 200 GB/s
≈ 0.08 seconds
Compute time (rough estimate, counting the multiply-adds of one matrix product at peak throughput):
(n/P) × n × d operations per GPU / (GPU_throughput)
= 125K × 1M × 4096 / (312 TFLOPS) [A100 dense FP16/BF16 peak]
≈ 5.1 × 10^14 / (3.12 × 10^14 per second) ≈ 1.6 seconds
Communication (0.08 s) << Compute (≈ 1.6 s), so communication can be hidden behind compute when the two overlap.
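The same back-of-envelope comparison can be written as a function, so other sequence lengths and bandwidths are easy to try. This is a rough model under stated assumptions: both K and V are forwarded in every round, compute is counted as the multiply-adds of one matrix product at peak throughput (a full count with both products and 2 FLOPs per multiply-add is 4× larger), and `ring_cost_estimate` is an illustrative name.

```python
def ring_cost_estimate(n, d, P, bytes_per_element, link_bytes_per_s, ops_per_s):
    """Rough per-GPU seconds spent on communication and compute over all P rounds."""
    # Each round forwards a K chunk and a V chunk, each of shape (n/P, d).
    bytes_sent = P * 2 * (n / P) * d * bytes_per_element
    comm_s = bytes_sent / link_bytes_per_s
    # Each GPU scores its n/P queries against all n keys, d multiply-adds per score.
    ops = (n / P) * n * d
    compute_s = ops / ops_per_s
    return comm_s, compute_s

comm_s, compute_s = ring_cost_estimate(
    n=1_000_000, d=4096, P=8,
    bytes_per_element=2,        # float16
    link_bytes_per_s=200e9,     # 200 GB/s interconnect
    ops_per_s=312e12,           # peak throughput used in the estimate above
)
print(f"communication ≈ {comm_s:.3f} s, compute ≈ {compute_s:.2f} s")
```

Compute grows with n² while communication grows with n, so shortening the sequence or slowing the link narrows the gap. Halving n, for instance, halves the communication time but quarters the compute time.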
Wall-Clock Time
If compute and communication overlap perfectly:
Wall time ≈ max(compute_time, communication_time)
≈ compute_time
≈ O(n²d / (P × GPU_throughput))
Compared to single GPU:
Single GPU wall time ≈ O(n²d / GPU_throughput)
Ring Attention wall time ≈ O(n²d / (P × GPU_throughput))
Speedup: ~P (linear in number of GPUs, if well-balanced)
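To make the effect of overlap concrete, here is a toy timing model. The per-round durations are made-up illustrative inputs, not measurements, and `ring_wall_time` is a hypothetical helper. It assumes every round costs the same and that overlap is either perfect or absent.

```python
def ring_wall_time(compute_per_round_s, comm_per_round_s, P, overlap=True):
    """Wall-clock time for P rounds under a perfect-overlap or no-overlap assumption."""
    if overlap:
        per_round = max(compute_per_round_s, comm_per_round_s)
    else:
        per_round = compute_per_round_s + comm_per_round_s
    return P * per_round

# Hypothetical per-round costs: 0.2 s of compute, 0.01 s of communication, 8 GPUs.
with_overlap = ring_wall_time(0.2, 0.01, P=8)
without_overlap = ring_wall_time(0.2, 0.01, P=8, overlap=False)
print(with_overlap, without_overlap)
```

When communication per round is shorter than compute per round, overlap removes it from the critical path entirely. When it is longer, the link becomes the bottleneck and adding GPUs no longer gives a linear speedup.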
Part 4: Worked Example - Complexity Calculation
Problem: Compute Ring Attention complexity for:
- n = 8 tokens (total sequence)
- P = 2 GPUs
- d_head = 4 (dimensions per head)
- 1 attention head (simplified)
Step 1: Memory Per GPU
GPU 0 holds: Q[0:4], K[0:4], V[0:4]
GPU 1 holds: Q[4:8], K[4:8], V[4:8]
Memory per GPU = 2 × (8/2) × 4 + (8/2) × 4
= 2 × 4 × 4 + 4 × 4
= 32 + 16
= 48 values per GPU
Single GPU (comparison): 2 × 8 × 4 + 8 × 4 = 96 values
Reduction: 96 / 48 = 2× (equals P)
Step 2: Attention FLOPs Per GPU Per Round
Each GPU computes:
Q_chunk @ K_chunk^T = (4 × 4) @ (4 × 4)^T
= 4 × 4 × 4 = 64 FLOPs (matrix multiply)
softmax @ V = (4 × 4) @ (4 × 4) = 64 FLOPs
Total per round = 128 FLOPs
Step 3: Total FLOPs Across Rounds
P rounds: 2 rounds
FLOPs per GPU = 128 × 2 = 256 FLOPs
Total FLOPs (all GPUs) = 256 × 2 = 512 FLOPs
Single GPU (comparison):
Q @ K^T = 8 × 8 × 4 = 256 FLOPs
softmax @ V = 8 × 8 × 4 = 256 FLOPs
Total = 512 FLOPs
No difference in total FLOPs — computation is just distributed.
Step 4: Communication
Round 1: GPU 0 sends K[0:4], V[0:4] (2 × 4 × 4 = 32 values) to GPU 1
GPU 0 receives K[4:8], V[4:8] (32 values) from GPU 1
Round 2: GPU 0 sends K[4:8], V[4:8] (received in round 1)
GPU 0 receives K[0:4], V[0:4] (this only returns the chunks to their owners and can be skipped)
Per GPU per round: 32 values sent + 32 values received = 64 values
2 rounds: 128 values moved per GPU (64 sent, 64 received)
For P GPUs: P × (2 × n × d_head) = 2 × (2 × 8 × 4) = 128 values sent in total
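In this tiny example, communication and compute are comparable. In real examples (millions of tokens), compute per round grows as (n/P)² × d while the data sent per round grows only as (n/P) × d, so compute dominates and communication can be hidden behind it.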
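The memory and FLOP counts from Steps 1 to 3 can be reproduced in a few lines. The script uses the same counting convention as the worked example (one "FLOP" per multiply-add, one head); the variable names are illustrative.

```python
n, P, d_head = 8, 2, 4
chunk = n // P                                   # tokens per GPU

# Step 1: values stored for the Q, K and V chunks
mem_per_gpu = 3 * chunk * d_head
mem_single = 3 * n * d_head

# Step 2: multiply-adds per GPU per round (scores, then weights @ V)
flops_per_round = 2 * chunk * chunk * d_head

# Step 3: totals across rounds and GPUs, and the single-GPU comparison
flops_per_gpu = flops_per_round * P
flops_total = flops_per_gpu * P
flops_single = 2 * n * n * d_head

print(mem_per_gpu, mem_single, flops_per_round, flops_per_gpu, flops_total, flops_single)
```

Changing `n` and `P` shows the same pattern at any size: memory per GPU falls by a factor of P, while `flops_total` always equals `flops_single`.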
Part 5: Receptive Field and Long-Range Dependencies
Ring Attention maintains the full receptive field in every layer:
In Ring Attention, each query can attend to any key in the sequence
within a single layer (any earlier key in the causal case). The receptive
field does not have to grow across layers, as it does with the sliding
window attention (SWA) used in Mistral 7B.
Effective context = n (the full sequence)
Attention is exact, so nothing is approximated or dropped as the sequence gets longer
This is the key advantage over Mistral 7B's SWA approach.
Summary: Key Numbers
| Metric | Single GPU | Ring Attention (P GPUs) |
|---|---|---|
| Memory per GPU | O(n × d) | O((n/P) × d) |
| Total compute FLOPs | O(n²d) | O(n²d) distributed |
| Communication per GPU | None | O(n × d) |
| Wall-clock time (ideal) | O(n²d / throughput) | O(n²d / (P × throughput)) |
| Speedup | 1× | ~P× (linear) |
| Effective context | n | n (full) |
| Receptive field | n | n (no degradation) |