From af7ea68f75834b688cb5edd4c3f5a1aac4faff18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 3 Apr 2024 15:19:17 +0200 Subject: [PATCH] Minor fixes --- tests/test_evaluate.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index fc59676b85..94b90b4551 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -29,8 +29,8 @@ def test_evaluate_script(tmp_path, monkeypatch): ours_config = Config.from_name("pythia-14m") download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) - shutil.move(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json", tmp_path) - shutil.move(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json", tmp_path) + shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path)) + shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path)) ours_model = GPT(ours_config) checkpoint_path = tmp_path / "lit_model.pth" torch.save(ours_model.state_dict(), checkpoint_path) @@ -46,11 +46,11 @@ def test_evaluate_script(tmp_path, monkeypatch): tasks="mathqa" ) stdout = StringIO() - with redirect_stdout(stdout), mock.patch("sys.argv", [Path("eval") / "evaluate.py"]): + with redirect_stdout(stdout), mock.patch("sys.argv", ["eval/evaluate.py"]): module.convert_and_evaluate(**fn_kwargs) - stdout_out = stdout.getvalue() - assert "mathqa" in stdout_out - assert "Metric" in stdout_out + stdout = stdout.getvalue() + assert "mathqa" in stdout + assert "Metric" in stdout @pytest.mark.parametrize("mode", ["file", "entrypoint"])