Created: January 02, 2024
Modified: January 02, 2024

memory efficient backprop

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

Suppose we want to do automatic differentiation on a computational graph of sequential length $T$. This could equally well be a multilayer network with $T$ layers, a recurrent model applied to a sequence of length $T$, or the execution of a sequential algorithm like an inner-loop stochastic gradient descent (https://arxiv.org/abs/1502.03492). Naive backpropagation requires us to store the intermediate values (activations) at all $T$ layers, which might require a prohibitive quantity of memory. There are a few ways to get around this.
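For reference, here's a minimal sketch of the naive approach on a toy chain. The step function `f` and its VJP `f_vjp` are illustrative placeholders, not any particular model; the point is just that the backward pass touches all $T$ activations, so memory grows linearly in $T$.

```python
import numpy as np

def f(a):                        # one step of the chain: a_{t+1} = f(a_t)
    return np.tanh(a)

def f_vjp(a_t, g):               # vector-Jacobian product of f at a_t
    return g * (1.0 - np.tanh(a_t) ** 2)

def naive_backprop(a0, T, g_out):
    acts = [a0]                  # O(T) memory: keep every activation a_0, ..., a_T
    for _ in range(T):
        acts.append(f(acts[-1]))
    g = g_out                    # gradient of the loss w.r.t. the final activation
    for t in reversed(range(T)):
        g = f_vjp(acts[t], g)    # pull the gradient back through step t
    return g                     # gradient w.r.t. a_0

print(naive_backprop(np.ones(4), T=8, g_out=np.ones(4)))
```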

Invertibility: If the sequential steps are invertible then we can do a sequential backwards pass recomputing everything as we go using the inverse: https://arxiv.org/abs/1707.04585. This requires $O(1)$ memory and $O(T)$ computation.

This requires special network architectures which might be plausible for resnets (the case in the paper above). But it seems less promising for sequence models. In a sequence model we would also expect to need to hold on to the inputs at each step (since the successor state is a function of the previous state and the input), but these are generally of roughly the same size as the activations themselves, so at that point we're probably not saving very much?
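As a concrete illustration of the exactly invertible case, here's a minimal sketch, assuming each step has a cheap exact inverse. The affine step `f` below is a toy placeholder (not the reversible-ResNet construction from the paper): only the current activation is held, and earlier ones are recovered with the inverse during the backward pass.

```python
import numpy as np

def f(a):                        # a toy invertible step
    return 2.0 * a + 1.0

def f_inv(b):                    # its exact inverse
    return (b - 1.0) / 2.0

def f_vjp(a_t, g):               # VJP of f at a_t: f'(a) = 2 everywhere
    return 2.0 * g

def reversible_backprop(a_T, T, g_out):
    # O(1) memory: only the current activation and gradient are held;
    # earlier activations are recovered on the fly with f_inv.
    a, g = a_T, g_out
    for _ in range(T):
        a = f_inv(a)             # recover a_t from a_{t+1}
        g = f_vjp(a, g)          # then pull the gradient back through step t
    return g                     # gradient w.r.t. a_0

print(reversible_backprop(np.ones(4), T=8, g_out=np.ones(4)))  # 2^8 = 256 per entry
```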

Approximate invertibility: Generally we may not want our layers to be invertible. For example, the "memory" of an LSTM needs to be able to forget information. Steps that lose information can't be inverted directly.

LSTMs may already be hopeless due to the input issue above, but in other cases we may be able to salvage invertibility by only storing the lost information, which is much more efficient than storing the whole input.

For example, in gradient descent with momentum, we can invert a position update by simply subtracting the momentum term, and then use the gradient at that previous point to approximately invert the momentum update. Some small amount of information is lost due to finite precision, but this can be coded very efficiently (https://arxiv.org/abs/1502.03492). The resulting approach may still require linear memory in theory, but the practical benefits can be large.
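A rough sketch of that inversion for a single momentum-SGD step. The objective, learning rate, and momentum coefficient here are made-up placeholders, and the paper's careful finite-precision bookkeeping is omitted:

```python
import numpy as np

def grad_fn(x):                  # gradient of a toy objective, 0.5 * ||x||^2
    return x

lr, mu = 0.1, 0.9                # illustrative learning rate and momentum coefficient

def sgd_momentum_step(x, v):
    v_next = mu * v - lr * grad_fn(x)    # momentum update
    x_next = x + v_next                  # position update
    return x_next, v_next

def invert_step(x_next, v_next):
    x = x_next - v_next                  # the position update inverts exactly
    v = (v_next + lr * grad_fn(x)) / mu  # the momentum update inverts up to
    return x, v                          # finite-precision loss in the division

x, v = np.array([1.0, -2.0]), np.zeros(2)
x1, v1 = sgd_momentum_step(x, v)
print(invert_step(x1, v1))               # recovers (x, v) up to rounding error
```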

Activation checkpointing: if we can't rely on invertibility, a generic method is to store the activations every $\sqrt{T}$ steps, and recompute only between successive checkpoints. This was (maybe?) first proposed by Martens and Sutskever 2012 (https://www.cs.toronto.edu/~jmartens/docs/HF_book_chapter.pdf, section 7.1).

This stores a total of $\sqrt{T}$ checkpoints plus memory to recompute the $\sqrt{T}$ activations between the current pair of checkpoints, so still only $O(\sqrt{T})$ memory overall. Each non-checkpointed activation gets recomputed, so this uses a total of $O(T)$ computation.
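A minimal sketch of the $\sqrt{T}$ scheme, reusing the same toy step function and VJP as above (both placeholders): checkpoints are stored every $k \approx \sqrt{T}$ steps on the forward pass, and each segment is recomputed just before it is backpropagated through.

```python
import numpy as np

def f(a):                        # toy step function
    return np.tanh(a)

def f_vjp(a_t, g):               # VJP of f at a_t
    return g * (1.0 - np.tanh(a_t) ** 2)

def sqrt_checkpoint_backprop(a0, T, g_out):
    k = max(1, int(np.sqrt(T)))          # checkpoint spacing ~ sqrt(T)
    ckpts = {0: a0}                      # O(sqrt(T)) stored checkpoints
    a = a0
    for t in range(1, T + 1):
        a = f(a)
        if t % k == 0 and t < T:
            ckpts[t] = a
    g = g_out
    for start in sorted(ckpts, reverse=True):    # process segments right to left
        end = min(start + k, T)
        seg = [ckpts[start]]                     # recompute just this segment:
        for _ in range(start, end - 1):          # at most ~sqrt(T) live activations
            seg.append(f(seg[-1]))
        for a_t in reversed(seg):                # backprop through the segment
            g = f_vjp(a_t, g)
    return g                                     # gradient w.r.t. a_0

print(sqrt_checkpoint_backprop(np.ones(4), T=9, g_out=np.ones(4)))
```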

Recursive checkpointing: The preceding approach can be applied recursively. This was introduced by Chen et al. (https://arxiv.org/abs/1604.06174).

That is:

  1. We subdivide the sequence into $k$ chunks of length $T/k$ and on the forward pass we store the intermediate input to each chunk. Now we have a top-level compute graph of length $k$ and we are storing $k$ activations. We have done $O(T)$ work to compute the forward pass.
  2. To do the backward step for each chunk, we split the sequence within the chunk into $k$ subchunks. We run the computation all the way forward from the beginning of the chunk (using $T/k$ compute) but again only store the $k$ intermediates, the inputs to each subchunk. Note we are still storing the $k$ top-level activations and the $k$ sub-level activations, so only $2k$ activations total. We have done $O(T/k)$ work to compute the chunk forward pass. Aggregating this over all chunks (which we must process sequentially to keep the memory savings), we will do $O(T)$ work at this level.
  3. Now we split the subchunks into sub-subchunks, and so on. We will end up with $\log T$ levels. At each level we store $k$ checkpoints, so we end up with $k\log T$ total memory requirement. At each level we end up essentially doing a full forward pass (aggregated over all branches of the recursion), so we end up doing $T\log T$ compute.

Concretely for $k=2$ and a sequence of length $T=8$:

  1. The initial forward pass stores the activations $a(0)$ and $a(4)$.
  2. To differentiate the second chunk, starting from the output gradient $g(8) = \frac{dL}{da(8)}$ (writing $g(t)$ for the gradient of the loss with respect to $a(t)$), we do another forward pass starting at $t=4$ to compute the intermediate $a(6)$.
    1. Now we do a forward pass from $t=6$ to compute the intermediate $a(7)$. With this and $a(6)$ we compute the gradients $g(7)$ and $g(6)$. We can now forget the activation $a(6)$ and the gradient $g(7)$.
    2. Now we do a forward pass from $t=4$ to compute the intermediate $a(5)$. With this and $a(4)$ we compute the gradients $g(5)$ and $g(4)$. We can now forget the activations $a(5)$ and $a(4)$ (having finished the higher-level chunk) and the gradient $g(5)$.
  3. Now that we have $g(4)$ we apply a similar recursion to differentiate the first chunk. This produces $g(0)$ as desired.
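A compact sketch of the $k=2$ recursion in the walkthrough above, again with a toy step standing in for the real model: each call recomputes only its midpoint activation and handles the second half before the first, so at most $O(\log T)$ activations are live at once, at the cost of roughly $T\log T$ recomputation.

```python
import numpy as np

def f(a):                        # toy step function
    return np.tanh(a)

def f_vjp(a_t, g):               # VJP of f at a_t
    return g * (1.0 - np.tanh(a_t) ** 2)

def run_forward(a, steps):
    for _ in range(steps):
        a = f(a)
    return a

def recursive_backprop(a_start, length, g_end):
    # Gradient at the input of a chunk of `length` steps, given the chunk's
    # input activation and the gradient at its output.
    if length == 1:
        return f_vjp(a_start, g_end)
    half = length // 2
    a_mid = run_forward(a_start, half)                       # recompute the midpoint
    g_mid = recursive_backprop(a_mid, length - half, g_end)  # second half first...
    return recursive_backprop(a_start, half, g_mid)          # ...then the first half

print(recursive_backprop(np.ones(4), length=8, g_end=np.ones(4)))
```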

Of course there is flexibility to choose $k$ differently at different levels of the recursion (and even divide the sequence into uneven-length chunks, especially if the sequence itself contains heterogeneous elements). This design space can be optimized to minimize computation for a given memory budget: https://dl.acm.org/doi/10.5555/3157382.3157559.

Does this (recursive) chunking approach work for sequence models? We don't have to invert anything, but we still need inputs in order to run the forward pass. The case I'm worried about is where we have multiple sequence layers, like in a Mamba-style model but with LSTMs, say. We can store the LSTM state at all layers for a given sequence position and this is (by definition of the Markov assumption) enough to run a forward pass to new sequence positions. We still have to store the 'input' but this is just a token so it's not a huge deal. I wonder if this approach can give us something competitive with SSMs but more flexible.
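To make that claim concrete, here's a toy sketch (with a made-up recurrent cell, not an actual LSTM or Mamba block) showing that the stack of per-layer states at one sequence position, plus the input tokens, is enough to exactly recompute the forward pass from that position onward, which is all the checkpointing scheme needs.

```python
import numpy as np

rng = np.random.default_rng(0)
L, D, T, k = 3, 4, 16, 4              # layers, state dim, sequence length, chunk size
W = [rng.normal(size=(D, D)) for _ in range(L)]

def cell(h, x, l):                    # a toy recurrent cell (stand-in for an LSTM)
    return np.tanh(W[l] @ h + x)

def step(states, x):                  # advance the whole layer stack one position
    new_states, inp = [], x
    for l, h in enumerate(states):
        h = cell(h, inp, l)
        new_states.append(h)
        inp = h                       # layer l's output is layer l+1's input
    return new_states

tokens = [rng.normal(size=D) for _ in range(T)]

# Forward pass, storing the full per-layer state stack only at chunk boundaries.
states = [np.zeros(D) for _ in range(L)]
ckpts, trace = {}, []                 # `trace` is kept only to verify the claim below
for t in range(T):
    if t % k == 0:
        ckpts[t] = [h.copy() for h in states]
    states = step(states, tokens[t])
    trace.append([h.copy() for h in states])

# Recompute positions 4..7 from the checkpoint at t=4 plus the tokens alone.
recomputed = [h.copy() for h in ckpts[4]]
for t in range(4, 8):
    recomputed = step(recomputed, tokens[t])

print(all(np.allclose(a, b) for a, b in zip(recomputed, trace[7])))  # True
```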