Training
Almost all training-related callbacks that tweak or customize the training/validation loop are in this module.
Raises:

Type | Description |
---|---|
`CancelFitException` | Stop training and move to `after_fit`. |
`CancelEpochException` | Stop current epoch and move to `after_epoch`. |
`CancelTrainException` | Stop training current epoch and move to `after_train`. |
`CancelValidException` | Stop validation phase and move to `after_validate`. |
`CancelBatchException` | Stop current batch and move to `after_batch`. |
`CancelStepException` | Skip stepping the optimizer. |
`CancelBackwardException` | Skip the backward pass and move to `after_backward`. |
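These exceptions are the control-flow mechanism of the callback system: a callback raises one of them inside a hook to short-circuit the corresponding part of the loop. A minimal sketch, assuming the `Callback` base class from this module (`before_epoch` is an assumed hook name; `n_debug_batches` is our own illustrative parameter):

```python
class StopAfterNBatches(Callback):
    """Cancel each epoch after `n_debug_batches` batches, e.g. for smoke tests."""

    def __init__(self, n_debug_batches=5):
        self.n_debug_batches = n_debug_batches

    def before_epoch(self):
        self.batch_count = 0  # reset the counter at the start of every epoch

    def before_batch(self):
        if self.batch_count >= self.n_debug_batches:
            raise CancelEpochException  # jump straight to `after_epoch`
        self.batch_count += 1
```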
BatchTransform
Bases: Callback
Transform the batch using the `tfm` callable before every batch. Apply the transformation `tfm` on the batch as a whole.
__init__(tfm, on_train=True, on_valid=True)

Parameters:

Name | Type | Description | Default |
---|---|---|---|
`tfm` | `Callable` | Transformation to apply on the batch. | required |
`on_train` | `bool` | Whether to apply the transformation during training. | `True` |
`on_valid` | `bool` | Whether to apply the transformation during validation. | `True` |
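A hypothetical usage sketch: standardizing every batch with `BatchTransform`. The `(inputs, targets)` batch layout and the `Learner` construction are assumptions about the surrounding API:

```python
def standardize(batch):
    x, y = batch                                  # assumes an (inputs, targets) batch
    return (x - x.mean()) / (x.std() + 1e-8), y

cb = BatchTransform(standardize, on_train=True, on_valid=True)
# learner = Learner(model, dls, loss_func, callbacks=[cb])  # assumed API
```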
BatchTransformX
Bases: Callback
Transform the input `X` using the `tfm` callable before every batch. Apply the transformation `tfm` on the inputs only; targets are left unchanged.
__init__(tfm, on_train=True, on_valid=True)

Parameters:

Name | Type | Description | Default |
---|---|---|---|
`tfm` | `Callable` | Transformation to apply on the input `X`. | required |
`on_train` | `bool` | Whether to apply the transformation during training. | `True` |
`on_valid` | `bool` | Whether to apply the transformation during validation. | `True` |
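A hypothetical sketch of train-time-only input augmentation with `BatchTransformX`; unlike `BatchTransform`, the callable receives only the inputs `X`, so the targets stay untouched:

```python
import torch

def add_noise(x):
    return x + 0.01 * torch.randn_like(x)  # lightly jitter the inputs

cb = BatchTransformX(add_noise, on_train=True, on_valid=False)
```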
DeviceCallback
Bases: Callback
Move batch and model to `device`.

__init__(device=default_device)

Parameters:

Name | Type | Description | Default |
---|---|---|---|
`device` | `str \| device` | Device to copy the batch and model to. | `default_device` |
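A brief usage sketch, assuming nothing beyond the constructor documented above:

```python
import torch

# Pick an explicit device instead of relying on `default_device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
cb = DeviceCallback(device=device)
```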
LRFinder
Bases: Callback
Try different learning rates using an exponential schedule to help pick the best learning rate, following Cyclical Learning Rates for Training Neural Networks. When done, plot learning rate vs loss.
__init__(gamma=1.3, num_iter=100, stop_div=True, max_mult=4)

Parameters:

Name | Type | Description | Default |
---|---|---|---|
`gamma` | `float` | Multiplicative factor by which the learning rate grows each iteration. | `1.3` |
`num_iter` | `int` | Number of iterations to run the training. | `100` |
`stop_div` | `bool` | Whether to stop training if the loss diverges (loss > max_mult * best loss). | `True` |
`max_mult` | `int` | Divergence threshold: stop training if loss >= max_mult * minimum loss. | `4` |
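A hypothetical usage sketch. The `Learner` construction and `fit` call are assumptions about the surrounding API; that the finder stops training on its own once the schedule ends or the loss diverges is implied by the description above:

```python
lr_find = LRFinder(gamma=1.3, num_iter=100, stop_div=True, max_mult=4)
# learner = Learner(model, dls, loss_func, callbacks=[lr_find])  # assumed API
# learner.fit(n_epochs=1)  # stops early once the schedule ends or loss diverges
# Inspect the resulting loss-vs-lr plot and pick a learning rate somewhat
# below the point where the loss falls fastest.
```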
MetricsCallback
Bases: Callback
Compute/update the given metrics and log them using the `learner`-defined logger after every train/validate epoch. Metrics have to implement `reset` and `compute` methods. It is highly recommended to use metrics from the torcheval package or to inherit from its `Metric` base class for custom metrics.
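For example, logging accuracy with a torcheval metric, which already implements the required `reset`/`compute` interface (the keyword-argument convention for naming metrics below is an assumption):

```python
from torcheval.metrics import MulticlassAccuracy

cb = MetricsCallback(accuracy=MulticlassAccuracy())
```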
Mixup
Bases: Callback
Train the model on a mix of samples from each batch in the training data. Instead of feeding the model raw inputs, we use a linear combination of the inputs, weighted by λ drawn from a Beta distribution parameterized by `alpha`. This means the effective targets are likewise the linear combination of the original and shuffled labels, not the original labels alone. The implementation is largely based on the mixup paper (mixup: Beyond Empirical Risk Minimization).
__init__(alpha=0.4)

Parameters:

Name | Type | Description | Default |
---|---|---|---|
`alpha` | `float` | Concentration for the Beta distribution. | `0.4` |
before_batch()

Steps taken before passing inputs to the model:

- Draw a sample of size `batch_size` from a Beta distribution, so each image gets its own λ.
- To avoid combining two images with the same λ, take the max of λ and 1 - λ, so even if the same two images were combined together they would lead to a different mixed image.
- Shuffle the batch before computing the linear combination.
- Change the batch input to the linear combination of the original images and the shuffled images.

The loss is the linear combination of the loss of the original images with the original targets and the loss of the shuffled images with the shuffled targets, weighted by λ.
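A minimal, self-contained sketch of the steps above in plain PyTorch (independent of this library's `Mixup` callback):

```python
import torch

alpha = 0.4
x = torch.randn(32, 3, 28, 28)                 # dummy image batch
y = torch.randint(0, 10, (32,))                # dummy integer targets

# One λ per image, drawn from Beta(alpha, alpha).
lam = torch.distributions.Beta(alpha, alpha).sample((x.size(0),))
lam = torch.max(lam, 1 - lam)                  # keep λ >= 0.5 to avoid duplicate mixes
perm = torch.randperm(x.size(0))               # shuffle the batch
lam_x = lam.view(-1, 1, 1, 1)                  # broadcast λ over image dims
x_mixed = lam_x * x + (1 - lam_x) * x[perm]    # linear combination of inputs

# The per-sample loss is then mixed the same way (loss_fn with reduction="none"):
# loss = (lam * loss_fn(pred, y) + (1 - lam) * loss_fn(pred, y[perm])).mean()
```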
ModelResetter
Bases: Callback
Reset the model's parameters. This is very useful in the context of NLP, where we always reset the hidden state. The assumption here is that `model` has a `reset` method that knows which parameters to reset and how.
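An illustrative sketch (not from the library) of the contract `ModelResetter` relies on, using a stateful RNN that carries its hidden state across batches:

```python
import torch.nn as nn

class StatefulLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(100, 100, batch_first=True)
        self.hidden = None

    def reset(self):
        self.hidden = None  # drop the carried-over hidden state

    def forward(self, x):
        out, self.hidden = self.rnn(x, self.hidden)
        return out
```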
ProgressCallback
Bases: Callback
Track the progress of training with a progress bar and plot the train/valid losses, giving live feedback on the model's performance while it is still training.

__init__(plot=True)

Parameters:

Name | Type | Description | Default |
---|---|---|---|
`plot` | `bool` | Whether to plot train/valid losses during training. | `True` |
Recorder
Bases: Callback
Keep track of losses and the optimizer's `params` (such as learning rates) as training progresses so we can plot them later.

plot(pgid=-1, skip_last=0)

Plot loss vs lr (log scale) across all iterations of training.

plot_loss(skip_last=0)

Plot losses, optionally skipping the last `skip_last` losses.

plot_params(params='lr', pgid=-1, figsize=(8, 6))

Plot the values of all `params` across all iterations of training.
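A hypothetical usage sketch; the `learner.recorder` attribute access is an assumption, while the methods are the ones documented above:

```python
# learner.recorder.plot_loss(skip_last=5)    # losses, minus the last 5
# learner.recorder.plot_params(params='lr')  # learning-rate schedule
# learner.recorder.plot()                    # loss vs lr, log-scale x-axis
```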
SingleBatchCallback
Bases: Callback
Run one training/validation batch and stop by raising `CancelFitException`. Useful for debugging or for checking a few parameters after one batch.
SingleBatchForwardCallback
Bases: Callback
Run one training/validation batch and stop after the forward pass (after computing the loss) by raising `CancelFitException`. Useful for debugging or for checking a few parameters after one batch.
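A hypothetical debugging sketch; the `Learner` calls are assumptions about the surrounding API:

```python
# learner.fit(n_epochs=1, callbacks=[SingleBatchCallback()])  # assumed API
# print(learner.loss)  # e.g. inspect the loss after exactly one batch
# Use SingleBatchForwardCallback instead to stop right after the forward pass
# (loss computed, but no backward pass or optimizer step).
```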
TrainEvalCallback
Bases: Callback
Track the number of iterations and the percentage of training completed, and set the model's train/eval modes.