
LlaVA in MLX #461

Merged
merged 35 commits into ml-explore:main on Mar 1, 2024

Conversation

nkasmanoff
Contributor

Hi, please see the following for what I’ve made so far toward converting a LLaVA checkpoint (https://huggingface.co/llava-hf/llava-1.5-7b-hf) into an MLX implementation.

I chose this model because it was easy to load its accompanying HF implementation and compare the architecture and model weights side by side.

There are a few other choices I’ve made to simplify the first pass, such as combining all of the safetensors shards into a single PyTorch file so that I could lean on some prior MLX example implementations. Obviously this could become restrictive as the LLaVA models get bigger and the file size exceeds RAM, but that is one of several issues I wanted to point out in this draft and hopefully get some help fixing.
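For reference, combining the shards might look something like this (a sketch with placeholder paths, not necessarily how this draft does it):

import glob

import torch
from safetensors.torch import load_file

# Load each safetensors shard from the downloaded checkpoint and save the
# merged state dict as a single PyTorch file.
weights = {}
for shard in sorted(glob.glob("llava-1.5-7b-hf/*.safetensors")):
    weights.update(load_file(shard))

torch.save(weights, "llava-1.5-7b-hf/weights.pth")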

There are several TODOs in the repo marking these issues. Most urgently, I am confused about how to reconcile the supposed number of attention heads (32) with what I’m saving from the downloaded weights.

I tried to outline my work in the ‘Local Llava’ notebook (which can be deleted before the official PR); that is where I’d suggest anyone able to help get started.

The other major TODO is figuring out how to implement the forward pass and model.generate so that they take both text and image inputs into account, but I figure that can wait until we’ve confirmed the model is actually loaded correctly.

Improved documentation, README, and tests to come along afterwards.

Thank you in advance to anyone who can help!

Feel free to close this if another implementation comes along that’s further ahead :-)

P.S.: Another nice-to-have would be the ability to use other multimodal variants that build on different base models such as Phi and Mistral.

@mzbac
Contributor

mzbac commented Feb 19, 2024

Nice initiative. I am also attempting to port LLaVA to MLX, but I have not made any progress yet. I am happy to help in any way that I can. Could you clarify the issue you mentioned regarding reconciling the supposed number of attention heads (32) and provide more details?

@nkasmanoff
Contributor Author

@mzbac Thank you! I appreciate any help, or if this gives you a head start on your version of the port. I’ll explain my observation below:

According to the model config, the base LLM is Vicuna 7B v1.5, a fine-tuned version of Llama 2.

From that information, I used the config file of that model to fill out what the base LLM of LlaVa should be:
[screenshot: the base LLM config values]

However, when I do so, I get the following error for loading in the language model weights:

[screenshot: shape-mismatch error when loading the language model weights]

What this tells me is that the pre-trained weights have a shape of 4096 x 4096, whereas our implementation of Llama2 uses the following:

[screenshot: the Llama attention implementation's projection shapes]

Here 4096 (dim size) × 32 (num heads) = 131072, which leads me to believe there is something wrong either with how I saved the pre-trained weights or some mismatch between how Torch saved the attention heads and what is expected here.

If I update the model config to set num_heads = 1, this size mismatch doesn't happen. It's possible that is actually correct, but I haven't gotten anywhere on the forward pass and inference parts of the port, which would quickly show what is and isn't working.

@mzbac
Contributor

mzbac commented Feb 20, 2024

Yeah, there are two different formats for Llama models: the original PyTorch format released by Meta, and the HF format. Quickly checking the model link you shared, it looks like you are using the HF-format one. You should use the Llama model in mlx-lm to load the weights (it can load the HF-format model directly without converting); otherwise there may be issues from improperly loaded config values and a mismatch between the default configuration and the actual weights used by the model.
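To illustrate what the HF layout implies for the shapes above (a toy sketch, not code from this PR or mlx-lm): q_proj is stored as a single (n_heads * head_dim, hidden_size) matrix, i.e. 4096 x 4096 here, and the 32 heads are recovered by reshaping at runtime inside the attention block.

import mlx.core as mx

# Toy illustration of the HF checkpoint layout: one 4096 x 4096 query
# projection, with heads split out of its output by a reshape/transpose.
n_heads, head_dim, hidden_size = 32, 128, 4096
q_proj_weight = mx.zeros((n_heads * head_dim, hidden_size))  # 4096 x 4096
x = mx.zeros((1, 10, hidden_size))                           # (batch, seq, hidden)
queries = x @ q_proj_weight.T                                # (1, 10, 4096)
queries = queries.reshape(1, 10, n_heads, head_dim).transpose(0, 2, 1, 3)
print(queries.shape)  # (1, 32, 10, 128)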

@nkasmanoff
Contributor Author

nkasmanoff commented Feb 22, 2024

@mzbac that makes a ton of sense, thank you for the pointer; it looks like the weights now load properly! A lot of this code will inevitably need to be cleaned up before merging, but next up in my mind is getting the processor and inference working. I'll share any updates I have on that when I can.

Edit:

The vocab size difference is because LLaVA is an extension of the original Vicuna, but with additional tokens. This is a bit confusing to me as well. According to transformers:

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor.tokenizer.vocab_size
# 32000

That agrees with what you are saying, but according to the Hugging Face config, the vocab size is 32064 (https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json#L37).
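One way to inspect where the extra entries come from (a quick check with the transformers tokenizer API; my reading of the gap below is an assumption, not verified here):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(tok.vocab_size)            # 32000: the base Vicuna/Llama vocabulary
print(len(tok))                  # base vocabulary plus the added tokens
print(tok.added_tokens_decoder)  # shows the extra entries, e.g. <image>

My understanding is that the 32064 in config.json is the size of the model's embedding matrix, which is padded beyond the tokenizer's vocabulary (the base 32000 plus added special tokens such as <image>), so the two numbers need not match.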

@nkasmanoff
Contributor Author

@mzbac tracking here: it looks like there's a decent chunk of tasks remaining. I'm not sure when I'll be able to get to them all, so for anyone reading this who is curious, these are (I think) the remaining steps we'd need to take:

For the forward pass:

Once that is done, it should be possible to convert this into actual generation rather than just a single output, i.e. port over model.generate. I haven't looked that code over yet, so I can't say what sub-tasks remain.

And then there's the general housekeeping. Deleting unused classes, finding speedups, etc.

I'd imagine it's also valuable to port over the text and image processor, as was done for CLIP. That's something else that'd be great to add.

@mzbac
Contributor

mzbac commented Feb 22, 2024

Yeah, I am currently working on making the CLIP model compatible with the HF format and updating it to support outputting hidden states. I think I can complete that by the end of this week. If you're okay with it, you can leave that task to me.
FYI: #472

@nkasmanoff
Contributor Author

@mzbac I just looked at your CLIP PR and it looks great! Once that is merged I can sync my fork with main and update the forward pass for llava. I still need to do some more research on how the full forward pass is implemented to make sure nothing else like this "surprises" me; that may take some time, so if you (or anyone reading) also feel comfortable trying your hand at this, that's no problem.

@mzbac
Contributor

mzbac commented Feb 23, 2024

Yeah, I started looking at LLaVA's forward pass and have some understanding of how it works. I will set up some tests to ensure there are no surprises in our implementation and will share them with you once I have the setup.

@mzbac
Contributor

mzbac commented Feb 24, 2024

@nkasmanoff, while I was trying to set up the tests, I noticed that I had to refactor quite a bit of code to be able to run them and load the model. After writing two tests (one for the image features and one for merging image features with input IDs into input embeddings), most of the LLaVA model was complete (the rest is just a normal forward pass for the llama LLM), so I ended up finishing the llava model's forward pass. Here is the complete code (https://github.com/mzbac/mlx-examples/tree/llava/llava); let me know if you would like to sync it back to your branch.
Note: the crucial part of the LLaVA model is merging the image features with the input IDs, which I had to implement using NumPy because MLX doesn't support boolean indexing. However, this operation only occurs once before generation, so I hope the performance hit won't be too bad.

@awni
Member

awni commented Feb 24, 2024

I'm really excited about the progress on this implementation!

Note: The crucial part of the Llava model is merging image features with input IDs, which I had to implement using NumPy due to MLX not supporting boolean indexing.

Could you point me to the code for that? Maybe there is a workaround using MLX still.

@mzbac
Contributor

mzbac commented Feb 24, 2024

Yeah, here is the part that merges the image features with the input embeddings: https://github.com/mzbac/mlx-examples/blob/llava/llava/llava.py#L94-L131.
PS: if you pull down my fork's branch, you can run python generate.py to test the performance. I didn't notice much of a performance impact.

@nkasmanoff
Contributor Author

@mzbac it looks very good to me! I haven't made any major changes on my end, so I think syncing your changes would be ideal. I'm happy to test those changes on my end too.

@nkasmanoff
Contributor Author

@awni I think it is! I saw @mzbac as the reviewer, hence the tag, but please take a look too :-)

llava/README.md (resolved review thread)
self,
inputs: mx.array,
cache=None,
inputs_embeds=None,
Contributor

Most of the implementation is copied from mlx-lm's llama, with only updates made to the forward pass to allow for directly passing inputs_embeds for the initial prompt evaluation.
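A minimal sketch of that pattern (a toy module under assumed names, not the PR's actual llama code):

import mlx.core as mx
import mlx.nn as nn

# The caller can pass precomputed (merged text+image) embeddings for the
# initial prompt; otherwise the usual token embedding table is used.
class ToyLM(nn.Module):
    def __init__(self, vocab_size: int, dims: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dims)
        self.lm_head = nn.Linear(dims, vocab_size)

    def __call__(self, inputs: mx.array, cache=None, inputs_embeds=None):
        h = self.embed_tokens(inputs) if inputs_embeds is None else inputs_embeds
        return self.lm_head(h), cache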

llava/llava.py Outdated
Comment on lines 100 to 150
def _merge_input_ids_with_image_features(
    self, image_features, inputs_embeds, input_ids
):
    image_features = np.array(image_features)
    inputs_embeds = np.array(inputs_embeds)
    input_ids = np.array(input_ids)

    _, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape

    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = np.sum(special_image_token_mask, axis=-1)

    # if no special image tokens are found, log a warning
    if np.all(num_special_image_tokens == 0):
        logging.warning(
            "No special image tokens found in the input. Please make sure to include <image> in your prompt."
        )

    # calculate the final sequence length: the original sequence length plus
    # the number of image-patch positions to be inserted
    final_sequence_length = (
        np.max(num_special_image_tokens) * (num_image_patches - 1)
    ) + sequence_length

    non_image_indices = np.where(input_ids != self.config.image_token_index)

    new_token_positions = (
        np.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), axis=-1)
        - 1
    )
    text_to_overwrite = new_token_positions[non_image_indices]

    final_embedding = np.zeros(
        (batch_size, final_sequence_length, embed_dim), dtype=inputs_embeds.dtype
    )

    final_embedding[non_image_indices[0], text_to_overwrite, :] = inputs_embeds[
        non_image_indices
    ]

    image_to_overwrite = np.all(final_embedding == 0, axis=-1)
    reshaped_image_features = image_features.reshape(-1, embed_dim)
    final_embedding[image_to_overwrite, :] = reshaped_image_features[
        : np.sum(image_to_overwrite)
    ]

    return mx.array(final_embedding)
Contributor

@awni, this is the part where we have to use NumPy as a workaround because MLX doesn't support boolean indexing. Maybe you could give us some pointers on how we can implement this using MLX.

Member

I have a couple questions about how this function is supposed to work:

  • Shouldn't it be an error if the input text does not have the same number of <image> tokens as the image batch size?
  • Can the input text have a batch size other than 1? If so, how does this work with the image batch? Do all the input images need to have the same number of <image> tokens?

I also have a couple thoughts on how to change this to work better for MLX (and possibly be a lot simpler). Assuming the text has a batch size of 1:

  1. Get the positions of the <image> tags and split the text embeddings based on that
  2. Split the image embeddings along the batch axis.
  3. Interleave the two sets and concatenate

In this case it should be fine from an efficiency standpoint, since you typically have only one or two <image> tags, and I think it would be simpler. It would not be the most efficient approach if you had a lot of images and just a few embeddings per image.
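A rough sketch of that split-and-interleave idea (a hypothetical helper, assuming a text batch size of 1, inputs_embeds of shape (1, seq_len, embed_dim), image_features of shape (num_images, num_patches, embed_dim), and one image per <image> token; not the code in this PR):

import mlx.core as mx

def merge_by_interleaving(input_ids, inputs_embeds, image_features, image_token_index):
    # positions of the <image> tokens in the (single) text sequence
    positions = [i for i, t in enumerate(input_ids[0].tolist()) if t == image_token_index]

    segments, start = [], 0
    for img_idx, pos in enumerate(positions):
        # text embeddings up to (but not including) this <image> token
        segments.append(inputs_embeds[:, start:pos])
        # the image patch embeddings take the place of the <image> token
        segments.append(image_features[img_idx : img_idx + 1])
        start = pos + 1
    segments.append(inputs_embeds[:, start:])

    # interleaved text and image segments, concatenated along the sequence axis
    return mx.concatenate(segments, axis=1)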

Contributor

1. You are right. I skipped the error checking for mismatched image tokens and image batches to simplify the implementation, but we could definitely add it (a sketch of such a check follows below). FYI: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L333-L337
2. My understanding is that it should work with batches as long as the total number of images matches the total number of image tokens in the batch, given that the final embedding is batched; however, I have not done any batch testing.
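A small sketch of the kind of check mentioned in point 1 (names are assumptions; it mirrors the spirit of the transformers check linked above rather than copying it):

import numpy as np

def check_image_tokens(input_ids, image_features, image_token_index):
    # number of images provided vs. number of <image> tokens in the prompt(s)
    num_images = image_features.shape[0]
    num_image_tokens = int(np.sum(np.array(input_ids) == image_token_index))
    if num_images != num_image_tokens:
        raise ValueError(
            f"Got {num_images} image(s) but {num_image_tokens} <image> token(s); "
            "they must match."
        )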

Member

Regarding 2, that makes sense! Does the rest of the code work for batched text? E.g., does the generation part properly handle different text examples? If it's just using our MLX LM llama implementation (which it looks like it is), then the answer is probably no. I don't know if that simplifies much, though.

Contributor

Correct me if I am wrong, but I thought the Llama example should be able to handle batched text generation. The only difference is that the initial prompt gets its embedding (text plus image) from the code above and goes directly into the LLM's layers without passing through the embedding layer. Once that forward pass is completed, we have the next-token logits and the key-value cache, and the rest of the process is just a standard forward pass for the LLM.
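To make that two-phase flow concrete, a rough sketch for a single sequence (hypothetical model signature returning logits and a key-value cache; the real code may organize this differently):

import mlx.core as mx

def greedy_generate(model, input_ids, merged_embeds, max_tokens=32):
    # 1. Prompt evaluation: feed the merged text+image embeddings directly,
    #    bypassing the embedding layer, and keep the key-value cache.
    logits, cache = model(input_ids, cache=None, inputs_embeds=merged_embeds)
    y = mx.argmax(logits[:, -1, :], axis=-1)

    tokens = [y.item()]
    for _ in range(max_tokens - 1):
        # 2. Standard decoding: each later step is a plain token-by-token
        #    forward pass that reuses the cache.
        logits, cache = model(y[None], cache=cache)
        y = mx.argmax(logits[:, -1, :], axis=-1)
        tokens.append(y.item())
    return tokens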

Contributor

@awni, I don't have permission to push the changes. Maybe either you or @nkasmanoff could update it.

Member

😬 sorry. Yeah, happy to, or @nkasmanoff can give you temporary permission to push to his fork.

Contributor Author

@mzbac just gave you permission, let me know if that worked!

Contributor

@nkasmanoff @awni I have pushed the changes. Please note that test.py will fail due to the change from fast_gelu to the native gelu, but if you build the latest mlx locally, the test will pass.

Member

Hmm, we should put the version in the requirements in that case. It should work with 0.5 and up.

@awni
Member

awni commented Mar 1, 2024

Also it looks like this is almost ready to go. I will review a bit more tonight, but from what I looked at, it looked quite nice already!

@nkasmanoff
Contributor Author

@awni this is great to hear!

It depends on whether you think any of this is essential to this PR, but here are some ideas for additional features that are missing now and may be useful later on:

  • Adding an MLX-native processor + tokenizer, rather than leaning on transformers.
  • Making it possible to use other variants of LLaVA, or other VLMs.
  • A fine-tuning demo (discussed previously).

My opinion is that we don't need any of these for the first example, but is there anything else to consider?

@mzbac

@awni
Member

awni commented Mar 1, 2024

Thanks for the list @nkasmanoff; I also think we can leave most of it as follow-up, depending on which direction this example goes.

Member

@awni left a comment

🚀

Really great addition, thank you @mzbac and @nkasmanoff!

@awni merged commit a429263 into ml-explore:main on Mar 1, 2024
2 checks passed
@nxphi47

nxphi47 commented Mar 2, 2024

@nkasmanoff @mzbac Great work! Do you plan to support Llava-1.6-hf too?
It requires this change from transformers: huggingface/transformers#29012

@mzbac
Contributor

mzbac commented Mar 4, 2024

@nxphi47 I would like to add support for llava 1.6; however, could you create an issue for that? Comments on a closed PR may get lost.

devonthomas35 pushed a commit to devonthomas35/mlx-examples that referenced this pull request Mar 11, 2024
* add: llava mlx first draft

* add: weights comparision

* add forward pass skeleton

* update: now  imports weights correctly

* delete base

* latest

* adding config

* fix: use config

* add mlx config

* feat: add image processor for llava processor

* wip

* feat: llava working example

* chore: refactor generate script

* chore: clean up

* add: warning to user if no <image> token despite using one

* add: __call__ to LlavaModel

* add: call to LlavaModel

* update fp

* clean up var names

* update: native GeLU

* Cleanup

* update generate and readme

* remove todo comment

* rearrange tests

* fix example code

* nits in README

* update readme

* nit in readme

* nits in README

* chore(llava): refactor image embedding merging logic

* min mlx version

* nits in readmes

* fix cli prompt, some nits

* updates, slight simplify

---------

Co-authored-by: anchen <[email protected]>
Co-authored-by: Awni Hannun <[email protected]>
@jrp2014

jrp2014 commented Mar 23, 2024

1.6 seems to need an image_newline model parameter. It doesn’t seem to be a straight drop-in replacement for 1.5.

@jrp2014

jrp2014 commented Jun 14, 2024

Is v1.6 support abandoned, or could some of the newer MLX features help get it up and running?

@awni
Member

awni commented Jun 15, 2024

I think you can run it in MLX VLM, CC @Blaizzy

@Blaizzy
Contributor

Blaizzy commented Jun 15, 2024

Thanks @awni!

Llava 1.6 (Next) is coming to MLX VLM soon, alongside a trainer.

We currently support:

  • Llava 1.5
  • Deepseek VL
  • Idefics 2
  • Paligemma
  • NanoLlava
  • Phi-3-vision (almost done)
