Experiments
An experiment is the outermost class for controlling your training: it wraps your NetworkTrainer and provides utilities for cross-validation.
AbstractExperiment
class AbstractExperiment(n_epochs, *args, **kwargs)
Bases: sphinx.ext.autodoc.importer._MockObject
Abstract class representing a single experiment (must be subclassed for each backend).
kfold(num_epochs: int, data: List[delira.data_loading.data_manager.BaseDataManager], num_splits=None, shuffle=False, random_seed=None, **kwargs)
Runs K-fold cross-validation.
Parameters:
- num_epochs (int) – number of epochs to train the model
- data (list of BaseDataManager) – list of data managers (will be split for cross-validation)
- num_splits (None or int) – number of splits for k-fold; if None, len(data) splits will be validated
- shuffle (bool) – whether to shuffle the indices for k-fold
- random_seed (None or int) – random seed used to seed the k-fold split (if shuffle is True), PyTorch and NumPy
- **kwargs – additional keyword arguments (passed unchanged to self.run())
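The splitting step behind kfold() can be sketched as follows. This is a hypothetical helper, not delira's implementation: it assumes the items to split are addressed by integer indices, treats num_splits=None as one split per item (the "len(data) splits" case above), and uses the same fold layout as scikit-learn's KFold, where the first n % k folds get one extra item.

```python
import random


def kfold_indices(n_items, num_splits=None, shuffle=False, random_seed=None):
    """Yield (train_idx, val_idx) pairs for K-fold cross-validation.

    Hypothetical helper: num_splits=None means one split per item.
    """
    if num_splits is None:
        num_splits = n_items
    if not 2 <= num_splits <= n_items:
        raise ValueError("num_splits must be between 2 and the number of items")

    indices = list(range(n_items))
    if shuffle:
        # The seed only matters when shuffling, as in the parameter docs.
        random.Random(random_seed).shuffle(indices)

    # Spread the remainder over the first folds so sizes differ by at most one.
    fold_sizes = [n_items // num_splits + (1 if i < n_items % num_splits else 0)
                  for i in range(num_splits)]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


for train_idx, val_idx in kfold_indices(5, num_splits=2):
    print(train_idx, val_idx)
# [3, 4] [0, 1, 2]
# [0, 1, 2] [3, 4]
```

Each index lands in exactly one validation fold, so every item is validated once across the K runs.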
static load(file_name)
Loads a whole experiment.
Parameters: file_name (str) – name of the file to load the experiment from
Raises: NotImplementedError – if not overwritten in subclass
run(train_data: Union[delira.data_loading.data_manager.BaseDataManager, delira.data_loading.data_manager.ConcatDataManager], val_data: Union[delira.data_loading.data_manager.BaseDataManager, delira.data_loading.data_manager.ConcatDataManager, None] = None, params: Optional[delira.training.parameters.Parameters] = None, **kwargs)
Trains a single model.
Parameters:
- train_data (BaseDataManager or ConcatDataManager) – data manager containing the training data
- val_data (BaseDataManager or ConcatDataManager, optional) – data manager containing the validation data
- params (Parameters, optional) – class containing all parameters (defaults to None); if not specified, the parameters fall back to the ones given during class initialization
Raises: NotImplementedError – if not overwritten in subclass
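The params fallback described for run() amounts to a single None check. A minimal sketch, assuming init-time parameters are stored on self.params; the class name is hypothetical and the method returns the resolved parameters instead of training anything.

```python
class ParamsFallbackExperiment:
    """Hypothetical experiment illustrating the params fallback of run()."""

    def __init__(self, params=None):
        # Parameters given at construction time act as the default.
        self.params = params

    def run(self, train_data, val_data=None, params=None, **kwargs):
        # Explicitly passed params win; otherwise fall back to init-time params.
        if params is None:
            params = self.params
        # A real run() would now build and train a model; this sketch just
        # returns the parameters it resolved.
        return params


exp = ParamsFallbackExperiment(params={"lr": 1e-3})
print(exp.run(train_data=None))                       # {'lr': 0.001}
print(exp.run(train_data=None, params={"lr": 1e-2}))  # {'lr': 0.01}
```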
save()
Saves the whole experiment.
Raises: NotImplementedError – if not overwritten in subclass
setup(*args, **kwargs)
Abstract method to set up an AbstractNetworkTrainer.
Raises: NotImplementedError – if not overwritten by subclass
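The "must be subclassed for each backend" contract follows a common Python pattern: the base class defines every method but raises NotImplementedError, and a backend subclass overrides what it supports. The sketch below uses hypothetical names (BaseExperimentSketch, EchoExperiment) and plain methods rather than abc.abstractmethod, so the base class can still be instantiated and the error only appears when an unimplemented method is called.

```python
class BaseExperimentSketch:
    """Hypothetical base class mirroring the abstract interface above."""

    def setup(self, *args, **kwargs):
        raise NotImplementedError()

    def run(self, train_data, val_data=None, params=None, **kwargs):
        raise NotImplementedError()

    def save(self):
        raise NotImplementedError()

    @staticmethod
    def load(file_name):
        raise NotImplementedError()


class EchoExperiment(BaseExperimentSketch):
    """Hypothetical backend subclass that overrides only setup() and run()."""

    def setup(self, *args, **kwargs):
        return {"args": args, "kwargs": kwargs}

    def run(self, train_data, val_data=None, params=None, **kwargs):
        return self.setup(train_data, params=params)


exp = EchoExperiment()
print(exp.run([1, 2]))  # {'args': ([1, 2],), 'kwargs': {'params': None}}
try:
    exp.save()
except NotImplementedError:
    print("save() must be overridden by the backend subclass")
```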
PyTorchExperiment
class PyTorchExperiment(params: delira.training.parameters.Parameters, model_cls: delira.models.abstract_network.AbstractPyTorchNetwork, name=None, save_path=None, val_score_key=None, optim_builder=<function create_optims_default_pytorch>, checkpoint_freq=1, trainer_cls=<class 'delira.training.pytorch_trainer.PyTorchNetworkTrainer'>, **kwargs)
Bases: delira.training.experiment.AbstractExperiment
Single experiment for the PyTorch backend.
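The constructor takes optim_builder and trainer_cls as swappable callables with defaults. The sketch below shows that injection pattern with hypothetical stand-ins (ToyModel, default_optim_builder, SketchTrainer, SketchExperiment); it does not use create_optims_default_pytorch or PyTorchNetworkTrainer, and it only passes checkpoint_freq through without assuming what the trainer does with it.

```python
class ToyModel:
    """Hypothetical model exposing a parameters() method."""

    def parameters(self):
        return ["weight", "bias"]


def default_optim_builder(model_params, lr=0.1):
    # Hypothetical stand-in for create_optims_default_pytorch: returns a dict
    # of "optimizers" described as plain tuples.
    return {"default": ("SGD", tuple(model_params), lr)}


class SketchTrainer:
    """Hypothetical trainer that only stores what it was given."""

    def __init__(self, model, optimizers, checkpoint_freq):
        self.model = model
        self.optimizers = optimizers
        self.checkpoint_freq = checkpoint_freq


class SketchExperiment:
    def __init__(self, model_cls, optim_builder=default_optim_builder,
                 checkpoint_freq=1, trainer_cls=SketchTrainer):
        # Store the callables only; nothing is built until setup() runs.
        self.model_cls = model_cls
        self.optim_builder = optim_builder
        self.checkpoint_freq = checkpoint_freq
        self.trainer_cls = trainer_cls

    def setup(self):
        model = self.model_cls()
        optimizers = self.optim_builder(model.parameters())
        return self.trainer_cls(model, optimizers, self.checkpoint_freq)


default_trainer = SketchExperiment(ToyModel).setup()
print(default_trainer.optimizers)  # {'default': ('SGD', ('weight', 'bias'), 0.1)}

custom = SketchExperiment(ToyModel,
                          optim_builder=lambda p: {"default": ("Adam", tuple(p), 1e-3)},
                          checkpoint_freq=5)
print(custom.setup().optimizers["default"][0])  # Adam
```

Passing classes and functions instead of instances keeps construction deferred, so each setup() call (for example, once per k-fold split) gets a fresh model, optimizer set and trainer.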
kfold(num_epochs: int, data: List[delira.data_loading.data_manager.BaseDataManager], num_splits=None, shuffle=False, random_seed=None, **kwargs)
Runs K-fold cross-validation.
Parameters:
- num_epochs (int) – number of epochs to train the model
- data (list of BaseDataManager) – list of data managers (will be split for cross-validation)
- num_splits (None or int) – number of splits for k-fold; if None, len(data) splits will be validated
- shuffle (bool) – whether to shuffle the indices for k-fold
- random_seed (None or int) – random seed used to seed the k-fold split (if shuffle is True), PyTorch and NumPy
- **kwargs – additional keyword arguments (passed unchanged to self.run())
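Seeding "the k-fold split, PyTorch and NumPy" from one random_seed usually means calling each library's global seeding function. A minimal sketch using random.seed, numpy.random.seed and torch.manual_seed; the helper name seed_everything is hypothetical, and NumPy and PyTorch are skipped when they are not installed.

```python
import random


def seed_everything(random_seed):
    """Hypothetical helper: seed Python, and NumPy/PyTorch if installed."""
    random.seed(random_seed)
    try:
        import numpy as np
        np.random.seed(random_seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(random_seed)
    except ImportError:
        pass


seed_everything(42)
a = random.random()
seed_everything(42)
b = random.random()
print(a == b)  # True
```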
static load(file_name)
Loads a whole experiment.
Parameters: file_name (str) – name of the file to load the experiment from
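The save()/load() pair can be illustrated as a round trip through a file. This sketch assumes, purely for illustration, that the experiment's state is a few picklable attributes written to <save_path>/<name>.pkl; the file layout and the class PicklableExperiment are hypothetical and say nothing about the format PyTorchExperiment actually writes.

```python
import os
import pickle
import tempfile


class PicklableExperiment:
    """Hypothetical experiment whose state is a few plain attributes."""

    def __init__(self, name, save_path, params):
        self.name = name
        self.save_path = save_path
        self.params = params

    def save(self):
        # Assumed layout: <save_path>/<name>.pkl holding the attribute dict.
        os.makedirs(self.save_path, exist_ok=True)
        file_name = os.path.join(self.save_path, self.name + ".pkl")
        with open(file_name, "wb") as f:
            pickle.dump(vars(self), f)
        return file_name

    @staticmethod
    def load(file_name):
        with open(file_name, "rb") as f:
            state = pickle.load(f)
        # Rebuild the object without calling __init__, then restore its state.
        experiment = PicklableExperiment.__new__(PicklableExperiment)
        experiment.__dict__.update(state)
        return experiment


with tempfile.TemporaryDirectory() as tmp:
    path = PicklableExperiment("demo", tmp, {"lr": 0.01}).save()
    restored = PicklableExperiment.load(path)
    print(restored.name, restored.params)  # demo {'lr': 0.01}
```

Making load() a static method means it can be called on the class itself, before any experiment instance exists.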
run(train_data: Union[delira.data_loading.data_manager.BaseDataManager, delira.data_loading.data_manager.ConcatDataManager], val_data: Union[delira.data_loading.data_manager.BaseDataManager, delira.data_loading.data_manager.ConcatDataManager, None], params: Optional[delira.training.parameters.Parameters] = None, **kwargs)
Trains a single model.
Parameters:
- train_data (BaseDataManager or ConcatDataManager) – holds the training set
- val_data (BaseDataManager or ConcatDataManager or None) – holds the validation set (if None, the model will not be validated)
- params (Parameters, optional) – the parameters to construct a model and network trainer
- **kwargs – additional keyword arguments (passed unchanged to the trainer's __init__)
Returns: trainer of the trained network
Raises: ValueError – if the class has no attribute params and no parameters were given as a function argument
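The PyTorch run() contract combines three behaviours: fall back to stored params, raise ValueError when there are none, and forward **kwargs to the trainer's __init__ before returning the trained trainer. A minimal sketch of that control flow with hypothetical classes (RecordingTrainer, RunSketchExperiment); the train() method name is an assumption of the sketch, and no actual training happens.

```python
class RecordingTrainer:
    """Hypothetical trainer that records how it was built and what it trained on."""

    def __init__(self, params, **kwargs):
        self.params = params
        self.init_kwargs = kwargs
        self.trained_on = None

    def train(self, train_data, val_data):
        # A real trainer would run the training loop here.
        self.trained_on = (train_data, val_data)
        return self


class RunSketchExperiment:
    def __init__(self, params=None):
        self.params = params

    def setup(self, params, **kwargs):
        # Every extra keyword argument is forwarded to the trainer's __init__.
        return RecordingTrainer(params, **kwargs)

    def run(self, train_data, val_data, params=None, **kwargs):
        if params is None:
            # Fall back to stored params; a missing attribute counts as None.
            params = getattr(self, "params", None)
            if params is None:
                raise ValueError("no params given and the experiment has none stored")
        trainer = self.setup(params, **kwargs)
        return trainer.train(train_data, val_data)


trainer = RunSketchExperiment({"lr": 0.01}).run([1, 2, 3], None, verbose=True)
print(trainer.init_kwargs)  # {'verbose': True}

try:
    RunSketchExperiment().run([1, 2, 3], None)
except ValueError as err:
    print("ValueError:", err)
```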
setup(params: delira.training.parameters.Parameters, **kwargs)
Performs setup of the network trainer.
Parameters:
- params (Parameters) – the parameters to construct a model and network trainer
- **kwargs – keyword arguments