Skip to content

Commit

Permalink
Use VICRegProjectionHead in examples
Browse files Browse the repository at this point in the history
  • Loading branch information
philippmwirth committed Sep 15, 2023
1 parent eaa943b commit f56c64a
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
9 changes: 7 additions & 2 deletions examples/pytorch/vicreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@

## The projection head is the same as the Barlow Twins one
from lightly.loss.vicreg_loss import VICRegLoss
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.models.modules.heads import VICRegProjectionHead
from lightly.transforms.vicreg_transform import VICRegTransform


class VICReg(nn.Module):
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
self.projection_head = VICRegProjectionHead(
input_dim=512,
hidden_dim=2048,
output_dim=2048,
num_layers=2,
)

def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
Expand Down
9 changes: 7 additions & 2 deletions examples/pytorch_lightning/vicreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lightly.loss.vicreg_loss import VICRegLoss

## The projection head is the same as the Barlow Twins one
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.models.modules.heads import VICRegProjectionHead
from lightly.transforms.vicreg_transform import VICRegTransform


Expand All @@ -19,7 +19,12 @@ def __init__(self):
super().__init__()
resnet = torchvision.models.resnet18()
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
self.projection_head = VICRegProjectionHead(
input_dim=512,
hidden_dim=2048,
output_dim=2048,
num_layers=2,
)
self.criterion = VICRegLoss()

def forward(self, x):
Expand Down
9 changes: 7 additions & 2 deletions examples/pytorch_lightning_distributed/vicreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lightly.loss import VICRegLoss

## The projection head is the same as the Barlow Twins one
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.models.modules.heads import VICRegProjectionHead
from lightly.transforms.vicreg_transform import VICRegTransform


Expand All @@ -19,7 +19,12 @@ def __init__(self):
super().__init__()
resnet = torchvision.models.resnet18()
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
self.projection_head = VICRegProjectionHead(
input_dim=512,
hidden_dim=2048,
output_dim=2048,
num_layers=2,
)

# enable gather_distributed to gather features from all gpus
# before calculating the loss
Expand Down

0 comments on commit f56c64a

Please sign in to comment.