194 Automatic Differentiation
Automatic differentiation, often abbreviated as autodiff or AD, is the algorithmic machinery that computes exact derivatives of functions expressed as computer programs. It sits at the foundation of nearly every modern machine learning system, because training a model by gradient descent requires the gradient of a scalar loss with respect to millions or billions of parameters. This chapter develops the subject rigorously. We distinguish autodiff from its two siblings, numerical and symbolic differentiation, formalize the two principal modes of accumulation, introduce dual numbers and the Wengert tape as concrete computational devices, and examine how production frameworks realize these ideas.
194.1 1. Three Ways to Differentiate
Given a function \(f: \mathbb{R}^n \to \mathbb{R}^m\) implemented as code, there are three broadly distinct strategies for obtaining its derivatives.
194.1.1 1.1 Numerical Differentiation
Numerical differentiation approximates derivatives using finite differences. The forward difference estimate of a partial derivative is
\[ \frac{\partial f}{\partial x_i} \approx \frac{f(x + h\,e_i) - f(x)}{h}, \]
where \(e_i\) is the \(i\)th unit vector and \(h\) is a small step. This method is trivial to implement and requires only the ability to evaluate \(f\). It suffers from two compounding errors. Truncation error arises because the difference quotient drops higher order Taylor terms, contributing an error of order \(\mathcal{O}(h)\) for the forward difference and \(\mathcal{O}(h^2)\) for the central difference \(\frac{f(x + h\,e_i) - f(x - h\,e_i)}{2h}\). Roundoff error grows as \(h\) shrinks, because subtracting two nearly equal floating point numbers loses significant digits. For the forward scheme the step that balances these effects lies near \(h \approx \sqrt{\epsilon_{\text{mach}}}\), which yields only about half the available digits of accuracy. Worse, computing a full gradient in \(\mathbb{R}^n\) requires at least \(n + 1\) evaluations of \(f\) (or \(2n\) with central differences), which is prohibitive when \(n\) is large.
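A quick experiment makes the two error sources visible. The sketch below is only an illustration: it differentiates \(\sin\) at \(x = 1\) with both difference schemes over a range of step sizes and compares against the exact derivative \(\cos 1\). The helper names `forward_diff` and `central_diff` are hypothetical.

```python
import math

def forward_diff(f, x, h):
    # One-sided difference: truncation error O(h)
    return (f(x + h) - f(x)) / h

def central_diff(f, x, h):
    # Symmetric difference: truncation error O(h**2)
    return (f(x + h) - f(x - h)) / (2 * h)

x = 1.0
exact = math.cos(x)  # exact derivative of sin at x
for h in (1e-2, 1e-5, 1e-8, 1e-11, 1e-14):
    fwd_err = abs(forward_diff(math.sin, x, h) - exact)
    cen_err = abs(central_diff(math.sin, x, h) - exact)
    print(f"h={h:.0e}  forward error={fwd_err:.1e}  central error={cen_err:.1e}")
```

Shrinking \(h\) first reduces the error and then, once roundoff dominates, increases it again. On this grid the forward error is typically smallest near \(h = 10^{-8}\), roughly \(\sqrt{\epsilon_{\text{mach}}}\) in double precision.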
194.1.2 1.2 Symbolic Differentiation
Symbolic differentiation manipulates expressions algebraically, applying differentiation rules to produce a closed form derivative expression, as computer algebra systems such as Mathematica or SymPy do. It is exact in the sense of introducing no approximation error. Its weakness is expression swell. Consider a product of \(k\) factors: by the product rule, differentiating \(\prod_{i=1}^{k} g_i(x)\) produces a sum of \(k\) terms, each itself a product of \(k\) factors, and nested compositions can make the symbolic form grow exponentially in the depth of the expression when common subexpressions are not shared. Symbolic differentiation also struggles with control flow, iteration, and general program constructs, because a program is not an algebraic expression. It produces a formula rather than an efficient evaluation procedure, and the two can differ enormously in cost.
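The growth caused by the product rule can be measured with a deliberately naive symbolic differentiator. In this toy sketch, which is not how SymPy or Mathematica represent expressions, expressions are nested tuples, `diff` applies the sum and product rules with no simplification or sharing, and `size` counts tree nodes. For \((x + 1)(x + 2)\cdots(x + k)\) the expression tree has \(4k - 1\) nodes, while its derivative tree has \(2k^2 + 6k - 5\).

```python
def diff(e):
    """Differentiate an expression tree with respect to x, with no simplification."""
    op = e[0]
    if op == "x":
        return ("const", 1)
    if op == "const":
        return ("const", 0)
    if op == "add":  # sum rule
        return ("add", diff(e[1]), diff(e[2]))
    if op == "mul":  # product rule: (ab)' = a'b + ab'
        return ("add", ("mul", diff(e[1]), e[2]), ("mul", e[1], diff(e[2])))
    raise ValueError(f"unknown operator {op!r}")

def size(e):
    """Number of nodes in the expression tree."""
    return 1 + sum(size(c) for c in e[1:] if isinstance(c, tuple))

def product(k):
    """Left-nested tree for (x + 1)(x + 2)...(x + k)."""
    e = ("add", ("x",), ("const", 1))
    for i in range(2, k + 1):
        e = ("mul", e, ("add", ("x",), ("const", i)))
    return e

for k in (5, 10, 20, 40):
    print(k, size(product(k)), size(diff(product(k))))
```

The expression grows linearly in \(k\) and its derivative quadratically. Either mode of automatic differentiation, applied to the evaluation trace of the same product, costs only a constant multiple of the \(k\) operations needed to evaluate it.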
194.1.3 1.3 Automatic Differentiation
Automatic differentiation occupies a productive middle ground. Like symbolic differentiation it computes exact derivatives, accurate to machine precision with no truncation error. Like numerical differentiation it operates on the program itself rather than on a closed form expression, so it handles loops, branches, and function calls naturally. The key insight is that any program, no matter how elaborate, is ultimately a composition of elementary operations such as addition, multiplication, and transcendental functions, each of which has a known derivative. Autodiff applies the chain rule mechanically to this composition, propagating derivative information alongside the ordinary computation. It never builds a monolithic symbolic expression. Instead it evaluates derivatives numerically at a specific input while preserving exactness, reusing intermediate results to avoid the redundancy that plagues both alternatives.
194.2 2. The Chain Rule on a Computational Graph
Every numerical program can be decomposed into a sequence of primitive assignments. Suppose we evaluate \(y = f(x_1, x_2)\) through intermediate variables. A representative trace, sometimes called the Wengert list after R. E. Wengert, records each elementary step in order:
```
v1 = x1
v2 = x2
v3 = v1 * v2
v4 = sin(v1)
v5 = v3 + v4
y = v5
```
This trace is equivalent to a directed acyclic graph in which nodes are variables and each edge runs from an operand to the variable computed from it; the trace above computes \(f(x_1, x_2) = x_1 x_2 + \sin x_1\). Differentiation is the systematic application of the chain rule across this graph. For a composition \(y = g(h(x))\), the chain rule gives \(\frac{dy}{dx} = \frac{dy}{dh} \cdot \frac{dh}{dx}\). The two modes of automatic differentiation correspond to two orders in which one can multiply the chain of local Jacobians.
Let the full function be a composition \(f = f_L \circ f_{L-1} \circ \cdots \circ f_1\). Its Jacobian is the matrix product
\[ J_f = J_{f_L} \cdot J_{f_{L-1}} \cdots J_{f_1}. \]
Matrix multiplication is associative, so we may evaluate this product from either end. Multiplying right to left is forward mode. Multiplying left to right is reverse mode. The choice of association has no effect on the result but a decisive effect on computational cost.
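Counting scalar multiply-adds shows how much the association matters when the full product is formed from dense factors. The sketch assumes dense Jacobians and charges \(a \cdot b \cdot c\) operations for an \((a \times b)(b \times c)\) product; the function `chain_cost` is hypothetical, and the layer widths (1000 inputs, hidden width 100, one scalar output) are arbitrary illustrative choices.

```python
def chain_cost(shapes, order):
    """Multiply-adds to form J_L ... J_1 from dense factors.

    shapes[i] is (rows, cols) of J_{i+1}, listed in application order,
    so shapes[0] belongs to J_1. order is "right_to_left" (the forward
    mode association) or "left_to_right" (the reverse mode association).
    """
    cost = 0
    if order == "right_to_left":
        rows, cols = shapes[0]
        for r, c in shapes[1:]:
            assert c == rows, "inner dimensions must agree"
            cost += r * c * cols
            rows = r
    else:
        rows, cols = shapes[-1]
        for r, c in reversed(shapes[:-1]):
            assert cols == r, "inner dimensions must agree"
            cost += rows * r * c
            cols = c
    return cost

# f: R^1000 -> R^100 -> R^100 -> R (a scalar loss)
shapes = [(100, 1000), (100, 100), (1, 100)]
print(chain_cost(shapes, "right_to_left"))  # 10100000
print(chain_cost(shapes, "left_to_right"))  # 110000
```

Both orders produce the same \(1 \times 1000\) row vector, the gradient of the scalar output; only the work differs, here by roughly a factor of ninety.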
194.3 3. Forward Mode
Forward mode accumulates derivatives in the same order as the original computation, propagating from inputs toward outputs. With each intermediate variable \(v_i\) we associate a tangent \(\dot{v}_i = \frac{\partial v_i}{\partial x}\) representing the sensitivity of \(v_i\) to a chosen input direction. As we evaluate each primitive we simultaneously evaluate its tangent using the local derivative rule. For the trace above, seeding \(\dot{x}_1 = 1\) and \(\dot{x}_2 = 0\) propagates the partial derivative with respect to \(x_1\):
\[ \dot{v}_3 = \dot{v}_1 v_2 + v_1 \dot{v}_2, \qquad \dot{v}_4 = \cos(v_1)\,\dot{v}_1, \qquad \dot{v}_5 = \dot{v}_3 + \dot{v}_4. \]
A single forward pass computes the Jacobian vector product \(J_f\,\dot{x}\) for the seed direction \(\dot{x}\). To obtain the full Jacobian of \(f: \mathbb{R}^n \to \mathbb{R}^m\) one runs \(n\) passes, one per input coordinate. Forward mode is therefore efficient when \(n\) is small and \(m\) is large, that is, for tall Jacobians. Its cost is proportional to \(n\) times the cost of evaluating \(f\), with a small constant factor.
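Written out as straight-line code, forward mode is simply the original trace interleaved with one tangent assignment per primitive. The sketch below hand-codes this for the running example \(f(x_1, x_2) = x_1 x_2 + \sin x_1\); the function name `f_and_tangent` is hypothetical, and a real tool would generate the tangent lines mechanically.

```python
import math

def f_and_tangent(x1, x2, dx1, dx2):
    """Evaluate f(x1, x2) = x1*x2 + sin(x1) and its directional derivative.

    (dx1, dx2) is the seed direction; the result is (f, J_f . seed).
    """
    v1, dv1 = x1, dx1
    v2, dv2 = x2, dx2
    v3, dv3 = v1 * v2, dv1 * v2 + v1 * dv2      # product rule
    v4, dv4 = math.sin(v1), math.cos(v1) * dv1  # chain rule for sin
    v5, dv5 = v3 + v4, dv3 + dv4                # sum rule
    return v5, dv5

# One pass per input coordinate gives the full (here 1 x 2) Jacobian.
x1, x2 = 1.5, -2.0
_, df_dx1 = f_and_tangent(x1, x2, 1.0, 0.0)
_, df_dx2 = f_and_tangent(x1, x2, 0.0, 1.0)
print(df_dx1, x2 + math.cos(x1))  # the two values agree
print(df_dx2, x1)
```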
194.3.1 3.1 Dual Numbers
Dual numbers give forward mode an elegant algebraic realization. A dual number has the form \(a + b\,\epsilon\), where \(a\) and \(b\) are real and \(\epsilon\) is a nilpotent symbol satisfying \(\epsilon^2 = 0\) with \(\epsilon \neq 0\). This single rule encodes first order Taylor expansion exactly. Arithmetic follows naturally:
\[ (a + b\epsilon) + (c + d\epsilon) = (a + c) + (b + d)\epsilon, \]
\[ (a + b\epsilon)(c + d\epsilon) = ac + (ad + bc)\epsilon, \]
where the \(\epsilon^2\) term vanishes by nilpotency. The remarkable consequence is that evaluating any analytic function on a dual argument produces the function value in the real part and its derivative in the dual part. Expanding \(f(a + b\epsilon)\) as a Taylor series and discarding terms beyond first order gives
\[ f(a + b\epsilon) = f(a) + b\,f'(a)\,\epsilon. \]
Seeding the input as \(x + 1\cdot\epsilon\) thus yields \(f(x) + f'(x)\,\epsilon\) in one evaluation. The dual part carries the derivative for free. A minimal implementation overloads the elementary operations.
```python
class Dual:
    def __init__(self, real, dual):
        self.real = real
        self.dual = dual

    def __add__(self, other):
        # (a + b eps) + (c + d eps) = (a + c) + (b + d) eps
        return Dual(self.real + other.real, self.dual + other.dual)

    def __mul__(self, other):
        # (a + b eps)(c + d eps) = ac + (ad + bc) eps
        return Dual(self.real * other.real,
                    self.real * other.dual + self.dual * other.real)
```

Each primitive carries its local derivative rule, and the chain rule emerges automatically from operator composition. Dual numbers extend to multiple inputs by attaching a vector of partials, or by repeating the evaluation once per input direction.
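A slightly fuller sketch, still an illustration rather than a production library, redefines `Dual` so the block stands alone, adds coercion of plain numbers so that expressions such as `3 * x` work, and adds a `sin` primitive applying \(f(a + b\epsilon) = f(a) + b\,f'(a)\,\epsilon\). Evaluating \(g(x) = x \sin x + 3x\) at \(x + 1\cdot\epsilon\) then returns \(g'(x) = \sin x + x \cos x + 3\) in the dual part. The helpers `lift`, `sin`, and `g` are hypothetical names.

```python
import math

class Dual:
    """Dual number real + dual * eps with eps**2 == 0."""
    def __init__(self, real, dual=0.0):
        self.real, self.dual = real, dual

    def __add__(self, other):
        other = lift(other)
        return Dual(self.real + other.real, self.dual + other.dual)

    def __mul__(self, other):
        other = lift(other)
        return Dual(self.real * other.real,
                    self.real * other.dual + self.dual * other.real)

    # Let plain numbers appear on the left: 3 * x, 1 + x
    __radd__ = __add__
    __rmul__ = __mul__

def lift(value):
    """Treat a plain number c as the constant dual number c + 0*eps."""
    return value if isinstance(value, Dual) else Dual(value, 0.0)

def sin(x):
    """Elementary primitive: sin(a + b eps) = sin(a) + b cos(a) eps."""
    x = lift(x)
    return Dual(math.sin(x.real), x.dual * math.cos(x.real))

def g(x):
    return x * sin(x) + 3 * x

x = 0.5
out = g(Dual(x, 1.0))      # seed the tangent with 1
print(out.real, out.dual)  # g(0.5) and g'(0.5)
```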
194.4 4. Reverse Mode
Reverse mode accumulates derivatives in the opposite order, propagating from outputs back toward inputs. With each intermediate variable \(v_i\) we associate an adjoint \(\bar{v}_i = \frac{\partial y}{\partial v_i}\), the sensitivity of the chosen output to that variable. Reverse mode proceeds in two phases. A forward phase evaluates the function and records the trace and the intermediate values. A backward phase then traverses the trace in reverse, propagating adjoints by summing contributions through every outgoing edge:
\[ \bar{v}_i = \sum_{j : v_i \to v_j} \bar{v}_j \, \frac{\partial v_j}{\partial v_i}. \]
For our running example, seeding \(\bar{y} = 1\) and walking backward gives \(\bar{v}_5 = 1\), then \(\bar{v}_3 = \bar{v}_5\) and \(\bar{v}_4 = \bar{v}_5\), then \(\bar{v}_1 = \bar{v}_3 v_2 + \bar{v}_4 \cos(v_1)\) and \(\bar{v}_2 = \bar{v}_3 v_1\). A single backward pass computes the entire gradient \(\nabla f\) of a scalar output with respect to all \(n\) inputs at once.
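The same two-phase structure can be hand-coded for the running example. The forward phase keeps the intermediates that the local derivative rules need, and the backward phase visits the trace in reverse and accumulates adjoints, so \(v_1\), which feeds both \(v_3\) and \(v_4\), receives two contributions. Both partial derivatives come out of one backward sweep, whereas forward mode needs one pass per input. The function name `f_and_gradient` is hypothetical.

```python
import math

def f_and_gradient(x1, x2):
    """Value and gradient of f(x1, x2) = x1*x2 + sin(x1) by one reverse sweep."""
    # Forward phase: evaluate and keep the intermediates.
    v1, v2 = x1, x2
    v3 = v1 * v2
    v4 = math.sin(v1)
    v5 = v3 + v4

    # Backward phase: seed the output adjoint and walk the trace in reverse.
    v5_bar = 1.0
    v3_bar = v5_bar                 # d v5 / d v3 = 1
    v4_bar = v5_bar                 # d v5 / d v4 = 1
    v1_bar = v4_bar * math.cos(v1)  # contribution through v4 = sin(v1)
    v1_bar += v3_bar * v2           # contribution through v3 = v1 * v2
    v2_bar = v3_bar * v1
    return v5, (v1_bar, v2_bar)

value, grad = f_and_gradient(1.5, -2.0)
print(value, grad)  # grad equals (x2 + cos(x1), x1)
```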
This is the central economic fact of deep learning. For a function \(f: \mathbb{R}^n \to \mathbb{R}\) with one scalar output, reverse mode obtains the complete gradient at a cost of roughly two to four times a single function evaluation, independent of \(n\). For a neural network with billions of parameters this independence from \(n\) is what makes training feasible. Reverse mode is efficient when \(m\) is small and \(n\) is large, that is, for wide Jacobians, which is exactly the regime of loss functions. The backpropagation algorithm familiar from neural networks is precisely reverse mode automatic differentiation applied to the layered composition of a network.
194.4.1 4.1 The Cost Asymmetry
The complementary strengths of the two modes follow directly from the shape of the Jacobian. To compute \(J_f\) for \(f: \mathbb{R}^n \to \mathbb{R}^m\), forward mode costs about \(n\) function evaluations and reverse mode costs about \(m\). When \(n \gg m\), as in machine learning where the loss is scalar, reverse mode wins decisively. When \(m \gg n\), forward mode wins. The price reverse mode pays for its favorable scaling in \(n\) is memory. Because adjoints flow backward, every intermediate value needed by a local derivative rule must be retained from the forward phase until the backward phase consumes it.
194.5 5. The Wengert Tape
Reverse mode requires a record of the computation so the backward pass can replay it. This record is the Wengert tape, also called the trace or the tape. The tape stores, in evaluation order, each primitive operation, references to its inputs, and whatever intermediate values the corresponding backward rule will need. Conceptually each tape entry is a closure that knows how to push adjoints from its output to its inputs.
```python
# A tape entry records the local backward rule, here for z = a * b.
def multiply_backward(grad_out, a, b):
    return grad_out * b, grad_out * a  # adjoints for a and b
```

During the backward pass the framework iterates the tape in reverse, invoking each entry’s backward rule and accumulating the resulting adjoints into the inputs. Memory consumption is the salient cost. The tape grows with the number of operations, and for deep networks the stored activations dominate memory use during training.
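A tape can be as simple as a Python list of closures. The sketch below is a toy operator-overloading implementation with hypothetical names `Tape`, `Var`, and `gradient`; it supports only addition and multiplication, records one backward closure per operation, and replays the list in reverse. Differentiating \(f(a, b) = ab + a^2\) exercises adjoint accumulation, since \(a\) occupies three operand slots.

```python
class Tape:
    """Records, in evaluation order, a backward closure for every operation."""
    def __init__(self):
        self.entries = []

class Var:
    def __init__(self, value, tape):
        self.value, self.tape, self.grad = value, tape, 0.0

    def _record(self, value, parents):
        # parents: list of (input Var, local partial d out / d input)
        out = Var(value, self.tape)
        def backward():
            for var, local in parents:
                var.grad += out.grad * local  # accumulate, never overwrite
        self.tape.entries.append(backward)
        return out

    def __add__(self, other):
        return self._record(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        return self._record(self.value * other.value,
                            [(self, other.value), (other, self.value)])

def gradient(output, inputs):
    """Seed the output adjoint with 1 and replay the tape backward."""
    output.grad = 1.0
    for backward in reversed(output.tape.entries):
        backward()
    return [v.grad for v in inputs]

tape = Tape()
a, b = Var(3.0, tape), Var(4.0, tape)
y = a * b + a * a
print(y.value, gradient(y, [a, b]))  # 21.0 [10.0, 3.0]
```

Because the tape is appended in evaluation order, replaying it backward visits every node after all of its consumers, so each `out.grad` is complete before it is propagated. The local partials stored in `parents` are exactly the retained intermediate values whose memory cost the text describes.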
194.5.1 5.1 Checkpointing
Gradient checkpointing trades computation for memory. Rather than storing every intermediate value, one stores only a sparse set of checkpoints and recomputes the missing intermediate values during the backward pass by re-running the forward computation from the nearest checkpoint. For a sequential computation of length \(L\), storing \(\mathcal{O}(\sqrt{L})\) checkpoints reduces peak memory from \(\mathcal{O}(L)\) to \(\mathcal{O}(\sqrt{L})\) at the cost of one extra forward evaluation. This technique is indispensable for training very deep or very long sequence models that would otherwise exhaust accelerator memory.
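The \(\sqrt{L}\) strategy can be illustrated on a toy chain in which every step is \(x \mapsto \sin x\), so the derivative of the whole chain is a product of cosines. This is a simplified scheme in the spirit of Chen et al., not their implementation: it keeps only every \(k\)th state with \(k \approx \sqrt{L}\), then during the backward pass recomputes one segment at a time from its checkpoint. All function names and the memory counter `peak` (number of states held at once) are hypothetical.

```python
import math

def step(x):
    return math.sin(x)

def step_derivative(x):
    return math.cos(x)  # d step / dx evaluated at the step's input x

def grad_store_all(x0, L):
    """Reverse sweep that keeps every intermediate state."""
    states = [x0]
    for _ in range(L):
        states.append(step(states[-1]))
    g = 1.0
    for i in reversed(range(L)):
        g *= step_derivative(states[i])
    return g, len(states)  # gradient, peak number of stored states

def grad_checkpointed(x0, L):
    """Reverse sweep that stores about sqrt(L) checkpoints and recomputes."""
    k = max(1, math.isqrt(L))  # segment length
    checkpoints = {0: x0}
    x = x0
    for i in range(L):
        x = step(x)
        if (i + 1) % k == 0 and i + 1 < L:
            checkpoints[i + 1] = x
    g, peak = 1.0, len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, L)
        segment = [checkpoints[start]]  # inputs to steps start .. end-1
        for _ in range(start, end - 1):
            segment.append(step(segment[-1]))  # recomputation
        peak = max(peak, len(checkpoints) + len(segment))
        for s in reversed(segment):
            g *= step_derivative(s)
    return g, peak

g_full, mem_full = grad_store_all(0.7, 100)
g_ckpt, mem_ckpt = grad_checkpointed(0.7, 100)
print(g_full, g_ckpt)      # identical gradients
print(mem_full, mem_ckpt)  # 101 20
```

Each step is recomputed at most once, which is the single extra forward evaluation mentioned above, while the number of states held at any moment drops from \(L + 1\) to about \(2\sqrt{L}\).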
194.6 6. Higher Order and Mixed Modes
Derivatives of derivatives follow by composing the modes. The Hessian vector product \(H v\), central to second order optimization and to curvature estimation, is computed efficiently by applying forward mode over reverse mode. One first builds the reverse mode gradient, then propagates a forward tangent through that gradient computation, obtaining \(H v\) at a cost comparable to a few gradient evaluations and without ever forming the full \(n \times n\) Hessian. The general principle is that any composition of forward and reverse passes yields the corresponding higher order object, and the optimal choice of modes again depends on the shapes involved.
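Forward-over-reverse can be seen without a framework by writing the reverse sweep so that it works for any number type and then running it on dual numbers. In the sketch below, the hand-written adjoint code in `grad_f` stands in for what a reverse-mode tool would generate, for the illustrative function \(f(x_1, x_2) = x_1^2 x_2 + x_2^3\) with Hessian \(\begin{pmatrix} 2x_2 & 2x_1 \\ 2x_1 & 6x_2 \end{pmatrix}\). Seeding the inputs with tangent \(v\) makes the dual parts of the gradient equal to \(Hv\). All names are hypothetical.

```python
class Dual:
    """Minimal dual number for the forward-mode layer (eps**2 == 0)."""
    def __init__(self, real, dual=0.0):
        self.real, self.dual = real, dual

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.real + other.real, self.dual + other.dual)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.real * other.real,
                    self.real * other.dual + self.dual * other.real)

    __radd__ = __add__
    __rmul__ = __mul__

def grad_f(x1, x2):
    """Reverse-mode gradient of f = (x1*x1)*x2 + (x2*x2)*x2, written as adjoint code."""
    # Forward phase: intermediates needed by the backward rules
    v3 = x1 * x1
    v5 = x2 * x2
    # Backward phase, seeded with y_bar = 1 (y = v3*x2 + v5*x2)
    y_bar = 1.0
    v3_bar = y_bar * x2
    v5_bar = y_bar * x2
    x1_bar = v3_bar * x1 + v3_bar * x1           # through v3 = x1 * x1
    x2_bar = y_bar * v3 + y_bar * v5             # direct uses of x2 in y
    x2_bar = x2_bar + v5_bar * x2 + v5_bar * x2  # through v5 = x2 * x2
    return x1_bar, x2_bar

def hvp(x, v):
    """Hessian-vector product H(x) v by forward mode over the reverse sweep."""
    g = grad_f(Dual(x[0], v[0]), Dual(x[1], v[1]))
    return [gi.dual for gi in g]

print(hvp((1.0, 2.0), (1.0, 0.0)))  # [4.0, 2.0], first Hessian column
print(hvp((1.0, 2.0), (0.0, 1.0)))  # [2.0, 12.0], second Hessian column
```

The Hessian is never formed; each call costs a small constant multiple of one gradient evaluation.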
194.7 7. Implementation Strategies in Modern Frameworks
Frameworks realize automatic differentiation through two broad architectural choices, with a third hybrid approach emerging.
194.7.1 7.1 Operator Overloading and Dynamic Graphs
Operator overloading, the approach taken by PyTorch through its autograd engine, builds the tape dynamically as the program runs. Tensor operations are overloaded so that each one, in addition to computing its result, appends an entry to a tape recording how to propagate gradients backward. The graph is therefore defined by execution, which is why this style is called define by run. Its great advantage is flexibility. Ordinary control flow, dynamic shapes, and data dependent branching all work transparently, because the tape simply records whatever operations actually executed. The cost is interpreter overhead and fewer opportunities for whole program optimization, since the graph is not known until it has already run.
```python
import torch

# PyTorch records operations as they execute.
x = torch.tensor([2.0], requires_grad=True)
y = (x * x + torch.sin(x)).sum()
y.backward()   # reverse pass populates x.grad
print(x.grad)  # holds 2*x + cos(x) evaluated at x = 2
```

194.7.2 7.2 Source Transformation and Static Graphs
Source transformation analyzes the program and generates new code that computes derivatives. Early TensorFlow used a static define and run model in which the user first constructs a graph and then executes it, allowing the framework to optimize the entire graph before running. JAX takes a modern transformation based view. It traces a function to an intermediate representation and then applies functional transformations such as grad, jvp for forward mode Jacobian vector products, and vjp for reverse mode vector Jacobian products. Because the whole computation is captured before execution, the compiler can fuse operations, eliminate dead code, and target accelerators aggressively through a backend such as XLA.
```python
# JAX exposes differentiation as a function transformation.
import jax
import jax.numpy as jnp

df = jax.grad(lambda x: x ** 3 + jnp.sin(x))
df(2.0)  # evaluates 3*x**2 + cos(x) at x = 2
```

194.7.3 7.3 Tradeoffs
The dynamic approach favors expressiveness and ease of debugging, since the program is plain imperative code. The static or transformation approach favors performance and deployment, since the captured graph can be compiled and optimized as a whole. Contemporary systems blur the boundary. PyTorch added tracing and graph capture facilities to obtain compiled performance from define by run code, while JAX retains the feel of ordinary functions despite compiling them. The convergence reflects a shared goal of offering imperative ergonomics with compiled efficiency.
194.8 8. Subtleties and Correctness
Automatic differentiation is exact only where the underlying function is differentiable. Several practical concerns deserve mention. Nonsmooth primitives such as the absolute value or the rectified linear unit are not differentiable at isolated points, and frameworks adopt conventions, for instance assigning a subgradient value at the kink. Control flow is differentiated along the branch actually taken, so the computed derivative is valid only locally and may be discontinuous across branch boundaries. In place mutation of buffers can invalidate stored intermediate values that the backward pass relies on, which is why frameworks track versions and raise errors when a needed value has been overwritten. Finally, numerical stability of the derivative is a separate matter from numerical stability of the function. A primitive such as the logarithm of a sum of exponentials must be implemented with a numerically stable backward rule to avoid overflow and catastrophic cancellation in the gradient.
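The log-sum-exp case makes the last point concrete. Its gradient is the softmax, \(\frac{\partial}{\partial x_i} \log \sum_j e^{x_j} = e^{x_i} / \sum_j e^{x_j}\). Evaluated literally, the exponentials overflow for inputs around 1000; shifting by the maximum, which leaves both the value and the gradient mathematically unchanged, keeps every exponent at most zero. The sketch uses plain Python floats, and the function names are hypothetical.

```python
import math

def logsumexp_naive_grad(xs):
    """Literal softmax: math.exp overflows for large inputs."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def logsumexp_and_grad(xs):
    """Stable value and gradient using the shift m = max(xs)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # every exponent <= 0
    total = sum(exps)                     # total >= 1, so no division by zero
    value = m + math.log(total)
    grad = [e / total for e in exps]      # softmax, sums to 1
    return value, grad

xs = [1000.0, 1000.0, 999.0]
try:
    logsumexp_naive_grad(xs)
except OverflowError:
    print("naive gradient overflows")
value, grad = logsumexp_and_grad(xs)
print(value, grad)
```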
194.9 9. Summary
Automatic differentiation computes exact derivatives of programs by mechanically applying the chain rule to the elementary operations a program performs. It avoids the approximation error of finite differences and the expression swell of symbolic differentiation, while reusing intermediate values for efficiency. Forward mode propagates tangents from inputs to outputs and is realized cleanly with dual numbers, excelling when inputs are few. Reverse mode propagates adjoints from outputs to inputs using a recorded tape, excelling when inputs are many and outputs few, which is the regime of machine learning loss functions and the reason backpropagation is simply reverse mode applied to neural networks. Modern frameworks implement these ideas through operator overloading with dynamic tapes or through source transformation with static or traced graphs, and increasingly combine both to deliver imperative flexibility together with compiled performance.
194.10 References
- Baydin, A. G., Pearlmutter, B. A., Radul, A. A., and Siskind, J. M. “Automatic Differentiation in Machine Learning: a Survey.” Journal of Machine Learning Research, 2018. https://www.jmlr.org/papers/volume18/17-468/17-468.pdf
- Griewank, A., and Walther, A. “Evaluating Derivatives: Principles and Techniques of Algorithmic Differentiation.” SIAM, 2008. https://epubs.siam.org/doi/book/10.1137/1.9780898717761
- Wengert, R. E. “A Simple Automatic Derivative Evaluation Program.” Communications of the ACM, 1964. https://dl.acm.org/doi/10.1145/355586.364791
- Rumelhart, D. E., Hinton, G. E., and Williams, R. J. “Learning Representations by Back-Propagating Errors.” Nature, 1986. https://www.nature.com/articles/323533a0
- Paszke, A., et al. “PyTorch: An Imperative Style, High-Performance Deep Learning Library.” NeurIPS, 2019. https://papers.nips.cc/paper/9015-pytorch-an-imperative-style-high-performance-deep-learning-library.pdf
- Bradbury, J., et al. “JAX: Composable Transformations of Python and NumPy Programs.” 2018. https://github.com/jax-ml/jax
- Chen, T., Xu, B., Zhang, C., and Guestrin, C. “Training Deep Nets with Sublinear Memory Cost.” 2016. https://arxiv.org/abs/1604.06174
- Pearlmutter, B. A. “Fast Exact Multiplication by the Hessian.” Neural Computation, 1994. https://doi.org/10.1162/neco.1994.6.1.147