Skip to content

Commit

Permalink
Remove per-file CLIs
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed May 7, 2024
1 parent ca4a23d commit f9f280a
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 45 deletions.
9 changes: 2 additions & 7 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,8 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
assert re.match("Now chatting with Llama 3.*>> .*Reply: foo bar baz", out.getvalue(), re.DOTALL)


@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(mode):
if mode == "file":
cli_path = Path(__file__).parent.parent / "litgpt/chat/base.py"
args = [sys.executable, cli_path, "-h"]
else:
args = ["litgpt", "chat", "-h"]
def test_cli():
args = ["litgpt", "chat", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "Starts a conversation" in output
Expand Down
12 changes: 2 additions & 10 deletions tests/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import subprocess
import sys
from contextlib import redirect_stdout
from dataclasses import asdict
from io import StringIO
from pathlib import Path
from unittest import mock

import pytest
import torch
import yaml

Expand Down Expand Up @@ -43,13 +40,8 @@ def test_evaluate_script(tmp_path):
assert "Loading checkpoint shards" not in stdout


@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(mode):
if mode == "file":
cli_path = Path(__file__).parent.parent / "litgpt/eval/evaluate.py"
args = [sys.executable, cli_path, "-h"]
else:
args = ["litgpt", "evaluate", "-h"]
def test_cli():
args = ["litgpt", "evaluate", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "run the LM Evaluation Harness" in output
9 changes: 2 additions & 7 deletions tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,8 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()


@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(mode):
if mode == "file":
cli_path = Path(__file__).parent.parent / "litgpt/generate/base.py"
args = [sys.executable, cli_path, "-h"]
else:
args = ["litgpt", "generate", "base", "-h"]
def test_cli():
args = ["litgpt", "generate", "base", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "Generates text samples" in output
Expand Down
9 changes: 2 additions & 7 deletions tests/test_generate_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,8 @@ def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like):


@pytest.mark.parametrize("version", ("", "_v2"))
@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(version, mode):
if mode == "file":
cli_path = Path(__file__).parent.parent / f"litgpt/generate/adapter{version}.py"
args = [sys.executable, cli_path, "-h"]
else:
args = ["litgpt", "generate", f"adapter{version}", "-h"]
def test_cli(version):
args = ["litgpt", "generate", f"adapter{version}", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "Generates a response" in output
9 changes: 2 additions & 7 deletions tests/test_generate_sequentially.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,13 +294,8 @@ def test_base_with_sequentially(tmp_path):
assert base_stdout == sequential_stdout


@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(mode):
if mode == "file":
cli_path = Path(__file__).parent.parent / "litgpt/generate/sequentially.py"
args = [sys.executable, cli_path, "-h"]
else:
args = ["litgpt", "generate", "sequentially", "-h"]
def test_cli():
args = ["litgpt", "generate", "sequentially", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "Generates text samples" in output
9 changes: 2 additions & 7 deletions tests/test_generate_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,8 @@ def test_tp(tmp_path):
assert tp_stdout.startswith("What food do llamas eat?")


@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(mode):
if mode == "file":
cli_path = Path(__file__).parent.parent / "litgpt/generate/tp.py"
args = [sys.executable, cli_path, "-h"]
else:
args = ["litgpt", "generate", "tp", "-h"]
def test_cli():
args = ["litgpt", "generate", "tp", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "Generates text samples" in output

0 comments on commit f9f280a

Please sign in to comment.