Learner
`Learner` is a basic class that provides useful functionality:

- Tweaking/customizing the training loop through a system of callbacks driven by exceptions
- Loading/saving the model
- Fitting the model
- Getting a model summary
- Finding a good learning rate
The training loop consists of a minimal set of instructions; looping through the data, we:

- Compute the output of the model from the input
- Calculate the loss between this output and the desired target
- Compute the gradients of this loss with respect to the model parameters
- Update the parameters accordingly
- Zero all the gradients
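In plain PyTorch, these five steps amount to the following sketch (`model`, `opt`, and `train_dl` are hypothetical stand-ins for a model, optimizer, and training data loader, not names from this API):

```python
import torch.nn.functional as F

def one_epoch(model, opt, train_dl):
    for xb, yb in train_dl:
        pred = model(xb)             # 1. compute the output of the model
        loss = F.mse_loss(pred, yb)  # 2. calculate the loss vs. the target
        loss.backward()              # 3. compute gradients of the loss
        opt.step()                   # 4. update the parameters
        opt.zero_grad()              # 5. zero all the gradients
```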
Any tweak of this training loop is defined in a `Callback`. A callback can implement actions on the following events:

`before_fit`
: called before doing anything; ideal for initial setup

`before_epoch`
: called at the beginning of each epoch; useful for any behavior you need to reset at each epoch

`before_train`
: called at the beginning of the training part of an epoch

`before_batch`
: called at the beginning of each batch, just after drawing said batch. It can be used for any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes into the model (changing the input with techniques like mixup, for instance)

`after_pred`
: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss function

`after_loss`
: called after the loss has been computed, but before the backward pass. It can be used to add a penalty to the loss (AR or TAR in RNN training, for instance)

`after_cancel_backward`
: reached immediately after `CancelBackwardException`

`after_backward`
: called after the backward pass, but before updating the parameters. It can be used to change the gradients before the update (gradient clipping, for instance)

`after_cancel_step`
: reached immediately after `CancelStepException`

`after_step`
: called after the step, before the gradients are zeroed

`after_cancel_batch`
: reached immediately after `CancelBatchException`, before proceeding to `after_batch`

`after_batch`
: called at the end of a batch, for any clean-up before the next one

`after_cancel_train`
: reached immediately after `CancelTrainException`, before proceeding to `after_train`

`after_train`
: called at the end of the training phase of an epoch

`before_validate`
: called at the beginning of the validation phase of an epoch; useful for any setup needed specifically for validation

`after_cancel_validate`
: reached immediately after `CancelValidateException`, before proceeding to `after_validate`

`after_validate`
: called at the end of the validation phase of an epoch

`after_cancel_epoch`
: reached immediately after `CancelEpochException`, before proceeding to `after_epoch`

`after_epoch`
: called at the end of an epoch, for any clean-up before the next one

`after_cancel_fit`
: reached immediately after `CancelFitException`, before proceeding to `after_fit`

`after_fit`
: called at the end of training, for any final clean-up
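As a concrete example, a gradient-clipping tweak needs only the `after_backward` event, and raising a cancellation exception lets a callback cut a phase short. The sketch below is illustrative only: it assumes a `Callback` base class whose methods are named after the events above and a `learner` attribute for reaching training state; the import path and attribute names are assumptions, not confirmed API.

```python
import torch
# Hypothetical import path; adjust to wherever Callback and the
# cancellation exceptions live in your installation.
from learner import Callback, CancelFitException

class GradClip(Callback):
    """Clip gradients right after the backward pass, before the optimizer step."""
    def __init__(self, max_norm=1.0):
        self.max_norm = max_norm

    def after_backward(self):
        torch.nn.utils.clip_grad_norm_(self.learner.model.parameters(), self.max_norm)

class StopEarly(Callback):
    """Cancel the whole fit once the loss is small enough."""
    def after_epoch(self):
        if self.learner.loss < 1e-3:    # `learner.loss` is an assumed attribute
            raise CancelFitException()  # control jumps to after_cancel_fit, then after_fit
```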
Learner
`Learner` is a basic class that handles the training loop of a PyTorch model and utilizes a system of callbacks that makes the training loop very customizable and extensible. You just need to provide a list of callbacks and callback functions.
Attributes:
| Name | Type | Description |
|---|---|---|
| `model` | `Module` | PyTorch model. |
| `dls` | `DataLoaders` | Train and valid data loaders. |
| `n_inp` | `int` | Number of inputs to the model. |
| `loss_func` | `Callable[[tensor, tensor], tensor]` | Loss function. |
| `opt_func` | `Optimizer` | Optimizer function/class. |
| `lr` | `float` | Learning rate. |
| `splitter` | `Callable[[Module], Iterable[Parameter]]` | Function to split the model's parameters into groups. |
| `path` | `Path` | Path to save all artifacts. |
| `model_dir_path` | `Path` | Model directory path. |
| `callbacks` | `Iterable[Callable] \| None`, default=`None` | Iterable of callbacks of type `Callback`. |
| `logger` | `Any` | Logger to log metrics. Default is `print`. |
| `callbacks` | `list[Callback]` | List of all the callbacks used by the `Learner`. |
__init__(model, dls, n_inp=1, loss_func=F.mse_loss, opt_func=opt.SGD, lr=0.01, splitter=params_getter, path='.', model_dir='models', callbacks=None)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | PyTorch model. | required |
| `dls` | `DataLoaders` | Train and valid data loaders. | required |
| `n_inp` | `int` | Number of inputs to the model. | `1` |
| `loss_func` | `Callable[[tensor, tensor], tensor]` | Loss function. | `F.mse_loss` |
| `opt_func` | `Optimizer` | Optimizer function/class. | `opt.SGD` |
| `lr` | `float` | Learning rate. | `1e-2` |
| `splitter` | `Callable[[Module], Iterable[Parameter]]` | Function to split the model's parameters into groups; by default, all parameters belong to one group. | `params_getter` |
| `path` | `str` | Path to save all artifacts. | `"."` |
| `model_dir` | `str` | Model directory name relative to `path`. | `"models"` |
| `callbacks` | `Iterable[Callable] \| None` | Iterable of callbacks of type `Callback`. | `None` |
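Putting the constructor together on a toy regression problem (the import path is hypothetical, and `DataLoaders` is assumed to simply wrap a train and a valid loader):

```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
# Hypothetical import path for this library's classes.
from learner import Learner, DataLoaders

x = torch.randn(256, 10)
y = x @ torch.randn(10, 1)
ds = TensorDataset(x, y)
dls = DataLoaders(DataLoader(ds, batch_size=32, shuffle=True),
                  DataLoader(ds, batch_size=64))

learner = Learner(model=nn.Linear(10, 1), dls=dls,
                  loss_func=F.mse_loss, lr=1e-2)
```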
fit(n_epochs=1, run_train=True, run_valid=True, callbacks=None, lr=None, reset_opt=False)
Fit the model for `n_epochs`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `n_epochs` | `int` | Number of epochs to train the model. | `1` |
| `run_train` | `bool` | Whether to run training passes. | `True` |
| `run_valid` | `bool` | Whether to run validation passes. | `True` |
| `callbacks` | `Iterable[Callback] \| None` | Callbacks to add to the existing callbacks; the added callbacks are removed before `fit` returns. | `None` |
| `lr` | `float \| None` | Learning rate. If `None`, use `self.lr`. | `None` |
| `reset_opt` | `bool` | Whether to reset the optimizer. | `False` |
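Continuing the toy setup above, a plain call trains and validates with the stored settings, while the keyword arguments cover one-off variations:

```python
learner.fit(n_epochs=3)                   # train + validate for 3 epochs
learner.fit(n_epochs=1, lr=1e-3)          # one-off learning rate; self.lr is unchanged
learner.fit(n_epochs=1, run_train=False)  # validation passes only
```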
load_model(path=None, with_opt=False, with_epoch=False, with_loss=False)
Load the model and optionally the optimizer state, epoch number, and loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path \| None` | Model's file path. If `None`, a default path under `model_dir_path` is used. | `None` |
| `with_opt` | `bool` | Whether to load the optimizer state. | `False` |
| `with_epoch` | `bool` | Whether to load the current epoch number. | `False` |
| `with_loss` | `bool` | Whether to load the current loss. | `False` |
lr_find(start_lr=1e-07, gamma=1.3, num_iter=100, stop_div=True, max_mult=4)
Try different learning rates using an exponential schedule to help pick the best learning rate, following *Cyclical Learning Rates for Training Neural Networks*. When done, plot learning rate vs. loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `start_lr` | `float` | Starting learning rate. | `1e-7` |
| `gamma` | `float` | Multiplicative factor applied to the learning rate at each iteration (values > 1 grow the rate). | `1.3` |
| `num_iter` | `int` | Number of iterations to run the training. | `100` |
| `stop_div` | `bool` | Whether to stop training if the loss diverges. | `True` |
| `max_mult` | `int` | Divergence threshold: stop training if loss >= `max_mult` * minimum loss. | `4` |
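A typical sweep uses the defaults; once it finishes (or the loss diverges), the learning-rate-vs.-loss plot helps pick a value somewhat below the loss minimum:

```python
# Grow the learning rate from 1e-7 by a factor of 1.3 per iteration for up to
# 100 iterations, stopping early if the loss exceeds 4x the best loss so far.
learner.lr_find(start_lr=1e-7, gamma=1.3, num_iter=100, stop_div=True, max_mult=4)
```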
save_model(path=None, with_opt=False, with_epoch=False, with_loss=False, pickle_protocol=pickle.HIGHEST_PROTOCOL)
Save the model and optionally the optimizer state, epoch number, and loss. Useful for checkpointing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path \| None` | File path to save the model. If `None`, a default path under `model_dir_path` is used. | `None` |
| `with_opt` | `bool` | Whether to save the optimizer state. | `False` |
| `with_epoch` | `bool` | Whether to save the current epoch number. | `False` |
| `with_loss` | `bool` | Whether to save the current loss. | `False` |
| `pickle_protocol` | `int` | Protocol used by the pickler when saving the checkpoint. | `pickle.HIGHEST_PROTOCOL` |
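A checkpoint round trip might look like the sketch below (the file path is illustrative; omitting it falls back to the default under `model_dir_path`):

```python
# Save the model plus optimizer state and epoch number mid-training...
learner.save_model(path="models/ckpt.pt", with_opt=True, with_epoch=True)

# ...and restore all three later to resume where training left off.
learner.load_model(path="models/ckpt.pt", with_opt=True, with_epoch=True)
```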
show_batch(sample_sz=1, callbacks=None, **kwargs)
Show a batch of `sample_sz` inputs. The input is what the model would see when making predictions; therefore, all transformations and other augmentations will be applied to it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sample_sz` | `int` | Number of input samples to show. | `1` |
| `callbacks` | `Iterable[Callback] \| None` | Callbacks to add to the existing callbacks; the added callbacks are removed before `show_batch` returns. | `None` |
Raises:
| Type | Description |
|---|---|
| `NotImplementedError` | Different types of input require their own implementation of `show_batch`. |
summary(verbose=2, **kwargs)
Use the `torchinfo` package to print out the model summary.
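For example (any extra keyword arguments are presumably forwarded to `torchinfo`):

```python
learner.summary(verbose=2)
```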
params_getter(model)
Get all parameters of `model` recursively.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model. | required |
Yields:
| Type | Description |
|---|---|
| `Parameter` | Module parameter. |
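A quick check of the default splitter's behavior (the import path is hypothetical):

```python
from torch import nn
# Hypothetical import path for this library's helper.
from learner import params_getter

model = nn.Sequential(nn.Linear(10, 4), nn.ReLU(), nn.Linear(4, 1))

# Yields every parameter in the module tree: two weight matrices and two biases.
for p in params_getter(model):
    print(p.shape)
```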