Skip to content

Commit

Permalink
allow to check v2 transforms availability from anywhere (#1730)
Browse files Browse the repository at this point in the history
* allow access to v2 transforms availability from anywhere

* import unnecessary after exception

Co-authored-by: guarin <[email protected]>

* reformat

---------

Co-authored-by: guarin <[email protected]>
  • Loading branch information
liopeer and guarin authored Nov 12, 2024
1 parent c5a749f commit c37ecf4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
7 changes: 4 additions & 3 deletions lightly/transforms/torchvision_v2_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from torch import Tensor
from torchvision.transforms import ToTensor as ToTensorV1

try:
from lightly.utils import dependency as _dependency

if _dependency.torchvision_transforms_v2_available():
from torchvision.transforms import v2 as torchvision_transforms

_TRANSFORMS_V2 = True

except ImportError:
else:
from torchvision import transforms as torchvision_transforms

_TRANSFORMS_V2 = False
Expand Down
14 changes: 14 additions & 0 deletions lightly/utils/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,17 @@ def timm_vit_available() -> bool:
except ImportError:
return False
return True


@functools.lru_cache(maxsize=1)
def torchvision_transforms_v2_available() -> bool:
"""Checks if torchvision supports the transforms.v2 API.
Returns:
True if transforms.v2 are available, False otherwise
"""
try:
from torchvision.transforms import v2
except ImportError:
return False
return True

0 comments on commit c37ecf4

Please sign in to comment.