Skip to content

Commit

Permalink
updated few other locations
Browse files Browse the repository at this point in the history
  • Loading branch information
nnshah1 committed Jun 4, 2024
1 parent 5296aff commit 0a3c470
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
16 changes: 8 additions & 8 deletions python/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,10 @@ def test_basic_inference(self):
output_memory_type="cpu",
raise_on_error=True,
):
for name, input_ in inputs.items():
output_ = response.outputs[name.replace("input", "output")]
output_ = numpy.from_dlpack(output_)
numpy.testing.assert_array_equal(input_, output_)
for input_name, input_value in inputs.items():
output_value = response.outputs[input_name.replace("input", "output")]
output_value = numpy.from_dlpack(output_value)
numpy.testing.assert_array_equal(input_value, output_value)

# test normal bool
inputs = {"bool_input": [[True, False, False, True]]}
Expand All @@ -473,11 +473,11 @@ def test_basic_inference(self):
output_memory_type="cpu",
raise_on_error=True,
):
for name, input_ in inputs.items():
output_ = numpy.from_dlpack(
response.outputs[name.replace("input", "output")]
for input_name, input_value in inputs.items():
output_value = numpy.from_dlpack(
response.outputs[input_name.replace("input", "output")]
)
numpy.testing.assert_array_equal(input_, output_)
numpy.testing.assert_array_equal(input_value, output_value)

def test_parameters(self):
server = tritonserver.Server(self._server_options).start(wait_until_ready=True)
Expand Down
9 changes: 9 additions & 0 deletions python/tritonserver/_c/tritonserver_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ class PyParameter : public PyWrapper<struct TRITONSERVER_Parameter> {
{
}

PyParameter(const char* name, double val)
: PyWrapper(
TRITONSERVER_ParameterNew(
name, TRITONSERVER_PARAMETER_DOUBLE, &val),
true)
{
}

PyParameter(const char* name, bool val)
: PyWrapper(
TRITONSERVER_ParameterNew(name, TRITONSERVER_PARAMETER_BOOL, &val),
Expand Down Expand Up @@ -1770,6 +1778,7 @@ PYBIND11_MODULE(triton_bindings, m)
py::enum_<TRITONSERVER_ParameterType>(m, "TRITONSERVER_ParameterType")
.value("STRING", TRITONSERVER_PARAMETER_STRING)
.value("INT", TRITONSERVER_PARAMETER_INT)
.value("DOUBLE", TRITONSERVER_PARAMETER_DOUBLE)
.value("BOOL", TRITONSERVER_PARAMETER_BOOL)
.value("BYTES", TRITONSERVER_PARAMETER_BYTES);
// helper functions
Expand Down
6 changes: 6 additions & 0 deletions src/tritonserver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,8 @@ TRITONSERVER_ParameterTypeString(TRITONSERVER_ParameterType paramtype)
return "INT";
case TRITONSERVER_PARAMETER_BOOL:
return "BOOL";
case TRITONSERVER_PARAMETER_DOUBLE:
return "DOUBLE";
case TRITONSERVER_PARAMETER_BYTES:
return "BYTES";
default:
Expand All @@ -656,6 +658,10 @@ TRITONSERVER_ParameterNew(
lparam.reset(new tc::InferenceParameter(
name, *reinterpret_cast<const int64_t*>(value)));
break;
case TRITONSERVER_PARAMETER_DOUBLE:
lparam.reset(new tc::InferenceParameter(
name, *reinterpret_cast<const double*>(value)));
break;
case TRITONSERVER_PARAMETER_BOOL:
lparam.reset(new tc::InferenceParameter(
name, *reinterpret_cast<const bool*>(value)));
Expand Down

0 comments on commit 0a3c470

Please sign in to comment.