Skip to content

Commit

Permalink
Dont convert numpy array to list
Browse files Browse the repository at this point in the history
  • Loading branch information
David Rubinstein committed Oct 2, 2023
1 parent f85a617 commit 101d78a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion basic_pitch/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float3
if self.model_type == Model.MODEL_TYPES.TENSORFLOW:
return {k: v.numpy() for k, v in cast(tf.keras.Model, self.model(x)).items()}
elif self.model_type == Model.MODEL_TYPES.COREML:
return cast(ct.models.MLModel, self.model.predict({"input": x.tolist()})) # type: ignore
return cast(ct.models.MLModel, self.model.predict({"input": x})) # type: ignore
elif self.model_type == Model.MODEL_TYPES.TFLITE:
return self.model(input_2=x) # type: ignore
elif self.model_type == Model.MODEL_TYPES.ONNX:
Expand Down

0 comments on commit 101d78a

Please sign in to comment.