
Commit 9c7773c
Squash 10235
Signed-off-by: Jefferson Fialho <[email protected]>
fialhocoelho committed Dec 12, 2024
1 parent e22a57a commit 9c7773c
Showing 6 changed files with 654 additions and 132 deletions.
30 changes: 10 additions & 20 deletions tests/basic_correctness/test_chunked_prefill.py
@@ -7,7 +7,6 @@
Run `pytest tests/models/test_chunked_prefill.py`.
"""
import os
from contextlib import nullcontext

import pytest

@@ -232,7 +231,6 @@ def test_with_prefix_caching(

max_num_batched_tokens = max_num_seqs = chunk_size
outputs = {} # type: ignore
check_result = True
for enable in (True, False):
with vllm_runner(
model,
@@ -244,25 +242,17 @@
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
) as vllm_model:
# It should fail when prefix caching is enable and chunk
# size is not a multiple of block size (16).
should_fail = chunk_size % 16 != 0 and enable
check_result &= not should_fail
outputs[enable] = []
# Send the request one-by-one to ensure the cache is populated.
with pytest.raises(ValueError) if should_fail else nullcontext():
for prompt in full_prompts:
outputs[enable] += vllm_model.generate_greedy([prompt],
max_tokens)

# Check results only if we did not expect a failure.
if check_result:
check_outputs_equal(
outputs_0_lst=outputs[False],
outputs_1_lst=outputs[True],
name_0="w/o prefix caching",
name_1="with prefix caching",
)
for prompt in full_prompts:
outputs[enable] += vllm_model.generate_greedy([prompt],
max_tokens)

check_outputs_equal(
outputs_0_lst=outputs[False],
outputs_1_lst=outputs[True],
name_0="w/o prefix caching",
name_1="with prefix caching",
)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
298 changes: 296 additions & 2 deletions tests/core/test_chunked_prefill_scheduler.py
@@ -5,6 +5,9 @@

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, SequenceGroup

from .utils import create_dummy_prompt
@@ -14,7 +17,7 @@ def get_sequence_groups(scheduler_output):
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]


def append_new_token(seq_group, token_id: int):
def append_new_token(seq_group: SequenceGroup, token_id: int):
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})

@@ -121,6 +124,214 @@ def test_chunk():
assert out.num_batched_tokens == 57


def test_concurrent_chunking():
"""Verify prefills are chunked properly when
--max-num-partial-prefills is > 1"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 32
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []

# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
block_size=block_size)
scheduler.add_seq_group(seq_group)
running.append(seq_group)

# Verify both requests are chunked with half of max_num_batched_tokens each
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 32
assert seq_group_meta[1].token_chunk_size == 32
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64

# After one iteration, both should have 60 - 32 = 28 tokens left to prefill
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 28
assert seq_group_meta[1].token_chunk_size == 28
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 56


def test_concurrent_chunking_large_requests():
"""Verify large prefill requests are run one at a time"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests
cache_config.num_gpu_blocks = 3200
scheduler = Scheduler(scheduler_config, cache_config, None)

# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i),
prompt_length=1200, # Very large prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)

# Verify only a single request is chunked, and it gets all 64 tokens
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(get_sequence_groups(out)) == 1
assert seq_group_meta[0].token_chunk_size == 64
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 64


def test_short_prompts_jump_long_prompts_in_queue():
"""Verify large prefill requests are punted behind smaller ones if
another large prefill request is already running"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests
cache_config.num_gpu_blocks = 3200
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []

# Add 2 large seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i),
prompt_length=1200, # Very large prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)
running.append(seq_group)
assert seq_group.is_prefill()

# Add 2 small seq groups behind them
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i + 2),
prompt_length=40, # Very small prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)
running.append(seq_group)
assert seq_group.is_prefill()

# Verify one large req and 1 small req chunked
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens
assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens

# all 4 are prefilling
assert running[0].is_prefill()
assert running[1].is_prefill()
assert running[2].is_prefill()
assert running[3].is_prefill()

assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64

# in the second iteration,
# the first small request had only 8 tokens left
# so it went to decode
# The other small req is scheduled
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
# the new small req got 64 - (32+8) tokens
assert (seq_group_meta[0].token_chunk_size == 24)
assert seq_group_meta[1].token_chunk_size == 32 # large req still got 32
# the other small request had only 8 tokens left
assert seq_group_meta[2].token_chunk_size == 8 # 40-32

# notice the small request got to decode now
# this is because of max_num_partial_prefills logic
assert running[0].is_prefill()
assert running[1].is_prefill()
assert not running[2].is_prefill()
assert running[3].is_prefill()

assert out.num_prefill_groups == 3
assert out.num_batched_tokens == 64
# the small seq group has a new token appended.
append_new_token(running[2], 1)

# in the third iteration,
# the first small request has entered decode
# and other small req had 16 tokens left
# so it went to decode
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert seq_group_meta[0].token_chunk_size == 32 # large still got 32
# small req prefilled 40-24=16 tokens
assert (seq_group_meta[1].token_chunk_size == 16)
assert seq_group_meta[2].token_chunk_size == 1 # decode
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 49 # (32+16+1 decode)

# both small requests have now reached decode
assert running[0].is_prefill()
assert running[1].is_prefill()
assert not running[2].is_prefill()
assert not running[3].is_prefill()

# the small seq group has a new token appended.
append_new_token(running[2], 1)

# in the fourth iteration, both small requests are decoding
# so large request gets all the budget
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
# large req gets 63 tokens (minus 1 for decode)
assert seq_group_meta[0].token_chunk_size == 63
assert seq_group_meta[1].token_chunk_size == 1 # decode
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 64

assert running[0].is_prefill()
assert running[1].is_prefill()
assert not running[2].is_prefill()
assert not running[3].is_prefill()

# both the small seq groups have a new token appended
append_new_token(running[2], 1)
append_new_token(running[3], 1)

# in the fifth iteration, large request gets all the budget
# while both small requests are decoding
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert seq_group_meta[0].token_chunk_size == 62
assert seq_group_meta[1].token_chunk_size == 1 # decode
assert seq_group_meta[2].token_chunk_size == 1 # decode
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 64


def test_complex():
block_size = 4
max_seqs = 60
@@ -506,7 +717,7 @@ def test_chunked_prefill_max_seqs():
assert not running[1].is_prefill()


def test_perfix_caching():
def test_prefix_caching():
"""Verify allocating full blocks when prefix caching is enabled."""
block_size = 4
max_seqs = 10
@@ -546,3 +757,86 @@ def test_perfix_caching():
assert seq_group_meta[1].token_chunk_size == 12
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 62


def test_prefix_caching_with_concurrent_partial_prefills():
"""Verify allocating full blocks when prefix caching is enabled with
--max-num-partial-prefills > 1."""
block_size = 4
max_seqs = 10
max_model_len = 8000
max_num_batched_tokens = 60 # With two slots, each slot will get 30 tokens
scheduler_config = SchedulerConfig("generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2)
cache_config = CacheConfig(block_size,
1.0,
1,
"auto",
enable_prefix_caching=True)
cache_config.num_cpu_blocks = 0
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []

# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
block_size=block_size,
prompt_length=50)
scheduler.add_seq_group(seq_group)
running.append(seq_group)

seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
# To partially prefill both sequences, both can chunk up to 30 tokens
# But the next lowest multiple of the block size (4) is 28
assert seq_group_meta[0].token_chunk_size == 28
assert seq_group_meta[1].token_chunk_size == 28
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 56

# On the next iteration, both sequences should finish prefill
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
# Both sequences have 50 - 28 = 22 tokens left to prefill.
# This is not a multiple of the block size, but we don't care since we don't
# cache the final partial block of prefix sequences
assert seq_group_meta[0].token_chunk_size == 22
assert seq_group_meta[1].token_chunk_size == 22
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 44


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
def test_chunked_prefill_with_actual_engine(model: str,
max_num_partial_prefills: int):
"""Make sure the model can actually sample with concurrent
partial prefills
"""

prompt = "hello" * 40

engine_args = EngineArgs(
model=model,
max_num_partial_prefills=max_num_partial_prefills,
max_num_batched_tokens=40,
max_num_seqs=8,
enable_chunked_prefill=True,
gpu_memory_utilization=0.8,
)

engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(temperature=0)

for req_num in range(max_num_partial_prefills):
engine.add_request(f"{req_num}", prompt, sampling_params)
# first step
request_outputs = engine.step()
# means all are prefilling
assert len(request_outputs) == 0
assert len(engine.scheduler[0].running) == max_num_partial_prefills
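
Not part of this commit: a minimal standalone sketch to make the arithmetic asserted in the new scheduler tests easier to follow. It assumes the behavior implied by the test expectations (the batched-token budget is split evenly across the concurrent partial-prefill slots, and with prefix caching each slot is rounded down to a multiple of the block size); the helper name per_slot_chunk_size is hypothetical and does not exist in vLLM.

def per_slot_chunk_size(max_num_batched_tokens: int,
                        max_num_partial_prefills: int,
                        block_size: int,
                        enable_prefix_caching: bool) -> int:
    """Hypothetical helper mirroring the budget split asserted in the tests."""
    # Each concurrent partial prefill gets an even share of the batch budget.
    chunk = max_num_batched_tokens // max_num_partial_prefills
    if enable_prefix_caching:
        # Per the test expectations, chunks are aligned down to a full block
        # when prefix caching is on, so cached blocks stay complete.
        chunk -= chunk % block_size
    return chunk

# test_concurrent_chunking: 64-token budget, 2 slots -> 32 tokens per request
assert per_slot_chunk_size(64, 2, 4, False) == 32
# test_prefix_caching_with_concurrent_partial_prefills:
# 60-token budget, 2 slots, prefix caching -> 30, rounded down to 28
assert per_slot_chunk_size(60, 2, 4, True) == 28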
