Source code for delira.training.losses

from delira import get_backends

if "TORCH" in get_backends():
    import torch
    import torch.nn.functional as F

class BCEFocalLossPyTorch(torch.nn.Module):
    """
    Focal loss for the binary case WITHOUT logits.

    The prediction is passed directly to
    :func:`torch.nn.functional.binary_cross_entropy`, so it must already
    contain probabilities in ``[0, 1]`` (e.g. the output of a sigmoid).
    """

    def __init__(self, alpha=None, gamma=2, reduction='elementwise_mean'):
        """
        Implements Focal Loss for binary class case

        Parameters
        ----------
        alpha : float, optional
            class weight in range [0, 1]. Elements whose target equals 1 are
            weighted by ``alpha``, all other elements by ``1 - alpha``.
            If None (default), no class weighting is applied.
        gamma : float
            focusing parameter; the per-element BCE is scaled by
            ``(1 - p_t) ** gamma`` where ``p_t`` is the predicted
            probability of the target class
        reduction : str
            Specifies the reduction to apply to the output:
            'none' | 'elementwise_mean' | 'mean' | 'sum'.
            'none': no reduction will be applied,
            'elementwise_mean' (or its alias 'mean', the name used by newer
            PyTorch versions): the sum of the output will be divided by the
            number of elements in the output,
            'sum': the output will be summed
            (further information about parameters above can be found in
            pytorch documentation)
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, p, t):
        """
        Computes the focal loss

        Parameters
        ----------
        p : torch.Tensor
            predicted probabilities (values in [0, 1])
        t : torch.Tensor
            targets; elements equal to 1 are treated as positives, all
            other elements as negatives

        Returns
        -------
        torch.Tensor
            loss value, reduced according to ``self.reduction``

        Raises
        ------
        AttributeError
            if ``self.reduction`` is not one of the supported values
        """
        bce_loss = F.binary_cross_entropy(p, t, reduction='none')

        # p_t: predicted probability of the target class
        p_t = torch.where(torch.eq(t, 1.), p, 1 - p)
        # focal weight (1 - p_t) ** gamma; out-of-place to keep autograd simple
        focal_loss = (1 - p_t).pow(self.gamma) * bce_loss

        if self.alpha is not None:
            # alpha for positives, 1 - alpha for negatives. Built from p so the
            # weights share p's dtype and device (no float32 promotion, no
            # host->device copy per call).
            alpha_weight = p.new_full(t.shape, self.alpha)
            alpha_weight = torch.where(torch.eq(t, 1.), alpha_weight,
                                       1 - alpha_weight)
            focal_loss = alpha_weight * focal_loss

        if self.reduction in ('elementwise_mean', 'mean'):
            return torch.mean(focal_loss)
        if self.reduction == 'none':
            return focal_loss
        if self.reduction == 'sum':
            return torch.sum(focal_loss)
        raise AttributeError('Reduction parameter unknown.')
class BCEFocalLossLogitPyTorch(torch.nn.Module):
    """
    Focal loss for the binary case WITH logits.

    The prediction is passed to
    :func:`torch.nn.functional.binary_cross_entropy_with_logits` as raw
    logits; a sigmoid is applied internally only to compute the focal weight.
    """

    def __init__(self, alpha=None, gamma=2, reduction='elementwise_mean'):
        """
        Implements Focal Loss for binary class case

        Parameters
        ----------
        alpha : float, optional
            class weight in range [0, 1]. Elements whose target equals 1 are
            weighted by ``alpha``, all other elements by ``1 - alpha``.
            If None (default), no class weighting is applied.
        gamma : float
            focusing parameter; the per-element BCE is scaled by
            ``(1 - p_t) ** gamma`` where ``p_t`` is the predicted
            probability (sigmoid of the logit) of the target class
        reduction : str
            Specifies the reduction to apply to the output:
            'none' | 'elementwise_mean' | 'mean' | 'sum'.
            'none': no reduction will be applied,
            'elementwise_mean' (or its alias 'mean', the name used by newer
            PyTorch versions): the sum of the output will be divided by the
            number of elements in the output,
            'sum': the output will be summed
            (further information about parameters above can be found in
            pytorch documentation)
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, p, t):
        """
        Computes the focal loss from logits

        Parameters
        ----------
        p : torch.Tensor
            raw logits (no sigmoid applied yet)
        t : torch.Tensor
            targets; elements equal to 1 are treated as positives, all
            other elements as negatives

        Returns
        -------
        torch.Tensor
            loss value, reduced according to ``self.reduction``

        Raises
        ------
        AttributeError
            if ``self.reduction`` is not one of the supported values
        """
        bce_loss = F.binary_cross_entropy_with_logits(p, t, reduction='none')
        # probabilities are only needed for the focal weight
        probs = torch.sigmoid(p)

        # p_t: predicted probability of the target class
        p_t = torch.where(torch.eq(t, 1.), probs, 1 - probs)
        # focal weight (1 - p_t) ** gamma; out-of-place to keep autograd simple
        focal_loss = (1 - p_t).pow(self.gamma) * bce_loss

        if self.alpha is not None:
            # alpha for positives, 1 - alpha for negatives. Built from probs so
            # the weights share the input dtype and device (no float32
            # promotion, no host->device copy per call).
            alpha_weight = probs.new_full(t.shape, self.alpha)
            alpha_weight = torch.where(torch.eq(t, 1.), alpha_weight,
                                       1 - alpha_weight)
            focal_loss = alpha_weight * focal_loss

        if self.reduction in ('elementwise_mean', 'mean'):
            return torch.mean(focal_loss)
        if self.reduction == 'none':
            return focal_loss
        if self.reduction == 'sum':
            return torch.sum(focal_loss)
        raise AttributeError('Reduction parameter unknown.')