
Hooks

Hooks are useful for inspecting what happens during the forward and backward passes, for example to compute statistics of activations and gradients.

The module contains the following classes:

  • Hook: Registers a forward or backward hook on a single module
  • Hooks: Registers a forward or backward hook on multiple modules
  • HooksCallback: Uses callbacks to register and manage hooks
  • ActivationStats: Computes the means/stds of either activations or gradients and plots the computed stats.
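As a point of reference, the snippet below uses plain PyTorch (`nn.Module.register_forward_hook`) rather than this module's classes to record the mean and standard deviation of each layer's output. The toy model and the `stats` dictionary are illustrative only.

```python
import torch
from torch import nn

# Toy model used only for illustration.
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))
stats = {}  # layer index -> list of (mean, std) tuples, one per forward pass


def record_stats(name):
    # PyTorch calls a forward hook as hook(module, inputs, output).
    def hook(module, inputs, output):
        out = output.detach()
        stats.setdefault(name, []).append((out.mean().item(), out.std().item()))

    return hook


handles = [layer.register_forward_hook(record_stats(i)) for i, layer in enumerate(model)]

with torch.no_grad():
    model(torch.randn(64, 10))

# Hooks stay attached until removed; remove them so later passes are not recorded.
for h in handles:
    h.remove()

print({k: v[0] for k, v in stats.items()})
```

The classes documented on this page wrap this register/remove bookkeeping so it does not have to be repeated by hand.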

ActivationStats

Bases: HooksCallback

Plot the means, standard deviations, histograms, and dead activations of all modules' activations if is_forward is True, otherwise of their gradients.

__init__(modules=None, is_forward=True, bins=40, bins_range=(0, 10))

Parameters:

Name Type Description Default
modules Module | Iterable[Module] | None

Modules to register the hook on. Defaults to all modules.

None
is_forward bool

Whether to register the hook as a forward or backward hook.

True
bins int

Number of histogram bins.

40
bins_range Iterable

Lower and upper ends of the histogram's bin range.

(0, 10)
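To make the `bins` and `bins_range` defaults concrete: 40 bins over (0, 10) gives each bin a width of (10 - 0) / 40 = 0.25. The snippet assumes the histogram is built with `torch.histc` (the page does not name the function), which ignores values outside the range.

```python
import torch

bins, bins_range = 40, (0, 10)
low, high = bins_range
width = (high - low) / bins  # 0.25 for the defaults

values = torch.tensor([0.1, 0.2, 0.3, 9.9, 12.0])
hist = torch.histc(values, bins=bins, min=low, max=high)

print(width)       # 0.25
print(hist[:2])    # tensor([2., 1.]): 0.1 and 0.2 in [0, 0.25), 0.3 in [0.25, 0.5)
print(hist.sum())  # tensor(4.): 12.0 lies outside [0, 10] and is dropped
```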

dead_chart(bins_range, figsize=(11, 5))

Plot a line chart of the percentage of "dead" activations/gradients out of all activations/gradients.

Parameters:

Name Type Description Default
bins_range list | tuple

Range of bins around zero whose values are considered dead activations/gradients.

required
figsize tuple

Width, height of the figure.

(11, 5)

plot_hist(figsize=(11, 5))

Plot the histogram of activations/gradients as a heatmap.

Parameters:

Name Type Description Default
figsize tuple

Width, height of the heatmap figure.

(11, 5)

plot_stats(figsize=(10, 4))

Plot the means and standard deviations of activations/gradients for each layer that has a registered hook.

Parameters:

Name Type Description Default
figsize tuple

Width, height of the figure.

(10, 4)

Hook

Register either a forward or a backward hook on the module.

__init__(module, func, is_forward=True, **kwargs)

Parameters:

Name Type Description Default
module Module

The module to register the hook on.

required
func Callable

The hook to be registered.

required
is_forward bool

Whether to register func as a forward or backward hook.

True
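A minimal sketch of a `Hook`-style wrapper, assuming it delegates to PyTorch's `register_forward_hook` or `register_full_backward_hook` and passes itself as the first argument to `func` (which matches the `compute_stats(hook, module, inp, outp)` signature documented on this page). The class name `SimpleHook`, the use of `functools.partial`, and the context-manager methods are illustrative choices, not this library's confirmed implementation.

```python
from functools import partial

import torch
from torch import nn


class SimpleHook:
    """Hypothetical stand-in for Hook: registers func on a single module."""

    def __init__(self, module, func, is_forward=True, **kwargs):
        # Bind this hook object (and any extra kwargs) so PyTorch's call
        # hook(module, inp, outp) becomes func(self, module, inp, outp, **kwargs).
        bound = partial(func, self, **kwargs)
        if is_forward:
            self.handle = module.register_forward_hook(bound)
        else:
            # Backward hooks receive (module, grad_input, grad_output).
            self.handle = module.register_full_backward_hook(bound)

    def remove(self):
        self.handle.remove()

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.remove()


def count_calls(hook, module, inp, outp):
    hook.calls = getattr(hook, "calls", 0) + 1


layer = nn.Linear(4, 2)
with SimpleHook(layer, count_calls) as h:
    layer(torch.randn(3, 4))
    layer(torch.randn(3, 4))
layer(torch.randn(3, 4))  # the hook was removed on exit, so this call is not counted
print(h.calls)  # 2
```

Passing the hook object to `func` lets the hook function keep its own state (here a call counter) on the hook itself instead of in a global.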

Hooks

Register a hook on each of the provided modules.

__init__(modules, func, is_forward=True, **kwargs)

Parameters:

Name Type Description Default
modules Iterable[Module]

The modules to register the hook on.

required
func Callable

The hook to be registered.

required
is_forward bool

Whether to register func as a forward or backward hook.

True
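A `Hooks`-style container can be sketched as a list holding one hook per module plus a single `remove`. The sketch is self-contained: `_ModuleHook` and `SimpleHooks` are hypothetical names, and it assumes, as with `Hook`, that `func` receives the per-module hook object as its first argument.

```python
from functools import partial

import torch
from torch import nn


class _ModuleHook:
    """Minimal per-module hook; func is called as func(hook, module, inp, outp)."""

    def __init__(self, module, func, is_forward=True, **kwargs):
        register = module.register_forward_hook if is_forward else module.register_full_backward_hook
        self.handle = register(partial(func, self, **kwargs))


class SimpleHooks(list):
    """Hypothetical stand-in for Hooks: one _ModuleHook per module."""

    def __init__(self, modules, func, is_forward=True, **kwargs):
        super().__init__(_ModuleHook(m, func, is_forward, **kwargs) for m in modules)

    def remove(self):
        for h in self:
            h.handle.remove()

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.remove()


def record_mean(hook, module, inp, outp):
    hook.means = getattr(hook, "means", []) + [outp.mean().item()]


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
# Iterating an nn.Sequential yields its child modules.
with SimpleHooks(model, record_mean) as hooks:
    with torch.no_grad():
        model(torch.randn(32, 8))
print([len(h.means) for h in hooks])  # [1, 1, 1]
```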

HooksCallback

Bases: Callback

Base class to run hooks on modules as a callback.

__init__(hookfunc, on_train=True, on_valid=False, modules=None, is_forward=True)

Parameters:

Name Type Description Default
hookfunc Callable

The hook to be registered.

required
on_train bool

Whether to run the hook on modules during training.

True
on_valid bool

Whether to run the hook on modules during validation.

False
modules Module | Iterable[Module] | None

Modules to register the hook on. Defaults to all modules.

None
is_forward bool

Whether to register hookfunc as a forward or backward hook.

True
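`on_train` and `on_valid` decide when `hookfunc` actually runs. One common way to implement this, stated here as an assumption rather than a description of `HooksCallback` internals, is to wrap the hook so it checks `module.training` before doing any work. The helper name `gate_hook` is hypothetical, and for brevity its `hookfunc` takes PyTorch's standard `(module, inp, outp)` arguments.

```python
import torch
from torch import nn


def gate_hook(hookfunc, on_train=True, on_valid=False):
    """Hypothetical wrapper: only run hookfunc in the enabled mode(s)."""

    def wrapped(module, inp, outp):
        # module.training is toggled by model.train() / model.eval().
        if (on_train and module.training) or (on_valid and not module.training):
            hookfunc(module, inp, outp)

    return wrapped


calls = []
layer = nn.Linear(4, 2)
layer.register_forward_hook(gate_hook(lambda m, i, o: calls.append(m.training)))

layer.train()
layer(torch.randn(1, 4))  # recorded: training mode and on_train=True
layer.eval()
with torch.no_grad():
    layer(torch.randn(1, 4))  # skipped: eval mode and on_valid=False
print(calls)  # [True]
```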

compute_stats(hook, module, inp, outp, bins=40, bins_range=(0, 10))

Compute the means, standard deviations, and histogram of the module's activations/gradients and store them in the hook's stats.

Parameters:

Name Type Description Default
hook Hook

Registered hook on the provided module.

required
module Module

Module to compute the stats on.

required
inp Tensor

Input of the module.

required
outp Tensor

Output of the module.

required
bins int

Number of histogram bins.

40
bins_range Iterable

Lower and upper ends of the histogram's bin range.

(0, 10)
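A minimal sketch of a `compute_stats`-like function. It assumes the stats are kept on the hook as three lists (means, stds, histograms), that tuple outputs such as backward-hook gradients are reduced to their first tensor, and that the histogram is taken over absolute values; the page does not confirm these details. `compute_stats_sketch` is a hypothetical name and the `SimpleNamespace` object stands in for a real `Hook`.

```python
from types import SimpleNamespace

import torch


def compute_stats_sketch(hook, module, inp, outp, bins=40, bins_range=(0, 10)):
    # Backward hooks receive tuples of gradients; keep the first tensor.
    if isinstance(outp, tuple):
        outp = outp[0]
    outp = outp.detach().float().cpu()
    if not hasattr(hook, "stats"):
        hook.stats = ([], [], [])  # means, stds, histograms
    means, stds, hists = hook.stats
    means.append(outp.mean().item())
    stds.append(outp.std().item())
    # Assumption: histogram of absolute values, so bins_range can start at 0.
    hists.append(outp.abs().histc(bins, *bins_range))


hook = SimpleNamespace()  # stand-in for a Hook object
for _ in range(3):  # e.g. three training batches
    compute_stats_sketch(hook, None, None, torch.randn(16, 8))

means, stds, hists = hook.stats
print(len(means), hists[0].shape)  # 3 torch.Size([40])
```

Each call appends one entry per list, so after training the lists hold one value (or histogram) per step, which is what the plotting methods above consume.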

get_hist(hook)

Return a matrix of the activations'/gradients' histograms, ready for plotting as a heatmap.
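One plausible way to build such a matrix, assuming `hook.stats[2]` holds one histogram tensor per step: stack the histograms, transpose so rows are bins and columns are steps, and apply `log1p` so a few very full bins do not wash out the rest of the heatmap. The log scaling is a common choice, not something this page confirms, and `get_hist_sketch` is a hypothetical name.

```python
from types import SimpleNamespace

import torch


def get_hist_sketch(hook):
    # stats[2]: list of 1-D histograms, one per step -> shape (steps, bins)
    hists = torch.stack(hook.stats[2])
    # Transpose to (bins, steps) so time runs left to right; log1p tames large counts.
    return hists.t().float().log1p()


# Fake hook with 5 steps of 40-bin histograms.
hook = SimpleNamespace(stats=([], [], [torch.full((40,), float(i)) for i in range(5)]))
mat = get_hist_sketch(hook)
print(mat.shape)  # torch.Size([40, 5])
```

When drawing this with matplotlib's `imshow`, `origin="lower"` keeps bin 0 (values nearest zero) at the bottom of the image.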

get_min(hook, bins_range)

Compute the percentage of activations/gradients around zero from the hook's histogram matrix.

Parameters:

Name Type Description Default
hook Hook

Hook that holds the stats of the activations/gradients.

required
bins_range list | tuple

Bins range around zero.

required

Returns:

Type Description
Tensor

Percentage of the activations around zero.
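A sketch of the "dead" calculation, assuming `bins_range` is a `(start, stop)` slice of bin indices (so `(0, 2)` selects the two bins nearest zero) and that the result has one value per step. Under these assumptions, each step's value is the count in the selected bins divided by that step's total count; this is also the per-step series a chart like `dead_chart` would plot. `dead_fraction` is a hypothetical name, and it returns a fraction, so multiply by 100 for a percentage.

```python
import torch


def dead_fraction(hists, bins_range):
    # hists: list of per-step histograms -> matrix of shape (bins, steps)
    mat = torch.stack(hists).t().float()
    near_zero = mat[slice(*bins_range)].sum(0)  # counts in the selected bins, per step
    return near_zero / mat.sum(0)  # fraction in [0, 1], per step


# Two steps of a 4-bin histogram.
hists = [torch.tensor([6.0, 2.0, 1.0, 1.0]), torch.tensor([1.0, 1.0, 4.0, 4.0])]
print(dead_fraction(hists, (0, 2)))  # tensor([0.8000, 0.2000])
```

For the first step, (6 + 2) / 10 = 0.8 of the values fall in the two lowest bins; for the second, (1 + 1) / 10 = 0.2.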