Training
Training-related callbacks for customizing training and validation loops.
This module provides a comprehensive collection of callbacks that enhance and customize the training process. These callbacks handle various aspects of training including device management, progress tracking, metrics computation, data augmentation, and learning rate optimization.
The callbacks are designed to be easily composable and can be used together to create sophisticated training pipelines. Each callback has a specific execution order to ensure proper sequencing of operations.
Classes:
- DeviceCallback – Move model and data to specified device (CPU/GPU).
- TrainEvalCallback – Track training progress and manage training/evaluation modes.
- ProgressCallback – Display progress bars and live loss plots during training.
- Recorder – Record training metrics and optimizer parameters for analysis.
- ModelResetter – Reset model parameters at various training stages.
- LRFinder – Find optimal learning rate using exponential scheduling.
- BatchTransform – Apply transformations to entire batches.
- BatchTransformX – Apply transformations to input features only.
- SingleBatchCallback – Run only one batch for debugging purposes.
- SingleBatchForwardCallback – Run one batch and stop after forward pass.
- MetricsCallback – Compute and log various metrics during training.
- Mixup – Implement mixup data augmentation technique.
Examples:
Basic training with device management and progress tracking:
>>> from cmn_ai.callbacks.training import DeviceCallback, ProgressCallback
>>> from cmn_ai.callbacks.training import TrainEvalCallback, Recorder
>>>
>>> # Create callbacks
>>> callbacks = [
... DeviceCallback(device='cuda'), # Move to GPU
... TrainEvalCallback(), # Track progress
... ProgressCallback(plot=True), # Show progress bars
... Recorder('lr', 'momentum') # Record learning rate and momentum
... ]
>>>
>>> # Add to learner
>>> learner.add_cbs(callbacks)
Learning rate finding:
>>> from cmn_ai.callbacks.training import LRFinder
>>>
>>> # Find optimal learning rate
>>> lr_finder = LRFinder(gamma=1.3, num_iter=100, stop_div=True)
>>> learner.add_cb(lr_finder)
>>> learner.fit(1) # Run for 1 epoch
>>> lr_finder.recorder.plot() # Plot lr vs loss
Data augmentation with mixup:
>>> from cmn_ai.callbacks.training import Mixup
>>>
>>> # Add mixup augmentation
>>> mixup = Mixup(alpha=0.4)
>>> learner.add_cb(mixup)
>>> learner.fit(10)
Metrics tracking:
>>> from torcheval.metrics import MulticlassAccuracy
>>> from cmn_ai.callbacks.training import MetricsCallback
>>>
>>> # Track accuracy during training
>>> accuracy = MulticlassAccuracy()
>>> metrics_cb = MetricsCallback(accuracy=accuracy)
>>> learner.add_cb(metrics_cb)
>>> learner.fit(5)
Debugging with single batch:
>>> from cmn_ai.callbacks.training import SingleBatchCallback
>>>
>>> # Run only one batch for debugging
>>> debug_cb = SingleBatchCallback()
>>> learner.add_cb(debug_cb)
>>> learner.fit(1) # Will stop after first batch
Raises:
- CancelFitException – Stop training and move to after_fit.
- CancelEpochException – Stop current epoch and move to after_epoch.
- CancelTrainException – Stop training current epoch and move to after_train.
- CancelValidException – Stop validation phase and move to after_validate.
- CancelBatchException – Stop current batch and move to after_batch.
- CancelStepException – Skip stepping the optimizer.
- CancelBackwardException – Skip the backward pass and move to after_backward.
Notes
- Callbacks are executed in order based on their order attribute; lower order numbers execute earlier.
- Some callbacks modify the training loop behavior significantly.
- Always test callbacks individually before combining them.
- The Recorder callback is essential for post-training analysis.
- LRFinder should be used before full training to find the optimal learning rate.
BatchTransform
Bases: Callback
Apply transformations to entire batches before processing.
This callback applies a transformation function to the entire batch before it's processed by the model. The transformation runs on the device where the data is located.
Attributes:
- order (int) – Callback execution order (DeviceCallback.order + 1).
- tfm (Callable) – Transformation function to apply.
- on_train (bool) – Whether to apply transformation during training.
- on_valid (bool) – Whether to apply transformation during validation.
__init__(tfm, on_train=True, on_valid=True)
Initialize BatchTransform.
Parameters:
- tfm (Callable) – Transformation function to apply on the batch.
- on_train (bool, default: True) – Whether to apply the transformation during training.
- on_valid (bool, default: True) – Whether to apply the transformation during validation.
before_batch()
Apply transformation to batch if conditions are met.
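For example, a per-batch normalization (a minimal sketch: normalize_batch is a hypothetical user-defined function, the (inputs, targets) batch format is an assumption, and learner follows the same API as the module examples above):
>>> from cmn_ai.callbacks.training import BatchTransform
>>>
>>> def normalize_batch(batch):
...     xb, yb = batch  # assumes the batch is an (inputs, targets) tuple
...     return (xb - xb.mean()) / xb.std(), yb
>>>
>>> learner.add_cb(BatchTransform(normalize_batch, on_train=True, on_valid=True))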
BatchTransformX
Bases: Callback
Apply transformations to input features (X) only.
This callback applies a transformation function specifically to the input features (X) of the batch, leaving the targets (Y) unchanged.
Attributes:
- order (int) – Callback execution order (DeviceCallback.order + 1).
- tfm (Callable) – Transformation function to apply.
- on_train (bool) – Whether to apply transformation during training.
- on_valid (bool) – Whether to apply transformation during validation.
__init__(tfm, on_train=True, on_valid=True)
Initialize BatchTransformX.
Parameters:
- tfm (Callable) – Transformation function to apply on the input features.
- on_train (bool, default: True) – Whether to apply the transformation during training.
- on_valid (bool, default: True) – Whether to apply the transformation during validation.
before_batch()
Apply transformation to input features if conditions are met.
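For example, adding light input noise during training only (a sketch; add_noise is a hypothetical user-defined function):
>>> import torch
>>> from cmn_ai.callbacks.training import BatchTransformX
>>>
>>> def add_noise(x):
...     return x + 0.01 * torch.randn_like(x)  # Gaussian noise on inputs only; targets untouched
>>>
>>> learner.add_cb(BatchTransformX(add_noise, on_train=True, on_valid=False))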
DeviceCallback
Bases: Callback
Move batch and model to specified device.
This callback ensures that both the model and input data are moved to the specified device (CPU/GPU) before training begins.
Attributes:
- device (str | device) – Device to copy batch and model to.
__init__(device=DEFAULT_DEVICE)
Initialize DeviceCallback.
Parameters:
- device (str | device, default: DEFAULT_DEVICE) – Device to copy batch and model to.
before_batch()
Move batch data to specified device before each batch.
before_fit()
Move model to specified device before training starts.
LRFinder
Bases: Callback
Find optimal learning rate using exponential schedule.
This callback implements the learning rate finder technique from "Cyclical Learning Rates for Training Neural Networks". It tries different learning rates using an exponential schedule to help determine the best learning rate for training.
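Assuming the scheduler steps once per batch, the learning rate after iteration t is lr_0 * gamma^t; with the default gamma=1.3 that is roughly a 13.8x increase every 10 iterations (1.3^10 ≈ 13.8), so the sweep covers many orders of magnitude well within the default 100 iterations.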
Attributes:
- gamma (float) – Multiplicative factor for learning rate increase.
- num_iter (int) – Number of iterations to run the training.
- stop_div (bool) – Whether to stop training if loss diverges.
- max_mult (int) – Divergence threshold multiplier.
- scheduler (ExponentialLR) – Learning rate scheduler.
- best_loss (float) – Best loss encountered during training.
__init__(gamma=1.3, num_iter=100, stop_div=True, max_mult=4)
Initialize LRFinder.
Parameters:
- gamma (float, default: 1.3) – Multiplicative factor for learning rate increase.
- num_iter (int, default: 100) – Number of iterations to run the training.
- stop_div (bool, default: True) – Whether to stop training if loss diverges (loss > max_mult * best_loss).
- max_mult (int, default: 4) – Divergence threshold. If loss >= max_mult * best_loss, training stops.
after_batch()
Update best loss and check for divergence or completion.
after_fit()
Restore model state and plot learning rate vs loss.
before_fit()
Set up learning rate scheduler and save initial model state.
before_validate()
Skip validation during learning rate finding.
MetricsCallback
Bases: Callback
Compute and log metrics during training and validation.
This callback computes various metrics after each training/validation epoch and logs them using the learner's logger. Metrics must implement reset and compute methods. It's recommended to use metrics from the torcheval package or inherit from its Metric base class.
Attributes:
- metrics (dict[str, Any]) – Dictionary of named metrics to compute.
- all_metrics (dict[str, Any]) – All metrics including the loss metric.
- stats (list[str]) – Current statistics to log.
- start_time (float) – Start time for timing calculations.
__init__(*metrics, **named_metrics)
Initialize MetricsCallback.
Parameters:
- *metrics (Any, default: ()) – Positional metrics to add.
- **named_metrics (Any, default: {}) – Named metrics to add.
after_batch()
Update metrics with batch predictions and targets.
after_train()
Compute and log metrics after training epoch.
Returns:
- str – Logged statistics string.
after_validate()
Compute and log metrics after validation epoch.
Returns:
- str – Logged statistics string.
before_fit()
Log metric names as header before training starts.
Returns:
- str – Header string with metric names.
before_train()
Reset metrics before training phase.
before_validate()
Reset metrics before validation phase.
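Besides torcheval metrics, any duck-typed object with these methods works. A minimal sketch of a custom metric, assuming (per the method docs above) that the callback calls update with batch predictions and targets after each batch, reset before each phase, and compute at the end of each phase:
>>> class MeanAbsoluteError:
...     def __init__(self):
...         self.reset()
...     def reset(self):
...         self.total, self.count = 0.0, 0
...     def update(self, preds, targets):
...         # assumed call signature: (predictions, targets) as tensors
...         self.total += (preds - targets).abs().sum().item()
...         self.count += targets.numel()
...     def compute(self):
...         return self.total / max(self.count, 1)
>>>
>>> learner.add_cb(MetricsCallback(mae=MeanAbsoluteError()))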
Mixup
Bases: Callback
Implement mixup data augmentation technique.
This callback implements the mixup technique: instead of feeding raw data to the model, it feeds linear combinations of inputs, with the mixing weight λ sampled from a Beta(alpha, alpha) distribution. The labels are combined with the same weights. Based on the paper "mixup: BEYOND EMPIRICAL RISK MINIMIZATION".
Attributes:
- order (int) – Callback execution order (90).
- alpha (float) – Concentration parameter for Beta distribution.
- distrib (Beta) – Beta distribution for sampling mixup weights.
- old_loss_func (Callable) – Original loss function before mixup.
- λ (Tensor) – Mixup weight for current batch.
- yb1 (List[Tensor]) – Shuffled targets for mixup.
__init__(alpha=0.4)
Initialize Mixup.
Parameters:
- alpha (float, default: 0.4) – Concentration parameter for Beta distribution.
after_fit()
Restore original loss function after training ends.
before_batch()
Apply mixup transformation to batch inputs and targets.
The mixup process involves (see the sketch after this list):
1. Drawing samples from a beta distribution for each image
2. Taking the maximum of λ and 1-λ to avoid identical combinations
3. Shuffling the batch for combination
4. Creating linear combinations of inputs and targets
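A standalone sketch of these four steps on a toy batch (illustrative only, not the callback's exact code):
>>> import torch
>>> from torch.distributions.beta import Beta
>>>
>>> xb = torch.randn(4, 3, 8, 8)                # toy image batch
>>> lam = Beta(0.4, 0.4).sample((4,))           # step 1: one λ per image
>>> lam = torch.max(lam, 1 - lam)               # step 2: avoid identical combinations
>>> idx = torch.randperm(4)                     # step 3: shuffle the batch
>>> lam_x = lam.view(-1, 1, 1, 1)               # broadcast λ over feature dims
>>> x_mix = lam_x * xb + (1 - lam_x) * xb[idx]  # step 4: linear combination of inputs
The loss then mixes the two target sets the same way: λ * loss(pred, yb) + (1 - λ) * loss(pred, yb1).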
before_fit()
Store original loss function before training starts.
loss_func(pred, yb)
Compute mixup loss combining original and shuffled targets.
Parameters:
- pred (Tensor) – Model predictions.
- yb (Tensor) – Original targets.
Returns:
- Tensor – Mixup loss value.
ModelResetter
Bases: Callback
Reset model parameters at various training stages.
This callback is particularly useful for NLP models that need to reset hidden states. It assumes the model has a reset method that knows which parameters to reset and how.
after_fit()
Reset model after training ends.
before_train()
Reset model before training phase.
before_validate()
Reset model before validation phase.
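For instance, a recurrent model might expose a reset method that clears its carried-over hidden state (a hypothetical sketch; LM is not part of this module):
>>> import torch.nn as nn
>>> from cmn_ai.callbacks.training import ModelResetter
>>>
>>> class LM(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.rnn = nn.GRU(32, 64, batch_first=True)
...         self.hidden = None
...     def reset(self):
...         self.hidden = None  # drop state between training/validation phases
>>>
>>> learner.add_cb(ModelResetter())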
ProgressCallback
Bases: Callback
Track training progress with progress bars and live loss plotting.
This callback provides visual feedback during training by displaying progress bars and optionally plotting training and validation losses in real-time.
Attributes:
- order (int) – Callback execution order (-20).
- plot (bool) – Whether to plot train/valid losses during training.
- train_losses (List[float]) – List of training losses for plotting.
- valid_losses (List[float]) – List of validation losses for plotting.
- mbar (master_bar) – Master progress bar for epochs.
- pb (progress_bar) – Progress bar for batches.
__init__(plot=True)
Initialize ProgressCallback.
Parameters:
- plot (bool, default: True) – Whether to plot train/valid losses during training.
after_batch()
Update progress bar and optionally plot losses after each batch.
after_epoch()
Update validation loss plot after each epoch.
after_fit()
Clean up progress bar after training ends.
before_fit()
Initialize progress tracking and create progress bars.
before_train()
Set up progress bar before training phase.
before_validate()
Set up progress bar before validation phase.
set_pb()
Create and configure progress bar for current phase.
Recorder
Bases: Callback
Record training metrics and optimizer parameters for later analysis.
This callback keeps track of losses and optimizer parameters (like learning rates) throughout training, enabling post-training analysis and visualization.
Attributes:
- order (int) – Callback execution order (50).
- params (List[str]) – List of parameter names to track.
- params_records (Dict[str, List[List[float]]]) – Recorded parameter values for each parameter group.
- losses (List[float]) – Recorded training losses.
__init__(*params)
Initialize Recorder.
Parameters:
- *params (tuple[str, ...], default: ()) – Parameter names to track (e.g., 'lr', 'momentum').
after_batch()
Record parameters and loss after each training batch.
before_fit()
Initialize recording structures before training starts.
plot(pgid=-1, skip_last=0)
Plot loss vs learning rate (log-scale).
Parameters:
- pgid (int, default: -1) – Parameter group index to plot.
- skip_last (int, default: 0) – Number of last losses to skip in plotting.
plot_loss(skip_last=0)
Plot training losses.
Parameters:
- skip_last (int, default: 0) – Number of last losses to skip in plotting.
plot_params(params='lr', pgid=-1, figsize=(8, 6))
Plot parameter values across training iterations.
Parameters:
- params (str | Iterable[str], default: 'lr') – Parameter name(s) to plot.
- pgid (int, default: -1) – Parameter group index to plot.
- figsize (tuple[int, int], default: (8, 6)) – Figure size for the plot.
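A typical post-training analysis flow (assuming learner as in the module examples above):
>>> from cmn_ai.callbacks.training import Recorder
>>>
>>> recorder = Recorder('lr')
>>> learner.add_cb(recorder)
>>> learner.fit(5)
>>> recorder.plot_loss(skip_last=5)  # losses, ignoring the noisy tail
>>> recorder.plot_params('lr')       # learning rate across iterations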
SingleBatchCallback
Bases: Callback
Run only one training/validation batch and stop.
This callback is useful for debugging or when you want to check parameters after processing just one batch. It raises CancelFitException after the first batch to stop training.
Attributes:
- order (int) – Callback execution order (1).
after_batch()
Stop training after the first batch.
SingleBatchForwardCallback
Bases: Callback
Run one batch and stop after forward pass.
This callback runs one training/validation batch and stops after computing the loss (after forward pass) by raising CancelFitException. Useful for debugging or checking parameters after one forward pass.
Attributes:
- order (int) – Callback execution order (1).
after_loss()
Stop training after computing loss for the first batch.
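For example, to sanity-check a model end-to-end through the loss computation without running a backward pass:
>>> from cmn_ai.callbacks.training import SingleBatchForwardCallback
>>>
>>> learner.add_cb(SingleBatchForwardCallback())
>>> learner.fit(1)  # stops right after the first loss is computed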
TrainEvalCallback
Bases: Callback
Track training progress and manage training/evaluation modes.
This callback tracks the number of iterations, percentage of training completed, and sets the appropriate training or evaluation mode for the model.
Attributes:
- order (int) – Callback execution order (-10).
after_batch()
Update iteration counter and training percentage after each batch.
before_fit()
Initialize training counters before training starts.
before_train()
Set model to training mode and update training percentage.
before_validate()
Set model to evaluation mode before validation.