From c003c3224736c4ad0bb9463a3bd9214b5d18d1e4 Mon Sep 17 00:00:00 2001
From: IgorSusmelj
Date: Thu, 11 Jan 2024 21:10:01 +0100
Subject: [PATCH 01/13] Add projection head

---
 lightly/models/modules/__init__.py |  1 +
 lightly/models/modules/heads.py    | 21 +++++++++++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/lightly/models/modules/__init__.py b/lightly/models/modules/__init__.py
index 3d8fe3d7c..3b3db65be 100644
--- a/lightly/models/modules/__init__.py
+++ b/lightly/models/modules/__init__.py
@@ -25,6 +25,7 @@
     SMoGPrototypes,
     SwaVProjectionHead,
     SwaVPrototypes,
+    WMSEProjectionHead,
 )
 from lightly.models.modules.nn_memory_bank import NNMemoryBankModule

diff --git a/lightly/models/modules/heads.py b/lightly/models/modules/heads.py
index 541b49a29..9bd3abf29 100644
--- a/lightly/models/modules/heads.py
+++ b/lightly/models/modules/heads.py
@@ -699,6 +699,27 @@ def __init__(
         )


+class WMSEProjectionHead(SimCLRProjectionHead):
+    """Projection head used for W-MSE.
+
+    Uses the same projection head as SimCLR.[0]
+
+    [0]: 2021, W-MSE, https://arxiv.org/pdf/2007.06346.pdf
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 2048,
+        hidden_dim: int = 2048,
+        output_dim: int = 128,
+        num_layers: int = 2,
+        batch_norm: bool = True,
+    ):
+        super(SimCLRProjectionHead).__init__(
+            input_dim, hidden_dim, output_dim, num_layers, batch_norm
+        )
+
+
 class VICRegProjectionHead(ProjectionHead):
     """Projection head used for VICReg.

From d84b0e5fd14d9ea3ab370856932f313a626e5ed1 Mon Sep 17 00:00:00 2001
From: IgorSusmelj
Date: Thu, 11 Jan 2024 21:10:26 +0100
Subject: [PATCH 02/13] Add ImageNet benchmark code for wmse

---
 benchmarks/imagenet/resnet50/wmse.py | 104 +++++++++++++++++++++
 1 file changed, 104 insertions(+)
 create mode 100644 benchmarks/imagenet/resnet50/wmse.py

diff --git a/benchmarks/imagenet/resnet50/wmse.py b/benchmarks/imagenet/resnet50/wmse.py
new file mode 100644
index 000000000..cc6f1700e
--- /dev/null
+++ b/benchmarks/imagenet/resnet50/wmse.py
@@ -0,0 +1,104 @@
+import math
+from typing import List, Tuple
+
+import torch
+from pytorch_lightning import LightningModule
+from torch import Tensor
+from torch.nn import Identity
+from torchvision.models import resnet50
+
+from lightly.loss.wmse_loss import WMSELoss
+from lightly.models.modules import WMSEProjectionHead
+from lightly.models.utils import get_weight_decay_parameters
+from lightly.transforms import WMSETransform
+from lightly.utils.benchmarking import OnlineLinearClassifier
+from lightly.utils.lars import LARS
+from lightly.utils.scheduler import CosineWarmupScheduler
+
+
+class WMSE(LightningModule):
+    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
+        super().__init__()
+        self.save_hyperparameters()
+        self.batch_size_per_device = batch_size_per_device
+
+        resnet = resnet50()
+        resnet.fc = Identity()  # Ignore classification head
+        self.backbone = resnet
+        self.projection_head = WMSEProjectionHead()
+        self.criterion_WMSE4loss = WMSELoss(num_samples=4)
+
+        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.backbone(x)
+
+    def training_step(
+        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
+    ) -> Tensor:
+        views, targets = batch[0], batch[1]
+        features = self.forward(torch.cat(views)).flatten(start_dim=1)
+        z = self.projection_head(features)
+        z0, z1 = z.chunk(len(views))
+        loss = self.criterion_WMSE4loss(z0, z1)
+        self.log(
+            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
+ ) + + cls_loss, cls_log = self.online_classifier.training_step( + (features.detach(), targets.repeat(len(views))), batch_idx + ) + self.log_dict(cls_log, sync_dist=True, batch_size=len(targets)) + return loss + cls_loss + + def validation_step( + self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int + ) -> Tensor: + images, targets = batch[0], batch[1] + features = self.forward(images).flatten(start_dim=1) + cls_loss, cls_log = self.online_classifier.validation_step( + (features.detach(), targets), batch_idx + ) + self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets)) + return cls_loss + + def configure_optimizers(self): + # Don't use weight decay for batch norm, bias parameters, and classification + # head to improve performance. + params, params_no_weight_decay = get_weight_decay_parameters( + [self.backbone, self.projection_head] + ) + optimizer = LARS( + [ + {"name": "wmse", "params": params}, + { + "name": "wmse_no_weight_decay", + "params": params_no_weight_decay, + "weight_decay": 0.0, + }, + { + "name": "online_classifier", + "params": self.online_classifier.parameters(), + "weight_decay": 0.0, + }, + ], + lr=0.1 * math.sqrt(self.batch_size_per_device * self.trainer.world_size), + momentum=0.9, + weight_decay=1e-6, + ) + scheduler = { + "scheduler": CosineWarmupScheduler( + optimizer=optimizer, + warmup_epochs=int( + self.trainer.estimated_stepping_batches + / self.trainer.max_epochs + * 10 + ), + max_epochs=int(self.trainer.estimated_stepping_batches), + ), + "interval": "step", + } + return [optimizer], [scheduler] + + +transform = WMSETransform() From 753cd5ca90695d0e26803ecaa7fe3c1f39bfe916 Mon Sep 17 00:00:00 2001 From: IgorSusmelj Date: Thu, 11 Jan 2024 21:17:50 +0100 Subject: [PATCH 03/13] Add wmse model to main script --- benchmarks/imagenet/resnet50/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmarks/imagenet/resnet50/main.py b/benchmarks/imagenet/resnet50/main.py index 6230a8667..ebce13240 100644 --- a/benchmarks/imagenet/resnet50/main.py +++ b/benchmarks/imagenet/resnet50/main.py @@ -17,6 +17,7 @@ import tico import torch import vicreg +import wmse from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import ( DeviceStatsMonitor, @@ -64,6 +65,7 @@ "swav": {"model": swav.SwAV, "transform": swav.transform}, "tico": {"model": tico.TiCo, "transform": tico.transform}, "vicreg": {"model": vicreg.VICReg, "transform": vicreg.transform}, + "wmse": {"model": wmse.WMSE, "transform": wmse.transform}, } From 6d53a0a703079c5416e3853bd3e90eb7115d69a2 Mon Sep 17 00:00:00 2001 From: IgorSusmelj Date: Thu, 11 Jan 2024 21:21:31 +0100 Subject: [PATCH 04/13] Fix head --- lightly/models/modules/heads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightly/models/modules/heads.py b/lightly/models/modules/heads.py index 9bd3abf29..a7cc69aac 100644 --- a/lightly/models/modules/heads.py +++ b/lightly/models/modules/heads.py @@ -715,7 +715,7 @@ def __init__( num_layers: int = 2, batch_norm: bool = True, ): - super(SimCLRProjectionHead).__init__( + super(WMSEProjectionHead, self).__init__( input_dim, hidden_dim, output_dim, num_layers, batch_norm ) From dc9951596973ff51f59a51dc90f5d51c8ac78305 Mon Sep 17 00:00:00 2001 From: IgorSusmelj Date: Thu, 11 Jan 2024 22:51:44 +0100 Subject: [PATCH 05/13] Add better error messages. Support mixed precision and distributed. 
---
 lightly/loss/wmse_loss.py | 37 ++++++++++++++++++++++++++++++++++---
 1 file changed, 34 insertions(+), 3 deletions(-)

diff --git a/lightly/loss/wmse_loss.py b/lightly/loss/wmse_loss.py
index 4c3fe1f8b..15c2e7472 100644
--- a/lightly/loss/wmse_loss.py
+++ b/lightly/loss/wmse_loss.py
@@ -3,8 +3,10 @@
 from typing import Callable

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F

+from lightly.utils.dist import gather

 def norm_mse_loss(x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
@@ -59,10 +61,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye

+        # get the dtype of f_cov_shrinked and temporarily convert to full precision
+        # to support Cholesky decomposition
+        f_cov_shrinked_type = f_cov_shrinked.dtype
+        f_cov_shrinked = f_cov_shrinked.to(torch.float32)
+
         inv_sqrt = torch.linalg.solve_triangular(
             torch.linalg.cholesky(f_cov_shrinked), eye, upper=False
         )

+        # convert back to original type
+        inv_sqrt = inv_sqrt.to(f_cov_shrinked_type)
+
         inv_sqrt = inv_sqrt.contiguous().view(
             self.num_features, self.num_features, 1, 1
         )
@@ -117,6 +127,7 @@ def __init__(
         w_size: int = 256,
         loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = norm_mse_loss,
         num_samples: int = 2,
+        gather_distributed: bool = False,
     ):
         """Parameters as described in [0]

@@ -137,6 +148,9 @@ def __init__(
                 Loss function to use for the whitening.
             num_samples:
                 Number of samples generated by the transforms for each image.
+            gather_distributed:
+                If True then the input embeddings from all gpus are gathered
+                before the loss calculation.

         """
@@ -147,15 +161,22 @@ def __init__(
             eps=eps,
             track_running_stats=track_running_stats,
         )
+        if gather_distributed and not dist.is_available():
+            raise ValueError(
+                "gather_distributed is True but torch.distributed is not available. "
+                "Please set gather_distributed=False or install a torch version with "
+                "distributed support."
+            )
         if embedding_dim * 2 > w_size:
             raise ValueError(
-                "w_size should be at least twice the size of embedding_dim to avoid instabiliy"
+                f"w_size is {w_size} but it should be at least twice the size of embedding_dim which is {embedding_dim} to avoid instability"
             )
         self.w_iter = w_iter
         self.w_size = w_size
         self.loss_f = loss_fn
         self.num_samples = num_samples
         self.num_pairs = num_samples * (num_samples - 1) // 2
+        self.gather_distributed = gather_distributed

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         """Calculates the W-MSE loss.
@@ -173,13 +194,23 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             ValueError:
                 If the batch size is smaller than w_size.
""" + # gather all batches + if self.gather_distributed and dist.is_initialized(): + world_size = dist.get_world_size() + if world_size > 1: + input = torch.cat(gather(input), dim=0) + if input.shape[0] % self.num_samples != 0: - raise RuntimeError("input batch size must be divisible by num_samples") + raise RuntimeError( + f"input batch size is {input.shape[0]} but must be divisible by num_samples which is {self.num_samples}" + ) bs = input.shape[0] // self.num_samples if bs < self.w_size: - raise ValueError("batch size must be greater than or equal to w_size") + raise ValueError( + f"batch size is {bs} but must be greater than or equal to w_size which is {self.w_size}" + ) loss = torch.tensor(0.0, device=input.device, requires_grad=True) for _ in range(self.w_iter): From 8b19156257b63028ff268e3f4d95fe2e744368c2 Mon Sep 17 00:00:00 2001 From: IgorSusmelj Date: Thu, 11 Jan 2024 22:53:16 +0100 Subject: [PATCH 06/13] Update code --- benchmarks/imagenet/resnet50/wmse.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/benchmarks/imagenet/resnet50/wmse.py b/benchmarks/imagenet/resnet50/wmse.py index cc6f1700e..d632367e7 100644 --- a/benchmarks/imagenet/resnet50/wmse.py +++ b/benchmarks/imagenet/resnet50/wmse.py @@ -25,8 +25,14 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None: resnet = resnet50() resnet.fc = Identity() # Ignore classification head self.backbone = resnet - self.projection_head = WMSEProjectionHead() - self.criterion_WMSE4loss = WMSELoss(num_samples=4) + + # we use a projection head with output dimension 64 + # and w_size of 128 to support a batch size of 256 + self.projection_head = WMSEProjectionHead(output_dim=64) + + self.criterion_WMSE4loss = WMSELoss( + w_size=128, embedding_dim=64, num_samples=4, gather_distributed=True + ) self.online_classifier = OnlineLinearClassifier(num_classes=num_classes) @@ -39,8 +45,7 @@ def training_step( views, targets = batch[0], batch[1] features = self.forward(torch.cat(views)).flatten(start_dim=1) z = self.projection_head(features) - z0, z1 = z.chunk(len(views)) - loss = self.criterion_WMSE4loss(z0, z1) + loss = self.criterion_WMSE4loss(z) self.log( "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets) ) From 827c8af40710f8c796a2c0e1804cb41326efd05b Mon Sep 17 00:00:00 2001 From: IgorSusmelj Date: Mon, 15 Jan 2024 08:30:22 +0100 Subject: [PATCH 07/13] Update comment to make clear why we have to use fp32 --- lightly/loss/wmse_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightly/loss/wmse_loss.py b/lightly/loss/wmse_loss.py index 15c2e7472..3328e0800 100644 --- a/lightly/loss/wmse_loss.py +++ b/lightly/loss/wmse_loss.py @@ -62,7 +62,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye # get type of f_cov_shrinked and temporary convert to full precision - # to support chelosky decomposition + # to support chelosky decomposition (only supports fp32) f_cov_shrinked_type = f_cov_shrinked.dtype f_cov_shrinked = f_cov_shrinked.to(torch.float32) From e8b9d7e13152150bdb70710a88e260cfb16ef560 Mon Sep 17 00:00:00 2001 From: IgorSusmelj Date: Mon, 15 Jan 2024 14:09:29 +0100 Subject: [PATCH 08/13] Use parmeters from paper --- benchmarks/imagenet/resnet50/wmse.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/imagenet/resnet50/wmse.py b/benchmarks/imagenet/resnet50/wmse.py index d632367e7..a511e1547 100644 --- a/benchmarks/imagenet/resnet50/wmse.py +++ 
@@ -26,12 +26,12 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
         resnet.fc = Identity()  # Ignore classification head
         self.backbone = resnet

-        # we use a projection head with output dimension 64
-        # and w_size of 128 to support a batch size of 256
-        self.projection_head = WMSEProjectionHead(output_dim=64)
+        # we use a projection head with output dimension 128
+        # and w_size of 256 to support a batch size of 512
+        self.projection_head = WMSEProjectionHead(output_dim=128)

         self.criterion_WMSE4loss = WMSELoss(
-            w_size=128, embedding_dim=64, num_samples=4, gather_distributed=True
+            w_size=256, embedding_dim=128, num_samples=4
         )

         self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

From f5a913c38f08e034c8b01fc1c40394764dc9e009 Mon Sep 17 00:00:00 2001
From: IgorSusmelj
Date: Mon, 15 Jan 2024 14:09:56 +0100
Subject: [PATCH 09/13] Remove distributed gathering

---
 lightly/loss/wmse_loss.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/lightly/loss/wmse_loss.py b/lightly/loss/wmse_loss.py
index 3328e0800..7366d9a84 100644
--- a/lightly/loss/wmse_loss.py
+++ b/lightly/loss/wmse_loss.py
@@ -194,11 +194,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             ValueError:
                 If the batch size is smaller than w_size.
         """
-        # gather all batches
-        if self.gather_distributed and dist.is_initialized():
-            world_size = dist.get_world_size()
-            if world_size > 1:
-                input = torch.cat(gather(input), dim=0)

         if input.shape[0] % self.num_samples != 0:
             raise RuntimeError(
                 f"input batch size is {input.shape[0]} but must be divisible by num_samples which is {self.num_samples}"
             )

From d4457dc9e367dd45973456f840e30e7682ee1b3e Mon Sep 17 00:00:00 2001
From: IgorSusmelj
Date: Fri, 19 Jan 2024 08:48:19 +0100
Subject: [PATCH 10/13] Add W-MSE results

---
 docs/source/getting_started/benchmarks.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/source/getting_started/benchmarks.rst b/docs/source/getting_started/benchmarks.rst
index 6edf54152..9c41f6784 100644
--- a/docs/source/getting_started/benchmarks.rst
+++ b/docs/source/getting_started/benchmarks.rst
@@ -38,6 +38,7 @@ Evaluation settings are based on the following papers:
     "SwAV", "Res50", "256", "100", "67.2", "88.1", "75.4", "92.7", "49.5", "78.6", "`link `_", "`link `_"
     "TiCo", "Res50", "256", "100", "49.7", "74.4", "72.7", "90.9", "26.6", "53.6", "-", "`link `_"
     "VICReg", "Res50", "256", "100", "63.0", "85.4", "73.7", "91.9", "46.3", "75.2", "`link `_", "`link `_"
+    "W-MSE", "Res50", "512", "100", "54.6", "78.8", "73.6", "91.5", "31.2", "60.4", "-", "`link `_"

 *\*We use square root learning rate scaling instead of linear scaling as it yields better results for smaller batch sizes. See Appendix B.1 in the SimCLR paper.*

From a798dcbc607f4941606c24f63e088dfe7185f432 Mon Sep 17 00:00:00 2001
From: IgorSusmelj
Date: Fri, 19 Jan 2024 08:51:20 +0100
Subject: [PATCH 11/13] Add W-MSE results

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 59afc4ce1..23daeb4a5 100644
--- a/README.md
+++ b/README.md
@@ -283,6 +283,7 @@ See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details.
 | SwAV | Res50 | 256 | 100 | 67.2 | 75.4 | 49.5 | [link](https://tensorboard.dev/experiment/Ipx4Oxl5Qkqm5Sl5kWyKKg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_swav_2023-05-25_08-29-14/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
 | TiCo | Res50 | 256 | 100 | 49.7 | 72.7 | 26.6 | - | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_tico_2024-01-07_18-40-57/pretrain/version_0/checkpoints/epoch%3D99-step%3D250200.ckpt) |
 | VICReg | Res50 | 256 | 100 | 63.0 | 73.7 | 46.3 | [link](https://tensorboard.dev/experiment/qH5uywJbTJSzgCEfxc7yUw) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_vicreg_2023-09-11_10-53-08/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
+| W-MSE | Res50 | 512 | 100 | 54.6 | 73.6 | 31.2 | - | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_wmse_2024-01-15_22-02-33/pretrain/version_0/checkpoints/epoch%3D99-step%3D250200.ckpt) |

 _\*We use square root learning rate scaling instead of linear scaling as it yields better results for smaller batch sizes. See Appendix B.1 in the [SimCLR paper](https://arxiv.org/abs/2002.05709)._

From 0fbde505af72dba7af23bf376cf0d3540785a65f Mon Sep 17 00:00:00 2001
From: IgorSusmelj
Date: Sun, 21 Jan 2024 15:41:15 +0100
Subject: [PATCH 12/13] Add distributed gather again

---
 lightly/loss/wmse_loss.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/lightly/loss/wmse_loss.py b/lightly/loss/wmse_loss.py
index 7366d9a84..c594e98f1 100644
--- a/lightly/loss/wmse_loss.py
+++ b/lightly/loss/wmse_loss.py
@@ -70,7 +70,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             torch.linalg.cholesky(f_cov_shrinked), eye, upper=False
         )

-        # convert back to original type
+        # convert back to original type (e.g. fp16)
         inv_sqrt = inv_sqrt.to(f_cov_shrinked_type)

         inv_sqrt = inv_sqrt.contiguous().view(
             self.num_features, self.num_features, 1, 1
         )
@@ -195,6 +195,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             ValueError:
                 If the batch size is smaller than w_size.
""" + # gather all batches + if self.gather_distributed and dist.is_initialized(): + world_size = dist.get_world_size() + if world_size > 1: + input = torch.cat(gather(input), dim=0) + if input.shape[0] % self.num_samples != 0: raise RuntimeError( f"input batch size is {input.shape[0]} but must be divisible by num_samples which is {self.num_samples}" From 7ad82054c97e255b599cbea9737d73712ba60243 Mon Sep 17 00:00:00 2001 From: IgorSusmelj Date: Sun, 21 Jan 2024 16:04:04 +0100 Subject: [PATCH 13/13] Update code to match ImageNet experiments --- benchmarks/imagenet/resnet50/wmse.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/benchmarks/imagenet/resnet50/wmse.py b/benchmarks/imagenet/resnet50/wmse.py index a511e1547..0c78afcf6 100644 --- a/benchmarks/imagenet/resnet50/wmse.py +++ b/benchmarks/imagenet/resnet50/wmse.py @@ -26,13 +26,11 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None: resnet.fc = Identity() # Ignore classification head self.backbone = resnet - # we use a projection head with output dimension 128 - # and w_size of 256 to support a batch size of 512 - self.projection_head = WMSEProjectionHead(output_dim=128) + # we use a projection head with 3 layers for ImageNet + self.projection_head = WMSEProjectionHead(num_layers=3) - self.criterion_WMSE4loss = WMSELoss( - w_size=256, embedding_dim=128, num_samples=4 - ) + # we use 4 samples per image for ImageNet + self.criterion_WMSE4loss = WMSELoss(num_samples=4, gather_distributed=True) self.online_classifier = OnlineLinearClassifier(num_classes=num_classes) @@ -106,4 +104,5 @@ def configure_optimizers(self): return [optimizer], [scheduler] -transform = WMSETransform() +# we use 4 samples per image for ImageNet +transform = WMSETransform(num_samples=4)