Source code for delira.training.tf_trainer

import logging
import os

from batchgenerators.dataloading import MultiThreadedAugmenter

from .base_trainer import BaseNetworkTrainer
from .train_utils import convert_tf_tensor_to_npy
from .train_utils import create_optims_default_tf as create_optims_default
from .train_utils import initialize_uninitialized
from ..io import tf_load_checkpoint, tf_save_checkpoint

logger = logging.getLogger(__name__)


class TfNetworkTrainer(BaseNetworkTrainer):
    """
    Train and Validate a Network

    See Also
    --------
    :class:`AbstractNetwork`

    """

    def __init__(self, network, save_path,
                 key_mapping,
                 losses: dict, optimizer_cls,
                 optimizer_params=None, train_metrics=None,
                 val_metrics=None, lr_scheduler_cls=None,
                 lr_scheduler_params=None, gpu_ids=None,
                 save_freq=1, optim_fn=create_optims_default,
                 logging_type="tensorboardx", logging_kwargs=None,
                 fold=0, callbacks=None, start_epoch=1,
                 metric_keys=None,
                 convert_batch_to_npy_fn=convert_tf_tensor_to_npy,
                 val_freq=1,
                 **kwargs):
        """

        Parameters
        ----------
        network : :class:`AbstractTfNetwork`
            the network to train
        save_path : str
            path to save networks to
        key_mapping : dict
            a dictionary containing the mapping from the ``data_dict`` to
            the actual model's inputs. E.g. if a model accepts one input
            named 'x' and the data_dict contains one entry named 'data',
            this argument would have to be ``{'x': 'data'}``
        losses : dict
            dictionary containing the training losses
        optimizer_cls : subclass of tf.train.Optimizer
            optimizer class implementing the optimization algorithm of
            choice
        optimizer_params : dict
            keyword arguments passed to optimizer during construction
        train_metrics : dict, optional
            metrics, which will be evaluated during train phase
            (should work on numpy arrays)
        val_metrics : dict, optional
            metrics, which will be evaluated during test phase
            (should work on numpy arrays)
        lr_scheduler_cls : Any
            learning rate schedule class: must implement step() method
        lr_scheduler_params : dict
            keyword arguments passed to lr scheduler during construction
        gpu_ids : list
            list containing ids of GPUs to use; if empty: use cpu instead
        save_freq : int
            integer specifying how often to save the current model's
            state. State is saved every ``save_freq`` epochs
        optim_fn : function
            creates a dictionary containing all necessary optimizers
        logging_type : str or callable
            the type of logging. If string: it must be one of
            ["visdom", "tensorboardx"].
            If callable: it must be a logging handler class
        logging_kwargs : dict
            dictionary containing all logging keyword arguments
        fold : int
            current cross validation fold (0 per default)
        callbacks : list
            initial callbacks to register
        start_epoch : int
            epoch to start training at
        metric_keys : dict
            dict specifying which ``batch_dict`` entry to use as target
            for which metric; default: None, which results in the key
            "label" for all metrics
        convert_batch_to_npy_fn : type, optional
            function converting a batch-tensor to numpy, per default this
            is the identity function
        val_freq : int
            validation frequency specifying how often to validate the
            trained model (a value of 1 denotes validating every epoch,
            a value of 2 denotes validating every second epoch etc.);
            defaults to 1
        **kwargs :
            additional keyword arguments

        """
        if optimizer_params is None:
            optimizer_params = {}
        if train_metrics is None:
            train_metrics = {}
        if val_metrics is None:
            val_metrics = {}
        if lr_scheduler_params is None:
            lr_scheduler_params = {}
        if gpu_ids is None:
            gpu_ids = []
        if logging_kwargs is None:
            logging_kwargs = {}
        if callbacks is None:
            callbacks = []

        super().__init__(
            network, save_path, losses, optimizer_cls, optimizer_params,
            train_metrics, val_metrics, lr_scheduler_cls,
            lr_scheduler_params, gpu_ids, save_freq, optim_fn, key_mapping,
            logging_type, logging_kwargs, fold, callbacks, start_epoch,
            metric_keys, convert_batch_to_npy_fn, val_freq)

        self._setup(network, optim_fn, optimizer_cls, optimizer_params,
                    lr_scheduler_cls, lr_scheduler_params, key_mapping,
                    convert_batch_to_npy_fn, gpu_ids)

        for key, val in kwargs.items():
            setattr(self, key, val)
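    # Construction sketch. ``my_network`` and ``my_loss_fn`` are
    # hypothetical placeholders (not part of this module): a trainer for
    # a model whose graph expects one input named 'x', fed from the
    # data_dict entry 'data':
    #
    #     trainer = TfNetworkTrainer(
    #         network=my_network,          # an AbstractTfNetwork subclass
    #         save_path="./checkpoints",
    #         key_mapping={"x": "data"},   # input 'x' <- data_dict['data']
    #         losses={"ce": my_loss_fn},   # callables yielding loss tensors
    #         optimizer_cls=tf.train.AdamOptimizer,
    #         optimizer_params={"learning_rate": 1e-3})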
    def _setup(self, network, optim_fn, optimizer_cls, optimizer_params,
               lr_scheduler_cls, lr_scheduler_params, key_mapping,
               convert_batch_to_npy_fn, gpu_ids):
        """
        Defines the Trainer's Setup

        Parameters
        ----------
        network : instance of :class:`AbstractTfNetwork`
            the network to train
        optim_fn : function
            creates a dictionary containing all necessary optimizers
        optimizer_cls : subclass of tf.train.Optimizer
            optimizer class implementing the optimization algorithm of
            choice
        optimizer_params : dict
            keyword arguments passed to optimizer during construction
        lr_scheduler_cls : Any
            learning rate schedule class: must implement step() method
        lr_scheduler_params : dict
            keyword arguments passed to lr scheduler during construction
        key_mapping : dict
            mapping from the ``data_dict`` to the actual model's inputs
        convert_batch_to_npy_fn : type, optional
            function converting a batch-tensor to numpy, per default this
            is the identity function
        gpu_ids : list
            list containing ids of GPUs to use; if empty: use cpu instead

        """
        # TODO: implement multi-GPU and single GPU training with help of
        # keras multi-gpu model
        # note: might be bugged in combination with sess.run
        # https://github.com/tensorflow/tensorflow/issues/21788
        #
        # if gpu_ids and tf.test.is_gpu_available():
        #     assert len(gpu_ids) <= len(get_available_gpus()), \
        #         "more GPUs specified than available"
        #     self.use_gpu = True
        #     if len(gpu_ids) > 1:
        #         logger.warning("multi-GPU training not yet tested!")
        #         network.model = tf.keras.utils.multi_gpu_model(
        #             network.model, len(gpu_ids),
        #             cpu_merge=True, cpu_relocation=False)
        #     else:
        #         network.models = tf.keras.models.clone_model(network.model)
        # else:
        #     self.use_gpu = False

        self.optimizers = optim_fn(optimizer_cls, **optimizer_params)

        super()._setup(network, lr_scheduler_cls, lr_scheduler_params,
                       gpu_ids, key_mapping, convert_batch_to_npy_fn,
                       lambda x: x)

        self.use_gpu = True

        self.module._add_losses(self.losses)
        self.module._add_optims(self.optimizers)

        # check for uninitialized variables
        initialize_uninitialized(self.module._sess)

        # Load latest epoch file if available
        if os.path.isdir(self.save_path):
            latest_state_path, latest_epoch = self._search_for_prev_state(
                self.save_path, [".meta"])

            if latest_state_path is not None:
                # if pth file does not exist, load pt file instead
                if not os.path.isfile(latest_state_path):
                    latest_state_path = latest_state_path[:-1]

                logger.info("Attempting to load state from previous "
                            "training from %s" % latest_state_path)

                self.update_state(latest_state_path)
                self.start_epoch = latest_epoch
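    # ``optim_fn`` only has to turn an optimizer class and its keyword
    # arguments into a name -> optimizer dict that ``_add_optims`` can
    # consume. A minimal sketch of such a factory (the bundled
    # ``create_optims_default`` follows this pattern):
    #
    #     def single_optim(optim_cls, **optim_params):
    #         return {"default": optim_cls(**optim_params)}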
    def _at_training_end(self):
        """
        Defines Behaviour at end of training: Loads best model if
        available

        Returns
        -------
        :class:`AbstractTfNetwork`
            best network

        """
        if os.path.isfile(os.path.join(self.save_path,
                                       'checkpoint_best.meta')):

            # load best model and return it. Since the state is hidden in
            # the graph, we don't actually need to use self._update_state.
            self.update_state(os.path.join(self.save_path,
                                           'checkpoint_best'))

        return self.module
    def _train_single_epoch(self, batchgen: MultiThreadedAugmenter, epoch,
                            verbose=False):
        """
        Trains the network a single epoch

        Parameters
        ----------
        batchgen : MultiThreadedAugmenter
            generator yielding the training batches
        epoch : int
            current epoch
        verbose : bool
            whether to show a progress-bar, default: False

        """
        self.module.training = True

        return super()._train_single_epoch(batchgen, epoch,
                                           verbose=verbose)
    def predict_data_mgr(self, datamgr, batch_size=None, metrics=None,
                         metric_keys=None, verbose=False, **kwargs):
        """
        Defines a routine to predict data obtained from a batchgenerator

        Parameters
        ----------
        datamgr : :class:`BaseDataManager`
            manager producing a generator holding the batches
        batch_size : int
            artificial batchsize (sampling will be done with batchsize 1
            and sampled data will be stacked to match the artificial
            batchsize) (default: None)
        metrics : dict
            the metrics to calculate
        metric_keys : dict
            the ``batch_dict`` items to use for metric calculation
        verbose : bool
            whether to show a progress-bar or not, default: False
        **kwargs :
            additional keyword arguments

        """
        if metrics is None:
            metrics = {}

        self.module.training = False

        return super().predict_data_mgr(datamgr, batch_size, metrics,
                                        metric_keys, verbose=verbose)
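    # Evaluation sketch (``manager`` and the metric are hypothetical):
    # metrics operate on numpy arrays, and ``metric_keys`` selects the
    # batch_dict entry each metric is computed against:
    #
    #     results = trainer.predict_data_mgr(
    #         manager, batch_size=32,
    #         metrics={"mae": lambda pred, label: np.abs(pred - label).mean()},
    #         metric_keys={"mae": "label"})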
    def save_state(self, file_name, *args, **kwargs):
        """
        Saves the current state via :func:`delira.io.tf.save_checkpoint`

        Parameters
        ----------
        file_name : str
            filename to save the state to

        """
        tf_save_checkpoint(file_name, self.module)
    def load_state(self, file_name, *args, **kwargs):
        """
        Loads the new state from file via
        :func:`delira.io.tf.load_checkpoint`

        Parameters
        ----------
        file_name : str
            the file to load the state from

        Returns
        -------
        Any
            the loaded state as returned by
            :func:`delira.io.tf.load_checkpoint`

        """
        return tf_load_checkpoint(file_name, self.module)
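
# Checkpointing sketch: states are written in TensorFlow's native
# checkpoint format, so saving under "<save_path>/checkpoint_best"
# produces the "checkpoint_best.meta" (plus index/data) files that
# _at_training_end looks for. A hypothetical round trip:
#
#     trainer.save_state(os.path.join(trainer.save_path, "checkpoint_ep1"))
#     trainer.load_state(os.path.join(trainer.save_path, "checkpoint_ep1"))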