Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Apr 3, 2024
1 parent 9b381c1 commit b53b688
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
def test_evaluate_script(tmp_path, monkeypatch):
ours_config = Config.from_name("pythia-14m")
download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
shutil.move(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json", tmp_path)
shutil.move(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json", tmp_path)
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
ours_model = GPT(ours_config)
checkpoint_path = tmp_path / "lit_model.pth"
torch.save(ours_model.state_dict(), checkpoint_path)
Expand All @@ -46,11 +46,11 @@ def test_evaluate_script(tmp_path, monkeypatch):
tasks="mathqa"
)
stdout = StringIO()
with redirect_stdout(stdout), mock.patch("sys.argv", [Path("eval") / "evaluate.py"]):
with redirect_stdout(stdout), mock.patch("sys.argv", ["eval/evaluate.py"]):
module.convert_and_evaluate(**fn_kwargs)
stdout_out = stdout.getvalue()
assert "mathqa" in stdout_out
assert "Metric" in stdout_out
stdout = stdout.getvalue()
assert "mathqa" in stdout
assert "Metric" in stdout


@pytest.mark.parametrize("mode", ["file", "entrypoint"])
Expand Down

0 comments on commit b53b688

Please sign in to comment.