NetworkTrainer

The network trainer implements the actual training routine and can be subclassed for special routines.

BaseNetworkTrainer

class BaseNetworkTrainer(network: delira.models.abstract_network.AbstractNetwork, save_path: str, losses: dict, optimizer_cls: type, optimizer_params: dict, train_metrics: dict, val_metrics: dict, lr_scheduler_cls: type, lr_scheduler_params: dict, gpu_ids: List[int], save_freq: int, optim_fn, key_mapping: dict, logging_type: str, logging_kwargs: dict, fold: int, callbacks: List[delira.training.callbacks.abstract_callback.AbstractCallback], start_epoch=1, metric_keys=None, convert_batch_to_npy_fn=<function BaseNetworkTrainer.<lambda>>, val_freq=1, **kwargs)[source]

Bases: delira.training.predictor.Predictor

Defines a Base API and basic functions for Network Trainers

_BaseNetworkTrainer__KEYS_TO_GUARD = ['use_gpu', 'input_device', 'output_device', '_callbacks']
_Predictor__KEYS_TO_GUARD = []
static _Predictor__concatenate_dict_items(dict_like: dict)

Function to recursively concatenate dict-items

Parameters

dict_like (dict) – the (nested) dict whose items should be concatenated
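A rough illustration of what recursive concatenation over a nested dict can look like is given below. This is only a sketch under assumptions: concatenate_dict_items is a hypothetical stand-in for the private method, and plain Python lists take the place of whatever array type the real implementation handles.

```python
def concatenate_dict_items(dict_like):
    """Recursively concatenate the list of chunks stored at each leaf."""
    result = {}
    for key, value in dict_like.items():
        if isinstance(value, dict):
            # nested dict: recurse into it
            result[key] = concatenate_dict_items(value)
        else:
            # leaf: join the per-batch chunks, e.g. [[1, 2], [3]] -> [1, 2, 3]
            result[key] = [item for chunk in value for item in chunk]
    return result


nested = {"pred": [[1, 2], [3]], "aux": {"score": [[0.1], [0.2, 0.3]]}}
flat = concatenate_dict_items(nested)
```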

static _Predictor__convert_dict(old_dict, new_dict)

Function to recursively convert dicts

Parameters
  • old_dict (dict) – the old nested dict

  • new_dict (dict) – the new nested dict

Returns

the updated new nested dict

Return type

dict

_at_epoch_begin(metrics_val, val_score_key, epoch, num_epochs, **kwargs)[source]

Defines behaviour at the beginning of each epoch: executes the at_epoch_begin method of every callback

Parameters
  • metrics_val (dict) – validation metrics

  • val_score_key (str) – validation score key

  • epoch (int) – current epoch

  • num_epochs (int) – total number of epochs

  • **kwargs – keyword arguments

_at_epoch_end(metrics_val, val_score_key, epoch, is_best, **kwargs)[source]

Defines behaviour at the end of each epoch: executes the at_epoch_end method of every callback and saves the current state if necessary

Parameters
  • metrics_val (dict) – validation metrics

  • val_score_key (str) – validation score key

  • epoch (int) – current epoch

  • is_best (bool) – whether current model is best one so far

  • **kwargs – keyword arguments

_at_training_begin(*args, **kwargs)[source]

Defines the behaviour at the beginning of the training

Parameters
  • *args – positional arguments

  • **kwargs – keyword arguments

Raises

NotImplementedError – If not overwritten by subclass

_at_training_end(*args, **kwargs)[source]

Defines the behaviour at the end of the training

Parameters
  • *args – positional arguments

  • **kwargs – keyword arguments

Raises

NotImplementedError – If not overwritten by subclass

static _is_better_val_scores(old_val_score, new_val_score, mode='highest')[source]

Check whether the new val score is better than the old one with respect to the optimization goal

Parameters
  • old_val_score – old validation score

  • new_val_score – new validation score

  • mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]

Returns

True if new score is better, False otherwise

Return type

bool
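The comparison described above can be sketched in a few lines. This is a minimal sketch, not the library's code: the function name is hypothetical, and both the strict comparison and the ValueError for an unknown mode are assumptions.

```python
def is_better_val_score(old_val_score, new_val_score, mode="highest"):
    """Return True if the new score beats the old one for the given mode."""
    if mode not in ("highest", "lowest"):
        # assumption: an unknown mode is treated as an error
        raise ValueError("mode must be 'highest' or 'lowest'")
    if mode == "highest":
        return new_val_score > old_val_score
    return new_val_score < old_val_score


accuracy_improved = is_better_val_score(0.71, 0.75, mode="highest")
loss_improved = is_better_val_score(0.30, 0.35, mode="lowest")
```

With mode='highest' a metric such as accuracy is tracked, while mode='lowest' suits a loss-like score.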

_reinitialize_logging(logging_type, logging_kwargs: dict)[source]
static _search_for_prev_state(path, extensions=None)[source]

Helper function to search in a given path for previous epoch states (indicated by extensions)

Parameters
  • path (str) – the path to search in

  • extensions (list) – list of strings containing valid file extensions for checkpoint files

Returns

  • str – the file containing the latest checkpoint (if available)

  • None – if no latest checkpoint was found

  • int – the latest epoch (1 if no checkpoint was found)
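One way such a search could work is sketched below. The file naming scheme checkpoint_epoch_<n> and the default extension are assumptions made for this illustration only; the naming used by the real helper is not documented here.

```python
import os
import re
import tempfile


def search_for_prev_state(path, extensions=None):
    """Return (latest checkpoint file or None, latest epoch or 1)."""
    extensions = extensions or [".pkl"]  # assumed default
    pattern = re.compile(r"checkpoint_epoch_(\d+)$")  # assumed naming scheme
    latest_file, latest_epoch = None, 1
    for name in os.listdir(path):
        stem, ext = os.path.splitext(name)
        match = pattern.match(stem)
        if ext in extensions and match:
            epoch = int(match.group(1))
            if latest_file is None or epoch > latest_epoch:
                latest_file, latest_epoch = os.path.join(path, name), epoch
    return latest_file, latest_epoch


with tempfile.TemporaryDirectory() as tmp:
    # nothing saved yet: falls back to (None, 1)
    empty_result = search_for_prev_state(tmp, [".pt"])
    for epoch in (1, 2, 10):
        open(os.path.join(tmp, f"checkpoint_epoch_{epoch}.pt"), "w").close()
    latest_file, latest_epoch = search_for_prev_state(tmp, [".pt"])
    latest_name = os.path.basename(latest_file)
```

Epochs are compared as integers, so epoch 10 is correctly ranked after epoch 2.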

_setup(network, lr_scheduler_cls, lr_scheduler_params, gpu_ids, key_mapping, convert_batch_to_npy_fn, prepare_batch_fn)[source]
Parameters
  • network (AbstractNetwork) – the network to predict from

  • key_mapping (dict) – a dictionary containing the mapping from the data_dict to the actual model’s inputs. E.g. if a model accepts one input named ‘x’ and the data_dict contains one entry named ‘data’ this argument would have to be {'x': 'data'}

  • convert_batch_to_npy_fn (type) – a callable function to convert tensors in positional and keyword arguments to numpy

  • prepare_batch_fn (type) – function converting a batch-tensor to the framework specific tensor-type and pushing it to correct device, default: identity function
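The key_mapping example above ({'x': 'data'}) can be made concrete with a small sketch. The helper name map_batch_to_inputs is hypothetical and only illustrates the direction of the mapping: keys are model argument names, values are data_dict entries.

```python
def map_batch_to_inputs(data_dict, key_mapping):
    """Build the model's keyword inputs from a batch dictionary."""
    # model argument name -> value of the batch entry it is mapped to
    return {model_key: data_dict[data_key]
            for model_key, data_key in key_mapping.items()}


data_dict = {"data": [[0.0, 1.0]], "label": [1]}
inputs = map_batch_to_inputs(data_dict, {"x": "data"})
```

Entries that are not mentioned in key_mapping (here 'label') are not passed to the model in this sketch.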

_train_single_epoch(batchgen: delira.data_loading.data_manager.Augmenter, epoch, verbose=False)[source]

Trains the network a single epoch

Parameters
  • batchgen (Augmenter) – Generator yielding the training batches

  • epoch (int) – current epoch

_update_state(new_state)[source]

Update the state from a given new state

Parameters

new_state (dict) – new state to update internal state from

Returns

the trainer with a modified state

Return type

BaseNetworkTrainer

static calc_metrics(batch: delira.utils.config.LookupConfig, metrics=None, metric_keys=None)

Compute metrics

Parameters
  • batch (LookupConfig) – dictionary containing the whole batch (including predictions)

  • metrics (dict) – dict with metrics

  • metric_keys (dict) – dict of tuples which contains hashables for specifying the items to use for calculating the respective metric. If not specified for a metric, the keys “pred” and “label” are used per default

Returns

dict with metric results

Return type

dict
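The role of metric_keys and its default can be illustrated as follows. This is a simplified sketch: it uses a plain dict instead of a LookupConfig, and the metric functions are hypothetical helpers defined only for the example.

```python
def calc_metrics(batch, metrics=None, metric_keys=None):
    """Evaluate each metric on the batch items selected by metric_keys."""
    metrics = metrics or {}
    metric_keys = metric_keys or {}
    results = {}
    for name, fn in metrics.items():
        # fall back to ("pred", "label") if no keys are given for this metric
        keys = metric_keys.get(name, ("pred", "label"))
        results[name] = fn(*[batch[key] for key in keys])
    return results


def accuracy(pred, label):
    return sum(p == l for p, l in zip(pred, label)) / len(label)


def mean_abs_error(pred, label):
    return sum(abs(p - l) for p, l in zip(pred, label)) / len(label)


batch = {
    "pred": [1, 0, 1, 1], "label": [1, 0, 0, 1],
    "aux_pred": [0.5, 1.5], "aux_label": [1.0, 1.0],
}
results = calc_metrics(
    batch,
    metrics={"acc": accuracy, "mae": mean_abs_error},
    metric_keys={"mae": ("aux_pred", "aux_label")},
)
```

Here 'acc' uses the default keys, while 'mae' is pointed at the auxiliary entries of the batch.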

property fold

Get current fold

Returns

current fold

Return type

int

static load_state(file_name, *args, **kwargs)[source]

Loads the new state from file

Parameters
  • file_name (str) – the file to load the state from

  • *args – positional arguments

  • **kwargs – keyword arguments

Returns

new state

Return type

dict

predict(data: dict, **kwargs)

Predict single batch. Returns the predictions corresponding to the given data obtained by the model

Parameters
  • data (dict) – batch dictionary

  • **kwargs – keyword arguments (directly passed to prepare_batch)

Returns

predicted data

Return type

dict

predict_data_mgr(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)

Defines a routine to predict data obtained from a batchgenerator without explicitly caching anything

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields
  • dict – a dictionary containing all predictions of the current batch

  • dict – a dictionary containing all metrics of the current batch
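The generator behaviour described above (stack single samples into an artificial batch, then yield predictions and metrics per batch) might look roughly like the sketch below. The function predict_batches, the sample layout and the 'pred' key are assumptions for illustration; the real method works on a data manager rather than a list.

```python
def accuracy(pred, label):
    return sum(p == l for p, l in zip(pred, label)) / len(label)


def predict_batches(samples, predict_fn, batchsize, metrics=None):
    """Yield (predictions, metrics) per artificial batch without caching."""
    metrics = metrics or {}
    for start in range(0, len(samples), batchsize):
        # stack single samples to match the artificial batchsize
        chunk = samples[start:start + batchsize]
        data = [sample["data"] for sample in chunk]
        label = [sample["label"] for sample in chunk]
        preds = {"pred": predict_fn(data)}
        metric_vals = {name: fn(preds["pred"], label)
                       for name, fn in metrics.items()}
        yield preds, metric_vals


samples = [{"data": x, "label": x % 2} for x in range(5)]
outputs = list(predict_batches(
    samples, lambda xs: [x % 2 for x in xs], batchsize=2,
    metrics={"acc": accuracy}))
```

Because nothing is accumulated inside the generator, memory use is bounded by one batch as long as the caller does not collect the results (the list() call above is only for demonstration).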

predict_data_mgr_cache(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, cache_preds=False, **kwargs)

Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • cache_preds (bool) – whether to also cache predictions

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields
  • dict – a dictionary containing all validation metrics (may be empty)

  • dict – a dictionary containing all predictions (only if cache_preds=True)

Warning

Since this function caches all metrics and may additionally cache all predictions (based on the argument cache_preds), this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr() or consider setting cache_preds to False (if not done already)

predict_data_mgr_cache_all(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)

Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields
  • dict – a dictionary containing all predictions

  • dict – a dictionary containing all validation metrics (may be empty)

Warning

Since this function caches all predictions and metrics, this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr()

predict_data_mgr_cache_metrics_only(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)

Defines a routine to predict data obtained from a batchgenerator and caches the metrics

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields

dict – a dictionary containing all validation metrics (may be empty)

Notes

This function stores each prediction only temporarily for metric calculation. This typically results in much lower memory consumption than Predictor.predict_data_mgr_cache_all(), but the metrics are still cached. If this is not desired, use Predictor.predict_data_mgr() and iterate over the generator, as it only produces per-batch metrics and predictions and does not cache anything by default.
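The memory trade-off in the note above can be sketched as follows: predictions live only for one loop iteration, while metric values accumulate. The function name and the batch layout are hypothetical, and the sketch returns the cached dict instead of yielding it to keep the example short.

```python
def accuracy(pred, label):
    return sum(p == l for p, l in zip(pred, label)) / len(label)


def predict_cache_metrics_only(batches, predict_fn, metrics):
    """Cache per-batch metric values but never the predictions."""
    cached = {name: [] for name in metrics}
    for batch in batches:
        pred = predict_fn(batch["data"])  # discarded after this iteration
        for name, fn in metrics.items():
            cached[name].append(fn(pred, batch["label"]))
    return cached


batches = [
    {"data": [1, 2], "label": [1, 0]},
    {"data": [3], "label": [0]},
]
cached_metrics = predict_cache_metrics_only(
    batches, lambda xs: [x % 2 for x in xs], {"acc": accuracy})
```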

register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)[source]

Register Callback to Trainer

Parameters

callback (AbstractCallback) – the callback to register

Raises

AssertionError – if callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]
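The acceptance rule stated in the Raises entry can be sketched like this. The classes below are hypothetical stand-ins (not delira's own), and the hook signatures are assumptions; the point is only that a callback passes if it is an AbstractCallback instance or provides both hook methods.

```python
class AbstractCallback:
    """Stand-in base class with the two epoch hooks."""

    def at_epoch_begin(self, trainer, **kwargs):
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


class DuckCallback:
    """Not a subclass, but provides both hook methods."""

    def at_epoch_begin(self, trainer, **kwargs):
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


class TinyTrainer:
    def __init__(self):
        self._callbacks = []

    def register_callback(self, callback):
        has_hooks = all(callable(getattr(callback, name, None))
                        for name in ("at_epoch_begin", "at_epoch_end"))
        assert isinstance(callback, AbstractCallback) or has_hooks, \
            "callback must provide at_epoch_begin and at_epoch_end"
        self._callbacks.append(callback)


trainer = TinyTrainer()
trainer.register_callback(AbstractCallback())
trainer.register_callback(DuckCallback())
try:
    trainer.register_callback(object())  # has neither hook
    rejected = False
except AssertionError:
    rejected = True
```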

save_state(file_name, *args, **kwargs)[source]

Saves the current state

Parameters
  • file_name (str) – filename to save the state to

  • *args – positional arguments

  • **kwargs – keyword arguments

train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest', reduce_mode='mean', verbose=True)[source]

Defines a routine to train a specified number of epochs

Parameters
  • num_epochs (int) – number of epochs to train

  • datamgr_train (DataManager) – the datamanager holding the train data

  • datamgr_valid (DataManager) – the datamanager holding the validation data (default: None)

  • val_score_key (str) – the key specifying which metric to use for validation (default: None)

  • val_score_mode (str) – key specifying what kind of validation score is best

  • reduce_mode (str) – one of ‘mean’, ‘sum’, ‘first_only’

  • verbose (bool) – whether to show progress bars or not

Raises

NotImplementedError – If not overwritten by subclass
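How val_score_key and val_score_mode interact during training can be sketched with a stripped-down loop. This is an illustration under assumptions, not the trainer's implementation: run_training, train_epoch_fn and validate_fn are hypothetical names, and callbacks, logging, checkpointing and reduce_mode are left out entirely.

```python
def run_training(num_epochs, train_epoch_fn, validate_fn,
                 val_score_key=None, val_score_mode="highest"):
    """Train for num_epochs and track the best validation score."""
    best_score, best_epoch = None, None
    for epoch in range(1, num_epochs + 1):
        train_epoch_fn(epoch)
        metrics_val = validate_fn(epoch)
        if val_score_key is None or val_score_key not in metrics_val:
            continue  # nothing to compare against
        new_score = metrics_val[val_score_key]
        if best_score is None:
            is_best = True
        elif val_score_mode == "highest":
            is_best = new_score > best_score
        else:
            is_best = new_score < best_score
        if is_best:
            best_score, best_epoch = new_score, epoch
    return best_epoch, best_score


scores = {1: 0.6, 2: 0.8, 3: 0.7}
best_high = run_training(3, lambda e: None, lambda e: {"acc": scores[e]},
                         val_score_key="acc", val_score_mode="highest")
best_low = run_training(3, lambda e: None, lambda e: {"acc": scores[e]},
                        val_score_key="acc", val_score_mode="lowest")
```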

update_state(file_name, *args, **kwargs)[source]

Update internal state from a loaded state

Parameters
  • file_name (str) – file containing the new state to load

  • *args – positional arguments

  • **kwargs – keyword arguments

Returns

the trainer with a modified state

Return type

BaseNetworkTrainer
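The relationship between load_state, _update_state and update_state can be pictured with the toy class below. It is a sketch under assumptions: StatefulTrainer is hypothetical, pickle stands in for the real serialization, and the assumed composition (update_state loads a file and hands the result to _update_state) is inferred from the method descriptions.

```python
import os
import pickle
import tempfile


class StatefulTrainer:
    def __init__(self):
        self.start_epoch = 1
        self.weights = None

    @staticmethod
    def load_state(file_name):
        # read the state dict back from disk
        with open(file_name, "rb") as f:
            return pickle.load(f)

    def _update_state(self, new_state):
        # copy every entry of the state dict onto the trainer
        for key, value in new_state.items():
            setattr(self, key, value)
        return self

    def update_state(self, file_name):
        return self._update_state(self.load_state(file_name))


with tempfile.TemporaryDirectory() as tmp:
    file_name = os.path.join(tmp, "state.pkl")
    with open(file_name, "wb") as f:
        pickle.dump({"start_epoch": 5, "weights": [0.1, 0.2]}, f)
    trainer = StatefulTrainer()
    returned = trainer.update_state(file_name)
```

Returning the trainer itself allows the call to be chained directly after construction.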

PyTorchNetworkTrainer

class PyTorchNetworkTrainer(network: delira.models.abstract_network.AbstractPyTorchNetwork, save_path: str, key_mapping, losses=None, optimizer_cls=None, optimizer_params=None, train_metrics=None, val_metrics=None, lr_scheduler_cls=None, lr_scheduler_params=None, gpu_ids=None, save_freq=1, optim_fn=<function create_optims_default_pytorch>, logging_type='tensorboardx', logging_kwargs=None, fold=0, callbacks=None, start_epoch=1, metric_keys=None, convert_batch_to_npy_fn=<function convert_torch_tensor_to_npy>, mixed_precision=False, mixed_precision_kwargs=None, criterions=None, val_freq=1, **kwargs)[source]

Bases: delira.training.base_trainer.BaseNetworkTrainer

Train and Validate a Network

See also

AbstractNetwork

_BaseNetworkTrainer__KEYS_TO_GUARD = ['use_gpu', 'input_device', 'output_device', '_callbacks']
_Predictor__KEYS_TO_GUARD = []
static _Predictor__concatenate_dict_items(dict_like: dict)

Function to recursively concatenate dict-items

Parameters

dict_like (dict) – the (nested) dict whose items should be concatenated

static _Predictor__convert_dict(old_dict, new_dict)

Function to recursively convert dicts

Parameters
  • old_dict (dict) – the old nested dict

  • new_dict (dict) – the new nested dict

Returns

the updated new nested dict

Return type

dict

_at_epoch_begin(metrics_val, val_score_key, epoch, num_epochs, **kwargs)

Defines behaviour at the beginning of each epoch: executes the at_epoch_begin method of every callback

Parameters
  • metrics_val (dict) – validation metrics

  • val_score_key (str) – validation score key

  • epoch (int) – current epoch

  • num_epochs (int) – total number of epochs

  • **kwargs – keyword arguments

_at_epoch_end(metrics_val, val_score_key, epoch, is_best, **kwargs)[source]

Defines behaviour at the end of each epoch: executes the at_epoch_end method of every callback and saves the current state if necessary

Parameters
  • metrics_val (dict) – validation metrics

  • val_score_key (str) – validation score key

  • epoch (int) – current epoch

  • num_epochs (int) – total number of epochs

  • is_best (bool) – whether current model is best one so far

  • **kwargs – keyword arguments

_at_training_begin(*args, **kwargs)[source]

Defines behaviour at beginning of training

Parameters
  • *args – positional arguments

  • **kwargs – keyword arguments

_at_training_end()[source]

Defines Behaviour at end of training: Loads best model if available

Returns

best network

Return type

AbstractPyTorchNetwork

static _is_better_val_scores(old_val_score, new_val_score, mode='highest')

Check whether the new val score is better than the old one with respect to the optimization goal

Parameters
  • old_val_score – old validation score

  • new_val_score – new validation score

  • mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]

Returns

True if new score is better, False otherwise

Return type

bool

_reinitialize_logging(logging_type, logging_kwargs: dict)
static _search_for_prev_state(path, extensions=None)

Helper function to search in a given path for previous epoch states (indicated by extensions)

Parameters
  • path (str) – the path to search in

  • extensions (list) – list of strings containing valid file extensions for checkpoint files

Returns

  • str – the file containing the latest checkpoint (if available)

  • None – if no latest checkpoint was found

  • int – the latest epoch (1 if no checkpoint was found)

_setup(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, gpu_ids, key_mapping, convert_batch_to_npy_fn, mixed_precision, mixed_precision_kwargs)[source]

Defines the trainer’s setup

Parameters
  • network (AbstractPyTorchNetwork) – the network to train

  • optim_fn (function) – creates a dictionary containing all necessary optimizers

  • optimizer_cls (subclass of torch.optim.Optimizer) – optimizer class implementing the optimization algorithm of choice

  • optimizer_params (dict) – keyword arguments passed to the optimizer during construction

  • lr_scheduler_cls (Any) – learning rate schedule class: must implement step() method

  • lr_scheduler_params (dict) – keyword arguments passed to lr scheduler during construction

  • gpu_ids (list) – list containing ids of GPUs to use; if empty: use cpu instead

  • convert_batch_to_npy_fn (type) – function converting a batch-tensor to numpy

  • mixed_precision (bool) – whether to use mixed precision or not (False per default)

  • mixed_precision_kwargs (dict) – additional keyword arguments for mixed precision
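The interplay of optim_fn, optimizer_cls and optimizer_params listed above can be sketched without any framework. Everything here is an assumption for illustration: DummyOptimizer stands in for a torch.optim.Optimizer subclass, create_optims_default is a hypothetical optim_fn, and the 'default' dictionary key is not confirmed by this page.

```python
class DummyOptimizer:
    """Stand-in for an optimizer class; only stores its arguments."""

    def __init__(self, params, lr=0.01):
        self.params = list(params)
        self.lr = lr


def create_optims_default(params, optimizer_cls, **optimizer_params):
    """Return a dictionary holding all necessary optimizers."""
    # a single optimizer over all parameters, stored under an assumed key
    return {"default": optimizer_cls(params, **optimizer_params)}


optimizers = create_optims_default([1.0, 2.0], DummyOptimizer, lr=0.001)
```

Returning a dictionary leaves room for networks that need several optimizers, each stored under its own key.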

_train_single_epoch(batchgen: batchgenerators.dataloading.MultiThreadedAugmenter, epoch, verbose=False)[source]

Trains the network a single epoch

Parameters
  • batchgen (MultiThreadedAugmenter) – Generator yielding the training batches

  • epoch (int) – current epoch

_update_state(new_state)[source]

Update the state from a given new state

Parameters

new_state (dict) – new state to update internal state from

Returns

the trainer with a modified state

Return type

PyTorchNetworkTrainer

static calc_metrics(batch: delira.utils.config.LookupConfig, metrics=None, metric_keys=None)

Compute metrics

Parameters
  • batch (LookupConfig) – dictionary containing the whole batch (including predictions)

  • metrics (dict) – dict with metrics

  • metric_keys (dict) – dict of tuples which contains hashables for specifying the items to use for calculating the respective metric. If not specified for a metric, the keys “pred” and “label” are used per default

Returns

dict with metric results

Return type

dict

property fold

Get current fold

Returns

current fold

Return type

int

static load_state(file_name, **kwargs)[source]

Loads the new state from file via delira.io.torch.load_checkpoint()

Parameters
  • file_name (str) – the file to load the state from

  • **kwargs – keyword arguments

Returns

new state

Return type

dict

predict(data: dict, **kwargs)

Predict single batch. Returns the predictions corresponding to the given data obtained by the model

Parameters
  • data (dict) – batch dictionary

  • **kwargs – keyword arguments (directly passed to prepare_batch)

Returns

predicted data

Return type

dict

predict_data_mgr(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)[source]

Defines a routine to predict data obtained from a batchgenerator

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • **kwargs – additional keyword arguments

Returns

  • dict – predictions

  • dict – calculated metrics

predict_data_mgr_cache(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, cache_preds=False, **kwargs)

Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • cache_preds (bool) – whether to also cache predictions

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields
  • dict – a dictionary containing all validation metrics (may be empty)

  • dict – a dictionary containing all predictions (only if cache_preds=True)

Warning

Since this function caches all metrics and may additionally cache all predictions (based on the argument cache_preds), this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr() or consider setting cache_preds to False (if not done already)

predict_data_mgr_cache_all(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)

Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields
  • dict – a dictionary containing all predictions

  • dict – a dictionary containing all validation metrics (may be empty)

Warning

Since this function caches all predictions and metrics, this may result in huge memory consumption. If you are running out of memory, please have a look at Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr()

predict_data_mgr_cache_metrics_only(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)

Defines a routine to predict data obtained from a batchgenerator and caches the metrics

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields

dict – a dictionary containing all validation metrics (may be empty)

Notes

This function stores each prediction only temporarily for metric calculation. This typically results in much lower memory consumption than Predictor.predict_data_mgr_cache_all(), but the metrics are still cached. If this is not desired, use Predictor.predict_data_mgr() and iterate over the generator, as it only produces per-batch metrics and predictions and does not cache anything by default.

register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)

Register Callback to Trainer

Parameters

callback (AbstractCallback) – the callback to register

Raises

AssertionError – if callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]

save_state(file_name, epoch, **kwargs)[source]

Saves the current state via delira.io.torch.save_checkpoint()

Parameters
  • file_name (str) – filename to save the state to

  • epoch (int) – current epoch (will be saved for mapping back)

  • **kwargs – keyword arguments
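Why the epoch is stored alongside the state ("for mapping back") can be shown with a toy round trip. This is not how delira.io.torch.save_checkpoint() works internally; JSON and the two helper functions below are assumptions chosen only to keep the sketch dependency-free.

```python
import json
import os
import tempfile


def save_state(file_name, epoch, **kwargs):
    """Write the state together with the epoch it belongs to."""
    state = {"epoch": epoch, **kwargs}
    with open(file_name, "w") as f:
        json.dump(state, f)


def load_state(file_name):
    with open(file_name) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    file_name = os.path.join(tmp, "checkpoint_epoch_3.json")
    save_state(file_name, epoch=3, weights=[0.5, -0.25])
    restored = load_state(file_name)
```

With the epoch stored in the checkpoint, a resumed run can continue from the following epoch without relying on the file name.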

train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest', reduce_mode='mean', verbose=True)

Defines a routine to train a specified number of epochs

Parameters
  • num_epochs (int) – number of epochs to train

  • datamgr_train (DataManager) – the datamanager holding the train data

  • datamgr_valid (DataManager) – the datamanager holding the validation data (default: None)

  • val_score_key (str) – the key specifying which metric to use for validation (default: None)

  • val_score_mode (str) – key specifying what kind of validation score is best

  • reduce_mode (str) – one of ‘mean’, ‘sum’, ‘first_only’

  • verbose (bool) – whether to show progress bars or not

Raises

NotImplementedError – If not overwritten by subclass

update_state(file_name, *args, **kwargs)[source]

Update internal state from a loaded state

Parameters
  • file_name (str) – file containing the new state to load

  • *args – positional arguments

  • **kwargs – keyword arguments

Returns

the trainer with a modified state

Return type

BaseNetworkTrainer

TfNetworkTrainer

class TfNetworkTrainer(network, save_path, key_mapping, losses: dict, optimizer_cls, optimizer_params=None, train_metrics=None, val_metrics=None, lr_scheduler_cls=None, lr_scheduler_params=None, gpu_ids=None, save_freq=1, optim_fn=<function create_optims_default_tf>, logging_type='tensorboardx', logging_kwargs=None, fold=0, callbacks=None, start_epoch=1, metric_keys=None, convert_batch_to_npy_fn=<function convert_tf_tensor_to_npy>, val_freq=1, **kwargs)[source]

Bases: delira.training.base_trainer.BaseNetworkTrainer

Train and Validate a Network

See also

AbstractNetwork

_BaseNetworkTrainer__KEYS_TO_GUARD = ['use_gpu', 'input_device', 'output_device', '_callbacks']
_Predictor__KEYS_TO_GUARD = []
static _Predictor__concatenate_dict_items(dict_like: dict)

Function to recursively concatenate dict-items

Parameters

dict_like (dict) – the (nested) dict whose items should be concatenated

static _Predictor__convert_dict(old_dict, new_dict)

Function to recursively convert dicts

Parameters
  • old_dict (dict) – the old nested dict

  • new_dict (dict) – the new nested dict

Returns

the updated new nested dict

Return type

dict

_at_epoch_begin(metrics_val, val_score_key, epoch, num_epochs, **kwargs)

Defines behaviour at the beginning of each epoch: executes the at_epoch_begin method of every callback

Parameters
  • metrics_val (dict) – validation metrics

  • val_score_key (str) – validation score key

  • epoch (int) – current epoch

  • num_epochs (int) – total number of epochs

  • **kwargs – keyword arguments

_at_epoch_end(metrics_val, val_score_key, epoch, is_best, **kwargs)

Defines behaviour at the end of each epoch: executes the at_epoch_end method of every callback and saves the current state if necessary

Parameters
  • metrics_val (dict) – validation metrics

  • val_score_key (str) – validation score key

  • epoch (int) – current epoch

  • is_best (bool) – whether current model is best one so far

  • **kwargs – keyword arguments

_at_training_begin(*args, **kwargs)

Defines the behaviour at the beginning of the training

Parameters
  • *args – positional arguments

  • **kwargs – keyword arguments

Raises

NotImplementedError – If not overwritten by subclass

_at_training_end()[source]

Defines Behaviour at end of training: Loads best model if available

Returns

best network

Return type

AbstractTfNetwork

static _is_better_val_scores(old_val_score, new_val_score, mode='highest')

Check whether the new val score is better than the old one with respect to the optimization goal

Parameters
  • old_val_score – old validation score

  • new_val_score – new validation score

  • mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]

Returns

True if new score is better, False otherwise

Return type

bool

_reinitialize_logging(logging_type, logging_kwargs: dict)
static _search_for_prev_state(path, extensions=None)

Helper function to search in a given path for previous epoch states (indicated by extensions)

Parameters
  • path (str) – the path to search in

  • extensions (list) – list of strings containing valid file extensions for checkpoint files

Returns

  • str – the file containing the latest checkpoint (if available)

  • None – if no latest checkpoint was found

  • int – the latest epoch (1 if no checkpoint was found)

_setup(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, key_mapping, convert_batch_to_npy_fn, gpu_ids)[source]

Defines the trainer’s setup

Parameters
  • network (AbstractTfNetwork) – the network to train

  • optim_fn (function) – creates a dictionary containing all necessary optimizers

  • optimizer_cls (subclass of tf.train.Optimizer) – optimizer class implementing the optimization algorithm of choice

  • optimizer_params (dict) – keyword arguments passed to the optimizer during construction

  • lr_scheduler_cls (Any) – learning rate schedule class: must implement step() method

  • lr_scheduler_params (dict) – keyword arguments passed to lr scheduler during construction

  • convert_batch_to_npy_fn (type, optional) – function converting a batch-tensor to numpy, per default this is the identity function

  • gpu_ids (list) – list containing ids of GPUs to use; if empty: use cpu instead

_train_single_epoch(batchgen: batchgenerators.dataloading.MultiThreadedAugmenter, epoch, verbose=False)[source]

Trains the network a single epoch

Parameters
  • batchgen (MultiThreadedAugmenter) – Generator yielding the training batches

  • epoch (int) – current epoch

_update_state(new_state)

Update the state from a given new state

Parameters

new_state (dict) – new state to update internal state from

Returns

the trainer with a modified state

Return type

BaseNetworkTrainer

static calc_metrics(batch: delira.utils.config.LookupConfig, metrics=None, metric_keys=None)

Compute metrics

Parameters
  • batch (LookupConfig) – dictionary containing the whole batch (including predictions)

  • metrics (dict) – dict with metrics

  • metric_keys (dict) – dict of tuples which contains hashables for specifying the items to use for calculating the respective metric. If not specified for a metric, the keys “pred” and “label” are used per default

Returns

dict with metric results

Return type

dict

property fold

Get current fold

Returns

current fold

Return type

int

load_state(file_name, *args, **kwargs)[source]

Loads the new state from file via delira.io.tf.load_checkpoint()

Parameters

file_name (str) – the file to load the state from

predict(data: dict, **kwargs)

Predict single batch. Returns the predictions corresponding to the given data obtained by the model

Parameters
  • data (dict) – batch dictionary

  • **kwargs – keyword arguments (directly passed to prepare_batch)

Returns

predicted data

Return type

dict

predict_data_mgr(datamgr, batch_size=None, metrics=None, metric_keys=None, verbose=False, **kwargs)[source]

Defines a routine to predict data obtained from a batchgenerator

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batch_size (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • **kwargs – additional keyword arguments

predict_data_mgr_cache(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, cache_preds=False, **kwargs)

Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – Artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • cache_preds (bool) – whether to also cache predictions

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields
  • dict – a dictionary containing all validation metrics (may be empty)

  • dict – a dictionary containing all predictions (only if cache_preds=True)

Warning

Since this function caches all metrics and may additionally cache all predictions (depending on the argument cache_preds), it can consume a large amount of memory. If you are running out of memory, consider using Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr(), or set cache_preds to False (if not done already).

predict_data_mgr_cache_all(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)

Defines a routine to predict data obtained from a batchgenerator and caches all predictions and metrics (yields them in dicts)

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – artificial batch size: samples are drawn with batch size 1 and stacked to match the artificial batch size (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields
  • dict – a dictionary containing all predictions

  • dict – a dictionary containing all validation metrics (may be empty)

Warning

Since this function caches all predictions and metrics, it can consume a large amount of memory. If you are running out of memory, consider using Predictor.predict_data_mgr_cache_metrics_only() or Predictor.predict_data_mgr().

predict_data_mgr_cache_metrics_only(datamgr, batchsize=None, metrics=None, metric_keys=None, verbose=False, **kwargs)

Defines a routine to predict data obtained from a batchgenerator and caches the metrics

Parameters
  • datamgr (BaseDataManager) – Manager producing a generator holding the batches

  • batchsize (int) – artificial batch size: samples are drawn with batch size 1 and stacked to match the artificial batch size (default: None)

  • metrics (dict) – the metrics to calculate

  • metric_keys (dict) – the batch_dict items to use for metric calculation

  • verbose (bool) – whether to show a progress-bar or not, default: False

  • kwargs – keyword arguments passed to prepare_batch_fn()

Yields

dict – a dictionary containing all validation metrics (may be empty)

Notes

This function stores each prediction only temporarily for metric calculation. This typically results in far lower memory consumption than Predictor.predict_data_mgr_cache_all(), but the metrics are still cached. If this is not desired, use Predictor.predict_data_mgr() and iterate over the generator, as this only produces per-batch metrics and predictions and does not cache anything by default.
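
The trade-off between caching metrics and iterating over a generator can be sketched in plain Python. The functions below are hypothetical and do not call delira; they only contrast collecting every per-batch value with reducing values on the fly.

```python
def per_batch_metrics(batches, metric_fn):
    """Generator: yields one metric value per batch and stores nothing."""
    for pred, target in batches:
        yield metric_fn(pred, target)


def cached_metrics(batches, metric_fn):
    """Collects every per-batch metric value in a list before returning."""
    return [metric_fn(pred, target) for pred, target in batches]


def abs_error(pred, target):
    return abs(pred - target)


# toy (prediction, target) pairs standing in for real batches
batches = [(1.0, 1.5), (2.0, 2.0), (4.0, 3.0)]

# streaming: values are reduced as they arrive, so memory use stays constant
running_sum, count = 0.0, 0
for value in per_batch_metrics(batches, abs_error):
    running_sum += value
    count += 1
streamed_mean = running_sum / count

# caching: all values are held in memory until the end
cached = cached_metrics(batches, abs_error)
cached_mean = sum(cached) / len(cached)
```

Both paths give the same mean here; the difference is only in how many intermediate values are alive at once.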

register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)

Register Callback to Trainer

Parameters

callback (AbstractCallback) – the callback to register

Raises

AssertionError – callback is neither an instance of AbstractCallback nor provides both methods ['at_epoch_begin', 'at_epoch_end']
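
The condition behind this assertion can be sketched as follows. This is an illustration rather than the library's code: the AbstractCallback class below is a local stand-in, and is_valid_callback, DuckCallback, and HalfCallback are hypothetical names.

```python
class AbstractCallback:
    """Local stand-in for delira's AbstractCallback (illustration only)."""

    def at_epoch_begin(self, trainer, **kwargs):
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


REQUIRED_METHODS = ("at_epoch_begin", "at_epoch_end")


def is_valid_callback(callback):
    """True for AbstractCallback instances or objects offering both hooks."""
    if isinstance(callback, AbstractCallback):
        return True
    return all(callable(getattr(callback, name, None)) for name in REQUIRED_METHODS)


class DuckCallback:
    """Not a subclass, but provides both required methods."""

    def at_epoch_begin(self, trainer, **kwargs):
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


class HalfCallback:
    """Provides only one of the two required methods."""

    def at_epoch_begin(self, trainer, **kwargs):
        return {}


results = {
    "abstract": is_valid_callback(AbstractCallback()),
    "duck": is_valid_callback(DuckCallback()),
    "half": is_valid_callback(HalfCallback()),
}
```

Under this reading, any object with both hooks is accepted even if it does not inherit from AbstractCallback.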

save_state(file_name, *args, **kwargs)[source]

Saves the current state via delira.io.tf.save_checkpoint()

Parameters

file_name (str) – filename to save the state to

train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest', reduce_mode='mean', verbose=True)

Defines a routine to train for a specified number of epochs

Parameters
  • num_epochs (int) – number of epochs to train

  • datamgr_train (DataManager) – the datamanager holding the train data

  • datamgr_valid (DataManager) – the datamanager holding the validation data (default: None)

  • val_score_key (str) – the key specifying which metric to use for validation (default: None)

  • val_score_mode (str) – key specifying what kind of validation score is best (default: 'highest')

  • reduce_mode (str) – one of 'mean', 'sum', 'first_only' (default: 'mean')

  • verbose (bool) – whether to show progress bars or not

Raises

NotImplementedError – If not overwritten by subclass
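
How val_score_mode and reduce_mode might be interpreted by a subclass can be sketched as below. The helpers are hypothetical and rest on two assumptions not confirmed above: that reduce_mode collapses a list of per-batch values into one value, and that 'lowest' is the counterpart to the default 'highest'.

```python
def reduce_values(values, reduce_mode="mean"):
    """Collapse a list of values into one (assumed meaning of reduce_mode)."""
    if reduce_mode == "mean":
        return sum(values) / len(values)
    if reduce_mode == "sum":
        return sum(values)
    if reduce_mode == "first_only":
        return values[0]
    raise ValueError("unknown reduce_mode: %r" % (reduce_mode,))


def is_better(new_score, best_score, val_score_mode="highest"):
    """Decide whether new_score beats best_score.

    'lowest' is an assumed counterpart to the documented default 'highest'.
    """
    if best_score is None:
        return True
    if val_score_mode == "highest":
        return new_score > best_score
    if val_score_mode == "lowest":
        return new_score < best_score
    raise ValueError("unknown val_score_mode: %r" % (val_score_mode,))


# toy per-batch validation scores for three epochs
epoch_scores = [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [1.0, 1.0, 1.0]]

best_score, best_epoch = None, None
for epoch, scores in enumerate(epoch_scores, start=1):
    score = reduce_values(scores, reduce_mode="mean")
    if is_better(score, best_score, val_score_mode="highest"):
        best_score, best_epoch = score, epoch
```

With 'highest' and 'mean', the second epoch is selected in this toy run; a trainer would typically use such a flag to decide when to save the best state.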

update_state(file_name, *args, **kwargs)

Update internal state from a loaded state

Parameters
  • file_name (str) – file containing the new state to load

  • *args – positional arguments

  • **kwargs – keyword arguments

Returns

the trainer with a modified state

Return type

BaseNetworkTrainer