The Idea: Making SSMs Input-Dependent
The Core Innovation
Instead of this (fixed SSM):
x_t = A x_{t-1} + B u_t (A, B fixed for all t)
y_t = C x_t
Mamba does this (selective SSM):
Δ_t = softplus(W_Δ u_t) (step size depends on input)
B_t = W_B u_t (input matrix depends on input)
C_t = W_C u_t (output matrix depends on input)
A_bar_t = exp(Δ_t A) (discretized, input-dependent)
B_bar_t = Δ_t B_t (scaled input matrix)
x_t = A_bar_t x_{t-1} + B_bar_t u_t
y_t = C_t x_t
Now the model decides, for each token:
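- Δ_t: How much to “stretch” or “compress” time (large Δ = the old state decays faster and the current token is written in strongly; small Δ = the state is held and the token is mostly ignored)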
- B_t: How much this token should influence the state
- C_t: How to read out the state
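The selective equations above can be sketched as a sequential reference loop. This is a minimal sketch under assumptions not stated in the text: a single channel with model width d = 1 (so u_t is a scalar and W_Δ, W_B, W_C reduce to a scalar and two length-N vectors), a diagonal A stored as a vector of negative entries, and the B_bar_t = Δ_t B_t form used above. The function and parameter names are illustrative.

```python
import math

def softplus(x: float) -> float:
    # log(1 + exp(x)), rearranged so exp() never overflows
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def selective_ssm(u, A, w_delta, W_B, W_C):
    """Sequential selective SSM for one channel (assumes d = 1, diagonal A).

    u:       list of scalar inputs u_t
    A:       diagonal entries of A (length N, negative)
    w_delta: scalar, so Δ_t = softplus(w_delta * u_t)
    W_B/W_C: length-N vectors, so B_t = W_B * u_t and C_t = W_C * u_t
    """
    N = len(A)
    x = [0.0] * N                                   # state starts at zero
    ys = []
    for u_t in u:
        delta = softplus(w_delta * u_t)             # Δ_t depends on the token
        B_t = [w * u_t for w in W_B]                # B_t depends on the token
        C_t = [w * u_t for w in W_C]                # C_t depends on the token
        for n in range(N):
            a_bar = math.exp(delta * A[n])          # A_bar_t = exp(Δ_t A), entry n
            b_bar = delta * B_t[n]                  # B_bar_t = Δ_t B_t, entry n
            x[n] = a_bar * x[n] + b_bar * u_t       # x_t = A_bar_t x_{t-1} + B_bar_t u_t
        ys.append(sum(c * xn for c, xn in zip(C_t, x)))  # y_t = C_t · x_t
    return ys

print(selective_ssm([1.0, -0.5, 2.0], A=[-1.0, -2.0], w_delta=0.5,
                    W_B=[1.0, 0.5], W_C=[0.3, -0.2]))
```

Because Δ_t, B_t and C_t are recomputed inside the loop from u_t, two tokens with different content get different decay, write strength and readout, which is exactly what a fixed SSM cannot do.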
Intuition: Selective Memory
Think of a student reading a textbook:
Fixed SSM student:
- Reads each word at a constant pace
- Forgets old information at a constant rate
- Can’t speed up for important parts or skim through filler
Mamba student:
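- Sees the word “Definition:” → reads closely (Δ increases: older context fades faster and this passage is written into memory strongly)
- Sees “for example” → skims (Δ decreases: memory is held as-is and the passage barely registers)
- Important concepts → written into memory strongly (high B)
- Filler words → weak signal (low B)
- Decides what to pull from memory based on the current content (C)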
Because it is selective, this student can compete with the attentive reader (Transformer) on many tasks without re-reading everything it has already seen.
Mathematical Details
Step 1: Compute Selective Parameters
For each input u_t, compute:
Δ_t = softplus(W_Δ u_t) ∈ ℝ (positive scalar: approaches 0 but never reaches it, unbounded above)
B_t = Linear_B(u_t) ∈ ℝ^N (N-dimensional, for state size N)
C_t = Linear_C(u_t) ∈ ℝ^N (N-dimensional, for output)
Where:
W_Δ ∈ ℝ^(1 × d) (projects input to step size)
Linear_B, Linear_C are learned layers
softplus(x) = log(1 + exp(x)) (smooth approximation of max(0, x))
Why softplus? It is smooth (differentiable everywhere) and maps ℝ onto (0, ∞), so every step size is strictly positive.
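Step 1 for a single token can be sketched as follows, assuming bias-free projections and tiny hypothetical sizes (d = 3, N = 2); the helper names and weights are illustrative. The softplus is written in an overflow-safe form: the literal log(1 + exp(x)) overflows in floating point for large x, while max(x, 0) + log1p(exp(-|x|)) is the same function and never does.

```python
import math

def softplus(x: float) -> float:
    # Same as log(1 + exp(x)), but exp() only ever sees a non-positive argument
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def matvec(W, v):
    # W is a list of rows; returns W @ v
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def selective_params(u_t, w_delta, W_B, W_C):
    """Step 1 for one token.

    u_t:     input vector of length d
    w_delta: row vector of length d  (W_Δ ∈ ℝ^(1×d))
    W_B:     N rows of length d      (Linear_B, no bias)
    W_C:     N rows of length d      (Linear_C, no bias)
    """
    delta = softplus(sum(w * x for w, x in zip(w_delta, u_t)))  # Δ_t > 0
    B_t = matvec(W_B, u_t)                                      # B_t ∈ ℝ^N
    C_t = matvec(W_C, u_t)                                      # C_t ∈ ℝ^N
    return delta, B_t, C_t

# Hypothetical toy weights: d = 3, N = 2
u_t = [1.0, -2.0, 0.5]
delta, B_t, C_t = selective_params(
    u_t,
    w_delta=[0.1, 0.2, -0.3],
    W_B=[[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]],
    W_C=[[0.5, 0.5, 0.5], [1.0, -1.0, 0.0]],
)
print(delta, B_t, C_t)
```

Here W_Δ u_t is negative (about -0.45), yet Δ_t still comes out positive, which is the property the step size needs.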
Step 2: Discretization
The continuous SSM has a fixed A matrix (e.g., A = -I, which causes exponential decay). Discretize using the step size Δ_t:
Discretized: x_t = A_bar_t x_{t-1} + B_bar_t u_t
Where:
A_bar_t = exp(Δ_t A) ∈ ℝ^(N×N)
B_bar_t = (Δ_t A)^(-1) (exp(Δ_t A) - I) Δ_t B_t (exact, zero-order hold)
For small Δ_t:
A_bar_t ≈ I + Δ_t A (first-order Taylor expansion of the matrix exponential)
B_bar_t ≈ Δ_t B_t
Key insight: A stays fixed (learned), but A_bar_t changes per token because Δ_t changes.
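With a diagonal A, each state entry can be discretized on its own, which makes it easy to check numerically how close the small-Δ approximations are to the exact zero-order-hold formulas. A minimal sketch, assuming one diagonal entry a = -1 and one B_t entry b = 2 (hypothetical values):

```python
import math

def zoh(a: float, b: float, delta: float):
    """Exact zero-order-hold discretization of one diagonal entry (a < 0)."""
    a_bar = math.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b        # (Δa)^(-1) (exp(Δa) - 1) Δb
    return a_bar, b_bar

def first_order(a: float, b: float, delta: float):
    """Small-Δ approximations: A_bar ≈ 1 + Δa, B_bar ≈ Δb."""
    return 1.0 + delta * a, delta * b

a, b = -1.0, 2.0
for delta in (1.0, 0.1, 0.01):
    (za, zb), (fa, fb) = zoh(a, b, delta), first_order(a, b, delta)
    print(f"Δ={delta:<5} A_bar: {za:.5f} vs {fa:.5f}   B_bar: {zb:.5f} vs {fb:.5f}")
```

At Δ = 1 the approximations are far off, while at Δ = 0.01 they agree to about two significant digits, which is why Δ_t B_t is a reasonable stand-in when step sizes are small.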
Step 3: State Update (Recurrent)
x_t = A_bar_t x_{t-1} + B_bar_t u_t (standard recurrence)
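This costs O(N²) per step if A_bar_t is a dense matrix. Mamba keeps A diagonal, so exp(Δ_t A) is computed entry by entry and the update costs only O(N) per step. See the Hardware Efficiency section below for how the recurrence is made fast.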
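To see why a diagonal A matters for cost, the sketch below runs one update step both ways: as a general N×N matrix-vector product, and as an element-wise product on the diagonal. The values (A entries, Δ, B_bar, state, input) are hypothetical; the point is that both give the same state while the diagonal version does N multiplies instead of N².

```python
import math

def dense_step(A_bar, b_bar, x, u_t):
    # General A_bar (N×N): matrix-vector product, O(N²) per step
    N = len(x)
    return [sum(A_bar[i][j] * x[j] for j in range(N)) + b_bar[i] * u_t for i in range(N)]

def diag_step(a_bar, b_bar, x, u_t):
    # Diagonal A_bar stored as a vector: element-wise product, O(N) per step
    return [a * xi + b * u_t for a, xi, b in zip(a_bar, x, b_bar)]

A_diag, delta = [-1.0, -2.0, -3.0], 0.5
a_bar = [math.exp(delta * a) for a in A_diag]        # exp(Δ A) is element-wise for diagonal A
A_bar_dense = [[a_bar[i] if i == j else 0.0 for j in range(3)] for i in range(3)]
b_bar, x, u_t = [0.5, 0.5, 0.5], [1.0, 2.0, 3.0], 4.0

print(dense_step(A_bar_dense, b_bar, x, u_t))
print(diag_step(a_bar, b_bar, x, u_t))               # same state, N instead of N² multiplies
```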
Step 4: Output
y_t = C_t · x_t (inner product of two N-vectors, O(N))
Why This Works: The Selectivity Effect
Example: Long-Range Dependency
Suppose we want the model to remember token 0 up to token T.
Fixed SSM:
x_t = A^t x_0 + (stuff from recent tokens)
If A has eigenvalues with magnitude < 1 (stable):
x_T ≈ A^T x_0 (decayed exponentially)
If A = 0.9 (a scalar), then 0.9^T ≈ 0.005 already at T = 50
Mamba (selective):
At token 0 (important): Δ_0 is large → B_bar_0 = Δ_0 B_0 is large → token 0 is written strongly into the state
At tokens 1-T (filler): Δ_t is small → A_bar_t = exp(Δ_t A) ≈ I and B_bar_t ≈ 0 → the state is carried forward almost unchanged
x_T ≈ A_bar_T × ... × A_bar_1 × B_bar_0 u_0
≈ (product of factors close to 1) × memory_of_token_0
But here's the trick: Δ values are learned!
If token 0 matters, the model learns to set Δ_0 very large.
If tokens 1-T don't matter, the model learns to set Δ_t small.
Result: Mamba remembers token 0 despite the distance.
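A numeric check of this example, using a scalar state and hypothetical values (A = -1, B = 1). Only token 0 carries signal, so the final state measures how much of token 0 survives. The fixed SSM uses Δ = 0.1 for every token; the selective one uses a large Δ_0 = 2 to write token 0 in strongly and a tiny Δ_t = 0.001 on the fillers so that A_bar_t = exp(Δ_t A) stays close to 1.

```python
import math

def run(deltas, us, a=-1.0, b=1.0):
    """Scalar recurrence x_t = exp(Δ_t a) x_{t-1} + Δ_t b u_t; returns every state."""
    x, states = 0.0, []
    for d, u in zip(deltas, us):
        x = math.exp(d * a) * x + d * b * u
        states.append(x)
    return states

T = 50
us = [1.0] + [0.0] * T                       # only token 0 carries signal

fixed = run([0.1] * (T + 1), us)             # same Δ for every token
selective = run([2.0] + [0.001] * T, us)     # large Δ_0, tiny Δ on fillers

# Fraction of token 0's contribution left after T filler tokens
print(f"fixed keeps     {fixed[-1] / fixed[0]:.4f}")          # about exp(-5)
print(f"selective keeps {selective[-1] / selective[0]:.4f}")  # about exp(-0.05)
```

The fixed SSM keeps well under 1% of token 0 after 50 steps, while the selective one keeps about 95%, because the fillers contribute almost no decay.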
Example: Input-Dependent Importance
Sequence: “Alice lives in Paris. It was a beautiful day. Where does Alice live?”
Token "Paris" (important):
u_t = "Paris"
B_t = Linear_B("Paris") = [high, high, high, ...] (large values)
→ x_t = A_bar_t x_{t-1} + [large, large, ...] × "Paris"
→ "Paris" strongly encoded in state
Token "beautiful" (filler):
u_t = "beautiful"
B_t = Linear_B("beautiful") = [low, low, low, ...] (small values)
→ x_t = A_bar_t x_{t-1} + [small, small, ...] × "beautiful"
→ "beautiful" weakly encoded in state
Model learns this selectivity from data!
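The “Paris” example can be turned into a small calculation. In this sketch every token gets the same input value and the same Δ, so the only difference between tokens is a hand-set B_t standing in for Linear_B(u_t); the token list, the B values, A and Δ are all hypothetical. It splits the final scalar state into the share contributed by each token.

```python
import math

TOKENS = ["Alice", "lives", "in", "Paris", ".", "It", "was", "a", "beautiful", "day", "."]
B_VALUES = {"Alice": 1.0, "Paris": 1.0}   # content words: strong write (hypothetical)
DEFAULT_B = 0.05                          # other tokens: weak write (hypothetical)
A, DELTA, U = -1.0, 0.1, 1.0              # fixed decay, step size, and input value

def contributions(tokens):
    """How much of the final state x_T comes from each token (scalar state)."""
    a_bar = math.exp(DELTA * A)
    last = len(tokens) - 1
    result = {}
    for s, tok in enumerate(tokens):
        b_bar = DELTA * B_VALUES.get(tok, DEFAULT_B)        # B_bar_s = Δ B_s
        result[(s, tok)] = a_bar ** (last - s) * b_bar * U  # decayed by each later step
    return result

for (s, tok), c in contributions(TOKENS).items():
    print(f"{s:2d} {tok:<10} {c:.4f}")
```

Even though “beautiful” is closer to the end of the sequence, “Paris” still accounts for more than ten times as much of the final state, purely because of its larger B.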
Hardware Efficiency: The Parallel Scan Trick
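Naively, with a dense A_bar_t each step costs O(N²) work (a matrix-vector multiply), or O(n N²) for a sequence of length n. A diagonal A brings this down to O(N) per step, but the steps still run one after another, which leaves most of a GPU idle.
But there's a trick: the parallel scan algorithm (also called parallel prefix scan).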
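Training: Use a Parallel Scan (Not a Convolution)
Earlier SSMs such as S4 have fixed A_bar and B_bar, so their recurrence unrolls into a convolution:
y = K * u, with kernel K = (C B_bar, C A_bar B_bar, C A_bar² B_bar, ...) (computed with an FFT in O(n log n))
Mamba cannot do this: A_bar_t and B_bar_t change at every step, so no single kernel K exists. Instead, it uses the fact that each update x_t = A_bar_t x_{t-1} + B_bar_t u_t is an affine map, and composing affine maps is associative.
An associative recurrence can be evaluated with a parallel scan in O(log n) sequential rounds, which parallelizes training across time steps despite the recurrent structure.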
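A minimal sketch of the parallel-scan idea for one state entry. Each step is stored as a pair (a_t, b_t) meaning the affine map x → a_t x + b_t, with a_t = exp(Δ_t A) and b_t = Δ_t B_t u_t. For brevity this uses the simple Hillis-Steele scan, which does O(n log n) total work; production kernels typically use a work-efficient variant. The per-token values in the usage lines are hypothetical.

```python
import math
from typing import List, Tuple

Pair = Tuple[float, float]   # (a, b) represents the affine map x -> a*x + b

def combine(first: Pair, second: Pair) -> Pair:
    """Apply `first`, then `second`: x -> a2*(a1*x + b1) + b2. Associative."""
    a1, b1 = first
    a2, b2 = second
    return a2 * a1, a2 * b1 + b2

def sequential_scan(pairs: List[Pair]) -> List[float]:
    """Reference: x_t = a_t x_{t-1} + b_t with a zero initial state, one step at a time."""
    x, out = 0.0, []
    for a, b in pairs:
        x = a * x + b
        out.append(x)
    return out

def parallel_scan(pairs: List[Pair]) -> List[float]:
    """Hillis-Steele inclusive scan: ceil(log2 n) rounds. Within a round every
    update is independent, so on a GPU each round can run in parallel."""
    acc = list(pairs)
    step = 1
    while step < len(acc):
        # Position i absorbs the composed prefix ending at i - step
        acc = [acc[i] if i < step else combine(acc[i - step], acc[i]) for i in range(len(acc))]
        step *= 2
    # Starting from a zero state, each prefix map evaluates to its b-part
    return [b for _, b in acc]

deltas, us, A, B = [0.5, 0.1, 2.0, 0.01], [1.0, -1.0, 0.5, 3.0], -1.0, 1.0
pairs = [(math.exp(d * A), d * B * u) for d, u in zip(deltas, us)]
print(sequential_scan(pairs))
print(parallel_scan(pairs))   # same states, up to floating-point rounding
```

The only requirement is that `combine` is associative, which holds whether or not a_t and b_t change per token. That is why the scan works for Mamba where a convolution does not.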
Inference: Use Recurrence with Fused Kernels
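During inference, we generate one token at a time, so each step is a single recurrent update of a fixed-size state. But:
Standard way:
x_t = exp(Δ_t A) ⊙ x_{t-1} + Δ_t B_t u_t (⊙ is element-wise; Mamba keeps A diagonal)
(several separate operations, each reading and writing intermediate arrays in GPU memory)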
Fused kernel (hardware-aware):
x_t = fused_kernel(A, Δ_t, B_t, u_t, x_{t-1})
(one optimized kernel call, minimizing memory movement)
For these element-wise operations, memory bandwidth, not compute, is the bottleneck on modern GPUs.
Fusing operations reduces memory traffic dramatically.
Result: the Mamba authors report roughly 5× higher generation throughput than Transformers of similar size, with the advantage growing on long sequences (2K+ tokens).
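The loop structure a fused kernel implements can be sketched for one channel. Discretization, state update and readout happen in a single pass over the N state entries, so A_bar_t and B_bar_t are never stored as separate arrays. This is an illustration under assumptions (diagonal A, hypothetical constant B_t and C_t that a real model would compute per token); `mamba_step` is a hypothetical name, and in Python it does not actually reduce memory traffic.

```python
import math

def mamba_step(x, u_t, A, delta_t, B_t, C_t):
    """One recurrent inference step for a single channel, assuming diagonal A.

    Updates the state x in place and returns y_t. A_bar and B_bar are formed
    on the fly for each entry instead of being materialized first.
    """
    y_t = 0.0
    for n in range(len(x)):
        a_bar = math.exp(delta_t * A[n])                # discretize entry n
        x[n] = a_bar * x[n] + delta_t * B_t[n] * u_t    # update entry n
        y_t += C_t[n] * x[n]                            # accumulate the readout
    return y_t

# Streaming generation only ever keeps x (N numbers), however long the history
x = [0.0, 0.0]
A = [-1.0, -0.5]
for u_t, delta_t in [(1.0, 0.5), (0.2, 0.1), (-0.3, 1.0)]:
    y = mamba_step(x, u_t, A, delta_t, B_t=[1.0, 0.5], C_t=[0.3, 0.7])
    print(f"y = {y:.4f}, state size = {len(x)}")
```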
Architecture Recap
Input u_t (size d_model)
|
├─→ Linear (d_model → 2×d_inner), split into two branches [d_inner = E×d_model, typically E = 2]
|
├─→ SSM branch: causal Conv1d → SiLU → features v_t
|     ├─→ Linear (d_inner → 1) → softplus → Δ_t
|     └─→ Linear (d_inner → N), twice → B_t, C_t
|
└─→ Gate branch: z_t → SiLU (gating)
State x is maintained recurrently (per channel of v_t, with diagonal A):
x_t = exp(Δ_t A) ⊙ x_{t-1} + Δ_t B_t v_t
Output: y_t = (C_t · x_t) * SiLU(z_t)
y_t is projected back to d_model (via another Linear layer)
The gating (SiLU activation) is inspired by gating in RNNs and LSTMs — it allows the model to control information flow.
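The recap can be traced end to end in a toy, pure-Python block. This is a simplified sketch, not the reference implementation: it omits the causal Conv1d, biases and any skip term, shares one scalar Δ_t across channels, and uses small hypothetical sizes with random weights. It only shows how the two branches, the per-channel recurrence and the SiLU gate fit together; the class name is hypothetical.

```python
import math
import random

def sigmoid(v):
    return 0.5 * (1.0 + math.tanh(0.5 * v))    # overflow-free logistic function

def silu(v):
    return v * sigmoid(v)

def softplus(v):
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

class ToyMambaBlock:
    """Simplified Mamba-style block (no Conv1d, no biases); hypothetical sizes and weights."""

    def __init__(self, d_model=2, expand=2, N=3, seed=0):
        rng = random.Random(seed)

        def rand(rows, cols):
            return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

        self.d_inner = expand * d_model
        self.in_proj = rand(2 * self.d_inner, d_model)     # rows: SSM branch, then gate branch
        self.w_delta = rand(1, self.d_inner)[0]            # W_Δ (1 × d_inner)
        self.W_B = rand(N, self.d_inner)                   # Linear_B
        self.W_C = rand(N, self.d_inner)                   # Linear_C
        # Diagonal A per channel with negative entries -1, -2, ..., -N
        self.A = [[-(n + 1.0) for n in range(N)] for _ in range(self.d_inner)]
        self.out_proj = rand(d_model, self.d_inner)        # back to d_model
        self.h = [[0.0] * N for _ in range(self.d_inner)]  # state: N entries per channel

    def step(self, u_t):
        xz = matvec(self.in_proj, u_t)
        v = [silu(val) for val in xz[: self.d_inner]]      # SSM-branch features
        z = xz[self.d_inner :]                             # gate-branch features
        delta = softplus(sum(w * val for w, val in zip(self.w_delta, v)))
        B_t, C_t = matvec(self.W_B, v), matvec(self.W_C, v)
        y = []
        for c in range(self.d_inner):
            row = self.h[c]
            for n, a in enumerate(self.A[c]):
                row[n] = math.exp(delta * a) * row[n] + delta * B_t[n] * v[c]
            y_c = sum(C_t[n] * row[n] for n in range(len(row)))  # C_t · x_t for channel c
            y.append(y_c * silu(z[c]))                         # gate with SiLU(z)
        return matvec(self.out_proj, y)

block = ToyMambaBlock()
for u_t in ([1.0, 0.0], [0.0, 1.0], [0.5, -0.5]):
    print(block.step(u_t))
```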
The Trade-off
What Mamba Gains:
- ✓ O(1) work per generated token, O(n) for a whole sequence (no O(n²) attention)
- ✓ Constant memory per step: a fixed-size state instead of a KV cache that grows with n
- ✓ Selectivity (input-dependent parameters)
What Mamba Sacrifices:
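- ✗ No direct access to past tokens: everything must be compressed into a fixed-size state, so exact long-range recall is harder than with attention
- ✗ Harder to implement (needs custom kernels)
- ✗ Training relies on a parallel scan instead of large matrix multiplications, which is less hardware-friendly than attention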
The Bet: Selectivity + efficiency > full attention flexibility, for most tasks.
Comparison with Transformers
| Aspect | Mamba | Transformer |
|---|---|---|
| Forward pass | Recurrent at inference; parallel scan in training | Parallel (batched) |
| Memory per step | O(N·d) state, constant in n | O(n·d) KV cache per layer |
| Inference cost per new token | O(1) (O(n) for n tokens) | O(n) with KV cache (O(n²) for n tokens) |
| Long-context speed | 5× faster at 2K tokens, 10× at 64K | Standard baseline |
| Context length limit | No architectural limit (state size is constant), though recall degrades | Bounded by KV-cache memory and the trained context window |
| In-context recall | Struggles at very long range | Excellent (full attention) |
| Training speed | Comparable to Transformers (parallel scan) | Comparable to Mamba (parallel attention) |
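To put the memory row into numbers, here is back-of-the-envelope arithmetic for a hypothetical 48-layer model with d_model = 2048, 16-bit values, expansion factor 2 and N = 16 state entries per channel. It assumes plain multi-head attention (keys and values of width d_model per layer, no grouped-query sharing) and ignores Mamba's small convolution buffer; none of these figures describe a specific released model.

```python
def kv_cache_bytes(n_tokens, n_layers, d_model, bytes_per_value=2):
    # Keys and values: two vectors of width d_model per token per layer
    return 2 * n_tokens * n_layers * d_model * bytes_per_value

def ssm_state_bytes(n_layers, d_model, expand=2, N=16, bytes_per_value=2):
    # N state entries per channel, expand * d_model channels per layer;
    # note there is no n_tokens argument: the state does not grow with length
    return n_layers * expand * d_model * N * bytes_per_value

for n in (2_048, 65_536):
    kv = kv_cache_bytes(n, 48, 2048)
    ssm = ssm_state_bytes(48, 2048)
    print(f"n={n:>6}: KV cache {kv / 2**20:8.1f} MiB, SSM state {ssm / 2**20:.1f} MiB")
```

Under these assumptions the KV cache is 768 MiB at 2,048 tokens and 24 GiB at 65,536 tokens, while the SSM state stays at 6 MiB regardless of length.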