From 8541cfb094fb4e5f47ecd7883904602fe70f1ddb Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Sun, 5 May 2024 15:39:45 +0200
Subject: [PATCH] Add chat UI with gradio

---
 README.md          |   6 ++
 mlx_vlm/chat_ui.py | 142 +++++++++++++++++++++++++++++++++++++++++++++
 mlx_vlm/utils.py   |   8 ++-
 requirements.txt   |   1 +
 4 files changed, 155 insertions(+), 2 deletions(-)
 create mode 100644 mlx_vlm/chat_ui.py

diff --git a/README.md b/README.md
index 942312f..c7997b7 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,12 @@ pip install mlx-vlm
 
 **Inference**:
 
+**CLI**
 ```sh
 python -m mlx_vlm.generate --model qnguyen3/nanoLLaVA --max-tokens 100 --temp 0.0
+```
+
+**Chat UI with Gradio**
+```sh
+python -m mlx_vlm.chat_ui --model qnguyen3/nanoLLaVA
 ```
\ No newline at end of file
diff --git a/mlx_vlm/chat_ui.py b/mlx_vlm/chat_ui.py
new file mode 100644
index 0000000..aa3c5a1
--- /dev/null
+++ b/mlx_vlm/chat_ui.py
@@ -0,0 +1,142 @@
+import argparse
+from typing import Optional
+
+import gradio as gr
+import mlx.core as mx
+
+from mlx_vlm import load
+
+from .prompt_utils import get_message_json
+from .utils import (
+    generate_step,
+    load,
+    load_config,
+    load_image_processor,
+    prepare_inputs,
+    sample,
+)
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser(
+        description="Generate text from an image using a model."
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="qnguyen3/nanoLLaVA",
+        help="The path to the local model directory or Hugging Face repo.",
+    )
+    return parser.parse_args()
+
+
+args = parse_arguments()
+config = load_config(args.model)
+model, processor = load(args.model, {"trust_remote_code": True})
+image_processor = load_image_processor(args.model)
+
+
+def generate(
+    model,
+    processor,
+    image: str,
+    prompt: str,
+    image_processor=None,
+    temp: float = 0.0,
+    max_tokens: int = 100,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = None,
+    top_p: float = 1.0,
+):
+
+    if image_processor is not None:
+        tokenizer = processor
+    else:
+        tokenizer = processor.tokenizer
+
+    input_ids, pixel_values = prepare_inputs(image_processor, processor, image, prompt)
+    logits, cache = model(input_ids, pixel_values)
+    logits = logits[:, -1, :]
+    y, _ = sample(logits, temp, top_p)
+
+    detokenizer = processor.detokenizer
+    detokenizer.reset()
+
+    detokenizer.add_token(y.item())
+
+    for (token, _), n in zip(
+        generate_step(
+            model.language_model,
+            logits,
+            cache,
+            temp,
+            repetition_penalty,
+            repetition_context_size,
+            top_p,
+        ),
+        range(max_tokens),
+    ):
+        token = token.item()
+
+        if token == tokenizer.eos_token_id:
+            break
+
+        detokenizer.add_token(token)
+        detokenizer.finalize()
+        yield detokenizer.last_segment
+
+
+def chat(message, history, temperature, max_tokens):
+
+    chat = []
+    for item in history:
+        chat.append(get_message_json(config["model_type"], item[0]))
+        if item[1] is not None:
+            chat.append({"role": "assistant", "content": item[1]})
+
+    if message["files"]:
+        chat.append(get_message_json(config["model_type"], message["text"]))
+
+    messages = processor.apply_chat_template(
+        chat,
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+    response = ""
+    for chunk in generate(
+        model,
+        processor,
+        message["files"][0],
+        messages,
+        image_processor,
+        temperature,
+        max_tokens,
+    ):
+        response += chunk
+        yield response
+
+
+demo = gr.ChatInterface(
+    fn=chat,
+    title="MLX-VLM Chat UI",
+    additional_inputs_accordion=gr.Accordion(
+        label="⚙️ Parameters", open=False, render=False
+    ),
+    additional_inputs=[
+        gr.Slider(
+            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
+        ),
+        gr.Slider(
+            minimum=128,
+            maximum=4096,
+            step=1,
+            value=200,
+            label="Max new tokens",
+            render=False,
+        ),
+    ],
+    description=f"Now Running {args.model}",
+    multimodal=True,
+)
+
+demo.launch(inbrowser=True)
diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py
index 1890e1b..b2feaf0 100644
--- a/mlx_vlm/utils.py
+++ b/mlx_vlm/utils.py
@@ -128,7 +128,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 
 model_id= ""
 model = AutoModelForCausalLM.from_pretrained(model_id)
-processor = AutoProcessor.from_pretrained(model_id)
+processor = AutoProcessor.from_pretrained(model_id)
 
 model.save_pretrained("")
 processor.save_pretrained("")
@@ -233,7 +233,11 @@ def load(
     return model, processor
 
 
-def load_config(model_path: Path) -> dict:
+def load_config(model_path: Union[str, Path]) -> dict:
+
+    if isinstance(model_path, str):
+        model_path = get_model_path(model_path)
+
     try:
         with open(model_path / "config.json", "r") as f:
             config = json.load(f)
diff --git a/requirements.txt b/requirements.txt
index c9116a1..259a456 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,5 +4,6 @@ numpy
 transformers
 torch
 huggingface_hub
+gradio
 Pillow
 requests
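
As a quick sanity check outside Gradio, here is a minimal sketch (not part of the patch) that builds a single-turn prompt the same way `chat_ui.chat()` does, using only functions the patch already imports from `mlx_vlm`; the model repo and the prompt string are illustrative placeholders taken from the patch's defaults.

```python
# Minimal sketch: construct a chat prompt as chat_ui.chat() does before generation.
# Assumes the model repo below is reachable locally or on the Hugging Face Hub.
from mlx_vlm import load
from mlx_vlm.prompt_utils import get_message_json
from mlx_vlm.utils import load_config

model_path = "qnguyen3/nanoLLaVA"  # placeholder repo, same default as chat_ui.py

# load_config now accepts a plain string thanks to the utils.py change above.
config = load_config(model_path)
model, processor = load(model_path, {"trust_remote_code": True})

# One user turn, formatted for this model family, then rendered with the
# processor's chat template exactly as chat_ui.chat() does.
messages = [get_message_json(config["model_type"], "Describe this image.")]
prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
```

Streaming generation itself would then go through `chat_ui.generate()`, which the Gradio callback wraps to yield partial responses into the chat window.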