Automatic Differentiation Demystified

From dual numbers to backpropagation — the intuition, the trade-offs, and what breaks in practice for LLM training

MLSys
Author

Imad Dabbura

Published

February 3, 2024

Modified

May 10, 2025

The Derivative Engine Behind Every loss.backward()

Every time you train a neural network, something computes exact derivatives through millions of operations automatically. You call loss.backward() and gradients appear, but how? And why does training a 7B-parameter LLM consume several times more GPU memory than running inference on it?

The answer to both questions is Automatic Differentiation (AD): a family of techniques for computing exact derivatives through arbitrary code, efficiently. Understanding it changes how you reason about memory budgets, gradient flow failures, and why certain training tricks (gradient checkpointing, mixed precision) exist at all.

There are two fundamentally different approaches — forward mode and reverse mode — and the choice between them explains why deep learning frameworks are built the way they are.

Why Not Just Use Calculus or Finite Differences?

Before getting to AD, it helps to understand what it replaced.

Numerical differentiation approximates the derivative using finite differences: \(f'(x) \approx \frac{f(x+h) - f(x)}{h}\) for some small \(h\). It’s dead simple but has two fatal flaws: it requires one extra forward pass per parameter (catastrophic for millions of parameters), and floating-point subtraction of nearly-equal numbers amplifies numerical error badly.
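The precision problem is easy to see by sweeping \(h\) for a function whose derivative is known in closed form. The sketch below is plain Python; the test function \(\sin\) and the point \(x = 1\) are arbitrary choices. For large \(h\) the truncation error dominates. For very small \(h\), \(f(x+h)\) and \(f(x)\) agree in almost every bit, so the subtraction cancels most significant digits and round-off dominates.

```python
import math

def forward_difference(f, x, h):
    """One-sided finite-difference estimate of f'(x)."""
    return (f(x + h) - f(x)) / h

def fd_error(h, x=1.0):
    """Absolute error of the estimate for f = sin, whose exact derivative is cos."""
    return abs(forward_difference(math.sin, x, h) - math.cos(x))

for h in (1e-1, 1e-4, 1e-8, 1e-12, 1e-15):
    # error first shrinks with h (less truncation), then grows (more cancellation)
    print(f"h={h:.0e}  error={fd_error(h):.1e}")
```

In double precision the smallest error in this sweep lands near \(h = 10^{-8}\), and even there it is many orders of magnitude above machine precision, so no choice of \(h\) recovers a fully exact derivative.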

Symbolic differentiation (what a computer algebra system does) applies calculus rules to produce a closed-form derivative expression. It’s exact, but unless shared subexpressions are carefully reused, the derivative expressions can grow exponentially with computation depth; a 100-layer network would produce a gradient expression no machine could reasonably evaluate.

AD is neither. It applies the chain rule mechanically at each elementary operation, propagating numerical derivative values rather than building symbolic expressions. The result is exact (up to floating-point precision) and efficient: there is no expression explosion and, in reverse mode, no extra pass per parameter.

Three Ways to Differentiate Code
| Method | Accuracy | Cost | Practical for ML? |
|---|---|---|---|
| Numerical (finite diff) | Approximate | 1 extra pass per input | ❌ Too slow |
| Symbolic | Exact | Expression explosion | ❌ Intractable |
| AD (forward mode) | Exact | 1 pass per input | ⚠️ Only if few inputs |
| AD (reverse mode) | Exact | 1 pass per output | ✅ Standard choice |

Forward Mode AD: Sensitivity Flowing Downstream

Forward mode AD propagates derivatives alongside values as computation flows from inputs to outputs. At each operation, it tracks not just the result but how sensitive that result is to a chosen input.

The elegant implementation uses dual numbers: instead of a scalar \(x\), carry a pair \((x,\ \dot{x})\) where \(\dot{x}\) represents the derivative of \(x\) with respect to some chosen input \(x_i\). Operations on dual numbers automatically propagate the derivative via the chain rule — you never write it explicitly:

\[f(a + b\varepsilon) = f(a) + f'(a)\,b\,\varepsilon \qquad (\varepsilon^2 = 0)\]

The \(\varepsilon\) coefficient carries the derivative forward through every arithmetic operation.
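Here is a minimal sketch of dual numbers in plain Python. `Dual`, `sin`, and `derivative` are illustrative names, not a library API, and only the operations this example needs are overloaded. Each operation implements the rule above for one primitive; the derivative of the composite function then falls out of ordinary evaluation.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class Dual:
    """A dual number val + dot·ε with ε² = 0; `dot` carries the derivative."""
    val: float
    dot: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)  # constants have dot = 0
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # (a + bε)(c + dε) = ac + (ad + bc)ε, since the ε² term vanishes
        return Dual(self.val * other.val, self.val * other.dot + self.dot * other.val)

    __rmul__ = __mul__

def sin(x: Dual) -> Dual:
    """sin(a + bε) = sin(a) + cos(a)·b·ε, the rule above with f = sin."""
    return Dual(math.sin(x.val), math.cos(x.val) * x.dot)

def derivative(f, x: float) -> float:
    """Seed dot = 1 on the input, evaluate f once, read off the ε coefficient."""
    return f(Dual(x, 1.0)).dot

def f(x):
    # f(x) = x·sin(x) + 3x, so f'(x) = sin(x) + x·cos(x) + 3
    return x * sin(x) + 3 * x

print(derivative(f, 2.0))
print(math.sin(2.0) + 2.0 * math.cos(2.0) + 3)  # hand-derived value, for comparison
```

No derivative rule for `f` itself is ever written: the product and sum rules live in `__mul__` and `__add__`, and the chain rule happens because each operation consumes the `dot` of its inputs.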

flowchart LR
    x1["x₁\n(x₁, ẋ₁=1)"] --> mul["×"]
    x2["x₂\n(x₂, ẋ₂=0)"] --> mul
    mul -->|"(x₁x₂, x₂·1)"| add["+"]
    x3["x₃\n(x₃, ẋ₃=0)"] --> add
    add -->|"(x₁x₂+x₃, x₂)"| L["L\n∂L/∂x₁ = x₂"]

Forward mode propagates (value, derivative) pairs from inputs to output. The derivative component tracks sensitivity w.r.t. one chosen input. Here the seed is set for x₁, so the dots of x₂ and x₃ are 0.

The critical limitation is the seed. The one-hot seed vector \((0,\ldots,1,\ldots,0)\) selects a single input to differentiate with respect to, so one forward pass yields the sensitivity with respect to that input only. Getting gradients for all \(n\) inputs requires \(n\) passes.

For a 7B-parameter LLM, that’s 7 billion passes to compute a single gradient update. Forward mode is not the answer for ML.
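To make the \(n\)-pass cost concrete, here is a sketch that differentiates the same \(f(x_1, x_2, x_3) = x_1 x_2 + x_3\) as the diagram by propagating (value, tangent) pairs by hand. `f_jvp` and `forward_mode_gradient` are hypothetical helper names; the loop over one-hot seeds is the part that scales with the number of inputs.

```python
def f_jvp(xs, seed):
    """One forward-mode pass of f(x1, x2, x3) = x1*x2 + x3.

    Every intermediate is a (value, tangent) pair; `seed` is the one-hot
    direction that selects the input we differentiate with respect to.
    """
    (x1, x2, x3), (t1, t2, t3) = xs, seed
    m, tm = x1 * x2, x1 * t2 + t1 * x2   # product rule
    return m + x3, tm + t3               # sum rule

def forward_mode_gradient(jvp, xs):
    """Full gradient via forward mode: one pass per input, each with its own seed."""
    n = len(xs)
    grad, passes = [], 0
    for i in range(n):
        seed = [1.0 if j == i else 0.0 for j in range(n)]
        _, tangent = jvp(xs, seed)
        grad.append(tangent)
        passes += 1
    return grad, passes

grad, passes = forward_mode_gradient(f_jvp, [2.0, 5.0, 7.0])
print(grad, passes)  # [5.0, 2.0, 1.0] 3, i.e. (x2, x1, 1) after 3 passes
```

Each pass is cheap; the problem is the loop count, which equals the number of inputs.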

When Forward Mode Wins

Forward mode is efficient when outputs greatly outnumber inputs — the opposite of ML. It shines in scientific computing: a simulation with 3 input parameters and 10,000 output metrics needs only 3 forward passes, not 10,000. In ML the ratio is reversed: millions of inputs (parameters), one output (scalar loss). Reverse mode exists to handle exactly this case.

Reverse Mode AD: Tracing Blame Upstream

Reverse mode flips the direction. Instead of asking “how does changing this input affect the output?”, it asks “how much did each node contribute to this output?”

The key insight: for a scalar output (a loss function), one backward pass distributes gradient credit back to every node in the graph simultaneously. One pass. All gradients.

flowchart TD
    subgraph fwd ["① Forward Pass — compute and store"]
        direction LR
        x["x"] --> mul["mul"] --> add["add"] --> L["L (scalar)"]
        w["w"] --> mul
        b["b"] --> add
    end
    subgraph bwd ["② Backward Pass — propagate gradients"]
        direction RL
        dL["∂L/∂L = 1"] --> dadd["∂L/∂add"] --> dmul["∂L/∂mul"]
        dmul --> dx["∂L/∂x"]
        dmul --> dw["∂L/∂w"]
        dadd --> db["∂L/∂b"]
    end
    fwd --> bwd

Reverse mode runs two phases: a forward pass that computes and stores all intermediate values, then a backward pass that propagates ∂L/∂· back to every node.
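A scalar reverse-mode engine fits in a few dozen lines. The sketch below is illustrative only: the `Value` class is hypothetical, in the spirit of small educational autograd engines, and is not how PyTorch is implemented. Building the expression is the forward pass and records the graph; `backward()` seeds \(\partial L/\partial L = 1\) and visits nodes in reverse topological order, so each node's gradient is complete before it is pushed to its parents.

```python
class Value:
    """A scalar node that records how it was computed, so gradients can flow back."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents      # nodes this value was computed from
        self._backward_fn = None     # pushes self.grad into the parents' .grad

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def backward_fn():
            # local derivatives of a + b are both 1
            self.grad += out.grad
            other.grad += out.grad
        out._backward_fn = backward_fn
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def backward_fn():
            # uses the forward-pass inputs: ∂L/∂a = b·∂L/∂z, ∂L/∂b = a·∂L/∂z
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward_fn = backward_fn
        return out

    def backward(self):
        # Topological order: every node appears after all the nodes it depends on.
        order, seen = [], set()
        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)
        visit(self)
        self.grad = 1.0               # seed: ∂L/∂L = 1
        for node in reversed(order):  # one backward sweep reaches every node
            if node._backward_fn is not None:
                node._backward_fn()

# L = w·x + b, the same graph as the diagram above
x, w, b = Value(3.0), Value(-2.0), Value(0.5)
L = w * x + b
L.backward()
print(x.grad, w.grad, b.grad)  # -2.0 3.0 1.0
```

Gradients are accumulated with `+=` so that a node used in several places receives credit from every use.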

The Unavoidable Memory Cost

Here’s the catch. To compute gradients during the backward pass, each operation needs its inputs from the forward pass. For a mul node computing \(z = w \cdot x\), the backward step needs both \(w\) and \(x\) to distribute credit:

\[\frac{\partial L}{\partial w} = x \cdot \frac{\partial L}{\partial z}, \qquad \frac{\partial L}{\partial x} = w \cdot \frac{\partial L}{\partial z}\]

So the framework must keep every intermediate tensor alive until the backward pass consumes it. The consequence:

  • Inference: each layer’s activations can be discarded once the next layer is computed → memory is roughly \(O(1)\) in depth
  • Training: all activations must survive until their gradient is computed → memory is \(O(N)\) in depth

Together with the memory needed for gradients and optimizer state, this is why training a transformer consumes so much more memory than running inference on it. At large batch sizes, forward activations alone can dwarf the parameter memory.
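The "keep it alive until backward" rule can be sketched with a toy tape of multiplications. The names below are hypothetical; the pattern loosely mirrors how a custom `torch.autograd.Function` stashes tensors with `ctx.save_for_backward` in `forward` and reads them back in `backward`. Counting the live saved pairs after the forward pass shows memory growing with depth.

```python
class Mul:
    """One primitive op: forward() saves what backward() will need."""

    def __init__(self):
        self.saved = None

    def forward(self, w, x):
        self.saved = (w, x)            # must stay alive until backward runs
        return w * x

    def backward(self, grad_z):
        w, x = self.saved
        self.saved = None              # consumed: this memory can now be freed
        return x * grad_z, w * grad_z  # (dL/dw, dL/dx)

def forward_chain(x, weights):
    """Forward pass through one Mul per weight, recording each op on a tape."""
    tape = []
    for w in weights:
        op = Mul()
        x = op.forward(w, x)
        tape.append(op)
    return x, tape

def backward_chain(tape, grad=1.0):
    """Backward pass: walk the tape in reverse, consuming saved inputs."""
    grads_w = []
    for op in reversed(tape):
        grad_w, grad = op.backward(grad)
        grads_w.append(grad_w)
    return grads_w[::-1], grad  # (dL/dw per layer, dL/dx)

out, tape = forward_chain(2.0, [3.0, 4.0, 5.0])
print(out, sum(op.saved is not None for op in tape))  # 120.0 3: one saved pair per layer
grads_w, grad_x = backward_chain(tape)
print(grads_w, grad_x)                                # [40.0, 30.0, 24.0] 60.0
print(sum(op.saved is not None for op in tape))       # 0: all freed after backward
```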

Why Your GPU OOMs During Training But Not Inference

During inference, each layer’s output can reuse the previous layer’s buffer, so memory stays roughly constant regardless of model depth. During training, every layer’s output must survive until the backward pass reaches it: a 24-layer transformer holds the saved activations of all 24 layers simultaneously. Scale batch size by 4x and activation memory scales 4x too; parameters don’t budge, activations do. This is the first thing to check when you hit an OOM that doesn’t happen at inference time.
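A back-of-envelope estimate makes the scaling concrete. This is a deliberately crude sketch: it assumes each layer keeps a single `(batch, seq_len, hidden)` tensor in 2-byte precision and that inference needs only two such buffers. Real transformer layers save several tensors each (attention inputs, MLP intermediates, normalization statistics), so actual training numbers are higher. The shapes in the usage lines are hypothetical.

```python
def activation_bytes(num_layers, batch, seq_len, hidden,
                     bytes_per_elem=2, training=True):
    """Crude activation-memory estimate: one (batch, seq_len, hidden) tensor per layer."""
    per_layer = batch * seq_len * hidden * bytes_per_elem
    if training:
        # every layer's saved activation must survive until backward reaches it
        return num_layers * per_layer
    # inference: only the current layer's input and output buffers are live
    return 2 * per_layer

GiB = 2**30
print(activation_bytes(24, 8, 2048, 4096) / GiB)                  # 3.0
print(activation_bytes(24, 32, 2048, 4096) / GiB)                 # 12.0: 4x batch, 4x memory
print(activation_bytes(24, 8, 2048, 4096, training=False) / GiB)  # 0.25, independent of depth
```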

Gradient Checkpointing: Buying Memory Back with Compute

The standard solution to activation memory pressure is gradient checkpointing (also called activation recomputation): don’t store all activations during the forward pass. Store only at segment boundaries — checkpoints — and recompute intermediate activations on-the-fly during the backward pass when they’re needed.

flowchart LR
    subgraph s1 ["Segment 1"]
        L1["Layer 1"] --> L2["Layer 2"] --> L3["Layer 3"]
    end
    subgraph s2 ["Segment 2"]
        L4["Layer 4"] --> L5["Layer 5"] --> L6["Layer 6"]
    end
    s1 -->|"✓ checkpoint"| s2
    style L1 fill:#e8f5e9
    style L3 fill:#e8f5e9
    style L4 fill:#e8f5e9
    style L6 fill:#e8f5e9

Checkpointing stores activations only at segment boundaries (green). During backward, each segment re-runs its forward pass to recover the discarded intermediates.

| Strategy | Activation memory | Compute overhead |
|---|---|---|
| No checkpointing | \(O(N)\) layers | None |
| \(\sqrt{N}\) checkpoints | \(O(\sqrt{N})\) layers | ~1 extra forward pass |
| Recompute everything | \(O(1)\) | \(O(N^2)\) layer evaluations (about \(N/2\) extra forward passes) |

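A common sweet spot for LLM training is \(\sqrt{N}\) checkpoints: roughly one extra forward pass in exchange for a meaningful memory reduction. torch.utils.checkpoint.checkpoint_sequential implements this segment-based scheme; you choose the number of segments, and setting it to about \(\sqrt{N}\) gives the \(\sqrt{N}\) variant.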
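The \(\sqrt{N}\) rule falls out of a simple cost model. The sketch below is a toy, not a profiler: it assumes identical layers, counts one unit of memory per stored layer output and one unit of compute per layer evaluation, and recomputes every segment once during backward. Peak memory is the stored checkpoints plus one segment's recomputed intermediates, about \(N/k + k\) for segment size \(k\), which is minimized at \(k = \sqrt{N}\).

```python
import math

def checkpoint_cost(num_layers: int, segment_size: int) -> tuple[int, int]:
    """Toy model of segment checkpointing: returns (peak activations, layer evaluations)."""
    num_segments = math.ceil(num_layers / segment_size)
    # Memory: one checkpoint per segment boundary, plus the intermediates of the
    # single segment being recomputed during the backward pass.
    peak_activations = num_segments + segment_size
    # Compute: the original forward pass plus one re-run of every segment.
    layer_evals = 2 * num_layers
    return peak_activations, layer_evals

N = 64  # hypothetical depth; without checkpointing all 64 activations are stored
for k in (2, 4, 8, 16, 32):
    peak, evals = checkpoint_cost(N, k)
    print(f"segment_size={k:2d}  peak_activations={peak:2d}  layer_evals={evals}")

best = min(range(1, N + 1), key=lambda k: checkpoint_cost(N, k)[0])
print(best)  # 8, i.e. sqrt(64): 8 checkpoints + 8 recomputed = 16 instead of 64
```

Compute stays at two forward passes' worth of layer evaluations for every segment size, which is the "~1 extra forward pass" row of the table; only the memory term depends on \(k\).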

The Trade-off, Stated Clearly

| | Forward Mode | Reverse Mode |
|---|---|---|
| Passes needed | 1 per input variable | 1 per output variable |
| Best for | Few inputs, many outputs | Many inputs, few outputs (ML) |
| Memory overhead | Low: no stored intermediates | High: all intermediates stored |
| What frameworks use | Occasionally, for Jacobians | Always, for gradient-based training |

The Jacobian Perspective

Forward mode naturally computes a Jacobian-vector product (JVP) — the full Jacobian multiplied by a chosen input direction. Reverse mode naturally computes a vector-Jacobian product (VJP) — a chosen output direction multiplied by the full Jacobian. For a scalar loss, the VJP with direction \([1]\) gives you the complete gradient vector in one pass. This is the mathematical reason reverse mode dominates ML training.
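In symbols, for a function \(f:\mathbb{R}^n \to \mathbb{R}^m\) with Jacobian \(J\), and using the common AD notation \(\dot{x}\) for a tangent (input direction) and \(\bar{y}\) for a cotangent (output direction):

\[
J = \frac{\partial f}{\partial x} \in \mathbb{R}^{m \times n}, \qquad
\text{JVP: } \dot{y} = J\,\dot{x}, \qquad
\text{VJP: } \bar{x}^{\top} = \bar{y}^{\top} J
\]

Neither product materializes \(J\). Recovering the full Jacobian takes \(n\) JVPs (seeding \(\dot{x} = e_i\) yields column \(i\)) or \(m\) VJPs (seeding \(\bar{y} = e_j\) yields row \(j\)). For a scalar loss \(L\), \(m = 1\) and a single VJP with \(\bar{y} = 1\) returns the whole gradient:

\[
\bar{x}^{\top} = 1 \cdot J = \left[\frac{\partial L}{\partial x_1}, \ldots, \frac{\partial L}{\partial x_n}\right] = \nabla L^{\top}
\]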

What Breaks in Practice

Gradient flow failures. In reverse mode, gradients are products of local Jacobians chained across all layers. If those factors are consistently smaller than 1 (saturating activations, poor initialization) or larger than 1 (unbounded weights), the product shrinks or grows exponentially with depth, and the gradient signal degrades before reaching early layers. This is the vanishing/exploding gradient problem. It’s not specific to RNNs; it’s a structural property of chaining many local Jacobians through a deep computation, which is exactly what reverse mode does.
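A scalar toy chain shows how quickly the product runs away. The sketch below is plain Python; the depth of 20, the weights 1.0 and 1.5, and the starting value 0.5 are arbitrary choices. It runs a forward pass \(h_{k+1} = \text{act}(w\,h_k)\), stores each pre-activation, then multiplies the local derivatives \(w\,\text{act}'(z_k)\) in reverse. Because \(\sigma'(z) \le 0.25\), twenty sigmoid layers shrink the gradient below \(0.25^{20} \approx 10^{-12}\), while a linear chain with \(w = 1.5\) grows it to \(1.5^{20} \approx 3.3 \times 10^{3}\).

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def d_sigmoid(z):
    s = sigmoid(z)
    return s * (1.0 - s)  # at most 0.25, reached at z = 0

def identity(z):
    return z

def d_identity(z):
    return 1.0

def input_gradient(depth, w, act, d_act, x=0.5):
    """Reverse-mode dL/dx for the scalar chain h_{k+1} = act(w * h_k), with L = h_depth."""
    h, saved = x, []
    for _ in range(depth):         # forward pass: store each pre-activation
        z = w * h
        saved.append(z)
        h = act(z)
    grad = 1.0                     # seed dL/dL = 1
    for z in reversed(saved):      # backward pass: chain the local derivatives
        grad *= w * d_act(z)
    return grad

print(input_gradient(20, 1.0, sigmoid, d_sigmoid))    # vanishes: below 0.25**20
print(input_gradient(20, 1.5, identity, d_identity))  # explodes: equals 1.5**20
```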

Silent NaN propagation. A NaN anywhere in the forward pass propagates silently through the computation graph. During backward, every gradient flowing through the affected node becomes NaN, and the weight update can corrupt the entire model. Use torch.autograd.set_detect_anomaly(True) while debugging: it makes the backward pass raise an error as soon as it produces a NaN and prints a traceback of the forward operation responsible, which is invaluable for tracking these down. It also slows execution noticeably, so turn it off once you have found the culprit.

In-place operations on tensors with gradients. In-place ops (e.g., x += 1) can modify a tensor that the backward pass expects to find unchanged. PyTorch raises a RuntimeError when it detects this (the message begins “one of the variables needed for gradient computation has been modified by an inplace operation”), but the message can be confusing. The fix is simple: avoid in-place ops on tensors that require gradients or that autograd has saved for backward, or clone before modifying.
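PyTorch catches this with a version counter on each tensor: in-place ops increment it, and backward checks that every saved tensor still has the version it had when it was saved. The sketch below is a toy re-creation of that check in plain Python; `TrackedTensor`, `SavedTensor`, `mul_forward`, and `mul_backward` are hypothetical names, and the error text is invented for the example rather than copied from PyTorch.

```python
class TrackedTensor:
    """Toy tensor: a value plus a version counter bumped by every in-place op."""
    def __init__(self, value):
        self.value = value
        self.version = 0

    def add_(self, other):
        self.value += other  # mutate in place...
        self.version += 1    # ...and record that the storage changed
        return self

    def clone(self):
        return TrackedTensor(self.value)  # fresh storage, fresh version counter

class SavedTensor:
    """What the forward pass stashes: the tensor and its version at save time."""
    def __init__(self, tensor):
        self.tensor = tensor
        self.saved_version = tensor.version

    def unpack(self):
        if self.tensor.version != self.saved_version:
            raise RuntimeError(
                f"saved tensor was modified in place (saved at version "
                f"{self.saved_version}, now at version {self.tensor.version})"
            )
        return self.tensor.value

def mul_forward(w, x):
    """z = w * x; saves both inputs because mul's backward needs them."""
    return w.value * x.value, (SavedTensor(w), SavedTensor(x))

def mul_backward(saved, grad_z):
    saved_w, saved_x = saved
    return saved_x.unpack() * grad_z, saved_w.unpack() * grad_z  # (dL/dw, dL/dx)

# Buggy: x is modified in place after the forward pass saved it, so without the
# check, dL/dw would silently use x = 3.0 instead of 2.0.
w, x = TrackedTensor(3.0), TrackedTensor(2.0)
z, saved = mul_forward(w, x)
x.add_(1.0)
try:
    mul_backward(saved, 1.0)
except RuntimeError as err:
    print("caught:", err)

# Fixed: clone first, then modify the copy; the saved x is untouched.
w, x = TrackedTensor(3.0), TrackedTensor(2.0)
z, saved = mul_forward(w, x)
x_shifted = x.clone().add_(1.0)
print(mul_backward(saved, 1.0))  # (2.0, 3.0)
```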

Key Takeaways

  1. AD is not numerical or symbolic differentiation. It applies the chain rule exactly at each elementary operation — no approximation, no expression explosion.

  2. Forward mode needs one pass per input; reverse mode needs one pass per output. For ML — scalar loss, millions of parameters — reverse mode wins unconditionally.

  3. The cost of reverse mode is memory. Every intermediate tensor from the forward pass must stay alive for the backward pass. This is a major reason training uses far more memory than inference.

  4. Gradient checkpointing trades compute for memory. Store only at segment boundaries, recompute the rest during backward. Expect roughly one extra forward pass overhead for a meaningful memory reduction.

  5. Most gradient problems are reverse-mode problems. Vanishing/exploding gradients, NaN propagation, and in-place op errors all stem from how reverse-mode AD chains local Jacobians through the computation graph. Understanding the mechanism is the fastest path to diagnosing them.
