From 14288cada9772d20164dc9f7838ce1b2c21eeb83 Mon Sep 17 00:00:00 2001
From: Benedikt Terhechte
Date: Thu, 28 Nov 2024 15:16:17 +0100
Subject: [PATCH] Fix Bugs in chat UI (#96)

* fix bugs

* changes

---------

Co-authored-by: Prince Canuma
---
 mlx_vlm/chat_ui.py | 12 ++++++------
 mlx_vlm/utils.py   | 14 +++++++++++++-
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/mlx_vlm/chat_ui.py b/mlx_vlm/chat_ui.py
index 8c29165..210597f 100644
--- a/mlx_vlm/chat_ui.py
+++ b/mlx_vlm/chat_ui.py
@@ -29,20 +29,20 @@ def parse_arguments():
 
 def chat(message, history, temperature, max_tokens):
     chat = []
-    if len(message.files) >= 1:
-        chat.append(message.text)
+    if len(message["files"]) >= 1:
+        chat.append({"role": "user", "content": message["text"]})
     else:
         raise gr.Error("Please upload an image. Text only chat is not supported.")
-    files = message.files[-1].path
+    file = message["files"][-1]
 
     if model.config.model_type != "paligemma":
-        messages = apply_chat_template(processor, config, chat)
+        prompt = apply_chat_template(processor, config, chat)
     else:
-        messages = message.text
+        prompt = message.text
 
     response = ""
     for chunk in stream_generate(
-        model, processor, files, messages, image_processor, max_tokens, temp=temperature
+        model, processor, file, prompt, image_processor, max_tokens, temp=temperature
     ):
         response += chunk
         yield response
diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py
index 2248e33..9cf0d3d 100644
--- a/mlx_vlm/utils.py
+++ b/mlx_vlm/utils.py
@@ -1102,7 +1102,19 @@ def stream_generate(
         image_processor, processor, image, prompt, image_token_index, resize_shape
     )
     input_ids, pixel_values, mask = inputs[:3]
-    kwargs = {k: v for k, v in zip(["image_grid_thw", "image_sizes"], inputs[3:])}
+    kwargs = {
+        k: v
+        for k, v in zip(
+            [
+                "image_grid_thw",
+                "image_sizes",
+                "aspect_ratio_ids",
+                "aspect_ratio_mask",
+                "cross_attention_mask",
+            ],
+            inputs[3:],
+        )
+    }
 
     detokenizer = processor.detokenizer
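
A note on the chat_ui.py hunk: the before/after lines indicate that gradio's multimodal
ChatInterface delivers the user turn as a plain dict with "text" and "files" keys (files
as paths) rather than an object with .text/.files attributes, which is why the old
attribute accesses failed at runtime. Below is a minimal standalone sketch of the shape
the patched chat() now consumes; the message contents are invented for illustration:

    # Hypothetical message as delivered by gradio's multimodal ChatInterface:
    # a dict with "text" and "files" keys (this path is made up).
    message = {
        "text": "What is in this image?",
        "files": ["/tmp/upload/cat.jpg"],
    }

    # Mirrors the patched logic: require at least one image, wrap the text
    # as a chat-template turn, and pick the most recently uploaded file.
    chat = []
    if len(message["files"]) >= 1:
        chat.append({"role": "user", "content": message["text"]})
    else:
        raise ValueError("Please upload an image. Text only chat is not supported.")
    file = message["files"][-1]
    print(chat, file)

Note that the patch also wraps the text as a {"role": "user", "content": ...} turn,
which is the message shape apply_chat_template expects.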
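
The utils.py hunk works because zip() stops at the shorter of its arguments:
prepare_inputs can return anywhere from three to eight values, and only the extras that
actually exist get keyword names. The added keys cover models whose inputs include
aspect-ratio and cross-attention tensors (e.g. Llama-3.2-Vision-style cross-attention
models). A self-contained sketch of the pattern, using invented stand-in tuples rather
than real model outputs:

    # zip() truncates to the shorter input, so a longer key list is harmless:
    # models that return fewer extra tensors simply produce fewer kwargs.
    keys = [
        "image_grid_thw",
        "image_sizes",
        "aspect_ratio_ids",
        "aspect_ratio_mask",
        "cross_attention_mask",
    ]

    # Invented stand-ins for prepare_inputs() outputs beyond
    # (input_ids, pixel_values, mask):
    inputs_short = ("ids", "pix", "mask")                 # no extras
    inputs_long = ("ids", "pix", "mask", "thw", "sizes")  # two extras

    print({k: v for k, v in zip(keys, inputs_short[3:])})
    # -> {}
    print({k: v for k, v in zip(keys, inputs_long[3:])})
    # -> {'image_grid_thw': 'thw', 'image_sizes': 'sizes'}

This keeps models that only ever return image_grid_thw/image_sizes (or nothing extra)
working unchanged while forwarding the new tensors when present.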