From c229a52e30556c161676c59748ac3d44fac46bc9 Mon Sep 17 00:00:00 2001
From: tikikun
Date: Wed, 15 Nov 2023 16:31:12 +0700
Subject: [PATCH 1/2] feat: make non-stream completion fully compatible with the OpenAI API

---
 controllers/llamaCPP.cc | 152 ++++++++++++++++++++++++++++------------
 1 file changed, 109 insertions(+), 43 deletions(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 63710b466..b7a6fc3f2 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -28,6 +29,45 @@ std::shared_ptr createState(int task_id, llamaCPP *instance) {
 
 // --------------------------------------------
 
+std::string create_full_return_json(const std::string &id,
+                                    const std::string &model,
+                                    const std::string &content,
+                                    const std::string &system_fingerprint,
+                                    int prompt_tokens, int completion_tokens,
+                                    Json::Value finish_reason = Json::Value()) {
+
+  Json::Value root;
+
+  root["id"] = id;
+  root["model"] = model;
+  root["created"] = static_cast(std::time(nullptr));
+  root["object"] = "chat.completion";
+  root["system_fingerprint"] = system_fingerprint;
+
+  Json::Value choicesArray(Json::arrayValue);
+  Json::Value choice;
+
+  choice["index"] = 0;
+  Json::Value message;
+  message["role"] = "assistant";
+  message["content"] = content;
+  choice["message"] = message;
+  choice["finish_reason"] = finish_reason;
+
+  choicesArray.append(choice);
+  root["choices"] = choicesArray;
+
+  Json::Value usage;
+  usage["prompt_tokens"] = prompt_tokens;
+  usage["completion_tokens"] = completion_tokens;
+  usage["total_tokens"] = prompt_tokens + completion_tokens;
+  root["usage"] = usage;
+
+  Json::StreamWriterBuilder writer;
+  writer["indentation"] = ""; // Compact output
+  return Json::writeString(writer, root);
+}
+
 std::string create_return_json(const std::string &id, const std::string &model,
                                const std::string &content,
                                Json::Value finish_reason = Json::Value()) {
@@ -82,9 +122,9 @@ void llamaCPP::chatCompletion(
   json data;
   json stopWords;
   // To set default value
-  data["stream"] = true;
 
   if (jsonBody) {
+    data["stream"] = (*jsonBody).get("stream", false).asBool();
     data["n_predict"] = (*jsonBody).get("max_tokens", 500).asInt();
     data["top_p"] = (*jsonBody).get("top_p", 0.95).asFloat();
     data["temperature"] = (*jsonBody).get("temperature", 0.8).asFloat();
@@ -119,62 +159,87 @@ void llamaCPP::chatCompletion(
     data["stop"] = stopWords;
   }
 
+  bool is_streamed = data["stream"];
+
   const int task_id = llama.request_completion(data, false, false);
   LOG_INFO << "Resolved request for task_id:" << task_id;
-  auto state = createState(task_id, this);
+  if (is_streamed) {
+    auto state = createState(task_id, this);
 
-  auto chunked_content_provider =
-      [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-    if (!pBuffer) {
-      LOG_INFO << "Connection closed or buffer is null. Reset context";
Reset context"; - state->instance->llama.request_cancel(state->task_id); - return 0; - } - if (state->isStopped) { - return 0; - } - - task_result result = state->instance->llama.next_result(state->task_id); - if (!result.error) { - const std::string to_send = result.result_json["content"]; - const std::string str = - "data: " + - create_return_json(nitro_utils::generate_random_string(20), "_", - to_send) + - "\n\n"; - - std::size_t nRead = std::min(str.size(), nBuffSize); - memcpy(pBuffer, str.data(), nRead); + auto chunked_content_provider = + [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t { + if (!pBuffer) { + LOG_INFO << "Connection closed or buffer is null. Reset context"; + state->instance->llama.request_cancel(state->task_id); + return 0; + } + if (state->isStopped) { + return 0; + } - if (result.stop) { + task_result result = state->instance->llama.next_result(state->task_id); + if (!result.error) { + const std::string to_send = result.result_json["content"]; const std::string str = "data: " + - create_return_json(nitro_utils::generate_random_string(20), "_", "", - "stop") + - "\n\n" + "data: [DONE]" + "\n\n"; + create_return_json(nitro_utils::generate_random_string(20), "_", + to_send) + + "\n\n"; - LOG_VERBOSE("data stream", {{"to_send", str}}); std::size_t nRead = std::min(str.size(), nBuffSize); memcpy(pBuffer, str.data(), nRead); - LOG_INFO << "reached result stop"; - state->isStopped = true; - state->instance->llama.request_cancel(state->task_id); + + if (result.stop) { + const std::string str = + "data: " + + create_return_json(nitro_utils::generate_random_string(20), "_", + "", "stop") + + "\n\n" + "data: [DONE]" + "\n\n"; + + LOG_VERBOSE("data stream", {{"to_send", str}}); + std::size_t nRead = std::min(str.size(), nBuffSize); + memcpy(pBuffer, str.data(), nRead); + LOG_INFO << "reached result stop"; + state->isStopped = true; + state->instance->llama.request_cancel(state->task_id); + return nRead; + } return nRead; + } else { + return 0; } - return nRead; - } else { return 0; - } - return 0; - }; - auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider, - "chat_completions.txt"); - callback(resp); + }; + auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider, + "chat_completions.txt"); + callback(resp); - return; + return; + } else { + Json::Value respData; + auto resp = nitro_utils::nitroHttpResponse(); + respData["testing"] = "thunghiem value moi"; + if (!json_value(data, "stream", false)) { + std::string completion_text; + task_result result = llama.next_result(task_id); + if (!result.error && result.stop) { + int prompt_tokens = result.result_json["tokens_evaluated"]; + int predicted_tokens = result.result_json["tokens_predicted"]; + std::string full_return = + create_full_return_json(nitro_utils::generate_random_string(20), + "_", result.result_json["content"], "_", + prompt_tokens, predicted_tokens); + resp->setBody(full_return); + } else { + resp->setBody("internal error during inference"); + return; + } + callback(resp); + return; + } + } } - void llamaCPP::embedding( const HttpRequestPtr &req, std::function &&callback) { @@ -262,7 +327,8 @@ void llamaCPP::loadModel( this->pre_prompt = (*jsonBody) .get("pre_prompt", - "A chat between a curious user and an artificial intelligence " + "A chat between a curious user and an artificial " + "intelligence " "assistant. 
               "what.\\n")
          .asString();

From 2be0d28ab497e3bbf88ed81d01710d28d72af8c3 Mon Sep 17 00:00:00 2001
From: tikikun
Date: Wed, 15 Nov 2023 16:34:16 +0700
Subject: [PATCH 2/2] feat: clean up redundant includes

---
 controllers/llamaCPP.cc | 2 --
 1 file changed, 2 deletions(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index b7a6fc3f2..73e55bd01 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -5,11 +5,9 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
-#include
 
 using namespace inferences;
 using json = nlohmann::json;
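
With both patches applied, a chat completion request that leaves "stream" unset (or sets it to false) now gets back a single OpenAI-style chat.completion object built by create_full_return_json instead of an SSE stream. The standalone sketch below is not part of the patches; it only illustrates the shape of that non-streamed payload, assuming jsoncpp (which the controller already uses for Json::Value), with placeholder values for the id, model, content, and token counts.

// Illustrative only: mirrors the field layout of create_full_return_json.
// All values are placeholders; the controller fills them from the inference
// result (generated content, tokens_evaluated, tokens_predicted).
#include <ctime>
#include <iostream>
#include <string>
#include <json/json.h> // jsoncpp; may be <jsoncpp/json/json.h> on some installs

int main() {
  Json::Value root;
  root["id"] = "chatcmpl-placeholder"; // the controller uses a random 20-char id
  root["model"] = "_";                 // the controller currently passes "_"
  root["created"] = static_cast<Json::Int64>(std::time(nullptr));
  root["object"] = "chat.completion";
  root["system_fingerprint"] = "_";

  Json::Value message;
  message["role"] = "assistant";
  message["content"] = "Hello!";       // generated text goes here

  Json::Value choice;
  choice["index"] = 0;
  choice["message"] = message;
  choice["finish_reason"] = Json::Value(); // null, as in the patch

  Json::Value choices(Json::arrayValue);
  choices.append(choice);
  root["choices"] = choices;

  Json::Value usage;
  usage["prompt_tokens"] = 10;         // tokens_evaluated in the patch
  usage["completion_tokens"] = 5;      // tokens_predicted in the patch
  usage["total_tokens"] = 15;
  root["usage"] = usage;

  Json::StreamWriterBuilder writer;
  writer["indentation"] = ""; // compact, single-line output
  std::cout << Json::writeString(writer, root) << std::endl;
  return 0;
}

Serialized with empty indentation this prints one compact JSON line, structurally the same body that the non-streamed branch hands to resp->setBody(full_return).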