Skip to content

Commit

Permalink
[Bug][CLI] Allow users to disable prefix caching explicitly (vllm-pro…
Browse files Browse the repository at this point in the history
…ject#10724)

Signed-off-by: rickyx <[email protected]>
  • Loading branch information
rickyyx authored Nov 28, 2024
1 parent 278be67 commit d9b4b3f
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
19 changes: 19 additions & 0 deletions tests/engine/test_arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,25 @@ def test_compilation_config():
assert args.compilation_config.level == 3


def test_prefix_cache_default():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([])

engine_args = EngineArgs.from_cli_args(args=args)
assert (not engine_args.enable_prefix_caching
), "prefix caching defaults to off."

# with flag to turn it on.
args = parser.parse_args(["--enable-prefix-caching"])
engine_args = EngineArgs.from_cli_args(args=args)
assert engine_args.enable_prefix_caching

# with disable flag to turn it off.
args = parser.parse_args(["--no-enable-prefix-caching"])
engine_args = EngineArgs.from_cli_args(args=args)
assert not engine_args.enable_prefix_caching


def test_valid_pooling_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
Expand Down
19 changes: 19 additions & 0 deletions tests/v1/engine/test_engine_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

if not envs.VLLM_USE_V1:
pytest.skip(
Expand All @@ -12,6 +13,24 @@
)


def test_prefix_caching_from_cli():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([])
engine_args = EngineArgs.from_cli_args(args=args)
assert (engine_args.enable_prefix_caching
), "V1 turns on prefix caching by default."

# Turn it off possible with flag.
args = parser.parse_args(["--no-enable-prefix-caching"])
engine_args = EngineArgs.from_cli_args(args=args)
assert not engine_args.enable_prefix_caching

# Turn it on with flag.
args = parser.parse_args(["--enable-prefix-caching"])
engine_args = EngineArgs.from_cli_args(args=args)
assert engine_args.enable_prefix_caching


def test_defaults():
engine_args = EngineArgs(model="facebook/opt-125m")

Expand Down
10 changes: 7 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'tokens. This is ignored on neuron devices and '
'set to max-model-len')

parser.add_argument('--enable-prefix-caching',
action='store_true',
help='Enables automatic prefix caching.')
parser.add_argument(
"--enable-prefix-caching",
action=argparse.BooleanOptionalAction,
default=EngineArgs.enable_prefix_caching,
help="Enables automatic prefix caching. "
"Use --no-enable-prefix-caching to disable explicitly.",
)
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
Expand Down

0 comments on commit d9b4b3f

Please sign in to comment.