Section 05

Worked example: routing a batch of tokens through 4 experts

Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer (2017)

🔴 Advanced undergrad. We trace a batch of 4 tokens through a 4-expert MoE layer with k=2, then compute the auxiliary balancing loss.


Setup

4 experts, 2-dimensional token representations, k = 2.

The four tokens from the sentence “chai bahut garam hai” (“chai is very hot”):

x₁ = [1.0,  0.2]   ("chai"   — a concrete noun, drink)
x₂ = [0.3,  0.8]   ("bahut"  — an adverb, degree)
x₃ = [0.1,  0.5]   ("garam"  — an adjective, temperature)
x₄ = [0.6,  0.1]   ("hai"    — a verb, copula)

Gating weight matrix W_g (2 dimensions → 4 expert scores):

W_g = [[ 1.0,  0.5, -0.5,  0.2],   ← weights for dim 1
       [-0.2,  0.8,  1.0, -0.3]]   ← weights for dim 2
     shape: (2 × 4)

Expert networks (each is a simple linear map, for illustration):

Expert 1 (E₁): specialises in nouns/concrete objects
  W_E1 = [[1.2, 0.0], [0.0, 0.5]]   → E₁(x) = x · W_E1

Expert 2 (E₂): specialises in degree/quantifier words
  W_E2 = [[0.3, 0.0], [0.0, 1.4]]   → E₂(x) = x · W_E2

Expert 3 (E₃): specialises in descriptive/adjective words
  W_E3 = [[0.2, 0.8], [0.9, 0.1]]   → E₃(x) = x · W_E3

Expert 4 (E₄): specialises in function words/verbs
  W_E4 = [[0.7, 0.3], [0.1, 0.6]]   → E₄(x) = x · W_E4

Step 1: Compute raw gating logits for all tokens

h(xₜ) = xₜ · W_g for each token:

Token 1 "chai" x₁ = [1.0, 0.2]:
h(x₁) = [1.0×1.0 + 0.2×(−0.2),  1.0×0.5 + 0.2×0.8,  1.0×(−0.5) + 0.2×1.0,  1.0×0.2 + 0.2×(−0.3)]
       = [1.0−0.04,  0.5+0.16,  −0.5+0.20,  0.2−0.06]
       = [0.96,  0.66,  −0.30,  0.14]

Token 2 "bahut" x₂ = [0.3, 0.8]:
h(x₂) = [0.3×1.0 + 0.8×(−0.2),  0.3×0.5 + 0.8×0.8,  0.3×(−0.5) + 0.8×1.0,  0.3×0.2 + 0.8×(−0.3)]
       = [0.30−0.16,  0.15+0.64,  −0.15+0.80,  0.06−0.24]
       = [0.14,  0.79,  0.65,  −0.18]

Token 3 "garam" x₃ = [0.1, 0.5]:
h(x₃) = [0.1×1.0 + 0.5×(−0.2),  0.1×0.5 + 0.5×0.8,  0.1×(−0.5) + 0.5×1.0,  0.1×0.2 + 0.5×(−0.3)]
       = [0.10−0.10,  0.05+0.40,  −0.05+0.50,  0.02−0.15]
       = [0.00,  0.45,  0.45,  −0.13]

Token 4 "hai" x₄ = [0.6, 0.1]:
h(x₄) = [0.6×1.0 + 0.1×(−0.2),  0.6×0.5 + 0.1×0.8,  0.6×(−0.5) + 0.1×1.0,  0.6×0.2 + 0.1×(−0.3)]
       = [0.60−0.02,  0.30+0.08,  −0.30+0.10,  0.12−0.03]
       = [0.58,  0.38,  −0.20,  0.09]
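
The four products above are one vector-matrix multiplication per token. A minimal sketch using only the Python standard library follows; the helper name `gating_logits` is an illustrative assumption, and the numbers are copied from the Setup.

```python
# Token vectors and gating matrix copied from the Setup.
X = [
    [1.0, 0.2],  # "chai"
    [0.3, 0.8],  # "bahut"
    [0.1, 0.5],  # "garam"
    [0.6, 0.1],  # "hai"
]
W_g = [
    [1.0, 0.5, -0.5, 0.2],   # weights for dim 1
    [-0.2, 0.8, 1.0, -0.3],  # weights for dim 2
]


def gating_logits(x, w):
    """Row vector times matrix: h[j] = sum over i of x[i] * w[i][j]."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


logits = [gating_logits(x, W_g) for x in X]
for name, h in zip(["chai", "bahut", "garam", "hai"], logits):
    print(name, [round(v, 2) for v in h])
```

Each row of `logits` holds one token's four expert scores, in the same order as the hand calculation.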

Step 2: Apply Top-2 selection (k=2)

Keep the 2 highest scores per token, set others to −∞:

Token 1 "chai":  scores [0.96, 0.66, −0.30, 0.14]
  Top 2: Expert 1 (0.96) ✓, Expert 2 (0.66) ✓
  After TopK: [0.96, 0.66, −∞, −∞]   → Experts 3, 4 skipped

Token 2 "bahut": scores [0.14, 0.79, 0.65, −0.18]
  Top 2: Expert 2 (0.79) ✓, Expert 3 (0.65) ✓
  After TopK: [−∞, 0.79, 0.65, −∞]   → Experts 1, 4 skipped

Token 3 "garam": scores [0.00, 0.45, 0.45, −0.13]
  Top 2: Expert 2 (0.45) ✓, Expert 3 (0.45) ✓  (tied with each other, but both are in the top 2, so no tie-break is needed)
  After TopK: [−∞, 0.45, 0.45, −∞]   → Experts 1, 4 skipped

Token 4 "hai":   scores [0.58, 0.38, −0.20, 0.09]
  Top 2: Expert 1 (0.58) ✓, Expert 2 (0.38) ✓
  After TopK: [0.58, 0.38, −∞, −∞]   → Experts 3, 4 skipped
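
The masking step can be sketched as below, under the assumption that ties are broken by lowest expert index (an arbitrary choice; no tie-break is actually needed for this batch). `top_k_mask` is a hypothetical helper name.

```python
import math


def top_k_mask(scores, k):
    """Keep the k largest scores; replace every other entry with -inf."""
    # Sort indices by descending score. Python's sort is stable, so tied
    # scores stay in lowest-index-first order (an assumed tie-break rule).
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    keep = set(order[:k])
    return [s if i in keep else -math.inf for i, s in enumerate(scores)]


# Logits from Step 1.
logits = {
    "chai":  [0.96, 0.66, -0.30, 0.14],
    "bahut": [0.14, 0.79, 0.65, -0.18],
    "garam": [0.00, 0.45, 0.45, -0.13],
    "hai":   [0.58, 0.38, -0.20, 0.09],
}

masked = {tok: top_k_mask(h, k=2) for tok, h in logits.items()}

# 1-based ids of the experts that survived the mask, per token.
selected = {
    tok: [i + 1 for i, s in enumerate(m) if s != -math.inf]
    for tok, m in masked.items()
}
print(selected)
```

Only the experts listed in `selected` need to be evaluated for each token, which is where the compute saving of sparse routing comes from.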

Step 3: Compute gating weights (softmax over top-2)

Token 1 "chai":  softmax([0.96, 0.66])
  exp(0.96)=2.612, exp(0.66)=1.935, sum=4.547
  G(x₁) = [0.574, 0.426, 0.000, 0.000]

Token 2 "bahut": softmax([0.79, 0.65])
  exp(0.79)=2.203, exp(0.65)=1.916, sum=4.119
  G(x₂) = [0.000, 0.535, 0.465, 0.000]

Token 3 "garam": softmax([0.45, 0.45])
  exp(0.45)=1.568, exp(0.45)=1.568, sum=3.136
  G(x₃) = [0.000, 0.500, 0.500, 0.000]

Token 4 "hai":   softmax([0.58, 0.38])
  exp(0.58)=1.786, exp(0.38)=1.462, sum=3.248
  G(x₄) = [0.550, 0.450, 0.000, 0.000]
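
Because exp(−∞) = 0, a softmax over the masked scores gives the skipped experts exactly zero weight. A small standard-library sketch of that step, with the hypothetical helper name `sparse_softmax`:

```python
import math


def sparse_softmax(masked_scores):
    """Softmax over the finite entries; -inf entries get weight 0."""
    finite = [s for s in masked_scores if s != -math.inf]
    m = max(finite)  # subtracting the max keeps exp() well-behaved
    exps = [0.0 if s == -math.inf else math.exp(s - m) for s in masked_scores]
    total = sum(exps)
    return [e / total for e in exps]


NEG = -math.inf
# Masked scores from Step 2.
masked = {
    "chai":  [0.96, 0.66, NEG, NEG],
    "bahut": [NEG, 0.79, 0.65, NEG],
    "garam": [NEG, 0.45, 0.45, NEG],
    "hai":   [0.58, 0.38, NEG, NEG],
}

gates = {tok: sparse_softmax(m) for tok, m in masked.items()}
for tok, g in gates.items():
    print(tok, [round(v, 3) for v in g])
```

Subtracting the maximum before exponentiating does not change the result; it is a common numerical-stability habit rather than something this toy example requires.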

Step 4: Compute MoE outputs

Token 1 “chai” — Experts 1 and 2 activated:

E₁(x₁) = x₁ · W_E1 = [1.0, 0.2] · [[1.2, 0.0], [0.0, 0.5]]
        = [1.0×1.2 + 0.2×0.0,  1.0×0.0 + 0.2×0.5]
        = [1.200, 0.100]

E₂(x₁) = x₁ · W_E2 = [1.0, 0.2] · [[0.3, 0.0], [0.0, 1.4]]
        = [0.300, 0.280]

MoE(x₁) = 0.574×[1.200, 0.100] + 0.426×[0.300, 0.280]
         = [0.689, 0.057] + [0.128, 0.119]
         = [0.817, 0.176]

Expert 1 (noun specialist) gets the larger share of “chai” at 57.4%, which fits the toy labels: “chai” is a concrete noun.

Token 2 “bahut” — Experts 2 and 3 activated:

E₂(x₂) = [0.3, 0.8] · [[0.3,0],[0,1.4]] = [0.090, 1.120]
E₃(x₂) = [0.3, 0.8] · [[0.2,0.8],[0.9,0.1]] = [0.3×0.2+0.8×0.9, 0.3×0.8+0.8×0.1]
        = [0.060+0.720, 0.240+0.080] = [0.780, 0.320]

MoE(x₂) = 0.535×[0.090, 1.120] + 0.465×[0.780, 0.320]
         = [0.048, 0.599] + [0.363, 0.149]
         = [0.411, 0.748]

Token 3 “garam”, Experts 2 and 3 activated with equal weights:

E₂(x₃) = [0.1, 0.5] · [[0.3,0],[0,1.4]] = [0.030, 0.700]
E₃(x₃) = [0.1, 0.5] · [[0.2,0.8],[0.9,0.1]] = [0.1×0.2+0.5×0.9, 0.1×0.8+0.5×0.1]
        = [0.020+0.450, 0.080+0.050] = [0.470, 0.130]

MoE(x₃) = 0.500×[0.030, 0.700] + 0.500×[0.470, 0.130]
         = [0.015, 0.350] + [0.235, 0.065]
         = [0.250, 0.415]

Token 4 “hai” — Experts 1 and 2 activated (same experts as “chai” but different weights):

E₁(x₄) = [0.6, 0.1] · [[1.2,0],[0,0.5]] = [0.720, 0.050]
E₂(x₄) = [0.6, 0.1] · [[0.3,0],[0,1.4]] = [0.180, 0.140]

MoE(x₄) = 0.550×[0.720, 0.050] + 0.450×[0.180, 0.140]
         = [0.396, 0.028] + [0.081, 0.063]
         = [0.477, 0.091]
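
The weighted combination in this step can be sketched as follows. The expert matrices come from the Setup and the gate values are the rounded results of Step 3; `vec_mat` and `moe_output` are hypothetical helper names. The sketch also covers “garam”, which Step 3 routes to Experts 2 and 3 with equal weights.

```python
def vec_mat(x, w):
    """Row vector times matrix."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


# Expert weight matrices from the Setup, keyed by 1-based expert id.
EXPERTS = {
    1: [[1.2, 0.0], [0.0, 0.5]],
    2: [[0.3, 0.0], [0.0, 1.4]],
    3: [[0.2, 0.8], [0.9, 0.1]],
    4: [[0.7, 0.3], [0.1, 0.6]],
}


def moe_output(x, gates):
    """Sum of gate * expert(x) over the selected experts only.

    `gates` maps expert id -> gating weight; experts that are absent
    from it are never evaluated.
    """
    out = [0.0] * len(x)  # experts here map 2-d inputs to 2-d outputs
    for expert_id, g in gates.items():
        y = vec_mat(x, EXPERTS[expert_id])
        out = [o + g * v for o, v in zip(out, y)]
    return out


# (token vector, {expert id: gate}) with gates rounded as in Step 3.
batch = {
    "chai":  ([1.0, 0.2], {1: 0.574, 2: 0.426}),
    "bahut": ([0.3, 0.8], {2: 0.535, 3: 0.465}),
    "garam": ([0.1, 0.5], {2: 0.500, 3: 0.500}),
    "hai":   ([0.6, 0.1], {1: 0.550, 2: 0.450}),
}

outputs = {tok: moe_output(x, g) for tok, (x, g) in batch.items()}
for tok, y in outputs.items():
    print(tok, [round(v, 3) for v in y])
```

Small differences in the third decimal place relative to the hand calculation are expected, since the worked example rounds each intermediate term before adding.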

Step 5: Compute auxiliary balancing loss

Count how many times each expert was selected (hard routing) across all 4 tokens:

Expert 1: tokens 1, 4         → f₁ = 2/4 = 0.500
Expert 2: tokens 1, 2, 3, 4   → f₂ = 4/4 = 1.000
Expert 3: tokens 2, 3         → f₃ = 2/4 = 0.500
Expert 4: no tokens           → f₄ = 0/4 = 0.000

Expert 2 is selected for all 4 tokens and Expert 4 for none, which is a mild case of expert collapse. Because each token picks k = 2 experts, the fᵢ sum to 2 rather than 1.
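
Counting the hard selections is a short loop. The sketch below assumes the Step 2 routing is held as a mapping from token to selected expert ids; `selection_fractions` is a hypothetical helper name.

```python
# Step 2 routing: token -> 1-based ids of its two selected experts.
selected = {
    "chai":  [1, 2],
    "bahut": [2, 3],
    "garam": [2, 3],
    "hai":   [1, 2],
}
NUM_EXPERTS = 4


def selection_fractions(selected, num_experts):
    """Fraction of tokens that selected each expert (hard routing)."""
    counts = [0] * num_experts
    for experts in selected.values():
        for e in experts:
            counts[e - 1] += 1
    n_tokens = len(selected)
    return [c / n_tokens for c in counts]


f = selection_fractions(selected, NUM_EXPERTS)
print(f)
```

With this definition the entries of `f` add up to k, since every token contributes k selections.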

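Soft gating probabilities p (the full softmax over all 4 logits, averaged across the 4 tokens):

softmax(h(x₁)) ≈ [0.406, 0.301, 0.115, 0.179]
softmax(h(x₂)) ≈ [0.188, 0.361, 0.314, 0.137]
softmax(h(x₃)) ≈ [0.199, 0.313, 0.313, 0.175]
softmax(h(x₄)) ≈ [0.346, 0.283, 0.159, 0.212]

p₁ ≈ 0.285,  p₂ ≈ 0.314,  p₃ ≈ 0.225,  p₄ ≈ 0.176

Auxiliary loss (with n=4, α=0.01):

L_balance = 0.01 × 4 × (f₁p₁ + f₂p₂ + f₃p₃ + f₄p₄)
          = 0.04 × (0.500×0.285 + 1.000×0.314 + 0.500×0.225 + 0.000×0.176)
          = 0.04 × (0.142 + 0.314 + 0.113 + 0.000)
          = 0.04 × 0.569
          = 0.0228

For comparison, perfect balance (every fᵢ = 0.5, every pᵢ = 0.25) gives 0.04 × 4 × 0.125 = 0.02, so this batch sits about 14% above the balanced value. The counts fᵢ are not differentiable, so the gradient reaches W_g through the pᵢ: it lowers the soft probability of the over-used Expert 2 and, because the probabilities sum to 1, raises it for the unused Expert 4. Over many batches this pushes the routing toward balance.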
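
The soft probabilities and the loss can be recomputed from the Step 1 logits. This is a sketch under two assumptions taken from the text: p is the full (unmasked) softmax averaged over tokens, and the loss is α·n·Σ fᵢpᵢ. The names `mean_gate_probs` and `balance_loss` are hypothetical.

```python
import math

# Logits from Step 1 and hard-selection fractions from this step.
LOGITS = [
    [0.96, 0.66, -0.30, 0.14],   # "chai"
    [0.14, 0.79, 0.65, -0.18],   # "bahut"
    [0.00, 0.45, 0.45, -0.13],   # "garam"
    [0.58, 0.38, -0.20, 0.09],   # "hai"
]
F = [0.5, 1.0, 0.5, 0.0]


def softmax(scores):
    """Ordinary softmax over all entries (no top-k masking)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def mean_gate_probs(logits):
    """Average the per-token softmax, expert by expert."""
    probs = [softmax(h) for h in logits]
    n_tokens = len(probs)
    return [sum(p[j] for p in probs) / n_tokens for j in range(len(probs[0]))]


def balance_loss(f, p, alpha=0.01):
    """alpha * n_experts * sum_i f_i * p_i, the form used in this section."""
    return alpha * len(f) * sum(fi * pi for fi, pi in zip(f, p))


p = mean_gate_probs(LOGITS)
loss = balance_loss(F, p)
balanced = balance_loss([0.5] * 4, [0.25] * 4)  # uniform reference case

print([round(v, 3) for v in p], round(loss, 4), round(balanced, 4))
```

Evaluating `balance_loss` at the uniform case (every fᵢ = 0.5, every pᵢ = 0.25) gives a reference point for what perfect balance scores under these definitions.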


Summary of token routing

"chai"  (noun)     → Expert 1 (57%) + Expert 2 (43%)
"bahut" (adverb)   → Expert 2 (53%) + Expert 3 (47%)
"garam" (adjective)→ Expert 2 (50%) + Expert 3 (50%)
"hai"   (verb)     → Expert 1 (55%) + Expert 2 (45%)

Even in this tiny toy example with hand-crafted weights, the routing is partly interpretable: “chai” leans on the noun-specialist Expert 1, and “bahut” and “garam” share Experts 2 and 3, the degree-word and descriptive-word specialists. The match is not perfect, though: the copula “hai” also goes to Expert 1 rather than to Expert 4, the function-word/verb specialist, which receives no tokens at all. In a real trained model with 1,000 experts and millions of training steps, the specialists become far more nuanced.