diff --git a/include/triton/core/tritonbackend.h b/include/triton/core/tritonbackend.h
index de0fffb8d..4b5808d60 100644
--- a/include/triton/core/tritonbackend.h
+++ b/include/triton/core/tritonbackend.h
@@ -94,7 +94,7 @@ struct TRITONBACKEND_Batcher;
 /// }
 ///
 #define TRITONBACKEND_API_VERSION_MAJOR 1
-#define TRITONBACKEND_API_VERSION_MINOR 15
+#define TRITONBACKEND_API_VERSION_MINOR 16
 
 /// Get the TRITONBACKEND API version supported by Triton. This value
 /// can be compared against the TRITONBACKEND_API_VERSION_MAJOR and
@@ -375,6 +375,31 @@ TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_OutputBufferAttributes(
 TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestId(
     TRITONBACKEND_Request* request, const char** id);
 
+/// Query whether the request is cancelled or not.
+///
+/// If possible the backend should terminate any processing and
+/// send an error response with cancelled status.
+///
+/// \param request The inference request.
+/// \param is_cancelled Returns true if the request is cancelled,
+/// false otherwise.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestIsCancelled(
+    TRITONBACKEND_Request* request, bool* is_cancelled);
+
+/// Query whether the response factory is cancelled or not.
+///
+/// If possible the backend should terminate any processing and
+/// send an error response with cancelled status.
+///
+/// \param factory The response factory.
+/// \param is_cancelled Returns true if the response factory is cancelled,
+/// false otherwise.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseFactoryIsCancelled(
+    TRITONBACKEND_ResponseFactory* factory, bool* is_cancelled);
+
 /// Get the correlation ID of the request if it is an unsigned integer.
 /// Zero indicates that the request does not have a correlation ID.
 /// Returns failure if correlation ID for given request is not an unsigned
diff --git a/include/triton/core/tritonserver.h b/include/triton/core/tritonserver.h
index c037242bb..a77c00980 100644
--- a/include/triton/core/tritonserver.h
+++ b/include/triton/core/tritonserver.h
@@ -91,7 +91,7 @@ struct TRITONSERVER_MetricFamily;
 /// }
 ///
 #define TRITONSERVER_API_VERSION_MAJOR 1
-#define TRITONSERVER_API_VERSION_MINOR 24
+#define TRITONSERVER_API_VERSION_MINOR 25
 
 /// Get the TRITONBACKEND API version supported by the Triton shared
 /// library. This value can be compared against the
@@ -308,7 +308,8 @@ typedef enum TRITONSERVER_errorcode_enum {
   TRITONSERVER_ERROR_INVALID_ARG,
   TRITONSERVER_ERROR_UNAVAILABLE,
   TRITONSERVER_ERROR_UNSUPPORTED,
-  TRITONSERVER_ERROR_ALREADY_EXISTS
+  TRITONSERVER_ERROR_ALREADY_EXISTS,
+  TRITONSERVER_ERROR_CANCELLED
 } TRITONSERVER_Error_Code;
 
 /// Create a new error object. The caller takes ownership of the
@@ -1091,6 +1092,34 @@ TRITONSERVER_InferenceRequestSetCorrelationIdString(
     struct TRITONSERVER_InferenceRequest* inference_request,
     const char* correlation_id);
 
+/// Cancel an inference request. Requests are cancelled on a
+/// best-effort basis and no guarantee is provided that cancelling a
+/// request will result in early termination. Note that the
+/// inference request cancellation status will be reset after
+/// TRITONSERVER_ServerInferAsync is run. This means that if you cancel
+/// the request before calling TRITONSERVER_ServerInferAsync,
+/// the request will not be cancelled.
+///
+/// \param inference_request The request object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestCancel(
+    struct TRITONSERVER_InferenceRequest* inference_request);
+
+/// Query whether the request is cancelled or not.
+///
+/// If possible the backend should terminate any processing and
+/// send an error response with cancelled status.
+///
+/// \param inference_request The request object.
+/// \param is_cancelled Returns whether the inference request is cancelled or
+/// not.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestIsCancelled(
+    struct TRITONSERVER_InferenceRequest* inference_request,
+    bool* is_cancelled);
+
 /// Deprecated. See TRITONSERVER_InferenceRequestPriorityUInt64 instead.
 ///
 /// Get the priority for a request. The default is 0 indicating that
diff --git a/src/backend_model.cc b/src/backend_model.cc
index 7a2fbfc7d..552e88a62 100644
--- a/src/backend_model.cc
+++ b/src/backend_model.cc
@@ -1017,6 +1017,16 @@ TRITONBACKEND_RequestFlags(TRITONBACKEND_Request* request, uint32_t* flags)
   return nullptr;  // success
 }
 
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_RequestIsCancelled(
+    TRITONBACKEND_Request* request, bool* is_cancelled)
+{
+  InferenceRequest* tr = reinterpret_cast<InferenceRequest*>(request);
+  *is_cancelled = tr->IsCancelled();
+  return nullptr;  // success
+}
+
+
 TRITONAPI_DECLSPEC TRITONSERVER_Error*
 TRITONBACKEND_RequestCorrelationIdString(
     TRITONBACKEND_Request* request, const char** id)
@@ -1365,6 +1375,17 @@ TRITONBACKEND_ResponseFactorySendFlags(
   return nullptr;  // success
 }
 
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseFactoryIsCancelled(
+    TRITONBACKEND_ResponseFactory* factory, bool* is_cancelled)
+{
+  std::shared_ptr<InferenceResponseFactory>* response_factory =
+      reinterpret_cast<std::shared_ptr<InferenceResponseFactory>*>(factory);
+  *is_cancelled = (*response_factory)->IsCancelled();
+  return nullptr;  // success
+}
+
+
 ///
 /// TRITONBACKEND_Response
 ///
diff --git a/src/infer_request.cc b/src/infer_request.cc
index 803335bca..d1a042346 100644
--- a/src/infer_request.cc
+++ b/src/infer_request.cc
@@ -106,75 +106,69 @@ InferenceRequest::InferenceRequest(
     : needs_normalization_(true), model_raw_(model),
       requested_model_version_(requested_model_version), flags_(0),
       correlation_id_(0), batch_size_(0), timeout_us_(0), collect_stats_(true),
-      state_(InferenceRequest::State::INITIALIZED), null_request_(false),
-      decrement_pending_count_(false)
+      state_(InferenceRequest::State::INITIALIZED), null_request_(false)
 {
   SetPriority(0);
 }
 
-InferenceRequest::~InferenceRequest()
-{
-  // If request has been enqueued but hasn't started executing by destruction
-  // time, an error occurred and the pending request count will need to be
-  // decremented.
-  DecrementPendingRequestCount();
-}
-
-
 Status
 InferenceRequest::SetState(InferenceRequest::State new_state)
 {
+  LOG_VERBOSE(1) << LogRequest() << "Setting state from " << state_ << " to "
+                 << new_state;
   // No-op if this is already the current state, or if this is a null request.
   if (new_state == state_ || null_request_) {
     return Status::Success;
   }
 
-  // Allow RELEASED state transition from any state for now.
-  // Not all requests will follow linear transition, such as null requests
-  // used for padding batches, and ensemble requests.
-  if (new_state == InferenceRequest::State::RELEASED) {
-    state_ = new_state;
-    return Status::Success;
-  }
-
   // Generate error when called rather than copying it into every case below.
   const auto generate_error = [&]() {
     std::stringstream ss;
     ss << LogRequest() << "Invalid request state transition from " << state_
        << " to " << new_state;
-    return Status(Status::Code::INVALID_ARG, ss.str());
+    return Status(Status::Code::INTERNAL, ss.str());
   };
 
   // Define state transitions
   switch (state_) {
     case InferenceRequest::State::INITIALIZED: {
-      if (new_state != InferenceRequest::State::STARTED) {
+      if (new_state == InferenceRequest::State::PENDING) {
+        IncrementPendingRequestCount();
+      } else if (new_state == InferenceRequest::State::RELEASED) {
+        // No-op when moving from initialized to released, just releasing early.
+      } else {
         return generate_error();
       }
-      state_ = new_state;
-      IncrementPendingRequestCount();
       break;
     }
-    case InferenceRequest::State::STARTED: {
-      if (new_state != InferenceRequest::State::EXECUTING) {
+    case InferenceRequest::State::PENDING: {
+      // A request may move from pending to executing when scheduled on a
+      // backend, or to released if it is released early due to an error.
+      if (new_state == InferenceRequest::State::EXECUTING ||
+          new_state == InferenceRequest::State::RELEASED) {
+        DecrementPendingRequestCount();
+      } else {
+        // Unexpected state transition
         return generate_error();
       }
-      state_ = new_state;
-      DecrementPendingRequestCount();
       break;
     }
     case InferenceRequest::State::EXECUTING: {
       if (new_state != InferenceRequest::State::RELEASED) {
         return generate_error();
       }
-      state_ = new_state;
       break;
     }
     case InferenceRequest::State::RELEASED: {
-      // No state transition currently supported after release.
-      return generate_error();
+      if (new_state != InferenceRequest::State::INITIALIZED) {
+        // The only transition currently supported after release is to start
+        // over again, such as re-using request objects for multiple inferences.
+        return generate_error();
+      }
+      break;
     }
   }
+  state_ = new_state;
   return Status::Success;
 }
 
@@ -182,10 +176,11 @@ void
 InferenceRequest::IncrementPendingRequestCount()
 {
 #ifdef TRITON_ENABLE_METRICS
+  // Pending request count should always be 0 or 1 per request. If a request
+  // increments the count, it should not be incremented again until decremented.
   auto reporter = model_raw_->MetricReporter();
   if (reporter) {
     reporter->IncrementGauge(kPendingRequestMetric, 1);
-    decrement_pending_count_ = true;
   }
 #endif  // TRITON_ENABLE_METRICS
 }
@@ -194,13 +189,11 @@ void
 InferenceRequest::DecrementPendingRequestCount()
 {
 #ifdef TRITON_ENABLE_METRICS
-  // Only decrement if count has been incremented, and not already decremented.
-  if (decrement_pending_count_) {
-    auto reporter = model_raw_->MetricReporter();
-    if (reporter) {
-      reporter->DecrementGauge(kPendingRequestMetric, 1);
-    }
-    decrement_pending_count_ = false;
+  // Pending request count should always be 0 or 1 per request. A request
+  // should not decrement the count unless it has already been incremented.
+  auto reporter = model_raw_->MetricReporter();
+  if (reporter) {
+    reporter->DecrementGauge(kPendingRequestMetric, 1);
   }
 #endif  // TRITON_ENABLE_METRICS
 }
@@ -376,7 +369,7 @@ InferenceRequest::OutputBufferProperties(
 
 Status
 InferenceRequest::Run(std::unique_ptr<InferenceRequest>& request)
 {
-  RETURN_IF_ERROR(request->SetState(InferenceRequest::State::STARTED));
+  RETURN_IF_ERROR(request->SetState(InferenceRequest::State::PENDING));
   return request->model_raw_->Enqueue(request);
 }
 
@@ -826,6 +819,7 @@ InferenceRequest::PrepareForInference()
   // inference execution.
   inputs_.clear();
   override_inputs_.clear();
+  ResetCancel();
 
   // Renormalize if anything has changed in the inference request in a
   // way that could impact renormalization.
@@ -849,8 +843,10 @@ InferenceRequest::PrepareForInference()
   request_start_ns_ = 0;
 #endif  // TRITON_ENABLE_STATS
 
-  LOG_VERBOSE(1) << LogRequest() << "prepared: " << *this;
+  // Help enforce that PrepareForInference() is called prior to Run().
+  RETURN_IF_ERROR(SetState(InferenceRequest::State::INITIALIZED));
 
+  LOG_VERBOSE(1) << LogRequest() << "prepared: " << *this;
   return Status::Success;
 }
 
@@ -1580,8 +1576,8 @@ operator<<(std::ostream& out, const InferenceRequest::State& state)
       out << "INITIALIZED";
       break;
     }
-    case InferenceRequest::State::STARTED: {
-      out << "STARTED";
+    case InferenceRequest::State::PENDING: {
+      out << "PENDING";
       break;
     }
     case InferenceRequest::State::EXECUTING: {
diff --git a/src/infer_request.h b/src/infer_request.h
index 97c56ba27..56504abb9 100644
--- a/src/infer_request.h
+++ b/src/infer_request.h
@@ -63,7 +63,7 @@ class InferenceRequest {
     INITIALIZED,
 
     // The request has been enqueued, but is not yet executing.
-    STARTED,
+    PENDING,
 
     // The request has been picked up by a backend model instance for execution,
    // but hasn't been released yet.
@@ -291,7 +291,6 @@ class InferenceRequest {
       const int64_t requested_model_version);
 
   InferenceRequest(Model* model, const int64_t requested_model_version);
-  ~InferenceRequest();
 
   const std::string& ModelName() const;
   int64_t RequestedModelVersion() const { return requested_model_version_; }
@@ -680,6 +679,11 @@ class InferenceRequest {
     secondary_stats_aggregator_ = secondary_stats_aggregator;
   }
 
+  void Cancel() { response_factory_->Cancel(); }
+  void ResetCancel() { response_factory_->ResetCancel(); }
+
+  bool IsCancelled() { return response_factory_->IsCancelled(); }
+
 #endif  // TRITON_ENABLE_STATS
 
  private:
@@ -795,13 +799,10 @@ class InferenceRequest {
   std::shared_ptr<SequenceStates> sequence_states_;
 
   // The state of the request.
-  InferenceRequest::State state_;
+  std::atomic<InferenceRequest::State> state_;
 
   // Whether this is a null request used for direct sequence batch padding or
   // not.
   bool null_request_;
-  // Catch-all to correctly decrement pending count if needed on destruction
-  // if request doesn't follow normal execution path (error, unused, ensembles)
-  bool decrement_pending_count_;
 };
 
 std::ostream& operator<<(std::ostream& out, const InferenceRequest& request);
diff --git a/src/infer_response.h b/src/infer_response.h
index 5632c0f84..2beb9a667 100644
--- a/src/infer_response.h
+++ b/src/infer_response.h
@@ -59,10 +59,16 @@ class InferenceResponseFactory {
           std::unique_ptr<InferenceResponse>&&, const uint32_t)>& delegator)
       : model_(model), id_(id), allocator_(allocator),
         alloc_userp_(alloc_userp), response_fn_(response_fn),
-        response_userp_(response_userp), response_delegator_(delegator)
+        response_userp_(response_userp), response_delegator_(delegator),
+        is_cancelled_(false)
   {
   }
 
+  void Cancel() { is_cancelled_ = true; }
+  void ResetCancel() { is_cancelled_ = false; }
+
+  bool IsCancelled() { return is_cancelled_; }
+
   const ResponseAllocator* Allocator() { return allocator_; }
   void* AllocatorUserp() { return alloc_userp_; }
 
@@ -118,6 +124,7 @@ class InferenceResponseFactory {
   std::function<void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>
       response_delegator_;
+  std::atomic<bool> is_cancelled_;
 
 #ifdef TRITON_ENABLE_TRACING
   // Inference trace associated with this response.
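
Note (reviewer sketch, not part of the patch): the new TRITONBACKEND_RequestIsCancelled and TRITONBACKEND_ResponseFactoryIsCancelled entry points are intended to be polled by a backend while it is producing responses. Below is a minimal sketch of how a backend's execute loop might check them; the helper name ShouldStopExecution and the choice to keep executing when the query itself fails are assumptions for illustration, not code from this change.

#include "triton/core/tritonbackend.h"

// Illustrative helper (hypothetical): returns true if the request or its
// response factory has been cancelled and the backend should stop early.
static bool
ShouldStopExecution(
    TRITONBACKEND_Request* request, TRITONBACKEND_ResponseFactory* factory)
{
  bool request_cancelled = false;
  TRITONSERVER_Error* err =
      TRITONBACKEND_RequestIsCancelled(request, &request_cancelled);
  if (err != nullptr) {
    TRITONSERVER_ErrorDelete(err);
    return false;  // if the query itself fails, keep executing
  }

  bool factory_cancelled = false;
  err = TRITONBACKEND_ResponseFactoryIsCancelled(factory, &factory_cancelled);
  if (err != nullptr) {
    TRITONSERVER_ErrorDelete(err);
    return false;
  }

  return request_cancelled || factory_cancelled;
}

A backend that observes cancellation would typically stop producing output and send its final response with an error using the new TRITONSERVER_ERROR_CANCELLED code, which the src/status.cc change below maps to Status::Code::CANCELLED.
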
diff --git a/src/status.cc b/src/status.cc index 1640ee5ed..1344eeefc 100644 --- a/src/status.cc +++ b/src/status.cc @@ -48,7 +48,8 @@ TritonCodeToStatusCode(TRITONSERVER_Error_Code code) return Status::Code::UNSUPPORTED; case TRITONSERVER_ERROR_ALREADY_EXISTS: return Status::Code::ALREADY_EXISTS; - + case TRITONSERVER_ERROR_CANCELLED: + return Status::Code::CANCELLED; default: break; } @@ -74,7 +75,8 @@ StatusCodeToTritonCode(Status::Code status_code) return TRITONSERVER_ERROR_UNSUPPORTED; case Status::Code::ALREADY_EXISTS: return TRITONSERVER_ERROR_ALREADY_EXISTS; - + case Status::Code::CANCELLED: + return TRITONSERVER_ERROR_CANCELLED; default: break; } diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 7fe27f5c0..1b5dbc3ed 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -534,3 +534,43 @@ install( TARGETS register_api_test RUNTIME DESTINATION bin ) + +# +# Request Cancellation Unittest +# +add_executable( + request_cancellation_test + request_cancellation_test.cc +) + +set_target_properties( + request_cancellation_test + PROPERTIES + SKIP_BUILD_RPATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH_USE_LINK_PATH FALSE + INSTALL_RPATH "" +) + +target_include_directories( + request_cancellation_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/.. + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${GTEST_INCLUDE_DIRS} +) + +target_link_libraries( + request_cancellation_test + PRIVATE + triton-common-error # from repo-common + triton-common-logging # from repo-common + triton-core + GTest::gtest + GTest::gtest_main +) + +install( + TARGETS request_cancellation_test + RUNTIME DESTINATION bin +) diff --git a/src/test/request_cancellation_test.cc b/src/test/request_cancellation_test.cc new file mode 100644 index 000000000..8e2652d30 --- /dev/null +++ b/src/test/request_cancellation_test.cc @@ -0,0 +1,306 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+#include <future>
+#include <string>
+#include <thread>
+
+#include "gtest/gtest.h"
+#include "triton/core/tritonbackend.h"
+#include "triton/core/tritonserver.h"
+
+
+#define FAIL_TEST_IF_ERR(X)                                                   \
+  do {                                                                        \
+    std::shared_ptr<TRITONSERVER_Error> err__((X), TRITONSERVER_ErrorDelete); \
+    ASSERT_TRUE((err__ == nullptr))                                           \
+        << TRITONSERVER_ErrorCodeString(err__.get()) << " - "                 \
+        << TRITONSERVER_ErrorMessage(err__.get());                            \
+  } while (false)
+
+
+TRITONSERVER_Error*
+ResponseAlloc(
+    TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name,
+    size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type,
+    int64_t preferred_memory_type_id, void* userp, void** buffer,
+    void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type,
+    int64_t* actual_memory_type_id)
+{
+  *actual_memory_type = TRITONSERVER_MEMORY_CPU;
+  *actual_memory_type_id = preferred_memory_type_id;
+
+  if (byte_size == 0) {
+    *buffer = nullptr;
+    *buffer_userp = nullptr;
+  } else {
+    void* allocated_ptr = nullptr;
+    allocated_ptr = malloc(byte_size);
+
+    if (allocated_ptr != nullptr) {
+      *buffer = allocated_ptr;
+      *buffer_userp = new std::string(tensor_name);
+    }
+  }
+  return nullptr;  // Success
+}
+
+TRITONSERVER_Error*
+ResponseRelease(
+    TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp,
+    size_t byte_size, TRITONSERVER_MemoryType memory_type,
+    int64_t memory_type_id)
+{
+  return nullptr;  // Success
+}
+
+void
+InferRequestComplete(
+    TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
+{
+}
+
+void
+InferResponseComplete(
+    TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
+{
+  if (response != nullptr) {
+    // Notify that the response is complete.
+    std::promise<TRITONSERVER_InferenceResponse*>* p =
+        reinterpret_cast<std::promise<TRITONSERVER_InferenceResponse*>*>(userp);
+    p->set_value(response);
+    delete p;
+  }
+}
+
+class RequestCancellationTest : public ::testing::Test {
+ protected:
+  static void SetUpTestSuite()
+  {
+    // Create the server...
+    TRITONSERVER_ServerOptions* server_options = nullptr;
+    FAIL_TEST_IF_ERR(TRITONSERVER_ServerOptionsNew(&server_options));
+    FAIL_TEST_IF_ERR(TRITONSERVER_ServerOptionsSetModelRepositoryPath(
+        server_options, "./models"));
+    FAIL_TEST_IF_ERR(TRITONSERVER_ServerOptionsSetBackendDirectory(
+        server_options, "/opt/tritonserver/backends"));
+    FAIL_TEST_IF_ERR(
+        TRITONSERVER_ServerOptionsSetLogVerbose(server_options, 1));
+    FAIL_TEST_IF_ERR(TRITONSERVER_ServerOptionsSetRepoAgentDirectory(
+        server_options, "/opt/tritonserver/repoagents"));
+    FAIL_TEST_IF_ERR(
+        TRITONSERVER_ServerOptionsSetStrictModelConfig(server_options, true));
+
+    FAIL_TEST_IF_ERR(TRITONSERVER_ServerNew(&server_, server_options));
+    FAIL_TEST_IF_ERR(TRITONSERVER_ServerOptionsDelete(server_options));
+  }
+
+  static void TearDownTestSuite()
+  {
+    FAIL_TEST_IF_ERR(TRITONSERVER_ServerDelete(server_));
+  }
+
+  void SetUp() override
+  {
+    ASSERT_TRUE(server_ != nullptr) << "Server has not been created";
+    // Wait until the server is both live and ready.
+    size_t health_iters = 0;
+    while (true) {
+      bool live, ready;
+      FAIL_TEST_IF_ERR(TRITONSERVER_ServerIsLive(server_, &live));
+      FAIL_TEST_IF_ERR(TRITONSERVER_ServerIsReady(server_, &ready));
+      if (live && ready) {
+        break;
+      }
+
+      if (++health_iters >= 10) {
+        FAIL() << "failed to find healthy inference server";
+      }
+
+      std::this_thread::sleep_for(std::chrono::milliseconds(500));
+    }
+
+    // Create allocator with common callback
+    FAIL_TEST_IF_ERR(TRITONSERVER_ResponseAllocatorNew(
+        &allocator_, ResponseAlloc, ResponseRelease, nullptr /* start_fn */));
+
+    FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestNew(
+        &irequest_, server_, "model", -1 /* model_version */));
+
+    FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestSetReleaseCallback(
+        irequest_, InferRequestComplete, nullptr /* request_release_userp */));
+
+    std::vector<int64_t> input0_shape({1, 1000});
+    FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestAddInput(
+        irequest_, "INPUT0", TRITONSERVER_TYPE_INT32, &input0_shape[0],
+        input0_shape.size()));
+    FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestAppendInputData(
+        irequest_, "INPUT0", &input0_data_[0], input0_data_.size(),
+        TRITONSERVER_MEMORY_CPU, 0));
+  }
+
+  void TearDown() override
+  {
+    FAIL_TEST_IF_ERR(TRITONSERVER_ResponseAllocatorDelete(allocator_));
+    FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestDelete(irequest_));
+  }
+
+  static TRITONSERVER_Server* server_;
+  TRITONSERVER_ResponseAllocator* allocator_ = nullptr;
+  static std::vector<int32_t> input0_data_;
+  TRITONSERVER_InferenceRequest* irequest_ = nullptr;
+};
+
+TRITONSERVER_Server* RequestCancellationTest::server_ = nullptr;
+std::vector<int32_t> RequestCancellationTest::input0_data_(16, 1);
+
+TEST_F(RequestCancellationTest, Cancellation)
+{
+  auto p = new std::promise<TRITONSERVER_InferenceResponse*>();
+  std::future<TRITONSERVER_InferenceResponse*> future = p->get_future();
+
+  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestSetResponseCallback(
+      irequest_, allocator_, nullptr /* response_allocator_userp */,
+      InferResponseComplete, reinterpret_cast<void*>(p)));
+
+  TRITONBACKEND_Request* backend_request =
+      reinterpret_cast<TRITONBACKEND_Request*>(irequest_);
+  TRITONBACKEND_ResponseFactory* response_factory;
+  FAIL_TEST_IF_ERR(
+      TRITONBACKEND_ResponseFactoryNew(&response_factory, backend_request));
+
+  bool is_cancelled = true;
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
+  ASSERT_FALSE(is_cancelled);
+
+  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestCancel(irequest_));
+
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */));
+  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestCancel(irequest_));
+
+  is_cancelled = false;
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
+  ASSERT_TRUE(is_cancelled);
+
+  is_cancelled = false;
+  FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryIsCancelled(
+      response_factory, &is_cancelled));
+  ASSERT_TRUE(is_cancelled);
+
+  TRITONSERVER_InferenceResponse* response = future.get();
+  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceResponseDelete(response));
+  FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryDelete(response_factory));
+
+  // FIXME: Looks like there is an issue with internal request state management.
+  // If the backend sends responses before releasing the request, the state may
+  // not be set to "RELEASED", which is required for transitioning to "INITIALIZED".
+  std::this_thread::sleep_for(std::chrono::seconds(2));
+
+  p = new std::promise<TRITONSERVER_InferenceResponse*>();
+  future = p->get_future();
+
+  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestSetResponseCallback(
+      irequest_, allocator_, nullptr /* response_allocator_userp */,
+      InferResponseComplete, reinterpret_cast<void*>(p)));
+  FAIL_TEST_IF_ERR(
+      TRITONBACKEND_ResponseFactoryNew(&response_factory, backend_request));
+
+  // Send the request again; this time it should not be cancelled.
+  FAIL_TEST_IF_ERR(TRITONSERVER_ServerInferAsync(
+      server_, irequest_, nullptr
+      /* trace */));
+
+  is_cancelled = true;
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
+  ASSERT_FALSE(is_cancelled);
+
+  is_cancelled = true;
+  FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryIsCancelled(
+      response_factory, &is_cancelled));
+  ASSERT_FALSE(is_cancelled);
+
+  response = future.get();
+  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceResponseDelete(response));
+  FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryDelete(response_factory));
+}
+
+TEST_F(RequestCancellationTest, CancellationAfterRelease)
+{
+  auto p = new std::promise<TRITONSERVER_InferenceResponse*>();
+  std::future<TRITONSERVER_InferenceResponse*> future = p->get_future();
+
+  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestSetResponseCallback(
+      irequest_, allocator_, nullptr /* response_allocator_userp */,
+      InferResponseComplete, reinterpret_cast<void*>(p)));
+
+  TRITONBACKEND_Request* backend_request =
+      reinterpret_cast<TRITONBACKEND_Request*>(irequest_);
+  TRITONBACKEND_ResponseFactory* response_factory;
+  FAIL_TEST_IF_ERR(
+      TRITONBACKEND_ResponseFactoryNew(&response_factory, backend_request));
+  FAIL_TEST_IF_ERR(TRITONBACKEND_RequestRelease(
+      backend_request, TRITONSERVER_REQUEST_RELEASE_ALL));
+
+  bool is_cancelled = true;
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
+  ASSERT_FALSE(is_cancelled);
+
+  is_cancelled = true;
+  FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryIsCancelled(
+      response_factory, &is_cancelled));
+  ASSERT_FALSE(is_cancelled);
+
+  is_cancelled = false;
+  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestCancel(irequest_));
+
+  is_cancelled = false;
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
+  ASSERT_TRUE(is_cancelled);
+
+  is_cancelled = false;
+  FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryIsCancelled(
+      response_factory, &is_cancelled));
+  ASSERT_TRUE(is_cancelled);
+
+  FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryDelete(response_factory));
+}
+
+int
+main(int argc, char** argv)
+{
+#ifdef TRITON_ENABLE_LOGGING
+  LOG_SET_VERBOSE(2);
+#endif  // TRITON_ENABLE_LOGGING
+
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/src/test/response_cache_test.cc b/src/test/response_cache_test.cc
index 2b40a04c5..dad7d0faf 100644
--- a/src/test/response_cache_test.cc
+++ b/src/test/response_cache_test.cc
@@ -66,8 +66,6 @@ InferenceRequest::InferenceRequest(
   response_factory_.reset(new InferenceResponseFactory());
 }
 
-InferenceRequest::~InferenceRequest() {}
-
 InferenceRequest::Input::Input(
     const std::string& name, const inference::DataType datatype,
     const int64_t* shape, const uint64_t dim_count)
diff --git a/src/tritonserver.cc b/src/tritonserver.cc
index c9fc49fc4..998c7a90d 100644
--- a/src/tritonserver.cc
+++ b/src/tritonserver.cc
@@ -1637,6 +1637,27 @@ TRITONSERVER_InferenceRequestFlags(
   return nullptr;  // Success
 }
 
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestIsCancelled(
+    struct TRITONSERVER_InferenceRequest* inference_request, bool* is_cancelled)
+{
+  tc::InferenceRequest* lrequest =
+      reinterpret_cast<tc::InferenceRequest*>(inference_request);
+  *is_cancelled = lrequest->IsCancelled();
+  return nullptr;  // Success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestCancel(
+    struct TRITONSERVER_InferenceRequest* inference_request)
+{
+  tc::InferenceRequest* lrequest =
+      reinterpret_cast<tc::InferenceRequest*>(inference_request);
+  lrequest->Cancel();
+  return nullptr;  // Success
+}
+
+
 TRITONAPI_DECLSPEC TRITONSERVER_Error*
 TRITONSERVER_InferenceRequestSetFlags(
     TRITONSERVER_InferenceRequest* inference_request, uint32_t flags)
diff --git a/src/tritonserver_stub.cc b/src/tritonserver_stub.cc
index b0081a0a2..141b38fa6 100644
--- a/src/tritonserver_stub.cc
+++ b/src/tritonserver_stub.cc
@@ -118,6 +118,14 @@ TRITONSERVER_ResponseAllocatorDelete()
 {
 }
 TRITONAPI_DECLSPEC void
+TRITONSERVER_InferenceRequestIsCancelled()
+{
+}
+TRITONAPI_DECLSPEC void
+TRITONSERVER_InferenceRequestCancel()
+{
+}
+TRITONAPI_DECLSPEC void
 TRITONSERVER_MessageNewFromSerializedJson()
 {
 }
@@ -651,6 +659,14 @@ TRITONBACKEND_RequestId()
 {
 }
 TRITONAPI_DECLSPEC void
+TRITONBACKEND_ResponseFactoryIsCancelled()
+{
+}
+TRITONAPI_DECLSPEC void
+TRITONBACKEND_RequestIsCancelled()
+{
+}
+TRITONAPI_DECLSPEC void
 TRITONBACKEND_RequestCorrelationId()
 {
 }
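
Note (reviewer sketch, not part of the patch): on the in-process C API side, cancellation is requested with TRITONSERVER_InferenceRequestCancel, as exercised by the unit test above. A minimal caller-side sketch follows; the function name OnClientDisconnect is hypothetical and the error handling is an assumption, since cancellation is best effort and the backend may still run to completion.

#include "triton/core/tritonserver.h"

// Illustrative (hypothetical): cancel an in-flight request, e.g. when the
// client that issued it disconnects. 'irequest' is the same handle that was
// previously passed to TRITONSERVER_ServerInferAsync.
static void
OnClientDisconnect(TRITONSERVER_InferenceRequest* irequest)
{
  TRITONSERVER_Error* err = TRITONSERVER_InferenceRequestCancel(irequest);
  if (err != nullptr) {
    // The cancellation call itself failed; ignore or log, the request will
    // simply run to completion.
    TRITONSERVER_ErrorDelete(err);
  }
}

Cooperative backends can then observe the flag through TRITONBACKEND_RequestIsCancelled or TRITONBACKEND_ResponseFactoryIsCancelled and terminate early with a TRITONSERVER_ERROR_CANCELLED response.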