diff --git a/python/test/test_api.py b/python/test/test_api.py index df7a464f3..0f4a22a94 100644 --- a/python/test/test_api.py +++ b/python/test/test_api.py @@ -460,10 +460,10 @@ def test_basic_inference(self): output_memory_type="cpu", raise_on_error=True, ): - for name, input_ in inputs.items(): - output_ = response.outputs[name.replace("input", "output")] - output_ = numpy.from_dlpack(output_) - numpy.testing.assert_array_equal(input_, output_) + for input_name, input_value in inputs.items(): + output_value = response.outputs[input_name.replace("input", "output")] + output_value = numpy.from_dlpack(output_value) + numpy.testing.assert_array_equal(input_value, output_value) # test normal bool inputs = {"bool_input": [[True, False, False, True]]} @@ -473,11 +473,11 @@ def test_basic_inference(self): output_memory_type="cpu", raise_on_error=True, ): - for name, input_ in inputs.items(): - output_ = numpy.from_dlpack( - response.outputs[name.replace("input", "output")] + for input_name, input_value in inputs.items(): + output_value = numpy.from_dlpack( + response.outputs[input_name.replace("input", "output")] ) - numpy.testing.assert_array_equal(input_, output_) + numpy.testing.assert_array_equal(input_value, output_value) def test_parameters(self): server = tritonserver.Server(self._server_options).start(wait_until_ready=True) diff --git a/python/tritonserver/_c/tritonserver_pybind.cc b/python/tritonserver/_c/tritonserver_pybind.cc index b3993c705..127bb15b8 100644 --- a/python/tritonserver/_c/tritonserver_pybind.cc +++ b/python/tritonserver/_c/tritonserver_pybind.cc @@ -217,6 +217,14 @@ class PyParameter : public PyWrapper { { } + PyParameter(const char* name, double val) + : PyWrapper( + TRITONSERVER_ParameterNew( + name, TRITONSERVER_PARAMETER_DOUBLE, &val), + true) + { + } + PyParameter(const char* name, bool val) : PyWrapper( TRITONSERVER_ParameterNew(name, TRITONSERVER_PARAMETER_BOOL, &val), @@ -1770,6 +1778,7 @@ PYBIND11_MODULE(triton_bindings, m) py::enum_(m, "TRITONSERVER_ParameterType") .value("STRING", TRITONSERVER_PARAMETER_STRING) .value("INT", TRITONSERVER_PARAMETER_INT) + .value("DOUBLE", TRITONSERVER_PARAMETER_DOUBLE) .value("BOOL", TRITONSERVER_PARAMETER_BOOL) .value("BYTES", TRITONSERVER_PARAMETER_BYTES); // helper functions diff --git a/src/tritonserver.cc b/src/tritonserver.cc index 4871aa3ba..eae83ef2f 100644 --- a/src/tritonserver.cc +++ b/src/tritonserver.cc @@ -633,6 +633,8 @@ TRITONSERVER_ParameterTypeString(TRITONSERVER_ParameterType paramtype) return "INT"; case TRITONSERVER_PARAMETER_BOOL: return "BOOL"; + case TRITONSERVER_PARAMETER_DOUBLE: + return "DOUBLE"; case TRITONSERVER_PARAMETER_BYTES: return "BYTES"; default: @@ -656,6 +658,10 @@ TRITONSERVER_ParameterNew( lparam.reset(new tc::InferenceParameter( name, *reinterpret_cast(value))); break; + case TRITONSERVER_PARAMETER_DOUBLE: + lparam.reset(new tc::InferenceParameter( + name, *reinterpret_cast(value))); + break; case TRITONSERVER_PARAMETER_BOOL: lparam.reset(new tc::InferenceParameter( name, *reinterpret_cast(value)));