import abc
import logging
# Module-level logger, named after this module via ``__name__``
file_logger = logging.getLogger(__name__)
class AbstractNetwork(object):
    """
    Abstract class all networks should be derived from

    Notes
    -----
    The class does not use :class:`abc.ABCMeta` as its metaclass, so the
    ``abc.abstractmethod`` markers below are declarative only: Python does
    not block instantiation when they are not overridden. The default
    implementations raise :class:`NotImplementedError` where applicable.

    """

    # Class-level fallback only. It is never mutated in place; ``__init__``
    # rebinds a per-instance dict so that kwargs registered on one network
    # do not leak into other instances or subclasses.
    _init_kwargs = {}

    @abc.abstractmethod
    def __init__(self, **kwargs):
        """
        Init function to register init kwargs (should be called from all
        subclasses)

        Parameters
        ----------
        **kwargs
            keyword arguments (will be registered and made available via
            the `init_kwargs` property)

        """
        super().__init__()

        # Build a fresh dict instead of writing into the shared class-level
        # one. Previously registered values (class-level defaults or values
        # from an earlier ``__init__`` call on this very instance) are kept,
        # and new kwargs override them key by key.
        self._init_kwargs = {**self._init_kwargs, **kwargs}

    @abc.abstractmethod
    def __call__(self, *args, **kwargs):
        """
        AbstractMethod to specify that each model should be able to be called
        for predictions

        Parameters
        ----------
        *args :
            Positional arguments
        **kwargs :
            Keyword Arguments

        Raises
        ------
        NotImplementedError
            if not overwritten by subclass

        """
        raise NotImplementedError()

    @staticmethod
    @abc.abstractmethod
    def closure(model, data_dict: dict, optimizers: dict, losses=None,
                metrics=None, fold=0, **kwargs):
        """
        Function which handles prediction from batch, logging, loss calculation
        and optimizer step

        Parameters
        ----------
        model : :class:`AbstractNetwork`
            model to forward data through
        data_dict : dict
            dictionary containing the data
        optimizers : dict
            dictionary containing all optimizers to perform parameter update
        losses : dict
            Functions or classes to calculate losses
            (``None`` is meant to be treated as an empty dict)
        metrics : dict
            Functions or classes to calculate other metrics
            (``None`` is meant to be treated as an empty dict)
        fold : int
            Current Fold in Crossvalidation (default: 0)
        kwargs : dict
            additional keyword arguments

        Returns
        -------
        dict
            Metric values (with same keys as input dict metrics)
        dict
            Loss values (with same keys as input dict losses)
        dict
            Arbitrary number of predictions

        Raises
        ------
        NotImplementedError
            If not overwritten by subclass

        """
        # The base implementation has nothing to compute; it always raises.
        raise NotImplementedError()

    @staticmethod
    def prepare_batch(batch: dict, input_device, output_device):
        """
        Converts a numpy batch of data and labels to suitable datatype and
        pushes them to correct devices

        Parameters
        ----------
        batch : dict
            dictionary containing the batch (must have keys 'data' and
            'label')
        input_device :
            device for network inputs
        output_device :
            device for network outputs

        Returns
        -------
        dict
            dictionary containing all necessary data in right format and type
            and on the correct device

        Raises
        ------
        NotImplementedError
            If not overwritten by subclass

        """
        raise NotImplementedError()

    @property
    def init_kwargs(self):
        """
        Returns all arguments registered as init kwargs

        Returns
        -------
        dict
            init kwargs (the stored dict itself, not a copy)

        """
        return self._init_kwargs