Source code for delira.training.metrics

from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from ..utils.decorators import make_deprecated

import trixi

from delira import get_backends

if "TORCH" in get_backends():
    import torch

    from .train_utils import pytorch_tensor_to_numpy, float_to_pytorch_tensor 
@make_deprecated(trixi)
class AurocMetricPyTorch(torch.nn.Module):
    """
    Module computing the area under the ROC curve via scikit-learn

    .. deprecated:: 0.1
        :class:`AurocMetricPyTorch` will be removed in next release and is
        deprecated in favor of ``trixi.logging`` Modules

    .. warning::
        :class:`AurocMetricPyTorch` will be removed in next release

    """

    def __init__(self):
        super().__init__()

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor):
        """
        Compute the AuROC of ``outputs`` with respect to ``targets``

        Two-dimensional ``outputs`` are collapsed to class indices with an
        argmax over dimension 1 before scoring, so in that case the score
        is computed from hard labels rather than continuous scores.
        Outputs of any other dimensionality are scored as given.

        Parameters
        ----------
        outputs : torch.Tensor
            predictions from network
        targets : torch.Tensor
            training targets

        Returns
        -------
        torch.Tensor
            auroc value, as produced by ``float_to_pytorch_tensor``

        """
        is_per_class = outputs.dim() == 2
        predictions = outputs.argmax(dim=1) if is_per_class else outputs

        # sklearn works on numpy arrays: ground truth first, then predictions
        targets_np = pytorch_tensor_to_numpy(targets)
        predictions_np = pytorch_tensor_to_numpy(predictions)

        auroc = roc_auc_score(targets_np, predictions_np)
        return float_to_pytorch_tensor(auroc)
@make_deprecated(trixi)
class AccuracyMetricPyTorch(torch.nn.Module):
    """
    Metric to Calculate Accuracy

    .. deprecated:: 0.1
        :class:`AccuracyMetricPyTorch` will be removed in next release and is
        deprecated in favor of ``trixi.logging`` Modules

    .. warning::
        :class:`AccuracyMetricPyTorch` will be removed in next release

    """

    def __init__(self, normalize=True, sample_weight=None):
        """

        Parameters
        ----------
        normalize : bool, optional (default=True)
            If ``False``, return the number of correctly classified samples.
            Otherwise, return the fraction of correctly classified samples.
        sample_weight : array-like of shape = [n_samples], optional
            Sample weights.

        """
        super().__init__()
        self.normalize = normalize
        self.sample_weight = sample_weight

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor):
        """
        Actual accuracy calculation

        Two-dimensional ``outputs`` are converted to class indices with an
        argmax over dimension 1; outputs of any other dimensionality are
        binarized with a fixed threshold of 0.5.

        Parameters
        ----------
        outputs : torch.Tensor
            predictions from network
        targets : torch.Tensor
            training targets

        Returns
        -------
        torch.Tensor
            accuracy value, as produced by ``float_to_pytorch_tensor``

        """
        if outputs.dim() == 2:
            # The argmax must see the raw scores. Thresholding first would
            # reduce every row to a boolean mask, which loses the ranking
            # whenever zero or several entries exceed 0.5.
            outputs = torch.argmax(outputs, dim=1)
        else:
            outputs = outputs > 0.5

        # ``normalize`` and ``sample_weight`` are keyword-only in current
        # scikit-learn releases, so they must not be passed positionally.
        score = accuracy_score(pytorch_tensor_to_numpy(targets),
                               pytorch_tensor_to_numpy(outputs),
                               normalize=self.normalize,
                               sample_weight=self.sample_weight)
        return float_to_pytorch_tensor(score)