PyTorch

AbstractPyTorchNetwork

class AbstractPyTorchNetwork(**kwargs)

Bases: delira.models.abstract_network.AbstractNetwork, torch.nn.Module

Abstract class for PyTorch networks.

See also: AbstractNetwork
_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, losses={}, metrics={}, fold=0, **kwargs)

Closure method performing a single backpropagation step.

- Parameters
  - model (AbstractPyTorchNetwork) – trainable model
  - data_dict (dict) – dictionary containing the data
  - optimizers (dict) – dictionary of optimizers to optimize the model's parameters
  - losses (dict) – dict holding the losses to calculate errors (gradients from different losses will be accumulated)
  - metrics (dict) – dict holding the metrics to calculate
  - fold (int) – current fold in cross-validation (default: 0)
  - **kwargs – additional keyword arguments
- Returns
  - dict – metric values (with the same keys as the input dict metrics)
  - dict – loss values (with the same keys as the input dict losses)
  - list – arbitrary number of predictions as torch.Tensor
- Raises
  - AssertionError – if optimizers or losses are empty or the optimizers are not specified
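A typical closure implementation forwards the batch, accumulates all losses, backpropagates and steps the optimizers. The following is a minimal sketch, assuming data_dict keys "data" and "label" and an optimizer registered under "default" (these names are illustrative, not part of the API):

    import torch
    from delira.models.backends.torch.abstract_network import \
        AbstractPyTorchNetwork


    class MyNetwork(AbstractPyTorchNetwork):
        # __init__ and forward omitted; see the forward example below

        @staticmethod
        def closure(model, data_dict: dict, optimizers: dict, losses={},
                    metrics={}, fold=0, **kwargs):
            assert optimizers and losses, \
                "at least one optimizer and one loss are required"

            preds = model(data_dict["data"])

            # accumulate gradients from all losses before stepping
            loss_vals, total_loss = {}, 0
            for name, fn in losses.items():
                loss = fn(preds, data_dict["label"])
                loss_vals[name] = loss.item()
                total_loss = total_loss + loss

            optimizers["default"].zero_grad()
            total_loss.backward()
            optimizers["default"].step()

            # metrics are assumed to map (preds, labels) to scalar tensors
            with torch.no_grad():
                metric_vals = {name: fn(preds, data_dict["label"]).item()
                               for name, fn in metrics.items()}

            return metric_vals, loss_vals, [preds.detach()]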
abstract forward(*inputs)

Forward inputs through the module (defines the module's behavior).

- Parameters
  - *inputs (list) – inputs of arbitrary type and number
- Returns
  - result – module results of arbitrary type and number
- Return type
  - Any
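A concrete subclass only has to build its submodules in __init__ and implement forward; the linear classifier below is a minimal illustrative sketch (the architecture and argument names are assumptions):

    import torch
    from delira.models.backends.torch.abstract_network import \
        AbstractPyTorchNetwork


    class SimpleClassifier(AbstractPyTorchNetwork):
        def __init__(self, in_features: int, num_classes: int):
            # keyword arguments forwarded to super().__init__ are
            # registered as init kwargs (see init_kwargs below)
            super().__init__(in_features=in_features,
                             num_classes=num_classes)
            self.fc = torch.nn.Linear(in_features, num_classes)

        def forward(self, x):
            # results may be of arbitrary type and number
            return self.fc(x)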
property init_kwargs

Returns all arguments registered as init kwargs.

- Returns
  - init kwargs
- Return type
  - dict
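Using the SimpleClassifier sketch from above:

    net = SimpleClassifier(in_features=32, num_classes=10)
    print(net.init_kwargs)  # {'in_features': 32, 'num_classes': 10}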
DataParallelPyTorchNetwork

class DataParallelPyTorchNetwork(module: delira.models.backends.torch.abstract_network.AbstractPyTorchNetwork, device_ids=None, output_device=None, dim=0)

Bases: delira.models.backends.torch.abstract_network.AbstractPyTorchNetwork, torch.nn.DataParallel

A wrapper around an AbstractPyTorchNetwork instance to implement parallel training by splitting the batches across devices; see the usage sketch below.

_init_kwargs = {}
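A usage sketch, wrapping the SimpleClassifier from above to split each batch across two GPUs (the import path and device ids are assumptions):

    import torch
    # assumed import path; adjust to your delira version
    from delira.models.backends.torch import DataParallelPyTorchNetwork

    net = SimpleClassifier(in_features=32, num_classes=10)

    # replicate the model on GPUs 0 and 1; inputs are scattered along
    # dim 0 and the outputs are gathered on GPU 0
    parallel_net = DataParallelPyTorchNetwork(net, device_ids=[0, 1],
                                              output_device=0, dim=0)

    batch = torch.randn(64, 32, device="cuda:0")
    out = parallel_net(batch)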
property closure

Closure method performing a single backpropagation step.

- Parameters
  - model (AbstractPyTorchNetwork) – trainable model
  - data_dict (dict) – dictionary containing the data
  - optimizers (dict) – dictionary of optimizers to optimize the model's parameters
  - losses (dict) – dict holding the losses to calculate errors (gradients from different losses will be accumulated)
  - metrics (dict) – dict holding the metrics to calculate
  - fold (int) – current fold in cross-validation (default: 0)
  - **kwargs – additional keyword arguments
- Returns
  - dict – metric values (with the same keys as the input dict metrics)
  - dict – loss values (with the same keys as the input dict losses)
  - list – arbitrary number of predictions as torch.Tensor
- Raises
  - AssertionError – if optimizers or losses are empty or the optimizers are not specified
forward(*args, **kwargs)

Scatters the inputs (both positional and keyword arguments) across all devices, feeds them through the model replicas and re-builds the batches on the output device.

- Parameters
  - *args – positional arguments of arbitrary number and type
  - **kwargs – keyword arguments of arbitrary number and type
- Returns
  - combined output from all scattered models
- Return type
  - Any
property init_kwargs

Returns all arguments registered as init kwargs.

- Returns
  - init kwargs
- Return type
  - dict
property prepare_batch

Helper function to prepare network inputs and labels (convert them to the correct type and shape and push them to the correct devices).
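What such a helper typically does, as a sketch (the signature, key names and dtypes are assumptions):

    import torch


    def prepare_batch(batch: dict, input_device, output_device):
        # convert numpy arrays to tensors of the correct dtype and push
        # inputs and labels to the devices they are consumed on
        return {
            "data": torch.from_numpy(batch["data"]).float().to(input_device),
            "label": torch.from_numpy(batch["label"]).long().to(output_device),
        }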
scale_loss

scale_loss(loss, optimizers, loss_id=0, model=None, delay_unscale=False, **kwargs)

Context manager which automatically switches between loss scaling via apex.amp (if apex is available) and no loss scaling.

- Parameters
  - loss (torch.Tensor) – a PyTorch tensor containing the loss value
  - optimizers (list) – a list of torch.optim.Optimizer instances holding all parameters affected by the backpropagation of the current loss value
  - loss_id (int) – when used in conjunction with the num_losses argument to amp.initialize, enables Amp to use a different loss scale per loss. loss_id must be an integer between 0 and num_losses that tells Amp which loss is being used for the current backward pass. If loss_id is left unspecified, Amp will use the default global loss scaler for this backward pass.
  - model (AbstractPyTorchNetwork or None) – currently unused, reserved to enable future optimizations
  - delay_unscale (bool) – delay_unscale is never necessary, and the default value of False is strongly recommended. If True, Amp will not unscale the gradients or perform model->master gradient copies on context manager exit. delay_unscale=True is a minor ninja performance optimization and can result in weird gotchas (especially with multiple models/optimizers/losses), so only use it if you know what you're doing.
  - **kwargs – additional keyword arguments; currently unused, but provided in case amp extends the functionality here
- Yields
  - torch.Tensor – the new loss value (scaled if apex.amp is available and was configured to do so, unscaled in all other cases)
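Typical usage inside a closure such as the one sketched above, replacing a bare total_loss.backward() (the import path and optimizer key are assumptions):

    # assumed import path; adjust to your delira version
    from delira.models.backends.torch.utils import scale_loss

    optimizers["default"].zero_grad()

    # scaled_loss is scaled by apex.amp if available and configured,
    # otherwise it is the unchanged total_loss
    with scale_loss(total_loss, list(optimizers.values())) as scaled_loss:
        scaled_loss.backward()

    optimizers["default"].step()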