"""Deprecated visdom logging handler (``delira.logging.deprecated.visdom_image_handler``)."""

import logging
import torch
import numpy as np
from visdom import Visdom
from delira.utils import now
from delira.utils.decorators import make_deprecated


from ..trixi_handler import TrixiHandler


@make_deprecated(TrixiHandler)
class VisdomImageHandler(logging.Handler):
    """
    Logging Handler to show images and metric plots with visdom

    .. deprecated:: 0.1
        :class:`VisdomImageHandler` will be removed in next release and is
        deprecated in favor of ``trixi.logging`` Modules

    .. warning::
        :class:`VisdomImageHandler` will be removed in next release

    See Also
    --------
    `Visdom`
    :class:`TrixiHandler`

    """

    def __init__(self, port, prefix, log_freq_train, log_freq_val=1e10,
                 level=logging.NOTSET, log_freq_img=1, **kwargs):
        """

        Parameters
        ----------
        port : int
            port of visdom-server
        prefix : str
            prefix of environment names
        log_freq_train : int
            Defines logging frequency for scores in train mode: the
            accumulated train scores are averaged and plotted once the train
            batch counter is a multiple of this value
        log_freq_val : int or float
            Defines logging frequency for scores in validation mode (same
            mechanism as ``log_freq_train``); with the default of ``1e10``
            validation scores are practically never plotted
        level : int (default: logging.NOTSET)
            logging level
        log_freq_img : int (default: 1)
            images are only drawn on every ``log_freq_img``-th call of
            :meth:`emit` that reaches the drawing stage
        **kwargs :
            additional keyword arguments which are directly passed to visdom

        """
        super().__init__(level=level)

        self.viz = Visdom(port=port, env=prefix, **kwargs)
        self.env_prefix = prefix
        self.log_freq_train = log_freq_train
        self.log_freq_val = log_freq_val
        self.log_freq_img = log_freq_img

        # per-fold batch/epoch counters, score accumulators and window handles
        self._reset_fold_state()

        self.curr_env_name = prefix
        self.curr_fold = None
        self.img_count = 0

    def _reset_fold_state(self):
        """Reset batch/epoch counters, accumulated scores and window handles."""
        self.curr_batch_train = 1
        self.curr_batch_val = 1
        self.curr_epoch_train = 1
        self.curr_epoch_val = 1
        self.metrics = {}
        self.val_metrics = {}
        self.plot_windows = {}
        self.image_windows = {}
        self.heatmap_windows = {}
        self.text_windows = {}
        self.bar_windows = {}

    def _env_name_for_fold(self, fold):
        """
        Build the visdom environment name used for ``fold``

        Parameters
        ----------
        fold : int, str or None
            fold identifier; ``""`` and ``None`` mean "no fold"

        Returns
        -------
        str
            ``env_prefix`` if there is no fold, otherwise
            ``"<env_prefix>_fold_<id>_<now()>"`` where ``<id>`` is the
            zero-padded number for numeric folds and the string itself for
            string folds

        """
        if fold is None or (isinstance(fold, str) and not fold):
            return self.env_prefix

        # ``%02d`` only works for numbers; string folds are used verbatim
        if isinstance(fold, str):
            fold_id = fold
        else:
            fold_id = "%02d" % fold
        return self.env_prefix + "_fold_%s_%s" % (fold_id, now())

    def _draw(self, windows, win_key, draw_fn, *args, **kwargs):
        """
        Call a visdom drawing function, reusing the window cached in
        ``windows[win_key]`` if present and caching the new window otherwise

        Parameters
        ----------
        windows : dict
            cache mapping names to visdom window handles
        win_key : str
            key of the window inside ``windows``
        draw_fn : callable
            visdom drawing method (e.g. ``self.viz.image``)
        *args :
            positional arguments for ``draw_fn``
        **kwargs :
            keyword arguments for ``draw_fn``

        """
        if win_key not in windows:
            windows[win_key] = draw_fn(*args, **kwargs)
        else:
            draw_fn(*args, win=windows[win_key], **kwargs)

    def _accumulate_scores(self, scores):
        """
        Add the given scores to the running sums of the current epoch

        Keys starting with ``"val_"`` are accumulated (without that prefix)
        into ``self.val_metrics``, all others into ``self.metrics``.

        Parameters
        ----------
        scores : dict
            mapping of score names to scalar values

        """
        for key, value in scores.items():
            if key.startswith("val_"):
                target = self.val_metrics
                metric_name = key.split("_", maxsplit=1)[-1]
            else:
                target = self.metrics
                metric_name = key
            target[metric_name] = target.get(metric_name, 0) + \
                self._to_scalar(value)

    def emit(self, record):
        """
        shows images and metric plots in visdom

        Parameters
        ----------
        record : LogRecord
            entities to log; ``record.msg`` is expected to be a dict which
            may contain the keys ``"scores"``, ``"images"``, ``"heatmaps"``,
            ``"scalars"``, ``"bars"``, ``"fold"``, ``"text"`` and ``"plots"``

        Returns
        -------
        None
            * if no connection to `visdom` could be found
            * if `record.msg` is not a dict

        """
        # messages that cant be send fill (GPU-)RAM so return if no connection
        if not self.viz.check_connection():
            return

        msg = record.msg
        if not isinstance(msg, dict):
            return

        scores = msg.get("scores", {})
        images = msg.get("images", {})
        heatmaps = msg.get("heatmaps", {})
        scalars = msg.get("scalars", {})
        bars = msg.get("bars", {})
        fold = msg.get("fold", "")
        text = msg.get("text", {})
        plots = msg.get("plots", {})

        # a new fold starts from scratch in its own environment
        if fold != self.curr_fold:
            self._reset_fold_state()
            fold_name = self._env_name_for_fold(fold)
        else:
            fold_name = self.curr_env_name
        self.curr_fold = fold
        self.curr_env_name = fold_name

        # Log losses and metrics
        self._accumulate_scores(scores)

        # Draw images (only on every ``log_freq_img``-th call)
        self.img_count += 1
        if self.img_count % self.log_freq_img == 0:
            for image_name, tensor in images.items():
                self._draw(self.image_windows, image_name, self.viz.image,
                           self._to_image(tensor),
                           opts={'title': image_name}, env=fold_name)
            self.img_count = 0

        # draw heatmaps (only the first element of each entry is shown)
        for heatmap_name, tensor in heatmaps.items():
            heatmap = tensor[0]
            if isinstance(heatmap, torch.Tensor):
                # detach: ``.numpy()`` raises for tensors requiring grad
                heatmap = heatmap.detach().cpu().numpy()
            self._draw(self.heatmap_windows, heatmap_name, self.viz.heatmap,
                       heatmap,
                       opts=dict(title=heatmap_name, colormap='hot'),
                       env=fold_name)

        # visualize scalars
        # NOTE: scalars and text share ``text_windows``, so equal names
        # write into the same window
        for scalar_name, scalar_val in scalars.items():
            text_str = "<font face = 'Arial' size = '4'>%s</font>" % \
                str(self._to_scalar(scalar_val))
            self._draw(self.text_windows, scalar_name, self.viz.text,
                       text_str, env=fold_name)

        # draw bars
        for bar_name, bar_vals in bars.items():
            self._draw(self.bar_windows, bar_name, self.viz.bar, bar_vals,
                       opts={"title": bar_name}, env=fold_name)

        # visualize text
        for text_name, val_str in text.items():
            text_str = "<font face = 'Arial' size = '4'>%s</font>" % val_str
            self._draw(self.text_windows, text_name, self.viz.text,
                       text_str, env=fold_name)

        # visualize plots: either a dict with "x", "y" (and optional
        # "xlabel", "ylabel") or a pair ``(x, y)``
        # NOTE: these share ``plot_windows`` with the score plots below
        for plot_name, plot_vals in plots.items():
            if isinstance(plot_vals, dict):
                x_vals = plot_vals["x"]
                y_vals = plot_vals["y"]
                xlabel = plot_vals.get("xlabel", "")
                ylabel = plot_vals.get("ylabel", "")
            else:
                x_vals = np.array(plot_vals[0])
                y_vals = np.array(plot_vals[1])
                xlabel = ""
                ylabel = ""
            self._draw(self.plot_windows, plot_name, self.viz.line,
                       X=x_vals, Y=y_vals,
                       opts={'xlabel': xlabel, 'ylabel': ylabel,
                             'title': plot_name},
                       env=fold_name)

        self._plot_epoch_scores(scores, fold_name)

    def _plot_epoch_scores(self, scores, fold_name):
        """
        Plot the averaged scores if a train or validation epoch ended,
        otherwise advance the batch counters

        Only one epoch type can end per call; train is checked first.

        Parameters
        ----------
        scores : dict
            scores of the current record (only its keys are used to decide
            which batch counters to advance)
        fold_name : str
            visdom environment to plot into

        """
        if self.curr_batch_train % self.log_freq_train == 0:
            # train epoch ended
            score_dict = self.metrics
            curr_batch = self.curr_batch_train
            curr_epoch = self.curr_epoch_train
            name = "train"
            self.curr_epoch_train += 1
            self.curr_batch_train = 1
            self.metrics = {}
        elif self.curr_batch_val % self.log_freq_val == 0:
            # validation epoch ended
            score_dict = self.val_metrics
            curr_batch = self.curr_batch_val
            curr_epoch = self.curr_epoch_val
            name = "val"
            self.curr_epoch_val += 1
            self.curr_batch_val = 1
            self.val_metrics = {}
        else:
            # no epoch ended
            score_dict = {}

        if not score_dict:
            if any(key.startswith("val_") for key in scores):
                self.curr_batch_val += 1
            if any(not key.startswith("val_") for key in scores):
                self.curr_batch_train += 1
            return

        # Plot losses: one point per epoch holding the mean over its batches
        for metric_name, metric in score_dict.items():
            x_vals = np.array([curr_epoch])
            y_vals = np.array([metric / curr_batch])
            if metric_name not in self.plot_windows:
                self.plot_windows[metric_name] = self.viz.line(
                    X=x_vals, Y=y_vals,
                    opts={'xlabel': 'iterations', 'ylabel': metric_name,
                          'title': metric_name},
                    name=name, env=fold_name)
            else:
                self.viz.line(X=x_vals, Y=y_vals,
                              win=self.plot_windows[metric_name],
                              update='append', name=name, env=fold_name)

    @staticmethod
    def _to_scalar(val):
        """
        convert scalar wrapped in tensor or numpy array to float

        Parameters
        ----------
        val : torch.Tensor, numpy array or number
            value to be converted (arrays/tensors must hold one element)

        Returns
        -------
        float
            converted value (``.item()`` of an integer array/tensor yields
            a Python ``int``)

        """
        # ``np.asscalar`` was removed in NumPy 1.23; ``.item()`` replaces it
        # and works the same way for single-element torch tensors
        if isinstance(val, (np.ndarray, torch.Tensor)):
            return val.item()
        return float(val)

    @staticmethod
    def _to_image(tensor):
        """
        convert the first image of a batch to a uint8 numpy array

        Parameters
        ----------
        tensor : torch.Tensor or array-like
            batch of images; only ``tensor[0]`` is converted (assumed to be
            channel-first)

        Returns
        -------
        np.ndarray
            image min-max scaled to [0, 255] as ``np.uint8``; single-channel
            images are repeated to three channels

        """
        img = tensor[0]
        if isinstance(img, torch.Tensor):
            # detach: ``.numpy()`` raises for tensors requiring grad
            img = img.detach().cpu().numpy()

        # Always work on a copy: for CPU tensors ``.numpy()`` shares memory
        # with the tensor, so the in-place ops below would otherwise modify
        # the caller's data.
        img = np.array(img, copy=True)
        if not np.issubdtype(img.dtype, np.floating):
            # integer arrays cannot be scaled in place by a float factor
            img = img.astype(np.float64)

        if img.shape[0] == 1:
            img = np.tile(img, (3, 1, 1))

        img -= img.min()
        max_val = img.max()
        if max_val:
            img *= 255 / max_val
        return img.astype(np.uint8)