The Chain Rule

Intermediate
Before this tutorial: /math-tutorials/calculus/derivatives-introduction

1. What is this and why do we care?

Backpropagation is the algorithm that computes the gradients used to train almost every modern neural network. It is how a network figures out which weights to adjust, and by how much, after making a mistake.

Backpropagation is, at its mathematical core, nothing more than the chain rule applied repeatedly.

The chain rule lets you find the derivative of a composed function — a function inside a function. Neural networks are composed functions. Input → Layer 1 → Layer 2 → Layer 3 → Output → Loss. Each arrow is a function. The chain rule is the tool for tracking how a change at the input ripples through all these layers to affect the final loss.

If you understand the chain rule, you understand the mathematics behind backpropagation: not just roughly, but completely.


2. Prerequisites

You must understand derivatives first. Read Derivatives — Introduction before this tutorial.


3. The intuition — before any symbols

Imagine a relay race with three runners: Priya, Arjun, and Sunita.

  • Priya’s speed affects how fast the baton reaches Arjun.
  • Arjun’s speed affects how fast the baton reaches Sunita.
  • Sunita’s speed affects the final finishing time.

If you want to know: “how much does Priya’s speed affect the final finishing time?” — you cannot just look at Priya. You have to trace the effect through Arjun and Sunita.

“If Priya runs 10% faster, the baton reaches Arjun 8% sooner (because of friction, handoff time, etc.). If the baton reaches Arjun 8% sooner, Sunita gets it 6% sooner. If Sunita gets it 6% sooner, the team finishes 5% faster.”

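So a 10% increase in Priya’s speed → 5% improvement in finish time. Each handoff passes on only part of the change it receives: 8/10 = 0.8 at the first stage, 6/8 = 0.75 at the second, and 5/6 ≈ 0.83 at the third. Multiplying them gives 0.8 × 0.75 × 5/6 = 0.5, and 0.5 × 10% = 5%. The total effect is the product of all the intermediate effects.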

This is the chain rule. When a function is composed — when the output of one function feeds into another — the derivative of the whole thing is the product of the derivatives of each part.


4. A tiny worked example with real numbers

Say we have a composed function:

y = f(g(x))   where   g(x) = 2x   and   f(u) = u²

So: y = (2x)² = 4x²

The chain rule says: dy/dx = (dy/du) × (du/dx)

Where u = g(x) = 2x is the intermediate value.

Let us compute each piece:

u = g(x) = 2x        →   du/dx = 2       (derivative of 2x is 2)
y = f(u) = u²        →   dy/du = 2u      (derivative of u² is 2u)

Chain rule:

dy/dx = dy/du × du/dx = 2u × 2 = 4u = 4(2x) = 8x

Check: directly differentiating y = 4x²:

dy/dx = 4 × 2x = 8x  ✓

Both methods give 8x. At x = 3: dy/dx = 8 × 3 = 24.
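A quick way to build trust in the chain rule is to compare it with a numerical estimate of the slope. The sketch below is plain Python; the helper names (g, f, chain_rule_derivative, numerical_derivative) and the step size h are illustrative choices, not part of the tutorial. It evaluates dy/du at u = g(x), multiplies by du/dx, and compares the result with a central finite difference of y = f(g(x)).

```python
def g(x):
    # Inner function: u = g(x) = 2x
    return 2 * x

def f(u):
    # Outer function: y = f(u) = u²
    return u ** 2

def chain_rule_derivative(x):
    u = g(x)          # intermediate value
    du_dx = 2         # derivative of 2x
    dy_du = 2 * u     # derivative of u², evaluated at u = g(x)
    return dy_du * du_dx

def numerical_derivative(func, x, h=1e-6):
    # Central difference: slope of func measured around x
    return (func(x + h) - func(x - h)) / (2 * h)

x = 3
print(chain_rule_derivative(x))                     # 24
print(numerical_derivative(lambda t: f(g(t)), x))   # approximately 24
```

Both numbers agree with 8x = 24 at x = 3; the numerical one differs only by floating-point rounding.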


5. The general rule

If y = f(g(x)), then:

dy/dx = (df/dg) × (dg/dx)

Or in the Leibniz notation that neural network papers use:

dy/dx = (dy/du) × (du/dx)     where u = g(x)

In words: the derivative of the outer function, evaluated at the inner function’s output, times the derivative of the inner function.

For three nested functions — y = f(g(h(x))):

dy/dx = (dy/dg) × (dg/dh) × (dh/dx)

For n nested functions, you just keep multiplying. This is the chain of derivatives — hence “chain rule.”
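The “keep multiplying” idea can be written as a loop. In this sketch (an illustration, not a library API), each layer is a pair of a function and its derivative, listed innermost first; the hypothetical helper chain_derivative pushes the input forward through every layer, evaluates each local derivative at the value that layer actually received, and multiplies the results. Listing layers innermost first reverses the order of the formula above, which is harmless here because the factors are ordinary numbers.

```python
import math

def chain_derivative(layers, x):
    """Derivative of layers[-1](...layers[0](x)) with respect to x.

    `layers` is a list of (func, dfunc) pairs, innermost first.
    """
    total = 1.0
    value = x
    for func, dfunc in layers:
        total *= dfunc(value)   # local derivative, evaluated at this layer's input
        value = func(value)     # this layer's output becomes the next layer's input
    return total

# Example: y = f(g(h(x))) with h(x) = 3x + 1, g(v) = v², f(w) = sin(w)
layers = [
    (lambda v: 3 * v + 1, lambda v: 3.0),
    (lambda v: v ** 2,    lambda v: 2 * v),
    (math.sin,            math.cos),
]
print(chain_derivative(layers, 0.5))
```

At x = 0.5 the three factors are 3, 2 × 2.5 = 5 and cos(6.25), so the printed value is 15 × cos(6.25), just under 15.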


6. A neural network example — two layers

Consider a tiny neural network with one input x, one hidden neuron, and one output.

Layer 1:  h = w₁ × x          (hidden neuron output, simplified — no activation)
Layer 2:  ŷ = w₂ × h          (final output)
Loss:     L = (y - ŷ)²        (squared error, where y is the correct answer)

We want: dL/dw₁ — how does the loss change when we change w₁?

w₁ is two steps removed from L. We need the chain rule.

Step 1: Find intermediate derivatives

dL/dŷ  = -2(y - ŷ)           (derivative of squared error w.r.t. ŷ)
dŷ/dh  = w₂                   (ŷ = w₂ × h, so derivative w.r.t. h is w₂)
dh/dw₁ = x                    (h = w₁ × x, so derivative w.r.t. w₁ is x)

Step 2: Chain them together

dL/dw₁ = dL/dŷ × dŷ/dh × dh/dw₁
        = -2(y - ŷ) × w₂ × x

Now substitute numbers. Say: x = 2, y = 1 (correct answer), w₁ = 0.5, w₂ = 0.3.

h  = 0.5 × 2 = 1.0
ŷ  = 0.3 × 1.0 = 0.3
L  = (1 - 0.3)² = 0.49

dL/dw₁ = -2(1 - 0.3) × 0.3 × 2
        = -2 × 0.7 × 0.3 × 2
        = -0.84

The gradient is -0.84. A negative gradient means that increasing w₁ will decrease the loss. Gradient descent subtracts the gradient, scaled by a learning rate η (here η = 0.1):

w₁_new = w₁ - η × dL/dw₁ = 0.5 - 0.1 × (-0.84) = 0.5 + 0.084 = 0.584

w₁ increases from 0.5 to 0.584. The network is learning.

This is backpropagation — computed by hand, for two layers. The same logic scales to hundreds of layers.
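The hand calculation above can be sketched in a few lines of plain Python. It makes the same assumptions as the example (one input, one hidden neuron, no activation, squared error, η = 0.1); the function names forward, loss and gradient_w1 are hypothetical. The backward pass computes the local derivatives starting from the loss and multiplies them, as in Step 2.

```python
def forward(x, w1, w2):
    h = w1 * x          # Layer 1 (no activation)
    y_hat = w2 * h      # Layer 2
    return h, y_hat

def loss(y, y_hat):
    return (y - y_hat) ** 2      # squared error

def gradient_w1(x, y, w1, w2):
    _, y_hat = forward(x, w1, w2)
    # Backward pass: local derivatives, from the loss toward the input
    dL_dyhat = -2 * (y - y_hat)  # dL/dŷ
    dyhat_dh = w2                # dŷ/dh
    dh_dw1 = x                   # dh/dw₁
    return dL_dyhat * dyhat_dh * dh_dw1

x, y, w1, w2 = 2.0, 1.0, 0.5, 0.3
eta = 0.1                        # learning rate η

g1 = gradient_w1(x, y, w1, w2)
w1_new = w1 - eta * g1
print(g1, w1_new)
print(loss(y, forward(x, w1, w2)[1]), loss(y, forward(x, w1_new, w2)[1]))
```

Up to floating-point rounding, the first line prints -0.84 and 0.584, and the second shows the loss falling from 0.49 to about 0.42 after one update.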


7. Why “back” propagation?

Notice we computed the derivatives in reverse order — from the loss (at the output) back toward the input:

dL/dŷ   ← computed first (at the output layer)
dŷ/dh   ← computed second (one layer back)
dh/dw₁  ← computed third (at the input layer)

We “propagate” gradients backwards through the network, layer by layer. That is where the name comes from.

In a real network with millions of weights, software automates this backward pass, applying the same chain rule efficiently. In PyTorch it runs when you call loss.backward(); in TensorFlow you record the computation with tf.GradientTape and then call tape.gradient().
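To see what this automation does, here is a deliberately tiny reverse-mode sketch in plain Python. The Value class and its methods are hypothetical names invented for this illustration, not PyTorch’s or TensorFlow’s API; real libraries work on tensors, support many more operations and manage memory far more carefully. Each operation stores its inputs together with its local derivative, and backward() visits the nodes from the loss back toward the inputs, multiplying the upstream gradient by each local derivative.

```python
class Value:
    """A scalar that remembers how it was computed (hypothetical, for illustration)."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents   # tuples of (parent Value, local derivative)

    def __mul__(self, other):
        return Value(self.data * other.data,
                     ((self, other.data), (other, self.data)))

    def __sub__(self, other):
        return Value(self.data - other.data, ((self, 1.0), (other, -1.0)))

    def square(self):
        return Value(self.data ** 2, ((self, 2 * self.data),))

    def backward(self):
        # Topological order: every node comes after the values it was computed from
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent, _ in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0                            # dL/dL = 1
        for node in reversed(order):               # from the loss back toward the inputs
            for parent, local in node._parents:
                parent.grad += node.grad * local   # chain rule: upstream × local derivative

# The two-layer example from Section 6
x, y = Value(2.0), Value(1.0)
w1, w2 = Value(0.5), Value(0.3)
h = w1 * x
y_hat = w2 * h
L = (y - y_hat).square()
L.backward()
print(w1.grad, w2.grad)
```

w1.grad comes out at -0.84 (up to rounding), matching the hand calculation, and w2.grad is filled in by the same single pass.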


8. Where does this appear in AI?

Paper 03 — Backpropagation: The entire paper is an efficient algorithm for applying the chain rule to multi-layer networks. Rumelhart, Hinton and Williams showed that you can compute all gradients in a single backward pass — much faster than computing each weight’s gradient independently.

Paper 04 — LSTM: The vanishing gradient problem happens when the chain rule’s product of derivatives shrinks toward zero over many layers, or over many time steps in a recurrent network. LSTM is designed so that gradients can flow without vanishing: its cell state is updated additively, so the chain-rule factors along that path can stay close to 1 instead of shrinking.
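A small numerical illustration of why long chains shrink, assuming a stack of sigmoid units with every weight fixed at 1 (a simplification chosen only for this sketch): the sigmoid’s derivative is never larger than 0.25, so the chain-rule product over many layers falls off geometrically.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_derivative(z):
    s = sigmoid(z)
    return s * (1.0 - s)      # at most 0.25, reached at z = 0

def gradient_through_chain(depth, z0=0.0):
    """d(output)/d(input) for `depth` sigmoids applied in sequence (weights fixed at 1)."""
    grad, z = 1.0, z0
    for _ in range(depth):
        grad *= sigmoid_derivative(z)   # one chain-rule factor per layer
        z = sigmoid(z)                  # this output feeds the next layer
    return grad

for depth in (1, 5, 10, 20):
    print(depth, gradient_through_chain(depth))
```

Since every factor is at most 0.25, twenty layers give a gradient no larger than 0.25²⁰, roughly 10⁻¹².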

Every deep learning library: PyTorch’s autograd, TensorFlow’s GradientTape, and JAX’s grad all automate the chain rule. When you write loss.backward() in PyTorch, you trigger an automated chain rule computation across the entire computation graph.


9. Common mistakes

  • Forgetting to evaluate the outer derivative at the inner function’s output. The chain rule says (dy/du) × (du/dx) — and dy/du must be evaluated at u = g(x), not at x. Students sometimes plug in x where they should plug in u.

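  • Getting the order wrong. The chain rule multiplies from the outside in: outer derivative first, then inner. For single numbers, reversing the order still gives the same answer (multiplication is commutative), but the intermediate expressions become confusing. In real networks the factors are matrices of partial derivatives (Jacobians), and matrix multiplication is not commutative, so there the order does matter.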

  • Stopping the chain too early. In a 5-layer network, the gradient of the loss with respect to layer 1’s weights requires 5 chain rule applications. Students sometimes stop at 3 or 4. The chain must go all the way back to the weight you are computing the gradient for.
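The first mistake is easy to see with the Section 4 example, y = (2x)² at x = 3. In this minimal sketch (helper names are illustrative), plugging x into dy/du instead of u gives 12 rather than the correct 24.

```python
def du_dx(x):
    return 2              # u = 2x

def dy_du(u):
    return 2 * u          # y = u², so dy/du = 2u

x = 3
u = 2 * x                                # u = g(x) = 6

correct = dy_du(u) * du_dx(x)            # evaluate dy/du at u = 6
wrong = dy_du(x) * du_dx(x)              # mistake: evaluate dy/du at x = 3

print(correct, wrong)                    # 24 12
```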


10. Try it yourself

Exercise 1: Let y = (3x + 1)². Find dy/dx using the chain rule. Let u = 3x + 1, so y = u².

Answer:

du/dx = 3      (derivative of 3x + 1)
dy/du = 2u     (derivative of u²)

Chain rule: dy/dx = 2u × 3 = 6u = 6(3x+1)

At x = 2: dy/dx = 6(3×2+1) = 6×7 = 42

Check numerically: at x = 2, y = (3×2+1)² = 7² = 49. At x = 2.01, y = (3×2.01+1)² = 7.03² = 49.4209. The slope between them is (49.4209 - 49)/0.01 = 42.09 ≈ 42 ✓
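The numerical check above uses a one-sided difference with step 0.01, which is why it lands at 42.09 rather than exactly 42: for this parabola the one-sided estimate is off by (0.01/2) × y″ = 0.005 × 18 = 0.09. A central difference cancels that error term. A minimal sketch in plain Python (the step sizes are arbitrary choices):

```python
def y(x):
    return (3 * x + 1) ** 2

def dy_dx(x):
    u = 3 * x + 1
    return 2 * u * 3        # dy/du × du/dx

x = 2.0
forward_diff = (y(x + 0.01) - y(x)) / 0.01          # the hand check: about 42.09
central_diff = (y(x + 1e-5) - y(x - 1e-5)) / 2e-5   # about 42
print(dy_dx(x), forward_diff, central_diff)
```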


Exercise 2: A one-layer network: ŷ = w × x. Loss L = (y - ŷ)² = (y - wx)².

Using the chain rule, find dL/dw. Let u = y - wx, so L = u².

Answer:

u = y - wx   →   du/dw = -x               (y is constant; the derivative of -wx w.r.t. w is -x)
L = u²       →   dL/du = 2u = 2(y - wx)

Chain rule: dL/dw = dL/du × du/dw = 2(y - wx) × (-x) = -2x(y - wx)

This is the gradient of squared loss for a single-layer network. If x=3, y=6, w=1: dL/dw = -2×3×(6 - 1×3) = -6×3 = -18

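Negative → increase w to decrease loss. Gradient descent with learning rate 0.1: w_new = 1 - 0.1×(-18) = 1 + 1.8 = 2.8. After this update ŷ = 2.8×3 = 8.4. That overshoots the target of 6, but the error has shrunk from 3 (when ŷ = 3) to 2.4. Here dL/dw = -6(6 - 3w) = 18(w - 2), so each step multiplies the distance w - 2 by 1 - 0.1×18 = -0.8: w oscillates around 2 and converges to it (since y = 6 = 2×3).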
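A short loop makes the convergence concrete. This sketch assumes the same numbers as the exercise (x = 3, y = 6, starting at w = 1, learning rate 0.1) and uses the gradient formula derived above; the helper name grad and the count of 50 steps are arbitrary choices for the illustration.

```python
def grad(w, x=3.0, y=6.0):
    # dL/dw = -2x(y - wx) for L = (y - wx)²
    return -2 * x * (y - w * x)

w, eta = 1.0, 0.1
history = [w]
for step in range(50):
    w = w - eta * grad(w)
    history.append(w)

print(history[:4])   # about [1.0, 2.8, 1.36, 2.512]
print(w)             # close to 2
```

The first few values bounce above and below 2, and the distance to 2 shrinks by a factor of 0.8 each step.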


Exercise 3: In the two-layer example from Section 6, compute dL/dw₂ (not dL/dw₁). Use: x=2, y=1, w₁=0.5, w₂=0.3. (h=1.0, ŷ=0.3)

Answer:

dL/dw₂ = dL/dŷ × dŷ/dw₂

dL/dŷ = -2(y - ŷ) = -2(1 - 0.3) = -1.4
dŷ/dw₂ = h = 1.0      (ŷ = w₂×h, so the derivative w.r.t. w₂ is h)

dL/dw₂ = -1.4 × 1.0 = -1.4

Gradient descent: w₂_new = 0.3 - 0.1 × (-1.4) = 0.3 + 0.14 = 0.44
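To double-check the answer, you can nudge w₂ slightly in both directions and measure how the loss responds. A minimal sketch with the Section 6 numbers; the helper name loss and the step eps are illustrative choices.

```python
def loss(w1, w2, x=2.0, y=1.0):
    h = w1 * x
    y_hat = w2 * h
    return (y - y_hat) ** 2

w1, w2, eps = 0.5, 0.3, 1e-6
numeric = (loss(w1, w2 + eps) - loss(w1, w2 - eps)) / (2 * eps)
analytic = -2 * (1.0 - 0.3) * 1.0     # dL/dŷ × dŷ/dw₂ with ŷ = 0.3, h = 1.0
w2_new = w2 - 0.1 * analytic          # gradient descent step, learning rate 0.1
print(numeric, analytic, w2_new)
```

Both estimates come out at -1.4 (up to rounding), and the update reproduces w₂_new ≈ 0.44.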


Previous tutorial: Derivatives — Introduction
Next tutorial: Partial Derivatives
Used in: Paper 03 — Backpropagation