Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
- Thesis: The paper introduces Switch Transformers, a simplified and improved variant of Mixture of Experts (MoE) that routes each token to only a single expert (k=1) rather than multiple experts as in traditional MoE approaches. This architectural simplification enables efficient scaling to trillion-parameter models while maintaining computational efficiency, as the model can increase parameter count independently of FLOPs per token. The authors demonstrate that this sparse architecture not only simplifies implementation and reduces communication costs but also achieves superior performance compared to both dense models and traditional MoE approaches, establishing sparse models as a practical and effective path for scaling neural language models.
- Methods:
- Single Expert Routing: Route each token to only the top-1 expert based on router probability (instead of top-k)
- Selective Precision: Use float32 for router computations, bfloat16 elsewhere for stability
- Initialization: Reduce initialization scale by 10x (from 1.0 to 0.1) for training stability
- Expert Dropout: Apply higher dropout (0.4) in expert layers during fine-tuning
- Capacity Factor: Fixed expert capacity with overflow handling via residual connections
- Load Balancing: Auxiliary loss to encourage uniform expert utilization
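
A small sketch of the initialization change listed above. It assumes the scale `s` enters as `std = sqrt(s / fan_in)` on a normal distribution truncated at two standard deviations, which is how the paper's description is recalled here; the function names and the seed are made up for illustration.

```python
import math
import random

def init_std(fan_in, scale=0.1):
    # Assumed rule: std = sqrt(scale / fan_in); scale=1.0 is the old default, 0.1 the reduced one.
    return math.sqrt(scale / fan_in)

def truncated_normal(std, rng, bound=2.0):
    # Resample until the draw falls within `bound` standard deviations of the mean.
    while True:
        x = rng.gauss(0.0, std)
        if abs(x) <= bound * std:
            return x

def init_weights(fan_in, fan_out, scale=0.1, seed=0):
    # Returns a fan_in x fan_out weight matrix as a list of rows.
    rng = random.Random(seed)
    std = init_std(fan_in, scale)
    return [[truncated_normal(std, rng) for _ in range(fan_out)] for _ in range(fan_in)]

weights = init_weights(fan_in=8, fan_out=4)
```

Under this rule, cutting the scale by 10x shrinks the standard deviation by a factor of sqrt(10), not 10.
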
- Contribution:
- Achieved 7x pre-training speedup over T5-Base with same computational budget
- Scaled models to 1.6 trillion parameters (Switch-C)
- Demonstrated 4x speedup over T5-XXL for trillion parameter models
- Showed consistent improvements across all 101 languages in multilingual training
- Successfully distilled large sparse models into dense ones preserving ~30% of quality gains
- Provided comprehensive scaling analysis from 2 to 2048 experts
- Importance:
- Makes sparse expert models practical and accessible (works even with 2 experts)
- Proves sparse models superior to dense on speed-accuracy Pareto frontier
- Enables training of extremely large models without proportional compute increase
- Simplifies MoE implementation while improving stability and performance
- Opens path for efficient scaling beyond traditional dense model limitations
- Takeaways:
- Simpler is better: k=1 routing outperforms complex k>1 routing schemes
- Sparse models consistently beat dense models per step and wall-clock time
- Expert models scale well: doubling experts improves quality with minimal overhead
- Small scale benefits: Even 2-expert models show improvements over dense baselines
- Distillation viable: Can compress 99% while retaining 28% of performance gains
- Improvements:
- Stability: Largest models (Switch-XXL) still exhibit training instabilities
- Fine-tuning Gap: Pre-training gains don’t fully translate to downstream tasks (especially reasoning)
- Hardware Optimization: Better co-design of hardware and sparse architectures needed
- Heterogeneous Experts: Current design uses identical experts; varied expert sizes could help
- Beyond FFN: Preliminary success with attention layers needs more exploration
- Questions:
- How is the routing done in code: how do we know which device hosts a token’s assigned expert, and how do we send the token to an expert that resides on a different device?
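
One possible answer to the question above, as a single-process simulation in plain Python. It assumes one expert per device and models the cross-device exchange as an all-to-all step (each device sends one bucket to every other device). `all_to_all`, `dispatch_and_combine`, `route`, and `expert_fn` are hypothetical names, and this is not the paper's implementation; real systems use fixed-size buffers and a collective communication op.

```python
def all_to_all(send):
    """Simulated exchange: send[src][dst] becomes recv[dst][src]."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]

def dispatch_and_combine(tokens_per_device, route, expert_fn):
    """tokens_per_device[d]: tokens held by device d.
    route(token): index of the device hosting the chosen expert.
    expert_fn(device, token): output of the expert living on `device`."""
    n = len(tokens_per_device)
    # 1. Each device buckets its local tokens by destination device,
    #    remembering each token's original position.
    send = [[[] for _ in range(n)] for _ in range(n)]
    for src, tokens in enumerate(tokens_per_device):
        for pos, tok in enumerate(tokens):
            send[src][route(tok)].append((pos, tok))
    # 2. First exchange: tokens travel to the devices hosting their experts.
    recv = all_to_all(send)
    # 3. Each device runs its local expert on everything it received.
    processed = [[[(pos, expert_fn(dst, tok)) for pos, tok in bucket]
                  for bucket in recv[dst]] for dst in range(n)]
    # 4. Second exchange: results travel back to the source devices.
    returned = all_to_all(processed)
    # 5. Each source device restores the original token order.
    outputs = []
    for src, tokens in enumerate(tokens_per_device):
        out = [None] * len(tokens)
        for bucket in returned[src]:
            for pos, value in bucket:
                out[pos] = value
        outputs.append(out)
    return outputs

# Toy run: integer "tokens", 2 devices, expert chosen by parity.
result = dispatch_and_combine(
    [[1, 2, 3], [4, 5]],
    route=lambda tok: tok % 2,
    expert_fn=lambda device, tok: tok * 10 + device,
)
```

In the toy run, `result` is `[[11, 20, 31], [40, 51]]`: the last digit shows which device processed each token, and the order within each device is unchanged.
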
- Notes:
- MoEs were not widely adopted because of training instabilities, difficulty fine-tuning, large communication costs, and additional system complexity
- Main contribution:
- Sample efficiency
- Speed up
- Typically in ML/DL, every input is processed by the same model parameters. With MoE, different inputs are processed by different parameters, chosen by the routing mechanism. This way, we can increase the model’s capacity while keeping the computational cost per input fixed.
- Language modeling benefits hugely from increased model capacity because the task is complicated and there is plenty of data to train on. According to the LLM scaling laws [[chinchilla]], increasing the model’s capacity generally leads to better results
- The paper adds a fourth axis to the three in the scaling laws (dataset size, model size, computational budget): increasing the parameter count while keeping the FLOPs per example the same
- Experts are spread across different devices
- Switch Transformer routes each token to ONLY 1 expert, which is why the MoE layer is called a switch layer. This leads to less communication, less computation at the router (we only multiply the output of 1 expert by its router probability), and a smaller batch size per expert (which means we can afford many more experts, each with less capacity)
- Each expert has a predetermined capacity: the max number of tokens it can receive in a batch. \(\text{Expert Capacity} = \text{Capacity Factor} \times \frac{\text{Total Tokens}}{\text{Number of Experts}}\). A high capacity factor tolerates unbalanced routing, but it wastes memory (each expert is allocated memory for its max capacity and some experts won’t fill it) and computation (kernels pad each expert’s batch up to the capacity for increased parallelism)
- If an expert receives more tokens than its capacity, the overflow tokens are dropped: they do not go through any expert and pass directly to the next layer through the residual connection.
- A capacity factor > 1 adds buffer room so that fewer tokens are dropped when routing is unbalanced
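
A minimal sketch of top-1 routing with a fixed expert capacity, in plain Python. The function names and the first-come-first-served handling of overflow are assumptions for illustration, not the paper's implementation.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def expert_capacity(num_tokens, num_experts, capacity_factor):
    # Expert Capacity = Capacity Factor * Total Tokens / Number of Experts
    return int(capacity_factor * num_tokens / num_experts)

def switch_route(router_logits, num_experts, capacity_factor):
    """router_logits[t] holds one logit per expert for token t.
    Returns (assignments, gates): assignments[t] is the chosen expert,
    or None if the token overflowed and was dropped."""
    capacity = expert_capacity(len(router_logits), num_experts, capacity_factor)
    load = [0] * num_experts
    assignments, gates = [], []
    for logits in router_logits:
        probs = softmax(logits)
        expert = max(range(num_experts), key=probs.__getitem__)  # top-1 only
        if load[expert] < capacity:
            load[expert] += 1
            assignments.append(expert)
            gates.append(probs[expert])  # expert output is scaled by this probability
        else:
            assignments.append(None)  # dropped: token skips the experts (residual path)
            gates.append(0.0)
    return assignments, gates

logits = [[2.0, 0.0], [3.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
tight, _ = switch_route(logits, num_experts=2, capacity_factor=1.0)
loose, _ = switch_route(logits, num_experts=2, capacity_factor=1.5)
```

With 4 tokens and 2 experts, a capacity factor of 1.0 gives each expert 2 slots, so `tight` is `[0, 0, None, 1]` (the third token is dropped). Raising the factor to 1.5 gives 3 slots and `loose` is `[0, 0, 0, 1]`.
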
- We add a load-balancing auxiliary loss to the model’s loss (cross entropy) so that all experts stay balanced and receive a similar number of tokens. This is done by
- \(loss = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i\), where \(N\) is the number of experts, \(T\) is the number of tokens in the batch \(\mathcal{B}\), and \(\alpha\) is a scaling coefficient
- where \(f_i\) is the fraction of tokens dispatched to expert \(i\): \(f_i = \frac{1}{T} \sum_{x \in \mathcal{B}} \mathbb{1}\{\arg\max p(x) = i\}\)
- and \(P_i\) is the fraction of the router probability allocated for expert \(i\): \(P_i = \frac{1}{T} \sum_{x \in \mathcal{B}} p_i(x)\)
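
The auxiliary loss above can be sketched directly from its definition. This is plain Python for readability; `load_balancing_loss` is a made-up name and the default `alpha` is a placeholder to tune.

```python
def load_balancing_loss(router_probs, alpha=0.01):
    """router_probs[t] is the router's probability distribution over experts for token t."""
    T = len(router_probs)
    N = len(router_probs[0])
    f = [0.0] * N  # f_i: fraction of tokens whose argmax is expert i
    P = [0.0] * N  # P_i: mean router probability given to expert i
    for probs in router_probs:
        top = max(range(N), key=probs.__getitem__)
        f[top] += 1.0 / T
        for i in range(N):
            P[i] += probs[i] / T
    return alpha * N * sum(fi * pi for fi, pi in zip(f, P))

balanced = load_balancing_loss([[0.9, 0.1], [0.1, 0.9]])
collapsed = load_balancing_loss([[0.9, 0.1], [0.9, 0.1]])
```

With two experts, the balanced batch gives `alpha * 1.0` and the collapsed batch (every token sent to expert 0) gives `alpha * 1.8`, so the loss penalizes concentrating tokens on one expert.
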
- To address training instability, we cast the router logits to float32 before computing the softmax, then cast the resulting dispatch and combine tensors back to bfloat16
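
A rough illustration of why the router softmax is done in float32. Python has no native bfloat16, so this sketch emulates it by truncating the float32 bit pattern (real hardware typically rounds to nearest; truncation is a simplification). All names here are hypothetical.

```python
import math
import struct

def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack(">f", struct.pack(">f", x))[0]

def to_bfloat16(x):
    """Emulate bfloat16 by keeping only the top 16 bits of the float32 encoding."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

def softmax_in(cast, logits):
    """Softmax where every intermediate result is stored in the precision given by `cast`."""
    m = max(logits)
    exps = [cast(math.exp(cast(z - m))) for z in logits]
    total = 0.0
    for e in exps:
        total = cast(total + e)
    return [cast(e / total) for e in exps]

def router_probs_selective(logits_bf16):
    """Upcast logits to float32, take the softmax, then cast the result back to bfloat16."""
    probs_f32 = softmax_in(to_float32, [to_float32(z) for z in logits_bf16])
    return [to_bfloat16(p) for p in probs_f32]

def max_abs_error(probs, reference):
    return max(abs(p - r) for p, r in zip(probs, reference))

logits = [to_bfloat16(0.37 * i) for i in range(32)]   # 32 experts, bfloat16 inputs
reference = softmax_in(float, logits)                 # float64 reference
err_f32 = max_abs_error(softmax_in(to_float32, logits), reference)
err_bf16 = max_abs_error(softmax_in(to_bfloat16, logits), reference)
selective = router_probs_selective(logits)
```

Comparing `err_f32` with `err_bf16` shows how much error accumulates when the exponentials and the running sum are kept in bfloat16, while `selective` still hands bfloat16 values to the rest of the network.
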
- MoEs are known to overfit during fine-tuning on downstream tasks, since they have far more parameters than the corresponding dense models and downstream tasks typically have much smaller datasets. The paper suggests much more aggressive dropout inside the experts during fine-tuning
- As the number of experts increases, the loss decreases while the FLOPs per token stay the same. This means we can keep adding experts to improve the loss within the same FLOP budget (this is possible because each token is routed to only 1 expert)
- Switch Transformer is also sample efficient: the same loss can be reached with much less training time (or far fewer tokens). In other words, MoEs train much faster than dense models. Bigger MoEs are also more sample efficient than smaller MoEs. Increasing the size of the dense model narrows the speed-up gap, but the gap still favors sparse models (MoEs)
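
A back-of-the-envelope check of the claim above that adding experts grows the parameter count while the per-token compute stays roughly flat. The layer sizes are illustrative, biases are ignored, and compute is counted in multiply-accumulates; `switch_layer_stats` is a made-up helper.

```python
def ffn_params(d_model, d_ff):
    # Two weight matrices: d_model x d_ff and d_ff x d_model.
    return 2 * d_model * d_ff

def switch_layer_stats(d_model, d_ff, num_experts):
    """Returns (parameters, multiply-accumulates per token) for one switch layer."""
    router = d_model * num_experts
    params = num_experts * ffn_params(d_model, d_ff) + router
    # Each token pays for the router projection plus exactly ONE expert FFN.
    macs_per_token = router + ffn_params(d_model, d_ff)
    return params, macs_per_token

params_1, macs_1 = switch_layer_stats(d_model=768, d_ff=3072, num_experts=1)
params_128, macs_128 = switch_layer_stats(d_model=768, d_ff=3072, num_experts=128)
param_ratio = params_128 / params_1
macs_ratio = macs_128 / macs_1
```

Going from 1 to 128 experts multiplies the layer's parameters by about 128, while the per-token compute grows by about 2%, all of it from the larger router projection.
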
#nlp #llm