
Training

Almost all training-related callbacks that tweak or customize the training/validation loop are in this module.

Raises:

| Type | Description |
|---|---|
| CancelFitException | Stop training and move to after_fit. |
| CancelEpochException | Stop current epoch and move to after_epoch. |
| CancelTrainException | Stop training current epoch and move to after_train. |
| CancelValidException | Stop validation phase and move to after_validate. |
| CancelBatchException | Stop current batch and move to after_batch. |
| CancelStepException | Skip stepping the optimizer. |
| CancelBackwardException | Skip the backward pass and move to after_backward. |
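These exceptions work as control-flow signals: a callback raises one, and the loop catches it at the matching level and jumps to the corresponding after_* event. The following is a minimal sketch of that pattern using a hypothetical, stripped-down loop; run_with_events, fit, the log list and the "cancelled_*" entries are illustrative names, not this module's API, and only batch-level and fit-level cancellation are shown.

```python
from typing import Callable, Iterable, List, Optional, Type


class CancelBatchException(Exception):
    """Stop current batch and move to after_batch."""


class CancelFitException(Exception):
    """Stop training and move to after_fit."""


def run_with_events(
    name: str, body: Callable[[], None], cancel_exc: Type[Exception], log: List[str]
) -> None:
    """Run `body` between before_<name> and after_<name>.

    If `body` raises `cancel_exc`, the rest of `body` is skipped but
    after_<name> still runs; any other exception propagates normally.
    """
    log.append(f"before_{name}")
    try:
        body()
    except cancel_exc:
        log.append(f"cancelled_{name}")
    log.append(f"after_{name}")


def fit(
    n_batches: int, skip_batches: Iterable[int] = (), stop_after: Optional[int] = None
) -> List[str]:
    """Toy loop (hypothetical): skip_batches/stop_after simulate callbacks raising."""
    log: List[str] = []
    skip = set(skip_batches)

    def one_batch(i: int) -> None:
        if i in skip:
            raise CancelBatchException()  # only this batch is abandoned
        log.append(f"step {i}")

    def all_batches() -> None:
        for i in range(n_batches):
            run_with_events("batch", lambda: one_batch(i), CancelBatchException, log)
            if stop_after is not None and i == stop_after:
                raise CancelFitException()  # abandon all remaining batches

    run_with_events("fit", all_batches, CancelFitException, log)
    return log


print(fit(3, skip_batches=[1], stop_after=1))
# ['before_fit', 'before_batch', 'step 0', 'after_batch', 'before_batch',
#  'cancelled_batch', 'after_batch', 'cancelled_fit', 'after_fit']
```

Note that the cancelled batch still reaches after_batch, and cancelling the fit skips batch 2 entirely but still reaches after_fit.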

BatchTransform

Bases: Callback

Transform X as a batch using the tfm callable before every batch. The transformation tfm is applied to the batch as a whole rather than to individual samples.

__init__(tfm, on_train=True, on_valid=True)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tfm | Callback | Transformation to apply on the batch. | required |
| on_train | bool | Whether to apply the transformation during training. | True |
| on_valid | bool | Whether to apply the transformation during validation. | True |
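A minimal sketch of what a batch-level transform does, assuming a learner that exposes the current batch as an (xb, yb) tuple and a training flag. SimpleLearner, BatchTransformSketch and center are hypothetical names for illustration, not this module's classes.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class SimpleLearner:
    """Hypothetical stand-in for a learner: current batch (xb, yb) and mode flag."""
    batch: tuple
    training: bool = True


class BatchTransformSketch:
    """Apply `tfm` to the whole input batch xb before each batch (sketch only)."""

    def __init__(self, tfm: Callable[[Any], Any], on_train: bool = True, on_valid: bool = True) -> None:
        self.tfm = tfm
        self.on_train = on_train
        self.on_valid = on_valid

    def before_batch(self, learn: SimpleLearner) -> None:
        # Only transform during the phases that were requested
        if (learn.training and self.on_train) or (not learn.training and self.on_valid):
            xb, yb = learn.batch
            learn.batch = (self.tfm(xb), yb)  # targets yb are left untouched


def center(xb: List[float]) -> List[float]:
    """Example batch-level transform: subtract the batch mean."""
    mean = sum(xb) / len(xb)
    return [x - mean for x in xb]


learn = SimpleLearner(batch=([1.0, 2.0, 3.0], [0, 1, 0]))
BatchTransformSketch(center).before_batch(learn)
print(learn.batch)  # ([-1.0, 0.0, 1.0], [0, 1, 0])
```

center needs the whole batch to compute its mean, which is the kind of transform that cannot be expressed per sample.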

DeviceCallback

Bases: Callback

Move batch and model to device.

LRFinder

Bases: Callback

Try different learning rates using an exponential schedule to help pick the best learning rate, following Cyclical Learning Rates for Training Neural Networks. When done, plot learning rate vs loss.

__init__(gamma=1.3, num_iter=100, stop_div=True, max_mult=4)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| gamma | float | Multiplicative factor of the exponential learning rate schedule. | 1.3 |
| num_iter | int | Number of iterations to run the learning rate search for. | 100 |
| stop_div | bool | Whether to stop training if the loss diverges (see max_mult). | True |
| max_mult | int | Divergence threshold. If loss >= max_mult * minimum loss, stop training. | 4 |
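The following sketches an exponential learning rate sweep with a divergence check. It assumes gamma multiplies the learning rate after every iteration and that the sweep starts from a small start_lr; lr_range_test, start_lr and loss_at are hypothetical names, and loss_at stands in for "train one batch at this learning rate and return its loss". It is not the module's LRFinder implementation.

```python
import math
from typing import Callable, List, Tuple


def lr_range_test(
    loss_at: Callable[[float], float],
    start_lr: float = 1e-6,      # hypothetical starting point of the sweep
    gamma: float = 1.3,          # assumed per-iteration multiplicative factor
    num_iter: int = 100,
    stop_div: bool = True,
    max_mult: float = 4,
) -> Tuple[List[float], List[float]]:
    """Exponential LR sweep: lr_i = start_lr * gamma**i (sketch only)."""
    lrs: List[float] = []
    losses: List[float] = []
    best_loss = math.inf
    lr = start_lr
    for _ in range(num_iter):
        loss = loss_at(lr)
        lrs.append(lr)
        losses.append(loss)
        best_loss = min(best_loss, loss)
        # Stop once the loss blows up relative to the best loss seen so far
        if stop_div and (loss > max_mult * best_loss or math.isnan(loss)):
            break
        lr *= gamma
    return lrs, losses


# Toy loss: flat for small learning rates, then growing with lr
toy_loss = lambda lr: 1.0 if lr < 1e-2 else 1.0 + 100 * lr
lrs, losses = lr_range_test(toy_loss)
print(len(lrs), losses[-1] > 4 * min(losses))
```

The resulting (lrs, losses) pairs are what would be plotted as learning rate vs loss, usually with a log-scaled learning rate axis.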

MetricsCallback

Bases: Callback

Compute/update the given metrics and log them using the learner's logger after every train/validate epoch. Metrics have to implement reset and compute methods. It is highly recommended to use metrics from the torcheval package, or to inherit from its Metric base class for custom metrics.
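A minimal, dependency-free sketch of the reset/update/compute protocol and of roughly how a metrics callback drives it over one epoch. The method names mirror torcheval metrics, but Accuracy and run_epoch_metrics are hypothetical illustrations; in practice a torcheval metric would be used instead.

```python
from typing import Dict, Iterable, Sequence, Tuple


class Accuracy:
    """Hypothetical metric following the reset/update/compute protocol."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> "Accuracy":
        self.correct = 0
        self.total = 0
        return self

    def update(self, preds: Sequence[int], targets: Sequence[int]) -> "Accuracy":
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)
        return self

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


def run_epoch_metrics(
    metrics: Dict[str, Accuracy],
    batches: Iterable[Tuple[Sequence[int], Sequence[int]]],
) -> Dict[str, float]:
    """Sketch of one train or valid epoch from a metrics callback's point of view."""
    for m in metrics.values():
        m.reset()                     # start of the epoch: clear accumulated state
    for preds, targets in batches:
        for m in metrics.values():
            m.update(preds, targets)  # after each batch: accumulate
    # end of the epoch: compute the final values, which would then be logged
    return {name: m.compute() for name, m in metrics.items()}


batches = [([1, 0, 1], [1, 1, 1]), ([0], [0])]
print(run_epoch_metrics({"acc": Accuracy()}, batches))  # {'acc': 0.75}
```

Resetting at the start of each epoch is what keeps training and validation numbers from leaking into each other.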

Mixup

Bases: Callback

Train the model on mixes of samples from each batch of the training data. Instead of feeding the model raw inputs, we feed it linear combinations of inputs, with mixing weights λ drawn from a Beta distribution parameterized by alpha. This means the labels are also linear combinations of the original labels rather than the original labels themselves. The implementation is largely based on this paper.

__init__(alpha=0.4)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| alpha | float | Concentration parameter of the Beta distribution. | 0.4 |

before_batch()

Steps taken before passing inputs to the model:

  • Draw batch_size samples from the Beta distribution, so each image gets its own λ.
  • Take the max of λ and 1 - λ. Two images paired with each other could otherwise be mixed with complementary weights and produce the same combined image; keeping λ >= 0.5 makes each mixed image dominated by its original, so the two results differ.
  • Shuffle the batch to pick a mixing partner for each image.
  • Replace the batch input with the linear combination of the original images and the shuffled images.

The loss is the linear combination of the loss of the original images against the original targets and the loss against the shuffled targets, weighted by λ and 1 - λ respectively.
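The steps above can be sketched in plain Python, treating each image as a flat list of floats. This assumes λ is drawn from Beta(alpha, alpha), the usual mixup choice; mixup_batch, mixup_loss and per_sample_loss are hypothetical names, and a real implementation would operate on tensors instead of lists.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple


def mixup_batch(
    xb: List[List[float]], alpha: float = 0.4, rng: Optional[random.Random] = None
) -> Tuple[List[List[float]], List[int], List[float]]:
    """Return mixed inputs, the shuffle permutation, and per-sample λ."""
    rng = rng or random.Random()
    n = len(xb)
    # 1. One λ per sample (assumed Beta(alpha, alpha))
    lams = [rng.betavariate(alpha, alpha) for _ in range(n)]
    # 2. Keep λ >= 0.5 so each mixed sample is dominated by its original
    lams = [max(lam, 1 - lam) for lam in lams]
    # 3. Shuffle the batch to pick partners
    perm = list(range(n))
    rng.shuffle(perm)
    # 4. Linear combination of original and shuffled inputs
    mixed = [
        [lam * a + (1 - lam) * b for a, b in zip(xb[i], xb[perm[i]])]
        for i, lam in enumerate(lams)
    ]
    return mixed, perm, lams


def mixup_loss(
    per_sample_loss: Callable[[float, float], float],
    preds: Sequence[float],
    yb: Sequence[float],
    perm: Sequence[int],
    lams: Sequence[float],
) -> float:
    """λ * loss(original target) + (1 - λ) * loss(shuffled target), batch mean."""
    total = 0.0
    for i, lam in enumerate(lams):
        total += lam * per_sample_loss(preds[i], yb[i])
        total += (1 - lam) * per_sample_loss(preds[i], yb[perm[i]])
    return total / len(lams)


rng = random.Random(0)
xb = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
mixed, perm, lams = mixup_batch(xb, rng=rng)
```

Only the input side is mixed; the targets stay as they are, and the mixing is applied through the loss instead.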

ModelResetter

Bases: Callback

Reset the model's parameters. This is especially useful in NLP, where the hidden state of recurrent models has to be reset. The model is assumed to have a reset method that knows which parameters to reset and how.
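A toy illustration of the reset contract: the model owns a reset method that clears its own state, and the callback only calls it. TinyRNN and ModelResetterSketch are hypothetical, and resetting before training and before validation is an assumption about which events such a callback hooks into.

```python
from typing import Optional


class TinyRNN:
    """Hypothetical recurrent model that carries a hidden state between calls."""

    def __init__(self, decay: float = 0.5) -> None:
        self.decay = decay
        self.hidden: Optional[float] = None

    def reset(self) -> None:
        # The model knows which state to clear: here, only the hidden state
        self.hidden = None

    def __call__(self, x: float) -> float:
        prev = 0.0 if self.hidden is None else self.hidden
        self.hidden = self.decay * prev + x
        return self.hidden


class ModelResetterSketch:
    """Calls model.reset() at phase boundaries (assumed events)."""

    def __init__(self, model: TinyRNN) -> None:
        self.model = model

    def before_train(self) -> None:
        self.model.reset()

    def before_validate(self) -> None:
        self.model.reset()


m = TinyRNN()
m(1.0)
m(1.0)                 # hidden state carries over: 0.5 * 1.0 + 1.0
ModelResetterSketch(m).before_validate()
print(m(1.0))          # 1.0, since validation starts from a fresh hidden state
```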

ProgressCallback

Bases: Callback

Track training progress with a progress bar and plot train and valid losses, which gives live feedback on the model's performance while it is still training.

__init__(plot=True)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| plot | bool | Whether to plot train/valid losses during training. | True |

Recorder

Bases: Callback

Keep track of losses and optimizer params, such as learning rates, as training progresses so they can be plotted later.

plot(pgid=-1, skip_last=0)

Plot loss vs lr (log-scale) across all iterations of training.

plot_loss(skip_last=0)

Plot losses, optionally skipping the last skip_last losses.

plot_params(params='lr', pgid=-1, figsize=(8, 6))

Plot the values of all params across all iterations of training.
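A plotting-free sketch of the bookkeeping such a recorder needs. It assumes optimizer params are read from PyTorch-style param_groups (a list of dicts, one per parameter group) and that pgid indexes that list. RecorderSketch, loss_curve and lr_vs_loss are hypothetical names; the returned lists are what plot_loss and plot would hand to a plotting library.

```python
from typing import Dict, List, Sequence, Tuple


class RecorderSketch:
    """Hypothetical recorder: stores the loss and optimizer params after each training batch."""

    def __init__(self, params: Sequence[str] = ("lr",)) -> None:
        self.params = tuple(params)
        self.losses: List[float] = []
        # {param name: [values for group 0, values for group 1, ...]}
        self.history: Dict[str, List[List[float]]] = {p: [] for p in self.params}

    def after_batch(self, loss: float, param_groups: List[dict]) -> None:
        self.losses.append(loss)
        for p in self.params:
            if not self.history[p]:
                self.history[p] = [[] for _ in param_groups]
            for hist, group in zip(self.history[p], param_groups):
                hist.append(group[p])

    def loss_curve(self, skip_last: int = 0) -> List[float]:
        # Slice by length so skip_last=0 keeps everything (losses[:-0] would be empty)
        return self.losses[: len(self.losses) - skip_last]

    def lr_vs_loss(self, pgid: int = -1, skip_last: int = 0) -> Tuple[List[float], List[float]]:
        n = len(self.losses) - skip_last
        return self.history["lr"][pgid][:n], self.losses[:n]


rec = RecorderSketch()
rec.after_batch(1.0, [{"lr": 0.1}, {"lr": 0.01}])
rec.after_batch(0.5, [{"lr": 0.2}, {"lr": 0.02}])
rec.after_batch(3.0, [{"lr": 0.4}, {"lr": 0.04}])
print(rec.loss_curve(skip_last=1))  # [1.0, 0.5]
```

Skipping the last few points is handy when the tail is unrepresentative, for example the diverging end of a learning rate search.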

SingleBatchCallback

Bases: Callback

Run one training/validation batch and stop by raising CancelFitException. Useful for debugging, or for checking a few parameters after one batch.
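A toy version of the idea, assuming the exception is raised from after_batch so that exactly one batch is processed before the loop jumps to the end of the fit. SingleBatchSketch and fit are hypothetical stand-ins, not this module's classes.

```python
from typing import List, Sequence


class CancelFitException(Exception):
    """Stop training and move to after_fit."""


class SingleBatchSketch:
    """Hypothetical callback: abort the whole fit after the first batch."""

    def after_batch(self) -> None:
        raise CancelFitException()


def fit(batches: Sequence[int], callbacks: Sequence[SingleBatchSketch]) -> List[int]:
    processed: List[int] = []
    try:
        for b in batches:
            processed.append(b)  # stands in for forward, backward and optimizer step
            for cb in callbacks:
                cb.after_batch()
    except CancelFitException:
        pass  # skip the remaining batches and go straight to after_fit
    return processed


print(fit([1, 2, 3], [SingleBatchSketch()]))  # [1]
```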

TrainEvalCallback

Bases: Callback

Track the number of iterations and the percentage of training completed, and switch the model between training and eval modes.
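A minimal sketch of the bookkeeping, assuming only training batches advance the counters and that the percentage is expressed as a fraction of n_epochs * iters_per_epoch. TrainEvalSketch and its attributes are hypothetical; a real callback would also call model.train() and model.eval() on a PyTorch model where the comments indicate.

```python
class TrainEvalSketch:
    """Hypothetical: counts training iterations and the fraction of training done."""

    def __init__(self, n_epochs: int, iters_per_epoch: int) -> None:
        self.total_iters = n_epochs * iters_per_epoch
        self.n_iter = 0
        self.pct_train = 0.0
        self.training = False

    def before_train(self) -> None:
        self.training = True   # a real callback would call model.train() here

    def after_batch(self) -> None:
        # Validation batches do not advance the training counters
        if self.training:
            self.n_iter += 1
            self.pct_train += 1 / self.total_iters

    def before_validate(self) -> None:
        self.training = False  # a real callback would call model.eval() here


cb = TrainEvalSketch(n_epochs=2, iters_per_epoch=4)
for _ in range(2):
    cb.before_train()
    for _ in range(4):
        cb.after_batch()
    cb.before_validate()
    for _ in range(2):
        cb.after_batch()
print(cb.n_iter, cb.pct_train)  # 8 1.0
```

A fraction such as pct_train is convenient for schedulers that need to know how far along training is, independent of the number of epochs.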