diff --git a/serving/processor/serving/model_instance.cc b/serving/processor/serving/model_instance.cc
index d4a10a97ec6..fec822f72be 100644
--- a/serving/processor/serving/model_instance.cc
+++ b/serving/processor/serving/model_instance.cc
@@ -364,10 +364,10 @@ Status LocalSessionInstance::Warmup(
   int left_try_count = WARMUP_COUNT;
   while (left_try_count > 0) {
     if (warmup_session) {
-      s = warmup_session->LocalPredict(
+      s = warmup_session->Warmup(
           call.request, call.response);
     } else {
-      s = session_mgr_->LocalPredict(
+      s = session_mgr_->Warmup(
           call.request, call.response);
     }
     if (!s.ok()) return s;
@@ -563,11 +563,11 @@ Status RemoteSessionInstance::Warmup(
   int left_try_count = WARMUP_COUNT;
   while (left_try_count > 0) {
     if (warmup_session) {
-      s = warmup_session->LocalPredict(
-          call.request, call.response);
+      s = warmup_session->Warmup(
+          call.request, call.response, false);
     } else {
-      s = session_mgr_->LocalPredict(
-          call.request, call.response);
+      s = session_mgr_->Warmup(
+          call.request, call.response, false);
     }
     if (!s.ok()) return s;

diff --git a/serving/processor/serving/model_session.cc b/serving/processor/serving/model_session.cc
index 24400c733dc..534c0e094d2 100644
--- a/serving/processor/serving/model_session.cc
+++ b/serving/processor/serving/model_session.cc
@@ -262,6 +262,16 @@ int ModelSession::GetServingSessionId() {
 }
 
 Status ModelSession::Predict(Request& req, Response& resp) {
+  return InternalPredict(req, resp, GetServingSessionId());
+}
+
+Status ModelSession::Predict(Request& req, Response& resp,
+                             int sess_id) {
+  return InternalPredict(req, resp, sess_id);
+}
+
+Status ModelSession::InternalPredict(Request& req, Response& resp,
+                                     int sess_id) {
   if (is_local_) {
     return Status(error::Code::INTERNAL,
         "Local sparse storage, please use LocalPredict.");
@@ -278,17 +288,31 @@ Status ModelSession::Predict(Request& req, Response& resp) {
     // TODO: which session selected to run on, add some policy here
     status = session_group_->Run(run_options, req.inputs,
         req.output_tensor_names, {}, &resp.outputs,
-        &run_metadata, GetServingSessionId());
+        &run_metadata, sess_id);
     Tracer::GetTracer()->GenTimeline(run_metadata);
   } else {
     status = session_group_->Run(req.inputs, req.output_tensor_names,
-        {}, &resp.outputs, GetServingSessionId());
+        {}, &resp.outputs, sess_id);
   }
   --counter_;
   return status;
 }
 
-Status ModelSession::LocalPredict(Request& req, Response& resp) {
+Status ModelSession::LocalPredict(Request& req,
+                                  Response& resp) {
+  return InternalLocalPredict(req, resp,
+                              GetServingSessionId());
+}
+
+Status ModelSession::LocalPredict(Request& req,
+                                  Response& resp,
+                                  int sess_id) {
+  return InternalLocalPredict(req, resp, sess_id);
+}
+
+Status ModelSession::InternalLocalPredict(Request& req,
+                                          Response& resp,
+                                          int sess_id) {
   if (!is_local_) {
     return Status(error::Code::INTERNAL,
         "Remote sparse storage, please use Predict.");
@@ -302,16 +326,31 @@ Status ModelSession::LocalPredict(Request& req, Response& resp) {
     // TODO: which session selected to run on, add some policy here
     status = session_group_->Run(run_options, req.inputs,
         req.output_tensor_names, {}, &resp.outputs,
-        &run_metadata, GetServingSessionId());
+        &run_metadata, sess_id);
     Tracer::GetTracer()->GenTimeline(run_metadata);
   } else {
     status = session_group_->Run(req.inputs, req.output_tensor_names,
-        {}, &resp.outputs, GetServingSessionId());
+        {}, &resp.outputs, sess_id);
   }
   --counter_;
   return status;
 }
 
+Status ModelSession::Warmup(Request& req, Response& resp, bool local) {
+  int N = session_group_->GetSessionNum();
+  for (int i = 0; i < N; ++i) {
+    Status s;
+    if (local) {
+      s = LocalPredict(req, resp, i);
+    } else {
+      s = Predict(req, resp, i);
+    }
+    if (!s.ok()) return s;
+  }
+
+  return Status::OK();
+}
+
 Status ModelSessionMgr::Predict(Request& req, Response& resp) {
   return serving_session_->Predict(req, resp);
 }
@@ -320,6 +359,10 @@ Status ModelSessionMgr::LocalPredict(Request& req, Response& resp) {
   return serving_session_->LocalPredict(req, resp);
 }
 
+Status ModelSessionMgr::Warmup(Request& req, Response& resp, bool local) {
+  return serving_session_->Warmup(req, resp, local);
+}
+
 Status ModelSessionMgr::CreateModelSession(
     const Version& version, const char* ckpt_name,
     IFeatureStoreMgr* sparse_storage, bool is_incr_ckpt,
diff --git a/serving/processor/serving/model_session.h b/serving/processor/serving/model_session.h
index 394582657ca..a54952318ba 100644
--- a/serving/processor/serving/model_session.h
+++ b/serving/processor/serving/model_session.h
@@ -33,10 +33,13 @@ struct ModelSession {
   virtual ~ModelSession();
 
   Status Predict(Request& req, Response& resp);
+  Status Predict(Request& req, Response& resp, int sess_id);
   Status LocalPredict(Request& req, Response& resp);
+  Status LocalPredict(Request& req, Response& resp, int sess_id);
   Version GetVersion() {return version_;}
   void UpdateVersion(const Version& v) { version_ = v; }
   Session* GetSession();
+  Status Warmup(Request& req, Response& resp, bool local=true);
 
   SessionGroup* session_group_ = nullptr;
   SelectSessionPolicy select_session_policy_ =
@@ -54,6 +57,8 @@ struct ModelSession {
 
  private:
   int GetServingSessionId();
+  Status InternalPredict(Request& req, Response& resp, int sess_id);
+  Status InternalLocalPredict(Request& req, Response& resp, int sess_id);
 };
 
 class ModelSessionMgr {
@@ -64,6 +69,7 @@ class ModelSessionMgr {
 
   Status Predict(Request& req, Response& resp);
   Status LocalPredict(Request& req, Response& resp);
+  Status Warmup(Request& req, Response& resp, bool local=true);
 
   Status CreateModelSession(
     const Version& version,