From 338c2abdcf754f6574315393f40659cfbad99198 Mon Sep 17 00:00:00 2001
From: Sai Kiran Polisetty
Date: Mon, 26 Aug 2024 14:08:08 +0000
Subject: [PATCH] Add shm reference support

---
 include/triton/core/tritonserver.h | 26 ++++++++++++++++++++++++++
 src/infer_request.h                | 20 ++++++++++++++++++++
 src/tritonserver.cc                | 34 ++++++++++++++++++++++++++++++++++
 src/tritonserver_stub.cc           |  8 ++++++++
 4 files changed, 88 insertions(+)

diff --git a/include/triton/core/tritonserver.h b/include/triton/core/tritonserver.h
index d9701e890..dc40d1ac3 100644
--- a/include/triton/core/tritonserver.h
+++ b/include/triton/core/tritonserver.h
@@ -31,6 +31,9 @@
 #include <stddef.h>
 #include <stdint.h>
 
+#include <set>
+#include <string>
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -1484,6 +1487,29 @@ TRITONSERVER_InferenceRequestSetDoubleParameter(
     struct TRITONSERVER_InferenceRequest* request, const char* key,
     const double value);
 
+/// Add shm region name to the request.
+///
+/// \param request The request.
+/// \param name The name of the shm region.
+/// \param is_added Returns true if the name was newly added, false if the
+/// request already references a region with that name.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestAddRefShmRegion(
+    struct TRITONSERVER_InferenceRequest* request, const char* name,
+    bool* is_added);
+
+/// Get shm region names referred by request.
+///
+/// \param request The request.
+/// \param ref_shm_regions Returns a pointer to the set of shm region names.
+/// The set is owned by the request; the caller must not free it.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestGetRefShmRegions(
+    struct TRITONSERVER_InferenceRequest* request,
+    const std::set<std::string>** ref_shm_regions);
+
 /// TRITONSERVER_InferenceResponse
 ///
 /// Object representing an inference response. The inference response
diff --git a/src/infer_request.h b/src/infer_request.h
index c180d438b..dd249044a 100644
--- a/src/infer_request.h
+++ b/src/infer_request.h
@@ -760,6 +760,23 @@ class InferenceRequest {
     return is_cancelled;
   }
 
+  // Add the shm_region name to the ref_shm_regions_
+  // If the name is added successfully, set is_added to true
+  // if the name already exists, set is_added to false
+  Status AddRefShmRegion(const std::string& region_name, bool* is_added)
+  {
+    auto it = ref_shm_regions_.insert(region_name);
+    *is_added = it.second;
+    return Status::Success;
+  }
+
+  // Return the set of shm region names referenced by this request. The
+  // returned reference is to a member of this request.
+  const std::set<std::string>& GetRefShmRegions() const
+  {
+    return ref_shm_regions_;
+  }
+
  private:
   DISALLOW_COPY_AND_ASSIGN(InferenceRequest);
   friend std::ostream& operator<<(
@@ -885,6 +902,9 @@ class InferenceRequest {
   // not.
   bool null_request_;
 
+  // Set of shared memory region names used by InferenceRequest
+  std::set<std::string> ref_shm_regions_;
+
   // Response factory arguments
   const ResponseAllocator* response_allocator_;
   void* response_userp_;
diff --git a/src/tritonserver.cc b/src/tritonserver.cc
index 67343a730..3fe02924f 100644
--- a/src/tritonserver.cc
+++ b/src/tritonserver.cc
@@ -2133,6 +2133,40 @@ TRITONSERVER_InferenceRequestSetDoubleParameter(
   return nullptr;  // success
 }
 
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestAddRefShmRegion(
+    TRITONSERVER_InferenceRequest* request, const char* name, bool* is_added)
+{
+  // Reject null arguments up front: constructing the std::string region name
+  // from a null 'name' or writing through a null 'is_added' is undefined.
+  if (request == nullptr || name == nullptr || is_added == nullptr) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INTERNAL, "Received nullptr");
+  }
+
+  tc::InferenceRequest* tr = reinterpret_cast<tc::InferenceRequest*>(request);
+  RETURN_IF_STATUS_ERROR(tr->AddRefShmRegion(name, is_added));
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestGetRefShmRegions(
+    TRITONSERVER_InferenceRequest* request,
+    const std::set<std::string>** ref_shm_regions)
+{
+  if (ref_shm_regions == nullptr || request == nullptr) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INTERNAL, "Received nullptr");
+  }
+
+  tc::InferenceRequest* tr = reinterpret_cast<tc::InferenceRequest*>(request);
+  // Hand back the address of the set stored inside the request.
+  const std::set<std::string>& regions = tr->GetRefShmRegions();
+  *ref_shm_regions = &regions;
+
+  return nullptr;  // Success
+}
+
 //
 // TRITONSERVER_InferenceResponse
 //
diff --git a/src/tritonserver_stub.cc b/src/tritonserver_stub.cc
index 0f7b18e71..7d4cf47f0 100644
--- a/src/tritonserver_stub.cc
+++ b/src/tritonserver_stub.cc
@@ -755,6 +755,14 @@ TRITONSERVER_InferenceRequestSetDoubleParameter()
 {
 }
 TRITONAPI_DECLSPEC void
+TRITONSERVER_InferenceRequestAddRefShmRegion()
+{
+}
+TRITONAPI_DECLSPEC void
+TRITONSERVER_InferenceRequestGetRefShmRegions()
+{
+}
+TRITONAPI_DECLSPEC void
 TRITONSERVER_InferenceRequestSetIntParameter()
 {
 }