
Commit fe882a8
Brings in flexargparse and rebases
Signed-off-by: Rafael Vasquez <[email protected]>
rafvasq committed Jul 15, 2024
1 parent 0bacb8c commit fe882a8
Showing 6 changed files with 52 additions and 23 deletions.
15 changes: 10 additions & 5 deletions benchmarks/benchmark_serving.py
@@ -2,7 +2,7 @@
On the server side, run one of the following commands:
vLLM OpenAI API server
vllm serve <your_model>
vllm serve <your_model> \
--swap-space 16 \
--disable-log-requests
@@ -17,7 +17,7 @@
--dataset-path <path to dataset> \
--request-rate <request_rate> \ # By default <request_rate> is inf
--num-prompts <num_prompts> # By default <num_prompts> is 1000
when using tgi backend, add
--endpoint /generate_stream
to the end of the command above.
@@ -44,6 +44,11 @@
except ImportError:
from backend_request_func import get_tokenizer

try:
from vllm.utils import FlexibleArgumentParser
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser


@dataclass
class BenchmarkMetrics:
@@ -72,7 +77,6 @@ def sample_sharegpt_requests(
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
@@ -191,6 +195,7 @@ async def get_request(
if request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait.
continue

# Sample the request interval from the exponential distribution.
interval = np.random.exponential(1.0 / request_rate)
# The next request will be sent after the interval.
@@ -214,7 +219,7 @@ def calculate_metrics(
# We use the tokenizer to count the number of output tokens for all
# serving backends instead of looking at len(outputs[i].itl) since
# multiple output tokens may be bundled together
# Note: this may inflate the output token count slightly
# Note : this may inflate the output token count slightly
output_len = len(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
@@ -511,7 +516,7 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
11 changes: 7 additions & 4 deletions tests/utils.py
@@ -3,7 +3,6 @@
import sys
import time
import warnings
from argparse import ArgumentParser
from contextlib import contextmanager
from typing import List

@@ -14,7 +13,7 @@
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import get_open_port
from vllm.utils import FlexibleArgumentParser, get_open_port

# Path to root of repository so that utilities can be imported by ray workers
VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
@@ -32,7 +31,10 @@ def __init__(self, cli_args: List[str], *, wait_url: str,
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
self.proc = subprocess.Popen(
["vllm", "serve", *cli_args],
[
sys.executable, "-m", "vllm.entrypoints.openai.api_server",
*cli_args
],
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
@@ -72,7 +74,8 @@ def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:

cli_args = cli_args + ["--port", str(get_open_port())]

parser = ArgumentParser(description="vLLM's remote OpenAI server.")
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args(cli_args)
self.host = str(args.host or 'localhost')
7 changes: 3 additions & 4 deletions vllm/entrypoints/openai/api_server.py
@@ -1,4 +1,3 @@
import argparse
import asyncio
import importlib
import inspect
@@ -31,6 +30,7 @@
from vllm.logger import init_logger
from vllm.tgis_utils.args import add_tgis_args, postprocess_tgis_args
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION

TIMEOUT_KEEP_ALIVE = 5 # seconds
@@ -89,7 +89,7 @@ async def health() -> Response:

@router.get("/v1/models")
async def show_available_models():
models = await openai_serving_chat.show_available_models()
models = await openai_serving_completion.show_available_models()
return JSONResponse(content=models.model_dump())


@@ -263,8 +263,7 @@ def run_server(args, llm_engine=None):
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.

parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
parser = add_tgis_args(parser)
8 changes: 4 additions & 4 deletions vllm/entrypoints/openai/cli_args.py
@@ -10,6 +10,7 @@

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.utils import FlexibleArgumentParser


class LoRAParserAction(argparse.Action):
@@ -22,8 +23,7 @@ def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, lora_list)


def make_arg_parser(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=nullable_str,
default=None,
@@ -114,7 +114,7 @@ def make_arg_parser(
return parser


def create_parser_for_docs() -> argparse.ArgumentParser:
parser_for_docs = argparse.ArgumentParser(
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server")
return make_arg_parser(parser_for_docs)
9 changes: 3 additions & 6 deletions vllm/scripts.py
@@ -10,6 +10,7 @@

from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser


def registrer_signal_handlers():
@@ -23,10 +24,6 @@ def signal_handler(sig, frame):

def serve(args: argparse.Namespace) -> None:
# EngineArgs expects the model name to be passed as --model.
if args.model is not None and args.model == args.model_tag:
raise ValueError(
"The --model argument is not supported for the serve command. "
"Use positional argument [model_tag] instead.")
args.model = args.model_tag

run_server(args)
@@ -245,7 +242,7 @@ def convert_to_fast_tokenizer(


def _add_query_options(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
"--url",
type=str,
@@ -269,7 +266,7 @@ def _add_query_options(


def main():
parser = argparse.ArgumentParser(description="vLLM CLI")
parser = FlexibleArgumentParser(description="vLLM CLI")
subparsers = parser.add_subparsers(required=True)

serve_parser = subparsers.add_parser(
25 changes: 25 additions & 0 deletions vllm/utils.py
@@ -1,3 +1,4 @@
import argparse
import asyncio
import datetime
import enum
@@ -775,3 +776,27 @@ def wrapper(*args, **kwargs) -> Any:

wrapper.has_run = False # type: ignore[attr-defined]
return wrapper


class FlexibleArgumentParser(argparse.ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""

def parse_args(self, args=None, namespace=None):
if args is None:
args = sys.argv[1:]

# Convert underscores to dashes and vice versa in argument names
processed_args = []
for arg in args:
if arg.startswith('--'):
if '=' in arg:
key, value = arg.split('=', 1)
key = '--' + key[len('--'):].replace('_', '-')
processed_args.append(f'{key}={value}')
else:
processed_args.append('--' +
arg[len('--'):].replace('_', '-'))
else:
processed_args.append(arg)

return super().parse_args(processed_args, namespace)
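
The FlexibleArgumentParser added above only rewrites option names (tokens starting with "--"), so positional arguments and option values pass through unchanged. A minimal sketch of the behaviour this commit enables, assuming a vllm build that includes this change; the --swap-space option and the value 16 are only illustrative:

    from vllm.utils import FlexibleArgumentParser

    parser = FlexibleArgumentParser(description="underscore/dash demo")
    parser.add_argument("--swap-space", type=int, default=4)

    # All three spellings resolve to the same destination, args.swap_space == 16.
    print(parser.parse_args(["--swap-space", "16"]).swap_space)  # 16
    print(parser.parse_args(["--swap_space", "16"]).swap_space)  # 16
    print(parser.parse_args(["--swap_space=16"]).swap_space)     # 16

Because scripts.py, api_server.py, and benchmark_serving.py now build their parsers with this class, a command like vllm serve <your_model> --swap_space 16 should be accepted the same way as one written with --swap-space 16.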
