
Hooks

Hooks for inspecting neural network activations and gradients during training.

This module provides utilities for registering hooks on PyTorch modules to inspect what happens during the forward and backward passes. Typical uses include computing statistics of activations and gradients, debugging training issues, and monitoring model behavior.

Functions:

  • compute_stats : Callable

    Compute means, std, and histogram of module activations/gradients.

  • get_hist : Callable

    Return a matrix ready for plotting a heatmap of activations/gradients.

  • get_min : Callable

    Compute the percentage of activations/gradients around zero from the histogram.

Classes:

  • Hook : object

    Register either a forward or a backward hook on a single module.

  • Hooks : object

    Register hooks on multiple modules with context manager support.

  • HooksCallback : Callback

    Base class to run hooks on modules as a callback.

  • ActivationStats : HooksCallback

    Plot means, std, histogram, and dead activations of all modules.

Examples:

>>> # Register a hook on a single module
>>> hook = Hook(model.layer1, compute_stats)
>>> # Use as context manager
>>> with Hook(model.layer1, compute_stats):
...     output = model(input)
>>> # Register hooks on multiple modules
>>> hooks = Hooks(model.children(), compute_stats)
>>> hooks.remove()  # Clean up
>>> # Use as callback during training
>>> stats = ActivationStats(is_forward=True)  # hooks all model children by default
>>> learner.add_callback(stats)
>>> stats.plot_stats()  # Plot activation statistics

ActivationStats

Bases: HooksCallback

Plot activation/gradient statistics for all modules.

This class automatically computes and can plot various statistics of module activations (or gradients if is_forward=False), including means, standard deviations, histograms, and dead activation percentages.

Attributes:

  • bins (int) –

    Number of histogram bins.

  • bins_range (list | tuple) –

    Range for histogram bins.

Examples:

>>> # Monitor activation statistics during training
>>> stats = ActivationStats(is_forward=True)  # hooks all model children by default
>>> learner.add_callback(stats)
>>>
>>> # After training, plot the statistics
>>> stats.plot_stats()
>>> stats.plot_hist()
>>> stats.dead_chart([0, 5])  # Show dead activations in bins 0-4

__init__(modules=None, is_forward=True, bins=40, bins_range=(0, 10))

Initialize the activation statistics callback.

Parameters:

  • modules (Module | Iterable[Module] | None, default: None ) –

    Modules to register the hook on. If None, uses all model children.

  • is_forward (bool, default: True ) –

    Whether to monitor activations (True) or gradients (False).

  • bins (int, default: 40 ) –

    Number of histogram bins.

  • bins_range (list | tuple, default: (0, 10) ) –

    Lower and upper bounds for histogram bins.

dead_chart(bins_range, figsize=(11, 5))

Plot a line chart of the percentage of "dead" activations.

Shows the percentage of activations/gradients that fall within the specified range around zero over time, which can help identify when neurons become inactive.

Parameters:

  • bins_range (list | tuple) –

    Range of bins around zero to consider as "dead" activations/gradients, given as a [start, stop) pair used like a slice (e.g., [0, 5] selects bins 0-4).

  • figsize (tuple, default: (11, 5) ) –

    Width and height of the figure in inches.
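Assuming the histogram uses equal-width bins over the histogram range (as torch.histc does), the value interval that a bin slice covers follows from simple arithmetic. With the defaults of 40 bins over (0, 10), each bin is 0.25 wide, so [0, 5] covers values in [0, 1.25). The helper names below are hypothetical and not part of this module.

```python
# Minimal sketch, assuming equal-width bins like torch.histc.
# bin_edges and bins_to_values are hypothetical helpers, not part of this module.

def bin_edges(bins=40, bins_range=(0, 10)):
    """Return the bins + 1 edges of equal-width bins spanning bins_range."""
    lo, hi = bins_range
    width = (hi - lo) / bins
    return [lo + i * width for i in range(bins + 1)]


def bins_to_values(bin_slice, bins=40, bins_range=(0, 10)):
    """Return the [low, high) value interval covered by bins[start:stop]."""
    start, stop = bin_slice
    edges = bin_edges(bins, bins_range)
    return edges[start], edges[stop]


print(bins_to_values([0, 5]))  # (0.0, 1.25)
```

If you change bins or bins_range on ActivationStats, the same slice passed to dead_chart will cover a different value interval.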

plot_hist(figsize=(11, 5))

Plot histogram of activations/gradients as a heatmap.

Creates a heatmap visualization where each row represents a histogram bin and each column represents a timestep during training.

Parameters:

  • figsize (tuple, default: (11, 5) ) –

    Width and height of the figure in inches.

plot_stats(figsize=(10, 4))

Plot means and standard deviations of activations/gradients.

Creates two subplots showing the mean and standard deviation of activations/gradients for each layer over time.

Parameters:

  • figsize (tuple, default: (10, 4) ) –

    Width and height of the figure in inches.

Hook

Register either a forward or a backward hook on a single module.

This class provides a convenient way to register hooks on PyTorch modules and automatically handle cleanup. It can be used as a context manager for automatic hook removal.

Attributes:

  • is_forward (bool) –

    Whether the hook is a forward or backward hook.

  • hook (RemovableHandle) –

    The registered hook handle for cleanup.

Examples:

>>> hook = Hook(model.layer1, compute_stats)
>>> # ... use the hook
>>> hook.remove()
>>> # Use as context manager
>>> with Hook(model.layer1, compute_stats):
...     output = model(input)

__del__()

Destructor to ensure hook removal.

__enter__()

Enter the context manager.

Returns:

  • Hook

    Self reference for context manager usage.

__exit__(*args)

Exit the context manager and remove the hook.

__init__(module, func, is_forward=True, **kwargs)

Initialize the hook.

Parameters:

  • module (Module) –

    The module to register the hook on.

  • func (Callable) –

    The hook function to be registered. Should accept (hook, module, input, output) for forward hooks or (hook, module, grad_input, grad_output) for backward hooks.

  • is_forward (bool, default: True ) –

    Whether to register func as a forward or backward hook.

  • **kwargs

    Additional keyword arguments to pass to the hook function.

remove()

Remove the registered hook.

This method removes the hook from the module and should be called to prevent memory leaks.
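The pattern described above (passing the Hook object itself as the first argument so the hook function can store results on it) can be sketched with functools.partial. To keep the sketch runnable without PyTorch, FakeModule and FakeHandle are hypothetical stand-ins that mimic nn.Module.register_forward_hook and RemovableHandle. The use of register_full_backward_hook for backward hooks is an assumption, and the real class may differ.

```python
from functools import partial


class FakeHandle:
    """Hypothetical stand-in for torch.utils.hooks.RemovableHandle."""

    def __init__(self, registry, fn):
        self.registry, self.fn = registry, fn

    def remove(self):
        # Removing twice is harmless in this stand-in.
        if self.fn in self.registry:
            self.registry.remove(self.fn)


class FakeModule:
    """Hypothetical stand-in for nn.Module that supports forward hooks only."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, fn):
        self._forward_hooks.append(fn)
        return FakeHandle(self._forward_hooks, fn)

    def __call__(self, x):
        out = 2 * x  # this toy "layer" just doubles its input
        for fn in list(self._forward_hooks):
            fn(self, (x,), out)  # PyTorch also passes the inputs as a tuple
        return out


class Hook:
    """Sketch: bind the Hook itself as the first argument of func."""

    def __init__(self, module, func, is_forward=True, **kwargs):
        self.is_forward = is_forward
        # Assumption: backward hooks use register_full_backward_hook.
        register = (module.register_forward_hook if is_forward
                    else module.register_full_backward_hook)
        self.hook = register(partial(func, self, **kwargs))

    def remove(self):
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.remove()

    def __del__(self):
        if hasattr(self, "hook"):
            self.remove()


def record_output(hook, module, inp, outp):
    """Toy hook function: store every output on the Hook object."""
    if not hasattr(hook, "outputs"):
        hook.outputs = []
    hook.outputs.append(outp)


m = FakeModule()
with Hook(m, record_output) as h:
    m(3)
    m(5)
print(h.outputs)         # [6, 10]
print(m._forward_hooks)  # [] (the hook was removed on exit)
```

Binding the Hook as the first argument is what lets a function such as compute_stats keep per-module state (for example hook.stats) without any global bookkeeping.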

Hooks

Register hooks on multiple modules and manage them as a group.

This class provides a container for multiple hooks with convenient iteration, indexing, and cleanup methods. It can be used as a context manager for automatic cleanup of all hooks.

Attributes:

  • hooks (list[Hook]) –

    List of registered hooks.

Examples:

>>> hooks = Hooks(model.children(), compute_stats)
>>> output = model(input)  # run a forward pass so the hooks record stats
>>> for hook in hooks:
...     print(hook.stats)
>>> hooks.remove()
>>> # Use as context manager
>>> with Hooks(model.children(), compute_stats):
...     output = model(input)

__del__()

Destructor to ensure all hooks are removed.

__enter__()

Enter the context manager.

Returns:

  • Hooks

    Self reference for context manager usage.

__exit__(*args)

Exit the context manager and remove all hooks.

__getitem__(idx)

Get a hook by index.

Parameters:

  • idx (int) –

    Index of the hook to retrieve.

Returns:

  • Hook

    The hook at the specified index.

__init__(modules, func, is_forward=True, **kwargs)

Initialize hooks for multiple modules.

Parameters:

  • modules (Iterable[Module]) –

    Iterable of modules to register hooks on.

  • func (Callable) –

    The hook function to be registered on each module.

  • is_forward (bool, default: True ) –

    Whether to register func as a forward or backward hook.

  • **kwargs

    Additional keyword arguments to pass to the hook function.

__iter__()

Iterate over all hooks.

Returns:

  • Iterator[Hook]

    Iterator over all registered hooks.

__len__()

Get the number of hooks.

Returns:

  • int

    Number of registered hooks.

remove()

Remove all registered hooks.

This method removes all hooks from their respective modules and should be called to prevent memory leaks.
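A container like Hooks can mostly delegate to a plain list. The sketch below assumes that design and, to stay runnable without PyTorch, uses a hypothetical StubHook in place of Hook and strings in place of modules; the real class may be implemented differently.

```python
class StubHook:
    """Hypothetical stand-in for Hook that only records removal."""

    def __init__(self, module, func, is_forward=True, **kwargs):
        self.module, self.func, self.is_forward = module, func, is_forward
        self.removed = False

    def remove(self):
        self.removed = True


class Hooks:
    """Sketch of a container that owns one hook per module."""

    def __init__(self, modules, func, is_forward=True, **kwargs):
        self.hooks = [StubHook(m, func, is_forward, **kwargs) for m in modules]

    def __getitem__(self, idx):
        return self.hooks[idx]

    def __len__(self):
        return len(self.hooks)

    def __iter__(self):
        return iter(self.hooks)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.remove()

    def __del__(self):
        if hasattr(self, "hooks"):
            self.remove()

    def remove(self):
        for h in self.hooks:
            h.remove()


# Strings stand in for modules so the sketch runs without PyTorch.
with Hooks(["conv1", "conv2", "fc"], func=print) as hooks:
    print(len(hooks), hooks[1].module)  # 3 conv2
print(all(h.removed for h in hooks))   # True
```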

HooksCallback

Bases: Callback

Base class to run hooks on modules as a callback.

This class provides a convenient way to register and manage hooks during training using the callback system. It automatically handles hook registration before training and cleanup after training.

Attributes:

  • hookfunc (Callable) –

    The hook function to be registered on modules.

  • on_train (bool) –

    Whether to run hooks during training.

  • on_valid (bool) –

    Whether to run hooks during validation.

  • modules (list[Module]) –

    List of modules to register hooks on.

  • is_forward (bool) –

    Whether to register forward or backward hooks.

  • hooks (Hooks) –

    The hooks object managing all registered hooks.

Examples:

>>> # Create a custom hook callback
>>> class MyHookCallback(HooksCallback):
...     def __init__(self):
...         super().__init__(compute_stats, on_train=True, on_valid=False)
>>>
>>> callback = MyHookCallback()
>>> learner.add_callback(callback)

__init__(hookfunc, on_train=True, on_valid=False, modules=None, is_forward=True)

Initialize the hooks callback.

Parameters:

  • hookfunc (Callable) –

    The hook function to be registered on modules.

  • on_train (bool, default: True ) –

    Whether to run the hook on modules during training.

  • on_valid (bool, default: False ) –

    Whether to run the hook on modules during validation.

  • modules (Module | Iterable[Module] | None, default: None ) –

    Modules to register the hook on. If None, uses all model children.

  • is_forward (bool, default: True ) –

    Whether to register forward or backward hooks.

__iter__()

Iterate over all registered hooks.

Returns:

  • Iterator[Hook]

    Iterator over all registered hooks.

__len__()

Get the number of registered hooks.

Returns:

  • int

    Number of registered hooks.

after_fit()

Remove all hooks after training ends.

before_fit()

Register hooks before training begins.

If no modules are specified, registers hooks on all model children.
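The docstring does not specify how on_train and on_valid are enforced. One plausible approach, assumed here rather than taken from the source, is for before_fit to register a wrapper around hookfunc that checks the current phase on every call. The names should_run and make_gated_hook are hypothetical, and is_training stands in for something like the model's training flag.

```python
def should_run(training, on_train=True, on_valid=False):
    """Hypothetical gate: run the hook only in the phases the callback asked for."""
    return (on_train and training) or (on_valid and not training)


def make_gated_hook(hookfunc, is_training, on_train=True, on_valid=False):
    """Wrap hookfunc so it does nothing outside the requested phases.

    is_training is a zero-argument callable (hypothetical) returning the
    current phase, e.g. lambda: model.training in PyTorch.
    """
    def gated(*args, **kwargs):
        if should_run(is_training(), on_train, on_valid):
            return hookfunc(*args, **kwargs)
    return gated


calls = []
phase = {"training": True}
hook = make_gated_hook(lambda *a: calls.append(a), lambda: phase["training"])
hook("train-batch")
phase["training"] = False
hook("valid-batch")  # skipped: on_valid defaults to False
print(calls)  # [('train-batch',)]
```

With the defaults (on_train=True, on_valid=False), statistics are collected only from training batches, which keeps validation passes from mixing into the recorded timesteps.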

compute_stats(hook, module, inp, outp, bins=40, bins_range=(0, 10))

Compute the means, std, and histogram of module activations/gradients.

This function is designed to be used as a hook function. It computes statistics of the module's output (activations for forward hooks, gradients for backward hooks) and stores them in the hook object.

Parameters:

  • hook (Hook) –

    The registered hook object where stats will be stored.

  • module (Module) –

    The module that the hook is registered on.

  • inp (Tensor) –

    Input to the module (for forward hooks) or gradients with respect to the module's inputs (for backward hooks).

  • outp (Tensor) –

    Output from the module (for forward hooks) or gradients with respect to the module's outputs (for backward hooks).

  • bins (int, default: 40 ) –

    Number of histogram bins.

  • bins_range (list | tuple, default: (0, 10) ) –

    Lower and upper bounds for the histogram bins.

Notes

The computed statistics are stored in hook.stats as a tuple of three lists:

  • hook.stats[0]: List of mean values
  • hook.stats[1]: List of standard deviation values
  • hook.stats[2]: List of histogram tensors
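Following the Notes above, the sketch below shows what compute_stats records on each call, using plain Python lists instead of tensors so it runs without PyTorch. Several details are assumptions: the histogram mimics torch.histc (equal-width bins, out-of-range values dropped), the standard deviation is the sample (Bessel-corrected) estimate that torch.std uses by default, and values are histogrammed as-is, since the docstring does not say whether absolute values are taken first. compute_stats_sketch and histc are hypothetical names.

```python
import statistics
from types import SimpleNamespace


def histc(values, bins=40, lo=0.0, hi=10.0):
    """Equal-width histogram that, like torch.histc, ignores values outside [lo, hi]."""
    counts = [0] * bins
    width = (hi - lo) / bins
    for v in values:
        if lo <= v <= hi:
            # A value equal to hi lands in the last bin rather than overflowing.
            counts[min(int((v - lo) / width), bins - 1)] += 1
    return counts


def compute_stats_sketch(hook, module, inp, outp, bins=40, bins_range=(0, 10)):
    """Append the mean, std, and histogram of a flat list of numbers to hook.stats."""
    if not hasattr(hook, "stats"):
        hook.stats = ([], [], [])
    hook.stats[0].append(statistics.mean(outp))
    hook.stats[1].append(statistics.stdev(outp))  # sample std, like torch.std's default
    hook.stats[2].append(histc(outp, bins, *bins_range))


hook = SimpleNamespace()  # stands in for a Hook object
compute_stats_sketch(hook, None, None, [-1.0, 0.5, 0.5, 2.0, 4.0], bins=4, bins_range=(0, 4))
print(hook.stats[2])  # [[2, 0, 1, 1]]
```

Note that -1.0 still contributes to the mean and std but is dropped from the histogram because it falls outside bins_range.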

get_hist(hook)

Return a matrix ready for plotting a heatmap of activations/gradients.

Parameters:

  • hook (Hook) –

    Hook object containing histogram statistics.

Returns:

  • Tensor

    Matrix of histogram data ready for plotting as a heatmap. Shape is (bins, timesteps) with log1p applied for better visualization.
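Assuming hook.stats[2] holds one histogram per timestep (as the compute_stats Notes describe), get_hist amounts to stacking those histograms as columns and applying log1p. This sketch uses nested lists rather than a Tensor, and get_hist_sketch is a hypothetical name.

```python
import math
from types import SimpleNamespace


def get_hist_sketch(hook):
    """Return a bins x timesteps matrix (list of rows) with log1p applied."""
    hists = hook.stats[2]  # one histogram (list of counts) per timestep
    n_bins = len(hists[0])
    # Transpose so that rows are bins and columns are timesteps.
    return [[math.log1p(h[b]) for h in hists] for b in range(n_bins)]


hook = SimpleNamespace(stats=([], [], [[3, 0, 1], [0, 0, 7]]))
mat = get_hist_sketch(hook)
print(len(mat), len(mat[0]))  # 3 2
print(mat[1])                 # [0.0, 0.0]
```

log1p compresses large counts while mapping empty bins to 0, so sparsely populated bins remain visible next to very full ones in the heatmap.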

get_min(hook, bins_range)

Compute the percentage of activations/gradients around zero.

This function calculates what percentage of the activations or gradients fall within the specified range around zero, which can be useful for identifying "dead" neurons or gradients.

Parameters:

  • hook (Hook) –

    Hook object containing histogram statistics.

  • bins_range (list | tuple) –

    Range of bins around zero to consider as "dead" activations/gradients, given as a [start, stop) pair used like a slice (e.g., [0, 5] selects bins 0-4).

Returns:

  • Tensor

    Percentage of activations/gradients around zero for each timestep. Values range from 0 to 1.

Examples:

>>> # Get percentage of activations in bins 0-4 (around zero)
>>> dead_percentage = get_min(hook, [0, 5])
>>> print(f"Dead activations: {dead_percentage.mean():.2%}")
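Under the same assumption about hook.stats[2], get_min reduces to summing the selected bins and dividing by each timestep's total count. The sketch below is hypothetical (the name get_min_sketch and returning 0.0 for an empty histogram are choices made here, not taken from the source) and returns a list instead of a Tensor.

```python
from types import SimpleNamespace


def get_min_sketch(hook, bins_range):
    """Fraction of each timestep's histogram mass that falls in bins[start:stop]."""
    start, stop = bins_range
    fractions = []
    for h in hook.stats[2]:  # one histogram per timestep
        total = sum(h)
        fractions.append(sum(h[start:stop]) / total if total else 0.0)
    return fractions


hook = SimpleNamespace(stats=([], [], [[6, 2, 1, 1], [1, 1, 4, 4]]))
print(get_min_sketch(hook, [0, 2]))  # [0.8, 0.2]
```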