Skip to content

Commit

Permalink
Add torchvision.transforms v1/v2 independent ToTensor() implementation (
Browse files Browse the repository at this point in the history
#1718)

* add ToTensor() implementation compatible with v2

* make version independent ToTensor available

* add test for version indep. ToTensor() implementation
  • Loading branch information
liopeer authored Nov 6, 2024
1 parent c22ee83 commit 01fdac4
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 23 deletions.
5 changes: 4 additions & 1 deletion lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@
TiCoView1Transform,
TiCoView2Transform,
)
from lightly.transforms.torchvision_transforms import torchvision_transforms
from lightly.transforms.torchvision_v2_compatibility import (
ToTensor,
torchvision_transforms,
)
from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform
from lightly.transforms.vicregl_transform import VICRegLTransform, VICRegLViewTransform
from lightly.transforms.wmse_transform import WMSETransform
22 changes: 0 additions & 22 deletions lightly/transforms/torchvision_transforms.py

This file was deleted.

49 changes: 49 additions & 0 deletions lightly/transforms/torchvision_v2_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#
# Copyright (c) Lightly AG and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from typing import Union

import torch
from PIL.Image import Image
from torch import Tensor
from torchvision.transforms import ToTensor as ToTensorV1

try:
from torchvision.transforms import v2 as torchvision_transforms

_TRANSFORMS_V2 = True

except ImportError:
from torchvision import transforms as torchvision_transforms

_TRANSFORMS_V2 = False


def ToTensor() -> Union[torchvision_transforms.Compose, ToTensorV1]:
"""Convert a PIL Image to a tensor with value normalization, similar to [0].
This implementation is required since `torchvision.transforms.v2.ToTensor` is
deprecated and will be removed in the future (see [1]).
Input to this transform:
PIL Image (H x W x C) of uint8 type in range [0,255]
Output of this transform:
torch.Tensor (C x H x W) of type torch.float32 in range [0.0, 1.0]
- [0] https://pytorch.org/vision/main/generated/torchvision.transforms.ToTensor.html
- [1] https://pytorch.org/vision/0.20/generated/torchvision.transforms.v2.ToTensor.html?highlight=totensor#torchvision.transforms.v2.ToTensor
"""
T = torchvision_transforms
if _TRANSFORMS_V2 and hasattr(T, "ToImage") and hasattr(T, "ToDtype"):
# v2.transforms.ToTensor is deprecated and will be removed in the future.
# This is the new recommended way to convert a PIL Image to a tensor since
# torchvision v0.16.
# See also https://github.com/pytorch/vision/blame/33e47d88265b2d57c2644aad1425be4fccd64605/torchvision/transforms/v2/_deprecated.py#L19
return T.Compose([T.ToImage(), T.ToDtype(dtype=torch.float32, scale=True)])
else:
return T.ToTensor()
16 changes: 16 additions & 0 deletions tests/transforms/test_torchvision_v2compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import numpy as np
import torch
from PIL import Image
from torch import Tensor

from lightly.transforms import ToTensor


def test_ToTensor() -> None:
img_np = np.random.randint(0, 255, (20, 30, 3), dtype=np.uint8)
img_pil = Image.fromarray(img_np)
img_tens = ToTensor()(img_pil)
assert isinstance(img_tens, Tensor)
assert img_tens.shape == (3, 20, 30)
assert img_tens.dtype == torch.float32
assert img_tens.max() <= 1.0 and img_tens.min() >= 0.0

0 comments on commit 01fdac4

Please sign in to comment.