[WIP] Implement SSL-EY #1443

Open · wants to merge 26 commits into master
2 changes: 2 additions & 0 deletions README.md
@@ -52,6 +52,7 @@ and PyTorch Lightning distributed examples for all models to kickstart your project
- SimMIM, 2021 [paper](https://arxiv.org/abs/2111.09886) [docs](https://docs.lightly.ai/self-supervised-learning/examples/simmim.html)
- SimSiam, 2021 [paper](https://arxiv.org/abs/2011.10566) [docs](https://docs.lightly.ai/self-supervised-learning/examples/simsiam.html)
- SMoG, 2022 [paper](https://arxiv.org/abs/2207.06167) [docs](https://docs.lightly.ai/self-supervised-learning/examples/smog.html)
- SSL-EY, 2023 [paper](https://arxiv.org/abs/2310.01012) [docs](https://docs.lightly.ai/self-supervised-learning/examples/ssley.html)
- SwaV, 2020 [paper](https://arxiv.org/abs/2006.09882) [docs](https://docs.lightly.ai/self-supervised-learning/examples/swav.html)
- TiCo, 2022 [paper](https://arxiv.org/abs/2206.10698) [docs](https://docs.lightly.ai/self-supervised-learning/examples/tico.html)
- VICReg, 2022 [paper](https://arxiv.org/abs/2105.04906) [docs](https://docs.lightly.ai/self-supervised-learning/examples/vicreg.html)
@@ -287,6 +288,7 @@ See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details.
| SimCLR* | Res50 | 256 | 100 | 63.2 | 73.9 | 44.8 | [link](https://tensorboard.dev/experiment/Ugol97adQdezgcVibDYMMA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* + DCL | Res50 | 256 | 100 | 65.1 | 73.5 | 49.6 | [link](https://tensorboard.dev/experiment/k4ZonZ77QzmBkc0lXswQlg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dcl_2023-07-04_16-51-40/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* + DCLW | Res50 | 256 | 100 | 64.5 | 73.2 | 48.5 | [link](https://tensorboard.dev/experiment/TrALnpwFQ4OkZV3uvaX7wQ/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dclw_2023-07-07_14-57-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SSL-EY | Res50 | 256 | 100 | TODO | TODO | TODO | [link](TODO)| [link](TODO) |
| SwAV | Res50 | 256 | 100 | 67.2 | 75.4 | 49.5 | [link](https://tensorboard.dev/experiment/Ipx4Oxl5Qkqm5Sl5kWyKKg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_swav_2023-05-25_08-29-14/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| VICReg | Res50 | 256 | 100 | 63.0 | 73.7 | 46.3 | [link](https://tensorboard.dev/experiment/qH5uywJbTJSzgCEfxc7yUw) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_vicreg_2023-09-11_10-53-08/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |

2 changes: 2 additions & 0 deletions benchmarks/imagenet/resnet50/main.py
@@ -13,6 +13,7 @@
import linear_eval
import mocov2
import simclr
import ssley
import swav
import torch
import vicreg
@@ -60,6 +61,7 @@
"dino": {"model": dino.DINO, "transform": dino.transform},
"mocov2": {"model": mocov2.MoCoV2, "transform": mocov2.transform},
"simclr": {"model": simclr.SimCLR, "transform": simclr.transform},
"ssley": {"model": ssley.SSLEY, "transform": ssley.transform},
"swav": {"model": swav.SwAV, "transform": swav.transform},
"vicreg": {"model": vicreg.VICReg, "transform": vicreg.transform},
}
127 changes: 127 additions & 0 deletions benchmarks/imagenet/resnet50/ssley.py
@@ -0,0 +1,127 @@
from typing import List, Tuple

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torchvision.models import resnet50

from lightly.loss.ssley_loss import SSLEYLoss
from lightly.models.modules.heads import SSLEYProjectionHead
from lightly.models.utils import get_weight_decay_parameters
from lightly.transforms import SSLEYTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.lars import LARS
from lightly.utils.scheduler import CosineWarmupScheduler


class SSLEY(LightningModule):
def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
super().__init__()
self.save_hyperparameters()
self.batch_size_per_device = batch_size_per_device

resnet = resnet50()
resnet.fc = Identity() # Ignore classification head
self.backbone = resnet
self.projection_head = SSLEYProjectionHead()
self.criterion = SSLEYLoss()

self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

def forward(self, x: Tensor) -> Tensor:
return self.backbone(x)

def training_step(
self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
) -> Tensor:
views, targets = batch[0], batch[1]
features = self.forward(torch.cat(views)).flatten(start_dim=1)
z = self.projection_head(features)
z_a, z_b = z.chunk(len(views))
loss = self.criterion(z_a=z_a, z_b=z_b)
self.log(
"train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
)

# Online linear evaluation.
cls_loss, cls_log = self.online_classifier.training_step(
(features.detach(), targets.repeat(len(views))), batch_idx
)

self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
return loss + cls_loss

def validation_step(
self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
) -> Tensor:
images, targets = batch[0], batch[1]
features = self.forward(images).flatten(start_dim=1)
cls_loss, cls_log = self.online_classifier.validation_step(
(features.detach(), targets), batch_idx
)
self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
return cls_loss

def configure_optimizers(self):
# Don't use weight decay for batch norm, bias parameters, and classification
# head to improve performance.
params, params_no_weight_decay = get_weight_decay_parameters(
[self.backbone, self.projection_head]
)
global_batch_size = self.batch_size_per_device * self.trainer.world_size
base_lr = _get_base_learning_rate(global_batch_size=global_batch_size)
optimizer = LARS(
[
{"name": "ssley", "params": params},
{
"name": "ssley_no_weight_decay",
"params": params_no_weight_decay,
"weight_decay": 0.0,
},
{
"name": "online_classifier",
"params": self.online_classifier.parameters(),
"weight_decay": 0.0,
},
],
            # Linear learning rate scaling: the batch-size-dependent base
            # learning rate is scaled by global_batch_size / 256.
            # See https://arxiv.org/pdf/2105.04906.pdf for details.
lr=base_lr * global_batch_size / 256,
momentum=0.9,
weight_decay=1e-6,
)
scheduler = {
"scheduler": CosineWarmupScheduler(
optimizer=optimizer,
warmup_epochs=(
self.trainer.estimated_stepping_batches
/ self.trainer.max_epochs
* 10
),
max_epochs=self.trainer.estimated_stepping_batches,
                end_value=0.01,  # Anneal the learning rate to 1% of its peak value.
),
"interval": "step",
}
return [optimizer], [scheduler]


# SSLEY transform
transform = SSLEYTransform()


def _get_base_learning_rate(global_batch_size: int) -> float:
"""Returns the base learning rate for training 100 epochs with a given batch size.

This follows section C.4 in https://arxiv.org/pdf/2105.04906.pdf.

"""
if global_batch_size == 128:
return 0.8
elif global_batch_size == 256:
return 0.5
elif global_batch_size == 512:
return 0.4
else:
return 0.3
1 change: 1 addition & 0 deletions docs/source/examples/models.rst
@@ -24,6 +24,7 @@ for PyTorch and PyTorch Lightning to give you a headstart when implementing your
simmim.rst
simsiam.rst
smog.rst
ssley.rst
swav.rst
tico.rst
vicreg.rst
48 changes: 48 additions & 0 deletions docs/source/examples/ssley.rst
@@ -0,0 +1,48 @@
.. _ssley:

SSL-EY
=======

SSL-EY is a method that explicitly avoids the collapse problem with a simple
regularization term on the variance of the embeddings along each dimension.
It inherits the model structure from
`Barlow Twins, 2021 <https://arxiv.org/abs/2103.03230>`_, changing only the
loss. This stabilizes training and leads to performance improvements.

Reference:
`Efficient Algorithms for the CCA Family: Unconstrained Objectives with Unbiased Gradients, 2023 <https://arxiv.org/abs/2310.01012>`_
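
Under the hood, the loss rewards correlation between the embeddings of the
two views while penalizing the total variance term, following the
Eckart-Young characterization derived in the paper. The snippet below is a
minimal illustrative sketch of such an objective, not the ``SSLEYLoss``
implementation shipped in this PR; in particular, the normalization and
scaling constants are an assumption and may differ from the paper.

.. code-block:: python

    import torch

    def ey_loss_sketch(z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
        # z_a, z_b: (batch_size, dim) embeddings of two views of the same images.
        n = z_a.shape[0]
        z_a = z_a - z_a.mean(dim=0)
        z_b = z_b - z_b.mean(dim=0)
        # Empirical cross-view and within-view covariance matrices.
        c_ab = z_a.T @ z_b / (n - 1)
        c_aa = z_a.T @ z_a / (n - 1)
        c_bb = z_b.T @ z_b / (n - 1)
        # Eckart-Young style objective: reward cross-view correlation and
        # penalize the squared Frobenius norm of the summed variance matrices.
        # NOTE: the constants here are illustrative assumptions.
        return -2.0 * torch.trace(c_ab) + ((c_aa + c_bb) ** 2).sum()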


.. tabs::
.. tab:: PyTorch

This example can be run from the command line with::

python lightly/examples/pytorch/ssley.py

.. literalinclude:: ../../../examples/pytorch/ssley.py

.. tab:: Lightning

This example can be run from the command line with::

python lightly/examples/pytorch_lightning/ssley.py

.. literalinclude:: ../../../examples/pytorch_lightning/ssley.py

.. tab:: Lightning Distributed

This example runs on multiple GPUs using Distributed Data Parallel (DDP)
training with PyTorch Lightning. At least one GPU must be available on
the system. The example can be run from the command line with::

python lightly/examples/pytorch_lightning_distributed/ssley.py

The model differs in the following ways from the non-distributed
implementation:

- Distributed Data Parallel is enabled
- Distributed Sampling is used in the dataloader

Distributed Sampling makes sure that each distributed process sees only
a subset of the data.
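
For reference, a minimal Lightning setup along these lines is sketched
below; ``model`` and ``dataloader`` are placeholders for the objects built
in the example script, and two GPUs are assumed.

.. code-block:: python

    from pytorch_lightning import Trainer

    # A sketch: two-GPU DDP training. Lightning wraps the model in
    # DistributedDataParallel and injects a DistributedSampler into the
    # dataloader so that each process sees only its own shard of the data.
    trainer = Trainer(
        max_epochs=10,
        devices=2,
        accelerator="gpu",
        strategy="ddp",
        sync_batchnorm=True,  # keep batch norm statistics consistent across GPUs
    )
    trainer.fit(model, dataloader)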

.. literalinclude:: ../../../examples/pytorch_lightning_distributed/ssley.py
41 changes: 41 additions & 0 deletions docs/source/getting_started/benchmarks/imagenette_benchmark.py
@@ -29,6 +29,7 @@
| SimCLR | 256 | 200 | 0.835 | 49.7 Min | 3.7 GByte |
| SimMIM (ViT-B32) | 256 | 200 | 0.315 | 115.5 Min | 9.7 GByte |
| SimSiam | 256 | 200 | 0.752 | 58.2 Min | 3.9 GByte |
| SSL-EY | 256 | 200 | TO-DO | TO-DO | TO-DO GByte |
| SwaV | 256 | 200 | 0.861 | 73.3 Min | 6.4 GByte |
| SwaVQueue | 256 | 200 | 0.827 | 72.6 Min | 6.4 GByte |
| SMoG | 256 | 200 | 0.663 | 58.7 Min | 2.6 GByte |
@@ -50,6 +51,7 @@
| SimCLR | 256 | 800 | 0.889 | 193.5 Min | 3.7 GByte |
| SimMIM (ViT-B32) | 256 | 800 | 0.343 | 446.5 Min | 9.7 GByte |
| SimSiam | 256 | 800 | 0.872 | 206.4 Min | 3.9 GByte |
| SSL-EY | 256 | 800 | TO-DO | TO-DO | TO-DO GByte |
| SwaV | 256 | 800 | 0.902 | 283.2 Min | 6.4 GByte |
| SwaVQueue | 256 | 800 | 0.890 | 282.7 Min | 6.4 GByte |
| SMoG | 256 | 800 | 0.788 | 232.1 Min | 2.6 GByte |
@@ -81,6 +83,7 @@
NegativeCosineSimilarity,
NTXentLoss,
PMSNLoss,
SSLEYLoss,
SwaVLoss,
TiCoLoss,
VICRegLLoss,
@@ -266,6 +269,7 @@ def create_dataset_train_ssl(model):
SimCLRModel: simclr_transform,
SimMIMModel: simmim_transform,
SimSiamModel: simsiam_transform,
    SSLEYModel: vicreg_transform,  # SSL-EY reuses the VICReg augmentations
SwaVModel: swav_transform,
SwaVQueueModel: swav_transform,
SMoGModel: smog_transform,
@@ -1166,6 +1170,42 @@ def configure_optimizers(self):
return [optim], [cosine_scheduler]


class SSLEYModel(BenchmarkModule):
def __init__(self, dataloader_kNN, num_classes):
super().__init__(dataloader_kNN, num_classes)
# create a ResNet backbone and remove the classification head
resnet = torchvision.models.resnet18()
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = heads.BarlowTwinsProjectionHead(512, 2048, 2048)
self.criterion = SSLEYLoss()
self.warmup_epochs = 40 if max_epochs >= 800 else 20

def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
z = self.projection_head(x)
return z

def training_step(self, batch, batch_index):
(x0, x1), _, _ = batch
z0 = self.forward(x0)
z1 = self.forward(x1)
loss = self.criterion(z0, z1)
return loss

def configure_optimizers(self):
# Training diverges without LARS
optim = LARS(
self.parameters(),
lr=0.3 * lr_factor,
weight_decay=1e-4,
momentum=0.9,
)
cosine_scheduler = scheduler.CosineWarmupScheduler(
optim, self.warmup_epochs, max_epochs
)
return [optim], [cosine_scheduler]


class VICRegModel(BenchmarkModule):
def __init__(self, dataloader_kNN, num_classes):
super().__init__(dataloader_kNN, num_classes)
@@ -1411,6 +1451,7 @@ def configure_optimizers(self):
SimCLRModel,
# SimMIMModel, # disabled by default because SimMIM uses larger images with size 224
SimSiamModel,
SSLEYModel,
SwaVModel,
SwaVQueueModel,
SMoGModel,
3 changes: 3 additions & 0 deletions docs/source/lightly.loss.rst
@@ -41,6 +41,9 @@ lightly.loss
.. autoclass:: lightly.loss.regularizer.co2.CO2Regularizer
:members:

.. autoclass:: lightly.loss.ssley_loss.SSLEYLoss
:members:

.. autoclass:: lightly.loss.swav_loss.SwaVLoss
:members:

4 changes: 4 additions & 0 deletions docs/source/lightly.transforms.rst
@@ -71,6 +71,10 @@ lightly.transforms
:members:
:special-members: __call__

.. automodule:: lightly.transforms.ssley_transform
:members:
:special-members: __call__

.. automodule:: lightly.transforms.swav_transform
:members:
:special-members: __call__
66 changes: 66 additions & 0 deletions examples/pytorch/ssley.py
@@ -0,0 +1,66 @@
import torch
import torchvision
from torch import nn

from lightly.loss.ssley_loss import SSLEYLoss
from lightly.models.modules.heads import SSLEYProjectionHead
from lightly.transforms import SSLEYTransform


class SSLEY(nn.Module):
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
self.projection_head = SSLEYProjectionHead(
input_dim=512,
hidden_dim=2048,
output_dim=2048,
num_layers=2,
)

def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
z = self.projection_head(x)
return z


resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SSLEY(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

transform = SSLEYTransform(input_size=32)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# from lightly.data import LightlyDataset
# dataset = LightlyDataset("path/to/folder", transform=transform)

dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=256,
shuffle=True,
drop_last=True,
num_workers=8,
)
criterion = SSLEYLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

print("Starting Training")
for epoch in range(10):
total_loss = 0
for batch in dataloader:
x0, x1 = batch[0]
x0 = x0.to(device)
x1 = x1.to(device)
z0 = model(x0)
z1 = model(x1)
loss = criterion(z0, z1)
total_loss += loss.detach()
loss.backward()
optimizer.step()
optimizer.zero_grad()
avg_loss = total_loss / len(dataloader)
print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")