Schedule
Scheduling callbacks for hyperparameter management during training.
This module provides commonly used schedulers for hyperparameters such as learning rate.
It is important to apply hyperparameter updates (such as the learning rate) after the optimizer's update. In this framework, that means the update should happen in either before_batch or after_batch.
We have two choices for scheduling hyperparameters:
- Use any subclass of Scheduler, such as BatchScheduler, and pass it any scheduler from PyTorch's schedulers (see https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate).
- Use ParamScheduler and pass it any callable that takes the position and returns the hyperparameter value, such as exp_sched.
Functions:
- annealer : Callable – Decorator to create annealer functions.
- no_sched : Callable – Constant scheduler that returns the start value regardless of position.
- lin_sched : Callable – Linear scheduler that interpolates linearly between start and end values.
- cos_sched : Callable – Cosine scheduler that interpolates using cosine annealing.
- exp_sched : Callable – Exponential scheduler that interpolates exponentially between start and end values.
- cos_1cycle_anneal : Callable – Combine two cosine schedulers for 1-cycle learning rate scheduling.
- combine_scheds : Callable – Combine multiple schedulers, each running for a given percentage of training.
Classes:
- ParamScheduler : Callback – Callback to schedule hyperparameter values during training.
- Scheduler : Callback – Base scheduler to change hyperparameters using PyTorch schedulers.
- BatchScheduler : Scheduler – Scheduler that updates hyperparameters after every batch.
- EpochScheduler : Scheduler – Scheduler that updates hyperparameters after every epoch.
Examples:
>>> # Using ParamScheduler with exponential scheduling
>>> scheduler = ParamScheduler('lr', exp_sched(1e-3, 1e-5))
>>> learner.add_callback(scheduler)
>>> # Using BatchScheduler with PyTorch's OneCycleLR
>>> from torch.optim.lr_scheduler import OneCycleLR
>>> from functools import partial
>>> scheduler = BatchScheduler(partial(OneCycleLR, max_lr=1e-2, total_steps=1000))
>>> learner.add_callback(scheduler)
BatchScheduler
Bases: Scheduler
Scheduler that updates hyperparameters after every batch.
This scheduler applies the hyperparameter updates after each batch during training, which is useful for schedulers that need frequent updates like OneCycleLR.
Examples:
>>> from torch.optim.lr_scheduler import OneCycleLR
>>> from functools import partial
>>> scheduler = BatchScheduler(partial(OneCycleLR, max_lr=1e-2, total_steps=1000))
>>> learner.add_callback(scheduler)
after_batch()
Update hyperparameters after each batch during training.
Only updates values during training mode, not during validation.
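A minimal sketch of what this hook amounts to, assuming the callback exposes a training flag named training and that step() delegates to the wrapped PyTorch scheduler (both attribute names are assumptions, not the module's confirmed API):

def after_batch(self):
    # Assumed attribute: `self.training` marks the training phase, so the
    # scheduler is not stepped during validation.
    if self.training:
        self.step()  # delegate to the underlying PyTorch scheduler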
EpochScheduler
Bases: Scheduler
Scheduler that updates hyperparameters after every epoch.
This scheduler applies the hyperparameter updates after each epoch during training, which is useful for schedulers that need less frequent updates like StepLR or ReduceLROnPlateau.
Examples:
>>> from torch.optim.lr_scheduler import StepLR
>>> from functools import partial
>>> scheduler = EpochScheduler(partial(StepLR, step_size=30, gamma=0.1))
>>> learner.add_callback(scheduler)
after_epoch()
Update hyperparameters after each epoch during training.
Only updates values during training mode, not during validation.
ParamScheduler
Bases: Callback
Callback to schedule hyperparameter values during training.
This class is used to schedule the values of hyperparameters (such as learning rate) during the training process. It can handle different schedulers for different parameter groups in the optimizer.
Attributes:
- order (int) – The order in which this callback should be executed (default: 60).
- pname (str) – The name of the hyperparameter to be scheduled.
- sched_funcs (list[Callable] | tuple[Callable]) – List or tuple of scheduler functions for each parameter group.
Examples:
>>> # Schedule learning rate with exponential decay
>>> scheduler = ParamScheduler('lr', exp_sched(0.01, 0.001))
>>> learner.add_callback(scheduler)
>>> # Different schedulers for different parameter groups
>>> schedulers = [lin_sched(0.01, 0.001), cos_sched(0.005, 0.0005)]
>>> scheduler = ParamScheduler('lr', schedulers)
>>> learner.add_callback(scheduler)
__init__(pname, sched_funcs)
Initialize the parameter scheduler.
Parameters:
- pname (str) – The name of the hyperparameter to be scheduled (e.g., 'lr' for learning rate).
- sched_funcs (list[Callable] | tuple[Callable]) – A list or tuple of schedulers, one per parameter group. Each scheduler should accept a single argument (the position) and return a value for the hyperparameter. If a single scheduler is provided, it is applied to all parameter groups.
before_batch()
Update hyperparameter values before each batch during training.
Only updates values during training mode, not during validation.
before_fit()
Prepare the scheduler before training begins.
If a single scheduler function is provided, it will be replicated for all parameter groups in the optimizer.
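A minimal sketch of the replication described above, assuming the callback can reach the optimizer as self.opt (the attribute name is an assumption):

def before_fit(self):
    # If a single scheduler was given, repeat it once per parameter group.
    if not isinstance(self.sched_funcs, (list, tuple)):
        self.sched_funcs = [self.sched_funcs] * len(self.opt.param_groups)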
Scheduler
Bases: Callback
Base scheduler to change hyperparameters using PyTorch schedulers.
This class provides a base implementation for using PyTorch's built-in schedulers. Because PyTorch's schedulers take the optimizer as their first argument, pass a scheduler with all of its other arguments already bound, leaving only the optimizer to be supplied.
Attributes:
- scheduler (Callable) – The scheduler function to be used.
- scheduler_object (_LRScheduler) – The instantiated scheduler object.
Notes
PyTorch's schedulers take the optimizer as their first argument. Therefore, pass a scheduler with all of its other arguments already bound and only the optimizer left to supply; the optimizer is bound in Scheduler's before_fit method.
Examples:
>>> from torch.optim.lr_scheduler import OneCycleLR
>>> from functools import partial
>>> scheduler = Scheduler(partial(OneCycleLR, max_lr=1e-2, total_steps=1000))
>>> learner.add_callback(scheduler)
__init__(scheduler)
Initialize the scheduler.
Parameters:
- scheduler (Callable) – The scheduler function to be used. Should accept an optimizer as its first argument and return a scheduler object.
before_fit()
Create the scheduler object before training begins.
Instantiates the scheduler with the current optimizer.
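Conceptually this is a single call that binds the stored scheduler factory to the learner's optimizer; a sketch, assuming the optimizer is reachable as self.opt:

def before_fit(self):
    # Instantiate the PyTorch scheduler, supplying the missing optimizer argument.
    self.scheduler_object = self.scheduler(self.opt)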
step()
Step the scheduler to update hyperparameter values.
This method calls the step method of the underlying PyTorch scheduler.
annealer(func)
Decorator to create annealer functions.
This decorator wraps a function to create a partial function that can be used as a scheduler. The wrapped function should accept start, end, and position parameters and return the scheduled value.
Parameters:
- func (Callable) – The function to be wrapped. Should have the signature func(start, end, pos).
Returns:
- Callable – A partial function that can be used as a scheduler.
Examples:
>>> @annealer
... def my_scheduler(start, end, pos):
...     return start + (end - start) * pos
>>> scheduler = my_scheduler(0.1, 0.01)
>>> value = scheduler(0.5)  # pos = 0.5
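A minimal sketch of how such a decorator is commonly written with functools.partial; the module's actual implementation may differ in details:

from functools import partial

def annealer(func):
    # Calling the decorated function with (start, end) returns a partial
    # that only needs the position: pos -> func(start, end, pos).
    def _annealer(start, end):
        return partial(func, start, end)
    return _annealer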
combine_scheds(pcts, scheds)
Combine multiple schedulers, each running for a given percentage of training.
Parameters:
- pcts (list[float]) – List of percentages (should sum to 1.0) indicating how much of training each scheduler should run for.
- scheds (list[Callable]) – List of scheduler functions corresponding to each percentage.
Returns:
- Callable – A combined scheduler function that switches between schedulers based on the current training position.
Raises:
- AssertionError – If the number of percentages doesn't match the number of schedulers, if the percentages don't sum to 1.0, or if any percentage is negative.
Examples:
>>> scheduler1 = lin_sched(0.01, 0.005)
>>> scheduler2 = cos_sched(0.005, 0.001)
>>> combined = combine_scheds([0.6, 0.4], [scheduler1, scheduler2])
>>> # First 60% of training uses linear scheduler, last 40% uses cosine
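A minimal sketch of the combination logic, assuming the position is a fraction of total training in [0, 1]; the module's actual implementation (and its boundary handling) may differ:

def combine_scheds(pcts, scheds):
    assert len(pcts) == len(scheds)
    assert all(p >= 0 for p in pcts)  # the real function also checks that pcts sums to 1.0

    def _combined(pos):
        acc = 0.0
        for i, (pct, sched) in enumerate(zip(pcts, scheds)):
            if pos <= acc + pct or i == len(scheds) - 1:
                # Rescale the global position into this scheduler's own 0-1 range.
                return sched((pos - acc) / pct)
            acc += pct

    return _combined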
cos_1cycle_anneal(start, high, end)
Combine two cosine schedulers for 1-cycle learning rate scheduling.
Creates a list of two cosine schedulers: the first goes from start to high, and the second goes from high to end.
Parameters:
- start (float) – The starting value.
- high (float) – The peak value in the middle of training.
- end (float) – The ending value.
Returns:
- list – A list containing two cosine schedulers for the 1-cycle policy.
Examples:
>>> schedulers = cos_1cycle_anneal(0.001, 0.01, 0.0001)
>>> len(schedulers) # Returns 2 schedulers
2
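Since the result is a list of two schedulers, it pairs naturally with combine_scheds and ParamScheduler; a usage sketch (the 30/70 phase split is illustrative):

>>> scheds = cos_1cycle_anneal(0.001, 0.01, 0.0001)
>>> one_cycle = combine_scheds([0.3, 0.7], scheds)  # 30% warm-up, 70% cool-down
>>> learner.add_callback(ParamScheduler('lr', one_cycle))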
cos_sched(start, end, pos)
Cosine scheduler that interpolates using cosine annealing.
Parameters:
- start (float) – The starting value.
- end (float) – The ending value.
- pos (float) – Current position in training (0 to 1).
Returns:
- float – The cosine-interpolated value.
Examples:
>>> scheduler = cos_sched(0.01, 0.001)
>>> scheduler(0.0) # Returns start value
0.01
>>> scheduler(1.0) # Returns end value
0.001
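The values above are consistent with the usual half-cosine interpolation, sketched below with the documented three-argument signature; in the module the function is presumably wrapped by the annealer decorator, which is why the example calls it with only (start, end):

import math

def cos_sched(start, end, pos):
    # Half-cosine interpolation: equals `start` at pos=0 and `end` at pos=1.
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2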
exp_sched(start, end, pos)
Exponential scheduler that interpolates exponentially between start and end values.
Parameters:
- start (float) – The starting value.
- end (float) – The ending value.
- pos (float) – Current position in training (0 to 1).
Returns:
- float – The exponentially interpolated value.
Examples:
>>> scheduler = exp_sched(0.01, 0.001)
>>> scheduler(0.0) # Returns start value
0.01
>>> scheduler(1.0) # Returns end value
0.001
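The doctest is consistent with geometric interpolation; a sketch, assuming that form (start must be non-zero and share its sign with end):

def exp_sched(start, end, pos):
    # Geometric interpolation: equals `start` at pos=0 and `end` at pos=1.
    return start * (end / start) ** pos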
lin_sched(start, end, pos)
Linear scheduler that interpolates linearly between start and end values.
Parameters:
- start (float) – The starting value.
- end (float) – The ending value.
- pos (float) – Current position in training (0 to 1).
Returns:
- float – The linearly interpolated value.
Examples:
>>> scheduler = lin_sched(0.01, 0.001)
>>> scheduler(0.5) # Returns 0.0055 (halfway between 0.01 and 0.001)
0.0055
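Linear interpolation reduces to the familiar formula; a sketch:

def lin_sched(start, end, pos):
    # Straight-line interpolation between `start` and `end`.
    return start + pos * (end - start)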
no_sched(start, end, pos)
Constant scheduler that returns the start value regardless of position.
Parameters:
- start (float) – The constant value to return.
- end (float) – Not used in this scheduler (kept for interface consistency).
- pos (float) – Current position in training (0 to 1). Not used in this scheduler.
Returns:
- float – The start value (constant).
Examples:
>>> scheduler = no_sched(0.01, 0.001)
>>> scheduler(0.5) # Returns 0.01 regardless of position
0.01
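For completeness, the constant schedule simply ignores end and pos; a sketch:

def no_sched(start, end, pos):
    # Constant schedule: `end` and `pos` are accepted only for interface consistency.
    return start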