From 253c19b50967f3330ba569ccf784c90791616d4f Mon Sep 17 00:00:00 2001
From: vansangpfiev
Date: Sun, 28 Jul 2024 08:24:23 +0700
Subject: [PATCH] fix: better error handling (#20)

Co-authored-by: vansangpfiev
---
 src/onnx_engine.cc | 47 ++++++++++++++++++++++++++++++++++++----------
 src/onnx_engine.h  |  1 +
 2 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/src/onnx_engine.cc b/src/onnx_engine.cc
index 6450ce8..e379657 100644
--- a/src/onnx_engine.cc
+++ b/src/onnx_engine.cc
@@ -127,7 +127,7 @@ OnnxEngine::OnnxEngine() {
 void OnnxEngine::LoadModel(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
-  auto path = json_body->get("model_path", "").asString();
+  path_ = json_body->get("model_path", "").asString();
   user_prompt_ = json_body->get("user_prompt", "USER: ").asString();
   ai_prompt_ = json_body->get("ai_prompt", "ASSISTANT: ").asString();
   system_prompt_ =
@@ -136,7 +136,7 @@ void OnnxEngine::LoadModel(
   max_history_chat_ = json_body->get("max_history_chat", 2).asInt();
   try {
     std::cout << "Creating model..." << std::endl;
-    oga_model_ = OgaModel::Create(path.c_str());
+    oga_model_ = OgaModel::Create(path_.c_str());
     std::cout << "Creating tokenizer..." << std::endl;
     tokenizer_ = OgaTokenizer::Create(*oga_model_);
     tokenizer_stream_ = OgaTokenizerStream::Create(*tokenizer_);
@@ -149,7 +149,7 @@
     status["status_code"] = k200OK;
     callback(std::move(status), std::move(json_resp));
     model_id_ = GetModelId(*json_body);
-    LOG_INFO << "Model loaded successfully: " << path
+    LOG_INFO << "Model loaded successfully: " << path_
              << ", model_id: " << model_id_;
     model_loaded_ = true;
     start_time_ = std::chrono::system_clock::now().time_since_epoch() /
@@ -158,7 +158,7 @@
       q_ = std::make_unique<trantor::ConcurrentTaskQueue>(1, model_id_);
     }
   } catch (const std::exception& e) {
-    std::cout << e.what() << std::endl;
+    std::cout << "Failed to load model: " << e.what() << std::endl;
    oga_model_.reset();
    tokenizer_.reset();
    tokenizer_stream_.reset();
@@ -182,8 +182,8 @@ void OnnxEngine::HandleChatCompletion(
   auto is_stream = json_body->get("stream", false).asBool();

   std::string formatted_output = pre_prompt_;
-
-  int history_max = max_history_chat_ * 2; // both user and assistant
+
+  int history_max = max_history_chat_ * 2;  // both user and assistant
   int index = 0;
   for (const auto& message : req.messages) {
     std::string input_role = message["role"].asString();
@@ -272,8 +272,32 @@
       auto duration_ms =
           std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
               .count();
-      LOG_DEBUG << "Generated tokens per second: "
-                << generated_tokens / duration_ms * 1000;
+      std::cout << "Generated tokens per second: "
+                << generated_tokens / duration_ms * 1000 << std::endl;
+      if ((generated_tokens / duration_ms * 1000) < 1.0f) {
+        max_history_chat_ = std::max(1, max_history_chat_ / 2);
+        tokenizer_stream_.reset();
+        tokenizer_.reset();
+        oga_model_.reset();
+        generator.reset();
+        params.reset();
+        sequences.reset();
+        model_loaded_ = false;
+        LOG_WARN << "Something wrong happened, restart model and try again";
+        LOG_INFO << "Creating model...";
+        oga_model_ = OgaModel::Create(path_.c_str());
+        LOG_INFO << "Creating tokenizer...";
+        tokenizer_ = OgaTokenizer::Create(*oga_model_);
+        tokenizer_stream_ = OgaTokenizerStream::Create(*tokenizer_);
+        LOG_INFO << "Model loaded successfully: " << path_
+                 << ", model_id: " << model_id_;
+        model_loaded_ = true;
+        start_time_ = std::chrono::system_clock::now().time_since_epoch() /
+                      std::chrono::milliseconds(1);
+        if (q_ == nullptr) {
+          q_ = std::make_unique<trantor::ConcurrentTaskQueue>(1, model_id_);
+        }
+      }

       LOG_INFO << "End of result";
       Json::Value resp_data;
@@ -326,10 +350,13 @@
       status["has_error"] = false;
       status["is_stream"] = false;
       status["status_code"] = k200OK;
-      cb(std::move(status), std::move(resp_data));
+      cb(std::move(status), std::move(resp_data));
     }
   } catch (const std::exception& e) {
-    std::cout << e.what() << std::endl;
+    tokenizer_stream_.reset();
+    tokenizer_.reset();
+    oga_model_.reset();
+    std::cout << "Error during inference: " << e.what() << std::endl;
     Json::Value json_resp;
     json_resp["message"] = "Error during inference";
     Json::Value status;
diff --git a/src/onnx_engine.h b/src/onnx_engine.h
index ffad22b..b316b9e 100644
--- a/src/onnx_engine.h
+++ b/src/onnx_engine.h
@@ -51,5 +51,6 @@ class OnnxEngine : public EngineI {
   uint64_t start_time_;
   int max_history_chat_;
   std::unique_ptr<trantor::ConcurrentTaskQueue> q_;
+  std::string path_;
 };
 }  // namespace cortex_onnx
\ No newline at end of file
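
Reviewer note on the `-272,8 +272,32` hunk: it adds a throughput watchdog. When generation drops below one token per second, the engine halves `max_history_chat_`, tears down the OGA objects, and reloads the model from the stored `path_` (the reason `path` became the member `path_` in `LoadModel`). Below is a minimal, self-contained sketch of that pattern; `Engine`, `Model`, `CheckThroughput`, and the model path are illustrative stand-ins, not the cortex.onnx API. The sketch also guards the zero-duration case and divides in floating point, where the patch's integer expression `generated_tokens / duration_ms * 1000` truncates before the comparison against `1.0f`.

```cpp
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>

struct Model {  // hypothetical stand-in for OgaModel plus tokenizer state
  explicit Model(const std::string& p) : path(p) {}
  std::string path;
};

class Engine {
 public:
  void Load(const std::string& path) {
    path_ = path;  // keep the path so the engine can reload itself later
    model_ = std::make_unique<Model>(path_);
  }

  // Called after each completion with the measured generation stats.
  void CheckThroughput(int generated_tokens, long long duration_ms) {
    // Floating-point math with a zero-duration guard, so sub-millisecond
    // runs and truncation cannot produce a bogus tokens/sec value.
    double tps = duration_ms > 0
                     ? static_cast<double>(generated_tokens) * 1000.0 / duration_ms
                     : 0.0;
    std::cout << "Generated tokens per second: " << tps << std::endl;
    if (tps < 1.0) {
      // Mirror the patch: shrink the chat history kept in the prompt,
      // then drop the model and recreate it from the stored path.
      max_history_chat_ = std::max(1, max_history_chat_ / 2);
      model_.reset();
      Load(path_);
    }
  }

 private:
  std::unique_ptr<Model> model_;
  std::string path_;
  int max_history_chat_ = 2;  // same default the patch reads from JSON
};

int main() {
  Engine engine;
  engine.Load("/models/example-model");  // hypothetical model path
  engine.CheckThroughput(120, 2000);     // 60 tok/s: healthy, no reload
  engine.CheckThroughput(1, 5000);       // 0.2 tok/s: triggers a reload
}
```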
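A similar note on the `-326,10 +350,13` hunk: the catch block now releases `tokenizer_stream_`, `tokenizer_`, and `oga_model_` before reporting, so a failed request cannot leave a half-initialized session behind for the next call. Here is a compilable sketch of that report path using jsoncpp, as the engine does; `RunInference` and `k500InternalServerError` are assumptions for illustration (only `k200OK` appears in the patch).

```cpp
#include <functional>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <json/json.h>

constexpr int k500InternalServerError = 500;  // assumed status code

// Hypothetical stand-in for the engine's inference path: any exception
// clears session state first, then reports a structured error.
void RunInference(const std::function<void(Json::Value&&, Json::Value&&)>& cb) {
  try {
    throw std::runtime_error("token generation failed");  // simulated failure
  } catch (const std::exception& e) {
    // The patch resets tokenizer_stream_, tokenizer_, and oga_model_ at this
    // point, before any callback runs.
    std::cout << "Error during inference: " << e.what() << std::endl;
    Json::Value json_resp;
    json_resp["message"] = "Error during inference";
    Json::Value status;
    status["has_error"] = true;
    status["is_stream"] = false;
    status["status_code"] = k500InternalServerError;
    cb(std::move(status), std::move(json_resp));
  }
}

int main() {
  RunInference([](Json::Value&& status, Json::Value&& resp) {
    std::cout << status.toStyledString() << resp.toStyledString();
  });
}
```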