Modified: August 23, 2022
automatic differentiation
This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

This is my stab at explaining automatic differentiation, specifically backprop and applications to neural nets.
A few dimensions to think about:
- implementations: fixed graph versus trace-based
- forward vs reverse and mixed-mode autodiff
Say we have a function that returns a scalar loss $L(\theta)$ as a function of parameters $\theta$; we wish to compute the gradient $\nabla_\theta L$ of the loss with respect to the parameters. (The function might also take other inputs and return other outputs; for example, in supervised learning the loss may measure the correspondence of some prediction function $f_\theta(x)$ for some input $x$ with a true label $y$, but for current purposes we will view these as constants and focus simply on the parameters $\theta$ to be updated.) We will assume that this function is represented as a computational graph: a graph in which the nodes are computational primitives (things like addition, multiplication, exponentiation, etc.), with edges indicating how each computation depends on the outputs of other computations.
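As a concrete (if simplified) picture of what a trace-based representation of such a graph might look like, here is a minimal sketch; the `Node` class and its field names are my own illustrative choices, not any particular framework's API.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional

# Illustrative sketch only; real frameworks record and store graphs differently.
@dataclass
class Node:
    """One primitive operation in the computational graph."""
    op: str                                              # primitive name, e.g. "input", "mul", "exp", "add"
    inputs: List["Node"] = field(default_factory=list)   # nodes whose outputs this one consumes
    value: Optional[float] = None                        # output cached during the forward pass

# Recording the expression L(theta) = 2 * theta + exp(theta) at theta = 3.0:
theta = Node("input", value=3.0)
t2 = Node("mul", [theta], value=2.0 * theta.value)
e = Node("exp", [theta], value=math.exp(theta.value))
loss = Node("add", [t2, e], value=t2.value + e.value)
```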
For intuition, let's start with the simplest possible case where the graph is a chain of (unary) operations, each with scalar input and output. Take
$$L = f_3(f_2(f_1(\theta)))$$
for $x_0 = \theta$, $x_1 = f_1(x_0)$, $x_2 = f_2(x_1)$, $L = f_3(x_2)$. How can we compute the derivative of $L$ with respect to $\theta$? We simply work backward, applying the chain rule from calculus one step at a time:
$$\frac{dL}{d\theta} = \frac{dL}{dx_2} \cdot \frac{dx_2}{dx_1} \cdot \frac{dx_1}{dx_0} = f_3'(x_2)\, f_2'(x_1)\, f_1'(x_0),$$
finding that the result is a product of local terms corresponding to the derivatives of each component function at their respective inputs. So we can implement this as:
- Run the computation forward and save the intermediate values at each node (these are $x_0, x_1, x_2$).
- Initialize the gradient as $g = \frac{dL}{dL} = 1$, representing the derivative of the output with respect to itself.
- Walk backwards through the chain, updating $g \leftarrow g \cdot f_i'(x_{i-1})$ at each node $f_i$, where $x_{i-1}$ was the input to that node on the forward pass.
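In code, these three steps might look like the following. This is a minimal sketch for one concrete chain of my own choosing ($f_1(x) = x^2$, $f_2 = \sin$, $f_3 = \exp$), not a general-purpose implementation.

```python
import math

# Illustrative chain chosen for this sketch: L = exp(sin(theta^2)).
# Each operation is paired with its derivative function.
chain = [
    (lambda x: x ** 2, lambda x: 2 * x),   # f1(x) = x^2,  f1'(x) = 2x
    (math.sin, math.cos),                  # f2 = sin,     f2' = cos
    (math.exp, math.exp),                  # f3 = exp,     f3' = exp
]

def chain_grad(theta):
    # Step 1: forward pass, saving the intermediate values x_0, x_1, ..., x_{n-1}.
    xs = [theta]
    for f, _ in chain:
        xs.append(f(xs[-1]))
    loss = xs.pop()          # the final output L; xs now holds each node's input

    # Step 2: initialize the gradient as dL/dL = 1.
    g = 1.0

    # Step 3: walk backwards, multiplying in each node's local derivative
    # evaluated at the input it saw on the forward pass.
    for (_, df), x_in in zip(reversed(chain), reversed(xs)):
        g = g * df(x_in)
    return loss, g

# Sanity check against the hand-computed derivative of L = exp(sin(theta^2)):
theta = 0.7
loss, grad = chain_grad(theta)
expected = math.exp(math.sin(theta ** 2)) * math.cos(theta ** 2) * 2 * theta
assert abs(grad - expected) < 1e-9
```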
This seems simple enough. But computation graphs are rarely pure chains, since binary operations like addition, subtraction, multiplication, etc. take multiple inputs.
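The same principle extends to general graphs: a node whose output feeds several downstream operations receives a gradient contribution from each use, and these contributions are summed (this is just the multivariate chain rule). Below is a minimal sketch of a trace-based reverse-mode implementation in that style; the `Value` class, the operator set, and all names are illustrative choices of mine, not any particular library's API.

```python
import math

class Value:
    """A scalar that records the operations applied to it, for reverse-mode autodiff.
    Illustrative sketch only; names and API are not from any particular library."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data                # value computed on the forward pass
        self.parents = parents          # the Values this one was computed from
        self.local_grads = local_grads  # d(self)/d(parent) for each parent
        self.grad = 0.0                 # accumulated dL/d(self), filled in by backward()

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def exp(self):
        out = math.exp(self.data)
        return Value(out, (self,), (out,))

    def backward(self):
        # Order nodes so every node appears after everything it depends on.
        order, visited = [], set()
        def visit(v):
            if id(v) not in visited:
                visited.add(id(v))
                for p in v.parents:
                    visit(p)
                order.append(v)
        visit(self)

        # Walk in reverse topological order, summing contributions into each parent.
        self.grad = 1.0                  # dL/dL = 1
        for v in reversed(order):
            for parent, local in zip(v.parents, v.local_grads):
                parent.grad += v.grad * local

# Example with fan-out: theta is used twice, so its gradient is a sum of two terms.
theta = Value(0.5)
loss = theta * theta + theta.exp()           # L = theta^2 + exp(theta)
loss.backward()
print(theta.grad, 2 * 0.5 + math.exp(0.5))   # both should be ~2.6487
```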
Important points:
- intermediate values are reused: the cost of computing $\frac{dL}{d\theta}$ does not blow up, unlike naive symbolic differentiation, which can suffer from expression swell.
- more efficient than finite-difference approximations by a factor of $\dim(\theta)$: one reverse pass yields the whole gradient, whereas finite differences need separate perturbed evaluations for every parameter (a small demo follows this list).
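To make the cost comparison concrete, here is a small sketch on a toy loss of my own choosing: a central-difference estimate needs two function evaluations per parameter, i.e. $2\dim(\theta)$ in total, whereas reverse mode would deliver the whole gradient for roughly the cost of one extra backward pass.

```python
import math

# Toy example chosen for illustration only.
def loss(theta):
    """A toy scalar loss over a parameter vector."""
    return sum(t ** 2 for t in theta) + math.sin(theta[0])

def finite_difference_grad(f, theta, eps=1e-6):
    """Central-difference gradient: 2 * len(theta) evaluations of f."""
    grad = []
    for i in range(len(theta)):
        bumped_up = list(theta); bumped_up[i] += eps
        bumped_down = list(theta); bumped_down[i] -= eps
        grad.append((f(bumped_up) - f(bumped_down)) / (2 * eps))
    return grad

theta = [0.3, -1.2, 2.0]
print(finite_difference_grad(loss, theta))
# Analytic gradient for comparison: [2*t0 + cos(t0), 2*t1, 2*t2]
print([2 * theta[0] + math.cos(theta[0]), 2 * theta[1], 2 * theta[2]])
```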
Advanced topics: multivariate values (Jacobian-vector and vector-Jacobian products), cyclic graphs, programs with control flow, randomness, limitations of autodiff