
Core

Core callback system for training loop management.

This module provides the foundational callback system for injecting custom behavior into machine learning training loops. It defines the base Callback class and a set of control-flow exceptions that give callbacks fine-grained control over the training process.

Classes:

  • Callback

    Base class for all callbacks, providing the interface for training loop integration and utility methods for callback management.

  • CancelFitException

    Exception to stop training and move to after_fit phase.

  • CancelEpochException

    Exception to stop current epoch and move to after_epoch phase.

  • CancelTrainException

    Exception to stop training phase and move to after_train phase.

  • CancelValidateException

    Exception to stop validation phase and move to after_validate phase.

  • CancelBatchException

    Exception to stop current batch and move to after_batch phase.

  • CancelStepException

    Exception to skip optimizer step and move to after_step phase.

  • CancelBackwardException

    Exception to skip backward pass and move to after_backward phase.

Notes

The callback system works by defining specific event names that correspond to different phases of the training loop. Callbacks can implement methods with these event names to be called at the appropriate times:

  • before_fit: Called before training begins
  • after_fit: Called after training completes
  • before_epoch: Called before each epoch
  • after_epoch: Called after each epoch
  • before_train: Called before training phase of each epoch
  • after_train: Called after training phase of each epoch
  • before_validate: Called before validation phase of each epoch
  • after_validate: Called after validation phase of each epoch
  • before_batch: Called before processing each batch
  • after_batch: Called after processing each batch
  • before_step: Called before optimizer step
  • after_step: Called after optimizer step
  • before_backward: Called before backward pass
  • after_backward: Called after backward pass
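Seeing the events as a trace makes their nesting clearer. The following is a hypothetical sketch of where each event could fire in a plain Python loop. The `fit` function and `Recorder` class are illustrative names and are not part of this module. The sketch also assumes that validation batches skip the backward and step events; the real learner may order things differently.

```python
from typing import Callable, List


class Recorder:
    """Hypothetical callback stand-in that records every event it receives."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def __call__(self, event_nm: str) -> None:
        self.events.append(event_nm)


def fit(cb: Callable[[str], None], n_epochs: int = 1,
        n_train: int = 1, n_valid: int = 1) -> None:
    """Hypothetical loop showing where each event would fire (no real model)."""
    cb("before_fit")
    for _ in range(n_epochs):
        cb("before_epoch")
        cb("before_train")
        for _ in range(n_train):
            cb("before_batch")
            # forward pass and loss computation would happen here
            cb("before_backward")
            # loss.backward() would happen here
            cb("after_backward")
            cb("before_step")
            # optimizer.step() would happen here
            cb("after_step")
            cb("after_batch")
        cb("after_train")
        cb("before_validate")
        for _ in range(n_valid):
            # validation batches: forward pass only, no backward or step
            cb("before_batch")
            cb("after_batch")
        cb("after_validate")
        cb("after_epoch")
    cb("after_fit")


rec = Recorder()
fit(rec)
print(rec.events)
```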

Examples:

Creating a custom callback:

>>> class MyCallback(Callback):
...     def before_epoch(self):
...         print(f"Starting epoch {self.epoch}")
...
...     def after_batch(self):
...         if self.loss > 0.5:
...             raise CancelEpochException("Loss too high")

Using control flow exceptions:

>>> class EarlyStoppingCallback(Callback):
...     def after_epoch(self):
...         if self.epoch > 10 and self.valid_loss > self.best_loss:
...             raise CancelFitException("Early stopping triggered")
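The early-stopping example above assumes that the learner exposes a `best_loss` attribute. A hypothetical variant could instead track the best validation loss itself and stop after `patience` epochs without improvement. So that it runs on its own, the sketch uses minimal stand-ins for `Callback` and `CancelFitException` and a `SimpleNamespace` in place of a real learner. `patience` and `wait` are illustrative names, not part of this module.

```python
from types import SimpleNamespace
from typing import Any


class CancelFitException(Exception):
    """Stand-in for this module's CancelFitException."""


class Callback:
    """Minimal stand-in for the documented Callback base class."""

    order = 0

    def set_learner(self, learner: Any) -> None:
        self.learner = learner

    def __call__(self, event_nm: str) -> Any:
        method = getattr(self, event_nm, None)
        return method() if method is not None else None

    def __getattr__(self, k: str) -> Any:
        # Only reached when normal lookup fails; fall back to the learner.
        learner = self.__dict__.get("learner")
        if learner is None:
            raise AttributeError(k)
        return getattr(learner, k)


class EarlyStoppingCallback(Callback):
    """Hypothetical patience-based early stopping that tracks its own best loss."""

    def __init__(self, patience: int = 3) -> None:
        self.patience = patience
        self.best_loss = float("inf")
        self.wait = 0

    def after_epoch(self) -> None:
        # self.valid_loss is resolved on the learner via __getattr__
        if self.valid_loss < self.best_loss:
            self.best_loss = self.valid_loss
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                raise CancelFitException(f"No improvement for {self.patience} epochs")


learner = SimpleNamespace(valid_loss=float("inf"))
cb = EarlyStoppingCallback(patience=2)
cb.set_learner(learner)

stopped_at = None
for epoch, loss in enumerate([0.9, 0.7, 0.72, 0.71, 0.5]):
    learner.valid_loss = loss  # normally computed during validation
    try:
        cb("after_epoch")
    except CancelFitException:
        stopped_at = epoch  # a real loop would move on to after_fit here
        break

print(stopped_at, cb.best_loss)  # 3 0.7
```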

Callback

Base class for all callbacks.

A callback is a mechanism to inject custom behavior into the training loop at specific points. Callbacks can be used for logging, early stopping, learning rate scheduling, and other custom functionality.

Attributes:

  • order (int) –

    The order in which callbacks should be executed. Lower numbers are executed first. Default is 0.

  • learner (Any) –

    Reference to the learner object, set by set_learner().

Notes

Subclasses should implement specific callback methods that correspond to training events (e.g., before_fit, after_epoch, etc.).

name property

Returns the name of the callback: the class name with the 'Callback' suffix removed, converted to snake_case (lowercase words separated by underscores).

Returns:

  • str

    The callback name in snake_case format with 'Callback' suffix removed. For example, 'TestCallback' becomes 'test'.

__call__(event_nm)

Call the callback method corresponding to the given event name.

If the callback has a method with the same name as the event, it will be called. Otherwise, nothing happens.

Parameters:

  • event_nm (str) –

    The name of the event to handle (e.g., 'before_fit', 'after_epoch').

Returns:

  • Any

    The return value of the callback method, or None if the method doesn't exist.
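The dispatch described above comes down to looking up a method by name and doing nothing if it is missing. Below is a minimal sketch of that idea, assuming `getattr` with a default is used; the module's actual implementation may differ, and `PrintEpochs` is a hypothetical subclass.

```python
from typing import Any


class Callback:
    """Hypothetical minimal base showing only event dispatch."""

    def __call__(self, event_nm: str) -> Any:
        # Look up a method named after the event; do nothing if it is absent.
        method = getattr(self, event_nm, None)
        return method() if method is not None else None


class PrintEpochs(Callback):
    def before_epoch(self) -> str:
        return "epoch starting"


cb = PrintEpochs()
print(cb("before_epoch"))  # epoch starting
print(cb("after_fit"))     # None (no after_fit method defined)
```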

__getattr__(k)

Allow access to learner attributes directly through the callback.

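This lets callbacks write self.obj instead of self.learner.obj for attributes that live on the learner. Because __getattr__ is only invoked after normal attribute lookup (__getattribute__) raises AttributeError, the callback's own attributes always take precedence over the learner's.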

Parameters:

  • k (str) –

    The attribute name to access from the learner.

Returns:

  • Any

    The attribute value from the learner object.

Raises:

  • AttributeError

    If the attribute is not found in the learner object.
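Here is a minimal sketch of this delegation, together with the `set_learner()` method it relies on. The guard for a missing `learner` is an assumption added here to prevent infinite recursion if an attribute is read before `set_learner()` is called. `LossLogger` and the `SimpleNamespace` learner are hypothetical.

```python
from types import SimpleNamespace
from typing import Any


class Callback:
    """Hypothetical minimal base showing learner attribute delegation."""

    def set_learner(self, learner: Any) -> None:
        self.learner = learner

    def __getattr__(self, k: str) -> Any:
        # Python calls this only after normal lookup has failed.
        learner = self.__dict__.get("learner")
        if learner is None:
            # Guard: avoid infinite recursion before set_learner() is called.
            raise AttributeError(k)
        return getattr(learner, k)  # raises AttributeError if missing


class LossLogger(Callback):
    epoch = -1  # defined on the callback, so it shadows the learner's epoch


cb = LossLogger()
cb.set_learner(SimpleNamespace(loss=0.42, epoch=3))
print(cb.loss)   # 0.42 (read from the learner)
print(cb.epoch)  # -1 (the callback's own attribute wins)
```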

__init__()

Initialize the callback.

camel2snake(name) staticmethod

Convert a CamelCase name to snake_case.

Inserts an underscore between a lowercase letter and a following uppercase letter, then lowercases the result. For example, TestCallback becomes test_callback.

Parameters:

  • name (str) –

    The camelCase string to convert.

Returns:

  • str

    The converted snake_case string.

Examples:

>>> Callback.camel2snake("TestCallback")
'test_callback'
>>> Callback.camel2snake("MyCustomCallback")
'my_custom_callback'
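Below is one possible implementation of `camel2snake`, along with the `name` property built on top of it. It assumes a single regular expression for the lowercase-to-uppercase boundary and a second regex that strips a trailing 'Callback'. Treat it as a sketch, not necessarily the module's own code.

```python
import re


class Callback:
    """Hypothetical minimal base showing the naming helpers."""

    @staticmethod
    def camel2snake(name: str) -> str:
        # Insert "_" between a lowercase letter (or digit) and an uppercase
        # letter, then lowercase everything.
        return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name).lower()

    @property
    def name(self) -> str:
        # Drop a trailing "Callback" from the class name, then snake_case it.
        base = re.sub(r"Callback$", "", type(self).__name__)
        return self.camel2snake(base)


class TestCallback(Callback):
    pass


class MyCustomCallback(Callback):
    pass


print(Callback.camel2snake("TestCallback"))  # test_callback
print(TestCallback().name)                   # test
print(MyCustomCallback().name)               # my_custom
```

With only this one pattern, runs of capitals stay joined, so 'LRFinder' would become 'lrfinder'. Splitting acronyms would require an additional pattern.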

set_learner(learner)

Store the learner as an attribute so that callbacks can access the learner's attributes without it being passed to every method of every callback.

Parameters:

  • learner (Any) –

    The learner whose training events will trigger this callback. It is stored as self.learner.
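From the learner's side, `set_learner()` and the `order` attribute might fit together roughly as follows. This is a hypothetical sketch: the `Learner` class, its `callback()` dispatcher, and the `Setup`/`Logger` callbacks are illustrative and are not this module's API.

```python
from typing import Any, List


class Callback:
    """Minimal stand-in for the documented Callback base class."""

    order = 0

    def set_learner(self, learner: Any) -> None:
        self.learner = learner

    def __call__(self, event_nm: str) -> Any:
        method = getattr(self, event_nm, None)
        return method() if method is not None else None


class Learner:
    """Hypothetical learner: attaches itself to callbacks and dispatches by order."""

    def __init__(self, callbacks: List[Callback]) -> None:
        self.callbacks = callbacks
        self.log: List[str] = []
        for cb in callbacks:
            cb.set_learner(self)

    def callback(self, event_nm: str) -> None:
        # Lower `order` runs first.
        for cb in sorted(self.callbacks, key=lambda c: c.order):
            cb(event_nm)


class Logger(Callback):
    order = 10

    def before_fit(self) -> None:
        self.learner.log.append("logger")


class Setup(Callback):
    order = -5

    def before_fit(self) -> None:
        self.learner.log.append("setup")


learn = Learner([Logger(), Setup()])
learn.callback("before_fit")
print(learn.log)  # ['setup', 'logger']
```

Since Python's `sorted` is stable, callbacks that share the same `order` run in the order they were passed in.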

CancelBackwardException

Bases: Exception

Exception raised to skip the backward pass and move to after_backward.

This exception is used to skip the backward pass computation and immediately proceed to the after_backward phase.

CancelBatchException

Bases: Exception

Exception raised to stop current batch and move to after_batch.

This exception is used to interrupt the current batch processing and immediately proceed to the after_batch phase.

CancelEpochException

Bases: Exception

Exception raised to stop current epoch and move to after_epoch.

This exception is used to interrupt the current epoch and immediately proceed to the after_epoch phase.

CancelFitException

Bases: Exception

Exception raised to stop training and move to after_fit.

This exception is used to interrupt the training process and immediately proceed to the after_fit phase of the training loop.

CancelStepException

Bases: Exception

Exception raised to skip stepping the optimizer and move to after_step.

This exception is used to skip the optimizer step and immediately proceed to the after_step phase.

CancelTrainException

Bases: Exception

Exception raised to stop the training phase of the current epoch and move to after_train.

This exception is used to interrupt the training phase of the current epoch and immediately proceed to the after_train phase.

CancelValidateException

Bases: Exception

Exception raised to stop validation phase and move to after_validate.

This exception is used to interrupt the validation phase and immediately proceed to the after_validate phase.
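The behavior documented for these exceptions is to cancel a phase and then continue with its after_* event. One way to achieve this is to catch each exception at the boundary of its own phase. Below is a hypothetical sketch of that routing, using stand-in exception classes and a toy batch loop; `run_phase`, `CANCEL`, and `train` are illustrative names. The logic inside `emit` shows one assumed use of CancelStepException, gradient accumulation, in which the optimizer steps only on every k-th batch.

```python
from typing import Callable, Dict, List, Type


class CancelBatchException(Exception):
    """Stand-in for this module's CancelBatchException."""


class CancelBackwardException(Exception):
    """Stand-in for this module's CancelBackwardException."""


class CancelStepException(Exception):
    """Stand-in for this module's CancelStepException."""


# Hypothetical mapping from phase name to the exception that cancels it.
CANCEL: Dict[str, Type[Exception]] = {
    "batch": CancelBatchException,
    "backward": CancelBackwardException,
    "step": CancelStepException,
}


def run_phase(name: str, body: Callable[[], None],
              emit: Callable[[str], None]) -> None:
    """Emit before_<name>, run body, then emit after_<name>.

    A Cancel<Name>Exception raised by a callback or by body skips the
    rest of the phase and jumps straight to after_<name>.
    """
    try:
        emit(f"before_{name}")
        body()
    except CANCEL[name]:
        pass
    emit(f"after_{name}")


def train(n_batches: int, accumulate_every: int) -> List[int]:
    """Toy loop; returns the batch indices on which the optimizer 'stepped'."""
    stepped: List[int] = []
    for i in range(n_batches):
        def emit(event: str) -> None:
            # Hypothetical gradient-accumulation callback: skip the optimizer
            # step except on every `accumulate_every`-th batch.
            if event == "before_step" and (i + 1) % accumulate_every:
                raise CancelStepException()

        def batch_body() -> None:
            run_phase("backward", lambda: None, emit)  # loss.backward() here
            run_phase("step", lambda: stepped.append(i), emit)

        run_phase("batch", batch_body, emit)
    return stepped


print(train(n_batches=4, accumulate_every=2))  # [1, 3]
```

Each phase catches only its own exception. A CancelBatchException raised during the backward or step phase therefore propagates outward to the batch phase, which handles it and emits after_batch.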