NetworkTrainer
- The network trainer implements the actual training routine and can be subclassed
for special routines.
BaseNetworkTrainer
-
class
BaseNetworkTrainer
(network: delira.models.abstract_network.AbstractNetwork, save_path: str, losses: dict, optimizer_cls: type, optimizer_params: dict, train_metrics: dict, val_metrics: dict, lr_scheduler_cls: type, lr_scheduler_params: dict, gpu_ids: List[int], save_freq: int, optim_fn, key_mapping: dict, logging_type: str, logging_kwargs: dict, fold: int, callbacks: List[delira.training.callbacks.abstract_callback.AbstractCallback], start_epoch=1, metric_keys=None, convert_batch_to_npy_fn=<function BaseNetworkTrainer.<lambda>>, val_freq=1, **kwargs)[source]¶ Bases:
delira.training.predictor.Predictor
Defines a Base API and basic functions for Network Trainers
-
_BaseNetworkTrainer__KEYS_TO_GUARD
= ['use_gpu', 'input_device', 'output_device', '_callbacks']¶
-
_Predictor__KEYS_TO_GUARD
= []¶
-
static
_Predictor__concatenate_dict_items
(dict_like: dict)¶ Function to recursively concatenate dict-items
- Parameters
dict_like (dict) – the (nested) dict whose items should be concatenated
-
static
_Predictor__convert_dict
(old_dict, new_dict)¶ Function to recursively convert dicts
-
_at_epoch_begin
(metrics_val, val_score_key, epoch, num_epochs, **kwargs) Defines behaviour at the beginning of each epoch: executes the at_epoch_begin method of every callback
-
_at_epoch_end
(metrics_val, val_score_key, epoch, is_best, **kwargs) Defines behaviour at the end of each epoch: executes the at_epoch_end method of every callback and saves the current state if necessary
-
_at_training_begin
(*args, **kwargs) Defines the behaviour at the beginning of the training
- Parameters
*args – positional arguments
**kwargs – keyword arguments
- Raises
NotImplementedError – If not overwritten by subclass
-
_at_training_end
(*args, **kwargs)[source]¶ Defines the behaviour at the end of the training
- Parameters
*args – positional arguments
**kwargs – keyword arguments
- Raises
NotImplementedError – If not overwritten by subclass
-
static
_is_better_val_scores
(old_val_score, new_val_score, mode='highest')[source]¶ Check whether the new val score is better than the old one with respect to the optimization goal
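The comparison described here can be sketched in a few lines. This is a minimal illustration and not delira's implementation; the function name is hypothetical, and the name of the counterpart mode ('lowest') is an assumption, since only 'highest' appears in the signature above.

```python
def is_better_val_score(old_val_score, new_val_score, mode="highest"):
    """Return True if new_val_score improves on old_val_score for the given mode."""
    if mode == "highest":
        # e.g. accuracy: a larger value is an improvement
        return new_val_score > old_val_score
    if mode == "lowest":  # assumed name of the counterpart mode
        # e.g. a loss: a smaller value is an improvement
        return new_val_score < old_val_score
    raise ValueError(f"unknown mode: {mode!r}")


print(is_better_val_score(0.70, 0.80))
print(is_better_val_score(0.70, 0.80, mode="lowest"))
```

In this sketch a tie does not count as an improvement, which is a design choice of the example.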
-
static
_search_for_prev_state
(path, extensions=None)[source]¶ Helper function to search in a given path for previous epoch states (indicated by extensions)
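- Parameters
path – the path to search in
extensions – the file extensions that indicate previous epoch states (default: None)
- Returns
str – the file containing the latest checkpoint (if available)
None – if no latest checkpoint was found
int – the latest epoch (1 if no checkpoint was found)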
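A search of this kind might look like the sketch below. The file naming scheme (`checkpoint_epoch_<n>`) and the default extension are hypothetical choices made for the example; the documentation above does not state how delira names its checkpoint files.

```python
import os
import re
import tempfile


def search_for_prev_state(path, extensions=None):
    """Return (latest_checkpoint_file_or_None, latest_epoch)."""
    extensions = extensions or [".pkl"]  # hypothetical default
    # hypothetical naming scheme: checkpoint_epoch_<n><extension>
    pattern = re.compile(r"checkpoint_epoch_(\d+)$")
    latest_file, latest_epoch = None, 1
    if not os.path.isdir(path):
        return latest_file, latest_epoch
    for name in os.listdir(path):
        stem, ext = os.path.splitext(name)
        match = pattern.match(stem)
        if ext in extensions and match:
            epoch = int(match.group(1))
            if latest_file is None or epoch > latest_epoch:
                latest_file, latest_epoch = os.path.join(path, name), epoch
    return latest_file, latest_epoch


with tempfile.TemporaryDirectory() as tmp:
    print(search_for_prev_state(tmp))  # nothing saved yet
    for epoch in (1, 2, 10):
        open(os.path.join(tmp, f"checkpoint_epoch_{epoch}.pkl"), "w").close()
    print(search_for_prev_state(tmp))
```

Parsing the epoch as an integer keeps epoch 10 ahead of epoch 2, which a plain string sort would get wrong.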
-
_setup
(network, lr_scheduler_cls, lr_scheduler_params, gpu_ids, key_mapping, convert_batch_to_npy_fn, prepare_batch_fn)
- Parameters
network (AbstractNetwork) – the network to predict from
key_mapping (dict) – a dictionary containing the mapping from the data_dict to the actual model’s inputs. E.g. if a model accepts one input named ‘x’ and the data_dict contains one entry named ‘data’, this argument would have to be {'x': 'data'}
convert_batch_to_npy_fn (type) – a callable function to convert tensors in positional and keyword arguments to numpy
prepare_batch_fn (type) – function converting a batch-tensor to the framework specific tensor-type and pushing it to correct device, default: identity function
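The key_mapping example above ({'x': 'data'}) can be illustrated with plain dictionaries. The helper `map_inputs` and the toy model are hypothetical and only show the idea of renaming data_dict entries to model arguments.

```python
def map_inputs(data_dict, key_mapping):
    """Build the model's keyword arguments from a data_dict."""
    return {model_arg: data_dict[data_key]
            for model_arg, data_key in key_mapping.items()}


def toy_model(x):
    # stand-in for a network that accepts one input named 'x'
    return [2 * value for value in x]


data_dict = {"data": [1, 2, 3], "label": [0, 1, 0]}
inputs = map_inputs(data_dict, {"x": "data"})
preds = toy_model(**inputs)
print(inputs, preds)
```

Entries that are not named in key_mapping, such as 'label' here, are simply not passed to the model.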
-
_train_single_epoch
(batchgen: delira.data_loading.data_manager.Augmenter, epoch, verbose=False)[source]¶ Trains the network a single epoch
- Parameters
batchgen (Augmenter) – Generator yielding the training batches
epoch (int) – current epoch
-
_update_state
(new_state)[source]¶ Update the state from a given new state
- Parameters
new_state (dict) – new state to update internal state from
- Returns
the trainer with a modified state
- Return type
BaseNetworkTrainer
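A state update that returns the trainer itself could be sketched as follows. The class is a hypothetical stand-in, and skipping the names listed in KEYS_TO_GUARD is an assumption about what "guarding" means; the documentation above does not describe how the real trainer treats those keys.

```python
class TinyTrainer:
    # mirrors the guarded names listed for the trainer classes
    KEYS_TO_GUARD = ["use_gpu", "input_device", "output_device", "_callbacks"]

    def __init__(self):
        self.use_gpu = False
        self.start_epoch = 1

    def update_state(self, new_state):
        """Copy entries of new_state onto the trainer and return the trainer."""
        for key, value in new_state.items():
            if key in self.KEYS_TO_GUARD:
                continue  # assumption: guarded attributes are left untouched
            setattr(self, key, value)
        return self


trainer = TinyTrainer().update_state({"start_epoch": 5, "use_gpu": True})
print(trainer.start_epoch, trainer.use_gpu)
```

Returning `self` allows the call to be chained directly after construction, as in the usage line.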
-
static
calc_metrics
(batch: delira.utils.config.LookupConfig, metrics=None, metric_keys=None)¶ Compute metrics
- Parameters
batch (LookupConfig) – dictionary containing the whole batch (including predictions)
metrics (dict) – dict with metrics
metric_keys (dict) – dict of tuples which contains hashables for specifying the items to use for calculating the respective metric. If not specified for a metric, the keys “pred” and “label” are used per default
- Returns
dict with metric results
- Return type
dict
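The role of metric_keys and its documented default ("pred" and "label") can be sketched with a plain dict in place of a LookupConfig. This is an illustration under assumptions, not the library code: it assumes the items named in each tuple are passed to the metric positionally and in order.

```python
def calc_metrics(batch, metrics=None, metric_keys=None):
    """Evaluate every metric on the batch items selected by metric_keys."""
    metrics = metrics or {}
    metric_keys = metric_keys or {}
    results = {}
    for name, metric_fn in metrics.items():
        # fall back to the documented default keys
        keys = metric_keys.get(name, ("pred", "label"))
        results[name] = metric_fn(*(batch[key] for key in keys))
    return results


def accuracy(pred, target):
    return sum(p == t for p, t in zip(pred, target)) / len(target)


batch = {"pred": [1, 0, 1, 1], "label": [1, 0, 0, 1], "seg": [1, 0, 0, 1]}
default = calc_metrics(batch, {"acc": accuracy})
custom = calc_metrics(batch, {"acc": accuracy}, {"acc": ("seg", "label")})
print(default, custom)
```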
-
predict
(data: dict, **kwargs) Predict a single batch. Returns the predictions corresponding to the given data obtained by the model
-
predict_data_mgr
(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)¶ Defines a routine to predict data obtained from a batchgenerator without explicitly caching anything
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
kwargs – keyword arguments passed to prepare_batch_fn()
- Yields
dict – a dictionary containing all predictions of the current batch
dict – a dictionary containing all metrics of the current batch
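The generator-style routine with an artificial batchsize can be sketched with plain lists. All names below are hypothetical, the samples are simple dicts rather than a data manager, and treating `batchsize=None` as 1 is an assumption of this example.

```python
def predict_stream(samples, predict_fn, batchsize=None, metrics=None):
    """Yield (predictions, metrics) per artificial batch without caching."""
    metrics = metrics or {}
    batchsize = batchsize or 1  # assumption: None falls back to 1
    for start in range(0, len(samples), batchsize):
        # samples are drawn one at a time, then stacked into one batch
        chunk = samples[start:start + batchsize]
        batch = {
            "data": [sample["data"] for sample in chunk],
            "label": [sample["label"] for sample in chunk],
        }
        preds = {"pred": predict_fn(batch["data"])}
        scores = {name: fn(preds["pred"], batch["label"])
                  for name, fn in metrics.items()}
        yield preds, scores


def accuracy(pred, target):
    return sum(p == t for p, t in zip(pred, target)) / len(target)


samples = [{"data": i, "label": i % 2} for i in range(5)]
results = list(predict_stream(samples, lambda data: [d % 2 for d in data],
                              batchsize=2, metrics={"acc": accuracy}))
for preds, scores in results:
    print(preds, scores)
```

Because nothing is stored between iterations, the caller decides what to keep; collecting into a list, as done here for display, gives up that advantage.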
-
predict_data_mgr_cache
(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, cache_preds=False, **kwargs)¶ Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
cache_preds (bool) – whether to also cache predictions
kwargs – keyword arguments passed to prepare_batch_fn()
- Yields
dict – a dictionary containing all validation metrics (may be empty)
dict – a dictionary containing all predictions, if cache_preds=True
Warning
Since this function caches all metrics and may additionally cache all predictions (based on the argument cache_preds), this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr(), or consider setting cache_preds to False (if not done already)
-
predict_data_mgr_cache_all
(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)¶ Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
kwargs – keyword arguments passed to prepare_batch_fn()
- Yields
dict – a dictionary containing all predictions
dict – a dictionary containing all validation metrics (may be empty)
Warning
Since this function caches all predictions and metrics, this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr()
-
predict_data_mgr_cache_metrics_only
(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)¶ Defines a routine to predict data obtained from a batchgenerator and caches the metrics
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
kwargs – keyword arguments passed to prepare_batch_fn()
- Yields
dict – a dictionary containing all validation metrics (may be empty)
Notes
This function stores each prediction only temporarily for metric calculation. This typically results in a much lower memory consumption than Predictor.predict_data_mgr_cache_all(), but still caches the metrics. If this is not desired, it is recommended to use Predictor.predict_data_mgr() and iterate over the generator, as this only produces per-batch metrics and predictions and does not cache anything by default
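The memory trade-off described in the note can be sketched as follows: each prediction lives only for one loop iteration, while the per-batch metric values accumulate. The names are hypothetical, and for brevity the sketch returns the cached dict instead of yielding it.

```python
def predict_cache_metrics_only(batches, predict_fn, metrics):
    """Cache per-batch metric values; predictions are discarded after use."""
    cached = {name: [] for name in metrics}
    for batch in batches:
        pred = predict_fn(batch["data"])  # held only for this iteration
        for name, metric_fn in metrics.items():
            cached[name].append(metric_fn(pred, batch["label"]))
    return cached


def accuracy(pred, target):
    return sum(p == t for p, t in zip(pred, target)) / len(target)


batches = [
    {"data": [1, 2], "label": [1, 0]},
    {"data": [3, 4], "label": [1, 1]},
]
cached = predict_cache_metrics_only(
    batches, lambda data: [d % 2 for d in data], {"acc": accuracy})
print(cached)
```

The cached structure grows with the number of batches but not with the size of the predictions, which is the saving the note refers to.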
-
register_callback
(callback: delira.training.callbacks.abstract_callback.AbstractCallback)[source]¶ Register Callback to Trainer
- Parameters
callback (AbstractCallback) – the callback to register
- Raises
AssertionError – callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]
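Read literally, the Raises clause means a callback is accepted if it is an AbstractCallback instance or if it provides both hook methods. A minimal sketch of such a check is shown below; the `AbstractCallback` class here is a local stand-in, and the hook signatures are assumptions made for the example.

```python
class AbstractCallback:
    """Local stand-in for delira's class of the same name."""

    def at_epoch_begin(self, trainer, **kwargs):
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


def register_callback(callbacks, callback):
    """Append callback after checking its type or its hook methods."""
    has_hooks = all(
        callable(getattr(callback, name, None))
        for name in ("at_epoch_begin", "at_epoch_end")
    )
    assert isinstance(callback, AbstractCallback) or has_hooks, \
        "callback must be an AbstractCallback or provide both hook methods"
    callbacks.append(callback)


class DuckCallback:
    # not a subclass, but offers both hooks
    def at_epoch_begin(self, trainer, **kwargs):
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


registered = []
register_callback(registered, AbstractCallback())
register_callback(registered, DuckCallback())
print(len(registered))
```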
-
save_state
(file_name, *args, **kwargs)[source]¶ saves the current state
- Parameters
file_name (str) – filename to save the state to
*args – positional arguments
**kwargs – keyword arguments
-
train
(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest', reduce_mode='mean', verbose=True)[source]¶ Defines a routine to train a specified number of epochs
- Parameters
num_epochs (int) – number of epochs to train
datamgr_train (DataManager) – the datamanager holding the train data
datamgr_valid (DataManager) – the datamanager holding the validation data (default: None)
val_score_key (str) – the key specifying which metric to use for validation (default: None)
val_score_mode (str) – key specifying what kind of validation score is best (default: ‘highest’)
reduce_mode (str) – one of ‘mean’, ‘sum’, ‘first_only’ (default: ‘mean’)
verbose (bool) – whether to show progress bars or not
- Raises
NotImplementedError – If not overwritten by subclass
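The three reduce_mode values can be read as three ways of collapsing a sequence of values into one. The sketch below assumes that what gets reduced is a list of per-batch values; the documentation above does not say which quantities are reduced, and the helper names are hypothetical.

```python
from statistics import mean

# one reduction function per documented reduce_mode value
REDUCE_FNS = {
    "mean": mean,
    "sum": sum,
    "first_only": lambda values: values[0],
}


def reduce_values(values, reduce_mode="mean"):
    """Collapse a sequence of values according to reduce_mode."""
    return REDUCE_FNS[reduce_mode](values)


per_batch = [1.0, 2.0, 3.0]
for mode in ("mean", "sum", "first_only"):
    print(mode, reduce_values(per_batch, mode))
```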
-
PyTorchNetworkTrainer
-
class
PyTorchNetworkTrainer
(network: delira.models.abstract_network.AbstractPyTorchNetwork, save_path: str, key_mapping, losses=None, optimizer_cls=None, optimizer_params=None, train_metrics=None, val_metrics=None, lr_scheduler_cls=None, lr_scheduler_params=None, gpu_ids=None, save_freq=1, optim_fn=<function create_optims_default_pytorch>, logging_type='tensorboardx', logging_kwargs=None, fold=0, callbacks=None, start_epoch=1, metric_keys=None, convert_batch_to_npy_fn=<function convert_torch_tensor_to_npy>, mixed_precision=False, mixed_precision_kwargs=None, criterions=None, val_freq=1, **kwargs)[source]¶ Bases:
delira.training.base_trainer.BaseNetworkTrainer
Train and Validate a Network
See also
AbstractNetwork
-
_BaseNetworkTrainer__KEYS_TO_GUARD
= ['use_gpu', 'input_device', 'output_device', '_callbacks']¶
-
_Predictor__KEYS_TO_GUARD
= []¶
-
static
_Predictor__concatenate_dict_items
(dict_like: dict)¶ Function to recursively concatenate dict-items
- Parameters
dict_like (dict) – the (nested) dict whose items should be concatenated
-
static
_Predictor__convert_dict
(old_dict, new_dict)¶ Function to recursively convert dicts
-
_at_epoch_begin
(metrics_val, val_score_key, epoch, num_epochs, **kwargs) Defines behaviour at the beginning of each epoch: executes the at_epoch_begin method of every callback
-
_at_epoch_end
(metrics_val, val_score_key, epoch, is_best, **kwargs) Defines behaviour at the end of each epoch: executes the at_epoch_end method of every callback and saves the current state if necessary
-
_at_training_begin
(*args, **kwargs)[source]¶ Defines behaviour at beginning of training
- Parameters
*args – positional arguments
**kwargs – keyword arguments
-
_at_training_end
()[source]¶ Defines Behaviour at end of training: Loads best model if available
- Returns
best network
- Return type
AbstractPyTorchNetwork
-
static
_is_better_val_scores
(old_val_score, new_val_score, mode='highest')¶ Check whether the new val score is better than the old one with respect to the optimization goal
-
_reinitialize_logging
(logging_type, logging_kwargs: dict)¶
-
static
_search_for_prev_state
(path, extensions=None)¶ Helper function to search in a given path for previous epoch states (indicated by extensions)
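- Parameters
path – the path to search in
extensions – the file extensions that indicate previous epoch states (default: None)
- Returns
str – the file containing the latest checkpoint (if available)
None – if no latest checkpoint was found
int – the latest epoch (1 if no checkpoint was found)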
-
_setup
(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, gpu_ids, key_mapping, convert_batch_to_npy_fn, mixed_precision, mixed_precision_kwargs)[source]¶ Defines the Trainers Setup
- Parameters
network (AbstractPyTorchNetwork) – the network to train
optim_fn (function) – creates a dictionary containing all necessary optimizers
optimizer_cls (subclass of torch.optim.Optimizer) – optimizer class implementing the optimization algorithm of choice
optimizer_params (dict) – keyword arguments passed to the optimizer during construction
lr_scheduler_cls (Any) – learning rate schedule class: must implement step() method
lr_scheduler_params (dict) – keyword arguments passed to lr scheduler during construction
gpu_ids (list) – list containing ids of GPUs to use; if empty: use cpu instead
convert_batch_to_npy_fn (type) – function converting a batch-tensor to numpy
mixed_precision (bool) – whether to use mixed precision or not (False per default)
mixed_precision_kwargs (dict) – additional keyword arguments for mixed precision
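How optim_fn, optimizer_cls and optimizer_params might fit together can be sketched without any framework. The function signature, the 'default' key and the `DummySGD` class are all assumptions for illustration; the real default is `create_optims_default_pytorch`, whose exact signature is not given above.

```python
class DummySGD:
    """Stand-in for an optimizer class, e.g. a torch.optim.Optimizer subclass."""

    def __init__(self, params, lr=0.01):
        self.params = list(params)
        self.lr = lr


def create_optims(params, optimizer_cls, **optimizer_params):
    """Return a dictionary containing all necessary optimizers."""
    # "default" is an assumed key name for the single-optimizer case
    return {"default": optimizer_cls(params, **optimizer_params)}


optims = create_optims([0.1, 0.2], DummySGD, lr=0.001)
print(type(optims["default"]).__name__, optims["default"].lr)
```

Returning a dictionary rather than a single optimizer leaves room for networks that need several optimizers, each under its own key.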
-
_train_single_epoch
(batchgen: batchgenerators.dataloading.MultiThreadedAugmenter, epoch, verbose=False)[source]¶ Trains the network a single epoch
- Parameters
batchgen (MultiThreadedAugmenter) – Generator yielding the training batches
epoch (int) – current epoch
-
_update_state
(new_state)[source]¶ Update the state from a given new state
- Parameters
new_state (dict) – new state to update internal state from
- Returns
the trainer with a modified state
- Return type
PyTorchNetworkTrainer
-
static
calc_metrics
(batch: delira.utils.config.LookupConfig, metrics=None, metric_keys=None)¶ Compute metrics
- Parameters
batch (LookupConfig) – dictionary containing the whole batch (including predictions)
metrics (dict) – dict with metrics
metric_keys (dict) – dict of tuples which contains hashables for specifying the items to use for calculating the respective metric. If not specified for a metric, the keys “pred” and “label” are used per default
- Returns
dict with metric results
- Return type
dict
-
static
load_state
(file_name, **kwargs)[source]¶ Loads the new state from file via
delira.io.torch.load_checkpoint()
-
predict
(data: dict, **kwargs) Predict a single batch. Returns the predictions corresponding to the given data obtained by the model
-
predict_data_mgr
(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)[source]¶ Defines a routine to predict data obtained from a batchgenerator
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
**kwargs – additional keyword arguments
- Returns
dict – predictions
dict – calculated metrics
-
predict_data_mgr_cache
(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, cache_preds=False, **kwargs)¶ Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
cache_preds (bool) – whether to also cache predictions
kwargs – keyword arguments passed to prepare_batch_fn()
- Yields
dict – a dictionary containing all validation metrics (may be empty)
dict – a dictionary containing all predictions, if cache_preds=True
Warning
Since this function caches all metrics and may additionally cache all predictions (based on the argument cache_preds), this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr(), or consider setting cache_preds to False (if not done already)
-
predict_data_mgr_cache_all
(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)¶ Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
kwargs – keyword arguments passed to prepare_batch_fn()
- Yields
dict – a dictionary containing all predictions
dict – a dictionary containing all validation metrics (may be empty)
Warning
Since this function caches all predictions and metrics, this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr()
-
predict_data_mgr_cache_metrics_only
(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)¶ Defines a routine to predict data obtained from a batchgenerator and caches the metrics
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
kwargs – keyword arguments passed to prepare_batch_fn()
- Yields
dict – a dictionary containing all validation metrics (may be empty)
Notes
This function stores each prediction only temporarily for metric calculation. This typically results in a much lower memory consumption than Predictor.predict_data_mgr_cache_all(), but still caches the metrics. If this is not desired, it is recommended to use Predictor.predict_data_mgr() and iterate over the generator, as this only produces per-batch metrics and predictions and does not cache anything by default
-
register_callback
(callback: delira.training.callbacks.abstract_callback.AbstractCallback)¶ Register Callback to Trainer
- Parameters
callback (AbstractCallback) – the callback to register
- Raises
AssertionError – callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]
-
save_state
(file_name, epoch, **kwargs)[source]¶ saves the current state via
delira.io.torch.save_checkpoint()
-
train
(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest', reduce_mode='mean', verbose=True)¶ Defines a routine to train a specified number of epochs
- Parameters
num_epochs (int) – number of epochs to train
datamgr_train (DataManager) – the datamanager holding the train data
datamgr_valid (DataManager) – the datamanager holding the validation data (default: None)
val_score_key (str) – the key specifying which metric to use for validation (default: None)
val_score_mode (str) – key specifying what kind of validation score is best (default: ‘highest’)
reduce_mode (str) – one of ‘mean’, ‘sum’, ‘first_only’ (default: ‘mean’)
verbose (bool) – whether to show progress bars or not
- Raises
NotImplementedError – If not overwritten by subclass
-
TfNetworkTrainer
-
class
TfNetworkTrainer
(network, save_path, key_mapping, losses: dict, optimizer_cls, optimizer_params=None, train_metrics=None, val_metrics=None, lr_scheduler_cls=None, lr_scheduler_params=None, gpu_ids=None, save_freq=1, optim_fn=<function create_optims_default_tf>, logging_type='tensorboardx', logging_kwargs=None, fold=0, callbacks=None, start_epoch=1, metric_keys=None, convert_batch_to_npy_fn=<function convert_tf_tensor_to_npy>, val_freq=1, **kwargs)[source]¶ Bases:
delira.training.base_trainer.BaseNetworkTrainer
Train and Validate a Network
See also
AbstractNetwork
-
_BaseNetworkTrainer__KEYS_TO_GUARD
= ['use_gpu', 'input_device', 'output_device', '_callbacks']¶
-
_Predictor__KEYS_TO_GUARD
= []¶
-
static
_Predictor__concatenate_dict_items
(dict_like: dict)¶ Function to recursively concatenate dict-items
- Parameters
dict_like (dict) – the (nested) dict whose items should be concatenated
-
static
_Predictor__convert_dict
(old_dict, new_dict)¶ Function to recursively convert dicts
-
_at_epoch_begin
(metrics_val, val_score_key, epoch, num_epochs, **kwargs) Defines behaviour at the beginning of each epoch: executes the at_epoch_begin method of every callback
-
_at_epoch_end
(metrics_val, val_score_key, epoch, is_best, **kwargs) Defines behaviour at the end of each epoch: executes the at_epoch_end method of every callback and saves the current state if necessary
-
_at_training_begin
(*args, **kwargs) Defines the behaviour at the beginning of the training
- Parameters
*args – positional arguments
**kwargs – keyword arguments
- Raises
NotImplementedError – If not overwritten by subclass
-
_at_training_end
()[source]¶ Defines Behaviour at end of training: Loads best model if available
- Returns
best network
- Return type
AbstractTfNetwork
-
static
_is_better_val_scores
(old_val_score, new_val_score, mode='highest')¶ Check whether the new val score is better than the old one with respect to the optimization goal
-
_reinitialize_logging
(logging_type, logging_kwargs: dict)¶
-
static
_search_for_prev_state
(path, extensions=None)¶ Helper function to search in a given path for previous epoch states (indicated by extensions)
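- Parameters
path – the path to search in
extensions – the file extensions that indicate previous epoch states (default: None)
- Returns
str – the file containing the latest checkpoint (if available)
None – if no latest checkpoint was found
int – the latest epoch (1 if no checkpoint was found)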
-
_setup
(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, key_mapping, convert_batch_to_npy_fn, gpu_ids)[source]¶ Defines the Trainers Setup
- Parameters
network (AbstractTfNetwork) – the network to train
optim_fn (function) – creates a dictionary containing all necessary optimizers
optimizer_cls (subclass of tf.train.Optimizer) – optimizer class implementing the optimization algorithm of choice
optimizer_params (dict) – keyword arguments passed to the optimizer during construction
lr_scheduler_cls (Any) – learning rate schedule class: must implement step() method
lr_scheduler_params (dict) – keyword arguments passed to lr scheduler during construction
convert_batch_to_npy_fn (type, optional) – function converting a batch-tensor to numpy, per default this is the identity function
gpu_ids (list) – list containing ids of GPUs to use; if empty: use cpu instead
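The only stated requirement for lr_scheduler_cls is a step() method, with lr_scheduler_params passed as keyword arguments during construction. A framework-agnostic sketch of a class meeting that requirement is given below; the classes, the argument-free step(), and passing the optimizer as the first constructor argument are all assumptions for illustration.

```python
class DummyOptimizer:
    """Stand-in object that only carries a learning rate."""

    def __init__(self, lr):
        self.lr = lr


class StepDecay:
    """Multiply the learning rate by gamma every step_size calls to step()."""

    def __init__(self, optimizer, step_size=10, gamma=0.1):
        self.optimizer = optimizer
        self.step_size = step_size
        self.gamma = gamma
        self.epoch = 0

    def step(self):
        self.epoch += 1
        if self.epoch % self.step_size == 0:
            self.optimizer.lr *= self.gamma


def build_scheduler(lr_scheduler_cls, optimizer, lr_scheduler_params):
    # keyword arguments are forwarded during construction
    return lr_scheduler_cls(optimizer, **lr_scheduler_params)


optimizer = DummyOptimizer(lr=1.0)
scheduler = build_scheduler(StepDecay, optimizer, {"step_size": 2, "gamma": 0.5})
for _ in range(4):
    scheduler.step()
print(optimizer.lr)
```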
-
_train_single_epoch
(batchgen: batchgenerators.dataloading.MultiThreadedAugmenter, epoch, verbose=False)[source]¶ Trains the network a single epoch
- Parameters
batchgen (MultiThreadedAugmenter) – Generator yielding the training batches
epoch (int) – current epoch
-
_update_state
(new_state)¶ Update the state from a given new state
- Parameters
new_state (dict) – new state to update internal state from
- Returns
the trainer with a modified state
- Return type
BaseNetworkTrainer
-
static
calc_metrics
(batch: delira.utils.config.LookupConfig, metrics=None, metric_keys=None)¶ Compute metrics
- Parameters
batch (LookupConfig) – dictionary containing the whole batch (including predictions)
metrics (dict) – dict with metrics
metric_keys (dict) – dict of tuples which contains hashables for specifying the items to use for calculating the respective metric. If not specified for a metric, the keys “pred” and “label” are used per default
- Returns
dict with metric results
- Return type
dict
-
load_state
(file_name, *args, **kwargs)[source]¶ Loads the new state from file via
delira.io.tf.load_checkpoint()
- Parameters
file_name (str) – the file to load the state from
-
predict
(data: dict, **kwargs) Predict a single batch. Returns the predictions corresponding to the given data obtained by the model
-
predict_data_mgr
(datamgr, batch_size=None, metrics=None, metric_keys=None, verbose=False, **kwargs)[source]¶ Defines a routine to predict data obtained from a batchgenerator
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batch_size (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
**kwargs – additional keyword arguments
-
predict_data_mgr_cache
(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, cache_preds=False, **kwargs)¶ Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
cache_preds (bool) – whether to also cache predictions
kwargs – keyword arguments passed to prepare_batch_fn()
- Yields
dict – a dictionary containing all validation metrics (may be empty)
dict – a dictionary containing all predictions, if cache_preds=True
Warning
Since this function caches all metrics and may additionally cache all predictions (based on the argument cache_preds), this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr(), or consider setting cache_preds to False (if not done already)
-
predict_data_mgr_cache_all
(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)¶ Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
kwargs – keyword arguments passed to prepare_batch_fn()
- Yields
dict – a dictionary containing all predictions
dict – a dictionary containing all validation metrics (may be empty)
Warning
Since this function caches all predictions and metrics, this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr()
-
predict_data_mgr_cache_metrics_only
(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)¶ Defines a routine to predict data obtained from a batchgenerator and caches the metrics
- Parameters
datamgr (BaseDataManager) – Manager producing a generator holding the batches
batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
metrics (dict) – the metrics to calculate
metric_keys (dict) – the batch_dict items to use for metric calculation
verbose (bool) – whether to show a progress-bar or not, default: False
kwargs – keyword arguments passed to prepare_batch_fn()
- Yields
dict – a dictionary containing all validation metrics (may be empty)
Notes
This function stores each prediction only temporarily for metric calculation. This typically results in a much lower memory consumption than Predictor.predict_data_mgr_cache_all(), but still caches the metrics. If this is not desired, it is recommended to use Predictor.predict_data_mgr() and iterate over the generator, as this only produces per-batch metrics and predictions and does not cache anything by default
-
register_callback
(callback: delira.training.callbacks.abstract_callback.AbstractCallback)¶ Register Callback to Trainer
- Parameters
callback (AbstractCallback) – the callback to register
- Raises
AssertionError – callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]
-
save_state
(file_name, *args, **kwargs)[source]¶ saves the current state via
delira.io.tf.save_checkpoint()
- Parameters
file_name (str) – filename to save the state to
-
train
(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest', reduce_mode='mean', verbose=True)¶ Defines a routine to train a specified number of epochs
- Parameters
num_epochs (int) – number of epochs to train
datamgr_train (DataManager) – the datamanager holding the train data
datamgr_valid (DataManager) – the datamanager holding the validation data (default: None)
val_score_key (str) – the key specifying which metric to use for validation (default: None)
val_score_mode (str) – key specifying what kind of validation score is best (default: ‘highest’)
reduce_mode (str) – one of ‘mean’, ‘sum’, ‘first_only’ (default: ‘mean’)
verbose (bool) – whether to show progress bars or not
- Raises
NotImplementedError – If not overwritten by subclass
-