From 6e3710d7f447305e753551fa873f2c72921e41fe Mon Sep 17 00:00:00 2001 From: Benjie Genchel Date: Thu, 8 Aug 2024 20:06:24 -0400 Subject: [PATCH] added two tests for example_deserialization, and made some corresponding changes in the original file. --- .../data/tf_example_deserialization.py | 6 +- tests/data/test_tf_example_deserialization.py | 67 ++++++++++++++++--- 2 files changed, 61 insertions(+), 12 deletions(-) diff --git a/basic_pitch/data/tf_example_deserialization.py b/basic_pitch/data/tf_example_deserialization.py index dc60b54..54551e4 100644 --- a/basic_pitch/data/tf_example_deserialization.py +++ b/basic_pitch/data/tf_example_deserialization.py @@ -219,7 +219,7 @@ def transcription_file_generator( dataset_names: List[str], datasets_base_path: str, sample_weights: np.ndarray, -) -> Tuple[Callable[[], Iterator[str]], bool]: +) -> Tuple[Callable[[], Iterator[tf.Tensor]], bool]: """ dataset_names: list of dataset dataset_names """ @@ -235,7 +235,7 @@ def transcription_file_generator( return lambda: _validation_file_generator(file_dict), True -def _train_file_generator(x: Dict[str, List[str]], weights: np.ndarray) -> Iterator[str]: +def _train_file_generator(x: Dict[str, tf.data.Dataset], weights: np.ndarray) -> Iterator[tf.Tensor]: x = {k: list(v) for (k, v) in x.items()} keys = list(x.keys()) # shuffle each list @@ -248,7 +248,7 @@ def _train_file_generator(x: Dict[str, List[str]], weights: np.ndarray) -> Itera yield fpath -def _validation_file_generator(x: Dict[str, tf.data.Dataset]) -> Iterator[str]: +def _validation_file_generator(x: Dict[str, tf.data.Dataset]) -> Iterator[tf.Tensor]: x = {k: list(v) for (k, v) in x.items()} # loop until there are no more test files while any(x.values()): diff --git a/tests/data/test_tf_example_deserialization.py b/tests/data/test_tf_example_deserialization.py index f7bcf61..d0f15ec 100644 --- a/tests/data/test_tf_example_deserialization.py +++ b/tests/data/test_tf_example_deserialization.py @@ -16,24 +16,73 @@ # limitations under the License. import numpy as np +import os import pathlib import tensorflow as tf -from basic_pitch.data.tf_example_deserialization import sample_datasets, transcription_file_generator +from basic_pitch.data.tf_example_deserialization import transcription_dataset, transcription_file_generator -def test_prepare_dataset(): - pass +def create_empty_tfrecord(filepath: pathlib.Path) -> None: + assert filepath.suffix == ".tfrecord" + with tf.io.TFRecordWriter(str(filepath)) as writer: + writer.write("") -def test_sample_datasets(): - pass +# def test_prepare_dataset() -> None: +# pass -def test_transcription_file_generator(tmpdir: str): - print("FUCK YOU ") - file_gen, random_seed = transcription_file_generator("train", ["test2"], datasets_base_path=tmpdir, sample_weights=np.ndarray(1)) +# def test_sample_datasets() -> None: +# pass + + +# def test_transcription_dataset(tmp_path: pathlib.Path) -> None: +# dataset_path = tmp_path / "test_ds" / "splits" / "train" +# dataset_path.mkdir(parents=True) +# create_empty_tfrecord(dataset_path / "test.tfrecord") + +# file_gen, random_seed = transcription_file_generator( +# "train", ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1]) +# ) + +# transcription_dataset(file_generator=file_gen, n_samples_per_track=1, random_seed=random_seed) + + +def test_transcription_file_generator_train(tmp_path: pathlib.Path) -> None: + dataset_path = tmp_path / "test_ds" / "splits" / "train" + dataset_path.mkdir(parents=True) + create_empty_tfrecord(dataset_path / "test.tfrecord") + + file_gen, random_seed = transcription_file_generator( + "train", ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1]) + ) + assert random_seed is False - print(file_gen()) + generator = file_gen() + assert next(generator).numpy().decode("utf-8") == str(dataset_path / "test.tfrecord") + try: + next(generator) + except Exception as e: + assert isinstance(e, StopIteration) + + +def test_transcription_file_generator_valid(tmp_path: pathlib.Path) -> None: + dataset_path = tmp_path / "test_ds" / "splits" / "valid" + dataset_path.mkdir(parents=True) + create_empty_tfrecord(dataset_path / "test.tfrecord") + + file_gen, random_seed = transcription_file_generator( + "valid", ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1]) + ) + + assert random_seed is True + + generator = file_gen() + assert next(generator).numpy().decode("utf-8") == str(dataset_path / "test.tfrecord") + try: + next(generator) + except Exception as e: + assert isinstance(e, StopIteration)