diff --git a/README.md b/README.md
index ddf733291e..78fdff2b6d 100644
--- a/README.md
+++ b/README.md
@@ -168,7 +168,7 @@ litgpt download --repo_id microsoft/phi-2
 # 2) Finetune the model
 curl -L https://huggingface.co/datasets/medalpaca/medical_meadow_health_advice/raw/main/medical_meadow_health_advice.json -o my_custom_dataset.json
 
-litgpt finetune lora \
+litgpt finetune \
   --checkpoint_dir checkpoints/microsoft/phi-2 \
   --data JSON \
   --data.json_path my_custom_dataset.json \
@@ -315,7 +315,7 @@ Browse all training recipes [here](config_hub).
 ### Example
 
 ```bash
-litgpt finetune lora \
+litgpt finetune \
   --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml
 ```
 
@@ -470,7 +470,7 @@ seed: 1337
 Override any parameter in the CLI:
 
 ```bash
-litgpt finetune lora \
+litgpt finetune \
   --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml \
   --lora_r 4
 ```
diff --git a/litgpt/__main__.py b/litgpt/__main__.py
index e88c6212c8..821c1f5801 100644
--- a/litgpt/__main__.py
+++ b/litgpt/__main__.py
@@ -1,4 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+import sys
 
 from typing import TYPE_CHECKING, Any
 
@@ -41,6 +42,12 @@ def _new_parser(**kwargs: Any) -> "ArgumentParser":
     return parser
 
 
+def _rewrite_argv_for_default_subcommand(parser_data: dict, command: str, subcommand: str) -> None:
+    """Rewrites the `sys.argv` such that `litgpt command` defaults to `litgpt command subcommand`."""
+    if len(sys.argv) > 2 and sys.argv[1] == command and sys.argv[2] not in parser_data[command].keys():
+        sys.argv.insert(2, subcommand)
+
+
 def main() -> None:
     parser_data = {
         "download": {"help": "Download weights or tokenizer data from the Hugging Face Hub.", "fn": download_fn},
@@ -90,6 +97,8 @@ def main() -> None:
     set_docstring_parse_options(attribute_docstrings=True)
     set_config_read_mode(urls_enabled=True)
 
+    _rewrite_argv_for_default_subcommand(parser_data, "finetune", "lora")
+
     root_parser = _new_parser(prog="litgpt")
 
     # register level 1 subcommands and level 2 subsubcommands. If there are more levels in the future we would want to
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 49a10a07ab..f95841ddc0 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -23,19 +23,8 @@ def test_cli():
     chat                Chat with a model."""
         in out
     )
-    assert ("""evaluate            Evaluate a model with the LM Evaluation Harness.""") in out
-    assert ("""serve               Serve and deploy a model with LitServe.""") in out
-    out = StringIO()
-    with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune", "-h"]):
-        main()
-    out = out.getvalue()
-    assert (
-        """Available subcommands:
-    lora                Finetune a model with LoRA.
-    full                Finetune a model."""
-        in out
-    )
-
+    assert """evaluate            Evaluate a model with the LM Evaluation Harness.""" in out
+    assert """serve               Serve and deploy a model with LitServe.""" in out
     out = StringIO()
     with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune", "lora", "-h"]):
         main()
@@ -61,3 +50,13 @@
         Optional[int], default: 3000000000000)"""
         in out
     )
+
+
+def test_rewrite_finetune_command():
+    out1 = StringIO()
+    with pytest.raises(SystemExit), redirect_stdout(out1), mock.patch("sys.argv", ["litgpt", "finetune", "-h"]):
+        main()
+    out2 = StringIO()
+    with pytest.raises(SystemExit), redirect_stdout(out2), mock.patch("sys.argv", ["litgpt", "finetune", "lora", "-h"]):
+        main()
+    assert out1.getvalue() == out2.getvalue()
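
For reviewers: below is a minimal standalone sketch of what the new `_rewrite_argv_for_default_subcommand` hook does. `PARSER_DATA` and `rewrite` are hypothetical stand-ins for this illustration only; the real `parser_data` built in `main()` also carries help strings and entry-point functions, and the real hook mutates `sys.argv` in place.

```python
# Sketch only: PARSER_DATA is a stand-in with the same shape as the dict in
# litgpt.__main__.main() -- the "finetune" entry maps subcommand names to parser data.
PARSER_DATA = {"finetune": {"lora": {}, "full": {}}}


def rewrite(argv, command="finetune", subcommand="lora"):
    # Mirrors _rewrite_argv_for_default_subcommand, but takes a list instead of
    # mutating sys.argv, so it is easy to exercise directly.
    if len(argv) > 2 and argv[1] == command and argv[2] not in PARSER_DATA[command]:
        argv.insert(2, subcommand)
    return argv


# `litgpt finetune -h` is rewritten to `litgpt finetune lora -h` ...
assert rewrite(["litgpt", "finetune", "-h"]) == ["litgpt", "finetune", "lora", "-h"]
# ... while explicit subcommands and unrelated commands pass through unchanged.
assert rewrite(["litgpt", "finetune", "full", "-h"]) == ["litgpt", "finetune", "full", "-h"]
assert rewrite(["litgpt", "download", "-h"]) == ["litgpt", "download", "-h"]
```

One edge case worth noting: a bare `litgpt finetune` (only two argv entries) is not rewritten because of the `len(sys.argv) > 2` guard, so it falls through to the parser's usual handling of a missing subcommand.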