diff --git a/lightly/loss/tico_loss.py b/lightly/loss/tico_loss.py index 171a8b7aa..8e690c436 100644 --- a/lightly/loss/tico_loss.py +++ b/lightly/loss/tico_loss.py @@ -109,11 +109,11 @@ def forward( # compute loss C = self.beta * self.C + (1 - self.beta) * B - loss = ( - 1 - - (z_a * z_b).sum(dim=1).mean() - + self.rho * (torch.mm(z_a, C) * z_a).sum(dim=1).mean() - ) + + transformative_invariance_loss = -(z_a * z_b).sum(dim=1).mean() + covariance_contrast_loss = self.rho * (torch.mm(z_a, C) * z_a).sum(dim=1).mean() + + loss = transformative_invariance_loss + covariance_contrast_loss # update covariance matrix if update_covariance_matrix: