From 101d78ad82510c793a2d8fd1046b9910ff12ed74 Mon Sep 17 00:00:00 2001 From: David Rubinstein Date: Mon, 2 Oct 2023 17:03:04 -0400 Subject: [PATCH] Dont convert numpy array to list --- basic_pitch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basic_pitch/inference.py b/basic_pitch/inference.py index a89a4e15..85e2f4e4 100644 --- a/basic_pitch/inference.py +++ b/basic_pitch/inference.py @@ -153,7 +153,7 @@ def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float3 if self.model_type == Model.MODEL_TYPES.TENSORFLOW: return {k: v.numpy() for k, v in cast(tf.keras.Model, self.model(x)).items()} elif self.model_type == Model.MODEL_TYPES.COREML: - return cast(ct.models.MLModel, self.model.predict({"input": x.tolist()})) # type: ignore + return cast(ct.models.MLModel, self.model.predict({"input": x})) # type: ignore elif self.model_type == Model.MODEL_TYPES.TFLITE: return self.model(input_2=x) # type: ignore elif self.model_type == Model.MODEL_TYPES.ONNX: