Fix I-JEPA Example (#1747)
* Add apply_masks and repeat_interleave_batch to models/utils.py
vectorvp authored Dec 3, 2024
Parent: 1e7274a · Commit: fac3dcb
Showing 2 changed files with 68 additions and 50 deletions.
64 changes: 14 additions & 50 deletions lightly/models/modules/ijepa_timm.py
@@ -1,14 +1,18 @@
 from __future__ import annotations

 from functools import partial
-from typing import Callable, List, Union
+from typing import Callable

 import torch
 import torch.nn as nn
 from timm.models.vision_transformer import Block
+from torch import Tensor

+from lightly.models import utils

+
-class IJEPAPredictorTIMM(nn.Module):
+# Type ignore because superclass has Any types.
+class IJEPAPredictorTIMM(nn.Module):  # type: ignore[misc]
     """Predictor for the I-JEPA model [0].

     Experimental: Support for I-JEPA is experimental, there might be breaking changes
@@ -56,7 +60,7 @@ def __init__(
         drop_path_rate: float = 0.0,
         proj_drop_rate: float = 0.0,
         attn_drop_rate: float = 0.0,
-        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+        norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6),
     ):
         """Initializes the IJEPAPredictorTIMM with the specified dimensions."""

@@ -95,10 +99,10 @@ def __init__(

     def forward(
         self,
-        x: torch.Tensor,
-        masks_x: Union[List[torch.Tensor], torch.Tensor],
-        masks: Union[List[torch.Tensor], torch.Tensor],
-    ) -> torch.Tensor:
+        x: Tensor,
+        masks_x: list[Tensor] | Tensor,
+        masks: list[Tensor] | Tensor,
+    ) -> Tensor:
         """Forward pass of the IJEPAPredictorTIMM.

         Args:
@@ -124,12 +128,12 @@ def forward
         x = self.predictor_embed(x)
         x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)

-        x += self.apply_masks(x_pos_embed, masks_x)
+        x += utils.apply_masks(x_pos_embed, masks_x)
         _, N_ctxt, _ = x.shape

         pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
-        pos_embs = self.apply_masks(pos_embs, masks)
-        pos_embs = self.repeat_interleave_batch(pos_embs, B, repeat=len_masks_x)
+        pos_embs = utils.apply_masks(pos_embs, masks)
+        pos_embs = utils.repeat_interleave_batch(pos_embs, B, repeat=len_masks_x)
         pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)

         pred_tokens += pos_embs
@@ -144,43 +148,3 @@ def forward
         x = self.predictor_proj(x)

         return x
-
-    def repeat_interleave_batch(
-        self, x: torch.Tensor, B: int, repeat: int
-    ) -> torch.Tensor:
-        """Repeat and interleave the input tensor."""
-        N = len(x) // B
-        x = torch.cat(
-            [
-                torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0)
-                for i in range(N)
-            ],
-            dim=0,
-        )
-        return x
-
-    def apply_masks(
-        self, x: torch.Tensor, masks: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> torch.Tensor:
-        """Apply masks to the input tensor.
-
-        From https://github.com/facebookresearch/ijepa/blob/main/src/masks/utils.py
-
-        Args:
-            x:
-                tensor of shape [B (batch-size), N (num-patches), D (feature-dim)].
-            masks:
-                tensor or list of tensors containing indices of patches in [N] to keep.
-
-        Returns:
-            Tensor of shape [B, N', D] where N' is the number of patches to keep.
-        """
-
-        if not isinstance(masks, list):
-            masks = [masks]
-
-        all_x = []
-        for m in masks:
-            mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
-            all_x += [torch.gather(x, dim=1, index=mask_keep)]
-        return torch.cat(all_x, dim=0)
54 changes: 54 additions & 0 deletions lightly/models/utils.py
@@ -1261,3 +1261,57 @@ def update_drop_path_rate(
         else:
             block.drop_path1 = Identity()
             block.drop_path2 = Identity()
+
+
+def repeat_interleave_batch(x: Tensor, B: int, repeat: int) -> Tensor:
+    """Repeat and interleave the input tensor.
+
+    Args:
+        x:
+            Tensor with shape (B * N, ...) where B is the batch size and N the number of
+            batches.
+        B:
+            Batch size.
+        repeat:
+            Number of times to repeat each batch.
+
+    Returns:
+        Tensor with shape (B * repeat * N, ...) where each batch is repeated `repeat`
+        times.
+    """
+    N = len(x) // B
+    x = torch.cat(
+        [
+            torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0)
+            for i in range(N)
+        ],
+        dim=0,
+    )
+    return x
+
+
+def apply_masks(x: Tensor, masks: Tensor | list[Tensor]) -> Tensor:
+    """Apply masks to the input tensor.
+
+    From https://github.com/facebookresearch/ijepa/blob/main/src/masks/utils.py
+
+    Args:
+        x:
+            Tensor of shape (B, N, D) where N is the number of patches.
+        masks:
+            Tensor or list of tensors containing indices of patches in
+            [0, N-1] to keep. Each tensor must have shape (B, K) where K is the number
+            of patches to keep. All masks must have the same K.
+
+    Returns:
+        Tensor of shape (B * num_masks, K, D) where K is the number of patches to keep.
+    """
+
+    if not isinstance(masks, list):
+        masks = [masks]
+
+    all_x = []
+    for m in masks:
+        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
+        all_x += [torch.gather(x, dim=1, index=mask_keep)]
+    return torch.cat(all_x, dim=0)
