From 9f5067a1a731540319f45632d8064b852bc930a8 Mon Sep 17 00:00:00 2001 From: Benjie Genchel Date: Wed, 31 Jul 2024 00:04:50 -0400 Subject: [PATCH] address pr comments - tmpdir -> tmp_path, cleaner code iteration, remove debugging prints. --- tests/data/test_maestro.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/data/test_maestro.py b/tests/data/test_maestro.py index 77004cb..27db976 100644 --- a/tests/data/test_maestro.py +++ b/tests/data/test_maestro.py @@ -86,15 +86,15 @@ def create_mock_midi(output_fpath: str) -> None: logging.info(f"Mock MIDI file '{output_fpath}' created successfully.") -def test_maestro_to_tf_example(tmpdir: str) -> None: - mock_maestro_home = pathlib.Path(tmpdir) / "maestro" +def test_maestro_to_tf_example(tmp_path: pathlib.Path) -> None: + mock_maestro_home = tmp_path / "maestro" mock_maestro_ext = mock_maestro_home / "2004" mock_maestro_ext.mkdir(parents=True, exist_ok=True) create_mock_wav(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".wav")), 3) create_mock_midi(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".midi"))) - output_dir = pathlib.Path(tmpdir) / "outputs" + output_dir = tmp_path / "outputs" output_dir.mkdir(parents=True, exist_ok=True) input_data: List[str] = [TRAIN_TRACK_ID] @@ -106,25 +106,24 @@ def test_maestro_to_tf_example(tmpdir: str) -> None: | "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(str(output_dir))) ) - assert len(os.listdir(str(output_dir))) == 1 - print("PASSED THIS POINT") - assert os.path.splitext(os.listdir(str(output_dir))[0])[-1] == ".tfrecord" - print("PASSED THIS OTHER POINT") + listdir = os.listdir(str(output_dir)) + assert len(listdir) == 1 + assert os.path.splitext(listdir[0])[-1] == ".tfrecord" with open(os.path.join(str(output_dir), os.listdir(str(output_dir))[0]), "rb") as fp: data = fp.read() assert len(data) != 0 -def test_maestro_invalid_tracks(tmpdir: str) -> None: - mock_maestro_home = pathlib.Path(tmpdir) / "maestro" +def test_maestro_invalid_tracks(tmp_path: pathlib.Path) -> None: + mock_maestro_home = tmp_path / "maestro" mock_maestro_ext = mock_maestro_home / "2004" mock_maestro_ext.mkdir(parents=True, exist_ok=True) - create_mock_wav(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".wav")), 3) - create_mock_wav(str(mock_maestro_ext / (VALID_TRACK_ID.split("/")[1] + ".wav")), 3) - create_mock_wav(str(mock_maestro_ext / (TEST_TRACK_ID.split("/")[1] + ".wav")), 3) - input_data = [(TRAIN_TRACK_ID, "train"), (VALID_TRACK_ID, "validation"), (TEST_TRACK_ID, "test")] + + for track_id, _ in input_data: + create_mock_wav(str(mock_maestro_ext / (track_id.split("/")[1] + ".wav")), 3) + split_labels = set([e[1] for e in input_data]) with TestPipeline() as p: splits = ( @@ -137,23 +136,24 @@ def test_maestro_invalid_tracks(tmpdir: str) -> None: ( getattr(splits, split) | f"Write {split} to text" - >> beam.io.WriteToText(os.path.join(tmpdir, f"output_{split}.txt"), shard_name_template="") + >> beam.io.WriteToText(str(tmp_path / f"output_{split}.txt"), shard_name_template="") ) for track_id, split in input_data: - with open(os.path.join(tmpdir, f"output_{split}.txt"), "r") as fp: + with open(str(tmp_path / f"output_{split}.txt"), "r") as fp: assert fp.read().strip() == track_id -def test_maestro_invalid_tracks_over_15_min(tmpdir: str) -> None: +def test_maestro_invalid_tracks_over_15_min(tmp_path: pathlib.Path) -> None: """ The track id used here is a real track id in maestro, and it is part of the train split, but we mock the data so as not to store a large file in git, hence the variable name. """ - mock_maestro_home = pathlib.Path(tmpdir) / "maestro" + mock_maestro_home = tmp_path / "maestro" mock_maestro_ext = mock_maestro_home / "2004" mock_maestro_ext.mkdir(parents=True, exist_ok=True) + mock_fpath = mock_maestro_ext / (GT_15M_TRACK_ID.split("/")[1] + ".wav") create_mock_wav(str(mock_fpath), 16) @@ -170,11 +170,11 @@ def test_maestro_invalid_tracks_over_15_min(tmpdir: str) -> None: ( getattr(splits, split) | f"Write {split} to text" - >> beam.io.WriteToText(os.path.join(tmpdir, f"output_{split}.txt"), shard_name_template="") + >> beam.io.WriteToText(str(tmp_path / f"output_{split}.txt"), shard_name_template="") ) for _, split in input_data: - with open(os.path.join(tmpdir, f"output_{split}.txt"), "r") as fp: + with open(str(tmp_path / f"output_{split}.txt"), "r") as fp: assert fp.read().strip() == ""