probabilistic transformers: Nonlinear Function
Created: October 30, 2020
Modified: October 30, 2020


This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.
  • A short note on interpreting a transformer layer as performing maximum-likelihood inference in a Gaussian mixture model: https://arxiv.org/abs/2010.15583
  • The correspondence is as follows:
    • The mixture 'components' are the units at each layer.
    • Each unit has an associated distribution over queries and values, which is assumed to be factored Normal: $p(q, v \mid u) = p(q \mid u)\,p(v \mid u) = N(q; \xi_u, \alpha)\,N(v; \mu_u, \beta)$, where $\xi_u$ and $\mu_u$ are the key and expected-value vectors for unit $u$.
    • Recall that in a transformer, each unit produces a key, a value, and a query for its counterpart at the next layer.
    • Here, each unit has a known key $\xi_u$, input value $\mu_u$, and sampled query $q_u$. The output value $v_u$ is unknown.
    • This model assumes that the query is randomly drawn from a Normal around the key at each layer---is that reasonable??
    • Given a query $q$ where we don't know what unit it came from, we want the most probable value for that query. That means we:
      • compute 'weights' $w_u$ for each unit according to how likely they are to have generated $q$.
      • take the weighted expectation $v = \sum_u w_u \mu_u$.
    • This is pretty much exactly the transformer update equation, with some differences in how the weights are computed: the Gaussian form uses the squared distance between query and key, while the actual transformer uses dot products. The two differ only by the norms of the query and key, which can be encoded as per-unit 'priors' to make the forms equivalent (see the sketch after this list).
  • I'm not super impressed by this. The model and the query feel pretty artificial.
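
To make the weighting step concrete, here is a minimal numpy sketch (my own illustration, not code from the paper); the sizes and the names `xi`, `mu`, and `alpha` are hypothetical stand-ins for the symbols above. It computes the mixture responsibilities from squared distances, takes the weighted expectation of the values, and checks that adding a $-\|\xi_u\|^2 / (2\alpha)$ 'prior' term to the dot-product logits reproduces exactly the same weights.

```python
import numpy as np

# Toy sketch of the mixture-model view of one attention step.
# Dimensions, variable names (xi, mu, alpha), and data are hypothetical.
rng = np.random.default_rng(0)
n_units, d = 8, 16
xi = rng.normal(size=(n_units, d))   # keys: Gaussian means over queries
mu = rng.normal(size=(n_units, d))   # values: expected value per unit
q = rng.normal(size=d)               # a query of unknown origin
alpha = 1.0                          # shared isotropic query variance

# Gaussian-mixture weights: w_u proportional to exp(-||q - xi_u||^2 / (2*alpha))
sq_dist = np.sum((q - xi) ** 2, axis=1)
w_gauss = np.exp(-sq_dist / (2 * alpha))
w_gauss /= w_gauss.sum()
v_gauss = w_gauss @ mu               # weighted expectation of the values

# Dot-product weights: expanding ||q - xi_u||^2 = ||q||^2 - 2 q.xi_u + ||xi_u||^2,
# the ||q||^2 term is constant across units and cancels under normalization,
# while -||xi_u||^2 / (2*alpha) acts as a per-unit log-prior on the logits.
logits = q @ xi.T / alpha - np.sum(xi ** 2, axis=1) / (2 * alpha)
w_dot = np.exp(logits - logits.max())
w_dot /= w_dot.sum()
v_dot = w_dot @ mu

assert np.allclose(w_gauss, w_dot)   # same weights once the 'prior' is included
assert np.allclose(v_gauss, v_dot)   # hence the same output value
```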