-
Notifications
You must be signed in to change notification settings - Fork 898
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
LlaVA in MLX #461
LlaVA in MLX #461
Conversation
Nice initiative. I am also attempting to port the llava to MLX, but I have not made any progress yet. I am happy to help in any way that I can, just clarify the issue you mentioned regarding reconciling the supposed number of attention heads (32). Could you provide more details on that? |
@mzbac Thank you! Appreciate any help, or if this gives you a way to start a few steps ahead on your version of the port. Will explain my observation below: According to the model config, the base LLM is Vicuna 7b 1.5, a fine-tuned version of Llama 2. From that information, I used the config file of that model to fill out what the base LLM of LlaVa should be: However, when I do so, I get the following error for loading in the language model weights: What this tells me is that the pre-trained weights have a shape of 4096 x 4096, whereas our implementation of Llama2 uses the following: Where 4096 (dim size) * 32 (num heads) = 131072, which leads me to believe there is something wrong either with how I saved the pre-trained weights, or some mismatch in how Torch saved the attention heads versus here. If I were to update the model config, set num_heads = 1, this size mismatch won't happen. It's possible this is then actually correct, but I haven't gotten anywhere in the forward pass & inference section of the port which would quickly prove what's working or not. |
Yeah, there are two different formats of Llama models. One is the original PyTorch format released by Meta, and the other one uses HF format. By quickly checking the model link you shared, it seems like you are using the HF format one. You would use the Llama model in mlx-lm to load the weights (it can directly load the hf format model without converting), otherwise, there may be some issues due to improper loading of config and default configuration mismatch with actual weights used by the model. Edit:
|
@mzbac that makes a ton of sense, thank you for the point it looks like the weights now load in properly! A ton of this code will inevitably need to be cleaned up before merging but next up in my mind is getting the processor and inference working, I'll share any updates I have on that when I can Edit: The vocab size difference is because LlaVA is an extension of that original Vicuna, but now with additional tokens. This is a bit confusing to me as well. According to transformers:
Which agrees with what you are saying, but also according to huggingface, the vocab size is 32064 (https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json#L37) . |
@mzbac tracking here, looks like there's a decent chunk of tasks remaining. Not sure when I'll be able to get to them all, so for anyone reading this and curious (I think) these are the remaining steps we'd need to do: For the forward pass:
Once done, this should then be able to be converted to actual generation and not just a single ouput, i.e. port over model.generate. Haven't looked this code over yet so I can't say what sub-tasks remain. And then there's the general housekeeping. Deleting unused classes, finding speedups, etc. I'd imagine it's also valuable to port over the text and image processor like was done for CLIP. That's something else that'd be great to add in. |
Yeah, I am currently working on making the clip model compatible with HF format and updating the model to support output hidden states. I think I can complete that by the end of this week. If you're okay with it, you can leave that task to me. |
@mzbac I just looked at your CLIP PR and looks great! Once that is merged I can sync my fork with main, and update the forward pass for llava. I still need to do some more research on how the full forward pass is implemented to make sure nothing else like this "surprises" me, that may take some time, so if you (or anyone reading) also feels comfortable trying your hand at this too it's no problem. |
Yeah, I started looking at Llava's forward pass and got some understanding of how it works. However, I will set up some tests to ensure there are no surprises in our implementation. I will share it with you once I have the setup. |
@nkasmanoff, while I was trying to set up the test, I noticed that I had to refactor quite a bit of code in order to be able to run the test and load the model. After writing two tests for image features and merging image features with input ids into input embeddings, most of the LLava model was completed (the rest is just a normal forward pass for llama LLM). I ended up finishing the llava model forward pass. Here is the complete code (https://github.com/mzbac/mlx-examples/tree/llava/llava), let me know if you would like to sync it back to your branch. |
I'm really excited about the progress on this implementation!
Could you point me to the code for that? Maybe there is a workaround using MLX still. |
Yeah, here is the merged image feature and input embedding part: https://github.com/mzbac/mlx-examples/blob/llava/llava/llava.py#L94-L131. |
@mzbac looks very good to me! I haven't made any major changes on my end, so I think syncing your changes would be ideal. Happy to test those changes on my end too. |
self, | ||
inputs: mx.array, | ||
cache=None, | ||
inputs_embeds=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most of the implementation is copied from mlx-lm's llama, with only updates made to the forward pass to allow for directly passing inputs_embeds for the initial prompt evaluation.
llava/llava.py
Outdated
def _merge_input_ids_with_image_features( | ||
self, image_features, inputs_embeds, input_ids | ||
): | ||
|
||
image_features = np.array(image_features) | ||
inputs_embeds = np.array(inputs_embeds) | ||
input_ids = np.array(input_ids) | ||
|
||
_, num_image_patches, embed_dim = image_features.shape | ||
batch_size, sequence_length = input_ids.shape | ||
|
||
special_image_token_mask = input_ids == self.config.image_token_index | ||
num_special_image_tokens = np.sum(special_image_token_mask, axis=-1) | ||
|
||
# if no special image tokens found, return a warning | ||
if np.all(num_special_image_tokens == 0): | ||
logging.warning( | ||
"No special image tokens found in the input. Please make sure to include <image> in your prompt." | ||
) | ||
|
||
# calculate the final sequence length. Will be the original sequence length + the # of image tokens to be inserted in. | ||
final_sequence_length = ( | ||
np.max(num_special_image_tokens) * (num_image_patches - 1) | ||
) + sequence_length | ||
|
||
non_image_indices = np.where( | ||
input_ids != self.config.image_token_index) | ||
|
||
new_token_positions = ( | ||
np.cumsum((special_image_token_mask * | ||
(num_image_patches - 1) + 1), axis=-1) | ||
- 1 | ||
) | ||
text_to_overwrite = new_token_positions[non_image_indices] | ||
|
||
final_embedding = np.zeros( | ||
(batch_size, final_sequence_length, | ||
embed_dim), dtype=inputs_embeds.dtype | ||
) | ||
|
||
final_embedding[non_image_indices[0], text_to_overwrite, :] = inputs_embeds[ | ||
non_image_indices | ||
] | ||
|
||
image_to_overwrite = np.all(final_embedding == 0, axis=-1) | ||
reshaped_image_features = image_features.reshape(-1, embed_dim) | ||
final_embedding[image_to_overwrite, :] = reshaped_image_features[ | ||
: np.sum(image_to_overwrite) | ||
] | ||
|
||
return mx.array(final_embedding) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@awni, this is the part where we have to use NumPy as a workaround because MLX doesn't support boolean indexing. Maybe you could give us some pointers on how we can implement this using MLX.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a couple questions about how this function is supposed to work:
- Shouldn't it be an error if the input text does not have the same number of
<image>
tokens as the image batch size? - Can the input text have a batch size other than 1? If so, how does this work with the image batch? Do all the input images need to have the same number of
<image>
tokens?
I also have a couple thoughts on how to change this to work better for MLX (and possibly be a lot simpler). Assuming the text has a batch size of 1:
- Get the positions of the
<image>
tags and split the text embeddings based on that - Split the image embeddings along the batch axis.
- Interleave the two sets and concatenate
In this case it should be fine from an efficiency standpoint since you typically have only one or two <image>
tags. And, I think it would be simpler. This would not be the most efficient approach if you had a lot of images and just a few embeddings per image.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1, you are right. I skipped the error checking for mismatched image tokens and image batches to simplify the implementation, but we could definitely add it. FYI: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L333-L337
2. My understanding is that as long as the total number of images matches the total image tokens in a batch, it should work with batches. Given the final embedding is batched, however, I have not done any batch testing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding 2, makes sense! Does the rest of the code work for batched text? E.g. the generation part properly handles different text examples? If it's just using our MLX LM llama implementation (which it looks like) then the answer is probably no. I don't know if that simplifies much though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct me if I am wrong, but I thought the Llama example should be able to handle batched text generation. The only difference is that the initial prompt will get the embedding with text and image from the code above and directly go to LLM's encode layer without going through the embedding/norm layer. Once that forward pass is completed, we get the next token logits and key-value cache, then the rest of process is just a standard forward pass for LLM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@awni, I don't have permission to push the changes. Maybe either you or @nkasmanoff could update it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😬 sorry. Yea happy to, or @nkasmanoff can give you temporary permissions to push to his fork.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mzbac just gave you permission, let me know if that worked!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nkasmanoff @awni I have pushed the changes. Please note that the test.py will fail due to the fast_gelu to native gelu change, but if you build the latest mlx locally, the test will pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, we should put the version in the requirement in that case. Should work with 0.5
and up.
Also it looks like this is almost ready to go. I will review a bit more tonight, but from what I looked at, it looked quite nice already! |
@awni this is great to hear! It depends if you think this is essential to this PR, but some ideas for additional features which are missing now that may be useful later on.
My opinion is we don't need any for first example, but is there anything else to consider? |
Thanks for the list @nkasmanoff, also think we can leave most of it as follow up depending on which direction this example goes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🚀
Really great addition, thank you @mzbac and @nkasmanoff !
@nkasmanoff @mzbac Great works! Do you plan to support Llava-1.6-hf too ? |
@nxphi47 would like to add support for llava 1.6, however, could you create an issue for that? Comments on the closed PR may get lost. |
* add: llava mlx first draft * add: weights comparision * add forward pass skeleton * update: now imports weights correctly * delete base * latest * adding config * fix: use config * add mlx config * feat: add image processor for llava processor * wip * feat: llava working example * chore: refactor generate script * chore: clean up * add: warning to user if no <image> token despite using one * add: __call__ to LlavaModel * add: call to LlavaModel * update fp * clean up var names * update: native GeLU * Cleanup * update generate and readme * remove todo comment * rearrange tests * fix example code * nits in README * update readme * nit in readme * nits in README * chore(llava): refactor image embedding merging logic * min mlx version * nits in readmes * fix cli prompt, some nits * updates, slight simplify --------- Co-authored-by: anchen <[email protected]> Co-authored-by: Awni Hannun <[email protected]>
1.6 seems to need an image_newline model parameter. It doesn’t seem to be a straight drop in replacement for 1.5 |
Is v1.6 support abandoned, or are some of the newer mlx features helpful to getting it up and running? |
Thanks @awni! Llava 1.6 (Next) is coming to MLX VLM soon, Alongside a trainer. We currently support:
|
Hi, please see the following for what I’ve made so far for converting a Llava checkpoint, (https://huggingface.co/llava-hf/llava-1.5-7b-hf) into an MLX implementation.
I chose this model because it was easy to load its accompanying implementation on HF, and compare the architecture / model weights side by side.
There’s a few other choices I’ve made to make this easier on the first pass, such as combining all of the safe-tensors shards into a single PyTorch file, that way I could lean on some prior MLX examples implementations. Obviously this could be a bit restrictive as the LlaVA models get bigger and this file size > RAM, but that is one of the several issues I wanted to point out in this draft I can hopefully get some help in fixing.
There are several TODOs in the repo showing these issues. I think most urgently, something I am confused by is how to reconcile the supposed # of attention heads (32) with what I’m saving from the downloaded weights.
I tried to outline my work in the notebook ‘Local Llava’ (can delete in the official PR) which is where I’d suggest anyone able to help get started from.
The other major TODO I also have left is figuring out how to do the forward pass & model.generate that takes into account text and image inputs, but figure that can wait until confirmation the model is actually loaded correctly.
Improved documentation, README, and tests to come along afterwards.
Thank you in advance to anyone who can help!
Feel free to close if another implementation comes and looks further along :-)
P.S: Another nice to have for this would be the ability to use other multi modal variants that accept different base models such as Phi and Mistral.