Skip to content

Commit

Permalink
[misc] partial prefix & random input generation benchmark (vllm-proje…
Browse files Browse the repository at this point in the history
…ct#9929)

Signed-off-by: rickyx <[email protected]>
  • Loading branch information
rickyyx authored Nov 18, 2024
1 parent 2298e69 commit 90a6c75
Showing 1 changed file with 91 additions and 25 deletions.
116 changes: 91 additions & 25 deletions benchmarks/benchmark_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,30 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
print(f"cost time {end_time - start_time}")


def sample_requests(
@dataclasses.dataclass
class Request:
prompt: str
prompt_len: int
output_len: int


def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
vocab = tokenizer.get_vocab()
# Remove the special tokens.
vocab = {
k: v
for k, v in vocab.items() if k not in tokenizer.all_special_ids
}
return random.choices(list(vocab.values()), k=length)


def sample_requests_from_dataset(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
input_length_range: Tuple[int, int],
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
) -> List[Request]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

Expand All @@ -77,68 +94,108 @@ def sample_requests(
random.shuffle(dataset)

min_len, max_len = input_length_range
assert min_len >= 0 and max_len >= min_len, "input_length_range too small"

# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
filtered_requests: List[Request] = []

for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
if len(filtered_requests) == num_requests:
break

# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
prompt_token_ids = tokenizer(dataset[i][0]).input_ids
prompt = tokenizer.decode(prompt_token_ids)
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
output_len = (len(completion_token_ids)
if fixed_output_len is None else fixed_output_len)
if min_len <= prompt_len <= max_len:
filtered_dataset.append((prompt, prompt_len, output_len))
filtered_requests.append(Request(prompt, prompt_len, output_len))

return filtered_requests


def sample_requests_from_random(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
input_length_range: Tuple[int, int],
fixed_output_len: Optional[int],
prefix_len: int,
) -> List[Request]:

return filtered_dataset
requests = []
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
min_len, max_len = input_length_range

for i in range(num_requests):
unique_part_token_ids = sample_tokens(
tokenizer,
random.randint(min_len - prefix_len, max_len - prefix_len))
prompt_token_ids = prefix_token_ids + unique_part_token_ids
prompt = tokenizer.decode(prompt_token_ids)
prompt_len = len(prompt_token_ids)
assert (min_len <= prompt_len <= max_len
), f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
requests.append(Request(prompt, prompt_len, fixed_output_len))
return requests


def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
def repeat_and_sort_requests(requests: List[Request],
repeat_count: int,
sort: bool = False) -> List[str]:
repeated_requests = requests * repeat_count
if sort:
repeated_requests.sort(key=lambda x: x[1])
else:
random.shuffle(repeated_requests)
return [req[0] for req in repeated_requests]
return [req.prompt for req in repeated_requests]


def main(args):
tokenizer = get_tokenizer(args.model, trust_remote_code=True)
input_length_range = tuple(map(int, args.input_length_range.split(':')))
random.seed(args.seed)
if args.dataset_path is not None:
print(f"Start to sample {args.num_prompts} prompts"
if args.prefix_len > 0:
raise ValueError("prefix-len is not supported when "
"dataset-path is provided.")
print(f"Start to sample {args.num_prompts} prompts "
f"from {args.dataset_path}")
filtered_datasets = sample_requests(
filtered_requests = sample_requests_from_dataset(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
input_length_range=input_length_range,
fixed_output_len=args.output_len,
)
else:
prompt_len = len(tokenizer(PROMPT).input_ids)
filtered_datasets = [(PROMPT, prompt_len, args.output_len)
] * args.num_prompts
print(f"Start to sample {args.num_prompts} prompts from random")
filtered_requests = sample_requests_from_random(
num_requests=args.num_prompts,
tokenizer=tokenizer,
input_length_range=input_length_range,
fixed_output_len=args.output_len,
prefix_len=args.prefix_len,
)

# Print some helpful stats of the requests.
print(f"Sampled {len(filtered_requests)} requests.")
prompt_lens = [req.prompt_len for req in filtered_requests]
print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}")
print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}")
print(f"Min Prompt Length: {min(prompt_lens)}")
print(f"Max Prompt Length: {max(prompt_lens)}")

engine_args = EngineArgs.from_cli_args(args)

llm = LLM(**dataclasses.asdict(engine_args))

sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

print("Testing filtered datasets")
prompts = repeat_and_sort_requests(filtered_datasets,
print("Testing filtered requests")
prompts = repeat_and_sort_requests(filtered_requests,
repeat_count=args.repeat_count,
sort=args.sort)

Expand All @@ -161,20 +218,29 @@ def main(args):
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--num-prompts',
type=int,
default=1,
required=True,
help="Number of the prompts sampled from dataset")
parser.add_argument('--repeat-count',
type=int,
default=100,
default=1,
help='Number of times to repeat each prompt')
parser.add_argument('--sort',
action='store_true',
help='Sort prompts by input length')
parser.add_argument('--input-length-range',
type=str,
default='128:256',
required=True,
help='Range of input lengths for sampling prompts,'
'specified as "min:max" (e.g., "128:256").')
parser.add_argument(
"--prefix-len",
type=int,
default=0,
help="Specifies the length of a common prefix to be "
"added to the input prompt. The input-length-range will "
"subtract this length when filtering prompts. Only used "
"when dataset-path is not provided.",
)

parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
Expand Down

0 comments on commit 90a6c75

Please sign in to comment.