NetworkTrainer

The network trainer implements the actual training routine and can be subclassed to implement special training routines.

Subclassing your trainer also means you have to subclass your experiment (to use the trainer).
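
The coupling between trainer and experiment can be illustrated with a minimal sketch. It assumes, for illustration only, that the experiment builds its trainer from a class attribute. `Trainer`, `Experiment`, `trainer_cls`, `MyTrainer` and `MyExperiment` are hypothetical stand-ins, not delira's actual classes or arguments.

```python
class Trainer:
    """Hypothetical stand-in for a network trainer."""

    def __init__(self, fold=0):
        self.fold = fold

    def train(self, num_epochs):
        return f"default training for {num_epochs} epochs"


class Experiment:
    """Hypothetical stand-in for an experiment that creates its own trainer."""

    # The experiment, not the user, decides which trainer class is built
    trainer_cls = Trainer

    def run(self, num_epochs, fold=0):
        trainer = self.trainer_cls(fold=fold)
        return trainer.train(num_epochs)


class MyTrainer(Trainer):
    """Custom trainer implementing a special training routine."""

    def train(self, num_epochs):
        return f"custom training for {num_epochs} epochs"


class MyExperiment(Experiment):
    # Without this override the experiment would keep building Trainer
    trainer_cls = MyTrainer


print(Experiment().run(3))    # default training for 3 epochs
print(MyExperiment().run(3))  # custom training for 3 epochs
```

Subclassing only the trainer would leave `Experiment.run` building the default class, which is why both classes change together in this sketch.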

AbstractNetworkTrainer

class AbstractNetworkTrainer(fold=0, callbacks=[])

Bases: object

Defines an abstract API for network trainers

_at_epoch_begin(*args, **kwargs)

Defines the behaviour at the beginning of each epoch

Parameters
  • *args – positional arguments

  • **kwargs – keyword arguments

Raises

NotImplementedError – If not overwritten by subclass

_at_epoch_end(*args, **kwargs)

Defines the behaviour at the end of each epoch

Parameters
  • *args – positional arguments

  • **kwargs – keyword arguments

Raises

NotImplementedError – If not overwritten by subclass

_at_training_begin(*args, **kwargs)

Defines the behaviour at the beginning of the training

Parameters
  • *args – positional arguments

  • **kwargs – keyword arguments

Raises

NotImplementedError – If not overwritten by subclass

_at_training_end(*args, **kwargs)

Defines the behaviour at the end of the training

Parameters
  • *args – positional arguments

  • **kwargs – keyword arguments

Raises

NotImplementedError – If not overwritten by subclass
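
The four hooks above are meant to be overridden together. The following sketch uses hypothetical plain-Python classes (`HookedTrainer`, `LoggingTrainer` and the `run` method are not part of delira) to show a base class whose hooks raise NotImplementedError and a subclass that overrides all of them. The call order in `run` is an assumption about a typical training routine.

```python
class HookedTrainer:
    """Hypothetical minimal base class mirroring the abstract hook API."""

    def _at_training_begin(self, *args, **kwargs):
        raise NotImplementedError()

    def _at_training_end(self, *args, **kwargs):
        raise NotImplementedError()

    def _at_epoch_begin(self, *args, **kwargs):
        raise NotImplementedError()

    def _at_epoch_end(self, *args, **kwargs):
        raise NotImplementedError()

    def run(self, num_epochs):
        # Assumed order: training hooks wrap the per-epoch hooks
        self._at_training_begin()
        for epoch in range(1, num_epochs + 1):
            self._at_epoch_begin(epoch=epoch)
            self._at_epoch_end(epoch=epoch)
        self._at_training_end()


class LoggingTrainer(HookedTrainer):
    """Overrides every hook and records the order of the calls."""

    def __init__(self):
        self.log = []

    def _at_training_begin(self, *args, **kwargs):
        self.log.append("training_begin")

    def _at_training_end(self, *args, **kwargs):
        self.log.append("training_end")

    def _at_epoch_begin(self, *args, **kwargs):
        self.log.append(f"epoch_begin {kwargs['epoch']}")

    def _at_epoch_end(self, *args, **kwargs):
        self.log.append(f"epoch_end {kwargs['epoch']}")
```

Calling `run` on the base class fails at the first hook, which is how the abstract trainer signals a missing override.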

static _is_better_val_scores(old_val_score, new_val_score, mode='highest')

Check whether the new val score is better than the old one with respect to the optimization goal

Parameters
  • old_val_score – old validation score

  • new_val_score – new validation score

  • mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]

Returns

True if new score is better, False otherwise

Return type

bool
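
A minimal sketch of the comparison described above, written as a free function. Two details are assumptions rather than documented behaviour: equal scores do not count as an improvement, and an unknown mode raises ValueError.

```python
def is_better_val_score(old_val_score, new_val_score, mode="highest"):
    """Return True if new_val_score improves on old_val_score under mode."""
    if mode == "highest":
        # e.g. accuracy: larger is better
        return new_val_score > old_val_score
    if mode == "lowest":
        # e.g. loss: smaller is better
        return new_val_score < old_val_score
    raise ValueError(f"mode must be 'highest' or 'lowest', got {mode!r}")
```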

_setup(*args, **kwargs)

Defines the actual trainer setup

Parameters
  • *args – positional arguments

  • **kwargs – keyword arguments

Raises

NotImplementedError – If not overwritten by subclass

_train_single_epoch(batchgen: batchgenerators.dataloading.MultiThreadedAugmenter, epoch)

Defines a routine to train a single epoch

Parameters
  • batchgen (MultiThreadedAugmenter) – generator holding the batches

  • epoch (int) – current epoch

Raises

NotImplementedError – If not overwritten by subclass

_update_state(new_state)

Update the state from a given new state

Parameters

new_state (dict) – new state to update internal state from

Returns

the trainer with a modified state

Return type

AbstractNetworkTrainer

fold

Get current fold

Returns

current fold

Return type

int

static load_state(file_name, *args, **kwargs)

Loads the new state from file

Parameters
  • file_name (str) – the file to load the state from

  • *args – positional arguments

  • **kwargs – keyword arguments

Returns

new state

Return type

dict

predict(batchgen, batchsize=None)

Defines a routine to predict data obtained from a batch generator

Parameters
  • batchgen (MultiThreadedAugmenter) – generator holding the batches

  • batchsize (None or int) – artificial batch size: sampling is done with batch size 1 and the sampled data is stacked to match the artificial batch size (default: None)

Raises

NotImplementedError – If not overwritten by subclass
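
The artificial batch size can be sketched as follows, assuming the generator yields one sample at a time and that an incomplete last batch is kept rather than dropped. `iter_artificial_batches` is a hypothetical helper, and plain lists stand in for stacked arrays.

```python
def iter_artificial_batches(single_samples, batchsize=None):
    """Group samples drawn with batch size 1 into artificial batches.

    single_samples: iterable yielding one sample at a time
    batchsize: size of the artificial batch; None keeps batch size 1
    """
    if batchsize is None:
        batchsize = 1
    batch = []
    for sample in single_samples:
        batch.append(sample)
        # Emit ("stack") the collected samples once the batch is full
        if len(batch) == batchsize:
            yield batch
            batch = []
    # Assumption: the incomplete last batch is still returned
    if batch:
        yield batch
```

For example, five samples with `batchsize=2` produce the batches `[0, 1]`, `[2, 3]` and `[4]`.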

register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)

Registers a callback to the trainer

Parameters

callback (AbstractCallback) – the callback to register

Raises

AssertionError – callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]
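
The documented check accepts either a subclass of AbstractCallback or any object that provides both hook methods. The sketch below reproduces that rule with a hypothetical stand-in `AbstractCallback` and a hypothetical `CallbackRegistry` class, so that it runs without delira installed. The hook signatures are assumptions.

```python
class AbstractCallback:
    """Hypothetical stand-in for delira's AbstractCallback."""

    def at_epoch_begin(self, trainer, **kwargs):
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


class DuckCallback:
    """Does not subclass AbstractCallback but provides both hooks."""

    def at_epoch_begin(self, trainer, **kwargs):
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


class CallbackRegistry:
    """Hypothetical trainer fragment that only manages callbacks."""

    def __init__(self):
        self._callbacks = []

    def register_callback(self, callback):
        # Duck typing: both hook methods must exist and be callable
        has_hooks = all(
            callable(getattr(callback, name, None))
            for name in ("at_epoch_begin", "at_epoch_end")
        )
        assert isinstance(callback, AbstractCallback) or has_hooks, (
            "callback must subclass AbstractCallback or define "
            "'at_epoch_begin' and 'at_epoch_end'"
        )
        self._callbacks.append(callback)


registry = CallbackRegistry()
registry.register_callback(AbstractCallback())
registry.register_callback(DuckCallback())  # accepted via duck typing
```

An `assert` statement matches the documented AssertionError, but note that assertions are skipped when Python runs with the -O flag.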

save_state(file_name, *args, **kwargs)

Saves the current state

Parameters
  • file_name (str) – filename to save the state to

  • *args – positional arguments

  • **kwargs – keyword arguments

train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')

Defines a routine to train a specified number of epochs

Parameters
  • num_epochs (int) – number of epochs to train

  • datamgr_train (DataManager) – the datamanager holding the train data

  • datamgr_valid (DataManager) – the datamanager holding the validation data (default: None)

  • val_score_key (str) – the key specifying which metric to use for validation (default: None)

  • val_score_mode (str) – string specifying whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’] (default: ‘highest’)

Raises

NotImplementedError – If not overwritten by subclass
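
The control flow of a train() implementation might look like the following sketch. `LoopTrainer` is hypothetical: a plain iterable of batches replaces the DataManager, pre-computed scores replace a real validation run, and the method returns the best score instead of a model.

```python
class LoopTrainer:
    """Hypothetical trainer illustrating the control flow of train()."""

    def __init__(self, val_scores):
        # Pre-computed scores (one per epoch) replace a real validation run
        self._val_scores = list(val_scores)
        self.best_epoch = None

    @staticmethod
    def _is_better_val_scores(old_val_score, new_val_score, mode="highest"):
        if mode == "highest":
            return new_val_score > old_val_score
        return new_val_score < old_val_score

    def _train_single_epoch(self, batchgen, epoch):
        for _batch in batchgen:
            pass  # forward pass, loss computation and optimizer step go here

    def _validate(self, epoch):
        return {"val_acc": self._val_scores[epoch - 1]}

    def train(self, num_epochs, train_batches, val_score_key=None,
              val_score_mode="highest"):
        best_score = None
        for epoch in range(1, num_epochs + 1):
            self._train_single_epoch(train_batches, epoch)
            metrics_val = self._validate(epoch)
            score = metrics_val.get(val_score_key)
            if score is None:
                continue  # no (valid) score key: nothing to compare
            if best_score is None or self._is_better_val_scores(
                    best_score, score, val_score_mode):
                best_score, self.best_epoch = score, epoch
        return best_score
```

With scores `[0.5, 0.7, 0.6]` and mode `'highest'`, epoch 2 is kept as the best one; without a `val_score_key`, no best epoch is recorded.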

update_state(file_name, *args, **kwargs)

Update internal state from a loaded state

Parameters
  • file_name (str) – file containing the new state to load

  • *args – positional arguments

  • **kwargs – keyword arguments

Returns

the trainer with a modified state

Return type

AbstractNetworkTrainer
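
save_state, load_state and update_state work together: the state is written to disk, read back as a dict and applied to the trainer. The sketch below shows that round trip with pickle, which is an assumption for illustration only (the framework-specific trainers use delira.io.torch and delira.io.tf). `PersistentTrainer` and its state keys are hypothetical.

```python
import os
import pickle
import tempfile


class PersistentTrainer:
    """Hypothetical trainer whose whole state is two attributes."""

    def __init__(self, fold=0, start_epoch=1):
        self.fold = fold
        self.start_epoch = start_epoch

    def save_state(self, file_name):
        # Collect the state in a dict and serialize it
        state = {"fold": self.fold, "start_epoch": self.start_epoch}
        with open(file_name, "wb") as f:
            pickle.dump(state, f)

    @staticmethod
    def load_state(file_name):
        # Only read the state; applying it is the job of update_state
        with open(file_name, "rb") as f:
            return pickle.load(f)

    def _update_state(self, new_state):
        # Overwrite each attribute named in the dict, then return the
        # trainer itself so that calls can be chained
        for key, value in new_state.items():
            setattr(self, key, value)
        return self

    def update_state(self, file_name):
        return self._update_state(self.load_state(file_name))


# Round trip through a temporary directory
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "trainer_state.pkl")
    PersistentTrainer(fold=2, start_epoch=7).save_state(path)
    restored = PersistentTrainer().update_state(path)
```

Keeping load_state static means a state file can be inspected without constructing a trainer first.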

PyTorchNetworkTrainer

class PyTorchNetworkTrainer(network, save_path, criterions: dict, optimizer_cls, optimizer_params={}, metrics={}, lr_scheduler_cls=None, lr_scheduler_params={}, gpu_ids=[], save_freq=1, optim_fn=<function create_optims_default_pytorch>, fold=0, callbacks=[], start_epoch=1, mixed_precision=False, mixed_precision_kwargs={'allow_banned': False, 'enable_caching': True, 'verbose': False}, **kwargs)

Bases: delira.training.abstract_trainer.AbstractNetworkTrainer

Train and Validate a Network

See also

AbstractNetwork

_at_epoch_begin(metrics_val, val_score_key, epoch, num_epochs, **kwargs)

Defines behaviour at the beginning of each epoch: executes the at_epoch_begin method of all registered callbacks

Parameters
  • metrics_val (dict) – validation metrics

  • val_score_key (str) – validation score key

  • epoch (int) – current epoch

  • num_epochs (int) – total number of epochs

  • **kwargs – keyword arguments
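
Executing callbacks at the start of an epoch can be sketched as follows. Assumed for illustration: each callback returns a dict of attribute updates that is applied to the trainer before the next callback runs. The keyword names passed to the callbacks (`val_metrics`, `val_score_key`, `curr_epoch`, `num_epochs`) are hypothetical, as are `HalveLrCallback` and `CallbackTrainer`.

```python
class HalveLrCallback:
    """Hypothetical callback returning attribute updates for the trainer."""

    def at_epoch_begin(self, trainer, **kwargs):
        # Halve the learning rate at every even epoch
        if kwargs["curr_epoch"] % 2 == 0:
            return {"lr": trainer.lr / 2}
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


class CallbackTrainer:
    """Hypothetical trainer running its callbacks at the start of an epoch."""

    def __init__(self, lr, callbacks):
        self.lr = lr
        self._callbacks = list(callbacks)

    def _update_state(self, new_state):
        for key, value in new_state.items():
            setattr(self, key, value)
        return self

    def _at_epoch_begin(self, metrics_val, val_score_key, epoch, num_epochs,
                        **kwargs):
        # Run callbacks in registration order; each returned dict is
        # applied before the next callback sees the trainer
        for callback in self._callbacks:
            self._update_state(callback.at_epoch_begin(
                self, val_metrics=metrics_val, val_score_key=val_score_key,
                curr_epoch=epoch, num_epochs=num_epochs, **kwargs))
```

Because updates are applied one callback at a time, two halving callbacks registered together reduce the learning rate to a quarter in the same epoch.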

_at_epoch_end(metrics_val, val_score_key, epoch, is_best, **kwargs)

Defines behaviour at the end of each epoch: executes the at_epoch_end method of all registered callbacks and saves the current state if necessary

Parameters
  • metrics_val (dict) – validation metrics

  • val_score_key (str) – validation score key

  • epoch (int) – current epoch

  • num_epochs (int) – total number of epochs

  • is_best (bool) – whether current model is best one so far

  • **kwargs – keyword arguments
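
The description above does not spell out when saving the state is necessary. One plausible policy, sketched under that assumption with a hypothetical helper and hypothetical checkpoint names, writes a regular checkpoint every `save_freq` epochs (the constructor argument of the same name) and a separate checkpoint whenever `is_best` is True.

```python
def checkpoints_to_write(epoch, is_best, save_freq=1):
    """Return the checkpoint names to write at the end of epoch.

    Assumed policy: a regular checkpoint every save_freq epochs plus a
    separate "best" checkpoint whenever the current model is the best so far.
    """
    names = []
    if save_freq and epoch % save_freq == 0:
        names.append(f"checkpoint_epoch_{epoch}")
    if is_best:
        names.append("checkpoint_best")
    return names
```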

_at_training_begin(*args, **kwargs)

Defines behaviour at the beginning of training

Parameters
  • *args – positional arguments

  • **kwargs – keyword arguments

_at_training_end()

Defines behaviour at the end of training: loads the best model if available

Returns

best network

Return type

AbstractPyTorchNetwork

static _is_better_val_scores(old_val_score, new_val_score, mode='highest')

Check whether the new val score is better than the old one with respect to the optimization goal

Parameters
  • old_val_score – old validation score

  • new_val_score – new validation score

  • mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]

Returns

True if new score is better, False otherwise

Return type

bool

_setup(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, gpu_ids, mixed_precision, mixed_precision_kwargs)

Defines the trainer's setup

Parameters
  • network (AbstractPyTorchNetwork) – the network to train

  • optim_fn (function) – creates a dictionary containing all necessary optimizers

  • optimizer_cls (subclass of torch.optim.Optimizer) – optimizer class implementing the optimization algorithm of choice

  • optimizer_params (dict) – keyword arguments passed to the optimizer during construction

  • lr_scheduler_cls (Any) – learning rate schedule class: must implement step() method

  • lr_scheduler_params (dict) – keyword arguments passed to lr scheduler during construction

  • gpu_ids (list) – list containing ids of GPUs to use; if empty, the CPU is used instead

  • mixed_precision (bool) – whether to use mixed precision or not (default: False)

  • mixed_precision_kwargs (dict) – additional keyword arguments for mixed precision
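
The roles of optim_fn and gpu_ids can be sketched without PyTorch. In this hedged example, `DummyOptimizer` stands in for a torch.optim.Optimizer subclass, `create_optims` is a hypothetical analogue of the default optim_fn that stores a single optimizer under the key "default", and `select_device` assumes that the first listed GPU is the primary device.

```python
class DummyOptimizer:
    """Hypothetical stand-in for a torch.optim.Optimizer subclass."""

    def __init__(self, params, lr=0.01, momentum=0.0):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum


def create_optims(parameters, optimizer_cls, **optimizer_params):
    # optimizer_params become keyword arguments of the optimizer class;
    # a single optimizer is stored under the key "default"
    return {"default": optimizer_cls(parameters, **optimizer_params)}


def select_device(gpu_ids):
    # An empty list means training runs on the CPU
    if not gpu_ids:
        return "cpu"
    # Assumption: the first listed GPU acts as the primary device
    return f"cuda:{gpu_ids[0]}"


optims = create_optims([0.5, 1.5], DummyOptimizer, lr=0.1, momentum=0.9)
device = select_device([])
```

Returning a dict keeps the interface the same for networks that need several optimizers, one per key.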

_train_single_epoch(batchgen: batchgenerators.dataloading.MultiThreadedAugmenter, epoch)

Trains the network a single epoch

Parameters
  • batchgen (MultiThreadedAugmenter) – Generator yielding the training batches

  • epoch (int) – current epoch

_update_state(new_state)

Update the state from a given new state

Parameters

new_state (dict) – new state to update internal state from

Returns

the trainer with a modified state

Return type

PyTorchNetworkTrainer

fold

Get current fold

Returns

current fold

Return type

int

static load_state(file_name, **kwargs)

Loads the new state from file via delira.io.torch.load_checkpoint()

Parameters
  • file_name (str) – the file to load the state from

  • **kwargs – keyword arguments

Returns

new state

Return type

dict

predict(batchgen, batch_size=None)

Returns predictions from network for batches from batchgen

Parameters
  • batchgen (MultiThreadedAugmenter) – Generator yielding the batches to predict

  • batch_size (None or int) – if int: collect batches until batch_size is reached and forward them together

Returns

  • np.ndarray – predictions from batches

  • list of np.ndarray – labels from batches

  • dict – dictionary containing the mean validation metrics and the mean loss values
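
The third return value holds mean metrics and losses. The averaging can be sketched as follows, assuming an unweighted mean over batches (a per-sample weighted mean would differ when the last batch is smaller). `mean_metrics` is a hypothetical helper.

```python
def mean_metrics(per_batch_metrics):
    """Average each metric over all batches.

    per_batch_metrics: list of dicts, one per batch, mapping names to floats
    """
    totals, counts = {}, {}
    for batch_metrics in per_batch_metrics:
        for name, value in batch_metrics.items():
            totals[name] = totals.get(name, 0.0) + value
            counts[name] = counts.get(name, 0) + 1
    return {name: totals[name] / counts[name] for name in totals}


result = mean_metrics([{"loss": 1.0, "acc": 0.5}, {"loss": 3.0, "acc": 1.0}])
```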

register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)

Registers a callback to the trainer

Parameters

callback (AbstractCallback) – the callback to register

Raises

AssertionError – callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]

save_state(file_name, epoch, **kwargs)

Saves the current state via delira.io.torch.save_checkpoint()

Parameters
  • file_name (str) – filename to save the state to

  • epoch (int) – current epoch (will be saved for mapping back)

  • **kwargs – keyword arguments

train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')

Trains the network

Parameters
  • num_epochs (int) – number of epochs to train

  • datamgr_train (BaseDataManager) – Data Manager to create Batch Generator for training

  • datamgr_valid (BaseDataManager) – Data Manager to create Batch Generator for validation

  • val_score_key (str) – Key of validation metric; must be key in self.metrics

  • val_score_mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]

Returns

Best model (if val_score_key is not a valid key the model of the last epoch will be returned)

Return type

AbstractPyTorchNetwork
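
The fallback described under Returns can be sketched as a small selection function. `select_final_model` is hypothetical, plain strings stand in for network snapshots, and an empty score dict models the case where val_score_key is not a valid key.

```python
def select_final_model(models_by_epoch, val_scores, mode="highest"):
    """Pick the model to return at the end of training.

    models_by_epoch: dict mapping epoch -> model (any object)
    val_scores: dict mapping epoch -> score; empty if no valid key was given
    """
    last_epoch = max(models_by_epoch)
    if not val_scores:
        # No usable validation metric: fall back to the last epoch's model
        return models_by_epoch[last_epoch]
    pick = max if mode == "highest" else min
    best_epoch = pick(val_scores, key=val_scores.get)
    return models_by_epoch[best_epoch]
```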

update_state(file_name, *args, **kwargs)

Update internal state from a loaded state

Parameters
  • file_name (str) – file containing the new state to load

  • *args – positional arguments

  • **kwargs – keyword arguments

Returns

the trainer with a modified state

Return type

AbstractNetworkTrainer

TfNetworkTrainer

class TfNetworkTrainer(network, save_path, losses: dict, optimizer_cls, optimizer_params={}, metrics={}, lr_scheduler_cls=None, lr_scheduler_params={}, gpu_ids=[], save_freq=1, optim_fn=<function create_optims_default_tf>, fold=0, callbacks=[], start_epoch=1, **kwargs)

Bases: delira.training.abstract_trainer.AbstractNetworkTrainer

Train and Validate a Network

See also

AbstractNetwork

_at_epoch_begin(metrics_val, val_score_key, epoch, num_epochs, **kwargs)

Defines behaviour at the beginning of each epoch: executes the at_epoch_begin method of all registered callbacks

Parameters
  • metrics_val (dict) – validation metrics

  • val_score_key (str) – validation score key

  • epoch (int) – current epoch

  • num_epochs (int) – total number of epochs

  • **kwargs – keyword arguments

_at_epoch_end(metrics_val, val_score_key, epoch, is_best, **kwargs)

Defines behaviour at the end of each epoch: executes the at_epoch_end method of all registered callbacks and saves the current state if necessary

Parameters
  • metrics_val (dict) – validation metrics

  • val_score_key (str) – validation score key

  • epoch (int) – current epoch

  • num_epochs (int) – total number of epochs

  • is_best (bool) – whether current model is best one so far

  • **kwargs – keyword arguments

_at_training_begin(*args, **kwargs)

Defines behaviour at the beginning of training

Parameters
  • *args – positional arguments

  • **kwargs – keyword arguments

_at_training_end()

Defines behaviour at the end of training: loads the best model if available

Returns

best network

Return type

AbstractTfNetwork

static _is_better_val_scores(old_val_score, new_val_score, mode='highest')

Check whether the new val score is better than the old one with respect to the optimization goal

Parameters
  • old_val_score – old validation score

  • new_val_score – new validation score

  • mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]

Returns

True if new score is better, False otherwise

Return type

bool

_setup(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, gpu_ids)

Defines the trainer's setup

Parameters
  • network (AbstractTfNetwork) – the network to train

  • optim_fn (function) – creates a dictionary containing all necessary optimizers

  • optimizer_cls (subclass of tf.train.Optimizer) – optimizer class implementing the optimization algorithm of choice

  • optimizer_params (dict) – keyword arguments passed to the optimizer during construction

  • lr_scheduler_cls (Any) – learning rate schedule class: must implement step() method

  • lr_scheduler_params (dict) – keyword arguments passed to lr scheduler during construction

  • gpu_ids (list) – list containing ids of GPUs to use; if empty, the CPU is used instead

_train_single_epoch(batchgen: batchgenerators.dataloading.MultiThreadedAugmenter, epoch)

Trains the network a single epoch

Parameters
  • batchgen (MultiThreadedAugmenter) – Generator yielding the training batches

  • epoch (int) – current epoch

_update_state(new_state)

Update the state from a given new state

Parameters

new_state (dict) – new state to update internal state from

Returns

the trainer with a modified state

Return type

AbstractNetworkTrainer

fold

Get current fold

Returns

current fold

Return type

int

load_state(file_name)

Loads the new state from file via delira.io.tf.load_checkpoint()

Parameters

file_name (str) – the file to load the state from

predict(batchgen, batch_size=None)

Returns predictions from network for batches from batchgen

Parameters
  • batchgen (MultiThreadedAugmenter) – Generator yielding the batches to predict

  • batch_size (None or int) – if int: collect batches until batch_size is reached and forward them together

Returns

  • np.ndarray – predictions from batches

  • list of np.ndarray – labels from batches

  • dict – dictionary containing the mean validation metrics and the mean loss values

register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)

Registers a callback to the trainer

Parameters

callback (AbstractCallback) – the callback to register

Raises

AssertionError – callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]

save_state(file_name)

Saves the current state via delira.io.tf.save_checkpoint()

Parameters

file_name (str) – filename to save the state to

train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')

Trains the network

Parameters
  • num_epochs (int) – number of epochs to train

  • datamgr_train (BaseDataManager) – Data Manager to create Batch Generator for training

  • datamgr_valid (BaseDataManager) – Data Manager to create Batch Generator for validation

  • val_score_key (str) – Key of validation metric; must be key in self.metrics

  • val_score_mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]

Returns

Best model (if val_score_key is not a valid key the model of the last epoch will be returned)

Return type

AbstractTfNetwork

update_state(file_name, *args, **kwargs)

Update internal state from a loaded state

Parameters
  • file_name (str) – file containing the new state to load

  • *args – positional arguments

  • **kwargs – keyword arguments

Returns

the trainer with a modified state

Return type

AbstractNetworkTrainer