Learner
Learner module for training PyTorch models with a callback system.

This module provides a Learner class that handles the training loop of PyTorch models using a customizable callback system. The training loop consists of a minimal set of instructions that can be extended and customized through callbacks.
The basic training loop iterates through data and:
- Computes the output of the model from the input
- Calculates a loss between this output and the desired target
- Computes the gradients of this loss with respect to model parameters
- Updates the parameters accordingly
- Zeros all the gradients
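In plain PyTorch, one iteration of this loop looks roughly like the sketch below, where model, loss_func, opt, and the batch (xb, yb) are assumed to be already defined:

    preds = model(xb)            # compute the output of the model from the input
    loss = loss_func(preds, yb)  # calculate a loss between output and target
    loss.backward()              # compute gradients of the loss w.r.t. parameters
    opt.step()                   # update the parameters
    opt.zero_grad()              # zero all the gradients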
Any customization of this training loop is defined in a Callback object. A callback can implement actions on the following events:
before_fit
: Called before doing anything; ideal for initial setup.

before_epoch
: Called at the beginning of each epoch; useful for any behavior you need to reset at each epoch.

before_train
: Called at the beginning of the training part of an epoch.

before_batch
: Called at the beginning of each batch, just after drawing said batch. It can be used for any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes into the model (e.g., changing the input with techniques like mixup).

after_pred
: Called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss function.

after_loss
: Called after the loss has been computed, but before the backward pass. It can be used to add a penalty to the loss (AR or TAR in RNN training, for instance).

after_cancel_backward
: Reached immediately after CancelBackwardException.

after_backward
: Called after the backward pass, but before updating the parameters. It can be used to change the gradients before any update (gradient clipping, for instance).

after_cancel_step
: Reached immediately after CancelStepException.

after_step
: Called after the step and before the gradients are zeroed.

after_cancel_batch
: Reached immediately after CancelBatchException, before proceeding to after_batch.

after_batch
: Called at the end of a batch, for any clean-up before the next one.

after_cancel_train
: Reached immediately after CancelTrainException, before proceeding to after_train.

after_train
: Called at the end of the training phase of an epoch.

before_validate
: Called at the beginning of the validation phase of an epoch; useful for any setup needed specifically for validation.

after_cancel_validate
: Reached immediately after CancelValidateException, before proceeding to after_validate.

after_validate
: Called at the end of the validation phase of an epoch.

after_cancel_epoch
: Reached immediately after CancelEpochException, before proceeding to after_epoch.

after_epoch
: Called at the end of an epoch, for any clean-up before the next one.

after_cancel_fit
: Reached immediately after CancelFitException, before proceeding to after_fit.

after_fit
: Called at the end of training, for any final clean-up.
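As a hypothetical illustration, a callback that clips gradients during after_backward might look like the sketch below. The attribute access (self.model) is an assumption about the Callback base class in cmn_ai.callbacks.core, not something this documentation guarantees:

    import torch.nn as nn
    from cmn_ai.callbacks.core import Callback

    class GradientClipCallback(Callback):
        # Hypothetical callback: clip gradient norms after the backward
        # pass but before the optimizer step.
        def after_backward(self):
            # Assumes the callback can reach the learner's model as self.model.
            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)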
Classes:

- Learner – Main class for training PyTorch models with a callback system.

Functions:

- params_getter – Get all parameters of a model recursively.
Examples:
Basic usage:
>>> from cmn_ai.learner import Learner
>>> from cmn_ai.utils.data import DataLoaders
>>> import torch.nn as nn
>>>
>>> # Create a simple model and data loaders
>>> model = nn.Linear(10, 1)
>>> dls = DataLoaders(train_dl, valid_dl)
>>>
>>> # Create learner
>>> learner = Learner(model, dls, loss_func=nn.MSELoss())
>>>
>>> # Train the model
>>> learner.fit(n_epochs=10)
Learning rate finding:
>>> learner.lr_find(start_lr=1e-6, num_iter=200)
Model checkpointing:
>>> learner.save_model("checkpoint.pt", with_opt=True, with_epoch=True)
>>> learner.load_model("checkpoint.pt", with_opt=True)
Notes
The TrainEvalCallback is automatically added to all learners and doesn't need to be provided manually. This callback handles the basic training and validation loop management.

See Also
cmn_ai.callbacks.core.Callback : Base callback class
cmn_ai.callbacks.training.TrainEvalCallback : Default training callback
cmn_ai.utils.data.DataLoaders : Data loader container
Learner
A learner class that handles the training loop of PyTorch models.
This class provides a customizable training loop using a callback system. It handles model training, validation, saving/loading, and learning rate finding. The training process can be customized through various callback events.
Attributes:

- model (Module) – The PyTorch model to train.
- dls (DataLoaders) – Training and validation data loaders.
- n_inp (int) – Number of inputs to the model.
- loss_func (Callable[[tensor, tensor], tensor]) – Loss function that takes predictions and targets.
- opt_func (Optimizer) – Optimizer class (not instance).
- lr (float) – Learning rate for training.
- splitter (Callable[[Module], Iterable[Parameter]]) – Function to split the model's parameters into groups.
- path (Path) – Base path for saving artifacts.
- model_dir_path (Path) – Directory path for saving models.
- callbacks (list[Callback]) – List of all callbacks used by the learner.
- logger (Any) – Logger for metrics. Default is print, but it is typically modified by callbacks such as ProgressCallback.
Examples:
>>> from cmn_ai.learner import Learner
>>> from cmn_ai.utils.data import DataLoaders
>>> import torch.nn as nn
>>>
>>> model = nn.Linear(10, 1)
>>> dls = DataLoaders(train_dl, valid_dl)
>>> learner = Learner(model, dls, loss_func=nn.MSELoss())
>>> learner.fit(n_epochs=5)
training (property, writable)
Get the training mode of the model.

Returns:

- bool – True if the model is in training mode, False otherwise.
__init__(model, dls, n_inp=1, loss_func=F.mse_loss, opt_func=opt.SGD, lr=0.01, splitter=params_getter, path='.', model_dir='models', callbacks=None)
Initialize the Learner.
Parameters:

- model (Module) – The PyTorch model to train.
- dls (DataLoaders) – Training and validation data loaders.
- n_inp (int, default: 1) – Number of inputs to the model.
- loss_func (Callable[[tensor, tensor], tensor], default: F.mse_loss) – Loss function that takes predictions and targets.
- opt_func (Optimizer, default: opt.SGD) – Optimizer class (not instance).
- lr (float, default: 0.01) – Learning rate for training.
- splitter (Callable[[Module], Iterable[Parameter]], default: params_getter) – Function to split the model's parameters into groups.
- path (str | Path, default: ".") – Base path for saving artifacts.
- model_dir (str, default: "models") – Model directory name relative to path.
- callbacks (Iterable[Callback] | None, default: None) – Initial callbacks to add to the learner.
Notes
The TrainEvalCallback is automatically added to the callbacks list and doesn't need to be provided manually.
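Given the splitter signature Callable[[Module], Iterable[Parameter]], any function that takes the model and returns an iterable of parameters can be passed. A minimal hypothetical example that yields only trainable parameters (note the exact shape the optimizer setup expects, flat iterable vs. parameter groups, is an assumption based on that signature):

    import torch.nn as nn

    def trainable_params(model: nn.Module):
        # Hypothetical splitter: yield only parameters that require
        # gradients, e.g., to skip frozen layers during optimization.
        return (p for p in model.parameters() if p.requires_grad)

    learner = Learner(model, dls, splitter=trainable_params)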
fit(n_epochs=1, run_train=True, run_valid=True, callbacks=None, lr=None, reset_opt=False)
Fit the model for a specified number of epochs.
Parameters:

- n_epochs (int, default: 1) – Number of epochs to train the model.
- run_train (bool, default: True) – Whether to run training passes.
- run_valid (bool, default: True) – Whether to run validation passes.
- callbacks (Iterable[Callback] | None, default: None) – Additional callbacks to add temporarily for this fit call. These callbacks are removed after training completes.
- lr (float | None, default: None) – Learning rate to use. If None, uses the learner's default lr.
- reset_opt (bool, default: False) – Whether to reset the optimizer.
Examples:
>>> learner.fit(n_epochs=10, lr=0.001)
>>> learner.fit(n_epochs=5, run_valid=False)
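Temporary callbacks can be attached for a single call; for instance, using the hypothetical GradientClipCallback sketched earlier:

>>> learner.fit(n_epochs=3, callbacks=[GradientClipCallback()])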
load_model(path=None, with_opt=False, with_epoch=False, with_loss=False)
Load the model and optionally optimizer state, epoch, and loss.
Parameters:

- path (str | Path | None, default: None) – Model's file path. If None, uses learner.model_dir_path/model.
- with_opt (bool, default: False) – Whether to load the optimizer state.
- with_epoch (bool, default: False) – Whether to load the current epoch number.
- with_loss (bool, default: False) – Whether to load the current loss value.
Examples:
>>> learner.load_model("checkpoint.pt", with_opt=True)
>>> learner.load_model() # Uses default path
lr_find(start_lr=1e-07, gamma=1.3, num_iter=100, stop_div=True, max_mult=4)
Find optimal learning rate using exponential schedule.
This method implements the learning rate finder described in "Cyclical Learning Rates for Training Neural Networks" (Smith, 2015). It tries different learning rates using an exponential schedule and plots learning rate vs. loss to help identify the optimal learning rate.
Parameters:

- start_lr (float, default: 1e-7) – Starting learning rate for the search.
- gamma (float, default: 1.3) – Multiplicative factor for the learning rate increase.
- num_iter (int, default: 100) – Number of iterations to run the learning rate search.
- stop_div (bool, default: True) – Whether to stop training if the loss diverges.
- max_mult (int, default: 4) – Divergence threshold. If loss >= max_mult * minimum loss, training stops.
Examples:
>>> learner.lr_find(start_lr=1e-6, num_iter=200)
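Under this schedule, the learning rate at iteration i is start_lr * gamma**i (assuming, per the gamma description above, that the rate is multiplied by gamma once per iteration). A quick sketch of the resulting values:

    start_lr, gamma, num_iter = 1e-7, 1.3, 100
    lrs = [start_lr * gamma**i for i in range(num_iter)]
    # lrs[0] == 1e-7; each value is 1.3x the previous one, so the sweep
    # grows rapidly until stop_div halts it or num_iter is exhausted.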
save_model(path=None, with_opt=False, with_epoch=False, with_loss=False, pickle_protocol=pickle.HIGHEST_PROTOCOL)
Save the model and optionally optimizer state, epoch, and loss.
This method is useful for checkpointing during training. It saves the model state dict and optionally includes optimizer state, current epoch, and current loss.
Parameters:

- path (str | Path | None, default: None) – File path to save the model. If None, uses learner.model_dir_path/model.
- with_opt (bool, default: False) – Whether to save the optimizer state.
- with_epoch (bool, default: False) – Whether to save the current epoch number.
- with_loss (bool, default: False) – Whether to save the current loss value.
- pickle_protocol (int, default: pickle.HIGHEST_PROTOCOL) – Protocol used by the pickler when saving the checkpoint.
Examples:
>>> learner.save_model("checkpoint.pt", with_opt=True, with_epoch=True)
>>> learner.save_model() # Uses default path
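A typical checkpoint-and-resume round trip combines save_model with load_model and then continues training:

>>> learner.save_model("checkpoint.pt", with_opt=True, with_epoch=True)
>>> # ... later, with the same learner setup ...
>>> learner.load_model("checkpoint.pt", with_opt=True, with_epoch=True)
>>> learner.fit(n_epochs=5)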
show_batch(sample_sz=1, callbacks=None, **kwargs)
Show a sample batch of input data.
This method displays what the model would see when making predictions, including all transformations and augmentations applied to the input.
Parameters:

- sample_sz (int, default: 1) – Number of input samples to show.
- callbacks (Iterable[Callback] | None, default: None) – Additional callbacks to add temporarily for this operation. These callbacks are removed after the operation completes.
- **kwargs (Any) – Additional arguments passed to the show_batch implementation.
Raises:

- NotImplementedError – Different types of Learners must implement their own version depending on the type of input data. For example, VisionLearners would show images.
Examples:
>>> learner.show_batch(sample_sz=3)
summary(verbose=2, **kwargs)
Generate and display model summary using torchinfo.
Parameters:

- verbose (int, default: 2) – Verbosity level for the summary output.
- **kwargs (Any) – Additional arguments passed to torchinfo.summary.
Returns:

- Any – The summary object returned by torchinfo.
Examples:
>>> learner.summary(verbose=1)
>>> learner.summary(col_names=["input_size", "output_size"])
params_getter(model)
Get all parameters of a model recursively.
Parameters:

- model (Module) – The PyTorch model to extract parameters from.
Yields:

- Parameter – Each parameter of the model.
Examples:
>>> model = nn.Linear(10, 1)
>>> params = list(params_getter(model))
>>> len(params)
2
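A minimal sketch of what such a getter can look like (the library's actual implementation may differ):

    import torch.nn as nn

    def params_getter(model: nn.Module):
        # Recursively yield every parameter of the model;
        # nn.Module.parameters() already walks all submodules.
        yield from model.parameters()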