
Commit ed9a948
working numpy example
Blaizzy committed May 3, 2024
1 parent a96c5fa commit ed9a948
Showing 1 changed file with 33 additions and 15 deletions.
mlx_vlm/models/idefics2/vision.py: 48 changes (33 additions, 15 deletions)
@@ -5,7 +5,6 @@
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-import torch
 
 
 @dataclass
@@ -165,32 +164,51 @@ def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
             W // self.patch_size,
         )
         N = max_nb_patches_h * max_nb_patches_w
-        boundaries = torch.arange(1 / self.num_patches, 1.0, 1 / self.num_patches)
-        position_ids = torch.full(
-            size=(B, max_nb_patches_h * max_nb_patches_w), fill_value=0
-        )
+        boundaries = np.arange(1 / self.num_patches, 1.0, 1 / self.num_patches)
+        sequence = np.zeros((max_nb_patches_h * max_nb_patches_w))
+
+        # One position id slot per mask entry, filled in per batch element below
+        position_ids = np.zeros_like(mask, dtype=int)
+
+        def bucketize(values, boundaries):
+            idx = (
+                np.digitize(values, boundaries, right=True) - 1
+            )  # shift down so bucket indices start at 0
+            idx[idx == -1] = 0  # handle any -1 indices that may appear
+            return idx
 
         for batch_idx, p_attn_mask in enumerate(np.array(mask)):
             nb_patches_h = p_attn_mask[:, 0].sum()
             nb_patches_w = p_attn_mask[0].sum()
 
-            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
-            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+            fractional_coords_h = np.linspace(0, 1 - 1e-6, nb_patches_h)
+            fractional_coords_w = np.linspace(0, 1 - 1e-6, nb_patches_w)
 
-            bucket_coords_h = torch.bucketize(
-                fractional_coords_h, boundaries, right=True
-            )
-            bucket_coords_w = torch.bucketize(
-                fractional_coords_w, boundaries, right=True
-            )
+            bucket_coords_h = bucketize(fractional_coords_h, boundaries)
+            bucket_coords_w = bucketize(fractional_coords_w, boundaries)
 
             pos_ids = (
                 bucket_coords_h[:, None] * self.num_patches + bucket_coords_w
             ).flatten()
-            position_ids[batch_idx][p_attn_mask.reshape(-1)] = pos_ids
 
-        print(position_ids)
+            flat_indices = np.flatnonzero(
+                p_attn_mask
+            )  # Get flat indices where p_attn_mask is non-zero
+
+            # Ensure pos_ids has sufficient length
+            if len(pos_ids) < len(flat_indices):
+                raise ValueError(
+                    "Not enough pos_ids generated: {} needed, but only {} generated.".format(
+                        len(flat_indices), len(pos_ids)
+                    )
+                )
+
+            # Apply position ids to the positions indicated by p_attn_mask
+            position_ids[batch_idx].flat[flat_indices] = pos_ids[: len(flat_indices)]
+
+        position_ids = position_ids.reshape(B, N)
 
         embeddings = patch_embeddings
         embeddings += self.position_embedding(mx.array(position_ids))
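
For reference, the new bucketize helper can be exercised on its own. A minimal sketch, assuming small hypothetical sizes (num_patches=4, three patches along one axis); the expected output is shown in the comment:

import numpy as np

num_patches = 4
boundaries = np.arange(1 / num_patches, 1.0, 1 / num_patches)  # [0.25, 0.5, 0.75]

def bucketize(values, boundaries):
    # np.digitize(..., right=True) returns indices in [0, len(boundaries)],
    # so shift down by one and clamp negatives back to bucket 0.
    idx = np.digitize(values, boundaries, right=True) - 1
    idx[idx == -1] = 0
    return idx

nb_patches_h = 3
fractional_coords_h = np.linspace(0, 1 - 1e-6, nb_patches_h)  # [0.0, ~0.5, ~1.0]
print(bucketize(fractional_coords_h, boundaries))  # [0 0 2]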
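Similarly, the flat-index scatter introduced in this commit can be checked with a toy example. A sketch using a hypothetical 3x4 patch-attention mask in which only the top-left 2x3 region is valid:

import numpy as np

# Hypothetical patch-attention mask: True marks valid patches.
p_attn_mask = np.zeros((3, 4), dtype=bool)
p_attn_mask[:2, :3] = True

position_ids = np.zeros_like(p_attn_mask, dtype=int)

# Row-major flat indices of the valid patches.
flat_indices = np.flatnonzero(p_attn_mask)  # [0 1 2 4 5 6]

# One position id per valid patch; masked positions stay 0.
position_ids.flat[flat_indices] = np.arange(len(flat_indices))

print(position_ids)
# [[0 1 2 0]
#  [3 4 5 0]
#  [0 0 0 0]]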
