import abc
import logging
import tensorflow as tf
import numpy as np

from delira.models.abstract_network import AbstractNetwork

class AbstractTfGraphNetwork(AbstractNetwork, metaclass=abc.ABCMeta):
    """
    Abstract base class for graph-mode TensorFlow networks.

    Subclasses are expected to populate ``self.inputs``,
    ``self.outputs_train`` and ``self.outputs_eval`` with graph tensors;
    this class only provides the plumbing to feed and evaluate them.

    See Also
    --------
    :class:`AbstractNetwork`

    """

    @abc.abstractmethod
    def __init__(self, sess=tf.Session, **kwargs):
        """
        Set up the session and the (initially empty) tensor registries.

        Parameters
        ----------
        sess : callable
            called without arguments to create the session that
            evaluates the graph (default: ``tf.Session``)
        **kwargs :
            keyword arguments (are passed to :class:`AbstractNetwork`'s
            ``__init__`` to register them as init kwargs)

        """
        AbstractNetwork.__init__(self, **kwargs)

        self._sess = sess()

        # name -> tensor mappings; to be filled by the concrete subclass
        self.inputs = {}
        self.outputs_train = {}
        self.outputs_eval = {}

        # stay ``None`` until ``_add_losses`` / ``_add_optims`` were called
        self._losses = None
        self._optims = None

        # selects between ``outputs_train`` and ``outputs_eval`` in ``run``
        # NOTE(review): attribute name restored from a garbled source
        # listing -- confirm it matches what the trainer toggles
        self.training = True

    def __call__(self, *args, **kwargs):
        """
        Wrapper for calling the network in eval setting.

        Switches ``self.training`` to ``False`` (it is not switched back
        afterwards) and forwards everything to :meth:`run`.

        Parameters
        ----------
        *args :
            positional arguments (passed to :meth:`run`)
        **kwargs :
            keyword arguments (passed to :meth:`run`)

        Returns
        -------
        Any
            result: module results of arbitrary type and number

        """
        self.training = False
        return self.run(*args, **kwargs)

    def run(self, *args, **kwargs):
        """
        Evaluate ``self.outputs_train`` or ``self.outputs_eval``,
        depending on ``self.training``.

        Parameters
        ----------
        *args :
            currently unused, exist for compatibility reasons
        **kwargs :
            values to feed into the graph; the same keys as in
            ``self.inputs`` must be used

        Returns
        -------
        dict
            same keys as ``outputs_train`` or ``outputs_eval``,
            containing the evaluated expressions as values

        Raises
        ------
        AssertionError
            if a keyword is not a key of ``self.inputs``

        """
        feed_dict = {}
        for feed_key, feed_value in kwargs.items():
            assert feed_key in self.inputs, \
                "{} not found in self.inputs".format(feed_key)
            # the session is fed by tensor, not by name
            feed_dict[self.inputs[feed_key]] = feed_value

        if self.training:
            fetches = self.outputs_train
        else:
            fetches = self.outputs_eval
        return self._sess.run(fetches, feed_dict=feed_dict)

    def _add_losses(self, losses: dict):
        """
        Add the losses that are to be used by optimizers or during
        evaluation. Can be overwritten for more advanced loss behavior.

        Every loss is called with ``self.inputs["label"]`` and
        ``self.outputs_train["pred"]``; the results plus their mean
        (stored under ``'total'``) are registered as ``"losses"`` in both
        ``self.outputs_train`` and ``self.outputs_eval``.

        Parameters
        ----------
        losses : dict
            dictionary containing all losses. Individual losses are
            averaged

        Raises
        ------
        NotImplementedError
            if losses were already added and ``losses`` is not empty

        """
        if self._losses is not None:
            if losses:
                logging.warning('Change of losses is not yet supported')
                raise NotImplementedError()
            # losses already registered and nothing new requested
            return

        label = self.inputs["label"]
        pred = self.outputs_train["pred"]
        self._losses = {
            name: loss_fn(label, pred) for name, loss_fn in losses.items()
        }

        # average over the individual losses (computed before 'total'
        # itself is inserted into the dict)
        total_loss = tf.reduce_mean(list(self._losses.values()), axis=0)
        self._losses['total'] = total_loss

        # train and eval outputs share the very same dict object
        self.outputs_train["losses"] = self._losses
        self.outputs_eval["losses"] = self._losses

    def _add_optims(self, optims: dict):
        """
        Add the optimizer that is to be used during training. Can be
        overwritten for more advanced optimizers.

        Only ``optims['default']`` is used; its update op for
        ``self._losses['total']`` is registered as ``"default_step"`` in
        ``self.outputs_train``, so it is evaluated with every training
        run. Requires the losses to have been added before.

        Parameters
        ----------
        optims : dict
            dictionary containing all optimizers, optimizers should be of
            Type[tf.train.Optimizer]

        """
        if self._optims is not None:
            if optims:
                # unlike ``_add_losses`` this only warns and keeps the
                # already registered optimizer
                logging.warning('Change of optims is not yet supported')
            return

        self._optims = optims['default']
        grads = self._optims.compute_gradients(self._losses['total'])
        step = self._optims.apply_gradients(grads)
        self.outputs_train["default_step"] = step

    @staticmethod
    def prepare_batch(batch: dict, input_device, output_device):
        """
        Helper function to prepare network inputs and labels: casts every
        entry of the batch to ``np.float32``.

        Parameters
        ----------
        batch : dict
            dictionary containing all the data; every value must provide
            an ``astype`` method
        input_device : Any
            device for module inputs (will be ignored here; just given
            for compatibility)
        output_device : Any
            device for module outputs (will be ignored here; just given
            for compatibility)

        Returns
        -------
        dict
            dictionary with the same keys as ``batch`` and the values
            cast to ``np.float32``

        """
        return {k: v.astype(np.float32) for k, v in batch.items()}

    @staticmethod
    def closure(model, data_dict: dict, optimizers: dict, losses={},
                metrics={}, fold=0, **kwargs):
        """
        Default closure method to do a single training or validation
        step; could be overwritten for more advanced models.

        Parameters
        ----------
        model : :class:`AbstractTfGraphNetwork`
            trainable model
        data_dict : dict
            dictionary containing the data; the entries ``'data'`` and
            ``'label'`` are used
        optimizers : dict
            dictionary of optimizers to optimize model's parameters;
            ignored here, just passed for compatibility reasons
        losses : dict
            dict holding the losses to calculate errors;
            ignored here, just passed for compatibility reasons
        metrics : dict
            dict holding the metrics to calculate; each one is called as
            ``metric_fn(prediction, label)``
        fold : int
            Current Fold in Crossvalidation (default: 0); unused here
        **kwargs:
            additional keyword arguments; unused here

        Returns
        -------
        dict
            Metric values (with same keys as input dict metrics)
        dict
            Loss values as evaluated by the model (the ``'losses'`` entry
            of its outputs)
        dict
            dictionary containing all outputs of the model (including the
            predictions under ``'pred'``)

        Notes
        -----
        If the model is not in training mode, all keys of the returned
        metric and loss dicts are prefixed with ``"val_"``.

        The ``{}`` defaults are kept for interface compatibility; neither
        of them is mutated in here.

        """
        inputs = data_dict['data']
        label = data_dict['label']

        # NOTE(review): the feed key "data" was restored from a garbled
        # source listing -- confirm it matches the model's ``inputs`` keys
        outputs = model.run(data=inputs, label=label)
        preds = outputs['pred']

        loss_vals = dict(outputs['losses'])
        metric_vals = {
            key: metric_fn(preds, label)
            for key, metric_fn in metrics.items()
        }

        if not model.training:
            # add prefix "val" in validation mode
            loss_vals = {
                "val_" + str(key): val for key, val in loss_vals.items()
            }
            metric_vals = {
                "val_" + str(key): val for key, val in metric_vals.items()
            }

        return metric_vals, loss_vals, outputs