From b85ed1be95e5ba6c9c309681caed3bcad92a6e03 Mon Sep 17 00:00:00 2001 From: David Rubinstein Date: Sun, 1 Oct 2023 19:49:21 -0400 Subject: [PATCH] More informative error --- basic_pitch/inference.py | 44 ++++++++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/basic_pitch/inference.py b/basic_pitch/inference.py index 0fd7c76f..746f22d2 100644 --- a/basic_pitch/inference.py +++ b/basic_pitch/inference.py @@ -18,6 +18,7 @@ import csv import enum import json +import logging import os import pathlib from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast @@ -72,44 +73,75 @@ class MODEL_TYPES(enum.Enum): ONNX = enum.auto() def __init__(self, model_path: Union[pathlib.Path, str]): + present = [] if TF_PRESENT: + present.append("TensorFlow") try: self.model_type = Model.MODEL_TYPES.TENSORFLOW self.model = tf.saved_model.load(model_path) return except Exception: - pass + if os.path.isdir(model_path) and {"saved_model.pb", "variables"} & set(os.path.listdir(model_path)): + logging.warning( + "Could not load TensorFlowLite file %s even " + "though it looks like a TFLite file with error %s. Are you sure it's a TFLite file", + model_path, + e.__repr__(), + ) if CT_PRESENT: + present.append("CoreML") try: self.model_type = Model.MODEL_TYPES.COREML self.model = ct.models.MLModel(str(model_path)) return except Exception: - pass + if str(model_path).endswith(".mlpackage"): + logging.warning( + "Could not load TensorFlowLite file %s even " + "though it looks like a TFLite file with error %s. Are you sure it's a TFLite file", + model_path, + e.__repr__(), + ) if TFLITE_PRESENT: + present.append("TensorFlowLite") try: self.model_type = Model.MODEL_TYPES.TFLITE self.interpreter = tflite.Interpreter(model_path=model_path) self.interpreter.allocate_tensors() self.model = self.interpreter.get_signature_runner() return - except Exception: - pass + except Exception as e: + if str(model_path).endswith(".tflite"): + logging.warning( + "Could not load TensorFlowLite file %s even " + "though it looks like a TFLite file with error %s. Are you sure it's a TFLite file", + model_path, + e.__repr__(), + ) if ONNX_PRESENT: + present.append("ONNX") try: self.model_type = Model.MODEL_TYPES.ONNX self.model = ort.InferenceSession(model_path) return except Exception: - pass + if str(model_path).endswith(".onnx"): + logging.warning( + "Could not load TensorFlowLite file %s even " + "though it looks like a TFLite file with error %s. Are you sure it's a TFLite file", + model_path, + e.__repr__(), + ) raise ValueError( f"File {model_path} cannot be loaded into either " "TensorFlow, CoreML, TFLite or ONNX. " - "Please check if it is a supported and valid serialized model" + "Please check if it is a supported and valid serialized model " + "and that one of these packages are installed. On this system, " + f"{present} is installed." ) def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float32]]: