Skip to content

Commit

Permalink
feat: add max_history_chat parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
sangjanai committed Jun 20, 2024
1 parent c5f73d0 commit 9a3d8d8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/onnx_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ void OnnxEngine::LoadModel(
system_prompt_ =
json_body->get("system_prompt", "ASSISTANT's RULE: ").asString();
pre_prompt_ = json_body->get("pre_prompt", "").asString();
max_history_chat_ = json_body->get("max_history_chat", 2).asInt();
try {
std::cout << "Creating model..." << std::endl;
oga_model_ = OgaModel::Create(path.c_str());
Expand Down Expand Up @@ -181,17 +182,24 @@ void OnnxEngine::HandleChatCompletion(
auto is_stream = json_body->get("stream", false).asBool();

std::string formatted_output = pre_prompt_;

int history_max = max_history_chat_ * 2; // both user and assistant
int index = 0;
for (const auto& message : req.messages) {
std::string input_role = message["role"].asString();
std::string role;
if (input_role == "user") {
role = user_prompt_;
std::string content = message["content"].asString();
formatted_output += role + content;
if (index > static_cast<int>(req.messages.size()) - history_max) {
formatted_output += role + content;
}
} else if (input_role == "assistant") {
role = ai_prompt_;
std::string content = message["content"].asString();
formatted_output += role + content;
if (index > static_cast<int>(req.messages.size()) - history_max) {
formatted_output += role + content;
}
} else if (input_role == "system") {
role = system_prompt_;
std::string content = message["content"].asString();
Expand All @@ -200,7 +208,9 @@ void OnnxEngine::HandleChatCompletion(
role = input_role;
std::string content = message["content"].asString();
formatted_output += role + content;
LOG_WARN << "Should specify input_role";
}
index++;
}
formatted_output += ai_prompt_;

Expand Down Expand Up @@ -305,7 +315,8 @@ void OnnxEngine::HandleChatCompletion(
std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
.count();
LOG_DEBUG << "Generated tokens per second: "
<< static_cast<double>(output_sequence_length) / duration_ms * 1000;
<< static_cast<double>(output_sequence_length) / duration_ms *
1000;

std::string to_send = out_string.p_;
auto resp_data = CreateFullReturnJson(GenerateRandomString(20), "_",
Expand Down
1 change: 1 addition & 0 deletions src/onnx_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class OnnxEngine : public EngineI {
std::string pre_prompt_;
std::string model_id_;
uint64_t start_time_;
int max_history_chat_;
std::unique_ptr<trantor::ConcurrentTaskQueue> q_;
};
} // namespace cortex_onnx

0 comments on commit 9a3d8d8

Please sign in to comment.