```mermaid
flowchart LR
x1["x₁\n(x₁, ẋ₁=1)"] --> mul["×"]
x2["x₂\n(x₂, ẋ₂=0)"] --> mul
mul -->|"(x₁x₂, x₂·1)"| add["+"]
x3["x₃\n(x₃, ẋ₃=0)"] --> add
add -->|"(x₁x₂+x₃, x₂)"| L["L\n∂L/∂x₁ = x₂"]
```
The Derivative Engine Behind Every loss.backward()
Every time you train a neural network, something computes exact derivatives through millions of operations automatically. You call loss.backward() and gradients appear — but how? And why does training a 7B-parameter LLM consume 5x more GPU memory than running inference on it?
The answer to both questions is Automatic Differentiation (AD): a family of techniques for computing exact derivatives through arbitrary code, efficiently. Understanding it changes how you reason about memory budgets, gradient flow failures, and why certain training tricks (gradient checkpointing, mixed precision) exist at all.
There are two fundamentally different approaches — forward mode and reverse mode — and the choice between them explains why deep learning frameworks are built the way they are.
Why Not Just Use Calculus or Finite Differences?
Before getting to AD, it helps to understand what it replaced.
Numerical differentiation approximates the derivative using finite differences: \(f'(x) \approx \frac{f(x+h) - f(x)}{h}\) for some small \(h\). It’s dead simple but has two fatal flaws: it requires one extra forward pass per parameter (catastrophic for millions of parameters), and floating-point subtraction of nearly-equal numbers amplifies numerical error badly.
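Both flaws are easy to see in a few lines. This is an illustrative sketch, not a recommended implementation: it approximates \(f'(2)\) for \(f(x) = x^3\) (exact answer: 12) with a reasonable step size, then with a step so small that the subtraction of nearly-equal values destroys the result.

```python
# Forward-difference sketch: (f(x+h) - f(x)) / h for f(x) = x^3.
def finite_diff(f, x, h):
    return (f(x + h) - f(x)) / h

f = lambda x: x ** 3
good = finite_diff(f, 2.0, 1e-6)    # ≈ 12.000006 — truncation error ~ h
bad = finite_diff(f, 2.0, 1e-15)    # far off — catastrophic cancellation:
                                    # f(x+h) and f(x) agree in nearly all bits
print(good, bad)
```

And the cost flaw is structural: `finite_diff` perturbs one input at a time, so a model with \(n\) parameters needs \(n\) extra forward passes per gradient.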
Symbolic differentiation (what a computer algebra system does) applies calculus rules to produce a closed-form derivative expression. It’s exact, but the resulting expressions grow exponentially with computation depth — a 100-layer network would produce a gradient expression no machine could reasonably evaluate.
AD is neither. It applies the chain rule mechanically at each elementary operation, accumulating intermediate values rather than symbolic expressions. The result is exact (to floating-point precision) and efficient — no expression explosion, no extra passes per parameter.
| Method | Accuracy | Cost | Practical for ML? |
|---|---|---|---|
| Numerical (finite diff) | Approximate | 1 extra pass per input | ❌ Too slow |
| Symbolic | Exact | Expression explosion | ❌ Intractable |
| AD — forward mode | Exact | 1 pass per input | ⚠️ Only if few inputs |
| AD — reverse mode | Exact | 1 pass per output | ✅ Standard choice |
Forward Mode AD: Sensitivity Flowing Downstream
Forward mode AD propagates derivatives alongside values as computation flows from inputs to outputs. At each operation, it tracks not just the result but how sensitive that result is to a chosen input.
The elegant implementation uses dual numbers: instead of a scalar \(x\), carry a pair \((x,\ \dot{x})\) where \(\dot{x}\) represents the derivative of \(x\) with respect to some chosen input \(x_i\). Operations on dual numbers automatically propagate the derivative via the chain rule — you never write it explicitly:
\[f(a + b\varepsilon) \approx f(a) + f'(a)\cdot b\varepsilon \qquad (\varepsilon^2 = 0)\]
The \(\varepsilon\) coefficient carries the derivative forward through every arithmetic operation.
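A minimal dual-number class makes this concrete. This is an illustrative sketch, not any library's API: each value carries a pair `(val, dot)`, and overloaded `+` and `*` apply the chain rule mechanically. It reproduces the \(L = x_1 x_2 + x_3\) example from the diagram at the top.

```python
# Minimal dual-number sketch: each value carries (val, dot),
# where dot = d(val)/d(chosen input).
class Dual:
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        # Sum rule: (a + b)' = a' + b'
        return Dual(self.val + other.val, self.dot + other.dot)

    def __mul__(self, other):
        # Product rule: (a·b)' = a'·b + a·b'
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

# Differentiate L = x1*x2 + x3 with respect to x1: seed x1 with dot = 1.
x1 = Dual(3.0, 1.0)   # the seed: d(x1)/d(x1) = 1
x2 = Dual(4.0, 0.0)
x3 = Dual(5.0, 0.0)
L = x1 * x2 + x3
print(L.val, L.dot)   # 17.0 4.0 — and indeed ∂L/∂x1 = x2 = 4
```

Note that the derivative rules live in the operator overloads; the expression `x1 * x2 + x3` never mentions them.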
The critical limitation: the initial seed vector — the \((0,\ldots,1,\ldots,0)\) that selects which input you’re differentiating with respect to — means one forward pass gives you the sensitivity with respect to one input. Getting gradients for all \(n\) inputs requires \(n\) passes.
For a 7B-parameter LLM, that’s 7 billion passes to compute a single gradient update. Forward mode is not the answer for ML.
Forward mode is efficient when outputs greatly outnumber inputs — the opposite of ML. It shines in scientific computing: a simulation with 3 input parameters and 10,000 output metrics needs only 3 forward passes, not 10,000. In ML the ratio is reversed: millions of inputs (parameters), one output (scalar loss). Reverse mode exists to handle exactly this case.
Reverse Mode AD: Tracing Blame Upstream
Reverse mode flips the direction. Instead of asking “how does changing this input affect the output?”, it asks “how much did each node contribute to this output?”
The key insight: for a scalar output (a loss function), one backward pass distributes gradient credit back to every node in the graph simultaneously. One pass. All gradients.
```mermaid
flowchart TD
subgraph fwd ["① Forward Pass — compute and store"]
direction LR
x["x"] --> mul["mul"] --> add["add"] --> L["L (scalar)"]
w["w"] --> mul
b["b"] --> add
end
subgraph bwd ["② Backward Pass — propagate gradients"]
direction RL
dL["∂L/∂L = 1"] --> dadd["∂L/∂add"] --> dmul["∂L/∂mul"]
dmul --> dx["∂L/∂x"]
dmul --> dw["∂L/∂w"]
dadd --> db["∂L/∂b"]
end
fwd --> bwd
```
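The two passes above can be sketched in a few lines. This is a toy illustration of the mechanism, not PyTorch's internals: each node stores its value plus the local derivatives with respect to its parents, and `backward` pushes gradient credit upstream.

```python
# Minimal reverse-mode sketch for L = x*w + b.
class Var:
    def __init__(self, val, parents=()):
        self.val, self.grad, self.parents = val, 0.0, parents

    def __mul__(self, other):
        # Store local derivatives: d(a·b)/da = b, d(a·b)/db = a.
        return Var(self.val * other.val,
                   [(self, other.val), (other, self.val)])

    def __add__(self, other):
        # d(a+b)/da = d(a+b)/db = 1.
        return Var(self.val + other.val, [(self, 1.0), (other, 1.0)])

    def backward(self, seed=1.0):
        # Accumulate gradient here, then distribute credit to parents
        # (chain rule). Fine for this tree-shaped graph; a real engine
        # uses a topological order to visit shared nodes once.
        self.grad += seed
        for parent, local_grad in self.parents:
            parent.backward(seed * local_grad)

x, w, b = Var(2.0), Var(3.0), Var(1.0)
L = x * w + b          # ① forward pass: values stored on each node
L.backward()           # ② one backward pass: every gradient at once
print(x.grad, w.grad, b.grad)   # 3.0 2.0 1.0
```

One call to `backward`, and all three gradients are filled in — exactly the one-pass, all-gradients property the diagram shows.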
Gradient Checkpointing: Buying Memory Back with Compute
The standard solution to activation memory pressure is gradient checkpointing (also called activation recomputation): don’t store all activations during the forward pass. Store only at segment boundaries — checkpoints — and recompute intermediate activations on-the-fly during the backward pass when they’re needed.
```mermaid
flowchart LR
subgraph s1 ["Segment 1"]
L1["Layer 1"] --> L2["Layer 2"] --> L3["Layer 3"]
end
subgraph s2 ["Segment 2"]
L4["Layer 4"] --> L5["Layer 5"] --> L6["Layer 6"]
end
s1 -->|"✓ checkpoint"| s2
style L1 fill:#e8f5e9
style L3 fill:#e8f5e9
style L4 fill:#e8f5e9
style L6 fill:#e8f5e9
```
| Strategy | Activation memory | Compute overhead |
|---|---|---|
| No checkpointing | \(O(N)\) layers | None |
| \(\sqrt{N}\) checkpoints | \(O(\sqrt{N})\) layers | ~1 extra forward pass |
| Recompute everything | \(O(1)\) | Up to \(N\) extra forward passes |
The sweet spot for most LLM training is \(\sqrt{N}\) checkpoints — roughly one extra forward pass in exchange for a meaningful memory reduction. This is what torch.utils.checkpoint.checkpoint_sequential implements.
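The arithmetic behind that sweet spot can be sketched with pure-Python bookkeeping (this is a cost model, not `torch.utils.checkpoint` itself; the function name and the simple "checkpoints + one live segment" memory estimate are illustrative assumptions):

```python
import math

def checkpoint_cost(n_layers, k):
    """Layer evaluations and peak stored activations when a
    checkpoint is kept every k layers (toy cost model)."""
    forward_evals = n_layers                 # one full forward pass
    # Backward: each segment recomputes its k-1 non-checkpointed
    # activations from the segment's stored checkpoint.
    n_segments = math.ceil(n_layers / k)
    recompute_evals = n_segments * (k - 1)
    stored = n_segments + k                  # checkpoints + one live segment
    return forward_evals + recompute_evals, stored

# N = 100 layers, k = 10 ≈ √N:
total, stored = checkpoint_cost(100, 10)
print(total, stored)   # 190 20 — ~1 extra forward pass, ~20 vs. 100 stored
```

With \(k = \sqrt{N}\), recomputation adds roughly \(N\) layer evaluations (one extra forward pass) while peak activation storage drops from \(O(N)\) to \(O(\sqrt{N})\).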
The Trade-off, Stated Clearly
| | Forward Mode | Reverse Mode |
|---|---|---|
| Passes needed | 1 per input variable | 1 per output variable |
| Best for | Few inputs, many outputs | Many inputs, few outputs (ML) |
| Memory overhead | Low — no stored intermediates | High — all intermediates stored |
| What frameworks use | Occasionally for Jacobians | Always for gradient-based training |
Forward mode naturally computes a Jacobian-vector product (JVP) — the full Jacobian multiplied by a chosen input direction. Reverse mode naturally computes a vector-Jacobian product (VJP) — a chosen output direction multiplied by the full Jacobian. For a scalar loss, the VJP with direction \([1]\) gives you the complete gradient vector in one pass. This is the mathematical reason reverse mode dominates ML training.
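A toy example makes the JVP/VJP distinction concrete. This sketch materializes the Jacobian explicitly — something real AD never does — purely to show what each product extracts. The function \(f(x_1, x_2) = (x_1 x_2,\ x_1 + x_2)\) is a made-up example with Jacobian \(J = \begin{bmatrix} x_2 & x_1 \\ 1 & 1 \end{bmatrix}\):

```python
# JVP vs. VJP on f(x1, x2) = (x1*x2, x1 + x2), Jacobian made explicit
# for illustration only.
def jacobian(x1, x2):
    return [[x2, x1], [1.0, 1.0]]

def jvp(J, v):
    # Forward mode computes J @ v: one pass probes one input direction.
    return [sum(Jij * vj for Jij, vj in zip(row, v)) for row in J]

def vjp(u, J):
    # Reverse mode computes u @ J: one pass probes one output direction.
    return [sum(u[i] * J[i][j] for i in range(len(u)))
            for j in range(len(J[0]))]

J = jacobian(3.0, 4.0)
print(jvp(J, [1.0, 0.0]))   # [4.0, 1.0] — column 1: both outputs' sensitivity to x1
print(vjp([1.0, 0.0], J))   # [4.0, 3.0] — row 1: full gradient of output 1
```

For a scalar loss, \(J\) has exactly one row, so the VJP with \(u = [1]\) is the whole gradient — one reverse pass, regardless of how many inputs there are.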
What Breaks in Practice
Gradient flow failures. In reverse mode, gradients are products of local Jacobians chained across all layers. If any factor is consistently small (saturating activations, poor initialization) or large (unbounded weights), the gradient signal degrades before it reaches the early layers. This is the vanishing/exploding gradient problem; it is not specific to RNNs but a structural property of deep reverse-mode computation.
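The exponential decay is worth seeing numerically. A hypothetical 50-layer chain where every local derivative is 0.25 (the maximum slope of the sigmoid) is enough:

```python
# Backpropagated gradient = product of 50 local derivative factors,
# each at most 0.25 (sigmoid's maximum slope).
local_grad = 0.25
grad = 1.0              # seed at the scalar loss
for _ in range(50):
    grad *= local_grad
print(grad)             # ~7.9e-31 — effectively zero at the early layers
```

Fifty modest factors of 0.25 multiply out to \(0.25^{50} \approx 8 \times 10^{-31}\): the early layers receive no usable learning signal, which is why residual connections and careful initialization matter.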
Silent NaN propagation. A NaN anywhere in the forward pass propagates silently through the computation graph. During backward, every gradient flowing through the affected node becomes NaN, and the weight update corrupts the entire model. Use torch.autograd.set_detect_anomaly(True) to get a traceback pointing to the originating operation — invaluable for tracking these down.
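The silent part is the dangerous part — no exception is raised. A minimal sketch of the poisoning behavior (plain floats here; tensors behave the same way element-wise):

```python
import math

bad = float('inf') - float('inf')   # one ill-defined op quietly yields NaN
y = bad * 0.0 + 3.0                 # NaN poisons every downstream value
print(math.isnan(y))                # True — and every gradient flowing
                                    # through this node would be NaN too
```

Note that even `NaN * 0.0` stays NaN, so there is no arithmetic path that "washes it out" downstream — hence the value of anomaly detection pointing you at the originating op.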
In-place operations on tensors with gradients. In-place ops (e.g., x += 1) can modify a tensor that the backward pass expects to find unchanged. PyTorch raises a runtime error when it detects this, but the error message can be confusing. The fix is simple: avoid in-place ops on any tensor that requires gradients, or clone before modifying.
Key Takeaways
AD is not numerical or symbolic differentiation. It applies the chain rule exactly at each elementary operation — no approximation, no expression explosion.
Forward mode needs one pass per input; reverse mode needs one pass per output. For ML — scalar loss, millions of parameters — reverse mode wins unconditionally.
The cost of reverse mode is memory. Every intermediate tensor from the forward pass must stay alive for the backward pass. This is the root cause of training using far more memory than inference.
Gradient checkpointing trades compute for memory. Store only at segment boundaries, recompute the rest during backward. Expect roughly one extra forward pass overhead for a meaningful memory reduction.
Most gradient problems are reverse-mode problems. Vanishing/exploding gradients, NaN propagation, and in-place op errors all stem from how reverse-mode AD chains local Jacobians through the computation graph. Understanding the mechanism is the fastest path to diagnosing them.