Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite finetune command if subcommand is not provided #1313

Merged
merged 6 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ litgpt download --repo_id microsoft/phi-2
# 2) Finetune the model
curl -L https://huggingface.co/datasets/medalpaca/medical_meadow_health_advice/raw/main/medical_meadow_health_advice.json -o my_custom_dataset.json

litgpt finetune lora \
litgpt finetune \
--checkpoint_dir checkpoints/microsoft/phi-2 \
--data JSON \
--data.json_path my_custom_dataset.json \
Expand Down Expand Up @@ -315,7 +315,7 @@ Browse all training recipes [here](config_hub).
### Example

```bash
litgpt finetune lora \
litgpt finetune \
--config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml
```

Expand Down Expand Up @@ -470,7 +470,7 @@ seed: 1337
Override any parameter in the CLI:

```bash
litgpt finetune lora \
litgpt finetune \
--config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml \
--lora_r 4
```
Expand Down
9 changes: 9 additions & 0 deletions litgpt/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys

from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -41,6 +42,12 @@ def _new_parser(**kwargs: Any) -> "ArgumentParser":
return parser


def _rewrite_argv_for_default_subcommand(parser_data: dict, command: str, subcommand: str) -> None:
"""Rewrites the `sys.argv` such that `litgpt command` defaults to `litgpt command subcommand`."""
if len(sys.argv) > 2 and sys.argv[1] == command and sys.argv[2] not in parser_data[command].keys():
sys.argv.insert(2, subcommand)


def main() -> None:
parser_data = {
"download": {"help": "Download weights or tokenizer data from the Hugging Face Hub.", "fn": download_fn},
Expand Down Expand Up @@ -90,6 +97,8 @@ def main() -> None:
set_docstring_parse_options(attribute_docstrings=True)
set_config_read_mode(urls_enabled=True)

_rewrite_argv_for_default_subcommand(parser_data, "finetune", "lora")

root_parser = _new_parser(prog="litgpt")

# register level 1 subcommands and level 2 subsubcommands. If there are more levels in the future we would want to
Expand Down
25 changes: 12 additions & 13 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,8 @@ def test_cli():
chat Chat with a model."""
in out
)
assert ("""evaluate Evaluate a model with the LM Evaluation Harness.""") in out
assert ("""serve Serve and deploy a model with LitServe.""") in out
out = StringIO()
with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune", "-h"]):
main()
out = out.getvalue()
assert (
"""Available subcommands:
lora Finetune a model with LoRA.
full Finetune a model."""
in out
)

assert """evaluate Evaluate a model with the LM Evaluation Harness.""" in out
assert """serve Serve and deploy a model with LitServe.""" in out
out = StringIO()
with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune", "lora", "-h"]):
main()
Expand All @@ -61,3 +50,13 @@ def test_cli():
Optional[int], default: 3000000000000)"""
in out
)


def test_rewrite_finetune_command():
    """`litgpt finetune -h` must produce the same output as `litgpt finetune lora -h`.

    The CLI rewrites a bare ``finetune`` command to default to the ``lora``
    subcommand, so both invocations must emit identical help text.
    """
    # BUGFIX: the command was previously misspelled "fineune" in both argv mocks,
    # so the test compared two identical unknown-command outputs and passed vacuously
    # without exercising the argv-rewrite feature.
    out1 = StringIO()
    with pytest.raises(SystemExit), redirect_stdout(out1), mock.patch("sys.argv", ["litgpt", "finetune", "-h"]):
        main()
    out2 = StringIO()
    with pytest.raises(SystemExit), redirect_stdout(out2), mock.patch("sys.argv", ["litgpt", "finetune", "lora", "-h"]):
        main()
    # The rewritten bare command must render exactly the lora subcommand's help.
    assert out1.getvalue() == out2.getvalue()