added files for downloading and converting datasets, and a guitarset to test the ecosystem. Tests are written but tox is currently not passing. Many files reformatted by black.
bgenchel committed Mar 8, 2024
1 parent b6d1e0a commit 52603f7
Showing 21 changed files with 80,213 additions and 169 deletions.
8 changes: 6 additions & 2 deletions basic_pitch/commandline_printing.py
@@ -42,7 +42,9 @@ def generating_file_message(output_type: str) -> None:
print(f"\n\n Creating {output_type.replace('_', ' ').lower()}...")


def file_saved_confirmation(output_type: str, save_path: Union[pathlib.Path, str]) -> None:
def file_saved_confirmation(
output_type: str, save_path: Union[pathlib.Path, str]
) -> None:
"""Print a confirmation that the file was saved succesfully
Args:
@@ -61,7 +63,9 @@ def failed_to_save(output_type: str, save_path: Union[pathlib.Path, str]) -> None
save_path: The path to output file.
"""
print(f"\n🚨 Failed to save {output_type.replace('_', ' ').lower()} to {save_path} \n")
print(
f"\n🚨 Failed to save {output_type.replace('_', ' ').lower()} to {save_path} \n"
)


@contextmanager
12 changes: 9 additions & 3 deletions basic_pitch/constants.py
@@ -51,11 +51,17 @@
}


def _freq_bins(bins_per_semitone: int, base_frequency: float, n_semitones: int) -> np.array:
def _freq_bins(
bins_per_semitone: int, base_frequency: float, n_semitones: int
) -> np.array:
d = 2.0 ** (1.0 / (12 * bins_per_semitone))
bin_freqs = base_frequency * d ** np.arange(bins_per_semitone * n_semitones)
return bin_freqs


FREQ_BINS_NOTES = _freq_bins(NOTES_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES)
FREQ_BINS_CONTOURS = _freq_bins(CONTOURS_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES)
FREQ_BINS_NOTES = _freq_bins(
NOTES_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES
)
FREQ_BINS_CONTOURS = _freq_bins(
CONTOURS_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES
)
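
The reformatted _freq_bins helper builds a geometric series of bin frequencies, with bins_per_semitone bins per equal-tempered semitone starting at base_frequency. A minimal sketch of the computation, assuming illustrative values (27.5 Hz = A0, 12 semitones) since the actual ANNOTATIONS_* constant values are not shown in this hunk:

import numpy as np

def _freq_bins(bins_per_semitone: int, base_frequency: float, n_semitones: int) -> np.ndarray:
    # Ratio between consecutive bins: the (12 * bins_per_semitone)-th root of 2,
    # so bins_per_semitone bins span exactly one equal-tempered semitone.
    d = 2.0 ** (1.0 / (12 * bins_per_semitone))
    return base_frequency * d ** np.arange(bins_per_semitone * n_semitones)

# Assumed example values (27.5 Hz = A0); the real constants live in basic_pitch/constants.py.
print(_freq_bins(1, 27.5, 12)[:3])  # ~[27.5, 29.14, 30.87] Hz -> A0, A#0/Bb0, B0
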
65 changes: 45 additions & 20 deletions basic_pitch/data/commandline.py
@@ -1,11 +1,11 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2022 Spotify AB
# Cos.pathyright 2022 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a cos.pathy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
@@ -17,41 +17,66 @@

import argparse
import os
import os.path as op

from typing import Optional

def add_default(parser: argparse.ArgumentParser, dataset_name: str):
parser.add_argument("source", nargs='?', default=op.join(op.expanduser('~'), 'mir_datasets'),
help="Source directory for mir data. Defaults to local mir_datasets folder.")
parser.add_argument("destination", nargs='?',
default=op.join(op.expanduser('~'), 'data', 'basic_pitch', dataset_name),
help="Output directory to write results to. Defaults to local ~/data/basic_pitch/{dataset}/")
parser.add_argument("--runner", choices=["DataflowRunner", "DirectRunner"], default="DirectRunner")
parser.add_argument("--timestamped", default=False, action="store_true",
help="If passed, the dataset will be put into a timestamp directory instead of 'splits'")
parser.add_argument("--batch-size", default=5, type=int, help="Number of examples per tfrecord")
parser.add_argument("--worker-harness-container-image", default="",
help="Container image to run dataset generation job with. \
Required due to non-python dependencies.")

def add_default(parser: argparse.ArgumentParser, dataset_name: str = "") -> None:
parser.add_argument(
"--source",
default=os.path.join(os.path.expanduser("~"), "mir_datasets", dataset_name),
help="Source directory for mir data. Defaults to local mir_datasets folder.",
)
parser.add_argument(
"--destination",
default=os.path.join(
os.path.expanduser("~"), "data", "basic_pitch", dataset_name
),
help="Output directory to write results to. Defaults to local ~/data/basic_pitch/{dataset}/",
)
parser.add_argument(
"--runner",
choices=["DataflowRunner", "DirectRunner"],
default="DirectRunner",
help="Whether to run the download and process locally or on GCP Dataflow",
)
parser.add_argument(
"--timestamped",
default=False,
action="store_true",
help="If passed, the dataset will be put into a timestamp directory instead of 'splits'",
)
parser.add_argument(
"--batch-size", default=5, type=int, help="Number of examples per tfrecord"
)
parser.add_argument(
"--worker-harness-container-image",
default="",
help="Container image to run dataset generation job with. \
Required due to non-python dependencies.",
)

def resolve_destination(namespace: argparse.Namespace, dataset: str, time_created: int) -> str:
return os.path.join(namespace.destination, str(time_created) if namespace.timestamped else "splits")

def resolve_destination(namespace: argparse.Namespace, time_created: int) -> str:
return os.path.join(
namespace.destination, str(time_created) if namespace.timestamped else "splits"
)


def add_split(
parser: argparse.ArgumentParser,
train_percent: float = 0.8,
validation_percent: float = 0.1,
split_seed: int = None,
split_seed: Optional[int] = None,
):
parser.add_argument(
"--train-percent",
type=float,
default=train_percent,
help="Percentage of tracks to mark as train",
)
parser.add_argument( "--validation-percent",
parser.add_argument(
"--validation-percent",
type=float,
default=validation_percent,
help="Percentage of tracks to mark as validation",
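
A minimal sketch of how a dataset script might use these helpers; the flag names come from add_default and add_split above, while the sample argument values are hypothetical:

import argparse
from basic_pitch.data import commandline

parser = argparse.ArgumentParser()
commandline.add_default(parser, "guitarset")
commandline.add_split(parser)
# parse_known_args keeps unrecognized flags (e.g. Dataflow options) for the Beam pipeline.
known_args, pipeline_args = parser.parse_known_args(
    ["--runner", "DirectRunner", "--batch-size", "10", "--timestamped"]
)
print(known_args.destination)  # defaults to ~/data/basic_pitch/guitarset (expanded)
print(pipeline_args)           # [] here; would hold any extra Dataflow flags
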
30 changes: 17 additions & 13 deletions basic_pitch/data/datasets/guitarset.py
@@ -18,7 +18,6 @@
import argparse
import logging
import os
import os.path as op
import random
import time
from typing import List, Tuple, Optional
@@ -28,11 +27,9 @@

from basic_pitch.data import commandline, pipeline

DIRNAME = "guitarset"


class GuitarSetInvalidTracks(beam.DoFn):
def process(self, element: Tuple[str, str]):
def process(self, element: Tuple[str, str], *args, **kwargs):
track_id, split = element
yield beam.pvalue.TaggedOutput(split, track_id)

@@ -41,17 +38,21 @@ class GuitarSetToTfExample(beam.DoFn):
DOWNLOAD_ATTRIBUTES = ["audio_mic_path", "jams_path"]

def __init__(self, source: str):
print(f"source_dir: {source}")
self.source = source

def setup(self):
import apache_beam as beam
import mirdata

self.guitarset_remote = mirdata.initialize("guitarset", data_home=os.path.join(self.source, DIRNAME))
self.guitarset_remote = mirdata.initialize("guitarset", data_home=self.source)
self.filesystem = beam.io.filesystems.FileSystems()
if (
type(self.filesystem.get_filesystem(self.source))
== beam.io.localfilesystem.LocalFileSystem
):
self.guitarset_remote.download()

def process(self, element: List[str]):
def process(self, element: List[str], *args, **kwargs):
import tempfile

import mirdata
@@ -80,10 +81,11 @@ def process(self, element: List[str]):

for attribute in self.DOWNLOAD_ATTRIBUTES:
source = getattr(track_remote, attribute)
print(f"source: {source}")
destination = getattr(track_local, attribute)
os.makedirs(os.path.dirname(destination), exist_ok=True)
with self.filesystem.open(source) as s, open(destination, "wb") as d:
with self.filesystem.open(source) as s, open(
destination, "wb"
) as d:
d.write(s.read())

local_wav_path = f"{track_local.audio_mic_path}_tmp.wav"
@@ -146,15 +148,17 @@ def determine_split() -> str:
return "test"

guitarset = mirdata.initialize("guitarset")
guitarset.download()
# guitarset.download()

return [(track_id, determine_split()) for track_id in guitarset.track_ids]


def main(known_args, pipeline_args):
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, DIRNAME, time_created)
input_data = create_input_data(known_args.train_percent, known_args.validation_percent, known_args.split_seed)
destination = commandline.resolve_destination(known_args, time_created)
input_data = create_input_data(
known_args.train_percent, known_args.validation_percent, known_args.split_seed
)

pipeline_options = {
"runner": known_args.runner,
@@ -178,7 +182,7 @@ def main(known_args, pipeline_args):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
commandline.add_default(parser, op.basename(op.splittext(__file__)[0]))
commandline.add_default(parser, os.path.basename(os.path.splitext(__file__)[0]))
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args()

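
Only part of the split-assignment logic in create_input_data / determine_split is visible in this hunk. The sketch below reconstructs the apparent pattern (one uniform draw per track, bucketed by cumulative percentage, with an optional seed); it is an assumption about the full implementation, and the track ids are made up:

import random
from typing import List, Optional, Tuple

def assign_splits(
    track_ids: List[str],
    train_percent: float = 0.8,
    validation_percent: float = 0.1,
    split_seed: Optional[int] = None,
) -> List[Tuple[str, str]]:
    # Draw one uniform sample per track and bucket it by cumulative percentage.
    rng = random.Random(split_seed)

    def determine_split() -> str:
        partition = rng.uniform(0, 1)
        if partition < train_percent:
            return "train"
        if partition < train_percent + validation_percent:
            return "validation"
        return "test"

    return [(track_id, determine_split()) for track_id in track_ids]

print(assign_splits(["track_a", "track_b", "track_c"], split_seed=42))
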
24 changes: 20 additions & 4 deletions basic_pitch/data/download.py
@@ -1,25 +1,41 @@
import argparse
import logging
import sys

from basic_pitch.data import commandline
from basic_pitch.data.datasets.guitarset import main as guitarset_main

logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(levelname)s:: %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

DATASET_DICT = {
'guitarset': guitarset_main,
"guitarset": guitarset_main,
}


def main():
dataset_parser = argparse.ArgumentParser()
dataset_parser.add_argument("dataset", choices=list(DATASET_DICT.keys()), help="The dataset to download / process.")
dataset_parser.add_argument(
"dataset",
choices=list(DATASET_DICT.keys()),
help="The dataset to download / process.",
)
args, remaining_args = dataset_parser.parse_known_args()
dataset = args.dataset
logger.info(f"Downloading and processing {dataset}")

print(f'got the arg: {dataset}')
cl_parser = argparse.ArgumentParser()
commandline.add_default(cl_parser, dataset)
commandline.add_split(cl_parser)
known_args, pipeline_args = cl_parser.parse_known_args(remaining_args)
for arg in vars(known_args):
logger.info(f"{arg} = {getattr(known_args, arg)}")
DATASET_DICT[dataset](known_args, pipeline_args)


if __name__ == '__main__':
if __name__ == "__main__":
main()
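
download.py parses in two stages: one parser picks the dataset, the leftover argv goes to the dataset-specific parser built by commandline.add_default / add_split, and whatever that parser does not recognize is forwarded as Beam pipeline args. A self-contained sketch of that pattern, using a stand-in for the guitarset entry point and an abbreviated flag set:

import argparse

def fake_dataset_main(known_args: argparse.Namespace, pipeline_args: list) -> None:
    # Stand-in for basic_pitch.data.datasets.guitarset.main, for illustration only.
    print("dataset args:", vars(known_args), "| forwarded to beam:", pipeline_args)

DATASET_DICT = {"guitarset": fake_dataset_main}

dataset_parser = argparse.ArgumentParser()
dataset_parser.add_argument("dataset", choices=list(DATASET_DICT.keys()))
args, remaining = dataset_parser.parse_known_args(
    ["guitarset", "--batch-size", "10", "--some-dataflow-flag", "x"]
)

cl_parser = argparse.ArgumentParser()
cl_parser.add_argument("--batch-size", type=int, default=5)  # add_default adds the full flag set
known_args, pipeline_args = cl_parser.parse_known_args(remaining)
DATASET_DICT[args.dataset](known_args, pipeline_args)  # pipeline_args: ["--some-dataflow-flag", "x"]
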
28 changes: 19 additions & 9 deletions basic_pitch/data/pipeline.py
@@ -18,7 +18,7 @@
import logging
import os
import uuid
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Callable, Union

import apache_beam as beam
import tensorflow as tf
@@ -40,23 +40,28 @@ def __init__(self, destination):
self.destination = destination

def process(self, element):
if not isinstance(element, list):
element = [element]

logging.info(f"Writing to file batch of length {len(element)}")
# hopefully uuids are unique enough
with tf.io.TFRecordWriter(os.path.join(self.destination, f"{uuid.uuid4()}.tfrecord")) as writer:
with tf.io.TFRecordWriter(
os.path.join(self.destination, f"{uuid.uuid4()}.tfrecord")
) as writer:
for example in element:
writer.write(example.SerializeToString())


def transcription_dataset_writer(
pcoll,
p: beam.Pipeline,
input_data: List[Tuple[str, str]],
to_tf_example: beam.DoFn,
filter_invalid_tracks: beam.DoFn,
to_tf_example: Union[beam.DoFn, Callable],
filter_invalid_tracks: beam.PTransform,
destination: str,
batch_size: int,
):
valid_track_ids = (
pcoll
p
| "Create PCollection of track IDS" >> beam.Create(input_data)
| "Remove invalid track IDs"
>> beam.ParDo(filter_invalid_tracks).with_outputs(
@@ -73,9 +78,12 @@ def transcription_dataset_writer(
| f"Batch {split}" >> beam.ParDo(Batch(batch_size))
| f"Reshuffle {split}" >> beam.Reshuffle() # To prevent fuses
| f"Create tf.Example {split} batch" >> beam.ParDo(to_tf_example)
| f"Write {split} batch to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(os.path.join(destination, split)))
| f"Write {split} batch to tfrecord"
>> beam.ParDo(WriteBatchToTfRecord(os.path.join(destination, split)))
)
getattr(valid_track_ids, split) | f"Write {split} index file" >> beam.io.textio.WriteToText(
getattr(
valid_track_ids, split
) | f"Write {split} index file" >> beam.io.textio.WriteToText(
os.path.join(destination, split, "index.csv"),
num_shards=1,
header="track_id",
@@ -92,4 +100,6 @@ def run(
batch_size: int,
):
with beam.Pipeline(options=PipelineOptions(**pipeline_options)) as p:
transcription_dataset_writer(p, input_data, to_tf_example, filter_invalid_tracks, destination, batch_size)
transcription_dataset_writer(
p, input_data, to_tf_example, filter_invalid_tracks, destination, batch_size
)
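
transcription_dataset_writer relies on Beam tagged outputs: the filter DoFn yields each (track_id, split) pair to an output named after its split, and each split is later retrieved with getattr. A minimal runnable sketch of that routing, with made-up track ids:

import apache_beam as beam

class RouteBySplit(beam.DoFn):
    # Mirrors the GuitarSetInvalidTracks pattern above: emit each track id on a
    # tagged output named after its split.
    def process(self, element, *args, **kwargs):
        track_id, split = element
        yield beam.pvalue.TaggedOutput(split, track_id)

with beam.Pipeline() as p:
    tagged = (
        p
        | "Create ids" >> beam.Create([("trk_a", "train"), ("trk_b", "validation"), ("trk_c", "test")])
        | "Route" >> beam.ParDo(RouteBySplit()).with_outputs("train", "validation", "test")
    )
    for split in ["train", "validation", "test"]:
        getattr(tagged, split) | f"Print {split}" >> beam.Map(print)
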
32 changes: 24 additions & 8 deletions basic_pitch/data/tf_example_serialization.py
@@ -62,14 +62,30 @@ def _to_transcription_tfex(
"file_id": bytes_feature(bytes(file_id, "utf-8")),
"source": bytes_feature(bytes(source, "utf-8")),
"audio_wav": bytes_feature(encoded_wav),
"notes_indices": bytes_feature(tf.io.serialize_tensor(np.array(notes_indices, np.int64))),
"notes_values": bytes_feature(tf.io.serialize_tensor(np.array(notes_values, np.float32))),
"onsets_indices": bytes_feature(tf.io.serialize_tensor(np.array(onsets_indices, np.int64))),
"onsets_values": bytes_feature(tf.io.serialize_tensor(np.array(onsets_values, np.float32))),
"contours_indices": bytes_feature(tf.io.serialize_tensor(np.array(contours_indices, np.int64))),
"contours_values": bytes_feature(tf.io.serialize_tensor(np.array(contours_values, np.float32))),
"notes_onsets_shape": bytes_feature(tf.io.serialize_tensor(np.array(notes_onsets_shape, np.int64))),
"contours_shape": bytes_feature(tf.io.serialize_tensor(np.array(contours_shape, np.int64))),
"notes_indices": bytes_feature(
tf.io.serialize_tensor(np.array(notes_indices, np.int64))
),
"notes_values": bytes_feature(
tf.io.serialize_tensor(np.array(notes_values, np.float32))
),
"onsets_indices": bytes_feature(
tf.io.serialize_tensor(np.array(onsets_indices, np.int64))
),
"onsets_values": bytes_feature(
tf.io.serialize_tensor(np.array(onsets_values, np.float32))
),
"contours_indices": bytes_feature(
tf.io.serialize_tensor(np.array(contours_indices, np.int64))
),
"contours_values": bytes_feature(
tf.io.serialize_tensor(np.array(contours_values, np.float32))
),
"notes_onsets_shape": bytes_feature(
tf.io.serialize_tensor(np.array(notes_onsets_shape, np.int64))
),
"contours_shape": bytes_feature(
tf.io.serialize_tensor(np.array(contours_shape, np.int64))
),
}
)
)
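
Each sparse-tensor component above is stored as a serialized tensor wrapped in a bytes feature. A small round-trip sketch of that encoding and its decoding; the feature name and values are arbitrary examples:

import numpy as np
import tensorflow as tf

# Serialize a variable-shaped int64 tensor to bytes and wrap it in a BytesList feature...
notes_indices = np.array([[0, 1], [2, 3]], np.int64)
serialized = tf.io.serialize_tensor(notes_indices)
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "notes_indices": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[serialized.numpy()])
            )
        }
    )
)

# ...then recover it by parsing the string feature and deserializing the tensor.
parsed = tf.io.parse_single_example(
    example.SerializeToString(),
    {"notes_indices": tf.io.FixedLenFeature([], tf.string)},
)
print(tf.io.parse_tensor(parsed["notes_indices"], out_type=tf.int64).numpy())
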