linear attention: Nonlinear Function
Created: December 07, 2023
Modified: December 07, 2023

linear attention

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.



The usual transformer attention mechanism is written as

$$V' = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$

or equivalently

$$V'_i = \frac{\sum_j \exp\left( Q_i^T K_j \right) V_j}{\sum_j \exp\left( Q_i^T K_j \right)} = \sum_j A_{i,j} V_j$$

where

$$A_{i,j} = \frac{\exp\left( Q_i^T K_j \right)}{\sum_k \exp\left( Q_i^T K_k \right)}$$

is the matrix of normalized attention scores.
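
For reference, a minimal NumPy sketch of this computation (the function and variable names are my own illustration, not from any particular library):

```python
import numpy as np

def softmax_attention(Q, K, V):
    """Standard attention. Q, K: [seq_len, d]; V: [seq_len, value_size]."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                          # [seq_len, seq_len]
    scores = scores - scores.max(axis=-1, keepdims=True)   # for numerical stability
    A = np.exp(scores)
    A = A / A.sum(axis=-1, keepdims=True)                  # normalized scores A_{i,j}
    return A @ V                                           # V'_i = sum_j A_{i,j} V_j
```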

Mechanically, nothing is stopping us from replacing these scores with any positive similarity function $\text{Sim}(Q_i, K_j)$.

Linear attention is the case where we choose $\text{Sim}$ to be linear in (features $\phi$ of) $Q$ and $K$:

$$A_{i,j} = \frac{\phi(Q_i)^T \phi(K_j)}{\sum_k \phi(Q_i)^T \phi(K_k)}$$
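
Computed directly, this is the same quadratic-cost computation with the exponentiated scores swapped out for featurized dot products. A sketch, assuming one common choice of positive feature map, $\phi(x) = \text{elu}(x) + 1$ (the note leaves $\phi$ unspecified):

```python
import numpy as np

def phi(x):
    # Assumed positive feature map: elu(x) + 1 (one common choice; not fixed by the note).
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention_quadratic(Q, K, V):
    """Linear attention computed the 'slow' way, via the full score matrix."""
    scores = phi(Q) @ phi(K).T                       # [seq_len, seq_len], entries all positive
    A = scores / scores.sum(axis=-1, keepdims=True)  # A_{i,j} = phi(Q_i)^T phi(K_j) / sum_k phi(Q_i)^T phi(K_k)
    return A @ V
```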

The advantage of doing this is that $\phi(Q_i)$ factors out of both the numerator and denominator sums,

$$V'_i = \frac{\phi(Q_i)^T \sum_j \phi(K_j) V_j^T}{\phi(Q_i)^T \sum_k \phi(K_k)},$$

so attention outputs can be computed recurrently, accumulating the key-value outer-product matrix $\sum_j \phi(K_j) V_j^T$ (and the normalizer $\sum_j \phi(K_j)$) across steps. The disadvantage is that we give up some expressivity relative to traditional attention.
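
A minimal sketch of the causal/autoregressive case (each position attends to itself and earlier positions), again assuming the $\text{elu}(x) + 1$ feature map:

```python
import numpy as np

def phi(x):
    # Assumed positive feature map (elu(x) + 1), as above.
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention_recurrent(Q, K, V):
    """Causal linear attention as a recurrence over positions.

    Q, K: [seq_len, key_size]; V: [seq_len, value_size].
    Carries two pieces of state across steps:
      S = sum_{j <= i} phi(K_j) V_j^T   # [key_size, value_size]
      z = sum_{j <= i} phi(K_j)         # [key_size]
    so each step costs O(key_size * value_size), independent of seq_len.
    """
    seq_len, key_size = Q.shape
    value_size = V.shape[-1]
    S = np.zeros((key_size, value_size))
    z = np.zeros(key_size)
    out = np.zeros((seq_len, value_size))
    for i in range(seq_len):
        k = phi(K[i])
        S += np.outer(k, V[i])              # accumulate key-value outer products
        z += k                              # accumulate the normalizer
        q = phi(Q[i])
        out[i] = (q @ S) / (q @ z + 1e-6)   # V'_i = phi(Q_i)^T S / phi(Q_i)^T z
    return out
```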

How should we understand the loss of expressivity?

Can we do multi-query attention with linear attention? At the final position, for each head we end up with an accumulated $\sum_j \phi(K_j) V_j^T$ of shape [key_size, value_size]; then with multiple queries we consider a query matrix of shape [num_queries, key_size].
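
One way this could look, as a hedged sketch with hypothetical names: at decode time, several query heads read the same accumulated [key_size, value_size] state, as in multi-query attention.

```python
import numpy as np

def phi(x):
    # Assumed positive feature map (elu(x) + 1), as above.
    return np.where(x > 0, x + 1.0, np.exp(x))

def multi_query_linear_attention_step(S, z, queries):
    """One step of multi-query linear attention (hypothetical sketch).

    S: [key_size, value_size], shared accumulated sum of phi(K_j) V_j^T.
    z: [key_size], shared accumulated sum of phi(K_j).
    queries: [num_queries, key_size], one row per query head.
    Returns per-query outputs of shape [num_queries, value_size].
    """
    q = phi(queries)                              # [num_queries, key_size]
    return (q @ S) / (q @ z + 1e-6)[:, None]      # each query reads the same shared state
```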