Skip to content

Commit

Permalink
Merge branch 'main' into stream-generate-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
andimarafioti authored Nov 28, 2024
2 parents 63d0156 + 14288ca commit 04b1168
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
12 changes: 6 additions & 6 deletions mlx_vlm/chat_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,20 @@ def parse_arguments():

def chat(message, history, temperature, max_tokens):
    """Stream a chat response for the Gradio multimodal chat UI.

    Parameters
    ----------
    message : dict
        Gradio multimodal message carrying "text" and "files" keys.
    history : list
        Prior chat turns; unused here but required by Gradio's callback
        signature.
    temperature : float
        Sampling temperature forwarded to ``stream_generate``.
    max_tokens : int
        Maximum number of tokens to generate.

    Yields
    ------
    str
        The accumulated response text after each generated chunk, so the
        UI renders the answer incrementally.

    Raises
    ------
    gr.Error
        If the message carries no uploaded file — text-only chat is not
        supported by this app.
    """
    chat = []
    if len(message["files"]) >= 1:
        chat.append({"role": "user", "content": message["text"]})
    else:
        raise gr.Error("Please upload an image. Text only chat is not supported.")

    # Only the most recently uploaded file is sent to the model.
    file = message["files"][-1]
    if model.config.model_type != "paligemma":
        prompt = apply_chat_template(processor, config, chat)
    else:
        # paligemma takes the raw text without a chat template.
        # NOTE(review): attribute access (message.text) differs from the dict
        # access (message["text"]) used above — confirm which Gradio message
        # shape this callback actually receives.
        prompt = message.text

    response = ""
    for chunk in stream_generate(
        model, processor, file, prompt, image_processor, max_tokens, temp=temperature
    ):
        # Yield the running concatenation so the UI streams the full answer
        # so far on every chunk.
        response += chunk
        yield response
Expand Down
14 changes: 13 additions & 1 deletion mlx_vlm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,19 @@ def stream_generate(
image_processor, processor, image, prompt, image_token_index, resize_shape
)
input_ids, pixel_values, mask = inputs[:3]
kwargs = {k: v for k, v in zip(["image_grid_thw", "image_sizes"], inputs[3:])}
kwargs = {
k: v
for k, v in zip(
[
"image_grid_thw",
"image_sizes",
"aspect_ratio_ids",
"aspect_ratio_mask",
"cross_attention_mask",
],
inputs[3:],
)
}

detokenizer = processor.detokenizer

Expand Down

0 comments on commit 04b1168

Please sign in to comment.