From e3622369d5065b5d65d272c14db372eef4fe49f4 Mon Sep 17 00:00:00 2001 From: Jefferson Fialho Date: Thu, 19 Dec 2024 12:09:17 -0300 Subject: [PATCH] Squash 10235 Signed-off-by: Jefferson Fialho --- .../basic_correctness/test_chunked_prefill.py | 30 +- tests/core/test_chunked_prefill_scheduler.py | 298 ++++++++++++- vllm/config.py | 15 + vllm/core/scheduler.py | 400 +++++++++++++----- vllm/engine/arg_utils.py | 34 +- vllm/model_executor/layers/sampler.py | 8 +- 6 files changed, 654 insertions(+), 131 deletions(-) diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 469d18a4dd7af..5f90c52481793 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -7,7 +7,6 @@ Run `pytest tests/models/test_chunked_prefill.py`. """ import os -from contextlib import nullcontext import pytest @@ -232,7 +231,6 @@ def test_with_prefix_caching( max_num_batched_tokens = max_num_seqs = chunk_size outputs = {} # type: ignore - check_result = True for enable in (True, False): with vllm_runner( model, @@ -244,25 +242,17 @@ def test_with_prefix_caching( enforce_eager=enforce_eager, max_num_seqs=max_num_seqs, ) as vllm_model: - # It should fail when prefix caching is enable and chunk - # size is not a multiple of block size (16). - should_fail = chunk_size % 16 != 0 and enable - check_result &= not should_fail outputs[enable] = [] - # Send the request one-by-one to ensure the cache is populated. - with pytest.raises(ValueError) if should_fail else nullcontext(): - for prompt in full_prompts: - outputs[enable] += vllm_model.generate_greedy([prompt], - max_tokens) - - # Check results only if we did not expect a failure. - if check_result: - check_outputs_equal( - outputs_0_lst=outputs[False], - outputs_1_lst=outputs[True], - name_0="w/o prefix caching", - name_1="with prefix caching", - ) + for prompt in full_prompts: + outputs[enable] += vllm_model.generate_greedy([prompt], + max_tokens) + + check_outputs_equal( + outputs_0_lst=outputs[False], + outputs_1_lst=outputs[True], + name_0="w/o prefix caching", + name_1="with prefix caching", + ) @pytest.mark.parametrize("model", ["facebook/opt-125m"]) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index eaaf004df38b2..71a203ec8db2a 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -5,6 +5,9 @@ from vllm.config import CacheConfig, SchedulerConfig from vllm.core.scheduler import Scheduler +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob, SequenceGroup from .utils import create_dummy_prompt @@ -14,7 +17,7 @@ def get_sequence_groups(scheduler_output): return [s.seq_group for s in scheduler_output.scheduled_seq_groups] -def append_new_token(seq_group, token_id: int): +def append_new_token(seq_group: SequenceGroup, token_id: int): for seq in seq_group.get_seqs(): seq.append_token_id(token_id, {token_id: Logprob(token_id)}) @@ -121,6 +124,214 @@ def test_chunk(): assert out.num_batched_tokens == 57 +def test_concurrent_chunking(): + """Verify prefills are chunked properly when + --max-num-partial-prefills is > 1""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + 
enable_chunked_prefill=True, + max_num_partial_prefills=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Verify both requests are chunked with half of max_num_batched_tokens each + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 32 + assert seq_group_meta[1].token_chunk_size == 32 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + + # After one iteration, both should have 60 - 32 = 28 tokens left to prefill + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 28 + assert seq_group_meta[1].token_chunk_size == 28 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 56 + + +def test_concurrent_chunking_large_requests(): + """Verify large prefill requests are run one at a time""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests + cache_config.num_gpu_blocks = 3200 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=1200, # Very large prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + + # Verify only a single request is chunked, and it gets all 64 tokens + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 1 + assert seq_group_meta[0].token_chunk_size == 64 + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 64 + + +def test_short_prompts_jump_long_prompts_in_queue(): + """Verify large prefill requests are punted behind smaller ones if + another large prefill request is already running""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests + cache_config.num_gpu_blocks = 3200 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add 2 large seq groups to scheduler. 
+ for i in range(2): + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=1200, # Very large prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + + # Add 2 small seq groups behind them + for i in range(2): + _, seq_group = create_dummy_prompt( + str(i + 2), + prompt_length=40, # Very small prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + + # Verify one large req and 1 small req chunked + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens + assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens + + # all 4 are prefilling + assert running[0].is_prefill() + assert running[1].is_prefill() + assert running[2].is_prefill() + assert running[3].is_prefill() + + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + + # in the second iteration, + # the first small request had only 8 tokens left + # so it went to decode + # The other small req is scheduled + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + # the new small req got 64 - (32+8) tokens + assert (seq_group_meta[0].token_chunk_size == 24) + assert seq_group_meta[1].token_chunk_size == 32 # large req still got 32 + # the other small request had only 8 tokens left + assert seq_group_meta[2].token_chunk_size == 8 # 40-32 + + # notice the small request got to decode now + # this is because of max_num_partial_prefills logic + assert running[0].is_prefill() + assert running[1].is_prefill() + assert not running[2].is_prefill() + assert running[3].is_prefill() + + assert out.num_prefill_groups == 3 + assert out.num_batched_tokens == 64 + # the small seq group has a new token appended. + append_new_token(running[2], 1) + + # in the third iteration, + # the first small request has entered decode + # and other small req had 16 tokens left + # so it went to decode + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 32 # large still got 32 + # small req prefilled 40-24=16 tokens + assert (seq_group_meta[1].token_chunk_size == 16) + assert seq_group_meta[2].token_chunk_size == 1 # decode + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 49 # (32+16+1 decode) + + # both small requests have now reached decode + assert running[0].is_prefill() + assert running[1].is_prefill() + assert not running[2].is_prefill() + assert not running[3].is_prefill() + + # the small seq group has a new token appended. 
+ append_new_token(running[2], 1) + + # in the fourth iteration, both small requests are decoding + # so large request gets all the budget + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + # large req gets 63 tokens (minus 1 for decode) + assert seq_group_meta[0].token_chunk_size == 63 + assert seq_group_meta[1].token_chunk_size == 1 # decode + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 64 + + assert running[0].is_prefill() + assert running[1].is_prefill() + assert not running[2].is_prefill() + assert not running[3].is_prefill() + + # both the small seq groups have a new token appended + append_new_token(running[2], 1) + append_new_token(running[3], 1) + + # in the fifth iteration, large request gets all the budget + # while both small requests are decoding + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 62 + assert seq_group_meta[1].token_chunk_size == 1 # decode + assert seq_group_meta[2].token_chunk_size == 1 # decode + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 64 + + def test_complex(): block_size = 4 max_seqs = 60 @@ -506,7 +717,7 @@ def test_chunked_prefill_max_seqs(): assert not running[1].is_prefill() -def test_perfix_caching(): +def test_prefix_caching(): """Verify allocating full blocks when prefix caching is enabled.""" block_size = 4 max_seqs = 10 @@ -546,3 +757,86 @@ def test_perfix_caching(): assert seq_group_meta[1].token_chunk_size == 12 assert out.num_prefill_groups == 2 assert out.num_batched_tokens == 62 + + +def test_prefix_caching_with_concurrent_partial_prefills(): + """Verify allocating full blocks when prefix caching is enabled with + --max-num-partial-prefills > 1.""" + block_size = 4 + max_seqs = 10 + max_model_len = 8000 + max_num_batched_tokens = 60 # With two slots, each slot will get 30 tokens + scheduler_config = SchedulerConfig("generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2) + cache_config = CacheConfig(block_size, + 1.0, + 1, + "auto", + enable_prefix_caching=True) + cache_config.num_cpu_blocks = 0 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + block_size=block_size, + prompt_length=50) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + # To partially prefill both sequences, both can chunk up to 30 tokens + # But the next lowest multiple of the block size (4) is 28 + assert seq_group_meta[0].token_chunk_size == 28 + assert seq_group_meta[1].token_chunk_size == 28 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 56 + + # On the next iteration, both sequences should finish prefill + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + # Both sequences have 50 - 28 = 22 tokens left to prefill. 
+ # This is not a multiple of the block size, but we don't care since we don't + # cache the final partial block of prefix sequences + assert seq_group_meta[0].token_chunk_size == 22 + assert seq_group_meta[1].token_chunk_size == 22 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 44 + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8]) +def test_chunked_prefill_with_actual_engine(model: str, + max_num_partial_prefills: int): + """Make sure the model can actually sample with concurrent + partial prefills + """ + + prompt = "hello" * 40 + + engine_args = EngineArgs( + model=model, + max_num_partial_prefills=max_num_partial_prefills, + max_num_batched_tokens=40, + max_num_seqs=8, + enable_chunked_prefill=True, + gpu_memory_utilization=0.8, + ) + + engine = LLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams(temperature=0) + + for req_num in range(max_num_partial_prefills): + engine.add_request(f"{req_num}", prompt, sampling_params) + # first step + request_outputs = engine.step() + # means all are prefilling + assert len(request_outputs) == 0 + assert len(engine.scheduler[0].running) == max_num_partial_prefills diff --git a/vllm/config.py b/vllm/config.py index 0e886e18fcd6d..10adc38d70c22 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1322,6 +1322,21 @@ class SchedulerConfig: # Maximum length of a sequence (including prompt and generated text). max_model_len: int = 8192 + # Maximum number of sequences that can be partially prefilled concurrently + max_num_partial_prefills: int = 1 + + # Maximum number of "very long prompt" sequences that can be prefilled + # concurrently (long is defined by long_prefill_threshold) + max_long_partial_prefills: int = 1 + + # Set a percentage of the context length that determines which + # sequences are considered "long" + long_prefill_threshold: float = 0.04 + + # calculate context length that determines which sequences are + # considered "long" + long_prefill_token_threshold = int(max_model_len * long_prefill_threshold) + # The number of slots to allocate per sequence per # step, beyond the known token ids. This is used in speculative # decoding to store KV activations of tokens which may or may not be diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3bc6becf0995..02220153f6cc6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -15,7 +15,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, - SequenceStatus) + SequenceStage, SequenceStatus) from vllm.utils import Device, PyObjectCache logger = init_logger(__name__) @@ -37,6 +37,7 @@ class PreemptionMode(enum.Enum): recompute them when the sequences are resumed, treating the sequences as new prompts. """ + SWAP = enum.auto() RECOMPUTE = enum.auto() @@ -52,6 +53,7 @@ class SchedulingBudget: happen if we only have chunked prefill scheduling, we can remove this feature from the API when chunked prefill is enabled by default. """ + token_budget: int max_num_seqs: int _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) @@ -130,6 +132,7 @@ class ScheduledSequenceGroup: @dataclass class SchedulerOutputs: """The scheduling decision made from a scheduler.""" + # Scheduled sequence groups. scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup] # Number of prefill groups scheduled. 
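
The new SchedulerConfig fields introduced in vllm/config.py above combine into a single token cutoff that classifies a prefill as "long". The following is a minimal illustrative sketch of that relationship, not code from the patch; the numbers assume the defaults shown in the diff (max_model_len = 8192, long_prefill_threshold = 0.04).

```python
# Sketch only: how the long-prompt cutoff follows from the new config fields.
max_model_len = 8192            # SchedulerConfig default
long_prefill_threshold = 0.04   # fraction of the context length

# Prefills that still need more than this many tokens count as "long".
long_prefill_token_threshold = int(max_model_len * long_prefill_threshold)
print(long_prefill_token_threshold)  # 327
```
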
@@ -203,6 +206,7 @@ class SchedulerRunningOutputs: Could contain prefill (prefill that's chunked) or decodes. If there's not enough memory, it can be preempted (for recompute) or swapped out. """ + # Selected sequences that are running and in a decoding phase. decode_seq_groups: List[ScheduledSequenceGroup] # Selected sequences that are running and in a prefill phase. @@ -244,6 +248,7 @@ class SchedulerSwappedInOutputs: Could contain prefill (prefill that's chunked) or decodes. """ + # Selected sequences that are going to be swapped in and is in a # decoding phase. decode_seq_groups: List[ScheduledSequenceGroup] @@ -278,6 +283,7 @@ class SchedulerPrefillOutputs: Could contain a fresh prefill requests or preempted requests that need to be recomputed from scratch. """ + # Selected sequences for prefill. seq_groups: List[ScheduledSequenceGroup] # Ignored sequence groups. @@ -294,23 +300,27 @@ def create_empty(cls) -> "SchedulerPrefillOutputs": def seq_group_metadata_builder(): - return SequenceGroupMetadata(request_id="", - is_prompt=False, - seq_data={}, - sampling_params=None, - block_tables={}) + return SequenceGroupMetadata( + request_id="", + is_prompt=False, + seq_data={}, + sampling_params=None, + block_tables={}, + ) def scheduler_running_outputs_builder(): - return SchedulerRunningOutputs(decode_seq_groups=[], - prefill_seq_groups=[], - preempted=[], - swapped_out=[], - blocks_to_swap_out=[], - blocks_to_copy=[], - num_lookahead_slots=0, - prefill_seq_groups_list=[], - decode_seq_groups_list=[]) + return SchedulerRunningOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + prefill_seq_groups_list=[], + decode_seq_groups_list=[], + ) def scheduled_seq_group_builder(): @@ -319,6 +329,99 @@ def scheduled_seq_group_builder(): # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) +@dataclass +class PartialPrefillMetadata: + """Holds information about the partial prefills that are currently running + during a single iteration of the Scheduler. + When chunked prefill is enabled, we allow a certain number of seqs to be + partially prefilled during each iteration. Having multiple partial prefills + in flight allows us to minimize TTFT and avoid decode starvation in cases + where a single sequence group with a very large prompt blocks the queue for + too many iterations. + The number of long prefill requests is limited so that smaller + requests may jump the queue in front of them and get to the decode + phase faster. 
+ """ + + # A minimum bound on the total number of prefills running during this + # scheduling step + partial_prefills: int + + # The number of long prefill requests currently running + long_partial_prefills: int + + scheduler_config: SchedulerConfig + + def cannot_schedule(self, seq_group: SequenceGroup) -> bool: + """When concurrent partial prefills are enabled, + we limit the number of long requests and only accept + shorter requests from the queue while running them + concurrently""" + return (seq_group.first_seq.get_num_new_tokens() > + self.scheduler_config.long_prefill_token_threshold + and self.long_partial_prefills >= + self.scheduler_config.max_long_partial_prefills + and self.scheduler_config.max_num_partial_prefills > 1) + + def increment_partial_prefills(self, seq_group: SequenceGroup) -> None: + # When a new prefill is scheduled, we need to know if it is a + # long request + if (seq_group.first_seq.get_num_new_tokens() > + self.scheduler_config.long_prefill_token_threshold): + self.long_partial_prefills += 1 + + @classmethod + def from_queues( + cls, + running: Deque[SequenceGroup], + waiting: Deque[SequenceGroup], + scheduler_config: SchedulerConfig, + ) -> "PartialPrefillMetadata": + """Create a PartialPrefillMetadata object from the current state of + the scheduler's queues. + This accounts for the currently running prefill requests, and peeks into + the waiting queue to see if there are more prefills to potentially be + scheduled during this iteration.""" + partial_prefills = 0 + long_partial_prefills = 0 + + waiting_partial_prefills = 0 + waiting_long_prefills = 0 + + for sg in running: + # TODO: Check if this stage is correctly updated before scheduling + if sg.first_seq.data.stage == SequenceStage.PREFILL: + partial_prefills += 1 + if (sg.first_seq.get_num_new_tokens() > + scheduler_config.long_prefill_token_threshold): + long_partial_prefills += 1 + + for sg in waiting: + # Don't bother looping through the rest of the queue if we know + # there are already at + # least max_partial_prefills requests to fill + if (partial_prefills + waiting_partial_prefills >= + scheduler_config.max_num_partial_prefills): + break + + # Don't count long requests from the waiting queue if we aren't + # going to schedule them anyway + if (sg.first_seq.get_num_new_tokens() > + scheduler_config.long_prefill_token_threshold): + if (long_partial_prefills + waiting_long_prefills >= + scheduler_config.max_long_partial_prefills): + continue + waiting_long_prefills += 1 + waiting_partial_prefills += 1 + + return PartialPrefillMetadata( + partial_prefills=min(partial_prefills + waiting_partial_prefills, + scheduler_config.max_num_partial_prefills), + long_partial_prefills=long_partial_prefills, + scheduler_config=scheduler_config, + ) + + class Scheduler: def __init__( @@ -358,7 +461,8 @@ def __init__( num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching) + enable_caching=self.cache_config.enable_prefix_caching, + ) # Sequence groups in the WAITING state. # Contain new prefill or preempted requests. @@ -419,6 +523,18 @@ def __init__( # for processing and deallocation by the free_finished_seq_groups() self._async_stopped: List[SequenceGroup] = [] + # List with the chunk sizes to hand out to each sequence depending + # on how many partial prefills are running. This is slightly faster than + # running an integer division every time a prefill is scheduled. 
+ # This splits the budget evenly among all prefills. + self.partial_prefill_budget_lookup_list = [0] * ( + self.scheduler_config.max_num_partial_prefills + 1) + self.partial_prefill_budget_lookup_list[0] = ( + scheduler_config.max_num_batched_tokens) + for i in range(1, self.scheduler_config.max_num_partial_prefills + 1): + self.partial_prefill_budget_lookup_list[i] = ( + scheduler_config.max_num_batched_tokens // i) + @property def next_cache_id(self): return (self.cache_id + 1) % self.num_cache_iters @@ -498,8 +614,8 @@ def _free_seq_group_cross_attn_blocks( self.block_manager.free_cross(seq_group) def has_unfinished_seqs(self) -> bool: - return len(self.waiting) != 0 or len(self.running) != 0 or len( - self.swapped) != 0 + return (len(self.waiting) != 0 or len(self.running) != 0 + or len(self.swapped) != 0) def get_prefix_cache_hit_rate(self, device: Device) -> float: return self.block_manager.get_prefix_cache_hit_rate(device) @@ -518,6 +634,7 @@ def _schedule_running( budget: SchedulingBudget, curr_loras: Optional[Set[int]], enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, ) -> SchedulerRunningOutputs: """Schedule sequence groups that are running. @@ -532,12 +649,14 @@ def _schedule_running( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. - + partial_prefill_metadata: information about the partial prefills + that are currently running + Returns: SchedulerRunningOutputs. """ - ret: SchedulerRunningOutputs = \ - self._scheduler_running_outputs_cache[self.cache_id].get_object() + ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[ + self.cache_id].get_object() ret.blocks_to_swap_out.clear() ret.blocks_to_copy.clear() ret.decode_seq_groups.clear() @@ -572,10 +691,14 @@ def _schedule_running( # 2. If a sequence is running with non-chunked prefill, then # there it's a decoding sequence, and the cached tokens info is # irrelevant. - num_uncached_new_tokens, _ = ( + num_uncached_new_tokens, _ = \ self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.RUNNING, enable_chunking, - budget)) + seq_group, + SequenceStatus.RUNNING, + enable_chunking, + budget, + partial_prefill_metadata, + ) num_running_tokens = num_uncached_new_tokens if num_running_tokens == 0: @@ -588,8 +711,8 @@ def _schedule_running( # to process the final tokens. The check below avoids this extra # decode run when the model max len is reached, in order to avoid # a memory overflow. 
- if self.use_async_output_proc and seq_group.seqs[0].get_len( - ) > self.scheduler_config.max_model_len: + if (self.use_async_output_proc and seq_group.seqs[0].get_len() > + self.scheduler_config.max_model_len): self._async_stopped.append(seq_group) continue @@ -648,8 +771,9 @@ def _schedule_running( self._append_slots(seq_group, blocks_to_copy, enable_chunking) is_prefill = seq_group.is_prefill() - scheduled_seq_group: ScheduledSequenceGroup = \ - self._scheduled_seq_group_cache[self.cache_id].get_object() + scheduled_seq_group: ScheduledSequenceGroup = ( + self._scheduled_seq_group_cache[ + self.cache_id].get_object()) scheduled_seq_group.seq_group = seq_group if is_prefill: scheduled_seq_group.token_chunk_size = num_running_tokens @@ -726,7 +850,8 @@ def _schedule_swapped( logger.warning( "Failing the request %s because there's not enough kv " "cache blocks to run the entire sequence.", - seq_group.request_id) + seq_group.request_id, + ) for seq in seq_group.get_seqs(): seq.status = SequenceStatus.FINISHED_IGNORED infeasible_seq_groups.append(seq_group) @@ -796,16 +921,17 @@ def _schedule_swapped( ) def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: - if self.scheduler_config.chunked_prefill_enabled and \ - not self.scheduler_config.is_multi_step: + if (self.scheduler_config.chunked_prefill_enabled + and not self.scheduler_config.is_multi_step): prompt_limit = self.scheduler_config.max_model_len else: - prompt_limit = min(self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) + prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) # Model is fine tuned with long context. Return the fine tuned max_len. - if (seq_group.lora_request - and seq_group.lora_request.long_lora_max_len): + if seq_group.lora_request and seq_group.lora_request.long_lora_max_len: assert prompt_limit <= seq_group.lora_request.long_lora_max_len return seq_group.lora_request.long_lora_max_len else: @@ -813,7 +939,7 @@ def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: def _get_priority(self, seq_group: SequenceGroup) -> Tuple[Optional[int], float]: - """ Get the priority of the sequence group. + """Get the priority of the sequence group. Highest preference to user-defined priority, followed by arrival time. Args: seq_group: The sequence group input. 
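
For reference, the per-slot prefill budget consumed during the scheduling passes above comes from the partial_prefill_budget_lookup_list built in Scheduler.__init__ earlier in this patch: it divides max_num_batched_tokens evenly by the number of prefills in flight. A standalone sketch using the values from test_concurrent_chunking (64 batched tokens, up to 2 partial prefills); the variable names here are illustrative.

```python
# Sketch only: precomputed per-slot prefill budgets, indexed by the number of
# concurrent partial prefills.
max_num_batched_tokens = 64
max_num_partial_prefills = 2

lookup = [0] * (max_num_partial_prefills + 1)
lookup[0] = max_num_batched_tokens
for i in range(1, max_num_partial_prefills + 1):
    lookup[i] = max_num_batched_tokens // i

print(lookup)  # [64, 64, 32]: with two prefills in flight, each slot gets 32 tokens
```

This is why both 60-token prompts in that test are chunked to 32 tokens on the first iteration.
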
@@ -846,14 +972,14 @@ def _schedule_priority_preemption( if waiting_queue: seq_group = waiting_queue.popleft() num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens_uncached, _ = ( + num_new_tokens_uncached, _ = \ self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.WAITING, False, budget)) + seq_group, SequenceStatus.WAITING, False, budget) - #Only preempt if priority inversion exists + # Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): - #Only preempt if waiting sequence cannot be allocated + # Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) if (num_new_tokens_uncached > 0 and can_allocate == AllocStatus.OK @@ -863,7 +989,7 @@ def _schedule_priority_preemption( )): break - #Adjust budget to remove the victim sequence group + # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens_uncached, _ = ( self._get_num_new_uncached_and_cached_tokens( @@ -874,11 +1000,11 @@ def _schedule_priority_preemption( budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) - #Preempt out the victim sequence group + # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out) waiting_queue.appendleft(vseq_group) force_preemption_count += 1 - #Put the sequence back into the waiting queue + # Put the sequence back into the waiting queue waiting_queue.appendleft(seq_group) waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) @@ -892,6 +1018,7 @@ def _schedule_prefills( budget: SchedulingBudget, curr_loras: Optional[Set[int]], enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, ) -> SchedulerPrefillOutputs: """Schedule sequence groups that are in prefill stage. @@ -912,10 +1039,20 @@ def _schedule_prefills( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. + partial_prefill_metadata: information about the partial prefills + that are currently running Returns: SchedulerPrefillOutputs. 
""" + if budget.remaining_token_budget() == 0: + # Do nothing: Can't add any more prefill anyway + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking), + ) ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = [] @@ -929,10 +1066,19 @@ def _schedule_prefills( assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") + if (partial_prefill_metadata is not None + and partial_prefill_metadata.cannot_schedule(seq_group)): + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue num_new_tokens_uncached, num_new_tokens_cached = ( self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.WAITING, enable_chunking, - budget)) + seq_group, + SequenceStatus.WAITING, + enable_chunking, + budget, + partial_prefill_metadata=partial_prefill_metadata, + )) num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached if not enable_chunking: @@ -943,7 +1089,10 @@ def _schedule_prefills( if num_new_tokens > prompt_limit: logger.warning( "Input prompt (%d tokens) is too long" - " and exceeds limit of %d", num_new_tokens, prompt_limit) + " and exceeds limit of %d", + num_new_tokens, + prompt_limit, + ) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -964,7 +1113,9 @@ def _schedule_prefills( logger.warning( "Input prompt (%d tokens) + lookahead slots (%d) is " "too long and exceeds the capacity of block_manager", - num_new_tokens, num_lookahead_slots) + num_new_tokens, + num_lookahead_slots, + ) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -1005,6 +1156,9 @@ def _schedule_prefills( waiting_queue.popleft() self._allocate_and_set_running(seq_group) + if partial_prefill_metadata is not None: + partial_prefill_metadata.increment_partial_prefills(seq_group) + if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] # init_multi_step_from_lookahead_slots happens in append_slots @@ -1020,7 +1174,8 @@ def _schedule_prefills( num_scheduler_steps=self.scheduler_config. num_scheduler_steps, is_multi_step=self.scheduler_config.is_multi_step, - enable_chunking=enable_chunking) + enable_chunking=enable_chunking, + ) seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, @@ -1041,11 +1196,12 @@ def _schedule_prefills( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking)) + is_prefill=True, enable_chunking=enable_chunking), + ) def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. 
If there's a pressure on GPU memory, decode requests can @@ -1061,9 +1217,9 @@ def _schedule_default(self) -> SchedulerOutputs: for seq_group in self.running: budget.add_num_seqs(seq_group.request_id, seq_group.get_max_num_running_seqs()) - curr_loras = set( + curr_loras = (set( seq_group.lora_int_id for seq_group in self.running - if seq_group.lora_int_id > 0) if self.lora_enabled else None + if seq_group.lora_int_id > 0) if self.lora_enabled else None) prefills = SchedulerPrefillOutputs.create_empty() running_scheduled = SchedulerRunningOutputs.create_empty() @@ -1089,12 +1245,13 @@ def _schedule_default(self) -> SchedulerOutputs: # If any sequence group is preempted, do not swap in any sequence # group. because it means there's no slot for new running requests. - if len(running_scheduled.preempted) + len( - running_scheduled.swapped_out) == 0: - swapped_in = self._schedule_swapped(budget, curr_loras) + if (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) == 0): + swapped_in = \ + self._schedule_swapped(budget, curr_loras) - assert (budget.num_batched_tokens <= - self.scheduler_config.max_num_batched_tokens) + assert budget.num_batched_tokens <= \ + self.scheduler_config.max_num_batched_tokens assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. @@ -1111,8 +1268,8 @@ def _schedule_default(self) -> SchedulerOutputs: # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) - preempted = (len(running_scheduled.preempted) + - len(running_scheduled.swapped_out)) + preempted = len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) # There should be no prefill from running queue because this policy # doesn't allow chunked prefills. @@ -1150,7 +1307,7 @@ def _schedule_default(self) -> SchedulerOutputs: def _schedule_chunked_prefill(self) -> SchedulerOutputs: """Schedule queued requests. - + Chunked prefill allows to chunk prefill requests, batch them together with decode requests. This policy 1. schedule as many decoding requests as possible. 2. schedule chunked prefill requests that are not @@ -1171,10 +1328,20 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: prefills = SchedulerPrefillOutputs.create_empty() swapped_in = SchedulerSwappedInOutputs.create_empty() + # Create partial prefill metadata + partial_prefill_metadata = PartialPrefillMetadata.from_queues( + running=self.running, + waiting=self.waiting, + scheduler_config=self.scheduler_config, + ) + # Decoding should be always scheduled first by fcfs. - running_scheduled = self._schedule_running(budget, - curr_loras, - enable_chunking=True) + running_scheduled = self._schedule_running( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) # Schedule swapped out requests. # If preemption happens, it means we don't have space for swap-in. 
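
The queue-jumping behaviour exercised by test_short_prompts_jump_long_prompts_in_queue comes from PartialPrefillMetadata.cannot_schedule above: a waiting prefill is deferred only when it is itself "long", all long-prefill slots are already taken, and concurrent partial prefills are enabled. The snippet below is a self-contained restatement of that predicate, not the scheduler's code; the cutoff value is an assumption taken from the test's max_model_len (2000) and the default long_prefill_threshold (0.04).

```python
# Sketch only: the gating rule applied to each waiting prefill.
def cannot_schedule(num_new_tokens: int,
                    long_partial_prefills: int,
                    long_prefill_token_threshold: int,
                    max_long_partial_prefills: int = 1,
                    max_num_partial_prefills: int = 2) -> bool:
    return (num_new_tokens > long_prefill_token_threshold
            and long_partial_prefills >= max_long_partial_prefills
            and max_num_partial_prefills > 1)

# Assumed long-prompt cutoff for illustration: 2000 * 0.04 = 80 tokens.
LONG_CUTOFF = 80

# With one 1200-token prompt already prefilling, a second long prompt waits,
# while a 40-token prompt is allowed to jump ahead of it.
print(cannot_schedule(1200, long_partial_prefills=1,
                      long_prefill_token_threshold=LONG_CUTOFF))  # True
print(cannot_schedule(40, long_partial_prefills=1,
                      long_prefill_token_threshold=LONG_CUTOFF))  # False
```
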
@@ -1182,12 +1349,15 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: running_scheduled.swapped_out) == 0: swapped_in = self._schedule_swapped(budget, curr_loras) - prefills = self._schedule_prefills(budget, - curr_loras, - enable_chunking=True) + prefills = self._schedule_prefills( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) - assert (budget.num_batched_tokens <= - self.scheduler_config.max_num_batched_tokens) + assert budget.num_batched_tokens <= \ + self.scheduler_config.max_num_batched_tokens assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. @@ -1221,7 +1391,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # If all prompts, then we set num_lookahead_slots to 0 # this allows us to go through the `no_spec` path in # `spec_decode_worker.py` - all_prefills = (len(scheduled_seq_groups) == num_prefill_groups) + all_prefills = len(scheduled_seq_groups) == num_prefill_groups num_lookahead_slots = (0 if (all_prefills and not self.scheduler_config.is_multi_step) @@ -1355,8 +1525,8 @@ def schedule( # NOTE: We use get_len instead of get_prompt_len because when # a sequence is preempted, prefill includes previous generated # output tokens. - if (token_chunk_size + num_computed_tokens < - seqs[0].data.get_len()): + if token_chunk_size + num_computed_tokens < seqs[ + 0].data.get_len(): do_sample = False # It assumes the scheduled_seq_groups is ordered by @@ -1381,10 +1551,12 @@ def schedule( # between engine and worker. # the subsequent comms can still use delta, but # `multi_modal_data` will be None. - multi_modal_data=seq_group.multi_modal_data - if scheduler_outputs.num_prefill_groups > 0 else None, - multi_modal_placeholders=seq_group.multi_modal_placeholders - if scheduler_outputs.num_prefill_groups > 0 else None, + multi_modal_data=(seq_group.multi_modal_data if + scheduler_outputs.num_prefill_groups > 0 + else None), + multi_modal_placeholders=( + seq_group.multi_modal_placeholders + if scheduler_outputs.num_prefill_groups > 0 else None), mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) @@ -1490,10 +1662,12 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING - def _append_slots(self, - seq_group: SequenceGroup, - blocks_to_copy: List[Tuple[int, int]], - enable_chunking: bool = False) -> None: + def _append_slots( + self, + seq_group: SequenceGroup, + blocks_to_copy: List[Tuple[int, int]], + enable_chunking: bool = False, + ) -> None: """Appends new slots to the sequences in the given sequence group. Args: @@ -1514,7 +1688,8 @@ def _append_slots(self, num_lookahead_slots, num_scheduler_steps=self.scheduler_config.num_scheduler_steps, is_multi_step=self.scheduler_config.is_multi_step, - enable_chunking=enable_chunking) + enable_chunking=enable_chunking, + ) seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING if self.scheduler_config.is_multi_step and enable_chunking: @@ -1557,8 +1732,11 @@ def _preempt(self, seq_group: SequenceGroup, "not enough KV cache space. This can affect the end-to-end " "performance. Increase gpu_memory_utilization or " "tensor_parallel_size to provide more KV cache memory. 
" - "total_num_cumulative_preemption=%d", seq_group.request_id, - preemption_mode, self.num_cumulative_preemption + 1) + "total_num_cumulative_preemption=%d", + seq_group.request_id, + preemption_mode, + self.num_cumulative_preemption + 1, + ) self.num_cumulative_preemption += 1 if preemption_mode == PreemptionMode.RECOMPUTE: @@ -1621,10 +1799,9 @@ def _passed_delay(self, now: float) -> bool: if self.scheduler_config.delay_factor > 0 and self.waiting: earliest_arrival_time = min( [e.metrics.arrival_time for e in self.waiting]) - passed_delay = ( - (now - earliest_arrival_time) > - (self.scheduler_config.delay_factor * self.last_prompt_latency) - or not self.running) + passed_delay = (now - earliest_arrival_time) > ( + self.scheduler_config.delay_factor * + self.last_prompt_latency) or not self.running else: passed_delay = True return passed_delay @@ -1664,6 +1841,7 @@ def _get_num_new_uncached_and_cached_tokens( status: SequenceStatus, enable_chunking: bool, budget: SchedulingBudget, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, ) -> Tuple[int, int]: """ Returns the number of new uncached and cached tokens to schedule for a @@ -1687,6 +1865,8 @@ def _get_num_new_uncached_and_cached_tokens( to schedule. enable_chunking: Whether to chunk the number of tokens to compute. budget: The budget to chunk the number of tokens to compute. + partial_prefill_metadata: information about the partial prefills + that are currently running Returns: @@ -1764,6 +1944,8 @@ def _get_num_new_uncached_and_cached_tokens( budget, self._get_prompt_limit(seq_group), num_uncached_new_tokens, + self.partial_prefill_budget_lookup_list, + partial_prefill_metadata, ) return num_uncached_new_tokens, num_cached_new_tokens @@ -1775,6 +1957,8 @@ def _chunk_new_tokens_to_schedule( budget: SchedulingBudget, prompt_limit: int, num_new_tokens: int, + partial_prefill_budget_lookup_list: List[int], + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, ) -> int: """ Chunks the number of new tokens to schedule based on the budget when @@ -1807,29 +1991,31 @@ def _chunk_new_tokens_to_schedule( # the sequence. return num_new_tokens - return (0 if num_new_tokens > remaining_token_budget else - num_new_tokens) + return 0 if num_new_tokens > \ + remaining_token_budget else num_new_tokens - if cache_config.enable_prefix_caching: - # Adjust the remaining token budget to be divisible by the block - # size when prefix caching is enabled. + # Get the number of tokens to allocate to this prefill slot + prefill_slot_budget = (remaining_token_budget + if partial_prefill_metadata is None else + partial_prefill_budget_lookup_list[ + partial_prefill_metadata.partial_prefills]) - # When prefix caching is enabled, we always allocate - # the number of new tokens that is dividable by the block - # size to avoid partial block matching. + if cache_config.enable_prefix_caching: + # When prefix caching is enabled and we're partially prefilling + # a sequence, we always allocate a number of new tokens that is + # divisible by the block size to avoid partial block matching. block_size = cache_config.block_size - remainder = budget.token_budget % block_size - if remainder != 0: - raise ValueError("When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}") - # Round down to block size. 
- remaining_token_budget = (remaining_token_budget // block_size * - block_size) - - num_new_tokens = min(num_new_tokens, remaining_token_budget) + # Don't exceed either the total budget or slot budget. + # Take min of those and get the next lowest multiple of the + # block size: + remaining_token_budget = ( + min(remaining_token_budget, prefill_slot_budget) // + block_size) * block_size + # NB: In the case where num_new_tokens < budget, we are + # finishing prefill for this sequence, so we do not need to + # allocate a full block. + + num_new_tokens = min(num_new_tokens, remaining_token_budget, + prefill_slot_budget) return num_new_tokens diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 912a8b2f54adb..e34923d92426a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -120,6 +120,9 @@ class EngineArgs: cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None + max_num_partial_prefills: Optional[int] = 1 + max_long_partial_prefills: Optional[int] = 1 + long_prefill_threshold: Optional[float] = 0.04 max_num_seqs: Optional[int] = None max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False @@ -507,6 +510,31 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.max_num_batched_tokens, help='Maximum number of batched tokens per ' 'iteration.') + parser.add_argument( + "--max-num-partial-prefills", + type=int, + default=EngineArgs.max_num_partial_prefills, + help="For chunked prefill, the max number of concurrent \ + partial prefills." + "Defaults to 1", + ) + parser.add_argument( + "--max-long-partial-prefills", + type=int, + default=EngineArgs.max_long_partial_prefills, + help="For chunked prefill, the max number of long concurrent " + "partial prefills. The length is determined by the " + "long-prefill-threshold argument. " + "Defaults to 1", + ) + parser.add_argument( + "--long-prefill-threshold", + type=float, + default=EngineArgs.long_prefill_threshold, + help="For chunked prefill, a request is considered long " + "if the prompt is longer than the " + "max_model_length * long_prefill_threshold. 
Defaults to 0.04%", + ) parser.add_argument('--max-num-seqs', type=int, default=EngineArgs.max_num_seqs, @@ -1188,7 +1216,11 @@ def create_engine_config(self, multi_step_stream_outputs=self.multi_step_stream_outputs, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), - policy=self.scheduling_policy) + policy=self.scheduling_policy, + max_num_partial_prefills=self.max_num_partial_prefills, + max_long_partial_prefills=self.max_long_partial_prefills, + long_prefill_threshold=self.long_prefill_threshold, + ) lora_config = LoRAConfig( bias_enabled=self.enable_lora_bias, max_lora_rank=self.max_lora_rank, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c10efefea5471..8792bd42d54d2 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -995,7 +995,9 @@ def get_logprobs( if len(query_indices) == 0: empty_sampled_logprob: SampleLogprobs = [] empty_prompt_logprob: Optional[PromptLogprobs] = None - return [empty_prompt_logprob], [empty_sampled_logprob] + num_seq_groups = len(sampling_metadata.seq_groups) + return [empty_prompt_logprob + ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups selected_logprobs, ranks = None, None top_logprobs, top_token_ids = None, None @@ -1262,6 +1264,10 @@ def _build_sampler_output( assert sample_logprobs is not None assert not isinstance(maybe_deferred_sample_results, SampleResultArgsType) + assert len(sampling_metadata.seq_groups) \ + == len(maybe_deferred_sample_results) \ + == len(prompt_logprobs) \ + == len(sample_logprobs) deferred_sample_results_args = None for (seq_group, sample_result, group_prompt_logprobs,