Modified: February 13, 2023
transformer
This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

The core of the transformer architecture is multi-headed attention.
```python
import math

import tensorflow as tf
from tensorflow.keras.layers import Dense

# Constants from 'Attention Is All You Need'.
d_k = 64  # Dimension of queries, keys, and values (per head).
d_ff = 2048  # Hidden dimension for feedforward layers.
num_heads = 8
d_model = d_k * num_heads  # 512

def multi_head_attention(qx, kvx, mask=None):
    # Project activations x
    # (shape `[n_pos, d_model]`)
    # down to per-head query/key/values
    # (shape `[num_heads, n_pos, d_k]`).
    # Encoder blocks have a single x
    # (`qx = kvx`), while decoder blocks also
    # include a layer that attends to
    # activations `kvx` output from the encoder.
    query = Dense(d_model)(qx)
    key = Dense(d_model)(kvx)
    value = Dense(d_model)(kvx)
    # Split d_model into heads, then move the head axis to the front.
    # Using -1 for the length lets `qx` and `kvx` differ in length.
    query, key, value = [
        tf.transpose(tf.reshape(v, [-1, num_heads, d_k]), [1, 0, 2])
        for v in (query, key, value)]
    # Normalized attention weights for each output (row).
    # Shape is `[num_heads, n_pos_q, n_pos_kv]`.
    scores = tf.matmul(query, key, transpose_b=True) / math.sqrt(d_k)
    if mask is not None:
        # Optional `[n_pos_q, n_pos_kv]` mask (1 = attend, 0 = block):
        # blocked scores get a large negative value, so softmax gives ~0.
        scores += (1.0 - mask) * -1e9
    p_attn = tf.nn.softmax(scores, axis=-1)
    # This is the heart of everything.
    x = tf.matmul(p_attn, value)  # [num_heads, n_pos_q, d_k]
    # Merge the heads back into a single vector per position, then
    # apply a linear layer to 'un-project' the per-head results.
    x = tf.reshape(tf.transpose(x, [1, 0, 2]), [-1, d_model])
    return Dense(d_model)(x)
```
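For poking at the shapes outside TensorFlow, here is a NumPy sketch of the same computation with explicit weight matrices standing in for the `Dense` layers (no biases, random weights). The name `mha_numpy` and the toy sizes are made up for illustration. It also shows that the query sequence and the key/value sequence can have different lengths, as in the encoder-decoder attention discussed later.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the row max for numerical stability.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mha_numpy(qx, kvx, w_q, w_k, w_v, w_o, num_heads):
    """Multi-head attention with explicit (bias-free) weight matrices.

    qx: [n_q, d_model]; kvx: [n_kv, d_model]; each w_*: [d_model, d_model].
    Returns the output [n_q, d_model] and p_attn [num_heads, n_q, n_kv].
    """
    n_q, d_model = qx.shape
    d_k = d_model // num_heads

    def split_heads(v):
        # [n, d_model] -> [n, num_heads, d_k] -> [num_heads, n, d_k]
        return v.reshape(v.shape[0], num_heads, d_k).transpose(1, 0, 2)

    query = split_heads(qx @ w_q)
    key = split_heads(kvx @ w_k)
    value = split_heads(kvx @ w_v)
    scores = query @ key.transpose(0, 2, 1) / np.sqrt(d_k)  # [num_heads, n_q, n_kv]
    p_attn = softmax(scores, axis=-1)
    x = p_attn @ value  # [num_heads, n_q, d_k]
    # Merge heads: [num_heads, n_q, d_k] -> [n_q, num_heads, d_k] -> [n_q, d_model]
    x = x.transpose(1, 0, 2).reshape(n_q, d_model)
    return x @ w_o, p_attn

rng = np.random.default_rng(0)
d_model, num_heads = 16, 4
w_q, w_k, w_v, w_o = (rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4))
qx = rng.normal(size=(5, d_model))   # e.g. 5 decoder positions
kvx = rng.normal(size=(7, d_model))  # e.g. 7 encoder positions
out, p_attn = mha_numpy(qx, kvx, w_q, w_k, w_v, w_o, num_heads)
print(out.shape, p_attn.shape)  # (5, 16) (4, 5, 7)
```

Dropping the biases is a simplification; Keras `Dense` layers include a bias by default.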
The transformer block consists of a multi-headed attention layer followed by a fully-connected MLP, each wrapped with layer normalization and a residual connection:
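```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, LayerNormalization

def layer_norm(x):
    return LayerNormalization()(x)

def dropout(x, rate=0.1):  # P_drop = 0.1; apply only during training.
    return tf.nn.dropout(x, rate=rate)

def transformer_block(x):
    # Attention mechanism.
    x_attn = multi_head_attention(qx=x, kvx=x)
    # Normalization and residual connection.
    x = layer_norm(x + dropout(x_attn))
    # Feedforward layers.
    x_ff = Dense(d_ff)(x)
    x_ff = dropout(tf.nn.relu(x_ff))
    x_ff = Dense(d_model)(x_ff)
    # Normalization and residual connection.
    return layer_norm(x + dropout(x_ff))
```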
Basic questions for intuition:
- why do multi-headed attention?
- what is the power of multiplicative interactions?
  - we can view them as fast weights
Computational cost: the attention matrices `scores` and `p_attn` have shape `[n, n]` (per head), so the basic attention mechanism has cost $O(n^2 d)$, i.e., quadratic in the sequence length $n$ and linear in the model dimension $d$. Meanwhile, the dense layers have cost $O(n d^2)$: linear in the sequence length and quadratic in the model dimension.
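To make the two scalings concrete, here is a rough multiply-add count for one block, under simplifying assumptions: biases, softmax, and layer norm are ignored, and only the matrix products are counted. Splitting into heads doesn't change the attention total, since `num_heads * n * n * d_k = n * n * d_model`. The function names are just for illustration.

```python
def attention_cost(n, d_model):
    """Multiply-adds for the n x n part of attention, summed over heads.

    scores = Q K^T:  n * n * d_model
    p_attn @ V:      n * n * d_model
    """
    return 2 * n * n * d_model

def dense_cost(n, d_model, d_ff):
    """Multiply-adds for the per-position dense layers of one block.

    Q, K, V and output projections:  4 * n * d_model**2
    Feedforward (two layers):        2 * n * d_model * d_ff
    """
    return 4 * n * d_model * d_model + 2 * n * d_model * d_ff

# Base-model constants from the paper.
d_model, d_ff = 512, 2048
for n in (128, 1024, 3072, 8192):
    ratio = attention_cost(n, d_model) / dense_cost(n, d_model, d_ff)
    print(f"n={n:5d}  attention/dense = {ratio:.3f}")
```

With `d_ff = 4 * d_model` the ratio simplifies to `n / (6 * d_model)`, i.e. `n / 3072` for these constants, so under this counting the quadratic attention term only dominates for sequences longer than about 3,000 positions.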
Other observations:
- The same weights are reused at every position in the sequence; in this sense the dense layers function as 1×1 convolutions. This means we can apply a trained transformer model to sequences of different lengths, with no retraining required.
- Transformers intrinsically model sets of embeddings: without masking, permuting the input positions just permutes the outputs in the same way. Because the model doesn't explicitly represent position within a sequence, sequence-modeling applications usually augment the input with a positional embedding.
- Attention is (up to the softmax nonlinearity) a third-order polynomial in the input activations $x$: a single head computes $\mathrm{softmax}\left((x W_Q)(x W_K)^\top / \sqrt{d_k}\right)(x W_V)$, so without the softmax it is a product of three linear functions of $x$.

Encoders and decoders: the use of these terms in the transformer literature doesn't make a lot of sense. The terminology is a relic of the original application to machine translation, where the 'encoder' maps from the input sentence to a latent representation, and the 'decoder' uses that representation to generate the translation. Now that purely autoregressive language models like GPT-3 are dominating everything, they are sometimes called 'decoder-only' models because they use the same kind of masking as the decoder blocks in the original setup. But this makes no sense conceptually (what could it even mean to 'decode' in the absence of an encoder?) and also turns out not to be fully accurate even at the technical level!
Classically, an 'encoder' block refers to just the basic transformer block shown above: it takes the input sequence, e.g., the French sentence to be translated, and outputs an activation vector at each position. A 'decoder' block takes the output sequence (which may be only partially generated) and also produces a vector of activations at each position. At the final decoder layer, these are interpreted as specifying logits over the next token, while the analogous activations at the final encoder layer are interpreted just as a latent representation of the input. The decoder block also adds two architectural ingredients:
- Masked attention: the attention pattern over the inputs is multiplied by a lower-triangular mask matrix (and renormalized; in practice the masked scores are set to $-\infty$ before the softmax) to ensure that activations at position $i$ are influenced only by inputs at positions $j \le i$. This is sometimes called 'causal' masking; it can also be seen as simply enforcing an autoregressive model structure.
- Encoder-decoder connection: the decoder block incorporates a second multi-head attention mechanism which attends to the final activations of the encoder stack, representing the input sentence (the queries to this head are built from the layer input as usual, but the keys and values come from the encoder). Note that this attention is not masked, but operates over the entire input sequence.
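A small NumPy check of the causal structure: single-head self-attention in which scores above the diagonal are set to `-inf` before the softmax, so their weights come out exactly zero. Perturbing only the last input position then leaves every earlier output unchanged. The function names and sizes are illustrative assumptions, not from any library.

```python
import numpy as np

def causal_mask(n):
    # 1 where query position i may attend to key position j (j <= i), else 0.
    return np.tril(np.ones((n, n)))

def causal_self_attention(x, w_q, w_k, w_v):
    """Single-head masked self-attention; x has shape [n, d]."""
    n, d = x.shape
    scores = (x @ w_q) @ (x @ w_k).T / np.sqrt(d)
    # Blocked entries become -inf, so exp() gives them weight exactly 0.
    scores = np.where(causal_mask(n) == 1, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ (x @ w_v), weights

rng = np.random.default_rng(0)
n, d = 6, 8
w_q, w_k, w_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
x = rng.normal(size=(n, d))
out, weights = causal_self_attention(x, w_q, w_k, w_v)

# Perturb only the last position; earlier outputs should not change.
x2 = x.copy()
x2[-1] += 10.0
out2, _ = causal_self_attention(x2, w_q, w_k, w_v)
print(np.allclose(out[:-1], out2[:-1]))  # True
```

Setting scores to `-inf` before the softmax is equivalent to multiplying the exponentiated scores by the lower-triangular mask and renormalizing each row.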
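```python
def decoder_block(x, encoder_activations, mask):
    # Masked (causal) self-attention on layer inputs. `mask` is the
    # lower-triangular [n_pos, n_pos] matrix (1 = attend, 0 = block), e.g.
    # tf.linalg.band_part(tf.ones([n_pos, n_pos]), -1, 0). It must act on the
    # attention scores before the softmax, not on the output, so this assumes
    # multi_head_attention takes it as an optional `mask` argument.
    x_attn = multi_head_attention(qx=x, kvx=x, mask=mask)
    x = layer_norm(x + dropout(x_attn))
    # Also attend to encoder activations (no mask).
    x_attn = multi_head_attention(
        qx=x, kvx=encoder_activations)
    x = layer_norm(x + dropout(x_attn))
    # Feedforward layers identical to encoder.
    ...
    return layer_norm(x + dropout(x_ff))
```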
The transformer blocks in autoregressive language models necessarily use masked attention, so in this sense they are like the decoder blocks of the original setup. But of course they don't use an encoder-decoder connection, since there's no separate input to attend to! Mathematically, it is as if they are always 'decoding' the null input. So they are 'decoder' stacks only in a very degenerate sense!
Sometimes people also talk about 'encoder-only' transformer models, usually in reference to the BERT family. These really do look like the encoder stack in the original translation setup, where each block uses non-masked attention over the entire input sequence. Usually these are trained by instead masking some parts of the input, and then attempting to reconstruct those parts from the final activations.
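A minimal sketch of that input corruption, assuming integer token ids with a hypothetical `MASK_ID` reserved for the mask token. BERT selects 15% of positions and replaces most of them with the mask token, but also swaps some for random tokens or leaves them unchanged; this sketch skips that refinement and replaces every selected position. The model would then be trained to predict the original ids wherever `targets != -1`.

```python
import numpy as np

MASK_ID = 0  # hypothetical token id reserved for the mask token

def mask_tokens(tokens, mask_prob, rng):
    """Randomly replace tokens with MASK_ID for masked-input training.

    Returns (corrupted, targets): `corrupted` is the model input, and
    `targets` holds the original id at masked positions and -1 elsewhere
    (positions marked -1 contribute no loss).
    """
    tokens = np.asarray(tokens)
    chosen = rng.random(tokens.shape) < mask_prob
    corrupted = np.where(chosen, MASK_ID, tokens)
    targets = np.where(chosen, tokens, -1)
    return corrupted, targets

rng = np.random.default_rng(0)
tokens = rng.integers(1, 1000, size=20)  # toy ids; 0 is reserved for MASK_ID
corrupted, targets = mask_tokens(tokens, mask_prob=0.15, rng=rng)
print(tokens)
print(corrupted)
print(targets)
```

Because the encoder attends over the whole (corrupted) sequence without a causal mask, each prediction can use context from both sides of the masked position.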