diff --git a/benchmarks/imagenet/resnet50/vicreg.py b/benchmarks/imagenet/resnet50/vicreg.py index 2cd0d4243..0da61677e 100644 --- a/benchmarks/imagenet/resnet50/vicreg.py +++ b/benchmarks/imagenet/resnet50/vicreg.py @@ -24,7 +24,6 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None: resnet = resnet50() resnet.fc = Identity() # Ignore classification head self.backbone = resnet - # VICReg uses Barlow Twins projection head. self.projection_head = VICRegProjectionHead(num_layers=2) self.criterion = VICRegLoss()