Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorSusmelj committed Sep 16, 2023
1 parent 41798a1 commit efc559c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 22 deletions.
20 changes: 1 addition & 19 deletions benchmarks/imagenet/resnet50/tico.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,7 @@ def training_step(
teacher_features, teacher_projections = self.forward_teacher(views[0])
student_projections = self.forward_student(views[1])

transformative_invariance_loss, covariance_contrast_loss = self.criterion(
teacher_projections, student_projections
)
loss = transformative_invariance_loss, covariance_contrast_loss
self.log(
"trans_loss",
transformative_invariance_loss,
prog_bar=True,
sync_dist=True,
batch_size=len(targets),
)

self.log(
"cov_loss",
covariance_contrast_loss,
prog_bar=True,
sync_dist=True,
batch_size=len(targets),
)
loss = self.criterion(teacher_projections, student_projections)

self.log(
"train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
Expand Down
6 changes: 3 additions & 3 deletions lightly/loss/tico_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ def forward(
# compute loss
C = self.beta * self.C + (1 - self.beta) * B

transformative_invariance_loss = -(z_a * z_b).sum(dim=1).mean()
transformative_invariance_loss = 1.0 - (z_a * z_b).sum(dim=1).mean()
covariance_contrast_loss = self.rho * (torch.mm(z_a, C) * z_a).sum(dim=1).mean()

# loss = transformative_invariance_loss + covariance_contrast_loss
loss = transformative_invariance_loss + covariance_contrast_loss

# update covariance matrix
if update_covariance_matrix:
self.C = C.detach()

return transformative_invariance_loss, covariance_contrast_loss
return loss

0 comments on commit efc559c

Please sign in to comment.