# Source code for delira.models.abstract_network

import abc
import logging

# Logger named after this module; not referenced by AbstractNetwork itself.
file_logger = logging.getLogger(name=__name__)


class AbstractNetwork(object):
    """
    Abstract class all networks should be derived from

    Notes
    -----
    This class does not use :class:`abc.ABCMeta` as its metaclass, so the
    ``abc.abstractmethod`` markers on its methods are not enforced: Python
    will not refuse to instantiate this class (or a subclass that fails to
    override the marked methods). The methods without a real implementation
    here raise :class:`NotImplementedError` when called instead.
    """

    # Class-level fallback so ``init_kwargs`` still works for subclasses that
    # never call ``AbstractNetwork.__init__``. It is intentionally never
    # mutated: ``__init__`` registers kwargs in a per-instance dict instead.
    _init_kwargs = {}

    @abc.abstractmethod
    def __init__(self, **kwargs):
        """
        Init function to register init kwargs (should be called from all
        subclasses)

        Parameters
        ----------
        **kwargs
            keyword arguments (will be registered to `self.init_kwargs`)

        """
        super().__init__()

        # Register on the *instance*: writing into the class-level
        # ``_init_kwargs`` dict would leak kwargs between all instances and
        # all subclasses. Repeated calls on the same instance (e.g. under
        # cooperative multiple inheritance) keep accumulating into one dict.
        registered = vars(self).get("_init_kwargs")
        if registered is None:
            registered = {}
            self._init_kwargs = registered
        registered.update(kwargs)

    @abc.abstractmethod
    def __call__(self, *args, **kwargs):
        """
        AbstractMethod to specify that each model should be able to be called
        for predictions

        Parameters
        ----------
        *args :
            Positional arguments
        **kwargs :
            Keyword Arguments

        Raises
        ------
        NotImplementedError
            if not overwritten by subclass

        """
        raise NotImplementedError()

    @staticmethod
    @abc.abstractmethod
    def closure(model, data_dict: dict, optimizers: dict, losses=None,
                metrics=None, fold=0, **kwargs):
        """
        Function which handles prediction from batch, logging, loss
        calculation and optimizer step

        Parameters
        ----------
        model : :class:`AbstractNetwork`
            model to forward data through
        data_dict : dict
            dictionary containing the data
        optimizers : dict
            dictionary containing all optimizers to perform parameter update
        losses : dict
            Functions or classes to calculate losses (implementations should
            treat ``None`` as an empty dict)
        metrics : dict
            Functions or classes to calculate other metrics (implementations
            should treat ``None`` as an empty dict)
        fold : int
            Current Fold in Crossvalidation (default: 0)
        kwargs : dict
            additional keyword arguments

        Returns
        -------
        dict
            Metric values (with same keys as input dict metrics)
        dict
            Loss values (with same keys as input dict losses)
        dict
            Arbitrary number of predictions

        Raises
        ------
        NotImplementedError
            If not overwritten by subclass

        """
        raise NotImplementedError()

    @staticmethod
    def prepare_batch(batch: dict, input_device, output_device):
        """
        Converts a numpy batch of data and labels to suitable datatype and
        pushes them to correct devices

        Parameters
        ----------
        batch : dict
            dictionary containing the batch (must have keys 'data' and
            'label')
        input_device :
            device for network inputs
        output_device :
            device for network outputs

        Returns
        -------
        dict
            dictionary containing all necessary data in right format and type
            and on the correct device

        Raises
        ------
        NotImplementedError
            If not overwritten by subclass

        """
        raise NotImplementedError()

    @property
    def init_kwargs(self):
        """
        Returns all arguments registered as init kwargs

        Returns
        -------
        dict
            init kwargs (the per-instance dict filled by ``__init__``, or the
            empty class-level fallback if ``__init__`` was never called)

        """
        return self._init_kwargs