# Source code for the module defining the ``Predictor`` class

import logging
import copy

import numpy as np
from tqdm import tqdm

from ..data_loading import BaseDataManager
from .train_utils import convert_batch_to_numpy_identity
from ..utils.config import LookupConfig

logger = logging.getLogger(__name__)

class Predictor(object):
    """
    Defines an API for Predictions from a Network

    A predictor wraps a callable model together with a key mapping (which
    entries of a batch dictionary are fed to which model inputs), a function
    preparing a batch for the model and a function converting the model
    outputs back to numpy.

    See Also
    --------
    :class:`PyTorchNetworkTrainer`

    """

    # static variable to prevent certain attributes from being overwritten
    # (see ``__setattr__``); empty for the plain predictor
    __KEYS_TO_GUARD = []

    def __init__(self, model, key_mapping: dict,
                 convert_batch_to_npy_fn=convert_batch_to_numpy_identity,
                 prepare_batch_fn=lambda x: x, **kwargs):
        """

        Parameters
        ----------
        model : :class:`AbstractNetwork`
            the model to predict from
        key_mapping : dict
            a dictionary containing the mapping from the ``data_dict`` to
            the actual model's inputs.
            E.g. if a model accepts one input named 'x' and the data_dict
            contains one entry named 'data' this argument would have to
            be ``{'x': 'data'}``
        convert_batch_to_npy_fn : type, optional
            a callable function to convert tensors in positional and keyword
            arguments to numpy; default: identity function
        prepare_batch_fn : type, optional
            function converting a batch-tensor to the framework specific
            tensor-type and pushing it to correct device, default: identity
            function. Note that the default identity function only accepts
            the batch itself, so no additional keyword arguments can be
            forwarded to it by :meth:`predict`
        **kwargs :
            additional keyword arguments (forwarded to :meth:`_setup`)

        """
        self._setup(model, key_mapping, convert_batch_to_npy_fn,
                    prepare_batch_fn, **kwargs)

        # description shown next to the progress bar of ``predict_data_mgr``
        self._tqdm_desc = "Test"

    def _setup(self, network, key_mapping,
               convert_batch_args_kwargs_to_npy_fn, prepare_batch_fn,
               **kwargs):
        """
        Stores the given network and helper functions as attributes

        Parameters
        ----------
        network : :class:`AbstractNetwork`
            the network to predict from
        key_mapping : dict
            a dictionary containing the mapping from the ``data_dict`` to
            the actual model's inputs.
            E.g. if a model accepts one input named 'x' and the data_dict
            contains one entry named 'data' this argument would have to
            be ``{'x': 'data'}``
        convert_batch_args_kwargs_to_npy_fn : type
            a callable function to convert tensors in positional and keyword
            arguments to numpy
        prepare_batch_fn : type
            function converting a batch-tensor to the framework specific
            tensor-type and pushing it to correct device
        **kwargs :
            additional keyword arguments; accepted but not used here

        """
        self.module = network
        self.key_mapping = key_mapping
        self._convert_to_npy_fn = convert_batch_args_kwargs_to_npy_fn
        self._prepare_batch = prepare_batch_fn

    def __call__(self, data: dict, **kwargs):
        """
        Method to call the class.
        Returns the predictions corresponding to the given data
        obtained by the model

        Parameters
        ----------
        data : dict
            batch dictionary
        **kwargs :
            keyword arguments (directly passed to :meth:`predict`)

        Returns
        -------
        dict
            predicted data

        """
        return self.predict(data, **kwargs)

    def predict(self, data: dict, **kwargs):
        """
        Predict single batch
        Returns the predictions corresponding to the given data
        obtained by the model

        Parameters
        ----------
        data : dict
            batch dictionary
        **kwargs :
            keyword arguments (directly passed to ``prepare_batch``)

        Returns
        -------
        dict
            predicted data

        """
        data = self._prepare_batch(data, **kwargs)

        # pick the model inputs from the batch according to the key mapping
        mapped_data = {k: data[v] for k, v in self.key_mapping.items()}

        pred = self.module(**mapped_data)

        # converts positional arguments and keyword arguments,
        # but returns only keyword arguments, since positional
        # arguments are not given.
        return self._convert_to_npy_fn(**pred)[1]

    def predict_data_mgr(self, datamgr, batchsize=None, metrics=None,
                         metric_keys=None, verbose=False, **kwargs):
        """
        Defines a routine to predict data obtained from a batchgenerator
        without explicitly caching anything

        Parameters
        ----------
        datamgr : :class:`BaseDataManager`
            Manager producing a generator holding the batches
        batchsize : int
            Artificial batchsize (sampling will be done with batchsize
            1 and sampled data will be stacked to match the artificial
            batchsize)(default: None, which falls back to the batchsize of
            ``datamgr``)
        metrics : dict
            the metrics to calculate
        metric_keys : dict
            the ``batch_dict`` items to use for metric calculation
        verbose : bool
            whether to show a progress-bar or not, default: False
        kwargs :
            keyword arguments passed to :func:`prepare_batch_fn`

        Yields
        ------
        dict
            a dictionary containing all predictions of the current batch
        dict
            a dictionary containing all metrics of the current batch

        Notes
        -----
        The batchsize and the number of augmentation processes of
        ``datamgr`` are temporarily set to 1. They are restored (and the
        batch generator is finished) as soon as this generator is exhausted,
        closed or left by an exception.

        """
        if metrics is None:
            metrics = {}

        orig_num_aug_processes = datamgr.n_process_augmentation
        orig_batch_size = datamgr.batch_size

        if batchsize is None:
            batchsize = orig_batch_size

        batchgen = None

        try:
            # sample single items; they are stacked manually below
            datamgr.batch_size = 1
            datamgr.n_process_augmentation = 1

            batchgen = datamgr.get_batchgen()
            n_batches = batchgen.num_batches

            if verbose:
                iterable = tqdm(enumerate(batchgen), unit=' sample',
                                total=n_batches, desc=self._tqdm_desc)
            else:
                iterable = enumerate(batchgen)

            batch_list = []

            for i, batch in iterable:
                # shrink the artificial batchsize at the start of a new
                # queue if fewer samples than a full batch are left, so
                # the trailing samples are still processed
                if (batchsize is not None and not batch_list
                        and (n_batches - i) < batchsize):
                    batchsize = n_batches - i
                    logger.debug("Set Batchsize down to %d to avoid cutting "
                                 "of the last batches", batchsize)

                batch_list.append(batch)

                # wait until the queue is full
                if batchsize is not None and len(batch_list) < batchsize:
                    continue

                # group the sampled items by key and stack them
                grouped = {}
                for _batch in batch_list:
                    for key, val in _batch.items():
                        grouped.setdefault(key, []).append(val)

                batch_dict = {key: np.concatenate(val_list)
                              for key, val_list in grouped.items()}

                preds = self.predict(copy.copy(batch_dict), **kwargs)

                # convert batchdict back to numpy (self.predict may convert
                # it to backend-specific tensor type) - no-op if already
                # numpy
                batch_dict = self._convert_to_npy_fn(**batch_dict)[1]

                preds_batch = LookupConfig()
                preds_batch.update(batch_dict)
                preds_batch.update(preds)

                # calculate metrics for predicted batch
                _metric_vals = self.calc_metrics(preds_batch,
                                                 metrics=metrics,
                                                 metric_keys=metric_keys)

                yield preds, _metric_vals

                batch_list = []

        finally:
            # always clean up, even if the consumer stops iterating early
            # or an exception is raised during prediction
            try:
                if batchgen is not None:
                    batchgen._finish()
            finally:
                datamgr.batch_size = orig_batch_size
                datamgr.n_process_augmentation = orig_num_aug_processes

    def predict_data_mgr_cache_metrics_only(self, datamgr, batchsize=None,
                                            metrics=None, metric_keys=None,
                                            verbose=False, **kwargs):
        """
        Defines a routine to predict data obtained from a batchgenerator and
        caches the metrics

        Parameters
        ----------
        datamgr : :class:`BaseDataManager`
            Manager producing a generator holding the batches
        batchsize : int
            Artificial batchsize (sampling will be done with batchsize
            1 and sampled data will be stacked to match the artificial
            batchsize)(default: None)
        metrics : dict
            the metrics to calculate
        metric_keys : dict
            the ``batch_dict`` items to use for metric calculation
        verbose : bool
            whether to show a progress-bar or not, default: False
        kwargs :
            keyword arguments passed to :func:`prepare_batch_fn`

        Yields
        ------
        dict
            a dictionary containing all validation metrics (maybe empty)

        Notes
        -----
        This function stores each prediction temporarily for metric
        calculation; This results in a (typically) way lower memory
        consumption than :meth:`Predictor.predict_data_mgr_cache_all`,
        but still caches the metrics. If this is not desired, it is
        recommended to use :meth:`Predictor.predict_data_mgr` and iterate
        over the generator as this only produces per-batch metrics and
        predictions and does not cache anything by default

        """
        if metrics is None:
            metrics = {}

        yield from self.predict_data_mgr_cache(datamgr=datamgr,
                                               batchsize=batchsize,
                                               metrics=metrics,
                                               metric_keys=metric_keys,
                                               verbose=verbose,
                                               cache_preds=False,
                                               **kwargs)

    def predict_data_mgr_cache_all(self, datamgr, batchsize=None,
                                   metrics=None, metric_keys=None,
                                   verbose=False, **kwargs):
        """
        Defines a routine to predict data obtained from a batchgenerator and
        caches all predictions and metrics (yields them in dicts)

        Parameters
        ----------
        datamgr : :class:`BaseDataManager`
            Manager producing a generator holding the batches
        batchsize : int
            Artificial batchsize (sampling will be done with batchsize
            1 and sampled data will be stacked to match the artificial
            batchsize)(default: None)
        metrics : dict
            the metrics to calculate
        metric_keys : dict
            the ``batch_dict`` items to use for metric calculation
        verbose : bool
            whether to show a progress-bar or not, default: False
        kwargs :
            keyword arguments passed to :func:`prepare_batch_fn`

        Yields
        ------
        dict
            a dictionary containing all predictions;
        dict
            a dictionary containing all validation metrics (maybe empty)

        Warnings
        --------
        Since this function caches all predictions and metrics, this may
        result in huge memory consumption. If you are running out of
        memory, please have a look at
        :meth:`Predictor.predict_data_mgr_cache_metrics_only` or
        :meth:`Predictor.predict_data_mgr`

        """
        if metrics is None:
            metrics = {}

        yield from self.predict_data_mgr_cache(datamgr=datamgr,
                                               batchsize=batchsize,
                                               metrics=metrics,
                                               metric_keys=metric_keys,
                                               verbose=verbose,
                                               cache_preds=True,
                                               **kwargs)

    def predict_data_mgr_cache(self, datamgr, batchsize=None, metrics=None,
                               metric_keys=None, verbose=False,
                               cache_preds=False, **kwargs):
        """
        Defines a routine to predict data obtained from a batchgenerator and
        caches all metrics and (optionally) all predictions

        Parameters
        ----------
        datamgr : :class:`BaseDataManager`
            Manager producing a generator holding the batches
        batchsize : int
            Artificial batchsize (sampling will be done with batchsize
            1 and sampled data will be stacked to match the artificial
            batchsize)(default: None)
        metrics : dict
            the metrics to calculate
        metric_keys : dict
            the ``batch_dict`` items to use for metric calculation
        verbose : bool
            whether to show a progress-bar or not, default: False
        cache_preds : bool
            whether to also cache predictions
        kwargs :
            keyword arguments passed to :func:`prepare_batch_fn`

        Yields
        ------
        dict
            a dictionary containing all predictions;
            only yielded (as first item of a tuple) if ``cache_preds=True``
        dict
            a dictionary containing all validation metrics (maybe empty);
            yielded on its own if ``cache_preds=False``

        Warnings
        --------
        Since this function caches all metrics and may additionally cache
        all predictions (based on the argument ``cache_preds``), this may
        result in huge memory consumption. If you are running out of
        memory, please have a look at
        :meth:`Predictor.predict_data_mgr_cache_metrics_only` or
        :meth:`Predictor.predict_data_mgr` or consider setting
        ``cache_preds`` to ``False`` (if not done already)

        """
        if metrics is None:
            metrics = {}

        predictions_all = []
        metric_vals = {k: [] for k in metrics.keys()}

        for preds, _metric_vals in self.predict_data_mgr(
                datamgr=datamgr, batchsize=batchsize, metrics=metrics,
                metric_keys=metric_keys, verbose=verbose, **kwargs):

            if cache_preds:
                predictions_all.append(preds)

            for k, v in _metric_vals.items():
                metric_vals[k].append(v)

        # one array of per-batch values for each metric
        metric_vals = {k: np.array(v) for k, v in metric_vals.items()}

        if not cache_preds:
            yield metric_vals
            return

        # convert predictions from list of dicts to dict of lists
        # (recursively for all nested dicts)
        new_predictions_all = {}
        for preds in predictions_all:
            new_predictions_all = self.__convert_dict(preds,
                                                      new_predictions_all)

        # concatenate lists to single arrays
        preds_all = self.__concatenate_dict_items(new_predictions_all)

        yield preds_all, metric_vals

    @staticmethod
    def __convert_dict(old_dict, new_dict):
        """
        Function to recursively convert dicts

        Appends every leaf of ``old_dict`` to a list stored under the same
        (nested) key in ``new_dict``

        Parameters
        ----------
        old_dict : dict
            the old nested dict
        new_dict : dict
            the new nested dict (updated in place)

        Returns
        -------
        dict
            the updated new nested dict

        """
        for k, v in old_dict.items():

            # apply same function again on item if item is dict
            if isinstance(v, dict):
                new_dict[k] = Predictor.__convert_dict(
                    v, new_dict.setdefault(k, {}))
                continue

            # check if v is scalar and convert to npy-array if
            # necessary.
            # Otherwise concatenation might fail
            if np.isscalar(v):
                v = np.array(v)

            # check for zero-sized arrays and reshape if necessary.
            # Otherwise concatenation might fail
            if v.shape == ():
                v = v.reshape(1)

            new_dict.setdefault(k, []).append(v)

        return new_dict

    @staticmethod
    def __concatenate_dict_items(dict_like: dict):
        """
        Function to recursively concatenate dict-items

        Parameters
        ----------
        dict_like : dict
            the (nested) dict, whose items should be concatenated

        Returns
        -------
        dict
            the given dict (modified in place), in which every non-dict
            item has been replaced by the result of ``np.concatenate``

        """
        # only values of existing keys are replaced, so iterating over the
        # dict while assigning is safe
        for k, v in dict_like.items():
            if isinstance(v, dict):
                dict_like[k] = Predictor.__concatenate_dict_items(v)
            else:
                dict_like[k] = np.concatenate(v)

        return dict_like

    def __setattr__(self, key, value):
        """
        Set attributes and guard specific attributes after they have been set
        once

        Parameters
        ----------
        key : str
            the attributes name
        value : Any
            the value to set

        Raises
        ------
        PermissionError
            If attribute which should be set is guarded

        """
        # check if key has been set once
        if key in self.__KEYS_TO_GUARD and hasattr(self, key):
            raise PermissionError("%s should not be overwritten after "
                                  "it has been set once" % key)

        super().__setattr__(key, value)

    @staticmethod
    def calc_metrics(batch: LookupConfig, metrics=None, metric_keys=None):
        """
        Compute metrics

        Parameters
        ----------
        batch: LookupConfig
            dictionary containing the whole batch
            (including predictions)
        metrics: dict
            dict with metrics
        metric_keys : dict
            dict of tuples which contains hashables for specifying the items
            to use for calculating the respective metric.
            If not specified for a metric, the keys "pred" and "label"
            are used per default

        Returns
        -------
        dict
            dict with metric results

        """
        if metrics is None:
            metrics = {}

        if metric_keys is None:
            metric_keys = {}

        # fall back to the default keys per metric, so a partially
        # specified ``metric_keys`` does not raise a KeyError
        default_keys = ("pred", "label")

        return {
            key: metric_fn(*[batch.nested_get(k)
                             for k in metric_keys.get(key, default_keys)])
            for key, metric_fn in metrics.items()
        }