Skip to content

Commit

Permalink
Update ValidateNonLinearFormatIO()
Browse files Browse the repository at this point in the history
  • Loading branch information
pskiran1 committed Jul 19, 2024
1 parent 9b43798 commit fb9bd83
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 36 deletions.
53 changes: 28 additions & 25 deletions src/model_config_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,34 @@ ValidateIOShape(
return Status::Success;
}

/// Validate that Non-linear format inputs or outputs are specified correctly
/// in a model configuration.
template <class ModelIO>
Status
ValidateNonLinearFormatIO(
const ModelIO& io, const std::string& platform, bool is_input)
{
if (!io.is_non_linear_format_io()) {
// Nothing to validate as the tensor is not non-linear format.
return Status::Success;
}

if (platform != kTensorRTPlanPlatform) {
return Status(
Status::Code::INVALID_ARG,
"Non-linear IO format is only supported for the TensorRT platform");
}

if (io.dims_size() != 3) {
std::string io_type = is_input ? "input" : "output";
return Status(
Status::Code::INVALID_ARG,
"Non-linear IO format " + io_type + " requires 3 dims");
}

return Status::Success;
}

} // namespace

Status
Expand Down Expand Up @@ -1712,31 +1740,6 @@ ValidateInstanceGroup(
return Status::Success;
}

Status
ValidateNonLinearFormatIO(
const inference::ModelInput& io, const std::string& platform, bool is_input)
{
if (!io.is_non_linear_format_io()) {
// Nothing to validate as the tensor is not non-linear format.
return Status::Success;
}

if (platform != kTensorRTPlanPlatform) {
return Status(
Status::Code::INVALID_ARG,
"Non-linear IO format is only supported for the TensorRT platform");
}

if (io.dims_size() != 3) {
std::string io_type = is_input ? "input" : "output";
return Status(
Status::Code::INVALID_ARG,
"Non-linear IO format " + io_type + " requires 3 dims");
}

return Status::Success;
}

Status
ValidateModelInput(
const inference::ModelInput& io, int32_t max_batch_size,
Expand Down
11 changes: 0 additions & 11 deletions src/model_config_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,6 @@ Status ValidateInstanceGroup(
/// is not valid.
Status ValidateModelIOConfig(const inference::ModelConfig& config);

/// Validate that Non-linear format inputs or outputs are specified correctly
/// in a model configuration.
/// \param io The model input.
/// \param platform The platform name
/// \param is_input Specifies whether it is an input or an output.
/// \return The error status. A non-OK status indicates the configuration
/// is not valid.
Status ValidateNonLinearFormatIO(
const inference::ModelInput& io, const std::string& platform,
bool is_input);

/// Validate that input is specified correctly in a model
/// configuration.
/// \param io The model input.
Expand Down

0 comments on commit fb9bd83

Please sign in to comment.