NetworkTrainer
The network trainer implements the actual training routine and can be subclassed for special routines.
Subclassing your trainer also means you have to subclass your experiment (to use the trainer).
AbstractNetworkTrainer
-
class AbstractNetworkTrainer(fold=0, callbacks=[])
Bases: object
Defines an abstract API for network trainers
-
_at_epoch_begin(*args, **kwargs)
Defines the behaviour at the beginning of each epoch
Parameters: - *args – positional arguments
- **kwargs – keyword arguments
Raises: NotImplementedError
– If not overwritten by subclass
-
_at_epoch_end(*args, **kwargs)
Defines the behaviour at the end of each epoch
Parameters: - *args – positional arguments
- **kwargs – keyword arguments
Raises: NotImplementedError
– If not overwritten by subclass
-
_at_training_begin(*args, **kwargs)
Defines the behaviour at the beginning of the training
Parameters: - *args – positional arguments
- **kwargs – keyword arguments
Raises: NotImplementedError
– If not overwritten by subclass
-
_at_training_end(*args, **kwargs)
Defines the behaviour at the end of the training
Parameters: - *args – positional arguments
- **kwargs – keyword arguments
Raises: NotImplementedError
– If not overwritten by subclass
-
static _is_better_val_scores(old_val_score, new_val_score, mode='highest')
Checks whether the new validation score is better than the old one with respect to the optimization goal
Parameters: - old_val_score – old validation score
- new_val_score – new validation score
- mode (str) – specifies whether a higher or lower validation score is optimal; must be one of [‘highest’, ‘lowest’]
Returns: True if the new score is better, False otherwise
Return type: bool
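The comparison described above can be sketched as a standalone function. This is a minimal illustration, not the library's implementation: the name `is_better_val_score` is hypothetical, and both the strict comparison on ties and the `ValueError` for an unknown mode are assumptions that the documentation does not confirm.

```python
def is_better_val_score(old_val_score, new_val_score, mode="highest"):
    """Return True if new_val_score improves on old_val_score for the given mode.

    Hypothetical helper; tie handling (strict comparison) is an assumption.
    """
    if mode not in ("highest", "lowest"):
        # Assumption: reject unknown modes explicitly.
        raise ValueError("mode must be 'highest' or 'lowest', got %r" % (mode,))
    if mode == "highest":
        return new_val_score > old_val_score
    return new_val_score < old_val_score


# An accuracy-like metric improves when it grows ...
print(is_better_val_score(0.80, 0.85, mode="highest"))
# ... while a loss-like metric improves when it shrinks.
print(is_better_val_score(0.30, 0.25, mode="lowest"))
```

Use 'highest' for scores such as accuracy or Dice, and 'lowest' for losses or error rates.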
-
_setup(*args, **kwargs)
Defines the actual trainer setup
Parameters: - *args – positional arguments
- **kwargs – keyword arguments
Raises: NotImplementedError
– If not overwritten by subclass
-
_train_single_epoch(batchgen: MultiThreadedAugmenter, epoch)
Defines a routine to train a single epoch
Parameters: - batchgen (MultiThreadedAugmenter) – generator holding the batches
- epoch (int) – current epoch
Raises: NotImplementedError
– If not overwritten by subclass
-
_update_state(new_state)
Updates the state from a given new state
Parameters: new_state (dict) – new state to update internal state from
Returns: the trainer with a modified state
Return type: AbstractNetworkTrainer
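Because the method returns the trainer itself, an update can be chained with further calls. The sketch below uses a hypothetical stand-in class, `StatefulTrainer`, rather than the real trainer, and it assumes that every key of the state dictionary is stored as an attribute; the documentation does not say how the state is actually applied.

```python
class StatefulTrainer:
    """Hypothetical stand-in for a trainer; not the delira class."""

    def __init__(self):
        self.start_epoch = 1

    def _update_state(self, new_state):
        # Assumption: each key in the state dict becomes an attribute.
        for key, value in new_state.items():
            setattr(self, key, value)
        # Returning self matches the documented return value
        # ("the trainer with a modified state").
        return self


trainer = StatefulTrainer()._update_state({"start_epoch": 5})
print(trainer.start_epoch)
```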
-
static load_state(file_name, *args, **kwargs)
Loads the new state from file
Parameters: - file_name (str) – the file to load the state from
- *args – positional arguments
- **kwargs – keyword arguments
Returns: new state
Return type: dict
-
predict(batchgen, batchsize=None)
Defines a routine to predict data obtained from a batch generator
Parameters: - batchgen (MultiThreadedAugmenter) – generator holding the batches
- batchsize – artificial batch size: sampling is done with batch size 1 and the sampled data is stacked to match the artificial batch size (default: None)
Raises: NotImplementedError
– If not overwritten by subclass
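The artificial batch size can be pictured as regrouping individually sampled items. The helper below, `stack_to_batches`, is a hypothetical pure-Python sketch: it groups single samples into lists of the requested length and keeps a shorter final group, which is an assumption since the documentation does not say how a remainder is handled.

```python
def stack_to_batches(samples, batchsize):
    """Group single samples into batches of length batchsize.

    Hypothetical helper; the last batch may be shorter (assumption).
    """
    if batchsize < 1:
        raise ValueError("batchsize must be a positive integer")
    samples = list(samples)
    return [samples[i:i + batchsize] for i in range(0, len(samples), batchsize)]


# Five samples drawn one at a time, regrouped with an artificial batch size of 2.
single_samples = (index for index in range(5))
print(stack_to_batches(single_samples, batchsize=2))
```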
-
register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)
Registers a callback with the trainer
Parameters: callback (AbstractCallback) – the callback to register
Raises: AssertionError
– If callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]
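Read literally, the rule above accepts either a real AbstractCallback instance or any object that provides both hook methods. The sketch below illustrates that check with stand-in classes; `AbstractCallback` here is a local placeholder, `CallbackRegistry` is hypothetical, and the hook signatures are assumptions rather than the library's actual ones.

```python
class AbstractCallback:
    """Local placeholder for the library's callback base class."""

    def at_epoch_begin(self, trainer, **kwargs):
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


class CallbackRegistry:
    """Hypothetical holder that mimics the documented registration check."""

    REQUIRED_HOOKS = ("at_epoch_begin", "at_epoch_end")

    def __init__(self):
        self._callbacks = []

    def register_callback(self, callback):
        has_hooks = all(
            callable(getattr(callback, name, None)) for name in self.REQUIRED_HOOKS
        )
        # Accept a subclass instance OR any object with both hook methods.
        assert isinstance(callback, AbstractCallback) or has_hooks, (
            "callback must be an AbstractCallback or provide both hook methods"
        )
        self._callbacks.append(callback)


class DuckTypedCallback:
    """Not a subclass, but provides both hooks."""

    def at_epoch_begin(self, trainer, **kwargs):
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        return {}


registry = CallbackRegistry()
registry.register_callback(AbstractCallback())
registry.register_callback(DuckTypedCallback())
print(len(registry._callbacks))
```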
-
save_state(file_name, *args, **kwargs)
Saves the current state
Parameters: - file_name (str) – filename to save the state to
- *args – positional arguments
- **kwargs – keyword arguments
-
train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')
Defines a routine to train a specified number of epochs
Parameters: - num_epochs (int) – number of epochs to train
- datamgr_train (DataManager) – the datamanager holding the train data
- datamgr_valid (DataManager) – the datamanager holding the validation data (default: None)
- val_score_key (str) – the key specifying which metric to use for validation (default: None)
- val_score_mode (str) – key specifying what kind of validation score is best
Raises: NotImplementedError
– If not overwritten by subclass
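A subclass has to supply the hooks and the training loop itself. The toy class below is a hypothetical stand-in that only records the order in which the hooks are called; the call order shown is inferred from the method names and is an assumption, not a statement about how the concrete trainers are implemented.

```python
class ToyTrainer:
    """Hypothetical stand-in that records the order of hook calls."""

    def __init__(self):
        self.events = []

    def _at_training_begin(self):
        self.events.append("training_begin")

    def _at_epoch_begin(self, epoch):
        self.events.append("epoch_begin:%d" % epoch)

    def _train_single_epoch(self, batchgen, epoch):
        # A real trainer would iterate over batchgen here.
        self.events.append("train:%d" % epoch)

    def _at_epoch_end(self, epoch):
        self.events.append("epoch_end:%d" % epoch)

    def _at_training_end(self):
        self.events.append("training_end")
        return self

    def train(self, num_epochs):
        # Assumed order: training hook, then per-epoch hooks around each epoch.
        self._at_training_begin()
        for epoch in range(1, num_epochs + 1):
            self._at_epoch_begin(epoch)
            self._train_single_epoch(batchgen=None, epoch=epoch)
            self._at_epoch_end(epoch)
        return self._at_training_end()


trainer = ToyTrainer().train(num_epochs=2)
print(trainer.events)
```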
-
PyTorchNetworkTrainer
-
class PyTorchNetworkTrainer(network, save_path, criterions: dict, optimizer_cls, optimizer_params={}, metrics={}, lr_scheduler_cls=None, lr_scheduler_params={}, gpu_ids=[], save_freq=1, optim_fn=<function create_optims_default_pytorch>, fold=0, callbacks=[], start_epoch=1, mixed_precision=False, mixed_precision_kwargs={'allow_banned': False, 'enable_caching': True, 'verbose': False}, **kwargs)
Bases: delira.training.abstract_trainer.AbstractNetworkTrainer
Trains and validates a network
See also
AbstractNetwork
-
_at_epoch_begin(metrics_val, val_score_key, epoch, num_epochs, **kwargs)
Defines the behaviour at the beginning of each epoch: executes the at_epoch_begin method of every callback
-
_at_epoch_end(metrics_val, val_score_key, epoch, is_best, **kwargs)
Defines the behaviour at the end of each epoch: executes the at_epoch_end method of every callback and saves the current state if necessary
-
_at_training_begin(*args, **kwargs)
Defines the behaviour at the beginning of training
Parameters: - *args – positional arguments
- **kwargs – keyword arguments
-
_at_training_end()
Defines the behaviour at the end of training: loads the best model if available
Returns: best network
Return type: AbstractPyTorchNetwork
-
static _is_better_val_scores(old_val_score, new_val_score, mode='highest')
Checks whether the new validation score is better than the old one with respect to the optimization goal
Parameters: - old_val_score – old validation score
- new_val_score – new validation score
- mode (str) – specifies whether a higher or lower validation score is optimal; must be one of [‘highest’, ‘lowest’]
Returns: True if the new score is better, False otherwise
Return type: bool
-
_setup(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, gpu_ids, mixed_precision, mixed_precision_kwargs)
Defines the trainer's setup
Parameters: - network (AbstractPyTorchNetwork) – the network to train
- optim_fn (function) – creates a dictionary containing all necessary optimizers
- optimizer_cls (subclass of torch.optim.Optimizer) – optimizer class implementing the optimization algorithm of choice
- optimizer_params (dict) –
- lr_scheduler_cls (Any) – learning rate schedule class; must implement a step() method
- lr_scheduler_params (dict) – keyword arguments passed to the lr scheduler during construction
- gpu_ids (list) – list containing the ids of the GPUs to use; if empty, the CPU is used instead
- mixed_precision (bool) – whether to use mixed precision or not (default: False)
- mixed_precision_kwargs (dict) – additional keyword arguments for mixed precision
-
_train_single_epoch(batchgen: MultiThreadedAugmenter, epoch)
Trains the network for a single epoch
Parameters: - batchgen (MultiThreadedAugmenter) – Generator yielding the training batches
- epoch (int) – current epoch
-
_update_state(new_state)
Updates the state from a given new state
Parameters: new_state (dict) – new state to update internal state from
Returns: the trainer with a modified state
Return type: PyTorchNetworkTrainer
-
static load_state(file_name, **kwargs)
Loads the new state from file via delira.io.torch.load_checkpoint()
Parameters: - file_name (str) – the file to load the state from
- **kwargs – keyword arguments
Returns: new state
Return type: dict
-
predict(batchgen, batch_size=None)
Returns predictions from the network for batches from batchgen
Returns: - np.ndarray – predictions from batches
- list of np.ndarray – labels from batches
- dict – dictionary containing the mean validation metrics and the mean loss values
-
register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)
Registers a callback with the trainer
Parameters: callback (AbstractCallback) – the callback to register
Raises: AssertionError
– If callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]
-
save_state(file_name, epoch, **kwargs)
Saves the current state via delira.io.torch.save_checkpoint()
-
train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')
Trains the network
Parameters: - num_epochs (int) – number of epochs to train
- datamgr_train (BaseDataManager) – Data Manager to create Batch Generator for training
- datamgr_valid (BaseDataManager) – Data Manager to create Batch Generator for validation
- val_score_key (str) – Key of validation metric; must be key in self.metrics
- val_score_mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]
Returns: Best model (if val_score_key is not a valid key the model of the last epoch will be returned)
Return type: AbstractPyTorchNetwork
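The interplay of val_score_key and val_score_mode with the returned model can be sketched without any framework. The function `select_best_epoch` below is hypothetical: it takes a list of per-epoch metric dictionaries (an assumed data layout) and returns the index of the epoch whose model would count as best, falling back to the last epoch when the key is missing, as the return description above states.

```python
def select_best_epoch(metrics_per_epoch, val_score_key=None, val_score_mode="highest"):
    """Return the index of the best epoch, or of the last epoch if the key is unusable.

    Hypothetical helper; metrics_per_epoch is a list of dicts (assumed layout).
    """
    if not metrics_per_epoch:
        raise ValueError("at least one epoch is required")
    key_missing = val_score_key is None or any(
        val_score_key not in metrics for metrics in metrics_per_epoch
    )
    if key_missing:
        # Documented fallback: the model of the last epoch is returned.
        return len(metrics_per_epoch) - 1
    scores = [metrics[val_score_key] for metrics in metrics_per_epoch]
    best = max(scores) if val_score_mode == "highest" else min(scores)
    # On ties, the earliest epoch with the best score wins (assumption).
    return scores.index(best)


history = [
    {"accuracy": 0.70, "loss": 0.90},
    {"accuracy": 0.90, "loss": 0.40},
    {"accuracy": 0.80, "loss": 0.35},
]
print(select_best_epoch(history, "accuracy", "highest"))
print(select_best_epoch(history, "loss", "lowest"))
print(select_best_epoch(history, "f1", "highest"))
```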
-
TfNetworkTrainer
-
class TfNetworkTrainer(network, save_path, losses: dict, optimizer_cls, optimizer_params={}, metrics={}, lr_scheduler_cls=None, lr_scheduler_params={}, gpu_ids=[], save_freq=1, optim_fn=<function create_optims_default_tf>, fold=0, callbacks=[], start_epoch=1, **kwargs)
Bases: delira.training.abstract_trainer.AbstractNetworkTrainer
Trains and validates a network
See also
AbstractNetwork
-
_at_epoch_begin(metrics_val, val_score_key, epoch, num_epochs, **kwargs)
Defines the behaviour at the beginning of each epoch: executes the at_epoch_begin method of every callback
-
_at_epoch_end(metrics_val, val_score_key, epoch, is_best, **kwargs)
Defines the behaviour at the end of each epoch: executes the at_epoch_end method of every callback and saves the current state if necessary
-
_at_training_begin(*args, **kwargs)
Defines the behaviour at the beginning of training
Parameters: - *args – positional arguments
- **kwargs – keyword arguments
-
_at_training_end()
Defines the behaviour at the end of training: loads the best model if available
Returns: best network
Return type: AbstractTfNetwork
-
static _is_better_val_scores(old_val_score, new_val_score, mode='highest')
Checks whether the new validation score is better than the old one with respect to the optimization goal
Parameters: - old_val_score – old validation score
- new_val_score – new validation score
- mode (str) – specifies whether a higher or lower validation score is optimal; must be one of [‘highest’, ‘lowest’]
Returns: True if the new score is better, False otherwise
Return type: bool
-
_setup(network, optim_fn, optimizer_cls, optimizer_params, lr_scheduler_cls, lr_scheduler_params, gpu_ids)
Defines the trainer's setup
Parameters: - network (AbstractTfNetwork) – the network to train
- optim_fn (function) – creates a dictionary containing all necessary optimizers
- optimizer_cls (subclass of tf.train.Optimizer) – optimizer class implementing the optimization algorithm of choice
- optimizer_params (dict) –
- lr_scheduler_cls (Any) – learning rate schedule class: must implement step() method
- lr_scheduler_params (dict) – keyword arguments passed to lr scheduler during construction
- gpu_ids (list) – list containing ids of GPUs to use; if empty: use cpu instead
-
_train_single_epoch(batchgen: MultiThreadedAugmenter, epoch)
Trains the network for a single epoch
Parameters: - batchgen (MultiThreadedAugmenter) – Generator yielding the training batches
- epoch (int) – current epoch
-
_update_state(new_state)
Updates the state from a given new state
Parameters: new_state (dict) – new state to update internal state from
Returns: the trainer with a modified state
Return type: AbstractNetworkTrainer
-
load_state(file_name)
Loads the new state from file via delira.io.tf.load_checkpoint()
Parameters: file_name (str) – the file to load the state from
-
predict(batchgen, batch_size=None)
Returns predictions from the network for batches from batchgen
Returns: - np.ndarray – predictions from batches
- list of np.ndarray – labels from batches
- dict – dictionary containing the mean validation metrics and the mean loss values
-
register_callback(callback: delira.training.callbacks.abstract_callback.AbstractCallback)
Registers a callback with the trainer
Parameters: callback (AbstractCallback) – the callback to register
Raises: AssertionError
– If callback is not an instance of AbstractCallback and does not have both methods [‘at_epoch_begin’, ‘at_epoch_end’]
-
save_state(file_name)
Saves the current state via delira.io.tf.save_checkpoint()
Parameters: file_name (str) – filename to save the state to
-
train(num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest')
Trains the network
Parameters: - num_epochs (int) – number of epochs to train
- datamgr_train (BaseDataManager) – Data Manager to create Batch Generator for training
- datamgr_valid (BaseDataManager) – Data Manager to create Batch Generator for validation
- val_score_key (str) – Key of validation metric; must be key in self.metrics
- val_score_mode (str) – String to specify whether a higher or lower validation score is optimal; must be in [‘highest’, ‘lowest’]
Returns: Best model (if val_score_key is not a valid key the model of the last epoch will be returned)
Return type: AbstractTfNetwork