Section 04

The Math: Formal Algorithm and Complexity Analysis

Ring Attention with Blockwise Transformers for Near-Infinite Context 2023

This section formalises Ring Attention with precise notation and complexity analysis.

Part 1: Ring Attention Algorithm

Notation

  • n: Total sequence length
  • P: Number of GPUs
  • i: GPU index (0 ≤ i < P)
  • d: Embedding (model) dimension, d = n_heads × d_head
  • d_head: Dimension per attention head
  • n_heads: Number of attention heads
  • Q_i ∈ ℝ^((n/P) × d_head): Query chunk on GPU i
  • K_i ∈ ℝ^((n/P) × d_head): Key chunk on GPU i
  • V_i ∈ ℝ^((n/P) × d_head): Value chunk on GPU i
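
As a concrete picture of the shapes in this notation, the following NumPy sketch splits a toy single-head Q, K and V into P contiguous chunks. The sizes and random inputs are illustrative assumptions, and `np.split` only stands in for whatever sharding a real multi-GPU setup performs.

```python
import numpy as np

n, P, d_head = 8, 2, 4            # toy sizes chosen for illustration
rng = np.random.default_rng(0)
Q = rng.standard_normal((n, d_head))
K = rng.standard_normal((n, d_head))
V = rng.standard_normal((n, d_head))

# Contiguous split along the sequence axis: chunk i is the one held by GPU i.
Q_chunks = np.split(Q, P, axis=0)
K_chunks = np.split(K, P, axis=0)
V_chunks = np.split(V, P, axis=0)

for i in range(P):
    # Each chunk has shape (n / P, d_head).
    print(i, Q_chunks[i].shape, K_chunks[i].shape, V_chunks[i].shape)
```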

Algorithm: Standard (Bidirectional) Attention

RING_ATTENTION(Q, K, V, n_heads, P):

  Input: Query, Key, Value matrices (distributed across P GPUs)
  Output: Attention output (distributed, each GPU stores its Q chunk's output)

  // Executed by every GPU i in parallel, independently for each head
  K_cur, V_cur = K_i, V_i     // KV block currently held by this GPU
  m = -∞, l = 0, o = 0        // Per-query running max, running sum, unnormalised output

  for round = 0 to P-1 do
    // After `round` transfers, this GPU holds the block that started on GPU kv_idx
    kv_idx = (i - round) % P

    // Blockwise attention: Q_i against the held block
    S = Q_i @ K_cur^T / √d_head

    // Online softmax merge (per-block softmax outputs cannot simply be summed)
    m_new = max(m, rowmax(S))
    E = exp(S - m_new)
    l = l × exp(m - m_new) + rowsum(E)
    o = o × exp(m - m_new) + E @ V_cur
    m = m_new

    // Communicate: pass the held KV block around the ring
    // (no transfer is needed after the last round)
    if round < P-1 then
      SEND(K_cur, V_cur) → GPU (i + 1) % P
      RECV(K_cur, V_cur) ← GPU (i - 1) % P

  OUTPUT_i = o / l            // Normalise once, after all P blocks
  return OUTPUT_i
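
The ring schedule can be checked in a single process. The sketch below is a minimal NumPy simulation under stated assumptions: one head, P simulated devices held in Python lists, and a list rotation standing in for the SEND/RECV step. The names `ring_attention` and `full_attention` are hypothetical helpers, not part of any library. Blocks are merged with running max and sum statistics, and the output is normalised once at the end.

```python
import numpy as np

def ring_attention(Q_chunks, K_chunks, V_chunks):
    """Simulate bidirectional Ring Attention for one head in a single process."""
    P = len(Q_chunks)
    d_head = Q_chunks[0].shape[1]
    # Per-device running statistics: row max m, row sum l, unnormalised output o.
    m = [np.full(q.shape[0], -np.inf) for q in Q_chunks]
    l = [np.zeros(q.shape[0]) for q in Q_chunks]
    o = [np.zeros((q.shape[0], V_chunks[0].shape[1])) for q in Q_chunks]
    # held[i] is the (K, V) block currently resident on device i.
    held = list(zip(K_chunks, V_chunks))
    for _ in range(P):
        for i in range(P):
            K_blk, V_blk = held[i]
            s = Q_chunks[i] @ K_blk.T / np.sqrt(d_head)
            m_new = np.maximum(m[i], s.max(axis=1))
            e = np.exp(s - m_new[:, None])
            scale = np.exp(m[i] - m_new)      # rescales the old accumulators
            l[i] = l[i] * scale + e.sum(axis=1)
            o[i] = o[i] * scale[:, None] + e @ V_blk
            m[i] = m_new
        # Ring step: device i receives the block held by device (i - 1) % P.
        held = [held[(i - 1) % P] for i in range(P)]
    return [o[i] / l[i][:, None] for i in range(P)]

def full_attention(Q, K, V):
    """Reference: ordinary softmax attention over the whole sequence."""
    s = Q @ K.T / np.sqrt(Q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
n, P, d_head = 8, 2, 4
Q = rng.standard_normal((n, d_head))
K = rng.standard_normal((n, d_head))
V = rng.standard_normal((n, d_head))
out = np.concatenate(ring_attention(np.split(Q, P), np.split(K, P), np.split(V, P)))
matches = np.allclose(out, full_attention(Q, K, V))
print(matches)
```

In a real deployment the inner loop over devices runs concurrently, and the rotation is a point-to-point transfer that can be overlapped with the next block's computation.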

Algorithm: Causal (Autoregressive) Attention

In autoregressive generation (language modelling), query position t cannot attend to key position s if s > t.

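RING_ATTENTION_CAUSAL(Q, K, V, n_heads, P):

  // Executed by every GPU i in parallel; chunk i covers positions [i × n/P, (i+1) × n/P)
  K_cur, V_cur = K_i, V_i
  m = -∞, l = 0, o = 0

  for round = 0 to P-1 do
    kv_idx = (i - round) % P

    // Key difference: apply the causal mask
    // Blocks with kv_idx > i lie entirely in the future; skip them, because a
    // fully masked block would make the softmax undefined (all scores -∞)
    if kv_idx ≤ i then
      // Compute scores
      scores = Q_i @ K_cur^T / √d_head

      // Mask future tokens using global positions (only the kv_idx = i block is affected)
      causal_mask = (position_q[:, None] >= position_kv[None, :])
      scores[~causal_mask] = -∞

      // Blockwise attention with online softmax
      m_new = max(m, rowmax(scores))
      E = exp(scores - m_new)
      l = l × exp(m - m_new) + rowsum(E)
      o = o × exp(m - m_new) + E @ V_cur
      m = m_new

    // Communication: pass the held KV block around the ring
    if round < P-1 then
      SEND(K_cur, V_cur) → GPU (i + 1) % P
      RECV(K_cur, V_cur) ← GPU (i - 1) % P

  // Note: GPU i does useful work in only i+1 of the P rounds, so load is unbalanced
  OUTPUT_i = o / l
  return OUTPUT_i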
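
A minimal NumPy sketch of the causal variant follows, again for one head in a single process. It is an illustration under stated assumptions: the outer loop plays the role of each GPU in turn, and indexing `K_chunks[kv_idx]` stands in for the ring transfer. Blocks that lie wholly in the future are skipped rather than masked, since a fully masked block would give every score the value -∞. The function names are hypothetical.

```python
import numpy as np

def ring_attention_causal(Q_chunks, K_chunks, V_chunks):
    """Simulate causal Ring Attention for one head; returns one output chunk per GPU."""
    P = len(Q_chunks)
    c, d_head = Q_chunks[0].shape            # c = chunk length n / P
    outputs = []
    for i in range(P):                       # each iteration plays the role of GPU i
        m = np.full(c, -np.inf)              # running row max
        l = np.zeros(c)                      # running row sum
        o = np.zeros((c, V_chunks[0].shape[1]))
        pos_q = i * c + np.arange(c)         # global query positions on GPU i
        for rnd in range(P):
            kv_idx = (i - rnd) % P
            if kv_idx > i:
                continue                     # block lies entirely in the future
            s = Q_chunks[i] @ K_chunks[kv_idx].T / np.sqrt(d_head)
            pos_kv = kv_idx * c + np.arange(c)
            # Only the diagonal block (kv_idx == i) actually loses entries here.
            s = np.where(pos_q[:, None] >= pos_kv[None, :], s, -np.inf)
            m_new = np.maximum(m, s.max(axis=1))
            e = np.exp(s - m_new[:, None])
            scale = np.exp(m - m_new)
            l = l * scale + e.sum(axis=1)
            o = o * scale[:, None] + e @ V_chunks[kv_idx]
            m = m_new
        outputs.append(o / l[:, None])
    return outputs

def full_causal_attention(Q, K, V):
    """Reference: causal softmax attention over the whole sequence."""
    n = Q.shape[0]
    s = Q @ K.T / np.sqrt(Q.shape[1])
    s = np.where(np.tril(np.ones((n, n), dtype=bool)), s, -np.inf)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(1)
n, P, d_head = 8, 4, 4
Q = rng.standard_normal((n, d_head))
K = rng.standard_normal((n, d_head))
V = rng.standard_normal((n, d_head))
out = np.concatenate(ring_attention_causal(np.split(Q, P), np.split(K, P), np.split(V, P)))
causal_matches = np.allclose(out, full_causal_attention(Q, K, V))
print(causal_matches)
```

Because GPU i processes only i + 1 of the P blocks, later chunks do more work than earlier ones; this contiguous layout is the simplest one, not a balanced one.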

Part 2: Online Softmax (Numerical Stability)

Standard softmax applied to each block independently:

softmax_block_b = exp(scores_b) / sum(exp(scores_b))

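Problem: Each block is normalised by its own sum rather than by the sum over all keys, so the softmax of block 1 and the softmax of block 2 cannot simply be added or concatenated. Done naively, you would have to keep every score and recompute the full softmax across all blocks.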
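
A small standard-library example makes the mismatch visible. The two score blocks are made-up numbers for a single query; `softmax` is a local helper written for this sketch.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores_block_1 = [1.0, 2.0]   # illustrative scores of one query against block 1
scores_block_2 = [3.0, 4.0]   # and against block 2

# Softmax per block, then concatenated: each block sums to 1 on its own.
per_block = softmax(scores_block_1) + softmax(scores_block_2)
# Softmax over all four scores at once: the whole row sums to 1.
full = softmax(scores_block_1 + scores_block_2)

print(sum(per_block), sum(full))
print(per_block)
print(full)
```

The per-block weights sum to about 2 instead of 1, and block 1 receives far more weight than it should, because its normaliser never saw the larger scores in block 2.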

Solution: Online (Incremental) Softmax

Maintain running statistics:

  • m (maximum): The largest score seen so far, per query row (initially -∞)
  • l (normalisation): The running sum of exp(score - m) over the keys seen so far (initially 0)
  • o (output): The accumulated attention output, normalised by l (initially 0)

For each block b:

// Step 1: Compute block-local statistics (per query row)
S_b = Q @ K_b^T / √d_head
m_b = max(S_b, axis=1)
exp_scores_b = exp(S_b - m_b)  // Stabilised
l_b = sum(exp_scores_b, axis=1)  // Per-query sum

// Step 2: Update running max and normalisation
m_old = m  // Previous global max
m_new = max(m_old, m_b)  // New global max
l_new = l_old × exp(m_old - m_new) + l_b × exp(m_b - m_new)
// Rescale both old and new contributions relative to the new max

// Step 3: Update output (o_old is already normalised, so multiply it back by l_old first)
o_new = (l_old × exp(m_old - m_new) × o_old + exp(m_b - m_new) × exp_scores_b @ V_b) / l_new
// Equivalent: keep o unnormalised (drop l_old and the division) and divide by l once at the end

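After P blocks, o_new equals, up to floating-point rounding, the result of computing:

softmax(concat([scores_1, ..., scores_P])) @ concat([V_1, ..., V_P])

The difference is that it is computed incrementally and stably, one block at a time, without ever holding all n scores for a query. Because the merge step is associative, the blocks can also arrive in any order.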
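
The equivalence can be checked numerically. The NumPy sketch below implements the three steps for a batch of queries, keeping o normalised after every block, and compares the result with a one-shot softmax. The function name, block sizes and the deliberately large score scale are assumptions made for the test, not part of the original algorithm description.

```python
import numpy as np

def online_softmax_attention(Q, K_blocks, V_blocks):
    """Attention over a list of KV blocks using running (m, l, o) statistics."""
    d_head = Q.shape[1]
    m = np.full(Q.shape[0], -np.inf)                   # running row max
    l = np.zeros(Q.shape[0])                           # running normaliser
    o = np.zeros((Q.shape[0], V_blocks[0].shape[1]))   # normalised output so far
    for K_b, V_b in zip(K_blocks, V_blocks):
        # Step 1: block-local statistics
        s = Q @ K_b.T / np.sqrt(d_head)
        m_b = s.max(axis=1)
        exp_scores_b = np.exp(s - m_b[:, None])
        l_b = exp_scores_b.sum(axis=1)
        # Step 2: new max and normaliser
        m_new = np.maximum(m, m_b)
        alpha = np.exp(m - m_new)      # rescale factor for the old state
        beta = np.exp(m_b - m_new)     # rescale factor for the new block
        l_new = l * alpha + l_b * beta
        # Step 3: un-normalise the old output, add the new block, re-normalise
        o = ((l * alpha)[:, None] * o + beta[:, None] * (exp_scores_b @ V_b)) / l_new[:, None]
        m, l = m_new, l_new
    return o

rng = np.random.default_rng(2)
d_head = 4
Q = 20.0 * rng.standard_normal((3, d_head))            # large scores stress stability
K_blocks = [rng.standard_normal((5, d_head)) for _ in range(4)]
V_blocks = [rng.standard_normal((5, d_head)) for _ in range(4)]

s = Q @ np.concatenate(K_blocks).T / np.sqrt(d_head)
p = np.exp(s - s.max(axis=1, keepdims=True))
reference = (p / p.sum(axis=1, keepdims=True)) @ np.concatenate(V_blocks)

incremental = online_softmax_attention(Q, K_blocks, V_blocks)
print(np.allclose(incremental, reference))
```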

Part 3: Complexity Analysis

Memory Complexity Per GPU

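Standard single-GPU full attention (assuming a blockwise kernel that never materialises the n × n score matrix; a naive implementation needs a further O(n²) per head):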

Memory = KV_cache + Q cache + intermediate activations
       = 2 × n × d + n × d + O(n × d)
       = O(n × d)

Ring Attention with P GPUs (per GPU):

Memory per GPU = KV_cache for chunk + Q chunk + activations
               = 2 × (n/P) × d + (n/P) × d + O((n/P) × d)
               = O((n/P) × d)

Memory reduction factor: P

For 1M tokens, P=8 GPUs, d=4096, float16 (2 bytes per element), per layer:

Single GPU: 1M × 4096 × 2 bytes ≈ 8 GB (key) + 8 GB (value) = 16 GB
Ring Attention: (1M/8) × 4096 × 2 bytes = 125K × 8192 bytes ≈ 1 GB (key) + 1 GB (value) = 2 GB per GPU

Saving: 8× per GPU
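
The same arithmetic can be parameterised by element width, which changes the absolute figures but not the ratio. This is a back-of-the-envelope sketch: `kv_bytes` is a hypothetical helper, it counts only K and V for one layer, and it assumes the sequence divides evenly across GPUs.

```python
def kv_bytes(n_tokens, d, bytes_per_element, num_gpus=1):
    """Bytes for K plus V of one layer, per GPU, with the sequence split evenly."""
    tokens_per_gpu = n_tokens // num_gpus
    return 2 * tokens_per_gpu * d * bytes_per_element

n, d = 1_000_000, 4096
for bytes_per_element in (1, 2):          # 8-bit and 16-bit storage
    single = kv_bytes(n, d, bytes_per_element)
    ring = kv_bytes(n, d, bytes_per_element, num_gpus=8)
    # Print sizes in GB (10^9 bytes) and the reduction factor.
    print(bytes_per_element, single / 1e9, ring / 1e9, single / ring)
```

Whatever the element width, the per-GPU footprint falls by the factor P = 8.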

Compute Complexity

Standard single-GPU attention:

Attention compute per layer:
  Q @ K^T: O(n²d) operations
  softmax @ V: O(n²d) operations
  Total: O(n²d)

For n=1M, d=4096: 1M² × 4096 ≈ 4.1 × 10^15 multiply-adds for each of the two matrix products

Ring Attention with P GPUs:

Each GPU computes (n/P) × n attention scores, each costing O(d) operations
Per GPU: O((n/P) × n × d) = O(n²d / P)

But only for its own queries. Across all P GPUs:
Total compute: P × O(n²d / P) = O(n²d)

No reduction in total computation — it’s just distributed across P devices.
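
A short count confirms that the work is redistributed rather than reduced. The sketch counts multiply-adds for the two matrix products only and ignores the softmax itself; `attention_macs` is a hypothetical helper.

```python
def attention_macs(n_queries, n_keys, d):
    """Multiply-adds for Q @ K^T plus (softmax weights) @ V."""
    return 2 * n_queries * n_keys * d

n, d, P = 1_000_000, 4096, 8
single_gpu = attention_macs(n, n, d)        # all queries against all keys
per_gpu = attention_macs(n // P, n, d)      # one query chunk against all keys
print(single_gpu, per_gpu, per_gpu * P == single_gpu)
```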

Communication Complexity

Ring Attention:

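In each of P rounds, each GPU sends its current K and V blocks, 2 × (n/P) × d elements,
  to the next GPU and receives the same amount from the previous GPU.

Per round: O((n/P) × d) data per GPU
P rounds: O(n × d) total data per GPU

In practice, network bandwidth determines the transfer time:
  Time = (2 × n × d × bytes_per_element) / (network_bandwidth)
  
For n=1M, d=4096, bytes=2 (float16), bandwidth=200 GB/s:
  Time = (2 × 1M × 4096 × 2 bytes) / (200 GB/s)
       ≈ 16.4 GB / 200 GB/s
       ≈ 0.08 seconds in total, or ≈ 0.01 seconds per round

Compute time (rough estimate at peak throughput, counting two matrix products and 2 FLOPs per multiply-add):

4 × (n/P) × n × d FLOPs per GPU / (GPU_throughput)
= 4 × 125K × 1M × 4096 / (312 TFLOPS)  [A100 dense FP16 peak]
≈ 2.05 × 10^15 / (3.12 × 10^14)
≈ 6.6 seconds in total, or ≈ 0.8 seconds per round

Communication (≈ 0.01s per round) << Compute (≈ 0.8s per round), so each transfer can be hidden behind the computation on the current block.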
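
Whether a transfer can be hidden is decided round by round, so it helps to compare the two costs for a single block. The sketch below is a rough estimate under stated assumptions: both K and V are sent, each multiply-add counts as 2 FLOPs, and the hardware runs at the peak bandwidth and throughput figures quoted in the text. `per_round_times` is a hypothetical helper.

```python
def per_round_times(n, d, P, bytes_per_element, bandwidth_bytes_s, flops_s):
    """Return (communication seconds, compute seconds) for one ring round on one GPU."""
    c = n // P                                    # tokens per chunk
    comm_bytes = 2 * c * d * bytes_per_element    # one K block plus one V block
    flops = 4 * c * c * d                         # 2 matmuls, 2 FLOPs per multiply-add
    return comm_bytes / bandwidth_bytes_s, flops / flops_s

comm_s, compute_s = per_round_times(
    n=1_000_000, d=4096, P=8,
    bytes_per_element=2,          # float16
    bandwidth_bytes_s=200e9,      # 200 GB/s, as in the text
    flops_s=312e12,               # 312 TFLOPS, as in the text
)
print(comm_s, compute_s, compute_s / comm_s)
```

With these inputs the per-round compute time is roughly 80 times the per-round transfer time, which leaves ample room to overlap them. Real kernels run below peak, so the margin in practice is smaller than this estimate suggests.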

Wall-Clock Time

If compute and communication overlap perfectly:

Wall time ≈ max(compute_time, communication_time)
          ≈ compute_time
          ≈ O(n²d / (P × GPU_throughput))

Compared to single GPU:

Single GPU wall time ≈ O(n²d / GPU_throughput)
Ring Attention wall time ≈ O(n²d / (P × GPU_throughput))

Speedup: ~P (linear in number of GPUs, if well-balanced)
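
The idealised model above can be written as a few lines of Python to show when the linear speedup holds and when it does not. The per-round times passed in are illustrative numbers, and both function names are hypothetical.

```python
def ring_wall_time(compute_per_round_s, comm_per_round_s, P, overlap=True):
    """Idealised wall time for the P rounds each GPU executes."""
    if overlap:
        per_round = max(compute_per_round_s, comm_per_round_s)
    else:
        per_round = compute_per_round_s + comm_per_round_s
    return P * per_round

def speedup_over_single_gpu(compute_per_round_s, comm_per_round_s, P, overlap=True):
    """Speedup relative to one GPU processing all P * P chunk pairs itself."""
    single_gpu_time = P * P * compute_per_round_s
    return single_gpu_time / ring_wall_time(compute_per_round_s, comm_per_round_s, P, overlap)

compute_bound = speedup_over_single_gpu(0.8, 0.01, P=8)   # transfer hidden behind compute
comm_bound = speedup_over_single_gpu(0.8, 1.6, P=8)       # transfer slower than compute
print(compute_bound, comm_bound)
```

When compute dominates each round the speedup is P; once a transfer takes longer than the block computation it replaces compute as the per-round cost and the speedup drops below P.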

Part 4: Worked Example - Complexity Calculation

Problem: Compute Ring Attention complexity for:

  • n = 8 tokens (total sequence)
  • P = 2 GPUs
  • d_head = 4 (dimensions per head)
  • 1 attention head (simplified)

Step 1: Memory Per GPU

GPU 0 holds: Q[0:4], K[0:4], V[0:4]
GPU 1 holds: Q[4:8], K[4:8], V[4:8]

Memory per GPU = 2 × (8/2) × 4 + (8/2) × 4
               = 2 × 4 × 4 + 4 × 4
               = 32 + 16
               = 48 values per GPU

Single GPU (comparison): 2 × 8 × 4 + 8 × 4 = 96 values
Reduction: 96 / 48 = 2× (equals P)

Step 2: Attention FLOPs Per GPU Per Round

Each GPU computes:

Q_chunk @ K_chunk^T = (4 × 4) @ (4 × 4)^T
                    = 4 × 4 × 4 = 64 FLOPs (matrix multiply, counting one multiply-add as one FLOP)
softmax @ V = (4 × 4) @ (4 × 4) = 64 FLOPs
Total per round = 128 FLOPs

Step 3: Total FLOPs Across Rounds

P rounds: 2 rounds
FLOPs per GPU = 128 × 2 = 256 FLOPs
Total FLOPs (all GPUs) = 256 × 2 = 512 FLOPs

Single GPU (comparison):
Q @ K^T = 8 × 8 × 4 = 256 FLOPs
softmax @ V = 8 × 8 × 4 = 256 FLOPs
Total = 512 FLOPs

No difference in total FLOPs — computation is just distributed.

Step 4: Communication

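Each K or V chunk is 4 tokens × 4 dimensions = 16 values, so one KV block is 32 values.

Round 1: GPU 0 sends K[0:4], V[0:4] (32 values) to GPU 1
         GPU 0 receives K[4:8], V[4:8] (32 values) from GPU 1

Round 2: GPU 0 sends K[4:8], V[4:8] (received in round 1)
         GPU 0 receives K[0:4], V[0:4]
         (This last transfer only returns each block to its owner; the attention result needs just P - 1 = 1 transfer.)

Per GPU per round: 32 values sent + 32 values received = 64 values
2 rounds: 128 values total per GPU

For P GPUs: P × 2 × (n × d) = 2 × 2 × (8 × 4) = 128 values sent in total

In this tiny example, communication and compute are comparable (32 values sent versus 128 FLOPs per round). Per-round compute grows as (n/P)² × d while per-round communication grows only as (n/P) × d, so with realistic sequences (millions of tokens) compute dominates and communication is hidden.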
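
The counts in this worked example can be reproduced mechanically. The sketch below follows the same conventions as the steps above (one multiply-add counted as one FLOP, P rounds per GPU) and counts every element of the K and V blocks, each of which holds (n/P) × d_head values. Variable names are chosen for this illustration.

```python
n, P, d_head = 8, 2, 4
c = n // P                                   # tokens per chunk

memory_per_gpu = 3 * c * d_head              # Q, K and V chunks
single_gpu_memory = 3 * n * d_head

flops_per_round = 2 * c * c * d_head         # Q @ K^T plus weights @ V
flops_per_gpu = flops_per_round * P          # P rounds
flops_all_gpus = flops_per_gpu * P
single_gpu_flops = 2 * n * n * d_head

sent_per_round = 2 * c * d_head              # one K block plus one V block

print(memory_per_gpu, single_gpu_memory)
print(flops_per_round, flops_per_gpu, flops_all_gpus, single_gpu_flops)
print(sent_per_round)
```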

Part 5: Receptive Field and Long-Range Dependencies

Ring Attention maintains full receptive field without depth trade-off:

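In Ring Attention, each query can attend to any key in the sequence within
a single layer (any earlier key, under causal masking). The receptive field
does not need to grow across layers, as it does with the sliding-window
attention (SWA) used in Mistral 7B.

Effective context = n (the full sequence)
The attention is exact at every sequence length; nothing is approximated or dropped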

This is the key advantage over Mistral 7B’s SWA approach.

Summary: Key Numbers

Metric | Single GPU | Ring Attention (P GPUs)
Memory per GPU | O(n × d) | O((n/P) × d)
Total compute FLOPs | O(n²d) | O(n²d) distributed
Communication per GPU | None | O(n × d)
Wall-clock time (ideal) | O(n²d / throughput) | O(n²d / (P × throughput))
Speedup | - | ~P× (linear)
Effective context | n | n (full)
Receptive field | n | n (no degradation)