import logging
import os
import pickle
import typing
import numpy as np
from tqdm import tqdm
from delira.logging import TrixiHandler
from .callbacks import AbstractCallback
from .predictor import Predictor
from ..data_loading.data_manager import Augmenter
from ..models import AbstractNetwork
logger = logging.getLogger(__name__)
[docs]class BaseNetworkTrainer(Predictor):
"""
Defines a base API and basic functions for network trainers.

The constructor of this class deliberately does not call ``_setup``;
subclasses have to call it themselves.

See Also
--------
:class:`PyTorchNetworkTrainer`
:class:`TfNetworkTrainer`
"""
# Attribute names collected under a name-mangled class attribute.
# NOTE(review): this list is not referenced anywhere in the visible part of
# the class; presumably these are attributes that must not be overwritten
# by state updates -- TODO confirm against subclasses.
__KEYS_TO_GUARD = ["use_gpu",
"input_device",
"output_device",
"_callbacks"]
def __init__(self,
             network: AbstractNetwork,
             save_path: str,
             losses: dict,
             optimizer_cls: type,
             optimizer_params: dict,
             train_metrics: dict,
             val_metrics: dict,
             lr_scheduler_cls: type,
             lr_scheduler_params: dict,
             gpu_ids: typing.List[int],
             save_freq: int,
             optim_fn,
             key_mapping: dict,
             logging_type: str,
             logging_kwargs: dict,
             fold: int,
             callbacks: typing.List[AbstractCallback],
             start_epoch=1,
             metric_keys=None,
             convert_batch_to_npy_fn=lambda x: x,
             val_freq=1,
             **kwargs
             ):
    """
    Validate the basic arguments, create the save directory, store the
    bookkeeping attributes, register the initial callbacks and set up
    logging.

    Several arguments (``optimizer_cls``, ``lr_scheduler_cls``,
    ``optim_fn``, ``key_mapping``, ``convert_batch_to_npy_fn`` and
    ``**kwargs``) are accepted but not used by this constructor itself;
    ``optimizer_params``, ``lr_scheduler_params`` and ``gpu_ids`` are only
    type-checked here.

    Parameters
    ----------
    network : :class:`AbstractNetwork`
        the network to train
    save_path : str
        path to save networks to; created if it does not exist yet
    losses : dict
        dictionary containing the training losses
    optimizer_cls : type
        optimizer class implementing the optimization algorithm of choice
    optimizer_params : dict
        keyword arguments passed to optimizer during construction
    train_metrics : dict
        metrics, which will be evaluated during train phase
        (should work on numpy arrays)
    val_metrics : dict
        metrics, which will be evaluated during test phase
        (should work on numpy arrays)
    lr_scheduler_cls : Any
        learning rate schedule class: must implement step() method
    lr_scheduler_params : dict
        keyword arguments passed to lr scheduler during construction
    gpu_ids : list
        list containing ids of GPUs to use; if empty: use cpu instead
    save_freq : int
        integer specifying how often to save the current model's state.
        State is saved every ``save_freq`` epochs
    optim_fn : function
        creates a dictionary containing all necessary optimizers
    key_mapping : dict
        a dictionary containing the mapping from the ``data_dict`` to
        the actual model's inputs.
        E.g. if a model accepts one input named 'x' and the data_dict
        contains one entry named 'data' this argument would have to
        be ``{'x': 'data'}``
    logging_type : str or callable
        the type of logging. If string: it must be one of
        ["visdom", "tensorboardx"]
        If callable: it must be a logging handler class
    logging_kwargs : dict
        dictionary containing all logging keyword arguments
    fold : int
        current cross validation fold
    callbacks : list
        initial callbacks to register
    start_epoch : int
        epoch to start training at
    metric_keys : dict
        the batch_dict keys to use for each metric to calculate.
        Should contain a value for each key in ``metrics``.
        If no values are given for a key, per default ``pred`` and
        ``label`` will be used for metric calculation
    convert_batch_to_npy_fn : type, optional
        function converting a batch-tensor to numpy, per default this is
        the identity function
    val_freq : int
        validation frequency specifying how often to validate the trained
        model (a value of 1 denotes validating every epoch,
        a value of 2 denotes validating every second epoch etc.);
        defaults to 1
    **kwargs :
        Additional keyword arguments

    Raises
    ------
    AssertionError
        if one of the type-checked arguments has an unexpected type
    """
    # explicitly do not call self._setup here to reuse the __init__ of
    # the abstract class. self._setup has to be called in the subclass

    # check argument types
    assert isinstance(network, AbstractNetwork)
    assert isinstance(save_path, str)
    assert isinstance(losses, dict)
    assert isinstance(optimizer_params, dict)
    assert isinstance(train_metrics, dict)
    assert isinstance(val_metrics, dict)
    assert isinstance(lr_scheduler_params, dict)
    assert isinstance(gpu_ids, list)

    if os.path.isdir(save_path):
        logger.warning(
            "Save Path already exists. Saved Models may be overwritten")
    else:
        # exist_ok guards against another process creating the directory
        # between the check above and this call
        os.makedirs(save_path, exist_ok=True)

    # must exist before register_callback appends to it
    self._callbacks = []
    self._fold = fold
    self.start_epoch = start_epoch
    self.save_path = save_path
    self.losses = losses
    self.train_metrics = train_metrics
    self.val_metrics = val_metrics
    self.stop_training = False
    self.save_freq = save_freq
    self.metric_keys = metric_keys

    for cbck in callbacks:
        self.register_callback(cbck)

    # reads self.fold and self.save_path, so both must be set already
    self._reinitialize_logging(logging_type, logging_kwargs)
    self._tqdm_desc = "Validate"
    self.val_freq = val_freq
def _setup(self, network, lr_scheduler_cls, lr_scheduler_params, gpu_ids,
           key_mapping, convert_batch_to_npy_fn, prepare_batch_fn):
    """
    Finish the trainer setup: delegate to the parent ``_setup``, store the
    network's closure, register one lr scheduler per optimizer and derive
    the ``use_gpu`` flag.

    ``self.optimizers`` has to exist before this method is called.

    Parameters
    ----------
    network :
        the network to train; its ``closure`` attribute is stored as
        ``self.closure_fn``
    lr_scheduler_cls :
        class used to build one scheduler per optimizer; skipped if None
    lr_scheduler_params : dict
        keyword arguments passed to ``lr_scheduler_cls``
    gpu_ids : list
        GPU ids; ``self.use_gpu`` is True iff this is non-empty
    key_mapping :
        forwarded to the parent ``_setup``
    convert_batch_to_npy_fn :
        forwarded to the parent ``_setup``
    prepare_batch_fn :
        forwarded to the parent ``_setup``
    """
    super()._setup(network, key_mapping, convert_batch_to_npy_fn,
                   prepare_batch_fn)
    self.closure_fn = network.closure

    if lr_scheduler_cls is not None:
        for _, optimizer in self.optimizers.items():
            scheduler_is_callback = issubclass(lr_scheduler_cls,
                                               AbstractCallback)
            if not scheduler_is_callback:
                logger.warning("lr_scheduler_cls is not a callback.")
            scheduler = lr_scheduler_cls(optimizer, **lr_scheduler_params)
            self.register_callback(scheduler)

    self.use_gpu = bool(gpu_ids)
[docs] def _at_training_begin(self, *args, **kwargs):
"""
Defines the behaviour at beginnig of the training
Parameters
----------
*args :
positional arguments
**kwargs :
keyword arguments
Raises
------
NotImplementedError
If not overwritten by subclass
"""
self.save_state(os.path.join(self.save_path, "checkpoint_epoch_0"))
[docs] def _at_training_end(self, *args, **kwargs):
"""
Defines the behaviour at the end of the training
Parameters
----------
*args :
positional arguments
**kwargs :
keyword arguments
Raises
------
NotImplementedError
If not overwritten by subclass
"""
return self.module
[docs] def _at_epoch_begin(self, metrics_val, val_score_key, epoch, num_epochs,
**kwargs):
"""
Defines behaviour at beginning of each epoch: Executes all callbacks's
`at_epoch_begin` method
Parameters
----------
metrics_val : dict
validation metrics
val_score_key : str
validation score key
epoch : int
current epoch
num_epochs : int
total number of epochs
**kwargs :
keyword arguments
"""
# execute all callbacks
for cb in self._callbacks:
self._update_state(cb.at_epoch_begin(self, val_metrics=metrics_val,
val_score_key=val_score_key,
curr_epoch=epoch))
[docs] def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,
**kwargs):
"""
Defines behaviour at beginning of each epoch: Executes all callbacks's
`at_epoch_end` method and saves current state if necessary
Parameters
----------
metrics_val : dict
validation metrics
val_score_key : str
validation score key
epoch : int
current epoch
num_epochs : int
total number of epochs
**kwargs :
keyword arguments
"""
for cb in self._callbacks:
self._update_state(cb.at_epoch_end(self, val_metrics=metrics_val,
val_score_key=val_score_key,
curr_epoch=epoch))
if epoch % self.save_freq == 0:
self.save_state(os.path.join(self.save_path,
"checkpoint_epoch_%d" % epoch))
if is_best:
self.save_state(os.path.join(self.save_path,
"checkpoint_best"))
[docs] def _train_single_epoch(self, batchgen: Augmenter, epoch,
verbose=False):
"""
Trains the network a single epoch
Parameters
----------
batchgen : :class:`Augmenter`
Generator yielding the training batches
epoch : int
current epoch
"""
metrics, losses = [], []
n_batches = batchgen.num_batches
if verbose:
iterable = tqdm(
enumerate(batchgen),
unit=' batch',
total=n_batches,
desc='Epoch %d' %
epoch)
else:
iterable = enumerate(batchgen)
for batch_nr, batch in iterable:
data_dict = self._prepare_batch(batch)
_metrics, _losses, _ = self.closure_fn(self.module, data_dict,
optimizers=self.optimizers,
losses=self.losses,
metrics=self.train_metrics,
fold=self.fold,
batch_nr=batch_nr)
metrics.append(_metrics)
losses.append(_losses)
batchgen._finish()
total_losses, total_metrics = {}, {}
for _metrics in metrics:
for key, val in _metrics.items():
if key in total_metrics:
total_metrics[key].append(val)
else:
total_metrics[key] = [val]
for _losses in losses:
for key, val in _losses.items():
if key in total_losses:
total_losses[key].append(val)
else:
total_losses[key] = [val]
return total_metrics, total_losses
def train(self, num_epochs, datamgr_train, datamgr_valid=None,
          val_score_key=None, val_score_mode='highest', reduce_mode='mean',
          verbose=True):
    """
    Defines a routine to train a specified number of epochs

    Parameters
    ----------
    num_epochs : int
        number of epochs to train
    datamgr_train : DataManager
        the datamanager holding the train data
    datamgr_valid : DataManager
        the datamanager holding the validation data (default: None)
    val_score_key : str
        the key specifying which metric to use for validation
        (default: None)
    val_score_mode : str
        key specifying what kind of validation score is best
        ('highest' or 'lowest')
    reduce_mode : str
        how the per-batch values of each metric are reduced to a single
        value: 'mean', 'sum', 'first_only' or 'last_only'
    verbose : bool
        whether to show progress bars or not

    Returns
    -------
    Any
        the return value of ``self._at_training_end()``

    Raises
    ------
    ValueError
        if ``reduce_mode`` is not one of the supported modes
    """
    self._at_training_begin()

    # start from the worst possible score so that the first real score
    # always counts as an improvement (also for negative metrics)
    if val_score_mode == 'highest':
        best_val_score = float('-inf')
    elif val_score_mode == 'lowest':
        best_val_score = float('inf')
    else:
        best_val_score = None

    is_best = False
    new_val_score = best_val_score

    reduce_fns = {
        'mean': np.mean,
        'sum': np.sum,
        'first_only': lambda batch: batch[0],
        'last_only': lambda batch: batch[-1],
    }
    try:
        reduce_fn = reduce_fns[reduce_mode]
    except (KeyError, TypeError):
        raise ValueError("No valid reduce mode given") from None

    # NOTE(review): this dict is never updated inside the loop, so the
    # epoch-begin hook always receives an empty dict -- confirm intent
    metrics_val = {}

    # validation metrics (and their keys) are all prefixed with "val_"
    val_metric_fns = {}
    for k, v in self.val_metrics.items():
        if not k.startswith("val_"):
            k = "val_" + k
        val_metric_fns[k] = v

    if self.metric_keys is None:
        val_metric_keys = None
    else:
        val_metric_keys = {}
        for k, v in self.metric_keys.items():
            if not k.startswith("val_"):
                k = "val_" + k
            val_metric_keys[k] = v

    for epoch in range(self.start_epoch, num_epochs + 1):

        self._at_epoch_begin(metrics_val, val_score_key, epoch,
                             num_epochs)

        batch_gen_train = datamgr_train.get_batchgen(seed=epoch)

        # train single network epoch
        train_metrics, train_losses = self._train_single_epoch(
            batch_gen_train, epoch, verbose=verbose)

        total_metrics = {
            **train_metrics,
            **train_losses}

        # validate network
        if datamgr_valid is not None and (epoch % self.val_freq == 0):
            # next must be called here because
            # self.predict_data_mgr_cache_metrics_only returns a
            # generator and we want to get its first item
            val_metrics = next(
                self.predict_data_mgr_cache_metrics_only(
                    datamgr_valid, datamgr_valid.batch_size,
                    metrics=val_metric_fns, metric_keys=val_metric_keys,
                    verbose=verbose))

            total_metrics.update(val_metrics)

        # reduce the per-batch values of every metric to a single value
        for k, v in total_metrics.items():
            total_metrics[k] = reduce_fn(v)

        # check if metric became better
        if val_score_key is not None:
            if val_score_key not in total_metrics:
                if "val_" + val_score_key not in total_metrics:
                    logger.warning(
                        "val_score_key '%s' not a valid key for "
                        "validation metrics" % str(val_score_key))

                    new_val_score = best_val_score

                else:
                    # fall back to the prefixed key and keep using it
                    new_val_score = total_metrics["val_" + val_score_key]
                    val_score_key = "val_" + val_score_key
            else:
                new_val_score = total_metrics.get(val_score_key)

        if new_val_score != best_val_score:
            is_best = self._is_better_val_scores(
                best_val_score, new_val_score, val_score_mode)

            # set best_val_score to new_val_score if is_best
            if is_best:
                best_val_score = new_val_score

            if is_best and verbose:
                logging.info("New Best Value at Epoch %03d : %03.3f" %
                             (epoch, best_val_score))

        # log metrics and loss values through the root logger, which is
        # where _reinitialize_logging installs its handlers
        for key, val in total_metrics.items():
            logging.info({"value": {"value": val, "name": key
                                    }})

        self._at_epoch_end(total_metrics, val_score_key, epoch, is_best)

        is_best = False

        # stop training (might be caused by early stopping)
        if self.stop_training:
            break

    return self._at_training_end()
@property
def fold(self):
    """
    Current cross validation fold

    Returns
    -------
    int
        the value stored in ``self._fold``
    """
    current_fold = self._fold
    return current_fold

@fold.setter
def fold(self, fold):
    """
    Set the current fold after converting it with :obj:`int`

    Parameters
    ----------
    fold : int
        new fold

    Raises
    ------
    ValueError
        if `fold` is not convertible to :obj:`int`; the error is logged
        before it is re-raised
    """
    try:
        converted = int(fold)
    except ValueError as err:
        logger.error(err)
        raise err
    self._fold = converted
def register_callback(self, callback: AbstractCallback):
    """
    Append a callback to the trainer's list of callbacks

    Parameters
    ----------
    callback : :class:`AbstractCallback`
        the callback to register

    Raises
    ------
    AssertionError
        `callback` is neither an instance of :class:`AbstractCallback` nor
        provides both attributes ['at_epoch_begin', 'at_epoch_end']
    """
    required_hooks = ("at_epoch_begin", "at_epoch_end")

    is_callback_instance = isinstance(callback, AbstractCallback)
    provides_hooks = all([hasattr(callback, hook)
                          for hook in required_hooks])

    assert is_callback_instance or provides_hooks, (
        "Given callback is not valid; Must be instance of "
        "AbstractCallback or provide functions "
        "'at_epoch_begin' and 'at_epoch_end'")

    self._callbacks.append(callback)
def save_state(self, file_name, *args, **kwargs):
    """
    Pickle the instance's attribute dictionary (``vars(self)``) to a file

    Parameters
    ----------
    file_name : str
        filename to save the state to
    *args :
        positional arguments forwarded to :func:`pickle.dump`
    **kwargs :
        keyword arguments forwarded to :func:`pickle.dump`
    """
    with open(file_name, "wb") as handle:
        state = vars(self)
        pickle.dump(state, handle, *args, **kwargs)
@staticmethod
def load_state(file_name, *args, **kwargs):
    """
    Unpickle a state from file

    The file is read with :func:`pickle.load`, so only files from trusted
    sources should be passed in.

    Parameters
    ----------
    file_name : str
        the file to load the state from
    *args :
        positional arguments forwarded to :func:`pickle.load`
    **kwargs :
        keyword arguments forwarded to :func:`pickle.load`

    Returns
    -------
    Any
        the unpickled object (a dict if the file was written by
        ``save_state``)
    """
    with open(file_name, "rb") as handle:
        return pickle.load(handle, *args, **kwargs)
[docs] def _update_state(self, new_state):
"""
Update the state from a given new state
Parameters
----------
new_state : dict
new state to update internal state from
Returns
-------
:class:`BaseNetworkTrainer`
the trainer with a modified state
"""
for key, val in new_state.items():
if key.startswith("__") and key.endswith("__"):
continue
try:
setattr(self, key, val)
except PermissionError:
logger.error("Trying to overwrite attribute %s of "
"NetworkTrainer, which is not allowed!" % key)
return self
def update_state(self, file_name, *args, **kwargs):
    """
    Update internal state from a loaded state

    The state is loaded with ``self.load_state`` and applied with
    ``self._update_state``.

    Parameters
    ----------
    file_name : str
        file containing the new state to load
    *args :
        positional arguments forwarded to ``load_state``
    **kwargs :
        keyword arguments forwarded to ``load_state``

    Returns
    -------
    :class:`BaseNetworkTrainer`
        the trainer with a modified state (the result of
        ``self._update_state``)
    """
    loaded_state = self.load_state(file_name, *args, **kwargs)
    # hand the result back so the documented return value is honoured
    return self._update_state(loaded_state)
[docs] @staticmethod
def _is_better_val_scores(old_val_score, new_val_score,
mode='highest'):
"""
Check whether the new val score is better than the old one
with respect to the optimization goal
Parameters
----------
old_val_score :
old validation score
new_val_score :
new validation score
mode: str
String to specify whether a higher or lower validation score is
optimal; must be in ['highest', 'lowest']
Returns
-------
bool
True if new score is better, False otherwise
"""
assert mode in ['highest', 'lowest'], "Invalid Comparison Mode"
if mode == 'highest':
return new_val_score > old_val_score
elif mode == 'lowest':
return new_val_score < old_val_score
def _reinitialize_logging(self, logging_type, logging_kwargs: dict):
    """
    Replace all :class:`TrixiHandler` instances on the root logger with a
    freshly created logging handler.

    Parameters
    ----------
    logging_type : str or callable
        either one of the strings "visdom" / "tensorboardx"
        (case-insensitive) or a logging handler class
    logging_kwargs : dict
        keyword arguments for the handler; they override the defaults
        chosen for the built-in handler classes

    Raises
    ------
    ValueError
        if ``logging_type`` is a string other than the supported ones
    """
    from ..logging import TensorboardXLoggingHandler, VisdomLoggingHandler

    if isinstance(logging_type, str):
        if logging_type.lower() == "visdom":
            logging_cls = VisdomLoggingHandler

        elif logging_type.lower() == "tensorboardx":
            logging_cls = TensorboardXLoggingHandler

        else:
            raise ValueError("Invalid Logging Type")

    else:
        logging_cls = logging_type

    if logging_cls == VisdomLoggingHandler:
        _logging_kwargs = {"exp_name": "main",
                           "level": 0}
    elif logging_cls == TensorboardXLoggingHandler:
        _logging_kwargs = {"log_dir": self.save_path,
                           "level": 0}
    else:
        # custom handler classes have no known defaults; without this
        # branch the update below would hit an unbound local
        _logging_kwargs = {}

    _logging_kwargs.update(logging_kwargs)

    # make the experiment name unique per fold
    if "exp_name" in _logging_kwargs:
        _logging_kwargs["exp_name"] = _logging_kwargs["exp_name"] + \
            "_%02d" % self.fold

    # remove prior Trixihandlers and reinitialize it with given logging
    # type
    # This facilitates visualization of multiple splits/fold inside one
    # tensorboard-instance by means of
    # different tf.Summary.FileWriters()

    root_logger = logging.getLogger()
    new_handlers = []
    for handler in root_logger.handlers:
        if isinstance(handler, TrixiHandler):
            handler.close()
        else:
            new_handlers.append(handler)

    # basicConfig only takes effect on a root logger without handlers
    root_logger.handlers = []

    new_handlers.append(
        logging_cls(**_logging_kwargs)
    )

    logging.basicConfig(level=logging.INFO,
                        handlers=new_handlers)
[docs] @staticmethod
def _search_for_prev_state(path, extensions=None):
"""
Helper function to search in a given path for previous epoch states
(indicated by extensions)
Parameters
----------
path : str
the path to search in
extensions : list
list of strings containing valid file extensions for checkpoint
files
Returns
-------
str
the file containing the latest checkpoint (if available)
None
if no latst checkpoint was found
int
the latest epoch (1 if no checkpoint was found)
"""
if extensions is None:
extensions = []
files = []
for file in os.listdir(path):
for ext in extensions:
if not ext.startswith("."):
ext = "." + ext
if not file.endswith(ext):
continue
if not file.startswith("checkpoint"):
continue
if file.endswith("_best" + ext):
continue
files.append(file)
break
if files:
latest_epoch = max([
int(x.rsplit("_", 1)[-1].rsplit(".", 1)[0])
for x in files])
latest_state_path = [x for x in files
if x.startswith("checkpoint_%d"
% latest_epoch)][0]
return latest_state_path, latest_epoch
return None, 1