Source code for delira.training.pytorch_trainer

import os
import logging
import numpy as np
from tqdm.auto import tqdm
from collections import OrderedDict
from batchgenerators.dataloading import MultiThreadedAugmenter
from .callbacks import AbstractCallback
from .abstract_trainer import AbstractNetworkTrainer

from delira import get_backends

logger = logging.getLogger(__name__)

if "TORCH" in get_backends():
    import torch
    from .train_utils import pytorch_batch_to_numpy
    from .train_utils import create_optims_default_pytorch as create_optims_default
    from ..io.torch import load_checkpoint, save_checkpoint

    class PyTorchNetworkTrainer(AbstractNetworkTrainer):
        """
        Train and Validate a Network

        See Also
        --------
        :class:`AbstractNetwork`

        """

        def __init__(self, network, save_path,
                     criterions: dict, optimizer_cls,
                     optimizer_params={}, metrics={},
                     lr_scheduler_cls=None, lr_scheduler_params={},
                     gpu_ids=[], save_freq=1,
                     optim_fn=create_optims_default,
                     fold=0, callbacks=[], start_epoch=1,
                     mixed_precision=False,
                     mixed_precision_kwargs={"enable_caching": True,
                                             "verbose": False,
                                             "allow_banned": False},
                     **kwargs):
            """

            Parameters
            ----------
            network : :class:`AbstractPyTorchNetwork`
                the network to train
            save_path : str
                path to save networks to
            criterions : dict
                dictionary containing the training criterions
            optimizer_cls : subclass of torch.optim.Optimizer
                optimizer class implementing the optimization algorithm of choice
            optimizer_params : dict
                keyword arguments passed to optimizer during construction
            metrics : dict
                dictionary containing the validation metrics
            lr_scheduler_cls : Any
                learning rate schedule class: must implement step() method
            lr_scheduler_params : dict
                keyword arguments passed to lr scheduler during construction
            gpu_ids : list
                list containing ids of GPUs to use; if empty: use cpu instead
            save_freq : int
                integer specifying how often to save the current model's state.
                State is saved every `save_freq` epochs
            optim_fn : function
                creates a dictionary containing all necessary optimizers
            fold : int
                current cross validation fold (0 per default)
            callbacks : list
                initial callbacks to register
            start_epoch : int
                epoch to start training at
            mixed_precision : bool
                whether to use mixed precision or not (False per default)
            mixed_precision_kwargs : dict
                additional keyword arguments for mixed precision
            **kwargs :
                additional keyword arguments

            """
            super().__init__(fold, callbacks)

            self.save_path = save_path
            if os.path.isdir(save_path):
                logger.warning(
                    "Save Path already exists. Saved Models may be overwritten")
            else:
                os.makedirs(save_path)

            self.criterions = criterions
            self.metrics = metrics
            self.save_freq = save_freq

            # Whether or not to stop the training
            # Used for early stopping
            self.stop_training = False
            self.start_epoch = start_epoch

            self._setup(network, optim_fn, optimizer_cls, optimizer_params,
                        lr_scheduler_cls, lr_scheduler_params, gpu_ids,
                        mixed_precision, mixed_precision_kwargs)

            for key, val in kwargs.items():
                setattr(self, key, val)
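        # Illustrative usage only (a hedged sketch, not part of the original
        # module): constructing the trainer with a hypothetical
        # :class:`AbstractPyTorchNetwork` instance ``my_network`` and the
        # standard ``torch.optim`` / ``torch.nn`` classes. The criterion and
        # metric keys are placeholders.
        #
        #   trainer = PyTorchNetworkTrainer(
        #       network=my_network,
        #       save_path="./checkpoints",
        #       criterions={"CE": torch.nn.CrossEntropyLoss()},
        #       optimizer_cls=torch.optim.Adam,
        #       optimizer_params={"lr": 1e-3},
        #       metrics={},
        #       gpu_ids=[0])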
        def _setup(self, network, optim_fn, optimizer_cls, optimizer_params,
                   lr_scheduler_cls, lr_scheduler_params, gpu_ids,
                   mixed_precision, mixed_precision_kwargs):
            """
            Defines the Trainer's Setup

            Parameters
            ----------
            network : :class:`AbstractPyTorchNetwork`
                the network to train
            optim_fn : function
                creates a dictionary containing all necessary optimizers
            optimizer_cls : subclass of torch.optim.Optimizer
                optimizer class implementing the optimization algorithm of choice
            optimizer_params : dict
                keyword arguments passed to optimizer during construction
            lr_scheduler_cls : Any
                learning rate schedule class: must implement step() method
            lr_scheduler_params : dict
                keyword arguments passed to lr scheduler during construction
            gpu_ids : list
                list containing ids of GPUs to use; if empty: use cpu instead
            mixed_precision : bool
                whether to use mixed precision or not (False per default)
            mixed_precision_kwargs : dict
                additional keyword arguments for mixed precision

            """
            try:
                from apex import amp
                self._amp_handle = amp.init(mixed_precision,
                                            **mixed_precision_kwargs)
                wrap_fn = self._amp_handle.wrap_optimizer
            except ImportError:
                if mixed_precision:
                    logger.warning("Apex was not found, trying to continue "
                                   "in full precision instead")
                from ..utils.context_managers import DefaultOptimWrapperTorch
                wrap_fn = DefaultOptimWrapperTorch

            # wrap optimizers by half_precision_optimizer via apex if necessary
            self.optimizers = {
                k: wrap_fn(v, num_loss=len(self.criterions))
                for k, v in optim_fn(network, optimizer_cls,
                                     **optimizer_params).items()}

            # schedulers
            if lr_scheduler_cls is not None:
                for key, optim in self.optimizers.items():
                    if not issubclass(lr_scheduler_cls, AbstractCallback):
                        logger.warning("lr_scheduler_cls is not a callback.")
                    # access actual optimizer by calling wrapped optimizer
                    # from wrapper
                    self.register_callback(
                        lr_scheduler_cls(optim._optimizer,
                                         **lr_scheduler_params))

            # store network in self.module to load previous state
            # (will be overwritten later)
            self.module = network

            # Load latest epoch file if available
            if os.path.isdir(self.save_path):
                # check all files in directory starting with "checkpoint" and
                # not ending with "_best.pth"
                files = [x for x in os.listdir(self.save_path)
                         if os.path.isfile(os.path.join(self.save_path, x))
                         and x.startswith("checkpoint")
                         and not x.endswith("_best.pth")]

                # if list is not empty: load previous state
                if files:
                    latest_epoch = max([
                        int(x.rsplit("_", 1)[-1].rsplit(".", 1)[0])
                        for x in files])

                    latest_state_path = os.path.join(
                        self.save_path,
                        "checkpoint_epoch_%d.pth" % latest_epoch)

                    logger.info("Attempting to load state from previous "
                                "training from %s" % latest_state_path)

                    try:
                        self.update_state(latest_state_path)
                    except KeyError:
                        logger.warning("Previous State could not be loaded, "
                                       "although it exists. Training will be "
                                       "restarted")

            # assign closure and prepare batch from network
            self.closure_fn = network.closure
            self._prepare_batch = network.prepare_batch

            if gpu_ids and torch.cuda.is_available():
                self.use_gpu = True
                if (len(gpu_ids) > 1) and (torch.cuda.device_count() > 1):
                    # use GPU 0 as default input GPU
                    self.input_device = torch.device("cuda:%d" % gpu_ids[0])

                    # Train on multiple GPUs and use GPU 1 as output device
                    self.module = torch.nn.DataParallel(
                        self.module.to(self.input_device),
                        device_ids=gpu_ids,
                        output_device=gpu_ids[1])

                    # use GPU 1 as default output GPU for balanced GPU usage
                    self.output_device = torch.device("cuda:%d" % gpu_ids[1])

                else:
                    # use the only available GPU as input device
                    self.input_device = torch.device("cuda:%d" % gpu_ids[0])
                    self.module = self.module.to(self.input_device)

                    # use GPU 0 as output device as well
                    self.output_device = torch.device("cuda:%d" % gpu_ids[0])

            else:
                self.use_gpu = False
                self.input_device = torch.device("cpu")
                self.output_device = torch.device("cpu")
                self.module = self.module.to(self.input_device)
        def train(self, num_epochs, datamgr_train, datamgr_valid=None,
                  val_score_key=None, val_score_mode='highest'):
            """
            train network

            Parameters
            ----------
            num_epochs : int
                number of epochs to train
            datamgr_train : BaseDataManager
                Data Manager to create Batch Generator for training
            datamgr_valid : BaseDataManager
                Data Manager to create Batch Generator for validation
            val_score_key : str
                Key of validation metric; must be key in self.metrics
            val_score_mode : str
                String to specify whether a higher or lower validation score is
                optimal; must be in ['highest', 'lowest']

            Returns
            -------
            :class:`AbstractPyTorchNetwork`
                Best model (if `val_score_key` is not a valid key the model of
                the last epoch will be returned)

            """
            self._at_training_begin()

            self.module.train()

            if val_score_mode == 'highest':
                best_val_score = 0
            elif val_score_mode == 'lowest':
                best_val_score = float('inf')
            else:
                best_val_score = None

            curr_val_score = best_val_score

            self.save_state(os.path.join(self.save_path,
                                         "checkpoint_epoch_0.pth"),
                            self.start_epoch)

            metrics_val = {}

            for epoch in range(self.start_epoch, num_epochs + 1):

                self._at_epoch_begin(metrics_val, val_score_key, epoch,
                                     num_epochs)

                batch_gen_train = datamgr_train.get_batchgen(seed=epoch)

                self._train_single_epoch(batch_gen_train, epoch)

                if datamgr_valid:
                    # validate with batchsize 1 and 1 augmentation process to
                    # avoid dropping of last elements
                    orig_num_aug_processes = \
                        datamgr_valid.n_process_augmentation
                    orig_batch_size = datamgr_valid.batch_size

                    datamgr_valid.batch_size = 1
                    datamgr_valid.n_process_augmentation = 1

                    pred_val, labels_val, metrics_val = self.predict(
                        datamgr_valid.get_batchgen(),
                        batch_size=orig_batch_size)

                    # reset old values
                    datamgr_valid.batch_size = orig_batch_size
                    datamgr_valid.n_process_augmentation = \
                        orig_num_aug_processes

                    # ToDo: Move decision, if current model is best, to callback
                    if val_score_key in metrics_val.keys():
                        curr_val_score = metrics_val[val_score_key]
                        is_best = self._is_better_val_scores(best_val_score,
                                                             curr_val_score,
                                                             val_score_mode)
                    else:
                        logger.warning(
                            "Validation score key not in metric dict. "
                            "Logging metrics but can't decide which model "
                            "is best")
                        is_best = False

                    if is_best:
                        best_val_score = curr_val_score
                        tqdm.write(
                            'Best val score = %2.3f' % best_val_score.item())
                    else:
                        is_best = False

                else:
                    is_best = False
                    labels_val, pred_val, metrics_val = {}, {}, {}

                self._at_epoch_end(metrics_val, val_score_key, epoch, is_best)

                # stop training (might be caused by early stopping)
                if self.stop_training:
                    break

            return self._at_training_end()
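        # Illustrative usage only (a hedged sketch, not part of the original
        # module): running training with hypothetical ``BaseDataManager``
        # instances ``dmgr_train`` / ``dmgr_val``; the metric key "accuracy"
        # is assumed to exist in ``self.metrics``.
        #
        #   best_model = trainer.train(num_epochs=10,
        #                              datamgr_train=dmgr_train,
        #                              datamgr_valid=dmgr_val,
        #                              val_score_key="accuracy",
        #                              val_score_mode="highest")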
        def _at_training_begin(self, *args, **kwargs):
            """
            Defines behaviour at beginning of training

            Parameters
            ----------
            *args :
                positional arguments
            **kwargs :
                keyword arguments

            """
            pass
        def _at_training_end(self):
            """
            Defines Behaviour at end of training: Loads best model if
            available

            Returns
            -------
            :class:`AbstractPyTorchNetwork`
                best network

            """
            if os.path.isfile(os.path.join(self.save_path,
                                           'checkpoint_best.pth')):

                # load best model and return it
                self.update_state(os.path.join(self.save_path,
                                               'checkpoint_best.pth'))

            return self.module
        def _at_epoch_begin(self, metrics_val, val_score_key, epoch,
                            num_epochs, **kwargs):
            """
            Defines behaviour at beginning of each epoch: Executes all
            callbacks' `at_epoch_begin` method

            Parameters
            ----------
            metrics_val : dict
                validation metrics
            val_score_key : str
                validation score key
            epoch : int
                current epoch
            num_epochs : int
                total number of epochs
            **kwargs :
                keyword arguments

            """
            # execute all callbacks
            for cb in self._callbacks:
                self._update_state(
                    cb.at_epoch_begin(self,
                                      val_metrics=metrics_val,
                                      val_score_key=val_score_key,
                                      curr_epoch=epoch))
        def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,
                          **kwargs):
            """
            Defines behaviour at end of each epoch: Executes all callbacks'
            `at_epoch_end` method and saves current state if necessary

            Parameters
            ----------
            metrics_val : dict
                validation metrics
            val_score_key : str
                validation score key
            epoch : int
                current epoch
            is_best : bool
                whether current model is best one so far
            **kwargs :
                keyword arguments

            """
            for cb in self._callbacks:
                self._update_state(
                    cb.at_epoch_end(self,
                                    val_metrics=metrics_val,
                                    val_score_key=val_score_key,
                                    curr_epoch=epoch))

            if epoch % self.save_freq == 0:
                self.save_state(os.path.join(
                    self.save_path, "checkpoint_epoch_%d.pth" % epoch),
                    epoch)

            if is_best:
                self.save_state(os.path.join(self.save_path,
                                             "checkpoint_best.pth"),
                                epoch)
        def _train_single_epoch(self, batchgen: MultiThreadedAugmenter, epoch):
            """
            Trains the network a single epoch

            Parameters
            ----------
            batchgen : MultiThreadedAugmenter
                Generator yielding the training batches
            epoch : int
                current epoch

            """
            self.module.train()

            n_batches = batchgen.generator.num_batches * batchgen.num_processes
            pbar = tqdm(enumerate(batchgen), unit=' batch', total=n_batches,
                        desc='Epoch %d' % epoch)

            for batch_nr, batch in pbar:

                data_dict = self._prepare_batch(batch, self.input_device,
                                                self.output_device)

                _, _, _ = self.closure_fn(self.module, data_dict,
                                          optimizers=self.optimizers,
                                          criterions=self.criterions,
                                          metrics=self.metrics,
                                          fold=self.fold,
                                          batch_nr=batch_nr)

            batchgen._finish()
        def predict(self, batchgen, batch_size=None):
            """
            Returns predictions from network for batches from batchgen

            Parameters
            ----------
            batchgen : MultiThreadedAugmenter
                Generator yielding the batches to predict
            batch_size : None or int
                if int: collect batches until batch_size is reached and
                forward them together

            Returns
            -------
            list of np.ndarray
                predictions from batches
            list of np.ndarray
                labels from batches
            dict
                dictionary containing the mean validation metrics and
                the mean loss values

            """
            self.module.eval()

            outputs_all, labels_all = [], []
            metric_mean_vals = {}
            loss_mean_vals = {}

            n_batches = batchgen.generator.num_batches * batchgen.num_processes
            pbar = tqdm(enumerate(batchgen), unit=' sample', total=n_batches,
                        desc='Test')

            orig_batch_size = batch_size
            batch_list = []

            for i, batch in pbar:

                if batch_size is not None and not batch_list \
                        and (n_batches - i) < batch_size:
                    batch_size = n_batches - i
                    logger.debug("Set Batchsize down to %d to avoid cutting "
                                 "off the last batches" % batch_size)

                data_dict = self._prepare_batch(batch, self.input_device,
                                                self.output_device)

                # queue inputs and labels
                batch_list.append(data_dict)

                # if queue is full, process queue:
                if batch_size is None or len(batch_list) >= batch_size:

                    batch_dict = {}
                    for batch in batch_list:
                        for key, val in batch.items():
                            if key in batch_dict.keys():
                                batch_dict[key].append(val)
                            else:
                                batch_dict[key] = [val]

                    for key, val_list in batch_dict.items():
                        batch_dict[key] = torch.cat(val_list)

                    met_vals, loss_vals, preds = self.closure_fn(
                        self.module, batch_dict,
                        optimizers={},
                        criterions=self.criterions,
                        metrics=self.metrics,
                        fold=self.fold)

                    for key, val in met_vals.items():
                        if key in metric_mean_vals.keys():
                            metric_mean_vals[key] += val.detach()
                        else:
                            metric_mean_vals[key] = val.detach()

                    for key, val in loss_vals.items():
                        if key in loss_mean_vals.keys():
                            loss_mean_vals[key] += val.detach()
                        else:
                            loss_mean_vals[key] = val.detach()

                    outputs_all.append(
                        [pytorch_batch_to_numpy(tmp) for tmp in preds])

                    label_dict = {}

                    for key, val in batch_dict.items():
                        if "data" not in key and "img" not in key:
                            label_dict[key] = pytorch_batch_to_numpy(val)

                    labels_all.append([label_dict[key]
                                       for key in sorted(label_dict.keys())])

                    batch_list = []

            batchgen._finish()

            # transpose labels and outputs to have a list of lists of
            # labels of same type
            labels_all = zip(*labels_all)
            outputs_all = zip(*outputs_all)

            labels_all = [np.vstack(_labels) for _labels in labels_all]
            outputs_all = [np.vstack(_outputs) for _outputs in outputs_all]

            # metric_mean_vals contains sums of metrics so far.
            # Dividing by number of batches to get mean values.
            # If virtual batchsize is given: calculate actual number of batches
            if batch_size is not None:
                div = np.ceil(n_batches / orig_batch_size)
            else:
                div = n_batches

            val_dict = {}

            for key, val in metric_mean_vals.items():
                val_dict[key] = val / div

            for key, val in loss_mean_vals.items():
                val_dict[key] = val / div

            return outputs_all, labels_all, val_dict
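        # Illustrative usage only (a hedged sketch, not part of the original
        # module): running inference with a hypothetical validation data
        # manager ``dmgr_val``. Note that ``predict`` expects the batch
        # generator, not the manager itself.
        #
        #   preds, labels, metric_dict = trainer.predict(
        #       dmgr_val.get_batchgen(), batch_size=dmgr_val.batch_size)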
        def save_state(self, file_name, epoch, **kwargs):
            """
            saves the current state via
            :func:`delira.io.torch.save_checkpoint`

            Parameters
            ----------
            file_name : str
                filename to save the state to
            epoch : int
                current epoch (will be saved for mapping back)
            **kwargs :
                keyword arguments

            """
            save_checkpoint(file_name, self.module, self.optimizers,
                            epoch=epoch, **kwargs)
        @staticmethod
        def load_state(file_name, **kwargs):
            """
            Loads the new state from file via
            :func:`delira.io.torch.load_checkpoint`

            Parameters
            ----------
            file_name : str
                the file to load the state from
            **kwargs :
                keyword arguments

            Returns
            -------
            dict
                new state

            """
            return load_checkpoint(file_name, **kwargs)
        def update_state(self, file_name, *args, **kwargs):
            """
            Update internal state from a loaded state

            Parameters
            ----------
            file_name : str
                file containing the new state to load
            *args :
                positional arguments
            **kwargs :
                keyword arguments

            Returns
            -------
            :class:`AbstractNetworkTrainer`
                the trainer with a modified state

            """
            return self._update_state(self.load_state(file_name, *args,
                                                       **kwargs))
        def _update_state(self, new_state):
            """
            Update the state from a given new state

            Parameters
            ----------
            new_state : dict
                new state to update internal state from

            Returns
            -------
            :class:`PyTorchNetworkTrainer`
                the trainer with a modified state

            """
            # print(",".join(new_state.keys()))
            if "model" in new_state:
                self.module.load_state_dict(new_state.pop("model"))

            if "optimizer" in new_state and new_state["optimizer"]:
                optim_state = new_state.pop("optimizer")
                for key in self.optimizers.keys():
                    self.optimizers[key].load_state_dict(optim_state[key])

            if "epoch" in new_state:
                self.start_epoch = new_state.pop("epoch")

            return super()._update_state(new_state)
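        # Illustrative usage only (a hedged sketch, not part of the original
        # module): resuming from a saved checkpoint; the file path is a
        # placeholder. ``update_state`` loads the file via ``load_state`` and
        # applies it through ``_update_state``.
        #
        #   trainer.update_state("./checkpoints/checkpoint_epoch_10.pth")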