Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds batched inference with left-padding #886

Closed
wants to merge 24 commits into from

Conversation

FlimFlamm
Copy link

@FlimFlamm FlimFlamm commented Jan 17, 2024

Adds a left-padding batched inference strategy by modifying generate/base.py and model.py

  • The only API change was to rename prompt to prompts in generate.py's main() function; it's still compatible with a string, and now with a list of strings
  • Under the right conditions (eg: batch size 64 and max_tokens=200 (something that wont overflow my 4090 at that batch size), and I can get just about 2000 tokens/second on mistral 7b. On stablem 3B i can get 6000 tokens per second (with batch size in the hundreds)
  • I added some ultra-light formatting and color to the decoding and printing of the generated resuts. Color/light is lit! (shown in vid)
2024-01-17.19-35-56.mp4

EDIT: currently triage'ing the test fails

Constructive feedback is very welcome. If something about this commit would adversely affect other parts of the repo that I have overlooked, I'll do my best to address it.

@FlimFlamm
Copy link
Author

I think the tests might need some tweaking (otherwise i might have broken them with a few to many pushes :D)

Will leave this as is for review, and will be back ASAP to carry out any fixes that might be desirable or necessary.

@carmocca
Copy link
Contributor

Hi @FlimFlamm! Thanks for working on this.

I had this partially implemented but never pushed it and I might have lost it because I cannot find it in my stashes 💀.

I'll sit on this for a bit and perhaps merge what you have with what I had. This will need tests and some performance benchmarking before landing.

@FlimFlamm
Copy link
Author

FlimFlamm commented Jan 19, 2024

Hi @FlimFlamm! Thanks for working on this.

I had this partially implemented but never pushed it and I might have lost it because I cannot find it in my stashes 💀.

I'll sit on this for a bit and perhaps merge what you have with what I had. This will need tests and some performance benchmarking before landing.

Happy to be of help!

For a relatively small change, it definitely affects computation in a lot of scripts (anything that touches generate), which includes things like generate/lora.py, generate/adapter.py, etc, so this is definitely one for careful review. Here are some notes/considerations I have so far in hindsight:

  • It's important to ensure that the ROPE application doesn't need to be modified for the left-padding case; left padding changes the way the rope applies by default (in the case of left padding in a given sequence, the position 1 of the ROPE wont be on the BOS token of the sequence). Outputs look good despite this, but they might be getting harmed by the difference, and some models might be more brittle in this situation.

  • A second important consideration is whether or not any extra masking needs to occur. (masking the left padding tokens). My implementation pads with 0's; originally I was masking them out inside of model.py's forward pass (modifying the mask that is loaded from the cache according to current left padding), but it didn't seem to result in any difference in outputs, so i removed it in the interest of not modifying model.py. Possibly the models I was testing with effectively ignore 0 tokens due to training/config dynamics, but other models might have issues. (since not all models have an explicit padding token, I'm not sure how this applies to all cases all cases )

  • Right padding might be a better alternative assuming the ROPE issue with left padding leads to complications for model.py. I originally tried implementing right padding but decided to do left padding because it seemed to require less alterations, but this might not be true if there's a ROPE issue.

  • I did some testing with generate/lora/adapter/sequentially, and the same left padding logic seems to work without issue (just requires the modifications found in generate-main(). Was able to do multi-gpu batched inference with mixtral!

I'm gonna tinker and test some more (hopefully to see if right padding an be more easily cinched in in case that turns out to be important for model performance)

@FlimFlamm
Copy link
Author

FlimFlamm commented Jan 21, 2024

Pushed some additions and changes that seemed sensical or cleaner. Made a simple padding function for utils (can do left and right padding), and set up the mask cache to optionally take a padding mask (can be passed when the kv cache is being set, or directly to build_mask_cache).

Also set up the same logic in sequentially.py for testing (seems to work great).

Finally I also added optional attention masking for the forward pass of model.py's GPT class (which isn't required, but seems like it would be useful for anyone using special masking)

NOTE: the masking strategy that bakes a batches left/right padding into the mask cache results in the mask cache being increased by a factor of batch_size (since we need unique padding inside each sequence's mask), but by doing so we dont have to do any tensor work during generation. In theory if the max sequence length explodes, this strategy loses its edge (because it quadratically scales the auto-regressive mask itself), which might make the batch_size factor start to hurt.

@WilliamGazeley
Copy link

Thanks for working on this @FlimFlamm, I was working on this functionality on my fork as well but the kv cache issue is a tricky one.

I cloned your repo and tried to run generation on stablelm and TinyLlama, but both produced outputs that were jibberish. I didn't make any changes to your code, any idea what could be going on?

@FlimFlamm
Copy link
Author

Thanks for working on this @FlimFlamm, I was working on this functionality on my fork as well but the kv cache issue is a tricky one.

I cloned your repo and tried to run generation on stablelm and TinyLlama, but both produced outputs that were jibberish. I didn't make any changes to your code, any idea what could be going on?

Can I ask exactly what method or CLI arg you used to test? Will try to reproduce and see if i can find the issue.

@WilliamGazeley
Copy link

WilliamGazeley commented Jan 25, 2024

I just did the following:

python scripts/download.py --repo_id 'TinyLlama/TinyLlama-1.1B-Chat-v1.0' --from_safetensors 1
python scripts/convert_hf_checkpoint.py --checkpoint_dir 'checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0'
python generate/base.py --checkpoint_dir 'checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0'

I set prompts = ["what food do llamas eat?"] and I get outputs that keep repeating words.

@FlimFlamm
Copy link
Author

FlimFlamm commented Jan 25, 2024

I just did the following:

python scripts/download.py --repo_id 'TinyLlama/TinyLlama-1.1B-Chat-v1.0' --from_safetensors 1
python scripts/convert_hf_checkpoint.py --checkpoint_dir 'checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0'
python generate/base.py --checkpoint_dir 'checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0'

I set prompts = ["what food do llamas eat?"] and I get outputs that keep repeating words.

Awesome, thanks for the details.

Editing...

@FlimFlamm
Copy link
Author

FlimFlamm commented Jan 26, 2024

I just did the following:

python scripts/download.py --repo_id 'TinyLlama/TinyLlama-1.1B-Chat-v1.0' --from_safetensors 1
python scripts/convert_hf_checkpoint.py --checkpoint_dir 'checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0'
python generate/base.py --checkpoint_dir 'checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0'

I set prompts = ["what food do llamas eat?"] and I get outputs that keep repeating words.

So I found the problem; I was building my mask incorrectly. recent push should have the replacement build_mask_cache() function. Only other necessary change was to use right padding instead of left padding, because having corrected the mask i started re-encountering the NaN issue described here pytorch/pytorch#103749

The exact cause of the problem are cases where an entire line of the causal attention mask is "False", which screws with the dot product attention. The fix is apparently common in a lot of repos. Ours would be something like:

    def scaled_dot_product_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        scale = 1.0 / math.sqrt(self.config.head_size)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=(1.0 - mask.to(dtype=q.dtype)) * -10000.0, scale=scale, is_causal=mask is None
        )
        return y.transpose(1, 2)

Instead of doing that, I just switched base.py to right padding, but the more i test, the more it looks like the above is a correct way to address the all False problem.

Another legitimate fix for this particular model seems to just be using the padding token that is assigned, and not using a padding mask at all. Whether or not the model itself defines a padding token might be an indicator that no extra masking is required for the padding...

Do let me know if the last push makes batched inference work on your end!

@FlimFlamm
Copy link
Author

FlimFlamm commented Jan 26, 2024

Applying the large-negative number fix seems to have done the trick; left and right padding now are both equivalent in terms of output for the tinyllama chat modell.

@WilliamGazeley
Copy link

WilliamGazeley commented Jan 26, 2024

Running benchmarks on TinyLlama using the original generate code vs your batch implemetation yields (almost) identical scores now. Well done!

Update
Spoke too soon.. when batching 10 prompts at a time, the score is close to the original code, however batching 25 results in a 10% drop in hellaswag perf.

@FlimFlamm
Copy link
Author

FlimFlamm commented Jan 26, 2024

Running benchmarks on TinyLlama using the original generate code vs your batch implemetation yields (almost) identical scores now. Well done!

Update Spoke too soon.. when batching 10 prompts at a time, the score is close to the original code, however batching 25 results in a 10% drop in hellaswag perf.

Very interesting. A few questions/requests that might help me replicate/track this issue down:

  1. Are all your prompts in a given batch the same? (If so, then i can narrow down some potential causes)

    1.1) If not, try turning off the padding mask (just dont pass it into set_kv) (EoS is 0 according to the generation config).

    1.2) Try using left and right padding (change the parameter from "left" to "right" in base.py where pad_batched_tokens is called)

  2. Are you using a particular prompt formatting, like <|user|> and <|assistant|>? (afaik this is what tinyllama 1b chat is trained for)

  3. Do the results of prompts in a batch's index 0 change for you compared to non-batched inference or the old generate code? (i dont think they should, but am wondering if performance degrades uniformly for all sequences in a batch)

I also wonder if the large-negative-number fix might not be ideally implemented here; im only using negative 10k (most implementations used torch.finfo(dtype).min, but this still resulted in NaN's for me (although maybe i omitted an additional related change))...

Possibly this performance hit is a consequence of batched inference in and of itself? Can't find much about it but maybe?

@WilliamGazeley
Copy link

WilliamGazeley commented Jan 27, 2024

  1. Promts within a batch are different, but the prompts passed to both runs are the same (i.e. single and batched receive the same prompts)
    1.1) Turning off the padding mask destroys the benchmark performance
    1.2) Right padding also results in bad performance
    1.3) Right padding + no mask results in bad performance (expected, but being thorough)

  2. You're right, I'm not using the correct format, but this shouldn't matter because both single and batch generation are not using the format - scores should be the same

  3. The inputs for single (batch_size = 1) and batched (batch_size > 1) are identical, but oddly the outputs are not. Temperature is 0 and random was seeded.

This branch and the upstream are starting to diverge, I'm going to copy your changes into my fork that's up-to-date and continue to dig around.

What are you getting on your end, is doing 10 prompts in a batch the same as the same 10 prompts one at a time?

@FlimFlamm
Copy link
Author

@WilliamGazeley Thanks for the effort on this!

What are you getting on your end, is doing 10 prompts in a batch the same as the same 10 prompts one at a time?

Batched inference does give different outputs for each sequence, which I think is by design. The good news is that the first sequence in the batch is the same as our single unbatched case along with original generate/mask code.

You're right, I'm not using the correct format, but this shouldn't matter because both single and batch generation are not using
the format - scores should be the same

I agree, although assuming there is some small unavoidable performance loss in batched inference cases, I was thinking that an input being more out of its training distribution could amplify the performance degradation. (since at 1B this model is relatively brittle, perhaps that also magnifies the issues we're seeing re: performance)

I'll keep poking at it as well to see what I can come up with (will fire up hellaswag soon as i can top replicate your findings and start hunting from there).

@WilliamGazeley
Copy link

WilliamGazeley commented Jan 29, 2024

Playing around further, I noticed that there's a huge difference in outputs if you change between bf16 and 16-true. This is somewhat expected I guess, but the batched 16-true is closer to the single bf16 than the batched bf16 is to the single bf16 - this is only on my benchmark though.

Also, I think your implementation of scaled_dot_product_attention() breaks training scripts that do not pass input_pos when generating logits. I've updated the function to:

def scaled_dot_product_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        scale = 1.0 / math.sqrt(self.config.head_size)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, 
            k, 
            v, 
            attn_mask=None if mask is None else (1.0 - mask.to(dtype=q.dtype)) * -10000.0, 
            dropout_p=0.0, 
            scale=scale, 
            is_causal=mask is None
        )
        return y.transpose(1, 2)

@Andrei-Aksionov
Copy link
Collaborator

Closing as it was added in #1702

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants