diff --git a/src/onnx_engine.cc b/src/onnx_engine.cc
index 54a73d1..6450ce8 100644
--- a/src/onnx_engine.cc
+++ b/src/onnx_engine.cc
@@ -133,6 +133,7 @@ void OnnxEngine::LoadModel(
   system_prompt_ =
       json_body->get("system_prompt", "ASSISTANT's RULE: ").asString();
   pre_prompt_ = json_body->get("pre_prompt", "").asString();
+  max_history_chat_ = json_body->get("max_history_chat", 2).asInt();
   try {
     std::cout << "Creating model..." << std::endl;
     oga_model_ = OgaModel::Create(path.c_str());
@@ -181,17 +182,24 @@ void OnnxEngine::HandleChatCompletion(
   auto is_stream = json_body->get("stream", false).asBool();
 
   std::string formatted_output = pre_prompt_;
+
+  int history_max = max_history_chat_ * 2;  // both user and assistant
+  int index = 0;
   for (const auto& message : req.messages) {
     std::string input_role = message["role"].asString();
     std::string role;
     if (input_role == "user") {
       role = user_prompt_;
       std::string content = message["content"].asString();
-      formatted_output += role + content;
+      if (index > static_cast<int>(req.messages.size()) - history_max) {
+        formatted_output += role + content;
+      }
     } else if (input_role == "assistant") {
       role = ai_prompt_;
       std::string content = message["content"].asString();
-      formatted_output += role + content;
+      if (index > static_cast<int>(req.messages.size()) - history_max) {
+        formatted_output += role + content;
+      }
     } else if (input_role == "system") {
       role = system_prompt_;
       std::string content = message["content"].asString();
@@ -200,7 +208,9 @@ void OnnxEngine::HandleChatCompletion(
       role = input_role;
       std::string content = message["content"].asString();
       formatted_output += role + content;
+      LOG_WARN << "Should specify input_role";
     }
+    index++;
   }
 
   formatted_output += ai_prompt_;
@@ -305,7 +315,8 @@
           std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
               .count();
       LOG_DEBUG << "Generated tokens per second: "
-                << static_cast<double>(output_sequence_length) / duration_ms * 1000;
+                << static_cast<double>(output_sequence_length) / duration_ms *
+                       1000;
       std::string to_send = out_string.p_;
       auto resp_data = CreateFullReturnJson(GenerateRandomString(20), "_",
diff --git a/src/onnx_engine.h b/src/onnx_engine.h
index d54f113..ffad22b 100644
--- a/src/onnx_engine.h
+++ b/src/onnx_engine.h
@@ -49,6 +49,7 @@ class OnnxEngine : public EngineI {
   std::string pre_prompt_;
   std::string model_id_;
   uint64_t start_time_;
+  int max_history_chat_;
   std::unique_ptr<trantor::ConcurrentTaskQueue> q_;
 };
 }  // namespace cortex_onnx
\ No newline at end of file
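
Note: the new `max_history_chat` setting rides along in the same JSON body that `LoadModel` already parses for `pre_prompt` and `system_prompt`. A minimal sketch of such a body follows; only the keys visible in the diff are shown, and their values here are the defaults the `get(...)` calls fall back to:

```json
{
  "pre_prompt": "",
  "system_prompt": "ASSISTANT's RULE: ",
  "max_history_chat": 2
}
```

With the default of 2, the engine keeps at most the last two user/assistant exchanges when it builds the prompt.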
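
To make the windowing rule easier to check in isolation, here is a self-contained sketch of the same arithmetic outside the engine. `Message`, `FormatPrompt`, and the role tags are illustrative stand-ins rather than engine API; only the `index > size - history_max` comparison and the unconditional system branch mirror the diff:

```cpp
#include <iostream>
#include <string>
#include <vector>

struct Message {  // illustrative stand-in for entries of req.messages
  std::string role;
  std::string content;
};

// Mirrors the diff's windowing rule: system messages are always kept,
// while user/assistant messages are kept only when they fall inside the
// most recent window of max_history_chat exchanges (x2: user + assistant).
std::string FormatPrompt(const std::vector<Message>& messages,
                         int max_history_chat) {
  int history_max = max_history_chat * 2;
  std::string out;
  int index = 0;
  for (const auto& m : messages) {
    bool in_window =
        index > static_cast<int>(messages.size()) - history_max;
    if (m.role == "system" || in_window) {
      out += "[" + m.role + "] " + m.content + "\n";
    }
    index++;
  }
  return out;
}

int main() {
  std::vector<Message> msgs = {
      {"system", "rules"},    {"user", "q1"}, {"assistant", "a1"},
      {"user", "q2"},         {"assistant", "a2"}, {"user", "q3"}};
  // With max_history_chat = 1, only indices > 6 - 2 = 4 pass the window,
  // so the output is the system message plus the final user message.
  std::cout << FormatPrompt(msgs, 1);
}
```

One consequence of the strict `>` worth double-checking in review: with `n` messages the condition admits indices `n - history_max + 1` through `n - 1`, i.e. `history_max - 1` user/assistant messages, so `max_history_chat = 2` keeps the last three rather than the last four. System messages bypass the window entirely, as does the unknown-role fallback branch.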