Models

delira comes with its own model-structure tree, with AbstractNetwork at its root, and integrates PyTorch models (AbstractPyTorchNetwork) deeply into the model structure. TensorFlow integration is planned.
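The tree can be pictured as a small class hierarchy. The sketch below is a simplified stand-in written with plain Python classes, not delira's actual source. The abstract root declares closure and prepare_batch as static methods that raise NotImplementedError, and a backend class (the hypothetical ToyBackendNetwork, playing the role of AbstractPyTorchNetwork) overrides what it supports.

```python
class AbstractNetworkSketch:
    """Simplified stand-in for the abstract root of the model tree."""

    _init_kwargs = {}

    @staticmethod
    def closure(model, data_dict, optimizers, criterions={}, metrics={},
                fold=0, **kwargs):
        # Each backend must supply its own training/evaluation step
        raise NotImplementedError()

    @staticmethod
    def prepare_batch(batch, input_device, output_device):
        # Each backend must convert batches and place them on devices itself
        raise NotImplementedError()


class ToyBackendNetwork(AbstractNetworkSketch):
    """Hypothetical backend class, playing the role of AbstractPyTorchNetwork."""

    @staticmethod
    def prepare_batch(batch, input_device, output_device):
        # Trivial override: return a shallow copy of the batch unchanged
        return dict(batch)
```

A user-defined network would then subclass the backend class and implement the members that are still abstract, such as closure.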

AbstractNetwork

class AbstractNetwork(type)

Bases: object

Abstract class from which all networks should be derived

_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)

Function which handles prediction from batch, logging, loss calculation and optimizer step

Parameters:
  • model (AbstractNetwork) – model to forward data through
  • data_dict (dict) – dictionary containing the data
  • optimizers (dict) – dictionary containing all optimizers to perform the parameter update
  • criterions (dict) – functions or classes to calculate criterions
  • metrics (dict) – functions or classes to calculate other metrics
  • fold (int) – current fold in cross-validation (default: 0)
  • kwargs (dict) – additional keyword arguments

Returns:
  • dict – Metric values (with same keys as input dict metrics)
  • dict – Loss values (with same keys as input dict criterions)
  • list – Arbitrary number of predictions
Raises: NotImplementedError – if not overridden by a subclass
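To make the return contract concrete, here is a framework-free sketch. It is not delira's implementation: the model is any callable, losses and metrics are plain functions, and CountingOptimizer is a hypothetical stand-in that only counts step() calls. What it illustrates is the shape of the result: a metric dict and a loss dict that reuse the keys of metrics and criterions, followed by a list of predictions.

```python
def mse(preds, labels):
    """Mean squared error over two equally long sequences."""
    return sum((p - t) ** 2 for p, t in zip(preds, labels)) / len(labels)


def mae(preds, labels):
    """Mean absolute error over two equally long sequences."""
    return sum(abs(p - t) for p, t in zip(preds, labels)) / len(labels)


class CountingOptimizer:
    """Hypothetical optimizer stand-in that only counts update steps."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


def closure_sketch(model, data_dict, optimizers, criterions={}, metrics={},
                   fold=0, **kwargs):
    """Follows the documented closure contract using plain Python values."""
    # 1. Prediction from the batch
    preds = model(data_dict["data"])

    # 2. Loss calculation, keyed like `criterions`
    loss_vals = {key: fn(preds, data_dict["label"])
                 for key, fn in criterions.items()}

    # 3. Metric calculation, keyed like `metrics`
    metric_vals = {key: fn(preds, data_dict["label"])
                   for key, fn in metrics.items()}

    # 4. Optimizer step; a real backend would backpropagate the losses first.
    #    Logging is omitted in this sketch.
    for optim in optimizers.values():
        optim.step()

    return metric_vals, loss_vals, [preds]


optim = CountingOptimizer()
metric_vals, loss_vals, predictions = closure_sketch(
    lambda xs: [2 * x for x in xs],  # toy model: doubles every input
    {"data": [1.0, 2.0], "label": [2.0, 5.0]},
    {"default": optim},
    criterions={"mse": mse},
    metrics={"mae": mae},
)
print(metric_vals, loss_vals, predictions)  # {'mae': 0.5} {'mse': 0.5} [[2.0, 4.0]]
```

In a real backend, the loss values would drive the parameter update before the optimizers step; here the update is reduced to a counter so that the sketch stays self-contained.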
init_kwargs

Returns all arguments registered as init kwargs

Returns: init kwargs
Return type: dict
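One way to picture the init-kwargs mechanism is a network that records its constructor arguments so an identically configured network can be built later. The sketch below is an assumption about the pattern, not delira's code; ToyNet and register_init_kwargs are hypothetical names. It keeps the recorded kwargs per instance, because a class-level dict such as _init_kwargs = {} would be shared by every instance if it were mutated in place.

```python
class ToyNet:
    """Hypothetical network that remembers how it was constructed."""

    def __init__(self, in_channels, n_classes=2):
        # Per-instance dict, so different networks never share recorded kwargs
        self._init_kwargs = {}
        self.register_init_kwargs(in_channels=in_channels, n_classes=n_classes)

    def register_init_kwargs(self, **kwargs):
        # Record constructor arguments under their parameter names
        self._init_kwargs.update(kwargs)

    @property
    def init_kwargs(self):
        # Return a copy so callers cannot alter the recorded arguments
        return dict(self._init_kwargs)


net = ToyNet(3, n_classes=10)
clone = ToyNet(**net.init_kwargs)  # rebuild an identically configured network
print(clone.init_kwargs)  # {'in_channels': 3, 'n_classes': 10}
```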
static prepare_batch(batch: dict, input_device, output_device)

Converts a numpy batch of data and labels to a suitable datatype and pushes them to the correct devices

Parameters:
  • batch (dict) – dictionary containing the batch (must have keys ‘data’ and ‘label’)
  • input_device – device for network inputs
  • output_device – device for network outputs
Returns:

dictionary containing all necessary data in the right format and type and on the correct device

Return type:

dict

Raises:

NotImplementedError – if not overridden by a subclass
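The prepare_batch contract can be sketched without any framework. In this hypothetical version, the batch must provide ‘data’ and ‘label’; inputs are placed on input_device and labels on output_device. Plain Python has no devices, so the hypothetical to_device helper only converts values to floats and records the device name; a real PyTorch backend would instead create tensors and move them, for example with Tensor.to(device).

```python
def to_device(values, device):
    """Hypothetical placement helper: converts to floats and records the device."""
    return {"values": [float(v) for v in values], "device": device}


def prepare_batch_sketch(batch, input_device, output_device):
    """Checks the required keys, converts the values and 'places' them."""
    missing = {"data", "label"} - set(batch)
    if missing:
        raise KeyError(f"batch is missing keys: {sorted(missing)}")

    return {
        "data": to_device(batch["data"], input_device),     # network inputs
        "label": to_device(batch["label"], output_device),  # compared with outputs
    }


prepared = prepare_batch_sketch({"data": [1, 2], "label": [0]}, "cuda:0", "cpu")
print(prepared["data"])  # {'values': [1.0, 2.0], 'device': 'cuda:0'}
```

Labels go to output_device because they are compared with the network outputs, which live there.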

AbstractPyTorchNetwork

class AbstractPyTorchNetwork(type)

Bases: delira.models.abstract_network.AbstractNetwork, torch.nn.Module

Abstract class for PyTorch networks

See also

torch.nn.Module, AbstractNetwork

_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)

Function which handles prediction from batch, logging, loss calculation and optimizer step

Parameters:
  • model (AbstractNetwork) – model to forward data through
  • data_dict (dict) – dictionary containing the data
  • optimizers (dict) – dictionary containing all optimizers to perform the parameter update
  • criterions (dict) – functions or classes to calculate criterions
  • metrics (dict) – functions or classes to calculate other metrics
  • fold (int) – current fold in cross-validation (default: 0)
  • kwargs (dict) – additional keyword arguments

Returns:
  • dict – Metric values (with same keys as input dict metrics)
  • dict – Loss values (with same keys as input dict criterions)
  • list – Arbitrary number of predictions
Raises: NotImplementedError – if not overridden by a subclass
forward(*inputs)

Forwards inputs through the module (defines module behavior)

Parameters:
  • inputs (list) – inputs of arbitrary type and number

Returns: result – module results of arbitrary type and number
Return type: Any
init_kwargs

Returns all arguments registered as init kwargs

Returns: init kwargs
Return type: dict
static prepare_batch(batch: dict, input_device, output_device)

Helper function to prepare network inputs and labels (converts them to the correct type and shape and pushes them to the correct devices)

Parameters:
  • batch (dict) – dictionary containing all the data
  • input_device (torch.device) – device for network inputs
  • output_device (torch.device) – device for network outputs
Returns:

dictionary containing the data in the correct type and shape and on the correct device

Return type:

dict

AbstractTfNetwork

class AbstractTfNetwork(sess=<TensorFlow session>, **kwargs)

Bases: delira.models.abstract_network.AbstractNetwork

Abstract class for TensorFlow networks

See also

AbstractNetwork

_add_losses(losses: dict)

Add losses to the model graph

Parameters: losses (dict) – dictionary containing losses.
_add_optims(optims: dict)

Add optimizers to the model graph

Parameters: optims (dict) – dictionary containing optimizers.
_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)

Function which handles prediction from batch, logging, loss calculation and optimizer step

Parameters:
  • model (AbstractNetwork) – model to forward data through
  • data_dict (dict) – dictionary containing the data
  • optimizers (dict) – dictionary containing all optimizers to perform the parameter update
  • criterions (dict) – functions or classes to calculate criterions
  • metrics (dict) – functions or classes to calculate other metrics
  • fold (int) – current fold in cross-validation (default: 0)
  • kwargs (dict) – additional keyword arguments

Returns:
  • dict – Metric values (with same keys as input dict metrics)
  • dict – Loss values (with same keys as input dict criterions)
  • list – Arbitrary number of predictions
Raises: NotImplementedError – if not overridden by a subclass
init_kwargs

Returns all arguments registered as init kwargs

Returns: init kwargs
Return type: dict
static prepare_batch(batch: dict, input_device, output_device)

Converts a numpy batch of data and labels to a suitable datatype and pushes them to the correct devices

Parameters:
  • batch (dict) – dictionary containing the batch (must have keys ‘data’ and ‘label’)
  • input_device – device for network inputs
  • output_device – device for network outputs
Returns:

dictionary containing all necessary data in the right format and type and on the correct device

Return type:

dict

Raises:

NotImplementedError – if not overridden by a subclass

run(*args)

Evaluates self.outputs_train or self.outputs_eval, depending on self.training

Parameters: *args – arguments to feed as self.inputs; must have the same length as self.inputs
Returns: depending on len(self.outputs*), either a list or an np.ndarray
Return type: np.ndarray or list
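The selection logic of run can be sketched without TensorFlow. In this simplified stand-in, which is not delira's code, the hypothetical session_run method plays the role of a session's run call, inputs are plain placeholder names, and outputs are functions of the feed dictionary. The sketch checks the argument count, picks outputs_train or outputs_eval according to training, and (as an assumption about how len(self.outputs*) is used) unwraps a single output while returning several as a list.

```python
class ToyTfNetwork:
    """Hypothetical stand-in for the run() logic of AbstractTfNetwork."""

    def __init__(self, inputs, outputs_train, outputs_eval):
        self.inputs = inputs                # placeholder names
        self.outputs_train = outputs_train  # callables: feed_dict -> value
        self.outputs_eval = outputs_eval
        self.training = True

    def session_run(self, fetches, feed_dict):
        # Stand-in for a session run call: evaluate every fetch on the feed
        return [fetch(feed_dict) for fetch in fetches]

    def run(self, *args):
        if len(args) != len(self.inputs):
            raise ValueError("run() needs exactly one argument per input")
        feed_dict = dict(zip(self.inputs, args))
        outputs = self.outputs_train if self.training else self.outputs_eval
        results = self.session_run(outputs, feed_dict)
        # Assumption: a single output is returned unwrapped, several as a list
        return results[0] if len(outputs) == 1 else results


net = ToyTfNetwork(
    inputs=["x"],
    outputs_train=[lambda feed: feed["x"] * 2, lambda feed: feed["x"] + 1],
    outputs_eval=[lambda feed: feed["x"] * 2],
)
print(net.run(3))  # [6, 4]
net.training = False
print(net.run(3))  # 6
```

Switching the training flag changes which output set is evaluated without rebuilding the network, which mirrors the train/eval split described above.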