reversal curse: Nonlinear Function
Created: January 10, 2024
Modified: January 10, 2024


This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.


The observation is that transformers trained on statements of the form "A is B" do not generalize to "B is A", even when the relation is in fact symmetric and its symmetry can be learned from the training set.

Note that a model that did learn the symmetry relation would be able to better predict its training set. So this is not a fundamental failing of the language modeling objective. It's a specific architectural issue with transformers.

Neel Nanda elaborates that the key-value lookup implemented by a transformer feedforward block is inherently asymmetric. If the model learns that an input query A should produce an output value B, then A and B are stored in the weights of the first and second layers of the feedforward block, respectively. There's no 'inverse index' mechanism by which a query of B can retrieve the value A.
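A minimal sketch of this asymmetry follows. It assumes a toy feedforward block $y = W_2\,\mathrm{relu}(W_1 x)$ with a single hidden unit and one-hot embeddings for A and B. The helper names `relu`, `matvec`, `transpose`, and `ffn` are hypothetical, and nothing here comes from a trained model. Row 0 of `W1` is the key A and column 0 of `W2` is the value B. A query of A therefore retrieves B, while a query of B matches no key and returns zero.

```python
# Toy feedforward key-value memory: y = W2 @ relu(W1 @ x).
# Assumptions (illustrative only): entities are one-hot vectors in d = 4
# dimensions and the block has a single hidden unit (one memory slot).

def relu(v):
    return [max(0.0, a) for a in v]

def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * a for m, a in zip(row, v)) for row in M]

def transpose(M):
    return [list(col) for col in zip(*M)]

def ffn(W1, W2, x):
    """Match x against the keys (rows of W1), then emit the values (columns of W2)."""
    return matvec(W2, relu(matvec(W1, x)))

A = [1.0, 0.0, 0.0, 0.0]
B = [0.0, 1.0, 0.0, 0.0]

W1 = [A]             # [n=1, d=4]: the key row stores A
W2 = transpose([B])  # [d=4, n=1]: the value column stores B

print(ffn(W1, W2, A))  # [0.0, 1.0, 0.0, 0.0]: querying A retrieves B
print(ffn(W1, W2, B))  # [0.0, 0.0, 0.0, 0.0]: querying B retrieves nothing
```

Nothing in the weights links B back to A. Answering "B is A" would require a second slot with key B and value A.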

Could we define an architecture that wouldn't have this limitation? Ideally we'd like the model to be able to learn that some relations are symmetric, rather than having to build this in with a training procedure (i.e., augmenting the training data to present both versions of the relation).

One question is whether transformers could in principle learn about symmetric relations. You could imagine some property of the input triggering a key-value lookup in the first layer whose output signals 'this is a symmetric relation'.

Brainstorming:

  • concept of association that works in both directions. we have an undirected graph of concepts and when we query a concept we get its neighbors, all relations it's associated with. then potentially we can store arbitrary data about the relation as part of that edge (e.g., whether it is symmetric, transitive, etc.).
  • can this just work as a vector database? concept vectors are associated with nearby concepts in vector space, but the specific direction of the difference gives you the relation between them.
  • how would you implement this concretely? the 'weights' are just a big matrix of concept vectors. doing the matrix multiply is a similarity search. then the result is some concatenation of the top results (concepts) and the differences (relations)? So if the input is $x: [d]$ and the weights are $W: [n, d]$, then $\sigma(Wx): [n]$ is an indicator for similar points, and we could return
    • $\sigma(Wx)W$: an average of similar points? this is probably not useful because it ignores the variation of these points from the input, which was the whole interesting bit. (and under plausible distributional assumptions the average of similar points might just be the input and we've accomplished nothing)
    • $W[\text{argmax}\, \sigma(Wx)]$: the single most similar point. not good because it's presumably just the input.
  • I think we also want to be able to search by relation? Like we don't just want everything we know about a given entity; we might be querying a specific relation. So we have $c$ and $r$, which are themselves both projections of the input. And we maybe then query $c + r$? But of course this is just a different projection of $x$.
  • https://thegradient.pub/dont-forget-about-associative-memories/
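The similarity-search idea in the list above can be sketched as follows, under several assumptions. $\sigma$ is taken to be a softmax over dot products, and the offset `W[i] - x` is read as the relation. The concept names and 3-dimensional vectors are made up, with the third coordinate standing for "is the capital of". The function `associative_lookup` is hypothetical. It skips stored vectors identical to the query, since the single most similar point is presumably the input itself.

```python
import math

# Sketch of a "vector database" associative memory. Assumptions:
# - rows of W are stored concept vectors; sigma is a softmax over dot products;
# - the offset W[i] - x between a retrieved concept and the query is read as
#   the relation between them. Names and vectors are invented for illustration.

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def associative_lookup(W, names, x, k=1):
    """Return up to k (name, weight, offset) triples for the concepts most similar
    to x, skipping any stored vector identical to x."""
    weights = softmax([dot(w, x) for w in W])  # sigma(Wx): [n]
    ranked = sorted(range(len(W)), key=lambda i: weights[i], reverse=True)
    results = []
    for i in ranked:
        offset = [a - b for a, b in zip(W[i], x)]  # candidate "relation" vector
        if all(abs(o) < 1e-9 for o in offset):
            continue  # this is the query itself
        results.append((names[i], weights[i], offset))
        if len(results) == k:
            break
    return results

# Hypothetical embeddings: the third coordinate encodes "is the capital of".
names = ["france", "paris", "germany", "berlin"]
W = [
    [1.0, 0.0, 0.0],
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
    [0.0, 1.0, 1.0],
]

print(associative_lookup(W, names, W[0]))  # top hit 'paris', offset [0.0, 0.0, 1.0]
print(associative_lookup(W, names, W[1]))  # top hit 'france', offset [0.0, 0.0, -1.0]
```

Retrieval here goes by similarity rather than through separate key and value matrices. The stored pair is therefore reachable from either end, and the offset simply flips sign. Whether an offset alone carries enough information to name the relation remains as open here as in the notes.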

I guess one thing we could do is key-value lookups but always maintain, and also query, the reverse index. That is, in addition to the normal query $y = W_2\sigma(W_1x)$ we also run the reverse query $y' = W_1^T\sigma(W_2^Tx)$, which would attempt, if passed a value, to recover the corresponding key. And then perhaps we average these as $\alpha y + (1 - \alpha) y'$, where $\alpha = \sigma(W_3x)$ is a data-dependent gating mechanism. If it's always 1 we recover standard transformers (with some inefficiency). But for a symmetric relation the model can learn to make it 0 in the noncanonical direction, and this is more parameter-efficient than explicitly storing both directions of the relation. Note that this won't be more computationally efficient than a model explicitly storing both directions (we're still doing double computation), but it does save memory and potentially improve learnability. OTOH it's a bit too specific a hack.
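Here is a toy version of the gated forward/reverse lookup, under stated assumptions. relu stands in for $\sigma$ in the hidden layer. The gate is a scalar $\alpha = \mathrm{sigmoid}(w_3 \cdot x)$ computed from a single hypothetical row `w3` rather than a full $W_3$. The gate weights are set by hand rather than learned. Only the A → B association is stored in `W1`/`W2`; the function name `bidirectional_ffn` is invented for this sketch.

```python
import math

# Gated forward/reverse key-value lookup:
#   y     = W2 relu(W1 x)        (forward: key -> value)
#   y'    = W1^T relu(W2^T x)    (reverse: value -> key)
#   alpha = sigmoid(w3 . x),     output = alpha * y + (1 - alpha) * y'
# Assumptions: relu hidden nonlinearity, scalar gate from one hypothetical
# row w3, and hand-set (not learned) gate weights.

def relu(v):
    return [max(0.0, a) for a in v]

def matvec(M, v):
    return [sum(m * a for m, a in zip(row, v)) for row in M]

def transpose(M):
    return [list(col) for col in zip(*M)]

def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def bidirectional_ffn(W1, W2, w3, x):
    y = matvec(W2, relu(matvec(W1, x)))                            # forward query
    y_rev = matvec(transpose(W1), relu(matvec(transpose(W2), x)))  # reverse query
    alpha = sigmoid(sum(a * b for a, b in zip(w3, x)))             # data-dependent gate
    return [alpha * f + (1.0 - alpha) * r for f, r in zip(y, y_rev)]

A = [1.0, 0.0, 0.0, 0.0]
B = [0.0, 1.0, 0.0, 0.0]
W1 = [A]             # only one direction is stored: key A ...
W2 = transpose([B])  # ... with value B
w3 = [10.0, -10.0, 0.0, 0.0]  # gate ~1 on A (use forward), ~0 on B (use reverse)

print(bidirectional_ffn(W1, W2, w3, A))  # approximately B
print(bidirectional_ffn(W1, W2, w3, B))  # approximately A
```

Querying with B returns (approximately) A even though only A → B was written into the weights, because the gate routes the query through the reverse path. Forcing the gate to 1 reduces the block to the ordinary one-directional lookup. The price is the second pass, which is the compute-vs-memory trade-off described in the paragraph above.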

One way to frame the general goal here is a 'relational' layer. If we're going to hard-code something into the network, you could do worse than a prior that the world contains both 'things' (concepts, etc.) and relations between things.

Another connection is to work on learning invariances in neural networks. If the data are invariant to some transformation, we can often build this into the network as parameter sharing. For example, convnets share parameters across translations of the input. Graph networks (and transformers) may share parameters across permutations. A symmetric relation would share parameters across the key and value layers.
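The last idea, a symmetric relation that shares parameters across the key and value layers, can be made concrete in one hypothetical way, offered only as a sketch. Store each related pair once as vectors `K[i]`, `V[i]`. Then assemble $W_1 = [K; V]$ and $W_2 = [V; K]^T$, so the same vectors serve both as key→value slots and as mirrored value→key slots. `symmetric_ffn_weights` is an invented helper, not an established layer, and the embeddings are toy one-hot vectors.

```python
# Parameter sharing for a symmetric relation: each pair (K[i], V[i]) is stored
# once, and the FFN weights are assembled from those shared vectors in both
# orders. Assumptions: relu nonlinearity, one-hot toy embeddings, and a
# hypothetical construction (symmetric_ffn_weights), not a standard layer.

def relu(v):
    return [max(0.0, a) for a in v]

def matvec(M, v):
    return [sum(m * a for m, a in zip(row, v)) for row in M]

def transpose(M):
    return [list(col) for col in zip(*M)]

def symmetric_ffn_weights(K, V):
    """K, V: lists of n d-dimensional vectors; pair i relates K[i] and V[i]."""
    W1 = K + V             # [2n, d]: keys are K[0..n-1], then V[0..n-1]
    W2 = transpose(V + K)  # [d, 2n]: matching values are V[...], then K[...]
    return W1, W2

def ffn(W1, W2, x):
    return matvec(W2, relu(matvec(W1, x)))

A = [1.0, 0.0, 0.0, 0.0]
B = [0.0, 1.0, 0.0, 0.0]
W1, W2 = symmetric_ffn_weights([A], [B])

print(ffn(W1, W2, A))  # B: forward slot
print(ffn(W1, W2, B))  # A: mirrored slot built from the same vectors
```

A gradient update to `K[i]` or `V[i]` changes both directions at once, which is the sense in which parameters are shared. The trainable parameter count matches a one-directional block, although the forward pass still touches $2n$ hidden units.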

There are also connections to editing / updating LLMs with new facts. We want LLMs to learn the symmetric and transitive consequences of those facts. E.g., if "A is B's mother" and we add "B is C's mother", then it should follow that "C is B's child" and also that "A is C's grandmother" and "C is A's grandchild". This seems nontrivial and probably intractable in full generality (nobody knows all the consequences of most of our factual knowledge, e.g., mathematical axioms), but it seems interesting nonetheless.
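To make the family example concrete, here is a toy forward-chaining sketch over (subject, relation, object) triples, read as "subject is object's relation". The rule tables `INVERSES` and `COMPOSITIONS` are hand-written assumptions for this one example, which is exactly what is unavailable in general. The point is only to show which new facts a single edit should entail.

```python
# Toy forward chaining over (subject, relation, object) triples, read as
# "subject is object's relation", e.g. ("A", "mother", "B") = "A is B's mother".
# Assumption: these rule tables are hand-written for this example only; a real
# knowledge-editing setting would not come with such a tidy, complete list.

INVERSES = {"mother": "child", "grandmother": "grandchild"}
COMPOSITIONS = {("mother", "mother"): "grandmother"}

def closure(facts):
    """Apply inverse and composition rules until no new triples appear."""
    known = set(facts)
    while True:
        new = set()
        for s, r, o in known:
            if r in INVERSES:
                new.add((o, INVERSES[r], s))  # "A is B's mother" -> "B is A's child"
            for s2, r2, o2 in known:
                if o == s2 and (r, r2) in COMPOSITIONS:
                    new.add((s, COMPOSITIONS[(r, r2)], o2))  # chain two relations
        if new <= known:
            return known
        known |= new

before = closure({("A", "mother", "B")})
after = closure(before | {("B", "mother", "C")})  # the "edit": add one new fact

for fact in sorted(after - before):
    print(fact)
```

Adding the single fact ("B", "mother", "C") yields four new triples: the fact itself, "C is B's child", "A is C's grandmother", and "C is A's grandchild". The naive loop rescans all pairs of known facts every round. That is fine for a toy, but it hints at why the general case gets expensive.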