diff --git a/basic_pitch/__init__.py b/basic_pitch/__init__.py index d3b38b40..b80b2144 100644 --- a/basic_pitch/__init__.py +++ b/basic_pitch/__init__.py @@ -19,17 +19,6 @@ import logging import pathlib -try: - import tensorflow - - TF_PRESENT = True -except ImportError: - TF_PRESENT = False - logging.warning( - "Tensorflow is not installed. " - "If you plan to use a TF Saved Model, " - "reinstall basic-pitch with `pip install 'basic-pitch[tf]'`" - ) try: import coremltools @@ -69,6 +58,19 @@ ) +try: + import tensorflow + + TF_PRESENT = True +except ImportError: + TF_PRESENT = False + logging.warning( + "Tensorflow is not installed. " + "If you plan to use a TF Saved Model, " + "reinstall basic-pitch with `pip install 'basic-pitch[tf]'`" + ) + + class FilenameSuffix(enum.Enum): tf = "nmp" coreml = "nmp.mlpackage" diff --git a/basic_pitch/inference.py b/basic_pitch/inference.py index f2900c25..a89a4e15 100644 --- a/basic_pitch/inference.py +++ b/basic_pitch/inference.py @@ -24,6 +24,8 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast +from basic_pitch import CT_PRESENT, ONNX_PRESENT, TF_PRESENT, TFLITE_PRESENT + try: import tensorflow as tf except ImportError: @@ -37,7 +39,8 @@ try: import tflite_runtime.interpreter as tflite except ImportError: - pass + if TF_PRESENT: + import tensorflow.lite as tflite try: import onnxruntime as ort @@ -55,7 +58,6 @@ ANNOTATIONS_FPS, FFT_HOP, ) -from basic_pitch import CT_PRESENT, ONNX_PRESENT, TF_PRESENT, TFLITE_PRESENT from basic_pitch.commandline_printing import ( generating_file_message, no_tf_warnings, @@ -106,7 +108,7 @@ def __init__(self, model_path: Union[pathlib.Path, str]): e.__repr__(), ) - if TFLITE_PRESENT: + if TFLITE_PRESENT or TF_PRESENT: present.append("TensorFlowLite") try: self.model_type = Model.MODEL_TYPES.TFLITE @@ -153,11 +155,22 @@ def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float3 elif self.model_type == Model.MODEL_TYPES.COREML: return cast(ct.models.MLModel, self.model.predict({"input": x.tolist()})) # type: ignore elif self.model_type == Model.MODEL_TYPES.TFLITE: - return cast(tflite.SignatureRunner, self.model)(x) # type: ignore + return self.model(input_2=x) # type: ignore elif self.model_type == Model.MODEL_TYPES.ONNX: - return cast(ort.InferenceSession, self.model).run( # type: ignore - ["note", "onset", "contour"], {"input": x} - ) + return { + k: v + for k, v in zip( + ["note", "onset", "contour"], + cast(ort.InferenceSession, self.model).run( + [ + "StatefulPartitionedCall:1", + "StatefulPartitionedCall:2", + "StatefulPartitionedCall:0", + ], + {"serving_default_input_2:0": x}, + ), + ) + } def window_audio_file(