relu selection: Nonlinear Function
Created: February 12, 2023
Modified: February 12, 2023

relu selection

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

The selection operation y = where(c, a, b) returns

$$y = \left\{\begin{array}{ll}a&\text{ if } c=1\\b&\text{ if } c=0\end{array}\right..$$

How can a transformer layer implement this operation? One approach is to use the feedforward layer, parameterized as

$$y = \mathbf{W}_2 \text{ ReLU}\left(\mathbf{W}_1 x + \mathbf{b}_1\right).$$
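A minimal NumPy sketch of this feedforward layer, for experimenting with the construction numerically. It mirrors the formula exactly, so there is no output bias. The helper names `relu` and `ffn` are my own, hypothetical choices.

```python
import numpy as np

def relu(z):
    # Elementwise ReLU: max(z, 0)
    return np.maximum(z, 0.0)

def ffn(x, W1, b1, W2):
    # Feedforward layer without residual: W2 @ ReLU(W1 @ x + b1)
    return W2 @ relu(W1 @ x + b1)

# Tiny usage: one hidden unit computing ReLU(x0 - x1)
W1 = np.array([[1.0, -1.0]])
b1 = np.array([0.0])
W2 = np.array([[1.0]])
out = ffn(np.array([3.0, 1.0]), W1, b1, W2)
print(out)  # [2.]
```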

We'll demonstrate a construction that works for a vector input $x = [c, a, b]$, where $c \in \{0, 1\}$ is the condition and $a, b$ are the vectors (of equal dimension) to be selected between. Here we assume that these are all stacked in a single vector at the current position. To implement the selection operation we choose

$$\mathbf{W}_1 = \left[\begin{array}{ccc}C & I & 0\\C & -I & 0\\-C & 0 & I\\-C & 0& -I\end{array}\right];\qquad \mathbf{b}_1 = \left[\begin{array}{c}-C\\-C\\0\\0\end{array}\right]$$

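for some large positive constant $C$ (each $C$ entry in $\mathbf{W}_1$ and $\mathbf{b}_1$ stands for a block of copies of $C$ matching the dimension of $a$ and $b$), so that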

$$\mathbf{W}_1 x + \mathbf{b}_1 = \left[\begin{array}{c}C(c-1) + a\\C(c-1) - a\\-Cc + b\\-Cc - b\end{array}\right]$$
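To check this algebra numerically, here is a sketch that builds $\mathbf{W}_1$ and $\mathbf{b}_1$ with NumPy and compares the pre-activation against the four blocks above. It assumes the input is laid out as `[c, a_1..a_d, b_1..b_d]` and that each $C$ block is a column of $d$ copies of $C$; the function name `selection_first_layer` is hypothetical.

```python
import numpy as np

def selection_first_layer(d, C):
    """Return (W1, b1) for the selection construction.

    Assumes x = [c, a_1..a_d, b_1..b_d], so W1 has shape (4d, 2d + 1).
    Each "C" block in the matrix is a column of d copies of C.
    """
    I = np.eye(d)
    Z = np.zeros((d, d))
    col = np.full((d, 1), float(C))
    W1 = np.block([
        [ col,  I,  Z],
        [ col, -I,  Z],
        [-col,  Z,  I],
        [-col,  Z, -I],
    ])
    b1 = np.concatenate([np.full(2 * d, -float(C)), np.zeros(2 * d)])
    return W1, b1

a = np.array([0.5, -2.0])
b = np.array([3.0, 1.0])
C = 100.0
W1, b1 = selection_first_layer(len(a), C)

checks = []
for c in (0, 1):
    x = np.concatenate([[c], a, b])
    pre = W1 @ x + b1
    # Blocks predicted by the derivation: C(c-1) + a, C(c-1) - a, -Cc + b, -Cc - b
    expected = np.concatenate([C * (c - 1) + a, C * (c - 1) - a,
                               -C * c + b, -C * c - b])
    checks.append(np.allclose(pre, expected))
print(checks)  # [True, True]
```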

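evaluates approximately to $[a, -a, -\infty, -\infty]^T$ when $c=1$ and to $[-\infty, -\infty, b, -b]^T$ when $c=0$. Passing this through a ReLU removes the negative components, so that we end up with $[a^+, (-a)^+, 0, 0]^T$ or $[0, 0, b^+, (-b)^+]^T$ respectively, where $z^+ = \max(z, 0)$ elementwise. The "$-\infty$" entries need not be truly infinite: it suffices that $C \geq \max_i \max(|a_i|, |b_i|)$, since then every entry of $-C \pm b$ (when $c=1$) or $-C \pm a$ (when $c=0$) is nonpositive and the ReLU zeroes it out exactly. Finally we choose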

$$\mathbf{W}_2 = [I, -I, I, -I],$$

so that

$$\mathbf{W}_2 \text{ ReLU}\left(\mathbf{W}_1 x + \mathbf{b}_1\right) = \left\{\begin{array}{rl}a^+ - (-a)^+ + 0 - 0= a&\text{ if } c=1\\0 - 0 + b^+ - (-b)^+ = b&\text{ if } c=0\end{array}\right.$$

exactly as desired.
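Putting the pieces together, here is a NumPy sketch of the whole selection layer, using the input layout $x = [c, a, b]$ from the construction. The helper names (`relu`, `selection_ffn`, `select`) are hypothetical. The last call deliberately picks $C$ smaller than the largest entry of $b$, to show how the selection breaks when $C$ is not large enough.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def selection_ffn(d, C):
    """Return (W1, b1, W2) for where(c, a, b) with x = [c, a, b], a and b of length d."""
    I, Z = np.eye(d), np.zeros((d, d))
    col = np.full((d, 1), float(C))          # the "C" blocks
    W1 = np.block([[col, I, Z], [col, -I, Z], [-col, Z, I], [-col, Z, -I]])
    b1 = np.concatenate([np.full(2 * d, -float(C)), np.zeros(2 * d)])
    W2 = np.hstack([I, -I, I, -I])           # a+ - (-a)+ + b+ - (-b)+
    return W1, b1, W2

def select(c, a, b, C):
    W1, b1, W2 = selection_ffn(len(a), C)
    x = np.concatenate([[c], a, b])
    return W2 @ relu(W1 @ x + b1)

a = np.array([0.5, -2.0, 7.0])
b = np.array([3.0, 1.0, -4.0])

print(select(1, a, b, C=10.0))  # equals a
print(select(0, a, b, C=10.0))  # equals b
# C = 1 is smaller than max |b_i| = 4, so b leaks into the c = 1 output:
print(select(1, a, b, C=1.0))   # [2.5, -2.0, 4.0], not a
```

With $C = 10 \geq 7 = \max_i \max(|a_i|, |b_i|)$ the output is exact, not just approximate.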

Most transformers actually include a residual connection

$$y = x + \mathbf{W}_2 \text{ ReLU}\left(\mathbf{W}_1 x + \mathbf{b}_1\right),$$

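which we've omitted for simplicity. In this setting we might want the output $y = [c, a, 0]$ if $c=1$ and $y = [c, b, 0]$ if $c=0$. Since the residual already contributes $[c, a, b]$, the feedforward layer must add $0$ in the first position, select between $0$ (when $c=1$) and $b - a$ (when $c=0$) in the second position, and add $-b$ in the third position. The selection is a straightforward modification of the above technique, and the third position needs no condition at all, since $-b = (-b)^+ - b^+$.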
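One way to realize the residual version, sketched in NumPy. The layout of the hidden units is my own assumption, not necessarily the author's: two blocks compute $-Cc \pm (b - a)$ (so they only fire when $c = 0$), and two blocks compute $\pm b$ unconditionally. This needs $C \geq \max_i |b_i - a_i|$. All function names are hypothetical.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def residual_selection_weights(d, C):
    """Hypothetical weights so that x + W2 ReLU(W1 x + b1) gives [c, a, 0] or [c, b, 0].

    x = [c, a, b] with a, b of length d; requires C >= max |b_i - a_i|.
    """
    I, Z = np.eye(d), np.zeros((d, d))
    col, zcol = np.full((d, 1), float(C)), np.zeros((d, 1))
    W1 = np.block([
        [-col, -I,  I],   # -Cc + (b - a)
        [-col,  I, -I],   # -Cc - (b - a)
        [zcol,  Z,  I],   #  b
        [zcol,  Z, -I],   # -b
    ])
    b1 = np.zeros(4 * d)
    zrow = np.zeros((1, d))
    W2 = np.block([
        [zrow, zrow, zrow, zrow],  # condition slot: add nothing
        [I,    -I,   Z,    Z],     # a slot: add (b - a) only when c = 0
        [Z,     Z,  -I,    I],     # b slot: add (-b)+ - b+ = -b
    ])
    return W1, b1, W2

def residual_select(c, a, b, C):
    W1, b1, W2 = residual_selection_weights(len(a), C)
    x = np.concatenate([[c], a, b])
    return x + W2 @ relu(W1 @ x + b1)

a = np.array([0.5, -2.0])
b = np.array([3.0, 1.0])
y1 = residual_select(1, a, b, C=10.0)  # -> [1, 0.5, -2, 0, 0]
y0 = residual_select(0, a, b, C=10.0)  # -> [0, 3, 1, 0, 0]
print(y1, y0)
```

The condition slot passes through the residual untouched, so later layers can still read $c$.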

Extensions

We can allow the condition to be an arbitrary positive or negative integer by first casting it to 0/1 with a relu inequality layer.
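A minimal sketch of one such cast, assuming "true" means $c > 0$ (the note's relu inequality construction may differ in detail). For integer $c$, $\text{ReLU}(c) - \text{ReLU}(c - 1)$ is $0$ when $c \leq 0$ and $1$ when $c \geq 1$. Because it is itself a small ReLU network, it could plausibly run in an earlier feedforward layer that writes the 0/1 flag into the condition slot. The name `positive_indicator` is hypothetical.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def positive_indicator(c):
    """For integer c, returns 1.0 if c >= 1 and 0.0 if c <= 0.

    Uses ReLU(c) - ReLU(c - 1); for non-integer c in (0, 1) it ramps linearly.
    """
    return relu(c) - relu(c - 1)

cs = np.arange(-3, 4)            # integers -3..3
flags = positive_indicator(cs)
print(flags)                     # 0 for -3..0, 1 for 1..3
```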