Created: February 12, 2023 Modified: February 12, 2023
relu selection
This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.
The selection operation $y = \mathrm{where}(c, a, b)$ returns

$$y = \begin{cases} a & \text{if } c = 1 \\ b & \text{if } c = 0. \end{cases}$$
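This is the elementwise select familiar from numpy; a quick illustration (the values are arbitrary):

```python
import numpy as np

# where(c, a, b): take a when the condition holds, else b.
c = 1
a = np.array([1.0, 2.0, 3.0])
b = np.array([-1.0, -2.0, -3.0])

y = np.where(c == 1, a, b)
print(y)  # [1. 2. 3.]
```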
How can a transformer layer implement this operation? One approach is to use the feedforward layer, parameterized as

$$y = W_2\,\mathrm{ReLU}(W_1 x + b_1).$$
We'll demonstrate a construction that works for a vector input $x = [c, a, b]$, where $c \in \{0, 1\}$ is a binary condition and $a, b$ are vectors to be selected. Here we assume that these are all stacked in a single vector at the current position. To implement the selection operation we choose
$$W_1 = \begin{bmatrix} C & I & 0 \\ C & -I & 0 \\ -C & 0 & I \\ -C & 0 & -I \end{bmatrix}; \qquad b_1 = \begin{bmatrix} -C \\ -C \\ 0 \\ 0 \end{bmatrix}$$

for some large positive constant $C$, so that

$$W_1 x + b_1 = \begin{bmatrix} C(c-1) + a \\ C(c-1) - a \\ -Cc + b \\ -Cc - b \end{bmatrix}$$
evaluates approximately to $[a, -a, -\infty, -\infty]^T$ when $c = 1$ and to $[-\infty, -\infty, b, -b]^T$ when $c = 0$. Passing this through a ReLU removes the negative components, so that we end up with $[a^+, (-a)^+, 0, 0]^T$ or $[0, 0, b^+, (-b)^+]^T$ respectively. Finally we choose
$$W_2 = \begin{bmatrix} I & -I & I & -I \end{bmatrix},$$
so that
$$W_2\,\mathrm{ReLU}(W_1 x + b_1) = \begin{cases} a^+ - (-a)^+ + 0 - 0 = a & \text{if } c = 1 \\ 0 - 0 + b^+ - (-b)^+ = b & \text{if } c = 0, \end{cases}$$
exactly as desired.
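As a sanity check, here is a small numpy sketch of the whole construction (the dimension `d` and the constant `C` are arbitrary choices for illustration):

```python
import numpy as np

# x = [c, a, b] with c a scalar in {0, 1} and a, b vectors of dimension d.
d = 3
C = 1e6  # large positive constant
I = np.eye(d)
Z = np.zeros((d, d))
col = np.ones((d, 1))

# W1 stacks the four blocks [C, I, 0; C, -I, 0; -C, 0, I; -C, 0, -I].
W1 = np.block([
    [ C * col,  I,  Z],
    [ C * col, -I,  Z],
    [-C * col,  Z,  I],
    [-C * col,  Z, -I],
])
b1 = np.concatenate([-C * np.ones(2 * d), np.zeros(2 * d)])
W2 = np.hstack([I, -I, I, -I])

def select(c, a, b):
    x = np.concatenate([[c], a, b])
    return W2 @ np.maximum(W1 @ x + b1, 0.0)

a = np.array([1.0, -2.0, 3.0])
b = np.array([4.0, 5.0, -6.0])
print(select(1, a, b))  # recovers a
print(select(0, a, b))  # recovers b
```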
Most transformers actually include a residual connection,

$$y = x + W_2\,\mathrm{ReLU}(W_1 x + b_1),$$
which we've omitted for simplicity. In this setting we might want the output y=[c,a,0] if c=1 and y=[c,b,0] if c=0. This can be obtained by selecting between 0 and b−a in the second position using a straightforward modification of the above technique.
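One way to realize that modification is sketched below; the block layout (six ReLU blocks: four selecting between $0$ and $b - a$ for the $a$-slot, two cancelling $b$ in its own slot) is an assumption about the details, not the only possibility:

```python
import numpy as np

d = 3
C = 1e6  # large positive constant
I = np.eye(d); Z = np.zeros((d, d))
col = np.ones((d, 1)); zcol = np.zeros((d, 1))

# Pre-activation blocks, applied to x = [c, a, b]:
W1 = np.block([
    [ C * col,  Z,  Z],   # C(c - 1)          (with bias -C)
    [ C * col,  Z,  Z],   # C(c - 1)          (with bias -C)
    [-C * col, -I,  I],   # -Cc + (b - a)
    [-C * col,  I, -I],   # -Cc - (b - a)
    [ zcol,     Z,  I],   # b
    [ zcol,     Z, -I],   # -b
])
b1 = np.concatenate([-C * np.ones(2 * d), np.zeros(4 * d)])

# Output layout matches x = [c, a-slot, b-slot].
W2 = np.block([
    [np.zeros((1, 6 * d))],             # c slot: left unchanged
    [np.hstack([I, -I, I, -I, Z, Z])],  # a slot: adds select(c, 0, b - a)
    [np.hstack([Z, Z, Z, Z, -I, I])],   # b slot: adds -b, zeroing it out
])

def layer(c, a, b):
    x = np.concatenate([[c], a, b])
    return x + W2 @ np.maximum(W1 @ x + b1, 0.0)

a = np.array([1.0, -2.0, 3.0])
b = np.array([4.0, 5.0, -6.0])
print(layer(1, a, b))  # ~ [c, a, 0]
print(layer(0, a, b))  # ~ [c, b, 0]
```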
Extensions
We can allow the condition to be an arbitrary positive or negative integer by casting it into 0/1 via a ReLU inequality layer.
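For integer $c$, one such ReLU inequality layer is a difference of two ramps, $\mathrm{ReLU}(c) - \mathrm{ReLU}(c - 1)$, which evaluates to $1$ for $c \geq 1$ and $0$ for $c \leq 0$; a minimal sketch:

```python
def to_binary(c):
    # For integer c: ReLU(c) - ReLU(c - 1) = 1 when c >= 1, and 0 when c <= 0.
    return max(c, 0) - max(c - 1, 0)

print([to_binary(c) for c in [-3, 0, 1, 7]])  # [0, 0, 1, 1]
```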