Fix request re-use when cancelling a request #253

Merged (6 commits) on Sep 13, 2023

3 changes: 2 additions & 1 deletion src/backend_model.cc
@@ -1022,7 +1022,8 @@ TRITONBACKEND_RequestIsCancelled(
     TRITONBACKEND_Request* request, bool* is_cancelled)
 {
   InferenceRequest* tr = reinterpret_cast<InferenceRequest*>(request);
-  *is_cancelled = tr->IsCancelled();
+
+  RETURN_TRITONSERVER_ERROR_IF_ERROR(tr->IsCancelled(is_cancelled));
   return nullptr;
 }

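With this change, TRITONBACKEND_RequestIsCancelled reports an error instead of reading a response factory that may not exist yet, so backend code that polls for cancellation should check the returned TRITONSERVER_Error*. A minimal sketch of that pattern follows; the helper name and the logging are illustrative, not part of the Triton API.

```cpp
#include <iostream>

#include "triton/core/tritonbackend.h"
#include "triton/core/tritonserver.h"

// Illustrative helper: poll a request for cancellation, treating an error
// (e.g. the request was never passed to TRITONSERVER_ServerInferAsync) as
// "not cancelled" after logging the message.
static bool
RequestIsCancelledOrLog(TRITONBACKEND_Request* request)
{
  bool is_cancelled = false;
  TRITONSERVER_Error* err =
      TRITONBACKEND_RequestIsCancelled(request, &is_cancelled);
  if (err != nullptr) {
    std::cerr << TRITONSERVER_ErrorMessage(err) << std::endl;
    TRITONSERVER_ErrorDelete(err);
    return false;
  }
  return is_cancelled;
}
```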
6 changes: 6 additions & 0 deletions src/backend_model_instance.cc
@@ -623,6 +623,12 @@ TritonModelInstance::WarmUp()
       request->SetResponseCallback(
           &warmup_allocator, nullptr, WarmupResponseComplete,
           &response_complete[i]);
+
+      // For warmup requests we need to manually set ResponseFactory
+      // since they modify the callback after PrepareForInference has
+      // been called.
+      request->SetResponseFactory();
+
       // Capture timestamp before run to avoid incorrect accumulation from
       // sequential warmup runs
 #ifdef TRITON_ENABLE_STATS
3 changes: 2 additions & 1 deletion src/infer_request.cc
@@ -599,6 +599,7 @@ InferenceRequest::CopyAsNull(const InferenceRequest& from)
   lrequest->SetResponseCallback(
       &null_allocator, nullptr, NullResponseComplete, nullptr);
   lrequest->SetReleaseCallback(NullRequestComplete, nullptr);
+  lrequest->SetResponseFactory();

   // Must normalize inputs here...
   for (auto& pr : lrequest->original_inputs_) {
@@ -828,7 +829,7 @@ InferenceRequest::PrepareForInference()
   // inference execution.
   inputs_.clear();
   override_inputs_.clear();
-  ResetCancel();
+  SetResponseFactory();

   // Renormalize if anything has changed in the inference request in a
   // way that could impact renormalization.
49 changes: 41 additions & 8 deletions src/infer_request.h
@@ -497,16 +497,17 @@ class InferenceRequest {
     return Status::Success;
   }

-  // Initialize the response factory that is to be used with any
-  // responses produced for this request.
+  // Initialize the response factory arguments that are going to be used with
+  // any responses produced for this request.
   Status SetResponseCallback(
       const ResponseAllocator* allocator, void* alloc_userp,
       TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
       void* response_userp)
   {
-    response_factory_.reset(new InferenceResponseFactory(
-        model_shared_, id_, allocator, alloc_userp, response_fn, response_userp,
-        response_delegator_));
+    response_allocator_ = allocator;
+    alloc_userp_ = alloc_userp;
+    response_callback_ = response_fn;
+    response_userp_ = response_userp;
     return Status::Success;
   }

@@ -550,6 +551,13 @@ class InferenceRequest {

   Status LoadInputStates();

+  void SetResponseFactory()
+  {
+    response_factory_.reset(new InferenceResponseFactory(
+        model_shared_, id_, response_allocator_, alloc_userp_,
+        response_callback_, response_userp_, response_delegator_));
+  }
+
   const std::shared_ptr<SequenceStates>& GetSequenceStates() const
   {
     return sequence_states_;
@@ -680,10 +688,29 @@
     secondary_stats_aggregator_ = secondary_stats_aggregator;
   }

-  void Cancel() { response_factory_->Cancel(); }
-  void ResetCancel() { response_factory_->ResetCancel(); }
+  Status Cancel()
+  {
+    if (!response_factory_) {
+      return Status(
+          Status::Code::INTERNAL,
+          "It is not possible to cancel an inference request before calling "
+          "TRITONSERVER_InferAsync.");
+    }
+    response_factory_->Cancel();
+    return Status::Success;
+  }

-  bool IsCancelled() { return response_factory_->IsCancelled(); }
+  Status IsCancelled(bool* is_cancelled)
+  {
+    if (!response_factory_) {
+      return Status(
+          Status::Code::INTERNAL,
+          "It is not possible to query cancellation status before calling "
+          "TRITONSERVER_InferAsync.");
+    }
+    *is_cancelled = response_factory_->IsCancelled();
+    return Status::Success;
+  }

#endif // TRITON_ENABLE_STATS

@@ -804,6 +831,12 @@ class InferenceRequest {
   // Whether this is a null request used for direct sequence batch padding or
   // not.
   bool null_request_;
+
+  // Response factory arguments
+  const ResponseAllocator* response_allocator_;
+  void* response_userp_;
+  void* alloc_userp_;
+  TRITONSERVER_InferenceResponseCompleteFn_t response_callback_;
 };

 std::ostream& operator<<(std::ostream& out, const InferenceRequest& request);
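Taken together, these header changes mean SetResponseCallback now only records the allocator and callback arguments, and a fresh InferenceResponseFactory is built by SetResponseFactory each time PrepareForInference runs. A self-contained sketch of that pattern, using simplified stand-in names rather than the real Triton classes, shows why a re-used request no longer carries the previous run's cancellation flag:

```cpp
#include <memory>

// Simplified stand-ins for InferenceResponseFactory / InferenceRequest
// (illustrative only; not the real Triton types).
struct Factory {
  bool cancelled = false;
};

class Request {
 public:
  // Analogous to SetResponseCallback(): only cache the arguments.
  void SetResponseCallback(void* userp) { userp_ = userp; }

  // Analogous to PrepareForInference() calling SetResponseFactory():
  // every inference gets a brand-new factory, so state such as the
  // cancellation flag cannot leak from one use of the request to the next.
  void PrepareForInference() { factory_ = std::make_unique<Factory>(); }

  void Cancel() { factory_->cancelled = true; }
  bool IsCancelled() const { return factory_->cancelled; }

 private:
  void* userp_ = nullptr;
  std::unique_ptr<Factory> factory_;
};
```

This is also why the warmup path and CopyAsNull above call SetResponseFactory() explicitly: they swap callbacks after the factory for the current run has already been built.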
1 change: 0 additions & 1 deletion src/infer_response.h
@@ -65,7 +65,6 @@ class InferenceResponseFactory {
   }

   void Cancel() { is_cancelled_ = true; }
-  void ResetCancel() { is_cancelled_ = false; }

   bool IsCancelled() { return is_cancelled_; }

8 changes: 8 additions & 0 deletions src/test/CMakeLists.txt
@@ -592,6 +592,14 @@ set_target_properties(
     INSTALL_RPATH ""
 )

+target_include_directories(
+  request_cancellation_test
+  PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/..
+    ${CMAKE_CURRENT_SOURCE_DIR}/../../include
+    ${GTEST_INCLUDE_DIRS}
+)
+
 target_link_libraries(
   request_cancellation_test
   PRIVATE
25 changes: 13 additions & 12 deletions src/test/request_cancellation_test.cc
@@ -188,22 +188,16 @@ TEST_F(RequestCancellationTest, Cancellation)

   TRITONBACKEND_Request* backend_request =
       reinterpret_cast<TRITONBACKEND_Request*>(irequest_);
-  TRITONBACKEND_ResponseFactory* response_factory;
-  FAIL_TEST_IF_ERR(
-      TRITONBACKEND_ResponseFactoryNew(&response_factory, backend_request));
-
-  bool is_cancelled = true;
-  FAIL_TEST_IF_ERR(
-      TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
-  ASSERT_FALSE(is_cancelled);

   FAIL_TEST_IF_ERR(
       TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */));
   FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestCancel(irequest_));

+  TRITONBACKEND_ResponseFactory* response_factory;
+  FAIL_TEST_IF_ERR(
+      TRITONBACKEND_ResponseFactoryNew(&response_factory, backend_request));
+
-  is_cancelled = false;
+  bool is_cancelled = false;
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
   ASSERT_TRUE(is_cancelled);
@@ -228,13 +222,13 @@ TEST_F(RequestCancellationTest, Cancellation)
   FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestSetResponseCallback(
       irequest_, allocator_, nullptr /* response_allocator_userp */,
       InferResponseComplete, reinterpret_cast<void*>(p)));
-  FAIL_TEST_IF_ERR(
-      TRITONBACKEND_ResponseFactoryNew(&response_factory, backend_request));

   // Sending another request and the request should not be cancelled.
   FAIL_TEST_IF_ERR(TRITONSERVER_ServerInferAsync(
       server_, irequest_, nullptr
       /* trace */));
+  FAIL_TEST_IF_ERR(
+      TRITONBACKEND_ResponseFactoryNew(&response_factory, backend_request));

   is_cancelled = true;
   FAIL_TEST_IF_ERR(
@@ -260,6 +254,10 @@ TEST_F(RequestCancellationTest, CancellationAfterRelease)
       irequest_, allocator_, nullptr /* response_allocator_userp */,
       InferResponseComplete, reinterpret_cast<void*>(p)));

+  FAIL_TEST_IF_ERR(TRITONSERVER_ServerInferAsync(
+      server_, irequest_, nullptr
+      /* trace */));
+
   TRITONBACKEND_Request* backend_request =
       reinterpret_cast<TRITONBACKEND_Request*>(irequest_);
   TRITONBACKEND_ResponseFactory* response_factory;
@@ -286,6 +284,9 @@ TEST_F(RequestCancellationTest, CancellationAfterRelease)
       TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
   ASSERT_TRUE(is_cancelled);

+  TRITONSERVER_InferenceResponse* response = future.get();
+  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceResponseDelete(response));
+
   is_cancelled = false;
   FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryIsCancelled(
       response_factory, &is_cancelled));
4 changes: 2 additions & 2 deletions src/tritonserver.cc
@@ -1643,7 +1643,7 @@ TRITONSERVER_InferenceRequestIsCancelled(
 {
   tc::InferenceRequest* lrequest =
       reinterpret_cast<tc::InferenceRequest*>(inference_request);
-  *is_cancelled = lrequest->IsCancelled();
+  RETURN_IF_STATUS_ERROR(lrequest->IsCancelled(is_cancelled));
   return nullptr;  // Success
 }

@@ -1653,7 +1653,7 @@ TRITONSERVER_InferenceRequestCancel(
   tc::InferenceRequest* lrequest =
       reinterpret_cast<tc::InferenceRequest*>(inference_request);
-  lrequest->Cancel();
+  RETURN_IF_STATUS_ERROR(lrequest->Cancel());
   return nullptr;  // Success
 }

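At the C API level, the practical effect is that cancellation calls made before the request has been handed to TRITONSERVER_ServerInferAsync now fail with an explicit error rather than acting on a stale response factory. A short sketch of the caller-side handling; it assumes `irequest` was created elsewhere, and the quoted text paraphrases the error messages added above.

```cpp
#include <iostream>

#include "triton/core/tritonserver.h"

// Illustrative: attempt to cancel a request that has not been scheduled yet.
void
TryCancel(TRITONSERVER_InferenceRequest* irequest)
{
  TRITONSERVER_Error* err = TRITONSERVER_InferenceRequestCancel(irequest);
  if (err != nullptr) {
    // e.g. "It is not possible to cancel an inference request before
    // calling TRITONSERVER_InferAsync."
    std::cerr << TRITONSERVER_ErrorMessage(err) << std::endl;
    TRITONSERVER_ErrorDelete(err);
  }
}
```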