NetworkTrainer¶
- The network trainer implements the actual training routine and can be subclassed
for custom routines.
If you subclass the trainer, you also have to subclass your experiment so that it uses the new trainer.
AbstractNetworkTrainer¶
-
class
AbstractNetworkTrainer
(fold=0, callbacks=[])[source]¶ Bases:
object
Defines an abstract API for Network Trainers
See also
-
_at_epoch_begin
(*args, **kwargs)[source]¶ Defines the behaviour at the beginning of each epoch
- Parameters
*args – positional arguments
**kwargs – keyword arguments
- Raises
NotImplementedError – If not overwritten by subclass
-
_at_epoch_end
(*args, **kwargs)[source]¶ Defines the behaviour at the end of each epoch
- Parameters
*args – positional arguments
**kwargs – keyword arguments
- Raises
NotImplementedError – If not overwritten by subclass
-
_at_training_begin
(*args, **kwargs)[source]¶ Defines the behaviour at the beginning of the training
- Parameters
*args – positional arguments
**kwargs – keyword arguments
- Raises
NotImplementedError – If not overwritten by subclass
-
_at_training_end
(*args, **kwargs)[source]¶ Defines the behaviour at the end of the training
- Parameters
*args – positional arguments
**kwargs – keyword arguments
- Raises
NotImplementedError – If not overwritten by subclass
-
static
_is_better_val_scores
(old_val_score, new_val_score, mode='highest')[source]¶ Check whether the new val score is better than the old one with respect to the optimization goal
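A minimal sketch of the comparison described above, assuming mode accepts the two values 'highest' and 'lowest' that are listed for val_score_mode in the train methods. The function name is_better_val_score is hypothetical and this is not the library's implementation.

```python
def is_better_val_score(old_val_score, new_val_score, mode="highest"):
    """Return True if new_val_score improves on old_val_score."""
    if mode == "highest":
        # a larger score is better, e.g. accuracy
        return new_val_score > old_val_score
    if mode == "lowest":
        # a smaller score is better, e.g. a loss value
        return new_val_score < old_val_score
    # rejecting unknown modes is an assumption of this sketch
    raise ValueError("mode must be 'highest' or 'lowest', got %r" % (mode,))
```

In this sketch an equal score does not count as an improvement, so a stored best model would only be replaced on a strict gain.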
-
_setup
(*args, **kwargs)[source]¶ Defines the actual Trainer Setup
- Parameters
*args – positional arguments
**kwargs – keyword arguments
- Raises
NotImplementedError – If not overwritten by subclass
-
_train_single_epoch
(batchgen: batchgenerators.dataloading.MultiThreadedAugmenter, epoch)[source]¶ Defines a routine to train a single epoch
- Parameters
batchgen (MultiThreadedAugmenter) – generator holding the batches
epoch (int) – current epoch
- Raises
NotImplementedError – If not overwritten by subclass
-
_update_state
(new_state)[source]¶ Update the state from a given new state
- Parameters
new_state (dict) – new state to update internal state from
- Returns
the trainer with a modified state
- Return type
AbstractNetworkTrainer
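A minimal sketch of a state update that returns the trainer itself, assuming the state is kept as plain attributes on the trainer. The class StatefulTrainer and its update_state method are hypothetical stand-ins, not the delira classes.

```python
class StatefulTrainer:
    """Hypothetical stand-in for a trainer that keeps its state in attributes."""

    def __init__(self):
        self.start_epoch = 1

    def update_state(self, new_state):
        # copy every entry of the dict onto the trainer as an attribute
        for key, value in new_state.items():
            setattr(self, key, value)
        # returning the trainer allows the call to be chained
        return self


trainer = StatefulTrainer().update_state({"start_epoch": 5})
```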
-
predict
(batchgen, batchsize=None)[source]¶ Defines a routine to predict data obtained from a batchgenerator
- Parameters
batchgen (MultiThreadedAugmenter) – Generator Holding the Batches
batchsize – artificial batchsize (sampling will be done with batchsize 1 and sampled data will be stacked to match the artificial batchsize) (default: None)
- Raises
NotImplementedError – If not overwritten by subclass
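The artificial batchsize can be pictured as grouping samples that were drawn one at a time. The sketch below uses plain lists instead of arrays to stay self-contained; the helper name stack_samples is hypothetical.

```python
from itertools import islice


def stack_samples(single_samples, batchsize):
    """Group samples drawn one at a time into lists of at most batchsize items."""
    iterator = iter(single_samples)
    while True:
        # take the next batchsize samples; fewer remain at the end of the data
        batch = list(islice(iterator, batchsize))
        if not batch:
            return
        yield batch


batches = list(stack_samples(range(5), batchsize=2))
```

The last group is smaller whenever the number of samples is not a multiple of the artificial batchsize.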
-
register_callback
(callback: delira.training.callbacks.abstract_callback.AbstractCallback)[source]¶ Register Callback to Trainer
- Parameters
callback (AbstractCallback) – the callback to register
- Raises
AssertionError – if callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]
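A minimal sketch of the check described under Raises: a callback is accepted if it is an instance of the base class or if it provides both hook methods. The names check_callback and EpochLogger are hypothetical, and base_cls stands in for AbstractCallback so that the example runs without delira.

```python
class EpochLogger:
    """Hypothetical callback that only provides the two required hooks."""

    def at_epoch_begin(self, trainer, **kwargs):
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


def check_callback(callback, base_cls):
    """Accept instances of base_cls or any object with both epoch hooks."""
    required = ("at_epoch_begin", "at_epoch_end")
    has_hooks = all(hasattr(callback, name) for name in required)
    assert isinstance(callback, base_cls) or has_hooks, \
        "callback must subclass the base class or provide both epoch hooks"
    return callback
```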
-
save_state
(file_name, *args, **kwargs)[source]¶ Saves the current state
- Parameters
file_name (str) – filename to save the state to
*args – positional arguments
**kwargs – keyword arguments
-
train
(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')[source]¶ Defines a routine to train a specified number of epochs
- Parameters
num_epochs (int) – number of epochs to train
datamgr_train (DataManager) – the datamanager holding the train data
datamgr_valid (DataManager) – the datamanager holding the validation data (default: None)
val_score_key (str) – the key specifying which metric to use for validation (default: None)
val_score_mode (str) – specifies whether the highest or the lowest validation score is considered best (default: ‘highest’)
- Raises
NotImplementedError – If not overwritten by subclass
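The hooks of the abstract API suggest a loop of the following shape. This is a sketch only: the call order is an assumption inferred from the hook names, validation is left out, and RecordingTrainer is a hypothetical stand-in that records which hooks ran.

```python
class RecordingTrainer:
    """Hypothetical stand-in that records the order in which its hooks run."""

    def __init__(self):
        self.calls = []

    def _at_training_begin(self):
        self.calls.append("training_begin")

    def _at_epoch_begin(self, epoch):
        self.calls.append("epoch_begin_%d" % epoch)

    def _train_single_epoch(self, batchgen, epoch):
        # a real trainer would iterate over batchgen and update the network here
        self.calls.append("train_epoch_%d" % epoch)

    def _at_epoch_end(self, epoch):
        self.calls.append("epoch_end_%d" % epoch)

    def _at_training_end(self):
        self.calls.append("training_end")
        return self.calls

    def train(self, num_epochs, batchgen=None):
        self._at_training_begin()
        for epoch in range(1, num_epochs + 1):
            self._at_epoch_begin(epoch)
            self._train_single_epoch(batchgen, epoch)
            self._at_epoch_end(epoch)
        return self._at_training_end()


calls = RecordingTrainer().train(num_epochs=2)
```

A subclass that needs a special routine would override only the hooks it cares about and leave train untouched.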
-
PyTorchNetworkTrainer¶
-
class
PyTorchNetworkTrainer
(network, save_path, criterions: dict, optimizer_cls, optimizer_params={}, metrics={}, lr_scheduler_cls=None, lr_scheduler_params={}, gpu_ids=[], save_freq=1, optim_fn=<function create_optims_default_pytorch>, fold=0, callbacks=[], start_epoch=1, mixed_precision=False, mixed_precision_kwargs={'allow_banned': False, 'enable_caching': True, 'verbose': False}, **kwargs)[source]¶ Bases:
delira.training.abstract_trainer.AbstractNetworkTrainer
Train and Validate a Network
See also
AbstractNetwork
-
_at_epoch_begin
(metrics_val, val_score_key, epoch, num_epochs, **kwargs)[source]¶ Defines behaviour at beginning of each epoch: executes the at_epoch_begin method of every callback
-
_at_epoch_end
(metrics_val, val_score_key, epoch, is_best, **kwargs)[source]¶ Defines behaviour at end of each epoch: executes the at_epoch_end method of every callback and saves the current state if necessary
-
_at_training_begin
(*args, **kwargs)[source]¶ Defines behaviour at beginning of training
- Parameters
*args – positional arguments
**kwargs – keyword arguments
-
_at_training_end
()[source]¶ Defines Behaviour at end of training: Loads best model if available
- Returns
best network
- Return type
AbstractPyTorchNetwork
-
static
_is_better_val_scores
(old_val_score, new_val_score, mode='highest')¶ Check whether the new val score is better than the old one with respect to the optimization goal
-
_setup
(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, gpu_ids, mixed_precision, mixed_precision_kwargs)[source]¶ Defines the trainer's setup
- Parameters
network (AbstractPyTorchNetwork) – the network to train
optim_fn (function) – creates a dictionary containing all necessary optimizers
optimizer_cls (subclass of torch.optim.Optimizer) – optimizer class implementing the optimization algorithm of choice
optimizer_params (dict) – keyword arguments passed to the optimizer during construction
lr_scheduler_cls (Any) – learning rate schedule class: must implement step() method
lr_scheduler_params (dict) – keyword arguments passed to lr scheduler during construction
gpu_ids (list) – list containing ids of GPUs to use; if empty: use cpu instead
mixed_precision (bool) – whether to use mixed precision or not (False per default)
mixed_precision_kwargs (dict) – additional keyword arguments for mixed precision
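The gpu_ids rule (an empty list means CPU) can be sketched as a small helper. The name select_devices is hypothetical, and the strings follow PyTorch's "cuda:<index>" device naming; how the trainer actually places the network on these devices is not shown here.

```python
def select_devices(gpu_ids):
    """Map a list of GPU ids to device strings, falling back to the CPU."""
    if not gpu_ids:
        # empty list: use the CPU instead
        return ["cpu"]
    return ["cuda:%d" % gpu_id for gpu_id in gpu_ids]
```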
-
_train_single_epoch
(batchgen: batchgenerators.dataloading.MultiThreadedAugmenter, epoch)[source]¶ Trains the network for a single epoch
- Parameters
batchgen (MultiThreadedAugmenter) – Generator yielding the training batches
epoch (int) – current epoch
-
_update_state
(new_state)[source]¶ Update the state from a given new state
- Parameters
new_state (dict) – new state to update internal state from
- Returns
the trainer with a modified state
- Return type
PyTorchNetworkTrainer
-
static
load_state
(file_name, **kwargs)[source]¶ Loads the new state from file via
delira.io.torch.load_checkpoint()
-
predict
(batchgen, batch_size=None)[source]¶ Returns predictions from network for batches from batchgen
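- Parameters
batchgen (MultiThreadedAugmenter) – generator holding the batches
batch_size – artificial batch size (default: None)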
- Returns
np.ndarray – predictions from batches
list of np.ndarray – labels from batches
dict – dictionary containing the mean validation metrics and the mean loss values
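The third return value is a dictionary of means. A minimal sketch of that aggregation, assuming each batch yields a dict with the same keys; the helper name mean_over_batches and the example keys are hypothetical.

```python
from statistics import mean


def mean_over_batches(per_batch_values):
    """Average each metric or loss over a non-empty list of per-batch dicts."""
    keys = per_batch_values[0].keys()
    return {key: mean(batch[key] for batch in per_batch_values) for key in keys}


summary = mean_over_batches([
    {"loss": 1.0, "accuracy": 0.25},
    {"loss": 3.0, "accuracy": 0.75},
])
```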
-
register_callback
(callback: delira.training.callbacks.abstract_callback.AbstractCallback)¶ Register Callback to Trainer
- Parameters
callback (AbstractCallback) – the callback to register
- Raises
AssertionError – if callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]
-
save_state
(file_name, epoch, **kwargs)[source]¶ Saves the current state via
delira.io.torch.save_checkpoint()
-
train
(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')[source]¶ Trains the network
- Parameters
num_epochs (int) – number of epochs to train
datamgr_train (BaseDataManager) – Data Manager to create Batch Generator for training
datamgr_valid (BaseDataManager) – Data Manager to create Batch Generator for validation
val_score_key (str) – Key of validation metric; must be key in self.metrics
val_score_mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]
- Returns
Best model (if val_score_key is not a valid key the model of the last epoch will be returned)
- Return type
AbstractPyTorchNetwork
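How val_score_key and val_score_mode determine which model is returned can be sketched as follows. The helper pick_best_epoch is hypothetical and works on a list of per-epoch metric dicts instead of real networks; the fallback to the last epoch mirrors the Returns note above.

```python
def pick_best_epoch(metrics_per_epoch, val_score_key=None, val_score_mode="highest"):
    """Return the 1-based index of the epoch whose model would be kept."""
    last_epoch = len(metrics_per_epoch)
    # without a valid key there is nothing to compare: keep the last epoch
    if not metrics_per_epoch or val_score_key not in metrics_per_epoch[0]:
        return last_epoch
    scores = [metrics[val_score_key] for metrics in metrics_per_epoch]
    best = max(scores) if val_score_mode == "highest" else min(scores)
    # on ties the earliest epoch wins (a choice made for this sketch)
    return scores.index(best) + 1


history = [{"accuracy": 0.5}, {"accuracy": 0.9}, {"accuracy": 0.7}]
```

With this history, the key "accuracy" selects a different epoch depending on the mode, while a missing key always selects the last one.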
-
TfNetworkTrainer¶
-
class
TfNetworkTrainer
(network, save_path, losses: dict, optimizer_cls, optimizer_params={}, metrics={}, lr_scheduler_cls=None, lr_scheduler_params={}, gpu_ids=[], save_freq=1, optim_fn=<function create_optims_default_tf>, fold=0, callbacks=[], start_epoch=1, **kwargs)[source]¶ Bases:
delira.training.abstract_trainer.AbstractNetworkTrainer
Train and Validate a Network
See also
AbstractNetwork
-
_at_epoch_begin
(metrics_val, val_score_key, epoch, num_epochs, **kwargs)[source]¶ Defines behaviour at beginning of each epoch: executes the at_epoch_begin method of every callback
-
_at_epoch_end
(metrics_val, val_score_key, epoch, is_best, **kwargs)[source]¶ Defines behaviour at end of each epoch: executes the at_epoch_end method of every callback and saves the current state if necessary
-
_at_training_begin
(*args, **kwargs)[source]¶ Defines behaviour at beginning of training
- Parameters
*args – positional arguments
**kwargs – keyword arguments
-
_at_training_end
()[source]¶ Defines Behaviour at end of training: Loads best model if available
- Returns
best network
- Return type
AbstractTfNetwork
-
static
_is_better_val_scores
(old_val_score, new_val_score, mode='highest')¶ Check whether the new val score is better than the old one with respect to the optimization goal
-
_setup
(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, gpu_ids)[source]¶ Defines the trainer's setup
- Parameters
network (AbstractTfNetwork) – the network to train
optim_fn (function) – creates a dictionary containing all necessary optimizers
optimizer_cls (subclass of tf.train.Optimizer) – optimizer class implementing the optimization algorithm of choice
optimizer_params (dict) – keyword arguments passed to the optimizer during construction
lr_scheduler_cls (Any) – learning rate schedule class: must implement step() method
lr_scheduler_params (dict) – keyword arguments passed to lr scheduler during construction
gpu_ids (list) – list containing ids of GPUs to use; if empty: use cpu instead
-
_train_single_epoch
(batchgen: batchgenerators.dataloading.MultiThreadedAugmenter, epoch)[source]¶ Trains the network for a single epoch
- Parameters
batchgen (MultiThreadedAugmenter) – Generator yielding the training batches
epoch (int) – current epoch
-
_update_state
(new_state)¶ Update the state from a given new state
- Parameters
new_state (dict) – new state to update internal state from
- Returns
the trainer with a modified state
- Return type
AbstractNetworkTrainer
-
load_state
(file_name)[source]¶ Loads the new state from file via
delira.io.tf.load_checkpoint()
- Parameters
file_name (str) – the file to load the state from
-
predict
(batchgen, batch_size=None)[source]¶ Returns predictions from network for batches from batchgen
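- Parameters
batchgen (MultiThreadedAugmenter) – generator holding the batches
batch_size – artificial batch size (default: None)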
- Returns
np.ndarray – predictions from batches
list of np.ndarray – labels from batches
dict – dictionary containing the mean validation metrics and the mean loss values
-
register_callback
(callback: delira.training.callbacks.abstract_callback.AbstractCallback)¶ Register Callback to Trainer
- Parameters
callback (AbstractCallback) – the callback to register
- Raises
AssertionError – if callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]
-
save_state
(file_name)[source]¶ Saves the current state via
delira.io.tf.save_checkpoint()
- Parameters
file_name (str) – filename to save the state to
-
train
(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')[source]¶ Trains the network
- Parameters
num_epochs (int) – number of epochs to train
datamgr_train (BaseDataManager) – Data Manager to create Batch Generator for training
datamgr_valid (BaseDataManager) – Data Manager to create Batch Generator for validation
val_score_key (str) – Key of validation metric; must be key in self.metrics
val_score_mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]
- Returns
Best model (if val_score_key is not a valid key the model of the last epoch will be returned)
- Return type
AbstractTfNetwork
-