diff --git a/benchmarks/imagenet/resnet50/README.md b/benchmarks/imagenet/resnet50/README.md
index 9f1a32c89..a5173a056 100644
--- a/benchmarks/imagenet/resnet50/README.md
+++ b/benchmarks/imagenet/resnet50/README.md
@@ -19,7 +19,7 @@ See [simclr.py](simclr.py).
 
 Instead, we try to provide building blocks and examples to make it as easy as
 possible to build on top of existing SSL methods.
 
-You can find benchmark resuls in our [docs](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html).
+You can find benchmark results in our [docs](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html).
 
 ## Run Benchmark
diff --git a/lightly/__init__.py b/lightly/__init__.py
index e90552f27..1f856ba0d 100644
--- a/lightly/__init__.py
+++ b/lightly/__init__.py
@@ -75,7 +75,7 @@
 # All Rights Reserved
 
 __name__ = "lightly"
-__version__ = "1.5.13"
+__version__ = "1.5.14"
 
 import os
diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py
index 02258f35d..422b64cac 100644
--- a/lightly/transforms/__init__.py
+++ b/lightly/transforms/__init__.py
@@ -47,6 +47,10 @@
     TiCoView1Transform,
     TiCoView2Transform,
 )
+from lightly.transforms.torchvision_v2_compatibility import (
+    ToTensor,
+    torchvision_transforms,
+)
 from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform
 from lightly.transforms.vicregl_transform import VICRegLTransform, VICRegLViewTransform
 from lightly.transforms.wmse_transform import WMSETransform
diff --git a/lightly/transforms/torchvision_v2_compatibility.py b/lightly/transforms/torchvision_v2_compatibility.py
new file mode 100644
index 000000000..edcc81399
--- /dev/null
+++ b/lightly/transforms/torchvision_v2_compatibility.py
@@ -0,0 +1,49 @@
+#
+# Copyright (c) Lightly AG and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+from typing import Union
+
+import torch
+from PIL.Image import Image
+from torch import Tensor
+from torchvision.transforms import ToTensor as ToTensorV1
+
+try:
+    from torchvision.transforms import v2 as torchvision_transforms
+
+    _TRANSFORMS_V2 = True
+
+except ImportError:
+    from torchvision import transforms as torchvision_transforms
+
+    _TRANSFORMS_V2 = False
+
+
+def ToTensor() -> Union[torchvision_transforms.Compose, ToTensorV1]:
+    """Convert a PIL Image to a tensor with value normalization, similar to [0].
+
+    This implementation is required since `torchvision.transforms.v2.ToTensor` is
+    deprecated and will be removed in the future (see [1]).
+
+    Input to this transform:
+        PIL Image (H x W x C) of uint8 type in range [0,255]
+
+    Output of this transform:
+        torch.Tensor (C x H x W) of type torch.float32 in range [0.0, 1.0]
+
+    - [0] https://pytorch.org/vision/main/generated/torchvision.transforms.ToTensor.html
+    - [1] https://pytorch.org/vision/0.20/generated/torchvision.transforms.v2.ToTensor.html?highlight=totensor#torchvision.transforms.v2.ToTensor
+    """
+    T = torchvision_transforms
+    if _TRANSFORMS_V2 and hasattr(T, "ToImage") and hasattr(T, "ToDtype"):
+        # v2.transforms.ToTensor is deprecated and will be removed in the future.
+        # This is the new recommended way to convert a PIL Image to a tensor since
+        # torchvision v0.16.
+        # See also https://github.com/pytorch/vision/blame/33e47d88265b2d57c2644aad1425be4fccd64605/torchvision/transforms/v2/_deprecated.py#L19
+        return T.Compose([T.ToImage(), T.ToDtype(dtype=torch.float32, scale=True)])
+    else:
+        return T.ToTensor()
diff --git a/tests/transforms/test_torchvision_v2compatibility.py b/tests/transforms/test_torchvision_v2compatibility.py
new file mode 100644
index 000000000..f126ef581
--- /dev/null
+++ b/tests/transforms/test_torchvision_v2compatibility.py
@@ -0,0 +1,16 @@
+import numpy as np
+import torch
+from PIL import Image
+from torch import Tensor
+
+from lightly.transforms import ToTensor
+
+
+def test_ToTensor() -> None:
+    img_np = np.random.randint(0, 255, (20, 30, 3), dtype=np.uint8)
+    img_pil = Image.fromarray(img_np)
+    img_tens = ToTensor()(img_pil)
+    assert isinstance(img_tens, Tensor)
+    assert img_tens.shape == (3, 20, 30)
+    assert img_tens.dtype == torch.float32
+    assert img_tens.max() <= 1.0 and img_tens.min() >= 0.0
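Note: below is a minimal usage sketch of the new compatibility helper, not part of the patch. It assumes only what the diff above adds, namely that lightly.transforms exports ToTensor and torchvision_transforms; the Normalize statistics are arbitrary example values (the common ImageNet constants).

import numpy as np
from PIL import Image

from lightly.transforms import ToTensor, torchvision_transforms as T

# ToTensor() resolves to Compose([ToImage(), ToDtype(float32, scale=True)]) when
# torchvision v2 transforms are available, and to the classic transforms.ToTensor
# otherwise. Either way it maps a uint8 PIL image (H x W x C) to a float32 tensor
# (C x H x W) in [0.0, 1.0], so it can be chained with Normalize as usual.
transform = T.Compose(
    [
        ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

img = Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8))
x = transform(img)
print(x.shape, x.dtype)  # torch.Size([3, 32, 32]) torch.float32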