Predictor

The predictor implements the basic prediction and metric calculation routines and can be subclassed for special routines. It is also the base class of all trainers, which extend its functionality with training routines.

Predictor
class Predictor(model, key_mapping: dict, convert_batch_to_npy_fn=<function convert_batch_to_numpy_identity>, prepare_batch_fn=<function Predictor.<lambda>>, **kwargs)

    Bases: object

    Defines an API for predictions from a network.

_Predictor__KEYS_TO_GUARD = []
static _Predictor__concatenate_dict_items(dict_like: dict)

    Function to recursively concatenate dict items.

    Parameters
        dict_like (dict) – the (nested) dict whose items should be concatenated
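The recursive concatenation can be sketched as follows. This is a minimal standalone illustration, not delira's actual implementation: the real helper presumably concatenates array-typed items, while plain lists of per-batch values are used here for brevity.

```python
def concatenate_dict_items(dict_like: dict) -> dict:
    """Recursively concatenate the items of a (nested) dict.

    Sketch only: each value is either a nested dict (recursed into) or
    a list of per-batch lists, which are concatenated into one flat list.
    """
    result = {}
    for key, value in dict_like.items():
        if isinstance(value, dict):
            # recurse into nested dicts
            result[key] = concatenate_dict_items(value)
        else:
            # concatenate the per-batch items
            result[key] = [item for batch in value for item in batch]
    return result
```

For example, `concatenate_dict_items({"pred": [[1, 2], [3]]})` returns `{"pred": [1, 2, 3]}`.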
static _Predictor__convert_dict(old_dict, new_dict)

    Function to recursively convert dicts.
_setup(network, key_mapping, convert_batch_args_kwargs_to_npy_fn, prepare_batch_fn, **kwargs)

    Parameters
        network (AbstractNetwork) – the network to predict from
        key_mapping (dict) – a dictionary containing the mapping from the data_dict to the actual model inputs. E.g. if a model accepts one input named 'x' and the data_dict contains one entry named 'data', this argument would have to be {'x': 'data'}
        convert_batch_to_npy_fn (callable) – a function to convert tensors in positional and keyword arguments to numpy
        prepare_batch_fn (callable) – function converting a batch tensor to the framework-specific tensor type and pushing it to the correct device, default: identity function
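The key_mapping lookup described above can be sketched with a small helper. The helper name is hypothetical and not part of delira's API; it only illustrates how {'x': 'data'} routes a data_dict entry to a model input.

```python
def map_batch_to_inputs(batch: dict, key_mapping: dict) -> dict:
    """Remap entries of a data dict to a model's input names.

    key_mapping maps each model input name to the key holding the
    corresponding entry in the data dict, e.g. {'x': 'data'} feeds
    batch['data'] to the model input named 'x'.
    """
    return {input_name: batch[data_key]
            for input_name, data_key in key_mapping.items()}
```

For the example from the parameter description, `map_batch_to_inputs({'data': [0.5], 'label': [1]}, {'x': 'data'})` yields `{'x': [0.5]}`; entries not named in the mapping (here 'label') are simply not passed to the model.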
static calc_metrics(batch: delira.utils.config.LookupConfig, metrics={}, metric_keys=None)

    Compute metrics.

    Parameters
        batch (LookupConfig) – dictionary containing the whole batch (including predictions)
        metrics (dict) – dict with metrics
        metric_keys (dict) – dict of tuples containing hashables that specify the items to use for calculating the respective metric. If not specified for a metric, the keys 'pred' and 'label' are used by default

    Returns
        dict with metric results

    Return type
        dict
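The fallback to the 'pred'/'label' keys can be sketched as follows. This is a standalone sketch of the documented behaviour, operating on a plain dict rather than a LookupConfig, and is not delira's actual implementation.

```python
def calc_metrics(batch: dict, metrics=None, metric_keys=None) -> dict:
    """Sketch of the metric calculation described above.

    For each named metric function, look up which batch items to feed
    it; metrics without an entry in metric_keys fall back to the
    default keys ('pred', 'label').
    """
    results = {}
    for name, metric_fn in (metrics or {}).items():
        pred_key, label_key = (metric_keys or {}).get(name, ("pred", "label"))
        results[name] = metric_fn(batch[pred_key], batch[label_key])
    return results
```

For example, with an accuracy function and `batch = {"pred": [1, 0, 1], "label": [1, 1, 1]}`, `calc_metrics(batch, {"acc": accuracy})` computes the accuracy from `batch["pred"]` and `batch["label"]`, while passing `metric_keys={"acc": ("out", "gt")}` would redirect it to other batch entries.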
predict(data: dict, **kwargs)

    Predict a single batch.

    Returns the predictions corresponding to the given data, obtained from the model.
predict_data_mgr(datamgr, batchsize=None, metrics={}, metric_keys=None, verbose=False, **kwargs)

    Defines a routine to predict data obtained from a batchgenerator without explicitly caching anything.

    Parameters
        datamgr (BaseDataManager) – manager producing a generator holding the batches
        batchsize (int) – artificial batchsize (sampling will be done with batchsize 1 and the sampled data will be stacked to match the artificial batchsize), default: None
        metrics (dict) – the metrics to calculate
        metric_keys (dict) – the batch_dict items to use for metric calculation
        verbose (bool) – whether to show a progress bar or not, default: False
        kwargs – keyword arguments passed to prepare_batch_fn()

    Yields
        dict – a dictionary containing all predictions of the current batch
        dict – a dictionary containing all metrics of the current batch
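The stacking to an artificial batchsize can be sketched as a generator. This is standalone and hypothetical: samples stand in for the batchsize-1 draws described above, and predict_fn stands in for the model call.

```python
def predict_batches(samples, predict_fn, batchsize=None):
    """Sketch of the batching behaviour described above.

    Samples are drawn one at a time and collected until the artificial
    batchsize is reached, then predicted in one call. Nothing is
    cached: one result is yielded per (stacked) batch.
    """
    buffer = []
    for sample in samples:
        buffer.append(sample)
        if batchsize is None or len(buffer) >= batchsize:
            yield predict_fn(buffer)
            buffer = []
    if buffer:  # remainder smaller than the artificial batchsize
        yield predict_fn(buffer)
```

With `batchsize=2`, five samples produce three prediction batches of sizes 2, 2 and 1; because this is a generator, each batch can be consumed and discarded before the next one is produced.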
predict_data_mgr_cache(datamgr, batchsize=None, metrics={}, metric_keys=None, verbose=False, cache_preds=False, **kwargs)

    Defines a routine to predict data obtained from a batchgenerator and caches all metrics and, if cache_preds=True, all predictions (yields them in dicts).

    Parameters
        datamgr (BaseDataManager) – manager producing a generator holding the batches
        batchsize (int) – artificial batchsize (sampling will be done with batchsize 1 and the sampled data will be stacked to match the artificial batchsize), default: None
        metrics (dict) – the metrics to calculate
        metric_keys (dict) – the batch_dict items to use for metric calculation
        verbose (bool) – whether to show a progress bar or not, default: False
        cache_preds (bool) – whether to also cache all predictions, default: False
        kwargs – keyword arguments passed to prepare_batch_fn()

    Yields
        dict – a dictionary containing all validation metrics (may be empty)
        dict – a dictionary containing all predictions, if cache_preds=True

    Warning
        Since this function caches all metrics and may additionally cache all predictions (based on the argument cache_preds), this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr(), or consider setting cache_preds to False (if not done already).
predict_data_mgr_cache_all(datamgr, batchsize=None, metrics={}, metric_keys=None, verbose=False, **kwargs)

    Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts).

    Parameters
        datamgr (BaseDataManager) – manager producing a generator holding the batches
        batchsize (int) – artificial batchsize (sampling will be done with batchsize 1 and the sampled data will be stacked to match the artificial batchsize), default: None
        metrics (dict) – the metrics to calculate
        metric_keys (dict) – the batch_dict items to use for metric calculation
        verbose (bool) – whether to show a progress bar or not, default: False
        kwargs – keyword arguments passed to prepare_batch_fn()

    Yields
        dict – a dictionary containing all predictions
        dict – a dictionary containing all validation metrics (may be empty)

    Warning
        Since this function caches all predictions and metrics, this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr().
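The caching behaviour warned about above can be sketched as follows. The names are hypothetical and standalone; batch_results stands for any per-batch generator such as the one Predictor.predict_data_mgr() yields from.

```python
def cache_all(batch_results):
    """Sketch of why cached prediction grows with the dataset: every
    per-batch prediction dict and metric dict is held in memory and
    only returned once the generator is exhausted.
    """
    cached_preds, cached_metrics = [], []
    for preds, batch_metrics in batch_results:
        cached_preds.append(preds)       # grows with every batch
        cached_metrics.append(batch_metrics)
    return cached_preds, cached_metrics
```

Memory consumption is proportional to the number of batches times the per-batch prediction size; keeping only the (much smaller) metric dicts is what makes the metrics-only variant cheaper.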
predict_data_mgr_cache_metrics_only(datamgr, batchsize=None, metrics={}, metric_keys=None, verbose=False, **kwargs)

    Defines a routine to predict data obtained from a batchgenerator and caches the metrics.

    Parameters
        datamgr (BaseDataManager) – manager producing a generator holding the batches
        batchsize (int) – artificial batchsize (sampling will be done with batchsize 1 and the sampled data will be stacked to match the artificial batchsize), default: None
        metrics (dict) – the metrics to calculate
        metric_keys (dict) – the batch_dict items to use for metric calculation
        verbose (bool) – whether to show a progress bar or not, default: False
        kwargs – keyword arguments passed to prepare_batch_fn()

    Yields
        dict – a dictionary containing all validation metrics (may be empty)

    Notes
        This function stores each prediction only temporarily for metric calculation, which typically results in much lower memory consumption than Predictor.predict_data_mgr_cache_all(), but it still caches the metrics. If this is not desired, it is recommended to use Predictor.predict_data_mgr() and iterate over the generator, as this only produces per-batch metrics and predictions and does not cache anything by default.