import warnings

from delira import get_backends

from .abstract_callback import AbstractCallback

if 'TORCH' in get_backends():
    from torch.optim.lr_scheduler import ReduceLROnPlateau, \
        CosineAnnealingLR, ExponentialLR, LambdaLR, MultiStepLR, StepLR
class DefaultPyTorchSchedulerCallback(AbstractCallback):
    """
    Implements a Callback whose `at_epoch_end` function is suitable for most
    PyTorch learning rate schedulers.

    Subclasses are expected to assign a scheduler instance to
    ``self.scheduler`` in their ``__init__``; this base class only
    initializes it to ``None``.

    """

    def __init__(self, *args, **kwargs):
        """

        Parameters
        ----------
        *args :
            Arbitrary Positional Arguments (not used by this class)
        **kwargs :
            Arbitrary Keyword Arguments (not used by this class)

        """
        # Chain to the base callback so any state it sets up is initialized;
        # called without arguments since its signature is not visible here.
        super().__init__()

        self.scheduler = None

    def at_epoch_end(self, trainer, **kwargs):
        """
        Executes a single scheduling step

        Parameters
        ----------
        trainer : :class:`PyTorchNetworkTrainer`
            the trainer class; it is not used or modified by this callback
        **kwargs :
            additional keyword arguments; ``curr_epoch`` (if given) is
            forwarded to the scheduler's ``step`` as ``epoch``

        Returns
        -------
        dict
            an empty dict

        """
        # NOTE(review): recent PyTorch releases deprecate the explicit
        # ``epoch`` argument of ``step`` (warning when it is not None);
        # kept to preserve the existing scheduling semantics.
        self.scheduler.step(epoch=kwargs.get("curr_epoch", None))
        return {}
class ReduceLROnPlateauCallback(DefaultPyTorchSchedulerCallback):
    """
    Wraps PyTorch's `ReduceLROnPlateau` Scheduler as Callback

    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 verbose=False, threshold=1e-4, threshold_mode='rel',
                 cooldown=0, min_lr=0, eps=1e-8):
        """

        Parameters
        ----------
        optimizer : Optimizer
            Wrapped optimizer.
        mode : str
            One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor : float
            Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience : int
            Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        verbose : bool
            If ``True``, prints a message to stdout for
            each update. Default: ``False``. Only forwarded to the
            scheduler when ``True``.
        threshold : float
            Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode : string
            One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in `max`
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown : int
            Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr : float or list
            A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps : float
            Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8

        """
        super().__init__()

        # Pass everything by keyword: the positional order of
        # ``ReduceLROnPlateau`` is not guaranteed to match this wrapper's
        # (NOTE(review): newer PyTorch moved ``verbose`` to the end and
        # deprecated it, which would shift every later positional argument).
        scheduler_kwargs = {
            "mode": mode,
            "factor": factor,
            "patience": patience,
            "threshold": threshold,
            "threshold_mode": threshold_mode,
            "cooldown": cooldown,
            "min_lr": min_lr,
            "eps": eps,
        }
        # Only forward ``verbose`` when explicitly requested, so the default
        # path never touches the (possibly deprecated) argument.
        if verbose:
            scheduler_kwargs["verbose"] = verbose

        self.scheduler = ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    def at_epoch_end(self, trainer, **kwargs):
        """
        Executes a single scheduling step based on the monitored validation
        metric

        Parameters
        ----------
        trainer : :class:`PyTorchNetworkTrainer`
            the trainer class; it is not used or modified by this callback
        **kwargs :
            additional keyword arguments; the monitored value is looked up
            as ``kwargs["val_metrics"][kwargs["val_score_key"]]``

        Returns
        -------
        dict
            an empty dict

        """
        # ``or {}`` also covers an explicitly passed ``val_metrics=None``
        val_metrics = kwargs.get("val_metrics", {}) or {}
        val_score_key = kwargs.get("val_score_key", None)
        metrics = val_metrics.get(val_score_key)

        if metrics is None:
            # Without the monitored quantity (e.g. no validation this epoch)
            # the plateau logic has nothing to compare; skip this step
            # rather than handing ``None`` to the scheduler.
            warnings.warn(
                f"ReduceLROnPlateauCallback: no validation metric found for "
                f"key {val_score_key!r}; skipping scheduler step."
            )
            return {}

        self.scheduler.step(metrics)
        return {}
class CosineAnnealingLRCallback(DefaultPyTorchSchedulerCallback):
    """
    Wraps PyTorch's `CosineAnnealingLR` Scheduler as callback

    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        """

        Parameters
        ----------
        optimizer : Optimizer
            Wrapped optimizer.
        T_max : int
            Maximum number of iterations, i.e. scheduler steps; the
            inherited `at_epoch_end` performs one step per call.
        eta_min : float
            Minimum learning rate. Default: 0.
        last_epoch : int
            The index of last epoch. Default: -1.

        """
        super().__init__()

        # keyword arguments keep this robust against positional reordering
        self.scheduler = CosineAnnealingLR(optimizer, T_max=T_max,
                                           eta_min=eta_min,
                                           last_epoch=last_epoch)
class ExponentialLRCallback(DefaultPyTorchSchedulerCallback):
    """
    Wraps PyTorch's `ExponentialLR` Scheduler as callback

    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        """

        Parameters
        ----------
        optimizer : Optimizer
            Wrapped optimizer.
        gamma : float
            Multiplicative factor of learning rate decay.
        last_epoch : int
            The index of last epoch. Default: -1.

        """
        super().__init__()

        # keyword arguments keep this robust against positional reordering
        self.scheduler = ExponentialLR(optimizer, gamma=gamma,
                                       last_epoch=last_epoch)
class LambdaLRCallback(DefaultPyTorchSchedulerCallback):
    """
    Wraps PyTorch's `LambdaLR` Scheduler as callback

    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        """

        Parameters
        ----------
        optimizer : Optimizer
            Wrapped optimizer.
        lr_lambda : function or list
            A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch : int
            The index of last epoch. Default: -1.

        """
        super().__init__()

        # keyword arguments keep this robust against positional reordering
        self.scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda,
                                  last_epoch=last_epoch)
class MultiStepLRCallback(DefaultPyTorchSchedulerCallback):
    """
    Wraps PyTorch's `MultiStepLR` Scheduler as callback

    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
        """

        Parameters
        ----------
        optimizer : Optimizer
            Wrapped optimizer.
        milestones : list
            List of epoch indices. Must be increasing.
        gamma : float
            Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch : int
            The index of last epoch. Default: -1.

        """
        super().__init__()

        # keyword arguments keep this robust against positional reordering
        self.scheduler = MultiStepLR(optimizer, milestones=milestones,
                                     gamma=gamma, last_epoch=last_epoch)
class StepLRCallback(DefaultPyTorchSchedulerCallback):
    """
    Wraps PyTorch's `StepLR` Scheduler as callback

    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        """

        Parameters
        ----------
        optimizer : Optimizer
            Wrapped optimizer.
        step_size : int
            Period of learning rate decay.
        gamma : float
            Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch : int
            The index of last epoch. Default: -1

        """
        super().__init__()

        # keyword arguments keep this robust against positional reordering
        self.scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma,
                                last_epoch=last_epoch)