Skip to content

Commit

Permalink
Merge branch 'master' into phase_shift_transform
Browse files Browse the repository at this point in the history
  • Loading branch information
guarin authored Nov 7, 2024
2 parents 1c40c24 + 3b7357e commit 1afd12d
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 2 deletions.
2 changes: 1 addition & 1 deletion benchmarks/imagenet/resnet50/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ See [simclr.py](simclr.py).
Instead, we try to provide building blocks and examples to make it as easy as possible to
build on top of existing SSL methods.

You can find benchmark resuls in our [docs](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html).
You can find benchmark results in our [docs](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html).

## Run Benchmark

Expand Down
2 changes: 1 addition & 1 deletion lightly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
# All Rights Reserved

__name__ = "lightly"
__version__ = "1.5.13"
__version__ = "1.5.14"


import os
Expand Down
4 changes: 4 additions & 0 deletions lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
TiCoView1Transform,
TiCoView2Transform,
)
from lightly.transforms.torchvision_v2_compatibility import (
ToTensor,
torchvision_transforms,
)
from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform
from lightly.transforms.vicregl_transform import VICRegLTransform, VICRegLViewTransform
from lightly.transforms.wmse_transform import WMSETransform
49 changes: 49 additions & 0 deletions lightly/transforms/torchvision_v2_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#
# Copyright (c) Lightly AG and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from typing import Union

import torch
from PIL.Image import Image
from torch import Tensor
from torchvision.transforms import ToTensor as ToTensorV1

try:
from torchvision.transforms import v2 as torchvision_transforms

_TRANSFORMS_V2 = True

except ImportError:
from torchvision import transforms as torchvision_transforms

_TRANSFORMS_V2 = False


def ToTensor() -> Union[torchvision_transforms.Compose, ToTensorV1]:
"""Convert a PIL Image to a tensor with value normalization, similar to [0].
This implementation is required since `torchvision.transforms.v2.ToTensor` is
deprecated and will be removed in the future (see [1]).
Input to this transform:
PIL Image (H x W x C) of uint8 type in range [0,255]
Output of this transform:
torch.Tensor (C x H x W) of type torch.float32 in range [0.0, 1.0]
- [0] https://pytorch.org/vision/main/generated/torchvision.transforms.ToTensor.html
- [1] https://pytorch.org/vision/0.20/generated/torchvision.transforms.v2.ToTensor.html?highlight=totensor#torchvision.transforms.v2.ToTensor
"""
T = torchvision_transforms
if _TRANSFORMS_V2 and hasattr(T, "ToImage") and hasattr(T, "ToDtype"):
# v2.transforms.ToTensor is deprecated and will be removed in the future.
# This is the new recommended way to convert a PIL Image to a tensor since
# torchvision v0.16.
# See also https://github.com/pytorch/vision/blame/33e47d88265b2d57c2644aad1425be4fccd64605/torchvision/transforms/v2/_deprecated.py#L19
return T.Compose([T.ToImage(), T.ToDtype(dtype=torch.float32, scale=True)])
else:
return T.ToTensor()
16 changes: 16 additions & 0 deletions tests/transforms/test_torchvision_v2compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import numpy as np
import torch
from PIL import Image
from torch import Tensor

from lightly.transforms import ToTensor


def test_ToTensor() -> None:
img_np = np.random.randint(0, 255, (20, 30, 3), dtype=np.uint8)
img_pil = Image.fromarray(img_np)
img_tens = ToTensor()(img_pil)
assert isinstance(img_tens, Tensor)
assert img_tens.shape == (3, 20, 30)
assert img_tens.dtype == torch.float32
assert img_tens.max() <= 1.0 and img_tens.min() >= 0.0

0 comments on commit 1afd12d

Please sign in to comment.