4. The math — gating, top-k, and the balancing loss
🔴 Advanced undergrad. Requires softmax and matrix multiplication. Read Softmax Function and Cross-Entropy Loss first.
The gating function
Let x be a token representation of dimension d_model. The gating network has a single learnable weight matrix W_g of shape (d_model × n), where n is the number of experts.
Raw logits:
h(x) = x · W_g    (shape: (n,), one score per expert)
Noisy version (during training only):
H(x)ᵢ = h(x)ᵢ + ε · Softplus( x · W_noise )ᵢ
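Here ε ~ N(0, 1) is standard Gaussian noise, Softplus(z) = log(1 + eᶻ), and W_noise is a second learnable matrix with the same shape as W_g. Softplus keeps the noise scale positive, so the model learns how much noise to apply to each expert's score. The noise encourages exploration: the model tries routing tokens to different experts and learns which work best.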
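The noisy logits can be sketched in a few lines of plain Python. This is a minimal illustration, not a reference implementation: the function names are made up for this example, the W_noise values are hypothetical, and drawing a separate ε for each expert is an assumption about how the formula is meant to be read.

```python
import math
import random

def softplus(z):
    # Numerically stable log(1 + e^z): stays finite for large |z|.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def matvec(x, W):
    # Row vector x (length d) times matrix W (d rows, n columns) -> length n.
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]

def noisy_logits(x, W_g, W_noise, rng=None):
    # H(x)_i = h(x)_i + eps_i * softplus((x . W_noise)_i), eps_i ~ N(0, 1).
    # Assumption: one independent eps per expert.
    # rng=None skips the noise term (noise is used during training only).
    h = matvec(x, W_g)
    if rng is None:
        return h
    scale = [softplus(s) for s in matvec(x, W_noise)]
    return [h_i + rng.gauss(0.0, 1.0) * s_i for h_i, s_i in zip(h, scale)]

x = [0.8, 0.6]
W_g = [[1, 0, -1], [0, 1, 1]]
W_noise = [[0.1, 0.2, 0.3], [0.0, 0.1, 0.2]]   # hypothetical values

clean = noisy_logits(x, W_g, W_noise)                        # no noise
noisy = noisy_logits(x, W_g, W_noise, rng=random.Random(0))  # seeded noise
```

Passing a seeded `random.Random` keeps the perturbation reproducible, which helps when debugging routing decisions.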
TopK operation:
TopK(v, k)ᵢ = vᵢ if vᵢ is among the k largest values in v
= −∞ otherwise
This is just “keep the top k entries, set the rest to −∞.”
Final gating weights (sparse):
G(x) = Softmax( TopK( H(x), k ) )
Because softmax maps −∞ to 0, only the k selected experts have non-zero weights. G(x) is a sparse vector with exactly k non-zero entries summing to 1.
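As a rough sketch of the TopK-then-softmax gate, the following plain-Python functions mask all but the k largest logits and normalise what is left. The helper names and the four example logits are illustrative choices, not part of the original text.

```python
import math

def top_k_mask(v, k):
    # Keep the k largest entries of v; replace the others with -inf.
    keep = sorted(range(len(v)), key=lambda i: v[i], reverse=True)[:k]
    return [v[i] if i in keep else -math.inf for i in range(len(v))]

def softmax(v):
    # Subtract the max for numerical stability; exp(-inf) evaluates to 0.0.
    m = max(v)
    exps = [math.exp(a - m) for a in v]
    total = sum(exps)
    return [e / total for e in exps]

def gate(logits, k):
    # G(x) = Softmax(TopK(H(x), k))
    return softmax(top_k_mask(logits, k))

weights = gate([2.0, -1.0, 0.5, 1.0], k=2)
# Only positions 0 and 3 survive the mask; the weights sum to 1.
```

Masking with −∞ before the softmax, rather than zeroing weights afterwards, is what makes the surviving weights renormalise to 1 automatically.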
MoE layer output:
MoE(x) = Σᵢ G(x)ᵢ · Eᵢ(x)
Only the k experts with G(x)ᵢ > 0 are evaluated. The rest contribute zero.
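To make the "only selected experts are evaluated" point concrete, here is a small sketch in which each toy expert records when it is called. The experts simply scale their input; they are hypothetical stand-ins for real feed-forward networks.

```python
def moe_output(x, gate_weights, experts):
    # Weighted sum of expert outputs, skipping every expert whose weight is 0.
    out = None
    for g, expert in zip(gate_weights, experts):
        if g == 0.0:
            continue  # unselected expert: never called
        y = expert(x)
        if out is None:
            out = [0.0] * len(y)
        out = [o + g * y_j for o, y_j in zip(out, y)]
    return out

calls = []

def make_expert(name, scale):
    # Hypothetical toy expert: scales the input and records that it ran.
    def expert(x):
        calls.append(name)
        return [scale * x_j for x_j in x]
    return expert

experts = [make_expert("E1", 1.0), make_expert("E2", 2.0), make_expert("E3", 3.0)]
y = moe_output([1.0, -1.0], [0.25, 0.0, 0.75], experts)
# y is [2.5, -2.5]; calls is ["E1", "E3"], so E2 never ran.
```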
Worked numerical example
Setup: 3 experts, a 2-dimensional token vector, and k = 2 (use the top-2 experts). The noise term is omitted, so H(x) = h(x), as it would be outside training.
Token representation: x = [0.8, 0.6]
Gating weight matrix:
W_g = [[ 1, 0, -1],    ← row 1: weights for input dimension 1
       [ 0, 1,  1]]    ← row 2: weights for input dimension 2
Shape: (2 × 3), mapping a 2D input to 3 expert scores.
Step 1: Compute raw logits h(x) = x · W_g
h(x) = [0.8, 0.6] · [[1, 0, -1],
[0, 1, 1]]
h(x)₁ = 0.8×1 + 0.6×0 = 0.800 (score for Expert 1)
h(x)₂ = 0.8×0 + 0.6×1 = 0.600 (score for Expert 2)
h(x)₃ = 0.8×(−1) + 0.6×1 = −0.200 (score for Expert 3)
h(x) = [0.800, 0.600, −0.200]
Step 2: Apply Top-2 (k=2) — keep only the 2 highest scores
Top 2 values: Expert 1 (0.800) and Expert 2 (0.600).
TopK(h(x), 2) = [0.800, 0.600, −∞]
Expert 3 is masked to −∞.
Step 3: Apply softmax to get gating weights
exp(0.800) = 2.226
exp(0.600) = 1.822
exp(−∞) = 0.000
Sum = 2.226 + 1.822 + 0.000 = 4.048
G(x) = [2.226/4.048, 1.822/4.048, 0/4.048]
= [0.550, 0.450, 0.000]
Expert 1 gets weight 0.550, Expert 2 gets 0.450, Expert 3 is completely skipped.
Check: 0.550 + 0.450 + 0.000 = 1.000 ✓
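Steps 1 to 3 can be checked with a short standard-library script. It uses the x and W_g from the setup above; everything else (variable names, rounding to three decimals) is just for this sketch.

```python
import math

x = [0.8, 0.6]
W_g = [[1, 0, -1],
       [0, 1, 1]]
k = 2

# Step 1: h(x) = x . W_g, one logit per expert.
h = [sum(x[r] * W_g[r][c] for r in range(2)) for c in range(3)]

# Step 2: mask everything outside the top k to -inf.
top = sorted(range(3), key=lambda i: h[i], reverse=True)[:k]
masked = [h[i] if i in top else -math.inf for i in range(3)]

# Step 3: softmax over the masked logits (exp(-inf) is 0.0).
exps = [math.exp(v) for v in masked]
G = [e / sum(exps) for e in exps]

print([round(v, 3) for v in h])   # [0.8, 0.6, -0.2]
print([round(g, 3) for g in G])   # [0.55, 0.45, 0.0]
```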
Step 4: Evaluate the two active experts
Each expert is a feed-forward network (here, a single linear + ReLU for simplicity):
Expert 1: E₁(x) = ReLU( x · W₁ ), with W₁ = [[1, 0], [0, 1]]
x · W₁ = [0.8, 0.6], so E₁(x) = ReLU([0.8, 0.6]) = [0.8, 0.6]
Expert 2: E₂(x) = ReLU( x · W₂ ), with W₂ = [[0, 1], [1, 0]]
x · W₂ = [0.6, 0.8], so E₂(x) = ReLU([0.6, 0.8]) = [0.6, 0.8]
Expert 3: not evaluated (G(x)₃ = 0, compute saved)
Step 5: Compute MoE output
MoE(x) = G(x)₁ · E₁(x) + G(x)₂ · E₂(x) + G(x)₃ · E₃(x)
= 0.550 × [0.8, 0.6] + 0.450 × [0.6, 0.8] + 0 × (skipped)
= [0.440, 0.330] + [0.270, 0.360]
= [0.710, 0.690]
The final output [0.710, 0.690] is a blend of Expert 1’s output (weighted 55%) and Expert 2’s (weighted 45%). Expert 3 contributed nothing and was never computed.
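Steps 4 and 5 can be reproduced the same way. This sketch takes the rounded gating weights from Step 3 as given and deliberately stores no weights for Expert 3, to mirror the fact that it is never evaluated.

```python
def relu(v):
    return [max(0.0, a) for a in v]

def linear(x, W):
    # Row vector x times matrix W.
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]

x = [0.8, 0.6]
G = [0.550, 0.450, 0.0]      # gating weights from Step 3, rounded
W1 = [[1, 0], [0, 1]]        # Expert 1
W2 = [[0, 1], [1, 0]]        # Expert 2
expert_weights = {0: W1, 1: W2}   # no entry for Expert 3: it is never needed

out = [0.0, 0.0]
for i, g in enumerate(G):
    if g == 0.0:
        continue             # skipped expert, no compute spent
    y = relu(linear(x, expert_weights[i]))
    out = [o + g * y_j for o, y_j in zip(out, y)]

print([round(v, 3) for v in out])   # [0.71, 0.69]
```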
Compute savings: 1 of 3 experts (33%) was skipped entirely. At scale, with 1,000 equally sized experts and k = 2, each token runs only 2 of them, which is 99.8% less expert compute than evaluating all 1,000.
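The savings figure is simple arithmetic. A one-function sketch, assuming all experts cost the same and ignoring the (small) cost of the gating network itself:

```python
def skipped_fraction(n_experts, k):
    # Share of experts that are not evaluated for a single token.
    return (n_experts - k) / n_experts

print(f"{skipped_fraction(3, 2):.1%}")      # 33.3%
print(f"{skipped_fraction(1000, 2):.1%}")   # 99.8%
```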
The load-balancing auxiliary loss
Notation for a batch of T tokens, n experts:
For each expert i, define:
- fᵢ = fraction of tokens routed to expert i by the hard top-k selection
fᵢ = (1/T) · Σₜ 𝟙[i ∈ TopK(H(xₜ), k)]
(𝟙 is 1 if expert i is selected for token t, 0 otherwise)
- pᵢ = mean soft gating probability for expert i across the batch
pᵢ = (1/T) · Σₜ Softmax(H(xₜ))ᵢ
(using the full softmax before TopK masking)
The auxiliary loss:
L_balance = α · n · Σᵢ fᵢ · pᵢ
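Why this works:
With top-1 routing (k = 1) the fᵢ sum to 1, so perfectly balanced load means fᵢ = 1/n for all i. (For general k the fᵢ sum to k, balance means fᵢ = k/n, and the value below is multiplied by k.) Then:
L_balance = α · n · Σᵢ (1/n) · pᵢ = α · Σᵢ pᵢ = α · 1 = α
(since soft probabilities sum to 1)
In practice this is the floor. The router sends tokens to the experts with the highest probabilities, so fᵢ and pᵢ rise and fall together, and concentrating tokens on one expert raises both factors of its fᵢ · pᵢ term. Strictly, the bound relies on that coupling between f and p.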
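One way to see why the balanced value is a floor is to look at an idealised special case. The derivation below assumes top-1 routing and, as a simplification that is not stated in the text, that the hard fractions equal the soft probabilities (fᵢ = pᵢ = qᵢ). Under that assumption the Cauchy-Schwarz inequality gives the bound directly.

```latex
% Assumption (idealised): f_i = p_i = q_i, with \sum_i q_i = 1.
\begin{align*}
L_{\text{balance}} &= \alpha \, n \sum_{i=1}^{n} q_i^2 \\
1 = \Bigl(\sum_{i=1}^{n} q_i \cdot 1\Bigr)^2
  &\le \Bigl(\sum_{i=1}^{n} q_i^2\Bigr)\Bigl(\sum_{i=1}^{n} 1^2\Bigr)
   = n \sum_{i=1}^{n} q_i^2 \\
\Rightarrow \quad L_{\text{balance}} &\ge \alpha,
  \quad \text{with equality iff } q_i = \tfrac{1}{n} \text{ for all } i.
\end{align*}
```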
If one expert gets all the tokens (f₁ = 1, fᵢ = 0 for i > 1):
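L_balance = α · n · 1 · p₁ = α · n · p₁ ≈ α · n when p₁ ≈ 1 (n times the balanced value)
The hard counts fᵢ come from a non-differentiable top-k selection, so the gradient of L_balance reaches W_g only through the soft probabilities pᵢ. It lowers pᵢ most for the experts with the largest fᵢ, which pushes the gating network toward an equal distribution across experts.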
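The two extremes discussed above can be compared numerically. This is a small sketch with made-up values (n = 8, α = 0.01); in the collapsed case it assumes the soft probability mass also sits entirely on expert 1.

```python
def balance_loss(f, p, alpha):
    # L_balance = alpha * n * sum_i f_i * p_i
    n = len(f)
    return alpha * n * sum(f_i * p_i for f_i, p_i in zip(f, p))

alpha, n = 0.01, 8
uniform = [1.0 / n] * n
collapsed = [1.0] + [0.0] * (n - 1)

balanced = balance_loss(uniform, uniform, alpha)     # f and p both uniform
worst = balance_loss(collapsed, collapsed, alpha)    # everything on expert 1

# balanced is alpha (0.01); worst is alpha * n (0.08), n times larger.
```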
Numerical example:
3 experts, 4 tokens in a batch, top-1 routing (k = 1). Hard routing: tokens 1, 2 and 3 → Expert 1; token 4 → Expert 2.
f₁ = 3/4 = 0.750, f₂ = 1/4 = 0.250, f₃ = 0/4 = 0.000
Soft probabilities (average over batch, hypothetical):
p₁ = 0.600, p₂ = 0.300, p₃ = 0.100
L_balance = α × 3 × (0.750×0.600 + 0.250×0.300 + 0.000×0.100)
= α × 3 × (0.450 + 0.075 + 0.000)
= α × 3 × 0.525
= 1.575α
If α = 0.01 → L_balance = 0.0158. Small relative to the language modelling loss, but its gradient continuously nudges toward balance.
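The same numbers can be recomputed from the raw token assignments. The sketch below uses 0-based expert indices and the hypothetical soft probabilities from the example; the function name is illustrative.

```python
def routing_fractions(assignments, n_experts):
    # f_i: share of tokens whose hard (top-1) choice is expert i.
    counts = [0] * n_experts
    for expert in assignments:
        counts[expert] += 1
    return [c / len(assignments) for c in counts]

assignments = [0, 0, 0, 1]     # tokens 1-3 -> Expert 1, token 4 -> Expert 2
p = [0.600, 0.300, 0.100]      # hypothetical mean soft probabilities
alpha = 0.01

f = routing_fractions(assignments, 3)                        # [0.75, 0.25, 0.0]
loss = alpha * 3 * sum(f_i * p_i for f_i, p_i in zip(f, p))  # about 0.01575
```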
Total training loss:
L_total = L_language_model + L_balance
The model learns to predict text accurately (L_language_model, via cross-entropy) while distributing tokens evenly across experts (L_balance). The coefficient α, already included in L_balance, sets the trade-off between the two.
Capacity factor and token dropping
One practical detail: each expert can only handle a fixed number of tokens per training batch, called its capacity.
Capacity = (tokens per batch / n experts) × capacity_factor
If a popular expert receives more tokens than its capacity, excess tokens are dropped — they skip the MoE layer and pass through unchanged (via a residual connection). The capacity factor (typically 1.0–2.0) controls how much buffer each expert has beyond its fair share.
This is a key implementation detail: without it, a single overloaded expert becomes a bottleneck on distributed hardware. With it, the system degrades gracefully and training continues.
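A minimal sketch of capacity-limited dispatch follows. It assumes top-1 routing, rounds the capacity down to an integer, and serves tokens in batch order so that later arrivals are the ones dropped. All three are simplifying assumptions for illustration, and the assignments are made up.

```python
def expert_capacity(tokens_per_batch, n_experts, capacity_factor):
    # Capacity = (tokens per batch / n experts) * capacity_factor, rounded down.
    return int(tokens_per_batch / n_experts * capacity_factor)

def dispatch(assignments, n_experts, capacity_factor):
    # assignments[t] is the expert chosen for token t (top-1 routing assumed).
    # Returns (token indices kept per expert, token indices dropped).
    cap = expert_capacity(len(assignments), n_experts, capacity_factor)
    kept = [[] for _ in range(n_experts)]
    dropped = []
    for t, e in enumerate(assignments):
        if len(kept[e]) < cap:
            kept[e].append(t)
        else:
            dropped.append(t)   # passes through via the residual connection
    return kept, dropped

assignments = [0, 0, 0, 0, 0, 1, 1, 2]    # hypothetical: expert 0 is popular
kept, dropped = dispatch(assignments, n_experts=4, capacity_factor=1.5)
# Capacity is int(8 / 4 * 1.5) = 3, so expert 0 keeps tokens 0-2 and drops 3 and 4.
```

Raising `capacity_factor` reduces dropped tokens at the cost of larger per-expert buffers; with a factor of 2.5 in this example, expert 0 could hold all five of its tokens.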