Igor lig 4447 w mse benchmark #1474
base: master
Changes from 11 commits
c003c32
d84b0e5
753cd5c
6d53a0a
dc99515
8b19156
827c8af
e8b9d7e
f5a913c
d4457dc
a798dcb
0fbde50
7ad8205
@@ -0,0 +1,109 @@
import math
from typing import List, Tuple

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torchvision.models import resnet50

from lightly.loss.wmse_loss import WMSELoss
from lightly.models.modules import WMSEProjectionHead
from lightly.models.utils import get_weight_decay_parameters
from lightly.transforms import WMSETransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.lars import LARS
from lightly.utils.scheduler import CosineWarmupScheduler


class WMSE(LightningModule):
    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        resnet = resnet50()
        resnet.fc = Identity()  # Ignore classification head
        self.backbone = resnet

        # We use a projection head with output dimension 128
        # and w_size of 256 to support a batch size of 512.
        self.projection_head = WMSEProjectionHead(output_dim=128)

        self.criterion_WMSE4loss = WMSELoss(
            w_size=256, embedding_dim=128, num_samples=4
        )

        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.backbone(x)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        views, targets = batch[0], batch[1]
        features = self.forward(torch.cat(views)).flatten(start_dim=1)
        z = self.projection_head(features)
        loss = self.criterion_WMSE4loss(z)
        self.log(
            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
        )

        cls_loss, cls_log = self.online_classifier.training_step(
            (features.detach(), targets.repeat(len(views))), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        images, targets = batch[0], batch[1]
        features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance.
        params, params_no_weight_decay = get_weight_decay_parameters(
            [self.backbone, self.projection_head]
        )
        optimizer = LARS(
            [
                {"name": "wmse", "params": params},
                {
                    "name": "wmse_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            lr=0.1 * math.sqrt(self.batch_size_per_device * self.trainer.world_size),
            momentum=0.9,
            weight_decay=1e-6,
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=int(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 10
                ),
                max_epochs=int(self.trainer.estimated_stepping_batches),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]


transform = WMSETransform()
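For reviewers who want to try the benchmark module locally, here is a minimal usage sketch. It is not part of this PR; the dataset path, batch size, and Trainer settings are illustrative assumptions.

# Hypothetical usage sketch -- dataset path, batch size, and Trainer settings
# are assumptions for illustration, not part of this PR.
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

from lightly.data import LightlyDataset

model = WMSE(batch_size_per_device=512, num_classes=1000)
dataset = LightlyDataset("datasets/imagenet/train", transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=512,  # must be >= w_size (256) so the whitening checks pass
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
trainer = Trainer(max_epochs=100, accelerator="gpu", devices=1, precision="16-mixed")
trainer.fit(model, train_dataloaders=dataloader)

With a multi-view transform such as WMSETransform, each batch arrives as (views, targets, filenames), which is what training_step above unpacks.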
@@ -3,8 +3,10 @@
from typing import Callable

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from lightly.utils.dist import gather


def norm_mse_loss(x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
@@ -59,10 +61,18 @@

        f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye

        # Get the dtype of f_cov_shrinked and temporarily convert to full precision
        # to support the Cholesky decomposition (which only supports fp32).
        f_cov_shrinked_type = f_cov_shrinked.dtype

Review comment: This looks super duper hacky. Why is it necessary?
Reply: Yes, as written in the comment. The original code is not using half precision.

        f_cov_shrinked = f_cov_shrinked.to(torch.float32)

        inv_sqrt = torch.linalg.solve_triangular(
            torch.linalg.cholesky(f_cov_shrinked), eye, upper=False
        )

        # Convert back to the original dtype.
        inv_sqrt = inv_sqrt.to(f_cov_shrinked_type)

        inv_sqrt = inv_sqrt.contiguous().view(
            self.num_features, self.num_features, 1, 1
        )
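To make the dtype round-trip above easier to evaluate in isolation, here is a small standalone sketch (my own, not from the diff). It assumes the covariance may arrive in fp16 when training with half precision, which torch.linalg.cholesky does not accept.

# Standalone sketch of the fp32 round-trip used above (assumption: the input
# covariance can be fp16 under half-precision training; torch.linalg.cholesky
# only supports float32/float64 inputs).
import torch


def whitening_matrix(cov: torch.Tensor) -> torch.Tensor:
    eye = torch.eye(cov.shape[0], device=cov.device, dtype=torch.float32)
    orig_dtype = cov.dtype
    cov_fp32 = cov.to(torch.float32)  # promote for the decomposition
    # cov = L L^T; solving L X = I gives X = L^-1, the triangular whitening factor.
    inv_sqrt = torch.linalg.solve_triangular(
        torch.linalg.cholesky(cov_fp32), eye, upper=False
    )
    return inv_sqrt.to(orig_dtype)  # hand back the caller's dtype


cov = 2.0 * torch.eye(4, dtype=torch.float16) + 1e-3  # symmetric positive definite
print(whitening_matrix(cov).dtype)  # torch.float16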
@@ -117,6 +127,7 @@
        w_size: int = 256,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = norm_mse_loss,
        num_samples: int = 2,
        gather_distributed: bool = False,
    ):
        """Parameters as described in [0]

@@ -137,6 +148,9 @@
                Loss function to use for the whitening.
            num_samples:
                Number of samples generated by the transforms for each image.
            gather_distributed:
                If True then the cross-correlation matrices from all gpus are
                gathered and summed before the loss calculation.

        """
@@ -147,15 +161,22 @@
            eps=eps,
            track_running_stats=track_running_stats,
        )
        if gather_distributed and not dist.is_available():
            raise ValueError(
                "gather_distributed is True but torch.distributed is not available. "
                "Please set gather_distributed=False or install a torch version with "
                "distributed support."
            )
        if embedding_dim * 2 > w_size:
            raise ValueError(
                f"w_size is {w_size} but it should be at least twice the size of embedding_dim which is {embedding_dim} to avoid instability"
            )
        self.w_iter = w_iter
        self.w_size = w_size
        self.loss_f = loss_fn
        self.num_samples = num_samples
        self.num_pairs = num_samples * (num_samples - 1) // 2
        self.gather_distributed = gather_distributed

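A short usage note on the new gather_distributed flag (my own illustration, not from this diff): per the docstring above it gathers and sums across GPUs before the loss calculation, and it requires a torch build with distributed support, otherwise the constructor check added above raises a ValueError.

# Hypothetical sketch: enabling the new flag for multi-GPU (e.g. DDP) training.
from lightly.loss.wmse_loss import WMSELoss

criterion = WMSELoss(
    embedding_dim=128,
    w_size=256,
    num_samples=4,
    gather_distributed=True,  # requires torch.distributed support
)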
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Calculates the W-MSE loss.

@@ -173,13 +194,18 @@
            ValueError:
                If the batch size is smaller than w_size.
        """
        if input.shape[0] % self.num_samples != 0:
            raise RuntimeError(
                f"input batch size is {input.shape[0]} but must be divisible by num_samples which is {self.num_samples}"
            )

        bs = input.shape[0] // self.num_samples

        if bs < self.w_size:
            raise ValueError(
                f"batch size is {bs} but must be greater than or equal to w_size which is {self.w_size}"
            )
        loss = torch.tensor(0.0, device=input.device, requires_grad=True)

        for _ in range(self.w_iter):

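As a quick, hedged illustration of the shape contract these checks enforce (values chosen to mirror the benchmark above; not part of the diff):

# Hypothetical usage sketch: the concatenated views must satisfy
# input.shape[0] % num_samples == 0 and input.shape[0] // num_samples >= w_size.
import torch

from lightly.loss.wmse_loss import WMSELoss

criterion = WMSELoss(embedding_dim=128, w_size=256, num_samples=4)
z = torch.randn(4 * 512, 128)  # 4 views of a 512-image batch, concatenated on dim 0
loss = criterion(z)  # bs = 2048 // 4 = 512 >= w_size = 256, so both checks pass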
@@ -699,6 +699,27 @@
    )


class WMSEProjectionHead(SimCLRProjectionHead):

Review comment: I don't think we need this. We should be able to use SimCLR instead.
Reply: @guarin, we should make sure things are consistent. I'm not sure what we agreed on. AFAIK, The same goes for the transforms.
Reply: Aren't the default values different? In any case, I prefer if all components of the

    """Projection head used for W-MSE.

    Uses the same projection head as SimCLR. [0]

    [0]: 2021, W-MSE, https://arxiv.org/pdf/2007.06346.pdf
    """

    def __init__(
        self,
        input_dim: int = 2048,
        hidden_dim: int = 2048,
        output_dim: int = 128,
        num_layers: int = 2,
        batch_norm: bool = True,
    ):
        super(WMSEProjectionHead, self).__init__(

Review comment (suggested change): replace `super(WMSEProjectionHead, self).__init__(` with `super().__init__(`. In general the class should not be passed to the super method.

            input_dim, hidden_dim, output_dim, num_layers, batch_norm
        )
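To make the equivalence discussed in the thread above concrete, here is a small check of my own (assuming the defaults shown in this diff): with no arguments the head maps 2048-dimensional ResNet-50 features to 128-dimensional embeddings, i.e. a two-layer SimCLR-style MLP with batch norm.

# Hypothetical check (not part of this PR); defaults assumed as in the diff above.
import torch

head = WMSEProjectionHead()  # input_dim=2048, hidden_dim=2048, output_dim=128
x = torch.randn(8, 2048)  # e.g. flattened ResNet-50 features
print(head(x).shape)  # torch.Size([8, 128])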


class VICRegProjectionHead(ProjectionHead):
    """Projection head used for VICReg.

Review comment: The denominator is missing here.