Source code for delira.training.base_trainer

from abc import abstractmethod
import logging
import pickle
import typing

from ..data_loading.data_manager import Augmenter

from .predictor import Predictor
from .callbacks import AbstractCallback
from ..models import AbstractNetwork

import numpy as np
import os
from tqdm import tqdm
from delira.logging import TrixiHandler

logger = logging.getLogger(__name__)


[docs]class BaseNetworkTrainer(Predictor): """ Defines a Base API and basic functions for Network Trainers See Also -------- :class:`PyTorchNetworkTrainer` :class:`TfNetworkTrainer` """ __KEYS_TO_GUARD = ["use_gpu", "input_device", "output_device", "_callbacks"] def __init__(self, network: AbstractNetwork, save_path: str, losses: dict, optimizer_cls: type, optimizer_params: dict, train_metrics: dict, val_metrics: dict, lr_scheduler_cls: type, lr_scheduler_params: dict, gpu_ids: typing.List[int], save_freq: int, optim_fn, key_mapping: dict, logging_type: str, logging_kwargs: dict, fold: int, callbacks: typing.List[AbstractCallback], start_epoch=1, metric_keys=None, convert_batch_to_npy_fn=lambda x: x, val_freq=1, **kwargs ): """ Parameters ---------- network : :class:`AbstractTfNetwork` the network to train save_path : str path to save networks to losses : dict dictionary containing the training losses optimizer_cls : subclass of tf.train.Optimizer optimizer class implementing the optimization algorithm of choice optimizer_params : dict keyword arguments passed to optimizer during construction train_metrics : dict, optional metrics, which will be evaluated during train phase (should work on numpy arrays) val_metrics : dict, optional metrics, which will be evaluated during test phase (should work on numpy arrays) lr_scheduler_cls : Any learning rate schedule class: must implement step() method lr_scheduler_params : dict keyword arguments passed to lr scheduler during construction gpu_ids : list list containing ids of GPUs to use; if empty: use cpu instead save_freq : int integer specifying how often to save the current model's state. State is saved every state_freq epochs optim_fn : function creates a dictionary containing all necessary optimizers key_mapping : dict a dictionary containing the mapping from the ``data_dict`` to the actual model's inputs. E.g. if a model accepts one input named 'x' and the data_dict contains one entry named 'data' this argument would have to be ``{'x': 'data'}`` logging_type : str or callable the type of logging. If string: it must be one of ["visdom", "tensorboardx"] If callable: it must be a logging handler class logging_kwargs : dict dictionary containing all logging keyword arguments fold : int current cross validation fold (0 per default) callbacks : list initial callbacks to register start_epoch : int epoch to start training at metric_keys : dict the batch_dict keys to use for each metric to calculate. Should contain a value for each key in ``metrics``. If no values are given for a key, per default ``pred`` and ``label`` will be used for metric calculation convert_batch_to_npy_fn : type, optional function converting a batch-tensor to numpy, per default this is the identity function val_freq : int validation frequency specifying how often to validate the trained model (a value of 1 denotes validating every epoch, a value of 2 denotes validating every second epoch etc.); defaults to 1 **kwargs : Additional keyword arguments """ # explicity not call self._setup here to reuse the __init__ of # abstract class. self._setup has to be called in subclass # check argument types assert isinstance(network, AbstractNetwork) assert isinstance(save_path, str) assert isinstance(losses, dict) assert isinstance(optimizer_params, dict) assert isinstance(train_metrics, dict) assert isinstance(val_metrics, dict) assert isinstance(lr_scheduler_params, dict) assert isinstance(gpu_ids, list) if os.path.isdir(save_path): logger.warning( "Save Path already exists. Saved Models may be overwritten") else: os.makedirs(save_path) self._callbacks = [] self._fold = fold self.start_epoch = start_epoch self.save_path = save_path self.losses = losses self.train_metrics = train_metrics self.val_metrics = val_metrics self.stop_training = False self.save_freq = save_freq self.metric_keys = metric_keys for cbck in callbacks: self.register_callback(cbck) self._reinitialize_logging(logging_type, logging_kwargs) self._tqdm_desc = "Validate" self.val_freq = val_freq
[docs] def _setup(self, network, lr_scheduler_cls, lr_scheduler_params, gpu_ids, key_mapping, convert_batch_to_npy_fn, prepare_batch_fn): super()._setup(network, key_mapping, convert_batch_to_npy_fn, prepare_batch_fn) self.closure_fn = network.closure # optimizers must exist before calling _setup() if lr_scheduler_cls is not None: for key, optim in self.optimizers.items(): if not issubclass(lr_scheduler_cls, AbstractCallback): logger.warning("lr_scheduler_cls is not a callback.") self.register_callback(lr_scheduler_cls(optim, **lr_scheduler_params)) if gpu_ids: self.use_gpu = True else: self.use_gpu = False
[docs] def _at_training_begin(self, *args, **kwargs): """ Defines the behaviour at beginnig of the training Parameters ---------- *args : positional arguments **kwargs : keyword arguments Raises ------ NotImplementedError If not overwritten by subclass """ self.save_state(os.path.join(self.save_path, "checkpoint_epoch_0"))
[docs] def _at_training_end(self, *args, **kwargs): """ Defines the behaviour at the end of the training Parameters ---------- *args : positional arguments **kwargs : keyword arguments Raises ------ NotImplementedError If not overwritten by subclass """ return self.module
[docs] def _at_epoch_begin(self, metrics_val, val_score_key, epoch, num_epochs, **kwargs): """ Defines behaviour at beginning of each epoch: Executes all callbacks's `at_epoch_begin` method Parameters ---------- metrics_val : dict validation metrics val_score_key : str validation score key epoch : int current epoch num_epochs : int total number of epochs **kwargs : keyword arguments """ # execute all callbacks for cb in self._callbacks: self._update_state(cb.at_epoch_begin(self, val_metrics=metrics_val, val_score_key=val_score_key, curr_epoch=epoch))
[docs] def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best, **kwargs): """ Defines behaviour at beginning of each epoch: Executes all callbacks's `at_epoch_end` method and saves current state if necessary Parameters ---------- metrics_val : dict validation metrics val_score_key : str validation score key epoch : int current epoch num_epochs : int total number of epochs **kwargs : keyword arguments """ for cb in self._callbacks: self._update_state(cb.at_epoch_end(self, val_metrics=metrics_val, val_score_key=val_score_key, curr_epoch=epoch)) if epoch % self.save_freq == 0: self.save_state(os.path.join(self.save_path, "checkpoint_epoch_%d" % epoch)) if is_best: self.save_state(os.path.join(self.save_path, "checkpoint_best"))
[docs] def _train_single_epoch(self, batchgen: Augmenter, epoch, verbose=False): """ Trains the network a single epoch Parameters ---------- batchgen : :class:`Augmenter` Generator yielding the training batches epoch : int current epoch """ metrics, losses = [], [] n_batches = batchgen.num_batches if verbose: iterable = tqdm( enumerate(batchgen), unit=' batch', total=n_batches, desc='Epoch %d' % epoch) else: iterable = enumerate(batchgen) for batch_nr, batch in iterable: data_dict = self._prepare_batch(batch) _metrics, _losses, _ = self.closure_fn(self.module, data_dict, optimizers=self.optimizers, losses=self.losses, metrics=self.train_metrics, fold=self.fold, batch_nr=batch_nr) metrics.append(_metrics) losses.append(_losses) batchgen._finish() total_losses, total_metrics = {}, {} for _metrics in metrics: for key, val in _metrics.items(): if key in total_metrics: total_metrics[key].append(val) else: total_metrics[key] = [val] for _losses in losses: for key, val in _losses.items(): if key in total_losses: total_losses[key].append(val) else: total_losses[key] = [val] return total_metrics, total_losses
[docs] def train(self, num_epochs, datamgr_train, datamgr_valid=None, val_score_key=None, val_score_mode='highest', reduce_mode='mean', verbose=True): """ Defines a routine to train a specified number of epochs Parameters ---------- num_epochs : int number of epochs to train datamgr_train : DataManager the datamanager holding the train data datamgr_valid : DataManager the datamanager holding the validation data (default: None) val_score_key : str the key specifying which metric to use for validation (default: None) val_score_mode : str key specifying what kind of validation score is best reduce_mode : str 'mean','sum','first_only' verbose : bool whether to show progress bars or not Raises ------ NotImplementedError If not overwritten by subclass """ self._at_training_begin() if val_score_mode == 'highest': best_val_score = 0 elif val_score_mode == 'lowest': best_val_score = float('inf') else: best_val_score = None is_best = False new_val_score = best_val_score if reduce_mode == 'mean': def reduce_fn(batch): return np.mean(batch) elif reduce_mode == 'sum': def reduce_fn(batch): return np.sum(batch) elif reduce_mode == 'first_only': def reduce_fn(batch): return batch[0] elif reduce_mode == 'last_only': def reduce_fn(batch): return batch[-1] else: raise ValueError("No valid reduce mode given") metrics_val = {} val_metric_fns = {} for k, v in self.val_metrics.items(): if not k.startswith("val_"): k = "val_" + k val_metric_fns[k] = v if self.metric_keys is None: val_metric_keys = None else: val_metric_keys = {} for k, v in self.metric_keys.items(): if not k.startswith("val_"): k = "val_" + k val_metric_keys[k] = v for epoch in range(self.start_epoch, num_epochs + 1): self._at_epoch_begin(metrics_val, val_score_key, epoch, num_epochs) batch_gen_train = datamgr_train.get_batchgen(seed=epoch) # train single network epoch train_metrics, train_losses = self._train_single_epoch( batch_gen_train, epoch, verbose=verbose) total_metrics = { **train_metrics, **train_losses} # validate network if datamgr_valid is not None and (epoch % self.val_freq == 0): # next must be called here because self.predict_data_mgr # returns a generator (of size 1) and we want to get the first # (and only) item val_metrics = next( self.predict_data_mgr_cache_metrics_only( datamgr_valid, datamgr_valid.batch_size, metrics=val_metric_fns, metric_keys=val_metric_keys, verbose=verbose)) total_metrics.update(val_metrics) for k, v in total_metrics.items(): total_metrics[k] = reduce_fn(v) # check if metric became better if val_score_key is not None: if val_score_key not in total_metrics: if "val_" + val_score_key not in total_metrics: logger.warning( "val_score_key '%s' not a valid key for \ validation metrics ") new_val_score = best_val_score else: new_val_score = total_metrics["val_" + val_score_key] val_score_key = "val_" + val_score_key else: new_val_score = total_metrics.get(val_score_key) if new_val_score != best_val_score: is_best = self._is_better_val_scores( best_val_score, new_val_score, val_score_mode) # set best_val_score to new_val_score if is_best if is_best: best_val_score = new_val_score if is_best and verbose: logging.info("New Best Value at Epoch %03d : %03.3f" % (epoch, best_val_score)) # log metrics and loss values for key, val in total_metrics.items(): logging.info({"value": {"value": val, "name": key }}) self._at_epoch_end(total_metrics, val_score_key, epoch, is_best) is_best = False # stop training (might be caused by early stopping) if self.stop_training: break return self._at_training_end()
@property def fold(self): """ Get current fold Returns ------- int current fold """ return self._fold @fold.setter def fold(self, fold): """ Set the current fold Parameters ---------- fold : int new fold Raises ------ ValueError if `fold` is not covertable to :obj:`int` """ try: self._fold = int(fold) except ValueError as e: logger.error(e) raise e
[docs] def register_callback(self, callback: AbstractCallback): """ Register Callback to Trainer Parameters ---------- callback : :class:`AbstractCallback` the callback to register Raises ------ AssertionError `callback` is not an instance of :class:`AbstractCallback` and has not both methods ['at_epoch_begin', 'at_epoch_end'] """ assertion_str = "Given callback is not valid; Must be instance of " \ "AbstractCallback or provide functions " \ "'at_epoch_begin' and 'at_epoch_end'" instance_check = isinstance(callback, AbstractCallback) attr_check_begin = hasattr(callback, "at_epoch_begin") attr_check_end = hasattr(callback, "at_epoch_end") attr_check_both = attr_check_begin and attr_check_end assert instance_check or attr_check_both, assertion_str self._callbacks.append(callback)
[docs] def save_state(self, file_name, *args, **kwargs): """ saves the current state Parameters ---------- file_name : str filename to save the state to *args : positional arguments **kwargs : keyword arguments """ with open(file_name, "wb") as f: pickle.dump(vars(self), f, *args, **kwargs)
[docs] @staticmethod def load_state(file_name, *args, **kwargs): """ Loads the new state from file Parameters ---------- file_name : str the file to load the state from *args : positional arguments **kwargs : keyword arguments Returns ------- dict new state """ with open(file_name, "rb") as f: new_state = pickle.load(f, *args, **kwargs) return new_state
[docs] def _update_state(self, new_state): """ Update the state from a given new state Parameters ---------- new_state : dict new state to update internal state from Returns ------- :class:`BaseNetworkTrainer` the trainer with a modified state """ for key, val in new_state.items(): if (key.startswith("__") and key.endswith("__")): continue try: setattr(self, key, val) except PermissionError: logger.error("Trying to overwrite attribute %s of " "NetworkTrainer, which is not allowed!" % key) return self
[docs] def update_state(self, file_name, *args, **kwargs): """ Update internal state from a loaded state Parameters ---------- file_name : str file containing the new state to load *args : positional arguments **kwargs : keyword arguments Returns ------- :class:`BaseNetworkTrainer` the trainer with a modified state """ self._update_state(self.load_state(file_name, *args, **kwargs))
[docs] @staticmethod def _is_better_val_scores(old_val_score, new_val_score, mode='highest'): """ Check whether the new val score is better than the old one with respect to the optimization goal Parameters ---------- old_val_score : old validation score new_val_score : new validation score mode: str String to specify whether a higher or lower validation score is optimal; must be in ['highest', 'lowest'] Returns ------- bool True if new score is better, False otherwise """ assert mode in ['highest', 'lowest'], "Invalid Comparison Mode" if mode == 'highest': return new_val_score > old_val_score elif mode == 'lowest': return new_val_score < old_val_score
[docs] def _reinitialize_logging(self, logging_type, logging_kwargs: dict): from ..logging import TensorboardXLoggingHandler, VisdomLoggingHandler if isinstance(logging_type, str): if logging_type.lower() == "visdom": logging_cls = VisdomLoggingHandler elif logging_type.lower() == "tensorboardx": logging_cls = TensorboardXLoggingHandler else: raise ValueError("Invalid Logging Type") else: logging_cls = logging_type if logging_cls == VisdomLoggingHandler: _logging_kwargs = {"exp_name": "main", "level": 0} elif logging_cls == TensorboardXLoggingHandler: _logging_kwargs = {"log_dir": self.save_path, "level": 0} _logging_kwargs.update(logging_kwargs) if "exp_name" in _logging_kwargs.keys(): _logging_kwargs["exp_name"] = _logging_kwargs["exp_name"] + \ "_%02d" % self.fold # remove prior Trixihandlers and reinitialize it with given logging # type # This facilitates visualization of multiple splits/fold inside one # tensorboard-instance by means of # different tf.Summary.FileWriters() root_logger = logging.getLogger() new_handlers = [] for handler in root_logger.handlers: if isinstance(handler, TrixiHandler): handler.close() else: new_handlers.append(handler) root_logger.handlers = [] new_handlers.append( logging_cls(**_logging_kwargs) ) logging.basicConfig(level=logging.INFO, handlers=new_handlers)
[docs] @staticmethod def _search_for_prev_state(path, extensions=[]): """ Helper function to search in a given path for previous epoch states (indicated by extensions) Parameters ---------- path : str the path to search in extensions : list list of strings containing valid file extensions for checkpoint files Returns ------- str the file containing the latest checkpoint (if available) None if no latst checkpoint was found int the latest epoch (1 if no checkpoint was found) """ files = [] for file in os.listdir(path): for ext in extensions: if not ext.startswith("."): ext = "." + ext if not file.endswith(ext): continue if not file.startswith("checkpoint"): continue if file.endswith("_best" + ext): continue files.append(file) break if files: latest_epoch = max([ int(x.rsplit("_", 1)[-1].rsplit(".", 1)[0]) for x in files]) latest_state_path = [x for x in files if x.startswith("checkpoint_%d" % latest_epoch)][0] return latest_state_path, latest_epoch return None, 1