Skip to content

Commit

Permalink
Merge pull request #100 from spotify/drubinstein/pt
Browse files Browse the repository at this point in the history
Add alternative model serializations for inference
  • Loading branch information
drubinstein authored Mar 12, 2024
2 parents 1c1f862 + 638da87 commit b0d185b
Show file tree
Hide file tree
Showing 16 changed files with 558 additions and 179 deletions.
20 changes: 12 additions & 8 deletions .github/workflows/tox.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,34 @@ on:
jobs:
test:
name: test ${{ matrix.py }} - ${{ matrix.os }}
runs-on: ${{ matrix.os }}-latest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os:
- Ubuntu
- Windows
- MacOs
- ubuntu-latest
- windows-latest
py:
- "3.11"
- "3.10"
- "3.9"
- "3.8"
include:
- os: macos-latest-xlarge
py: "3.10.11"
- os: macos-latest-xlarge
py: "3.11.8"
steps:
- name: Setup python for test ${{ matrix.py }}
uses: actions/setup-python@v3
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.py }}
- uses: actions/checkout@v3
- name: Install soundlibs Ubuntu
run: sudo apt-get update && sudo apt-get install --no-install-recommends -y --fix-missing pkg-config libsndfile1
if: matrix.os == 'Ubuntu'
- name: Install soundlibs MacOs
run: brew install libsndfile
run: brew install libsndfile llvm libomp
if: matrix.os == 'MacOs'
- name: Install soundlibs Windows
run: choco install libsndfile
Expand All @@ -41,8 +45,8 @@ jobs:
# We will only check this on the minimum python version
- name: Check formatting, lint and mypy
run: tox -c tox.ini -e check-formatting,lint,mypy
if: matrix.py == '3.8'
if: matrix.py == '3.10'
- name: Run test suite
run: tox -c tox.ini -e py,manifest
run: tox -c tox.ini -e py,manifest,full
- name: Check that basic-pitch can be run as a commandline
run: pip3 install -e . && basic-pitch --help
5 changes: 2 additions & 3 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
include *.txt tox.ini *.rst *.md LICENSE
include catalog-info.yaml
recursive-include tests *.py *.wav
recursive-include tests *.py *.wav *.npz
recursive-include basic_pitch *.py
recursive-include basic_pitch/saved_models/*
recursive-include basic_pitch *.index *.pb variables.data*
recursive-include basic_pitch/saved_models *.index *.pb variables.data* *.mlmodel *.json *.onnx *.tflite *.bin
33 changes: 27 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,31 @@ To update Basic Pitch to the latest version, add `--upgrade` to the above comman

#### Compatible Environments:
- MacOS, Windows and Ubuntu operating systems
- Python versions 3.7, 3.8, 3.9, 3.10
- Python versions 3.7, 3.8, 3.9, 3.10, 3.11
- **For Mac M1 hardware, we currently only support python version 3.10. Otherwise, we suggest using a virtual machine.**


### Model Runtime

Basic Pitch comes with the original TensorFlow model and the TensorFlow model converted to [CoreML](https://developer.apple.com/documentation/coreml), [TensorFlowLite](https://www.tensorflow.org/lite), and [ONNX](https://onnx.ai/). By default, Basic Pitch will _not_ install TensorFlow as a dependency *unless you are using Python>=3.11*. Instead, by default, CoreML will be installed on MacOS, TensorFlowLite will be installed on Linux and ONNX will be installed on Windows. If you want to install TensorFlow along with the default model inference runtime, you can install TensorFlow via `pip install basic-pitch[tf]`.

## Usage

### Model Prediction

### Model Runtime

By default, Basic Pitch will attempt to load a model in the following order:

1. TensorFlow
2. CoreML
3. TensorFlowLite
4. ONNX

Additionally, the module variable ICASSP_2022_MODEL_PATH will default to the first available version in the list.

We will explain how to override this priority list below. Because all other model serializations were converted from TensorFlow, we recommend using TensorFlow when possible. N.B. Basic Pitch does not install TensorFlow by default to save the user time when installing and running Basic Pitch.

#### Command Line Tool

This library offers a command line tool interface. A basic prediction command will generate and save a MIDI file transcription of audio at the `<input-audio-path>` to the `<output-directory>`:
Expand All @@ -73,9 +90,11 @@ basic-pitch <output-directory> <input-audio-path-1> <input-audio-path-2> <input-

Optionally, you may append any of the following flags to your prediction command to save additional formats of the prediction output to the `<output-directory>`:

- `--sonify-midi` to additionally save a `.wav` audio rendering of the MIDI file
- `--save-model-outputs` to additionally save raw model outputs as an NPZ file
- `--save-note-events` to additionally save the predicted note events as a CSV file
- `--sonify-midi` to additionally save a `.wav` audio rendering of the MIDI file.
- `--save-model-outputs` to additionally save raw model outputs as an NPZ file.
- `--save-note-events` to additionally save the predicted note events as a CSV file.

If you want to use a non-default model type (e.g., use CoreML instead of TF), use the `--model-serialization` argument. The CLI will change the loaded model to the type you prefer.

To discover more parameter control, run:
```bash
Expand All @@ -100,17 +119,19 @@ model_output, midi_data, note_events = predict(<input-audio-path>)
- `midi_data` is the transcribed MIDI data derived from the `model_output`
- `note_events` is a list of note events derived from the `model_output`

Note: As mentioned previously, ICASSP_2022_MODEL_PATH will default to the runtime first supported in the list TensorFlow, CoreML, TensorFlowLite, ONNX.

**predict() in a loop**

To run prediction within a loop, you'll want to load the model yourself and provide `predict()` with the loaded model object itself to be used for repeated prediction calls, in order to avoid redundant and sluggish model loading.

```python
import tensorflow as tf

from basic_pitch.inference import predict
from basic_pitch.inference import predict, Model
from basic_pitch import ICASSP_2022_MODEL_PATH

basic_pitch_model = tf.saved_model.load(str(ICASSP_2022_MODEL_PATH))
basic_pitch_model = Model(ICASSP_2022_MODEL_PATH))

for x in range():
...
Expand Down
82 changes: 75 additions & 7 deletions basic_pitch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,81 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import logging
import pathlib

__author__ = "Spotify"
__version__ = "0.3.0"
__email__ = "[email protected]"
__demowebsite__ = "https://basicpitch.io"
__description__ = "Basic Pitch, a lightweight yet powerful audio-to-MIDI converter with pitch bend detection."
__url__ = "https://github.com/spotify/basic-pitch"

ICASSP_2022_MODEL_PATH = pathlib.Path(__file__).parent / "saved_models/icassp_2022/nmp"
try:
import coremltools

CT_PRESENT = True
except ImportError:
CT_PRESENT = False
logging.warning(
"Coremltools is not installed. "
"If you plan to use a CoreML Saved Model, "
"reinstall basic-pitch with `pip install 'basic-pitch[coreml]'`"
)

try:
import tflite_runtime.interpreter

TFLITE_PRESENT = True
except ImportError:
TFLITE_PRESENT = False
logging.warning(
"tflite-runtime is not installed. "
"If you plan to use a TFLite Model, "
"reinstall basic-pitch with `pip install 'basic-pitch tflite-runtime'` or "
"`pip install 'basic-pitch[tf]'"
)

try:
import onnxruntime

ONNX_PRESENT = True
except ImportError:
ONNX_PRESENT = False
logging.warning(
"onnxruntime is not installed. "
"If you plan to use an ONNX Model, "
"reinstall basic-pitch with `pip install 'basic-pitch[onnx]'`"
)


try:
import tensorflow

TF_PRESENT = True
except ImportError:
TF_PRESENT = False
logging.warning(
"Tensorflow is not installed. "
"If you plan to use a TF Saved Model, "
"reinstall basic-pitch with `pip install 'basic-pitch[tf]'`"
)


class FilenameSuffix(enum.Enum):
tf = "nmp"
coreml = "nmp.mlpackage"
tflite = "nmp.tflite"
onnx = "nmp.onnx"


if TF_PRESENT:
_default_model_type = FilenameSuffix.tf
elif CT_PRESENT:
_default_model_type = FilenameSuffix.coreml
elif TFLITE_PRESENT:
_default_model_type = FilenameSuffix.tflite
elif ONNX_PRESENT:
_default_model_type = FilenameSuffix.onnx


def build_icassp_2022_model_path(suffix: FilenameSuffix) -> pathlib.Path:
return pathlib.Path(__file__).parent / "saved_models/icassp_2022" / suffix.value


ICASSP_2022_MODEL_PATH = build_icassp_2022_model_path(_default_model_type)
Loading

0 comments on commit b0d185b

Please sign in to comment.