Section 05

Worked Example: Complete Mamba Forward Pass

Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023)


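Let’s trace the selective state-space (SSM) core of a minimal Mamba layer by hand as it processes a 3-token sequence. Values are rounded to three decimal places at each step, so the last digit can differ slightly from a full-precision calculation.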


Scenario

We’re processing the sequence: “The cat sat”

u₁ = "The"  (embedding)
u₂ = "cat"  (embedding)
u₃ = "sat"  (embedding)

For simplicity, represent each as a scalar (dimension 1):
u₁ = 0.5   (representing "The")
u₂ = 1.0   (representing "cat")
u₃ = 0.2   (representing "sat")

Model Setup

State dimension: N = 2
Input dimension: d = 1
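Fixed state matrix: A = [[-1.0, 0], [0, -0.5]]  (diagonal, stable; unlike Δ, B and C it does not depend on the input)

Learned weights (fixed for this example):
  W_Δ = [0.5, -0.3]  (projects input to Δ; both entries multiply u and are summed, so W_Δ · u = 0.2u)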
  W_B = [[0.8], [0.6]]
  W_C = [[0.7], [0.4]]

Initial state: x₀ = [0, 0]
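The "(diagonal, stable)" label holds for every input, not just the three used here. Because A is diagonal, its matrix exponential is just the entry-wise exponential. Δ = softplus(·) is always positive and both diagonal entries of A are negative, so every entry of Ā lies strictly between 0 and 1:

```latex
\bar{A}_t = e^{\Delta_t A} =
\begin{pmatrix} e^{-\Delta_t} & 0 \\ 0 & e^{-0.5\,\Delta_t} \end{pmatrix},
\qquad
\Delta_t > 0 \;\Longrightarrow\; 0 < e^{-\Delta_t} < e^{-0.5\,\Delta_t} < 1 .
```

The previous state is therefore always scaled down, never amplified, and the first state component always forgets faster than the second.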

Processing Token 1: u₁ = 0.5 (“The”)

Step 1a: Compute Selective Parameters

Δ₁ = softplus(W_Δ · u₁)
   = softplus([0.5, -0.3] · 0.5)
   = softplus(0.25 - 0.15)
   = softplus(0.1)
   = log(1 + exp(0.1))
   ≈ log(1 + 1.105)
   ≈ log(2.105)
   ≈ 0.744

B₁ = W_B · u₁ = [[0.8], [0.6]] × 0.5 = [[0.4], [0.3]]

C₁ = W_C · u₁ = [[0.7], [0.4]] × 0.5 = [[0.35], [0.2]]

Step 1b: Discretise State Transition

Ā₁ = e^(Δ₁ × A)   (A is diagonal, so this exponentiates each diagonal entry)
   = e^(0.744 × [[-1.0, 0], [0, -0.5]])
   = e^([[-0.744, 0], [0, -0.372]])
   = [[e^(-0.744), 0], [0, e^(-0.372)]]
   = [[0.475, 0], [0, 0.689]]

B̄₁ = Δ₁ × B₁   (simplified first-order discretisation of B)
   = 0.744 × [[0.4], [0.3]]
   = [[0.298], [0.223]]

Step 1c: Update State

x₁ = Ā₁ x₀ + B̄₁ u₁
   = [[0.475, 0], [0, 0.689]] × [[0], [0]] + [[0.298], [0.223]] × 0.5
   = [[0], [0]] + [[0.149], [0.112]]
   = [[0.149], [0.112]]

Verification:
  x₁[0] = 0.475 × 0 + 0.298 × 0.5 = 0.149 ✓
  x₁[1] = 0.689 × 0 + 0.223 × 0.5 = 0.112 ✓

Step 1d: Compute Output

y₁ = C₁ · x₁
   = [[0.35], [0.2]] · [[0.149], [0.112]]
   = 0.35 × 0.149 + 0.2 × 0.112
   = 0.052 + 0.022
   = 0.074

Output for token 1: 0.074
(In a full model this output is gated, projected and passed through further layers
 before a next-token prediction is made; here we track only the SSM output.)
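Tokens 2 and 3 repeat exactly these four steps, each with a new input and the previous state. Written once in general form with the symbols defined above, as a sketch of the per-token equations used in this trace:

```latex
\begin{aligned}
&\text{(a)}\quad \Delta_t = \operatorname{softplus}(W_\Delta \cdot u_t), \qquad B_t = W_B\,u_t, \qquad C_t = W_C\,u_t \\
&\text{(b)}\quad \bar{A}_t = e^{\Delta_t A}, \qquad \bar{B}_t = \Delta_t B_t \\
&\text{(c)}\quad x_t = \bar{A}_t\,x_{t-1} + \bar{B}_t\,u_t \\
&\text{(d)}\quad y_t = C_t^{\top} x_t
\end{aligned}
```

Step (b) uses the simplified \bar{B}_t = \Delta_t B_t from this trace. Exact zero-order-hold discretisation would instead give \bar{B}_t = (\Delta_t A)^{-1}(e^{\Delta_t A} - I)\,\Delta_t B_t, which is close to \Delta_t B_t when \Delta_t A is small.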

State after processing “The”:

x₁ = [0.149, 0.112]

Processing Token 2: u₂ = 1.0 (“cat”)

Step 2a: Compute Selective Parameters

Δ₂ = softplus(W_Δ · u₂)
   = softplus([0.5, -0.3] · 1.0)
   = softplus(0.5 - 0.3)
   = softplus(0.2)
   = log(1 + exp(0.2))
   ≈ log(1 + 1.221)
   ≈ log(2.221)
   ≈ 0.798

B₂ = W_B · u₂ = [[0.8], [0.6]] × 1.0 = [[0.8], [0.6]]

C₂ = W_C · u₂ = [[0.7], [0.4]] × 1.0 = [[0.7], [0.4]]

Step 2b: Discretise

Ā₂ = e^(0.798 × [[-1.0, 0], [0, -0.5]])
   = e^([[-0.798, 0], [0, -0.399]])
   = [[e^(-0.798), 0], [0, e^(-0.399)]]
   = [[0.450, 0], [0, 0.671]]

B̄₂ = 0.798 × [[0.8], [0.6]]
   = [[0.638], [0.479]]

Step 2c: Update State (Using Previous State x₁)

x₂ = Ā₂ x₁ + B̄₂ u₂
   = [[0.450, 0], [0, 0.671]] × [[0.149], [0.112]] + [[0.638], [0.479]] × 1.0

Compute Ā₂ x₁:
  Ā₂ x₁[0] = 0.450 × 0.149 + 0 × 0.112 = 0.067
  Ā₂ x₁[1] = 0 × 0.149 + 0.671 × 0.112 = 0.075
  Ā₂ x₁ = [[0.067], [0.075]]

x₂ = [[0.067], [0.075]] + [[0.638], [0.479]]
   = [[0.705], [0.554]]

Interpretation: The previous state [0.149, 0.112] decayed to [0.067, 0.075], then the strong new input “cat” (u₂ = 1.0) added [0.638, 0.479], resulting in [0.705, 0.554].

Step 2d: Compute Output

y₂ = C₂ · x₂
   = [[0.7], [0.4]] · [[0.705], [0.554]]
   = 0.7 × 0.705 + 0.4 × 0.554
   = 0.4935 + 0.2216
   ≈ 0.715

Output for token 2: 0.715

State after “The cat”:

x₂ = [0.705, 0.554]

Processing Token 3: u₃ = 0.2 (“sat”)

Step 3a: Compute Selective Parameters

Δ₃ = softplus(W_Δ · u₃)
   = softplus([0.5, -0.3] · 0.2)
   = softplus(0.1 - 0.06)
   = softplus(0.04)
   = log(1 + exp(0.04))
   ≈ log(1 + 1.041)
   ≈ log(2.041)
   ≈ 0.713

B₃ = W_B · u₃ = [[0.8], [0.6]] × 0.2 = [[0.16], [0.12]]

C₃ = W_C · u₃ = [[0.7], [0.4]] × 0.2 = [[0.14], [0.08]]

Step 3b: Discretise

Ā₃ = e^(0.713 × [[-1.0, 0], [0, -0.5]])
   = e^([[-0.713, 0], [0, -0.357]])
   = [[e^(-0.713), 0], [0, e^(-0.357)]]
   = [[0.490, 0], [0, 0.700]]

B̄₃ = 0.713 × [[0.16], [0.12]]
   = [[0.114], [0.086]]

Step 3c: Update State (Using Previous State x₂)

x₃ = Ā₃ x₂ + B̄₃ u₃
   = [[0.490, 0], [0, 0.700]] × [[0.705], [0.554]] + [[0.114], [0.086]] × 0.2

Compute Ā₃ x₂:
  Ā₃ x₂[0] = 0.490 × 0.705 + 0 × 0.554 = 0.345
  Ā₃ x₂[1] = 0 × 0.705 + 0.700 × 0.554 = 0.388
  Ā₃ x₂ = [[0.345], [0.388]]

Compute B̄₃ u₃:
  B̄₃ u₃ = [[0.114], [0.086]] × 0.2 = [[0.023], [0.017]]

x₃ = [[0.345], [0.388]] + [[0.023], [0.017]]
   = [[0.368], [0.405]]

Interpretation: The strong previous state [0.705, 0.554] decayed to [0.345, 0.388], then the weak input “sat” (u₃ = 0.2) added only [0.023, 0.017], resulting in [0.368, 0.405].

Step 3d: Compute Output

y₃ = C₃ · x₃
   = [[0.14], [0.08]] · [[0.368], [0.405]]
   = 0.14 × 0.368 + 0.08 × 0.405
   = 0.052 + 0.032
   = 0.084

Output for token 3: 0.084

Final state:

x₃ = [0.368, 0.405]
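The whole trace can be reproduced with a short loop. This is a minimal pure-Python sketch of the toy recurrence only, assuming the hand-picked weights from Model Setup and the simplified B̄ = ΔB; `toy_selective_scan` is a hypothetical name, not the API of any Mamba library, and it runs sequentially, whereas the Mamba paper computes this recurrence with a hardware-aware parallel scan.

```python
import math

# Toy parameters from "Model Setup"; A is diagonal, so only its diagonal is stored.
A_DIAG = [-1.0, -0.5]
W_DELTA = [0.5, -0.3]  # both entries multiply u and are summed: W_delta . u = 0.2u
W_B = [0.8, 0.6]
W_C = [0.7, 0.4]


def softplus(z: float) -> float:
    """log(1 + e^z): always positive. Naive form, fine for the small z used here."""
    return math.log1p(math.exp(z))


def toy_selective_scan(inputs):
    """Hypothetical helper: run the toy recurrence, returning (delta, A_bar, x, y) per token."""
    x = [0.0] * len(A_DIAG)  # x0 = [0, 0]
    trace = []
    for u in inputs:
        delta = softplus(sum(w * u for w in W_DELTA))  # step a: input-dependent step size
        B = [w * u for w in W_B]                       # step a: input-dependent B_t
        C = [w * u for w in W_C]                       # step a: input-dependent C_t
        A_bar = [math.exp(delta * a) for a in A_DIAG]  # step b: e^(delta*A), entry-wise for diagonal A
        B_bar = [delta * b for b in B]                 # step b: simplified B_bar = delta * B
        x = [a * xi + b * u for a, xi, b in zip(A_bar, x, B_bar)]  # step c: state update
        y = sum(c * xi for c, xi in zip(C, x))         # step d: readout y_t = C_t . x_t
        trace.append((delta, A_bar, x, y))
    return trace


for word, (delta, A_bar, x, y) in zip(["The", "cat", "sat"], toy_selective_scan([0.5, 1.0, 0.2])):
    print(f"{word}: delta={delta:.3f} A_bar={[round(a, 3) for a in A_bar]} "
          f"x={[round(v, 3) for v in x]} y={y:.3f}")
```

Because the loop keeps full precision, its results can differ from the hand-rounded trace in the last digit. For example, the final state comes out as about [0.3685, 0.4048].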

Summary Table

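| Token | Input u | Δ | Ā diagonal (eigenvalues) | State x | Output y | Interpretation |
|---|---|---|---|---|---|---|
| “The” | 0.5 | 0.744 | (0.475, 0.689) | (0.149, 0.112) | 0.074 | Initial encoding of “The” |
| “cat” | 1.0 | 0.798 | (0.450, 0.671) | (0.705, 0.554) | 0.715 | Strong input “cat” dominates the state |
| “sat” | 0.2 | 0.713 | (0.490, 0.700) | (0.368, 0.405) | 0.084 | Weak input; state is mostly the decayed memory of “cat” |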

Key Observations

1. Selectivity in Action

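  • Token “cat” (strong, u = 1.0): largest Δ = 0.798 → smallest Ā entries (0.450, 0.671), so the old state decays fastest, while the large B̄₂ writes the new input in strongly
  • Token “sat” (weak, u = 0.2): smallest Δ = 0.713 → largest Ā entries (0.490, 0.700), so the existing state is preserved best and little new input is written

Because Δ depends on the current input, each token decides how much of the old state to overwrite: a large Δ means “focus on this token”, a small Δ means “keep what is already stored”. Here W_Δ is hand-picked, so Δ simply grows with u. In a trained model W_Δ is learned, so which tokens get a large Δ is decided by training rather than by input magnitude alone.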
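The direction of this effect can be checked in closed form for the toy model. With z = W_Δ · u = 0.2u, softplus gives e^(−Δ) = 1 / (1 + e^z), so both diagonal entries of Ā are decreasing functions of u:

```latex
\bar{A}(u) = \operatorname{diag}\!\left( \frac{1}{1 + e^{0.2u}},\; \frac{1}{\sqrt{1 + e^{0.2u}}} \right),
\qquad
\bar{A}(0.5) \approx \operatorname{diag}(0.475,\, 0.689),\quad
\bar{A}(1.0) \approx \operatorname{diag}(0.450,\, 0.671),\quad
\bar{A}(0.2) \approx \operatorname{diag}(0.490,\, 0.700).
```

These match the Ā values in the trace. Since the effective weight 0.2 is positive, a larger u in this toy model always gives a larger Δ and therefore faster forgetting of the old state.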

2. State Accumulation

  • After “The”: state = [0.149, 0.112] (small, just the first word)
  • After “The cat”: state = [0.705, 0.554] (much larger, dominated by “cat”)
  • After “The cat sat”: state = [0.368, 0.405] (smaller again: “sat” adds little, and the stored “cat” information has decayed by factors of 0.490 and 0.700)

The state is a running summary: every earlier token still contributes, scaled by the Ā of each later step, so recent strong inputs dominate.
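To see the running summary explicitly, the recurrence can be unrolled: x₃ = Ā₃Ā₂B̄₁u₁ + Ā₃B̄₂u₂ + B̄₃u₃. Each token’s contribution is its own write B̄ₛuₛ, scaled by the decay of every later token. The following minimal pure-Python sketch uses the same toy weights; the helper `discretise` is a hypothetical name. It computes each token’s share of the final state and checks the sum against the step-by-step recurrence:

```python
import math

A_DIAG = [-1.0, -0.5]                  # diagonal of A
W_DELTA, W_B = [0.5, -0.3], [0.8, 0.6]
inputs = [0.5, 1.0, 0.2]               # "The", "cat", "sat"


def discretise(u):
    """Hypothetical helper: (A_bar diagonal, B_bar) for one token, as in the trace."""
    delta = math.log1p(math.exp(sum(w * u for w in W_DELTA)))  # softplus
    A_bar = [math.exp(delta * a) for a in A_DIAG]
    B_bar = [delta * w * u for w in W_B]                        # B_bar = delta * (W_B u)
    return A_bar, B_bar


params = [discretise(u) for u in inputs]

# Unrolled form: x_T = sum over s of (A_bar_T ... A_bar_{s+1}) * B_bar_s * u_s
contributions = []
for s, u in enumerate(inputs):
    decay = [1.0] * len(A_DIAG)
    for A_bar, _ in params[s + 1:]:    # decay applied by every later token
        decay = [d * a for d, a in zip(decay, A_bar)]
    contributions.append([d * b * u for d, b in zip(decay, params[s][1])])

x_unrolled = [sum(c[i] for c in contributions) for i in range(len(A_DIAG))]

# The same state from the step-by-step recurrence x_t = A_bar_t x_{t-1} + B_bar_t u_t.
x = [0.0] * len(A_DIAG)
for (A_bar, B_bar), u in zip(params, inputs):
    x = [a * xi + b * u for a, xi, b in zip(A_bar, x, B_bar)]

for word, c in zip(["The", "cat", "sat"], contributions):
    print(word, [round(v, 3) for v in c])
```

At full precision this prints roughly [0.033, 0.052] for “The”, [0.313, 0.335] for “cat” and [0.023, 0.017] for “sat”. Even after one further decay step, “cat” accounts for more than 80% of each state component.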

3. Output Variance

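  • y₁ = 0.074 (small: first token, so the state holds only “The”)
  • y₂ = 0.715 (largest: strong input “cat” and a large readout C₂)
  • y₃ = 0.084 (small: weak input “sat”)

The outputs track input magnitude, with “cat” producing the strongest signal. Note that y₃ is small mainly because the readout C₃ = W_C u₃ is small, not because the state is empty: reading the same state x₃ with C₂ would give 0.7 × 0.368 + 0.4 × 0.405 ≈ 0.420.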


What This Represents

In a full Mamba model:

  • The SSM outputs y₁, y₂, y₃ would be gated: multiplied element-wise by SiLU of a second, parallel projection of the layer input
  • The gated result is projected back to the model dimension and passed on to the next layer
  • After the last layer, a final projection to vocabulary size produces next-token logits, and softmax turns them into probabilities

The core mechanism is nonetheless all here: selective memory, input-dependent decay, and linear-time processing, since every token costs one fixed-size state update no matter how long the sequence is.
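The following is a minimal sketch of the gating and the final softmax, in plain Python. Everything except the SSM outputs is hypothetical: `gate_z` stands in for the second (gate) projection, `W_vocab` is a made-up 3-word LM head acting on a single number, and the output projection, residual connections, normalisation and layer stacking are all skipped.

```python
import math


def silu(v: float) -> float:
    """SiLU / swish activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def softmax(logits):
    """Turn logits into probabilities (max subtracted for numerical stability)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


ssm_out = [0.074, 0.715, 0.084]  # y1, y2, y3 from the trace
gate_z = [0.3, -0.1, 0.8]        # HYPOTHETICAL gate-branch values (second input projection)

# Gating: multiply each SSM output element-wise by SiLU of its gate value.
gated = [y * silu(z) for y, z in zip(ssm_out, gate_z)]

# HYPOTHETICAL LM head for a 3-word vocabulary, applied at the last position only.
W_vocab = [2.0, -1.0, 0.5]
logits = [w * gated[-1] for w in W_vocab]
probs = softmax(logits)
print([round(p, 3) for p in probs])
```

Because the gated output at the last position is small, the resulting distribution is close to uniform. In a real model the LM head acts on a full model-dimension vector, not on a single scalar.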

