diff --git a/lightly/models/modules/ijepa_timm.py b/lightly/models/modules/ijepa_timm.py index c726851ee..372acc6c7 100644 --- a/lightly/models/modules/ijepa_timm.py +++ b/lightly/models/modules/ijepa_timm.py @@ -11,7 +11,8 @@ from lightly.models import utils -class IJEPAPredictorTIMM(nn.Module): +# Type ignore because superclass has Any types. +class IJEPAPredictorTIMM(nn.Module): # type: ignore[misc] """Predictor for the I-JEPA model [0]. Experimental: Support for I-JEPA is experimental, there might be breaking changes @@ -127,12 +128,12 @@ def forward( x = self.predictor_embed(x) x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) - x += self.apply_masks(x_pos_embed, masks_x) + x += utils.apply_masks(x_pos_embed, masks_x) _, N_ctxt, _ = x.shape pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) - pos_embs = self.apply_masks(pos_embs, masks) - pos_embs = self.repeat_interleave_batch(pos_embs, B, repeat=len_masks_x) + pos_embs = utils.apply_masks(pos_embs, masks) + pos_embs = utils.repeat_interleave_batch(pos_embs, B, repeat=len_masks_x) pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1) pred_tokens += pos_embs