Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix various tests caused by cases byte size validation not handled properly #364

Merged
merged 10 commits into from
Jun 7, 2024
157 changes: 103 additions & 54 deletions src/infer_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1176,27 +1176,44 @@ InferenceRequest::Normalize()
// Note: Since we're using normalized input.ShapeWithBatchDim() here,
// make sure that all the normalization is before the check.
{
const size_t& byte_size = input.Data()->TotalByteSize();
const auto& data_type = input.DType();
const auto& input_dims = input.ShapeWithBatchDim();
int64_t expected_byte_size = INT_MAX;
// Because Triton expects STRING type to be in special format
// (prepend 4 bytes to specify string length), so need to add all the
// first 4 bytes for each element to find expected byte size
if (data_type == inference::DataType::TYPE_STRING) {
RETURN_IF_ERROR(
ValidateBytesInputs(input_id, input, &expected_byte_size));
} else {
expected_byte_size = triton::common::GetByteSize(data_type, input_dims);
}
if ((byte_size > INT_MAX) ||
(static_cast<int64_t>(byte_size) != expected_byte_size)) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input byte size mismatch for input '" + input_id +
"' for model '" + ModelName() + "'. Expected " +
std::to_string(expected_byte_size) + ", got " +
std::to_string(byte_size));

// FIXME: Skip byte size validation for TensorRT backend because it breaks
// shape-size assumption. See DLIS-6805 for proper fix for TRT backend
// reformat_free tensors.
bool skip_byte_size_check = false;
constexpr char trt_prefix[] = "tensorrt_";
const std::string& platform = model_raw_->Config().platform();
skip_byte_size_check |= (platform.rfind(trt_prefix) == 0);
rmccorm4 marked this conversation as resolved.
Show resolved Hide resolved

if (!skip_byte_size_check) {
TRITONSERVER_MemoryType input_memory_type;
// Because Triton expects STRING type to be in special format
// (prepend 4 bytes to specify string length), so need to add all the
// first 4 bytes for each element to find expected byte size
if (data_type == inference::DataType::TYPE_STRING) {
RETURN_IF_ERROR(
ValidateBytesInputs(input_id, input, &input_memory_type));
// FIXME: Temporarily skips byte size checks for GPU tensors. See
// DLIS-6820.
skip_byte_size_check |=
(input_memory_type == TRITONSERVER_MEMORY_GPU);
} else {
const auto& input_dims = input.ShapeWithBatchDim();
int64_t expected_byte_size = INT_MAX;
expected_byte_size =
triton::common::GetByteSize(data_type, input_dims);
const size_t& byte_size = input.Data()->TotalByteSize();
if ((byte_size > INT_MAX) ||
(static_cast<int64_t>(byte_size) != expected_byte_size)) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input byte size mismatch for input '" +
input_id + "' for model '" + ModelName() + "'. Expected " +
std::to_string(expected_byte_size) + ", got " +
std::to_string(byte_size));
}
}
}
}
}
Expand Down Expand Up @@ -1267,55 +1284,87 @@ InferenceRequest::ValidateRequestInputs()
Status
InferenceRequest::ValidateBytesInputs(
const std::string& input_id, const Input& input,
int64_t* const expected_byte_size) const
TRITONSERVER_MemoryType* buffer_memory_type) const
{
const auto& input_dims = input.ShapeWithBatchDim();

int64_t element_count = triton::common::GetElementCount(input_dims);
int64_t element_idx = 0;
*expected_byte_size = 0;
for (size_t i = 0; i < input.Data()->BufferCount(); ++i) {
size_t content_byte_size;
TRITONSERVER_MemoryType content_memory_type;
int64_t content_memory_id;
const char* content = input.Data()->BufferAt(
i, &content_byte_size, &content_memory_type, &content_memory_id);

while (content_byte_size >= sizeof(uint32_t)) {
if (element_idx >= element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected number of string elements " +
std::to_string(element_idx + 1) + " for inference input '" +
input_id + "', expecting " + std::to_string(element_count));
int64_t element_checked = 0;
size_t remaining_element_size = 0;

size_t buffer_next_idx = 0;
const size_t buffer_count = input.DataBufferCount();

const char* buffer = nullptr;
size_t remaining_buffer_size = 0;
int64_t buffer_memory_id;

// Validate elements until all buffers have been fully processed.
while (remaining_buffer_size || buffer_next_idx < buffer_count) {
// Get the next buffer if not currently processing one.
if (!remaining_buffer_size) {
// Reset remaining buffer size and pointers for next buffer.
RETURN_IF_ERROR(input.DataBuffer(
buffer_next_idx++, (const void**)(&buffer), &remaining_buffer_size,
buffer_memory_type, &buffer_memory_id));

if (*buffer_memory_type == TRITONSERVER_MEMORY_GPU) {
return Status::Success;
}
}

const uint32_t len = *(reinterpret_cast<const uint32_t*>(content));
content += sizeof(uint32_t);
content_byte_size -= sizeof(uint32_t);
*expected_byte_size += sizeof(uint32_t);

if (content_byte_size < len) {
constexpr size_t kElementSizeIndicator = sizeof(uint32_t);
// Get the next element if not currently processing one.
if (!remaining_element_size) {
// FIXME: Assume the string element's byte size indicator is not spread
// across buffer boundaries for simplicity.
if (remaining_buffer_size < kElementSizeIndicator) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "incomplete string data for inference input '" +
input_id + "', expecting string of length " +
std::to_string(len) + " but only " +
std::to_string(content_byte_size) + " bytes available");
LogRequest() +
"element byte size indicator exceeds the end of the buffer.");
}

content += len;
content_byte_size -= len;
*expected_byte_size += len;
element_idx++;
// Start the next element and reset the remaining element size.
remaining_element_size = *(reinterpret_cast<const uint32_t*>(buffer));
element_checked++;

// Advance pointer and remainder by the indicator size.
buffer += kElementSizeIndicator;
remaining_buffer_size -= kElementSizeIndicator;
}

// If the rest of the element fits in the remaining buffer: consume the rest
// of the element, then proceed to the next element.
if (remaining_buffer_size >= remaining_element_size) {
buffer += remaining_element_size;
remaining_buffer_size -= remaining_element_size;
remaining_element_size = 0;
}
// Otherwise the remaining element is larger: consume the rest of the
// buffer, proceed to the next buffer.
else {
remaining_element_size -= remaining_buffer_size;
remaining_buffer_size = 0;
}
}

if (element_idx != element_count) {
// Validate that the number of processed buffers exactly matches expectations.
if (buffer_next_idx != buffer_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(buffer_count) +
" buffers for inference input '" + input_id + "', got " +
std::to_string(buffer_next_idx));
}

// Validate that the number of processed elements exactly matches expectations.
if (element_checked != element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(element_count) +
" strings for inference input '" + input_id + "', got " +
std::to_string(element_idx));
" string elements for inference input '" + input_id + "', got " +
std::to_string(element_checked));
}

return Status::Success;
Expand Down
2 changes: 1 addition & 1 deletion src/infer_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ class InferenceRequest {

Status ValidateBytesInputs(
const std::string& input_id, const Input& input,
int64_t* const expected_byte_size) const;
TRITONSERVER_MemoryType* buffer_memory_type) const;

// Helpers for pending request metrics
void IncrementPendingRequestCount();
Expand Down
Loading