Models
delira comes with its own model-structure tree, with AbstractNetwork at its root, and integrates PyTorch models (AbstractPyTorchNetwork) deeply into the model structure. TensorFlow integration is planned.
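The class tree described above can be pictured with plain-Python stand-ins. This is a minimal sketch, not delira's code: `Module` is a hypothetical placeholder for `torch.nn.Module` so the example runs without PyTorch, and all class bodies are empty. Only the inheritance relationships mirror the "Bases:" entries documented on this page.

```python
# Minimal sketch of delira's model-structure tree using stand-in classes.
# Module is a HYPOTHETICAL placeholder for torch.nn.Module; class bodies are omitted.


class AbstractNetwork:
    """Root of the tree (stand-in)."""


class Module:
    """Hypothetical stand-in for torch.nn.Module."""


class AbstractPyTorchNetwork(AbstractNetwork, Module):
    """PyTorch branch: derives from both the root and the (stand-in) Module."""


class AbstractTfNetwork(AbstractNetwork):
    """TensorFlow branch: derives from the root only."""


if __name__ == "__main__":
    for cls in (AbstractPyTorchNetwork, AbstractTfNetwork):
        # The method resolution order shows where each class sits in the tree
        print(cls.__name__, "->", [base.__name__ for base in cls.__mro__])
```

Because the PyTorch branch inherits from both bases, an instance can be passed anywhere either an `AbstractNetwork` or a module is expected.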
AbstractNetwork
class AbstractNetwork(type)
Bases: object
Abstract class all networks should be derived from.
_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)
Handles prediction from a batch, logging, loss calculation and the optimizer step.
Parameters:
- model (AbstractNetwork) – model to forward data through
- data_dict (dict) – dictionary containing the data
- optimizers (dict) – dictionary containing all optimizers to perform the parameter update
- criterions (dict) – functions or classes to calculate criterions
- metrics (dict) – functions or classes to calculate other metrics
- fold (int) – current fold in cross-validation (default: 0)
- kwargs (dict) – additional keyword arguments
Returns:
- dict – metric values (with the same keys as the input dict metrics)
- dict – loss values (with the same keys as the input dict criterions)
- list – arbitrary number of predictions
Raises: NotImplementedError – if not overwritten by subclass
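As a rough illustration of this contract (forward the data, compute every criterion and metric under the same keys as the input dicts, step the optimizers, and return metrics, losses and predictions in that order), here is a framework-free sketch. Everything in it is hypothetical: `ScalarModel`, `SGDStub`, the `"data"`/`"label"` keys of `data_dict`, and the hand-written gradient are illustrative assumptions, not delira's implementation. Logging is omitted.

```python
# Framework-free sketch of the closure contract. ScalarModel, SGDStub, the
# "data"/"label" keys and the analytic gradient are HYPOTHETICAL.


class ScalarModel:
    """Toy model: prediction = weight * x."""

    def __init__(self, weight=0.0):
        self.weight = weight

    def __call__(self, xs):
        return [self.weight * x for x in xs]


class SGDStub:
    """Toy optimizer holding a reference to the model and a learning rate."""

    def __init__(self, model, lr=0.1):
        self.model = model
        self.lr = lr

    def step(self, grad):
        self.model.weight -= self.lr * grad


def mse(preds, labels):
    return sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(labels)


def mae(preds, labels):
    return sum(abs(p - y) for p, y in zip(preds, labels)) / len(labels)


def closure(model, data_dict: dict, optimizers: dict, criterions=None, metrics=None, fold=0, **kwargs):
    # None defaults avoid shared mutable defaults; the dicts are only read
    criterions = criterions or {}
    metrics = metrics or {}
    xs, ys = data_dict["data"], data_dict["label"]

    preds = model(xs)  # forward pass

    # Loss and metric values keep the keys of the input dicts
    loss_vals = {key: fn(preds, ys) for key, fn in criterions.items()}
    metric_vals = {key: fn(preds, ys) for key, fn in metrics.items()}

    if optimizers:  # only update parameters when optimizers are given
        # Gradient of the MSE w.r.t. the weight; assumes MSE is the training loss
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(ys)
        for optim in optimizers.values():
            optim.step(grad)

    return metric_vals, loss_vals, [preds]
```

In this sketch, passing an empty `optimizers` dict skips the update, so the same function can also compute losses and metrics without training.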
static prepare_batch(batch: dict, input_device, output_device)
Converts a numpy batch of data and labels to a suitable datatype and pushes it to the correct devices.
Parameters:
- batch (dict) – dictionary containing the batch (must have keys ‘data’ and ‘label’)
- input_device – device for network inputs
- output_device – device for network outputs
Returns: dictionary containing all necessary data in the right format and type and on the correct device
Return type: dict
Raises: NotImplementedError – if not overwritten by subclass
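The base-class method only defines the interface: it raises NotImplementedError, and a subclass decides how to convert the batch. The sketch below shows that override pattern with a hypothetical `ToyNetwork`. It assumes plain lists in place of numpy arrays, checks the two documented required keys, and converts values to floats; device placement is deliberately left out to keep the example framework-free.

```python
# Sketch of the prepare_batch contract from the base class. ToyNetwork is HYPOTHETICAL;
# plain lists stand in for numpy arrays and device placement is omitted here.


class AbstractNetwork:
    @staticmethod
    def prepare_batch(batch: dict, input_device, output_device):
        # The base class only defines the interface
        raise NotImplementedError()


class ToyNetwork(AbstractNetwork):
    @staticmethod
    def prepare_batch(batch: dict, input_device, output_device):
        # The documented contract requires both keys to be present
        missing = {"data", "label"} - set(batch)
        if missing:
            raise KeyError(f"batch is missing required keys: {sorted(missing)}")
        # Convert values to a suitable datatype (here: Python floats)
        return {
            "data": [float(v) for v in batch["data"]],
            "label": [float(v) for v in batch["label"]],
        }
```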
AbstractPyTorchNetwork
class AbstractPyTorchNetwork(type)
Bases: delira.models.abstract_network.AbstractNetwork, torch.nn.Module
Abstract class for PyTorch networks.
See also: torch.nn.Module, AbstractNetwork
_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)
Handles prediction from a batch, logging, loss calculation and the optimizer step.
Parameters:
- model (AbstractNetwork) – model to forward data through
- data_dict (dict) – dictionary containing the data
- optimizers (dict) – dictionary containing all optimizers to perform the parameter update
- criterions (dict) – functions or classes to calculate criterions
- metrics (dict) – functions or classes to calculate other metrics
- fold (int) – current fold in cross-validation (default: 0)
- kwargs (dict) – additional keyword arguments
Returns:
- dict – metric values (with the same keys as the input dict metrics)
- dict – loss values (with the same keys as the input dict criterions)
- list – arbitrary number of predictions
Raises: NotImplementedError – if not overwritten by subclass
forward(*inputs)
Forwards inputs through the module (defines the module's behavior).
Parameters: inputs (list) – inputs of arbitrary type and number
Returns: result – module results of arbitrary type and number
Return type: Any
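In `torch.nn.Module`, calling a module instance runs its `forward` method, so a subclass normally only defines `forward`. The sketch below imitates that dispatch with a hypothetical, heavily simplified `Module` stand-in (no hooks, no parameter registration) so it runs without PyTorch. `ScaleNet` is also hypothetical and shows a `forward` that accepts an arbitrary number of inputs.

```python
# Simplified stand-in for torch.nn.Module's call-to-forward dispatch (no hooks, no
# parameter tracking). Module and ScaleNet are HYPOTHETICAL illustration classes.


class Module:
    def __call__(self, *inputs):
        # torch.nn.Module also routes instance calls to forward (plus hooks, omitted here)
        return self.forward(*inputs)

    def forward(self, *inputs):
        raise NotImplementedError()


class ScaleNet(Module):
    def __init__(self, scale=2.0):
        self.scale = scale

    def forward(self, *inputs):
        # Accept any number of inputs; scale each one and return them as a tuple
        return tuple(self.scale * x for x in inputs)


if __name__ == "__main__":
    net = ScaleNet()
    print(net(1.0, 3.0))  # (2.0, 6.0)
```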
static prepare_batch(batch: dict, input_device, output_device)
Helper function to prepare network inputs and labels (converts them to the correct type and shape and pushes them to the correct devices).
Parameters:
- batch (dict) – dictionary containing all the data
- input_device (torch.device) – device for network inputs
- output_device (torch.device) – device for network outputs
Returns: dictionary containing the data in the correct type and shape and on the correct device
Return type: dict
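This helper separates network inputs from the values the outputs are compared against. A plausible reading, which is an assumption rather than something stated here, is that the `"data"` entry goes to `input_device` while every other entry (labels and any extra keys) goes to `output_device`. The sketch models that split without PyTorch; `to_device` is a hypothetical helper standing in for converting a numpy array with `torch.from_numpy(...)` and moving it with `.to(device)`.

```python
# Sketch of an (assumed) device split in prepare_batch, framework-free.
# to_device is a HYPOTHETICAL stand-in for torch.from_numpy(array).to(device).


def to_device(values, device):
    # Cast to float (mimicking a dtype conversion) and record the target device
    return {"values": [float(v) for v in values], "device": device}


def prepare_batch(batch: dict, input_device, output_device):
    # Network inputs go to the input device
    result = {"data": to_device(batch["data"], input_device)}
    # All other entries (e.g. "label") are compared with outputs, so they go to the output device
    for key, values in batch.items():
        if key != "data":
            result[key] = to_device(values, output_device)
    return result
```

Sending labels to the output device means loss computations never have to move tensors between devices, which matters when inputs and outputs live on different GPUs.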
AbstractTfNetwork
class AbstractTfNetwork(sess=<sphinx.ext.autodoc.importer._MockObject object>, **kwargs)
Bases: delira.models.abstract_network.AbstractNetwork
Abstract class for TensorFlow networks.
_add_losses(losses: dict)
Adds losses to the model graph.
Parameters: losses (dict) – dictionary containing losses
_add_optims(optims: dict)
Adds optimizers to the model graph.
Parameters: optims (dict) – dictionary containing optimizers
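These two private helpers register losses and optimizers with the model. The sketch below is an assumption about how such bookkeeping could look, not delira's TensorFlow code: losses are stored by key and summed into a total, and each optimizer is registered by key. In TensorFlow 1.x graph code the registration step would typically call an optimizer's `minimize(total_loss)` to create a training op; here a plain callable `make_train_op` and plain floats stand in for that, and `GraphModelSketch` is a hypothetical class.

```python
# HYPOTHETICAL bookkeeping for _add_losses/_add_optims, without TensorFlow.
# Plain floats stand in for loss tensors; make_train_op stands in for optimizer.minimize.


class GraphModelSketch:
    def __init__(self):
        self._losses = {}
        self._total_loss = 0.0
        self._optims = {}

    def _add_losses(self, losses: dict):
        # Register each loss under its key and keep a running total
        self._losses.update(losses)
        self._total_loss = sum(self._losses.values())

    def _add_optims(self, optims: dict):
        # Build one "train op" per optimizer, each minimising the total loss
        for key, make_train_op in optims.items():
            self._optims[key] = make_train_op(self._total_loss)
```

In this sketch losses must be added before optimizers, since each training op is built from the total loss that exists at registration time.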
_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)
Handles prediction from a batch, logging, loss calculation and the optimizer step.
Parameters:
- model (AbstractNetwork) – model to forward data through
- data_dict (dict) – dictionary containing the data
- optimizers (dict) – dictionary containing all optimizers to perform the parameter update
- criterions (dict) – functions or classes to calculate criterions
- metrics (dict) – functions or classes to calculate other metrics
- fold (int) – current fold in cross-validation (default: 0)
- kwargs (dict) – additional keyword arguments
Returns:
- dict – metric values (with the same keys as the input dict metrics)
- dict – loss values (with the same keys as the input dict criterions)
- list – arbitrary number of predictions
Raises: NotImplementedError – if not overwritten by subclass
static prepare_batch(batch: dict, input_device, output_device)
Converts a numpy batch of data and labels to a suitable datatype and pushes it to the correct devices.
Parameters:
- batch (dict) – dictionary containing the batch (must have keys ‘data’ and ‘label’)
- input_device – device for network inputs
- output_device – device for network outputs
Returns: dictionary containing all necessary data in the right format and type and on the correct device
Return type: dict
Raises: NotImplementedError – if not overwritten by subclass