From f9f280a3dcdf2acba437a53a0c4e8130296c7277 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 7 May 2024 16:29:13 +0200
Subject: [PATCH] Remove per-file CLIs

---
 tests/test_chat.py                  |  9 ++-------
 tests/test_evaluate.py              | 12 ++----------
 tests/test_generate.py              |  9 ++-------
 tests/test_generate_adapter.py      |  9 ++-------
 tests/test_generate_sequentially.py |  9 ++-------
 tests/test_generate_tp.py           |  9 ++-------
 6 files changed, 12 insertions(+), 45 deletions(-)

diff --git a/tests/test_chat.py b/tests/test_chat.py
index 13b432897d..ac6ef5760f 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -122,13 +122,8 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
     assert re.match("Now chatting with Llama 3.*>> .*Reply: foo bar baz", out.getvalue(), re.DOTALL)
 
 
-@pytest.mark.parametrize("mode", ["file", "entrypoint"])
-def test_cli(mode):
-    if mode == "file":
-        cli_path = Path(__file__).parent.parent / "litgpt/chat/base.py"
-        args = [sys.executable, cli_path, "-h"]
-    else:
-        args = ["litgpt", "chat", "-h"]
+def test_cli():
+    args = ["litgpt", "chat", "-h"]
     output = subprocess.check_output(args)
     output = str(output.decode())
     assert "Starts a conversation" in output
diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py
index 12f8a68f9c..3b6c56ae1b 100644
--- a/tests/test_evaluate.py
+++ b/tests/test_evaluate.py
@@ -1,14 +1,11 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
 import subprocess
-import sys
 from contextlib import redirect_stdout
 from dataclasses import asdict
 from io import StringIO
-from pathlib import Path
 from unittest import mock
 
-import pytest
 import torch
 import yaml
 
@@ -43,13 +40,8 @@ def test_evaluate_script(tmp_path):
     assert "Loading checkpoint shards" not in stdout
 
 
-@pytest.mark.parametrize("mode", ["file", "entrypoint"])
-def test_cli(mode):
-    if mode == "file":
-        cli_path = Path(__file__).parent.parent / "litgpt/eval/evaluate.py"
-        args = [sys.executable, cli_path, "-h"]
-    else:
-        args = ["litgpt", "evaluate", "-h"]
+def test_cli():
+    args = ["litgpt", "evaluate", "-h"]
     output = subprocess.check_output(args)
     output = str(output.decode())
     assert "run the LM Evaluation Harness" in output
diff --git a/tests/test_generate.py b/tests/test_generate.py
index 7cd0dca9db..430cd6ada0 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -83,13 +83,8 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
     assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()
 
 
-@pytest.mark.parametrize("mode", ["file", "entrypoint"])
-def test_cli(mode):
-    if mode == "file":
-        cli_path = Path(__file__).parent.parent / "litgpt/generate/base.py"
-        args = [sys.executable, cli_path, "-h"]
-    else:
-        args = ["litgpt", "generate", "base", "-h"]
+def test_cli():
+    args = ["litgpt", "generate", "base", "-h"]
     output = subprocess.check_output(args)
     output = str(output.decode())
     assert "Generates text samples" in output
diff --git a/tests/test_generate_adapter.py b/tests/test_generate_adapter.py
index e977b0a93b..0df5fb67d2 100644
--- a/tests/test_generate_adapter.py
+++ b/tests/test_generate_adapter.py
@@ -48,13 +48,8 @@ def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like):
 
 
 @pytest.mark.parametrize("version", ("", "_v2"))
-@pytest.mark.parametrize("mode", ["file", "entrypoint"])
-def test_cli(version, mode):
-    if mode == "file":
-        cli_path = Path(__file__).parent.parent / f"litgpt/generate/adapter{version}.py"
-        args = [sys.executable, cli_path, "-h"]
-    else:
-        args = ["litgpt", "generate", f"adapter{version}", "-h"]
+def test_cli(version):
+    args = ["litgpt", "generate", f"adapter{version}", "-h"]
     output = subprocess.check_output(args)
     output = str(output.decode())
     assert "Generates a response" in output
diff --git a/tests/test_generate_sequentially.py b/tests/test_generate_sequentially.py
index b0bed4797e..5f4d9c7203 100644
--- a/tests/test_generate_sequentially.py
+++ b/tests/test_generate_sequentially.py
@@ -294,13 +294,8 @@ def test_base_with_sequentially(tmp_path):
     assert base_stdout == sequential_stdout
 
 
-@pytest.mark.parametrize("mode", ["file", "entrypoint"])
-def test_cli(mode):
-    if mode == "file":
-        cli_path = Path(__file__).parent.parent / "litgpt/generate/sequentially.py"
-        args = [sys.executable, cli_path, "-h"]
-    else:
-        args = ["litgpt", "generate", "sequentially", "-h"]
+def test_cli():
+    args = ["litgpt", "generate", "sequentially", "-h"]
     output = subprocess.check_output(args)
     output = str(output.decode())
     assert "Generates text samples" in output
diff --git a/tests/test_generate_tp.py b/tests/test_generate_tp.py
index 039dd0ea4b..634b81c974 100644
--- a/tests/test_generate_tp.py
+++ b/tests/test_generate_tp.py
@@ -130,13 +130,8 @@ def test_tp(tmp_path):
     assert tp_stdout.startswith("What food do llamas eat?")
 
 
-@pytest.mark.parametrize("mode", ["file", "entrypoint"])
-def test_cli(mode):
-    if mode == "file":
-        cli_path = Path(__file__).parent.parent / "litgpt/generate/tp.py"
-        args = [sys.executable, cli_path, "-h"]
-    else:
-        args = ["litgpt", "generate", "tp", "-h"]
+def test_cli():
+    args = ["litgpt", "generate", "tp", "-h"]
    output = subprocess.check_output(args)
     output = str(output.decode())
     assert "Generates text samples" in output