diff --git a/build_cortex_onnx.bat b/build_cortex_onnx.bat
index bd321d4..1af5516 100644
--- a/build_cortex_onnx.bat
+++ b/build_cortex_onnx.bat
@@ -1,5 +1,5 @@
 cmake -S ./third-party -B ./build_deps/third-party
 cmake --build ./build_deps/third-party --config Release -j4
-cmake -S .\onnxruntime-genai\ -B .\onnxruntime-genai\build -DUSE_DML=ON -DUSE_CUDA=OFF -DORT_HOME="./build_deps/ort" -DENABLE_PYTHON=OFF
+cmake -S .\onnxruntime-genai\ -B .\onnxruntime-genai\build -DUSE_DML=ON -DUSE_CUDA=OFF -DORT_HOME="./build_deps/ort" -DENABLE_PYTHON=OFF -DENABLE_TESTS=OFF -DENABLE_MODEL_BENCHMARK=OFF
 cmake --build .\onnxruntime-genai\build --config Release -j4
 
diff --git a/src/onnx_engine.cc b/src/onnx_engine.cc
index c26eead..57fd28f 100644
--- a/src/onnx_engine.cc
+++ b/src/onnx_engine.cc
@@ -153,6 +153,9 @@ void OnnxEngine::LoadModel(
     model_loaded_ = true;
     start_time_ = std::chrono::system_clock::now().time_since_epoch() /
                   std::chrono::milliseconds(1);
+    if (q_ == nullptr) {
+      q_ = std::make_unique<trantor::ConcurrentTaskQueue>(1, model_id_);
+    }
   } catch (const std::exception& e) {
     std::cout << e.what() << std::endl;
     oga_model_.reset();
@@ -202,97 +205,115 @@ void OnnxEngine::HandleChatCompletion(
   formatted_output += ai_prompt_;
 
   // LOG_DEBUG << formatted_output;
-
-  try {
-    if (req.stream) {
-      auto sequences = OgaSequences::Create();
-      tokenizer_->Encode(formatted_output.c_str(), *sequences);
-
-      auto params = OgaGeneratorParams::Create(*oga_model_);
-      // TODO(sang)
-      params->SetSearchOption("max_length", req.max_tokens);
-      params->SetSearchOption("top_p", req.top_p);
-      params->SetSearchOption("temperature", req.temperature);
-      // params->SetSearchOption("repetition_penalty", 0.95);
-      params->SetInputSequences(*sequences);
-
-      auto generator = OgaGenerator::Create(*oga_model_, *params);
-      while (!generator->IsDone()) {
-        generator->ComputeLogits();
-        generator->GenerateNextToken();
-
-        const int32_t num_tokens = generator->GetSequenceCount(0);
-        int32_t new_token = generator->GetSequenceData(0)[num_tokens - 1];
-        auto out_string = tokenizer_stream_->Decode(new_token);
-        std::cout << out_string;
+  // TODO(sang)
+  q_->runTaskInQueue([this, cb = std::move(callback), formatted_output, req] {
+    try {
+      if (req.stream) {
+
+        auto sequences = OgaSequences::Create();
+        tokenizer_->Encode(formatted_output.c_str(), *sequences);
+
+        auto params = OgaGeneratorParams::Create(*oga_model_);
+        // TODO(sang)
+        params->SetSearchOption("max_length", req.max_tokens);
+        params->SetSearchOption("top_p", req.top_p);
+        params->SetSearchOption("temperature", req.temperature);
+        // params->SetSearchOption("repetition_penalty", 0.95);
+        params->SetInputSequences(*sequences);
+
+        auto generator = OgaGenerator::Create(*oga_model_, *params);
+        while (!generator->IsDone() && model_loaded_) {
+          generator->ComputeLogits();
+          generator->GenerateNextToken();
+
+          const int32_t num_tokens = generator->GetSequenceCount(0);
+          int32_t new_token = generator->GetSequenceData(0)[num_tokens - 1];
+          auto out_string = tokenizer_stream_->Decode(new_token);
+          // std::cout << out_string;
+          const std::string str =
+              "data: " +
+              CreateReturnJson(GenerateRandomString(20), "_", out_string) +
+              "\n\n";
+          Json::Value resp_data;
+          resp_data["data"] = str;
+          Json::Value status;
+          status["is_done"] = false;
+          status["has_error"] = false;
+          status["is_stream"] = true;
+          status["status_code"] = k200OK;
+          cb(std::move(status), std::move(resp_data));
+        }
+
+        if (!model_loaded_) {
+          LOG_WARN << "Model unloaded during inference";
+          Json::Value respData;
+          respData["data"] = std::string();
+          Json::Value status;
+          status["is_done"] = false;
+          status["has_error"] = true;
+          status["is_stream"] = true;
+          status["status_code"] = k200OK;
+          cb(std::move(status), std::move(respData));
+          return;
+        }
+
+        LOG_INFO << "End of result";
+        Json::Value resp_data;
         const std::string str =
             "data: " +
-            CreateReturnJson(GenerateRandomString(20), "_", out_string) +
-            "\n\n";
-        Json::Value resp_data;
+            CreateReturnJson(GenerateRandomString(20), "_", "", "stop") +
+            "\n\n" + "data: [DONE]" + "\n\n";
         resp_data["data"] = str;
         Json::Value status;
-        status["is_done"] = false;
+        status["is_done"] = true;
         status["has_error"] = false;
         status["is_stream"] = true;
         status["status_code"] = k200OK;
-        callback(std::move(status), std::move(resp_data));
+        cb(std::move(status), std::move(resp_data));
+
+      } else {
+        auto sequences = OgaSequences::Create();
+        tokenizer_->Encode(formatted_output.c_str(), *sequences);
+
+        auto params = OgaGeneratorParams::Create(*oga_model_);
+        params->SetSearchOption("max_length", req.max_tokens);
+        params->SetSearchOption("top_p", req.top_p);
+        params->SetSearchOption("temperature", req.temperature);
+        // params->SetSearchOption("repetition_penalty", req.frequency_penalty);
+        params->SetInputSequences(*sequences);
+
+        auto output_sequences = oga_model_->Generate(*params);
+        const auto output_sequence_length =
+            output_sequences->SequenceCount(0) - sequences->SequenceCount(0);
+        const auto* output_sequence_data =
+            output_sequences->SequenceData(0) + sequences->SequenceCount(0);
+        auto out_string =
+            tokenizer_->Decode(output_sequence_data, output_sequence_length);
+
+        // std::cout << "Output: " << std::endl << out_string << std::endl;
+
+        std::string to_send = out_string.p_;
+        auto resp_data = CreateFullReturnJson(GenerateRandomString(20), "_",
+                                              to_send, "_", 0, 0);
+        Json::Value status;
+        status["is_done"] = true;
+        status["has_error"] = false;
+        status["is_stream"] = false;
+        status["status_code"] = k200OK;
+        cb(std::move(status), std::move(resp_data));
       }
-
-      LOG_INFO << "End of result";
-      Json::Value resp_data;
-      const std::string str =
-          "data: " + CreateReturnJson("gdsf", "_", "", "stop") + "\n\n" +
-          "data: [DONE]" + "\n\n";
-      resp_data["data"] = str;
-      Json::Value status;
-      status["is_done"] = true;
-      status["has_error"] = false;
-      status["is_stream"] = true;
-      status["status_code"] = k200OK;
-      callback(std::move(status), std::move(resp_data));
-    } else {
-      auto sequences = OgaSequences::Create();
-      tokenizer_->Encode(formatted_output.c_str(), *sequences);
-
-      auto params = OgaGeneratorParams::Create(*oga_model_);
-      params->SetSearchOption("max_length", req.max_tokens);
-      params->SetSearchOption("top_p", req.top_p);
-      params->SetSearchOption("temperature", req.temperature);
-      // params->SetSearchOption("repetition_penalty", req.frequency_penalty);
-      params->SetInputSequences(*sequences);
-
-      auto output_sequences = oga_model_->Generate(*params);
-      const auto output_sequence_length =
-          output_sequences->SequenceCount(0) - sequences->SequenceCount(0);
-      const auto* output_sequence_data =
-          output_sequences->SequenceData(0) + sequences->SequenceCount(0);
-      auto out_string =
-          tokenizer_->Decode(output_sequence_data, output_sequence_length);
-
-      // std::cout << "Output: " << std::endl << out_string << std::endl;
-
-      std::string to_send = out_string.p_;
-      auto resp_data = CreateFullReturnJson(GenerateRandomString(20), "_",
-                                            to_send, "_", 0, 0);
+    } catch (const std::exception& e) {
+      std::cout << e.what() << std::endl;
+      Json::Value json_resp;
+      json_resp["message"] = "Error during inference";
       Json::Value status;
-      status["is_done"] = true;
-      status["has_error"] = false;
+      status["is_done"] = false;
+      status["has_error"] = true;
       status["is_stream"] = false;
-      status["status_code"] = k200OK;
-      callback(std::move(status), std::move(resp_data));
+      status["status_code"] = k500InternalServerError;
+      cb(std::move(status), std::move(json_resp));
     }
-  } catch (const std::exception& e) {
-    std::cout << e.what() << std::endl;
-    Json::Value json_resp;
-    json_resp["message"] = "Error during inference";
-    Json::Value status;
-    status["is_done"] = false;
-    status["has_error"] = true;
-    status["is_stream"] = false;
-    status["status_code"] = k500InternalServerError;
-    callback(std::move(status), std::move(json_resp));
-  }
+  });
 }
 
 void OnnxEngine::HandleEmbedding(
diff --git a/src/onnx_engine.h b/src/onnx_engine.h
index e52dd07..d54f113 100644
--- a/src/onnx_engine.h
+++ b/src/onnx_engine.h
@@ -6,6 +6,7 @@
 #include "json/value.h"
 #include "ort_genai.h"
 #include "ort_genai_c.h"
+#include "trantor/utils/ConcurrentTaskQueue.h"
 
 namespace cortex_onnx {
 class OnnxEngine : public EngineI {
@@ -48,5 +49,6 @@ class OnnxEngine : public EngineI {
   std::string pre_prompt_;
   std::string model_id_;
   uint64_t start_time_;
+  std::unique_ptr<trantor::ConcurrentTaskQueue> q_;
 };
 }  // namespace cortex_onnx
\ No newline at end of file
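Note on the threading change: `LoadModel` now lazily creates a single-worker `trantor::ConcurrentTaskQueue` (one thread, named after `model_id_`), and `HandleChatCompletion` moves the whole generation loop onto that queue via `runTaskInQueue`, so the handler returns immediately, tokens flow out through the moved-in callback, and the `model_loaded_` check lets an unload interrupt a generation in progress. Below is a minimal, self-contained sketch of that pattern; `StreamingWorker`, its fake token loop, and the two-argument callback are illustrative stand-ins, not the engine's real API or part of this diff.

```cpp
#include <atomic>
#include <chrono>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <thread>

#include "trantor/utils/ConcurrentTaskQueue.h"

class StreamingWorker {
 public:
  // One worker thread means queued jobs run strictly in submission order,
  // so two requests can never interleave output on the same model.
  StreamingWorker()
      : q_(std::make_unique<trantor::ConcurrentTaskQueue>(1, "worker")) {}

  // Returns immediately; the loop runs on the queue's thread and reports
  // chunks through the moved-in callback, like HandleChatCompletion above.
  void Submit(std::function<void(const std::string&, bool)> cb) {
    q_->runTaskInQueue([this, cb = std::move(cb)] {
      for (int i = 0; i < 5 && running_.load(); ++i) {
        cb("token-" + std::to_string(i), /*done=*/false);
      }
      cb("", /*done=*/true);  // end-of-stream, mirrors the "data: [DONE]" event
    });
  }

  // Flipping the flag mid-generation makes the loop bail out early,
  // the role the model_loaded_ check plays in the diff.
  void Stop() { running_ = false; }

 private:
  std::atomic<bool> running_{true};
  std::unique_ptr<trantor::ConcurrentTaskQueue> q_;
};

int main() {
  StreamingWorker worker;
  worker.Submit([](const std::string& chunk, bool done) {
    std::cout << (done ? std::string("[DONE]") : chunk) << '\n';
  });
  // Demo only: give the worker a moment to drain before destruction.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  return 0;
}
```

One consequence of capturing the callback with `cb = std::move(callback)` is that the caller's `callback` is consumed and every response, including errors, must now be emitted from inside the queued task, which is why the catch block also reports through `cb`.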