sparse mixture of experts: Nonlinear Function
Created: February 13, 2023
Modified: February 13, 2023


This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

References:

- Shazeer et al. (2017), "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer."
- Fedus, Zoph, and Shazeer (2022), "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity."
- Du et al. (2022), "GLaM: Efficient Scaling of Language Models with Mixture-of-Experts."
- Zhou et al. (2022), "Mixture-of-Experts with Expert Choice Routing."

A sparsely-gated mixture-of-experts layer has 'expert networks' $E_1, \ldots, E_n$ and a 'gating network' $G$. The output of the layer is

$$y = \sum_i G(x)_i E_i(x).$$
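
As a concrete illustration of the formula, here is a minimal (and deliberately dense) PyTorch sketch; the module and dimension names are invented for the example, and it evaluates every expert on every input, whereas the point of a sparse layer is to skip experts whose gate value is zero.

```python
import torch
import torch.nn as nn

class DenseMoELayer(nn.Module):
    """Computes y = sum_i G(x)_i * E_i(x) with a softmax gate over all experts.
    A sparse MoE layer would only evaluate the experts with nonzero gate values."""

    def __init__(self, n_experts: int, d_model: int, d_hidden: int):
        super().__init__()
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.ReLU(), nn.Linear(d_hidden, d_model))
            for _ in range(n_experts)
        )
        self.gate = nn.Linear(d_model, n_experts)  # produces the gating logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gates = torch.softmax(self.gate(x), dim=-1)              # (batch, n_experts)
        outs = torch.stack([E(x) for E in self.experts], dim=1)  # (batch, n_experts, d_model)
        return torch.einsum('be,bed->bd', gates, outs)           # gate-weighted sum over experts

# Example: DenseMoELayer(n_experts=4, d_model=16, d_hidden=64)(torch.randn(8, 16)) -> (8, 16)
```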

The gating network in Shazeer et al. (2017) is a softmax applied to noisy logits that have been explicitly sparsified by a 'KeepTopK' operation. This is trained by simple backpropagation. Although the sparsification zeroes out gradients to the non-selected experts, this is allegedly not a big deal in practice, since gradients still pass through to the top-K experts. Because only a few experts fire for any particular input, an additional 'load-balancing' loss term is added to encourage all experts to have equal gating values on average (in effect this encourages exploration of less-commonly-used experts); this avoids a 'rich-get-richer' effect in which the most commonly selected experts receive the most training and so become even more likely to be selected in the future, while other experts go unused.
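
A rough sketch of this gating scheme, continuing the PyTorch convention above. In the paper the noise scale is itself learned and the auxiliary losses are more involved; the function names and the simple coefficient-of-variation penalty below are my own simplification.

```python
import torch
import torch.nn.functional as F

def noisy_top_k_gating(logits: torch.Tensor, k: int, noise_std: float = 1.0) -> torch.Tensor:
    """'KeepTopK' gating: perturb the logits with noise, keep the top-k entries per
    row, and softmax over those. Non-selected experts get a gate value of exactly 0,
    so they are neither evaluated nor updated for this input."""
    noisy = logits + noise_std * torch.randn_like(logits)
    topk_vals, topk_idx = noisy.topk(k, dim=-1)
    masked = torch.full_like(noisy, float('-inf'))   # -inf logits become softmax weight 0
    masked.scatter_(-1, topk_idx, topk_vals)
    return F.softmax(masked, dim=-1)                 # (batch, n_experts), k nonzeros per row

def load_balancing_loss(gates: torch.Tensor) -> torch.Tensor:
    """A simple balancing penalty: squared coefficient of variation of each expert's
    total gate mass over the batch; it is minimized when experts get equal mass."""
    importance = gates.sum(dim=0)                    # (n_experts,) total gate mass per expert
    mean = importance.mean()
    var = ((importance - mean) ** 2).mean()
    return var / (mean ** 2 + 1e-9)
```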

Parallelization: it is tricky to parallelize a mixture-of-experts model since different batch elements will need different experts. The approach of Shazeer et al. (2017) is essentially to reshuffle the inputs after the gating network, sending each batch element to the device(s) that 'host' the appropriate expert(s). Each expert lives on a specific device, and all inputs that require that expert are routed to that device so that the expert can be applied to a batch of inputs.
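
A single-device sketch of this dispatch/combine pattern, with $k=1$ for simplicity and a hypothetical list `experts` of expert modules; in the actual multi-device setting, each per-expert group of inputs would be shipped to the device hosting that expert rather than processed locally.

```python
import torch

def dispatch_and_combine(x, expert_idx, gate_vals, experts):
    """x: (batch, d_model); expert_idx, gate_vals: (batch,) from a k=1 gating network.
    Groups the inputs assigned to each expert, applies the expert to its group as a
    single batch, and scatters the gate-weighted outputs back to their original rows."""
    y = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        rows = (expert_idx == e).nonzero(as_tuple=True)[0]  # batch elements routed to expert e
        if rows.numel() == 0:
            continue                                        # this expert received no inputs
        y[rows] = gate_vals[rows, None] * expert(x[rows])   # scale expert output by its gate value
    return y
```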

The switch transformer (Fedus, Zoph, and Shazeer, 2022) adapts this approach to transformer language models, where the 'experts' are the feedforward layers of the transformer. While previous work conjectured that $K \ge 2$ active experts are necessary to get a useful training signal for the gating network, the switch transformer finds that $K=1$ actually works well.
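
In code, this is just the $K=1$ special case of the gating above: each token is sent to its single argmax expert, and that expert's output is scaled by the router's softmax probability so the router still receives a gradient. A sketch, reusing the hypothetical dispatch helper from the previous section:

```python
import torch

def switch_route(logits: torch.Tensor):
    """K=1 routing: pick the single highest-probability expert for each token and keep
    that probability as the factor used to scale the chosen expert's output."""
    probs = torch.softmax(logits, dim=-1)        # (tokens, n_experts)
    gate_vals, expert_idx = probs.max(dim=-1)    # chosen expert and its router probability
    return expert_idx, gate_vals                 # e.g. feed into dispatch_and_combine above
```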

The Generalist Language Model (GLaM) uses a similar architecture to the switch transformer, but with $K=2$ experts activated per token (and trains a decoder-only language model rather than an encoder-decoder sequence-to-sequence model). Every other feedforward layer is a mixture of experts. The model is scaled to 64 experts per MoE layer, with a 'base dense size' of 64B parameters; since half of the feedforward layers are mixture-of-two-experts layers, roughly 96B parameters are activated per token. There are a bunch of details about how this is distributed and balanced, which seems to be a lot of the effort, since these models are more complex to scale than dense models. The claim is that this achieves better performance than GPT-3 with 1/3 the training energy consumption.

Expert choice routing (Zhou et al., 2022) tries to do a better job of keeping all experts at full utilization. Instead of choosing the top $k$ experts for each token (which can result in some experts getting more tokens than others), it chooses the top $k$ tokens for each expert. This guarantees that the experts are balanced by construction. It's weird, however, because it means that the number (and identity) of experts used by a given token is not fixed; it depends on the other tokens in the batch.
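
A sketch of the assignment step, with a hypothetical per-expert `capacity` (how many tokens each expert takes per batch):

```python
import torch

def expert_choice_assignment(logits: torch.Tensor, capacity: int):
    """Each *expert* picks its top-`capacity` tokens (a top-k down each column of the
    token-expert affinity matrix), rather than each token picking its top-k experts.
    Every expert therefore processes exactly `capacity` tokens, while a given token may
    be picked by zero, one, or several experts, depending on the rest of the batch."""
    affinity = torch.softmax(logits, dim=-1)              # (n_tokens, n_experts)
    weights, token_idx = affinity.topk(capacity, dim=0)   # top tokens per expert (column-wise)
    # token_idx[c, e]: index of the c-th token chosen by expert e
    # weights[c, e]:   gating weight applied when combining expert e's output for that token
    return token_idx, weights
```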