-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
27 lines (19 loc) · 1 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
class NSBCELoss(nn.BCELoss):
    """Binary cross-entropy evaluated only on a masked subset of elements.

    An element takes part in the loss when ``target + negative_mask`` is
    nonzero at that position: any element with a nonzero target, plus any
    element flagged in ``negative_mask``. Every other element is ignored.
    The "NS" prefix presumably stands for negative sampling, with
    ``negative_mask`` marking sampled negatives; TODO confirm with callers.

    Constructor arguments are exactly those of :class:`torch.nn.BCELoss`.
    """

    def __init__(
        self,
        weight: Optional[Tensor] = None,
        size_average=None,
        reduce=None,
        reduction: str = 'mean',
    ) -> None:
        super().__init__(weight, size_average, reduce, reduction)

    def forward(self, input: Tensor, target: Tensor, negative_mask: Tensor) -> Tensor:
        """Return the BCE loss restricted to the selected elements.

        Args:
            input: Predicted probabilities with the same shape as ``target``.
            target: Ground-truth labels.
            negative_mask: Added to ``target`` to build the selection mask.
                It is expected to have the same shape as ``target``
                (assumption; TODO confirm).

        Returns:
            The loss reduced according to ``self.reduction``. With
            ``'none'`` this is a 1-D tensor with one entry per selected
            element, in row-major order. If nothing is selected and the
            reduction is ``'mean'``, a zero tensor attached to the autograd
            graph is returned instead of NaN.
        """
        # Same selection as the nonzero indices of target + negative_mask.
        selected = (target + negative_mask) != 0
        selected_input = input[selected]
        selected_target = target[selected]

        if selected_input.numel() == 0 and self.reduction == 'mean':
            # The mean of an empty tensor is NaN, which would poison training.
            # Summing the empty selection gives 0 and keeps the graph intact.
            return selected_input.sum()

        weight = self.weight
        if weight is not None:
            # BCELoss weights are broadcast against the *full* input shape.
            # Expand first, then apply the same mask so each weight stays
            # aligned with its element in the flattened selection.
            weight = weight.expand_as(input)[selected]

        return nn.functional.binary_cross_entropy(
            selected_input,
            selected_target,
            weight=weight,
            reduction=self.reduction,
        )
class BPRLoss(nn.Module):
    """Bayesian Personalized Ranking loss: mean of ``-log(sigmoid(pos - neg))``.

    The loss is small when each positive score exceeds its paired negative
    score, and it grows as the negative overtakes the positive.
    """

    def __init__(self):
        super().__init__()
        self.logsigmoid = nn.LogSigmoid()

    def forward(self, positive_preds, negative_preds):
        """Return the scalar BPR loss averaged over all element pairs.

        Args:
            positive_preds: Scores for positive items.
            negative_preds: Scores for negative items. They are subtracted
                element-wise from ``positive_preds``, so the shapes must be
                broadcast-compatible.
        """
        log_likelihood = self.logsigmoid(positive_preds - negative_preds)
        return -log_likelihood.mean()