import contextlib
from delira import get_backends
from delira.utils.decorators import make_deprecated
if "TORCH" in get_backends():
import torch
[docs] class DefaultOptimWrapperTorch(object):
"""
Class wrapping a ``torch`` optimizer to mirror the behavior of ``apex``
without depending on it
"""
@make_deprecated(
"'delira.models.model_utils.scale_loss' combined with "
"new apex.amp API (https://github.com/NVIDIA/apex)")
def __init__(self, optimizer: torch.optim.Optimizer, *args, **kwargs):
"""
Parameters
----------
optimizer : torch.optim.Optimizer
the actual optimizer to wrap
*args :
additional positional arguments (unused)
**kwargs :
additional keyword arguments (unused)
"""
self._optimizer = optimizer
[docs] @contextlib.contextmanager
def scale_loss(self, loss):
"""
Function which scales the loss in ``apex`` and yields the unscaled
loss here to mirror the API
Parameters
----------
loss : torch.Tensor
the unscaled loss
"""
yield loss
return
[docs] def step(self, closure=None):
"""
Wraps the step method of the optimizer and calls the original step
method
Parameters
----------
closure : callable
A closure that reevaluates the model and returns the loss.
Optional for most optimizers.
"""
return self._optimizer.step(closure=closure)
# Forward any attribute lookups
def __getattr__(self, attr):
return getattr(self._optimizer, attr)
# Forward all torch.optim.Optimizer methods
def __getstate__(self):
return self._optimizer.__getstate__()
def __setstate__(self, *args, **kwargs):
return self._optimizer.__setstate__(*args, **kwargs)
def __repr__(self):
return self._optimizer.__repr__()
[docs] def state_dict(self):
return self._optimizer.state_dict()
[docs] def load_state_dict(self, state_dict):
return self._optimizer.load_state_dict(state_dict)
[docs] def zero_grad(self):
return self._optimizer.zero_grad()
[docs] def add_param_group(self, param_group):
return self._optimizer.add_param_group(param_group)
from delira import get_current_debug_mode, set_debug_mode
[docs] class DebugMode(object):
"""
Context Manager to set a specific debug mode for the code inside the
defined context (and reverting to previous mode afterwards)
"""
def __init__(self, mode):
"""
Parameters
----------
mode : bool
the debug mode; if ``True`` disables all multiprocessing
"""
self._mode = mode
[docs] def _switch_to_new_mode(self):
"""
helper function to switch to the new debug mode
(and saving the previous one in ``self._mode``)
"""
prev_mode = get_current_debug_mode()
set_debug_mode(self._mode)
self._mode = prev_mode
def __enter__(self):
"""
Sets the specified debug mode on entering the context
"""
self._switch_to_new_mode()
def __exit__(self, *args, **kwargs):
"""
Resets the previous debug mode on exiting the context
Parameters
----------
*args :
arbitrary positional arguments
(ignored here, just needed for compatibility with other context
managers)
**kwargs :
arbitrary keyword arguments
(ignored here, just needed for compatibility with other context
managers)
"""
self._switch_to_new_mode()
[docs] class DebugEnabled(DebugMode):
"""
Context Manager to enable the debug mode for the wrapped context
"""
def __init__(self):
super().__init__(True)
[docs] class DebugDisabled(DebugMode):
"""
Context Manager to disable the debug mode for the wrapped context
"""
def __init__(self):
super().__init__(False)