
Commit

Merge pull request #133 from spotify/bgenchel/add-slakh
Add Slakh
drubinstein authored Aug 7, 2024
2 parents 6cfc090 + 6f3d2fd commit 7147db2
Showing 13 changed files with 967 additions and 60 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/tox.yml
@@ -30,13 +30,13 @@ jobs:
python-version: ${{ matrix.py }}
- uses: actions/checkout@v3
- name: Install soundlibs Ubuntu
run: sudo apt-get update && sudo apt-get install --no-install-recommends -y --fix-missing pkg-config libsndfile1 sox
run: sudo apt-get update && sudo apt-get install --no-install-recommends -y --fix-missing pkg-config libsndfile1 sox ffmpeg
if: matrix.os == 'ubuntu-latest'
- name: Install soundlibs MacOs
run: brew install libsndfile llvm libomp sox
run: brew install libsndfile llvm libomp sox ffmpeg
if: matrix.os == 'macos-latest-xlarge'
- name: Install soundlibs Windows
run: choco install libsndfile sox.portable
run: choco install libsndfile sox.portable flac ffmpeg
if: matrix.os == 'windows-latest'
- name: Upgrade pip
run: python -m pip install -U pip
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -1,6 +1,6 @@
include *.txt tox.ini *.rst *.md LICENSE
include catalog-info.yaml
include Dockerfile .dockerignore
recursive-include tests *.py *.wav *.npz *.jams *.zip *.midi *.csv *.json
recursive-include tests *.py *.wav *.npz *.jams *.zip *.mid *.flac *.yaml *.json
recursive-include basic_pitch *.py *.md
recursive-include basic_pitch/saved_models *.index *.pb variables.data* *.mlmodel *.json *.onnx *.tflite *.bin
223 changes: 223 additions & 0 deletions basic_pitch/data/datasets/slakh.py
@@ -0,0 +1,223 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import time

from typing import List, Tuple, Any

import apache_beam as beam
import mirdata

from basic_pitch.data import commandline, pipeline


class SlakhFilterInvalidTracks(beam.DoFn):
DOWNLOAD_ATTRIBUTES = ["audio_path", "metadata_path", "midi_path"]

def __init__(self, source: str):
self.source = source

def setup(self) -> None:
import mirdata

self.slakh_remote = mirdata.initialize("slakh", data_home=self.source)
self.filesystem = beam.io.filesystems.FileSystems()

def process(self, element: Tuple[str, str]) -> Any:
import tempfile

import apache_beam as beam
import ffmpeg

from basic_pitch.constants import (
AUDIO_N_CHANNELS,
AUDIO_SAMPLE_RATE,
)

track_id, split = element
if split == "omitted":
return None

logging.info(f"Processing (track_id, split): ({track_id}, {split})")

track_remote = self.slakh_remote.track(track_id)

with tempfile.TemporaryDirectory() as local_tmp_dir:
slakh_local = mirdata.initialize("slakh", local_tmp_dir)
track_local = slakh_local.track(track_id)

for attr in self.DOWNLOAD_ATTRIBUTES:
source = getattr(track_remote, attr)
dest = getattr(track_local, attr)
if not dest:
return None
logging.info(f"Downloading {attr} from {source} to {dest}")
os.makedirs(os.path.dirname(dest), exist_ok=True)
with self.filesystem.open(source) as s, open(dest, "wb") as d:
d.write(s.read())

if track_local.is_drum:
return None

local_wav_path = "{}_tmp.wav".format(track_local.audio_path)
try:
ffmpeg.input(track_local.audio_path).output(
local_wav_path, ar=AUDIO_SAMPLE_RATE, ac=AUDIO_N_CHANNELS
).run()
except Exception as e:
logging.info(f"Could not process {local_wav_path}. Exception: {e}")
return None

# if there are no notes, skip this track
if track_local.notes is None or len(track_local.notes.intervals) == 0:
return None

yield beam.pvalue.TaggedOutput(split, track_id)


class SlakhToTfExample(beam.DoFn):
DOWNLOAD_ATTRIBUTES = ["audio_path", "metadata_path", "midi_path"]

def __init__(self, source: str, download: bool) -> None:
self.source = source
self.download = download

def setup(self) -> None:
import apache_beam as beam
import mirdata

self.slakh_remote = mirdata.initialize("slakh", data_home=self.source)
self.filesystem = beam.io.filesystems.FileSystems() # TODO: replace with fsspec
if self.download:
self.slakh_remote.download()

def process(self, element: List[str]) -> List[Any]:
import tempfile

import numpy as np
import ffmpeg

from basic_pitch.constants import (
AUDIO_N_CHANNELS,
AUDIO_SAMPLE_RATE,
FREQ_BINS_CONTOURS,
FREQ_BINS_NOTES,
ANNOTATION_HOP,
N_FREQ_BINS_NOTES,
N_FREQ_BINS_CONTOURS,
)
from basic_pitch.data import tf_example_serialization

logging.info(f"Processing {element}")
batch = []

for track_id in element:
track_remote = self.slakh_remote.track(track_id)

with tempfile.TemporaryDirectory() as local_tmp_dir:
slakh_local = mirdata.initialize("slakh", local_tmp_dir)
track_local = slakh_local.track(track_id)

for attr in self.DOWNLOAD_ATTRIBUTES:
source = getattr(track_remote, attr)
dest = getattr(track_local, attr)
logging.info(f"Downloading {attr} from {source} to {dest}")
os.makedirs(os.path.dirname(dest), exist_ok=True)
with self.filesystem.open(source) as s, open(dest, "wb") as d:
d.write(s.read())

local_wav_path = "{}_tmp.wav".format(track_local.audio_path)
ffmpeg.input(track_local.audio_path).output(
local_wav_path, ar=AUDIO_SAMPLE_RATE, ac=AUDIO_N_CHANNELS
).run()

duration = float(ffmpeg.probe(local_wav_path)["format"]["duration"])
time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
n_time_frames = len(time_scale)

note_indices, note_values = track_local.notes.to_sparse_index(time_scale, "s", FREQ_BINS_NOTES, "hz")
onset_indices, onset_values = track_local.notes.to_sparse_index(
time_scale, "s", FREQ_BINS_NOTES, "hz", onsets_only=True
)
contour_indices, contour_values = track_local.multif0.to_sparse_index(
time_scale, "s", FREQ_BINS_CONTOURS, "hz"
)

batch.append(
tf_example_serialization.to_transcription_tfexample(
track_id,
"slakh",
local_wav_path,
note_indices,
note_values,
onset_indices,
onset_values,
contour_indices,
contour_values,
(n_time_frames, N_FREQ_BINS_NOTES),
(n_time_frames, N_FREQ_BINS_CONTOURS),
)
)

logging.info(f"Finished processing batch of length {len(batch)}")
return [batch]


def create_input_data() -> List[Tuple[str, str]]:
slakh = mirdata.initialize("slakh")
return [(track_id, track.data_split) for track_id, track in slakh.load_tracks().items()]


def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, time_created)
input_data = create_input_data()

pipeline_options = {
"runner": known_args.runner,
"job_name": f"slakh-tfrecords-{time_created}",
"machine_type": "e2-standard-4",
"num_workers": 25,
"disk_size_gb": 128,
"experiments": ["use_runner_v2"],
"save_main_session": True,
"sdk_container_image": known_args.sdk_container_image,
"job_endpoint": known_args.job_endpoint,
"environment_type": "DOCKER",
"environment_config": known_args.sdk_container_image,
}
pipeline.run(
pipeline_options,
pipeline_args,
input_data,
SlakhToTfExample(known_args.source, download=True),
SlakhFilterInvalidTracks(known_args.source),
destination,
known_args.batch_size,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
commandline.add_default(parser, os.path.basename(os.path.splitext(__file__)[0]))
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args() # sys.argv)

main(known_args, pipeline_args)
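
Both new DoFns convert the downloaded stem audio to WAV with ffmpeg-python before any feature work, and SlakhToTfExample additionally probes the WAV's duration to size the annotation time grid. The standalone sketch below isolates that step for reference; the helper name and the overwrite_output/quiet flags are illustrative additions, not part of this commit.

# Sketch only (not in the commit): the transcode-and-probe step performed inside
# SlakhFilterInvalidTracks.process and SlakhToTfExample.process.
import ffmpeg

from basic_pitch.constants import AUDIO_N_CHANNELS, AUDIO_SAMPLE_RATE


def transcode_to_wav(src_path: str) -> float:
    """Resample/downmix src_path to a temporary WAV and return its duration in seconds."""
    wav_path = f"{src_path}_tmp.wav"
    # Same ffmpeg-python call the DoFns make: set sample rate (ar) and channel count (ac).
    ffmpeg.input(src_path).output(
        wav_path, ar=AUDIO_SAMPLE_RATE, ac=AUDIO_N_CHANNELS
    ).run(overwrite_output=True, quiet=True)
    # ffprobe reports the container duration as a string; SlakhToTfExample feeds it
    # into np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP) to build the time grid.
    return float(ffmpeg.probe(wav_path)["format"]["duration"])

In the pipeline the temporary WAV sits inside each track's TemporaryDirectory, so it is cleaned up automatically once the DoFn finishes with the track.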
3 changes: 2 additions & 1 deletion basic_pitch/data/download.py
@@ -22,16 +22,17 @@
from basic_pitch.data.datasets.ikala import main as ikala_main
from basic_pitch.data.datasets.maestro import main as maestro_main
from basic_pitch.data.datasets.medleydb_pitch import main as medleydb_pitch_main
from basic_pitch.data.datasets.slakh import main as slakh_main

logger = logging.getLogger()
logger.setLevel(logging.INFO)


DATASET_DICT = {
"guitarset": guitarset_main,
"ikala": ikala_main,
"maestro": maestro_main,
"medleydb_pitch": medleydb_pitch_main,
"slakh": slakh_main,
}


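DATASET_DICT is the registry the download entry point uses to map a dataset name to that dataset's Beam main(); the surrounding CLI handling is collapsed in this view. A minimal, hypothetical lookup (only the registry itself comes from the diff):

# Hypothetical usage sketch; download.py's real argument handling is not shown here.
from basic_pitch.data.download import DATASET_DICT

dataset_main = DATASET_DICT["slakh"]  # -> basic_pitch.data.datasets.slakh.main
print(f"slakh -> {dataset_main.__module__}.{dataset_main.__name__}")
# Like the other dataset entry points, it is invoked as dataset_main(known_args, pipeline_args).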
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -57,7 +57,8 @@ data = [
"apache_beam",
"mirdata",
"smart_open",
"sox"
"sox",
"ffmpeg-python"
]
test = [
"basic_pitch[data]",
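ffmpeg-python is a thin wrapper that shells out to the ffmpeg executable, which is why the tox.yml change above also installs the binary on all three CI platforms. A small, illustrative environment check (not part of the commit):

# Illustrative check that the ffmpeg binary required by ffmpeg-python is available.
import shutil

if shutil.which("ffmpeg") is None:
    raise RuntimeError("ffmpeg binary not found on PATH; install it via apt, brew or choco before running the data pipelines")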
59 changes: 6 additions & 53 deletions tests/data/test_maestro.py
@@ -14,13 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import os
import pathlib
import wave

from mido import MidiFile, MidiTrack, Message
from typing import List

import apache_beam as beam
@@ -33,6 +29,8 @@
)
from basic_pitch.data.pipeline import WriteBatchToTfRecord

from utils import create_mock_wav, create_mock_midi

RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"
MAESTRO_TEST_DATA_PATH = RESOURCES_PATH / "data" / "maestro"

Expand All @@ -42,57 +40,13 @@
GT_15M_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav"


def create_mock_wav(output_fpath: str, duration_min: int) -> None:
duration_seconds = duration_min * 60
sample_rate = 44100
n_channels = 2 # Stereo
sampwidth = 2 # 2 bytes per sample (16-bit audio)

# Generate a silent audio data array
num_samples = duration_seconds * sample_rate
audio_data = np.zeros((num_samples, n_channels), dtype=np.int16)

# Create the WAV file
with wave.open(str(output_fpath), "w") as wav_file:
wav_file.setnchannels(n_channels)
wav_file.setsampwidth(sampwidth)
wav_file.setframerate(sample_rate)
wav_file.writeframes(audio_data.tobytes())

logging.info(f"Mock {duration_min}-minute WAV file '{output_fpath}' created successfully.")


def create_mock_midi(output_fpath: str) -> None:
# Create a new MIDI file with one track
mid = MidiFile()
track = MidiTrack()
mid.tracks.append(track)

# Define a sequence of notes (time, type, note, velocity)
notes = [
(0, "note_on", 60, 64), # C4
(500, "note_off", 60, 64),
(0, "note_on", 62, 64), # D4
(500, "note_off", 62, 64),
]

# Add the notes to the track
for time, type, note, velocity in notes:
track.append(Message(type, note=note, velocity=velocity, time=time))

# Save the MIDI file
mid.save(output_fpath)

logging.info(f"Mock MIDI file '{output_fpath}' created successfully.")


def test_maestro_to_tf_example(tmp_path: pathlib.Path) -> None:
mock_maestro_home = tmp_path / "maestro"
mock_maestro_ext = mock_maestro_home / "2004"
mock_maestro_ext.mkdir(parents=True, exist_ok=True)

create_mock_wav(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".wav")), 3)
create_mock_midi(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".midi")))
create_mock_wav(mock_maestro_ext / f"{TRAIN_TRACK_ID.split('/')[1]}.wav", 3)
create_mock_midi(mock_maestro_ext / f"{TRAIN_TRACK_ID.split('/')[1]}.midi")

output_dir = tmp_path / "outputs"
output_dir.mkdir(parents=True, exist_ok=True)
@@ -122,7 +76,7 @@ def test_maestro_invalid_tracks(tmp_path: pathlib.Path) -> None:
input_data = [(TRAIN_TRACK_ID, "train"), (VALID_TRACK_ID, "validation"), (TEST_TRACK_ID, "test")]

for track_id, _ in input_data:
create_mock_wav(str(mock_maestro_ext / (track_id.split("/")[1] + ".wav")), 3)
create_mock_wav(mock_maestro_ext / f"{track_id.split('/')[1]}.wav", 3)

split_labels = set([e[1] for e in input_data])
with TestPipeline() as p:
@@ -154,8 +108,7 @@ def test_maestro_invalid_tracks_over_15_min(tmp_path: pathlib.Path) -> None:
mock_maestro_ext = mock_maestro_home / "2004"
mock_maestro_ext.mkdir(parents=True, exist_ok=True)

mock_fpath = mock_maestro_ext / (GT_15M_TRACK_ID.split("/")[1] + ".wav")
create_mock_wav(str(mock_fpath), 16)
create_mock_wav(mock_maestro_ext / f"{GT_15M_TRACK_ID.split('/')[1]}.wav", 16)

input_data = [(GT_15M_TRACK_ID, "train")]
split_labels = set([e[1] for e in input_data])
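The create_mock_wav and create_mock_midi helpers removed above are now imported from a shared test utils module (see the new "from utils import create_mock_wav, create_mock_midi" line), presumably so the Slakh tests added elsewhere in this PR can reuse them. That module is not shown in this view; below is a sketch of the WAV helper as it would plausibly look there, with the signature loosened to accept the pathlib.Path arguments the updated call sites now pass (that signature is an assumption).

# Assumed shape of the shared helper in the tests utils module; the body mirrors
# the implementation deleted above, with a Path-friendly signature.
import logging
import pathlib
import wave
from typing import Union

import numpy as np


def create_mock_wav(output_fpath: Union[str, pathlib.Path], duration_min: int) -> None:
    sample_rate = 44100
    n_channels = 2  # stereo
    sampwidth = 2  # 2 bytes per sample (16-bit audio)

    # Generate silent audio for the requested number of minutes.
    num_samples = duration_min * 60 * sample_rate
    audio_data = np.zeros((num_samples, n_channels), dtype=np.int16)

    with wave.open(str(output_fpath), "w") as wav_file:
        wav_file.setnchannels(n_channels)
        wav_file.setsampwidth(sampwidth)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_data.tobytes())

    logging.info(f"Mock {duration_min}-minute WAV file '{output_fpath}' created successfully.")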
(Diffs for the remaining 7 changed files are not shown.)
