Source code for delira.training.experiment


import typing
import logging
import pickle
import os
from datetime import datetime
from functools import partial

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, \
    StratifiedShuffleSplit, ShuffleSplit

from delira import get_backends

from ..data_loading import BaseDataManager
from ..models import AbstractNetwork

from .parameters import Parameters
from .base_trainer import BaseNetworkTrainer
from .predictor import Predictor

logger = logging.getLogger(__name__)


class BaseExperiment(object):
    """
    Baseclass for Experiments.

    Implements:

    * Setup-Behavior for Models, Trainers and Predictors (depending on
      train and test case)

    * The K-Fold logic (including stratified and random splitting)

    * Argument Handling

    """

    def __init__(self,
                 params: typing.Union[str, Parameters],
                 model_cls: AbstractNetwork,
                 n_epochs=None,
                 name=None,
                 save_path=None,
                 key_mapping=None,
                 val_score_key=None,
                 optim_builder=None,
                 checkpoint_freq=1,
                 trainer_cls=BaseNetworkTrainer,
                 predictor_cls=Predictor,
                 **kwargs):
        """
        Parameters
        ----------
        params : :class:`Parameters` or str
            the training parameters; if a string is passed, it is treated
            as a path to a pickle file, from which the parameters are
            loaded
        model_cls : subclass of :class:`AbstractNetwork`
            the class implementing the model to train
        n_epochs : int or None
            the number of epochs to train; if None: can be specified later
            during actual training
        name : str or None
            the experiment's name
        save_path : str or None
            the path to save the results and checkpoints to;
            if None: the current working directory will be used
        key_mapping : dict
            mapping between data_dict and model inputs (necessary for
            prediction with the :class:`Predictor`-API)
        val_score_key : str or None
            key defining which metric to use for validation (determining
            the best model and scheduling the lr);
            if None: no validation-based operations will be done (the
            model might still get validated, but validation metrics can
            only be logged and not used further)
        optim_builder : function
            function returning a dict of backend-specific optimizers
        checkpoint_freq : int
            frequency of saving checkpoints (1 denotes saving every epoch,
            2 denotes saving every second epoch etc.); default: 1
        trainer_cls : subclass of :class:`BaseNetworkTrainer`
            the trainer class to use for training the model
        predictor_cls : subclass of :class:`Predictor`
            the predictor class to use for testing the model
        **kwargs :
            additional keyword arguments

        """
        # params could also be a file containing a pickled instance of
        # parameters
        if isinstance(params, str):
            with open(params, "rb") as f:
                params = pickle.load(f)

        if n_epochs is None:
            n_epochs = params.nested_get("n_epochs",
                                         params.nested_get("num_epochs"))

        self.n_epochs = n_epochs

        if name is None:
            name = "UnnamedExperiment"
        self.name = name

        if save_path is None:
            save_path = os.path.abspath(".")

        self.save_path = os.path.join(save_path, name,
                                      str(datetime.now().strftime(
                                          "%y-%m-%d_%H-%M-%S")))

        if os.path.isdir(self.save_path):
            logger.warning("Save Path %s already exists", self.save_path)

        os.makedirs(self.save_path, exist_ok=True)

        self.trainer_cls = trainer_cls
        self.predictor_cls = predictor_cls

        if val_score_key is None:
            if params.nested_get("val_metrics", False):
                val_score_key = sorted(
                    params.nested_get("val_metrics").keys())[0]

        self.val_score_key = val_score_key

        assert key_mapping is not None
        self.key_mapping = key_mapping
        self.params = params
        self.model_cls = model_cls

        self._optim_builder = optim_builder
        self.checkpoint_freq = checkpoint_freq
        self._run = 0

        self.kwargs = kwargs
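
    # Usage sketch (illustrative only): ``MyNetwork``, ``my_params`` and
    # the data managers below are hypothetical user-defined objects, not
    # part of this module.
    #
    # >>> exp = BaseExperiment(my_params, MyNetwork,
    # ...                      name="demo_experiment",
    # ...                      key_mapping={"x": "data"},
    # ...                      val_score_key="val_loss")
    # >>> best_model = exp.run(train_datamanager, val_datamanager)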
    def setup(self, params, training=True, **kwargs):
        """
        Defines the setup behavior (model, trainer etc.) for training and
        testing case

        Parameters
        ----------
        params : :class:`Parameters`
            the parameters to use for setup
        training : bool
            whether to setup for training case or for testing case
        **kwargs :
            additional keyword arguments

        Returns
        -------
        :class:`BaseNetworkTrainer`
            the created trainer (if ``training=True``)
        :class:`Predictor`
            the created predictor (if ``training=False``)

        See Also
        --------

        * :meth:`BaseExperiment._setup_training` for training setup

        * :meth:`BaseExperiment._setup_test` for test setup

        """
        if training:
            return self._setup_training(params, **kwargs)

        return self._setup_test(params, **kwargs)
    def _setup_training(self, params, **kwargs):
        """
        Handles the setup for training case

        Parameters
        ----------
        params : :class:`Parameters`
            the parameters containing the model and training kwargs
        **kwargs :
            additional keyword arguments

        Returns
        -------
        :class:`BaseNetworkTrainer`
            the created trainer

        """
        model_params = params.permute_training_on_top().model

        model_kwargs = {**model_params.fixed, **model_params.variable}

        model = self.model_cls(**model_kwargs)

        training_params = params.permute_training_on_top().training
        losses = training_params.nested_get("losses")
        optimizer_cls = training_params.nested_get("optimizer_cls")
        optimizer_params = training_params.nested_get("optimizer_params")
        train_metrics = training_params.nested_get("train_metrics", {})
        lr_scheduler_cls = training_params.nested_get("lr_sched_cls", None)
        lr_scheduler_params = training_params.nested_get("lr_sched_params",
                                                         {})
        val_metrics = training_params.nested_get("val_metrics", {})

        # necessary for resuming training from a given path
        save_path = kwargs.pop("save_path", os.path.join(
            self.save_path,
            "checkpoints",
            "run_%02d" % self._run))

        return self.trainer_cls(
            network=model,
            save_path=save_path,
            losses=losses,
            key_mapping=self.key_mapping,
            optimizer_cls=optimizer_cls,
            optimizer_params=optimizer_params,
            train_metrics=train_metrics,
            val_metrics=val_metrics,
            lr_scheduler_cls=lr_scheduler_cls,
            lr_scheduler_params=lr_scheduler_params,
            optim_fn=self._optim_builder,
            save_freq=self.checkpoint_freq,
            **kwargs
        )
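
    # The method above expects ``params`` to follow the nested
    # ``Parameters`` layout with a ``model`` and a ``training`` branch,
    # matching the ``nested_get`` lookups. A minimal sketch (the concrete
    # loss, metric and optimizer objects are placeholders):
    #
    # >>> params = Parameters(fixed_params={
    # ...     "model": {"in_channels": 1, "n_outputs": 10},
    # ...     "training": {
    # ...         "losses": {"ce": some_loss},
    # ...         "optimizer_cls": some_optimizer_class,
    # ...         "optimizer_params": {"lr": 1e-3},
    # ...         "num_epochs": 5,
    # ...         "val_metrics": {"mae": some_metric}}})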
    def _setup_test(self, params, model, convert_batch_to_npy_fn,
                    prepare_batch_fn, **kwargs):
        """
        Handles the setup for testing case

        Parameters
        ----------
        params : :class:`Parameters`
            the parameters containing the model and training kwargs
            (ignored here, just passed for subclassing and unified API)
        model : :class:`AbstractNetwork`
            the model to test
        convert_batch_to_npy_fn : function
            function to convert a batch of tensors to numpy
        prepare_batch_fn : function
            function to convert a batch-dict to a format accepted by the
            model. This conversion typically includes dtype-conversion,
            reshaping, wrapping to backend-specific tensors and pushing to
            the correct devices
        **kwargs :
            additional keyword arguments

        Returns
        -------
        :class:`Predictor`
            the created predictor

        """
        predictor = self.predictor_cls(
            model=model, key_mapping=self.key_mapping,
            convert_batch_to_npy_fn=convert_batch_to_npy_fn,
            prepare_batch_fn=prepare_batch_fn, **kwargs)
        return predictor
    def run(self, train_data: BaseDataManager,
            val_data: BaseDataManager = None,
            params: Parameters = None,
            **kwargs):
        """
        Setup and run training

        Parameters
        ----------
        train_data : :class:`BaseDataManager`
            the data to use for training
        val_data : :class:`BaseDataManager` or None
            the data to use for validation (no validation is done if
            passing None); default: None
        params : :class:`Parameters` or None
            the parameters to use for training and model instantiation
            (will be merged with ``self.params``)
        **kwargs :
            additional keyword arguments

        Returns
        -------
        :class:`AbstractNetwork`
            the trained network returned by the trainer (usually the best
            network)

        See Also
        --------
        :class:`BaseNetworkTrainer` for training itself

        """
        params = self._resolve_params(params)
        kwargs = self._resolve_kwargs(kwargs)

        params.permute_training_on_top()

        training_params = params.training

        trainer = self.setup(params, training=True, **kwargs)

        self._run += 1

        num_epochs = kwargs.get("num_epochs", training_params.nested_get(
            "num_epochs", self.n_epochs))

        if num_epochs is None:
            num_epochs = self.n_epochs

        return trainer.train(num_epochs, train_data, val_data,
                             self.val_score_key,
                             kwargs.get("val_score_mode", "lowest"))
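
    # A typical call (sketch; the data managers are assumed to be set up
    # elsewhere). ``val_score_mode`` is forwarded to the trainer and
    # decides whether lower or higher validation scores count as better:
    #
    # >>> best_model = exp.run(train_datamanager, val_datamanager,
    # ...                      num_epochs=10, val_score_mode="lowest")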
    def resume(self, save_path: str,
               train_data: BaseDataManager,
               val_data: BaseDataManager = None,
               params: Parameters = None,
               **kwargs):
        """
        Resumes a previous training by passing an explicit ``save_path``
        instead of generating a new one

        Parameters
        ----------
        save_path : str
            path to the previous training
        train_data : :class:`BaseDataManager`
            the data to use for training
        val_data : :class:`BaseDataManager` or None
            the data to use for validation (no validation is done if
            passing None); default: None
        params : :class:`Parameters` or None
            the parameters to use for training and model instantiation
            (will be merged with ``self.params``)
        **kwargs :
            additional keyword arguments

        Returns
        -------
        :class:`AbstractNetwork`
            the trained network returned by the trainer (usually the best
            network)

        See Also
        --------
        :class:`BaseNetworkTrainer` for training itself

        """
        return self.run(train_data=train_data, val_data=val_data,
                        params=params, save_path=save_path, **kwargs)
    def test(self, network, test_data: BaseDataManager, metrics: dict,
             metric_keys=None, verbose=False, prepare_batch=lambda x: x,
             convert_fn=lambda x: x, **kwargs):
        """
        Setup and run testing on a given network

        Parameters
        ----------
        network : :class:`AbstractNetwork`
            the (trained) network to test
        test_data : :class:`BaseDataManager`
            the data to use for testing
        metrics : dict
            the metrics to calculate
        metric_keys : dict of tuples
            the batch_dict keys to use for each metric to calculate.
            Should contain a value for each key in ``metrics``.
            If no values are given for a key, per default ``pred`` and
            ``label`` will be used for metric calculation
        verbose : bool
            verbosity of the test process
        prepare_batch : function
            function to convert a batch-dict to a format accepted by the
            model. This conversion typically includes dtype-conversion,
            reshaping, wrapping to backend-specific tensors and pushing to
            the correct devices
        convert_fn : function
            function to convert a batch of tensors to numpy
        **kwargs :
            additional keyword arguments

        Returns
        -------
        dict
            all predictions obtained by feeding the ``test_data`` through
            the ``network``
        dict
            all metrics calculated upon the ``test_data`` and the obtained
            predictions

        """
        kwargs = self._resolve_kwargs(kwargs)

        predictor = self.setup(None, training=False, model=network,
                               convert_batch_to_npy_fn=convert_fn,
                               prepare_batch_fn=prepare_batch, **kwargs)

        # return first item of generator
        return next(predictor.predict_data_mgr_cache_all(
            test_data, 1, metrics, metric_keys, verbose))
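
    # Testing sketch: ``metric_keys`` maps each metric name to the
    # batch-dict keys used as its inputs; without an entry, "pred" and
    # "label" are used. ``trained_model`` and ``mean_absolute_error`` are
    # hypothetical:
    #
    # >>> preds, vals = exp.test(
    # ...     trained_model, test_datamanager,
    # ...     metrics={"mae": mean_absolute_error},
    # ...     metric_keys={"mae": ("pred", "label")},
    # ...     verbose=True)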
    def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
              num_splits=None, shuffle=False, random_seed=None,
              split_type="random", val_split=0.2, label_key="label",
              train_kwargs: dict = None, metric_keys: dict = None,
              test_kwargs: dict = None, params=None, verbose=False,
              **kwargs):
        """
        Performs a k-Fold cross-validation

        Parameters
        ----------
        data : :class:`BaseDataManager`
            the data to use for training(, validation) and testing. Will
            be split based on ``split_type`` and ``val_split``
        metrics : dict
            dictionary containing the metrics to evaluate during k-fold
        num_epochs : int or None
            number of epochs to train (if not given, will either be
            extracted from ``params``, ``self.params`` or
            ``self.n_epochs``)
        num_splits : int or None
            the number of splits to extract from ``data``.
            If None: uses a default of 10
        shuffle : bool
            whether to shuffle the data before splitting or not
            (implemented by index-shuffling rather than actual
            data-shuffling to retain the potentially lazy behavior of
            datasets)
        random_seed : int or None
            seed for numpy, the splitting functions and the used
            backend-framework
        split_type : str
            must be one of ['random', 'stratified'];
            if 'random': uses random data splitting;
            if 'stratified': uses stratified data splitting.
            Stratification will be based on ``label_key``
        val_split : float or None
            the fraction of the train data to use as validation set.
            If None: no validation will be done during training; only
            testing for each fold after the training is complete
        label_key : str
            the label to use for stratification. Will be ignored unless
            ``split_type`` is 'stratified'. Default: 'label'
        train_kwargs : dict or None
            kwargs to update the behavior of the :class:`BaseDataManager`
            containing the train data. If None: an empty dict will be
            passed
        metric_keys : dict of tuples
            the batch_dict keys to use for each metric to calculate.
            Should contain a value for each key in ``metrics``.
            If no values are given for a key, per default ``pred`` and
            ``label`` will be used for metric calculation
        test_kwargs : dict or None
            kwargs to update the behavior of the :class:`BaseDataManager`
            containing the test and validation data.
            If None: an empty dict will be passed
        params : :class:`Parameters` or None
            the training and model parameters (will be merged with
            ``self.params``)
        verbose : bool
            verbosity
        **kwargs :
            additional keyword arguments

        Returns
        -------
        dict
            all predictions from all folds
        dict
            all metric values from all folds

        Raises
        ------
        ValueError
            if ``split_type`` is neither 'random' nor 'stratified'

        See Also
        --------

        * :class:`sklearn.model_selection.KFold`
          and :class:`sklearn.model_selection.ShuffleSplit`
          for random data-splitting

        * :class:`sklearn.model_selection.StratifiedKFold`
          and :class:`sklearn.model_selection.StratifiedShuffleSplit`
          for stratified data-splitting

        * :meth:`BaseDataManager.update_state_from_dict` for updating the
          data managers by kwargs

        * :meth:`BaseExperiment.run` for the training

        * :meth:`BaseExperiment.test` for the testing

        Notes
        -----
        using stratified splits may be slow during split-calculation,
        since each item must be loaded once to obtain the labels necessary
        for stratification.
""" # set number of splits if not specified if num_splits is None: num_splits = 10 logger.warning("num_splits not defined, using default value of \ 10 splits instead ") metrics_test, outputs = {}, {} split_idxs = list(range(len(data.dataset))) if train_kwargs is None: train_kwargs = {} if test_kwargs is None: test_kwargs = {} # switch between differnt kfold types if split_type == "random": split_cls = KFold val_split_cls = ShuffleSplit # split_labels are ignored for random splitting, set them to # split_idxs just ensures same length split_labels = split_idxs elif split_type == "stratified": split_cls = StratifiedKFold val_split_cls = StratifiedShuffleSplit # iterate over dataset to get labels for stratified splitting split_labels = [data.dataset[_idx][label_key] for _idx in split_idxs] else: raise ValueError("split_type must be one of " "['random', 'stratified'], but got: %s" % str(split_type)) fold = split_cls(n_splits=num_splits, shuffle=shuffle, random_state=random_seed) if random_seed is not None: np.random.seed(random_seed) # iterate over folds for idx, (train_idxs, test_idxs) in enumerate( fold.split(split_idxs, split_labels)): # extract data from single manager train_data = data.get_subset(train_idxs) test_data = data.get_subset(test_idxs) train_data.update_state_from_dict(train_kwargs) test_data.update_state_from_dict(test_kwargs) val_data = None if val_split is not None: if split_type == "random": # split_labels are ignored for random splitting, set them # to split_idxs just ensures same length train_labels = train_idxs elif split_type == "stratified": # iterate over dataset to get labels for stratified # splitting train_labels = [train_data.dataset[_idx][label_key] for _idx in train_idxs] else: raise ValueError("split_type must be one of " "['random', 'stratified'], but got: %s" % str(split_type)) _val_split = val_split_cls(n_splits=1, test_size=val_split, random_state=random_seed) for _train_idxs, _val_idxs in _val_split.split(train_idxs, train_labels): val_data = train_data.get_subset(_val_idxs) val_data.update_state_from_dict(test_kwargs) train_data = train_data.get_subset(_train_idxs) model = self.run(train_data=train_data, val_data=val_data, params=params, num_epochs=num_epochs, fold=idx, **kwargs) _outputs, _metrics_test = self.test(model, test_data, metrics=metrics, metric_keys=metric_keys, verbose=verbose) outputs[str(idx)] = _outputs metrics_test[str(idx)] = _metrics_test return outputs, metrics_test
    def __str__(self):
        """
        Converts :class:`BaseExperiment` to string representation

        Returns
        -------
        str
            representation of class

        """
        s = "Experiment:\n"
        for k, v in vars(self).items():
            s += "\t{} = {}\n".format(k, v)
        return s

    def __call__(self, *args, **kwargs):
        """
        Call :meth:`BaseExperiment.run`

        Parameters
        ----------
        *args :
            positional arguments
        **kwargs :
            keyword arguments

        Returns
        -------
        :class:`AbstractNetwork`
            the trained network returned by :meth:`BaseExperiment.run`

        """
        return self.run(*args, **kwargs)
    def save(self):
        """
        Saves the whole experiment

        """
        with open(os.path.join(self.save_path, "experiment.delira.pkl"),
                  "wb") as f:
            pickle.dump(self, f)

        self.params.save(os.path.join(self.save_path, "parameters"))
    @staticmethod
    def load(file_name):
        """
        Loads whole experiment

        Parameters
        ----------
        file_name : str
            file_name to load the experiment from

        Returns
        -------
        :class:`BaseExperiment`
            the loaded experiment

        """
        with open(file_name, "rb") as f:
            return pickle.load(f)
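
    # Save/load round trip (sketch): ``save`` pickles the experiment into
    # its ``save_path``, ``load`` restores it from that file:
    #
    # >>> exp.save()
    # >>> restored = BaseExperiment.load(
    # ...     os.path.join(exp.save_path, "experiment.delira.pkl"))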
    def _resolve_params(self, params: typing.Union[Parameters, None]):
        """
        Merges the given ``params`` with ``self.params``. If the same
        argument is present in both, the one from the currently given
        parameters is used here

        Parameters
        ----------
        params : :class:`Parameters` or None
            the parameters to merge with ``self.params``

        Returns
        -------
        :class:`Parameters`
            the merged parameter instance

        """
        if params is None:
            params = Parameters()

        if hasattr(self, "params") and isinstance(self.params, Parameters):
            _params = params
            params = self.params
            params.update(_params)

        return params
    def _resolve_kwargs(self, kwargs: typing.Union[dict, None]):
        """
        Merges the given ``kwargs`` with ``self.kwargs``. If the same
        argument is present in both, the one from the given kwargs is used
        here

        Parameters
        ----------
        kwargs : dict
            the given kwargs to merge with ``self.kwargs``

        Returns
        -------
        dict
            the merged kwargs

        """
        if kwargs is None:
            kwargs = {}

        if hasattr(self, "kwargs") and isinstance(self.kwargs, dict):
            _kwargs = kwargs
            kwargs = self.kwargs
            kwargs.update(_kwargs)

        return kwargs
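
    # Merge precedence (sketch): values passed at call time override the
    # ones stored on the experiment:
    #
    # >>> exp.kwargs = {"val_score_mode": "lowest", "fold": 0}
    # >>> exp._resolve_kwargs({"fold": 3})
    # {'val_score_mode': 'lowest', 'fold': 3}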
    def __getstate__(self):
        return vars(self)

    def __setstate__(self, state):
        vars(self).update(state)
if "TORCH" in get_backends(): from .train_utils import create_optims_default_pytorch, \ convert_torch_tensor_to_npy from .pytorch_trainer import PyTorchNetworkTrainer as PTNetworkTrainer from ..models import AbstractPyTorchNetwork import torch
    class PyTorchExperiment(BaseExperiment):
        def __init__(self,
                     params: typing.Union[str, Parameters],
                     model_cls: AbstractPyTorchNetwork,
                     n_epochs=None,
                     name=None,
                     save_path=None,
                     key_mapping=None,
                     val_score_key=None,
                     optim_builder=create_optims_default_pytorch,
                     checkpoint_freq=1,
                     trainer_cls=PTNetworkTrainer,
                     **kwargs):
            """
            Parameters
            ----------
            params : :class:`Parameters` or str
                the training parameters; if a string is passed, it is
                treated as a path to a pickle file, from which the
                parameters are loaded
            model_cls : subclass of :class:`AbstractPyTorchNetwork`
                the class implementing the model to train
            n_epochs : int or None
                the number of epochs to train; if None: can be specified
                later during actual training
            name : str or None
                the experiment's name
            save_path : str or None
                the path to save the results and checkpoints to;
                if None: the current working directory will be used
            key_mapping : dict
                mapping between data_dict and model inputs (necessary for
                prediction with the :class:`Predictor`-API); if no
                key_mapping is given, a default of {"x": "data"} will be
                used here
            val_score_key : str or None
                key defining which metric to use for validation
                (determining the best model and scheduling the lr);
                if None: no validation-based operations will be done (the
                model might still get validated, but validation metrics
                can only be logged and not used further)
            optim_builder : function
                function returning a dict of backend-specific optimizers;
                defaults to :func:`create_optims_default_pytorch`
            checkpoint_freq : int
                frequency of saving checkpoints (1 denotes saving every
                epoch, 2 denotes saving every second epoch etc.);
                default: 1
            trainer_cls : subclass of :class:`PyTorchNetworkTrainer`
                the trainer class to use for training the model;
                defaults to :class:`PyTorchNetworkTrainer`
            **kwargs :
                additional keyword arguments

            """
            if key_mapping is None:
                key_mapping = {"x": "data"}

            super().__init__(params=params,
                             model_cls=model_cls,
                             n_epochs=n_epochs,
                             name=name,
                             save_path=save_path,
                             key_mapping=key_mapping,
                             val_score_key=val_score_key,
                             optim_builder=optim_builder,
                             checkpoint_freq=checkpoint_freq,
                             trainer_cls=trainer_cls,
                             **kwargs)
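
        # Usage sketch (hypothetical model and data): behaves like
        # ``BaseExperiment``, but defaults to the PyTorch trainer and a
        # key_mapping of {"x": "data"}:
        #
        # >>> exp = PyTorchExperiment(my_params, MyTorchNetwork,
        # ...                         name="torch_demo")
        # >>> best_model = exp.run(train_datamanager, val_datamanager)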
        def kfold(self, data: BaseDataManager, metrics: dict,
                  num_epochs=None, num_splits=None, shuffle=False,
                  random_seed=None, split_type="random", val_split=0.2,
                  label_key="label", train_kwargs: dict = None,
                  test_kwargs: dict = None, metric_keys: dict = None,
                  params=None, verbose=False, **kwargs):
            """
            Performs a k-Fold cross-validation

            Parameters
            ----------
            data : :class:`BaseDataManager`
                the data to use for training(, validation) and testing.
                Will be split based on ``split_type`` and ``val_split``
            metrics : dict
                dictionary containing the metrics to evaluate during
                k-fold
            num_epochs : int or None
                number of epochs to train (if not given, will either be
                extracted from ``params``, ``self.params`` or
                ``self.n_epochs``)
            num_splits : int or None
                the number of splits to extract from ``data``.
                If None: uses a default of 10
            shuffle : bool
                whether to shuffle the data before splitting or not
                (implemented by index-shuffling rather than actual
                data-shuffling to retain the potentially lazy behavior of
                datasets)
            random_seed : int or None
                seed for numpy, the splitting functions and the used
                backend-framework
            split_type : str
                must be one of ['random', 'stratified'];
                if 'random': uses random data splitting;
                if 'stratified': uses stratified data splitting.
                Stratification will be based on ``label_key``
            val_split : float or None
                the fraction of the train data to use as validation set.
                If None: no validation will be done during training; only
                testing for each fold after the training is complete
            label_key : str
                the label to use for stratification. Will be ignored
                unless ``split_type`` is 'stratified'. Default: 'label'
            train_kwargs : dict or None
                kwargs to update the behavior of the
                :class:`BaseDataManager` containing the train data.
                If None: an empty dict will be passed
            metric_keys : dict of tuples
                the batch_dict keys to use for each metric to calculate.
                Should contain a value for each key in ``metrics``.
                If no values are given for a key, per default ``pred`` and
                ``label`` will be used for metric calculation
            test_kwargs : dict or None
                kwargs to update the behavior of the
                :class:`BaseDataManager` containing the test and
                validation data. If None: an empty dict will be passed
            params : :class:`Parameters` or None
                the training and model parameters (will be merged with
                ``self.params``)
            verbose : bool
                verbosity
            **kwargs :
                additional keyword arguments

            Returns
            -------
            dict
                all predictions from all folds
            dict
                all metric values from all folds

            Raises
            ------
            ValueError
                if ``split_type`` is neither 'random' nor 'stratified'

            See Also
            --------

            * :class:`sklearn.model_selection.KFold`
              and :class:`sklearn.model_selection.ShuffleSplit`
              for random data-splitting

            * :class:`sklearn.model_selection.StratifiedKFold`
              and :class:`sklearn.model_selection.StratifiedShuffleSplit`
              for stratified data-splitting

            * :meth:`BaseDataManager.update_state_from_dict` for updating
              the data managers by kwargs

            * :meth:`BaseExperiment.run` for the training

            * :meth:`BaseExperiment.test` for the testing

            Notes
            -----
            using stratified splits may be slow during split-calculation,
            since each item must be loaded once to obtain the labels
            necessary for stratification.

            """
            # seed torch backend
            if random_seed is not None:
                torch.manual_seed(random_seed)

            return super().kfold(
                data=data, metrics=metrics, num_epochs=num_epochs,
                num_splits=num_splits, shuffle=shuffle,
                random_seed=random_seed, split_type=split_type,
                val_split=val_split, label_key=label_key,
                train_kwargs=train_kwargs, test_kwargs=test_kwargs,
                metric_keys=metric_keys, params=params,
                verbose=verbose, **kwargs)
        def test(self, network, test_data: BaseDataManager, metrics: dict,
                 metric_keys=None, verbose=False, prepare_batch=None,
                 convert_fn=None, **kwargs):
            """
            Setup and run testing on a given network

            Parameters
            ----------
            network : :class:`AbstractNetwork`
                the (trained) network to test
            test_data : :class:`BaseDataManager`
                the data to use for testing
            metrics : dict
                the metrics to calculate
            metric_keys : dict of tuples
                the batch_dict keys to use for each metric to calculate.
                Should contain a value for each key in ``metrics``.
                If no values are given for a key, per default ``pred`` and
                ``label`` will be used for metric calculation
            verbose : bool
                verbosity of the test process
            prepare_batch : function
                function to convert a batch-dict to a format accepted by
                the model. This conversion typically includes
                dtype-conversion, reshaping, wrapping to backend-specific
                tensors and pushing to the correct devices.
                If not further specified, the ``network``'s
                ``prepare_batch`` is used with the network's own device as
                input and output device
            convert_fn : function
                function to convert a batch of tensors to numpy;
                if not specified, defaults to
                :func:`convert_torch_tensor_to_npy`
            **kwargs :
                additional keyword arguments

            Returns
            -------
            dict
                all predictions obtained by feeding the ``test_data``
                through the ``network``
            dict
                all metrics calculated upon the ``test_data`` and the
                obtained predictions

            """
            # use backend-specific and model-specific prepare_batch fn
            # (runs on the same device as the passed network per default)
            device = next(network.parameters()).device
            if prepare_batch is None:
                prepare_batch = partial(network.prepare_batch,
                                        input_device=device,
                                        output_device=device)

            # switch to backend-specific convert function
            if convert_fn is None:
                convert_fn = convert_torch_tensor_to_npy

            return super().test(network=network, test_data=test_data,
                                metrics=metrics, metric_keys=metric_keys,
                                verbose=verbose,
                                prepare_batch=prepare_batch,
                                convert_fn=convert_fn, **kwargs)
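
        # Sketch: by default the batch is prepared on the network's own
        # device; an explicit prepare_batch can override this, e.g. to
        # test on the CPU (``trained_model`` is hypothetical):
        #
        # >>> preds, vals = exp.test(
        # ...     trained_model, test_datamanager,
        # ...     metrics={"mae": mean_absolute_error},
        # ...     prepare_batch=partial(trained_model.prepare_batch,
        # ...                           input_device="cpu",
        # ...                           output_device="cpu"))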
if "TF" in get_backends(): from .tf_trainer import TfNetworkTrainer from .train_utils import create_optims_default_tf, \ convert_tf_tensor_to_npy, initialize_uninitialized from ..models import AbstractTfNetwork from .parameters import Parameters import tensorflow as tf
    class TfExperiment(BaseExperiment):
        def __init__(self,
                     params: typing.Union[str, Parameters],
                     model_cls: AbstractTfNetwork,
                     n_epochs=None,
                     name=None,
                     save_path=None,
                     key_mapping=None,
                     val_score_key=None,
                     optim_builder=create_optims_default_tf,
                     checkpoint_freq=1,
                     trainer_cls=TfNetworkTrainer,
                     **kwargs):
            """
            Parameters
            ----------
            params : :class:`Parameters` or str
                the training parameters; if a string is passed, it is
                treated as a path to a pickle file, from which the
                parameters are loaded
            model_cls : subclass of :class:`AbstractTfNetwork`
                the class implementing the model to train
            n_epochs : int or None
                the number of epochs to train; if None: can be specified
                later during actual training
            name : str or None
                the experiment's name
            save_path : str or None
                the path to save the results and checkpoints to;
                if None: the current working directory will be used
            key_mapping : dict
                mapping between data_dict and model inputs (necessary for
                prediction with the :class:`Predictor`-API); if no
                key_mapping is given, a default of {"images": "data"} will
                be used here
            val_score_key : str or None
                key defining which metric to use for validation
                (determining the best model and scheduling the lr);
                if None: no validation-based operations will be done (the
                model might still get validated, but validation metrics
                can only be logged and not used further)
            optim_builder : function
                function returning a dict of backend-specific optimizers;
                defaults to :func:`create_optims_default_tf`
            checkpoint_freq : int
                frequency of saving checkpoints (1 denotes saving every
                epoch, 2 denotes saving every second epoch etc.);
                default: 1
            trainer_cls : subclass of :class:`TfNetworkTrainer`
                the trainer class to use for training the model;
                defaults to :class:`TfNetworkTrainer`
            **kwargs :
                additional keyword arguments

            """
            if key_mapping is None:
                key_mapping = {"images": "data"}

            super().__init__(params=params,
                             model_cls=model_cls,
                             n_epochs=n_epochs,
                             name=name,
                             save_path=save_path,
                             key_mapping=key_mapping,
                             val_score_key=val_score_key,
                             optim_builder=optim_builder,
                             checkpoint_freq=checkpoint_freq,
                             trainer_cls=trainer_cls,
                             **kwargs)
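
        # Usage sketch (hypothetical model and data): defaults to the TF
        # trainer and a key_mapping of {"images": "data"}:
        #
        # >>> exp = TfExperiment(my_params, MyTfNetwork, name="tf_demo")
        # >>> best_model = exp.run(train_datamanager, val_datamanager)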
        def setup(self, params, training=True, **kwargs):
            """
            Defines the setup behavior (model, trainer etc.) for training
            and testing case

            Parameters
            ----------
            params : :class:`Parameters`
                the parameters to use for setup
            training : bool
                whether to setup for training case or for testing case
            **kwargs :
                additional keyword arguments

            Returns
            -------
            :class:`BaseNetworkTrainer`
                the created trainer (if ``training=True``)
            :class:`Predictor`
                the created predictor (if ``training=False``)

            See Also
            --------

            * :meth:`BaseExperiment._setup_training` for training setup

            * :meth:`BaseExperiment._setup_test` for test setup

            """
            # the TF backend builds a global default graph; reset it
            # before each setup to avoid collisions between consecutive
            # runs
            tf.reset_default_graph()

            return super().setup(params=params, training=training,
                                 **kwargs)
        def kfold(self, data: BaseDataManager, metrics: dict,
                  num_epochs=None, num_splits=None, shuffle=False,
                  random_seed=None, split_type="random", val_split=0.2,
                  label_key="label", train_kwargs: dict = None,
                  test_kwargs: dict = None, metric_keys: dict = None,
                  params=None, verbose=False, **kwargs):
            """
            Performs a k-Fold cross-validation

            Parameters
            ----------
            data : :class:`BaseDataManager`
                the data to use for training(, validation) and testing.
                Will be split based on ``split_type`` and ``val_split``
            metrics : dict
                dictionary containing the metrics to evaluate during
                k-fold
            num_epochs : int or None
                number of epochs to train (if not given, will either be
                extracted from ``params``, ``self.params`` or
                ``self.n_epochs``)
            num_splits : int or None
                the number of splits to extract from ``data``.
                If None: uses a default of 10
            shuffle : bool
                whether to shuffle the data before splitting or not
                (implemented by index-shuffling rather than actual
                data-shuffling to retain the potentially lazy behavior of
                datasets)
            random_seed : int or None
                seed for numpy, the splitting functions and the used
                backend-framework
            split_type : str
                must be one of ['random', 'stratified'];
                if 'random': uses random data splitting;
                if 'stratified': uses stratified data splitting.
                Stratification will be based on ``label_key``
            val_split : float or None
                the fraction of the train data to use as validation set.
                If None: no validation will be done during training; only
                testing for each fold after the training is complete
            label_key : str
                the label to use for stratification. Will be ignored
                unless ``split_type`` is 'stratified'. Default: 'label'
            train_kwargs : dict or None
                kwargs to update the behavior of the
                :class:`BaseDataManager` containing the train data.
                If None: an empty dict will be passed
            metric_keys : dict of tuples
                the batch_dict keys to use for each metric to calculate.
                Should contain a value for each key in ``metrics``.
                If no values are given for a key, per default ``pred`` and
                ``label`` will be used for metric calculation
            test_kwargs : dict or None
                kwargs to update the behavior of the
                :class:`BaseDataManager` containing the test and
                validation data. If None: an empty dict will be passed
            params : :class:`Parameters` or None
                the training and model parameters (will be merged with
                ``self.params``)
            verbose : bool
                verbosity
            **kwargs :
                additional keyword arguments

            Returns
            -------
            dict
                all predictions from all folds
            dict
                all metric values from all folds

            Raises
            ------
            ValueError
                if ``split_type`` is neither 'random' nor 'stratified'

            See Also
            --------

            * :class:`sklearn.model_selection.KFold`
              and :class:`sklearn.model_selection.ShuffleSplit`
              for random data-splitting

            * :class:`sklearn.model_selection.StratifiedKFold`
              and :class:`sklearn.model_selection.StratifiedShuffleSplit`
              for stratified data-splitting

            * :meth:`BaseDataManager.update_state_from_dict` for updating
              the data managers by kwargs

            * :meth:`BaseExperiment.run` for the training

            * :meth:`BaseExperiment.test` for the testing

            Notes
            -----
            using stratified splits may be slow during split-calculation,
            since each item must be loaded once to obtain the labels
            necessary for stratification.

            """
            # seed tf backend
            if random_seed is not None:
                tf.set_random_seed(random_seed)

            return super().kfold(
                data=data, metrics=metrics, num_epochs=num_epochs,
                num_splits=num_splits, shuffle=shuffle,
                random_seed=random_seed, split_type=split_type,
                val_split=val_split, label_key=label_key,
                train_kwargs=train_kwargs, test_kwargs=test_kwargs,
                metric_keys=metric_keys, params=params,
                verbose=verbose, **kwargs)
        def test(self, network, test_data: BaseDataManager, metrics: dict,
                 metric_keys=None, verbose=False,
                 prepare_batch=lambda x: x, convert_fn=None, **kwargs):
            """
            Setup and run testing on a given network

            Parameters
            ----------
            network : :class:`AbstractNetwork`
                the (trained) network to test
            test_data : :class:`BaseDataManager`
                the data to use for testing
            metrics : dict
                the metrics to calculate
            metric_keys : dict of tuples
                the batch_dict keys to use for each metric to calculate.
                Should contain a value for each key in ``metrics``.
                If no values are given for a key, per default ``pred`` and
                ``label`` will be used for metric calculation
            verbose : bool
                verbosity of the test process
            prepare_batch : function
                function to convert a batch-dict to a format accepted by
                the model. This conversion typically includes
                dtype-conversion, reshaping, wrapping to backend-specific
                tensors and pushing to the correct devices.
                If not further specified, the identity function is used
            convert_fn : function
                function to convert a batch of tensors to numpy;
                if not specified, defaults to
                :func:`convert_tf_tensor_to_npy`
            **kwargs :
                additional keyword arguments

            Returns
            -------
            dict
                all predictions obtained by feeding the ``test_data``
                through the ``network``
            dict
                all metrics calculated upon the ``test_data`` and the
                obtained predictions

            """
            # switch convert_fn to the correct backend function
            if convert_fn is None:
                convert_fn = convert_tf_tensor_to_npy

            # make sure all variables inside the network's session are
            # initialized before prediction
            initialize_uninitialized(network._sess)

            return super().test(network=network, test_data=test_data,
                                metrics=metrics, metric_keys=metric_keys,
                                verbose=verbose,
                                prepare_batch=prepare_batch,
                                convert_fn=convert_fn, **kwargs)