Source code for delira.logging.deprecated.imgsave_handler

import logging
import torch
from skimage import io as sio
import numpy as np
import os

from delira.utils.decorators import make_deprecated

from ..trixi_handler import TrixiHandler


@make_deprecated(TrixiHandler)
class ImgSaveHandler(logging.Handler):
    """
    Logging Handler which saves images to dir

    .. deprecated:: 0.1
        :class:`ImgSaveHandler` will be removed in next release and is
        deprecated in favor of ``trixi.logging`` Modules

    .. warning::
        :class:`ImgSaveHandler` will be removed in next release

    See Also
    --------
    :class:`TrixiHandler`

    """

    def __init__(self, save_dir_train, save_dir_val=None, save_freq_train=1,
                 save_freq_val=1, level=logging.NOTSET):
        """

        Parameters
        ----------
        save_dir_train : str
            directory to which the training images are saved; it is created
            if necessary (an already existing directory is reused)
        save_dir_val : str (default: None)
            directory to which the validation images are saved; it is
            created if necessary. If None, validation images are not saved
        save_freq_train : int (default: 1)
            images of every ``save_freq_train``-th training batch are saved
        save_freq_val : int (default: 1)
            images of every ``save_freq_val``-th validation batch are saved
        level : int (default: logging.NOTSET)
            logging level

        Raises
        ------
        AssertionError
            if a save frequency is not a positive integer
        DeprecationWarning
            First Time a class instance is created

        """
        super().__init__(level)

        self._save_dir_train = save_dir_train
        self._save_dir_val = save_dir_val

        # running indices which become part of the saved file names
        self._curr_index_train = 0
        self._curr_index_val = 0

        def _set_save_freq(name, save_freq):
            # check the type first, so an invalid type fails the assertion
            # instead of raising a TypeError in the comparison
            assert isinstance(save_freq, int)
            assert save_freq > 0
            setattr(self, name, save_freq)

        _set_save_freq("save_freq_train", save_freq_train)
        _set_save_freq("save_freq_val", save_freq_val)

        # number of image batches seen so far (compared to the save freq.)
        self.curr_batch_train = 0
        self.curr_batch_val = 0

        os.makedirs(save_dir_train, exist_ok=True)
        if save_dir_val:
            os.makedirs(save_dir_val, exist_ok=True)

    def emit(self, record):
        """
        Save the images contained in a log record

        Two kinds of messages are handled:

        * a non-dict ``record.msg`` is treated (best effort) as a batch of
          training images; messages which cannot be converted to images
          (e.g. plain text) are silently ignored
        * a dict ``record.msg`` may contain an ``'images'`` entry (either a
          list of image batches or a dict mapping names to image batches)
          and a ``'scores'`` entry. If any score name starts with
          ``"val_"``, the images are treated as validation images,
          otherwise as training images

        Parameters
        ----------
        record : LogRecord
            values to log

        """
        if not isinstance(record.msg, dict):
            self._emit_non_dict(record.msg)
            return

        try:
            self._emit_dict(record.msg)
        except Exception:
            # logging.Handler contract: report errors through handleError
            # instead of propagating them into the code which logged
            self.handleError(record)

    def _emit_non_dict(self, msg):
        """
        Best-effort saving of a non-dict message as a training image batch

        Parameters
        ----------
        msg : Any
            the logged message

        """
        if self.curr_batch_train % self.save_freq_train != 0:
            return
        try:
            self._save_image_batch(msg,
                                   "image_%05d" % self._curr_index_train)
        except Exception:
            # deliberately ignored: most non-dict messages (e.g. text) are
            # not images and cannot be converted
            return
        self._curr_index_train += 1

    def _emit_dict(self, msg):
        """
        Save the images of a dict message (see :meth:`emit`)

        Parameters
        ----------
        msg : dict
            the logged message

        """
        images = msg.get('images', {})
        scores = msg.get('scores', {})

        is_train = not any(name.startswith("val_") for name in scores.keys())

        if isinstance(images, list):
            # distinct names per list entry; a shared name would make every
            # entry overwrite the previous one
            named_batches = {"image_%d" % idx: img
                             for idx, img in enumerate(images)}
        elif isinstance(images, dict):
            named_batches = {self._strip_val_prefix(key): img
                             for key, img in images.items()}
        else:
            return

        # messages without images do not count as a batch
        if not named_batches:
            return
        if not self._advance_batch(is_train):
            return

        if is_train:
            curr_index = self._curr_index_train
            self._curr_index_train += 1
        else:
            curr_index = self._curr_index_val
            self._curr_index_val += 1

        for name, batch in named_batches.items():
            self._save_image_batch(batch, name + '_%05d' % curr_index,
                                   is_train)

    def _advance_batch(self, is_train):
        """
        Count one more batch for the given phase

        Parameters
        ----------
        is_train : bool
            whether the batch belongs to training or validation

        Returns
        -------
        bool
            whether the images of this batch should be saved

        """
        if is_train:
            self.curr_batch_train += 1
            return self.curr_batch_train % self.save_freq_train == 0
        self.curr_batch_val += 1
        return self.curr_batch_val % self.save_freq_val == 0

    @staticmethod
    def _strip_val_prefix(key):
        """
        Remove a leading ``"val_"`` (only as prefix) from an image name

        Parameters
        ----------
        key : str
            image name

        Returns
        -------
        str
            image name without the ``"val_"`` prefix

        """
        prefix = "val_"
        return key[len(prefix):] if key.startswith(prefix) else key

    def _save_image_batch(self, batch, prefix, is_train=True):
        """
        Save every element of a batch as ``<prefix>_<idx>.png``

        Parameters
        ----------
        batch : iterable
            batch of images
        prefix : str
            file-prefix
        is_train : bool (default: True)
            whether to save to the training or the validation directory;
            nothing is saved if the chosen directory is None

        """
        save_dir = self._save_dir_train if is_train else self._save_dir_val
        if save_dir is None:
            # the validation directory is optional: nowhere to save to
            return

        if isinstance(batch, torch.Tensor):
            # keep a leading dim of size 1; _to_image takes element [0]
            batch_elements = list(batch.split(1))
        else:
            batch_elements = list(batch)

        for idx, img in enumerate(batch_elements):
            sio.imsave(os.path.join(save_dir, prefix + "_%d.png" % idx),
                       self._to_image(img))

    @staticmethod
    def _to_image(tensor):
        """
        Convert an image to a min-max normalized (H, W, C) numpy array

        Parameters
        ----------
        tensor : entity which is convertible to numpy array
            image tensor; torch tensors are expected to carry a leading
            dimension of size 1 (as produced by ``Tensor.split(1)``), other
            inputs are expected as (C, H, W) or (H, W)

        Returns
        -------
        np.ndarray
            converted image with values in [0, 1]; single-channel images are
            replicated to 3 channels

        """
        if isinstance(tensor, torch.Tensor):
            # detach: tensors requiring grad cannot be converted to numpy;
            # float: dtypes like bfloat16 have no numpy equivalent
            img = tensor[0].detach().cpu().float().numpy()
        else:
            img = np.asarray(tensor)
        img = img.astype(np.float32)

        if img.ndim == 2:
            # grayscale (H, W) image: add a channel axis
            img = img[np.newaxis]
        if img.shape[0] == 1:
            img = np.tile(img, (3, 1, 1))

        img -= img.min()
        max_val = img.max()
        if max_val:
            img /= max_val
        return img.transpose(1, 2, 0)