Modified: October 04, 2023
transformer parallelization math
This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.
What does the computational profile of a transformer vs a similar RNN look like?
First, the transformer.
Let's take the LLaMA 6.7B model, with:
- d=4096
- 32 attention heads
- 32 layers
and assume this means we have 4096/32 = 128 dims per head? And maybe d_ff of 16384? So then we have matrices of size:
- W_k, W_v, W_q: [D, 3D]
- W_o: [D, D]
- W_ff: [D, 4D], W_ff2: [4D, D]
which gives 3D^2 + D^2 + 4D^2 + 4D^2 = 12D^2 = 12 * 16777216 ≈ 201M params per layer. Multiplying by 32 layers gives 6.4B params total, so we're roughly in the ballpark (after accounting for biases, embeddings, modified attention mechanisms, etc.).
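The per-layer count above can be checked with a few lines of arithmetic. This is a sketch under the same guessed shapes: a [D, 3D] QKV projection, a [D, D] output projection, and a two-matrix feedforward block with d_ff = 4D. Biases, norms and embeddings are ignored, and the helper name is made up for illustration.

```python
# Per-layer weight count for the guessed LLaMA-like shapes in the text.
# Assumptions: [D, 3D] QKV projection, [D, D] output projection,
# two-matrix feedforward with d_ff = 4D; biases, norms, embeddings ignored.
D = 4096
N_LAYERS = 32
D_FF = 4 * D  # guessed feedforward width (16384)


def params_per_layer(d: int, d_ff: int) -> int:
    qkv = d * (3 * d)          # W_k, W_v, W_q: [D, 3D]
    out = d * d                # W_o: [D, D]
    ffn = d * d_ff + d_ff * d  # W_ff: [D, 4D] and W_ff2: [4D, D]
    return qkv + out + ffn


per_layer = params_per_layer(D, D_FF)
all_layers = per_layer * N_LAYERS

print(f"per layer:  {per_layer:,}")   # 201,326,592 (= 12 * D^2)
print(f"all layers: {all_layers:,}")  # 6,442,450,944
```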
Suppose we also have embeddings for 50000 tokens (wild guess). Then this is a [50000, 4096] matrix with approximately 205M params.
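Here is the same arithmetic for the embedding table, plus a running total that reuses the 12 D^2-per-layer estimate. The 50,000-token vocabulary is the guess from the text and not LLaMA's actual vocabulary. Only a single (input) embedding table is counted.

```python
# Embedding table size under the text's wild-guess vocabulary of 50,000 tokens.
VOCAB = 50_000  # assumed, not a confirmed LLaMA value
D = 4096
N_LAYERS = 32

embedding_params = VOCAB * D
layer_params = 12 * D * D * N_LAYERS  # 12 D^2 per layer
total = layer_params + embedding_params

print(f"embedding: {embedding_params:,}")  # 204,800,000
print(f"total:     {total:,}")             # 6,647,250,944
```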
Now in a forward pass we have activations per position of 4096 * 32 = 131K (one 4096-dim vector per layer). So with a context window of size 2048, this is 268M floats to store the activations for a single sequence. In this case the activations are about 1/24 of the layer weights. In general the ratio depends on the context length and on the width (and so on the depth, at a fixed parameter count): per layer, activations scale with D per position while weights scale with D^2, so here the ratio is 12D / 2048 = 24.
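Below is a sketch of the activation count for one sequence, compared against the layer weights. It counts only one D-dim residual vector per layer per position, as the text does. For a fixed shape the ratio reduces to 12 D / context_length. So at a fixed parameter budget, a deeper and narrower model carries relatively more activations.

```python
# Residual-stream activations for one 2048-token sequence vs. layer weights.
# Counts one D-dim vector per layer per position (no attention/FFN internals).
D = 4096
N_LAYERS = 32
CONTEXT = 2048

acts_per_position = D * N_LAYERS
acts_per_sequence = acts_per_position * CONTEXT
layer_weights = 12 * D * D * N_LAYERS

ratio = layer_weights / acts_per_sequence  # simplifies to 12 * D / CONTEXT
print(f"activations per position: {acts_per_position:,}")  # 131,072
print(f"activations per sequence: {acts_per_sequence:,}")  # 268,435,456
print(f"weights / activations:    {ratio:.1f}")            # 24.0
```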
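(Is this right? We'll also have the intermediate activations within the attention and feedforward layers, and attention weights of up to 2048 per position at each layer, for each head. The former might multiply the activation size by a factor of several (the feedforward hidden layer alone is 16384 wide, 4x the residual). The latter may not be minor after all: with 32 heads that is up to 32 * 2048 = 65K attention weights per position per layer, versus 4096 residual activations, unless they are recomputed rather than stored.)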
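The next sketch tallies what a naive, store-everything forward pass might keep per position and per layer. It assumes d_ff = 4D, and that each of the 32 heads stores a full 2048-long row of attention weights. That is an upper bound, because with causal masking position i only attends to i + 1 entries. These numbers are assumptions about an implementation, not measurements of LLaMA.

```python
# Rough per-position, per-layer storage for a naive (store-everything) forward pass.
# Assumptions: d_ff = 4D, 32 heads each keeping a full CONTEXT-long attention row
# (an upper bound; causal masking means position i only attends to i + 1 entries).
D = 4096
N_HEADS = 32
D_FF = 4 * D
CONTEXT = 2048

per_position_per_layer = {
    "residual": D,                           # layer input/output vector
    "q, k, v": 3 * D,                        # projected queries, keys, values
    "attention weights": N_HEADS * CONTEXT,  # one weight per (head, attended position)
    "ffn hidden": D_FF,                      # feedforward intermediate
}

for name, count in per_position_per_layer.items():
    print(f"{name:>17}: {count:>6,}  ({count // D}x residual)")
```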
If LLaMA was trained with a batch size of 4M sequences, that would be giant. I wonder if it must mean 4M tokens, i.e. 2048 tokens per sequence * 2048 sequences.
If that's the case, and we have a cluster of 64 GPUs, then each GPU processes 32 sequences. That means it has about 7B weights and 32 * 268M ≈ 8.6B activations in memory, which seems plausible.
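Here is the per-GPU split as a short calculation. It assumes that "4M" means 2048 * 2048 tokens and uses pure data parallelism, so every GPU holds a full copy of the weights. Only the residual activations from above are counted. The 2 bytes per activation is an assumed 16-bit format, used only to turn the float count into GB.

```python
# Per-GPU share of a 4M-token batch under pure data parallelism.
# Assumptions: "4M" means 2048 * 2048 tokens, 64 GPUs, full weight copy per GPU,
# and only the residual-stream activations are counted.
CONTEXT = 2048
TOKENS_PER_BATCH = 2048 * CONTEXT
N_GPUS = 64
ACTS_PER_SEQUENCE = 4096 * 32 * CONTEXT  # ~268M floats
BYTES_PER_FLOAT = 2                      # assumed 16-bit activations

sequences = TOKENS_PER_BATCH // CONTEXT
seqs_per_gpu = sequences // N_GPUS
acts_per_gpu = seqs_per_gpu * ACTS_PER_SEQUENCE

print(f"sequences per batch: {sequences}")     # 2048
print(f"sequences per GPU:   {seqs_per_gpu}")  # 32
print(f"activations per GPU: {acts_per_gpu:,} "
      f"(~{acts_per_gpu * BYTES_PER_FLOAT / 1e9:.1f} GB)")  # 8,589,934,592 (~17.2 GB)
```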
Now let's imagine an RNN. The simplest would be to train the same transformer model as a "Markovized" RNN where the state is the key-value cache. So we still have all the same parameters. But we now have a 'state', the KV cache holding all the previous activations, so this is of size 268M (or 2 * 268M = 537M if the cache stores separate key and value vectors at every layer rather than the layer activations they are computed from). It's still the same number of FLOPs to predict a new token. But a batch size of 4M tokens would now require 4M copies of the 268M activations, instead of just 2048 copies. Basically we now have a context_length * sequences_in_batch memory requirement, where sequences_in_batch must itself be larger than in the transformer case by a factor of context_length if we are to maintain the same tokens per batch. So the memory requirement for an equivalent context goes up from linear in context length to effectively quadratic in context length.
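The scaling argument can be turned into a toy calculation that counts stored activation floats at a fixed budget of 2048 * 2048 tokens per batch. Both functions are hypothetical simplifications that ignore weights, optimizer state and attention internals. At fixed tokens per batch, the Markovized RNN costs exactly context_length times more. If the number of sequences is held fixed instead, the transformer cost grows linearly in context length and the RNN cost grows quadratically.

```python
# Toy memory model: stored activation floats per batch at a fixed tokens-per-batch budget.
# Hypothetical simplification: ignores weights, optimizer state, and attention internals.
ACTS_PER_POSITION = 4096 * 32  # residual activations per position, all layers


def transformer_floats(tokens_per_batch: int, context: int) -> int:
    # tokens_per_batch / context sequences, each storing `context` positions once
    sequences = tokens_per_batch // context
    return sequences * context * ACTS_PER_POSITION


def markov_rnn_floats(tokens_per_batch: int, context: int) -> int:
    # every token is a separate example carrying its own `context`-position state
    return tokens_per_batch * context * ACTS_PER_POSITION


TOKENS = 2048 * 2048
for context in (512, 1024, 2048):
    t = transformer_floats(TOKENS, context)
    r = markov_rnn_floats(TOKENS, context)
    print(f"context {context:>4}: transformer {t:.2e}, RNN {r:.2e}, ratio {r // t}")
```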
(Attention was already quadratic in the context length, but the attention matrices are relatively small compared to the activation caches.)
Something particularly interesting about the transformer setup is that it is effectively recurrent at runtime (with the KV cache), so training it is a parallel training procedure for a recurrent process. For whatever other recurrent processes we come up with, we might hope to find a parallel training procedure too.
What makes the parallelization possible is that the recurrent state (the KV cache) decouples into the activations of individual previous timesteps, which can be computed independently.
In general this is true for any part of the system that does purely bottom-up processing independently at each step, or that incorporates only earlier-layer information from previous steps (as transformers do).
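Here is a minimal single-head sketch of that decoupling. It uses plain scaled dot-product attention with no projections, layers or batching, and the function names are hypothetical. The training-style function builds the whole causally masked score matrix at once. The inference-style function walks forward one step at a time with a growing KV cache. Both give the same outputs, because each step's keys and values depend only on that step's own inputs, not on earlier outputs. It is written in pure Python so it runs without dependencies.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # exp(-inf) == 0.0 for masked entries
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def weighted_sum(weights, vectors):
    return [sum(w * v[j] for w, v in zip(weights, vectors)) for j in range(len(vectors[0]))]


def attention_parallel(qs, ks, vs):
    """Training-style: score every (query, key) pair at once, mask the future."""
    scale = 1 / math.sqrt(len(qs[0]))
    n = len(qs)
    scores = [
        [dot(qs[i], ks[j]) * scale if j <= i else -math.inf for j in range(n)]
        for i in range(n)
    ]
    return [weighted_sum(softmax(row), vs) for row in scores]


def attention_recurrent(qs, ks, vs):
    """Inference-style: one step at a time; the state is a growing KV cache."""
    scale = 1 / math.sqrt(len(qs[0]))
    k_cache, v_cache, outputs = [], [], []
    for q, k, v in zip(qs, ks, vs):
        k_cache.append(k)
        v_cache.append(v)
        weights = softmax([dot(q, kc) * scale for kc in k_cache])
        outputs.append(weighted_sum(weights, v_cache))
    return outputs


def random_vectors(n, d, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]


rng = random.Random(0)
n_steps, dim = 6, 4
qs, ks, vs = (random_vectors(n_steps, dim, rng) for _ in range(3))

par = attention_parallel(qs, ks, vs)
rec = attention_recurrent(qs, ks, vs)
max_diff = max(abs(a - b) for p_row, r_row in zip(par, rec) for a, b in zip(p_row, r_row))
print(f"max |parallel - recurrent| = {max_diff:.1e}")
```

The parallel form does O(n^2) work in one shot, which suits training hardware. The recurrent form does the same work spread over n steps, which is how generation runs.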
What about a hypothetical video / robot model that sees several new frames every second, with each frame represented as 256 tokens or so? I don't think this directly changes anything, except that it implies you need a longer token context.
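These are some back-of-the-envelope numbers for the video case. They use the placeholder of 256 tokens per frame, a few example frame rates, and the 2048-token context from earlier. None of these values come from a real model.

```python
# Back-of-the-envelope context needs for a video model.
# Assumptions: 256 tokens per frame (the text's placeholder), example frame rates,
# and the 2048-token context used earlier.
TOKENS_PER_FRAME = 256
CONTEXT = 2048


def tokens_per_second(fps: float) -> float:
    return fps * TOKENS_PER_FRAME


def seconds_in_context(fps: float, context: int = CONTEXT) -> float:
    return context / tokens_per_second(fps)


for fps in (1, 5, 10):
    print(f"{fps:>2} fps: {tokens_per_second(fps):>6.0f} tokens/s, "
          f"{CONTEXT} tokens cover {seconds_in_context(fps):.2f} s, "
          f"one minute needs {tokens_per_second(fps) * 60:,.0f} tokens")
```

At 10 frames per second the 2048-token context covers less than a second. Even a minute of video needs roughly 75x that context.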