Add files via upload
chuangua authored Jun 30, 2022
1 parent e9491dc commit d058289
Showing 1 changed file with 252 additions and 0 deletions.
src/loss_functions/losses.py (+252 -0)
@@ -0,0 +1,252 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class Hill(nn.Module):
r""" Hill as described in the paper "Robust Loss Design for Multi-Label Learning with Missing Labels "
.. math::
Loss = y \times (1-p_{m})^\gamma\log(p_{m}) + (1-y) \times -(\lambda-p){p}^2
where : math:`\lambda-p` is the weighting term to down-weight the loss for possibly false negatives,
: math:`m` is a margin parameter,
: math:`\gamma` is a commonly used value same as Focal loss.
.. note::
Sigmoid will be done in loss.
Args:
lambda (float): Specifies the down-weight term. Default: 1.5. (We did not change the value of lambda in our experiment.)
margin (float): Margin value. Default: 1 . (Margin value is recommended in [0.5,1.0], and different margins have little effect on the result.)
gamma (float): Commonly used value same as Focal loss. Default: 2
"""

    def __init__(self, lamb: float = 1.5, margin: float = 1.0, gamma: float = 2.0, reduction: str = 'sum') -> None:
        super(Hill, self).__init__()
        self.lamb = lamb
        self.margin = margin
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """Compute the Hill loss.

        Args:
            logits : Predicted logits before sigmoid, with shape :math:`(N, C)`
            targets : Multi-label binarized vector, with shape :math:`(N, C)`
        Returns:
            torch.Tensor: loss
        """

        # Calculate predicted probabilities
        logits_margin = logits - self.margin
        pred_pos = torch.sigmoid(logits_margin)
        pred_neg = torch.sigmoid(logits)

        # Focal margin for positive loss
        pt = (1 - pred_pos) * targets + (1 - targets)
        focal_weight = pt ** self.gamma

        # Hill loss calculation
        los_pos = targets * torch.log(pred_pos)
        los_neg = (1 - targets) * -(self.lamb - pred_neg) * pred_neg ** 2

        loss = -(los_pos + los_neg)
        loss *= focal_weight

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
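

# A minimal usage sketch for Hill (not part of the original file): the batch size,
# class count, and random tensors below are assumptions for illustration only.
# Logits are raw model outputs (no sigmoid) and targets are {0, 1} floats.
def _hill_usage_example() -> torch.Tensor:
    criterion = Hill(lamb=1.5, margin=1.0, gamma=2.0, reduction='sum')
    logits = torch.randn(4, 80)                     # (N, C) raw scores
    targets = torch.randint(0, 2, (4, 80)).float()  # (N, C) multi-label binarized vector
    return criterion(logits, targets)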


class SPLC(nn.Module):
    r"""SPLC loss as described in the paper "Simple Loss Design for Multi-Label Learning with Missing Labels".

    .. math::
        &L_{SPLC}^+ = loss^+(p)

        &L_{SPLC}^- = \mathbb{I}(p\leq \tau)loss^-(p) + (1-\mathbb{I}(p\leq \tau))loss^+(p)

    where :math:`\tau` is a threshold used to identify missing labels,
    :math:`\mathbb{I}(\cdot)\in\{0,1\}` is the indicator function, and
    :math:`loss^+(\cdot), loss^-(\cdot)` are the loss functions for positives and negatives, respectively.

    .. note::
        SPLC can be combined with various multi-label loss functions.
        SPLC performs best combined with Focal margin loss in our paper; the code of SPLC with Focal margin loss is released here.
        Since the model after the first epoch can already recall some missing labels with high precision, SPLC can be used after the first epoch.
        The sigmoid is applied inside the loss, so ``logits`` should be raw scores.

    Args:
        tau (float): Threshold value. Default: 0.7
        change_epoch (int): Epoch from which to apply the SPLC correction. Default: 1
        margin (float): Margin value. Default: 1.0
        gamma (float): Hard-mining (focusing) value. Default: 2.0
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'sum'``
    """

    def __init__(self,
                 tau: float = 0.7,
                 change_epoch: int = 1,
                 margin: float = 1.0,
                 gamma: float = 2.0,
                 reduction: str = 'sum') -> None:
        super(SPLC, self).__init__()
        self.tau = tau
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor,
                epoch) -> torch.Tensor:
        """Compute the SPLC loss.

        Args:
            logits : Predicted logits before sigmoid, with shape :math:`(N, C)`
            targets : Multi-label binarized vector, with shape :math:`(N, C)`
            epoch : The current training epoch.
        Returns:
            torch.Tensor: loss
        """

        # Subtract margin for positive logits
        logits = torch.where(targets == 1, logits - self.margin, logits)

        # SPLC missing label correction: pseudo-label confident predictions as positives
        if epoch >= self.change_epoch:
            targets = torch.where(
                torch.sigmoid(logits) > self.tau,
                torch.ones_like(targets), targets)

        pred = torch.sigmoid(logits)

        # Focal margin for positive loss
        pt = (1 - pred) * targets + pred * (1 - targets)
        focal_weight = pt ** self.gamma

        los_pos = targets * F.logsigmoid(logits)
        los_neg = (1 - targets) * F.logsigmoid(-logits)

        loss = -(los_pos + los_neg)
        loss *= focal_weight

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
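

# A minimal usage sketch for SPLC (not part of the original file): epoch numbers,
# shapes, and random tensors are assumptions for illustration. Before `change_epoch`
# the loss is plain Focal margin; from `change_epoch` on, confident negative
# predictions (sigmoid > tau) are treated as positives.
def _splc_usage_example():
    criterion = SPLC(tau=0.7, change_epoch=1, margin=1.0, gamma=2.0, reduction='sum')
    logits = torch.randn(4, 80)
    targets = torch.randint(0, 2, (4, 80)).float()
    loss_warmup = criterion(logits, targets, epoch=0)     # before change_epoch: no correction
    loss_corrected = criterion(logits, targets, epoch=1)  # from change_epoch on: correction active
    return loss_warmup, loss_corrected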


class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
        super(AsymmetricLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w

        return -loss.sum()
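

# A minimal usage sketch for AsymmetricLoss (not part of the original file):
# shapes and random tensors are assumptions for illustration. gamma_neg > gamma_pos
# focuses the loss on hard negatives, and `clip` shifts negative probabilities
# before focusing (asymmetric clipping).
def _asl_usage_example() -> torch.Tensor:
    criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05)
    logits = torch.randn(4, 80)
    targets = torch.randint(0, 2, (4, 80)).float()
    return criterion(logits, targets)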


class AsymmetricLossOptimized(nn.Module):
    '''Notice - optimized version of AsymmetricLoss: minimizes memory allocation and GPU uploads,
    and favors in-place operations.'''

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        super(AsymmetricLossOptimized, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # Cache buffers to prevent memory allocation and GPU uploads on every iteration,
        # and to encourage in-place operations
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        self.targets = y
        self.anti_targets = 1 - y

        # Calculating Probabilities
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            self.xs_pos = self.xs_pos * self.targets
            self.xs_neg = self.xs_neg * self.anti_targets
            self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                          self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            self.loss *= self.asymmetric_w

        return -self.loss.sum()
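

# A minimal sanity-check sketch (not part of the original file): it assumes both ASL
# variants are constructed with the same hyper-parameters, in which case the plain
# and optimized implementations should produce numerically matching loss values.
def _asl_optimized_matches_plain() -> bool:
    plain = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05, disable_torch_grad_focal_loss=False)
    optimized = AsymmetricLossOptimized(gamma_neg=4, gamma_pos=1, clip=0.05, disable_torch_grad_focal_loss=False)
    logits = torch.randn(4, 80)
    targets = torch.randint(0, 2, (4, 80)).float()
    return torch.allclose(plain(logits, targets), optimized(logits, targets))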
