The Math: Discretisation and Selective State Spaces
Prerequisites: Eigenvalues and Eigenvectors, Matrix Multiplication
Part 1: Continuous State Space Model
A linear time-invariant (LTI) system in continuous time:
dx/dt = A x(t) + B u(t) (state equation)
y(t) = C x(t) + D u(t) (output equation)
Where:
x(t) ∈ ℝ^N (hidden state at time t)
u(t) ∈ ℝ (input at time t)
y(t) ∈ ℝ (output at time t)
A ∈ ℝ^(N×N) (state transition matrix)
B ∈ ℝ^(N×1) (input projection)
C ∈ ℝ^(1×N) (output projection)
D ∈ ℝ (feedthrough, usually 0)
Example: Exponential Decay
Simplest case: dx/dt = -0.5 x(t) + u(t)
Solution: x(t) = e^(-0.5t) x(0) + ∫₀^t e^(-0.5(t-τ)) u(τ) dτ
Interpretation:
- Initial state x(0) decays as e^(-0.5t) (half-life ≈ 1.4 time units)
- Each input u(τ) contributes to the current state, weighted by e^(-0.5(t-τ))
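- Older inputs carry less weight: the factor e^(-0.5(t-τ)) shrinks as the elapsed time t - τ grows (every input decays at the same rate; older ones have simply been decaying for longer)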
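This is a minimal numerical check of the solution. It assumes a constant input u(t) = 1 and x(0) = 0, chosen only for illustration. Under those assumptions the formula reduces to x(t) = 2(1 - e^(-0.5t)), which should agree with a brute-force forward-Euler integration of the ODE. The helper names `closed_form` and `euler` are hypothetical.

```python
import math

A = -0.5   # scalar coefficient from dx/dt = -0.5 x(t) + u(t)
U = 1.0    # constant input u(t) = 1 (assumed for this check)

def closed_form(t: float, x0: float = 0.0) -> float:
    """x(t) = e^(At) x0 + ∫_0^t e^(A(t-τ)) U dτ, with the integral done analytically."""
    return math.exp(A * t) * x0 + (U / -A) * (1.0 - math.exp(A * t))

def euler(t_end: float, x0: float = 0.0, dt: float = 1e-4) -> float:
    """Integrate dx/dt = A x + U with many small forward-Euler steps."""
    x = x0
    for _ in range(round(t_end / dt)):
        x += dt * (A * x + U)
    return x

half_life = math.log(2) / -A   # time for x(0) to halve when the input is zero
print(f"half-life ≈ {half_life:.3f}")   # prints 1.386
for t in (1.0, 3.0, 6.0):
    print(t, closed_form(t), euler(t))
```

The two columns agree to about four decimals, and the gap shrinks further as `dt` decreases.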
Part 2: Discretisation (Zero-Order Hold)
In practice, we process discrete sequences (tokens), not continuous signals. We need to convert the continuous system to discrete time.
Assume the input u(t) is piecewise constant over intervals [t, t+Δt):
u(t) = u_k for t ∈ [k·Δt, (k+1)·Δt) (u_k is constant during interval k)
The solution over one step, from t_k to t_{k+1} = t_k + Δt, is (writing t for t_k):
x(t+Δt) = e^(A·Δt) x(t) + ∫_0^Δt e^(A(Δt-τ)) B u(t+τ) dτ
Since u is constant over the step (u(t+τ) = u(t) for τ ∈ [0, Δt)):
x(t+Δt) = e^(A·Δt) x(t) + (∫_0^Δt e^(A(Δt-τ)) dτ) B u(t)
Compute the integral:
∫_0^Δt e^(A(Δt-τ)) dτ
Substitute s = Δt - τ. Then ds = -dτ, and as τ runs from 0 to Δt, s runs from Δt to 0:
∫_0^Δt e^(A(Δt-τ)) dτ = ∫_0^Δt e^(As) ds
Because d/ds [A^(-1) e^(As)] = e^(As) (this requires A to be invertible):
∫_0^Δt e^(As) ds = A^(-1) (e^(A·Δt) - I)
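For a scalar A, or a diagonal A taken entry by entry, the identity can be checked numerically. This minimal sketch compares a midpoint-rule approximation of the integral with the closed form A^(-1)(e^(A·Δt) - I). The sample values of a and Δt are arbitrary, and the function names are hypothetical.

```python
import math

def integral_midpoint(a: float, dt: float, n: int = 10_000) -> float:
    """Approximate ∫_0^dt e^(a(dt - τ)) dτ with the midpoint rule on n sub-intervals."""
    h = dt / n
    return sum(math.exp(a * (dt - (i + 0.5) * h)) for i in range(n)) * h

def integral_closed(a: float, dt: float) -> float:
    """Closed form a^(-1) (e^(a·dt) - 1); valid only for a != 0 (A invertible)."""
    return (math.exp(a * dt) - 1.0) / a

for a in (-0.9, -0.8, 0.3):
    print(a, integral_midpoint(a, 0.5), integral_closed(a, 0.5))
```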
So the discrete recurrence is:
x_{k+1} = e^(A·Δt) x_k + A^(-1) (e^(A·Δt) - I) B u_k
Define Discretised Matrices
Let:
Ā = e^(A·Δt) (discretised state transition)
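B̄ = A^(-1) (e^(A·Δt) - I) B (discretised input matrix, exact under zero-order hold), or B̄ ≈ Δt·B when Δt is small (see Part 4)
Re-indexing so that x_k is the state after consuming input u_k, the discrete system becomes: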
x_k = Ā x_{k-1} + B̄ u_k
y_k = C x_k
This is now a discrete recurrence that we can compute step by step.
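Zero-order hold is exact when the input really is piecewise constant. This minimal sketch checks that claim for a diagonal A. It runs the discrete recurrence with the exact Ā and B̄, then compares the result with a brute-force integration of the continuous ODE over the same piecewise-constant input. The values of A, B, Δt and the input sequence are assumed, and the helper names are hypothetical.

```python
import math

A_DIAG = [-0.9, -0.8]   # diagonal of A (assumed example values, N = 2)
B = [1.0, 1.0]
DT = 0.5                # step size Δt (assumed)
U = [1.0, 0.5, 2.0]     # one input value held constant on each interval

def zoh(a_diag, b, dt):
    """Exact ZOH discretisation for diagonal A: Ā = e^(a·Δt), B̄ = a^(-1)(e^(a·Δt) - 1) b."""
    a_bar = [math.exp(a * dt) for a in a_diag]
    b_bar = [(math.exp(a * dt) - 1.0) / a * bi for a, bi in zip(a_diag, b)]
    return a_bar, b_bar

def discrete_run(u):
    """x_k = Ā x_{k-1} + B̄ u_k, element-wise because A is diagonal."""
    a_bar, b_bar = zoh(A_DIAG, B, DT)
    x = [0.0] * len(A_DIAG)
    for uk in u:
        x = [ab * xi + bb * uk for ab, xi, bb in zip(a_bar, x, b_bar)]
    return x

def continuous_run(u, substeps=20_000):
    """Brute-force dx/dt = A x + B u with tiny Euler steps, u held constant per interval."""
    h = DT / substeps
    x = [0.0] * len(A_DIAG)
    for uk in u:
        for _ in range(substeps):
            x = [xi + h * (a * xi + bi * uk) for xi, a, bi in zip(x, A_DIAG, B)]
    return x

print(discrete_run(U))
print(continuous_run(U))
```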
Part 3: Mamba’s Selective Discretisation
Mamba modifies the discretised system to be input-dependent:
Δ_k = softplus(W_Δ u_k) (step size depends on input, scalar)
B_k = Linear_B(u_k) (input matrix depends on input, N-dimensional)
C_k = Linear_C(u_k) (output matrix depends on input, N-dimensional)
Ā_k = e^(Δ_k A) (discretised state transition, input-dependent)
B̄_k = (Δ_k A)^(-1) (e^(Δ_k A) - I) Δ_k B_k ≈ Δ_k B_k
x_k = Ā_k x_{k-1} + B̄_k u_k
y_k = C_k x_k
Key differences from fixed SSM:
- Δ_k varies per token — allows adaptive memory decay
- B_k varies per token — allows input-dependent importance weighting
- C_k varies per token — allows input-dependent readout
- Ā_k varies per token — a consequence of Δ_k varying
Part 4: Simplification for Computation
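Computing a general matrix exponential e^(Δ_k A) and the inverse A^(-1) is expensive (roughly O(N³) for a dense N×N matrix), and Δ_k changes every token. Mamba uses two simplifications: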
For Small Δ_k (Taylor Expansion)
e^(Δ_k A) ≈ I + Δ_k A + (Δ_k A)²/2! + ...
For small Δ_k:
e^(Δ_k A) ≈ I + Δ_k A
Then:
A^(-1) (e^(Δ_k A) - I) ≈ A^(-1) (Δ_k A) = Δ_k I
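So: B̄_k ≈ Δ_k B_k
Mamba uses this first-order form only for B̄_k. Ā_k keeps the exact exponential, which is cheap when A is diagonal.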
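The approximation B̄_k ≈ Δ_k B_k is accurate only when Δ_k|a_i| is small. This minimal sketch compares it with the exact ZOH weight for one diagonal entry a = -0.9 over a range of step sizes. The Δ values are chosen for illustration; 1.313 is the largest step used in the worked example below.

```python
import math

def b_bar_zoh(a: float, delta: float, b: float = 1.0) -> float:
    """Exact ZOH weight (Δa)^(-1) (e^(Δa) - 1) Δ b for one diagonal entry a."""
    return (math.exp(delta * a) - 1.0) / a * b

def b_bar_euler(delta: float, b: float = 1.0) -> float:
    """First-order simplification B̄ ≈ Δ b."""
    return delta * b

a = -0.9
for delta in (0.01, 0.1, 0.5, 1.313):
    exact, approx = b_bar_zoh(a, delta), b_bar_euler(delta)
    print(f"Δ={delta:<6} exact={exact:.4f} approx={approx:.4f} "
          f"rel.err={(approx - exact) / exact:.1%}")
```

The relative error grows with Δ. It stays under 1% at Δ = 0.01 but is large at Δ ≈ 1.3. Mamba uses the simplified form regardless, and the learned parameters adapt to it.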
For Diagonal A (Efficient Computation)
If A is diagonal (A = diag(a₁, a₂, …, a_N)):
e^(Δ_k A) = diag(e^(Δ_k a₁), e^(Δ_k a₂), ..., e^(Δ_k a_N))
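The state update then needs only the diagonal entries a = (a₁, …, a_N):
x_k = e^(Δ_k a) ⊙ x_{k-1} + Δ_k B_k u_k (⊙ is the element-wise product; exp is applied entry-wise)
This is just element-wise multiplication, so each step costs O(N) instead of the O(N²) of a dense matrix-vector product.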
Mamba uses diagonal SSMs to keep computation efficient.
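To see why the diagonal case is cheap, this minimal sketch computes e^(Δ·A) for the 2×2 diagonal A of the worked example in two ways. The first is a general dense route: a truncated Taylor series, used here only for illustration and not how a numerical library would do it. The second applies one scalar exp per diagonal entry. The sketch also shows the O(N) update. All helper names are hypothetical.

```python
import math

def mat_mul(X, Y):
    """Plain-Python product of two square matrices stored as lists of rows."""
    n = len(X)
    return [[sum(X[i][k] * Y[k][j] for k in range(n)) for j in range(n)] for i in range(n)]

def expm_taylor(M, terms=30):
    """Dense e^M from a truncated Taylor series; adequate only for small ||M||."""
    n = len(M)
    result = [[float(i == j) for j in range(n)] for i in range(n)]  # identity I
    term = [row[:] for row in result]                                 # M^0 / 0!
    for k in range(1, terms):
        term = [[v / k for v in row] for row in mat_mul(term, M)]     # M^k / k!
        result = [[r + t for r, t in zip(rr, tr)] for rr, tr in zip(result, term)]
    return result

delta = 0.974
a_diag = [-0.9, -0.8]
M = [[delta * a_diag[0], 0.0], [0.0, delta * a_diag[1]]]   # Δ·A as a dense matrix
dense = expm_taylor(M)                                     # O(N³) per matrix product
diagonal = [math.exp(delta * a) for a in a_diag]           # O(N): one exp per entry

def diag_update(x, a_bar, b_bar, u):
    """x_k = Ā ⊙ x_{k-1} + B̄ u_k with Ā stored as a vector: O(N) work."""
    return [ab * xi + bb * u for ab, xi, bb in zip(a_bar, x, b_bar)]

print(dense)
print(diagonal)
print(diag_update([0.0, 0.0], diagonal, [delta, delta], 1.0))
```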
Part 5: Worked Numerical Example
Let’s trace a complete example with real numbers.
Setup
State dimension: N = 2
Sequence length: T = 3
Input sequence: u = [1.0, 0.5, 2.0]
Fixed state matrix: A = [[-0.9, 0], [0, -0.8]] (stable, diagonal)
Fixed input matrix: B = [[1.0], [1.0]]
Fixed output matrix: C = [1.0, 1.0]
Initial state: x_0 = [0, 0]
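Selective parameters (hand-picked pre-softplus values standing in for W_Δ u_k; they are illustrative and do not come from a single trained weight):
For u_k = 1.0: Δ_k = softplus(0.5) ≈ 0.974 (moderate step)
For u_k = 0.5: Δ_k = softplus(-0.2) ≈ 0.598 (smallest step: keeps the most of the previous state, writes the least of the new input)
For u_k = 2.0: Δ_k = softplus(1.0) ≈ 1.313 (largest step: forgets the most of the previous state, writes the most of the new input)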
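The Δ values above can be checked directly. This minimal sketch uses a numerically stable softplus. The rewrite max(z, 0) + log1p(e^(-|z|)) is a standard identity for ln(1 + e^z), not something the text prescribes.

```python
import math

def softplus(z: float) -> float:
    """softplus(z) = ln(1 + e^z), written so that large |z| does not overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

for z in (0.5, -0.2, 1.0):
    print(z, round(softplus(z), 3))   # 0.974, 0.598, 1.313
```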
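For simplicity, use the exact exponential for Ā_k and the first-order approximation for B̄_k. B and C stay fixed; only Δ_k varies:
B̄_k ≈ Δ_k B = Δ_k [[1.0], [1.0]]
Ā_k = e^(Δ_k A)
Each step carries three-decimal rounded values forward, so results can differ from full-precision arithmetic in the last digit.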
Step 1: t = 1, u₁ = 1.0
Δ₁ = softplus(0.5) ≈ 0.974
B̄₁ = 0.974 × [[1.0], [1.0]] = [[0.974], [0.974]]
Ā₁ = e^(0.974 × [[-0.9, 0], [0, -0.8]])
= [[e^(-0.877), 0], [0, e^(-0.779)]]
= [[0.416, 0], [0, 0.459]]
x₁ = Ā₁ x₀ + B̄₁ u₁
= [[0.416, 0], [0, 0.459]] × [[0], [0]] + [[0.974], [0.974]] × 1.0
= [[0], [0]] + [[0.974], [0.974]]
= [[0.974], [0.974]]
y₁ = C x₁ = [1.0, 1.0] × [[0.974], [0.974]] = 0.974 + 0.974 = 1.948
Step 2: t = 2, u₂ = 0.5
Δ₂ = softplus(-0.2) ≈ 0.598
B̄₂ = 0.598 × [[1.0], [1.0]] = [[0.598], [0.598]]
Ā₂ = e^(0.598 × [[-0.9, 0], [0, -0.8]])
= [[e^(-0.538), 0], [0, e^(-0.478)]]
= [[0.584, 0], [0, 0.620]]
x₂ = Ā₂ x₁ + B̄₂ u₂
= [[0.584, 0], [0, 0.620]] × [[0.974], [0.974]] + [[0.598], [0.598]] × 0.5
= [[0.584×0.974], [0.620×0.974]] + [[0.299], [0.299]]
= [[0.569], [0.604]] + [[0.299], [0.299]]
= [[0.868], [0.903]]
y₂ = C x₂ = [1.0, 1.0] × [[0.868], [0.903]] = 0.868 + 0.903 = 1.771
Step 3: t = 3, u₃ = 2.0
Δ₃ = softplus(1.0) ≈ 1.313
B̄₃ = 1.313 × [[1.0], [1.0]] = [[1.313], [1.313]]
Ā₃ = e^(1.313 × [[-0.9, 0], [0, -0.8]])
= [[e^(-1.182), 0], [0, e^(-1.050)]]
= [[0.307, 0], [0, 0.350]]
x₃ = Ā₃ x₂ + B̄₃ u₃
= [[0.307, 0], [0, 0.350]] × [[0.868], [0.903]] + [[1.313], [1.313]] × 2.0
= [[0.307×0.868], [0.350×0.903]] + [[2.626], [2.626]]
= [[0.266], [0.316]] + [[2.626], [2.626]]
= [[2.892], [2.942]]
y₃ = C x₃ = [1.0, 1.0] × [[2.892], [2.942]] = 2.892 + 2.942 = 5.834
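The whole trace can be reproduced at full precision. This minimal sketch assumes exactly the setup above: diagonal A, fixed B and C, the hand-picked pre-softplus values, exact exp for Ā_k, and B̄_k ≈ Δ_k B. Because it does not round between steps, it should print y ≈ 1.948, 1.770, 5.835. These values differ from the hand-rounded trace by at most 0.001.

```python
import math

def softplus(z: float) -> float:
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

A_DIAG = [-0.9, -0.8]          # diagonal of A
B = [1.0, 1.0]                 # fixed B (not input-dependent in this example)
C = [1.0, 1.0]                 # fixed C
U = [1.0, 0.5, 2.0]            # input sequence
PRE_DELTA = [0.5, -0.2, 1.0]   # hand-picked pre-softplus values from the example

x = [0.0, 0.0]
ys = []
for u, z in zip(U, PRE_DELTA):
    delta = softplus(z)
    a_bar = [math.exp(delta * a) for a in A_DIAG]   # exact exp, entry-wise
    b_bar = [delta * b for b in B]                  # B̄ ≈ Δ B
    x = [ab * xi + bb * u for ab, xi, bb in zip(a_bar, x, b_bar)]
    ys.append(sum(c * xi for c, xi in zip(C, x)))
    print(f"Δ={delta:.3f}  Ā={[round(v, 3) for v in a_bar]}  "
          f"x={[round(v, 3) for v in x]}  y={ys[-1]:.3f}")
```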
Part 6: Interpretation
Memory Trace
t=1: x₁ = [0.974, 0.974] (input 1.0 with Δ=0.974 encoded)
t=2: x₂ = [0.868, 0.903] (previous state decayed by Ā₂, plus new input 0.5)
t=3: x₃ = [2.892, 2.942] (previous state decayed by Ā₃, plus large new input 2.0)
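Key observation: when u₃ = 2.0 arrived, Δ₃ was the largest step (1.313). As a result, Ā₃ had the smallest entries (0.307 and 0.350, versus 0.416/0.459 at t=1 and 0.584/0.620 at t=2) and B̄₃ was the largest. So:
- A large Δ forgets more of the previous state and writes the current input strongly. At t=3, only 0.266 and 0.316 are carried over from x₂, while 2.626 per entry comes from u₃.
- A small Δ (Δ₂ = 0.598) keeps more of the previous state and writes the current input weakly.
- Because Δ is computed from the input, the model can choose per token whether to focus on the new token or preserve its context.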
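The trade-off can be tabulated for one diagonal entry. This minimal sketch assumes a = -0.9 from the example and the simplification B̄ ≈ Δ B. It prints how much of the previous state is kept (e^(Δa)) and how strongly the current input is written (Δ) as Δ grows. The values 0.1 and 3.0 are extra illustrative points.

```python
import math

a = -0.9                                   # one diagonal entry of A, as in the example
deltas = (0.1, 0.598, 0.974, 1.313, 3.0)   # Δ values; the middle three appear above
keeps, writes = [], []
for delta in deltas:
    keeps.append(math.exp(delta * a))      # Ā entry: share of the previous state retained
    writes.append(delta)                   # B̄ / B under the simplification B̄ ≈ Δ B
    print(f"Δ={delta:<5}  keep previous={keeps[-1]:.3f}  write current={writes[-1]:.3f}")
```

As Δ increases, the retained share falls monotonically and the write weight rises. Large Δ "resets" the state toward the current token; small Δ preserves it.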
(Note: In real training, the Linear layers for Δ, B, C are learned to produce these selective values.)
Part 7: Stability and Eigenvalues
The fixed matrix A should be stable — all eigenvalues should have negative real parts.
For the diagonal example:
A = [[-0.9, 0], [0, -0.8]]
Eigenvalues: λ₁ = -0.9, λ₂ = -0.8 (both negative ✓)
Both are negative, so:
- With zero input, each state component x_i(t) = e^(λ_i t) x_i(0) → 0 as t → ∞ (exponential decay)
- The system is stable
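If A had an eigenvalue with positive real part, the corresponding e^(λt) would grow without bound (unstable).
Mamba keeps A stable by parameterising it as A = -exp(A_log), so every diagonal entry is negative whatever value A_log learns. In the discretised system, this makes every entry of Ā_k, e^(Δ_k a_i), lie strictly between 0 and 1 for any Δ_k > 0, which softplus guarantees.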
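This minimal sketch illustrates the stability argument. The A_log values are hypothetical unconstrained numbers; mapping them through A = -exp(A_log) always yields negative diagonal entries. The discrete decay factors then stay in (0, 1) for any positive Δ. For contrast, a positive coefficient makes the zero-input recurrence blow up. The helper `run` is hypothetical.

```python
import math

A_log = [math.log(1.0), math.log(2.0), 0.7, -3.0]   # hypothetical unconstrained parameters
A_diag = [-math.exp(v) for v in A_log]              # strictly negative by construction

for delta in (1e-3, 0.5, 10.0):                     # any Δ > 0, as softplus guarantees
    a_bar = [math.exp(delta * a) for a in A_diag]
    assert all(0.0 < v < 1.0 for v in a_bar)        # every decay factor lies in (0, 1)

def run(a: float, delta: float, steps: int) -> float:
    """Zero-input recurrence x_k = e^(Δa) x_{k-1}, starting from x_0 = 1."""
    x = 1.0
    for _ in range(steps):
        x *= math.exp(delta * a)
    return x

print(run(-0.9, 0.5, 50))   # decays toward 0
print(run(+0.2, 0.5, 50))   # grows without bound as steps increase (unstable)
```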
Part 8: Comparison with Fixed SSM
Fixed SSM (e.g., S4)
Ā = e^(Δ A) (Δ is learned, but it is a single value shared by all tokens)
B̄ = Δ B (fixed)
For all tokens, same state transition Ā.
For all tokens, same input importance B̄.
Mamba (Selective SSM)
Δ_k = softplus(W_Δ u_k) (varies per token)
Ā_k = e^(Δ_k A) (varies per token)
B̄_k = Δ_k B_k (varies per token, since both Δ_k and B_k vary)
C_k = Linear_C(u_k) (varies per token)
Different tokens get different state transitions and input importance.
The fixed SSM is simpler, but it cannot decide, based on content, what to store or ignore. Mamba's selectivity adds exactly this ability, and it is what makes Mamba competitive.
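One concrete consequence of input-dependent parameters is that the map from inputs to outputs is no longer linear. This minimal sketch uses a scalar state (N = 1) with A = -0.9 and B = C = 1. It assumes a fixed Δ = 0.1 for the fixed SSM and a hypothetical scalar weight `W_DELTA` for the selective one. Doubling the input doubles the fixed SSM's output, but not the selective SSM's, because Δ itself changes with the input.

```python
import math

def softplus(z: float) -> float:
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

A, B, C = -0.9, 1.0, 1.0   # scalar SSM (N = 1) for brevity
W_DELTA = 0.5              # hypothetical weight for the Δ projection

def run(u, selective: bool):
    x, ys = 0.0, []
    for uk in u:
        delta = softplus(W_DELTA * uk) if selective else 0.1   # fixed Δ = 0.1 otherwise
        x = math.exp(delta * A) * x + delta * B * uk           # Ā x + B̄ u with B̄ ≈ Δ B
        ys.append(C * x)
    return ys

u = [1.0, 0.5, 2.0]
u2 = [2 * v for v in u]
fixed_gap = max(abs(2 * a - b) for a, b in zip(run(u, False), run(u2, False)))
sel_gap = max(abs(2 * a - b) for a, b in zip(run(u, True), run(u2, True)))
print(fixed_gap, sel_gap)   # fixed: rounding-level; selective: clearly non-zero
```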
Summary: Full Mamba Forward Pass
For each token t:
1. Compute selective parameters:
Δ_t = softplus(W_Δ u_t)
B_t = Linear_B(u_t)
C_t = Linear_C(u_t)
2. Discretise:
Ā_t = exp(Δ_t A)
B̄_t = Δ_t B_t (approximation)
3. Update state:
x_t = Ā_t x_{t-1} + B̄_t u_t
4. Compute output:
y_t = C_t x_t
Total per-token cost: with a dense A, a fresh matrix exponential every token (roughly O(N³), since Δ_t changes) plus an O(N²) matrix-vector product; with diagonal A, O(N)
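The four steps can be sketched as a sequential loop for a single input channel with diagonal A. The projections are hypothetical stand-ins for W_Δ, Linear_B and Linear_C: a scalar `w_delta` and per-state weight lists `w_b` and `w_c`, each multiplied by u_t with no bias. Real Mamba computes these from the whole token vector and uses a parallel scan; this sketch does neither.

```python
import math
from typing import List

def softplus(z: float) -> float:
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def selective_scan(u: List[float], a_diag: List[float], w_delta: float,
                   w_b: List[float], w_c: List[float]) -> List[float]:
    """Sequential selective SSM for one channel; projections are hypothetical."""
    x = [0.0] * len(a_diag)
    ys = []
    for ut in u:
        # 1. selective parameters from the current token
        delta = softplus(w_delta * ut)
        b_t = [wb * ut for wb in w_b]
        c_t = [wc * ut for wc in w_c]
        # 2. discretise: exact exp for Ā (diagonal), B̄ ≈ Δ B
        a_bar = [math.exp(delta * a) for a in a_diag]
        b_bar = [delta * b for b in b_t]
        # 3. state update, element-wise: O(N)
        x = [ab * xi + bb * ut for ab, xi, bb in zip(a_bar, x, b_bar)]
        # 4. readout with the input-dependent C_t
        ys.append(sum(c * xi for c, xi in zip(c_t, x)))
    return ys

ys = selective_scan([1.0, 0.5, 2.0], a_diag=[-0.9, -0.8], w_delta=0.5,
                    w_b=[1.0, 1.0], w_c=[1.0, 1.0])
print(ys)
```

Because B_t and C_t are themselves multiplied by u_t here, the outputs differ from the worked example, which used fixed B and C.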