
Commit 473692e: add prompt utils and code formatting
Blaizzy committed May 3, 2024
1 parent ed9a948
Showing 5 changed files with 35 additions and 9 deletions.
12 changes: 8 additions & 4 deletions mlx_vlm/generate.py
@@ -3,6 +3,7 @@

 import mlx.core as mx

+from .prompt_utils import get_message_json
 from .utils import generate, get_model_path, load, load_config, load_image_processor

 MODEL_TYPE = ""
@@ -50,9 +51,10 @@ def parse_arguments():

 def get_model_and_processors(model_path):
     model_path = get_model_path(model_path)
+    config = load_config(model_path)
     model, processor = load(model_path, {"trust_remote_code": True})
     image_processor = load_image_processor(model_path)
-    return model, processor, image_processor
+    return model, processor, image_processor, config


 def sample(logits, temperature=0.0):
@@ -64,22 +66,24 @@

 def main():
     args = parse_arguments()
-    model, processor, image_processor = get_model_and_processors(args.model)
+    model, processor, image_processor, config = get_model_and_processors(args.model)

     prompt = codecs.decode(args.prompt, "unicode_escape")

     if "chat_template" in processor.__dict__.keys():
         prompt = processor.apply_chat_template(
-            [{"role": "user", "content": f"<image>\n{prompt}"}],
+            [get_message_json(config["model_type"], prompt)],
             tokenize=False,
             add_generation_prompt=True,
         )

     elif "tokenizer" in processor.__dict__.keys():
         prompt = processor.tokenizer.apply_chat_template(
-            [{"role": "user", "content": f"<image>\n{prompt}"}],
+            [get_message_json(config["model_type"], prompt)],
             tokenize=False,
             add_generation_prompt=True,
         )

     else:
         raise ValueError(
             "Error: processor does not have 'chat_template' or 'tokenizer' attribute."
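Taken together, the generate.py changes route prompt construction through the loaded model config instead of a hard-coded LLaVA-style string. A minimal sketch of the new flow, using only functions that appear in this diff (the repo id is a hypothetical placeholder):

    from mlx_vlm.prompt_utils import get_message_json
    from mlx_vlm.utils import get_model_path, load, load_config

    model_path = get_model_path("mlx-community/idefics2-8b-4bit")  # hypothetical repo id
    config = load_config(model_path)
    model, processor = load(model_path, {"trust_remote_code": True})

    # The message layout is now chosen per model type rather than
    # always using the "<image>\n{prompt}" form.
    message = get_message_json(config["model_type"], "Describe this image.")
    prompt = processor.apply_chat_template(
        [message], tokenize=False, add_generation_prompt=True
    )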
1 change: 0 additions & 1 deletion mlx_vlm/models/idefics2/idefics2.py
@@ -291,7 +291,6 @@ def from_pretrained(path_or_hf_repo: str):
         weights.update(mx.load(wf))

     weights = model.sanitize(weights=weights)
-
     weights = VisionModel(model_config.vision_config).sanitize(weights=weights)
     weights = LanguageModel(model_config.text_config).sanitize(weights=weights)
     model.load_weights(list(weights.items()))
1 change: 0 additions & 1 deletion mlx_vlm/models/idefics2/vision.py
@@ -167,7 +167,6 @@ def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
         boundaries = np.arange(1 / self.num_patches, 1.0, 1 / self.num_patches)
         sequence = np.zeros((max_nb_patches_h * max_nb_patches_w))
-
         # Step 3: Use broadcasting to expand this row to B rows
         position_ids = np.zeros_like(mask, dtype=int)

         def bucketize(values, boundaries):
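The bucketize body is truncated in this diff. For orientation only, a helper with this signature is commonly a thin wrapper over np.digitize; a sketch under that assumption, not the file's actual implementation:

    import numpy as np

    def bucketize(values, boundaries):
        # Index of the right-closed bucket each value falls into,
        # mirroring torch.bucketize for sorted boundaries.
        return np.digitize(values, boundaries, right=True)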
22 changes: 22 additions & 0 deletions mlx_vlm/prompt_utils.py
@@ -0,0 +1,22 @@
+def get_message_json(model_name, prompt):
+    """
+    Get the appropriate JSON message based on the specified model.
+
+    Args:
+        model_name (str): The model for which to generate the message. Options: 'idefics2', 'llava-qwen2', 'llava'.
+        prompt (str): The text prompt to be included in the message.
+
+    Returns:
+        dict: A dictionary representing the JSON message for the specified model.
+    """
+    if model_name == "idefics2":
+        message = {
+            "role": "user",
+            "content": [{"type": "image"}, {"type": "text", "text": prompt}],
+        }
+    elif model_name in ["llava-qwen2", "llava"]:
+        message = {"role": "user", "content": f"<image>\n{prompt}"}
+    else:
+        raise ValueError(f"Unsupported model: {model_name}")
+
+    return message
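For reference, this is what the new helper produces for the two supported message layouts (prompt text is a hypothetical example):

    from mlx_vlm.prompt_utils import get_message_json

    # Idefics2 processors expect structured content parts:
    get_message_json("idefics2", "What is in this image?")
    # {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is in this image?"}]}

    # LLaVA-style processors take the <image> token inline:
    get_message_json("llava", "What is in this image?")
    # {"role": "user", "content": "<image>\nWhat is in this image?"}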
8 changes: 5 additions & 3 deletions mlx_vlm/utils.py
@@ -143,9 +143,11 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         config["vision_config"]
     )
     model_config.text_config = model_class.TextConfig.from_dict(config["text_config"])
-    model_config.perceiver_config = model_class.PerceiverConfig.from_dict(
-        config["perceiver_config"]
-    )
+
+    if hasattr(model_config, "perceiver_config"):
+        model_config.perceiver_config = model_class.PerceiverConfig.from_dict(
+            config["perceiver_config"]
+        )
     model = model_class.Model(model_config)

     if hasattr(model, "sanitize"):
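The new hasattr guard makes the perceiver section optional: it is parsed only for architectures whose ModelConfig declares a perceiver_config field (Idefics2 here), instead of failing on model types that lack one. A toy sketch of the pattern with hypothetical config classes:

    from dataclasses import dataclass

    @dataclass
    class Idefics2Config:  # hypothetical: declares the optional field
        perceiver_config: dict = None

    @dataclass
    class LlavaConfig:  # hypothetical: no perceiver section
        pass

    config = {"perceiver_config": {"resampler_n_latents": 64}}

    for model_config in (Idefics2Config(), LlavaConfig()):
        # Parse the sub-config only when the architecture defines the attribute.
        if hasattr(model_config, "perceiver_config"):
            model_config.perceiver_config = config["perceiver_config"]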
