Skip to content

Commit

Permalink
Add examples (vanilla PyTorch, PyTorch Lightning, and PyTorch Lightni…
Browse files Browse the repository at this point in the history
…ng Distributed) (#1480)

* Add MMCR examples
* Add MMCR docs page
* Add MMCRTransform to docs
* Add MMCRLoss to docs
  • Loading branch information
johnsutor authored Jan 24, 2024
1 parent 87be5a1 commit b7f4ae2
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 0 deletions.
48 changes: 48 additions & 0 deletions docs/source/examples/mmcr.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
.. _mmcr:

MMCR
====

Example implementation of the MMCR architecture.

Reference:
`Learning Efficient Coding of Natural Images with Maximum Manifold Capacity Representations, 2023 <https://arxiv.org/abs/2303.03307>`_


.. tabs::

.. tab:: PyTorch

This example can be run from the command line with::

python lightly/examples/pytorch/mmcr.py

.. literalinclude:: ../../../examples/pytorch/mmcr.py

.. tab:: Lightning

This example can be run from the command line with::

python lightly/examples/pytorch_lightning/mmcr.py

.. literalinclude:: ../../../examples/pytorch_lightning/mmcr.py

.. tab:: Lightning Distributed

This example runs on multiple gpus using Distributed Data Parallel (DDP)
training with Pytorch Lightning. At least one GPU must be available on
the system. The example can be run from the command line with::

python lightly/examples/pytorch_lightning_distributed/mmcr.py

The model differs in the following ways from the non-distributed
implementation:

- Distributed Data Parallel is enabled
- Synchronized Batch Norm is used in place of standard Batch Norm

Note that Synchronized Batch Norm is optional and the model can also be
trained without it. Without Synchronized Batch Norm the batch norm for
each GPU is only calculated based on the features on that specific GPU.

.. literalinclude:: ../../../examples/pytorch_lightning_distributed/mmcr.py
1 change: 1 addition & 0 deletions docs/source/examples/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ for PyTorch and PyTorch Lightning to give you a headstart when implementing your
dino.rst
fastsiam.rst
mae.rst
mmcr.rst
msn.rst
moco.rst
nnclr.rst
Expand Down
3 changes: 3 additions & 0 deletions docs/source/lightly.loss.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ lightly.loss
.. autoclass:: lightly.loss.memory_bank.MemoryBankModule
:members:

.. autoclass:: lightly.loss.mmcr_loss.MMCRLoss
:members:

.. autoclass:: lightly.loss.msn_loss.MSNLoss
:members:

Expand Down
4 changes: 4 additions & 0 deletions docs/source/lightly.transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ lightly.transforms
:members:
:special-members: __call__

.. automodule:: lightly.transforms.mmcr_transform
:members:
:special-members: __call__

.. automodule:: lightly.transforms.moco_transform
:members:
:special-members: __call__
Expand Down
92 changes: 92 additions & 0 deletions examples/pytorch/mmcr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import copy

import torch
import torchvision
from torch import nn

from lightly.loss import MMCRLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.mmcr_transform import MMCRTransform
from lightly.utils.scheduler import cosine_schedule


class MMCR(nn.Module):
def __init__(self, backbone):
super().__init__()

self.backbone = backbone
self.projection_head = SimCLRProjectionHead(512, 512, 128)

self.backbone_momentum = copy.deepcopy(self.backbone)
self.projection_head_momentum = copy.deepcopy(self.projection_head)

deactivate_requires_grad(self.backbone_momentum)
deactivate_requires_grad(self.projection_head_momentum)

def forward(self, x):
y = self.backbone(x).flatten(start_dim=1)
z = self.projection_head(y)
return z

def forward_momentum(self, x):
y = self.backbone_momentum(x).flatten(start_dim=1)
z = self.projection_head_momentum(y)
z = z.detach()
return z


resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = MMCR(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

transform = MMCRTransform(k=8, input_size=32, gaussian_blur=0.0)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)

dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=256,
shuffle=True,
drop_last=True,
num_workers=8,
)

criterion = MMCRLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

epochs = 10

print("Starting Training")
for epoch in range(epochs):
total_loss = 0
momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
for batch in dataloader:
update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
update_momentum(
model.projection_head, model.projection_head_momentum, m=momentum_val
)
z_o = [model(x.to(device)) for x in batch[0]]
z_m = [model.forward_momentum(x.to(device)) for x in batch[0]]

# Switch dimensions to (batch_size, k, embedding_size)
z_o = torch.stack(z_o, dim=1)
z_m = torch.stack(z_m, dim=1)

loss = criterion(z_o, z_m)
total_loss += loss.detach()
loss.backward()
optimizer.step()
optimizer.zero_grad()
avg_loss = total_loss / len(dataloader)
print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
84 changes: 84 additions & 0 deletions examples/pytorch_lightning/mmcr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import copy

import pytorch_lightning as pl
import torch
import torchvision
from torch import nn

from lightly.loss import MMCRLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.mmcr_transform import MMCRTransform
from lightly.utils.scheduler import cosine_schedule


class MMCR(pl.LightningModule):
def __init__(self):
super().__init__()
resnet = torchvision.models.resnet18()
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = SimCLRProjectionHead(512, 512, 128)

self.backbone_momentum = copy.deepcopy(self.backbone)
self.projection_head_momentum = copy.deepcopy(self.projection_head)

deactivate_requires_grad(self.backbone_momentum)
deactivate_requires_grad(self.projection_head_momentum)

self.criterion = MMCRLoss()

def forward(self, x):
y = self.backbone(x).flatten(start_dim=1)
z = self.projection_head(y)
return z

def forward_momentum(self, x):
y = self.backbone_momentum(x).flatten(start_dim=1)
z = self.projection_head_momentum(y)
z = z.detach()
return z

def training_step(self, batch, batch_idx):
momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)
update_momentum(self.backbone, self.backbone_momentum, m=momentum)
update_momentum(self.projection_head, self.projection_head_momentum, m=momentum)
z_o = [model(x) for x in batch[0]]
z_m = [model.forward_momentum(x) for x in batch[0]]

# Switch dimensions to (batch_size, k, embedding_size)
z_o = torch.stack(z_o, dim=1)
z_m = torch.stack(z_m, dim=1)

loss = self.criterion(z_o, z_m)
return loss

def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=0.06)


model = MMCR()

# We disable resizing and gaussian blur for cifar10.
transform = MMCRTransform(k=8, input_size=32, gaussian_blur=0.0)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)

dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=256,
shuffle=True,
drop_last=True,
num_workers=8,
)

accelerator = "gpu" if torch.cuda.is_available() else "cpu"

trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)
92 changes: 92 additions & 0 deletions examples/pytorch_lightning_distributed/mmcr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import copy

import pytorch_lightning as pl
import torch
import torchvision
from torch import nn

from lightly.loss import MMCRLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.mmcr_transform import MMCRTransform
from lightly.utils.scheduler import cosine_schedule


class MMCR(pl.LightningModule):
def __init__(self):
super().__init__()
resnet = torchvision.models.resnet18()
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = SimCLRProjectionHead(512, 512, 128)

self.backbone_momentum = copy.deepcopy(self.backbone)
self.projection_head_momentum = copy.deepcopy(self.projection_head)

deactivate_requires_grad(self.backbone_momentum)
deactivate_requires_grad(self.projection_head_momentum)

self.criterion = MMCRLoss()

def forward(self, x):
y = self.backbone(x).flatten(start_dim=1)
z = self.projection_head(y)
return z

def forward_momentum(self, x):
y = self.backbone_momentum(x).flatten(start_dim=1)
z = self.projection_head_momentum(y)
z = z.detach()
return z

def training_step(self, batch, batch_idx):
momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)
update_momentum(self.backbone, self.backbone_momentum, m=momentum)
update_momentum(self.projection_head, self.projection_head_momentum, m=momentum)
z_o = [model(x) for x in batch[0]]
z_m = [model.forward_momentum(x) for x in batch[0]]

# Switch dimensions to (batch_size, k, embedding_size)
z_o = torch.stack(z_o, dim=1)
z_m = torch.stack(z_m, dim=1)

loss = self.criterion(z_o, z_m)
return loss

def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=0.06)


model = MMCR()

# We disable resizing and gaussian blur for cifar10.
transform = MMCRTransform(k=8, input_size=32, gaussian_blur=0.0)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)

dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=256,
shuffle=True,
drop_last=True,
num_workers=8,
)

# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm
# calculation. Distributed sampling is also enabled with replace_sampler_ddp=True.
if __name__ == "__main__":
trainer = pl.Trainer(
max_epochs=10,
devices="auto",
accelerator="gpu",
strategy="ddp",
sync_batchnorm=True,
use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)

0 comments on commit b7f4ae2

Please sign in to comment.