From 658dd83b208608d0f646b4167c6b3af1705031ae Mon Sep 17 00:00:00 2001 From: Vladislav Tumko <56307628+vectorvp@users.noreply.github.com> Date: Sat, 30 Nov 2024 06:31:37 +0700 Subject: [PATCH] refactor: update typings --- lightly/models/modules/ijepa_timm.py | 11 ++++++----- lightly/models/utils.py | 6 +++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/lightly/models/modules/ijepa_timm.py b/lightly/models/modules/ijepa_timm.py index af254c639..c726851ee 100644 --- a/lightly/models/modules/ijepa_timm.py +++ b/lightly/models/modules/ijepa_timm.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn from timm.models.vision_transformer import Block +from torch import Tensor from lightly.models import utils @@ -58,7 +59,7 @@ def __init__( drop_path_rate: float = 0.0, proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, - norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): """Initializes the IJEPAPredictorTIMM with the specified dimensions.""" @@ -97,10 +98,10 @@ def __init__( def forward( self, - x: torch.Tensor, - masks_x: list[torch.Tensor] | torch.Tensor, - masks: list[torch.Tensor] | torch.Tensor, - ) -> torch.Tensor: + x: Tensor, + masks_x: list[Tensor] | Tensor, + masks: list[Tensor] | Tensor, + ) -> Tensor: """Forward pass of the IJEPAPredictorTIMM. Args: diff --git a/lightly/models/utils.py b/lightly/models/utils.py index 3a4744ff7..3c65a6fce 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -1263,7 +1263,7 @@ def update_drop_path_rate( block.drop_path2 = Identity() -def repeat_interleave_batch(x: torch.Tensor, B: int, repeat: int) -> torch.Tensor: +def repeat_interleave_batch(x: Tensor, B: int, repeat: int) -> Tensor: """Repeat and interleave the input tensor.""" N = len(x) // B x = torch.cat( @@ -1277,8 +1277,8 @@ def repeat_interleave_batch(x: torch.Tensor, B: int, repeat: int) -> torch.Tenso def apply_masks( - x: torch.Tensor, masks: torch.Tensor | list[torch.Tensor] -) -> torch.Tensor: + x: Tensor, masks: Tensor | list[Tensor] +) -> Tensor: """Apply masks to the input tensor. From https://github.com/facebookresearch/ijepa/blob/main/src/masks/utils.py