From db9e5708a98b7209cf4465a0391139cf8fca7674 Mon Sep 17 00:00:00 2001 From: Peng Guanwen Date: Tue, 30 Jul 2024 00:47:31 +0800 Subject: [PATCH] [Core] Reduce unnecessary compute when logprobs=None (#6532) --- tests/samplers/test_logprobs.py | 39 ++++++- vllm/model_executor/layers/sampler.py | 144 +++++++++++++++----------- vllm/outputs.py | 17 +-- vllm/sampling_params.py | 15 +-- 4 files changed, 135 insertions(+), 80 deletions(-) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index f7bcd4c855799..c07c71e38233f 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -14,7 +14,7 @@ @pytest.mark.parametrize("dtype", ["float"]) # needed for comparing logprobs with HF @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) -@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size +@pytest.mark.parametrize("num_top_logprobs", [0, 6]) # 32000 == vocab_size @pytest.mark.parametrize("detokenize", [True, False]) def test_get_prompt_logprobs( hf_runner, @@ -63,7 +63,10 @@ def test_get_prompt_logprobs( assert result.outputs[0].logprobs is not None assert len(result.outputs[0].logprobs) == max_tokens for logprobs in result.outputs[0].logprobs: - assert len(logprobs) == num_top_logprobs + # If the output token is not included in the top X + # logprob, it can return 1 more data + assert (len(logprobs) == num_top_logprobs + or len(logprobs) == num_top_logprobs + 1) output_text = result.outputs[0].text output_string_from_most_likely_tokens_lst: List[str] = [] for top_logprobs in result.outputs[0].logprobs: @@ -135,3 +138,35 @@ def test_max_logprobs(): bad_sampling_params = SamplingParams(logprobs=2) with pytest.raises(ValueError): runner.generate(["Hello world"], sampling_params=bad_sampling_params) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) +@pytest.mark.parametrize("detokenize", [True, False]) +def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int, + detokenize: bool, example_prompts): + max_num_seqs = 256 + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) + max_num_batched_tokens = chunked_prefill_token_size + max_tokens = 5 + + with vllm_runner( + model, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs, + ) as vllm_model: + sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens, + logprobs=None, + temperature=0.0, + detokenize=detokenize) + results_logprobs_none = vllm_model.model.generate( + example_prompts, sampling_params=sampling_params_logprobs_none) + + for i in range(len(results_logprobs_none)): + assert results_logprobs_none[i].outputs[0].logprobs is None + assert results_logprobs_none[i].outputs[0].cumulative_logprob is None diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 121458f8156a1..60fa3fbb51be6 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,5 +1,6 @@ """A layer that samples the next tokens from the model's outputs.""" import itertools +from math import inf from typing import Dict, List, Optional, Tuple import torch @@ -774,8 +775,11 @@ def _get_logprobs( # The next token ids to get the logprob value from. 
next_token_ids: List[int] = [] # The largest requested number of logprobs. We find logprobs as many as the - # largest num logprobs in this API. - largest_num_logprobs = 1 + # largest num logprobs in this API. If every logprobs is None, it will be + # set to -1. + largest_num_logprobs = -1 + # If beam search is enabled. + use_beam_search = False # Select indices to compute logprob from, ranks of token ids, and the top # k token ids from logprobs. @@ -808,6 +812,8 @@ def _get_logprobs( largest_num_logprobs = max(largest_num_logprobs, sampling_params.logprobs) + use_beam_search = use_beam_search or sampling_params.use_beam_search + assert len(next_token_ids) == len(query_indices) if len(query_indices) == 0: @@ -815,35 +821,40 @@ def _get_logprobs( empty_prompt_logprob: Optional[PromptLogprobs] = None return [empty_prompt_logprob], [empty_sampled_logprob] - query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) - next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device) - - # (num_selected_query_tokens, num_logprobs). Note that query_indices can - # contain duplicates if beam search is enabled. - selected_logprobs = logprobs[[ - query_indices_gpu, - next_token_ids_gpu, - ]] - ranks = _get_ranks( - logprobs[query_indices_gpu], - next_token_ids_gpu, - ) - assert selected_logprobs.shape[0] == ranks.shape[0] - - # Logprobs of topk tokens for a batch of sequence groups. - # (num_query_tokens_across_batch). - if largest_num_logprobs > 0: - top_logprobs, top_token_ids = torch.topk(logprobs, - largest_num_logprobs, - dim=-1) - else: - top_logprobs, top_token_ids = None, None + selected_logprobs, ranks = None, None + top_logprobs, top_token_ids = None, None + + # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can + # skip the whole logprob calculation. + if largest_num_logprobs >= 0 or use_beam_search: + query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) + next_token_ids_gpu = torch.tensor(next_token_ids, + device=logprobs.device) + + # (num_selected_query_tokens, num_logprobs). Note that query_indices can + # contain duplicates if beam search is enabled. + selected_logprobs = logprobs[[ + query_indices_gpu, + next_token_ids_gpu, + ]] + ranks = _get_ranks( + logprobs[query_indices_gpu], + next_token_ids_gpu, + ) + assert selected_logprobs.shape[0] == ranks.shape[0] + + # We need to compute top k only if there exists logprobs > 0. + if largest_num_logprobs > 0: + # Logprobs of topk tokens for a batch of sequence groups. + # (num_query_tokens_across_batch). + top_logprobs, top_token_ids = torch.topk(logprobs, + largest_num_logprobs, + dim=-1) + top_logprobs = top_logprobs.to('cpu') + top_token_ids = top_token_ids.to('cpu') - selected_logprobs = selected_logprobs.to('cpu') - ranks = ranks.to('cpu') - if top_logprobs is not None and top_token_ids is not None: - top_logprobs = top_logprobs.to('cpu') - top_token_ids = top_token_ids.to('cpu') + selected_logprobs = selected_logprobs.to('cpu') + ranks = ranks.to('cpu') # Find prompt/sample logprobs. 
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] @@ -940,46 +951,53 @@ def _get_sampled_logprob_if_needed( ): """Compute the sample logprob if needed.""" seq_ids = seq_group.seq_ids - num_logprobs = seq_group.sampling_params.logprobs or 0 + num_logprobs = seq_group.sampling_params.logprobs + use_beam_search = seq_group.sampling_params.use_beam_search sampled_logprobs: SampleLogprobs = [] next_token_ids, parent_seq_ids = sample_result if seq_group.do_sample: assert len(next_token_ids) > 0 - # Pre-select items from tensor. tolist() is faster than repetitive - # `.item()` calls. - selected_logprob_items = selected_logprobs[ - selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - for idx, (next_token_id, - parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)): - # Get the logprob of a sampled token. - sampled_logprobs_dict = { - next_token_id: (selected_logprob_items[idx], rank_items[idx]) - } - # Get top K logprobs. - if num_logprobs > 0: - top_ids = top_token_ids[top_logprob_idx + - parent_id, :num_logprobs].tolist() - top_probs = top_logprobs[top_logprob_idx + - parent_id, :num_logprobs].tolist() - # Top K is already sorted by rank, so we can use 1 ~ - # num_logprobs + 1 for rank. - top_ranks = range(1, num_logprobs + 1) - sampled_logprobs_dict.update({ - top_id: (top_prob, rank) - for top_id, top_prob, rank in zip(top_ids, top_probs, - top_ranks) + if num_logprobs is None and not use_beam_search: + for next_token_id in next_token_ids: + # Use a dummy logprob + sampled_logprobs.append({next_token_id: Logprob(inf)}) + else: + # Pre-select items from tensor. tolist() is faster than repetitive + # `.item()` calls. + selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + for idx, (next_token_id, parent_id) in enumerate( + zip(next_token_ids, parent_seq_ids)): + # Get the logprob of a sampled token. + sampled_logprobs_dict = { + next_token_id: + (selected_logprob_items[idx], rank_items[idx]) + } + if num_logprobs is not None and num_logprobs > 0: + # Get top K logprobs. + top_ids = top_token_ids[top_logprob_idx + + parent_id, :num_logprobs].tolist() + top_probs = top_logprobs[ + top_logprob_idx + parent_id, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + sampled_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip( + top_ids, top_probs, top_ranks) + }) + + sampled_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in + sampled_logprobs_dict.items() }) - sampled_logprobs.append({ - token_id: Logprob(*logprob_and_rank) - for token_id, logprob_and_rank in - sampled_logprobs_dict.items() - }) - # NOTE: This part of code is not intuitive. `selected_logprobs` include # logprobs for the current step, which has len(next_token_ids) tokens # per sequence group. `logprobs` includes logprobs from the previous diff --git a/vllm/outputs.py b/vllm/outputs.py index 4cb7f06bdb8c7..b1cb1cd07fbb1 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -29,7 +29,7 @@ class CompletionOutput: index: int text: str token_ids: Tuple[int, ...] 
- cumulative_logprob: float + cumulative_logprob: Optional[float] logprobs: Optional[SampleLogprobs] finish_reason: Optional[str] = None stop_reason: Union[int, str, None] = None @@ -124,13 +124,14 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": include_logprobs = seq_group.sampling_params.logprobs is not None text_buffer_length = seq_group.sampling_params.output_text_buffer_length outputs = [ - CompletionOutput(seqs.index(seq), - seq.get_output_text_to_return(text_buffer_length), - seq.get_output_token_ids(), - seq.get_cumulative_logprob(), - seq.output_logprobs if include_logprobs else None, - SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason) for seq in top_n_seqs + CompletionOutput( + seqs.index(seq), + seq.get_output_text_to_return(text_buffer_length), + seq.get_output_token_ids(), + seq.get_cumulative_logprob() if include_logprobs else None, + seq.output_logprobs if include_logprobs else None, + SequenceStatus.get_finished_reason(seq.status), + seq.stop_reason) for seq in top_n_seqs ] # Every sequence in the sequence group should have the same prompt. diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 638c870c04371..2598325439ebf 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -92,11 +92,12 @@ class SamplingParams: min_tokens: Minimum number of tokens to generate per output sequence before EOS or stop_token_ids can be generated logprobs: Number of log probabilities to return per output token. - Note that the implementation follows the OpenAI API: The return - result includes the log probabilities on the `logprobs` most likely - tokens, as well the chosen tokens. The API will always return the - log probability of the sampled token, so there may be up to - `logprobs+1` elements in the response. + When set to None, no probability is returned. If set to a non-None + value, the result includes the log probabilities of the specified + number of most likely tokens, as well as the chosen tokens. + Note that the implementation follows the OpenAI API: The API will + always return the log probability of the sampled token, so there + may be up to `logprobs+1` elements in the response. prompt_logprobs: Number of log probabilities to return per prompt token. detokenize: Whether to detokenize the output. Defaults to True. skip_special_tokens: Whether to skip special tokens in the output. @@ -168,8 +169,8 @@ def __init__( self.ignore_eos = ignore_eos self.max_tokens = max_tokens self.min_tokens = min_tokens - self.logprobs = logprobs - self.prompt_logprobs = prompt_logprobs + self.logprobs = 1 if logprobs is True else logprobs + self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs # NOTE: This parameter is only exposed at the engine level for now. # It is not exposed in the OpenAI API server, as the OpenAI API does # not support returning only a list of token IDs.
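
Usage note (not part of the patch): a minimal sketch of how the change surfaces to callers, assuming the standard vllm.LLM entry point. The model name below is only a placeholder for illustration, not something this PR prescribes.

    from vllm import LLM, SamplingParams

    # Placeholder model; any model supported by vLLM behaves the same way here.
    llm = LLM(model="facebook/opt-125m")

    # With logprobs=None the sampler skips the logprob computation entirely,
    # and both output fields come back as None (cumulative_logprob is now
    # Optional[float]).
    params_none = SamplingParams(max_tokens=5, temperature=0.0, logprobs=None)
    out_none = llm.generate(["Hello world"], params_none)[0].outputs[0]
    assert out_none.logprobs is None
    assert out_none.cumulative_logprob is None

    # With logprobs=k, each generated step carries the top-k candidates plus
    # the sampled token itself, so a step may hold k or k + 1 entries
    # (OpenAI-style behavior described in the docstring above).
    params_topk = SamplingParams(max_tokens=5, temperature=0.0, logprobs=2)
    out_topk = llm.generate(["Hello world"], params_topk)[0].outputs[0]
    for step in out_topk.logprobs:
        assert len(step) in (2, 3)

Because cumulative_logprob is now Optional, callers that previously assumed it was always a float should handle None whenever they do not request logprobs.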