Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed May 8, 2024
1 parent 61e0b2c commit 9dcf291
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions tests/test_generate_sequentially.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,9 @@ def test_base_with_sequentially(tmp_path):
f"--checkpoint_dir={str(checkpoint_dir)}",
]
env = {"CUDA_VISIBLE_DEVICES": "0,1"}
base_stdout = subprocess.check_output([sys.executable, root / "litgpt/generate/base.py", *args], env=env).decode()
base_stdout = subprocess.check_output(["litgpt", "generate", "base", *args], env=env).decode()
sequential_stdout = subprocess.check_output(
[sys.executable, root / "litgpt/generate/sequentially.py", *args], env=env
["litgpt", "generate", "sequentially", *args], env=env
).decode()

assert base_stdout.startswith("What food do llamas eat?")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_generate_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_tp(tmp_path):
f"--checkpoint_dir={str(checkpoint_dir)}",
]
env = {"CUDA_VISIBLE_DEVICES": "0,1"}
tp_stdout = subprocess.check_output([sys.executable, root / "litgpt/generate/tp.py", *args], env=env).decode()
tp_stdout = subprocess.check_output(["litgpt", "generate", "tp", *args], env=env).decode()

# there is some unaccounted randomness so cannot compare the output with that of `generate/base.py`
assert tp_stdout.startswith("What food do llamas eat?")
Expand Down

0 comments on commit 9dcf291

Please sign in to comment.