Modified: January 02, 2024
memory efficient backprop
This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

Suppose we want to do automatic differentiation on a computational graph of sequential length $T$. This could equally well be a multilayer network with $T$ layers, a recurrent model applied to a sequence of length $T$, or the execution of a sequential algorithm like an inner-loop stochastic gradient descent (https://arxiv.org/abs/1502.03492). Naive backpropagation requires us to store the intermediate values (activations) at all $T$ layers, which might require a prohibitive quantity of memory. There are a few ways to get around this.
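As a baseline for the sketches below, here is what naive backprop over a chain looks like in code. This is a minimal plain-Python sketch, not any particular library's API: the step functions `f` and their vector-Jacobian products `vjp` are hypothetical stand-ins, and parameter gradients are elided.

```python
def naive_backprop(x0, steps, output_grad):
    """Reverse-mode autodiff over a chain x_{t+1} = f_t(x_t).

    `steps` is a list of (f, vjp) pairs, where vjp(x_t, g) maps the gradient
    g with respect to f(x_t) back to the gradient with respect to x_t.
    Stores all T intermediate activations: O(T) memory, O(T) compute.
    """
    xs = [x0]
    for f, _ in steps:   # forward pass: keep every activation
        xs.append(f(xs[-1]))

    g = output_grad
    for (_, vjp), x in zip(reversed(steps), reversed(xs[:-1])):
        g = vjp(x, g)    # backward pass: consume the stored activations
    return g

# Toy chain: four doubling steps, so the output is 16 * x_0 and dL/dx_0 = 16.
steps = [(lambda x: 2.0 * x, lambda x, g: 2.0 * g)] * 4
assert naive_backprop(1.0, steps, 1.0) == 16.0
```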
Invertibility: If the sequential steps are invertible then we can do a sequential backwards pass recomputing everything as we go using the inverse: https://arxiv.org/abs/1707.04585. This requires $O(1)$ memory and $O(T)$ computation.
This requires special network architectures which might be plausible for resnets (the case in the paper above). But it seems less promising for sequence models. In a sequence model we would also expect to need to hold on to the inputs at each step (since the successor state is a function of the previous state and the input), but these are generally of roughly the same size as the activations themselves, so at that point we're probably not saving very much?
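Here is a sketch of the mechanics (not the RevNet architecture itself, which uses additive coupling layers to make each block invertible): if every step comes with an inverse, the backward pass can recompute each input from the output it already holds, keeping only one activation live at a time. The `f_inv` functions are assumed to exist and, like the other names here, are hypothetical.

```python
def reversible_backprop(y_final, steps, output_grad):
    """Backward pass for a chain of invertible steps x_{t+1} = f_t(x_t).

    `steps` is a list of (f, f_inv, vjp) triples; f_inv recomputes each
    step's input from its output, so no activations are stored during the
    forward pass: O(1) memory, at the cost of one inverse per step.
    """
    x_next, g = y_final, output_grad
    for _, f_inv, vjp in reversed(steps):
        x = f_inv(x_next)   # recompute the input instead of having stored it
        g = vjp(x, g)       # ordinary VJP through this step
        x_next = x
    return g

# Toy invertible chain: add 1.0 five times; the gradient passes through unchanged.
steps = [(lambda x: x + 1.0, lambda y: y - 1.0, lambda x, g: g)] * 5
assert reversible_backprop(5.0, steps, 1.0) == 1.0
```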
Approximate invertibility: Generally we may not want our layers to be invertible. For example, the "memory" of an LSTM needs to be able to forget information. Steps that lose information can't be inverted directly.
LSTMs may already be hopeless due to the input issue above, but in other cases we may be able to salvage invertibility by only storing the lost information, which is much more efficient than storing the whole input.
For example, in gradient descent with momentum, we can invert a position update by simply subtracting the momentum term, and then use the gradient at that previous point to approximately invert the momentum update. Some small amount of information is lost due to finite precision, but this can be coded very efficiently: https://arxiv.org/abs/1502.03492. The resulting approach may still require linear memory in theory, but the practical benefits can be large.
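A minimal sketch of this inversion, assuming a standard heavy-ball update $v_{t+1} = \gamma v_t - \alpha \nabla L(w_t)$, $w_{t+1} = w_t + v_{t+1}$; the function names are mine, and this omits the cited paper's careful bookkeeping of the bits lost to finite precision:

```python
import numpy as np

def sgd_momentum_step(w, v, grad_fn, lr=0.1, decay=0.9):
    """One heavy-ball step: v' = decay * v - lr * grad(w); w' = w + v'."""
    v_new = decay * v - lr * grad_fn(w)
    w_new = w + v_new
    return w_new, v_new

def sgd_momentum_step_inverse(w_new, v_new, grad_fn, lr=0.1, decay=0.9):
    """Approximately invert the step above without having stored (w, v).

    Exact in exact arithmetic; in floating point, the forward multiplication
    by `decay` discards low-order bits of v that the division cannot recover.
    The cited paper stores exactly those lost bits rather than the trajectory.
    """
    w = w_new - v_new                       # undo the position update exactly
    v = (v_new + lr * grad_fn(w)) / decay   # approximately undo the momentum update
    return w, v

# Round-trip check on a toy quadratic loss L(w) = 0.5 * ||w||^2 (grad = w).
grad_fn = lambda w: w
w, v = np.array([1.0, -2.0]), np.zeros(2)
w1, v1 = sgd_momentum_step(w, v, grad_fn)
w0, v0 = sgd_momentum_step_inverse(w1, v1, grad_fn)
assert np.allclose(w0, w) and np.allclose(v0, v)
```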
Activation checkpointing: if we can't rely on invertibility, a generic method is to store the activations every $K$ steps, and recompute only between successive checkpoints. This was (maybe?) first proposed by Martens and Sutskever 2012 (https://www.cs.toronto.edu/~jmartens/docs/HF_book_chapter.pdf, section 7.1).
This stores a total of $T/K$ checkpoints, plus $O(K)$ memory to recompute the activations between the current pair of checkpoints, so $O(T/K + K)$ memory overall, which is minimized at $O(\sqrt{T})$ by taking $K = \sqrt{T}$. Each non-checkpointed activation gets recomputed once, so this uses a total of $O(T)$ extra computation (roughly one additional forward pass).
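A minimal sketch of this non-recursive scheme, reusing the hypothetical `(f, vjp)` step convention from the first snippet; the default $K = \sqrt{T}$ gives the $O(\sqrt{T})$ memory figure above:

```python
import math

def checkpointed_backprop(x0, steps, output_grad, K=None):
    """Backprop over x_{t+1} = f_t(x_t), storing only every K-th activation.

    Memory: O(T/K) checkpoints plus O(K) recomputed activations at a time.
    Compute: roughly one extra forward pass in total, since each
    non-checkpointed activation is recomputed once.
    """
    T = len(steps)
    K = K or max(1, math.isqrt(T))  # default K = sqrt(T) balances the two terms

    # Forward pass: keep only the checkpoints at positions 0, K, 2K, ...
    checkpoints = {0: x0}
    x = x0
    for t, (f, _) in enumerate(steps):
        x = f(x)
        if (t + 1) % K == 0 and t + 1 < T:
            checkpoints[t + 1] = x

    # Backward pass: walk the blocks in reverse, recomputing each block's
    # activations from its checkpoint just before backpropagating through it.
    g = output_grad
    for start in range(K * ((T - 1) // K), -1, -K):
        end = min(start + K, T)
        xs = [checkpoints.pop(start)]
        for f, _ in steps[start:end - 1]:
            xs.append(f(xs[-1]))
        for (f, vjp), x in zip(reversed(steps[start:end]), reversed(xs)):
            g = vjp(x, g)
    return g
```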
Recursive checkpointing: The preceding approach can be applied recursively. This was introduced by Chen et al. (https://arxiv.org/abs/1604.06174).
That is:
- We subdivide the sequence into $K$ chunks of length $T/K$ and on the forward pass we store the intermediate input to each chunk. Now we have a top-level compute graph of length $K$ and we are storing $K$ activations. We have done $O(T)$ work to compute the forward pass.
- To do the backward step for each chunk, we split the sequence within the chunk into $K$ subchunks. We run the computation all the way forward from the beginning of the chunk (using $O(T/K)$ compute) but again only store the $K$ intermediates, the inputs to each subchunk. Note we are still storing the $K$ top-level activations alongside the $K$ sub-level activations, so only $2K$ activations total. We have done $O(T/K)$ work to compute the chunk forward pass. Aggregating this over all $K$ chunks (which we must process sequentially to keep the memory savings), we will do $O(T)$ work at this level.
- Now we split the subchunks into sub-subchunks, and so on. We will end up with $\log_K T$ levels. At each level we store $K$ checkpoints, so we end up with an $O(K \log_K T)$ total memory requirement. At each level we end up essentially doing a full forward pass (aggregated over all branches of the recursion), so we end up doing $O(T \log_K T)$ compute.
Concretely, for $K = 2$ and a sequence of length $T = 8$ (with activations $x_0, x_1, \ldots, x_8$):
- The initial forward pass stores the activations $x_0$ and $x_4$.
- To differentiate the second chunk, starting from the output gradient $\partial L / \partial x_8$, we do another forward pass starting at $x_4$ to compute the intermediate $x_6$.
- Now we do a forward pass from $x_6$ to compute the intermediate $x_7$. With this and $\partial L / \partial x_8$ we compute the gradient $\partial L / \partial x_7$, and then (using the stored $x_6$) $\partial L / \partial x_6$. We can now forget the activations $x_7$ and $x_6$ and the gradients $\partial L / \partial x_8$ and $\partial L / \partial x_7$.
- Now we do a forward pass from $x_4$ to compute the intermediate $x_5$. With this and $\partial L / \partial x_6$ we compute the gradients $\partial L / \partial x_5$ and $\partial L / \partial x_4$. We can now forget the activations $x_5$ and $x_4$ (having finished the higher-level chunk) and the gradient $\partial L / \partial x_6$.
- Now that we have $\partial L / \partial x_4$ we apply a similar recursion to differentiate the first chunk. This produces $\partial L / \partial x_0$ as desired.
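The same idea as a short recursive sketch (again with the hypothetical `(f, vjp)` convention from the snippets above), specialized to $K = 2$: each call checkpoints the midpoint of its range, backprops the second half, then the first, so only $O(\log T)$ activations are ever live.

```python
def recursive_backprop(x, steps, output_grad):
    """Recursive binary checkpointing over x_{t+1} = f_t(x_t).

    Holds O(log T) activations at once and performs O(T log T) step
    evaluations: each level of the recursion reruns roughly half of the
    forward steps below it, aggregated over its branches.
    """
    T = len(steps)
    if T == 1:
        _, vjp = steps[0]
        return vjp(x, output_grad)

    mid = T // 2
    x_mid = x
    for f, _ in steps[:mid]:   # forward to the midpoint checkpoint
        x_mid = f(x_mid)

    # Backprop the second half first (it owns the incoming output gradient),
    # then drop x_mid and backprop the first half with the returned gradient.
    g_mid = recursive_backprop(x_mid, steps[mid:], output_grad)
    return recursive_backprop(x, steps[:mid], g_mid)

# Same toy chain as the earlier sketches, with T = 8 as in the walk-through
# above: eight doubling steps, so dL/dx_0 = 2**8.
steps = [(lambda x: 2.0 * x, lambda x, g: 2.0 * g)] * 8
assert recursive_backprop(1.0, steps, 1.0) == 2.0 ** 8
```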
Of course there is flexibility to choose $K$ differently at different levels of the recursion (and even to divide the sequence into uneven-length chunks, especially if the sequence itself contains heterogeneous elements). This design space can be optimized to minimize computation for a given memory budget: https://dl.acm.org/doi/10.5555/3157382.3157559.
Does this (recursive) chunking approach work for sequence models? We don't have to invert anything, but we still need the inputs in order to run the forward pass. The case I'm worried about is where we have multiple sequence layers, like a Mamba-style model but with LSTMs, say. We can store the LSTM state at all layers for a given sequence position, and this is (by the Markov structure of the recurrence) enough to run the forward pass onward to later sequence positions. We still have to store the 'input' at each position, but this is just a token, so it's not a huge deal. I wonder if this approach can give us something competitive with SSMs but more flexible.
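A rough sketch of how the non-recursive version might look for a stacked recurrent model, under the assumptions in the paragraph above; `step`, `step_vjp`, and the treatment of the state as one opaque stack of per-layer hidden states are all hypothetical simplifications, and parameter gradients are again elided.

```python
def rnn_checkpointed_backprop(state0, tokens, step, step_vjp, output_grad, K):
    """Checkpointed backprop through a stacked recurrent model.

    `state` holds the hidden states of all layers at one sequence position;
    `step(state, token)` advances every layer by one position, and
    `step_vjp(state, token, g)` maps the gradient w.r.t. the next state back
    to the gradient w.r.t. `state`. We store the full state stack every K
    positions plus all tokens, which are assumed small relative to the states.
    """
    T = len(tokens)
    checkpoints = {0: state0}
    state = state0
    for t, tok in enumerate(tokens):           # forward pass over the sequence
        state = step(state, tok)
        if (t + 1) % K == 0 and t + 1 < T:
            checkpoints[t + 1] = state

    g = output_grad
    for start in range(K * ((T - 1) // K), -1, -K):
        end = min(start + K, T)
        states = [checkpoints.pop(start)]      # recompute this block's states
        for t in range(start, end - 1):
            states.append(step(states[-1], tokens[t]))
        for t in reversed(range(start, end)):  # backprop through the block
            g = step_vjp(states[t - start], tokens[t], g)
    return g
```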