# Source code for delira.training.parameters

import pickle
import warnings
from copy import deepcopy, copy

import yaml

from ..utils import LookupConfig


class Parameters(LookupConfig):
    """
    Class Containing all variable and fixed parameters for training and model
    instantiation

    The parameters are kept in one of two nested layouts ("hierarchies"):

    * variability on top: top-level keys ``fixed`` and ``variable``, each
      holding ``model`` and ``training``
    * training on top: top-level keys ``model`` and ``training``, each
      holding ``fixed`` and ``variable``

    See Also
    --------
    ``trixi.util.Config``

    """

    # Strings returned by :attr:`hierarchy` and accepted by
    # :meth:`permute_to_hierarchy`
    _VARIABILITY_ON_TOP = "variability\n|\n->\ttraining"
    _TRAINING_ON_TOP = "training\n|\n->\tvariability"

    def __init__(self, fixed_params=None, variable_params=None):
        """

        Parameters
        ----------
        fixed_params : dict, optional
            fixed parameters (are not variated using hyperparameter search);
            defaults to ``{"model": {}, "training": {}}``
        variable_params : dict, optional
            variable parameters (can be variated by a hyperparameter search);
            defaults to ``{"model": {}, "training": {}}``

        """
        # build fresh default dicts on every call instead of sharing one
        # mutable default object between all instances
        if fixed_params is None:
            fixed_params = {"model": {}, "training": {}}
        if variable_params is None:
            variable_params = {"model": {}, "training": {}}
        super().__init__(fixed=fixed_params, variable=variable_params)

    def permute_hierarchy(self):
        """
        switches hierarchy

        Swaps the two upper nesting levels in place: ``fixed``/``variable``
        over ``model``/``training`` becomes ``model``/``training`` over
        ``fixed``/``variable`` and vice versa.

        Returns
        -------
        Parameters
            the class with a permuted hierarchy

        Raises
        ------
        AttributeError
            if no valid hierarchy is found

        """
        if self.variability_on_top:
            fixed = self.pop("fixed")
            variable = self.pop("variable")
            model = {"fixed": fixed.pop("model"),
                     "variable": variable.pop("model")}
            training = {"fixed": fixed.pop("training"),
                        "variable": variable.pop("training")}
            self.model = model
            self.training = training
        elif self.training_on_top:
            model = self.pop("model")
            training = self.pop("training")
            fixed = {"model": model.pop("fixed"),
                     "training": training.pop("fixed")}
            variable = {"model": model.pop("variable"),
                        "training": training.pop("variable")}
            self.fixed = fixed
            self.variable = variable
        else:
            # raise (not return) so that callers relying on the returned
            # object fail loudly instead of receiving an exception instance
            raise AttributeError(
                "%s must either have keys ('model', 'training') or "
                "('fixed', 'variable')" % self.__class__.__name__)
        return self

    def permute_training_on_top(self):
        """
        permutes hierarchy in a way that the training-model hierarchy is on
        top (no-op if it already is)

        Returns
        -------
        Parameters
            Parameters with permuted hierarchy

        """
        if self.training_on_top:
            return self
        return self.permute_hierarchy()

    def permute_variability_on_top(self):
        """
        permutes hierarchy in a way that the fixed-variable (variability)
        hierarchy is on top (no-op if it already is)

        Returns
        -------
        Parameters
            Parameters with permuted hierarchy

        """
        if self.variability_on_top:
            return self
        return self.permute_hierarchy()

    @property
    def hierarchy(self):
        """
        Returns the current hierarchy

        Returns
        -------
        str
            current hierarchy

        """
        if self.variability_on_top:
            return self._VARIABILITY_ON_TOP
        if self.training_on_top:
            return self._TRAINING_ON_TOP
        return "no valid hierarchy"

    def permute_to_hierarchy(self, hierarchy: str):
        """
        Permute hierarchy to match the specified hierarchy

        Parameters
        ----------
        hierarchy : str
            target hierarchy, as returned by :attr:`hierarchy`

        Raises
        ------
        ValueError
            Specified hierarchy is invalid

        Returns
        -------
        Parameters
            parameters with proper hierarchy

        """
        # the target string names the level that must end up on top
        if hierarchy == self._VARIABILITY_ON_TOP:
            return self.permute_variability_on_top()
        if hierarchy == self._TRAINING_ON_TOP:
            return self.permute_training_on_top()
        raise ValueError("Invalid Hierarchy: %s" % hierarchy)

    @property
    def variability_on_top(self):
        """
        Return whether the variability is on top

        Returns
        -------
        bool
            whether variability is on top

        """
        return hasattr(self, "fixed") and hasattr(self, "variable")

    @property
    def training_on_top(self):
        """
        Return whether the training hierarchy is on top

        Returns
        -------
        bool
            whether training is on top

        """
        return hasattr(self, "model") and hasattr(self, "training")

    def save(self, filepath: str):
        """
        Saves class to given filepath (YAML + Pickle)

        A ``.yml`` extension is appended to ``filepath`` unless it already
        ends with ``.yaml`` or ``.yml``. The pickle file is written to the
        same path without that trailing extension. Both files are written
        with variability on top. Writing the YAML file is best-effort: a
        ``TypeError`` while dumping only emits a warning. The pickle file is
        always attempted. If this instance had training on top, that
        hierarchy is restored afterwards.

        Parameters
        ----------
        filepath : str
            file to save data to

        """
        if not filepath.endswith((".yaml", ".yml")):
            filepath = filepath + ".yml"
        # strip only the trailing extension, not every occurrence of
        # ".yml"/".yaml" elsewhere in the path
        pickle_path = filepath.rsplit(".", 1)[0]

        restore_training_on_top = (self.training_on_top
                                   and not self.variability_on_top)
        try:
            try:
                with open(filepath, "w") as f:
                    yaml.dump(self.permute_variability_on_top(), f)
            except TypeError as e:
                warnings.warn("Could not save %s as YAML to %s: %s"
                              % (self.__class__.__name__, filepath, e))
            finally:
                with open(pickle_path, "wb") as f:
                    pickle.dump(self, f)
        finally:
            # saving must not change the caller's view of the hierarchy
            if restore_training_on_top:
                self.permute_training_on_top()

    def update(self, dict_like, deep=False, ignore=None,
               allow_dict_overwrite=True):
        """Update entries in the Parameters

        Parameters
        ----------
        dict_like : dict
            Update source
        deep : bool
            Make deep copies of all references in the source.
        ignore : Iterable
            Iterable of keys to ignore in update
        allow_dict_overwrite : bool
            Allow overwriting with dict.
            Regular dicts only update on the highest level while we recurse
            and merge Configs. This flag decides whether it is possible to
            overwrite a 'regular' value with a dict/Config at lower levels.
            See examples for an illustration of the difference

        Raises
        ------
        RuntimeError
            if ``dict_like`` is a plain mapping that shares no top-level key
            with the current hierarchy

        Notes
        -----
        If ``dict_like`` is a :class:`Parameters` instance, it is
        temporarily permuted to this instance's hierarchy and restored
        afterwards, even if the update fails.

        Examples
        --------
        The following illustrates the update behaviour if
        :obj:allow_dict_overwrite is active. If it isn't, an AttributeError
        would be raised, originating from trying to update "string"::

            config1 = Config(config={
                "lvl0": {
                    "lvl1": "string",
                    "something": "else"
                }
            })

            config2 = Config(config={
                "lvl0": {
                    "lvl1": {
                        "lvl2": "string"
                    }
                }
            })

            config1.update(config2, allow_dict_overwrite=True)

            >>>config1
            {
                "lvl0": {
                    "lvl1": {
                        "lvl2": "string"
                    },
                    "something": "else"
                }
            }

        """
        # no unique hierarchy on self (both or neither): plain update
        if self.variability_on_top == self.training_on_top:
            super().update(dict_like=dict_like, deep=deep, ignore=ignore,
                           allow_dict_overwrite=allow_dict_overwrite)
            return

        variability_on_top = self.variability_on_top
        is_parameters = isinstance(dict_like, Parameters)

        if is_parameters:
            # bring dict_like into the same hierarchy as self for the merge
            dict_like_variability_on_top = dict_like.variability_on_top
            if variability_on_top:
                dict_like = dict_like.permute_variability_on_top()
            else:
                dict_like = dict_like.permute_training_on_top()
        else:
            if variability_on_top:
                top_keys = ("fixed", "variable")
            else:
                top_keys = ("model", "training")
            # a plain mapping must share at least one top-level key with the
            # current hierarchy, otherwise merging would mix layouts
            if not any(key in dict_like.keys() for key in top_keys):
                raise RuntimeError("Unsafe to Update from dict with "
                                   "another structre as current "
                                   "parameters")

        try:
            super().update(dict_like=dict_like, deep=deep, ignore=ignore,
                           allow_dict_overwrite=allow_dict_overwrite)
        finally:
            if is_parameters:
                # restore original permutation of dict_like
                if variability_on_top and not dict_like_variability_on_top:
                    dict_like.permute_training_on_top()
                elif not variability_on_top and dict_like_variability_on_top:
                    dict_like.permute_variability_on_top()

    def __str__(self):
        """
        String Representation of class

        Returns
        -------
        str
            string representation

        """
        parts = ["Parameters:\n"]
        for key, value in vars(self).items():
            try:
                parts.append("\t{} = {}\n".format(key, value))
            except TypeError:
                parts.append("\t{} = {}\n".format(
                    key, value.__class__.__name__))
        return "".join(parts)

    def _copy_with(self, copy_fn):
        """
        Build a new :class:`Parameters` from ``copy_fn`` applied to this
        instance's content, returned in this instance's hierarchy.

        This instance is temporarily permuted to variability-on-top and is
        always restored, even if copying fails.

        """
        var_top = self.variability_on_top
        try:
            # the new instance must be filled BEFORE self is restored:
            # restoring pops the nested entries that a shallow copy still
            # shares with self
            _params = Parameters()
            _params.update(copy_fn(dict(self.permute_variability_on_top())))
        finally:
            if not var_top:
                # restore original permutation
                self.permute_training_on_top()

        if var_top:
            return _params.permute_variability_on_top()
        return _params.permute_training_on_top()

    def __copy__(self):
        """
        Enables shallow copy

        Returns
        -------
        :class:`Parameters`
            copied parameters

        """
        return self._copy_with(copy)

    def __deepcopy__(self, memo):
        """
        Enables deepcopy

        Returns
        -------
        :class:`Parameters`
            deepcopied parameters

        """
        return self._copy_with(lambda content: deepcopy(content, memo=memo))