Learner

Learner is a basic class that provides the following functionality:

  • Customizing the training loop through a system of callbacks, with control flow handled by exceptions
  • Loading and saving the model
  • Fitting the model
  • Printing a model summary
  • Finding a learning rate

The training loop consists of a minimal set of instructions; looping through the data we:

  • Compute the output of the model from the input
  • Calculate a loss between this output and the desired target
  • Compute the gradients of this loss with respect to model parameters
  • Update the parameters accordingly
  • Zero all the gradients
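
The five steps above can be sketched in plain Python without any framework. This is only an illustration, assuming a one-parameter linear model y_hat = w * x, a mean squared error loss, and a hand-derived gradient. The names `train_one_epoch`, `w`, and `grad` are hypothetical and are not part of Learner.

```python
def train_one_epoch(w, batches, lr=0.1):
    """Run one pass over `batches` for the toy model y_hat = w * x."""
    grad = 0.0
    loss = None
    for xs, ys in batches:
        # 1. Compute the output of the model from the input.
        preds = [w * x for x in xs]
        # 2. Calculate the loss (mean squared error) against the target.
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
        # 3. Compute the gradient of the loss with respect to w.
        grad += sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
        # 4. Update the parameter.
        w -= lr * grad
        # 5. Zero the gradient so it does not accumulate into the next batch.
        grad = 0.0
    return w, loss


# Toy data generated by y = 3 * x, split into two batches.
batches = [([1.0, 2.0], [3.0, 6.0]), ([3.0, 4.0], [9.0, 12.0])]
w = 0.0
for _ in range(50):
    w, loss = train_one_epoch(w, batches, lr=0.02)
```

After these epochs `w` should be close to 3, the slope used to generate the toy data. The gradient is accumulated with `+=` on purpose, to show why step 5 is needed.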

Any tweak of this training loop is defined in a Callback. A callback can implement actions on the following events:

  • before_fit: called before doing anything, ideal for initial setup
  • before_epoch: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch
  • before_train: called at the beginning of the training part of an epoch
  • before_batch: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (change of the input with techniques like mixup for instance)
  • after_pred: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss function
  • after_loss: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance)
  • after_cancel_backward: reached immediately after CancelBackwardException
  • after_backward: called after the backward pass, but before updating the parameters. It can be used to do any change to the gradients before any updates (gradient clipping for instance)
  • after_cancel_step: reached immediately after CancelStepException
  • after_step: called after the step and before gradients are zeroed
  • after_cancel_batch: reached immediately after CancelBatchException before proceeding to after_batch
  • after_batch: called at the end of a batch, for any clean-up before the next one
  • after_cancel_train: reached immediately after CancelTrainException before proceeding to after_train
  • after_train: called at the end of the training phase of an epoch
  • before_validate: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation
  • after_cancel_validate: reached immediately after CancelValidateException before proceeding to after_validate
  • after_validate: called at the end of the validation phase of an epoch
  • after_cancel_epoch: reached immediately after CancelEpochException before proceeding to after_epoch
  • after_epoch: called at the end of an epoch, for any clean-up before the next one
  • after_cancel_fit: reached immediately after CancelFitException before proceeding to after_fit
  • after_fit: called at the end of training, for any final clean-up
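
The way an event, its cancel exception, and its after_cancel_* hook fit together can be sketched as follows. This is a simplified illustration for the batch level only, not the Learner implementation: the `Runner`, `SkipOddBatches`, and `Recorder` classes and the `_fire` helper are hypothetical, and the exception class is redefined locally so the example is self-contained.

```python
class CancelBatchException(Exception):
    """Raised by a callback to skip the rest of the current batch."""


class SkipOddBatches:
    def before_batch(self, runner):
        # Cancel every batch with an odd index.
        if runner.batch_idx % 2 == 1:
            raise CancelBatchException


class Recorder:
    def __init__(self):
        self.events = []

    def after_cancel_batch(self, runner):
        self.events.append(("after_cancel_batch", runner.batch_idx))

    def after_batch(self, runner):
        self.events.append(("after_batch", runner.batch_idx))


class Runner:
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.processed = []

    def _fire(self, event):
        # Call `event` on every callback that implements it.
        for cb in self.callbacks:
            method = getattr(cb, event, None)
            if method is not None:
                method(self)

    def one_batch(self, batch_idx):
        self.batch_idx = batch_idx
        try:
            self._fire("before_batch")
            self.processed.append(batch_idx)  # stand-in for the real work
        except CancelBatchException:
            # Reached immediately after the exception ...
            self._fire("after_cancel_batch")
        finally:
            # ... before proceeding to after_batch, which always runs.
            self._fire("after_batch")


recorder = Recorder()
runner = Runner([SkipOddBatches(), recorder])
for i in range(3):
    runner.one_batch(i)
```

In this sketch batch 1 is skipped, so `runner.processed` holds only batches 0 and 2, while `recorder.events` shows `after_cancel_batch` firing just before `after_batch` for the cancelled batch.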

Learner

Learner is a basic class that handles the training loop of a PyTorch model and uses a system of callbacks that makes the training loop highly customizable and extensible. You just need to provide a list of callbacks and callback functions.

Attributes:

  • model (Module): PyTorch model.
  • dls (DataLoaders): Training and validation data loaders.
  • n_inp (int): Number of inputs to the model.
  • loss_func (Callable[[tensor, tensor], tensor]): Loss function.
  • opt_func (Optimizer): Optimizer function/class.
  • lr (float): Learning rate.
  • splitter (Callable[[Module], Iterable[Parameter]]): Function to split the model's parameters into groups.
  • path (Path): Path to save all artifacts.
  • model_dir_path (Path): Model directory path.
  • callbacks (Iterable[Callable] | None, default=None): Iterable of callbacks of type Callback.
  • logger (Any): Logger to log metrics. Default is print but is typically modified by callbacks such as ProgressCallback.
  • callbacks (list[Callback]): List of all callbacks used by the learner. TrainEvalCallback is added by the learner, so there is no need to add it.

__init__(model, dls, n_inp=1, loss_func=F.mse_loss, opt_func=opt.SGD, lr=0.01, splitter=params_getter, path='.', model_dir='models', callbacks=None)

Parameters:

  • model (Module, required): PyTorch model.
  • dls (DataLoaders, required): Training and validation data loaders.
  • n_inp (int, default 1): Number of inputs to the model.
  • loss_func (Callable[[tensor, tensor], tensor], default F.mse_loss): Loss function.
  • opt_func (Optimizer, default opt.SGD): Optimizer function/class.
  • lr (float, default 1e-2): Learning rate.
  • splitter (Callable[[Module], Iterable[Parameter]], default params_getter): Function to split the model's parameters into groups. By default, all parameters belong to one group.
  • path (str, default "."): Path to save all artifacts.
  • model_dir (str, default "models"): Model directory name relative to path.
  • callbacks (Iterable[Callable] | None, default None): Iterable of callbacks of type Callback.

fit(n_epochs=1, run_train=True, run_valid=True, callbacks=None, lr=None, reset_opt=False)

Fit the model for n_epochs.

Parameters:

  • n_epochs (int, default 1): Number of epochs to train the model.
  • run_train (bool, default True): Whether to run training passes.
  • run_valid (bool, default True): Whether to run validation passes.
  • callbacks (Iterable[Callback] | None, default None): Callbacks to add to the existing callbacks. The added callbacks will be removed before fit returns.
  • lr (float | None, default None): Learning rate. If None, the lr passed to Learner will be used.
  • reset_opt (bool, default False): Whether to reset the optimizer.
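
Two behaviors described for fit, the fallback to the constructor's lr and the removal of temporary callbacks before returning, can be sketched as below. This is an assumption about how such behavior might be structured, not the actual implementation: `ToyLearner` is hypothetical, and callbacks are reduced to plain functions for brevity.

```python
class ToyLearner:
    """Hypothetical stand-in for Learner; callbacks are plain callables here."""

    def __init__(self, lr=0.01, callbacks=None):
        self.lr = lr
        self.callbacks = list(callbacks or [])
        self.log = []

    def fit(self, n_epochs=1, callbacks=None, lr=None):
        # Fall back to the learning rate given at construction when lr is None.
        lr = self.lr if lr is None else lr
        extra = list(callbacks or [])
        self.callbacks.extend(extra)
        try:
            for epoch in range(n_epochs):
                for cb in self.callbacks:
                    self.log.append(cb(epoch, lr))
        finally:
            # Remove the temporary callbacks before fit returns, even on error.
            for cb in extra:
                self.callbacks.remove(cb)


def permanent(epoch, lr):
    return ("permanent", epoch, lr)


def temporary(epoch, lr):
    return ("temporary", epoch, lr)


learner = ToyLearner(lr=0.01, callbacks=[permanent])
learner.fit(n_epochs=1, callbacks=[temporary], lr=0.1)  # temporary callback, explicit lr
learner.fit(n_epochs=1)                                  # only permanent, default lr
```

Using `try`/`finally` is one way to guarantee that the extra callbacks do not leak into later calls if training is interrupted.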

load_model(path=None, with_opt=False, with_epoch=False, with_loss=False)

Load the model and optionally the optimizer, epoch, and the loss.

Parameters:

  • path (str | Path | None, default None): Model's file path. If None, use learner.model_dir_path/model.
  • with_opt (bool, default False): Whether to load the optimizer state.
  • with_epoch (bool, default False): Whether to load the current epoch number.
  • with_loss (bool, default False): Whether to load the current loss.

lr_find(start_lr=1e-07, gamma=1.3, num_iter=100, stop_div=True, max_mult=4)

Try different learning rates on an exponential schedule to help pick the best learning rate, following the paper "Cyclical Learning Rates for Training Neural Networks". When done, plot learning rate vs. loss.

Parameters:

  • start_lr (float, default 1e-7): Start learning rate.
  • gamma (float, default 1.3): Multiplicative factor applied to the learning rate at each iteration.
  • num_iter (int, default 100): Number of iterations to run the training.
  • stop_div (bool, default True): Whether to stop training if the loss diverges.
  • max_mult (int, default 4): Divergence threshold. If loss >= max_mult * minimum loss, stop training.
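
The interaction between start_lr, gamma, num_iter, stop_div, and max_mult can be sketched without a model. The example below assumes the learning rate at iteration i is start_lr * gamma**i and replaces real training with a hypothetical `toy_loss` function of the learning rate; `lr_schedule` and `toy_lr_find` are illustrative names, and the plotting step is omitted.

```python
import math


def lr_schedule(start_lr=1e-7, gamma=1.3, num_iter=100):
    """Yield start_lr * gamma**i for i in range(num_iter)."""
    lr = start_lr
    for _ in range(num_iter):
        yield lr
        lr *= gamma


def toy_lr_find(loss_at, start_lr=1e-7, gamma=1.3, num_iter=100,
                stop_div=True, max_mult=4):
    """Record (lr, loss) pairs, stopping early if the loss diverges."""
    lrs, losses = [], []
    min_loss = float("inf")
    for lr in lr_schedule(start_lr, gamma, num_iter):
        loss = loss_at(lr)  # stand-in for one training iteration at this lr
        lrs.append(lr)
        losses.append(loss)
        min_loss = min(min_loss, loss)
        # Divergence check: stop once loss >= max_mult * minimum loss seen.
        if stop_div and loss >= max_mult * min_loss:
            break
    return lrs, losses


def toy_loss(lr):
    # Hypothetical loss curve with its minimum near lr = 1e-2.
    return (math.log10(lr) + 2) ** 2 + 1


lrs, losses = toy_lr_find(toy_loss)
```

With this toy curve the search stops before using all 100 iterations, because the loss eventually rises above four times its minimum. Plotting `lrs` against `losses` on a logarithmic x-axis would give the kind of curve the method describes.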

save_model(path=None, with_opt=False, with_epoch=False, with_loss=False, pickle_protocol=pickle.HIGHEST_PROTOCOL)

Save the model and optionally the optimizer, epoch, and the loss. Useful for checkpointing.

Parameters:

  • path (str | Path | None, default None): File path to save the model. If None, use learner.model_dir_path/model.
  • with_opt (bool, default False): Whether to save the optimizer state.
  • with_epoch (bool, default False): Whether to save the current epoch number.
  • with_loss (bool, default False): Whether to save the current loss.
  • pickle_protocol (int, default pickle.HIGHEST_PROTOCOL): Protocol used by the pickler when saving the checkpoint.
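
The checkpointing idea behind save_model and load_model, a model state stored together with optional optimizer state, epoch, and loss, can be sketched with the standard library alone. This is a simplified sketch under stated assumptions: plain dictionaries stand in for model and optimizer state, `save_checkpoint` and `load_checkpoint` are hypothetical helpers, and the actual on-disk layout used by Learner is not specified in this page.

```python
import pickle
import tempfile
from pathlib import Path


def save_checkpoint(path, model_state, opt_state=None, epoch=None, loss=None,
                    pickle_protocol=pickle.HIGHEST_PROTOCOL):
    """Write the model state plus any optional extras to `path`."""
    checkpoint = {"model": model_state}
    # Only store the optional pieces that were actually provided.
    for key, value in (("opt", opt_state), ("epoch", epoch), ("loss", loss)):
        if value is not None:
            checkpoint[key] = value
    with open(path, "wb") as f:
        pickle.dump(checkpoint, f, protocol=pickle_protocol)


def load_checkpoint(path):
    """Read back the dictionary written by save_checkpoint."""
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = Path(tmp) / "model"
    save_checkpoint(ckpt_path, {"w": [0.5, -1.2]}, epoch=3, loss=0.42)
    restored = load_checkpoint(ckpt_path)
```

Storing the epoch and loss next to the weights is what makes a checkpoint useful for resuming training rather than only for inference.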

show_batch(sample_sz=1, callbacks=None, **kwargs)

Show sample_sz input samples from a batch. The input shown is what the model would see when making predictions, so all transformations and augmentations are applied to it.

Parameters:

  • sample_sz (int, default 1): Number of input samples to show.
  • callbacks (Iterable[Callback] | None, default None): Callbacks to add to the existing callbacks. The added callbacks will be removed before show_batch returns.

Raises:

  • NotImplementedError: Each type of Learner must implement its own version, depending on the type of input data. For example, a VisionLearner would show images.
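
The pattern described here, a base method that raises NotImplementedError and a subclass that overrides it for its data type, might look like the sketch below. `ToyLearner` and `ToyTextLearner` are hypothetical classes written for this illustration, and the subclass returns the samples instead of displaying them so the behavior is easy to check.

```python
class ToyLearner:
    """Hypothetical base class: it does not know how to display its inputs."""

    def show_batch(self, sample_sz=1, **kwargs):
        raise NotImplementedError("Subclasses must implement show_batch.")


class ToyTextLearner(ToyLearner):
    """Hypothetical subclass for text inputs."""

    def __init__(self, batch):
        self.batch = batch

    def show_batch(self, sample_sz=1, **kwargs):
        # Return the first `sample_sz` inputs of the batch.
        return self.batch[:sample_sz]


samples = ToyTextLearner(["first sentence", "second sentence"]).show_batch(sample_sz=1)
```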

summary(verbose=2, **kwargs)

Print the model summary using the torchinfo package.

params_getter(model)

Get all parameters of model recursively.

Parameters:

  • model (Module, required): Model.

Yields:

  • Parameter: Module parameter.
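
Recursively yielding parameters from a module tree can be illustrated with a generator. The sketch below does not use PyTorch: `ToyModule` is a hypothetical container with its own parameters and child modules, and `toy_params_getter` only mirrors the idea of walking a model recursively and yielding each parameter once.

```python
class ToyModule:
    """Hypothetical module holding its own parameters and child modules."""

    def __init__(self, params=(), children=()):
        self.params = list(params)
        self.children = list(children)


def toy_params_getter(module):
    """Yield the module's own parameters, then those of each child, depth first."""
    yield from module.params
    for child in module.children:
        yield from toy_params_getter(child)


model = ToyModule(
    params=["w0"],
    children=[
        ToyModule(params=["w1", "b1"]),
        ToyModule(children=[ToyModule(params=["w2"])]),
    ],
)
all_params = list(toy_params_getter(model))
```

Because the function is a generator, a splitter built on it can consume parameters lazily or collect them into a single group, which matches the default of one parameter group.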