Source code for delira.training.parameters

from ..utils import LookupConfig
import pickle
import yaml

class Parameters(LookupConfig):
    """
    Class Containing all variable and fixed parameters for training and
    model instantiation

    The parameters are organised in a two-level hierarchy which is either
    "variability on top" (top-level keys ``fixed``/``variable``, each
    holding ``model``/``training``) or "training on top" (top-level keys
    ``model``/``training``, each holding ``fixed``/``variable``).

    See Also
    --------
    ``trixi.util.Config``

    """

    def __init__(self, fixed_params=None, variable_params=None):
        """

        Parameters
        ----------
        fixed_params : dict
            fixed parameters (are not variated using hyperparameter search);
            defaults to ``{"model": {}, "training": {}}``
        variable_params: dict
            variable parameters (can be variated by a hyperparameter search);
            defaults to ``{"model": {}, "training": {}}``

        """
        # Build the defaults per call: dict literals in the signature would
        # be shared (and possibly mutated) across every instantiation.
        if fixed_params is None:
            fixed_params = {"model": {}, "training": {}}
        if variable_params is None:
            variable_params = {"model": {}, "training": {}}

        super().__init__(fixed=fixed_params, variable=variable_params)

    def permute_hierarchy(self):
        """
        switches hierarchy

        The permutation is done in place: the current top-level keys are
        popped from the instance and the new top-level keys are set on it.

        Returns
        -------
        Parameters
            the class with a permuted hierarchy

        Raises
        ------
        AttributeError
            if no valid hierarchy is found

        """
        if self.variability_on_top:
            fixed = self.pop("fixed")
            variable = self.pop("variable")

            model = {"fixed": fixed.pop("model"),
                     "variable": variable.pop("model")}
            training = {"fixed": fixed.pop("training"),
                        "variable": variable.pop("training")}

            self.model = model
            self.training = training

        elif self.training_on_top:
            model = self.pop("model")
            training = self.pop("training")

            fixed = {"model": model.pop("fixed"),
                     "training": training.pop("fixed")}
            variable = {"model": model.pop("variable"),
                        "training": training.pop("variable")}

            self.fixed = fixed
            self.variable = variable

        else:
            # The exception has to be raised; returning it would hand the
            # exception object to the caller as if it were the permuted
            # parameters.
            raise AttributeError(
                "%s must either have keys ('model', 'training') or "
                "('fixed', 'variable')" % self.__class__.__name__)

        return self

    def permute_training_on_top(self):
        """
        permutes hierarchy in a way that the training-model hierarchy is on
        top

        Returns
        -------
        Parameters
            Parameters with permuted hierarchy

        Raises
        ------
        AttributeError
            if no valid hierarchy is found

        """
        if self.training_on_top:
            return self
        return self.permute_hierarchy()

    def permute_variability_on_top(self):
        """
        permutes hierarchy in a way that the fixed-variable hierarchy is on
        top

        Returns
        -------
        Parameters
            Parameters with permuted hierarchy

        Raises
        ------
        AttributeError
            if no valid hierarchy is found

        """
        if self.variability_on_top:
            return self
        return self.permute_hierarchy()

    @property
    def hierarchy(self):
        """
        Returns the current hierarchy

        Returns
        -------
        str
            current hierarchy; the returned string can be passed to
            :meth:`permute_to_hierarchy` unless it is
            ``"no valid hierarchy"``

        """
        if self.variability_on_top:
            hierarchy = "variability\n|\n->\ttraining"
        elif self.training_on_top:
            hierarchy = "training\n|\n->\tvariability"
        else:
            hierarchy = "no valid hierarchy"

        return hierarchy

    def permute_to_hierarchy(self, hierarchy: str):
        """
        Permute hierarchy to match the specified hierarchy

        Parameters
        ----------
        hierarchy : str
            target hierarchy (one of the strings produced by the
            ``hierarchy`` property)

        Raises
        ------
        ValueError
            Specified hierarchy is invalid

        Returns
        -------
        Parameters
            parameters with proper hierarchy

        """
        # The strings must map the same way as in the ``hierarchy``
        # property, so that ``p.permute_to_hierarchy(p.hierarchy)`` leaves
        # the hierarchy untouched.
        if hierarchy == "variability\n|\n->\ttraining":
            return self.permute_variability_on_top()
        elif hierarchy == "training\n|\n->\tvariability":
            return self.permute_training_on_top()
        else:
            raise ValueError("Invalid Hierarchy: %s" % hierarchy)

    @property
    def variability_on_top(self):
        """
        Return whether the variability is on top

        Returns
        -------
        bool
            whether variability is on top

        """
        return hasattr(self, "fixed") and hasattr(self, "variable")

    @property
    def training_on_top(self):
        """
        Return whether the training hierarchy is on top

        Returns
        -------
        bool
            whether training is on top

        """
        return hasattr(self, "model") and hasattr(self, "training")

    def save(self, filepath: str):
        """
        Saves class to given filepath (YAML + Pickle)

        The YAML file is written to ``filepath`` (``".yml"`` is appended if
        the path has no YAML extension); the pickle file is written to the
        same path without the YAML extension. The YAML dump is best effort,
        the pickle file is always attempted.

        Notes
        -----
        The instance is permuted in place to have the variability on top
        before it is dumped.

        Parameters
        ----------
        filepath : str
            file to save data to

        """
        # Imported locally so this method does not depend on a module-level
        # import being present.
        import logging

        if not filepath.endswith((".yaml", ".yml")):
            filepath = filepath + ".yml"

        # ``filepath`` is guaranteed to end with a YAML extension here, so
        # dropping everything after the last dot removes exactly that
        # extension. (Replacing ".yaml"/".yml" everywhere in the string
        # would also corrupt directory names containing those substrings.)
        pickle_path = filepath.rsplit(".", 1)[0]

        try:
            with open(filepath, "w") as f:
                yaml.dump(self.permute_variability_on_top(), f)
        except (TypeError, AttributeError) as err:
            # Best effort only: content which cannot be represented as YAML
            # (or an instance without a valid hierarchy) must not prevent
            # the pickle file below from being the saved result.
            logging.getLogger(__name__).warning(
                "Could not dump parameters as YAML to %s: %s", filepath, err)
        finally:
            with open(pickle_path, "wb") as f:
                pickle.dump(self, f)

    def __str__(self):
        """
        String Representation of class

        Returns
        -------
        str
            string representation

        """
        lines = ["Parameters:\n"]
        for k, v in vars(self).items():
            try:
                lines.append("\t{} = {}\n".format(k, v))
            except TypeError:
                # fall back to the type name for values which cannot be
                # formatted
                lines.append("\t{} = {}\n".format(k, v.__class__.__name__))
        return "".join(lines)