diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index b881f3e9..9f9efdc6 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -6,22 +6,26 @@ on: jobs: test: name: test ${{ matrix.py }} - ${{ matrix.os }} - runs-on: ${{ matrix.os }}-latest + runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: os: - - Ubuntu - - Windows - - MacOs + - ubuntu-latest + - windows-latest py: - "3.11" - "3.10" - "3.9" - "3.8" + include: + - os: macos-latest-xlarge + py: "3.10.11" + - os: macos-latest-xlarge + py: "3.11.8" steps: - name: Setup python for test ${{ matrix.py }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.py }} - uses: actions/checkout@v3 @@ -29,7 +33,7 @@ jobs: run: sudo apt-get update && sudo apt-get install --no-install-recommends -y --fix-missing pkg-config libsndfile1 if: matrix.os == 'Ubuntu' - name: Install soundlibs MacOs - run: brew install libsndfile + run: brew install libsndfile llvm libomp if: matrix.os == 'MacOs' - name: Install soundlibs Windows run: choco install libsndfile @@ -41,8 +45,8 @@ jobs: # We will only check this on the minimum python version - name: Check formatting, lint and mypy run: tox -c tox.ini -e check-formatting,lint,mypy - if: matrix.py == '3.8' + if: matrix.py == '3.10' - name: Run test suite - run: tox -c tox.ini -e py,manifest + run: tox -c tox.ini -e py,manifest,full - name: Check that basic-pitch can be run as a commandline run: pip3 install -e . && basic-pitch --help diff --git a/MANIFEST.in b/MANIFEST.in index 16788a73..d2ed1c31 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,5 @@ include *.txt tox.ini *.rst *.md LICENSE include catalog-info.yaml -recursive-include tests *.py *.wav +recursive-include tests *.py *.wav *.npz recursive-include basic_pitch *.py -recursive-include basic_pitch/saved_models/* -recursive-include basic_pitch *.index *.pb variables.data* +recursive-include basic_pitch/saved_models *.index *.pb variables.data* *.mlmodel *.json *.onnx *.tflite *.bin diff --git a/README.md b/README.md index 344977d3..00acb6e0 100644 --- a/README.md +++ b/README.md @@ -44,14 +44,31 @@ To update Basic Pitch to the latest version, add `--upgrade` to the above comman #### Compatible Environments: - MacOS, Windows and Ubuntu operating systems -- Python versions 3.7, 3.8, 3.9, 3.10 +- Python versions 3.7, 3.8, 3.9, 3.10, 3.11 - **For Mac M1 hardware, we currently only support python version 3.10. Otherwise, we suggest using a virtual machine.** +### Model Runtime + +Basic Pitch comes with the original TensorFlow model and the TensorFlow model converted to [CoreML](https://developer.apple.com/documentation/coreml), [TensorFlowLite](https://www.tensorflow.org/lite), and [ONNX](https://onnx.ai/). By default, Basic Pitch will _not_ install TensorFlow as a dependency *unless you are using Python>=3.11*. Instead, by default, CoreML will be installed on MacOS, TensorFlowLite will be installed on Linux and ONNX will be installed on Windows. If you want to install TensorFlow along with the default model inference runtime, you can install TensorFlow via `pip install basic-pitch[tf]`. + ## Usage ### Model Prediction +### Model Runtime + +By default, Basic Pitch will attempt to load a model in the following order: + +1. TensorFlow +2. CoreML +3. TensorFlowLite +4. ONNX + +Additionally, the module variable ICASSP_2022_MODEL_PATH will default to the first available version in the list. + +We will explain how to override this priority list below. 
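+
+For example, one way to bypass this priority order and force a specific serialization programmatically (a sketch that only uses names added in this change; `"input.wav"` is a placeholder path):
+
+```python
+from basic_pitch import FilenameSuffix, build_icassp_2022_model_path
+from basic_pitch.inference import Model, predict
+
+# Build the path to the ONNX serialization shipped with the package
+# (requires onnxruntime to be installed) and load it explicitly.
+onnx_model = Model(build_icassp_2022_model_path(FilenameSuffix.onnx))
+model_output, midi_data, note_events = predict("input.wav", onnx_model)
+```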
+Because all other model serializations were converted from TensorFlow, we recommend using TensorFlow when possible. N.B. Basic Pitch does not install TensorFlow by default to save the user time when installing and running Basic Pitch.
+
 #### Command Line Tool
 
 This library offers a command line tool interface. A basic prediction command will generate and save a MIDI file transcription of audio at the `<input-audio-path>` to the `<output-directory>`:
@@ -73,9 +90,11 @@ basic-pitch `:
-- `--sonify-midi` to additionally save a `.wav` audio rendering of the MIDI file
-- `--save-model-outputs` to additionally save raw model outputs as an NPZ file
-- `--save-note-events` to additionally save the predicted note events as a CSV file
+- `--sonify-midi` to additionally save a `.wav` audio rendering of the MIDI file.
+- `--save-model-outputs` to additionally save raw model outputs as an NPZ file.
+- `--save-note-events` to additionally save the predicted note events as a CSV file.
+
+If you want to use a non-default model type (e.g., CoreML instead of TensorFlow), use the `--model-serialization` argument and the CLI will load the serialization you prefer.
 
 To discover more parameter control, run:
 
 ```bash
 basic-pitch --help
 ```
@@ -100,6 +119,8 @@ model_output, midi_data, note_events = predict()
 - `midi_data` is the transcribed MIDI data derived from the `model_output`
 - `note_events` is a list of note events derived from the `model_output`
 
+Note: As mentioned above, `ICASSP_2022_MODEL_PATH` defaults to the first available runtime in the order TensorFlow, CoreML, TensorFlowLite, ONNX.
+
 **predict() in a loop**
 
 To run prediction within a loop, you'll want to load the model yourself and provide `predict()` with the loaded model object itself to be used for repeated prediction calls, in order to avoid redundant and sluggish model loading.
@@ -107,10 +128,10 @@ To run prediction within a loop, you'll want to load the model yourself and prov
 ```python
 import tensorflow as tf
 
-from basic_pitch.inference import predict
+from basic_pitch.inference import predict, Model
 from basic_pitch import ICASSP_2022_MODEL_PATH
 
-basic_pitch_model = tf.saved_model.load(str(ICASSP_2022_MODEL_PATH))
+basic_pitch_model = Model(ICASSP_2022_MODEL_PATH)
 
 for x in range():
     ...
diff --git a/basic_pitch/__init__.py b/basic_pitch/__init__.py
index bde0d4a8..b80b2144 100644
--- a/basic_pitch/__init__.py
+++ b/basic_pitch/__init__.py
@@ -15,13 +15,81 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import enum
+import logging
 import pathlib
 
-__author__ = "Spotify"
-__version__ = "0.3.0"
-__email__ = "basic-pitch@spotify.com"
-__demowebsite__ = "https://basicpitch.io"
-__description__ = "Basic Pitch, a lightweight yet powerful audio-to-MIDI converter with pitch bend detection."
-__url__ = "https://github.com/spotify/basic-pitch"
-ICASSP_2022_MODEL_PATH = pathlib.Path(__file__).parent / "saved_models/icassp_2022/nmp"
+try:
+    import coremltools
+
+    CT_PRESENT = True
+except ImportError:
+    CT_PRESENT = False
+    logging.warning(
+        "Coremltools is not installed. "
+        "If you plan to use a CoreML Saved Model, "
+        "reinstall basic-pitch with `pip install 'basic-pitch[coreml]'`"
+    )
+
+try:
+    import tflite_runtime.interpreter
+
+    TFLITE_PRESENT = True
+except ImportError:
+    TFLITE_PRESENT = False
+    logging.warning(
+        "tflite-runtime is not installed. 
" + "If you plan to use a TFLite Model, " + "reinstall basic-pitch with `pip install 'basic-pitch tflite-runtime'` or " + "`pip install 'basic-pitch[tf]'" + ) + +try: + import onnxruntime + + ONNX_PRESENT = True +except ImportError: + ONNX_PRESENT = False + logging.warning( + "onnxruntime is not installed. " + "If you plan to use an ONNX Model, " + "reinstall basic-pitch with `pip install 'basic-pitch[onnx]'`" + ) + + +try: + import tensorflow + + TF_PRESENT = True +except ImportError: + TF_PRESENT = False + logging.warning( + "Tensorflow is not installed. " + "If you plan to use a TF Saved Model, " + "reinstall basic-pitch with `pip install 'basic-pitch[tf]'`" + ) + + +class FilenameSuffix(enum.Enum): + tf = "nmp" + coreml = "nmp.mlpackage" + tflite = "nmp.tflite" + onnx = "nmp.onnx" + + +if TF_PRESENT: + _default_model_type = FilenameSuffix.tf +elif CT_PRESENT: + _default_model_type = FilenameSuffix.coreml +elif TFLITE_PRESENT: + _default_model_type = FilenameSuffix.tflite +elif ONNX_PRESENT: + _default_model_type = FilenameSuffix.onnx + + +def build_icassp_2022_model_path(suffix: FilenameSuffix) -> pathlib.Path: + return pathlib.Path(__file__).parent / "saved_models/icassp_2022" / suffix.value + + +ICASSP_2022_MODEL_PATH = build_icassp_2022_model_path(_default_model_type) diff --git a/basic_pitch/inference.py b/basic_pitch/inference.py index be521a15..7d79cf81 100644 --- a/basic_pitch/inference.py +++ b/basic_pitch/inference.py @@ -18,12 +18,37 @@ import csv import enum import json +import logging import os import pathlib -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast + + +from basic_pitch import CT_PRESENT, ONNX_PRESENT, TF_PRESENT, TFLITE_PRESENT + +try: + import tensorflow as tf +except ImportError: + pass + +try: + import coremltools as ct +except ImportError: + pass + +try: + import tflite_runtime.interpreter as tflite +except ImportError: + if TF_PRESENT: + import tensorflow.lite as tflite + +try: + import onnxruntime as ort +except ImportError: + pass -from tensorflow import Tensor, signal, keras, saved_model import numpy as np +import numpy.typing as npt import librosa import pretty_midi @@ -33,16 +58,132 @@ ANNOTATIONS_FPS, FFT_HOP, ) -from basic_pitch import ICASSP_2022_MODEL_PATH, note_creation as infer from basic_pitch.commandline_printing import ( generating_file_message, no_tf_warnings, file_saved_confirmation, failed_to_save, ) +import basic_pitch.note_creation as infer + + +class Model: + class MODEL_TYPES(enum.Enum): + TENSORFLOW = enum.auto() + COREML = enum.auto() + TFLITE = enum.auto() + ONNX = enum.auto() + + def __init__(self, model_path: Union[pathlib.Path, str]): + present = [] + if TF_PRESENT: + present.append("TensorFlow") + try: + self.model_type = Model.MODEL_TYPES.TENSORFLOW + self.model = tf.saved_model.load(str(model_path)) + return + except Exception as e: + if os.path.isdir(model_path) and {"saved_model.pb", "variables"} & set(os.listdir(model_path)): + logging.warning( + "Could not load TensorFlow saved model %s even " + "though it looks like a saved model file with error %s. 
" + "Are you sure it's a TensorFlow saved model?", + model_path, + e.__repr__(), + ) + + if CT_PRESENT: + present.append("CoreML") + try: + self.model_type = Model.MODEL_TYPES.COREML + self.model = ct.models.MLModel(str(model_path)) + return + except Exception as e: + if str(model_path).endswith(".mlpackage"): + logging.warning( + "Could not load CoreML file %s even " + "though it looks like a CoreML file with error %s. " + "Are you sure it's a CoreML file?", + model_path, + e.__repr__(), + ) + + if TFLITE_PRESENT or TF_PRESENT: + present.append("TensorFlowLite") + try: + self.model_type = Model.MODEL_TYPES.TFLITE + self.interpreter = tflite.Interpreter(str(model_path)) + self.model = self.interpreter.get_signature_runner() + return + except Exception as e: + if str(model_path).endswith(".tflite"): + logging.warning( + "Could not load TensorFlowLite file %s even " + "though it looks like a TFLite file with error %s. " + "Are you sure it's a TFLite file?", + model_path, + e.__repr__(), + ) + + if ONNX_PRESENT: + present.append("ONNX") + try: + self.model_type = Model.MODEL_TYPES.ONNX + self.model = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"]) + return + except Exception as e: + if str(model_path).endswith(".onnx"): + logging.warning( + "Could not load ONNX file %s even " + "though it looks like a ONNX file with error %s. " + "Are you sure it's a ONNX file?", + model_path, + e.__repr__(), + ) + + raise ValueError( + f"File {model_path} cannot be loaded into either " + "TensorFlow, CoreML, TFLite or ONNX. " + "Please check if it is a supported and valid serialized model " + "and that one of these packages are installed. On this system, " + f"{present} is installed." + ) - -def window_audio_file(audio_original: Tensor, hop_size: int) -> Tuple[Tensor, List[Dict[str, int]]]: + def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float32]]: + if self.model_type == Model.MODEL_TYPES.TENSORFLOW: + return {k: v.numpy() for k, v in cast(tf.keras.Model, self.model(x)).items()} + elif self.model_type == Model.MODEL_TYPES.COREML: + print(f"isfinite: {np.all(np.isfinite(x))}", flush=True) + print(f"shape: {x.shape}", flush=True) + print(f"dtype: {x.dtype}", flush=True) + result = cast(ct.models.MLModel, self.model).predict({"input_2": x}) + return { + "note": result["Identity_1"], + "onset": result["Identity_2"], + "contour": result["Identity"], + } + elif self.model_type == Model.MODEL_TYPES.TFLITE: + return self.model(input_2=x) # type: ignore + elif self.model_type == Model.MODEL_TYPES.ONNX: + return { + k: v + for k, v in zip( + ["note", "onset", "contour"], + cast(ort.InferenceSession, self.model).run( + [ + "StatefulPartitionedCall:1", + "StatefulPartitionedCall:2", + "StatefulPartitionedCall:0", + ], + {"serving_default_input_2:0": x}, + ), + ) + } + + +def window_audio_file( + audio_original: npt.NDArray[np.float32], hop_size: int +) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float]]]: """ Pad appropriately an audio file, and return as windowed signal, with window length = AUDIO_N_SAMPLES @@ -53,25 +194,24 @@ def window_audio_file(audio_original: Tensor, hop_size: int) -> Tuple[Tensor, Li window_times: list of {'start':.., 'end':...} objects (times in seconds) """ - from tensorflow import expand_dims # imporing this here so the module loads faster - - audio_windowed = expand_dims( - signal.frame(audio_original, AUDIO_N_SAMPLES, hop_size, pad_end=True, pad_value=0), - axis=-1, - ) - window_times = [ - { + for i in range(0, 
audio_original.shape[0], hop_size): + window = audio_original[i : i + AUDIO_N_SAMPLES] + if len(window) < AUDIO_N_SAMPLES: + window = np.pad( + window, + pad_width=[[0, AUDIO_N_SAMPLES - len(window)]], + ) + t_start = float(i) / AUDIO_SAMPLE_RATE + window_time = { "start": t_start, "end": t_start + (AUDIO_N_SAMPLES / AUDIO_SAMPLE_RATE), } - for t_start in np.arange(audio_windowed.shape[0]) * hop_size / AUDIO_SAMPLE_RATE - ] - return audio_windowed, window_times + yield np.expand_dims(window, axis=-1), window_time def get_audio_input( audio_path: Union[pathlib.Path, str], overlap_len: int, hop_size: int -) -> Tuple[Tensor, List[Dict[str, int]], int]: +) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]: """ Read wave file (as mono), pad appropriately, and return as windowed signal, with window length = AUDIO_N_SAMPLES @@ -90,11 +230,15 @@ def get_audio_input( original_length = audio_original.shape[0] audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original]) - audio_windowed, window_times = window_audio_file(audio_original, hop_size) - return audio_windowed, window_times, original_length + for window, window_time in window_audio_file(audio_original, hop_size): + yield np.expand_dims(window, axis=0), window_time, original_length -def unwrap_output(output: Tensor, audio_original_length: int, n_overlapping_frames: int) -> np.array: +def unwrap_output( + output: npt.NDArray[np.float32], + audio_original_length: int, + n_overlapping_frames: int, +) -> np.array: """Unwrap batched model predictions to a single matrix. Args: @@ -105,45 +249,53 @@ def unwrap_output(output: Tensor, audio_original_length: int, n_overlapping_fram Returns: array (n_times, n_freqs) """ - raw_output = output.numpy() - if len(raw_output.shape) != 3: + if len(output.shape) != 3: return None n_olap = int(0.5 * n_overlapping_frames) if n_olap > 0: # remove half of the overlapping frames from beginning and end - raw_output = raw_output[:, n_olap:-n_olap, :] + output = output[:, n_olap:-n_olap, :] - output_shape = raw_output.shape + output_shape = output.shape n_output_frames_original = int(np.floor(audio_original_length * (ANNOTATIONS_FPS / AUDIO_SAMPLE_RATE))) - unwrapped_output = raw_output.reshape(output_shape[0] * output_shape[1], output_shape[2]) + unwrapped_output = output.reshape(output_shape[0] * output_shape[1], output_shape[2]) return unwrapped_output[:n_output_frames_original, :] # trim to original audio length def run_inference( audio_path: Union[pathlib.Path, str], - model: keras.Model, + model_or_model_path: Union[Model, pathlib.Path, str], debug_file: Optional[pathlib.Path] = None, ) -> Dict[str, np.array]: """Run the model on the input audio path. Args: audio_path: The audio to run inference on. - model: A loaded keras model to run inference with. + model_or_model_path: A loaded Model or path to a serialized model to load. debug_file: An optional path to output debug data to. Useful for testing/verification. Returns: A dictionary with the notes, onsets and contours from model inference. 
""" + if isinstance(model_or_model_path, Model): + model = model_or_model_path + else: + model = Model(model_or_model_path) + # overlap 30 frames n_overlapping_frames = 30 overlap_len = n_overlapping_frames * FFT_HOP hop_size = AUDIO_N_SAMPLES - overlap_len - audio_windowed, _, audio_original_length = get_audio_input(audio_path, overlap_len, hop_size) + output: Dict[str, Any] = {"note": [], "onset": [], "contour": []} + for audio_windowed, _, audio_original_length in get_audio_input(audio_path, overlap_len, hop_size): + for k, v in model.predict(audio_windowed).items(): + output[k].append(v) - output = model(audio_windowed) - unwrapped_output = {k: unwrap_output(output[k], audio_original_length, n_overlapping_frames) for k in output} + unwrapped_output = { + k: unwrap_output(np.concatenate(output[k]), audio_original_length, n_overlapping_frames) for k in output + } if debug_file: with open(debug_file, "w") as f: @@ -261,7 +413,7 @@ def save_note_events( def predict( audio_path: Union[pathlib.Path, str], - model_or_model_path: Union[keras.Model, pathlib.Path, str] = ICASSP_2022_MODEL_PATH, + model_or_model_path: Union[Model, pathlib.Path, str], onset_threshold: float = 0.5, frame_threshold: float = 0.3, minimum_note_length: float = 127.70, @@ -271,12 +423,16 @@ def predict( melodia_trick: bool = True, debug_file: Optional[pathlib.Path] = None, midi_tempo: float = 120, -) -> Tuple[Dict[str, np.array], pretty_midi.PrettyMIDI, List[Tuple[float, float, int, float, Optional[List[int]]]],]: +) -> Tuple[ + Dict[str, np.array], + pretty_midi.PrettyMIDI, + List[Tuple[float, float, int, float, Optional[List[int]]]], +]: """Run a single prediction. Args: audio_path: File path for the audio to run inference on. - model_or_model_path: Path to load the Keras saved model from. Can be local or on GCS. + model_or_model_path: A loaded Model or path to a serialized model to load. onset_threshold: Minimum energy required for an onset to be considered present. frame_threshold: Minimum energy requirement for a frame to be considered present. minimum_note_length: The minimum allowed note length in milliseconds. @@ -290,17 +446,9 @@ def predict( """ with no_tf_warnings(): - # It's convenient to be able to pass in a keras saved model so if - # someone wants to place this function in a loop, - # the model doesn't have to be reloaded every function call - if isinstance(model_or_model_path, (pathlib.Path, str)): - model = saved_model.load(str(model_or_model_path)) - else: - model = model_or_model_path - print(f"Predicting MIDI for {audio_path}...") - model_output = run_inference(audio_path, model, debug_file) + model_output = run_inference(audio_path, model_or_model_path, debug_file) min_note_len = int(np.round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP))) midi_data, note_events = infer.model_output_to_notes( model_output, @@ -348,7 +496,7 @@ def predict_and_save( sonify_midi: bool, save_model_outputs: bool, save_notes: bool, - model_path: Union[pathlib.Path, str] = ICASSP_2022_MODEL_PATH, + model_or_model_path: Union[Model, str, pathlib.Path], onset_threshold: float = 0.5, frame_threshold: float = 0.3, minimum_note_length: float = 127.70, @@ -369,7 +517,7 @@ def predict_and_save( sonify_midi: Whether or not to render audio from the MIDI and output it to a file. save_model_outputs: True to save contours, onsets and notes from the model prediction. save_notes: True to save note events. - model_path: Path to load the Keras saved model from. Can be local or on GCS. 
+ model_or_model_path: A loaded Model or path to a serialized model to load. onset_threshold: Minimum energy required for an onset to be considered present. frame_threshold: Minimum energy requirement for a frame to be considered present. minimum_note_length: The minimum allowed note length in milliseconds. @@ -380,14 +528,12 @@ def predict_and_save( debug_file: An optional path to output debug data to. Useful for testing/verification. sonification_samplerate: Sample rate for rendering audio from MIDI. """ - model = saved_model.load(str(model_path)) - for audio_path in audio_path_list: print("") try: model_output, midi_data, note_events = predict( pathlib.Path(audio_path), - model, + model_or_model_path, onset_threshold, frame_threshold, minimum_note_length, diff --git a/basic_pitch/predict.py b/basic_pitch/predict.py index 76bb8862..d4b40df1 100644 --- a/basic_pitch/predict.py +++ b/basic_pitch/predict.py @@ -20,7 +20,12 @@ import pathlib import traceback -from basic_pitch import ICASSP_2022_MODEL_PATH +from basic_pitch import ( + ICASSP_2022_MODEL_PATH, + FilenameSuffix, + build_icassp_2022_model_path, +) +from basic_pitch.inference import Model os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -30,12 +35,25 @@ def main() -> None: """Handle command line arguments. Entrypoint for this script.""" parser = argparse.ArgumentParser(description="Predict midi from audio.") parser.add_argument("output_dir", type=str, help="directory to save outputs") - parser.add_argument("audio_paths", type=str, nargs="+", help="Space separated paths to the input audio files.") parser.add_argument( - "--model_path", + "audio_paths", + type=str, + nargs="+", + help="Space separated paths to the input audio files.", + ) + parser.add_argument( + "--model-path", type=str, default=ICASSP_2022_MODEL_PATH, - help="path to the saved model directory. Defaults to a ICASSP 2022 model", + help="path to the saved model directory. Defaults to a ICASSP 2022 model. 
" + "The preferred model is determined by the first library available in " + "[tensorflow, coreml, tensorflow-lite, onnx]", + ) + parser.add_argument( + "--model-serialization", + type=str, + choices=["tf", "coreml", "tflite", "onnx"], + help="If used, --model-path is ignored and instead the model serialization type" "specified is used.", ) parser.add_argument( "--save-midi", @@ -106,8 +124,17 @@ def main() -> None: default=120, help="The tempo for the midi file.", ) - parser.add_argument("--debug-file", default=None, help="Optional file for debug output for inference.") - parser.add_argument("--no-melodia", default=False, action="store_true", help="Skip the melodia trick.") + parser.add_argument( + "--debug-file", + default=None, + help="Optional file for debug output for inference.", + ) + parser.add_argument( + "--no-melodia", + default=False, + action="store_true", + help="Skip the melodia trick.", + ) args = parser.parse_args() print("") @@ -119,7 +146,11 @@ def main() -> None: # tensorflow is very slow to import # this import is here so that the help messages print faster print("Importing Tensorflow (this may take a few seconds)...") - from basic_pitch.inference import predict_and_save, verify_output_dir, verify_input_path + from basic_pitch.inference import ( + predict_and_save, + verify_output_dir, + verify_input_path, + ) output_dir = pathlib.Path(args.output_dir) verify_output_dir(output_dir) @@ -128,6 +159,11 @@ def main() -> None: for audio_path in audio_path_list: verify_input_path(audio_path) + if args.model_serialization: + model = Model(build_icassp_2022_model_path(FilenameSuffix[args.model_serialization])) + else: + model = Model(args.model_path) + try: predict_and_save( audio_path_list, @@ -136,7 +172,7 @@ def main() -> None: args.sonify_midi, args.save_model_outputs, args.save_note_events, - pathlib.Path(args.model_path), + model, args.onset_threshold, args.frame_threshold, args.minimum_note_length, diff --git a/basic_pitch/saved_models/icassp_2022/nmp.mlpackage/Data/com.apple.CoreML/model.mlmodel b/basic_pitch/saved_models/icassp_2022/nmp.mlpackage/Data/com.apple.CoreML/model.mlmodel new file mode 100644 index 00000000..4fa2f93d Binary files /dev/null and b/basic_pitch/saved_models/icassp_2022/nmp.mlpackage/Data/com.apple.CoreML/model.mlmodel differ diff --git a/basic_pitch/saved_models/icassp_2022/nmp.mlpackage/Data/com.apple.CoreML/weights/weight.bin b/basic_pitch/saved_models/icassp_2022/nmp.mlpackage/Data/com.apple.CoreML/weights/weight.bin new file mode 100644 index 00000000..883fcaf0 Binary files /dev/null and b/basic_pitch/saved_models/icassp_2022/nmp.mlpackage/Data/com.apple.CoreML/weights/weight.bin differ diff --git a/basic_pitch/saved_models/icassp_2022/nmp.mlpackage/Manifest.json b/basic_pitch/saved_models/icassp_2022/nmp.mlpackage/Manifest.json new file mode 100644 index 00000000..37bfe14d --- /dev/null +++ b/basic_pitch/saved_models/icassp_2022/nmp.mlpackage/Manifest.json @@ -0,0 +1,18 @@ +{ + "fileFormatVersion": "1.0.0", + "itemInfoEntries": { + "7EB69391-53BD-4281-BFA1-831180579ABC": { + "author": "com.apple.CoreML", + "description": "CoreML Model Specification", + "name": "model.mlmodel", + "path": "com.apple.CoreML/model.mlmodel" + }, + "C784FB56-1BA8-4A0E-AC2A-F0388FE1A2DB": { + "author": "com.apple.CoreML", + "description": "CoreML Model Weights", + "name": "weights", + "path": "com.apple.CoreML/weights" + } + }, + "rootModelIdentifier": "7EB69391-53BD-4281-BFA1-831180579ABC" +} diff --git a/basic_pitch/saved_models/icassp_2022/nmp.onnx 
b/basic_pitch/saved_models/icassp_2022/nmp.onnx new file mode 100644 index 00000000..c30e5f94 Binary files /dev/null and b/basic_pitch/saved_models/icassp_2022/nmp.onnx differ diff --git a/basic_pitch/saved_models/icassp_2022/nmp.tflite b/basic_pitch/saved_models/icassp_2022/nmp.tflite new file mode 100644 index 00000000..85a41bef Binary files /dev/null and b/basic_pitch/saved_models/icassp_2022/nmp.tflite differ diff --git a/pyproject.toml b/pyproject.toml index aa8effae..24698370 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,15 +17,19 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", ] dependencies = [ + "coremltools; platform_system == 'Darwin' and python_version < '3.11'", "librosa>=0.8.0", "mir_eval>=0.6", "numpy>=1.18", + "onnxruntime; platform_system == 'Windows' and python_version < '3.11'", "pretty_midi>=0.2.9", - "resampy>=0.2.2", + "resampy>=0.2.2,<0.4.3", + "scikit-learn", "scipy>=1.4.1", + "tensorflow>=2.4.1,<2.16; platform_system != 'Darwin' and python_version >= '3.11'", + "tensorflow-macos>=2.4.1; platform_system == 'Darwin' and python_version >= '3.11'", + "tflite-runtime; platform_system == 'Linux' and python_version < '3.11'", "typing_extensions", - "tensorflow>=2.4.1; platform_system != 'Darwin'", - "tensorflow-macos>=2.4.1; platform_system == 'Darwin'", ] [metadata] @@ -51,9 +55,15 @@ test = [ "pytest>=6.1.1", "pytest-mock", ] +tf = [ + "tensorflow>=2.4.1,<2.16; platform_system != 'Darwin'", + "tensorflow-macos>=2.4.1,<2.16; platform_system == 'Darwin' and python_version > '3.7'", +] +coreml = ["coremltools"] +onnx = ["onnxruntime"] docs = ["mkdocs>=1.0.4"] dev = [ - "basic_pitch[test,docs]", + "basic_pitch[test,tf,coreml,onnx,docs]", "mypy", "tox", ] @@ -63,6 +73,7 @@ universal = true [build-system] requires = [ - "setuptools", + "setuptools>=40.8.0", "wheel", -] \ No newline at end of file + "cython", +] diff --git a/tests/resources/vocadito_10/model_output.npz b/tests/resources/vocadito_10/model_output.npz new file mode 100644 index 00000000..9b2a4045 Binary files /dev/null and b/tests/resources/vocadito_10/model_output.npz differ diff --git a/tests/resources/vocadito_10/note_events.npz b/tests/resources/vocadito_10/note_events.npz new file mode 100644 index 00000000..a5e7d51b Binary files /dev/null and b/tests/resources/vocadito_10/note_events.npz differ diff --git a/tests/test_inference.py b/tests/test_inference.py index 23a477e5..89665855 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -15,107 +15,171 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import faulthandler import os import pathlib import tempfile -import unittest +from typing import Dict, List import librosa import pretty_midi import numpy as np +import numpy.typing as npt from basic_pitch import ICASSP_2022_MODEL_PATH, inference -from basic_pitch.constants import ANNOTATIONS_N_SEMITONES +from basic_pitch.constants import ( + AUDIO_SAMPLE_RATE, + AUDIO_N_SAMPLES, + ANNOTATIONS_N_SEMITONES, + FFT_HOP, +) RESOURCES_PATH = pathlib.Path(__file__).parent / "resources" +faulthandler.enable() -class TestPredict(unittest.TestCase): - def test_predict(self) -> None: - test_audio_path = RESOURCES_PATH / "vocadito_10.wav" - model_output, midi_data, note_events = inference.predict( + +def test_predict() -> None: + test_audio_path = RESOURCES_PATH / "vocadito_10.wav" + model_output, midi_data, note_events = inference.predict( + test_audio_path, + inference.Model(ICASSP_2022_MODEL_PATH), + ) + assert set(model_output.keys()) == set(["note", "onset", "contour"]) + assert model_output["note"].shape == model_output["onset"].shape + assert isinstance(midi_data, pretty_midi.PrettyMIDI) + lowest_supported_midi = 21 + note_pitch_min = [n[2] >= lowest_supported_midi for n in note_events] + note_pitch_max = [n[2] <= lowest_supported_midi + ANNOTATIONS_N_SEMITONES for n in note_events] + assert all(note_pitch_min) + assert all(note_pitch_max) + assert isinstance(note_events, list) + + expected_model_output = np.load("tests/resources/vocadito_10/model_output.npz", allow_pickle=True)["arr_0"].item() + for k in expected_model_output.keys(): + np.testing.assert_allclose(expected_model_output[k], model_output[k], atol=1e-4, rtol=0) + + expected_note_events = np.load("tests/resources/vocadito_10/note_events.npz", allow_pickle=True)["arr_0"] + + assert len(expected_note_events) == len(note_events) + for expected, calculated in zip(expected_note_events, note_events): + for i in range(len(expected)): + np.testing.assert_allclose(expected[i], calculated[i], atol=1e-4, rtol=0) + + +def test_predict_with_saves() -> None: + test_audio_path = RESOURCES_PATH / "vocadito_10.wav" + with tempfile.TemporaryDirectory() as tmpdir: + inference.predict_and_save( + [test_audio_path], + tmpdir, + True, + True, + True, + True, + model_or_model_path=ICASSP_2022_MODEL_PATH, + ) + expected_midi_path = tmpdir / pathlib.Path("vocadito_10_basic_pitch.mid") + expected_csv_path = tmpdir / pathlib.Path("vocadito_10_basic_pitch.csv") + expected_npz_path = tmpdir / pathlib.Path("vocadito_10_basic_pitch.npz") + expected_sonif_path = tmpdir / pathlib.Path("vocadito_10_basic_pitch.wav") + + for output_path in [ + expected_midi_path, + expected_csv_path, + expected_npz_path, + expected_sonif_path, + ]: + assert os.path.exists(output_path) + + +def test_predict_onset_threshold() -> None: + test_audio_path = RESOURCES_PATH / "vocadito_10.wav" + for onset_threshold in [0, 0.3, 0.8, 1]: + inference.predict( test_audio_path, ICASSP_2022_MODEL_PATH, + onset_threshold=onset_threshold, ) - assert set(model_output.keys()) == set(["note", "onset", "contour"]) - assert model_output["note"].shape == model_output["onset"].shape - assert isinstance(midi_data, pretty_midi.PrettyMIDI) - lowest_supported_midi = 21 - note_pitch_min = [n[2] >= lowest_supported_midi for n in note_events] - note_pitch_max = [n[2] <= lowest_supported_midi + ANNOTATIONS_N_SEMITONES for n in note_events] - assert all(note_pitch_min) - assert all(note_pitch_max) - assert isinstance(note_events, list) - - def test_predict_with_saves(self) -> None: - test_audio_path = 
RESOURCES_PATH / "vocadito_10.wav" - with tempfile.TemporaryDirectory() as tmpdir: - inference.predict_and_save( - [test_audio_path], - tmpdir, - True, - True, - True, - True, - ) - expected_midi_path = tmpdir / pathlib.Path("vocadito_10_basic_pitch.mid") - expected_csv_path = tmpdir / pathlib.Path("vocadito_10_basic_pitch.csv") - expected_npz_path = tmpdir / pathlib.Path("vocadito_10_basic_pitch.npz") - expected_sonif_path = tmpdir / pathlib.Path("vocadito_10_basic_pitch.wav") - - for output_path in [expected_midi_path, expected_csv_path, expected_npz_path, expected_sonif_path]: - assert os.path.exists(output_path) - - def test_predict_onset_threshold(self) -> None: - test_audio_path = RESOURCES_PATH / "vocadito_10.wav" - for onset_threshold in [0, 0.3, 0.8, 1]: - inference.predict( - test_audio_path, - ICASSP_2022_MODEL_PATH, - onset_threshold=onset_threshold, - ) - - def test_predict_frame_threshold(self) -> None: - test_audio_path = RESOURCES_PATH / "vocadito_10.wav" - for frame_threshold in [0, 0.3, 0.8, 1]: - inference.predict( - test_audio_path, - ICASSP_2022_MODEL_PATH, - frame_threshold=frame_threshold, - ) - - def test_predict_min_note_length(self) -> None: - test_audio_path = RESOURCES_PATH / "vocadito_10.wav" - for minimum_note_length in [10, 100, 1000]: - _, _, note_events = inference.predict( - test_audio_path, - ICASSP_2022_MODEL_PATH, - minimum_note_length=minimum_note_length, - ) - min_len_s = minimum_note_length / 1000.0 - note_lengths = [n[1] - n[0] >= min_len_s for n in note_events] - assert all(note_lengths) - - def test_predict_min_freq(self) -> None: - test_audio_path = RESOURCES_PATH / "vocadito_10.wav" - for minimum_frequency in [40, 80, 200, 2000]: - _, _, note_events = inference.predict( - test_audio_path, - ICASSP_2022_MODEL_PATH, - minimum_frequency=minimum_frequency, - ) - min_freq_midi = np.round(librosa.hz_to_midi(minimum_frequency)) - note_pitch = [n[2] >= min_freq_midi for n in note_events] - assert all(note_pitch) - - def test_predict_max_freq(self) -> None: - test_audio_path = RESOURCES_PATH / "vocadito_10.wav" - for maximum_frequency in [40, 80, 200, 2000]: - _, _, note_events = inference.predict( - test_audio_path, - ICASSP_2022_MODEL_PATH, - maximum_frequency=maximum_frequency, - ) - max_freq_midi = np.round(librosa.hz_to_midi(maximum_frequency)) - note_pitch = [n[2] <= max_freq_midi for n in note_events] - assert all(note_pitch) + + +def test_predict_frame_threshold() -> None: + test_audio_path = RESOURCES_PATH / "vocadito_10.wav" + for frame_threshold in [0, 0.3, 0.8, 1]: + inference.predict( + test_audio_path, + ICASSP_2022_MODEL_PATH, + frame_threshold=frame_threshold, + ) + + +def test_predict_min_note_length() -> None: + test_audio_path = RESOURCES_PATH / "vocadito_10.wav" + for minimum_note_length in [10, 100, 1000]: + _, _, note_events = inference.predict( + test_audio_path, + ICASSP_2022_MODEL_PATH, + minimum_note_length=minimum_note_length, + ) + min_len_s = minimum_note_length / 1000.0 + note_lengths = [n[1] - n[0] >= min_len_s for n in note_events] + assert all(note_lengths) + + +def test_predict_min_freq() -> None: + test_audio_path = RESOURCES_PATH / "vocadito_10.wav" + for minimum_frequency in [40, 80, 200, 2000]: + _, _, note_events = inference.predict( + test_audio_path, + ICASSP_2022_MODEL_PATH, + minimum_frequency=minimum_frequency, + ) + min_freq_midi = np.round(librosa.hz_to_midi(minimum_frequency)) + note_pitch = [n[2] >= min_freq_midi for n in note_events] + assert all(note_pitch) + + +def test_predict_max_freq() -> None: + 
test_audio_path = RESOURCES_PATH / "vocadito_10.wav" + for maximum_frequency in [40, 80, 200, 2000]: + _, _, note_events = inference.predict( + test_audio_path, + ICASSP_2022_MODEL_PATH, + maximum_frequency=maximum_frequency, + ) + max_freq_midi = np.round(librosa.hz_to_midi(maximum_frequency)) + note_pitch = [n[2] <= max_freq_midi for n in note_events] + assert all(note_pitch) + + +def test_window_audio_file() -> None: + test_audio_path = RESOURCES_PATH / "vocadito_10.wav" + audio, _ = librosa.load(str(test_audio_path), sr=AUDIO_SAMPLE_RATE, mono=True) + audio_windowed, window_times = zip(*inference.window_audio_file(audio, AUDIO_N_SAMPLES - 30 * FFT_HOP)) + assert len(audio_windowed) == 6 + assert len(window_times) == 6 + for time in window_times: + assert time["start"] <= time["end"] + np.testing.assert_equal(audio[:AUDIO_N_SAMPLES], np.squeeze(audio_windowed[0])) + + +def test_get_audio_input() -> None: + test_audio_path = RESOURCES_PATH / "vocadito_10.wav" + audio, _ = librosa.load(str(test_audio_path), sr=AUDIO_SAMPLE_RATE, mono=True) + overlap_len = 30 * FFT_HOP + audio = np.concatenate([np.zeros((overlap_len // 2,), dtype=np.float32), audio]) + audio_windowed: List[npt.NDArray[np.float32]] = [] + window_times: List[Dict[str, float]] = [] + for audio_window, window_time, original_length in inference.get_audio_input( + test_audio_path, overlap_len, AUDIO_N_SAMPLES - overlap_len + ): + audio_windowed.append(audio_window) + window_times.append(window_time) + audio_windowed = np.array(audio_windowed) + assert len(audio_windowed) == 6 + assert len(window_times) == 6 + for time in window_times: + assert time["start"] <= time["end"] + np.testing.assert_equal(audio[:AUDIO_N_SAMPLES], np.squeeze(audio_windowed[0])) + + assert original_length == 200607 diff --git a/tox.ini b/tox.ini index b53a493e..cbb2c008 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py38,py39,py310,py311,manifest,check-formatting,lint,mypy +envlist = py38,py39,py310,py311,full,manifest,check-formatting,lint,mypy skipsdist = True usedevelop = True requires = @@ -7,7 +7,18 @@ requires = wheel pip>=24 +# Tests with the minimal package (either coreml/onnx/tflite) [testenv] +deps = -e .[test] +commands = + pytest tests --ignore tests/test_nn.py {posargs} +setenv = + SOURCE = {toxinidir}/basic_pitch + TEST_SOURCE = {toxinidir}/tests + LINE_LENGTH = "120" + +# Tests with TF +[testenv:full] deps = -e .[dev] commands = pytest tests {posargs} @@ -22,27 +33,27 @@ skip_install = true commands = check-manifest --ignore 'tests/*' [testenv:check-formatting] -basepython = python3.8 -deps = black==23.1.0 +basepython = python3.10 +deps = black skip_install = true commands = black {env:SOURCE} tests --line-length {env:LINE_LENGTH} --diff --check [testenv:format] -basepython = python3.8 +basepython = python3.10 deps = black skip_install = true commands = black {env:SOURCE} tests --line-length {env:LINE_LENGTH} [testenv:lint] -basepython = python3.8 +basepython = python3.10 deps = flake8 skip_install = true commands = flake8 [testenv:mypy] -basepython = python3.8 +basepython = python3.10 deps = mypy commands = mypy basic_pitch tests --strict --ignore-missing-imports --allow-subclassing-any @@ -52,6 +63,7 @@ show-source = true max-line-length = 120 exclude = .venv,.tox,.git,dist,doc,*.egg,build ignore = E203,E731,W503,E231 +per-file-ignores = __init__.py:F401 [pytest] addopts = -v
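
As a quick sanity check of the windowing bookkeeping introduced in `basic_pitch/inference.py` above, the rewritten `window_audio_file()` can be exercised on its own. This is a sketch, not part of the change; it only uses names defined in this diff and in `basic_pitch.constants`:

```python
import numpy as np

from basic_pitch.constants import AUDIO_N_SAMPLES, AUDIO_SAMPLE_RATE, FFT_HOP
from basic_pitch.inference import window_audio_file

# run_inference() overlaps 30 frames between consecutive windows.
n_overlapping_frames = 30
overlap_len = n_overlapping_frames * FFT_HOP
hop_size = AUDIO_N_SAMPLES - overlap_len

# Ten seconds of noise stands in for a real recording.
audio = np.random.default_rng(0).random(10 * AUDIO_SAMPLE_RATE).astype(np.float32)
windows, times = zip(*window_audio_file(audio, hop_size))

# Every window is padded to a fixed length and carries its start/end time in seconds.
assert len(windows) == -(-len(audio) // hop_size)  # ceil(len(audio) / hop_size)
assert all(w.shape == (AUDIO_N_SAMPLES, 1) for w in windows)
assert all(t["start"] < t["end"] for t in times)
```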