Models
delira comes with its own model-structure tree, with AbstractNetwork at its root, and integrates PyTorch models (AbstractPyTorchNetwork) deeply into the model structure. TensorFlow integration is planned.
AbstractNetwork

class AbstractNetwork(type)
Bases: object

Abstract class all networks should be derived from.

_init_kwargs = {}

static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)
Function which handles the prediction from a batch, logging, loss calculation and the optimizer step.
Parameters
model (AbstractNetwork) – model to forward data through
data_dict (dict) – dictionary containing the data
optimizers (dict) – dictionary containing all optimizers to perform the parameter update
criterions (dict) – functions or classes to calculate the losses
metrics (dict) – functions or classes to calculate other metrics
fold (int) – current fold in cross-validation (default: 0)
kwargs (dict) – additional keyword arguments
Returns
dict – metric values (with the same keys as the input dict metrics)
dict – loss values (with the same keys as the input dict criterions)
list – arbitrary number of predictions
Raises
NotImplementedError – if not overridden by a subclass
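The documented contract of closure can be illustrated with a minimal, framework-free sketch. It is not delira's implementation: `closure_sketch` is a hypothetical name, the model is any callable, the batch is assumed to use the keys 'data' and 'label' (as required by prepare_batch), the optimizers are hypothetical stand-ins exposing a `step(loss)` method in place of a real backward pass and parameter update, and logging is omitted.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


def closure_sketch(
    model: Callable[[Any], Any],
    data_dict: Dict[str, Any],
    optimizers: Dict[str, Any],
    criterions: Optional[Dict[str, Callable[[Any, Any], float]]] = None,
    metrics: Optional[Dict[str, Callable[[Any, Any], float]]] = None,
    fold: int = 0,
    **kwargs: Any,
) -> Tuple[Dict[str, float], Dict[str, float], List[Any]]:
    """Hypothetical closure following the documented contract (not delira's code)."""
    criterions = criterions or {}
    metrics = metrics or {}

    # Prediction: forward the network inputs through the model.
    preds = model(data_dict["data"])
    labels = data_dict["label"]  # assumed key, see prepare_batch

    # Loss calculation: one value per criterion, keyed like ``criterions``.
    loss_vals = {name: fn(preds, labels) for name, fn in criterions.items()}
    # Metrics: one value per metric function, keyed like ``metrics``.
    metric_vals = {name: fn(preds, labels) for name, fn in metrics.items()}

    # Optimizer step: only when optimizers are passed (e.g. during training).
    # The optimizers are stand-ins exposing a hypothetical ``step(loss)``.
    if optimizers:
        total_loss = sum(loss_vals.values())
        for optim in optimizers.values():
            optim.step(total_loss)

    # Same order as the documented return values.
    return metric_vals, loss_vals, [preds]
```

The `None` defaults replace the mutable `{}` defaults of the documented signature so that the dictionaries are not shared between calls; `fold` is accepted only for signature compatibility.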

static prepare_batch(batch: dict, input_device, output_device)
Converts a numpy batch of data and labels to a suitable datatype and pushes them to the correct devices.
Parameters
batch (dict) – dictionary containing the batch (must have the keys ‘data’ and ‘label’)
input_device – device for network inputs
output_device – device for network outputs
Returns
dictionary containing all necessary data in the right format and type and on the correct device
Return type
dict
Raises
NotImplementedError – if not overridden by a subclass
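A hypothetical subclass shows how the abstract static methods are meant to be overridden. The sketch assumes that 'data' is placed on input_device and every other key (such as 'label') on output_device; plain lists and the made-up `Placed` dataclass stand in for numpy arrays and backend tensors, and device names are plain strings.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class Placed:
    """Hypothetical stand-in for a tensor that lives on a given device."""
    values: List[float]
    device: str


class AbstractNetworkSketch:
    """Mirrors the documented contract: both static methods must be overridden."""

    @staticmethod
    def closure(model, data_dict: dict, optimizers: dict,
                criterions=None, metrics=None, fold=0, **kwargs):
        raise NotImplementedError()

    @staticmethod
    def prepare_batch(batch: dict, input_device, output_device):
        raise NotImplementedError()


class ToyNetwork(AbstractNetworkSketch):
    """Hypothetical subclass overriding only prepare_batch."""

    @staticmethod
    def prepare_batch(batch: Dict[str, Any], input_device: str,
                      output_device: str) -> Dict[str, Placed]:
        # Network inputs are converted and placed on the input device.
        prepared = {"data": Placed([float(v) for v in batch["data"]], input_device)}
        # Every other entry (e.g. 'label') is placed on the output device,
        # where it can be compared with the network's predictions.
        for key, vals in batch.items():
            if key != "data":
                prepared[key] = Placed([float(v) for v in vals], output_device)
        return prepared
```

Because ToyNetwork overrides only prepare_batch, calling its closure still raises NotImplementedError.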
AbstractPyTorchNetwork

class AbstractPyTorchNetwork(type)
Bases: delira.models.abstract_network.AbstractNetwork, torch.nn.Module

Abstract class for PyTorch networks.

See also: AbstractNetwork

_init_kwargs = {}

static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)
Function which handles the prediction from a batch, logging, loss calculation and the optimizer step.
Parameters
model (AbstractNetwork) – model to forward data through
data_dict (dict) – dictionary containing the data
optimizers (dict) – dictionary containing all optimizers to perform the parameter update
criterions (dict) – functions or classes to calculate the losses
metrics (dict) – functions or classes to calculate other metrics
fold (int) – current fold in cross-validation (default: 0)
kwargs (dict) – additional keyword arguments
Returns
dict – metric values (with the same keys as the input dict metrics)
dict – loss values (with the same keys as the input dict criterions)
list – arbitrary number of predictions
Raises
NotImplementedError – if not overridden by a subclass

forward(*inputs)
Forwards inputs through the module (defines the module behavior).
Parameters
inputs (list) – inputs of arbitrary type and number
Returns
result – module results of arbitrary type and number
Return type
Any
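AbstractPyTorchNetwork combines AbstractNetwork with torch.nn.Module, whose instances run forward() when called. The sketch below mirrors that inheritance in plain Python so it runs without PyTorch; `ModuleSketch`, `AbstractNetworkSketch`, `AbstractPyTorchNetworkSketch` and `SumNetwork` are hypothetical stand-ins, and the hooks that the real torch.nn.Module call machinery also runs are omitted.

```python
from typing import Any


class ModuleSketch:
    """Stand-in for torch.nn.Module: calling an instance dispatches to forward()."""

    def __call__(self, *inputs: Any) -> Any:
        return self.forward(*inputs)

    def forward(self, *inputs: Any) -> Any:
        raise NotImplementedError()


class AbstractNetworkSketch:
    """Stand-in for AbstractNetwork (only the closure hook is shown)."""

    @staticmethod
    def closure(model, data_dict, optimizers, criterions=None, metrics=None,
                fold=0, **kwargs):
        raise NotImplementedError()


class AbstractPyTorchNetworkSketch(AbstractNetworkSketch, ModuleSketch):
    """Mirrors the documented bases: AbstractNetwork first, then the module class."""


class SumNetwork(AbstractPyTorchNetworkSketch):
    """Hypothetical network: forward accepts any number of inputs."""

    def forward(self, *inputs: Any) -> Any:
        return sum(inputs)
```

Calling `SumNetwork()(1, 2, 3)` goes through `ModuleSketch.__call__` and returns the result of forward, so a closure can simply call the model on its inputs.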
AbstractTfNetwork

class AbstractTfNetwork(sess=tensorflow.Session, **kwargs)
Bases: delira.models.abstract_network.AbstractNetwork

Abstract class for TensorFlow networks.

_add_losses(losses: dict)
Adds losses to the model graph.
Parameters
losses (dict) – dictionary containing the losses

_add_optims(optims: dict)
Adds optimizers to the model graph.
Parameters
optims (dict) – dictionary containing the optimizers
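_add_losses and _add_optims attach losses and optimizers to a graph that is only executed later. The following is a minimal sketch of that deferred pattern, not delira's implementation: `GraphSketch` is hypothetical, zero-argument Python callables stand in for TensorFlow graph nodes, `run` stands in for executing a node in the session, and each optimizer node is assumed to consume the sum of all registered losses (so losses must be added first).

```python
from typing import Callable, Dict, List


class GraphSketch:
    """Hypothetical stand-in for a static graph: nodes are deferred computations."""

    def __init__(self, outputs: List[float], labels: List[float]):
        self.outputs = outputs
        self.labels = labels
        self.losses: Dict[str, Callable[[], float]] = {}
        self.optims: Dict[str, Callable[[], float]] = {}

    def _add_losses(self, losses: Dict[str, Callable]) -> None:
        # Register one deferred loss node per entry; nothing is computed yet.
        for name, loss_fn in losses.items():
            self.losses[name] = lambda fn=loss_fn: fn(self.outputs, self.labels)

    def _add_optims(self, optims: Dict[str, Callable]) -> None:
        # Register one deferred update node per optimizer; each one consumes
        # the summed loss, which is why losses have to be registered first.
        for name, update_fn in optims.items():
            self.optims[name] = lambda fn=update_fn: fn(
                sum(node() for node in self.losses.values()))

    def run(self, node: Callable[[], float]) -> float:
        # Stand-in for executing a graph node in a session.
        return node()
```

Since the nodes are evaluated only in `run`, changing `outputs` after registration changes the next evaluated loss, which is the behavior a static graph fed with new batches relies on.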

_init_kwargs = {}

static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)
Function which handles the prediction from a batch, logging, loss calculation and the optimizer step.
Parameters
model (AbstractNetwork) – model to forward data through
data_dict (dict) – dictionary containing the data
optimizers (dict) – dictionary containing all optimizers to perform the parameter update
criterions (dict) – functions or classes to calculate the losses
metrics (dict) – functions or classes to calculate other metrics
fold (int) – current fold in cross-validation (default: 0)
kwargs (dict) – additional keyword arguments
Returns
dict – metric values (with the same keys as the input dict metrics)
dict – loss values (with the same keys as the input dict criterions)
list – arbitrary number of predictions
Raises
NotImplementedError – if not overridden by a subclass

static prepare_batch(batch: dict, input_device, output_device)
Converts a numpy batch of data and labels to a suitable datatype and pushes them to the correct devices.
Parameters
batch (dict) – dictionary containing the batch (must have the keys ‘data’ and ‘label’)
input_device – device for network inputs
output_device – device for network outputs
Returns
dictionary containing all necessary data in the right format and type and on the correct device
Return type
dict
Raises
NotImplementedError – if not overridden by a subclass