feat: use async queue for chat completion #8

Merged · 4 commits · Jun 14, 2024
2 changes: 1 addition & 1 deletion build_cortex_onnx.bat
@@ -1,5 +1,5 @@
cmake -S ./third-party -B ./build_deps/third-party
cmake --build ./build_deps/third-party --config Release -j4

cmake -S .\onnxruntime-genai\ -B .\onnxruntime-genai\build -DUSE_DML=ON -DUSE_CUDA=OFF -DORT_HOME="./build_deps/ort" -DENABLE_PYTHON=OFF
cmake -S .\onnxruntime-genai\ -B .\onnxruntime-genai\build -DUSE_DML=ON -DUSE_CUDA=OFF -DORT_HOME="./build_deps/ort" -DENABLE_PYTHON=OFF -DENABLE_TESTS=OFF -DENABLE_MODEL_BENCHMARK=OFF
cmake --build .\onnxruntime-genai\build --config Release -j4
181 changes: 101 additions & 80 deletions src/onnx_engine.cc
@@ -153,6 +153,9 @@ void OnnxEngine::LoadModel(
model_loaded_ = true;
start_time_ = std::chrono::system_clock::now().time_since_epoch() /
std::chrono::milliseconds(1);
if (q_ == nullptr) {
q_ = std::make_unique<trantor::ConcurrentTaskQueue>(1, model_id_);
}
} catch (const std::exception& e) {
std::cout << e.what() << std::endl;
oga_model_.reset();
@@ -202,97 +205,115 @@ void OnnxEngine::HandleChatCompletion(
formatted_output += ai_prompt_;

// LOG_DEBUG << formatted_output;

try {
if (req.stream) {
auto sequences = OgaSequences::Create();
tokenizer_->Encode(formatted_output.c_str(), *sequences);

auto params = OgaGeneratorParams::Create(*oga_model_);
// TODO(sang)
params->SetSearchOption("max_length", req.max_tokens);
params->SetSearchOption("top_p", req.top_p);
params->SetSearchOption("temperature", req.temperature);
// params->SetSearchOption("repetition_penalty", 0.95);
params->SetInputSequences(*sequences);

auto generator = OgaGenerator::Create(*oga_model_, *params);
while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();

const int32_t num_tokens = generator->GetSequenceCount(0);
int32_t new_token = generator->GetSequenceData(0)[num_tokens - 1];
auto out_string = tokenizer_stream_->Decode(new_token);
std::cout << out_string;
// TODO(sang)
q_->runTaskInQueue([this, cb = std::move(callback), formatted_output, req] {
try {
if (req.stream) {

auto sequences = OgaSequences::Create();
tokenizer_->Encode(formatted_output.c_str(), *sequences);

auto params = OgaGeneratorParams::Create(*oga_model_);
// TODO(sang)
params->SetSearchOption("max_length", req.max_tokens);
params->SetSearchOption("top_p", req.top_p);
params->SetSearchOption("temperature", req.temperature);
// params->SetSearchOption("repetition_penalty", 0.95);
params->SetInputSequences(*sequences);

auto generator = OgaGenerator::Create(*oga_model_, *params);
while (!generator->IsDone() && model_loaded_) {
generator->ComputeLogits();
generator->GenerateNextToken();

const int32_t num_tokens = generator->GetSequenceCount(0);
int32_t new_token = generator->GetSequenceData(0)[num_tokens - 1];
auto out_string = tokenizer_stream_->Decode(new_token);
// std::cout << out_string;
const std::string str =
"data: " +
CreateReturnJson(GenerateRandomString(20), "_", out_string) +
"\n\n";
Json::Value resp_data;
resp_data["data"] = str;
Json::Value status;
status["is_done"] = false;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = k200OK;
cb(std::move(status), std::move(resp_data));
}

if (!model_loaded_) {
LOG_WARN << "Model unloaded during inference";
Json::Value respData;
respData["data"] = std::string();
Json::Value status;
status["is_done"] = false;
status["has_error"] = true;
status["is_stream"] = true;
status["status_code"] = k200OK;
cb(std::move(status), std::move(respData));
return;
}

LOG_INFO << "End of result";
Json::Value resp_data;
const std::string str =
"data: " +
CreateReturnJson(GenerateRandomString(20), "_", out_string) +
"\n\n";
Json::Value resp_data;
CreateReturnJson(GenerateRandomString(20), "_", "", "stop") +
"\n\n" + "data: [DONE]" + "\n\n";
resp_data["data"] = str;
Json::Value status;
status["is_done"] = false;
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = k200OK;
callback(std::move(status), std::move(resp_data));
cb(std::move(status), std::move(resp_data));

} else {
auto sequences = OgaSequences::Create();
tokenizer_->Encode(formatted_output.c_str(), *sequences);

auto params = OgaGeneratorParams::Create(*oga_model_);
params->SetSearchOption("max_length", req.max_tokens);
params->SetSearchOption("top_p", req.top_p);
params->SetSearchOption("temperature", req.temperature);
// params->SetSearchOption("repetition_penalty", req.frequency_penalty);
params->SetInputSequences(*sequences);

auto output_sequences = oga_model_->Generate(*params);
const auto output_sequence_length =
output_sequences->SequenceCount(0) - sequences->SequenceCount(0);
const auto* output_sequence_data =
output_sequences->SequenceData(0) + sequences->SequenceCount(0);
auto out_string =
tokenizer_->Decode(output_sequence_data, output_sequence_length);

// std::cout << "Output: " << std::endl << out_string << std::endl;

std::string to_send = out_string.p_;
auto resp_data = CreateFullReturnJson(GenerateRandomString(20), "_",
to_send, "_", 0, 0);
Json::Value status;
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = false;
status["status_code"] = k200OK;
cb(std::move(status), std::move(resp_data));
}

LOG_INFO << "End of result";
Json::Value resp_data;
const std::string str =
"data: " + CreateReturnJson("gdsf", "_", "", "stop") + "\n\n" +
"data: [DONE]" + "\n\n";
resp_data["data"] = str;
Json::Value status;
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = k200OK;
callback(std::move(status), std::move(resp_data));
} else {
auto sequences = OgaSequences::Create();
tokenizer_->Encode(formatted_output.c_str(), *sequences);

auto params = OgaGeneratorParams::Create(*oga_model_);
params->SetSearchOption("max_length", req.max_tokens);
params->SetSearchOption("top_p", req.top_p);
params->SetSearchOption("temperature", req.temperature);
// params->SetSearchOption("repetition_penalty", req.frequency_penalty);
params->SetInputSequences(*sequences);

auto output_sequences = oga_model_->Generate(*params);
const auto output_sequence_length =
output_sequences->SequenceCount(0) - sequences->SequenceCount(0);
const auto* output_sequence_data =
output_sequences->SequenceData(0) + sequences->SequenceCount(0);
auto out_string =
tokenizer_->Decode(output_sequence_data, output_sequence_length);

// std::cout << "Output: " << std::endl << out_string << std::endl;

std::string to_send = out_string.p_;
auto resp_data = CreateFullReturnJson(GenerateRandomString(20), "_",
to_send, "_", 0, 0);
} catch (const std::exception& e) {
std::cout << e.what() << std::endl;
Json::Value json_resp;
json_resp["message"] = "Error during inference";
Json::Value status;
status["is_done"] = true;
status["has_error"] = false;
status["is_done"] = false;
status["has_error"] = true;
status["is_stream"] = false;
status["status_code"] = k200OK;
callback(std::move(status), std::move(resp_data));
status["status_code"] = k500InternalServerError;
cb(std::move(status), std::move(json_resp));
}
} catch (const std::exception& e) {
std::cout << e.what() << std::endl;
Json::Value json_resp;
json_resp["message"] = "Error during inference";
Json::Value status;
status["is_done"] = false;
status["has_error"] = true;
status["is_stream"] = false;
status["status_code"] = k500InternalServerError;
callback(std::move(status), std::move(json_resp));
}
});
}

void OnnxEngine::HandleEmbedding(
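The core of this change: HandleChatCompletion no longer runs inference on the caller's thread. It submits the work to q_, the single-threaded trantor::ConcurrentTaskQueue created in LoadModel, moving the callback into the task and checking model_loaded_ inside the generation loop so an unload can stop an in-flight stream. The sketch below illustrates only the pattern; Engine, Callback, and the prompt/token strings are placeholders, not the engine's real types.

#include <atomic>
#include <chrono>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <trantor/utils/ConcurrentTaskQueue.h>

// Stand-in callback type: (status, data) as plain strings instead of Json::Value.
using Callback = std::function<void(std::string status, std::string data)>;

struct Engine {
  // One worker thread named after the model, mirroring the queue created in
  // LoadModel(): requests are serialized rather than racing on generator state.
  std::unique_ptr<trantor::ConcurrentTaskQueue> q_ =
      std::make_unique<trantor::ConcurrentTaskQueue>(1, "model-id");
  std::atomic<bool> model_loaded_{true};

  void HandleChatCompletion(const std::string& prompt, Callback callback) {
    // Move the callback into the task and capture the prompt by value so both
    // outlive this function, which returns immediately.
    q_->runTaskInQueue([this, cb = std::move(callback), prompt] {
      for (int token = 0; token < 3 && model_loaded_; ++token) {
        cb("streaming", "token " + std::to_string(token) + " for: " + prompt);
      }
      cb(model_loaded_ ? "done" : "aborted", "");
    });
  }
};

int main() {
  Engine engine;
  engine.HandleChatCompletion("hello", [](std::string status, std::string data) {
    std::cout << status << " " << data << "\n";
  });
  std::cout << "handler returned before inference finished\n";
  // Let the worker thread drain the queue before the Engine is destroyed.
  std::this_thread::sleep_for(std::chrono::milliseconds(200));
}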
2 changes: 2 additions & 0 deletions src/onnx_engine.h
@@ -6,6 +6,7 @@
#include "json/value.h"
#include "ort_genai.h"
#include "ort_genai_c.h"
#include "trantor/utils/ConcurrentTaskQueue.h"

namespace cortex_onnx {
class OnnxEngine : public EngineI {
@@ -48,5 +49,6 @@ class OnnxEngine : public EngineI {
std::string pre_prompt_;
std::string model_id_;
uint64_t start_time_;
std::unique_ptr<trantor::ConcurrentTaskQueue> q_;
};
} // namespace cortex_onnx
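onnx_engine.h now owns the queue member and pulls in trantor/utils/ConcurrentTaskQueue.h. As a minimal, self-contained illustration (assuming trantor is available; the queue name and timings here are invented), a 1-thread ConcurrentTaskQueue runs tasks strictly in submission order, which is what serializes overlapping completion requests:

#include <chrono>
#include <iostream>
#include <thread>
#include <trantor/utils/ConcurrentTaskQueue.h>

int main() {
  // 1 worker thread; the second argument names it (the engine passes model_id_).
  trantor::ConcurrentTaskQueue q(1, "onnx-engine");

  for (int request = 1; request <= 3; ++request) {
    q.runTaskInQueue([request] {
      // Each "request" finishes before the next one starts: FIFO, one at a time.
      std::cout << "request " << request << " start\n";
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
      std::cout << "request " << request << " end\n";
    });
  }

  // Let the queue drain before it goes out of scope.
  std::this_thread::sleep_for(std::chrono::milliseconds(300));
}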