long-term context in Transformers
Created: March 20, 2020
Modified: March 21, 2020

long-term context in Transformers

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.
  • Notes on https://www.pragmatic.ml/a-survey-of-methods-for-incorporating-long-term-context/
  • 'Standard' transformers have O(n**2) complexity in sequence length. Why?
    • I guess at each position we produce a query and we have to match that query against the keys at each of the n sequence points. So that's quadratic even before we then do it at every level.
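    • A minimal numpy sketch (mine, not from the article) of where the quadratic cost comes from: the score matrix Q @ K.T is n-by-n, so compute and memory both scale as O(n**2).
      import numpy as np

      def full_attention(Q, K, V):
          # Q, K, V: (n, d) for a single head
          scores = Q @ K.T / np.sqrt(Q.shape[-1])        # (n, n) -- the quadratic part
          weights = np.exp(scores - scores.max(-1, keepdims=True))
          weights /= weights.sum(-1, keepdims=True)      # softmax over all n keys per query
          return weights @ V                             # (n, d)

      n, d = 1024, 64
      rng = np.random.default_rng(0)
      Q, K, V = rng.normal(size=(3, n, d))
      print(full_attention(Q, K, V).shape)               # (1024, 64); the hidden (n, n) matrix is the cost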
  • How could we imagine making this faster? This is just me thinking things through a priori, before reading the article:
    • LSH: hash all keys with a locality-sensitive hash. Then for each query we get the top k keys in roughly constant time. We can't do fully soft attention, but we can sum over those top k, and that's probably most of what we need (rough sketch below).
      • tree structure: for all values at a level, build a space-partitioning tree?
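      • A rough sketch of the bucketed version of this idea (mine, using random hyperplanes; not the actual Reformer algorithm): each query attends only to keys whose hash bucket matches its own.
        import numpy as np

        def lsh_bucket_attention(Q, K, V, n_bits=4, seed=0):
            rng = np.random.default_rng(seed)
            planes = rng.normal(size=(Q.shape[-1], n_bits))      # random hyperplanes
            def bucket(X):                                       # integer bucket id per row
                return ((X @ planes) > 0).astype(int) @ (2 ** np.arange(n_bits))
            qb, kb = bucket(Q), bucket(K)
            out = np.zeros_like(Q)
            for i in range(Q.shape[0]):
                idx = np.flatnonzero(kb == qb[i])                # keys sharing the query's bucket
                if idx.size == 0:
                    idx = np.array([i])                          # fall back to attending to self
                s = Q[i] @ K[idx].T / np.sqrt(Q.shape[-1])
                w = np.exp(s - s.max()); w /= w.sum()
                out[i] = w @ V[idx]
            return out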
    • locality: each point only connects to a constant number of nearby points. by keeping this constant we maintain linear complexity in seq length, but lose the ability to handle long-term dependence.
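      • Sketch of this (mine): a sliding window of w neighbours on each side, so per-position cost is O(w) and the total is O(n*w).
        import numpy as np

        def local_attention(Q, K, V, w=3):
            n, d = Q.shape
            out = np.zeros_like(Q)
            for i in range(n):
                lo, hi = max(0, i - w), min(n, i + w + 1)    # window around position i
                s = Q[i] @ K[lo:hi].T / np.sqrt(d)
                p = np.exp(s - s.max()); p /= p.sum()
                out[i] = p @ V[lo:hi]
            return out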
    • decaying sparsity: each point connects only to a constant number of others, but instead of being local, they are chosen randomly with some distribution that decays with distance. this allows long-term deps but is weird---in a book we'll depend on random words from previous chapters but not their neighbors.
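      • Sketch of the index pattern (mine): each position attends to itself plus k positions sampled with a power-law distance decay.
        import numpy as np

        def decaying_sparse_indices(n, k=8, alpha=1.5, seed=0):
            rng = np.random.default_rng(seed)
            per_position = []
            for i in range(n):
                others = np.delete(np.arange(n), i)
                p = np.abs(others - i).astype(float) ** (-alpha)   # decays with distance
                p /= p.sum()
                chosen = rng.choice(others, size=min(k, others.size), replace=False, p=p)
                per_position.append(np.sort(np.append(chosen, i)))
            return per_position   # gather K/V at these indices and attend as usual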
    • hierarchical sparsity: in a book, maybe we depend on the past few words, but also on 'abstract' representations of previous pages and previous chapters (which themselves depend on the page representations, etc.). Those abstract reps come from higher up the transformer stack? For each 'section' (could just be a constant-length subsequence) we define a single set of attention heads whose inputs are all of its component representations (and previous sections). Then we condition the representations in the next section on these heads too. We'd have to do this with a logarithmic amount of conditioning: given a previous context of length N, we depend on something that represents the first N/2, another that represents N/2-3N/4, another that represents 3N/4-7N/8, etc. (rough sketch after the questions below)
      • Q: what does it mean to 'condition' on these heads? Do we just add them into our representation? That seems low-powered.
      • Q: Are these 'long-term context' heads different from higher-up-stack heads? Or are they just a subset?
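      • Rough sketch of the dyadic-summary part of this idea (my own reading, nothing from the article): summarize the first N/2 of the prefix, then the next N/4, the next N/8, ..., giving O(log N) extra vectors to attend to alongside the local ones. Mean pooling stands in for whatever the real summarization would be.
        import numpy as np

        def dyadic_summaries(H):
            # H: (N, d) representations of the previous context
            N = H.shape[0]
            summaries, start, block = [], 0, N // 2
            while block >= 1 and start < N:
                summaries.append(H[start:start + block].mean(axis=0))
                start += block
                block //= 2
            return np.stack(summaries) if summaries else H[:0]   # (~log2 N, d)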
    • Of these, I'm most optimistic about some combination of hashing and hierarchical sparsity?
  • Okay what do people actually do?
  • Factorized attention: https://arxiv.org/abs/1904.10509 : half of the heads attend to local context, but the other half attend to positions spread evenly throughout the sequence. Then any info can be routed through the transformer in only two levels: first to its local 'broadcast point', then to wherever else attends to that point. Total complexity O(n sqrt(n)) (why?)
    • this seems not-crazy. I think the sqrt comes from the stride: each long-range head attends to every sqrt(n)-th position, which is evenly spaced but still only sqrt(n) points per query, and the local heads cover another ~sqrt(n), so per-position cost is O(sqrt(n)) and the total is O(n sqrt(n)). Still, sqrt isn't good enough; we want n log n. (sketch below)
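    • Sketch of the attended-index pattern (my reading of the local/strided factorization, not their exact masks):
      import math

      def sparse_indices(i, n):
          stride = max(1, math.isqrt(n))
          local = range(max(0, i - stride), i + 1)        # 'local' head: last ~sqrt(n) positions
          strided = range(i % stride, i + 1, stride)      # 'strided' head: every sqrt(n)-th position
          return sorted(set(local) | set(strided))

      n = 256
      print(len(sparse_indices(n - 1, n)))                # roughly 2*sqrt(n) positions, not n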
  • Adaptive span: https://www.pragmatic.ml/a-survey-of-methods-for-incorporating-long-term-context/ Each head learns its own context length (span), with a penalty on long spans. So most heads settle on short spans, but some heads (esp. at higher levels) do learn to look back a long way. Unfortunately this nonuniformity doesn't parallelize well.
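    • Sketch of the kind of soft masking involved (the ramp is roughly the paper's masking function, written from memory; z is the learned span, R the ramp width):
      import numpy as np

      def span_mask(distance, z, R=32.0):
          # 1 for distances well inside the span z, ramping linearly to 0 beyond it
          return np.clip((R + z - distance) / R, 0.0, 1.0)

      d = np.arange(200)
      print(span_mask(d, z=50.0)[[0, 49, 81, 100]])       # [1.0, 1.0, ~0.03, 0.0]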
  • Transformer-XL: splits the sequence into segments and passes the cached activations of segment t as extra context (keys/values) to segment t+1. For tractability, it doesn't backprop across segments. Also introduces relative position encodings.
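    • Sketch of the data flow (mine, in numpy; relative position encodings omitted): activations cached from segment t are concatenated onto the keys/values for segment t+1, with no gradients flowing back into them.
      import numpy as np

      def segment_step(Q, H_cur, memory):
          # memory: cached hidden states from the previous segment, treated as constants
          # (in a real framework you'd wrap them in stop_gradient / detach)
          H_ext = H_cur if memory is None else np.concatenate([memory, H_cur], axis=0)
          K = V = H_ext                                   # pretend the K/V projections are identity
          s = Q @ K.T / np.sqrt(Q.shape[-1])
          w = np.exp(s - s.max(-1, keepdims=True))
          w /= w.sum(-1, keepdims=True)
          return w @ V, H_cur                             # output, plus new memory for segment t+2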
  • Compressive Transformer: https://arxiv.org/abs/1911.05507 Like Transformer-XL, conditions on previous activations; additionally, activations that age out of the ordinary memory are compressed into a longer-horizon compressed memory instead of being discarded.
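    • Sketch of the memory update (mine; the paper learns the compression function, mean pooling at rate c is just the simplest stand-in):
      import numpy as np

      def update_memories(memory, comp_memory, new_acts, mem_len=128, c=4):
          memory = np.concatenate([memory, new_acts], axis=0)
          if memory.shape[0] > mem_len:
              old, memory = memory[:-mem_len], memory[-mem_len:]   # evicted activations
              n_keep = (old.shape[0] // c) * c
              if n_keep:
                  pooled = old[:n_keep].reshape(-1, c, old.shape[-1]).mean(axis=1)
                  comp_memory = np.concatenate([comp_memory, pooled], axis=0)
          return memory, comp_memory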