-
Notifications
You must be signed in to change notification settings - Fork 282
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #100 from spotify/drubinstein/pt
Add alternative model serializations for inference
- Loading branch information
Showing
16 changed files
with
558 additions
and
179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,5 @@ | ||
include *.txt tox.ini *.rst *.md LICENSE | ||
include catalog-info.yaml | ||
recursive-include tests *.py *.wav | ||
recursive-include tests *.py *.wav *.npz | ||
recursive-include basic_pitch *.py | ||
recursive-include basic_pitch/saved_models/* | ||
recursive-include basic_pitch *.index *.pb variables.data* | ||
recursive-include basic_pitch/saved_models *.index *.pb variables.data* *.mlmodel *.json *.onnx *.tflite *.bin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,13 +15,81 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import enum | ||
import logging | ||
import pathlib | ||
|
||
__author__ = "Spotify" | ||
__version__ = "0.3.0" | ||
__email__ = "[email protected]" | ||
__demowebsite__ = "https://basicpitch.io" | ||
__description__ = "Basic Pitch, a lightweight yet powerful audio-to-MIDI converter with pitch bend detection." | ||
__url__ = "https://github.com/spotify/basic-pitch" | ||
|
||
ICASSP_2022_MODEL_PATH = pathlib.Path(__file__).parent / "saved_models/icassp_2022/nmp" | ||
try: | ||
import coremltools | ||
|
||
CT_PRESENT = True | ||
except ImportError: | ||
CT_PRESENT = False | ||
logging.warning( | ||
"Coremltools is not installed. " | ||
"If you plan to use a CoreML Saved Model, " | ||
"reinstall basic-pitch with `pip install 'basic-pitch[coreml]'`" | ||
) | ||
|
||
try: | ||
import tflite_runtime.interpreter | ||
|
||
TFLITE_PRESENT = True | ||
except ImportError: | ||
TFLITE_PRESENT = False | ||
logging.warning( | ||
"tflite-runtime is not installed. " | ||
"If you plan to use a TFLite Model, " | ||
"reinstall basic-pitch with `pip install 'basic-pitch tflite-runtime'` or " | ||
"`pip install 'basic-pitch[tf]'" | ||
) | ||
|
||
try: | ||
import onnxruntime | ||
|
||
ONNX_PRESENT = True | ||
except ImportError: | ||
ONNX_PRESENT = False | ||
logging.warning( | ||
"onnxruntime is not installed. " | ||
"If you plan to use an ONNX Model, " | ||
"reinstall basic-pitch with `pip install 'basic-pitch[onnx]'`" | ||
) | ||
|
||
|
||
try: | ||
import tensorflow | ||
|
||
TF_PRESENT = True | ||
except ImportError: | ||
TF_PRESENT = False | ||
logging.warning( | ||
"Tensorflow is not installed. " | ||
"If you plan to use a TF Saved Model, " | ||
"reinstall basic-pitch with `pip install 'basic-pitch[tf]'`" | ||
) | ||
|
||
|
||
class FilenameSuffix(enum.Enum): | ||
tf = "nmp" | ||
coreml = "nmp.mlpackage" | ||
tflite = "nmp.tflite" | ||
onnx = "nmp.onnx" | ||
|
||
|
||
if TF_PRESENT: | ||
_default_model_type = FilenameSuffix.tf | ||
elif CT_PRESENT: | ||
_default_model_type = FilenameSuffix.coreml | ||
elif TFLITE_PRESENT: | ||
_default_model_type = FilenameSuffix.tflite | ||
elif ONNX_PRESENT: | ||
_default_model_type = FilenameSuffix.onnx | ||
|
||
|
||
def build_icassp_2022_model_path(suffix: FilenameSuffix) -> pathlib.Path: | ||
return pathlib.Path(__file__).parent / "saved_models/icassp_2022" / suffix.value | ||
|
||
|
||
ICASSP_2022_MODEL_PATH = build_icassp_2022_model_path(_default_model_type) |
Oops, something went wrong.