Skip to content

Commit

Permalink
Add shm reference support
Browse files Browse the repository at this point in the history
  • Loading branch information
pskiran1 committed Aug 26, 2024
1 parent 13b6046 commit 338c2ab
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 0 deletions.
24 changes: 24 additions & 0 deletions include/triton/core/tritonserver.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
#include <stddef.h>
#include <stdint.h>

#include <set>
#include <string>

#ifdef __cplusplus
extern "C" {
#endif
Expand Down Expand Up @@ -1484,6 +1487,27 @@ TRITONSERVER_InferenceRequestSetDoubleParameter(
struct TRITONSERVER_InferenceRequest* request, const char* key,
const double value);

/// Add shm region name to the request.
///
/// \param request The request.
/// \param region_name The name of the shm region.
/// \param is_added Returns true if region_name added, false otherwise.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
TRITONSERVER_InferenceRequestAddRefShmRegion(
struct TRITONSERVER_InferenceRequest* request, const char* name,
bool* is_added);

/// Get shm region names referred by request.
///
/// \param request The request.
/// \param ref_shm_regions Returns set of shm region names.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
TRITONSERVER_InferenceRequestGetRefShmRegions(
TRITONSERVER_InferenceRequest* request,
const std::set<std::string>** input_ref_shm_regions);

/// TRITONSERVER_InferenceResponse
///
/// Object representing an inference response. The inference response
Expand Down
18 changes: 18 additions & 0 deletions src/infer_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,21 @@ class InferenceRequest {
return is_cancelled;
}

// Add the shm_region name to the ref_shm_regions_
// If the name is added successfully, set is_added to true
// if the name already exists, set is_added to false
Status AddRefShmRegion(const std::string& region_name, bool* is_added)
{
auto it = ref_shm_regions_.insert(region_name);
*is_added = it.second;
return Status::Success;
}

const std::set<std::string>& GetRefShmRegions() const
{
return ref_shm_regions_;
}

private:
DISALLOW_COPY_AND_ASSIGN(InferenceRequest);
friend std::ostream& operator<<(
Expand Down Expand Up @@ -885,6 +900,9 @@ class InferenceRequest {
// not.
bool null_request_;

// Set of shared memory region names used by InferenceRequest
std::set<std::string> ref_shm_regions_;

// Response factory arguments
const ResponseAllocator* response_allocator_;
void* response_userp_;
Expand Down
26 changes: 26 additions & 0 deletions src/tritonserver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2133,6 +2133,32 @@ TRITONSERVER_InferenceRequestSetDoubleParameter(
return nullptr; // success
}

TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONSERVER_InferenceRequestAddRefShmRegion(
TRITONSERVER_InferenceRequest* request, const char* name, bool* is_added)
{
tc::InferenceRequest* tr = reinterpret_cast<tc::InferenceRequest*>(request);
RETURN_IF_STATUS_ERROR(tr->AddRefShmRegion(name, is_added));
return nullptr; // success
}

TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONSERVER_InferenceRequestGetRefShmRegions(
TRITONSERVER_InferenceRequest* request,
const std::set<std::string>** ref_shm_regions)
{
if (ref_shm_regions == nullptr || request == nullptr) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, "Received nullptr");
}

tc::InferenceRequest* tr = reinterpret_cast<tc::InferenceRequest*>(request);
const std::set<std::string>& regions = tr->GetRefShmRegions();
*ref_shm_regions = &regions;

return nullptr; // Success
}

//
// TRITONSERVER_InferenceResponse
//
Expand Down
8 changes: 8 additions & 0 deletions src/tritonserver_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,14 @@ TRITONSERVER_InferenceRequestSetDoubleParameter()
{
}
TRITONAPI_DECLSPEC void
TRITONSERVER_InferenceRequestAddRefShmRegion()
{
}
TRITONAPI_DECLSPEC void
TRITONSERVER_InferenceRequestGetRefShmRegions()
{
}
TRITONAPI_DECLSPEC void
TRITONSERVER_InferenceRequestSetIntParameter()
{
}
Expand Down

0 comments on commit 338c2ab

Please sign in to comment.