Source code for delira.models.classification.classification_network

import logging
file_logger = logging.getLogger(__name__)

from delira import get_backends

if "TORCH" in get_backends():
    import torch
    from torchvision import models as t_models
    from delira.models.abstract_network import AbstractPyTorchNetwork

[docs] class ClassificationNetworkBasePyTorch(AbstractPyTorchNetwork): """ Implements basic classification with ResNet18 References ---------- See Also -------- :class:`AbstractPyTorchNetwork` """ def __init__(self, in_channels: int, n_outputs: int, **kwargs): """ Parameters ---------- in_channels : int number of input_channels n_outputs : int number of outputs (usually same as number of classes) """ # register params by passing them as kwargs to parent class __init__ super().__init__(in_channels=in_channels, n_outputs=n_outputs, **kwargs) self.module = self._build_model(in_channels, n_outputs, **kwargs) for key, value in kwargs.items(): setattr(self, key, value)
[docs] def forward(self, input_batch: torch.Tensor): """ Forward input_batch through network Parameters ---------- input_batch : torch.Tensor batch to forward through network Returns ------- torch.Tensor Classification Result """ return self.module(input_batch)
[docs] @staticmethod def closure(model: AbstractPyTorchNetwork, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs): """ closure method to do a single backpropagation step Parameters ---------- model : :class:`ClassificationNetworkBasePyTorch` trainable model data_dict : dict dictionary containing the data optimizers : dict dictionary of optimizers to optimize model's parameters criterions : dict dict holding the criterions to calculate errors (gradients from different criterions will be accumulated) metrics : dict dict holding the metrics to calculate fold : int Current Fold in Crossvalidation (default: 0) **kwargs: additional keyword arguments Returns ------- dict Metric values (with same keys as input dict metrics) dict Loss values (with same keys as input dict criterions) list Arbitrary number of predictions as torch.Tensor Raises ------ AssertionError if optimizers or criterions are empty or the optimizers are not specified """ assert (optimizers and criterions) or not optimizers, \ "Criterion dict cannot be emtpy, if optimizers are passed" loss_vals = {} metric_vals = {} total_loss = 0 # choose suitable context manager: if optimizers: context_man = torch.enable_grad else: context_man = torch.no_grad with context_man(): inputs = data_dict.pop("data") preds = model(inputs) if data_dict: for key, crit_fn in criterions.items(): _loss_val = crit_fn(preds, *data_dict.values()) loss_vals[key] = _loss_val.detach() total_loss += _loss_val with torch.no_grad(): for key, metric_fn in metrics.items(): metric_vals[key] = metric_fn( preds, *data_dict.values()) if optimizers: optimizers['default'].zero_grad() # perform loss scaling via apex if half precision is enabled with optimizers["default"].scale_loss(total_loss) as scaled_loss: scaled_loss.backward() optimizers['default'].step() else: # add prefix "val" in validation mode eval_loss_vals, eval_metrics_vals = {}, {} for key in loss_vals.keys(): eval_loss_vals["val_" + str(key)] = loss_vals[key] for key in metric_vals: eval_metrics_vals["val_" + str(key)] = metric_vals[key] loss_vals = eval_loss_vals metric_vals = eval_metrics_vals for key, val in {**metric_vals, **loss_vals}.items():{"value": {"value": val.item(), "name": key, "env_appendix": "_%02d" % fold }}){'image_grid': {"images": inputs, "name": "input_images", "env_appendix": "_%02d" % fold}}) return metric_vals, loss_vals, [preds]
[docs] @staticmethod def _build_model(in_channels: int, n_outputs: int, **kwargs): """ builds actual model (resnet 18) Parameters ---------- in_channels : int number of input channels n_outputs : int number of outputs (usually same as number of classes) **kwargs : dict additional keyword arguments Returns ------- torch.nn.Module created model """ _model = t_models.resnet18( pretrained=False, num_classes=n_outputs, **kwargs) _model.conv1 = torch.nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) return _model
[docs] @staticmethod def prepare_batch(batch: dict, input_device, output_device): """ Helper Function to prepare Network Inputs and Labels (convert them to correct type and shape and push them to correct devices) Parameters ---------- batch : dict dictionary containing all the data input_device : torch.device device for network inputs output_device : torch.device device for network outputs Returns ------- dict dictionary containing data in correct type and shape and on correct device """ return_dict = {"data": torch.from_numpy(batch.pop("data")).to( input_device).to(torch.float)} for key, vals in batch.items(): return_dict[key] = torch.from_numpy(vals).to( output_device).squeeze(-1).to(torch.long) return return_dict