Predictor

The predictor implements the basic prediction and metric calculation routines and can be subclassed for special routines.

It is also the base class of all trainers, which extend its functionality with training routines.

Predictor

class Predictor(model, key_mapping: dict, convert_batch_to_npy_fn=<function convert_batch_to_numpy_identity>, prepare_batch_fn=<function Predictor.<lambda>>, **kwargs)[source]

Bases: object

Defines an API for Predictions from a Network

_Predictor__KEYS_TO_GUARD = []
static _Predictor__concatenate_dict_items(dict_like: dict)

Function to recursively concatenate dict-items

Parameters

dict_like (dict) – the (nested) dict whose items should be concatenated
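The recursion described here can be sketched in plain Python. This is an illustrative stand-in, not the library's code: the name `concatenate_dict_items_sketch` is hypothetical, and it assumes each leaf is a list of per-batch lists, whereas the real method may operate on arrays.

```python
def concatenate_dict_items_sketch(dict_like):
    """Hypothetical stand-in: recursively join per-batch lists in a nested dict."""
    result = {}
    for key, value in dict_like.items():
        if isinstance(value, dict):
            # Descend into nested dicts and concatenate their items as well.
            result[key] = concatenate_dict_items_sketch(value)
        else:
            # Assumption: value is a list of per-batch lists.
            result[key] = [item for chunk in value for item in chunk]
    return result


cached = {"pred": [[1, 0], [1]], "metrics": {"acc": [[1.0], [0.0]]}}
merged = concatenate_dict_items_sketch(cached)
```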

static _Predictor__convert_dict(old_dict, new_dict)

Function to recursively convert dicts

Parameters
  • old_dict (dict) – the old nested dict

  • new_dict (dict) – the new nested dict

Returns

the updated new nested dict

Return type

dict

_setup(network, key_mapping, convert_batch_args_kwargs_to_npy_fn, prepare_batch_fn, **kwargs)[source]
Parameters
  • network (AbstractNetwork) – the network to predict from

  • key_mapping (dict) – a dictionary containing the mapping from the data_dict to the actual model’s inputs. E.g. if a model accepts one input named ‘x’ and the data_dict contains one entry named ‘data’ this argument would have to be {'x': 'data'}

  • convert_batch_to_npy_fn (type) – a callable function to convert tensors in positional and keyword arguments to numpy

  • prepare_batch_fn (type) – function converting a batch-tensor to the framework specific tensor-type and pushing it to correct device, default: identity function
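
To illustrate the key_mapping argument, here is a minimal sketch in plain Python. It is not delira's implementation: the helper `map_batch_to_model_inputs` and the `toy_model` function are hypothetical, and only the `{'x': 'data'}` mapping is taken from the description above.

```python
def map_batch_to_model_inputs(data_dict, key_mapping):
    """Hypothetical helper: select model inputs from a batch dict.

    key_mapping maps model argument names to keys in data_dict.
    """
    return {arg_name: data_dict[batch_key]
            for arg_name, batch_key in key_mapping.items()}


def toy_model(x):
    # Stand-in for a network that accepts a single input named 'x'.
    return [2 * value for value in x]


batch = {"data": [1, 2, 3], "label": [0, 1, 0]}
model_inputs = map_batch_to_model_inputs(batch, {"x": "data"})
prediction = toy_model(**model_inputs)
```

The 'label' entry is not forwarded because key_mapping does not mention it, so the model only sees the argument it declares.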

static calc_metrics(batch: delira.utils.config.LookupConfig, metrics={}, metric_keys=None)[source]

Compute metrics

Parameters
  • batch (LookupConfig) – dictionary containing the whole batch (including predictions)

  • metrics (dict) – dict with metrics

  • metric_keys (dict) – dict of tuples containing the hashable keys that select the batch items used to calculate the respective metric. If no keys are specified for a metric, the keys “pred” and “label” are used by default

Returns

dict with metric results

Return type

dict
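
The default-key behaviour described for metric_keys can be sketched as follows. This is a hedged illustration rather than the actual method: `calc_metrics_sketch` and `accuracy` are hypothetical names, a plain dict stands in for LookupConfig, and passing the selected items to the metric in tuple order is an assumption.

```python
def calc_metrics_sketch(batch, metrics=None, metric_keys=None):
    """Hypothetical re-implementation of the documented behaviour."""
    metrics = metrics or {}
    metric_keys = metric_keys or {}
    results = {}
    for name, metric_fn in metrics.items():
        # Fall back to ("pred", "label") when no keys are given for a metric.
        keys = metric_keys.get(name, ("pred", "label"))
        results[name] = metric_fn(*[batch[key] for key in keys])
    return results


def accuracy(pred, label):
    # Fraction of positions where prediction and label agree.
    return sum(p == l for p, l in zip(pred, label)) / len(label)


batch = {"pred": [1, 0, 1, 1], "label": [1, 0, 0, 1], "aux_pred": [1, 0, 0, 1]}
default_result = calc_metrics_sketch(batch, {"acc": accuracy})
custom_result = calc_metrics_sketch(
    batch, {"acc": accuracy}, metric_keys={"acc": ("aux_pred", "label")})
```

With the default keys the metric compares "pred" against "label"; the second call redirects the same metric to the "aux_pred" entry.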

predict(data: dict, **kwargs)[source]

Predict a single batch. Returns the predictions the model produces for the given data.

Parameters
  • data (dict) – batch dictionary

  • **kwargs – keyword arguments (passed directly to prepare_batch)

Returns

predicted data

Return type

dict

predict_data_mgr(datamgr, batchsize=None, metrics={}, metric_keys=None, verbose=False, **kwargs)[source]

Defines a routine to predict data obtained from a batchgenerator without explicitly caching anything

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize), default: None

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields
  • dict – a dictionary containing all predictions of the current batch

  • dict – a dictionary containing all metrics of the current batch
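
The non-caching pattern can be sketched as a plain generator. Everything here is an assumption made for illustration: `predict_stream_sketch`, `threshold_model`, and `accuracy` are hypothetical, a list of sample dicts stands in for the data manager, and how the real method treats a final incomplete batch is not specified above.

```python
def predict_stream_sketch(samples, predict_fn, metrics, batchsize=None):
    """Hypothetical generator: yields (preds, metric_vals) per batch."""
    batchsize = batchsize or 1
    for start in range(0, len(samples), batchsize):
        # Samples are drawn one at a time and stacked to the artificial batchsize.
        chunk = samples[start:start + batchsize]
        batch = {
            "data": [s["data"] for s in chunk],
            "label": [s["label"] for s in chunk],
        }
        preds = {"pred": predict_fn(batch["data"])}
        metric_vals = {name: fn(preds["pred"], batch["label"])
                       for name, fn in metrics.items()}
        # Nothing is stored here; the caller decides what to keep.
        yield preds, metric_vals


def threshold_model(data):
    # Toy predictor: class 1 if the value exceeds 0.5, else class 0.
    return [int(value > 0.5) for value in data]


def accuracy(pred, label):
    return sum(p == l for p, l in zip(pred, label)) / len(label)


samples = [{"data": 0.9, "label": 1}, {"data": 0.2, "label": 0},
           {"data": 0.7, "label": 0}]
results = list(predict_stream_sketch(samples, threshold_model,
                                     {"acc": accuracy}, batchsize=2))
```

Because each batch is yielded and then discarded, memory use stays bounded by one batch unless the caller collects the results, as the `list(...)` call does here for demonstration.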

predict_data_mgr_cache(datamgr, batchsize=None, metrics={}, metric_keys=None, verbose=False, cache_preds=False, **kwargs)[source]

Defines a routine to predict data obtained from a batchgenerator that caches all metrics and, if cache_preds is True, all predictions (yields them in dicts)

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize), default: None

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • cache_preds (bool) – whether to additionally cache all predictions, default: False

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields
  • dict – a dictionary containing all validation metrics (may be empty)

  • dict – a dictionary containing all predictions (only if cache_preds=True)

Warning

Since this function caches all metrics and may additionally cache all predictions (depending on the argument cache_preds), it may consume a large amount of memory. If you run out of memory, consider using Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr(), or set cache_preds to False (if not done already).

predict_data_mgr_cache_all(datamgr, batchsize=None, metrics={}, metric_keys=None, verbose=False, **kwargs)[source]

Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize), default: None

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields
  • dict – a dictionary containing all predictions

  • dict – a dictionary containing all validation metrics (may be empty)

Warning

Since this function caches all predictions and metrics, it may consume a large amount of memory. If you run out of memory, consider using Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr().

predict_data_mgr_cache_metrics_only(datamgr, batchsize=None, metrics={}, metric_keys=None, verbose=False, **kwargs)[source]

Defines a routine to predict data obtained from a batchgenerator and caches the metrics

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize), default: None

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields

dict – a dictionary containing all validation metrics (may be empty)

Notes

This function stores each prediction only temporarily, for metric calculation. This typically results in much lower memory consumption than Predictor.predict_data_mgr_cache_all(), but the metrics are still cached. If this is not desired, use Predictor.predict_data_mgr() and iterate over the generator, as it only produces per-batch metrics and predictions and does not cache anything by default.
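
The memory trade-off in this note can be sketched in a few lines. This is an illustration under stated assumptions, not the library's code: `cache_metrics_only_sketch` is a hypothetical name, it returns the cached dict instead of yielding it, and the input is assumed to be an iterable of (predictions, metrics) pairs like the one Predictor.predict_data_mgr() is documented to yield.

```python
def cache_metrics_only_sketch(batch_results):
    """Hypothetical: keep metric values, drop predictions after each batch."""
    cached_metrics = {}
    for _preds, metric_vals in batch_results:
        # _preds is rebound on the next iteration, so it can be freed.
        for name, value in metric_vals.items():
            cached_metrics.setdefault(name, []).append(value)
    return cached_metrics


per_batch = iter([
    ({"pred": [1, 0]}, {"acc": 1.0}),
    ({"pred": [1]}, {"acc": 0.0}),
])
all_metrics = cache_metrics_only_sketch(per_batch)
```

Only one small value per metric and batch is retained, which is why this variant usually needs far less memory than caching the full prediction arrays.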