TorchScript

AbstractTorchScriptNetwork

class AbstractTorchScriptNetwork(optimize=True, **kwargs)

Bases: delira.models.abstract_network.AbstractNetwork, torch.jit.ScriptModule

Abstract interface class for TorchScript networks. For more information, see https://pytorch.org/docs/stable/jit.html#torchscript

Warning

In addition to the API defined here, a forward method must be implemented and decorated with @torch.jit.script_method.
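The following is a minimal sketch of what this warning requires. It subclasses torch.jit.ScriptModule directly so that it runs without delira installed; an actual delira network would inherit from AbstractTorchScriptNetwork instead. The class name TinyScriptNet and the layer sizes are hypothetical.

```python
import torch


class TinyScriptNet(torch.jit.ScriptModule):
    """Hypothetical minimal TorchScript network (illustration only)."""

    def __init__(self, in_features: int = 4, out_features: int = 2):
        super().__init__()
        self.fc = torch.nn.Linear(in_features, out_features)

    # The decorator makes TorchScript compile forward instead of
    # running it as plain Python.
    @torch.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


net = TinyScriptNet()
out = net(torch.randn(3, 4))
print(out.shape)  # torch.Size([3, 2])
```

Because forward is compiled, its body is restricted to the subset of Python that TorchScript supports, and arguments that are not tensors need type annotations.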

_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, losses={}, metrics={}, fold=0, **kwargs)

Closure method to perform a single backpropagation step.

Parameters
  • model (AbstractTorchScriptNetwork) – trainable model

  • data_dict (dict) – dictionary containing the data

  • optimizers (dict) – dictionary of optimizers to optimize model’s parameters

  • losses (dict) – dict holding the losses used to calculate errors (gradients from different losses are accumulated)

  • metrics (dict) – dict holding the metrics to calculate

  • fold (int) – current fold in cross-validation (default: 0)

  • **kwargs – additional keyword arguments

Returns

  • dict – Metric values (with same keys as input dict metrics)

  • dict – Loss values (with same keys as input dict losses)

  • list – Arbitrary number of predictions as torch.Tensor

Raises

AssertionError – if optimizers or losses are empty or the optimizers are not specified
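The closure contract can be sketched in plain PyTorch as below. This is not delira's implementation: the dictionary keys "data" and "label", the optimizer key "default", and the name closure_sketch are assumptions made for illustration. All losses are summed before a single backward() call, which accumulates their gradients as described above, and the results are returned in the documented order (metrics, losses, predictions).

```python
import torch


def closure_sketch(model, data_dict, optimizers, losses, metrics=None):
    """Hypothetical single backpropagation step (not delira's code).

    Assumed layout: inputs under "data", targets under "label",
    and the optimizer stored under the key "default".
    """
    metrics = metrics or {}
    # Same precondition as the documented AssertionError
    assert optimizers and losses, "optimizers and losses must not be empty"

    preds = model(data_dict["data"])
    targets = data_dict["label"]

    # Evaluate every loss and backpropagate their sum once; this accumulates
    # the gradients of all losses in the model parameters.
    loss_vals = {name: fn(preds, targets) for name, fn in losses.items()}
    total_loss = sum(loss_vals.values())

    optimizer = optimizers["default"]
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Metrics do not need gradients
    with torch.no_grad():
        metric_vals = {name: fn(preds, targets).item()
                       for name, fn in metrics.items()}

    loss_floats = {name: val.item() for name, val in loss_vals.items()}
    return metric_vals, loss_floats, [preds.detach()]


# Usage: one SGD step on a toy linear regression problem
torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
optimizers = {"default": torch.optim.SGD(model.parameters(), lr=0.1)}
batch = {"data": torch.randn(8, 3), "label": torch.randn(8, 1)}
metric_vals, loss_vals, preds = closure_sketch(
    model, batch, optimizers,
    losses={"mse": torch.nn.MSELoss()},
    metrics={"mae": torch.nn.L1Loss()},
)
```

Summing the losses first is equivalent to calling backward() on each loss separately, since PyTorch adds gradients into .grad until zero_grad() is called.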

property init_kwargs

Returns all arguments registered as init kwargs

Returns

init kwargs

Return type

dict
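The init-kwargs mechanism can be pictured with the pure-Python sketch below. It is a hypothetical illustration, not delira's code: it only records the keyword arguments passed to __init__ so they can be queried later, for example to re-create an identically configured network.

```python
class InitKwargsSketch:
    """Hypothetical illustration of the init_kwargs pattern (not delira code)."""

    # Class-level default, mirroring the documented _init_kwargs = {}
    _init_kwargs = {}

    def __init__(self, **kwargs):
        # Store a per-instance copy instead of mutating the shared
        # class-level dict, which would leak between instances.
        self._init_kwargs = dict(kwargs)

    @property
    def init_kwargs(self) -> dict:
        """Return all arguments registered as init kwargs."""
        return self._init_kwargs


net = InitKwargsSketch(in_channels=1, n_outputs=10)
print(net.init_kwargs)  # {'in_channels': 1, 'n_outputs': 10}
clone = InitKwargsSketch(**net.init_kwargs)  # same configuration
```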

static prepare_batch(batch: dict, input_device, output_device)

Helper function to prepare network inputs and labels (convert them to the correct type and shape and push them to the correct devices).

Parameters
  • batch (dict) – dictionary containing all the data

  • input_device (torch.device) – device for network inputs

  • output_device (torch.device) – device for network outputs

Returns

dictionary containing the data in the correct type and shape and on the correct devices

Return type

dict
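A sketch of the conversion prepare_batch describes, under stated assumptions: the batch values are numpy arrays, the network input is stored under the key "data", and every other entry (labels, masks, ...) belongs on the output device. The name prepare_batch_sketch is hypothetical, and casting everything to float32 is a simplification; labels for classification losses such as cross-entropy would usually need an integer dtype instead.

```python
import numpy as np
import torch


def prepare_batch_sketch(batch: dict, input_device, output_device) -> dict:
    """Hypothetical batch preparation (not delira's implementation)."""
    prepared = {}
    for key, value in batch.items():
        # Wrap the numpy array as a tensor (shares memory), then cast to
        # float32, which makes a copy when the dtype differs.
        tensor = torch.from_numpy(np.asarray(value)).to(torch.float)
        # Network inputs go to the input device, everything else
        # (e.g. labels) to the device where the outputs are produced.
        device = input_device if key == "data" else output_device
        prepared[key] = tensor.to(device)
    return prepared


cpu = torch.device("cpu")
batch = {"data": np.zeros((2, 1, 4, 4), dtype=np.uint8), "label": np.array([0, 1])}
prepared = prepare_batch_sketch(batch, cpu, cpu)
print(prepared["data"].dtype, prepared["label"].device)  # torch.float32 cpu
```

On a GPU machine, input_device and output_device would typically be devices such as torch.device("cuda:0"); keeping them separate lets the loss be computed on a different device than the one holding the network inputs.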