Fix state transitions for re-running requests #251

Merged · 6 commits · Sep 6, 2023
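A reading aid before the per-file diffs: this change renames the STARTED state to PENDING, ties the pending-request metric to the PENDING state, and adds a RELEASED to INITIALIZED transition so a request object can be re-run. Below is a minimal, self-contained sketch of the transition rules encoded by the reworked SetState() switch; RequestState and IsValidTransition are illustrative names only, not part of the Triton core API.

#include <cstddef>
#include <iostream>

// Illustrative stand-in for InferenceRequest::State (STARTED is renamed to
// PENDING in this change).
enum class RequestState { INITIALIZED, PENDING, EXECUTING, RELEASED };

// Encodes the per-state checks from the SetState() switch; in the real code
// the pending-request gauge is incremented on INITIALIZED -> PENDING and
// decremented when leaving PENDING.
bool
IsValidTransition(RequestState from, RequestState to)
{
  switch (from) {
    case RequestState::INITIALIZED:
      // Normal path is enqueueing (PENDING); releasing early is also allowed.
      return to == RequestState::PENDING || to == RequestState::RELEASED;
    case RequestState::PENDING:
      // Scheduled onto a backend instance, or released early on error.
      return to == RequestState::EXECUTING || to == RequestState::RELEASED;
    case RequestState::EXECUTING:
      return to == RequestState::RELEASED;
    case RequestState::RELEASED:
      // New here: a released request may be re-initialized and re-run.
      return to == RequestState::INITIALIZED;
  }
  return false;
}

int
main()
{
  // Walk the same (conceptual) request through two inferences back to back,
  // which is the re-run path this change enables.
  const RequestState path[] = {
      RequestState::INITIALIZED, RequestState::PENDING,
      RequestState::EXECUTING,   RequestState::RELEASED,
      RequestState::INITIALIZED, RequestState::PENDING,
      RequestState::EXECUTING,   RequestState::RELEASED};
  for (std::size_t i = 1; i < sizeof(path) / sizeof(path[0]); ++i) {
    std::cout << (IsValidTransition(path[i - 1], path[i]) ? "ok" : "invalid")
              << "\n";
  }
  return 0;
}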
79 changes: 37 additions & 42 deletions src/infer_request.cc
@@ -106,86 +106,81 @@ InferenceRequest::InferenceRequest(
: needs_normalization_(true), model_raw_(model),
requested_model_version_(requested_model_version), flags_(0),
correlation_id_(0), batch_size_(0), timeout_us_(0), collect_stats_(true),
state_(InferenceRequest::State::INITIALIZED), null_request_(false),
decrement_pending_count_(false)
state_(InferenceRequest::State::INITIALIZED), null_request_(false)
{
SetPriority(0);
}

InferenceRequest::~InferenceRequest()
{
// If request has been enqueued but hasn't started executing by destruction
// time, an error occurred and the pending request count will need to be
// decremented.
DecrementPendingRequestCount();
}


Status
InferenceRequest::SetState(InferenceRequest::State new_state)
{
LOG_VERBOSE(1) << LogRequest() << "Setting state from " << state_ << " to "
<< new_state;
// No-op if this is already the current state, or if this is a null request.
if (new_state == state_ || null_request_) {
return Status::Success;
}

// Allow RELEASED state transition from any state for now.
// Not all requests will follow linear transition, such as null requests
// used for padding batches, and ensemble requests.
if (new_state == InferenceRequest::State::RELEASED) {
state_ = new_state;
return Status::Success;
}

// Generate error when called rather than copying it into every case below.
const auto generate_error = [&]() {
std::stringstream ss;
ss << LogRequest() << "Invalid request state transition from " << state_
<< " to " << new_state;
return Status(Status::Code::INVALID_ARG, ss.str());
return Status(Status::Code::INTERNAL, ss.str());
};

// Define state transitions
switch (state_) {
case InferenceRequest::State::INITIALIZED: {
if (new_state != InferenceRequest::State::STARTED) {
if (new_state == InferenceRequest::State::PENDING) {
IncrementPendingRequestCount();
} else if (new_state == InferenceRequest::State::RELEASED) {
// No-op when moving from initialized to released, just releasing early.
} else {
return generate_error();
}
state_ = new_state;
IncrementPendingRequestCount();
break;
}
case InferenceRequest::State::STARTED: {
if (new_state != InferenceRequest::State::EXECUTING) {
case InferenceRequest::State::PENDING: {
// Request may move from pending to either execution when scheduled to
// backend, or released early due to some error.
if (new_state == InferenceRequest::State::EXECUTING ||
new_state == InferenceRequest::State::RELEASED) {
DecrementPendingRequestCount();
} else {
// Unexpected state transition
return generate_error();
}
state_ = new_state;
DecrementPendingRequestCount();
break;
}
case InferenceRequest::State::EXECUTING: {
if (new_state != InferenceRequest::State::RELEASED) {
return generate_error();
}
state_ = new_state;
break;
}
case InferenceRequest::State::RELEASED: {
// No state transition currently supported after release.
return generate_error();
if (new_state != InferenceRequest::State::INITIALIZED) {
// Only transition currently supported after release is to start over
// again, such as re-using request objects for multiple inferences.
return generate_error();
}
break;
}
}
state_ = new_state;
return Status::Success;
}

void
InferenceRequest::IncrementPendingRequestCount()
{
#ifdef TRITON_ENABLE_METRICS
// Pending request count should always be 0 or 1 per-request. If a request
// increments the count, it should not be incremented again until decremented.
auto reporter = model_raw_->MetricReporter();
if (reporter) {
reporter->IncrementGauge(kPendingRequestMetric, 1);
decrement_pending_count_ = true;
}
#endif // TRITON_ENABLE_METRICS
}
@@ -194,13 +189,11 @@ void
InferenceRequest::DecrementPendingRequestCount()
{
#ifdef TRITON_ENABLE_METRICS
// Only decrement if count has been incremented, and not already decremented.
if (decrement_pending_count_) {
auto reporter = model_raw_->MetricReporter();
if (reporter) {
reporter->DecrementGauge(kPendingRequestMetric, 1);
}
decrement_pending_count_ = false;
// Pending request count should always be 0 or 1 per-request. A request should
// not decrement the count unless it has already been incremented.
auto reporter = model_raw_->MetricReporter();
if (reporter) {
reporter->DecrementGauge(kPendingRequestMetric, 1);
}
#endif // TRITON_ENABLE_METRICS
}
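With the decrement_pending_count_ flag and the destructor hook removed, the gauge now depends on the state machine alone: the count is incremented exactly once when a request enters PENDING and decremented exactly once when it leaves PENDING (to EXECUTING or RELEASED). A tiny hypothetical sketch of that invariant, with pending_gauge standing in for the real kPendingRequestMetric reporter:

#include <atomic>
#include <cassert>

// Hypothetical process-wide gauge; the real metric lives behind
// model_raw_->MetricReporter(). This only illustrates the pairing.
static std::atomic<int> pending_gauge{0};

void OnEnterPending() { pending_gauge.fetch_add(1); }  // INITIALIZED -> PENDING
void OnLeavePending() { pending_gauge.fetch_sub(1); }  // PENDING -> EXECUTING or RELEASED

int
main()
{
  OnEnterPending();            // request enqueued
  assert(pending_gauge == 1);  // at most one pending increment per request
  OnLeavePending();            // scheduled to a backend, or released early
  assert(pending_gauge == 0);  // EXECUTING -> RELEASED never touches the gauge
  return 0;
}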
@@ -376,7 +369,7 @@ InferenceRequest::OutputBufferProperties(
Status
InferenceRequest::Run(std::unique_ptr<InferenceRequest>& request)
{
RETURN_IF_ERROR(request->SetState(InferenceRequest::State::STARTED));
RETURN_IF_ERROR(request->SetState(InferenceRequest::State::PENDING));
return request->model_raw_->Enqueue(request);
}

@@ -849,8 +842,10 @@ InferenceRequest::PrepareForInference()
request_start_ns_ = 0;
#endif // TRITON_ENABLE_STATS

LOG_VERBOSE(1) << LogRequest() << "prepared: " << *this;
// Help enforce that PrepareForInference() is called prior to Run().
RETURN_IF_ERROR(SetState(InferenceRequest::State::INITIALIZED));

LOG_VERBOSE(1) << LogRequest() << "prepared: " << *this;
return Status::Success;
}

@@ -1580,8 +1575,8 @@ operator<<(std::ostream& out, const InferenceRequest::State& state)
out << "INITIALIZED";
break;
}
case InferenceRequest::State::STARTED: {
out << "STARTED";
case InferenceRequest::State::PENDING: {
out << "PENDING";
break;
}
case InferenceRequest::State::EXECUTING: {
6 changes: 1 addition & 5 deletions src/infer_request.h
@@ -63,7 +63,7 @@ class InferenceRequest {
INITIALIZED,

// The request has been enqueued, but is not yet executing.
STARTED,
PENDING,

// The request has been picked up by a backend model instance for execution,
// but hasn't been released yet.
@@ -291,7 +286,6 @@
const int64_t requested_model_version);

InferenceRequest(Model* model, const int64_t requested_model_version);
~InferenceRequest();

const std::string& ModelName() const;
int64_t RequestedModelVersion() const { return requested_model_version_; }
@@ -799,9 +798,6 @@ class InferenceRequest {
// Whether this is a null request used for direct sequence batch padding or
// not.
bool null_request_;
// Catch-all to correctly decrement pending count if needed on destruction
// if request doesn't follow normal execution path (error, unused, ensembles)
bool decrement_pending_count_;
};

std::ostream& operator<<(std::ostream& out, const InferenceRequest& request);
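The RELEASED to INITIALIZED transition is what allows one request object to serve several inferences, with PrepareForInference() required before each Run(). The following toy model of that calling pattern is illustrative only; ToyRequest is not the real InferenceRequest, which goes through PrepareForInference(), Run(), and the backend's release callback.

#include <iostream>
#include <stdexcept>

// Toy model of the lifecycle documented by the State enum above.
class ToyRequest {
 public:
  enum class State { INITIALIZED, PENDING, EXECUTING, RELEASED };

  // RELEASED -> INITIALIZED: the same object may be prepared again after
  // release instead of being destroyed.
  void PrepareForInference()
  {
    if (state_ != State::RELEASED && state_ != State::INITIALIZED) {
      throw std::logic_error("request still in flight");
    }
    state_ = State::INITIALIZED;
  }
  // Enqueue only works on a freshly prepared request, mirroring how the real
  // code enforces that PrepareForInference() precedes Run().
  void Enqueue() { Transition(State::INITIALIZED, State::PENDING); }
  void Execute() { Transition(State::PENDING, State::EXECUTING); }
  void Release() { Transition(State::EXECUTING, State::RELEASED); }

 private:
  void Transition(State expected, State next)
  {
    if (state_ != expected) {
      throw std::logic_error("invalid request state transition");
    }
    state_ = next;
  }
  State state_ = State::INITIALIZED;
};

int
main()
{
  ToyRequest req;
  for (int i = 0; i < 2; ++i) {  // reuse the same request object twice
    req.PrepareForInference();
    req.Enqueue();
    req.Execute();
    req.Release();
  }
  std::cout << "re-ran the same request object twice\n";
  return 0;
}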
2 changes: 0 additions & 2 deletions src/test/response_cache_test.cc
@@ -66,8 +66,6 @@ InferenceRequest::InferenceRequest(
response_factory_.reset(new InferenceResponseFactory());
}

InferenceRequest::~InferenceRequest() {}

InferenceRequest::Input::Input(
const std::string& name, const inference::DataType datatype,
const int64_t* shape, const uint64_t dim_count)