Skip to content

Commit

Permalink
Multitrack class (#304)
Browse files Browse the repository at this point in the history
* MultiTrack class

* check for unequal lengths and sample rates, add option to concatenate

* mix --> average

* bump version

* updated the multitrack api

* add tests for future datasets

* track_audio_attribute --> track_audio_property

Co-authored-by: Rachel Bittner <[email protected]>
  • Loading branch information
rabitt and Rachel Bittner authored Oct 28, 2020
1 parent 434775e commit ef952b6
Show file tree
Hide file tree
Showing 5 changed files with 538 additions and 90 deletions.
53 changes: 53 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,59 @@ class Track(track.Track):
# -- see the documentation for `jams_utils.jams_converter` for all fields


# -- if the dataset contains multitracks, you can define a MultiTrack similar to a Track
# -- you can delete the block of code below if the dataset has no multitracks
class MultiTrack(track.MultiTrack):
    """Example multitrack class

    Args:
        mtrack_id (str): multitrack id
        data_home (str): Local path where the dataset is stored.
            If `None`, looks for the data in the default directory, `~/mir_datasets/Example`

    Attributes:
        mtrack_id (str): multitrack id
        track_ids (list): ids of the tracks that are part of the multitrack
        tracks (dict): {track_id: Track}
        track_audio_property (str): the name of the property of Track which
            returns the audio to be mixed
        # -- Add any of the dataset specific attributes here

    """

    def __init__(self, mtrack_id, data_home):
        self.mtrack_id = mtrack_id
        self._data_home = data_home
        # these three attributes below must have exactly these names
        self.track_ids = [...]  # define which track_ids should be part of the multitrack
        # build the tracks from the attribute set on the line above
        self.tracks = {t: Track(t, self._data_home) for t in self.track_ids}
        self.track_audio_property = "audio"  # the property of Track which returns the relevant audio file for mixing

        # -- optionally add any multitrack specific attributes here
        self.mix_path = ...  # this can be called whatever makes sense for the datasets
        self.annotation_path = ...

    # -- multitracks can optionally have mix-level cached properties and properties
    @utils.cached_property
    def annotation(self):
        """output type: description of output"""
        return load_annotation(self.annotation_path)

    @property
    def audio(self):
        """(np.ndarray, float): DESCRIPTION audio signal, sample rate"""
        # load the mix-level audio from the path defined in __init__
        return load_audio(self.mix_path)

    # -- multitrack objects are themselves Tracks, and also need a to_jams method
    # -- for any mixture-level annotations
    def to_jams(self):
        """Jams: the track's data in jams format"""
        return jams_utils.jams_converter(
            audio_path=self.mix_path,
            annotation_data=[(self.annotation, None)],
            # ... -- add any other relevant fields here
        )
        # -- see the documentation for `jams_utils.jams_converter` for all fields


def load_audio(audio_path):
"""Load a Example audio file.
Expand Down
128 changes: 122 additions & 6 deletions mirdata/track.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# -*- coding: utf-8 -*-
"""track object utility functions
"""


import types

import numpy as np

MAX_STR_LEN = 100


class Track(object):
def __repr__(self):
properties = [v for v in dir(self.__class__) if not v.startswith('_')]
properties = [v for v in dir(self.__class__) if not v.startswith("_")]
attributes = [
v for v in dir(self) if not v.startswith('_') and v not in properties
v for v in dir(self) if not v.startswith("_") and v not in properties
]

repr_str = "Track(\n"
Expand All @@ -21,7 +21,7 @@ def __repr__(self):
val = getattr(self, attr)
if isinstance(val, str):
if len(val) > MAX_STR_LEN:
val = '...{}'.format(val[-MAX_STR_LEN:])
val = "...{}".format(val[-MAX_STR_LEN:])
val = '"{}"'.format(val)
repr_str += " {}={},\n".format(attr, val)

Expand All @@ -33,11 +33,127 @@ def __repr__(self):
if val.__doc__ is None:
raise ValueError("{} has no documentation".format(prop))

val_type_str = val.__doc__.split(':')[0]
val_type_str = val.__doc__.split(":")[0]
repr_str += " {}: {},\n".format(prop, val_type_str)

repr_str += ")"
return repr_str

def to_jams(self):
raise NotImplementedError


class MultiTrack(Track):
    """MultiTrack class.

    A multitrack class is a collection of track objects and their associated audio
    that can be mixed together.
    A multitrack is itself a Track, and can have its own associated audio (such as
    a mastered mix), its own metadata and its own annotations.

    The mixing methods require the instance to define two attributes:

    * ``tracks``: a mapping from track key to track object
    * ``track_audio_property``: the name of the attribute of each track object
      which returns an ``(audio, sample_rate)`` pair

    """

    def _check_mixable(self):
        """Ensure the attributes needed for mixing are present.

        Raises:
            NotImplementedError: if ``tracks`` or ``track_audio_property``
                is not set on this object

        """
        if not hasattr(self, "tracks") or not hasattr(self, "track_audio_property"):
            raise NotImplementedError(
                "This MultiTrack has no tracks/track_audio_property. Cannot perform mixing"
            )

    def get_target(self, track_keys, weights=None, average=True, enforce_length=True):
        """Get target which is a linear mixture of tracks

        Args:
            track_keys (list): list of track keys to mix together
            weights (list or None): list of positive scalars to be used in the average.
                If None, all tracks are weighted equally
            average (bool): if True, computes a weighted average of the tracks
                if False, computes a weighted sum of the tracks
            enforce_length (bool): If True, raises ValueError if the tracks are
                not the same length. If False, pads audio with zeros to match the length
                of the longest track

        Returns:
            target (np.ndarray): target audio with shape (n_channels, n_samples)

        Raises:
            NotImplementedError: if this object defines no tracks/track_audio_property
            ValueError:
                if track_keys is empty
                if sample rates of the tracks are not equal
                if enforce_length=True and lengths are not equal

        """
        self._check_mixable()
        track_keys = list(track_keys)
        if not track_keys:
            # fail early with a readable message instead of a numpy reduction error
            raise ValueError("track_keys must contain at least one track key")

        signals = []
        sample_rates = []
        for key in track_keys:
            audio, sample_rate = getattr(self.tracks[key], self.track_audio_property)
            # ensure all signals are shape (n_channels, n_samples)
            if audio.ndim == 1:
                audio = audio[np.newaxis, :]
            signals.append(audio)
            sample_rates.append(sample_rate)

        if len(set(sample_rates)) > 1:
            raise ValueError(
                "Sample rates for tracks {} are not equal: {}".format(
                    track_keys, sample_rates
                )
            )

        lengths = [signal.shape[1] for signal in signals]
        if len(set(lengths)) > 1:
            if enforce_length:
                raise ValueError(
                    "Track's {} audio are not the same length {}. Use enforce_length=False to pad with zeros.".format(
                        track_keys, lengths
                    )
                )
            # zero-pad the end of every signal up to the longest one; the mode is
            # passed explicitly because older numpy versions have no default for it
            max_length = max(lengths)
            signals = [
                np.pad(
                    signal,
                    ((0, 0), (0, max_length - signal.shape[1])),
                    mode="constant",
                )
                for signal in signals
            ]

        if weights is None:
            weights = np.ones((len(track_keys),))

        target = np.average(signals, axis=0, weights=weights)
        if not average:
            # undo the normalization of the average to obtain the weighted sum
            target *= np.sum(weights)

        return target

    def get_random_target(self, n_tracks=None, min_weight=0.3, max_weight=1.0):
        """Get a random target by combining a random selection of tracks with random weights

        Args:
            n_tracks (int or None): number of tracks to randomly mix. If None, uses all tracks
            min_weight (float): minimum possible weight when mixing
            max_weight (float): maximum possible weight when mixing

        Returns:
            target (np.ndarray): mixture audio with shape (n_channels, n_samples)
            tracks (list): list of keys of included tracks
            weights (np.ndarray): array of weights used to mix tracks

        Raises:
            NotImplementedError: if this object defines no tracks/track_audio_property
            ValueError: see `get_target`

        """
        self._check_mixable()
        tracks = list(self.tracks.keys())
        if n_tracks is not None and n_tracks < len(tracks):
            # sample positions rather than the keys themselves, so the keys keep
            # their original type and the result is always a plain list
            chosen = np.random.choice(len(tracks), n_tracks, replace=False)
            tracks = [tracks[i] for i in chosen]

        weights = np.random.uniform(low=min_weight, high=max_weight, size=len(tracks))
        target = self.get_target(tracks, weights=weights)
        return target, tracks, weights

    def get_mix(self):
        """Create a linear mixture (equally weighted average) of all tracks.

        Returns:
            target (np.ndarray): mixture audio with shape (n_channels, n_samples)

        Raises:
            NotImplementedError: if this object defines no tracks/track_audio_property
            ValueError: see `get_target`

        """
        self._check_mixable()
        return self.get_target(list(self.tracks.keys()))
4 changes: 2 additions & 2 deletions mirdata/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# -*- coding: utf-8 -*-
"""Version info"""

short_version = '0.2'
version = '0.2.0'
short_version = "0.2"
version = "0.2.1"
Loading

0 comments on commit ef952b6

Please sign in to comment.