from sklearn.metrics import accuracy_score, balanced_accuracy_score, \
f1_score, fbeta_score, hamming_loss, jaccard_similarity_score, log_loss, \
matthews_corrcoef, precision_score, recall_score, zero_one_loss, \
roc_auc_score
from sklearn.preprocessing import label_binarize
import numpy as np
class SklearnClassificationMetric(object):
    """
    Callable wrapper which turns a score function into a metric

    Optionally reduces ground truth and/or predictions with an argmax over
    the last axis before handing them to the wrapped score function.
    """

    def __init__(self, score_fn, gt_logits=False, pred_logits=True, **kwargs):
        """
        Wraps a score function as a metric

        Parameters
        ----------
        score_fn : function
            function which should be wrapped; it is called with the keyword
            arguments ``y_true`` and ``y_pred``
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments passed to score_fn function
            on every call

        """
        self._score_fn = score_fn
        self._gt_logits = gt_logits
        self._pred_logits = pred_logits
        self.kwargs = kwargs

    def __call__(self, y_true, y_pred, **kwargs):
        """
        Compute metric with score_fn

        Parameters
        ----------
        y_true: np.ndarray
            ground truth data
        y_pred: np.ndarray
            predictions of network
        kwargs:
            variable number of keyword arguments passed to score_fn; a key
            given here overrides the same key given to the constructor

        Returns
        -------
        float
            result from score function

        """
        # logits are reduced to class indices along the last axis
        if self._gt_logits:
            y_true = np.argmax(y_true, axis=-1)

        if self._pred_logits:
            y_pred = np.argmax(y_pred, axis=-1)

        # Merge instead of unpacking both dicts into the call: unpacking both
        # raises a TypeError as soon as a key is present in both of them.
        merged_kwargs = {**self.kwargs, **kwargs}

        return self._score_fn(y_true=y_true, y_pred=y_pred, **merged_kwargs)
class SklearnAccuracyScore(SklearnClassificationMetric):
    """
    Accuracy metric, wrapping ``accuracy_score`` from ``sklearn.metrics``
    """

    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):
        """
        Parameters
        ----------
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments forwarded to the parent
            constructor

        """
        super().__init__(
            score_fn=accuracy_score,
            gt_logits=gt_logits,
            pred_logits=pred_logits,
            **kwargs
        )
class SklearnBalancedAccuracyScore(SklearnClassificationMetric):
    """
    Balanced accuracy metric, wrapping ``balanced_accuracy_score`` from
    ``sklearn.metrics``
    """

    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):
        """
        Parameters
        ----------
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments forwarded to the parent
            constructor

        """
        super().__init__(
            score_fn=balanced_accuracy_score,
            gt_logits=gt_logits,
            pred_logits=pred_logits,
            **kwargs
        )
class SklearnF1Score(SklearnClassificationMetric):
    """
    F1 score metric, wrapping ``f1_score`` from ``sklearn.metrics``
    """

    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):
        """
        Parameters
        ----------
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments forwarded to the parent
            constructor

        """
        super().__init__(
            score_fn=f1_score,
            gt_logits=gt_logits,
            pred_logits=pred_logits,
            **kwargs
        )
class SklearnFBetaScore(SklearnClassificationMetric):
    """
    F-Beta score metric (generalized F1), wrapping ``fbeta_score`` from
    ``sklearn.metrics``

    NOTE(review): this wrapper does not set ``beta`` itself; presumably
    callers supply it through ``kwargs`` -- confirm against usage.
    """

    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):
        """
        Parameters
        ----------
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments forwarded to the parent
            constructor

        """
        super().__init__(
            score_fn=fbeta_score,
            gt_logits=gt_logits,
            pred_logits=pred_logits,
            **kwargs
        )
class SklearnHammingLoss(SklearnClassificationMetric):
    """
    Hamming loss metric, wrapping ``hamming_loss`` from ``sklearn.metrics``
    """

    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):
        """
        Parameters
        ----------
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments forwarded to the parent
            constructor

        """
        super().__init__(
            score_fn=hamming_loss,
            gt_logits=gt_logits,
            pred_logits=pred_logits,
            **kwargs
        )
class SklearnJaccardSimilarityScore(SklearnClassificationMetric):
    """
    Jaccard similarity metric, wrapping ``jaccard_similarity_score`` from
    ``sklearn.metrics``
    """

    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):
        """
        Parameters
        ----------
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments forwarded to the parent
            constructor

        """
        super().__init__(
            score_fn=jaccard_similarity_score,
            gt_logits=gt_logits,
            pred_logits=pred_logits,
            **kwargs
        )
class SklearnLogLoss(SklearnClassificationMetric):
    """
    Log loss (NLL) metric, wrapping ``log_loss`` from ``sklearn.metrics``

    NOTE(review): with the default ``pred_logits=True`` the parent class
    reduces ``y_pred`` to class indices via argmax before scoring, while
    ``log_loss`` presumably expects probability estimates -- confirm that
    this default is intended.
    """

    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):
        """
        Parameters
        ----------
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments forwarded to the parent
            constructor

        """
        super().__init__(
            score_fn=log_loss,
            gt_logits=gt_logits,
            pred_logits=pred_logits,
            **kwargs
        )
class SklearnMatthewsCorrCoeff(SklearnClassificationMetric):
    """
    Matthews correlation coefficient metric, wrapping ``matthews_corrcoef``
    from ``sklearn.metrics``
    """

    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):
        """
        Parameters
        ----------
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments forwarded to the parent
            constructor

        """
        super().__init__(
            score_fn=matthews_corrcoef,
            gt_logits=gt_logits,
            pred_logits=pred_logits,
            **kwargs
        )
class SklearnPrecisionScore(SklearnClassificationMetric):
    """
    Precision metric, wrapping ``precision_score`` from ``sklearn.metrics``
    """

    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):
        """
        Parameters
        ----------
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments forwarded to the parent
            constructor

        """
        super().__init__(
            score_fn=precision_score,
            gt_logits=gt_logits,
            pred_logits=pred_logits,
            **kwargs
        )
class SklearnRecallScore(SklearnClassificationMetric):
    """
    Recall metric, wrapping ``recall_score`` from ``sklearn.metrics``
    """

    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):
        """
        Parameters
        ----------
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments forwarded to the parent
            constructor

        """
        super().__init__(
            score_fn=recall_score,
            gt_logits=gt_logits,
            pred_logits=pred_logits,
            **kwargs
        )
class SklearnZeroOneLoss(SklearnClassificationMetric):
    """
    Zero-one loss metric, wrapping ``zero_one_loss`` from ``sklearn.metrics``
    """

    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):
        """
        Parameters
        ----------
        gt_logits : bool
            whether given ``y_true`` are logits or not
        pred_logits : bool
            whether given ``y_pred`` are logits or not
        **kwargs:
            variable number of keyword arguments forwarded to the parent
            constructor

        """
        super().__init__(
            score_fn=zero_one_loss,
            gt_logits=gt_logits,
            pred_logits=pred_logits,
            **kwargs
        )
class AurocMetric(object):
    """
    Area under the ROC curve for binary and multi class classification,
    computed with ``roc_auc_score``
    """

    def __init__(self, classes=(0, 1), **kwargs):
        """
        Implements the auroc metric for binary and multi class classification

        Parameters
        ----------
        classes: array-like
            uniquely holds the label for each class.
        kwargs:
            variable number of keyword arguments passed to roc_auc_score

        Raises
        ------
        ValueError
            if not at least two classes are provided

        """
        self.classes = classes
        self.kwargs = kwargs

        if len(self.classes) < 2:
            raise ValueError("At least classes 2 must exist for "
                             "classification. Only classes {} were passed to "
                             "AurocMetric.".format(classes))

    def __call__(self, y_true, y_pred, **kwargs):
        """
        Compute auroc

        Parameters
        ----------
        y_true: np.ndarray
            ground truth data with shape (N)
        y_pred: np.ndarray
            predictions of network in numpy format with shape (N, nclasses)
            (or shape (N) for a single output unit)
        kwargs:
            variable number of keyword arguments passed to roc_auc_score; a
            key given here overrides the same key given to the constructor

        Returns
        -------
        float
            computes auc score

        Raises
        ------
        ValueError
            if two classes are given and the predictions contain more than two
            classes

        """
        # Constructor and call-time keywords are merged once so that every
        # branch forwards the same arguments and duplicate keys cannot
        # collide inside the call.
        score_kwargs = {**self.kwargs, **kwargs}

        # binary classification
        if len(self.classes) == 2:
            # single output unit (e.g. sigmoid)
            if len(y_pred.shape) == 1 or y_pred.shape[1] == 1:
                return roc_auc_score(y_true, y_pred, **score_kwargs)

            # output of two units (e.g. softmax): score of the second class
            if y_pred.shape[1] == 2:
                return roc_auc_score(y_true, y_pred[:, 1], **score_kwargs)

            # the class axis of (N, nclasses) is axis 1
            raise ValueError("Can not compute auroc metric for binary "
                             "classes with {} predicted "
                             "classes.".format(y_pred.shape[1]))

        # classification with multiple classes (the constructor guarantees
        # at least two classes, so more than two remain here)
        y_true_bin = label_binarize(y_true, classes=self.classes)
        return roc_auc_score(y_true_bin, y_pred, **score_kwargs)