How to implement paged attention in HF format? #616
To use the paged mode (flash-attn only), you first need a cache initialized with a batch size of 1 and a length which is some multiple of the page size. The page size is always 256 with the current version of flash-attn. Essentially, this cache won't have a shape, just a total capacity. PagedParams is constructed like so:

```python
params = ExLlamaV2Attention.PagedParams(
    batch_size = batch_size,
    block_index = block_index,
    cache_seqlens = cache_seqlens,
    max_cache_seqlen = cache_seqlens.max().item(),
    page_size = 256,
    q_len = q_len,
)
```
So say you have three sequences that are currently 10, 1025 and 320 tokens long, respectively, and you want room in the cache for each to grow by 500 tokens. You're forwarding a single token. That could look like:
So when the forward pass writes the keys/values for position 10, it only touches page 0 in the cache. At the same time it will write position 512+1025, which goes to page 6, etc. It's the …

Now, there are some choices you could make about how to get to the above point in the first place. You could do one forward pass with a shape of … Or you could just do each sequence in the batch as a bsz 1 forward pass. This is what the dynamic generator does, and it simplifies things a lot, especially for continuous batching.
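A small sketch of the addressing behind those numbers, assuming the three sequences from the example were given contiguous runs of pages: ceil(510/256) = 2 pages (0 and 1), ceil(1525/256) = 6 pages (2 to 7) and ceil(820/256) = 4 pages (8 to 11), with the rows padded on the right. `physical_slot` is a hypothetical helper written for illustration; it is not a function in exllamav2 or flash-attn.

```python
PAGE_SIZE = 256  # flash-attn page size mentioned in the thread


def physical_slot(block_index, seq, pos, page_size=PAGE_SIZE):
    """Map logical position `pos` of sequence `seq` to (page, flat cache slot).

    Hypothetical helper: the logical page number is pos // page_size, the
    block index row translates it to a physical page, and the offset within
    that page is pos % page_size.
    """
    page = block_index[seq][pos // page_size]
    return page, page * page_size + pos % page_size


# Assumed contiguous layout for sequences of 10, 1025 and 320 tokens,
# each with room for 500 more, padded on the right with zeros.
block_index = [
    [0, 1, 0, 0, 0, 0],
    [2, 3, 4, 5, 6, 7],
    [8, 9, 10, 11, 0, 0],
]
cache_seqlens = [10, 1025, 320]  # the next token is written at this position

for seq, pos in enumerate(cache_seqlens):
    print(seq, physical_slot(block_index, seq, pos))
# 0 (0, 10)
# 1 (6, 1537)
# 2 (9, 2368)
```

Sequence 1 starts at page 2 (flat slot 512), so its position 1025 falls on its fifth page, which is physical page 6, at slot 512 + 1025 = 1537.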
There are a bunch of fun details about paged attention, such as the fact that the page indices don't need to be contiguous. They also don't need to be unique, as long as you're not updating the same page twice in a forward pass. The dynamic generator uses both of those details, for continuous batching and deduplication, respectively. If you didn't want a predefined max_new_tokens, you could allocate pages dynamically during inference. There's nothing that prevents you from adding page 13 after page 1 in the first sequence, or growing the … It does, of course, require some bookkeeping in your generator, and I'm not sure how well that plays together with HF and pipelines and whatnot.
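As a rough illustration of that bookkeeping, here is a minimal allocator sketch, assuming a fixed pool of pages handed out from a free list on demand. `PageAllocator` and its methods are hypothetical; the sketch covers only the non-contiguous, grow-on-demand part and never shares pages between sequences, so it does no deduplication.

```python
PAGE_SIZE = 256


class PageAllocator:
    """Hypothetical free-list page allocator for one shared paged cache."""

    def __init__(self, num_pages):
        self.free = list(range(num_pages))  # physical pages not in use
        self.pages = {}  # sequence id -> its physical pages, in logical order

    def ensure_capacity(self, seq_id, needed_tokens, page_size=PAGE_SIZE):
        """Append free pages to a sequence until it can hold `needed_tokens`."""
        pages = self.pages.setdefault(seq_id, [])
        while len(pages) * page_size < needed_tokens:
            if not self.free:
                raise RuntimeError("paged cache is full")
            # Any free page will do; a sequence's pages need not be contiguous.
            pages.append(self.free.pop(0))
        return pages

    def release(self, seq_id):
        """Return a finished sequence's pages to the free list."""
        self.free.extend(self.pages.pop(seq_id, []))

    def block_index(self, seq_ids, pad_value=0):
        """Build right-padded block index rows for the given sequences."""
        rows = [self.pages[s] for s in seq_ids]
        width = max(len(r) for r in rows)
        return [r + [pad_value] * (width - len(r)) for r in rows]


alloc = PageAllocator(num_pages=16)
alloc.ensure_capacity("a", 300)  # pages [0, 1]
alloc.ensure_capacity("b", 100)  # page [2]
alloc.ensure_capacity("a", 600)  # "a" grows past 512 tokens and gets page 3
print(alloc.block_index(["a", "b"]))  # [[0, 1, 3], [2, 0, 0]]
```

The block index would be rebuilt from `alloc.pages` before each forward pass, so a sequence can keep growing into whatever page happens to be free.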
Okay, I kind of get the concept. I think I want to forward each sequence as a bsz 1 forward pass. Does this mean we have to loop over each sequence in the batch, rather than doing one big batched forward pass? What about the cache instance? Should I make one for each sequence, or just one for all of them? And how does the cache know which sequence is being forwarded with it?
You use one cache for everything, and it's the …

One way to go about it would be to start by tokenizing all the prompts in a batch, then constructing the block index based on how many pages each sequence is going to need, including both the prompt and the completion.
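A minimal sketch of that planning step, assuming pages are handed out contiguously in prompt order and zero is used as the right-padding value. `plan_paged_batch` is a hypothetical helper; it also returns the total cache length, which would be the length of the single batch-size-1 cache and is always a multiple of the page size.

```python
import math

PAGE_SIZE = 256


def plan_paged_batch(prompts, max_new_tokens, page_size=PAGE_SIZE):
    """Hypothetical planner for a batch of tokenized prompts (lists of token IDs).

    Returns the right-padded block index and the total cache length for the
    shared batch-size-1 cache.
    """
    rows, next_page = [], 0
    for ids in prompts:
        # Room for the whole prompt plus the completion.
        n_pages = math.ceil((len(ids) + max_new_tokens) / page_size)
        rows.append(list(range(next_page, next_page + n_pages)))
        next_page += n_pages
    width = max(len(r) for r in rows)
    block_index = [r + [0] * (width - len(r)) for r in rows]
    cache_length = next_page * page_size
    return block_index, cache_length


prompts = [[1] * 10, [1] * 1025, [1] * 320]  # stand-ins for real token IDs
block_index, cache_length = plan_paged_batch(prompts, max_new_tokens=500)
print(block_index)   # [[0, 1, 0, 0, 0, 0], [2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 0, 0]]
print(cache_length)  # 3072
```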
Then you run the three individual forward passes to prefill.
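A sketch of the keyword values each of those bsz 1 prefill passes could receive, reusing the parameter names from the `PagedParams` call at the top of the thread. The values are plain lists here, while the real call expects tensors (flash-attn takes int32 for the block table and sequence lengths). This assumes `cache_seqlens` counts the tokens already in the cache, so it is 0 for a fresh prompt; `prefill_params` is a hypothetical helper.

```python
PAGE_SIZE = 256


def prefill_params(block_index, prompt_lens, page_size=PAGE_SIZE):
    """Yield PagedParams-style keyword values for one bsz 1 prefill per sequence.

    Hypothetical helper: values are plain Python lists and ints for illustration.
    """
    for row, n in zip(block_index, prompt_lens):
        yield {
            "batch_size": 1,
            "block_index": [row],  # only this sequence's pages
            "cache_seqlens": [0],  # assumed: nothing cached yet for this prompt
            "max_cache_seqlen": 0,
            "page_size": page_size,
            "q_len": n,  # the whole prompt in a single pass
        }


block_index = [[0, 1, 0, 0, 0, 0], [2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 0, 0]]
prompt_lens = [10, 1025, 320]
for kwargs in prefill_params(block_index, prompt_lens):
    print(kwargs["block_index"], kwargs["q_len"])
```

After these three passes, each sequence's entry in `cache_seqlens` would equal its prompt length, which is the starting point for batched decoding.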
It doesn't matter if the block index has extra padding on the right, since it's indexed from the left. And then for each token you pass …
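A sketch of a batched decode loop, assuming all sequences are forwarded together with `q_len = 1` and that `cache_seqlens` grows by one after each step. `decode_step_params` is hypothetical, and the actual model call is left out.

```python
PAGE_SIZE = 256


def decode_step_params(block_index, cache_seqlens, page_size=PAGE_SIZE):
    """PagedParams-style keyword values for one batched decode step (hypothetical).

    Every sequence contributes exactly one new token, written at position
    cache_seqlens[i] inside its own pages.
    """
    return {
        "batch_size": len(cache_seqlens),
        "block_index": block_index,
        "cache_seqlens": list(cache_seqlens),
        "max_cache_seqlen": max(cache_seqlens),
        "page_size": page_size,
        "q_len": 1,
    }


block_index = [[0, 1, 0, 0, 0, 0], [2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 0, 0]]
cache_seqlens = [10, 1025, 320]  # lengths after the prefill passes

for _ in range(3):
    kwargs = decode_step_params(block_index, cache_seqlens)
    # The model forward pass with these values is omitted here.
    cache_seqlens = [n + 1 for n in cache_seqlens]  # one more token each

print(cache_seqlens)  # [13, 1028, 323]
```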
I understand, but I have another question: what about the input mask and position offsets? The input mask might already be taken care of, since masking is done inside flash-attn. But what about the position offsets?
You wouldn't use masking or position offsets in paged mode, only a list of sequence lengths; the flash-attn kernel handles the rest. This allows all sequences to start at position zero (as long as that corresponds to a page boundary in the cache, as determined by …)
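To illustrate why no separate offsets are needed, here is a sketch assuming the position of each new token is taken from its sequence's current length: with no left padding, sequence i's new tokens simply continue from `cache_seqlens[i]`. `token_positions` is a hypothetical helper, not exllamav2 code.

```python
def token_positions(cache_seqlens, q_len):
    """Positions of the q_len new tokens for each sequence (hypothetical).

    Every sequence starts at position 0 in its own pages, so the next
    positions follow directly from how many tokens are already cached.
    """
    return [list(range(n, n + q_len)) for n in cache_seqlens]


print(token_positions([10, 1025, 320], q_len=1))  # [[10], [1025], [320]]
print(token_positions([0], q_len=4))  # a fresh prefill: [[0, 1, 2, 3]]
```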
So, I've made exllamav2 work in HF format, and it works well with batches. My code is in #606. Now I have a new problem: a bigger batch means bigger memory usage, mostly spent on padding, especially when the token sequences have different lengths. Could you explain how exllamav2's paged attention works in code? I checked the code in exllamav2/model.py, where PagedParams is used, but I don't know what to pass as its parameters.
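Some rough arithmetic on where the padding cost comes from, under the simplifying assumption that cache memory is proportional to the number of token slots (the per-slot cost, layers times KV heads times head dimension times element size, is the same in both layouts). The lengths are the prompt-plus-completion totals from the three-sequence example in this thread; both helper functions are hypothetical.

```python
import math


def padded_slots(lengths):
    """Cache slots when every sequence is padded to the longest one."""
    return len(lengths) * max(lengths)


def paged_slots(lengths, page_size=256):
    """Cache slots when each sequence only holds the pages it needs."""
    return sum(math.ceil(n / page_size) * page_size for n in lengths)


lengths = [510, 1525, 820]  # 10 + 500, 1025 + 500, 320 + 500
print(padded_slots(lengths), paged_slots(lengths))  # 4575 3072
```

The paged layout only wastes the unused tail of each sequence's last page, while the padded layout wastes the gap to the longest sequence in every row.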