NetworkTrainer

The network trainer implements the actual training routine and can be subclassed
to implement custom routines.

If you subclass a trainer, you also have to subclass the corresponding experiment so that it uses your trainer.
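
For example, a custom trainer typically overrides one or more of the hooks documented below. A minimal sketch, assuming the PyTorch backend and that PyTorchNetworkTrainer can be imported from delira.training (the class name and the printed messages are made up for illustration):

    from delira.training import PyTorchNetworkTrainer


    class VerboseTrainer(PyTorchNetworkTrainer):
        """Hypothetical trainer that logs a message at every epoch boundary."""

        def _at_epoch_begin(self, metrics_val, val_score_key, epoch,
                            num_epochs, **kwargs):
            print("Starting epoch %d / %d" % (epoch, num_epochs))
            # keep the default behaviour (running the registered callbacks)
            super()._at_epoch_begin(metrics_val, val_score_key, epoch,
                                    num_epochs, **kwargs)

        def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,
                          **kwargs):
            if is_best:
                print("Epoch %d produced the best model so far" % epoch)
            # keep the default behaviour (callbacks and checkpointing)
            super()._at_epoch_end(metrics_val, val_score_key, epoch, is_best,
                                  **kwargs)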

AbstractNetworkTrainer

class AbstractNetworkTrainer(fold=0, callbacks=[])[source]

Bases: object

Defines an abstract API for Network Trainers

_at_epoch_begin(*args, **kwargs)[source]

Defines the behaviour at the beginning of each epoch

Parameters:
  • *args – positional arguments
  • **kwargs – keyword arguments
Raises:

NotImplementedError – If not overwritten by subclass

_at_epoch_end(*args, **kwargs)[source]

Defines the behaviour at the end of each epoch

Parameters:
  • *args – positional arguments
  • **kwargs – keyword arguments
Raises:

NotImplementedError – If not overwritten by subclass

_at_training_begin(*args, **kwargs)[source]

Defines the behaviour at the beginning of the training

Parameters:
  • *args – positional arguments
  • **kwargs – keyword arguments
Raises:

NotImplementedError – If not overwritten by subclass

_at_training_end(*args, **kwargs)[source]

Defines the behaviour at the end of the training

Parameters:
  • *args – positional arguments
  • **kwargs – keyword arguments
Raises:

NotImplementedError – If not overwritten by subclass

static _is_better_val_scores(old_val_score, new_val_score, mode='highest')[source]

Check whether the new val score is better than the old one with respect to the optimization goal

Parameters:
  • old_val_score – old validation score
  • new_val_score – new validation score
  • mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]
Returns:

True if new score is better, False otherwise

Return type:

bool
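
A short usage sketch (the method is static, so it can be called on the class itself; the scores below are made-up values):

    from delira.training.abstract_trainer import AbstractNetworkTrainer

    # 'highest' mode: a larger score is better (e.g. an accuracy)
    AbstractNetworkTrainer._is_better_val_scores(0.80, 0.85, mode="highest")  # True

    # 'lowest' mode: a smaller score is better (e.g. a loss value)
    AbstractNetworkTrainer._is_better_val_scores(0.42, 0.57, mode="lowest")   # False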

_setup(*args, **kwargs)[source]

Defines the actual Trainer Setup

Parameters:
  • *args – positional arguments
  • **kwargs – keyword arguments
Raises:

NotImplementedError – If not overwritten by subclass

_train_single_epoch(batchgen: MultiThreadedAugmenter, epoch)[source]

Defines a routine to train a single epoch

Parameters:
  • batchgen (MultiThreadedAugmenter) – generator holding the batches
  • epoch (int) – current epoch
Raises:

NotImplementedError – If not overwritten by subclass

_update_state(new_state)[source]

Update the state from a given new state

Parameters: new_state (dict) – new state to update internal state from
Returns: the trainer with a modified state
Return type: AbstractNetworkTrainer
fold

Get current fold

Returns: current fold
Return type: int
static load_state(file_name, *args, **kwargs)[source]

Loads the new state from file

Parameters:
  • file_name (str) – the file to load the state from
  • *args – positional arguments
  • **kwargs – keyword arguments
Returns:

new state

Return type:

dict

predict(batchgen, batchsize=None)[source]

Defines a routine to predict data obtained from a batchgenerator

Parameters:
  • batchgen (MultiThreadedAugmenter) – generator holding the batches
  • batchsize (int) – artificial batch size; sampling will be done with batch size 1 and the sampled data will be stacked to match the artificial batch size (default: None)
Raises:

NotImplementedError – If not overwritten by subclass

register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)[source]

Register Callback to Trainer

Parameters: callback (AbstractCallback) – the callback to register
Raises: AssertionError – if callback is neither an instance of AbstractCallback nor provides both methods [‘at_epoch_begin’, ‘at_epoch_end’]
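
A callback only needs the two methods named above. A minimal sketch, using the import path from the signature of register_callback; the callback's behaviour and the keyword arguments it inspects are made up for illustration, and 'trainer' is assumed to be an existing trainer instance:

    from delira.training.callbacks.abstract_callback import AbstractCallback


    class PrintMetricsCallback(AbstractCallback):
        """Hypothetical callback that prints the validation metrics."""

        def at_epoch_begin(self, trainer, **kwargs):
            # nothing to do here; return no updates
            return {}

        def at_epoch_end(self, trainer, **kwargs):
            # the exact keyword arguments forwarded by the trainer may differ
            # between delira versions
            if "val_metrics" in kwargs:
                print("validation metrics:", kwargs["val_metrics"])
            return {}


    trainer.register_callback(PrintMetricsCallback())
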
save_state(file_name, *args, **kwargs)[source]

Saves the current state

Parameters:
  • file_name (str) – filename to save the state to
  • *args – positional arguments
  • **kwargs – keyword arguments
train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')[source]

Defines a routine to train a specified number of epochs

Parameters:
  • num_epochs (int) – number of epochs to train
  • datamgr_train (DataManager) – the datamanager holding the train data
  • datamgr_valid (DataManager) – the datamanager holding the validation data (default: None)
  • val_score_key (str) – the key specifying which metric to use for validation (default: None)
  • val_score_mode (str) – key specifying what kind of validation score is best
Raises:

NotImplementedError – If not overwritten by subclass

update_state(file_name, *args, **kwargs)[source]

Update internal state from a loaded state

Parameters:
  • file_name (str) – file containing the new state to load
  • *args – positional arguments
  • **kwargs – keyword arguments
Returns:

the trainer with a modified state

Return type:

AbstractNetworkTrainer
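
Because the modified trainer is returned, resuming from a previously saved state can be written as a single call. A small sketch (the checkpoint path is made up):

    # 'trainer' is an existing trainer instance; the file was written earlier
    # by save_state()
    trainer = trainer.update_state("./checkpoints/checkpoint_epoch_10.pth")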

PyTorchNetworkTrainer

class PyTorchNetworkTrainer(network, save_path, criterions: dict, optimizer_cls, optimizer_params={}, metrics={}, lr_scheduler_cls=None, lr_scheduler_params={}, gpu_ids=[], save_freq=1, optim_fn=<function create_optims_default_pytorch>, fold=0, callbacks=[], start_epoch=1, mixed_precision=False, mixed_precision_kwargs={'allow_banned': False, 'enable_caching': True, 'verbose': False}, **kwargs)[source]

Bases: delira.training.abstract_trainer.AbstractNetworkTrainer

Train and Validate a Network

See also

AbstractNetwork
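
A construction sketch based on the signature above. It assumes that PyTorchNetworkTrainer can be imported from delira.training; MyNetwork (an AbstractPyTorchNetwork subclass) and my_accuracy_fn (a metric callable) are placeholders:

    import torch
    from delira.training import PyTorchNetworkTrainer

    trainer = PyTorchNetworkTrainer(
        network=MyNetwork(),               # placeholder AbstractPyTorchNetwork subclass
        save_path="./checkpoints",         # states are saved here every save_freq epochs
        criterions={"CE": torch.nn.CrossEntropyLoss()},
        optimizer_cls=torch.optim.Adam,
        optimizer_params={"lr": 1e-3},
        metrics={"accuracy": my_accuracy_fn},  # placeholder metric callable
        gpu_ids=[0],                       # empty list -> train on CPU
        save_freq=1,
        fold=0,
        mixed_precision=False,
    )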

_at_epoch_begin(metrics_val, val_score_key, epoch, num_epochs, **kwargs)[source]

Defines behaviour at the beginning of each epoch: executes each callback's at_epoch_begin method

Parameters:
  • metrics_val (dict) – validation metrics
  • val_score_key (str) – validation score key
  • epoch (int) – current epoch
  • num_epochs (int) – total number of epochs
  • **kwargs – keyword arguments
_at_epoch_end(metrics_val, val_score_key, epoch, is_best, **kwargs)[source]

Defines behaviour at the end of each epoch: executes each callback's at_epoch_end method and saves the current state if necessary

Parameters:
  • metrics_val (dict) – validation metrics
  • val_score_key (str) – validation score key
  • epoch (int) – current epoch
  • is_best (bool) – whether current model is best one so far
  • **kwargs – keyword arguments
_at_training_begin(*args, **kwargs)[source]

Defines behaviour at beginning of training

Parameters:
  • *args – positional arguments
  • **kwargs – keyword arguments
_at_training_end()[source]

Defines behaviour at the end of training: loads the best model if available

Returns: best network
Return type: AbstractPyTorchNetwork
static _is_better_val_scores(old_val_score, new_val_score, mode='highest')

Check whether the new val score is better than the old one with respect to the optimization goal

Parameters:
  • old_val_score – old validation score
  • new_val_score – new validation score
  • mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]
Returns:

True if new score is better, False otherwise

Return type:

bool

_setup(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, gpu_ids, mixed_precision, mixed_precision_kwargs)[source]

Defines the Trainer's Setup

Parameters:
  • network (AbstractPyTorchNetwork) – the network to train
  • optim_fn (function) – creates a dictionary containing all necessary optimizers
  • optimizer_cls (subclass of torch.optim.Optimizer) – optimizer class implementing the optimization algorithm of choice
  • optimizer_params (dict) – keyword arguments passed to optimizer during construction
  • lr_scheduler_cls (Any) – learning rate schedule class: must implement step() method
  • lr_scheduler_params (dict) – keyword arguments passed to lr scheduler during construction
  • gpu_ids (list) – list containing ids of GPUs to use; if empty: use cpu instead
  • mixed_precision (bool) – whether to use mixed precision or not (False per default)
  • mixed_precision_kwargs (dict) – additional keyword arguments for mixed precision
_train_single_epoch(batchgen: MultiThreadedAugmenter, epoch)[source]

Trains the network for a single epoch

Parameters:
  • batchgen (MultiThreadedAugmenter) – Generator yielding the training batches
  • epoch (int) – current epoch
_update_state(new_state)[source]

Update the state from a given new state

Parameters: new_state (dict) – new state to update internal state from
Returns: the trainer with a modified state
Return type: PyTorchNetworkTrainer
fold

Get current fold

Returns: current fold
Return type: int
static load_state(file_name, **kwargs)[source]

Loads the new state from file via delira.io.torch.load_checkpoint()

Parameters:
  • file_name (str) – the file to load the state from
  • **kwargs – keyword arguments
Returns:

new state

Return type:

dict

predict(batchgen, batch_size=None)[source]

Returns predictions from network for batches from batchgen

Parameters:
  • batchgen (MultiThreadedAugmenter) – Generator yielding the batches to predict
  • batch_size (None or int) – if int: collect batches until batch_size is reached and forward them together
Returns:

  • np.ndarray – predictions from batches
  • list of np.ndarray – labels from batches
  • dict – dictionary containing the mean validation metrics and the mean loss values
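
A usage sketch; datamgr_test is a placeholder data manager, and get_batchgen() is assumed to be the way a MultiThreadedAugmenter is obtained from it (the exact call may differ between delira versions):

    # obtain a batch generator from a (placeholder) data manager
    batchgen = datamgr_test.get_batchgen()

    predictions, labels, metrics = trainer.predict(batchgen, batch_size=32)

    print(predictions.shape)  # stacked network outputs
    print(metrics)            # mean validation metrics and mean loss values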

register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)

Register Callback to Trainer

Parameters: callback (AbstractCallback) – the callback to register
Raises: AssertionError – if callback is neither an instance of AbstractCallback nor provides both methods [‘at_epoch_begin’, ‘at_epoch_end’]
save_state(file_name, epoch, **kwargs)[source]

Saves the current state via delira.io.torch.save_checkpoint()

Parameters:
  • file_name (str) – filename to save the state to
  • epoch (int) – current epoch (will be saved for mapping back)
  • **kwargs – keyword arguments
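
A sketch of saving and restoring a state (the checkpoint path is made up); load_state is static and only returns the raw state dict, while update_state applies a state file to the trainer directly:

    # save the current state together with the epoch number
    trainer.save_state("./checkpoints/checkpoint_epoch_5.pth", epoch=5)

    # later: inspect the raw state ...
    state = PyTorchNetworkTrainer.load_state("./checkpoints/checkpoint_epoch_5.pth")

    # ... or apply the state file to a trainer in one step
    trainer = trainer.update_state("./checkpoints/checkpoint_epoch_5.pth")
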
train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')[source]

Train the network

Parameters:
  • num_epochs (int) – number of epochs to train
  • datamgr_train (BaseDataManager) – Data Manager to create Batch Generator for training
  • datamgr_valid (BaseDataManager) – Data Manager to create Batch Generator for validation
  • val_score_key (str) – Key of validation metric; must be key in self.metrics
  • val_score_mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]
Returns:

Best model (if val_score_key is not a valid key the model of the last epoch will be returned)

Return type:

AbstractPyTorchNetwork
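
A training sketch continuing the construction example above; datamgr_train and datamgr_val are placeholder BaseDataManager instances:

    best_model = trainer.train(
        num_epochs=50,
        datamgr_train=datamgr_train,
        datamgr_valid=datamgr_val,
        val_score_key="accuracy",    # must be a key in the trainer's metrics dict
        val_score_mode="highest",    # a higher accuracy is better
    )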

update_state(file_name, *args, **kwargs)[source]

Update internal state from a loaded state

Parameters:
  • file_name (str) – file containing the new state to load
  • *args – positional arguments
  • **kwargs – keyword arguments
Returns:

the trainer with a modified state

Return type:

AbstractNetworkTrainer

TfNetworkTrainer

class TfNetworkTrainer(network, save_path, losses: dict, optimizer_cls, optimizer_params={}, metrics={}, lr_scheduler_cls=None, lr_scheduler_params={}, gpu_ids=[], save_freq=1, optim_fn=<function create_optims_default_tf>, fold=0, callbacks=[], start_epoch=1, **kwargs)[source]

Bases: delira.training.abstract_trainer.AbstractNetworkTrainer

Train and Validate a Network

See also

AbstractNetwork
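
A construction and training sketch based on the signature above. It assumes that TfNetworkTrainer can be imported from delira.training and targets the TF1-style tf.train.Optimizer API; MyTfNetwork (an AbstractTfNetwork subclass), my_loss_fn (a loss callable) and the data managers are placeholders:

    import tensorflow as tf
    from delira.training import TfNetworkTrainer

    trainer = TfNetworkTrainer(
        network=MyTfNetwork(),            # placeholder AbstractTfNetwork subclass
        save_path="./checkpoints_tf",
        losses={"CE": my_loss_fn},        # placeholder loss callable
        optimizer_cls=tf.train.AdamOptimizer,
        optimizer_params={"learning_rate": 1e-3},
        gpu_ids=[0],                      # empty list -> train on CPU
        save_freq=1,
    )

    # datamgr_train / datamgr_val are placeholder BaseDataManager instances
    best_model = trainer.train(
        num_epochs=50,
        datamgr_train=datamgr_train,
        datamgr_valid=datamgr_val,
    )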

_at_epoch_begin(metrics_val, val_score_key, epoch, num_epochs, **kwargs)[source]

Defines behaviour at the beginning of each epoch: executes each callback's at_epoch_begin method

Parameters:
  • metrics_val (dict) – validation metrics
  • val_score_key (str) – validation score key
  • epoch (int) – current epoch
  • num_epochs (int) – total number of epochs
  • **kwargs – keyword arguments
_at_epoch_end(metrics_val, val_score_key, epoch, is_best, **kwargs)[source]

Defines behaviour at the end of each epoch: executes each callback's at_epoch_end method and saves the current state if necessary

Parameters:
  • metrics_val (dict) – validation metrics
  • val_score_key (str) – validation score key
  • epoch (int) – current epoch
  • is_best (bool) – whether current model is best one so far
  • **kwargs – keyword arguments
_at_training_begin(*args, **kwargs)[source]

Defines behaviour at beginning of training

Parameters:
  • *args – positional arguments
  • **kwargs – keyword arguments
_at_training_end()[source]

Defines behaviour at the end of training: loads the best model if available

Returns: best network
Return type: AbstractTfNetwork
static _is_better_val_scores(old_val_score, new_val_score, mode='highest')

Check whether the new val score is better than the old one with respect to the optimization goal

Parameters:
  • old_val_score – old validation score
  • new_val_score – new validation score
  • mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]
Returns:

True if new score is better, False otherwise

Return type:

bool

_setup(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, gpu_ids)[source]

Defines the Trainer's Setup

Parameters:
  • network (AbstractTfNetwork) – the network to train
  • optim_fn (function) – creates a dictionary containing all necessary optimizers
  • optimizer_cls (subclass of tf.train.Optimizer) – optimizer class implementing the optimization algorithm of choice
  • optimizer_params (dict) – keyword arguments passed to optimizer during construction
  • lr_scheduler_cls (Any) – learning rate schedule class: must implement step() method
  • lr_scheduler_params (dict) – keyword arguments passed to lr scheduler during construction
  • gpu_ids (list) – list containing ids of GPUs to use; if empty: use cpu instead
_train_single_epoch(batchgen: MultiThreadedAugmenter, epoch)[source]

Trains the network for a single epoch

Parameters:
  • batchgen (MultiThreadedAugmenter) – Generator yielding the training batches
  • epoch (int) – current epoch
_update_state(new_state)

Update the state from a given new state

Parameters: new_state (dict) – new state to update internal state from
Returns: the trainer with a modified state
Return type: AbstractNetworkTrainer
fold

Get current fold

Returns: current fold
Return type: int
load_state(file_name)[source]

Loads the new state from file via delira.io.tf.load_checkpoint()

Parameters: file_name (str) – the file to load the state from
predict(batchgen, batch_size=None)[source]

Returns predictions from network for batches from batchgen

Parameters:
  • batchgen (MultiThreadedAugmenter) – Generator yielding the batches to predict
  • batch_size (None or int) – if int: collect batches until batch_size is reached and forward them together
Returns:

  • np.ndarray – predictions from batches
  • list of np.ndarray – labels from batches
  • dict – dictionary containing the mean validation metrics and the mean loss values

register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)

Register Callback to Trainer

Parameters: callback (AbstractCallback) – the callback to register
Raises: AssertionError – if callback is neither an instance of AbstractCallback nor provides both methods [‘at_epoch_begin’, ‘at_epoch_end’]
save_state(file_name)[source]

Saves the current state via delira.io.tf.save_checkpoint()

Parameters: file_name (str) – filename to save the state to
train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')[source]

Train the network

Parameters:
  • num_epochs (int) – number of epochs to train
  • datamgr_train (BaseDataManager) – Data Manager to create Batch Generator for training
  • datamgr_valid (BaseDataManager) – Data Manager to create Batch Generator for validation
  • val_score_key (str) – Key of validation metric; must be key in self.metrics
  • val_score_mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]
Returns:

Best model (if val_score_key is not a valid key the model of the last epoch will be returned)

Return type:

AbstractTfNetwork

update_state(file_name, *args, **kwargs)

Update internal state from a loaded state

Parameters:
  • file_name (str) – file containing the new state to load
  • *args – positional arguments
  • **kwargs – keyword arguments
Returns:

the trainer with a modified state

Return type:

AbstractNetworkTrainer