-
Notifications
You must be signed in to change notification settings - Fork 286
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add torchvision.transforms v1/v2 independent ToTensor() implementation (
#1718) * add ToTensor() implementation compatible with v2 * make version independent ToTensor available * add test for version indep. ToTensor() implementation
- Loading branch information
Showing
4 changed files
with
69 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# | ||
# Copyright (c) Lightly AG and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
from typing import Union | ||
|
||
import torch | ||
from PIL.Image import Image | ||
from torch import Tensor | ||
from torchvision.transforms import ToTensor as ToTensorV1 | ||
|
||
try: | ||
from torchvision.transforms import v2 as torchvision_transforms | ||
|
||
_TRANSFORMS_V2 = True | ||
|
||
except ImportError: | ||
from torchvision import transforms as torchvision_transforms | ||
|
||
_TRANSFORMS_V2 = False | ||
|
||
|
||
def ToTensor() -> Union[torchvision_transforms.Compose, ToTensorV1]: | ||
"""Convert a PIL Image to a tensor with value normalization, similar to [0]. | ||
This implementation is required since `torchvision.transforms.v2.ToTensor` is | ||
deprecated and will be removed in the future (see [1]). | ||
Input to this transform: | ||
PIL Image (H x W x C) of uint8 type in range [0,255] | ||
Output of this transform: | ||
torch.Tensor (C x H x W) of type torch.float32 in range [0.0, 1.0] | ||
- [0] https://pytorch.org/vision/main/generated/torchvision.transforms.ToTensor.html | ||
- [1] https://pytorch.org/vision/0.20/generated/torchvision.transforms.v2.ToTensor.html?highlight=totensor#torchvision.transforms.v2.ToTensor | ||
""" | ||
T = torchvision_transforms | ||
if _TRANSFORMS_V2 and hasattr(T, "ToImage") and hasattr(T, "ToDtype"): | ||
# v2.transforms.ToTensor is deprecated and will be removed in the future. | ||
# This is the new recommended way to convert a PIL Image to a tensor since | ||
# torchvision v0.16. | ||
# See also https://github.com/pytorch/vision/blame/33e47d88265b2d57c2644aad1425be4fccd64605/torchvision/transforms/v2/_deprecated.py#L19 | ||
return T.Compose([T.ToImage(), T.ToDtype(dtype=torch.float32, scale=True)]) | ||
else: | ||
return T.ToTensor() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import numpy as np | ||
import torch | ||
from PIL import Image | ||
from torch import Tensor | ||
|
||
from lightly.transforms import ToTensor | ||
|
||
|
||
def test_ToTensor() -> None: | ||
img_np = np.random.randint(0, 255, (20, 30, 3), dtype=np.uint8) | ||
img_pil = Image.fromarray(img_np) | ||
img_tens = ToTensor()(img_pil) | ||
assert isinstance(img_tens, Tensor) | ||
assert img_tens.shape == (3, 20, 30) | ||
assert img_tens.dtype == torch.float32 | ||
assert img_tens.max() <= 1.0 and img_tens.min() >= 0.0 |