Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add torchvision.transforms v1/v2 independent ToTensor() implementation #1718

Merged
merged 7 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@
TiCoView1Transform,
TiCoView2Transform,
)
from lightly.transforms.torchvision_transforms import torchvision_transforms
from lightly.transforms.torchvision_v2compatibility import (
ToTensor,
torchvision_transforms,
)
from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform
from lightly.transforms.vicregl_transform import VICRegLTransform, VICRegLViewTransform
from lightly.transforms.wmse_transform import WMSETransform
22 changes: 0 additions & 22 deletions lightly/transforms/torchvision_transforms.py

This file was deleted.

49 changes: 49 additions & 0 deletions lightly/transforms/torchvision_v2compatibility.py
liopeer marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#
# Copyright (c) Lightly AG and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from typing import Union

import torch
from PIL.Image import Image
from torch import Tensor
from torchvision.transforms import ToTensor as ToTensorV1

try:
from torchvision.transforms import v2 as torchvision_transforms

_TRANSFORMS_V2 = True

except ImportError:
from torchvision import transforms as torchvision_transforms

Check warning on line 21 in lightly/transforms/torchvision_v2compatibility.py

View check run for this annotation

Codecov / codecov/patch

lightly/transforms/torchvision_v2compatibility.py#L20-L21

Added lines #L20 - L21 were not covered by tests

_TRANSFORMS_V2 = False

Check warning on line 23 in lightly/transforms/torchvision_v2compatibility.py

View check run for this annotation

Codecov / codecov/patch

lightly/transforms/torchvision_v2compatibility.py#L23

Added line #L23 was not covered by tests


def ToTensor() -> Union[torchvision_transforms.Compose, ToTensorV1]:
"""Convert a PIL Image to a tensor with value normalization, similar to [0].

This implementation is required since `torchvision.transforms.v2.ToTensor` is
deprecated and will be removed in the future (see [1]).

Input to this transform:
PIL Image (H x W x C) of uint8 type in range [0,255]

Output of this transform:
torch.Tensor (C x H x W) of type torch.float32 in range [0.0, 1.0]

- [0] https://pytorch.org/vision/main/generated/torchvision.transforms.ToTensor.html
- [1] https://pytorch.org/vision/0.20/generated/torchvision.transforms.v2.ToTensor.html?highlight=totensor#torchvision.transforms.v2.ToTensor
"""
T = torchvision_transforms
if _TRANSFORMS_V2 and hasattr(T, "ToImage") and hasattr(T, "ToDtype"):
# v2.transforms.ToTensor is deprecated and will be removed in the future.
# This is the new recommended way to convert a PIL Image to a tensor since
# torchvision v0.16.
# See also https://github.com/pytorch/vision/blame/33e47d88265b2d57c2644aad1425be4fccd64605/torchvision/transforms/v2/_deprecated.py#L19
return T.Compose([T.ToImage(), T.ToDtype(dtype=torch.float32, scale=True)])
else:
return T.ToTensor()

Check warning on line 49 in lightly/transforms/torchvision_v2compatibility.py

View check run for this annotation

Codecov / codecov/patch

lightly/transforms/torchvision_v2compatibility.py#L49

Added line #L49 was not covered by tests
16 changes: 16 additions & 0 deletions tests/transforms/test_torchvision_v2compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import numpy as np
import torch
from PIL import Image
from torch import Tensor

from lightly.transforms import ToTensor


def test_ToTensor() -> None:
img_np = np.random.randint(0, 255, (20, 30, 3), dtype=np.uint8)
img_pil = Image.fromarray(img_np)
img_tens = ToTensor()(img_pil)
assert isinstance(img_tens, Tensor)
assert img_tens.shape == (3, 20, 30)
assert img_tens.dtype == torch.float32
assert img_tens.max() <= 1.0 and img_tens.min() >= 0.0