from ..utils import now
from ..data_loading import BaseDataManager, ConcatDataManager
from .. import __version__ as delira_version
from .parameters import Parameters
from trixi.experiment import Experiment as TrixiExperiment
import os
import logging
import yaml
import typing
import numpy as np
import pickle
from abc import abstractmethod
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from datetime import datetime
from inspect import signature
from functools import partial
logger = logging.getLogger(__name__)
NOT_IMPLEMENTED_KEYS = []
class AbstractExperiment(TrixiExperiment):
"""
Abstract class representing a single experiment (must be subclassed for
each backend)
See Also
--------
:class:`PyTorchExperiment`
"""
@abstractmethod
def __init__(self, n_epochs, *args, **kwargs):
"""
Parameters
----------
n_epochs : int
number of epochs to train
*args :
positional arguments
**kwargs :
keyword arguments
"""
super().__init__(n_epochs)
self._run = 0
@abstractmethod
def setup(self, *args, **kwargs):
"""
Abstract method to set up an :class:`AbstractNetworkTrainer`
Raises
------
NotImplementedError
if not overwritten by subclass
"""
raise NotImplementedError()
@abstractmethod
def run(self, train_data: typing.Union[BaseDataManager, ConcatDataManager],
val_data: typing.Optional[typing.Union[BaseDataManager,
ConcatDataManager]] = None,
params: typing.Optional[Parameters] = None,
**kwargs):
"""
Trains a single model
Parameters
----------
train_data : :class:`BaseDataManager` or :class:`ConcatDataManager`
data manager containing the training data
val_data : :class:`BaseDataManager` or :class:`ConcatDataManager`, optional
data manager containing the validation data (default: None)
params : :class:`Parameters`, optional
class containing all parameters (defaults to None);
if not specified, the parameters fall back to the ones given during
class initialization
Raises
------
NotImplementedError
If not overwritten in subclass
"""
raise NotImplementedError()
def kfold(self, num_epochs: int, data: typing.List[BaseDataManager],
num_splits=None, shuffle=False, random_seed=None, **kwargs):
"""
Runs k-fold cross-validation
Parameters
----------
num_epochs : int
number of epochs to train the model
data : list of BaseDataManager
list of datamanagers (will be split for crossvalidation)
num_splits : None or int
number of splits for the kfold;
if None, ``len(data)`` splits will be used (leave-one-out)
shuffle : bool
whether or not to shuffle the indices for the kfold
random_seed : None or int
random seed used to seed the kfold (if ``shuffle`` is True) as well as
PyTorch and NumPy
**kwargs :
additional keyword arguments (completely passed to self.run())
"""
if num_splits is None:
num_splits = len(data)
fold = KFold(n_splits=num_splits, shuffle=shuffle,
random_state=random_seed)
if random_seed is not None:
torch.manual_seed(random_seed)
np.random.seed(random_seed)
for idx, (train_idxs, test_idxs) in enumerate(fold.split(data)):
self.run(ConcatDataManager(
[data[_idx] for _idx in train_idxs]),
ConcatDataManager([data[_idx] for _idx in test_idxs]),
num_epochs=num_epochs,
fold=idx,
**kwargs)
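# A minimal usage sketch (illustrative, not part of the library): assuming
# ``experiment`` is an instance of a concrete subclass and ``managers`` is a
# list of ``BaseDataManager`` objects, a shuffled 5-fold cross-validation
# could be started like this:
#
#     experiment.kfold(num_epochs=50, data=managers, num_splits=5,
#                      shuffle=True, random_seed=42)
#
# Each fold concatenates the selected managers via ``ConcatDataManager`` and
# calls ``self.run`` once per split, passing ``fold=idx`` along.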
def __str__(self):
"""
Converts :class:`AbstractExperiment` to string representation
Returns
-------
str
representation of class
"""
s = "Experiment:\n"
for k, v in vars(self).items():
s += "\t{} = {}\n".format(k, v)
return s
def __call__(self, *args, **kwargs):
"""
Call :meth:`AbstractExperiment.run`
Parameters
----------
*args :
positional arguments
**kwargs :
keyword arguments
Returns
-------
:class:`AbstractNetworkTrainer`
trainer of trained network
"""
return self.run(*args, **kwargs)
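# Convenience note (illustrative): because of ``__call__`` these two lines are
# equivalent, assuming ``experiment`` is a concrete subclass instance and
# ``train_mgr``/``val_mgr`` are data managers:
#
#     experiment(train_mgr, val_mgr)
#     experiment.run(train_mgr, val_mgr)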
@abstractmethod
def save(self):
"""
Saves the whole experiment
Raises
------
NotImplementedError
If not overwritten in subclass
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def load(file_name):
"""
Loads a whole experiment from file
Parameters
----------
file_name : str
file_name to load the experiment from
Raises
------
NotImplementedError
if not overwritten in subclass
"""
raise NotImplementedError()
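# Sketch of the interface a backend-specific subclass is expected to provide
# (hedged, names are placeholders; the PyTorch implementation below is the
# actual example shipped with this module):
#
#     class MyBackendExperiment(AbstractExperiment):
#         def setup(self, *args, **kwargs):
#             # build and return a backend-specific network trainer
#             ...
#
#         def run(self, train_data, val_data=None, params=None, **kwargs):
#             trainer = self.setup(params, **kwargs)
#             self._run += 1
#             return trainer.train(...)
#
#         def save(self):
#             ...
#
#         @staticmethod
#         def load(file_name):
#             ...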
try:
import torch
from .train_utils import create_optims_default_pytorch
from .pytorch_trainer import PyTorchNetworkTrainer as PTNetworkTrainer
from ..models import AbstractPyTorchNetwork
class PyTorchExperiment(AbstractExperiment):
"""
Single Experiment for PyTorch Backend
See Also
--------
:class:`AbstractExperiment`
"""
def __init__(self,
params: Parameters,
model_cls: AbstractPyTorchNetwork,
name=None,
save_path=None,
val_score_key=None,
optim_builder=create_optims_default_pytorch,
checkpoint_freq=1,
trainer_cls=PTNetworkTrainer,
**kwargs
):
"""
Parameters
----------
params : :class:`Parameters` or str
the training and model parameters or a path to a pickled
:class:`Parameters` object
model_cls : subclass of :class:`AbstractPyTorchNetwork`
the class used to instantiate the model to train
name : str
the experiment's name,
default: None -> "UnnamedExperiment"
save_path : str
the path to save the experiment to
(a date-time signature will be appended),
default: None -> Use current working dir
val_score_key : str or None
key to access the metric to monitor for model
selection and callbacks (often starts with "val_")
optim_builder : function
function returning a dictionary of optimizers;
defaults to :func:`create_optims_default_pytorch`
checkpoint_freq : int
save a checkpoint after every ``n`` epochs
(if set to 1, checkpoints will be saved after each epoch,
if set to 2, checkpoints will be saved after every
2 epochs, etc.)
trainer_cls :
class defining the actual trainer,
defaults to :class:`PyTorchNetworkTrainer`,
which should be suitable for most cases,
but can easily be overwritten and exchanged if necessary
**kwargs :
additional keyword arguments
"""
if isinstance(params, str):
with open(params, "rb") as f:
params = pickle.load(f)
n_epochs = params.nested_get("num_epochs")
AbstractExperiment.__init__(self, n_epochs)
if name is None:
name = "UnnamedExperiment"
self.name = name
if save_path is None:
save_path = os.path.abspath(".")
self.save_path = os.path.join(save_path, name,
str(datetime.now().strftime(
"%y-%m-%d_%H-%M-%S")))
if os.path.isdir(self.save_path):
logger.warning("Save Path %s already exists")
os.makedirs(self.save_path, exist_ok=True)
self.trainer_cls = trainer_cls
if val_score_key is None and params.nested_get("metrics"):
val_score_key = sorted(params.nested_get("metrics").keys())[0]
self.val_score_key = val_score_key
self.params = params
self.model_cls = model_cls
self.kwargs = kwargs
self._optim_builder = optim_builder
self.checkpoint_freq = checkpoint_freq
self._run = 0
# log HyperParameters
logger.info({"text": {"text":
str(params) + "\n\tmodel_class = %s"
% model_cls.__class__.__name__}})
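# Construction sketch (hedged): ``MyNetwork`` and the exact parameter layout
# are placeholders, not part of this module; see the :class:`Parameters` docs
# for the actual signature. Assuming a nested layout with "model" and
# "training" sections, an experiment could be created like this:
#
#     params = Parameters(fixed_params={
#         "model": {"in_channels": 1, "num_classes": 2},
#         "training": {"num_epochs": 50,
#                      "criterions": {"CE": torch.nn.CrossEntropyLoss()},
#                      "optimizer_cls": torch.optim.Adam,
#                      "optimizer_params": {"lr": 1e-3},
#                      "metrics": {},
#                      "lr_sched_cls": None,
#                      "lr_sched_params": {}}})
#
#     experiment = PyTorchExperiment(params, MyNetwork,
#                                    name="MyExperiment",
#                                    save_path="./experiments")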
def setup(self, params: Parameters, **kwargs):
"""
Perform setup of Network Trainer
Parameters
----------
params : :class:`Parameters`
the parameters to construct a model and network trainer
**kwargs :
additional keyword arguments (passed to the trainer's constructor)
Returns
-------
instance of ``self.trainer_cls`` (by default :class:`PyTorchNetworkTrainer`)
the configured network trainer
"""
model_params = params.permute_training_on_top().model
model_kwargs = {}
for key in signature(self.model_cls.__init__).parameters.keys():
if key in ["self", "args", "kwargs"]:
continue
try:
model_kwargs[key] = model_params.nested_get(key)
except KeyError:
pass
model = self.model_cls(**model_kwargs)
training_params = params.permute_training_on_top().training
criterions = training_params.nested_get("criterions")
optimizer_cls = training_params.nested_get("optimizer_cls")
optimizer_params = training_params.nested_get("optimizer_params")
metrics = training_params.nested_get("metrics")
lr_scheduler_cls = training_params.nested_get("lr_sched_cls")
lr_scheduler_params = training_params.nested_get("lr_sched_params")
return self.trainer_cls(
network=model,
save_path=os.path.join(
self.save_path,
"checkpoints",
"run_%02d" % self._run),
criterions=criterions,
optimizer_cls=optimizer_cls,
optimizer_params=optimizer_params,
metrics=metrics,
lr_scheduler_cls=lr_scheduler_cls,
lr_scheduler_params=lr_scheduler_params,
optim_fn=self._optim_builder,
save_freq=self.checkpoint_freq,
**self.kwargs,
**kwargs
)
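# Note on the construction above (hedged illustration): ``setup`` inspects
# ``signature(self.model_cls.__init__)`` and only forwards those entries of
# the "model" parameters whose names match an ``__init__`` argument of the
# model class; names missing from the parameters raise a ``KeyError`` that is
# silently swallowed. With a hypothetical model
#
#     class MyNetwork(AbstractPyTorchNetwork):
#         def __init__(self, in_channels, num_classes):
#             ...
#
# and model parameters {"in_channels": 1, "num_classes": 2, "extra": 5},
# only ``in_channels`` and ``num_classes`` would be passed to ``MyNetwork``.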
def run(self,
train_data: typing.Union[BaseDataManager, ConcatDataManager],
val_data: typing.Union[BaseDataManager, ConcatDataManager, None],
params: typing.Optional[Parameters] = None,
**kwargs):
"""
Trains a single model
Parameters
----------
train_data : BaseDataManager or ConcatDataManager
holds the trainset
val_data : BaseDataManager or ConcatDataManager or None
holds the validation set (if None, the model will not be validated)
params : :class:`Parameters`, optional
the parameters to construct a model and network trainer;
if None, ``self.params`` is used
**kwargs :
holds additional keyword arguments
(which are completely passed to the trainer's ``__init__``)
Returns
-------
:class:`AbstractNetworkTrainer`
trainer of trained network
Raises
------
ValueError
Class has no Attribute ``params`` and no parameters were given as
function argument
"""
if params is None:
if hasattr(self, "params"):
params = self.params
else:
raise ValueError("No parameters given")
else:
self.params = params
training_params = params.permute_training_on_top().training
trainer = self.setup(params, **kwargs)
self._run += 1
num_epochs = kwargs.get("num_epochs", training_params.nested_get(
"num_epochs"))
return trainer.train(num_epochs, train_data, val_data,
self.val_score_key,
self.kwargs.get("val_score_mode",
"lowest")
)
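# Usage sketch (hedged): ``train_mgr`` and ``val_mgr`` are assumed to be
# data managers built elsewhere:
#
#     trainer = experiment.run(train_mgr, val_mgr)
#
# ``num_epochs`` is taken from ``kwargs`` if given there, otherwise from the
# training section of ``params``; the direction of the monitored score can be
# changed by passing ``val_score_mode="highest"`` (default: "lowest") to the
# experiment's ``__init__`` via ``**kwargs``.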
def save(self):
"""
Saves the whole experiment
"""
with open(os.path.join(self.save_path, "experiment.delira.pkl"),
"wb") as f:
pickle.dump(self, f)
self.params.save(os.path.join(self.save_path, "parameters"))
@staticmethod
def load(file_name):
"""
Loads a whole experiment from file
Parameters
----------
file_name : str
file_name to load the experiment from
Returns
-------
:class:`PyTorchExperiment`
the loaded experiment
"""
with open(file_name, "rb") as f:
return pickle.load(f)
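# Round-trip sketch (hedged): ``save`` pickles the whole experiment to
# ``<save_path>/experiment.delira.pkl`` and stores the parameters alongside
# it, while ``load`` simply unpickles such a file:
#
#     experiment.save()
#     restored = PyTorchExperiment.load(
#         os.path.join(experiment.save_path, "experiment.delira.pkl"))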
def __getstate__(self):
return vars(self)
def __setstate__(self, state):
vars(self).update(state)
except ImportError as e:
raise e