TorchScript
AbstractTorchScriptNetwork
class AbstractTorchScriptNetwork(optimize=True, **kwargs)
Bases: delira.models.abstract_network.AbstractNetwork, torch.jit.ScriptModule
Abstract interface class for TorchScript networks. For more information, see https://pytorch.org/docs/stable/jit.html#torchscript
Warning
In addition to the API defined here, a forward function must be implemented and decorated with @torch.jit.script_method.
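Here is a minimal sketch of what the warning asks for. It assumes a PyTorch version that still includes the legacy `torch.jit.ScriptModule` / `@torch.jit.script_method` API. The hypothetical `TinyScriptNet` subclasses `torch.jit.ScriptModule` directly so the sketch runs without delira. A delira model would subclass AbstractTorchScriptNetwork instead and implement the rest of its API as well.

```python
import torch


class TinyScriptNet(torch.jit.ScriptModule):
    """Hypothetical example network; not part of delira."""

    def __init__(self, in_features: int = 4, out_features: int = 2):
        super().__init__()
        # submodules assigned in __init__ are scripted together with forward
        self.fc = torch.nn.Linear(in_features, out_features)

    # the decorator marks forward for compilation to TorchScript
    @torch.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


net = TinyScriptNet()
out = net(torch.randn(3, 4))  # output has shape (3, 2)
```

Because `forward` is compiled to TorchScript, only the TorchScript subset of Python can appear inside it.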
_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, losses={}, metrics={}, fold=0, **kwargs)
Closure method to perform a single backpropagation step.
- Parameters
model (AbstractTorchScriptNetwork) – trainable model
data_dict (dict) – dictionary containing the data
optimizers (dict) – dictionary of optimizers to optimize the model's parameters
losses (dict) – dict holding the losses used to calculate errors (gradients from different losses are accumulated)
metrics (dict) – dict holding the metrics to calculate
fold (int) – current fold in cross-validation (default: 0)
**kwargs – additional keyword arguments
- Returns
dict – Metric values (with the same keys as the input dict metrics)
dict – Loss values (with the same keys as the input dict losses)
list – Arbitrary number of predictions as torch.Tensor
- Raises
AssertionError – if optimizers or losses are empty or the optimizers are not specified
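The documented contract of `closure` can be sketched as a single backpropagation step in the usual PyTorch order. The step runs the model, sums all losses so their gradients accumulate, and steps the optimizer. It then returns metrics, losses and predictions. This is a hypothetical stand-in, not delira's implementation. The keys `"data"`, `"label"` and `"default"` are assumptions. `fold` and `**kwargs` are omitted because the documentation does not say how they are used.

```python
import torch


def closure_sketch(model, data_dict: dict, optimizers: dict, losses: dict, metrics=None):
    """Hypothetical single training step following the documented closure contract.

    Assumed (not from delira): inputs in data_dict["data"], targets in
    data_dict["label"], and the optimizer to use in optimizers["default"].
    """
    # mirror the documented AssertionError for empty optimizers/losses
    assert optimizers and losses, "optimizers and losses must not be empty"
    metrics = metrics or {}

    inputs, target = data_dict["data"], data_dict["label"]
    preds = model(inputs)

    # evaluate every loss; summing them makes their gradients accumulate
    loss_tensors = {name: fn(preds, target) for name, fn in losses.items()}
    total_loss = sum(loss_tensors.values())

    optimizer = optimizers["default"]
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # metrics only report values, so no autograd graph is needed
    with torch.no_grad():
        metric_vals = {name: fn(preds, target).item() for name, fn in metrics.items()}
    loss_vals = {name: val.item() for name, val in loss_tensors.items()}

    # same order as documented: metrics, losses, list of predictions
    return metric_vals, loss_vals, [preds.detach()]
```

Summing the loss tensors before one `backward()` call gives the same gradients as calling `backward()` on each loss separately. That is one way to read the note that gradients from different losses are accumulated.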
property init_kwargs
Returns all arguments registered as init kwargs.
- Returns
init kwargs
- Return type
dict
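One way to implement the `_init_kwargs` / `init_kwargs` pair is sketched below. `InitKwargsMixin` and `MyNet` are hypothetical names. The per-instance copy is a design choice of this sketch, not documented delira behavior. Storing each instance's kwargs in its own dict keeps instances from sharing the class-level `{}` default.

```python
class InitKwargsMixin:
    """Hypothetical sketch of registering constructor kwargs; not delira's code."""

    # class-level default, mirroring the documented attribute
    _init_kwargs: dict = {}

    def __init__(self, **kwargs):
        # per-instance copy so instances never mutate the shared class default
        self._init_kwargs = dict(kwargs)

    @property
    def init_kwargs(self) -> dict:
        """Return all arguments registered as init kwargs."""
        return self._init_kwargs


class MyNet(InitKwargsMixin):
    """Hypothetical model that registers its constructor arguments."""

    def __init__(self, in_channels: int, n_outputs: int):
        super().__init__(in_channels=in_channels, n_outputs=n_outputs)
        self.in_channels = in_channels
        self.n_outputs = n_outputs


net = MyNet(in_channels=3, n_outputs=10)
clone = type(net)(**net.init_kwargs)  # rebuild with identical arguments
```

Rebuilding a model with identical hyperparameters via `type(net)(**net.init_kwargs)` is one plausible reason to register the kwargs in the first place.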