Igor lig 4447 w mse benchmark #1474

Open · wants to merge 13 commits into base: master
2 changes: 2 additions & 0 deletions benchmarks/imagenet/resnet50/main.py
@@ -17,6 +17,7 @@
import tico
import torch
import vicreg
import wmse
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import (
DeviceStatsMonitor,
@@ -64,6 +65,7 @@
"swav": {"model": swav.SwAV, "transform": swav.transform},
"tico": {"model": tico.TiCo, "transform": tico.transform},
"vicreg": {"model": vicreg.VICReg, "transform": vicreg.transform},
"wmse": {"model": wmse.WMSE, "transform": wmse.transform},
}


109 changes: 109 additions & 0 deletions benchmarks/imagenet/resnet50/wmse.py
@@ -0,0 +1,109 @@
import math
from typing import List, Tuple

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torchvision.models import resnet50

from lightly.loss.wmse_loss import WMSELoss
from lightly.models.modules import WMSEProjectionHead
from lightly.models.utils import get_weight_decay_parameters
from lightly.transforms import WMSETransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.lars import LARS
from lightly.utils.scheduler import CosineWarmupScheduler


class WMSE(LightningModule):
def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
super().__init__()
self.save_hyperparameters()
self.batch_size_per_device = batch_size_per_device

resnet = resnet50()
resnet.fc = Identity() # Ignore classification head
self.backbone = resnet

# we use a projection head with output dimension 64
# and w_size of 128 to support a batch size of 256
self.projection_head = WMSEProjectionHead(output_dim=64)
Contributor:
I think the output dimension is wrong here. From the paper:

Finally, we use an embedding size of 64 for CIFAR-10 and CIFAR-100, and an embedding of size 128 for STL-10 and Tiny ImageNet. For ImageNet-100 we use a configuration similar to the Tiny ImageNet experiments, and 240 epochs of training. Finally, in the ImageNet experiments (Tab. 3), we use the implementation and the hyperparameter configuration of (Chen et al., 2020b) (same number of layers in the projection head, etc.) based on their open-source implementation, the only difference being the learning rate and the loss function (respectively, 0.075 and the contrastive loss in (Chen et al., 2020b) vs. 0.1 and Eq. 6 in W-MSE 4).

So they're using a SimCLR2 projection head.

Contributor:

And most likely the embedding dim is the same as the one for SimCLR2.
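
If that reading is right, the head could reuse the existing SimCLRProjectionHead with SimCLR-v2-style settings. A minimal sketch, assuming a 3-layer MLP with 2048 hidden units and a 128-dimensional output; these exact values are an assumption based on the quoted paper text, not a decision from this review.

from lightly.models.modules import SimCLRProjectionHead

# Hypothetical ImageNet configuration implied by the review comments above.
projection_head = SimCLRProjectionHead(
    input_dim=2048,   # ResNet-50 feature dimension
    hidden_dim=2048,
    output_dim=128,   # assumed to match the SimCLR v2 embedding size
    num_layers=3,     # SimCLR v2 uses a 3-layer MLP head
)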


self.criterion_WMSE4loss = WMSELoss(
w_size=128, embedding_dim=64, num_samples=4, gather_distributed=True
Contributor:

For ImageNet they probably use w_size=256:

For CIFAR-10 and CIFAR-100, the slicing sub-batch size is 128, for Tiny ImageNet and STL-10, it is 256.

)
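
If the reviewer is right about the slicing sub-batch size, the loss would instead be constructed roughly as follows; w_size=256 and embedding_dim=128 come from the quoted paper text and are assumptions, not settings decided in this PR.

from lightly.loss.wmse_loss import WMSELoss

criterion = WMSELoss(
    embedding_dim=128,        # assumed to match the projection head output above
    w_size=256,               # slicing sub-batch size suggested for ImageNet-scale runs
    num_samples=4,            # W-MSE 4 variant, as in this benchmark
    gather_distributed=True,  # gather embeddings across devices before whitening
)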

self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

def forward(self, x: Tensor) -> Tensor:
return self.backbone(x)

def training_step(
self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
) -> Tensor:
views, targets = batch[0], batch[1]
features = self.forward(torch.cat(views)).flatten(start_dim=1)
z = self.projection_head(features)
loss = self.criterion_WMSE4loss(z)
self.log(
"train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
)

cls_loss, cls_log = self.online_classifier.training_step(
(features.detach(), targets.repeat(len(views))), batch_idx
)
self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
return loss + cls_loss

def validation_step(
self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
) -> Tensor:
images, targets = batch[0], batch[1]
features = self.forward(images).flatten(start_dim=1)
cls_loss, cls_log = self.online_classifier.validation_step(
(features.detach(), targets), batch_idx
)
self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
return cls_loss

def configure_optimizers(self):
# Don't use weight decay for batch norm, bias parameters, and classification
# head to improve performance.
params, params_no_weight_decay = get_weight_decay_parameters(
[self.backbone, self.projection_head]
)
optimizer = LARS(
[
{"name": "wmse", "params": params},
{
"name": "wmse_no_weight_decay",
"params": params_no_weight_decay,
"weight_decay": 0.0,
},
{
"name": "online_classifier",
"params": self.online_classifier.parameters(),
"weight_decay": 0.0,
},
],
lr=0.1 * math.sqrt(self.batch_size_per_device * self.trainer.world_size),
Contributor:

The denominator is missing here.

momentum=0.9,
weight_decay=1e-6,
)
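
On the learning-rate comment above: the PR scales the base rate by the square root of the global batch size without any reference denominator. Purely as illustration, the two scaling conventions commonly used with LARS look roughly like the sketch below; the reference batch size of 256 is an assumption, not a value agreed on in this review.

import math

batch_size_per_device, world_size = 256, 8  # hypothetical values for illustration
global_batch_size = batch_size_per_device * world_size

# square-root scaling relative to a reference batch size
lr_sqrt = 0.1 * math.sqrt(global_batch_size / 256)
# linear scaling relative to the same reference batch size
lr_linear = 0.1 * global_batch_size / 256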
scheduler = {
"scheduler": CosineWarmupScheduler(
optimizer=optimizer,
warmup_epochs=int(
self.trainer.estimated_stepping_batches
/ self.trainer.max_epochs
* 10
),
max_epochs=int(self.trainer.estimated_stepping_batches),
),
"interval": "step",
}
return [optimizer], [scheduler]


transform = WMSETransform()
37 changes: 34 additions & 3 deletions lightly/loss/wmse_loss.py
@@ -3,8 +3,10 @@
from typing import Callable

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from lightly.utils.dist import gather


def norm_mse_loss(x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
@@ -59,10 +61,18 @@

f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye

# remember the dtype of f_cov_shrinked and temporarily convert to full precision
# to support Cholesky decomposition
f_cov_shrinked_type = f_cov_shrinked.dtype
Contributor:

This looks super duper hacky. Why is it necessary?

Contributor Author:

Yes, as written in the comment. The original code is not using half precision.

f_cov_shrinked = f_cov_shrinked.to(torch.float32)

inv_sqrt = torch.linalg.solve_triangular(
torch.linalg.cholesky(f_cov_shrinked), eye, upper=False
)

# convert back to original type
inv_sqrt = inv_sqrt.to(f_cov_shrinked_type)

inv_sqrt = inv_sqrt.contiguous().view(
self.num_features, self.num_features, 1, 1
)
@@ -117,6 +127,7 @@
w_size: int = 256,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = norm_mse_loss,
num_samples: int = 2,
gather_distributed: bool = False,
):
"""Parameters as described in [0]

@@ -137,6 +148,9 @@
Loss function to use for the whitening.
num_samples:
Number of samples generated by the transforms for each image.
gather_distributed:
If True then the embeddings from all GPUs are gathered before the loss calculation.


"""
@@ -147,15 +161,22 @@
eps=eps,
track_running_stats=track_running_stats,
)
if gather_distributed and not dist.is_available():
raise ValueError(

"gather_distributed is True but torch.distributed is not available. "
"Please set gather_distributed=False or install a torch version with "
"distributed support."
)
if embedding_dim * 2 > w_size:
raise ValueError(
"w_size should be at least twice the size of embedding_dim to avoid instabiliy"
f"w_size is {w_size} but it should be at least twice the size of embedding_dim which is {embedding_dim} to avoid instabiliy"
)
self.w_iter = w_iter
self.w_size = w_size
self.loss_f = loss_fn
self.num_samples = num_samples
self.num_pairs = num_samples * (num_samples - 1) // 2
self.gather_distributed = gather_distributed

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Calculates the W-MSE loss.
@@ -173,13 +194,23 @@
ValueError:
If the batch size is smaller than w_size.
"""
# gather all batches
if self.gather_distributed and dist.is_initialized():
world_size = dist.get_world_size()
if world_size > 1:
input = torch.cat(gather(input), dim=0)

Contributor:

Are you sure this is correct? Intuitively I think there could be problems because now every device computes the exact same loss, right?

Contributor Author:

I removed it but I will add it again. That seems the easiest and most proper way to support multi-GPU training. I'll make sure we divide the loss by the number of devices to make runs more comparable between different multi-GPU setups.
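
The adjustment described here (divide the gathered loss by the number of devices so that different multi-GPU setups stay comparable) could look roughly like the helper below. scale_gathered_loss is a hypothetical name and its exact placement in the loss is an assumption, not code from this PR.

import torch
import torch.distributed as dist


def scale_gathered_loss(loss: torch.Tensor) -> torch.Tensor:
    # With gathered embeddings every device computes the loss over the full
    # global batch; dividing by the world size keeps the reported loss on the
    # same scale as a single-device run, which is the comparability the
    # author describes.
    if dist.is_available() and dist.is_initialized():
        return loss / dist.get_world_size()
    return loss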


if input.shape[0] % self.num_samples != 0:
raise RuntimeError("input batch size must be divisible by num_samples")
raise RuntimeError(
f"input batch size is {input.shape[0]} but must be divisible by num_samples which is {self.num_samples}"
)

bs = input.shape[0] // self.num_samples

if bs < self.w_size:
raise ValueError("batch size must be greater than or equal to w_size")
raise ValueError(

f"batch size is {bs} but must be greater than or equal to w_size which is {self.w_size}"
)
loss = torch.tensor(0.0, device=input.device, requires_grad=True)

for _ in range(self.w_iter):
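For reference, the loss is consumed the same way as in the benchmark file above. A minimal standalone sketch with random data and small dimensions chosen only for illustration:

import torch

from lightly.loss.wmse_loss import WMSELoss

# 4 views per image (W-MSE 4) and 256 images give 1024 embeddings of dimension 64.
criterion = WMSELoss(embedding_dim=64, w_size=128, num_samples=4)
z = torch.randn(4 * 256, 64)  # concatenated projections, one block of 256 per view
loss = criterion(z)
print(loss)
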
1 change: 1 addition & 0 deletions lightly/models/modules/__init__.py
@@ -25,6 +25,7 @@
SMoGPrototypes,
SwaVProjectionHead,
SwaVPrototypes,
WMSEProjectionHead,
)
from lightly.models.modules.nn_memory_bank import NNMemoryBankModule

21 changes: 21 additions & 0 deletions lightly/models/modules/heads.py
@@ -699,6 +699,27 @@
)


class WMSEProjectionHead(SimCLRProjectionHead):
Contributor:

I don't think we need this. We should be able to use SimCLR instead.

Contributor Author:

@guarin, we should make sure things are consistent. I'm not sure what we agreed on. AFAIK, the same goes for the transforms.

Contributor:

Aren't the default values different?

In any case, I prefer if all components of the WMSE model are called WMSESomething. Mixing components from different models is always confusing and it makes the components harder to discover in the code. If two models have the same head then we can just subclass from the first model and update the docstring.

"""Projection head used for W-MSE.

Uses the same projection head as SimCLR.[0]

[0]: 2021, W-MSE, https://arxiv.org/pdf/2007.06346.pdf
"""

def __init__(
self,
input_dim: int = 2048,
hidden_dim: int = 2048,
output_dim: int = 128,
num_layers: int = 2,
batch_norm: bool = True,
):
super(WMSEProjectionHead, self).__init__(

Contributor:

Suggested change:
- super(WMSEProjectionHead, self).__init__(
+ super().__init__(

In general the class should not be passed to the super method.

input_dim, hidden_dim, output_dim, num_layers, batch_norm
)


class VICRegProjectionHead(ProjectionHead):
"""Projection head used for VICReg.
