automatic differentiation: Nonlinear Function
Created: August 23, 2022
Modified: August 23, 2022

automatic differentiation

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

This is my stab at explaining automatic differentiation, specifically backprop and applications to neural nets.

A few dimensions to think about:

  • implementations: fixed graph versus trace-based
  • forward-mode vs. reverse-mode (and mixed-mode) autodiff

Say we have a function $\ell(\theta)$ that returns a scalar loss as a function of parameters $\theta$; we wish to compute the gradient of the loss with respect to the parameters. (The function might also take other inputs and return other outputs; in supervised learning, for example, the loss may measure the correspondence of some prediction function $f_\theta(x)$, for an input $x$, with a true label $y$. For current purposes we will treat these as constants and focus simply on the parameters to be updated.) We will assume that this function is represented as a computational graph: a graph in which the nodes are computational primitives (things like addition, multiplication, exponentiation, etc.), with edges indicating how each computation depends on the outputs of other computations.
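As a concrete (and entirely illustrative) picture, here is one way such a graph might be written down for a tiny squared-error loss $\ell(\theta) = (\theta x - y)^2$, with $x$ and $y$ treated as constants. The dictionary encoding and node names are my own sketch, not any particular library's format.

```python
# Hypothetical encoding of a computational graph: each node names a primitive
# operation and lists the nodes whose outputs it consumes (the edges).
graph = {
    "theta": {"op": "parameter", "inputs": []},
    "x":     {"op": "constant",  "inputs": []},
    "y":     {"op": "constant",  "inputs": []},
    "pred":  {"op": "multiply",  "inputs": ["theta", "x"]},  # prediction: theta * x
    "err":   {"op": "subtract",  "inputs": ["pred", "y"]},   # prediction error
    "loss":  {"op": "square",    "inputs": ["err"]},         # scalar loss
}
```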

For intuition, let's start with the simplest possible case where the graph is a chain of (unary) operations, each with scalar input and output. Take

$$\ell(\theta) = \frac{1}{1 - e^{-\theta}} = f_1(f_2(f_3(f_4(\theta))))$$

for $f_1(x) = \frac{1}{x}$, $f_2(x) = 1 - x$, $f_3(x) = e^x$, $f_4(x) = -x$. How can we compute the derivative of $\ell(\theta)$ with respect to $\theta$? We simply work backward, applying the chain rule from calculus one step at a time:

$$\begin{align*}
\ell'(\theta) &= f_1'(g_1(\theta)) \cdot g_1'(\theta) \\
&\qquad \text{where } g_1(\theta) = f_2(f_3(f_4(\theta))) \\
&= f_1'(g_1(\theta)) \cdot f_2'(g_2(\theta)) \cdot g_2'(\theta) \\
&\qquad \text{where } g_2(\theta) = f_3(f_4(\theta)) \\
&= f_1'(g_1(\theta)) \cdot f_2'(g_2(\theta)) \cdot f_3'(g_3(\theta)) \cdot g_3'(\theta) \\
&\qquad \text{where } g_3(\theta) = f_4(\theta) \\
&= f_1'(f_2(f_3(f_4(\theta)))) \cdot f_2'(f_3(f_4(\theta))) \cdot f_3'(f_4(\theta)) \cdot f_4'(\theta)
\end{align*}$$

finding that the result is a product of local terms corresponding to the derivatives of each component function at their respective inputs. So we can implement this as:

  1. Run the computation forward and save the intermediate values at each node (in the example above, these are $g_3, g_2, g_1$, computed in that order).
  2. Initialize the gradient as $\nabla = 1$, representing the derivative of the output with respect to itself.
  3. Walk backwards through the chain, updating $\nabla \leftarrow f_i'(g_i) \cdot \nabla$ at each node $f_i$, where $g_i$ was the input to that node on the forward pass (a code sketch follows below).
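To make this recipe concrete, here is a minimal sketch (my own code, not any library's API; the names `chain` and `loss_and_grad` are purely illustrative) applying the three steps to the running example $\ell(\theta) = \frac{1}{1 - e^{-\theta}}$.

```python
import math

# Each entry pairs a primitive f_i with its local derivative f_i',
# listed in the order the forward pass applies them: f_4, f_3, f_2, f_1.
chain = [
    (lambda x: -x,          lambda x: -1.0),           # f_4(x) = -x
    (lambda x: math.exp(x), lambda x: math.exp(x)),    # f_3(x) = e^x
    (lambda x: 1.0 - x,     lambda x: -1.0),           # f_2(x) = 1 - x
    (lambda x: 1.0 / x,     lambda x: -1.0 / x**2),    # f_1(x) = 1/x
]

def loss_and_grad(theta):
    # Step 1: forward pass, saving the input each node sees (theta, g_3, g_2, g_1).
    inputs, value = [], theta
    for f, _ in chain:
        inputs.append(value)
        value = f(value)
    # Step 2: seed the gradient with d(output)/d(output) = 1.
    grad = 1.0
    # Step 3: walk backward, multiplying in each node's local derivative
    # evaluated at the input it received on the forward pass.
    for (_, df), x in zip(reversed(chain), reversed(inputs)):
        grad = df(x) * grad
    return value, grad

value, grad = loss_and_grad(2.0)
analytic = -math.exp(-2.0) / (1.0 - math.exp(-2.0)) ** 2   # closed-form derivative
print(value, grad, analytic)  # grad should match the closed-form value
```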

This seems simple enough. But computation graphs are rarely pure chains: binary operations like addition, subtraction, and multiplication take multiple inputs, and a single intermediate value may feed several downstream operations, in which case the gradient contributions flowing back along each outgoing edge must be summed.
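To show how the same recipe extends to such graphs, here is a hedged sketch in the style of trace-based autodiff libraries. The `Var` class and `backward` function are my own illustrative names, not a real library's API: each recorded operation stores its parents together with the local derivative with respect to each, and the backward pass visits nodes in reverse topological order, summing contributions where paths merge.

```python
class Var:
    """One node of the recorded computation: a value plus, for each input,
    the local derivative of this node's output with respect to that input."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents   # tuple of (parent Var, local derivative)
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))

def backward(output):
    # Order nodes so that each node's gradient is fully accumulated (summed
    # over all of its consumers) before it is pushed to its own parents.
    order, seen = [], set()
    def topo(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                topo(parent)
            order.append(node)
    topo(output)

    output.grad = 1.0
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += node.grad * local

# Example: loss = theta * theta + theta, so dloss/dtheta = 2*theta + 1 = 7 at theta = 3.
theta = Var(3.0)
loss = theta * theta + theta
backward(loss)
print(loss.value, theta.grad)  # 12.0 7.0
```

Note how `theta` feeds three operations here, so its gradient accumulates two contributions from the multiplication plus one from the addition.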

Important points:

  • intermediate values are computed once and reused, so the work does not blow up the way naive symbolic differentiation of the full expression can.
  • for a scalar loss with $d$ parameters, reverse mode is roughly a factor of $d$ more efficient than finite-difference approximations, which need $O(d)$ separate evaluations of the function to estimate the full gradient (see the sketch below).
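For a rough sense of where the factor of $d$ comes from, here is a small illustration (my own sketch; `finite_difference_grad` and the toy loss `f` are hypothetical names, and NumPy is used only for convenience) counting the cost of a central-difference gradient.

```python
import numpy as np

def finite_difference_grad(f, theta, eps=1e-6):
    # Central differences: each of the d coordinates needs two extra
    # evaluations of the full function, so the total cost is 2*d calls to f.
    d = theta.size
    grad = np.zeros(d)
    for i in range(d):
        e = np.zeros(d)
        e[i] = eps
        grad[i] = (f(theta + e) - f(theta - e)) / (2 * eps)
    return grad

def f(theta):
    # Toy scalar loss with a known gradient: sum_i theta_i^2 has gradient 2*theta.
    return np.sum(theta ** 2)

theta = np.arange(5.0)
print(finite_difference_grad(f, theta))  # ~[0, 2, 4, 6, 8], at the cost of 2*5 = 10 calls to f
print(2 * theta)                         # the exact gradient; reverse mode recovers it with
                                         # one forward plus one backward pass, independent of d
```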

Advanced topics: multivariate values (Jacobian-vector products), cyclic graphs, programs with control flow, randomness, limitations of autodiff