
working torch prototype
Blaizzy committed May 3, 2024
1 parent 34a3dd1 commit a96c5fa
Showing 1 changed file with 33 additions and 17 deletions.
50 changes: 33 additions & 17 deletions mlx_vlm/models/idefics2/vision.py
@@ -5,6 +5,7 @@
import mlx.core as mx
import mlx.nn as nn
import numpy as np
+import torch


@dataclass
@@ -151,8 +152,8 @@ def __init__(self, config: VisionConfig):
            stride=self.patch_size,
        )

-        self.num_patches = (self.image_size // self.patch_size) ** 2
-        self.num_positions = self.num_patches
+        self.num_patches = self.image_size // self.patch_size
+        self.num_positions = self.num_patches**2
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
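For a concrete sense of the change: with the Idefics2 vision defaults of image_size 980 and patch_size 14 (values assumed here for illustration), num_patches is now the per-side patch count and num_positions the total number of learned position embeddings:

    image_size, patch_size = 980, 14          # assumed Idefics2 vision config values
    num_patches = image_size // patch_size    # 70 patches per side
    num_positions = num_patches**2            # 4900 position embeddings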
@@ -163,10 +164,33 @@ def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
            H // self.patch_size,
            W // self.patch_size,
        )
-        sequence = np.arange(max_nb_patches_h * max_nb_patches_w)
+        N = max_nb_patches_h * max_nb_patches_w

-        # Tile the sequence to repeat it B times, each time as a new row
-        position_ids = np.tile(sequence, (B, 1))
+        boundaries = torch.arange(1 / self.num_patches, 1.0, 1 / self.num_patches)
+        position_ids = torch.full(
+            size=(B, max_nb_patches_h * max_nb_patches_w), fill_value=0
+        )
+
+        for batch_idx, p_attn_mask in enumerate(np.array(mask)):
+            nb_patches_h = p_attn_mask[:, 0].sum()
+            nb_patches_w = p_attn_mask[0].sum()
+
+            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+            bucket_coords_h = torch.bucketize(
+                fractional_coords_h, boundaries, right=True
+            )
+            bucket_coords_w = torch.bucketize(
+                fractional_coords_w, boundaries, right=True
+            )
+
+            pos_ids = (
+                bucket_coords_h[:, None] * self.num_patches + bucket_coords_w
+            ).flatten()
+            position_ids[batch_idx][p_attn_mask.reshape(-1)] = pos_ids
+
+        print(position_ids)

        embeddings = patch_embeddings
        embeddings += self.position_embedding(mx.array(position_ids))
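torch is pulled in above only as a stop-gap for the bucketing step. A torch-free sketch of the same position-id logic is shown below (an illustration, not part of this commit); np.digitize with its default right=False matches torch.bucketize with right=True:

    import numpy as np

    def position_ids_from_mask(mask: np.ndarray, num_patches: int) -> np.ndarray:
        # mask: (B, H_patches, W_patches) boolean patch attention mask (assumed shape).
        B, max_h, max_w = mask.shape
        boundaries = np.arange(1 / num_patches, 1.0, 1 / num_patches)
        position_ids = np.zeros((B, max_h * max_w), dtype=np.int32)
        for batch_idx, p_attn_mask in enumerate(mask):
            nb_patches_h = p_attn_mask[:, 0].sum()
            nb_patches_w = p_attn_mask[0].sum()
            frac_h = np.arange(0, 1 - 1e-6, 1 / nb_patches_h)
            frac_w = np.arange(0, 1 - 1e-6, 1 / nb_patches_w)
            bucket_h = np.digitize(frac_h, boundaries)  # == torch.bucketize(..., right=True)
            bucket_w = np.digitize(frac_w, boundaries)
            pos_ids = (bucket_h[:, None] * num_patches + bucket_w).reshape(-1)
            position_ids[batch_idx][p_attn_mask.reshape(-1)] = pos_ids
        return position_ids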
@@ -199,18 +223,13 @@ def __call__(
                B,
                L // patch_size,
                D // patch_size,
-            )
+            ),
+            dtype=mx.bool_,
        )

-        x = self.embeddings(x, mask=None)
+        x = self.embeddings(x, mask=patch_attention_mask)

        encoder_states = (x,) if output_hidden_states else None
-        patch_size = self.config.patch_size
-
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)

        for layers in self.encoder.layers:
            x = layers(x, mask=None)
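As a usage sketch (sizes assumed, not taken from the commit): for a single 980x980 image with 14-pixel patches, the all-ones default mask built above would be a (1, 70, 70) boolean array:

    patch_attention_mask = mx.ones((1, 980 // 14, 980 // 14), dtype=mx.bool_)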
@@ -224,10 +243,7 @@
    def sanitize(self, weights):
        sanitized_weights = {}
        for k, v in weights.items():
-            if "position_ids" in k:
-                # Remove unused position_ids
-                continue
-            elif "patch_embedding.weight" in k:
+            if "patch_embedding.weight" in k:
                # PyTorch conv2d weight tensors have shape:
                # [out_channels, in_channels, kH, KW]
                # MLX conv2d expects the weight be of shape:
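The hunk is truncated above, but the conversion the comment describes is the usual weight-layout permutation; a minimal sketch, assuming MLX expects conv weights as [out_channels, kH, kW, in_channels]:

    v = v.transpose(0, 2, 3, 1)  # [O, C, kH, kW] -> [O, kH, kW, C]; sketch, not the committed line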
