From b5eb73ed996913b8935965bb6d0c413c7075ba36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Apr 2024 19:17:10 +0200 Subject: [PATCH 1/5] rewrite --- litgpt/__main__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/litgpt/__main__.py b/litgpt/__main__.py index 59d53ac904..2d81041a12 100644 --- a/litgpt/__main__.py +++ b/litgpt/__main__.py @@ -1,4 +1,5 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +import sys from typing import TYPE_CHECKING, Any @@ -39,6 +40,16 @@ def _new_parser(**kwargs: Any) -> "ArgumentParser": return parser +def _rewrite_argv_for_default_subcommand(parser_data: dict, command: str, subcommand: str) -> None: + if ( + len(sys.argv) > 2 + and sys.argv[1] == command + and sys.argv[2] not in parser_data[command].keys() + and not any(h in sys.argv for h in ["-h", "--help"]) + ): + sys.argv.insert(2, subcommand) + + def main() -> None: parser_data = { "download": {"help": "Download weights or tokenizer data from the Hugging Face Hub.", "fn": download_fn}, @@ -87,6 +98,8 @@ def main() -> None: set_docstring_parse_options(attribute_docstrings=True) set_config_read_mode(urls_enabled=True) + _rewrite_argv_for_default_subcommand(parser_data, "finetune", "lora") + root_parser = _new_parser(prog="litgpt") # register level 1 subcommands and level 2 subsubcommands. If there are more levels in the future we would want to From e7ef0c600cfbe3e06b9a033669df39e8705befb9 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 17 Apr 2024 19:30:20 +0200 Subject: [PATCH 2/5] update --- litgpt/__main__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litgpt/__main__.py b/litgpt/__main__.py index 2d81041a12..9b73e8b422 100644 --- a/litgpt/__main__.py +++ b/litgpt/__main__.py @@ -41,6 +41,7 @@ def _new_parser(**kwargs: Any) -> "ArgumentParser": def _rewrite_argv_for_default_subcommand(parser_data: dict, command: str, subcommand: str) -> None: + """Rewrites the `sys.argv` such that `litgpt command` defaults to `litgpt command subcommand`.""" if ( len(sys.argv) > 2 and sys.argv[1] == command From 278b05ac4fc1d8dd6edac447319ca5401243fccb Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 17 Apr 2024 19:51:35 +0200 Subject: [PATCH 3/5] mini test --- litgpt/__main__.py | 7 +------ tests/test_cli.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/litgpt/__main__.py b/litgpt/__main__.py index 9b73e8b422..06a865f75e 100644 --- a/litgpt/__main__.py +++ b/litgpt/__main__.py @@ -42,12 +42,7 @@ def _new_parser(**kwargs: Any) -> "ArgumentParser": def _rewrite_argv_for_default_subcommand(parser_data: dict, command: str, subcommand: str) -> None: """Rewrites the `sys.argv` such that `litgpt command` defaults to `litgpt command subcommand`.""" - if ( - len(sys.argv) > 2 - and sys.argv[1] == command - and sys.argv[2] not in parser_data[command].keys() - and not any(h in sys.argv for h in ["-h", "--help"]) - ): + if len(sys.argv) > 2 and sys.argv[1] == command and sys.argv[2] not in parser_data[command].keys(): sys.argv.insert(2, subcommand) diff --git a/tests/test_cli.py b/tests/test_cli.py index 2c994fcf96..042a06f458 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -61,3 +61,13 @@ def test_cli(): Optional[int], default: 3000000000000)""" in out ) + + +def test_rewrite_finetune_command(): + out1 = StringIO() + with pytest.raises(SystemExit), redirect_stdout(out1), mock.patch("sys.argv", ["litgpt", "fineune", "-h"]): + main() + out2 = StringIO() + with pytest.raises(SystemExit), redirect_stdout(out2), mock.patch("sys.argv", ["litgpt", "fineune", "lora", "-h"]): + main() + assert out1.getvalue() == out2.getvalue() From 692e006eee91e5f44fb4e2f4bc3cf30704c212ae Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 17 Apr 2024 20:54:15 +0200 Subject: [PATCH 4/5] update test --- tests/test_cli.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 042a06f458..a45a1ca092 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -23,18 +23,7 @@ def test_cli(): chat Chat with a model.""" in out ) - assert ("""evaluate Evaluate a model with the LM Evaluation Harness.""") in out - - out = StringIO() - with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune", "-h"]): - main() - out = out.getvalue() - assert ( - """Available subcommands: - lora Finetune a model with LoRA. - full Finetune a model.""" - in out - ) + assert """evaluate Evaluate a model with the LM Evaluation Harness.""" in out out = StringIO() with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune", "lora", "-h"]): From 0372779904f9f3d74bb3b55b15ac1d60c4349e1e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 17 Apr 2024 23:35:05 +0200 Subject: [PATCH 5/5] update commands in readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f6e426b9fa..3694d0cd4a 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,7 @@ litgpt download --repo_id microsoft/phi-2 # 2) Finetune the model curl -L https://huggingface.co/datasets/medalpaca/medical_meadow_health_advice/raw/main/medical_meadow_health_advice.json -o my_custom_dataset.json -litgpt finetune lora \ +litgpt finetune \ --checkpoint_dir checkpoints/microsoft/phi-2 \ --data JSON \ --data.json_path my_custom_dataset.json \ @@ -267,7 +267,7 @@ Browse all training recipes [here](config_hub). ### Example ```bash -litgpt finetune lora \ +litgpt finetune \ --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml ``` @@ -422,7 +422,7 @@ seed: 1337 Override any parameter in the CLI: ```bash -litgpt finetune lora \ +litgpt finetune \ --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml \ --lora_r 4 ```