added medleydb dataset, test file, and include in download #131

Merged · 2 commits · Jul 30, 2024
14 changes: 8 additions & 6 deletions basic_pitch/data/datasets/guitarset.py
@@ -136,17 +136,19 @@ def create_input_data(
     if seed:
         random.seed(seed)
 
-    def determine_split() -> str:
-        partition = random.uniform(0, 1)
-        if partition < validation_bound:
+    def determine_split(index: int) -> str:
+        if index < len(track_ids) * validation_bound:
             return "train"
-        if partition < test_bound:
+        elif index < len(track_ids) * test_bound:
             return "validation"
-        return "test"
+        else:
+            return "test"
 
     guitarset = mirdata.initialize("guitarset")
+    track_ids = guitarset.track_ids
+    random.shuffle(track_ids)
 
-    return [(track_id, determine_split()) for track_id in guitarset.track_ids]
+    return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)]
 
 
 def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
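The point of this change, shown as a standalone sketch (the ids and bounds below are hypothetical, not from the PR): with the old per-track random.uniform draw, each track's split was an independent coin flip, so the split sizes only matched the bounds in expectation; indexing into a shuffled list makes the proportions exact.

import random

# Hypothetical ids and bounds for an 80/10/10 split; the real values come
# from create_input_data's arguments.
track_ids = [f"track_{i:03d}" for i in range(100)]
validation_bound, test_bound = 0.8, 0.9

random.seed(42)
random.shuffle(track_ids)

def determine_split(index: int) -> str:
    if index < len(track_ids) * validation_bound:
        return "train"
    elif index < len(track_ids) * test_bound:
        return "validation"
    else:
        return "test"

splits = [(tid, determine_split(i)) for i, tid in enumerate(track_ids)]
# Exactly 80 train, 10 validation, and 10 test tracks; the old per-track
# draw only hit these proportions on average.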
16 changes: 6 additions & 10 deletions basic_pitch/data/datasets/ikala.py
@@ -138,21 +138,17 @@ def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str
 def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]:
     assert train_percent < 1.0, "Don't over allocate the data!"
 
-    # Test percent is 1 - train - validation
-    validation_bound = train_percent
-
     if seed:
         random.seed(seed)
 
-    def determine_split() -> str:
-        partition = random.uniform(0, 1)
-        if partition < validation_bound:
-            return "train"
-        return "validation"
-
     ikala = mirdata.initialize("ikala")
+    track_ids = ikala.track_ids
+    random.shuffle(track_ids)
 
+    def determine_split(index: int) -> str:
+        return "train" if index < len(track_ids) * train_percent else "validation"
+
-    return [(track_id, determine_split()) for track_id in ikala.track_ids]
+    return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)]
 
 
 def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
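A note on the seed handling that guitarset, ikala, and the new medleydb_pitch module all share: seeding before the shuffle makes the partition reproducible across runs. A minimal sketch (the helper name split_ids is illustrative, not from the PR):

import random
from typing import List, Optional, Tuple

def split_ids(ids: List[str], train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]:
    if seed:  # note: a seed of 0 is falsy and would be silently ignored by this check
        random.seed(seed)
    ids = list(ids)
    random.shuffle(ids)
    return [(tid, "train" if i < len(ids) * train_percent else "validation") for i, tid in enumerate(ids)]

# Identical seeds yield identical partitions:
assert split_ids(["a", "b", "c", "d"], 0.5, seed=7) == split_ids(["a", "b", "c", "d"], 0.5, seed=7)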
187 changes: 187 additions & 0 deletions basic_pitch/data/datasets/medleydb_pitch.py
@@ -0,0 +1,187 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import random
import time
from typing import Any, Dict, List, Optional, Tuple

import apache_beam as beam
import mirdata

from basic_pitch.data import commandline, pipeline


class MedleyDbPitchInvalidTracks(beam.DoFn):
    def process(self, element: Tuple[str, str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
        track_id, split = element
        yield beam.pvalue.TaggedOutput(split, track_id)


class MedleyDbPitchToTfExample(beam.DoFn):
    DOWNLOAD_ATTRIBUTES = ["audio_path", "notes_pyin_path", "pitch_path"]

    def __init__(self, source: str, download: bool) -> None:
        self.source = source
        self.download = download

    def setup(self) -> None:
        import apache_beam as beam
        import mirdata

        self.medleydb_pitch_remote = mirdata.initialize("medleydb_pitch", data_home=self.source)
        self.filesystem = beam.io.filesystems.FileSystems()  # TODO: replace with fsspec
        if self.download:
            self.medleydb_pitch_remote.download()

    def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> List[Any]:
        import tempfile

        import numpy as np
        import sox

        from basic_pitch.constants import (
            AUDIO_N_CHANNELS,
            AUDIO_SAMPLE_RATE,
            FREQ_BINS_CONTOURS,
            FREQ_BINS_NOTES,
            ANNOTATION_HOP,
            N_FREQ_BINS_NOTES,
            N_FREQ_BINS_CONTOURS,
        )
        from basic_pitch.data import tf_example_serialization

        logging.info(f"Processing {element}")
        batch = []

        for track_id in element:
            track_remote = self.medleydb_pitch_remote.track(track_id)

            with tempfile.TemporaryDirectory() as local_tmp_dir:
                medleydb_pitch_local = mirdata.initialize("medleydb_pitch", local_tmp_dir)
                track_local = medleydb_pitch_local.track(track_id)

                for attr in self.DOWNLOAD_ATTRIBUTES:
                    source = getattr(track_remote, attr)
                    dest = getattr(track_local, attr)
                    os.makedirs(os.path.dirname(dest), exist_ok=True)
                    with self.filesystem.open(source) as s, open(dest, "wb") as d:
                        d.write(s.read())

                # will be in temp dir and get cleaned up
                local_wav_path = "{}_tmp.wav".format(track_local.audio_path)
                tfm = sox.Transformer()
                tfm.rate(AUDIO_SAMPLE_RATE)
                tfm.channels(AUDIO_N_CHANNELS)
                tfm.build(track_local.audio_path, local_wav_path)

                duration = sox.file_info.duration(local_wav_path)
                time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
                n_time_frames = len(time_scale)

                if track_local.notes_pyin is not None:
                    note_indices, note_values = track_local.notes_pyin.to_sparse_index(
                        time_scale, "s", FREQ_BINS_NOTES, "hz"
                    )
                    onset_indices, onset_values = track_local.notes_pyin.to_sparse_index(
                        time_scale, "s", FREQ_BINS_NOTES, "hz", onsets_only=True
                    )
                    note_shape = (n_time_frames, N_FREQ_BINS_NOTES)
                # if there are no notes, return empty note indices
                else:
                    note_shape = (0, 0)
                    note_indices = []
                    onset_indices = []
                    note_values = []
                    onset_values = []

                contour_indices, contour_values = track_local.pitch.to_sparse_index(
                    time_scale, "s", FREQ_BINS_CONTOURS, "hz"
                )

                batch.append(
                    tf_example_serialization.to_transcription_tfexample(
                        track_id,
                        "medleydb_pitch",
                        local_wav_path,
                        note_indices,
                        note_values,
                        onset_indices,
                        onset_values,
                        contour_indices,
                        contour_values,
                        note_shape,
                        (n_time_frames, N_FREQ_BINS_CONTOURS),
                    )
                )
        return [batch]


def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]:
    assert train_percent < 1.0, "Don't over allocate the data!"

    if seed:
        random.seed(seed)

    medleydb_pitch = mirdata.initialize("medleydb_pitch")
    track_ids = medleydb_pitch.track_ids
    random.shuffle(track_ids)

    def determine_split(index: int) -> str:
        return "train" if index < len(track_ids) * train_percent else "validation"

    return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)]


def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
    time_created = int(time.time())
    destination = commandline.resolve_destination(known_args, time_created)
    input_data = create_input_data(known_args.train_percent, known_args.split_seed)

    pipeline_options = {
        "runner": known_args.runner,
        "job_name": f"medleydb-pitch-tfrecords-{time_created}",
        "machine_type": "e2-standard-4",
        "num_workers": 25,
        "disk_size_gb": 128,
        "experiments": ["use_runner_v2"],
        "save_main_session": True,
        "sdk_container_image": known_args.sdk_container_image,
        "job_endpoint": known_args.job_endpoint,
        "environment_type": "DOCKER",
        "environment_config": known_args.sdk_container_image,
    }
    pipeline.run(
        pipeline_options,
        pipeline_args,
        input_data,
        MedleyDbPitchToTfExample(known_args.source, download=True),
        MedleyDbPitchInvalidTracks(),
        destination,
        known_args.batch_size,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    commandline.add_default(parser, os.path.basename(os.path.splitext(__file__)[0]))
    commandline.add_split(parser)
    known_args, pipeline_args = parser.parse_known_args()

    main(known_args, pipeline_args)
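One detail of the new file worth spelling out is how n_time_frames is derived: the annotation time scale runs from 0 to just past the audio duration in steps of ANNOTATION_HOP, and that length shapes the note and contour matrices. A sketch with a hypothetical hop and duration (the real constant lives in basic_pitch.constants):

import numpy as np

ANNOTATION_HOP = 0.011  # hypothetical hop in seconds; not the repo's actual value
duration = 2.0          # hypothetical audio duration in seconds

time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
n_time_frames = len(time_scale)  # includes the frame at t = 0 and one at/after `duration`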
3 changes: 2 additions & 1 deletion basic_pitch/data/download.py
@@ -20,11 +20,12 @@
 from basic_pitch.data import commandline
 from basic_pitch.data.datasets.guitarset import main as guitarset_main
 from basic_pitch.data.datasets.ikala import main as ikala_main
+from basic_pitch.data.datasets.medleydb_pitch import main as medleydb_pitch_main
 
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
-DATASET_DICT = {"guitarset": guitarset_main, "ikala": ikala_main}
+DATASET_DICT = {"guitarset": guitarset_main, "ikala": ikala_main, "medleydb_pitch": medleydb_pitch_main}
 
 
 def main() -> None:
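Registering the new entry point in DATASET_DICT is all download.py needs to dispatch to it. A self-contained sketch of the same registry pattern (names and the stub body are illustrative; the real main functions take parsed args and Beam pipeline args):

from typing import Callable, Dict, List

def medleydb_pitch_main(known_args: object, pipeline_args: List[str]) -> None:
    print("would launch the medleydb_pitch TFRecord pipeline")

DATASET_DICT: Dict[str, Callable] = {"medleydb_pitch": medleydb_pitch_main}

# The caller picks a dataset by name, e.g. from a command-line argument:
DATASET_DICT["medleydb_pitch"](None, [])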
68 changes: 68 additions & 0 deletions tests/data/test_medleydb_pitch.py
@@ -0,0 +1,68 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import apache_beam as beam
import itertools
import os

from apache_beam.testing.test_pipeline import TestPipeline

from basic_pitch.data.datasets.medleydb_pitch import (
    MedleyDbPitchInvalidTracks,
    create_input_data,
)


# TODO: Create test_medleydb_pitch_to_tf_example


def test_medleydb_pitch_invalid_tracks(tmpdir: str) -> None:
    split_labels = ["train", "validation"]
    input_data = [(str(i), split) for i, split in enumerate(split_labels)]
    with TestPipeline() as p:
        splits = (
            p
            | "Create PCollection" >> beam.Create(input_data)
            | "Tag it" >> beam.ParDo(MedleyDbPitchInvalidTracks()).with_outputs(*split_labels)
        )

        for split in split_labels:
            (
                getattr(splits, split)
                | f"Write {split} to text"
                >> beam.io.WriteToText(os.path.join(tmpdir, f"output_{split}.txt"), shard_name_template="")
            )

    for i, split in enumerate(split_labels):
        with open(os.path.join(tmpdir, f"output_{split}.txt"), "r") as fp:
            assert fp.read().strip() == str(i)


def test_medleydb_create_input_data() -> None:
    data = create_input_data(train_percent=0.5)
    data.sort(key=lambda el: el[1])  # sort by split
    tolerance = 0.01
    for _, group in itertools.groupby(data, lambda el: el[1]):
        assert (0.5 - tolerance) * len(data) <= len(list(group)) <= (0.5 + tolerance) * len(data)


def test_create_input_data_overallocate() -> None:
    try:
        create_input_data(train_percent=1.1)
    except AssertionError:
        assert True
    else:
        assert False
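The overallocation test's try/except/else dance works, but the same check is usually written with pytest.raises; an equivalent sketch (assuming pytest, which Python test suites like this one typically already use):

import pytest

from basic_pitch.data.datasets.medleydb_pitch import create_input_data

def test_create_input_data_overallocate() -> None:
    with pytest.raises(AssertionError):
        create_input_data(train_percent=1.1)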