Faster / more memory efficient Qwen VL #114

Merged 1 commit on Nov 14, 2024
mlx_vlm/models/qwen2_vl/language.py (6 additions, 101 deletions)

@@ -48,87 +48,6 @@ def from_dict(cls, params):
         )
 
 
-class Qwen2RotaryEmbedding:
-    def __init__(self, dim, max_position_embeddings=2048, base=10000):
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-
-        inv_freq = 1.0 / (
-            self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
-        )
-        self.inv_freq = inv_freq
-
-        # Build the cos and sin cache
-        self._set_cos_sin_cache(seq_len=max_position_embeddings)
-
-    def _set_cos_sin_cache(self, seq_len):
-        self.max_seq_len_cached = seq_len
-        t = mx.arange(self.max_seq_len_cached).astype(mx.float32)
-
-        freqs = mx.outer(t, self.inv_freq)
-
-        # Differs from the paper's layout, but an equivalent permutation
-        # yields the same result
-        emb = mx.concatenate((freqs, freqs), axis=-1)
-        self.cos_cached = mx.cos(emb)
-        self.sin_cached = mx.sin(emb)
-
-    def __call__(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len)
-
-        return (
-            self.cos_cached[:seq_len].astype(x.dtype),
-            self.sin_cached[:seq_len].astype(x.dtype),
-        )
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return mx.concatenate([-x2, x1], axis=-1)
-
-
-def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section):
-    """
-    Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
-
-    Args:
-        q (mx.array): The query tensor.
-        k (mx.array): The key tensor.
-        cos (mx.array): The cosine part of the rotary embedding.
-        sin (mx.array): The sine part of the rotary embedding.
-        position_ids (mx.array): The position indices, one row per temporal/height/width axis.
-        mrope_section (List[int]): Multimodal rope section for the channel dimension of temporal, height and width.
-
-    Returns:
-        tuple(mx.array): The rotated query and key tensors.
-    """
-
-    mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist()
-
-    cos = cos[position_ids]
-    sin = sin[position_ids]
-
-    cos = mx.concatenate(
-        [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))], axis=-1
-    )[
-        :, None, :, :
-    ]  # unsqueeze dim 1
-    sin = mx.concatenate(
-        [m[i % 3] for i, m in enumerate(mx.split(sin, mrope_section, axis=-1))], axis=-1
-    )[:, None, :, :]
-
-    # Apply rotary embedding
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-
-    return q_embed, k_embed
-
-
 class Attention(nn.Module):
     def __init__(self, args: TextConfig):
         super().__init__()
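Note: the deleted apply_multimodal_rotary_pos_emb above implements M-RoPE, where the head channels are partitioned into temporal/height/width sections and each section is rotated with position ids from a different axis. A toy illustration of the cumulative-split indexing it relies on (the section sizes below are examples, not read from this diff):

import numpy as np

mrope_section = [16, 24, 24]  # temporal, height, width channels (one half of a 128-dim head)

# `mrope_section * 2` repeats the Python list -> [16, 24, 24, 16, 24, 24],
# covering both concatenated cos/sin halves; cumsum gives the split boundaries,
# and the final endpoint is dropped because mx.split only needs interior cuts.
split_points = np.cumsum(mrope_section * 2)[:-1].tolist()
print(split_points)  # [16, 40, 64, 80, 104]

# Chunk i of the cos/sin tables is then taken from axis i % 3
# (0 = temporal, 1 = height, 2 = width) before re-concatenation.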
@@ -146,12 +65,10 @@ def __init__(self, args: TextConfig):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        self.rope_scaling = args.rope_scaling
-
-        self.rotary_emb = Qwen2RotaryEmbedding(
+        self.rotary_emb = nn.RoPE(
             head_dim,
-            max_position_embeddings=args.max_position_embeddings,
             base=args.rope_theta,
+            traditional=args.rope_traditional,
         )
 
     def __call__(
@@ -173,25 +90,13 @@ def __call__(
             0, 2, 1, 3
         )
 
-        kv_seq_len = keys.shape[-2]
-        if cache is not None:
-            kv_seq_len += cache.offset + 1
-            position_ids = mx.arange(cache.offset, cache.offset + L)
-        else:
-            position_ids = mx.arange(0, L)
-
-        position_ids = mx.expand_dims(position_ids, axis=0)
-        position_ids = mx.tile(position_ids, (3, 1, 1))
-
-        cos, sin = self.rotary_emb(values, kv_seq_len)
+        offset = cache.offset if cache else 0
 
         if mask is not None:
-            mask = mask[None, None, :, :]
-            mask = mask[:, :, :, : keys.shape[-2]]
+            mask = mask[..., : keys.shape[-2]]
 
-        queries, keys = apply_multimodal_rotary_pos_emb(
-            queries, keys, cos, sin, position_ids, self.rope_scaling["mrope_section"]
-        )
+        queries = self.rotary_emb(queries, offset=offset)
+        keys = self.rotary_emb(keys, offset=offset)
 
         if cache is not None:
             keys, values = cache.update_and_fetch(keys, values)
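In place of the cached-table class removed above, the attention layer now calls MLX's fused nn.RoPE, which takes the KV-cache offset directly and never materializes a max_position_embeddings-sized cos/sin cache. A minimal sketch of that usage pattern (the head count, sizes, and theta value are illustrative, not taken from the model config):

import mlx.core as mx
import mlx.nn as nn

head_dim = 128
rope = nn.RoPE(head_dim, traditional=False, base=1_000_000)

# Prefill: queries are [batch, n_heads, seq_len, head_dim]; positions 0..7.
queries = mx.random.normal((1, 12, 8, head_dim))
q_rot = rope(queries, offset=0)

# Decode: a single new token whose absolute position is the cache offset.
next_q = mx.random.normal((1, 12, 1, head_dim))
q_next = rope(next_q, offset=8)

Note the trade-off visible in the diff: nn.RoPE applies standard 1-D rotary positions, so the per-axis M-RoPE indexing is dropped in exchange for speed and memory.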
mlx_vlm/models/qwen2_vl/qwen2_vl.py (4 additions, 0 deletions)

@@ -92,6 +92,10 @@ def __call__(
     ):
         image_grid_thw = kwargs.pop("image_grid_thw", None)
         image_grid_thw = mx.array(image_grid_thw)
+
+        dtype = self.vision_tower.patch_embed.proj.weight.dtype
+        pixel_values = pixel_values.astype(dtype)
+
         input_embddings = self.get_input_embeddings(
             input_ids, pixel_values, image_grid_thw
         )
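The qwen2_vl.py change casts the incoming pixel values to whatever dtype the vision tower's patch-embedding weights use, so a half-precision model no longer runs its vision stack in float32. A standalone sketch of the same idea (the dtype and shapes are made up for illustration):

import mlx.core as mx

weight_dtype = mx.float16  # e.g. a converted half-precision checkpoint
pixel_values = mx.random.uniform(shape=(1, 3, 336, 336))  # float32 from preprocessing

# Matching input dtype to weight dtype keeps the conv/matmul path in half
# precision instead of promoting activations (and memory use) to float32.
pixel_values = pixel_values.astype(weight_dtype)
assert pixel_values.dtype == weight_dtype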
mlx_vlm/models/qwen2_vl/vision.py (1 addition, 34 deletions)

@@ -53,38 +53,6 @@ def check_array_shape(arr):
         return False
 
 
-class Qwen2RotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000):
-        super().__init__()
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        self.inv_freq = 1.0 / (
-            self.base
-            ** (mx.arange(0, self.dim, 2, dtype=mx.int64).astype(mx.float32) / self.dim)
-        )
-        self._set_cos_sin_cache(seq_len=max_position_embeddings, dtype=mx.float32)
-
-    def _set_cos_sin_cache(self, seq_len, dtype):
-        self.max_seq_len_cached = seq_len
-        t = mx.arange(self.max_seq_len_cached, dtype=mx.int64).astype(
-            self.inv_freq.dtype
-        )
-        freqs = mx.outer(t, self.inv_freq)
-        emb = mx.concatenate((freqs, freqs), axis=-1)
-        self.cos_cached = mx.cos(emb).astype(dtype)
-        self.sin_cached = mx.sin(emb).astype(dtype)
-
-    def __call__(self, x, seq_len=None):
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].astype(x.dtype),
-            self.sin_cached[:seq_len].astype(x.dtype),
-        )
-
-
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -200,8 +168,7 @@ def __call__(

         q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
         k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]
-
-        attention_mask = mx.ones((1, seq_length, seq_length))
+        attention_mask = mx.ones((1, seq_length, seq_length), dtype=x.dtype)
 
         for i in range(1, len(cu_seqlens)):
             start = int(cu_seqlens[i - 1])
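The vision-side change builds the attention mask directly in the activation dtype, avoiding an upcast when it is combined with half-precision attention scores; the loop over cu_seqlens (truncated here) gives each image's patches their own block. A rough sketch of a block-diagonal additive mask in that style (the fill value and sequence lengths are assumptions, since the loop body is cut off in the diff):

import numpy as np
import mlx.core as mx

def block_diagonal_mask(cu_seqlens, seq_length, dtype=mx.float16):
    # Large negative everywhere, zero inside each image's own patch block,
    # so attention cannot cross image boundaries.
    mask = np.full((1, seq_length, seq_length), -1e9, dtype=np.float32)
    for i in range(1, len(cu_seqlens)):
        start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        mask[:, start:end, start:end] = 0.0
    return mx.array(mask).astype(dtype)

# Two images: 2 patches and 4 patches -> cu_seqlens = [0, 2, 6].
mask = block_diagonal_mask([0, 2, 6], seq_length=6)
print(mask.dtype)  # float16, matching x.dtype in the diff above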