expressive transformer: Nonlinear Function
Created: October 27, 2022
Modified: October 28, 2022


This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

This note is a scratchpad for investigating the expressivity of the transformer architecture.

In general, one set of intuitions that we might have about neural nets is that they are differentiable circuits. Then training is a process of gradient-based search over all computable functions within certain limits (circuit depth / width, etc.), and if all goes well, with sufficient regularization we eventually expect to find the 'right' such circuit (as in grokking). This view relies on some wishful (if not entirely implausible) thinking about the power of hill-climbing search in highly structured spaces. It also depends on having an architecture of the correct 'shape' to express the computations we were hoping to express. It seems reasonable to ask if transformers are a plausible candidate for a generic 'shape', and if not, what would be?

Some basic questions:

  • How easily can a transformer implement basic mathematical functions?
    • addition
    • multiplication
    • exp/log/etc
  • Given two circuits that a transformer can implement (eg, addition and multiplication), can we construct a 'switchable' or 'programmable' transformer that applies one or the other of these circuits according to some other input token? (this is really trying to get at the question of whether transformers can do fast weights)
  • Can we construct a universal Turing machine (or equivalent, eg, a von Neumann machine, Scheme interpreter, etc.) within the transformer architecture?

The null hypothesis is that transformers can do all of these things, with varying but plausible degrees of effort (and appropriate input encodings). What would be interesting is if there is some simple function that:

  1. Transformers can't easily implement.
  2. Could be enabled by some reasonably simple and generic architecture change.

Literature review

Paper: Universal Transformers (Dehghani et al., 2018): this just uses a standard transformer block, but it repeats the same block at every layer, and adds a dynamic halting mechanism to choose how deep to go independently at each position. So the 'universal' part is that it supports adaptive depth instead of fixed depth. I guess that would help with generalization in something like addition, where you need depth $n$ to add $n$-bit inputs.
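
A minimal sketch of the adaptive-depth idea, in my own toy form rather than the paper's exact ACT mechanism (real ACT mixes intermediate states with a remainder term; this version just freezes positions once their accumulated halting probability crosses a threshold). All names here are hypothetical:

```python
import numpy as np

def universal_transformer_sketch(x, shared_block, halt_unit, max_steps=8, threshold=0.99):
    """Toy adaptive depth with a single shared block.

    x:            [n, d] input activations (one sequence, batch omitted).
    shared_block: function [n, d] -> [n, d]; the *same* weights at every step.
    halt_unit:    function [n, d] -> [n] of per-position halting probabilities.
    """
    n, d = x.shape
    cum_halt = np.zeros(n)                       # accumulated halting probability per position
    state = x.copy()
    for _ in range(max_steps):
        still_running = cum_halt < threshold     # [n] boolean mask
        if not still_running.any():
            break
        new_state = shared_block(state)          # same block reused at every "layer"
        p = halt_unit(new_state)                 # [n] halting probabilities this step
        # only positions that are still running get updated
        state = np.where(still_running[:, None], new_state, state)
        cum_halt = cum_halt + np.where(still_running, p, 0.0)
    return state
```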

Switchable transformers

Let a transformer block at layer $i$ be defined by the quantities $\mathbf{W}_i$, consisting of

$$
\begin{align*}
\text{per-head query } & W^{(h)}_{Q,i}: [d, d_k]\\
\text{per-head key } & W^{(h)}_{K,i}: [d, d_k]\\
\text{per-head value } & W^{(h)}_{V,i}: [d, d_k]\\
\text{per-head output } & W^{(h)}_{O,i}: [d_k, d]\\
\text{feedforward weights } & W_{1,i}: [d, d_\text{ff}]\\
\text{feedforward bias } & b_{1,i}: [d_\text{ff}]\\
\text{feedforward weights } & W_{2,i}: [d_\text{ff}, d]\\
\text{feedforward bias } & b_{2,i}: [d]
\end{align*}
$$
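
For concreteness, a minimal numpy forward pass of one block using exactly these shapes; this is a sketch (layer norms and other details omitted, and the dict layout and names are my own):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def block_forward(x, W):
    """One residual transformer block with the shapes above (no layer norm).

    x: [n, d] activations.
    W: dict with per-head lists W['Q'][h], W['K'][h], W['V'][h]: [d, d_k],
       W['O'][h]: [d_k, d], plus W['W1']: [d, d_ff], W['b1']: [d_ff],
       W['W2']: [d_ff, d], W['b2']: [d].
    """
    # multi-head attention: each head writes back into the residual stream via W_O
    attn_out = np.zeros_like(x)
    for Wq, Wk, Wv, Wo in zip(W['Q'], W['K'], W['V'], W['O']):
        q, k, v = x @ Wq, x @ Wk, x @ Wv                   # [n, d_k] each
        scores = softmax(q @ k.T / np.sqrt(k.shape[-1]))   # [n, n] attention weights
        attn_out += (scores @ v) @ Wo                      # [n, d]
    h = x + attn_out                                       # residual connection
    # position-wise feedforward with ReLU
    ff = np.maximum(h @ W['W1'] + W['b1'], 0.0) @ W['W2'] + W['b2']
    return h + ff                                          # residual connection
```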

Say we have two different transformer blocks $\mathbf{W}^A$ and $\mathbf{W}^B$ implementing two different computations on inputs of length $n$. Can we mechanically construct a new transformer block, or sequence of blocks, such that feeding it a length-$(n+1)$ sequence [A, …] or [B, …] returns results equivalent to the corresponding one of its components? That is, can we program a transformer at runtime?

Block solution: We can just stack the weights to get a block that does both computations at once, in separate subspaces (see the sketch after this list). It also adds a final subspace that just copies the initial 'A' or 'B' switching token into every dimension. Then we have a second block that uses this to choose which result to report:

  1. The query $W_{Q,2}$ for all heads $h$ just copies the switching token, represented as a vector [1, 0] for A or [0, 1] for B, plus something that picks out the current position.
  2. The key for heads that computed [A] contains [1, 0] along with the current position, and similarly the key for heads that computed [B] contains [0, 1]. (this would seem to require a bias term $b^{(h)}_{K,i}$ in the attention computations?)
  3. Thus the attention layer just copies the result from whichever of the two computations we actually requested.
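
Here is the stacking sketch referenced above, restricted to the feedforward weights to keep it short: block-diagonal stacking yields a doubled-width block that performs both computations at once in disjoint subspaces (attention weights stack the same way, head by head; the switch-token selection block isn't shown). All names are hypothetical:

```python
import numpy as np

def stack_ffn(Wa1, ba1, Wa2, ba2, Wb1, bb1, Wb2, bb2):
    """Block-diagonally stack two feedforward layers into one that runs both
    computations in disjoint subspaces of a width-2d residual stream."""
    def blockdiag(A, B):
        out = np.zeros((A.shape[0] + B.shape[0], A.shape[1] + B.shape[1]))
        out[:A.shape[0], :A.shape[1]] = A
        out[A.shape[0]:, A.shape[1]:] = B
        return out
    W1 = blockdiag(Wa1, Wb1)                # [2d, 2d_ff]
    b1 = np.concatenate([ba1, bb1])         # [2d_ff]
    W2 = blockdiag(Wa2, Wb2)                # [2d_ff, 2d]
    b2 = np.concatenate([ba2, bb2])         # [2d]
    return W1, b1, W2, b2

def ffn(x, W1, b1, W2, b2):
    return np.maximum(x @ W1 + b1, 0.0) @ W2 + b2

# sanity check: the stacked FFN computes both originals at once
d, dff, n = 4, 8, 3
rng = np.random.default_rng(0)
Wa1, Wa2 = rng.normal(size=(d, dff)), rng.normal(size=(dff, d))
Wb1, Wb2 = rng.normal(size=(d, dff)), rng.normal(size=(dff, d))
ba1, ba2 = rng.normal(size=dff), rng.normal(size=d)
bb1, bb2 = rng.normal(size=dff), rng.normal(size=d)
x = rng.normal(size=(n, d))
W1, b1, W2, b2 = stack_ffn(Wa1, ba1, Wa2, ba2, Wb1, bb1, Wb2, bb2)
both = ffn(np.concatenate([x, x], axis=-1), W1, b1, W2, b2)   # run A and B together
assert np.allclose(both[:, :d], ffn(x, Wa1, ba1, Wa2, ba2))   # A's result in the first subspace
assert np.allclose(both[:, d:], ffn(x, Wb1, bb1, Wb2, bb2))   # B's result in the second subspace
```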

This solution is unsatisfying because it still does both of the underlying computations! It's like making a 'programmable' computer by having it first compute the results of all possible programs, before you tell it which program you actually wanted. Conceptually it works, but it's not a practical strategy.

But maybe it works if we break things down further? Of course we can't run all possible full-length programs. But if there are only a small number of primitive operations that compose to make full-length programs, we can perhaps run all of these operations at each individual step and then choose which result to take (similar in principle to 'speculative execution' in modern CPUs). We can maybe view this in terms of the idea that attention is dual to metareasoning: instead of allocating computations given a fixed input, we allocate the input among a set of fixed computations.
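
A toy rendering of the 'run every primitive, then choose' idea, with a generic selector standing in for the attention-based choice; everything here is hypothetical naming, not a worked-out architecture:

```python
import numpy as np

def run_all_then_select(x, primitive_ops, selector):
    """Run every primitive op at this step and let per-position selection
    weights choose among the results (toy 'speculative execution').

    x:             [n, d] activations.
    primitive_ops: list of k functions [n, d] -> [n, d].
    selector:      function [n, d] -> [n, k] selection weights (e.g. a softmax).
    """
    results = np.stack([op(x) for op in primitive_ops], axis=0)   # [k, n, d]
    weights = selector(x)                                         # [n, k]
    # blend (or, with one-hot weights, pick) the result per position
    return np.einsum('nk,knd->nd', weights, results)
```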

research idea: A more radical change might be to incorporate attention over weight matrices. Suppose we allow multiple weight matrices, each with a learned key vector, and that each position produces a weight-query vector, so that the actual weights used are a runtime-determined linear combination of some set of options. You could do this just for the feedforward layers, or even for all the weights in a transformer block. This becomes kind of a continuous, combinatorial (if you do separate attention at each layer) mixture-of-experts model? Of course the downside is that you'd still need to keep all the weights in memory. And even if we're not multiplying by all these weight matrices, we're still averaging their entries, which still takes $d^2$ time. And then you get different weights for different batch elements, and presumably at different positions within the same batch element (unless you share the weight attention over all positions, which is interesting and also closer to the mixture-of-experts idea I think, but violates some intuitive notion of 'locality', in that it makes less sense as you imagine contexts getting larger and larger - in the limit the context should always be 'everything', which is constant, so if you're choosing your weight matrices based on full context then they too should be constant. but you could build in some lesser notion of locality, like convolutional shared weights over windows of adjacent tokens…).
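
A rough sketch of the feedforward-only version of this idea (biases omitted; the explicit per-position loop is there to make the extra averaging cost visible; all names are my own assumptions):

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def weight_attention_ffn(x, weight_keys, W1_bank, W2_bank, Wq):
    """Feedforward layer whose weights are a per-position mixture over a bank.

    x:           [n, d] activations.
    weight_keys: [m, d_q] learned key vector per candidate weight set.
    W1_bank:     [m, d, d_ff] and W2_bank: [m, d_ff, d] candidate weights.
    Wq:          [d, d_q] maps each position to a weight-query.
    """
    scores = softmax((x @ Wq) @ weight_keys.T)            # [n, m] attention over weight sets
    out = np.empty_like(x)
    for i in range(x.shape[0]):                           # per-position weights (no sharing)
        # averaging the bank costs O(m * d * d_ff) per position, on top of the usual matmuls
        W1 = np.tensordot(scores[i], W1_bank, axes=1)     # [d, d_ff]
        W2 = np.tensordot(scores[i], W2_bank, axes=1)     # [d_ff, d]
        out[i] = np.maximum(x[i] @ W1, 0.0) @ W2
    return out
```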

A related research idea: use retrieval (transformers with memory) except instead of retrieving data, we retrieve programs or capabilities. Looking at the input, the model produces a query vector, which gets matched to various weight sets, and we use the best match (or, for gradients, some mixture of the top few best matches). Maybe this is the key to kitchen sink deep learning.
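
A minimal sketch of what retrieving capabilities instead of data might look like, with soft top-k mixing so gradients can flow; the flattened-weights representation and all names are assumptions of mine:

```python
import numpy as np

def retrieve_program(x_summary, program_keys, program_weights, k=2):
    """Match a query against a library of stored weight sets and return a
    softmax mixture of the top-k matches.

    x_summary:       [d_q] query vector summarizing the input.
    program_keys:    [m, d_q] one learned key per stored weight set.
    program_weights: [m, p] each stored weight set flattened to p numbers.
    """
    scores = program_keys @ x_summary        # [m] match scores
    top = np.argsort(scores)[-k:]            # indices of the k best matches
    w = np.exp(scores[top] - scores[top].max())
    w = w / w.sum()                          # softmax over the top-k only
    return w @ program_weights[top]          # [p] blended weight set to plug into the block
```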