import abc
import chainer
import numpy as np
from delira.models.abstract_network import AbstractNetwork
# Use this Mixin Class to set __call__ to None, because there is an
# internal check inside chainer.Link.__call__ for other __call__ methods
# of parent classes to be not None. If this would be the case,
# this function would be executed instead of our forward
class ChainerMixin(AbstractNetwork):
    """
    Mixin tying :class:`AbstractNetwork` into chainer's call protocol.

    ``chainer.Link.__call__`` checks whether any parent class defines a
    non-``None`` ``__call__``; by pinning it to ``None`` here, chainer
    falls through to the network's ``forward`` method instead of
    executing a parent's ``__call__``.
    """
    # Must stay None so chainer.Link.__call__ dispatches to ``forward``
    __call__ = None
class AbstractChainerNetwork(chainer.Chain, ChainerMixin):
    """
    Abstract Class for Chainer Networks
    """

    def __init__(self, **kwargs):
        """
        Parameters
        ----------
        **kwargs :
            keyword arguments of arbitrary number and type
            (will be registered as ``init_kwargs``)
        """
        chainer.Chain.__init__(self)
        AbstractNetwork.__init__(self, **kwargs)

    @abc.abstractmethod
    def forward(self, *args, **kwargs) -> dict:
        """
        Feeds Arguments through the network

        Parameters
        ----------
        *args :
            positional arguments of arbitrary number and type
        **kwargs :
            keyword arguments of arbitrary number and type

        Returns
        -------
        dict
            dictionary containing all computation results
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs) -> dict:
        """
        Makes instances of this class callable.
        Calls the ``forward`` method.

        Parameters
        ----------
        *args :
            positional arguments of arbitrary number and type
        **kwargs :
            keyword arguments of arbitrary number and type

        Returns
        -------
        dict
            dictionary containing all computation results
        """
        # chainer.Chain.__call__ does chainer-internal bookkeeping and
        # then dispatches to ``forward`` (ChainerMixin.__call__ is None,
        # so this method is the one chainer actually invokes)
        return chainer.Chain.__call__(self, *args, **kwargs)

    @staticmethod
    def prepare_batch(batch: dict, input_device, output_device):
        """
        Helper Function to prepare Network Inputs and Labels (convert them
        to correct type and shape and push them to correct devices)

        Parameters
        ----------
        batch : dict
            dictionary containing all the data
        input_device : chainer.backend.Device or string
            device for network inputs
        output_device : chainer.backend.Device or string
            device for network outputs

        Returns
        -------
        dict
            dictionary containing data in correct type and shape and on
            correct device
        """
        # assumes every batch entry exposes ``astype`` (numpy-like array)
        # -- TODO confirm against the data loaders feeding this method
        new_batch = {k: chainer.as_variable(v.astype(np.float32))
                     for k, v in batch.items()}
        for k, v in new_batch.items():
            # network inputs ("data") go to the input device, everything
            # else (labels etc.) to the output device
            device = input_device if k == "data" else output_device
            # makes modification inplace!
            v.to_device(device)
        return new_batch

    @staticmethod
    def closure(model, data_dict: dict, optimizers: dict, losses=None,
                metrics=None, fold=0, **kwargs):
        """
        default closure method to do a single training step;
        Could be overwritten for more advanced models

        Parameters
        ----------
        model : :class:`AbstractChainerNetwork`
            trainable model
        data_dict : dict
            dictionary containing the data
        optimizers : dict
            dictionary of optimizers to optimize model's parameters;
            ignored here, just passed for compatibility reasons
        losses : dict
            dict holding the losses to calculate errors;
            ignored here, just passed for compatibility reasons
        metrics : dict
            dict holding the metrics to calculate
        fold : int
            Current Fold in Crossvalidation (default: 0)
        **kwargs:
            additional keyword arguments

        Returns
        -------
        dict
            Metric values (with same keys as input dict metrics)
        dict
            Loss values (with same keys as input dict losses; will always
            be empty here)
        dict
            dictionary containing all predictions
        """
        # avoid shared mutable default arguments; ``None`` sentinels are
        # backward-compatible with the previous ``{}`` defaults
        losses = {} if losses is None else losses
        metrics = {} if metrics is None else metrics

        assert (optimizers and losses) or not optimizers, \
            "Criterion dict cannot be empty, if optimizers are passed"

        loss_vals = {}
        metric_vals = {}
        total_loss = 0

        inputs = data_dict["data"]
        preds = model(inputs)

        if data_dict:
            for key, crit_fn in losses.items():
                _loss_val = crit_fn(preds["pred"], data_dict["label"])
                loss_vals[key] = _loss_val.item()
                total_loss += _loss_val

            # evaluate metrics without train-mode behavior / graph growth
            with chainer.using_config("train", False):
                for key, metric_fn in metrics.items():
                    metric_vals[key] = metric_fn(
                        preds["pred"], data_dict["label"]).item()

        if optimizers:
            model.cleargrads()
            total_loss.backward()
            optimizers['default'].update()
        else:
            # add prefix "val" in validation mode
            eval_loss_vals, eval_metrics_vals = {}, {}
            for key in loss_vals.keys():
                eval_loss_vals["val_" + str(key)] = loss_vals[key]
            for key in metric_vals:
                eval_metrics_vals["val_" + str(key)] = metric_vals[key]
            loss_vals = eval_loss_vals
            metric_vals = eval_metrics_vals

        # BUGFIX: ``chainer.Variable.unchain`` mutates in place and
        # returns None, so the former ``{k: v.unchain() ...}`` dict
        # comprehension returned a dict of Nones. Detach in place and
        # return the actual predictions.
        for pred in preds.values():
            pred.unchain()
        return metric_vals, loss_vals, preds