import torch
import torch.nn as nn
import torch.nn.functional as F


class Hill(nn.Module):
    r"""Hill loss as described in the paper "Robust Loss Design for Multi-Label Learning with Missing Labels".

    .. math::
        Loss = y \times (1-p_{m})^{\gamma}\log(p_{m}) + (1-y) \times -(\lambda-p){p}^2

    where :math:`\lambda-p` is the weighting term that down-weights the loss for possibly false negatives,
    :math:`m` is a margin parameter, and
    :math:`\gamma` is the focusing parameter, with the same value commonly used for Focal loss.

    .. note::
        Sigmoid is applied inside the loss.

    Args:
        lamb (float): Specifies the down-weighting term :math:`\lambda`. Default: 1.5. (We did not change the value of lambda in our experiments.)
        margin (float): Margin value. Default: 1.0. (A margin in [0.5, 1.0] is recommended; different margins have little effect on the result.)
        gamma (float): Focusing parameter, the same value commonly used for Focal loss. Default: 2.0.
        reduction (str): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'sum'``.
    """

    def __init__(self, lamb: float = 1.5, margin: float = 1.0, gamma: float = 2.0, reduction: str = 'sum') -> None:
        super(Hill, self).__init__()
        self.lamb = lamb
        self.margin = margin
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """
        Args:
            logits: The predicted logits before sigmoid, with shape :math:`(N, C)`.
            targets: Multi-label binarized vector with shape :math:`(N, C)`.

        Returns:
            torch.Tensor: loss
        """

        # Calculate predicted probabilities (the positive branch uses margin-shifted logits)
        logits_margin = logits - self.margin
        pred_pos = torch.sigmoid(logits_margin)
        pred_neg = torch.sigmoid(logits)

        # Focal margin for positive loss
        pt = (1 - pred_pos) * targets + (1 - targets)
        focal_weight = pt ** self.gamma

        # Hill loss calculation
        los_pos = targets * torch.log(pred_pos)
        los_neg = (1 - targets) * -(self.lamb - pred_neg) * pred_neg ** 2

        loss = -(los_pos + los_neg)
        loss *= focal_weight

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
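

# Minimal usage sketch for Hill (illustrative helper; the function name, batch size,
# class count, and random values are assumptions, not fixed by the loss itself).
# It shows the expected call pattern: (N, C) logits before sigmoid and multi-hot targets.
def _hill_usage_example() -> torch.Tensor:
    criterion = Hill(lamb=1.5, margin=1.0, gamma=2.0, reduction='sum')
    logits = torch.randn(4, 20)                     # (N, C) raw scores before sigmoid
    targets = torch.randint(0, 2, (4, 20)).float()  # multi-hot labels in {0, 1}
    return criterion(logits, targets)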


class SPLC(nn.Module):
    r"""SPLC loss as described in the paper "Simple Loss Design for Multi-Label Learning with Missing Labels".

    .. math::
        &L_{SPLC}^+ = loss^+(p) \\
        &L_{SPLC}^- = \mathbb{I}(p \leq \tau) loss^-(p) + (1 - \mathbb{I}(p \leq \tau)) loss^+(p)

    where :math:`\tau` is a threshold to identify missing labels,
    :math:`\mathbb{I}(\cdot) \in \{0, 1\}` is the indicator function, and
    :math:`loss^+(\cdot), loss^-(\cdot)` refer to the loss functions for positives and negatives, respectively.

    .. note::
        SPLC can be combined with various multi-label loss functions.
        SPLC performs best combined with Focal margin loss in our paper, so the code of SPLC with Focal margin loss is released here.
        Since the first epoch can already recall a few missing labels with high precision, SPLC can be used after the first epoch.
        Sigmoid is applied inside the loss.

    Args:
        tau (float): Threshold value. Default: 0.7.
        change_epoch (int): The epoch from which the SPLC correction is applied. Default: 1.
        margin (float): Margin value. Default: 1.0.
        gamma (float): Hard-mining (focusing) value. Default: 2.0.
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'sum'``.
    """

    def __init__(self,
                 tau: float = 0.7,
                 change_epoch: int = 1,
                 margin: float = 1.0,
                 gamma: float = 2.0,
                 reduction: str = 'sum') -> None:
        super(SPLC, self).__init__()
        self.tau = tau
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor,
                epoch: int) -> torch.Tensor:
        """
        Args:
            logits: The predicted logits before sigmoid, with shape :math:`(N, C)`.
            targets: Multi-label binarized vector with shape :math:`(N, C)`.
            epoch: The current training epoch.

        Returns:
            torch.Tensor: loss
        """

        # Subtract the margin from positive logits (Focal margin)
        logits = torch.where(targets == 1, logits - self.margin, logits)

        # SPLC missing-label correction: after `change_epoch`, unobserved labels whose
        # predicted probability exceeds `tau` are treated as positives.
        # `torch.ones_like` keeps the corrected targets on the same device and dtype as `targets`.
        if epoch >= self.change_epoch:
            targets = torch.where(
                torch.sigmoid(logits) > self.tau,
                torch.ones_like(targets), targets)

        pred = torch.sigmoid(logits)

        # Focal margin for positive loss
        pt = (1 - pred) * targets + pred * (1 - targets)
        focal_weight = pt ** self.gamma

        los_pos = targets * F.logsigmoid(logits)
        los_neg = (1 - targets) * F.logsigmoid(-logits)

        loss = -(los_pos + los_neg)
        loss *= focal_weight

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
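

# Minimal usage sketch for SPLC (illustrative helper; the function name, shapes, and
# epoch value are assumptions). Once `epoch >= change_epoch`, unobserved labels whose
# predicted probability exceeds `tau` are relabeled as positives inside the loss.
def _splc_usage_example() -> torch.Tensor:
    criterion = SPLC(tau=0.7, change_epoch=1, margin=1.0, gamma=2.0, reduction='sum')
    logits = torch.randn(4, 20)                     # (N, C) raw scores before sigmoid
    targets = torch.randint(0, 2, (4, 20)).float()  # observed (possibly incomplete) labels
    return criterion(logits, targets, epoch=2)      # correction active since epoch >= change_epoch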


class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
        super(AsymmetricLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric clipping (probability shifting for negatives)
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic binary cross-entropy calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric focusing: optionally compute the focal weights without gradient tracking
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w

        return -loss.sum()
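

# Minimal usage sketch for AsymmetricLoss (illustrative helper; the function name,
# shapes, and values are assumptions). gamma_neg > gamma_pos focuses the loss on hard
# negatives, while `clip` shifts negative probabilities before taking the log.
def _asymmetric_loss_usage_example() -> torch.Tensor:
    criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05)
    logits = torch.randn(4, 20)                     # (N, C) raw scores before sigmoid
    targets = torch.randint(0, 2, (4, 20)).float()  # multi-hot labels in {0, 1}
    return criterion(logits, targets)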


class AsymmetricLossOptimized(nn.Module):
    '''Notice - optimized version, minimizes memory allocation and gpu uploading,
    favors inplace operations'''

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        super(AsymmetricLossOptimized, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # Prevent memory allocation and gpu uploading every iteration, and encourage inplace operations
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        self.targets = y
        self.anti_targets = 1 - y

        # Calculating probabilities
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric clipping (in-place probability shifting for negatives)
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Basic binary cross-entropy calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            self.xs_pos = self.xs_pos * self.targets
            self.xs_neg = self.xs_neg * self.anti_targets
            self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                          self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            self.loss *= self.asymmetric_w

        return -self.loss.sum()
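

# Minimal sanity-check sketch (illustrative helper; the function name and shapes are
# assumptions): AsymmetricLossOptimized is intended to compute the same value as
# AsymmetricLoss while reusing buffers and in-place operations, so the two results
# printed below should be close for the same inputs.
def _asymmetric_loss_equivalence_example() -> None:
    logits = torch.randn(4, 20)
    targets = torch.randint(0, 2, (4, 20)).float()
    reference = AsymmetricLoss(disable_torch_grad_focal_loss=True)(logits, targets)
    optimized = AsymmetricLossOptimized(disable_torch_grad_focal_loss=True)(logits, targets)
    print(reference.item(), optimized.item())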