diff --git a/basic_pitch/inference.py b/basic_pitch/inference.py
index 0fd7c76f..d401dcb0 100644
--- a/basic_pitch/inference.py
+++ b/basic_pitch/inference.py
@@ -18,6 +18,7 @@
 import csv
 import enum
 import json
+import logging
 import os
 import pathlib
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast
@@ -72,44 +73,78 @@ class MODEL_TYPES(enum.Enum):
         ONNX = enum.auto()
 
     def __init__(self, model_path: Union[pathlib.Path, str]):
+        present = []
         if TF_PRESENT:
+            present.append("TensorFlow")
             try:
                 self.model_type = Model.MODEL_TYPES.TENSORFLOW
-                self.model = tf.saved_model.load(model_path)
+                self.model = tf.saved_model.load(str(model_path))
                 return
-            except Exception:
-                pass
+            except Exception as e:
+                if os.path.isdir(model_path) and {"saved_model.pb", "variables"} & set(os.listdir(model_path)):
+                    logging.warning(
+                        "Could not load TensorFlow saved model %s even "
+                        "though it looks like a saved model file with error %s. "
+                        "Are you sure it's a TensorFlow saved model?",
+                        model_path,
+                        repr(e),
+                    )
 
         if CT_PRESENT:
+            present.append("CoreML")
             try:
                 self.model_type = Model.MODEL_TYPES.COREML
                 self.model = ct.models.MLModel(str(model_path))
                 return
-            except Exception:
-                pass
+            except Exception as e:
+                if str(model_path).endswith(".mlpackage"):
+                    logging.warning(
+                        "Could not load CoreML file %s even "
+                        "though it looks like a CoreML file with error %s. "
+                        "Are you sure it's a CoreML file?",
+                        model_path,
+                        repr(e),
+                    )
 
         if TFLITE_PRESENT:
+            present.append("TensorFlowLite")
             try:
                 self.model_type = Model.MODEL_TYPES.TFLITE
-                self.interpreter = tflite.Interpreter(model_path=model_path)
-                self.interpreter.allocate_tensors()
+                self.interpreter = tflite.Interpreter(str(model_path))
                 self.model = self.interpreter.get_signature_runner()
                 return
-            except Exception:
-                pass
+            except Exception as e:
+                if str(model_path).endswith(".tflite"):
+                    logging.warning(
+                        "Could not load TensorFlowLite file %s even "
+                        "though it looks like a TFLite file with error %s. "
+                        "Are you sure it's a TFLite file?",
+                        model_path,
+                        repr(e),
+                    )
 
         if ONNX_PRESENT:
+            present.append("ONNX")
             try:
                 self.model_type = Model.MODEL_TYPES.ONNX
-                self.model = ort.InferenceSession(model_path)
+                self.model = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
                 return
-            except Exception:
-                pass
+            except Exception as e:
+                if str(model_path).endswith(".onnx"):
+                    logging.warning(
+                        "Could not load ONNX file %s even "
+                        "though it looks like an ONNX file with error %s. "
+                        "Are you sure it's an ONNX file?",
+                        model_path,
+                        repr(e),
+                    )
 
         raise ValueError(
             f"File {model_path} cannot be loaded into either "
             "TensorFlow, CoreML, TFLite or ONNX. "
-            "Please check if it is a supported and valid serialized model"
+            "Please check that it is a supported and valid serialized model "
+            "and that at least one of these packages is installed. On this system, "
+            f"the following are installed: {present}."
         )
 
     def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float32]]:
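
For reviewers, a minimal caller-side sketch of the new loading behavior follows. It assumes basic-pitch's public `ICASSP_2022_MODEL_PATH` constant and `predict` entry point, and that `predict` accepts a `Model` instance as in this diff; `example.wav` is a hypothetical input file.

```python
# Minimal usage sketch of the fallback loader above (assumptions:
# ICASSP_2022_MODEL_PATH and predict() from basic-pitch's public API;
# "example.wav" is a hypothetical input file).
import logging

from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.inference import Model, predict

# The revised loader reports failures via logging.warning() instead of
# swallowing them, so configure logging to make those warnings visible.
logging.basicConfig(level=logging.WARNING)

# Model.__init__ tries TensorFlow, CoreML, TFLite, then ONNX in order and
# returns on the first backend that loads; if every attempt fails it raises
# ValueError listing which of the four backends are installed on this system.
model = Model(ICASSP_2022_MODEL_PATH)
model_output, midi_data, note_events = predict("example.wav", model)
```

The main behavioral change this exercises: the silent `except Exception: pass` fallthroughs are replaced by warnings emitted only when the path's apparent format (saved-model directory, `.mlpackage`, `.tflite`, `.onnx`) suggests that backend should have succeeded, so a genuine mismatch no longer fails without explanation.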