From d0582896317b727610a7b6e80b287fbb9101249f Mon Sep 17 00:00:00 2001
From: lisha_li <441625350@qq.com>
Date: Thu, 30 Jun 2022 14:09:20 +0800
Subject: [PATCH] Add files via upload

---
 src/loss_functions/losses.py | 252 +++++++++++++++++++++++++++++++++++
 1 file changed, 252 insertions(+)
 create mode 100644 src/loss_functions/losses.py

diff --git a/src/loss_functions/losses.py b/src/loss_functions/losses.py
new file mode 100644
index 0000000..cdf9800
--- /dev/null
+++ b/src/loss_functions/losses.py
@@ -0,0 +1,252 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Hill(nn.Module):
+    r"""Hill loss as described in the paper "Robust Loss Design for Multi-Label Learning with Missing Labels".
+
+    .. math::
+        Loss = -y \times (1-p_{m})^\gamma \log(p_{m}) + (1-y) \times (\lambda-p)p^{2}
+
+    where :math:`\lambda-p` is the weighting term that down-weights the loss for possibly false negatives,
+    :math:`m` is a margin parameter, and
+    :math:`\gamma` is the focusing parameter commonly used in Focal loss.
+
+    .. note::
+        Sigmoid is applied inside the loss.
+
+    Args:
+        lamb (float): Specifies the down-weighting term. Default: 1.5. (We did not change the value of lamb in our experiments.)
+        margin (float): Margin value. Default: 1.0. (A margin in [0.5, 1.0] is recommended; different margins have little effect on the result.)
+        gamma (float): Focusing parameter, used the same way as in Focal loss. Default: 2.0.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'sum'``.
+    """
+
+    def __init__(self, lamb: float = 1.5, margin: float = 1.0, gamma: float = 2.0, reduction: str = 'sum') -> None:
+        super(Hill, self).__init__()
+        self.lamb = lamb
+        self.margin = margin
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            logits : The predicted logits before sigmoid, with shape :math:`(N, C)`.
+            targets : Multi-label binarized vector, with shape :math:`(N, C)`.
+
+        Returns:
+            torch.Tensor: loss
+        """
+
+        # Calculate predicted probabilities; positives use margin-shifted logits.
+        logits_margin = logits - self.margin
+        pred_pos = torch.sigmoid(logits_margin)
+        pred_neg = torch.sigmoid(logits)
+
+        # Focal margin for positive loss
+        pt = (1 - pred_pos) * targets + (1 - targets)
+        focal_weight = pt ** self.gamma
+
+        # Hill loss calculation
+        los_pos = targets * torch.log(pred_pos)
+        los_neg = (1 - targets) * -(self.lamb - pred_neg) * pred_neg ** 2
+
+        loss = -(los_pos + los_neg)
+        loss *= focal_weight
+
+        if self.reduction == 'mean':
+            return loss.mean()
+        elif self.reduction == 'sum':
+            return loss.sum()
+        else:
+            return loss
+
+
+class SPLC(nn.Module):
+    r"""SPLC loss as described in the paper "Simple Loss Design for Multi-Label Learning with Missing Labels".
+
+    .. math::
+        &L_{SPLC}^+ = loss^+(p) \\
+        &L_{SPLC}^- = \mathbb{I}(p\leq \tau)loss^-(p) + (1-\mathbb{I}(p\leq \tau))loss^+(p)
+
+    where :math:`\tau` is a threshold used to identify missing labels,
+    :math:`\mathbb{I}(\cdot)\in\{0,1\}` is the indicator function, and
+    :math:`loss^+(\cdot)`, :math:`loss^-(\cdot)` are the loss functions for positives and negatives, respectively.
+
+    .. note::
+        SPLC can be combined with various multi-label loss functions.
+        In our paper, SPLC performs best when combined with Focal margin loss, so the code released here implements SPLC with Focal margin loss.
+        Since the first epoch already recalls a few missing labels with high precision, SPLC can be applied after the first epoch.
+        Sigmoid is applied inside the loss.
+
+    Args:
+        tau (float): Threshold value. Default: 0.7.
+        change_epoch (int): The epoch from which to start applying SPLC. Default: 1.
+        margin (float): Margin value. Default: 1.0.
+        gamma (float): Hard-mining (focusing) value. Default: 2.0.
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+            ``'mean'``: the sum of the output will be divided by the number of
+            elements in the output, ``'sum'``: the output will be summed. Default: ``'sum'``.
+    """
+
+    def __init__(self,
+                 tau: float = 0.7,
+                 change_epoch: int = 1,
+                 margin: float = 1.0,
+                 gamma: float = 2.0,
+                 reduction: str = 'sum') -> None:
+        super(SPLC, self).__init__()
+        self.tau = tau
+        self.change_epoch = change_epoch
+        self.margin = margin
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, logits: torch.Tensor, targets: torch.Tensor,
+                epoch: int) -> torch.Tensor:
+        """
+        Args:
+            logits : The predicted logits before sigmoid, with shape :math:`(N, C)`.
+            targets : Multi-label binarized vector, with shape :math:`(N, C)`.
+            epoch : The current training epoch.
+
+        Returns:
+            torch.Tensor: loss
+        """
+
+        # Subtract margin from positive logits
+        logits = torch.where(targets == 1, logits - self.margin, logits)
+
+        # SPLC missing-label correction: relabel confident negatives as positives.
+        if epoch >= self.change_epoch:
+            targets = torch.where(
+                torch.sigmoid(logits) > self.tau,
+                torch.ones_like(targets), targets)  # device-safe, unlike a hard-coded .cuda()
+
+        pred = torch.sigmoid(logits)
+
+        # Focal margin for positive loss
+        pt = (1 - pred) * targets + pred * (1 - targets)
+        focal_weight = pt ** self.gamma
+
+        los_pos = targets * F.logsigmoid(logits)
+        los_neg = (1 - targets) * F.logsigmoid(-logits)
+
+        loss = -(los_pos + los_neg)
+        loss *= focal_weight
+
+        if self.reduction == 'mean':
+            return loss.mean()
+        elif self.reduction == 'sum':
+            return loss.sum()
+        else:
+            return loss
+
+
+class AsymmetricLoss(nn.Module):
+    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
+        super(AsymmetricLoss, self).__init__()
+
+        self.gamma_neg = gamma_neg
+        self.gamma_pos = gamma_pos
+        self.clip = clip
+        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+        self.eps = eps
+
+    def forward(self, x, y):
+        """
+        Parameters
+        ----------
+        x: input logits
+        y: targets (multi-label binarized vector)
+        """
+
+        # Calculating probabilities
+        x_sigmoid = torch.sigmoid(x)
+        xs_pos = x_sigmoid
+        xs_neg = 1 - x_sigmoid
+
+        # Asymmetric clipping (probability shifting for negatives)
+        if self.clip is not None and self.clip > 0:
+            xs_neg = (xs_neg + self.clip).clamp(max=1)
+
+        # Basic CE calculation
+        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
+        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
+        loss = los_pos + los_neg
+
+        # Asymmetric focusing
+        if self.gamma_neg > 0 or self.gamma_pos > 0:
+            if self.disable_torch_grad_focal_loss:
+                torch.set_grad_enabled(False)
+            pt0 = xs_pos * y
+            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
+            pt = pt0 + pt1
+            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
+            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
+            if self.disable_torch_grad_focal_loss:
+                torch.set_grad_enabled(True)
+            loss *= one_sided_w
+
+        return -loss.sum()
+
+
+class AsymmetricLossOptimized(nn.Module):
+    '''Notice - optimized version, minimizes memory allocation and GPU uploading,
+    favors inplace operations.'''
+
+    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
+        super(AsymmetricLossOptimized, self).__init__()
+
+        self.gamma_neg = gamma_neg
+        self.gamma_pos = gamma_pos
+        self.clip = clip
+        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+        self.eps = eps
+
+        # Prevent memory allocation and GPU uploading every iteration, and encourage inplace operations.
+        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None
+
+    def forward(self, x, y):
+        """
+        Parameters
+        ----------
+        x: input logits
+        y: targets (multi-label binarized vector)
+        """
+
+        self.targets = y
+        self.anti_targets = 1 - y
+
+        # Calculating probabilities
+        self.xs_pos = torch.sigmoid(x)
+        self.xs_neg = 1.0 - self.xs_pos
+
+        # Asymmetric clipping
+        if self.clip is not None and self.clip > 0:
+            self.xs_neg.add_(self.clip).clamp_(max=1)
+
+        # Basic CE calculation
+        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
+        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))
+
+        # Asymmetric focusing
+        if self.gamma_neg > 0 or self.gamma_pos > 0:
+            if self.disable_torch_grad_focal_loss:
+                torch.set_grad_enabled(False)
+            self.xs_pos = self.xs_pos * self.targets
+            self.xs_neg = self.xs_neg * self.anti_targets
+            self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
+                                          self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
+            if self.disable_torch_grad_focal_loss:
+                torch.set_grad_enabled(True)
+            self.loss *= self.asymmetric_w
+
+        return -self.loss.sum()
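+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (added for illustration, not part of the original
+    # training code): run the losses on small random CPU tensors. Shapes and
+    # values here are arbitrary; logits are raw scores, since sigmoid is
+    # applied inside each loss.
+    torch.manual_seed(0)
+    logits = torch.randn(4, 10)                  # (N, C) predictions before sigmoid
+    targets = (torch.rand(4, 10) > 0.7).float()  # multi-label binarized vector
+
+    print('Hill:', Hill()(logits, targets))
+    print('SPLC (epoch 0):', SPLC()(logits, targets, epoch=0))  # before change_epoch: no correction
+    print('SPLC (epoch 1):', SPLC()(logits, targets, epoch=1))  # correction of confident negatives active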
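+
+    # Sanity check (also illustrative): AsymmetricLoss and AsymmetricLossOptimized
+    # implement the same math, so their forward values should agree on the same
+    # inputs; the optimized variant only reduces allocations via in-place ops.
+    asl = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05)
+    asl_opt = AsymmetricLossOptimized(gamma_neg=4, gamma_pos=1, clip=0.05)
+    print('ASL variants agree:', torch.allclose(asl(logits, targets), asl_opt(logits, targets)))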