Skip to content

Commit

Permalink
Add more protection around backend response APIs to avoid crashing se…
Browse files Browse the repository at this point in the history
…rver (#367)
  • Loading branch information
rmccorm4 authored Jun 6, 2024
1 parent 7752666 commit 190d8e3
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/backend_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,11 @@ TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_ResponseSetStringParameter(
TRITONBACKEND_Response* response, const char* name, const char* value)
{
if (!response) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG, "response was nullptr");
}

InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
RETURN_TRITONSERVER_ERROR_IF_ERROR(tr->AddParameter(name, value));
return nullptr; // success
Expand All @@ -1558,6 +1563,11 @@ TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_ResponseSetIntParameter(
TRITONBACKEND_Response* response, const char* name, const int64_t value)
{
if (!response) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG, "response was nullptr");
}

InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
RETURN_TRITONSERVER_ERROR_IF_ERROR(tr->AddParameter(name, value));
return nullptr; // success
Expand All @@ -1567,6 +1577,11 @@ TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_ResponseSetBoolParameter(
TRITONBACKEND_Response* response, const char* name, const bool value)
{
if (!response) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG, "response was nullptr");
}

InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
RETURN_TRITONSERVER_ERROR_IF_ERROR(tr->AddParameter(name, value));
return nullptr; // success
Expand All @@ -1576,6 +1591,11 @@ TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_ResponseSetDoubleParameter(
TRITONBACKEND_Response* response, const char* name, const double value)
{
if (!response) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG, "response was nullptr");
}

InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
RETURN_TRITONSERVER_ERROR_IF_ERROR(tr->AddParameter(name, value));
return nullptr; // success
Expand All @@ -1587,6 +1607,11 @@ TRITONBACKEND_ResponseOutput(
const char* name, const TRITONSERVER_DataType datatype,
const int64_t* shape, const uint32_t dims_count)
{
if (!response) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG, "response was nullptr");
}

*output = nullptr;
InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
std::vector<int64_t> lshape(shape, shape + dims_count);
Expand All @@ -1602,6 +1627,11 @@ TRITONBACKEND_ResponseSend(
TRITONBACKEND_Response* response, const uint32_t send_flags,
TRITONSERVER_Error* error)
{
if (!response) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG, "response was nullptr");
}

InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);

std::unique_ptr<InferenceResponse> utr(tr);
Expand Down

0 comments on commit 190d8e3

Please sign in to comment.