From f56c64a791a0d9c47d530d257e775ca1bd0629a9 Mon Sep 17 00:00:00 2001
From: philippmwirth
Date: Fri, 15 Sep 2023 14:11:35 +0000
Subject: [PATCH] Use VICRegProjectionHead in examples

---
 examples/pytorch/vicreg.py                       | 9 +++++++--
 examples/pytorch_lightning/vicreg.py             | 9 +++++++--
 examples/pytorch_lightning_distributed/vicreg.py | 9 +++++++--
 3 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/examples/pytorch/vicreg.py b/examples/pytorch/vicreg.py
index a3221846c..39968cb08 100644
--- a/examples/pytorch/vicreg.py
+++ b/examples/pytorch/vicreg.py
@@ -7,7 +7,7 @@
 ## The projection head is the same as the Barlow Twins one
 from lightly.loss.vicreg_loss import VICRegLoss
-from lightly.models.modules import BarlowTwinsProjectionHead
+from lightly.models.modules.heads import VICRegProjectionHead
 from lightly.transforms.vicreg_transform import VICRegTransform
@@ -15,7 +15,12 @@ class VICReg(nn.Module):
     def __init__(self, backbone):
         super().__init__()
         self.backbone = backbone
-        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
+        self.projection_head = VICRegProjectionHead(
+            input_dim=512,
+            hidden_dim=2048,
+            output_dim=2048,
+            num_layers=2,
+        )

     def forward(self, x):
         x = self.backbone(x).flatten(start_dim=1)
diff --git a/examples/pytorch_lightning/vicreg.py b/examples/pytorch_lightning/vicreg.py
index 5fa5c89fd..2770309e7 100644
--- a/examples/pytorch_lightning/vicreg.py
+++ b/examples/pytorch_lightning/vicreg.py
@@ -10,7 +10,7 @@
 from lightly.loss.vicreg_loss import VICRegLoss
 ## The projection head is the same as the Barlow Twins one
-from lightly.models.modules import BarlowTwinsProjectionHead
+from lightly.models.modules.heads import VICRegProjectionHead
 from lightly.transforms.vicreg_transform import VICRegTransform
@@ -19,7 +19,12 @@ def __init__(self):
         super().__init__()
         resnet = torchvision.models.resnet18()
         self.backbone = nn.Sequential(*list(resnet.children())[:-1])
-        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
+        self.projection_head = VICRegProjectionHead(
+            input_dim=512,
+            hidden_dim=2048,
+            output_dim=2048,
+            num_layers=2,
+        )
         self.criterion = VICRegLoss()

     def forward(self, x):
diff --git a/examples/pytorch_lightning_distributed/vicreg.py b/examples/pytorch_lightning_distributed/vicreg.py
index a30c726e5..bc8aa38ad 100644
--- a/examples/pytorch_lightning_distributed/vicreg.py
+++ b/examples/pytorch_lightning_distributed/vicreg.py
@@ -10,7 +10,7 @@
 from lightly.loss import VICRegLoss
 ## The projection head is the same as the Barlow Twins one
-from lightly.models.modules import BarlowTwinsProjectionHead
+from lightly.models.modules.heads import VICRegProjectionHead
 from lightly.transforms.vicreg_transform import VICRegTransform
@@ -19,7 +19,12 @@ def __init__(self):
         super().__init__()
         resnet = torchvision.models.resnet18()
         self.backbone = nn.Sequential(*list(resnet.children())[:-1])
-        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
+        self.projection_head = VICRegProjectionHead(
+            input_dim=512,
+            hidden_dim=2048,
+            output_dim=2048,
+            num_layers=2,
+        )
         # enable gather_distributed to gather features from all gpus
         # before calculating the loss
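
Note: the hunks above only show the constructor swap from BarlowTwinsProjectionHead to VICRegProjectionHead. As a quick sanity check of the new head, it can be exercised standalone roughly as sketched below. The imports, the constructor arguments, and the ResNet-18 backbone are taken from the patch; the random input tensors standing in for VICRegTransform output and the two-view call to VICRegLoss are assumptions for illustration, not part of the patch.

    import torch
    import torchvision
    from torch import nn

    from lightly.loss.vicreg_loss import VICRegLoss
    from lightly.models.modules.heads import VICRegProjectionHead

    # Backbone as in the examples: ResNet-18 without the classification
    # layer, leaving a 512-dimensional feature vector per image.
    resnet = torchvision.models.resnet18()
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    # Projection head as introduced by this patch.
    projection_head = VICRegProjectionHead(
        input_dim=512,
        hidden_dim=2048,
        output_dim=2048,
        num_layers=2,
    )

    criterion = VICRegLoss()

    # Two augmented views of a batch (random tensors stand in for the
    # output of VICRegTransform here).
    x0 = torch.randn(8, 3, 224, 224)
    x1 = torch.randn(8, 3, 224, 224)

    z0 = projection_head(backbone(x0).flatten(start_dim=1))
    z1 = projection_head(backbone(x1).flatten(start_dim=1))

    loss = criterion(z0, z1)
    print(loss.item())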