Section 04

The math: gating, top-k, and the balancing loss

Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer 2017


🔴 Advanced undergrad. Requires softmax and matrix multiplication. Read Softmax Function and Cross-Entropy Loss first.


The gating function

Let x be a token representation of dimension d_model. The gating network's main learnable parameter is a weight matrix W_g of shape (d_model × n), where n is the number of experts. (The noisy variant below adds a second matrix, W_noise, of the same shape.)

Raw logits:

h(x) = x · W_g          shape: (n,)   — one score per expert

Noisy version (during training only):

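H(x)ᵢ = h(x)ᵢ + εᵢ · Softplus( x · W_noise )ᵢ

Here each εᵢ ~ N(0, 1) is Gaussian noise drawn independently for every expert and every token, Softplus(z) = log(1 + eᶻ) keeps the learned noise scale positive, and W_noise is a second learnable matrix of shape (d_model × n). The noise encourages exploration: tokens near a routing boundary are sometimes sent to a different expert, so the gate receives feedback about the alternatives and learns which experts work best.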
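A minimal sketch of the noisy logits in plain Python, assuming a single token stored as a list and weight matrices stored as lists of rows. The helper names `softplus`, `matvec` and `noisy_logits` are illustrative, not taken from the paper or any library.

```python
import math
import random

def softplus(z):
    # log(1 + e^z), rewritten as max(z, 0) + log(1 + e^-|z|) to avoid overflow
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def matvec(x, W):
    # x has length d_model; W has d_model rows and n columns -> list of length n
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]

def noisy_logits(x, W_g, W_noise, training=True, rng=random):
    """H(x)_i = h(x)_i + eps_i * Softplus(x . W_noise)_i with eps_i ~ N(0, 1)."""
    h = matvec(x, W_g)                      # clean logits h(x), one per expert
    if not training:
        return h                            # noise is only added during training
    scale = [softplus(s) for s in matvec(x, W_noise)]
    # one independent standard-normal draw per expert
    return [hi + rng.gauss(0.0, 1.0) * si for hi, si in zip(h, scale)]
```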

TopK operation:

TopK(v, k)ᵢ = vᵢ   if vᵢ is among the k largest values in v
            = −∞   otherwise

This is just “keep the top k entries, set the rest to −∞.”
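The TopK masking step can be sketched as follows. `top_k_mask` is a hypothetical helper name, and breaking ties in favour of the lower index is an arbitrary choice made here.

```python
import math

def top_k_mask(v, k):
    """Keep the k largest entries of v and set every other entry to -inf."""
    # sorted() is stable even with reverse=True, so equal values keep index order
    keep = set(sorted(range(len(v)), key=lambda i: v[i], reverse=True)[:k])
    return [v[i] if i in keep else -math.inf for i in range(len(v))]
```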

Final gating weights (sparse):

G(x) = Softmax( TopK( H(x), k ) )

Because softmax maps −∞ to 0, only the k selected experts have non-zero weights. G(x) is a sparse vector with exactly k non-zero entries summing to 1.
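Combining the mask with a softmax gives the sparse gate vector. This sketch (hypothetical name `sparse_gates`) subtracts the largest kept logit before exponentiating, a standard trick that leaves the softmax unchanged but avoids overflow.

```python
import math

def sparse_gates(logits, k):
    """G = Softmax(TopK(logits, k)): k non-zero weights that sum to 1."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)              # shift for numerical stability
    exps = {i: math.exp(logits[i] - m) for i in top}
    total = sum(exps.values())
    # experts outside the top k were masked to -inf, and exp(-inf) = 0
    return [exps[i] / total if i in exps else 0.0 for i in range(len(logits))]
```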

MoE layer output:

MoE(x) = Σᵢ G(x)ᵢ · Eᵢ(x)

Only the k experts with G(x)ᵢ > 0 are evaluated. The rest contribute zero.
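The sketch below shows where the compute saving comes from: experts whose gate is exactly zero are never called. It assumes experts are plain Python callables and that at least one gate is non-zero; `moe_output` is a hypothetical name.

```python
def moe_output(x, gates, experts):
    """MoE(x) = sum_i G(x)_i * E_i(x), calling only experts with a non-zero gate."""
    out = None
    for g, expert in zip(gates, experts):
        if g == 0.0:
            continue                      # inactive expert: its forward pass is skipped
        y = expert(x)
        if out is None:
            out = [g * yj for yj in y]
        else:
            out = [o + g * yj for o, yj in zip(out, y)]
    return out
```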


Worked numerical example

Setup: 3 experts, 2-dimensional token vector, k = 2 (use the top-2 experts). Noise is omitted, as at inference time, so H(x) = h(x).

Token representation: x = [0.8, 0.6]

Gating weight matrix:
W_g = [[ 1,   0,  -1],     ← weights for dimension 1
       [ 0,   1,   1]]     ← weights for dimension 2
     shape: (2 × 3) — maps 2D input to 3 expert scores

Step 1: Compute raw logits h(x) = x · W_g

h(x) = [0.8, 0.6] · [[1, 0, -1],
                       [0, 1,  1]]

h(x)₁ = 0.8×1 + 0.6×0 = 0.800   (score for Expert 1)
h(x)₂ = 0.8×0 + 0.6×1 = 0.600   (score for Expert 2)
h(x)₃ = 0.8×(−1) + 0.6×1 = −0.200  (score for Expert 3)

h(x) = [0.800,  0.600,  −0.200]

Step 2: Apply Top-2 (k=2) — keep only the 2 highest scores

Top 2 values: Expert 1 (0.800) and Expert 2 (0.600).

TopK(h(x), 2) = [0.800,  0.600,  −∞]

Expert 3 is masked to −∞.

Step 3: Apply softmax to get gating weights

exp(0.800) = 2.226
exp(0.600) = 1.822
exp(−∞)    = 0.000

Sum = 2.226 + 1.822 + 0.000 = 4.048

G(x) = [2.226/4.048,  1.822/4.048,  0/4.048]
      = [0.550,         0.450,         0.000]

Expert 1 gets weight 0.550, Expert 2 gets 0.450, Expert 3 is completely skipped.

Check: 0.550 + 0.450 + 0.000 = 1.000 ✓
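Steps 1 to 3 can be checked with a few lines of plain Python (no noise, matching the example above):

```python
import math

x = [0.8, 0.6]
W_g = [[1, 0, -1],
       [0, 1,  1]]

# Step 1: raw logits h(x) = x . W_g
h = [sum(x[r] * W_g[r][c] for r in range(2)) for c in range(3)]

# Step 2: keep the top-2 logits and mask the rest to -inf
k = 2
top = sorted(range(3), key=lambda i: h[i], reverse=True)[:k]
masked = [h[i] if i in top else -math.inf for i in range(3)]

# Step 3: softmax; math.exp(-math.inf) is 0.0, so the masked expert gets weight 0
exps = [math.exp(v) for v in masked]
G = [e / sum(exps) for e in exps]

print([round(v, 3) for v in h])   # [0.8, 0.6, -0.2]
print([round(v, 3) for v in G])   # [0.55, 0.45, 0.0]
```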

Step 4: Evaluate the two active experts

Each expert is a feed-forward network (here, a single linear + ReLU for simplicity):

Expert 1: E₁(x) = ReLU( x · W₁ )
  W₁ = [[1, 0],   → E₁([0.8, 0.6]) = ReLU([0.8, 0.6]) = [0.8, 0.6]
         [0, 1]]

Expert 2: E₂(x) = ReLU( x · W₂ )
  W₂ = [[0, 1],   → E₂([0.8, 0.6]) = ReLU([0.6, 0.8]) = [0.6, 0.8]
         [1, 0]]

Expert 3: not evaluated (G(x)₃ = 0, compute saved)

Step 5: Compute MoE output

MoE(x) = G(x)₁ · E₁(x) + G(x)₂ · E₂(x) + G(x)₃ · E₃(x)

       = 0.550 × [0.8, 0.6]  +  0.450 × [0.6, 0.8]  +  0 × (skipped)

       = [0.440, 0.330]  +  [0.270, 0.360]

       = [0.710, 0.690]

The final output [0.710, 0.690] is a blend of Expert 1’s output (weighted 55%) and Expert 2’s (weighted 45%). Expert 3 contributed nothing and was never computed.
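Steps 4 and 5 can be reproduced the same way. This sketch reuses the rounded gates from Step 3 and the toy linear + ReLU experts defined above; `linear_relu` is an illustrative helper name.

```python
# Gates from Step 3, rounded as in the text
G = [0.550, 0.450, 0.000]
x = [0.8, 0.6]

def linear_relu(x, W):
    # ReLU(x . W) for a 2x2 weight matrix stored as a list of rows
    return [max(0.0, sum(x[r] * W[r][c] for r in range(2))) for c in range(2)]

W1 = [[1, 0], [0, 1]]
W2 = [[0, 1], [1, 0]]
E1 = linear_relu(x, W1)        # [0.8, 0.6]
E2 = linear_relu(x, W2)        # [0.6, 0.8]
# Expert 3 is never called because G[2] == 0

moe = [G[0] * a + G[1] * b for a, b in zip(E1, E2)]
print([round(v, 3) for v in moe])   # [0.71, 0.69]
```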

Compute savings: 1 of 3 experts (33%) was skipped entirely. At scale with 1,000 experts and k=2, you save 99.8% of expert compute per token.
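The savings figure is simply 1 - k/n. This tiny helper (hypothetical name) counts expert evaluations only, assuming all experts are the same size and ignoring the cost of the gating network itself.

```python
def skipped_fraction(n_experts, k):
    """Fraction of expert evaluations skipped per token when only k of n run."""
    return 1 - k / n_experts

print(f"{skipped_fraction(3, 2):.1%}")      # 33.3%
print(f"{skipped_fraction(1000, 2):.1%}")   # 99.8%
```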


The load-balancing auxiliary loss

Notation for a batch of T tokens, n experts:

For each expert i, define:

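  • fᵢ = fraction of the T · k hard top-k routing assignments that go to expert i
fᵢ = (1/(k · T)) · Σₜ 𝟙[i ∈ TopK(H(xₜ), k)]

(𝟙 is 1 if expert i is selected for token t, 0 otherwise. Dividing by k · T makes Σᵢ fᵢ = 1; for k = 1, fᵢ is simply the fraction of tokens routed to expert i.)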

  • pᵢ = mean soft gating probability for expert i across the batch
pᵢ = (1/T) · Σₜ Softmax(H(xₜ))ᵢ

(using the full softmax before TopK masking)

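The auxiliary loss (this f · p form is the one used by the Switch Transformer; the 2017 paper balanced load with coefficient-of-variation penalties instead):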

L_balance = α · n · Σᵢ fᵢ · pᵢ
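A sketch of the whole auxiliary loss for one batch, in plain Python with one list of router logits per token. It is an illustration rather than a reference implementation: it normalizes fᵢ by k · T so the fractions sum to 1 (for k = 1 this is just the fraction of tokens), and `balance_loss` is a hypothetical name. In an autodiff framework only the pᵢ term would carry gradients.

```python
import math

def softmax(v):
    # subtract the max before exponentiating for numerical stability
    m = max(v)
    e = [math.exp(a - m) for a in v]
    s = sum(e)
    return [a / s for a in e]

def balance_loss(logits_batch, k, alpha):
    """Return (L_balance, f, p) for a batch of per-token router logits."""
    T, n = len(logits_batch), len(logits_batch[0])
    f = [0.0] * n   # share of the T*k hard routing assignments per expert
    p = [0.0] * n   # mean full-softmax probability per expert
    for logits in logits_batch:
        # hard top-k selection; each of the k picks adds 1/(k*T), so sum(f) == 1
        for i in sorted(range(n), key=lambda j: logits[j], reverse=True)[:k]:
            f[i] += 1.0 / (k * T)
        # soft probabilities from the full softmax, before any top-k masking
        for i, prob in enumerate(softmax(logits)):
            p[i] += prob / T
    loss = alpha * n * sum(fi * pi for fi, pi in zip(f, p))
    return loss, f, p
```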

Why this works:

If load is perfectly balanced (fᵢ = 1/n for all i), then:

L_balance = α · n · Σᵢ (1/n) · pᵢ = α · Σᵢ pᵢ = α · 1 = α

(since soft probabilities sum to 1)

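This is the value under perfectly balanced routing, and in practice it is the floor: the router's hard choices follow its own probabilities (fᵢ ≈ pᵢ), so concentrating tokens on a few experts raises both fᵢ and pᵢ for those experts and makes the sum larger.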
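Under the idealization that routing follows the soft probabilities exactly (fᵢ = pᵢ), the floor of α follows from the Cauchy-Schwarz inequality (Σᵢ pᵢ)² ≤ n · Σᵢ pᵢ²:

```latex
\mathcal{L}_{\text{balance}} = \alpha\, n \sum_{i=1}^{n} f_i\, p_i
\;\approx\; \alpha\, n \sum_{i=1}^{n} p_i^{2}
\;\ge\; \alpha\, n \cdot \frac{1}{n} \Big( \sum_{i=1}^{n} p_i \Big)^{2}
= \alpha
```

Equality holds exactly when every pᵢ = 1/n.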

If one expert gets all the tokens (with k = 1: f₁ = 1, fᵢ = 0 for i > 1):

L_balance = α · n · 1 · p₁ ≈ α · n    (as p₁ → 1; much larger than α when n is big)

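Because fᵢ comes from a hard top-k selection it has no gradient, so the gradient of L_balance with respect to W_g flows only through pᵢ. Each pᵢ is weighted by its expert's load fᵢ, so the router probabilities of overloaded experts are pushed down and those of underused experts rise, nudging the gating network toward an equal distribution across experts.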
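Treating fᵢ as a constant and writing sₜ = Softmax(H(xₜ)) for token t's soft probabilities, the standard softmax chain rule gives the gradient with respect to each router logit (a short derivation sketch):

```latex
\frac{\partial \mathcal{L}_{\text{balance}}}{\partial p_i} = \alpha\, n\, f_i,
\qquad
\frac{\partial p_i}{\partial H_j(x_t)} = \frac{1}{T}\, s_{t,i} \left( \delta_{ij} - s_{t,j} \right)

\frac{\partial \mathcal{L}_{\text{balance}}}{\partial H_j(x_t)}
= \sum_{i=1}^{n} \alpha\, n\, f_i \cdot \frac{1}{T}\, s_{t,i} \left( \delta_{ij} - s_{t,j} \right)
= \frac{\alpha\, n}{T}\, s_{t,j} \Big( f_j - \sum_{i=1}^{n} f_i\, s_{t,i} \Big)
```

Gradient descent therefore lowers H_j(xₜ) whenever expert j's load f_j exceeds the load averaged under that token's own soft probabilities, Σᵢ fᵢ sₜ,ᵢ; the gradient then reaches W_g through h(x) = x · W_g.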

Numerical example:

3 experts, 4 tokens in a batch, k = 1. Hard routing: tokens [1,2,3] → Expert 1, token [4] → Expert 2.

f₁ = 3/4 = 0.750,  f₂ = 1/4 = 0.250,  f₃ = 0/4 = 0.000

Soft probabilities (average over batch, hypothetical):
p₁ = 0.600,  p₂ = 0.300,  p₃ = 0.100

L_balance = α × 3 × (0.750×0.600 + 0.250×0.300 + 0.000×0.100)
          = α × 3 × (0.450 + 0.075 + 0.000)
          = α × 3 × 0.525
          = 1.575α

If α = 0.01, then L_balance ≈ 0.0158 (versus 0.01 under perfect balance). That is small relative to the language modelling loss, but its gradient continuously nudges the router toward balance.
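The arithmetic can be checked with a short helper that takes f and p directly (`balance_loss_from_stats` is a hypothetical name):

```python
def balance_loss_from_stats(f, p, alpha):
    # L_balance = alpha * n * sum_i f_i * p_i, with f and p already computed
    n = len(f)
    return alpha * n * sum(fi * pi for fi, pi in zip(f, p))

f = [3 / 4, 1 / 4, 0 / 4]     # hard routing fractions from the example
p = [0.600, 0.300, 0.100]     # hypothetical mean soft probabilities
print(round(balance_loss_from_stats(f, p, alpha=1.0), 3))    # 1.575
print(round(balance_loss_from_stats(f, p, alpha=0.01), 5))   # 0.01575
```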

Total training loss:

L_total = L_language_model + L_balance

The model learns to predict text accurately (L_language_model via cross-entropy) while distributing tokens evenly across experts (L_balance).


Capacity factor and token dropping

One practical detail: each expert can only handle a fixed number of tokens per training batch, called its capacity.

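Capacity = (tokens per batch / n experts) × capacity_factor

(This is the top-1 form. With top-k routing each token occupies k expert slots, so implementations typically also multiply by k.)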

If a popular expert receives more tokens than its capacity, excess tokens are dropped — they skip the MoE layer and pass through unchanged (via a residual connection). The capacity factor (typically 1.0–2.0) controls how much buffer each expert has beyond its fair share.

This is a key implementation detail: without it, a single overloaded expert becomes a bottleneck on distributed hardware. With it, the system degrades gracefully and training continues.
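A minimal sketch of capacity-limited dispatch, assuming top-1 routing, capacity rounded down to an integer, and first-come-first-served overflow in token order. Real systems may choose which tokens to drop differently, and `route_with_capacity` is a hypothetical name.

```python
import math

def route_with_capacity(assignments, n_experts, capacity_factor):
    """Top-1 dispatch with a fixed per-expert capacity.

    assignments[t] is the expert chosen for token t. Returns (kept, dropped):
    kept[i] lists the tokens expert i processes; dropped lists overflow tokens,
    which would bypass the MoE layer through the residual connection.
    """
    T = len(assignments)
    capacity = math.floor(T / n_experts * capacity_factor)
    kept = [[] for _ in range(n_experts)]
    dropped = []
    for t, i in enumerate(assignments):      # tokens claim slots in order
        if len(kept[i]) < capacity:
            kept[i].append(t)
        else:
            dropped.append(t)
    return kept, dropped
```

With 8 tokens, 4 experts and capacity_factor = 1.0, each expert holds 2 tokens, so the routing [0, 0, 0, 1, 0, 2, 3, 3] drops tokens 2 and 4; raising capacity_factor to 1.5 gives a capacity of 3 and only token 4 is dropped.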