General
Gradient accumulation: to avoid out-of-memory (OOM) errors while training on GPUs, break a large batch into smaller batches and accumulate gradients by back-propagating after every small batch, running the optimizer step only once the gradients have been gathered. This gives identical results to training with the larger batch, unless the model has layers whose forward pass depends on the batch size, such as BatchNorm.
- For example, in PyTorch, instead of calling optimizer.step() for every batch, we call it every few batches, as in the sketch below.
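A minimal sketch of this pattern, assuming a toy model and dataloader; the model, data, and accumulation_steps value are illustrative, not from the notes above:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
# Toy "dataloader": 16 small batches of size 8.
loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(16)]
accumulation_steps = 4  # effective batch size = 8 * 4 = 32

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(loader):
    loss = criterion(model(inputs), targets)
    # Scale the loss so the accumulated gradient matches the average over
    # the effective (large) batch rather than the sum over small batches.
    (loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()        # update only every accumulation_steps batches
        optimizer.zero_grad()
```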
When using mixed precision, it is better to increase the batch size to keep the GPU fully utilized, since each batch occupies much less memory. As a result, there are fewer updates per epoch, because the model sees fewer batches in each epoch. This means we need to 1) increase the number of epochs and 2) increase the learning rate.
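A minimal sketch of mixed-precision training with PyTorch's torch.cuda.amp utilities (autocast and GradScaler); the model, batch size, and learning rate are placeholders:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

# A larger batch fits in memory because activations are stored in fp16.
inputs = torch.randn(256, 10, device=device)
targets = torch.randint(0, 2, (256,), device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=(device == "cuda")):
    loss = criterion(model(inputs), targets)  # forward pass runs in fp16
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)
scaler.update()
```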
Mixup: take a linear combination of 2 random examples from the training dataset, weighted by a lambda parameter drawn from a Beta(α, α) distribution. The label vector is the same linear combination of the two examples' labels. This forces the model to be more robust and to learn linear combinations of examples instead of memorizing them; as a result, the model becomes less sensitive to corrupted labels and noise in the training data. This method can also be applied to tabular data. However, because it is harder for the model to learn to differentiate between the two examples and the weight of each, we need to train for much longer to get good results.
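A minimal sketch of mixup on a single batch, assuming one-hot (or soft) label vectors; the function name, alpha value, and toy tensors are illustrative:

```python
import torch

def mixup_batch(x, y_onehot, alpha=0.2):
    """Mix a batch with a shuffled copy of itself; lambda ~ Beta(alpha, alpha)."""
    lam = torch.distributions.Beta(alpha, alpha).sample()
    perm = torch.randperm(x.size(0))
    x_mixed = lam * x + (1 - lam) * x[perm]
    y_mixed = lam * y_onehot + (1 - lam) * y_onehot[perm]
    return x_mixed, y_mixed

# Example: mix a batch of 4 examples with 3 features and 2 classes.
x = torch.randn(4, 3)
y = torch.eye(2)[torch.randint(0, 2, (4,))]  # one-hot labels
x_mix, y_mix = mixup_batch(x, y, alpha=0.4)
```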
NLP
- With LLMs, more compute generally gives better results. Compute can be defined, roughly, as the number of parameters x the number of training tokens. Therefore, we can make the model bigger and keep the number of tokens fixed, OR keep the model size the same and increase the number of tokens, which means training for longer. There is a trade-off that depends on the task (see the rough sketch after this list).
- Improving log loss for LLMs is correlated with improved performance on downstream tasks.
- Even though loss scales smoothly with compute, individual downstream tasks may scale in an emergent fashion. This means some tasks’ loss may be flat, others may be inversely scaled, etc.
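A rough back-of-the-envelope sketch of the compute definition from the first bullet above (compute ≈ parameters x tokens); all numbers are made up for illustration:

```python
# Two ways to spend the same (arbitrary-unit) compute budget.
compute_budget = 1e21

# Option A: bigger model, fewer tokens.
params_a = 10e9
tokens_a = compute_budget / params_a   # 1e11 tokens

# Option B: smaller model, more tokens (i.e. train for longer).
params_b = 1e9
tokens_b = compute_budget / params_b   # 1e12 tokens

print(f"A: {params_a:.0e} params, {tokens_a:.0e} tokens")
print(f"B: {params_b:.0e} params, {tokens_b:.0e} tokens")
```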