Worked Example: Complete Mamba Forward Pass
Let’s trace, by hand, the selective state-space (SSM) layer at the heart of a minimal Mamba model as it processes a 3-token sequence.
Scenario
We’re processing the sequence: “The cat sat”
u₁ = "The" (embedding)
u₂ = "cat" (embedding)
u₃ = "sat" (embedding)
For simplicity, represent each as a scalar (dimension 1):
u₁ = 0.5 (representing "The")
u₂ = 1.0 (representing "cat")
u₃ = 0.2 (representing "sat")
Model Setup
State dimension: N = 2
Input dimension: d = 1
Fixed (input-independent) state matrix: A = [[-1.0, 0], [0, -0.5]] (diagonal with negative entries, so the recurrence is stable)
Learned weights (fixed for this example):
W_Δ = [0.5, -0.3] (projects input to Δ; in this toy the two weighted terms are summed, so W_Δ · u = 0.5u - 0.3u = 0.2u)
W_B = [[0.8], [0.6]]
W_C = [[0.7], [0.4]]
Initial state: x₀ = [0, 0]
Processing Token 1: u₁ = 0.5 (“The”)
Step 1a: Compute Selective Parameters
Δ₁ = softplus(W_Δ · u₁)
= softplus(0.5 × 0.5 + (-0.3) × 0.5)
= softplus(0.25 - 0.15)
= softplus(0.1)
= log(1 + exp(0.1))
≈ log(1 + 1.105)
= log(2.105)
≈ 0.744
B₁ = W_B · u₁ = [[0.8], [0.6]] × 0.5 = [[0.4], [0.3]]
C₁ = W_C · u₁ = [[0.7], [0.4]] × 0.5 = [[0.35], [0.2]]
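Step 1a can be sketched in plain Python for this toy setup. It is a minimal illustration under stated assumptions, not Mamba's implementation: W_B and W_C are stored as flat lists (valid only because d = 1), the two weighted W_Δ terms are summed exactly as in the hand calculation, and `softplus` and `selective_params` are hypothetical helper names.

```python
import math

# Toy weights from the Model Setup (d = 1, N = 2).
W_DELTA = [0.5, -0.3]  # the two weighted terms are summed: W_Δ · u = 0.2u
W_B = [0.8, 0.6]       # column [[0.8], [0.6]] stored as a flat list
W_C = [0.7, 0.4]       # column [[0.7], [0.4]] stored as a flat list


def softplus(z: float) -> float:
    """softplus(z) = log(1 + exp(z)); always positive, so Δ > 0."""
    return math.log1p(math.exp(z))


def selective_params(u: float) -> tuple[float, list[float], list[float]]:
    """Hypothetical helper: the input-dependent (Δ, B, C) for one scalar token u."""
    delta = softplus(sum(w * u for w in W_DELTA))
    b = [w * u for w in W_B]
    c = [w * u for w in W_C]
    return delta, b, c


delta_1, b_1, c_1 = selective_params(0.5)
print(f"Δ₁ = {delta_1:.3f}, B₁ = {b_1}, C₁ = {c_1}")
```

Calling `selective_params(1.0)` and `selective_params(0.2)` gives Δ ≈ 0.798 and Δ ≈ 0.713 for the later tokens.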
Step 1b: Discretise State Transition
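Ā₁ = e^(Δ₁ × A)
= e^(0.744 × [[-1.0, 0], [0, -0.5]])
= e^([[-0.744, 0], [0, -0.372]])
= [[e^(-0.744), 0], [0, e^(-0.372)]]   (A is diagonal, so the matrix exponential just exponentiates each diagonal entry)
≈ [[0.475, 0], [0, 0.689]]
B̄₁ = Δ₁ × B₁   (the simplified discretisation B̄ ≈ ΔB, a first-order approximation of the zero-order-hold rule)
= 0.744 × [[0.4], [0.3]]
≈ [[0.298], [0.223]]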
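Because A is diagonal, e^(ΔA) only needs the exponential of each diagonal entry. The sketch below assumes that, stores A by its diagonal, and applies the same simplified rule B̄ = ΔB as the text; `discretise` is a hypothetical helper name.

```python
import math

A_DIAG = [-1.0, -0.5]  # diagonal of A; off-diagonal entries are zero


def discretise(delta: float, b: list[float]) -> tuple[list[float], list[float]]:
    """Hypothetical helper: diagonal of Ā = exp(ΔA) and the simplified B̄ = ΔB."""
    a_bar = [math.exp(delta * a) for a in A_DIAG]  # elementwise; valid only for diagonal A
    b_bar = [delta * b_n for b_n in b]
    return a_bar, b_bar


delta_1 = math.log1p(math.exp(0.1))                 # Δ₁ = softplus(0.1) from Step 1a
a_bar_1, b_bar_1 = discretise(delta_1, [0.4, 0.3])  # B₁ = [0.4, 0.3]
print([round(v, 3) for v in a_bar_1], [round(v, 3) for v in b_bar_1])
```

It gives Ā₁ ≈ [0.475, 0.689] and B̄₁ ≈ [0.298, 0.223], matching the values above.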
Step 1c: Update State
x₁ = Ā₁ x₀ + B̄₁ u₁
= [[0.475, 0], [0, 0.689]] × [[0], [0]] + [[0.298], [0.223]] × 0.5
= [[0], [0]] + [[0.149], [0.112]]
= [[0.149], [0.112]]
Verification:
x₁[0] = 0.475 × 0 + 0.298 × 0.5 = 0.149 ✓
x₁[1] = 0.689 × 0 + 0.223 × 0.5 ≈ 0.112 ✓
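With a diagonal Ā, the product Ā x₀ decouples into one multiply per state channel, so each channel updates as x[n] = Ā[n] · x_prev[n] + B̄[n] · u. A minimal sketch under that diagonal assumption, with the rounded Step 1b values hard-coded; `update_state` is a hypothetical name.

```python
def update_state(x_prev: list[float], a_bar: list[float],
                 b_bar: list[float], u: float) -> list[float]:
    """Hypothetical helper: x_t = Ā x_{t-1} + B̄ u, channel by channel (diagonal Ā only)."""
    return [a_n * x_n + b_n * u for a_n, x_n, b_n in zip(a_bar, x_prev, b_bar)]


x_0 = [0.0, 0.0]          # initial state
a_bar_1 = [0.475, 0.689]  # diagonal of Ā₁ (Step 1b, rounded)
b_bar_1 = [0.298, 0.223]  # B̄₁ (Step 1b, rounded)
x_1 = update_state(x_0, a_bar_1, b_bar_1, 0.5)
print([round(v, 4) for v in x_1])
```

Exactly, 0.298 × 0.5 = 0.149 and 0.223 × 0.5 = 0.1115, which the text rounds to 0.112.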
Step 1d: Compute Output
y₁ = C₁ · x₁
= [[0.35], [0.2]] · [[0.149], [0.112]]
= 0.35 × 0.149 + 0.2 × 0.112
= 0.052 + 0.022
= 0.074
Output for token 1: 0.074
(In a full Mamba model this output is gated, projected and passed through further layers before any next-token prediction; here we track only the SSM output.)
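Steps 1a to 1d for the first token fit in a few lines. This sketch keeps full floating-point precision instead of rounding at each step, and assumes the toy weights, diagonal A, summed W_Δ and simplified B̄ = ΔB used above.

```python
import math

u_1 = 0.5
delta_1 = math.log1p(math.exp((0.5 - 0.3) * u_1))        # Δ₁ = softplus(W_Δ · u₁), summed W_Δ
a_bar_1 = [math.exp(delta_1 * a) for a in (-1.0, -0.5)]  # diagonal of Ā₁ = exp(Δ₁ A)
b_bar_1 = [delta_1 * w * u_1 for w in (0.8, 0.6)]        # B̄₁ = Δ₁ B₁, with B₁ = W_B u₁
c_1 = [w * u_1 for w in (0.7, 0.4)]                      # C₁ = W_C u₁

x_0 = [0.0, 0.0]
x_1 = [a * x + b * u_1 for a, x, b in zip(a_bar_1, x_0, b_bar_1)]  # x₁ = Ā₁ x₀ + B̄₁ u₁
y_1 = sum(c * x for c, x in zip(c_1, x_1))                          # y₁ = C₁ · x₁
print(f"x₁ = {[round(v, 3) for v in x_1]}, y₁ = {y_1:.3f}")
```

It prints x₁ ≈ [0.149, 0.112] and y₁ ≈ 0.074, matching the hand calculation.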
State after processing “The”:
x₁ = [0.149, 0.112]
Processing Token 2: u₂ = 1.0 (“cat”)
Step 2a: Compute Selective Parameters
Δ₂ = softplus(W_Δ · u₂)
= softplus(0.5 × 1.0 + (-0.3) × 1.0)
= softplus(0.5 - 0.3)
= softplus(0.2)
= log(1 + exp(0.2))
≈ log(1 + 1.221)
= log(2.221)
≈ 0.798
B₂ = W_B · u₂ = [[0.8], [0.6]] × 1.0 = [[0.8], [0.6]]
C₂ = W_C · u₂ = [[0.7], [0.4]] × 1.0 = [[0.7], [0.4]]
Step 2b: Discretise
Ā₂ = e^(0.798 × [[-1.0, 0], [0, -0.5]])
= e^([[-0.798, 0], [0, -0.399]])
= [[e^(-0.798), 0], [0, e^(-0.399)]]
≈ [[0.450, 0], [0, 0.671]]
B̄₂ = 0.798 × [[0.8], [0.6]]
≈ [[0.638], [0.479]]
Step 2c: Update State (Using Previous State x₁)
x₂ = Ā₂ x₁ + B̄₂ u₂
= [[0.450, 0], [0, 0.671]] × [[0.149], [0.112]] + [[0.638], [0.479]] × 1.0
Compute Ā₂ x₁:
Ā₂ x₁[0] = 0.450 × 0.149 + 0 × 0.112 ≈ 0.067
Ā₂ x₁[1] = 0 × 0.149 + 0.671 × 0.112 ≈ 0.075
Ā₂ x₁ = [[0.067], [0.075]]
x₂ = [[0.067], [0.075]] + [[0.638], [0.479]]
= [[0.705], [0.554]]
Interpretation: The previous state [0.149, 0.112] decayed to [0.067, 0.075], then the large input “cat” (u₂ = 1.0) added [0.638, 0.479], giving [0.705, 0.554].
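This interpretation splits x₂ into a decayed carry-over Ā₂x₁ and a fresh write B̄₂u₂. The sketch below, using the rounded hand values, measures how much of each state channel comes from “cat” (variable names are illustrative).

```python
a_bar_2 = [0.450, 0.671]  # diagonal of Ā₂
b_bar_2 = [0.638, 0.479]  # B̄₂
x_1 = [0.149, 0.112]      # state after "The"
u_2 = 1.0

carry = [a * x for a, x in zip(a_bar_2, x_1)]  # Ā₂ x₁: what survives from "The"
write = [b * u_2 for b in b_bar_2]             # B̄₂ u₂: what "cat" adds
x_2 = [c + w for c, w in zip(carry, write)]
share_from_cat = [w / x for w, x in zip(write, x_2)]
print([round(v, 3) for v in x_2], [round(s, 2) for s in share_from_cat])
```

Roughly 90% of the first channel and 86% of the second come from the new write.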
Step 2d: Compute Output
y₂ = C₂ · x₂
= [[0.7], [0.4]] · [[0.705], [0.554]]
= 0.7 × 0.705 + 0.4 × 0.554
= 0.4935 + 0.2216
≈ 0.715
Output for token 2: 0.715
State after “The cat”:
x₂ = [0.705, 0.554]
Processing Token 3: u₃ = 0.2 (“sat”)
Step 3a: Compute Selective Parameters
Δ₃ = softplus(W_Δ · u₃)
= softplus(0.5 × 0.2 + (-0.3) × 0.2)
= softplus(0.1 - 0.06)
= softplus(0.04)
= log(1 + exp(0.04))
≈ log(1 + 1.041)
= log(2.041)
≈ 0.713
B₃ = W_B · u₃ = [[0.8], [0.6]] × 0.2 = [[0.16], [0.12]]
C₃ = W_C · u₃ = [[0.7], [0.4]] × 0.2 = [[0.14], [0.08]]
Step 3b: Discretise
Ā₃ = e^(0.713 × [[-1.0, 0], [0, -0.5]])
≈ e^([[-0.713, 0], [0, -0.357]])
= [[e^(-0.713), 0], [0, e^(-0.357)]]
≈ [[0.490, 0], [0, 0.700]]
B̄₃ = 0.713 × [[0.16], [0.12]]
≈ [[0.114], [0.086]]
Step 3c: Update State (Using Previous State x₂)
x₃ = Ā₃ x₂ + B̄₃ u₃
= [[0.490, 0], [0, 0.700]] × [[0.705], [0.554]] + [[0.114], [0.086]] × 0.2
Compute Ā₃ x₂:
Ā₃ x₂[0] = 0.490 × 0.705 + 0 × 0.554 ≈ 0.345
Ā₃ x₂[1] = 0 × 0.705 + 0.700 × 0.554 ≈ 0.388
Ā₃ x₂ = [[0.345], [0.388]]
Compute B̄₃ u₃:
B̄₃ u₃ = [[0.114], [0.086]] × 0.2 ≈ [[0.023], [0.017]]
x₃ = [[0.345], [0.388]] + [[0.023], [0.017]]
= [[0.368], [0.405]]
Interpretation: The previous state [0.705, 0.554] decayed to [0.345, 0.388], then the small input “sat” (u₃ = 0.2) added only [0.023, 0.017], giving [0.368, 0.405].
Step 3d: Compute Output
y₃ = C₃ · x₃
= [[0.14], [0.08]] · [[0.368], [0.405]]
= 0.14 × 0.368 + 0.08 × 0.405
≈ 0.052 + 0.032
= 0.084
Output for token 3: 0.084
Final state:
x₃ = [0.368, 0.405]
Summary Table
| Token | Input u | Δ | Ā diagonal (eigenvalues) | State x | Output y | Interpretation |
|---|---|---|---|---|---|---|
| “The” | 0.5 | 0.744 | (0.475, 0.689) | (0.149, 0.112) | 0.074 | Initial encoding of “The” |
| “cat” | 1.0 | 0.798 | (0.450, 0.671) | (0.705, 0.554) | 0.715 | Large input “cat” dominates the state |
| “sat” | 0.2 | 0.713 | (0.490, 0.700) | (0.368, 0.405) | 0.084 | Small input; state is mostly the decayed carry-over of “The cat” |
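The whole table can be regenerated by running the recurrence over the three inputs. This is a minimal sketch of the toy selective SSM above, not the Mamba library: it assumes a diagonal A stored as a list, the summed W_Δ, the simplified B̄ = ΔB, and a hypothetical `selective_ssm` helper, and it keeps full precision throughout.

```python
import math

# Toy parameters from the Model Setup (d = 1, N = 2, diagonal A).
A_DIAG = [-1.0, -0.5]
W_DELTA = [0.5, -0.3]  # summed, so W_Δ · u = 0.2u in this toy
W_B = [0.8, 0.6]
W_C = [0.7, 0.4]


def selective_ssm(inputs: list[float]) -> list[dict]:
    """Hypothetical helper: run the toy selective SSM over scalar inputs, one token at a time."""
    x = [0.0, 0.0]  # x₀
    rows = []
    for u in inputs:
        delta = math.log1p(math.exp(sum(w * u for w in W_DELTA)))  # Δ_t = softplus(W_Δ · u_t)
        b = [w * u for w in W_B]                                   # B_t = W_B u_t
        c = [w * u for w in W_C]                                   # C_t = W_C u_t
        a_bar = [math.exp(delta * a) for a in A_DIAG]              # Ā_t = exp(Δ_t A), diagonal
        b_bar = [delta * b_n for b_n in b]                         # B̄_t = Δ_t B_t
        x = [a_n * x_n + bb_n * u for a_n, x_n, bb_n in zip(a_bar, x, b_bar)]
        y = sum(c_n * x_n for c_n, x_n in zip(c, x))               # y_t = C_t · x_t
        rows.append({"delta": delta, "a_bar": a_bar, "x": x, "y": y})
    return rows


for token, row in zip(["The", "cat", "sat"], selective_ssm([0.5, 1.0, 0.2])):
    print(f"{token}: Δ={row['delta']:.3f} Ā={[round(v, 3) for v in row['a_bar']]} "
          f"x={[round(v, 3) for v in row['x']]} y={row['y']:.3f}")
```

Its values agree with the table to within 0.001; the last-digit differences (for example x₂[0] ≈ 0.7055 versus 0.705) come from rounding each intermediate value in the hand calculation. The loop does a fixed amount of work per token, which is where the linear-time cost comes from.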
Key Observations
1. Selectivity in Action
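- Token “cat” (largest input, u=1.0): largest Δ=0.798 → strongest decay of the previous state (Ā diagonal 0.450, 0.671) and the largest write step B̄ = ΔB
- Token “sat” (smallest input, u=0.2): smallest Δ=0.713 → weakest decay (Ā diagonal 0.490, 0.700), so more of the existing state is kept
Because A has negative entries, Ā = e^(ΔA) shrinks as Δ grows: a large Δ overwrites the state with the current input, while a small Δ mostly preserves the state and writes little. With these hand-picked weights Δ simply grows with u (W_Δ · u = 0.2u), so the effect is mild; in a trained model, the learned Δ projection is what lets Mamba decide, token by token, what to remember and what to skip.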
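The direction of the effect follows from Ā = e^(ΔA) with a negative diagonal A. A short sketch, assuming the toy's summed W_Δ so that Δ = softplus(0.2u), sweeps the three inputs in increasing order; `delta_and_decay` is a hypothetical helper.

```python
import math

A_DIAG = [-1.0, -0.5]


def delta_and_decay(u: float) -> tuple[float, list[float]]:
    """Hypothetical helper: Δ = softplus(0.2u) (summed toy W_Δ) and the diagonal of Ā = exp(ΔA)."""
    delta = math.log1p(math.exp(0.2 * u))
    return delta, [math.exp(delta * a) for a in A_DIAG]


for u in (0.2, 0.5, 1.0):  # "sat", "The", "cat" in increasing order of u
    delta, a_bar = delta_and_decay(u)
    print(f"u={u}: Δ={delta:.3f}, Ā diagonal={[round(v, 3) for v in a_bar]}")
```

As u rises from 0.2 to 1.0, Δ rises (0.713, 0.744, 0.798) and every Ā entry falls. For the A = -1 channel this is visible in closed form: e^(-softplus(z)) = 1/(1 + e^z), which decreases as z grows.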
2. State Accumulation
- After “The”: state = [0.149, 0.112] (small, just the first word)
- After “The cat”: state = [0.705, 0.554] (much larger, dominated by “cat”)
- After “The cat sat”: state = [0.368, 0.405] (smaller again: “sat” adds only [0.023, 0.017], and the rest is the decayed contribution of “The cat”)
The state is a running, exponentially decaying summary of past inputs, weighted toward recent large inputs.
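Unrolling the recurrence makes the running summary explicit: x₃ = Ā₃Ā₂B̄₁u₁ + Ā₃B̄₂u₂ + B̄₃u₃, so each token's write is scaled by the decay of every later step. The sketch computes each token's share of x₃ at full precision under the same toy assumptions (summed W_Δ, diagonal A, B̄ = ΔB); the names are illustrative.

```python
import math

A_DIAG = [-1.0, -0.5]
W_B = [0.8, 0.6]
inputs = {"The": 0.5, "cat": 1.0, "sat": 0.2}  # dict order = sequence order

# Per-token Δ_t (summed toy W_Δ, so W_Δ · u = 0.2u), diagonal of Ā_t, and the write B̄_t u_t.
deltas = {t: math.log1p(math.exp(0.2 * u)) for t, u in inputs.items()}
a_bars = {t: [math.exp(d * a) for a in A_DIAG] for t, d in deltas.items()}
writes = {t: [deltas[t] * w * u * u for w in W_B] for t, u in inputs.items()}  # Δ_t (W_B u_t) u_t

tokens = list(inputs)
contributions = {}
for i, t in enumerate(tokens):
    contrib = writes[t]
    for later in tokens[i + 1:]:  # decayed by every later step's Ā
        contrib = [a * c for a, c in zip(a_bars[later], contrib)]
    contributions[t] = contrib

x_3 = [sum(c[n] for c in contributions.values()) for n in range(len(A_DIAG))]
for t, c in contributions.items():
    print(t, [round(v, 3) for v in c])
print("x₃ =", [round(v, 3) for v in x_3])
```

More than 80% of each channel of x₃ comes from “cat”, even though “sat” is the most recent token.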
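3. Output Magnitude
- y₁ = 0.074 (small: the state holds only one modest input)
- y₂ = 0.715 (large: strong state and strong readout C₂)
- y₃ = 0.084 (small, although the state is still sizeable)
Both the state update (through B = W_B · u) and the readout (through C = W_C · u) scale with the input, so “cat” produces the strongest signal. y₃ is small mainly because the readout C₃ is small, not because the memory of “The cat” has been erased.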
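To separate the two effects behind y₃, this sketch reads the same final state twice: with the actual input-dependent readout C₃ = W_C · u₃, and with a hypothetical unscaled readout W_C (as if u₃ were 1.0). The second readout is only a comparison, not part of the model.

```python
W_C = [0.7, 0.4]
x_3 = [0.368, 0.405]  # final state from the hand calculation
u_3 = 0.2


def dot(a: list[float], b: list[float]) -> float:
    """Plain dot product over the N = 2 state channels."""
    return sum(p * q for p, q in zip(a, b))


y_3 = dot([w * u_3 for w in W_C], x_3)  # actual readout: C₃ = W_C u₃
y_unscaled = dot(W_C, x_3)              # hypothetical readout, not scaled by u₃
print(f"y₃ = {y_3:.3f}, hypothetical unscaled readout = {y_unscaled:.3f}")
```

The unscaled readout gives about 0.420 versus y₃ ≈ 0.084, a factor of exactly 1/u₃ = 5.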
What This Represents
In a full Mamba model:
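- The SSM outputs [0.074, 0.715, 0.084] would get a skip term D·u added and be multiplied by a SiLU-activated gate computed from a separate projection of the input
- An output projection then maps them back to the model dimension, and the result flows through a residual connection into the next Mamba block
- After the last block, a language-model head produces vocabulary logits, and softmax turns them into next-token probabilities
The core mechanism (selective memory, input-dependent decay, and linear-time recurrent processing) is nonetheless captured in this simple example.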
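The gating and softmax steps listed above can be sketched as follows. Only the three SSM outputs come from this example. The skip weight D = 1.0, the scalar gate weight W_Z = 0.9 and the logits are made-up values, and computing the gate directly from the scalar u_t is a simplification of Mamba's separate input projection. The sketch includes the skip term D·u that Mamba's selective scan adds; `silu` and `softmax` are the standard definitions.

```python
import math


def silu(v: float) -> float:
    """SiLU(v) = v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def softmax(logits: list[float]) -> list[float]:
    """Turn logits into probabilities; subtracting the max avoids overflow."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


inputs = [0.5, 1.0, 0.2]       # u_t for "The", "cat", "sat"
y_ssm = [0.074, 0.715, 0.084]  # SSM outputs from the worked example
D = 1.0                        # hypothetical skip weight
W_Z = 0.9                      # hypothetical scalar gate projection

# Gated output before the output projection: (y_t + D u_t) * SiLU(z_t), with z_t = W_Z u_t.
gated = [(y + D * u) * silu(W_Z * u) for y, u in zip(y_ssm, inputs)]
print([round(g, 3) for g in gated])

probs = softmax([2.0, 0.5, -1.0])  # hypothetical logits for a 3-word vocabulary
print([round(p, 3) for p in probs])
```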