diff --git a/src/ensemble_scheduler/ensemble_scheduler.cc b/src/ensemble_scheduler/ensemble_scheduler.cc
index da5806d33..a16044062 100644
--- a/src/ensemble_scheduler/ensemble_scheduler.cc
+++ b/src/ensemble_scheduler/ensemble_scheduler.cc
@@ -646,8 +646,9 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
   if (response != nullptr) {
     RETURN_IF_TRITONSERVER_ERROR(TRITONSERVER_InferenceResponseError(response));
     uint32_t count;
+    bool parameter_override = false;
     InferenceRequest::SequenceId correlation_id = step_ptr->correlation_id_;
-    uint32_t flags = step_ptr->flags_;
+    uint32_t flags = 0;
     RETURN_IF_TRITONSERVER_ERROR(
         TRITONSERVER_InferenceResponseParameterCount(response, &count));
     for (uint32_t idx = 0; idx < count; idx++) {
@@ -661,10 +662,12 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
          case TRITONSERVER_PARAMETER_INT:
            correlation_id = InferenceRequest::SequenceId(
                *reinterpret_cast<const int64_t*>(vvalue));
+           parameter_override = true;
            break;
          case TRITONSERVER_PARAMETER_STRING:
            correlation_id = InferenceRequest::SequenceId(
                std::string(*reinterpret_cast<const char* const*>(vvalue)));
+           parameter_override = true;
            break;
          default:
            RETURN_IF_TRITONSERVER_ERROR(TRITONSERVER_ErrorNew(
@@ -683,6 +686,7 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
          if (*reinterpret_cast<const bool*>(vvalue)) {
            flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_START;
          }
+         parameter_override = true;
        }
      } else if (!strcmp(name, "sequence_end")) {
        if (type != TRITONSERVER_PARAMETER_BOOL) {
@@ -694,6 +698,7 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
          if (*reinterpret_cast<const bool*>(vvalue)) {
            flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_END;
          }
+         parameter_override = true;
        }
      }
    }
@@ -740,9 +745,16 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
      }

      auto& tensor_data = tensor_data_[it->second];
-     step_ptr->updated_tensors_.emplace(
-         it->second,
-         tensor_data.AddTensor(std::move(tensor), correlation_id, flags));
+     if (parameter_override) {
+       step_ptr->updated_tensors_.emplace(
+           it->second,
+           tensor_data.AddTensor(std::move(tensor), correlation_id, flags));
+     } else {
+       step_ptr->updated_tensors_.emplace(
+           it->second, tensor_data.AddTensor(
+                           std::move(tensor), step_ptr->correlation_id_,
+                           step_ptr->flags_));
+     }

      output_to_tensor.erase(it);
    } else {
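
For readers skimming the diff, the sketch below is a standalone, simplified illustration (not Triton code) of the control flow this change introduces: sequence-related response parameters override the step's correlation id and flags only when such a parameter is actually present, and otherwise the step's original values are reused. The names `StepInfo`, `ResponseParams`, and `ResolveSequenceInfo` are hypothetical stand-ins for the real scheduler types, not part of the Triton API.

```cpp
// Minimal sketch of the override-vs-fallback pattern from the diff.
// All type and function names here are hypothetical; only the logic
// (flip parameter_override when a parameter is seen, otherwise fall
// back to the step's stored values) mirrors the change.
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <utility>

// Stand-in for the per-step bookkeeping (step_ptr->correlation_id_ and
// step_ptr->flags_ in the real scheduler).
struct StepInfo {
  std::string correlation_id;
  uint32_t flags = 0;
};

// Stand-in for the sequence-related response parameters, each optional
// because a response may carry none, some, or all of them.
struct ResponseParams {
  std::optional<std::string> sequence_id;
  std::optional<bool> sequence_start;
  std::optional<bool> sequence_end;
};

constexpr uint32_t kFlagSequenceStart = 1u << 0;
constexpr uint32_t kFlagSequenceEnd = 1u << 1;

// Mirrors the diff: start with no override and flags rebuilt from zero,
// set parameter_override only when a parameter is actually present, and
// fall back to the step's original correlation id / flags otherwise.
std::pair<std::string, uint32_t>
ResolveSequenceInfo(const StepInfo& step, const ResponseParams& params)
{
  bool parameter_override = false;
  std::string correlation_id = step.correlation_id;
  uint32_t flags = 0;

  if (params.sequence_id) {
    correlation_id = *params.sequence_id;
    parameter_override = true;
  }
  if (params.sequence_start) {
    if (*params.sequence_start) {
      flags |= kFlagSequenceStart;
    }
    parameter_override = true;
  }
  if (params.sequence_end) {
    if (*params.sequence_end) {
      flags |= kFlagSequenceEnd;
    }
    parameter_override = true;
  }

  // No parameters at all: reuse the step's stored values rather than the
  // rebuilt (empty) flags.
  if (!parameter_override) {
    return {step.correlation_id, step.flags};
  }
  return {correlation_id, flags};
}

int main()
{
  StepInfo step{"step-corr-1", kFlagSequenceStart};

  // Case 1: no sequence parameters in the response -> step values pass through.
  auto kept = ResolveSequenceInfo(step, ResponseParams{});
  std::cout << kept.first << " flags=" << kept.second << "\n";

  // Case 2: response overrides the sequence id and marks sequence end.
  auto overridden =
      ResolveSequenceInfo(step, ResponseParams{"resp-corr-7", std::nullopt, true});
  std::cout << overridden.first << " flags=" << overridden.second << "\n";
  return 0;
}
```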