Skip to content

Commit

Permalink
fix: import modules from utils and typecheck ignore
Browse files Browse the repository at this point in the history
  • Loading branch information
vectorvp committed Nov 29, 2024
1 parent da8dd51 commit c45ece0
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions lightly/models/modules/ijepa_timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from lightly.models import utils


class IJEPAPredictorTIMM(nn.Module):
# Type ignore because superclass has Any types.
class IJEPAPredictorTIMM(nn.Module): # type: ignore[misc]
"""Predictor for the I-JEPA model [0].
Experimental: Support for I-JEPA is experimental, there might be breaking changes
Expand Down Expand Up @@ -127,12 +128,12 @@ def forward(
x = self.predictor_embed(x)
x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)

x += self.apply_masks(x_pos_embed, masks_x)
x += utils.apply_masks(x_pos_embed, masks_x)
_, N_ctxt, _ = x.shape

pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
pos_embs = self.apply_masks(pos_embs, masks)
pos_embs = self.repeat_interleave_batch(pos_embs, B, repeat=len_masks_x)
pos_embs = utils.apply_masks(pos_embs, masks)
pos_embs = utils.repeat_interleave_batch(pos_embs, B, repeat=len_masks_x)
pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)

pred_tokens += pos_embs
Expand Down

0 comments on commit c45ece0

Please sign in to comment.