Hooks
Hooks let you inspect what happens during the forward and backward passes, for example to compute statistics of activations and gradients.
The module contains the following classes:
Hook
: Registers a forward or backward hook on a single module.
Hooks
: Registers forward or backward hooks on multiple modules.
HooksCallback
: Uses callbacks to register and manage hooks.
ActivationStats
: Computes the means/stds of either activations or gradients and plots the computed stats.
ActivationStats
Bases: HooksCallback
Plot the means, standard deviations, histograms, and dead activations of all modules' activations if is_forward is True, otherwise of their gradients.
__init__(modules=None, is_forward=True, bins=40, bins_range=(0, 10))
Parameters:
Name | Type | Description | Default |
---|---|---|---|
modules | Module \| Iterable[Module] \| None | Modules to register the hook on. Defaults to all modules. | None |
is_forward | bool | Whether to register the hook as a forward hook (activations); if False, it is registered as a backward hook (gradients). | True |
bins | int | Number of histogram bins. | 40 |
bins_range | Iterable | Lower and upper ends of the histogram's bin range. | (0, 10) |
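When is_forward is False, the statistics describe gradients rather than activations. As a plain-PyTorch illustration (not this module's API), a full backward hook receives `grad_output`, a tuple whose first element is the gradient of the loss with respect to the module's output; the sketch below records its standard deviation per layer for one backward pass. The names `grad_stds` and `record_grad_std` are hypothetical.

```python
import torch
from torch import nn

grad_stds = {}  # layer index -> list of gradient stds, one per backward pass


def record_grad_std(name):
    # Returns a backward hook that logs the std of the gradient w.r.t. the output
    def hook(module, grad_input, grad_output):
        grad_stds.setdefault(name, []).append(grad_output[0].std().item())

    return hook


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
handles = [
    m.register_full_backward_hook(record_grad_std(str(i)))
    for i, m in enumerate(model)
]

# The input requires grad so that every layer's backward hook fires
x = torch.randn(16, 4, requires_grad=True)
loss = model(x).pow(2).mean()
loss.backward()

for h in handles:
    h.remove()  # detach the hooks once the stats are collected

print(sorted(grad_stds))  # ['0', '1', '2']
```

Removing the handles afterwards matters: hooks left in place keep firing (and accumulating data) on every later backward pass.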
dead_chart(bins_range, figsize=(11, 5))
Plot a line chart of the percentage of "dead" activations/gradients, i.e., those that fall into the bins around zero selected by bins_range.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
bins_range | list \| tuple | Range of bins around zero; activations/gradients falling into these bins are considered dead. | required |
figsize | tuple | Width and height of the figure. | (11, 5) |
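A worked formula, under stated assumptions: if $H_{b,t}$ is the count in bin $b$ (of $B$ bins) for batch $t$, and bins_range $= (b_0, b_1)$ selects bins $b_0$ through $b_1 - 1$ as a Python slice would, the dead fraction plotted for batch $t$ would be

$$
d_t = \frac{\sum_{b=b_0}^{b_1-1} H_{b,t}}{\sum_{b=0}^{B-1} H_{b,t}}
$$

For example, with the default 40 bins over (0, 10), each bin is 10 / 40 = 0.25 wide, so bins_range=(0, 2) would count values in [0, 0.5).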
plot_hist(figsize=(11, 5))
Plot the histograms of activations/gradients as a heatmap.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
figsize | tuple | Width and height of the heatmap figure. | (11, 5) |
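A standalone sketch of such a heatmap with matplotlib. Unlike the method above, the hypothetical `plot_hist_sketch` takes the histogram matrices explicitly; it assumes each one is shaped (bins, batches) and already log-scaled, and it puts bin 0 (values nearest zero) at the bottom so that training time runs left to right.

```python
import matplotlib

matplotlib.use("Agg")  # off-screen backend, no display needed
import matplotlib.pyplot as plt
import torch


def plot_hist_sketch(hist_mats, figsize=(11, 5)):
    """Hypothetical helper: draw one heatmap per hooked module."""
    fig, axes = plt.subplots(1, len(hist_mats), figsize=figsize, squeeze=False)
    for ax, mat in zip(axes.flat, hist_mats):
        # origin="lower" places bin 0 at the bottom of each panel
        ax.imshow(mat.numpy(), origin="lower", aspect="auto")
        ax.set_xlabel("batch")
        ax.set_ylabel("bin")
    fig.tight_layout()
    return fig


# Dummy (bins=40, batches=100) matrices for three modules, illustration only
mats = [torch.rand(40, 100).log1p() for _ in range(3)]
fig = plot_hist_sketch(mats)
```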
plot_stats(figsize=(10, 4))
Plot the means and standard deviations of activations/gradients for each layer that has a registered hook.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
figsize | tuple | Width and height of the figure. | (10, 4) |
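A standalone matplotlib sketch of the same idea: one line per layer, with means and standard deviations in separate panels. `plot_stats_sketch` is a hypothetical name, and it assumes the per-batch stats have already been collected into one list per layer.

```python
import matplotlib

matplotlib.use("Agg")  # off-screen backend, no display needed
import matplotlib.pyplot as plt
import torch


def plot_stats_sketch(means_per_layer, stds_per_layer, figsize=(10, 4)):
    """Hypothetical helper: a means panel and a stds panel, one line per layer."""
    fig, (ax_mean, ax_std) = plt.subplots(1, 2, figsize=figsize)
    for i, (means, stds) in enumerate(zip(means_per_layer, stds_per_layer)):
        ax_mean.plot(means, label=f"layer {i}")
        ax_std.plot(stds, label=f"layer {i}")
    ax_mean.set_title("Means")
    ax_std.set_title("Standard deviations")
    for ax in (ax_mean, ax_std):
        ax.set_xlabel("batch")
    ax_std.legend()
    return fig


# Dummy per-batch stats for two layers (random values, illustration only)
means = [torch.randn(50).tolist() for _ in range(2)]
stds = [torch.rand(50).tolist() for _ in range(2)]
fig = plot_stats_sketch(means, stds)
```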
Hook
Register either a forward or a backward hook on the module.
__init__(module, func, is_forward=True, **kwargs)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | Module | The module to register the hook on. | required |
func | Callable | The hook to be registered. | required |
is_forward | bool | Whether to register func as a forward hook; if False, it is registered as a backward hook. | True |
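A minimal sketch of how such a wrapper can sit on top of PyTorch's `register_forward_hook` and `register_full_backward_hook`. It assumes func is called as `func(hook, module, inp, outp, **kwargs)`, matching the compute_stats signature documented below; `HookSketch` and `record_mean` are hypothetical names. Binding the hook object with `functools.partial` lets the function keep state (here a list of means) on the hook itself.

```python
from functools import partial

import torch
from torch import nn


class HookSketch:
    """Bind `func` to a module as a forward or backward hook (sketch)."""

    def __init__(self, module, func, is_forward=True, **kwargs):
        # partial() prepends this object, so `func` can store state on it
        fn = partial(func, self, **kwargs)
        if is_forward:
            self.handle = module.register_forward_hook(fn)
        else:
            # Full backward hooks receive (module, grad_input, grad_output)
            self.handle = module.register_full_backward_hook(fn)

    def remove(self):
        self.handle.remove()


def record_mean(hook, module, inp, outp):
    # For forward hooks, `outp` is the module's output tensor
    if not hasattr(hook, "means"):
        hook.means = []
    hook.means.append(outp.mean().item())


layer = nn.Linear(4, 3)
hook = HookSketch(layer, record_mean)
layer(torch.randn(8, 4))
layer(torch.randn(8, 4))
hook.remove()
layer(torch.randn(8, 4))  # not recorded: the hook was removed
print(len(hook.means))  # 2
```

Calling `remove()` detaches the underlying PyTorch handle, which is why the third forward pass is not recorded.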
Hooks
Register hooks on all the given modules.
__init__(modules, func, is_forward=True, **kwargs)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
modules | Iterable[Module] | The modules to register the hook on. | required |
func | Callable | The hook to be registered. | required |
is_forward | bool | Whether to register func as a forward hook; if False, it is registered as a backward hook. | True |
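A sketch of the multi-module version under the same calling assumption, with a `SimpleNamespace` standing in for each per-module hook object. `HooksSketch` and `record_shape` are hypothetical; the context-manager interface is a design choice of this sketch that guarantees the hooks are removed even if the forward pass raises.

```python
from functools import partial
from types import SimpleNamespace

import torch
from torch import nn


class HooksSketch:
    """Register `func` on every module; each hook gets its own state object."""

    def __init__(self, modules, func, is_forward=True, **kwargs):
        self.states, self.handles = [], []
        for m in modules:
            state = SimpleNamespace()  # stand-in for a per-module Hook
            fn = partial(func, state, **kwargs)
            if is_forward:
                self.handles.append(m.register_forward_hook(fn))
            else:
                self.handles.append(m.register_full_backward_hook(fn))
            self.states.append(state)

    def remove(self):
        for h in self.handles:
            h.remove()

    # Context-manager support: hooks are removed when the block exits
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.remove()


def record_shape(state, module, inp, outp):
    state.shape = tuple(outp.shape)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
with HooksSketch(model, record_shape) as hooks:
    model(torch.randn(5, 4))
model(torch.randn(3, 4))  # hooks already removed: states keep batch size 5
print([s.shape for s in hooks.states])  # [(5, 8), (5, 8), (5, 2)]
```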
HooksCallback
Bases: Callback
Base class to run hooks on modules as a callback.
__init__(hookfunc, on_train=True, on_valid=False, modules=None, is_forward=True)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hookfunc | Callable | The hook to be registered. | required |
on_train | bool | Whether to run the hook on modules during training. | True |
on_valid | bool | Whether to run the hook on modules during validation. | False |
modules | Module \| Iterable[Module] \| None | Modules to register the hook on. Defaults to all modules. | None |
is_forward | bool | Whether to register hookfunc as a forward hook; if False, it is registered as a backward hook. | True |
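on_train and on_valid decide in which phase the hook actually runs. The sketch below is one way to implement that gating in plain PyTorch; it assumes the phase can be read from `module.training` (set by `train()`/`eval()`), whereas the callback itself may take it from the training loop instead. `make_gated_hook` is a hypothetical name.

```python
import torch
from torch import nn


def make_gated_hook(hookfunc, on_train=True, on_valid=False):
    """Wrap a forward hook so it only runs in the requested phases (sketch)."""

    def gated(module, inp, outp):
        # Assumption: the phase is read from module.training (train() vs. eval())
        if (on_train and module.training) or (on_valid and not module.training):
            hookfunc(module, inp, outp)

    return gated


calls = []
layer = nn.Linear(2, 2)
handle = layer.register_forward_hook(
    make_gated_hook(lambda m, i, o: calls.append(m.training))
)

layer.train()
layer(torch.randn(1, 2))  # runs: on_train=True
layer.eval()
with torch.no_grad():
    layer(torch.randn(1, 2))  # skipped: on_valid=False
handle.remove()
print(calls)  # [True]
```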
compute_stats(hook, module, inp, outp, bins=40, bins_range=(0, 10))
Compute the mean, standard deviation, and histogram of the module's activations/gradients and store them in the hook's stats.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hook | Hook | Hook registered on the provided module. | required |
module | Module | Module to compute the stats on. | required |
inp | Tensor | Input of the module. | required |
outp | Tensor | Output of the module. | required |
bins | int | Number of histogram bins. | 40 |
bins_range | Iterable | Lower and upper ends of the histogram's bin range. | (0, 10) |
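A minimal sketch of such a hook function, assuming the stats are kept on the hook as a 3-tuple of lists (means, stds, histograms) and that the histogram is taken over absolute values, so the (0, 10) default range also covers negative values. It handles a single output tensor only (for a backward hook the gradients arrive as a tuple). `compute_stats_sketch` is a hypothetical name, and a `SimpleNamespace` stands in for a Hook.

```python
from types import SimpleNamespace

import torch


def compute_stats_sketch(hook, module, inp, outp, bins=40, bins_range=(0, 10)):
    """Append this batch's mean, std and histogram to `hook.stats` (sketch)."""
    if not hasattr(hook, "stats"):
        hook.stats = ([], [], [])
    acts = outp.detach().float().cpu()
    means, stds, hists = hook.stats
    means.append(acts.mean().item())
    stds.append(acts.std().item())
    # Assumption: histogram of magnitudes; values outside bins_range are ignored
    hists.append(acts.abs().histc(bins, *bins_range))


hook = SimpleNamespace()  # stand-in for a Hook object
for _ in range(3):  # three dummy "batches" of outputs
    compute_stats_sketch(hook, None, None, torch.randn(64, 16))
means, stds, hists = hook.stats
print(len(means), hists[0].shape)  # 3 torch.Size([40])
```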
get_hist(hook)
Return a matrix of the activation/gradient histograms that is ready to be plotted as a heatmap.
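A sketch of what "ready for plotting" can mean, assuming the per-batch histograms are stored in `hook.stats[2]`: stack them into a (bins, batches) matrix and apply `log1p`, which compresses large counts so that sparsely populated bins stay visible in the heatmap. `get_hist_sketch` is a hypothetical name.

```python
from types import SimpleNamespace

import torch


def get_hist_sketch(hook):
    """Stack per-batch histograms into a log-scaled (bins, batches) matrix."""
    # stack -> (batches, bins); transpose so rows are bins, columns are batches
    return torch.stack(hook.stats[2]).t().float().log1p()


# Two batches, three bins each (dummy counts)
hook = SimpleNamespace(
    stats=([], [], [torch.tensor([3.0, 1.0, 0.0]), torch.tensor([2.0, 2.0, 0.0])])
)
mat = get_hist_sketch(hook)
print(mat.shape)  # torch.Size([3, 2])
```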
get_min(hook, bins_range)
Compute the percentage of activations/gradients around zero from the hook's histogram matrix.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hook | Hook | Hook that holds the stats of the activations/gradients. | required |
bins_range | list \| tuple | Range of bins around zero. | required |
Returns:
Type | Description |
---|---|
Tensor | Percentage of the activations/gradients around zero. |
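A sketch of this computation, assuming the histograms live in `hook.stats[2]` and bins_range is applied as a Python slice over the bin axis. It returns one fraction in [0, 1] per batch; multiply by 100 for a percentage. `get_min_sketch` is a hypothetical name.

```python
from types import SimpleNamespace

import torch


def get_min_sketch(hook, bins_range):
    """Fraction of histogram counts in the bins selected by `bins_range`."""
    hist = torch.stack(hook.stats[2]).t().float()  # (bins, batches)
    near_zero = hist[slice(*bins_range)].sum(0)  # counts in the selected bins
    return near_zero / hist.sum(0)  # divide by total counts per batch


# Two batches of 10 values each; bin 0 holds 6 and 1 of them respectively
hook = SimpleNamespace(
    stats=([], [], [torch.tensor([6.0, 2.0, 2.0]), torch.tensor([1.0, 1.0, 8.0])])
)
print(get_min_sketch(hook, (0, 1)))  # tensor([0.6000, 0.1000])
```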