diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py
new file mode 100644
index 0000000000000..5e8b7ca4d9977
--- /dev/null
+++ b/tests/entrypoints/test_llm_generate.py
@@ -0,0 +1,41 @@
+import pytest
+
+from vllm import LLM, SamplingParams
+
+
+def test_multiple_sampling_params():
+
+    llm = LLM(model="facebook/opt-125m",
+              max_num_batched_tokens=4096,
+              tensor_parallel_size=1)
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    sampling_params = [
+        SamplingParams(temperature=0.01, top_p=0.95),
+        SamplingParams(temperature=0.3, top_p=0.95),
+        SamplingParams(temperature=0.7, top_p=0.95),
+        SamplingParams(temperature=0.99, top_p=0.95),
+    ]
+
+    # Multiple SamplingParams should be matched one-to-one with the prompts
+    outputs = llm.generate(prompts, sampling_params=sampling_params)
+    assert len(prompts) == len(outputs)
+
+    # ValueError is raised if the lengths of params and prompts differ
+    with pytest.raises(ValueError):
+        outputs = llm.generate(prompts, sampling_params=sampling_params[:3])
+
+    # A single SamplingParams should be applied to every prompt
+    single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
+    outputs = llm.generate(prompts, sampling_params=single_sampling_params)
+    assert len(prompts) == len(outputs)
+
+    # If sampling_params is None, the default params should be applied
+    outputs = llm.generate(prompts, sampling_params=None)
+    assert len(prompts) == len(outputs)
\ No newline at end of file
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 961de5d5063fa..f745dbd736d17 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -127,7 +127,8 @@ def set_tokenizer(
     def generate(
         self,
         prompts: Optional[Union[str, List[str]]] = None,
-        sampling_params: Optional[SamplingParams] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
@@ -142,7 +143,10 @@ def generate(
         Args:
             prompts: A list of prompts to generate completions for.
             sampling_params: The sampling parameters for text generation. If
-                None, we use the default sampling parameters.
+                None, we use the default sampling parameters.
+                When it is a single value, it is applied to every prompt.
+                When it is a list, the list must have the same length as the
+                prompts and is paired one-to-one with the prompts.
             prompt_token_ids: A list of token IDs for the prompts. If None, we
                 use the tokenizer to convert the prompts to token IDs.
             use_tqdm: Whether to use tqdm to display the progress bar.
@@ -163,27 +167,33 @@ def generate(
                 and len(prompts) != len(prompt_token_ids)):
             raise ValueError("The lengths of prompts and prompt_token_ids "
                              "must be the same.")
+
+        if prompts is not None:
+            num_requests = len(prompts)
+        else:
+            assert prompt_token_ids is not None
+            num_requests = len(prompt_token_ids)
+
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
+        elif isinstance(sampling_params,
+                        list) and len(sampling_params) != num_requests:
+            raise ValueError("The lengths of prompts and sampling_params "
+                             "must be the same.")
         if multi_modal_data:
             multi_modal_data.data = multi_modal_data.data.to(torch.float16)
 
         # Add requests to the engine.
-        if prompts is not None:
-            num_requests = len(prompts)
-        else:
-            assert prompt_token_ids is not None
-            num_requests = len(prompt_token_ids)
-
         for i in range(num_requests):
             prompt = prompts[i] if prompts is not None else None
             token_ids = None if prompt_token_ids is None else prompt_token_ids[
                 i]
             self._add_request(
                 prompt,
-                sampling_params,
+                sampling_params[i]
+                if isinstance(sampling_params, list) else sampling_params,
                 token_ids,
                 lora_request=lora_request,
                 # Get ith image while maintaining the batch dim.
@@ -232,4 +242,4 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # This is necessary because some requests may be finished earlier than
         # its previous requests.
         outputs = sorted(outputs, key=lambda x: int(x.request_id))
-        return outputs
+        return outputs
\ No newline at end of file
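
For context, a minimal usage sketch (not part of the diff) of the new per-request sampling parameters paired with pre-tokenized input via `prompt_token_ids`, which the relocated `num_requests` computation also covers. The model name, token IDs, and parameter values below are illustrative placeholders, not values taken from this change.

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any small model works for illustration

# Two pre-tokenized prompts (hypothetical token IDs, for illustration only).
prompt_token_ids = [
    [2, 31414, 6, 127, 766, 16],
    [2, 133, 812, 9, 1470, 16],
]

# One SamplingParams per request; a length mismatch raises ValueError.
sampling_params = [
    SamplingParams(temperature=0.0, max_tokens=16),  # greedy decoding
    SamplingParams(temperature=0.8, top_p=0.95),     # more diverse sampling
]

outputs = llm.generate(prompt_token_ids=prompt_token_ids,
                       sampling_params=sampling_params)
assert len(outputs) == len(prompt_token_ids)
```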