Training

Training-related callbacks for customizing training and validation loops.

This module provides a comprehensive collection of callbacks that enhance and customize the training process. These callbacks handle various aspects of training including device management, progress tracking, metrics computation, data augmentation, and learning rate optimization.

The callbacks are designed to be easily composable and can be used together to create sophisticated training pipelines. Each callback has a specific execution order to ensure proper sequencing of operations.

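Classes:

  • BatchTransform

    Apply transformations to entire batches before processing.

  • BatchTransformX

    Apply transformations to input features (X) only.

  • DeviceCallback

    Move batch and model to specified device.

  • LRFinder

    Find optimal learning rate using exponential schedule.

  • MetricsCallback

    Compute and log metrics during training and validation.

  • Mixup

    Implement mixup data augmentation technique.

  • ModelResetter

    Reset model parameters at various training stages.

  • ProgressCallback

    Track training progress with progress bars and live loss plotting.

  • Recorder

    Record training metrics and optimizer parameters for later analysis.

  • SingleBatchCallback

    Run only one training/validation batch and stop.

  • SingleBatchForwardCallback

    Run one batch and stop after forward pass.

  • TrainEvalCallback

    Track training progress and manage training/evaluation modes.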

Examples:

Basic training with device management and progress tracking:

>>> from cmn_ai.callbacks.training import DeviceCallback, ProgressCallback
>>> from cmn_ai.callbacks.training import TrainEvalCallback, Recorder
>>>
>>> # Create callbacks
>>> callbacks = [
...     DeviceCallback(device='cuda'),  # Move to GPU
...     TrainEvalCallback(),            # Track progress
...     ProgressCallback(plot=True),    # Show progress bars
...     Recorder('lr', 'momentum')      # Record learning rate and momentum
... ]
>>>
>>> # Add to learner
>>> learner.add_cbs(callbacks)

Learning rate finding:

>>> from cmn_ai.callbacks.training import LRFinder
>>>
>>> # Find optimal learning rate
>>> lr_finder = LRFinder(gamma=1.3, num_iter=100, stop_div=True)
>>> learner.add_cb(lr_finder)
>>> learner.fit(1)  # Run for 1 epoch
>>> lr_finder.recorder.plot()  # Plot lr vs loss

Data augmentation with mixup:

>>> from cmn_ai.callbacks.training import Mixup
>>>
>>> # Add mixup augmentation
>>> mixup = Mixup(alpha=0.4)
>>> learner.add_cb(mixup)
>>> learner.fit(10)

Metrics tracking:

>>> from torcheval.metrics import MulticlassAccuracy
>>> from cmn_ai.callbacks.training import MetricsCallback
>>>
>>> # Track accuracy during training
>>> accuracy = MulticlassAccuracy()
>>> metrics_cb = MetricsCallback(accuracy=accuracy)
>>> learner.add_cb(metrics_cb)
>>> learner.fit(5)

Debugging with single batch:

>>> from cmn_ai.callbacks.training import SingleBatchCallback
>>>
>>> # Run only one batch for debugging
>>> debug_cb = SingleBatchCallback()
>>> learner.add_cb(debug_cb)
>>> learner.fit(1)  # Will stop after first batch

Raises:

  • CancelFitException

    Stop training and move to after_fit.

  • CancelEpochException

    Stop current epoch and move to after_epoch.

  • CancelTrainException

    Stop training current epoch and move to after_train.

  • CancelValidException

    Stop validation phase and move to after_validate.

  • CancelBatchException

    Stop current batch and move to after_batch.

  • CancelStepException

    Skip stepping the optimizer.

  • CancelBackwardException

    Skip the backward pass and move to after_backward.
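
How these exceptions redirect control flow can be sketched with a toy loop. This is only an illustration under stated assumptions, not the cmn_ai implementation: the two exception classes are redefined locally as stand-ins, and `run_fit`, `skip_batch`, and `stop_at` are hypothetical names.

```python
class CancelBatchException(Exception):
    """Local stand-in: abandon the current batch only."""


class CancelFitException(Exception):
    """Local stand-in: abandon the whole fit."""


def run_fit(n_batches, skip_batch=None, stop_at=None):
    """Toy loop that records which events fire (hypothetical helper)."""
    events = []
    try:
        for i in range(n_batches):
            try:
                if i == skip_batch:
                    raise CancelBatchException  # e.g. raised in before_batch
                events.append(f"step {i}")
            except CancelBatchException:
                pass  # the step is skipped, the loop carries on
            events.append(f"after_batch {i}")
            if i == stop_at:
                raise CancelFitException  # e.g. raised in after_batch
    except CancelFitException:
        pass  # remaining batches are skipped
    events.append("after_fit")
    return events


events = run_fit(n_batches=4, skip_batch=0, stop_at=2)
print(events)
```

Batch 0 is skipped but still reaches its after_batch event, while the fit-level exception raised after batch 2 prevents batch 3 from running at all.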

Notes
  • Callbacks are executed in order based on their order attribute
  • Lower order numbers execute earlier
  • Some callbacks modify the training loop behavior significantly
  • Always test callbacks individually before combining them
  • The Recorder callback is essential for post-training analysis
  • LRFinder should be used before full training to find a suitable learning rate
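
The ordering rule can be illustrated with a small sketch. `FakeCallback` is a hypothetical stand-in for real callback objects; the order values are the ones documented for each class on this page, and sorting by the `order` attribute is an assumption about how the learner sequences them.

```python
from dataclasses import dataclass


@dataclass
class FakeCallback:
    """Hypothetical stand-in carrying only a name and an order."""
    name: str
    order: int


callbacks = [
    FakeCallback("Mixup", 90),
    FakeCallback("Recorder", 50),
    FakeCallback("ProgressCallback", -20),
    FakeCallback("TrainEvalCallback", -10),
]

# Lower order runs first, regardless of the order in which callbacks were added.
run_order = [cb.name for cb in sorted(callbacks, key=lambda cb: cb.order)]
print(run_order)
```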

BatchTransform

Bases: Callback

Apply transformations to entire batches before processing.

This callback applies a transformation function to the entire batch before it's processed by the model. The transformation runs on the device where the data is located.

Attributes:

  • order (int) –

    Callback execution order (DeviceCallback.order + 1).

  • tfm (Callable) –

    Transformation function to apply.

  • on_train (bool) –

    Whether to apply transformation during training.

  • on_valid (bool) –

    Whether to apply transformation during validation.

__init__(tfm, on_train=True, on_valid=True)

Initialize BatchTransform.

Parameters:

  • tfm (Callable) –

    Transformation function to apply on the batch.

  • on_train (bool, default: True ) –

    Whether to apply the transformation during training.

  • on_valid (bool, default: True ) –

    Whether to apply the transformation during validation.

before_batch()

Apply transformation to batch if conditions are met.

BatchTransformX

Bases: Callback

Apply transformations to input features (X) only.

This callback applies a transformation function specifically to the input features (X) of the batch, leaving the targets (Y) unchanged.

Attributes:

  • order (int) –

    Callback execution order (DeviceCallback.order + 1).

  • tfm (Callable) –

    Transformation function to apply.

  • on_train (bool) –

    Whether to apply transformation during training.

  • on_valid (bool) –

    Whether to apply transformation during validation.

__init__(tfm, on_train=True, on_valid=True)

Initialize BatchTransformX.

Parameters:

  • tfm (Callable) –

    Transformation function to apply on the input features.

  • on_train (bool, default: True ) –

    Whether to apply the transformation during training.

  • on_valid (bool, default: True ) –

    Whether to apply the transformation during validation.

before_batch()

Apply transformation to input features if conditions are met.
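
The difference between BatchTransform and BatchTransformX can be sketched with plain functions. This assumes a batch is an `(xb, yb)` pair; `transform_batch`, `transform_inputs`, and the `training` flag are hypothetical names used for illustration, not the library's internals.

```python
def should_apply(training, on_train=True, on_valid=True):
    """True when the transform is enabled for the current phase."""
    return (training and on_train) or (not training and on_valid)


def transform_batch(batch, tfm, training, on_train=True, on_valid=True):
    """Whole-batch variant: tfm receives and returns the full (xb, yb) pair."""
    return tfm(batch) if should_apply(training, on_train, on_valid) else batch


def transform_inputs(batch, tfm, training, on_train=True, on_valid=True):
    """Inputs-only variant: tfm sees xb, and yb passes through untouched."""
    xb, yb = batch
    if should_apply(training, on_train, on_valid):
        return tfm(xb), yb
    return batch


batch = ([1.0, 2.0, 3.0], [0, 1, 0])

# Doubling the inputs during training leaves the targets alone.
doubled = transform_inputs(batch, lambda xs: [2 * x for x in xs], training=True)

# With on_valid=False the batch is returned unchanged during validation.
skipped = transform_inputs(
    batch, lambda xs: [2 * x for x in xs], training=False, on_valid=False
)

# A whole-batch transform may touch inputs and targets together.
flipped = transform_batch(batch, lambda b: (b[0][::-1], b[1][::-1]), training=True)
print(doubled, skipped, flipped)
```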

DeviceCallback

Bases: Callback

Move batch and model to specified device.

This callback ensures that both the model and input data are moved to the specified device (CPU/GPU) before training begins.

Attributes:

  • device (str | device) –

    Device to copy batch and model to.

__init__(device=DEFAULT_DEVICE)

Initialize DeviceCallback.

Parameters:

  • device (str | device, default: DEFAULT_DEVICE ) –

    Device to copy batch and model to.

before_batch()

Move batch data to specified device before each batch.

before_fit()

Move model to specified device before training starts.

LRFinder

Bases: Callback

Find optimal learning rate using exponential schedule.

This callback implements the learning rate finder technique from "Cyclical Learning Rates for Training Neural Networks". It tries different learning rates using an exponential schedule to help determine the best learning rate for training.

Attributes:

  • gamma (float) –

    Multiplicative factor for learning rate increase.

  • num_iter (int) –

    Number of iterations to run the training.

  • stop_div (bool) –

    Whether to stop training if loss diverges.

  • max_mult (int) –

    Divergence threshold multiplier.

  • scheduler (ExponentialLR) –

    Learning rate scheduler.

  • best_loss (float) –

    Best loss encountered during training.

__init__(gamma=1.3, num_iter=100, stop_div=True, max_mult=4)

Initialize LRFinder.

Parameters:

  • gamma (float, default: 1.3 ) –

    Multiplicative factor for learning rate increase.

  • num_iter (int, default: 100 ) –

    Number of iterations to run the training.

  • stop_div (bool, default: True ) –

    Whether to stop training if loss diverges (loss >= max_mult * best_loss).

  • max_mult (int, default: 4 ) –

    Divergence threshold. If loss >= max_mult * minimum loss, stop training.

after_batch()

Update best loss and check for divergence or completion.
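
The exponential schedule and the divergence check can be sketched in plain Python. This is a rough illustration, not the library's code: `lr_sweep` and `base_lr` are hypothetical names, and the sketch assumes the learning rate is multiplied by gamma once per batch.

```python
def lr_sweep(losses, base_lr=1e-6, gamma=1.3, num_iter=100, max_mult=4, stop_div=True):
    """Return the learning rates tried and the index of the last batch run."""
    best_loss = float("inf")
    lrs = []
    last = -1
    for i, loss in enumerate(losses):
        last = i
        lrs.append(base_lr * gamma ** i)  # exponential growth of the learning rate
        best_loss = min(best_loss, loss)
        if stop_div and loss >= max_mult * best_loss:
            break  # loss diverged relative to the best loss seen so far
        if i + 1 >= num_iter:
            break  # iteration budget used up
    return lrs, last


# Made-up loss curve: improves, then blows up as the learning rate grows.
losses = [2.0, 1.5, 1.0, 0.8, 1.2, 3.5, 9.0]
lrs, stop_index = lr_sweep(losses)
print(stop_index, lrs[-1])
```

With the default max_mult of 4 and a best loss of 0.8, the sweep stops at the first loss of at least 3.2, so the final value 9.0 is never reached.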

after_fit()

Restore model state and plot learning rate vs loss.

before_fit()

Set up learning rate scheduler and save initial model state.

before_validate()

Skip validation during learning rate finding.

MetricsCallback

Bases: Callback

Compute and log metrics during training and validation.

This callback computes metrics after each training/validation epoch and logs them using the learner's logger. Metrics must implement reset and compute methods. It's recommended to use metrics from the torcheval package or inherit from its Metric base class.

Attributes:

  • metrics (dict[str, Any]) –

    Dictionary of named metrics to compute.

  • all_metrics (dict[str, Any]) –

    All metrics including loss metric.

  • stats (list[str]) –

    Current statistics to log.

  • start_time (float) –

    Start time for timing calculations.

__init__(*metrics, **named_metrics)

Initialize MetricsCallback.

Parameters:

  • *metrics (Any, default: () ) –

    Positional metrics to add.

  • **named_metrics (Any, default: {} ) –

    Named metrics to add.

after_batch()

Update metrics with batch predictions and targets.

after_train()

Compute and log metrics after training epoch.

Returns:

  • str

    Logged statistics string.

after_validate()

Compute and log metrics after validation epoch.

Returns:

  • str

    Logged statistics string.

before_fit()

Log metric names as header before training starts.

Returns:

  • str

    Header string with metric names.

before_train()

Reset metrics before training phase.

before_validate()

Reset metrics before validation phase.
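
The lifecycle described by these hooks (reset before a phase, update after each batch, compute at the end of the phase) can be sketched with a toy metric. `ToyAccuracy` is a hypothetical class written for this illustration; it only mimics the reset/update/compute shape that torcheval metrics follow.

```python
class ToyAccuracy:
    """Hypothetical metric exposing reset, update, and compute."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0
        return self

    def update(self, preds, targets):
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)
        return self

    def compute(self):
        return self.correct / self.total if self.total else 0.0


metric = ToyAccuracy()
batches = [([0, 1, 1], [0, 1, 0]), ([1, 0], [1, 0])]

metric.reset()                      # before_train / before_validate
for preds, targets in batches:
    metric.update(preds, targets)   # after_batch
epoch_accuracy = metric.compute()   # after_train / after_validate
print(epoch_accuracy)

metric.reset()
after_reset = metric.compute()      # state no longer leaks into the next phase
```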

Mixup

Bases: Callback

Implement mixup data augmentation technique.

This callback implements the mixup technique: instead of feeding raw data to the model, it feeds linear combinations of pairs of inputs, with mixing weights λ sampled from a Beta distribution parameterized by alpha. The labels are combined with the same weights. Based on the paper "mixup: Beyond Empirical Risk Minimization".

Attributes:

  • order (int) –

    Callback execution order (90).

  • alpha (float) –

    Concentration parameter for Beta distribution.

  • distrib (Beta) –

    Beta distribution for sampling mixup weights.

  • old_loss_func (Callable) –

    Original loss function before mixup.

  • λ (Tensor) –

    Mixup weight for current batch.

  • yb1 (List[Tensor]) –

    Shuffled targets for mixup.

__init__(alpha=0.4)

Initialize Mixup.

Parameters:

  • alpha (float, default: 0.4 ) –

    Concentration parameter for Beta distribution.

after_fit()

Restore original loss function after training ends.

before_batch()

Apply mixup transformation to batch inputs and targets.

The mixup process involves:

  1. Drawing samples from a beta distribution for each image
  2. Taking the maximum of λ and 1-λ to avoid identical combinations
  3. Shuffling the batch for combination
  4. Creating linear combinations of inputs and targets
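
These steps can be sketched with the standard library alone. This is an illustration under stated assumptions rather than the callback's tensor code: `mixup_batch` is a hypothetical helper, inputs and one-hot targets are plain lists, and the symmetric Beta(alpha, alpha) form is assumed.

```python
import random


def mixup_batch(xb, yb, alpha=0.4, rng=None):
    """Mix every sample with a randomly chosen partner from the same batch."""
    rng = rng or random.Random()
    n = len(xb)
    # 1. One Beta draw per sample (symmetric Beta(alpha, alpha) is assumed).
    lams = [rng.betavariate(alpha, alpha) for _ in range(n)]
    # 2. Keep the larger of lam and 1 - lam, so each weight is at least 0.5.
    lams = [max(lam, 1.0 - lam) for lam in lams]
    # 3. Shuffle indices to pick a partner for every sample.
    perm = list(range(n))
    rng.shuffle(perm)

    def mix(a, b, lam):
        return [lam * u + (1.0 - lam) * v for u, v in zip(a, b)]

    # 4. Linear combinations of inputs and of targets, using the same weights.
    x_mixed = [mix(xb[i], xb[perm[i]], lams[i]) for i in range(n)]
    y_mixed = [mix(yb[i], yb[perm[i]], lams[i]) for i in range(n)]
    return x_mixed, y_mixed, lams, perm


xb = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
yb = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]  # one-hot targets
x_mixed, y_mixed, lams, perm = mixup_batch(xb, yb, rng=random.Random(0))
print(lams, perm)
```

Because the targets are one-hot and the two weights sum to one, every mixed target still sums to one.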

before_fit()

Store original loss function before training starts.

loss_func(pred, yb)

Compute mixup loss combining original and shuffled targets.

Parameters:

  • pred (Tensor) –

    Model predictions.

  • yb (Tensor) –

    Original targets.

Returns:

  • Tensor

    Mixup loss value.
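
One common way to combine original and shuffled targets is to weight the two per-sample losses by λ and 1-λ. The sketch below assumes that formulation with cross-entropy over predicted probabilities; `cross_entropy` and `mixup_loss` are hypothetical helpers, and the callback's actual reduction may differ.

```python
import math


def cross_entropy(probs, target):
    """Negative log-probability assigned to the target class."""
    return -math.log(probs[target])


def mixup_loss(probs_batch, yb, yb1, lams):
    """Mean over the batch of lam * loss(pred, y) + (1 - lam) * loss(pred, y1)."""
    per_sample = [
        lam * cross_entropy(p, y) + (1.0 - lam) * cross_entropy(p, y1)
        for p, y, y1, lam in zip(probs_batch, yb, yb1, lams)
    ]
    return sum(per_sample) / len(per_sample)


probs_batch = [[0.7, 0.3], [0.2, 0.8]]  # made-up predicted class probabilities
yb = [0, 1]    # original targets
yb1 = [1, 0]   # targets of the shuffled partners

plain = mixup_loss(probs_batch, yb, yb, [0.6, 0.9])   # identical partners
mixed = mixup_loss(probs_batch, yb, yb1, [0.6, 0.9])  # different partners
print(plain, mixed)
```

When a sample is paired with a partner of the same class, the mixed loss reduces to the ordinary loss, whatever the weight.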

ModelResetter

Bases: Callback

Reset model parameters at various training stages.

This callback is particularly useful for NLP models that need to reset hidden states. It assumes the model has a reset method that knows which parameters to reset and how.

after_fit()

Reset model after training ends.

before_train()

Reset model before training phase.

before_validate()

Reset model before validation phase.

ProgressCallback

Bases: Callback

Track training progress with progress bars and live loss plotting.

This callback provides visual feedback during training by displaying progress bars and optionally plotting training and validation losses in real-time.

Attributes:

  • order (int) –

    Callback execution order (-20).

  • plot (bool) –

    Whether to plot train/valid losses during training.

  • train_losses (List[float]) –

    List of training losses for plotting.

  • valid_losses (List[float]) –

    List of validation losses for plotting.

  • mbar (master_bar) –

    Master progress bar for epochs.

  • pb (progress_bar) –

    Progress bar for batches.

__init__(plot=True)

Initialize ProgressCallback.

Parameters:

  • plot (bool, default: True ) –

    Whether to plot train/valid losses during training.

after_batch()

Update progress bar and optionally plot losses after each batch.

after_epoch()

Update validation loss plot after each epoch.

after_fit()

Clean up progress bar after training ends.

before_fit()

Initialize progress tracking and create progress bars.

before_train()

Set up progress bar before training phase.

before_validate()

Set up progress bar before validation phase.

set_pb()

Create and configure progress bar for current phase.

Recorder

Bases: Callback

Record training metrics and optimizer parameters for later analysis.

This callback keeps track of losses and optimizer parameters (like learning rates) throughout training, enabling post-training analysis and visualization.

Attributes:

  • order (int) –

    Callback execution order (50).

  • params (List[str]) –

    List of parameter names to track.

  • params_records (Dict[str, List[List[float]]]) –

    Recorded parameter values for each parameter group.

  • losses (List[float]) –

    Recorded training losses.

__init__(*params)

Initialize Recorder.

Parameters:

  • *params (tuple[str, ...], default: () ) –

    Parameter names to track (e.g., 'lr', 'momentum').

after_batch()

Record parameters and loss after each training batch.

before_fit()

Initialize recording structures before training starts.
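
The nested shape of params_records (one list per parameter group for each tracked parameter) can be sketched as follows. The list of dicts stands in for an optimizer's parameter groups; `init_records` and `record_step` are hypothetical helpers, not the callback's methods.

```python
def init_records(params, n_groups):
    """One history list per tracked parameter and per parameter group."""
    return {name: [[] for _ in range(n_groups)] for name in params}


def record_step(records, losses, param_groups, loss):
    """Append the current value of every tracked parameter, plus the loss."""
    for name, histories in records.items():
        for history, group in zip(histories, param_groups):
            history.append(group[name])
    losses.append(loss)


# Stand-in for optimizer parameter groups with different learning rates.
param_groups = [{"lr": 0.01, "momentum": 0.9}, {"lr": 0.1, "momentum": 0.9}]
records = init_records(("lr", "momentum"), n_groups=len(param_groups))
losses = []

for loss in (1.0, 0.7, 0.5):
    record_step(records, losses, param_groups, loss)
    for group in param_groups:
        group["lr"] /= 2  # pretend a scheduler halves the learning rate

last_group_lrs = records["lr"][-1]  # index -1 mirrors the default pgid=-1
print(last_group_lrs)
```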

plot(pgid=-1, skip_last=0)

Plot loss vs learning rate (log-scale).

Parameters:

  • pgid (int, default: -1 ) –

    Parameter group index to plot.

  • skip_last (int, default: 0 ) –

    Number of last losses to skip in plotting.

plot_loss(skip_last=0)

Plot training losses.

Parameters:

  • skip_last (int, default: 0 ) –

    Number of last losses to skip in plotting.

plot_params(params='lr', pgid=-1, figsize=(8, 6))

Plot parameter values across training iterations.

Parameters:

  • params (str | Iterable[str], default: "lr" ) –

    Parameter name(s) to plot.

  • pgid (int, default: -1 ) –

    Parameter group index to plot.

  • figsize (tuple[int, int], default: (8, 6) ) –

    Figure size for the plot.

SingleBatchCallback

Bases: Callback

Run only one training/validation batch and stop.

This callback is useful for debugging or when you want to check parameters after processing just one batch. It raises CancelFitException after the first batch to stop training.

Attributes:

  • order (int) –

    Callback execution order (1).

after_batch()

Stop training after the first batch.

SingleBatchForwardCallback

Bases: Callback

Run one batch and stop after forward pass.

This callback runs one training/validation batch and stops after computing the loss (after forward pass) by raising CancelFitException. Useful for debugging or checking parameters after one forward pass.

Attributes:

  • order (int) –

    Callback execution order (1).

after_loss()

Stop training after computing loss for the first batch.

TrainEvalCallback

Bases: Callback

Track training progress and manage training/evaluation modes.

This callback tracks the number of iterations, percentage of training completed, and sets the appropriate training or evaluation mode for the model.

Attributes:

  • order (int) –

    Callback execution order (-10).

after_batch()

Update iteration counter and training percentage after each batch.

before_fit()

Initialize training counters before training starts.

before_train()

Set model to training mode and update training percentage.

before_validate()

Set model to evaluation mode before validation.
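
The bookkeeping these hooks describe can be sketched with a small counter class. `ProgressCounter` and its attribute names are hypothetical, and the sketch assumes that only training batches advance the iteration count and the completed fraction.

```python
class ProgressCounter:
    """Hypothetical stand-in for the counters TrainEvalCallback maintains."""

    def __init__(self, n_epochs, n_train_batches):
        self.total_iters = n_epochs * n_train_batches
        self.n_iters = 0
        self.pct_train = 0.0
        self.training = True

    def before_train(self):
        self.training = True   # the real callback also puts the model in training mode

    def before_validate(self):
        self.training = False  # the real callback also puts the model in evaluation mode

    def after_batch(self):
        if self.training:      # assumption: validation batches are not counted
            self.n_iters += 1
            self.pct_train = self.n_iters / self.total_iters


counter = ProgressCounter(n_epochs=2, n_train_batches=5)
history = []
for epoch in range(2):
    counter.before_train()
    for _ in range(5):
        counter.after_batch()
    counter.before_validate()
    for _ in range(3):
        counter.after_batch()
    history.append(counter.pct_train)
print(history)
```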