Section 02

The Problem: Distributed Attention is Hard

Ring Attention with Blockwise Transformers for Near-Infinite Context (2023)


Standard Transformer attention assumes all tokens can be accessed from a single compute device. Distributing attention across GPUs introduces challenges that previous approaches didn’t solve cleanly.

Problem 1: Data Localisation

When you split a sequence across P GPUs:

GPU 0: queries [0:n/P], keys [0:n/P], values [0:n/P]
GPU 1: queries [n/P:2n/P], keys [n/P:2n/P], values [n/P:2n/P]
...
GPU P-1: queries [(P-1)n/P:n], keys [(P-1)n/P:n], values [(P-1)n/P:n]

But to compute exact (non-causal) attention, every query must attend to every key-value pair in the sequence.

  • Queries on GPU 0 need keys and values from GPUs 1 to P-1
  • Queries on GPU 1 need keys and values from GPU 0 and GPUs 2 to P-1
  • And so on…
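To make the data dependency concrete, here is a minimal sketch of the even split above. The helpers `shard_bounds` and `remote_kv_ranges` are hypothetical names, and the sketch assumes n is divisible by P and non-causal attention.

```python
def shard_bounds(n: int, P: int, rank: int) -> tuple[int, int]:
    """Half-open token range [start, end) owned by `rank` when n tokens are split evenly over P GPUs."""
    if n % P != 0:
        raise ValueError("this sketch assumes n is divisible by P")
    chunk = n // P
    return rank * chunk, (rank + 1) * chunk


def remote_kv_ranges(n: int, P: int, rank: int) -> list[tuple[int, int]]:
    """KV ranges that live on *other* GPUs but that `rank`'s queries still need (non-causal case)."""
    return [shard_bounds(n, P, r) for r in range(P) if r != rank]


n, P = 1_000_000, 4
print(shard_bounds(n, P, 0))      # (0, 250000)
print(remote_kv_ranges(n, P, 0))  # [(250000, 500000), (500000, 750000), (750000, 1000000)]
```

Every GPU needs P-1 remote ranges, which is exactly the data that has to move somehow.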

Naive solution: Broadcast all keys and values to all GPUs.

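  • Memory cost: each GPU stores the keys and values for all n tokens, so per-GPU memory no longer shrinks with P (defeats the purpose)
  • Communication cost: every GPU must receive every other GPU's KV data, O(n × P) transfers in total

Each GPU ends up holding as much KV data as a single-GPU run, and you pay the communication cost on top!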
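A quick sketch of the per-GPU memory numbers, assuming one attention layer, d = 4096, and keys and values stored at 2 bytes per element (bf16). The function name `kv_bytes_per_gpu` is hypothetical.

```python
BYTES_PER_ELEM = 2  # assumption: bf16/fp16 storage for K and V


def kv_bytes_per_gpu(n: int, d: int, P: int, broadcast: bool) -> int:
    """Bytes of K and V resident on one GPU for one attention layer."""
    tokens = n if broadcast else n // P   # naive broadcast keeps every token's K/V
    return 2 * tokens * d * BYTES_PER_ELEM  # factor 2: one K row and one V row per token


n, d, P = 1_000_000, 4096, 8
naive = kv_bytes_per_gpu(n, d, P, broadcast=True)
sharded = kv_bytes_per_gpu(n, d, P, broadcast=False)
print(f"naive: {naive / 1e9:.1f} GB, sharded: {sharded / 1e9:.2f} GB")  # naive: 16.4 GB, sharded: 2.05 GB
```

The broadcast version holds P times more KV data per GPU than the sharded one, the same amount as a single device would.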

Problem 2: Communication Overhead

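If you broadcast KV data naively, every layer must wait for a large transfer before attention can start:

All-to-all broadcast of KV to P GPUs (one attention layer, one shared link):
  Time ≈ (n × d × P × bytes_per_element) / network_bandwidth
  
For n=1M tokens, d=4096, P=8, 2 bytes per element (bf16), network ~200 GB/s:
  Time ≈ (1M × 4096 × 8 × 2 bytes) / (200 GB/s)
       ≈ 65.5 GB / (200 GB/s)
       ≈ 0.33 seconds (for K alone; sending V as well doubles this)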
       
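Attention computation time: hardware-dependent, and it cannot start until this transfer has finished

Without overlap, communication is a serial cost added to every layer: you're waiting for data to arrive instead of computing. On slower links, or with smaller per-GPU chunks, it becomes the outright bottleneck.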
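The estimate above is easy to reproduce, and it helps to put a compute estimate next to it. This sketch assumes a hypothetical sustained throughput of 300 TFLOP/s per GPU (`ASSUMED_FLOPS`); real numbers vary widely by hardware and kernel.

```python
def broadcast_seconds(n, d, P, bytes_per_elem, bandwidth_bytes_per_s):
    """Back-of-envelope from the text: n * d * P elements through one shared link."""
    return n * d * P * bytes_per_elem / bandwidth_bytes_per_s


def attention_flops_per_gpu(n, d, P):
    """Q @ K^T and weights @ V for this GPU's n/P queries against all n keys:
    two matmuls, each 2 * (n/P) * n * d FLOPs (multiply + add)."""
    return 4 * (n // P) * n * d


n, d, P = 1_000_000, 4096, 8
t_comm = broadcast_seconds(n, d, P, bytes_per_elem=2, bandwidth_bytes_per_s=200e9)
ASSUMED_FLOPS = 300e12  # hypothetical sustained bf16 throughput per GPU
t_compute = attention_flops_per_gpu(n, d, P) / ASSUMED_FLOPS
print(f"transfer ≈ {t_comm:.2f} s, compute ≈ {t_compute:.1f} s")  # transfer ≈ 0.33 s, compute ≈ 6.8 s
```

With these assumed numbers, attention compute grows quadratically while the transfer grows linearly, so the transfer is a good candidate for hiding behind compute rather than paying for it serially.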

Problem 3: Balancing Computation and Communication

Even if you optimise communication patterns, there’s a fundamental trade-off:

  • Less communication: Only send data you need → limited KV visibility → sparse/limited attention
  • More communication: Send all KV data → full attention → communication dominates wall time

Earlier distributed attention schemes generally had to pick one side of this trade-off.

Problem 4: Causal Masking (for Autoregressive Models)

In autoregressive language models, token t may attend only to tokens at positions ≤ t, never to later (future) tokens.

In a distributed setting with tokens split across GPUs:

GPU 0: tokens [0:250K]
GPU 1: tokens [250K:500K]
GPU 2: tokens [500K:750K]
GPU 3: tokens [750K:1M]

When GPU 0 computes attention, each of its tokens (0-250K) may attend only to earlier tokens within 0-250K,
and never to tokens 250K-1M (the future).

But key/value blocks circulate around the ring, so how do you enforce this
when the KV block from GPU 1 (tokens 250K-500K) arrives at GPU 0?

Causal masking adds implementation complexity: each arriving KV block must carry its global token positions, so the receiving GPU can decide whether to use it in full, apply a triangular mask, or skip it entirely.
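With contiguous, equal-sized shards in rank order (the layout above), that decision reduces to comparing ranks. A minimal sketch; `BlockMask` and `classify` are hypothetical names.

```python
from enum import Enum


class BlockMask(Enum):
    FULL = "full"      # every key precedes every query: no mask needed
    CAUSAL = "causal"  # same block: apply the lower-triangular mask
    SKIP = "skip"      # every key is in the future: contributes nothing


def classify(query_rank: int, kv_rank: int) -> BlockMask:
    """Mask needed when GPU `query_rank`'s queries meet the KV block that originated on `kv_rank`.
    Assumes contiguous, equal-sized shards in rank order."""
    if kv_rank < query_rank:
        return BlockMask.FULL
    if kv_rank == query_rank:
        return BlockMask.CAUSAL
    return BlockMask.SKIP


print([classify(1, r).value for r in range(4)])  # ['full', 'causal', 'skip', 'skip']
```

The KV block only needs to travel with its origin rank (or start position) for the receiver to make this call.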

Problem 5: Load Imbalance

In a ring topology, if one GPU is slower, it becomes the bottleneck for the entire ring.

GPU 0: Fast (RTX 4090)
GPU 1: Fast (RTX 4090)
GPU 2: Slow (older GPU)
GPU 3: Fast (RTX 4090)

GPU 0 finishes each step, then waits for GPU 2 before the next KV hand-off can happen.
Effective speed = speed of slowest GPU.

This is the “straggler problem” in distributed systems. Unless all GPUs are identical, one slow device can drag down the entire computation.
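A simple lockstep model shows the effect. It assumes every GPU must finish a step before KV blocks move on, and the per-step times are illustrative numbers only; `ring_wall_time` is a hypothetical helper.

```python
def ring_wall_time(step_times: list[float]) -> float:
    """Wall time of one ring pass under a lockstep model.
    step_times[i] = seconds GPU i needs per step (illustrative numbers)."""
    P = len(step_times)
    # Each of the P steps ends with a hand-off, so each step lasts as long as the slowest GPU.
    return P * max(step_times)


print(ring_wall_time([1.0, 1.0, 2.5, 1.0]))  # 10.0 (GPU 2 is the straggler)
print(ring_wall_time([1.0, 1.0, 1.0, 1.0]))  # 4.0 (all GPUs equally fast)
```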

Problem 6: Synchronisation Overhead

Ring Attention requires careful synchronisation between GPUs:

  1. GPU i computes attention between its local queries and the KV chunk it currently holds
  2. GPU i sends that KV chunk to GPU (i+1) mod P and receives the next one from GPU (i-1) mod P
  3. GPU i+1 can proceed to the next round only once that KV chunk has arrived

If synchronisation is not managed carefully (for example, if sends only start after compute has finished), GPUs sit idle waiting for each other.
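The rotation in steps 1-3 can be sketched as a schedule: after s steps, GPU i holds the KV block that started on GPU (i - s) mod P. `ring_schedule` is a hypothetical helper that only tracks block indices, not real transfers.

```python
def ring_schedule(P: int) -> list[list[int]]:
    """schedule[s][i] = index of the KV block GPU i holds during step s.
    Each step, GPU i passes its block to (i + 1) % P and receives from (i - 1) % P."""
    return [[(i - s) % P for i in range(P)] for s in range(P)]


for step, held in enumerate(ring_schedule(4)):
    print(step, held)
# 0 [0, 1, 2, 3]
# 1 [3, 0, 1, 2]
# 2 [2, 3, 0, 1]
# 3 [1, 2, 3, 0]
```

After P steps every GPU has seen every KV block exactly once, and no GPU ever held more than one remote block at a time.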

Problem 7: Numerical Stability

Standard attention is:

output = softmax(Q @ K^T / √d_head) @ V

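A naive softmax overflows when scores are large, because exp(score) exceeds the floating-point range. The standard fix subtracts the row maximum before exponentiating, which is easy in standard attention: you compute all scores at once, then apply softmax.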
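A small illustration of the overflow and the max-subtraction fix, in plain Python for a single row of scores (the scores are made-up values):

```python
import math


def softmax_naive(scores):
    exps = [math.exp(s) for s in scores]  # math.exp raises OverflowError above ~709
    total = sum(exps)
    return [e / total for e in exps]


def softmax_stable(scores):
    m = max(scores)  # requires every score in the row up front
    exps = [math.exp(s - m) for s in scores]  # largest term is exp(0) = 1, so nothing overflows
    total = sum(exps)
    return [e / total for e in exps]


scores = [1000.0, 999.0, 0.0]
try:
    softmax_naive(scores)
except OverflowError:
    print("naive softmax overflowed")
print(softmax_stable(scores))  # about [0.731, 0.269, 0.0]
```

The catch for blockwise attention is the `max(scores)` line: it assumes the whole row is visible at once.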

With blockwise attention (each block in isolation), you need to be careful:

Block 1: scores₁, softmax(scores₁)
Block 2: scores₂, softmax(scores₂)
...

But the correct softmax should be over all scores combined:
softmax([scores₁, scores₂, ..., scoresₚ])

If you apply softmax to each block independently, each block's weights sum to 1 on their own, so across P blocks they sum to P and the relative weighting between blocks is lost. This breaks attention semantics.
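A minimal check of that claim for one query whose scores are split into three blocks (made-up scores):

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


blocks = [[2.0, 1.0], [0.5, 3.0], [1.5, -1.0]]  # one query's scores split into P = 3 KV blocks
per_block = [w for b in blocks for w in softmax(b)]
combined = softmax([s for b in blocks for s in b])
print(round(sum(per_block), 6), round(sum(combined), 6))  # 3.0 1.0
```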

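Solution: use online softmax (also called incremental softmax), which builds on the log-sum-exp trick. For each query, keep a running maximum and a running normaliser, and rescale the partial output whenever a new block raises the maximum. The result matches a full softmax exactly, at the cost of extra bookkeeping.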
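Here is a minimal single-query sketch of online softmax using plain Python lists instead of tensors; `attend_full` and `attend_online` are hypothetical names, and the blocks stand in for KV chunks arriving around the ring.

```python
import math


def attend_full(scores, values):
    """Reference: one softmax over all scores, then a weighted sum of the value vectors."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    dim = len(values[0])
    return [sum(wi * v[k] for wi, v in zip(w, values)) / z for k in range(dim)]


def attend_online(score_blocks, value_blocks):
    """Same result, visiting one KV block at a time.

    Keeps a running max `m`, running normaliser `norm` and unnormalised output `acc`;
    whenever a block raises the max, the old state is rescaled by exp(m_old - m_new).
    """
    dim = len(value_blocks[0][0])
    m, norm, acc = -math.inf, 0.0, [0.0] * dim
    for scores, values in zip(score_blocks, value_blocks):
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first block
        w = [math.exp(s - m_new) for s in scores]
        acc = [a * scale + sum(wi * v[k] for wi, v in zip(w, values))
               for k, a in enumerate(acc)]
        norm = norm * scale + sum(w)
        m = m_new
    return [a / norm for a in acc]


scores = [2.0, 1.0, 0.5, 3.0, 1.5, -1.0]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0], [0.0, 2.0], [1.0, -1.0]]
full = attend_full(scores, values)
online = attend_online([scores[0:2], scores[2:4], scores[4:6]],
                       [values[0:2], values[2:4], values[4:6]])
print(all(abs(a - b) < 1e-12 for a, b in zip(full, online)))  # True
```

Only three small pieces of state per query (max, normaliser, partial output) survive between blocks, which is what lets each GPU discard a KV block once it has used it.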

Summary: Why Previous Approaches Failed

Approach                | Memory       | Communication      | Full Attention? | Causal OK?
Single GPU              | O(n)         | None               | ✓               | ✓
Naive broadcast         | O(n) per GPU | O(n × P)           | ✓               | ✓ (complex)
Sparse attention        | O(n) per GPU | Varies             | ✗               |
Sliding window          | O(W)         | O(n × W)           | ✗               |
Sequence parallelism v1 | O(n/P)       | O(n × P)           |                 | ✗ (hard)
Ring Attention          | O(n/P)       | O(n × d) pipelined | ✓               | ✓

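Ring Attention addresses most of these (a slow GPU still slows the whole ring): it keeps memory at O(n/P) per GPU, moves O(n × d) KV data per GPU in n/P-sized blocks that are pipelined around the ring, maintains full attention expressiveness, and makes causal masking tractable.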

The key insight that makes it work: pipelined communication that hides latency through compute-communication overlap.
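An idealised timing model makes the benefit of overlap visible. It assumes every ring step takes the same compute and transfer time, and that overlap is perfect; `ring_pass_seconds` and the numbers below are hypothetical.

```python
def ring_pass_seconds(P: int, compute_s: float, comm_s: float, overlap: bool) -> float:
    """Idealised time for one full ring pass: P compute steps and P - 1 block transfers
    (the last block held is never forwarded). Assumes every step takes the same time."""
    if overlap:
        # Send the next block while computing on the current one: each of the first
        # P - 1 steps costs whichever is longer, then one final compute-only step.
        return (P - 1) * max(compute_s, comm_s) + compute_s
    # No overlap: every transfer is added on top of compute.
    return P * compute_s + (P - 1) * comm_s


print(ring_pass_seconds(8, 1.0, 0.5, overlap=False))  # 11.5
print(ring_pass_seconds(8, 1.0, 0.5, overlap=True))   # 8.0: transfer fully hidden
print(ring_pass_seconds(8, 1.0, 1.5, overlap=True))   # 11.5: transfer exposed again
```

Overlap only hides a transfer that takes no longer than the compute on one block, so per-GPU blocks have to be large enough for compute to keep up with the link.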