diff --git a/MANIFEST.in b/MANIFEST.in
index b2858a4..574c238 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,6 +1,6 @@
 include *.txt tox.ini *.rst *.md LICENSE
 include catalog-info.yaml
 include Dockerfile .dockerignore
-recursive-include tests *.py *.wav *.npz *.jams *.zip
+recursive-include tests *.py *.wav *.npz *.jams *.zip *.mid *.flac *.yaml
 recursive-include basic_pitch *.py *.md
 recursive-include basic_pitch/saved_models *.index *.pb variables.data* *.mlmodel *.json *.onnx *.tflite *.bin
diff --git a/basic_pitch/data/datasets/slakh.py b/basic_pitch/data/datasets/slakh.py
index 7fc1f11..6801cdb 100644
--- a/basic_pitch/data/datasets/slakh.py
+++ b/basic_pitch/data/datasets/slakh.py
@@ -20,7 +20,7 @@
 import os
 import time
 
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 import apache_beam as beam
 import mirdata
@@ -34,13 +34,13 @@ class SlakhFilterInvalidTracks(beam.DoFn):
     def __init__(self, source: str):
         self.source = source
 
-    def setup(self):
+    def setup(self) -> None:
         import mirdata
 
         self.slakh_remote = mirdata.initialize("slakh", data_home=self.source)
         self.filesystem = beam.io.filesystems.FileSystems()
 
-    def process(self, element: Tuple[str, str]):
+    def process(self, element: Tuple[str, str]) -> Any:
         import tempfile
 
         import apache_beam as beam
@@ -100,9 +100,8 @@ def __init__(self, source: str, download: bool) -> None:
         self.source = source
         self.download = download
 
-    def setup(self):
+    def setup(self) -> None:
         import apache_beam as beam
-        import os
         import mirdata
 
         self.slakh_remote = mirdata.initialize("slakh", data_home=self.source)
@@ -110,7 +109,7 @@ def setup(self):
         if self.download:
             self.slakh_remote.download()
 
-    def process(self, element: List[str]):
+    def process(self, element: List[str]) -> List[Any]:
         import tempfile
 
         import numpy as np
@@ -188,7 +187,7 @@ def create_input_data() -> List[Tuple[str, str]]:
     return [(track_id, track.data_split) for track_id, track in slakh.load_tracks().items()]
 
 
-def main(known_args, pipeline_args):
+def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
     time_created = int(time.time())
     destination = commandline.resolve_destination(known_args, time_created)
     input_data = create_input_data()
diff --git a/tests/data/test_slakh.py b/tests/data/test_slakh.py
index 93c4245..08f7c76 100644
--- a/tests/data/test_slakh.py
+++ b/tests/data/test_slakh.py
@@ -61,16 +61,14 @@ def test_slakh_to_tf_example(tmpdir: str) -> None:
 
 def test_slakh_invalid_tracks(tmpdir: str) -> None:
     split_labels = ["train", "validation", "test"]
-    input_data = [(TRAIN_PIANO_TRACK_ID, "train"),
-                  (VALID_PIANO_TRACK_ID, "validation"),
-                  (TEST_PIANO_TRACK_ID, "test")]
+    input_data = [(TRAIN_PIANO_TRACK_ID, "train"), (VALID_PIANO_TRACK_ID, "validation"), (TEST_PIANO_TRACK_ID, "test")]
 
     with TestPipeline() as p:
         splits = (
             p
             | "Create PCollection" >> beam.Create(input_data)
-            | "Tag it" >> beam.ParDo(
-                SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
+            | "Tag it"
+            >> beam.ParDo(SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
         )
 
         for split in split_labels:
@@ -87,15 +85,14 @@ def test_slakh_invalid_tracks(tmpdir: str) -> None:
 
 def test_slakh_invalid_tracks_omitted(tmpdir: str) -> None:
     split_labels = ["train", "omitted"]
-    input_data = [(TRAIN_PIANO_TRACK_ID, "train"),
-                  (OMITTED_PIANO_TRACK_ID, "omitted")]
+    input_data = [(TRAIN_PIANO_TRACK_ID, "train"), (OMITTED_PIANO_TRACK_ID, "omitted")]
 
     with TestPipeline() as p:
         splits = (
             p
             | "Create PCollection" >> beam.Create(input_data)
-            | "Tag it" >> beam.ParDo(
-                SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
+            | "Tag it"
+            >> beam.ParDo(SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
         )
 
         for split in split_labels:
@@ -114,16 +111,14 @@ def test_slakh_invalid_tracks_omitted(tmpdir: str) -> None:
 
 def test_slakh_invalid_tracks_drums(tmpdir: str) -> None:
     split_labels = ["train", "validation", "test"]
-    input_data = [(TRAIN_DRUMS_TRACK_ID, "train"),
-                  (VALID_DRUMS_TRACK_ID, "validation"),
-                  (TEST_DRUMS_TRACK_ID, "test")]
+    input_data = [(TRAIN_DRUMS_TRACK_ID, "train"), (VALID_DRUMS_TRACK_ID, "validation"), (TEST_DRUMS_TRACK_ID, "test")]
 
     with TestPipeline() as p:
         splits = (
             p
             | "Create PCollection" >> beam.Create(input_data)
-            | "Tag it" >> beam.ParDo(
-                SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
+            | "Tag it"
+            >> beam.ParDo(SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
         )
 
         for split in split_labels: