
Commit

Handle request re-use
Tabrizian committed Sep 6, 2023
1 parent 5f9ad31 commit 1541211
Showing 3 changed files with 153 additions and 44 deletions.
14 changes: 12 additions & 2 deletions src/infer_request.h
@@ -680,7 +680,17 @@ class InferenceRequest {
    secondary_stats_aggregator_ = secondary_stats_aggregator;
  }

  void Cancel() { response_factory_->Cancel(); }
  Status Cancel()
  {
    if (state_ != InferenceRequest::State::INITIALIZED) {
      response_factory_->Cancel();
    } else {
      return Status(
          Status::Code::INTERNAL,
          "Request cannot be cancelled before it has started executing.");
    }
    return Status::Success;
  }

  bool IsCancelled() { return response_factory_->IsCancelled(); }

@@ -799,7 +809,7 @@ class InferenceRequest {
  std::shared_ptr<SequenceStates> sequence_states_;

  // The state of the request.
  InferenceRequest::State state_;
  std::atomic<InferenceRequest::State> state_;
  // Whether this is a null request used for direct sequence batch padding or
  // not.
  bool null_request_;
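
For context, Cancel() now returns a Status instead of void, and state_ becomes std::atomic, presumably so the cancellation path can read the request state without taking the request lock. Callers therefore have to check the result rather than fire and forget; the updated test below exercises both outcomes. A minimal caller-side sketch (illustrative only; it assumes the core's Status::IsOk() and Message() helpers and the LOG_ERROR macro, none of which are touched by this commit):

// Hypothetical caller; `request` is an InferenceRequest* that may or may not
// have been handed to a scheduler yet.
Status status = request->Cancel();
if (!status.IsOk()) {
  // Per this change, cancelling a request that is still INITIALIZED (not yet
  // executing) is rejected rather than silently ignored.
  LOG_ERROR << status.Message();
}
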
181 changes: 140 additions & 41 deletions src/test/request_cancellation_test.cc
@@ -24,6 +24,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <future>
#include <thread>

#include "gtest/gtest.h"
@@ -40,6 +41,60 @@
} while (false)


TRITONSERVER_Error*
ResponseAlloc(
    TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name,
    size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type,
    int64_t preferred_memory_type_id, void* userp, void** buffer,
    void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type,
    int64_t* actual_memory_type_id)
{
  *actual_memory_type = TRITONSERVER_MEMORY_CPU;
  *actual_memory_type_id = preferred_memory_type_id;

  if (byte_size == 0) {
    *buffer = nullptr;
    *buffer_userp = nullptr;
  } else {
    void* allocated_ptr = nullptr;
    allocated_ptr = malloc(byte_size);

    if (allocated_ptr != nullptr) {
      *buffer = allocated_ptr;
      *buffer_userp = new std::string(tensor_name);
    }
  }
  return nullptr;  // Success
}

TRITONSERVER_Error*
ResponseRelease(
    TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp,
    size_t byte_size, TRITONSERVER_MemoryType memory_type,
    int64_t memory_type_id)
{
  return nullptr;  // Success
}

void
InferRequestComplete(
    TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
{
}

void
InferResponseComplete(
    TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
{
  if (response != nullptr) {
    // Notify completion of the response.
    std::promise<TRITONSERVER_InferenceResponse*>* p =
        reinterpret_cast<std::promise<TRITONSERVER_InferenceResponse*>*>(userp);
    p->set_value(response);
    delete p;
  }
}

class RequestCancellationTest : public ::testing::Test {
 protected:
  static void SetUpTestSuite()
@@ -51,6 +106,10 @@ class RequestCancellationTest : public ::testing::Test {
        server_options, "./models"));
    FAIL_TEST_IF_ERR(TRITONSERVER_ServerOptionsSetBackendDirectory(
        server_options, "/opt/tritonserver/backends"));
    FAIL_TEST_IF_ERR(TRITONSERVER_ServerOptionsSetRepoAgentDirectory(
        server_options, "/opt/tritonserver/repoagents"));
    FAIL_TEST_IF_ERR(
        TRITONSERVER_ServerOptionsSetStrictModelConfig(server_options, true));

    FAIL_TEST_IF_ERR(TRITONSERVER_ServerNew(&server_, server_options));
    FAIL_TEST_IF_ERR(TRITONSERVER_ServerOptionsDelete(server_options));
@@ -80,82 +139,123 @@ class RequestCancellationTest : public ::testing::Test {

      std::this_thread::sleep_for(std::chrono::milliseconds(500));
    }

    // Create allocator with common callback
    FAIL_TEST_IF_ERR(TRITONSERVER_ResponseAllocatorNew(
        &allocator_, ResponseAlloc, ResponseRelease, nullptr /* start_fn */));

    FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestNew(
        &irequest_, server_, "model", -1 /* model_version */));

    FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestSetReleaseCallback(
        irequest_, InferRequestComplete, nullptr /* request_release_userp */));

    std::vector<int64_t> input0_shape({1, 1000});
    FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestAddInput(
        irequest_, "INPUT0", TRITONSERVER_TYPE_INT32, &input0_shape[0],
        input0_shape.size()));
    FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestAppendInputData(
        irequest_, "INPUT0", &input0_data_[0], input0_data_.size(),
        TRITONSERVER_MEMORY_CPU, 0));
  }

  void TearDown() override
  {
    FAIL_TEST_IF_ERR(TRITONSERVER_ResponseAllocatorDelete(allocator_));
    FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestDelete(irequest_));
  }

  static TRITONSERVER_Server* server_;
  TRITONSERVER_ResponseAllocator* allocator_ = nullptr;
  static std::vector<int32_t> input0_data_;
  TRITONSERVER_InferenceRequest* irequest_ = nullptr;
};

TRITONSERVER_Server* RequestCancellationTest::server_ = nullptr;

void
ReleaseCallback(
    struct TRITONSERVER_InferenceRequest* request, const uint32_t flags,
    void* userp)
{
  // no-op
}
std::vector<int32_t> RequestCancellationTest::input0_data_(16, 1);

TEST_F(RequestCancellationTest, Cancellation)
{
  TRITONSERVER_InferenceRequest* request;
  FAIL_TEST_IF_ERR(
      TRITONSERVER_InferenceRequestNew(&request, server_, "model", 1));
  auto p = new std::promise<TRITONSERVER_InferenceResponse*>();
  std::future<TRITONSERVER_InferenceResponse*> future = p->get_future();

  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestSetResponseCallback(
      request, nullptr, nullptr, nullptr, nullptr));
      irequest_, allocator_, nullptr /* response_allocator_userp */,
      InferResponseComplete, reinterpret_cast<void*>(p)));

  TRITONBACKEND_Request* backend_request =
      reinterpret_cast<TRITONBACKEND_Request*>(request);
      reinterpret_cast<TRITONBACKEND_Request*>(irequest_);
  TRITONBACKEND_ResponseFactory* response_factory;
  FAIL_TEST_IF_ERR(
      TRITONBACKEND_ResponseFactoryNew(&response_factory, backend_request));

  bool is_cancelled = true;
  FAIL_TEST_IF_ERR(
      TRITONSERVER_InferenceRequestIsCancelled(request, &is_cancelled));
  ASSERT_FALSE(is_cancelled);

  is_cancelled = true;
  FAIL_TEST_IF_ERR(
      TRITONBACKEND_RequestIsCancelled(backend_request, &is_cancelled));
      TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
  ASSERT_FALSE(is_cancelled);

  is_cancelled = true;
  FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryIsCancelled(
      response_factory, &is_cancelled));
  ASSERT_FALSE(is_cancelled);

  is_cancelled = false;
  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestCancel(request));
  TRITONSERVER_Error* error = TRITONSERVER_InferenceRequestCancel(irequest_);
  ASSERT_TRUE(error != nullptr);
  ASSERT_TRUE(
      std::string(TRITONSERVER_ErrorMessage(error)) ==
      "Request cannot be cancelled before it has started executing.");

  FAIL_TEST_IF_ERR(
      TRITONBACKEND_RequestIsCancelled(backend_request, &is_cancelled));
  ASSERT_TRUE(is_cancelled);
      TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */));
  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestCancel(irequest_));

  is_cancelled = false;
  FAIL_TEST_IF_ERR(
      TRITONSERVER_InferenceRequestIsCancelled(request, &is_cancelled));
      TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
  ASSERT_TRUE(is_cancelled);

  is_cancelled = false;
  FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryIsCancelled(
      response_factory, &is_cancelled));
  ASSERT_TRUE(is_cancelled);

  TRITONSERVER_InferenceResponse* response = future.get();
  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceResponseDelete(response));
  FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryDelete(response_factory));
  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestDelete(request));

  // TODO: Enable after https://github.com/triton-inference-server/core/pull/251
  // is merged. Currently, it fails with "Invalid request state transition from
  // EXECUTING to STARTED".
  // p = new std::promise<TRITONSERVER_InferenceResponse*>();
  // future = p->get_future();

  // FAIL_TEST_IF_ERR(
  // TRITONSERVER_InferenceRequestSetResponseCallback(
  // irequest_, allocator_, nullptr /* response_allocator_userp */,
  // InferResponseComplete, reinterpret_cast<void*>(p)));
  // // Sending another request and the request should not be cancelled.
  // FAIL_TEST_IF_ERR(TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr
  // /* trace */));

  // is_cancelled = true;
  // FAIL_TEST_IF_ERR(
  // TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
  // ASSERT_FALSE(is_cancelled);

  // is_cancelled = false;
  // FAIL_TEST_IF_ERR(
  // TRITONBACKEND_ResponseFactoryIsCancelled(response_factory,
  // &is_cancelled));
  // ASSERT_FALSE(is_cancelled);
  // future.get();
}

TEST_F(RequestCancellationTest, CancellationAfterRelease)
{
  TRITONSERVER_InferenceRequest* request;
  FAIL_TEST_IF_ERR(
      TRITONSERVER_InferenceRequestNew(&request, server_, "model", 1));
  auto p = new std::promise<TRITONSERVER_InferenceResponse*>();
  std::future<TRITONSERVER_InferenceResponse*> future = p->get_future();

  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestSetResponseCallback(
      request, nullptr, nullptr, nullptr, nullptr));
  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestSetReleaseCallback(
      request, ReleaseCallback, nullptr));
      irequest_, allocator_, nullptr /* response_allocator_userp */,
      InferResponseComplete, reinterpret_cast<void*>(p)));

  TRITONBACKEND_Request* backend_request =
      reinterpret_cast<TRITONBACKEND_Request*>(request);
      reinterpret_cast<TRITONBACKEND_Request*>(irequest_);
  TRITONBACKEND_ResponseFactory* response_factory;
  FAIL_TEST_IF_ERR(
      TRITONBACKEND_ResponseFactoryNew(&response_factory, backend_request));
@@ -164,7 +264,7 @@ TEST_F(RequestCancellationTest, CancellationAfterRelease)

  bool is_cancelled = true;
  FAIL_TEST_IF_ERR(
      TRITONSERVER_InferenceRequestIsCancelled(request, &is_cancelled));
      TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
  ASSERT_FALSE(is_cancelled);

  is_cancelled = true;
@@ -173,11 +273,11 @@ TEST_F(RequestCancellationTest, CancellationAfterRelease)
  ASSERT_FALSE(is_cancelled);

  is_cancelled = false;
  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestCancel(request));
  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestCancel(irequest_));

  is_cancelled = false;
  FAIL_TEST_IF_ERR(
      TRITONSERVER_InferenceRequestIsCancelled(request, &is_cancelled));
      TRITONSERVER_InferenceRequestIsCancelled(irequest_, &is_cancelled));
  ASSERT_TRUE(is_cancelled);

  is_cancelled = false;
@@ -186,7 +286,6 @@ TEST_F(RequestCancellationTest, CancellationAfterRelease)
  ASSERT_TRUE(is_cancelled);

  FAIL_TEST_IF_ERR(TRITONBACKEND_ResponseFactoryDelete(response_factory));
  FAIL_TEST_IF_ERR(TRITONSERVER_InferenceRequestDelete(request));
}

int
2 changes: 1 addition & 1 deletion src/tritonserver.cc
@@ -1653,7 +1653,7 @@ TRITONSERVER_InferenceRequestCancel(
{
  tc::InferenceRequest* lrequest =
      reinterpret_cast<tc::InferenceRequest*>(inference_request);
  lrequest->Cancel();
  RETURN_IF_STATUS_ERROR(lrequest->Cancel());
  return nullptr;  // Success
}

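Since TRITONSERVER_InferenceRequestCancel now propagates the Status returned by Cancel(), a frontend that cancels a request before submitting it with TRITONSERVER_ServerInferAsync gets back a non-null error object, as exercised by the test above. A rough caller-side sketch (illustrative only; `irequest` is an assumed, already-constructed request handle, and <iostream> is assumed for std::cerr):

// Cancelling before TRITONSERVER_ServerInferAsync now reports an error
// instead of silently doing nothing.
TRITONSERVER_Error* err = TRITONSERVER_InferenceRequestCancel(irequest);
if (err != nullptr) {
  std::cerr << "cancel failed: " << TRITONSERVER_ErrorMessage(err) << std::endl;
  TRITONSERVER_ErrorDelete(err);
}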
