Experiments

Experiments are the outermost class for controlling your training: an experiment wraps your NetworkTrainer and provides utilities for cross-validation.

AbstractExperiment

class AbstractExperiment(n_epochs, *args, **kwargs)

Bases: trixi.experiment.Experiment

Abstract class representing a single experiment (must be subclassed for each backend).

kfold(num_epochs: int, data: delira.data_loading.data_manager.BaseDataManager, num_splits=None, shuffle=False, random_seed=None, train_kwargs={}, test_kwargs={}, **kwargs)

Runs K-fold cross-validation. The supported scenario is:

  • passing a single datamanager: the data within the single manager will be split and multiple datamanagers will be created holding the subsets.

Parameters
  • num_epochs (int) – number of epochs to train the model

  • data (single BaseDataManager) – single datamanager (will be split for crossvalidation)

  • num_splits (None or int) – number of splits for k-fold; if None, 10 splits are validated by default

  • shuffle (bool) – whether or not to shuffle the indices for k-fold

  • random_seed (None or int) – random seed used to seed the k-fold split (if shuffle is True) as well as PyTorch and NumPy

  • train_kwargs (dict) – keyword arguments to specify training behavior

  • test_kwargs (dict) – keyword arguments to specify testing behavior

  • **kwargs – additional keyword arguments (passed unchanged to self.run())

See also

BaseDataManager.update_state_from_dict for train_kwargs and test_kwargs
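As a rough illustration of how the samples of a single data manager can be divided into folds, the index bookkeeping can be sketched in plain Python. This is not delira's implementation: `kfold_indices` is a hypothetical helper that only mirrors the documented defaults (10 splits when num_splits is None, random_seed used only when shuffle is True); creating one data manager per subset and calling run() per fold are left out.

```python
import random
from typing import Iterator, List, Optional, Tuple


def kfold_indices(n_samples: int, num_splits: Optional[int] = None,
                  shuffle: bool = False, random_seed: Optional[int] = None
                  ) -> Iterator[Tuple[List[int], List[int]]]:
    """Hypothetical helper (not delira API): yield (train, test) indices per fold."""
    if num_splits is None:
        num_splits = 10  # documented default for kfold()
    if not 2 <= num_splits <= n_samples:
        raise ValueError("num_splits must be between 2 and n_samples")

    indices = list(range(n_samples))
    if shuffle:
        # The seed only has an effect when the indices are shuffled.
        random.Random(random_seed).shuffle(indices)

    # Spread samples evenly; the first (n_samples % num_splits) folds
    # receive one extra sample each.
    fold_sizes = [n_samples // num_splits] * num_splits
    for i in range(n_samples % num_splits):
        fold_sizes[i] += 1

    start = 0
    for size in fold_sizes:
        test_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, test_idx
        start += size
```

For example, `list(kfold_indices(25, num_splits=5))` yields five (train, test) pairs with 20 and 5 indices respectively, and every index appears in exactly one test fold.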

static load(file_name)

Loads the whole experiment.

Parameters

file_name (str) – name of the file to load the experiment from

Raises

NotImplementedError – if not overwritten in subclass

run(train_data: delira.data_loading.data_manager.BaseDataManager, val_data: Optional[delira.data_loading.data_manager.BaseDataManager] = None, params: Optional[delira.training.parameters.Parameters] = None, **kwargs)

Trains a single model.

Parameters
  • train_data (BaseDataManager) – data manager containing the training data

  • val_data (BaseDataManager) – data manager containing the validation data

  • params (Parameters, optional) – class containing all parameters (defaults to None). If not specified, the parameters given during class initialization are used

Raises

NotImplementedError – If not overwritten in subclass

save()

Saves the whole experiment.

Raises

NotImplementedError – If not overwritten in subclass

setup(*args, **kwargs)

Abstract method to set up an AbstractNetworkTrainer.

Raises

NotImplementedError – if not overwritten by subclass
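Because setup(), run(), save() and load() raise NotImplementedError in the base class, each backend has to supply its own versions. The sketch below mirrors only that contract in plain Python; `ExperimentSketch` and `EchoExperiment` are hypothetical classes, not part of delira, and the toy run() records its inputs instead of training anything.

```python
from typing import Optional


class ExperimentSketch:
    """Hypothetical stand-in for the AbstractExperiment contract.

    The base class leaves setup/run/save/load unimplemented; a
    backend-specific subclass must override them.
    """

    def __init__(self, n_epochs: int, params: Optional[dict] = None):
        self.n_epochs = n_epochs
        self.params = params

    def setup(self, *args, **kwargs):
        raise NotImplementedError

    def run(self, train_data, val_data=None, params=None, **kwargs):
        raise NotImplementedError

    def save(self):
        raise NotImplementedError

    @staticmethod
    def load(file_name: str):
        raise NotImplementedError


class EchoExperiment(ExperimentSketch):
    """Hypothetical toy 'backend' that records what run() received."""

    def setup(self, params, **kwargs):
        # A real backend would build its network trainer from params here.
        return {"params": params, **kwargs}

    def run(self, train_data, val_data=None, params=None, **kwargs):
        # Fall back to the parameters given at construction time.
        params = params if params is not None else self.params
        trainer = self.setup(params, **kwargs)
        trainer["n_train"] = len(train_data)
        return trainer
```

The design point is that generic logic living in the base class (such as the cross-validation helpers, which forward **kwargs to self.run()) only needs the overridden methods to work.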

stratified_kfold(num_epochs: int, data: delira.data_loading.data_manager.BaseDataManager, num_splits=None, shuffle=False, random_seed=None, label_key='label', train_kwargs={}, test_kwargs={}, **kwargs)

Runs stratified K-fold cross-validation. The supported scenario is:

  • passing a single datamanager: the data within the single manager will be split and multiple datamanagers will be created holding the subsets.

Parameters
  • num_epochs (int) – number of epochs to train the model

  • data (BaseDataManager) – single datamanager (will be split for crossvalidation)

  • num_splits (None or int) – number of splits for k-fold; if None, 10 splits are validated by default

  • shuffle (bool) – whether or not to shuffle the indices for k-fold

  • random_seed (None or int) – random seed used to seed the k-fold split (if shuffle is True) as well as PyTorch and NumPy

  • label_key (str (default: "label")) – the key used to extract the stratification label from each data sample

  • train_kwargs (dict) – keyword arguments to specify training behavior

  • test_kwargs (dict) – keyword arguments to specify testing behavior

  • **kwargs – additional keyword arguments (passed unchanged to self.run())

See also

BaseDataManager.update_state_from_dict for train_kwargs and test_kwargs
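Stratification keeps the label distribution of each fold close to that of the whole dataset. A minimal sketch, assuming each sample is a dict from which the label is read via label_key: `stratified_kfold_indices` is a hypothetical helper (not delira's code) that deals each label's indices round-robin across the folds.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Optional, Sequence, Tuple


def stratified_kfold_indices(samples: Sequence[Dict], label_key: str = "label",
                             num_splits: Optional[int] = None,
                             shuffle: bool = False,
                             random_seed: Optional[int] = None
                             ) -> Iterator[Tuple[List[int], List[int]]]:
    """Hypothetical helper (not delira API): stratified (train, test) indices."""
    if num_splits is None:
        num_splits = 10  # documented default
    rng = random.Random(random_seed)

    # Group sample indices by their label.
    by_label = defaultdict(list)
    for idx, sample in enumerate(samples):
        by_label[sample[label_key]].append(idx)

    # Deal each label's indices round-robin across the folds, continuing
    # where the previous label stopped so fold sizes stay balanced.
    folds: List[List[int]] = [[] for _ in range(num_splits)]
    offset = 0
    for label in sorted(by_label):
        idxs = by_label[label]
        if shuffle:
            rng.shuffle(idxs)
        for j, idx in enumerate(idxs):
            folds[(offset + j) % num_splits].append(idx)
        offset += len(idxs)

    for k in range(num_splits):
        test_idx = sorted(folds[k])
        test_set = set(test_idx)
        train_idx = [i for i in range(len(samples)) if i not in test_set]
        yield train_idx, test_idx
```

With 8 samples labelled 0 and 4 labelled 1, `num_splits=4` gives test folds that each contain two samples of label 0 and one of label 1.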

stratified_kfold_predict(num_epochs: int, data: delira.data_loading.data_manager.BaseDataManager, split_val=0.2, num_splits=None, shuffle=False, random_seed=None, label_key='label', train_kwargs={}, test_kwargs={}, **kwargs)

Runs stratified K-fold cross-validation. The supported scenario is:

  • passing a single datamanager: the data within the single manager will be split and multiple datamanagers will be created holding the subsets.

Parameters
  • num_epochs (int) – number of epochs to train the model

  • data (BaseDataManager) – single datamanager (will be split for crossvalidation)

  • num_splits (None or int) – number of splits for k-fold; if None, 10 splits are validated by default

  • shuffle (bool) – whether or not to shuffle the indices for k-fold

  • random_seed (None or int) – random seed used to seed the k-fold split (if shuffle is True) as well as PyTorch and NumPy

  • label_key (str (default: "label")) – the key used to extract the stratification label from each data sample

  • train_kwargs (dict) – keyword arguments to specify training behavior

  • test_kwargs (dict) – keyword arguments to specify testing behavior

  • **kwargs – additional keyword arguments (passed unchanged to self.run())

See also

BaseDataManager.update_state_from_dict for train_kwargs and test_kwargs

test(params: delira.training.parameters.Parameters, network: delira.models.abstract_network.AbstractNetwork, datamgr_test: delira.data_loading.data_manager.BaseDataManager, trainer_cls=<class 'delira.training.abstract_trainer.AbstractNetworkTrainer'>, **kwargs)

Executes prediction for all items in datamgr_test with network

Parameters
  • params (Parameters) – the parameters to construct a model

  • network (AbstractNetwork) – the network used for prediction

  • datamgr_test (BaseDataManager) – holds the test data

  • trainer_cls – class defining the actual trainer; defaults to AbstractNetworkTrainer, which should be suitable for most cases but can be exchanged if necessary

  • **kwargs – holds additional keyword arguments (passed unchanged to the trainer's __init__)

Returns

  • np.ndarray – predictions from batches

  • list of np.ndarray – labels from batches

  • dict – dictionary containing the mean validation metrics and the mean loss values
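The three return values can be pictured as the result of combining per-batch outputs. The sketch assumes, without confirmation from the docstring, that predictions are concatenated across batches, labels stay grouped per batch, and each metric is averaged with equal weight per batch; `aggregate_batches` is hypothetical and plain lists stand in for np.ndarray.

```python
from statistics import mean
from typing import Dict, List, Tuple


def aggregate_batches(batch_preds: List[List[float]],
                      batch_labels: List[List[int]],
                      batch_metrics: List[Dict[str, float]]
                      ) -> Tuple[List[float], List[List[int]], Dict[str, float]]:
    """Hypothetical helper: shape per-batch outputs like test()'s return values."""
    # Predictions from all batches, concatenated into one flat sequence.
    predictions = [p for batch in batch_preds for p in batch]
    # Labels stay grouped per batch ("list of np.ndarray").
    labels = list(batch_labels)
    # Every metric and loss value is averaged over the batches.
    keys = batch_metrics[0].keys() if batch_metrics else []
    mean_metrics = {k: mean(m[k] for m in batch_metrics) for k in keys}
    return predictions, labels, mean_metrics
```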

PyTorchExperiment

class PyTorchExperiment(params: delira.training.parameters.Parameters, model_cls: delira.models.abstract_network.AbstractPyTorchNetwork, name=None, save_path=None, val_score_key=None, optim_builder=<function create_optims_default_pytorch>, checkpoint_freq=1, trainer_cls=<class 'delira.training.pytorch_trainer.PyTorchNetworkTrainer'>, **kwargs)

Bases: delira.training.experiment.AbstractExperiment

Single experiment for the PyTorch backend.

kfold(num_epochs: int, data: Union[List[delira.data_loading.data_manager.BaseDataManager], delira.data_loading.data_manager.BaseDataManager], num_splits=None, shuffle=False, random_seed=None, train_kwargs={}, test_kwargs={}, **kwargs)

Runs K-fold cross-validation. The supported scenario is:

  • passing a single datamanager: the data within the single manager will be split and multiple datamanagers will be created holding the subsets.

Parameters
  • num_epochs (int) – number of epochs to train the model

  • data (single BaseDataManager) – single datamanager (will be split for crossvalidation)

  • num_splits (None or int) – number of splits for k-fold; if None, 10 splits are validated by default

  • shuffle (bool) – whether or not to shuffle the indices for k-fold

  • random_seed (None or int) – random seed used to seed the k-fold split (if shuffle is True) as well as PyTorch and NumPy

  • train_kwargs (dict) – keyword arguments to specify training behavior

  • test_kwargs (dict) – keyword arguments to specify testing behavior

  • **kwargs – additional keyword arguments (passed unchanged to self.run())

See also

BaseDataManager.update_state_from_dict for train_kwargs and test_kwargs

static load(file_name)

Loads the whole experiment.

Parameters

file_name (str) – name of the file to load the experiment from

run(train_data: delira.data_loading.data_manager.BaseDataManager, val_data: Optional[delira.data_loading.data_manager.BaseDataManager], params: Optional[delira.training.parameters.Parameters] = None, **kwargs)

Trains a single model.

Parameters
  • train_data (BaseDataManager) – holds the training set

  • val_data (BaseDataManager or None) – holds the validation set (if None, the model will not be validated)

  • params (Parameters) – the parameters to construct a model and network trainer

  • **kwargs – holds additional keyword arguments (passed unchanged to the trainer's __init__)

Returns

the trainer of the trained network

Return type

AbstractNetworkTrainer

Raises

ValueError – if the class has no attribute params and no parameters were given as a function argument
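The documented ValueError follows from how run() resolves its parameters: an explicit params argument wins, otherwise the instance's params attribute is used. A minimal sketch of that fallback, assuming an attribute set to None counts as missing; `ParamsFallbackSketch` and `resolve_params` are hypothetical names, not delira API.

```python
from typing import Any, Optional


class ParamsFallbackSketch:
    """Hypothetical illustration of run()'s parameter fallback."""

    def resolve_params(self, params: Optional[Any] = None) -> Any:
        # Explicit argument first, then the attribute set at construction.
        if params is None:
            params = getattr(self, "params", None)
        if params is None:
            raise ValueError("No parameters given and the experiment has "
                             "no attribute 'params'")
        return params
```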

save()

Saves the whole experiment.

setup(params: delira.training.parameters.Parameters, **kwargs)

Performs setup of the network trainer.

Parameters
  • params (Parameters) – the parameters to construct a model and network trainer

  • **kwargs – keyword arguments

stratified_kfold(num_epochs: int, data: delira.data_loading.data_manager.BaseDataManager, num_splits=None, shuffle=False, random_seed=None, label_key='label', train_kwargs={}, test_kwargs={}, **kwargs)

Runs stratified K-fold cross-validation. The supported scenario is:

  • passing a single datamanager: the data within the single manager will be split and multiple datamanagers will be created holding the subsets.

Parameters
  • num_epochs (int) – number of epochs to train the model

  • data (BaseDataManager) – single datamanager (will be split for crossvalidation)

  • num_splits (None or int) – number of splits for k-fold; if None, 10 splits are validated by default

  • shuffle (bool) – whether or not to shuffle the indices for k-fold

  • random_seed (None or int) – random seed used to seed the k-fold split (if shuffle is True) as well as PyTorch and NumPy

  • label_key (str (default: "label")) – the key used to extract the stratification label from each data sample

  • train_kwargs (dict) – keyword arguments to specify training behavior

  • test_kwargs (dict) – keyword arguments to specify testing behavior

  • **kwargs – additional keyword arguments (passed unchanged to self.run())

See also

BaseDataManager.update_state_from_dict for train_kwargs and test_kwargs

stratified_kfold_predict(num_epochs: int, data: delira.data_loading.data_manager.BaseDataManager, split_val=0.2, num_splits=None, shuffle=False, random_seed=None, label_key='label', train_kwargs={}, test_kwargs={}, **kwargs)

Runs stratified K-fold cross-validation. The supported scenario is:

  • passing a single datamanager: the data within the single manager will be split and multiple datamanagers will be created holding the subsets.

Parameters
  • num_epochs (int) – number of epochs to train the model

  • data (BaseDataManager) – single datamanager (will be split for crossvalidation)

  • num_splits (None or int) – number of splits for k-fold; if None, 10 splits are validated by default

  • shuffle (bool) – whether or not to shuffle the indices for k-fold

  • random_seed (None or int) – random seed used to seed the k-fold split (if shuffle is True) as well as PyTorch and NumPy

  • label_key (str (default: "label")) – the key used to extract the stratification label from each data sample

  • train_kwargs (dict) – keyword arguments to specify training behavior

  • test_kwargs (dict) – keyword arguments to specify testing behavior

  • **kwargs – additional keyword arguments (passed unchanged to self.run())

See also

BaseDataManager.update_state_from_dict for train_kwargs and test_kwargs

test(params: delira.training.parameters.Parameters, network: delira.models.abstract_network.AbstractPyTorchNetwork, datamgr_test: delira.data_loading.data_manager.BaseDataManager, **kwargs)

Executes prediction for all items in datamgr_test with network

Parameters
  • params (Parameters) – the parameters to construct a model

  • network (AbstractPyTorchNetwork) – the network used for prediction

  • datamgr_test (BaseDataManager) – holds the test data

  • **kwargs – holds additional keyword arguments (passed unchanged to the trainer's __init__)

Returns

  • np.ndarray – predictions from batches

  • list of np.ndarray – labels from batches

  • dict – dictionary containing the mean validation metrics and the mean loss values

TfExperiment

class TfExperiment(params: Union[delira.training.parameters.Parameters, str], model_cls: delira.models.abstract_network.AbstractTfNetwork, name=None, save_path=None, val_score_key=None, optim_builder=<function create_optims_default_tf>, checkpoint_freq=1, trainer_cls=<class 'delira.training.tf_trainer.TfNetworkTrainer'>, **kwargs)

Bases: delira.training.experiment.AbstractExperiment

Single experiment for the TensorFlow backend.

kfold(num_epochs: int, data: Union[List[delira.data_loading.data_manager.BaseDataManager], delira.data_loading.data_manager.BaseDataManager], num_splits=None, shuffle=False, random_seed=None, train_kwargs={}, test_kwargs={}, **kwargs)

Runs K-fold cross-validation. The supported scenario is:

  • passing a single datamanager: the data within the single manager will be split and multiple datamanagers will be created holding the subsets.

Parameters
  • num_epochs (int) – number of epochs to train the model

  • data (single BaseDataManager) – single datamanager (will be split for crossvalidation)

  • num_splits (None or int) – number of splits for k-fold; if None, 10 splits are validated by default

  • shuffle (bool) – whether or not to shuffle the indices for k-fold

  • random_seed (None or int) – random seed used to seed the k-fold split (if shuffle is True) as well as PyTorch and NumPy

  • train_kwargs (dict) – keyword arguments to specify training behavior

  • test_kwargs (dict) – keyword arguments to specify testing behavior

  • **kwargs – additional keyword arguments (passed unchanged to self.run())

See also

BaseDataManager.update_state_from_dict for train_kwargs and test_kwargs

static load(file_name)

Loads the whole experiment.

Parameters

file_name (str) – name of the file to load the experiment from

run(train_data: delira.data_loading.data_manager.BaseDataManager, val_data: Optional[delira.data_loading.data_manager.BaseDataManager], params: Optional[delira.training.parameters.Parameters] = None, **kwargs)

Trains a single model.

Parameters
  • train_data (BaseDataManager) – holds the training set

  • val_data (BaseDataManager or None) – holds the validation set (if None, the model will not be validated)

  • params (Parameters) – the parameters to construct a model and network trainer

  • **kwargs – holds additional keyword arguments (passed unchanged to the trainer's __init__)

Returns

the trainer of the trained network

Return type

AbstractNetworkTrainer

Raises

ValueError – if the class has no attribute params and no parameters were given as a function argument

save()

Saves the whole experiment.

setup(params: delira.training.parameters.Parameters, **kwargs)

Performs setup of the network trainer.

Parameters
  • params (Parameters) – the parameters to construct a model and network trainer

  • **kwargs – keyword arguments

stratified_kfold(num_epochs: int, data: delira.data_loading.data_manager.BaseDataManager, num_splits=None, shuffle=False, random_seed=None, label_key='label', train_kwargs={}, test_kwargs={}, **kwargs)

Runs stratified K-fold cross-validation. The supported scenario is:

  • passing a single datamanager: the data within the single manager will be split and multiple datamanagers will be created holding the subsets.

Parameters
  • num_epochs (int) – number of epochs to train the model

  • data (BaseDataManager) – single datamanager (will be split for crossvalidation)

  • num_splits (None or int) – number of splits for k-fold; if None, 10 splits are validated by default

  • shuffle (bool) – whether or not to shuffle the indices for k-fold

  • random_seed (None or int) – random seed used to seed the k-fold split (if shuffle is True) as well as PyTorch and NumPy

  • label_key (str (default: "label")) – the key used to extract the stratification label from each data sample

  • train_kwargs (dict) – keyword arguments to specify training behavior

  • test_kwargs (dict) – keyword arguments to specify testing behavior

  • **kwargs – additional keyword arguments (passed unchanged to self.run())

See also

BaseDataManager.update_state_from_dict for train_kwargs and test_kwargs

stratified_kfold_predict(num_epochs: int, data: delira.data_loading.data_manager.BaseDataManager, split_val=0.2, num_splits=None, shuffle=False, random_seed=None, label_key='label', train_kwargs={}, test_kwargs={}, **kwargs)

Runs stratified K-fold cross-validation. The supported scenario is:

  • passing a single datamanager: the data within the single manager will be split and multiple datamanagers will be created holding the subsets.

Parameters
  • num_epochs (int) – number of epochs to train the model

  • data (BaseDataManager) – single datamanager (will be split for crossvalidation)

  • num_splits (None or int) – number of splits for k-fold; if None, 10 splits are validated by default

  • shuffle (bool) – whether or not to shuffle the indices for k-fold

  • random_seed (None or int) – random seed used to seed the k-fold split (if shuffle is True) as well as PyTorch and NumPy

  • label_key (str (default: "label")) – the key used to extract the stratification label from each data sample

  • train_kwargs (dict) – keyword arguments to specify training behavior

  • test_kwargs (dict) – keyword arguments to specify testing behavior

  • **kwargs – additional keyword arguments (passed unchanged to self.run())

See also

BaseDataManager.update_state_from_dict for train_kwargs and test_kwargs

test(params: delira.training.parameters.Parameters, network: delira.models.abstract_network.AbstractNetwork, datamgr_test: delira.data_loading.data_manager.BaseDataManager, **kwargs)

Executes prediction for all items in datamgr_test with network

Parameters
  • params (Parameters) – the parameters to construct a model

  • network (AbstractNetwork) – the network used for prediction

  • datamgr_test (BaseDataManager) – holds the test data

  • **kwargs – holds additional keyword arguments (passed unchanged to the trainer's __init__)

Returns

  • np.ndarray – predictions from batches

  • list of np.ndarray – labels from batches

  • dict – dictionary containing the mean validation metrics and the mean loss values