
Commit

Remove redundant checks
yinggeh committed Jun 6, 2024
1 parent 272ef29 commit 0e3b63b
Showing 2 changed files with 14 additions and 23 deletions.
src/infer_request.cc: 14 additions & 22 deletions

@@ -1176,10 +1176,7 @@ InferenceRequest::Normalize()
     // Note: Since we're using normalized input.ShapeWithBatchDim() here,
     // make sure that all the normalization is before the check.
     {
-      const size_t& byte_size = input.Data()->TotalByteSize();
       const auto& data_type = input.DType();
-      const auto& input_dims = input.ShapeWithBatchDim();
-      int64_t expected_byte_size = INT_MAX;
 
       // FIXME: Skip byte size validation for TensorRT backend because it breaks
       // shape-size assumption. See DLIS-6805 for proper fix for TRT backend
@@ -1195,27 +1192,27 @@ InferenceRequest::Normalize()
       // (prepend 4 bytes to specify string length), so need to add all the
       // first 4 bytes for each element to find expected byte size
       if (data_type == inference::DataType::TYPE_STRING) {
-        RETURN_IF_ERROR(ValidateBytesInputs(
-            input_id, input, &expected_byte_size, &input_memory_type));
+        RETURN_IF_ERROR(
+            ValidateBytesInputs(input_id, input, &input_memory_type));
         // FIXME: Temporarily skips byte size checks for GPU tensors. See
         // DLIS-6820.
         skip_byte_size_check |=
             (input_memory_type == TRITONSERVER_MEMORY_GPU);
       } else {
+        const auto& input_dims = input.ShapeWithBatchDim();
+        int64_t expected_byte_size = INT_MAX;
         expected_byte_size =
             triton::common::GetByteSize(data_type, input_dims);
-      }
-
-      bool byte_size_invalid =
-          (byte_size > INT_MAX) ||
-          (static_cast<int64_t>(byte_size) != expected_byte_size);
-      if (!skip_byte_size_check && byte_size_invalid) {
-        return Status(
-            Status::Code::INVALID_ARG,
-            LogRequest() + "input byte size mismatch for input '" + input_id +
-                "' for model '" + ModelName() + "'. Expected " +
-                std::to_string(expected_byte_size) + ", got " +
-                std::to_string(byte_size));
+        const size_t& byte_size = input.Data()->TotalByteSize();
+        if ((byte_size > INT_MAX) ||
+            (static_cast<int64_t>(byte_size) != expected_byte_size)) {
+          return Status(
+              Status::Code::INVALID_ARG,
+              LogRequest() + "input byte size mismatch for input '" +
+                  input_id + "' for model '" + ModelName() + "'. Expected " +
+                  std::to_string(expected_byte_size) + ", got " +
+                  std::to_string(byte_size));
+        }
       }
     }
   }
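Note on the STRING path: because each element carries the 4-byte length prefix described in the comment above, the expected byte size of a STRING input cannot be derived from shape and dtype the way triton::common::GetByteSize does for fixed-size types; it depends on the payload itself. A minimal standalone sketch of that arithmetic (illustrative only, not Triton's code; SerializeStringTensor is a made-up name, and a little-endian host is assumed):

#include <cstdint>
#include <string>
#include <vector>

// Sketch: serialize strings the way the comment describes -- each element
// is a 4-byte length followed by the raw bytes. The total byte size is
// therefore sum(4 + s.size()) rather than a function of shape and dtype.
std::vector<char> SerializeStringTensor(const std::vector<std::string>& elems)
{
  std::vector<char> buffer;
  for (const std::string& s : elems) {
    const uint32_t len = static_cast<uint32_t>(s.size());
    const char* len_bytes = reinterpret_cast<const char*>(&len);
    buffer.insert(buffer.end(), len_bytes, len_bytes + sizeof(len));
    buffer.insert(buffer.end(), s.begin(), s.end());
  }
  return buffer;
}

// For a shape-[2] STRING tensor {"ab", "cde"} the buffer is
// (4+2) + (4+3) = 13 bytes, which is what a byte-size check would expect.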
@@ -1287,10 +1284,8 @@ InferenceRequest::ValidateRequestInputs()
 Status
 InferenceRequest::ValidateBytesInputs(
     const std::string& input_id, const Input& input,
-    int64_t* const expected_byte_size,
     TRITONSERVER_MemoryType* buffer_memory_type) const
 {
-  *expected_byte_size = 0;
   const auto& input_dims = input.ShapeWithBatchDim();
 
   int64_t element_count = triton::common::GetElementCount(input_dims);
@@ -1312,10 +1307,7 @@ InferenceRequest::ValidateBytesInputs(
     RETURN_IF_ERROR(input.DataBuffer(
         buffer_next_idx++, (const void**)(&buffer), &remaining_buffer_size,
         buffer_memory_type, &buffer_memory_id));
-    *expected_byte_size += remaining_buffer_size;
 
-    // FIXME: Skip GPU buffers for now, return an expected_byte_size of -1 as
-    // a signal to skip validation.
     if (*buffer_memory_type == TRITONSERVER_MEMORY_GPU) {
       return Status::Success;
     }
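With the *expected_byte_size accumulation removed, ValidateBytesInputs no longer reports a size for Normalize() to compare; it walks the buffers itself (still bailing out early on GPU memory, per DLIS-6820) and, presumably in the code elided below this hunk, checks that they decode to exactly element_count length-prefixed elements. A rough single-buffer sketch of such a walk, under that assumption (ValidateLengthPrefixed is a made-up name; the real code iterates multiple buffers via Input::DataBuffer):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Sketch: check that `buffer` of `byte_size` bytes contains exactly
// `element_count` elements, each a 4-byte length header plus payload.
bool ValidateLengthPrefixed(
    const char* buffer, size_t byte_size, int64_t element_count)
{
  size_t offset = 0;
  int64_t seen = 0;
  while (offset + sizeof(uint32_t) <= byte_size) {
    uint32_t len;
    std::memcpy(&len, buffer + offset, sizeof(len));
    offset += sizeof(len) + len;
    if (offset > byte_size) {
      return false;  // element payload runs past the end of the buffer
    }
    ++seen;
  }
  // Valid only if the buffer is fully consumed by exactly element_count
  // elements -- no trailing bytes, no missing elements.
  return (offset == byte_size) && (seen == element_count);
}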
src/infer_request.h: 0 additions & 1 deletion

@@ -749,7 +749,6 @@ class InferenceRequest {

   Status ValidateBytesInputs(
       const std::string& input_id, const Input& input,
-      int64_t* const expected_byte_size,
       TRITONSERVER_MemoryType* buffer_memory_type) const;
 
   // Helpers for pending request metrics
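The header change mirrors the .cc change: the out-parameter is dropped, so byte-size bookkeeping no longer leaks into the caller's contract. Side by side, as quoted from the diff:

// Before: the caller received a size to compare against, with the removed
// FIXME noting that -1 / a GPU buffer signaled "skip validation".
Status ValidateBytesInputs(
    const std::string& input_id, const Input& input,
    int64_t* const expected_byte_size,
    TRITONSERVER_MemoryType* buffer_memory_type) const;

// After: validation stays internal; only the buffer memory type escapes,
// so Normalize() can decide whether the GPU skip (DLIS-6820) applies.
Status ValidateBytesInputs(
    const std::string& input_id, const Input& input,
    TRITONSERVER_MemoryType* buffer_memory_type) const;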
