Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Igor lig 4447 w mse benchmark #1474

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details.
| SwAV | Res50 | 256 | 100 | 67.2 | 75.4 | 49.5 | [link](https://tensorboard.dev/experiment/Ipx4Oxl5Qkqm5Sl5kWyKKg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_swav_2023-05-25_08-29-14/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| TiCo | Res50 | 256 | 100 | 49.7 | 72.7 | 26.6 | - | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_tico_2024-01-07_18-40-57/pretrain/version_0/checkpoints/epoch%3D99-step%3D250200.ckpt) |
| VICReg | Res50 | 256 | 100 | 63.0 | 73.7 | 46.3 | [link](https://tensorboard.dev/experiment/qH5uywJbTJSzgCEfxc7yUw) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_vicreg_2023-09-11_10-53-08/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| W-MSE | Res50 | 512 | 100 | 54.6 | 73.6 | 31.2 | - | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_wmse_2024-01-15_22-02-33/pretrain/version_0/checkpoints/epoch%3D99-step%3D250200.ckpt) |

_\*We use square root learning rate scaling instead of linear scaling as it yields
better results for smaller batch sizes. See Appendix B.1 in the [SimCLR paper](https://arxiv.org/abs/2002.05709)._
Expand Down
2 changes: 2 additions & 0 deletions benchmarks/imagenet/resnet50/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tico
import torch
import vicreg
import wmse
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import (
DeviceStatsMonitor,
Expand Down Expand Up @@ -64,6 +65,7 @@
"swav": {"model": swav.SwAV, "transform": swav.transform},
"tico": {"model": tico.TiCo, "transform": tico.transform},
"vicreg": {"model": vicreg.VICReg, "transform": vicreg.transform},
"wmse": {"model": wmse.WMSE, "transform": wmse.transform},
}


Expand Down
109 changes: 109 additions & 0 deletions benchmarks/imagenet/resnet50/wmse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import math
from typing import List, Tuple

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torchvision.models import resnet50

from lightly.loss.wmse_loss import WMSELoss
from lightly.models.modules import WMSEProjectionHead
from lightly.models.utils import get_weight_decay_parameters
from lightly.transforms import WMSETransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.lars import LARS
from lightly.utils.scheduler import CosineWarmupScheduler


class WMSE(LightningModule):
def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
super().__init__()
self.save_hyperparameters()
self.batch_size_per_device = batch_size_per_device

resnet = resnet50()
resnet.fc = Identity() # Ignore classification head
self.backbone = resnet

# we use a projection head with output dimension 128
# and w_size of 256 to support a batch size of 512
self.projection_head = WMSEProjectionHead(output_dim=128)

self.criterion_WMSE4loss = WMSELoss(
w_size=256, embedding_dim=128, num_samples=4
)

self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

def forward(self, x: Tensor) -> Tensor:
return self.backbone(x)

def training_step(
self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
) -> Tensor:
views, targets = batch[0], batch[1]
features = self.forward(torch.cat(views)).flatten(start_dim=1)
z = self.projection_head(features)
loss = self.criterion_WMSE4loss(z)
self.log(
"train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
)

cls_loss, cls_log = self.online_classifier.training_step(
(features.detach(), targets.repeat(len(views))), batch_idx
)
self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
return loss + cls_loss

def validation_step(
self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
) -> Tensor:
images, targets = batch[0], batch[1]
features = self.forward(images).flatten(start_dim=1)
cls_loss, cls_log = self.online_classifier.validation_step(
(features.detach(), targets), batch_idx
)
self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
return cls_loss

def configure_optimizers(self):
# Don't use weight decay for batch norm, bias parameters, and classification
# head to improve performance.
params, params_no_weight_decay = get_weight_decay_parameters(
[self.backbone, self.projection_head]
)
optimizer = LARS(
[
{"name": "wmse", "params": params},
{
"name": "wmse_no_weight_decay",
"params": params_no_weight_decay,
"weight_decay": 0.0,
},
{
"name": "online_classifier",
"params": self.online_classifier.parameters(),
"weight_decay": 0.0,
},
],
lr=0.1 * math.sqrt(self.batch_size_per_device * self.trainer.world_size),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The denominator is missing here.

momentum=0.9,
weight_decay=1e-6,
)
scheduler = {
"scheduler": CosineWarmupScheduler(
optimizer=optimizer,
warmup_epochs=int(
self.trainer.estimated_stepping_batches
/ self.trainer.max_epochs
* 10
),
max_epochs=int(self.trainer.estimated_stepping_batches),
),
"interval": "step",
}
return [optimizer], [scheduler]


transform = WMSETransform()
1 change: 1 addition & 0 deletions docs/source/getting_started/benchmarks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Evaluation settings are based on the following papers:
"SwAV", "Res50", "256", "100", "67.2", "88.1", "75.4", "92.7", "49.5", "78.6", "`link <https://tensorboard.dev/experiment/Ipx4Oxl5Qkqm5Sl5kWyKKg>`_", "`link <https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_swav_2023-05-25_08-29-14/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt>`_"
"TiCo", "Res50", "256", "100", "49.7", "74.4", "72.7", "90.9", "26.6", "53.6", "-", "`link <https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_tico_2024-01-07_18-40-57/pretrain/version_0/checkpoints/epoch%3D99-step%3D250200.ckpt>`_"
"VICReg", "Res50", "256", "100", "63.0", "85.4", "73.7", "91.9", "46.3", "75.2", "`link <https://tensorboard.dev/experiment/qH5uywJbTJSzgCEfxc7yUw>`_", "`link <https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_vicreg_2023-09-11_10-53-08/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt>`_"
"W-MSE", "Res50", "512", "100", "54.6", "78.8", "73.6", "91.5", "31.2", "60.4", "-", "`link <https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_wmse_2024-01-15_22-02-33/pretrain/version_0/checkpoints/epoch%3D99-step%3D250200.ckpt>`_"

*\*We use square root learning rate scaling instead of linear scaling as it yields better results for smaller batch sizes. See Appendix B.1 in the SimCLR paper.*

Expand Down
32 changes: 29 additions & 3 deletions lightly/loss/wmse_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from typing import Callable

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from lightly.utils.dist import gather


def norm_mse_loss(x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -59,10 +61,18 @@

f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye

# get type of f_cov_shrinked and temporary convert to full precision
# to support chelosky decomposition (only supports fp32)
f_cov_shrinked_type = f_cov_shrinked.dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks super duper hacky. Why is it necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, as written in the comment. The original code is not using half precision.

f_cov_shrinked = f_cov_shrinked.to(torch.float32)

inv_sqrt = torch.linalg.solve_triangular(
torch.linalg.cholesky(f_cov_shrinked), eye, upper=False
)

# convert back to original type
inv_sqrt = inv_sqrt.to(f_cov_shrinked_type)

inv_sqrt = inv_sqrt.contiguous().view(
self.num_features, self.num_features, 1, 1
)
Expand Down Expand Up @@ -117,6 +127,7 @@
w_size: int = 256,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = norm_mse_loss,
num_samples: int = 2,
gather_distributed: bool = False,
):
"""Parameters as described in [0]

Expand All @@ -137,6 +148,9 @@
Loss function to use for the whitening.
num_samples:
Number of samples generated by the transforms for each image.
gather_distributed:
If True then the cross-correlation matrices from all gpus are
gathered and summed before the loss calculation.


"""
Expand All @@ -147,15 +161,22 @@
eps=eps,
track_running_stats=track_running_stats,
)
if gather_distributed and not dist.is_available():
raise ValueError(

Check warning on line 165 in lightly/loss/wmse_loss.py

View check run for this annotation

Codecov / codecov/patch

lightly/loss/wmse_loss.py#L165

Added line #L165 was not covered by tests
"gather_distributed is True but torch.distributed is not available. "
"Please set gather_distributed=False or install a torch version with "
"distributed support."
)
if embedding_dim * 2 > w_size:
raise ValueError(
"w_size should be at least twice the size of embedding_dim to avoid instabiliy"
f"w_size is {w_size} but it should be at least twice the size of embedding_dim which is {embedding_dim} to avoid instabiliy"
)
self.w_iter = w_iter
self.w_size = w_size
self.loss_f = loss_fn
self.num_samples = num_samples
self.num_pairs = num_samples * (num_samples - 1) // 2
self.gather_distributed = gather_distributed

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Calculates the W-MSE loss.
Expand All @@ -173,13 +194,18 @@
ValueError:
If the batch size is smaller than w_size.
"""

if input.shape[0] % self.num_samples != 0:
raise RuntimeError("input batch size must be divisible by num_samples")
raise RuntimeError(
f"input batch size is {input.shape[0]} but must be divisible by num_samples which is {self.num_samples}"
)

bs = input.shape[0] // self.num_samples

if bs < self.w_size:
raise ValueError("batch size must be greater than or equal to w_size")
raise ValueError(

Check warning on line 206 in lightly/loss/wmse_loss.py

View check run for this annotation

Codecov / codecov/patch

lightly/loss/wmse_loss.py#L206

Added line #L206 was not covered by tests
f"batch size is {bs} but must be greater than or equal to w_size which is {self.w_size}"
)
loss = torch.tensor(0.0, device=input.device, requires_grad=True)

for _ in range(self.w_iter):
Expand Down
1 change: 1 addition & 0 deletions lightly/models/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
SMoGPrototypes,
SwaVProjectionHead,
SwaVPrototypes,
WMSEProjectionHead,
)
from lightly.models.modules.nn_memory_bank import NNMemoryBankModule

Expand Down
21 changes: 21 additions & 0 deletions lightly/models/modules/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,27 @@
)


class WMSEProjectionHead(SimCLRProjectionHead):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this. We should be able to use SimCLR instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@guarin, we should make sure things are consistent. I'm not sure what we agreed on. AFAIK, The same goes for the transforms.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't the default values different?

In any case, I prefer if all components of the WMSE model are called WMSESomething. Mixing components from different models is always confusing and it makes the components harder to discover in the code. If two models have the same head then we can just subclass from the first model and update the docstring.

"""Projection head used for W-MSE.

Uses the same projection head as SimCLR.[0]

[0]: 2021, W-MSE, https://arxiv.org/pdf/2007.06346.pdf
"""

def __init__(
self,
input_dim: int = 2048,
hidden_dim: int = 2048,
output_dim: int = 128,
num_layers: int = 2,
batch_norm: bool = True,
):
super(WMSEProjectionHead, self).__init__(

Check warning on line 718 in lightly/models/modules/heads.py

View check run for this annotation

Codecov / codecov/patch

lightly/models/modules/heads.py#L718

Added line #L718 was not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
super(WMSEProjectionHead, self).__init__(
super().__init__(

In general the class should not be passed to the super method.

input_dim, hidden_dim, output_dim, num_layers, batch_norm
)


class VICRegProjectionHead(ProjectionHead):
"""Projection head used for VICReg.

Expand Down
Loading