diff --git a/mlx_vlm/models/idefics2/vision.py b/mlx_vlm/models/idefics2/vision.py
index cb90097..7a093fb 100644
--- a/mlx_vlm/models/idefics2/vision.py
+++ b/mlx_vlm/models/idefics2/vision.py
@@ -5,6 +5,7 @@
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+import torch
 
 
 @dataclass
@@ -151,8 +152,8 @@ def __init__(self, config: VisionConfig):
             stride=self.patch_size,
         )
 
-        self.num_patches = (self.image_size // self.patch_size) ** 2
-        self.num_positions = self.num_patches
+        self.num_patches = self.image_size // self.patch_size
+        self.num_positions = self.num_patches**2
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
 
     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
@@ -163,10 +164,33 @@ def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
             H // self.patch_size,
             W // self.patch_size,
         )
-        sequence = np.arange(max_nb_patches_h * max_nb_patches_w)
+        N = max_nb_patches_h * max_nb_patches_w
 
-        # Tile the sequence to repeat it B times, each time as a new row
-        position_ids = np.tile(sequence, (B, 1))
+        boundaries = torch.arange(1 / self.num_patches, 1.0, 1 / self.num_patches)
+        position_ids = torch.full(
+            size=(B, max_nb_patches_h * max_nb_patches_w), fill_value=0
+        )
+
+        for batch_idx, p_attn_mask in enumerate(np.array(mask)):
+            nb_patches_h = p_attn_mask[:, 0].sum()
+            nb_patches_w = p_attn_mask[0].sum()
+
+            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+            bucket_coords_h = torch.bucketize(
+                fractional_coords_h, boundaries, right=True
+            )
+            bucket_coords_w = torch.bucketize(
+                fractional_coords_w, boundaries, right=True
+            )
+
+            pos_ids = (
+                bucket_coords_h[:, None] * self.num_patches + bucket_coords_w
+            ).flatten()
+            position_ids[batch_idx][p_attn_mask.reshape(-1)] = pos_ids
+
+        print(position_ids)
 
         embeddings = patch_embeddings
         embeddings += self.position_embedding(mx.array(position_ids))
@@ -199,18 +223,13 @@ def __call__(
                 B,
                 L // patch_size,
                 D // patch_size,
-            )
+            ),
+            dtype=mx.bool_,
         )
 
-        x = self.embeddings(x, mask=None)
+        x = self.embeddings(x, mask=patch_attention_mask)
 
         encoder_states = (x,) if output_hidden_states else None
-        patch_size = self.config.patch_size
-
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
 
         for layers in self.encoder.layers:
             x = layers(x, mask=None)
@@ -224,10 +243,7 @@ def sanitize(self, weights):
         sanitized_weights = {}
         for k, v in weights.items():
-            if "position_ids" in k:
-                # Remove unused position_ids
-                continue
-            elif "patch_embedding.weight" in k:
+            if "patch_embedding.weight" in k:
                 # PyTorch conv2d weight tensors have shape:
                 # [out_channels, in_channels, kH, KW]
                 # MLX conv2d expects the weight be of shape:
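
The bucketing loop above is the only place this file now depends on torch. For reference, the same computation can be expressed with NumPy alone: `np.digitize` with its default `right=False` uses the interval convention `boundaries[i-1] <= v < boundaries[i]`, which matches `torch.bucketize(..., right=True)`. The sketch below (a hypothetical `bucket_position_ids` helper, not part of this diff) is one way that could look:

```python
import numpy as np

def bucket_position_ids(mask: np.ndarray, num_patches: int) -> np.ndarray:
    """Assign each valid patch a position id on a fixed num_patches x num_patches
    grid by bucketing its fractional (row, col) coordinates.

    mask: (B, H_patches, W_patches) boolean patch attention mask; as in the diff,
    the valid region is assumed to be a top-left rectangle of the grid.
    """
    B, max_h, max_w = mask.shape
    position_ids = np.zeros((B, max_h * max_w), dtype=np.int32)
    # Bucket edges at 1/N, 2/N, ..., (N-1)/N, mirroring torch.arange(1/N, 1.0, 1/N).
    boundaries = np.arange(1 / num_patches, 1.0, 1 / num_patches)

    for batch_idx, p_attn_mask in enumerate(mask):
        nb_patches_h = p_attn_mask[:, 0].sum()
        nb_patches_w = p_attn_mask[0].sum()

        # Fractional coordinates in [0, 1) for each valid row/column; equivalent
        # to torch.arange(0, 1 - 1e-6, 1 / nb_patches), without the float guard.
        fractional_coords_h = np.arange(nb_patches_h) / nb_patches_h
        fractional_coords_w = np.arange(nb_patches_w) / nb_patches_w

        # digitize(right=False) == bucketize(right=True): count of boundaries <= v.
        bucket_coords_h = np.digitize(fractional_coords_h, boundaries)
        bucket_coords_w = np.digitize(fractional_coords_w, boundaries)

        pos_ids = (
            bucket_coords_h[:, None] * num_patches + bucket_coords_w
        ).flatten()
        position_ids[batch_idx][p_attn_mask.reshape(-1)] = pos_ids

    return position_ids
```

Because `self.num_patches` is now the per-side patch count, each image's rows and columns are rescaled onto a fixed `num_patches x num_patches` grid before lookup, which is what lets the single `nn.Embedding(num_positions, embed_dim)` table serve variably sized, mask-padded inputs.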