diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 5e9c1a9b4..c03e11949 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -1,7 +1,5 @@
 #include "llamaCPP.h"
 
-#include <trantor/utils/SerialTaskQueue.h>
-
 #include "llama.h"
 #include "log.h"
 #include "utils/nitro_utils.h"
@@ -9,12 +7,6 @@
 using namespace inferences;
 using json = nlohmann::json;
 
-/**
- * Queue to handle the inference task, this is to ensure that the inference
- * task is handled in a sequential manner
- */
-static trantor::SerialTaskQueue queue("worker");
-
 /**
  * The state of the inference task
  */
@@ -32,7 +24,6 @@ enum InferenceStatus {
  * associated with.
  */
 struct inferenceState {
-  bool is_stopped = false;
   int task_id;
   InferenceStatus inferenceStatus = PENDING;
   llamaCPP *instance;
@@ -150,7 +141,7 @@ std::string create_return_json(const std::string &id, const std::string &model,
   return Json::writeString(writer, root);
 }
 
-llamaCPP::llamaCPP() {
+llamaCPP::llamaCPP(): queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP")) {
   // Some default values for now below
   log_disable(); // Disable the log to file feature, reduce bloat for
                  // target
@@ -341,18 +332,17 @@ void llamaCPP::inferenceImpl(
 
     if(state->inferenceStatus == PENDING) {
       state->inferenceStatus = RUNNING;
+    } else if (state->inferenceStatus == FINISHED) {
+      return 0;
     }
 
     if (!pBuffer) {
      LOG_INFO << "Connection closed or buffer is null. Reset context";
      state->instance->llama.request_cancel(state->task_id);
-      state->instance->single_queue_is_busy = false;
-      return 0;
-    }
-    if (state->is_stopped) {
-      state->instance->single_queue_is_busy = false;
+      state->inferenceStatus = FINISHED;
       return 0;
     }
+
     task_result result = state->instance->llama.next_result(state->task_id);
 
     if (!result.error) {
@@ -377,31 +367,27 @@ void llamaCPP::inferenceImpl(
       std::size_t nRead = std::min(str.size(), nBuffSize);
       memcpy(pBuffer, str.data(), nRead);
       LOG_INFO << "reached result stop";
-      state->is_stopped = true;
       state->instance->llama.request_cancel(state->task_id);
-      state->instance->single_queue_is_busy = false;
+      state->inferenceStatus = FINISHED;
     }
 
     // Make sure nBufferSize is not zero
     // Otherwise it stop streaming
     if(!nRead) {
-      state->instance->single_queue_is_busy = false;
+      state->inferenceStatus = FINISHED;
    }
    return nRead;
  }
-  state->instance->single_queue_is_busy = false;
+  state->inferenceStatus = FINISHED;
  return 0;
 };
-
-  // Run task in serial queue
-  queue.runTaskInQueue([callback, state, data,
+  // Queued task
+  state->instance->queue->runTaskInQueue([callback, state, data,
                         chunked_content_provider]() {
    state->task_id = state->instance->llama.request_completion(data, false, false, -1);
-    state->instance->single_queue_is_busy = true;
-
    // Start streaming response
    auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider, "chat_completions.txt");
@@ -410,7 +396,7 @@ void llamaCPP::inferenceImpl(
    int retries = 0;
    // Since this is an async task, we will wait for the task to be completed
-    while (state->instance->single_queue_is_busy && retries < 10) {
+    while (state->inferenceStatus != FINISHED && retries < 10) {
      // Should wait chunked_content_provider lambda to be called within 3s
      if(state->inferenceStatus == PENDING) {
        retries += 1;
      }
@@ -418,8 +404,6 @@ void llamaCPP::inferenceImpl(
      LOG_INFO << "Wait for task to be released:" << state->task_id;
      std::this_thread::sleep_for(std::chrono::milliseconds(300));
    }
-
-    state->inferenceStatus = FINISHED;
  });
  return;
 } else {
@@ -466,59 +450,51 @@ void llamaCPP::embeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
-  Json::Value responseData(Json::arrayValue);
+  // Queue embedding task
   auto state = create_inference_state(this);
-  if (jsonBody->isMember("input")) {
-    // If single queue is busy, we will wait if not we will just go ahead and
-    // process and make it busy, and yet i'm aware not DRY, i have the same
-    // stuff on chatcompletion as well
-    if (state->instance->llama.params.n_parallel == 1) {
-      while (state->instance->single_queue_is_busy) {
-        LOG_INFO << "Waiting for task to be released status:"
-                 << state->instance->single_queue_is_busy;
-        std::this_thread::sleep_for(
-            std::chrono::milliseconds(500)); // Waiting in 500 miliseconds step
-      }
-    }
-    const Json::Value &input = (*jsonBody)["input"];
-    if (input.isString()) {
-      // Process the single string input
-      state->task_id = llama.request_completion(
-          {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
-      state->instance->single_queue_is_busy = true;
-      task_result result = llama.next_result(state->task_id);
-      std::vector<float> embedding_result = result.result_json["embedding"];
-      responseData.append(create_embedding_payload(embedding_result, 0));
-    } else if (input.isArray()) {
-      // Process each element in the array input
-      for (const auto &elem : input) {
-        if (elem.isString()) {
-          const int task_id = llama.request_completion(
-              {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1);
-          task_result result = llama.next_result(task_id);
-          std::vector<float> embedding_result = result.result_json["embedding"];
-          responseData.append(create_embedding_payload(embedding_result, 0));
+
+  state->instance->queue->runTaskInQueue([this, state, jsonBody, callback]() {
+    Json::Value responseData(Json::arrayValue);
+
+    if (jsonBody->isMember("input")) {
+      const Json::Value &input = (*jsonBody)["input"];
+      if (input.isString()) {
+        // Process the single string input
+        state->task_id = llama.request_completion(
+            {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
+        task_result result = llama.next_result(state->task_id);
+        std::vector<float> embedding_result = result.result_json["embedding"];
+        responseData.append(create_embedding_payload(embedding_result, 0));
+      } else if (input.isArray()) {
+        // Process each element in the array input
+        for (const auto &elem : input) {
+          if (elem.isString()) {
+            const int task_id = llama.request_completion(
+                {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true,
+                -1);
+            task_result result = llama.next_result(task_id);
+            std::vector<float> embedding_result =
+                result.result_json["embedding"];
+            responseData.append(create_embedding_payload(embedding_result, 0));
+          }
         }
       }
     }
-  }
-
-  // We already got result of the embedding so no longer busy
-  state->instance->single_queue_is_busy = false;
-  auto resp = nitro_utils::nitroHttpResponse();
-  Json::Value root;
-  root["data"] = responseData;
-  root["model"] = "_";
-  root["object"] = "list";
-  Json::Value usage;
-  usage["prompt_tokens"] = 0;
-  usage["total_tokens"] = 0;
-  root["usage"] = usage;
-
-  resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
-  resp->setContentTypeString("application/json");
-  callback(resp);
+    auto resp = nitro_utils::nitroHttpResponse();
+    Json::Value root;
+    root["data"] = responseData;
+    root["model"] = "_";
+    root["object"] = "list";
+    Json::Value usage;
+    usage["prompt_tokens"] = 0;
+    usage["total_tokens"] = 0;
+    root["usage"] = usage;
+
+    resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
+    resp->setContentTypeString("application/json");
+    callback(resp);
+  });
 }
 
 void llamaCPP::unloadModel(
@@ -539,6 +515,7 @@ void llamaCPP::unloadModel(
   callback(resp);
   return;
 }
+
 void llamaCPP::modelStatus(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -555,6 +532,7 @@ void llamaCPP::modelStatus(
   callback(resp);
   return;
 }
+
 void llamaCPP::loadModel(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -674,6 +652,12 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
   }
   llama.initialize();
 
+  if (queue != nullptr) {
+    delete queue;
+  }
+
+  queue = new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP");
+
   llama.model_loaded_external = true;
 
   LOG_INFO << "Started background task here!";
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index 5f2be54b3..ad1889be0 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -26,6 +26,7 @@
 
 #include "common/base.h"
 #include "utils/json.hpp"
+#include <trantor/utils/ConcurrentTaskQueue.h>
 
 // auto generated files (update with ./deps.sh)
 
@@ -2562,10 +2563,13 @@ class llamaCPP : public drogon::HttpController<llamaCPP>, public ChatProvider {
   bool caching_enabled;
   std::atomic<int> no_of_chats = 0;
   int clean_cache_threshold;
-  std::atomic<bool> single_queue_is_busy; // This value only used under the
-                                          // condition n_parallel is 1
   std::string grammar_file_content;
 
+  /**
+   * Queue to handle the inference tasks
+   */
+  trantor::ConcurrentTaskQueue *queue;
+
   bool loadModelImpl(std::shared_ptr<Json::Value> jsonBody);
   void inferenceImpl(std::shared_ptr<Json::Value> jsonBody,
                      std::function<void(const HttpResponsePtr &)> &callback);
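
Below the patch, a minimal standalone sketch of the scheduling pattern the diff switches to: a trantor::ConcurrentTaskQueue sized from llama.params.n_parallel, with completion tracked per request through an inferenceState/InferenceStatus pair instead of the removed single_queue_is_busy flag. The constructor and runTaskInQueue() calls mirror the ones used in the diff; the <trantor/utils/ConcurrentTaskQueue.h> header path, the std::atomic wrapper, and the main() driver are assumptions made only to keep this sketch self-contained.

// Sketch only: n_parallel worker threads drain the queue; each request keeps
// its own status, so one finished task never resets state shared with others.
#include <trantor/utils/ConcurrentTaskQueue.h>  // assumed header path

#include <atomic>
#include <chrono>
#include <cstdio>
#include <memory>
#include <thread>

enum InferenceStatus { PENDING, RUNNING, FINISHED };

struct inferenceState {
  int task_id = -1;
  std::atomic<InferenceStatus> status{PENDING};  // atomic added for the sketch
};

int main() {
  const int n_parallel = 2;  // stand-in for llama.params.n_parallel
  trantor::ConcurrentTaskQueue queue(n_parallel, "llamaCPP");

  auto state = std::make_shared<inferenceState>();

  // Enqueue work, as inferenceImpl and embeddingImpl now do.
  queue.runTaskInQueue([state]() {
    state->status = RUNNING;
    state->task_id = 42;  // would come from llama.request_completion(...)
    std::this_thread::sleep_for(std::chrono::milliseconds(100));  // fake work
    state->status = FINISHED;
  });

  // Poll the per-request status, mirroring the retry loop in inferenceImpl.
  int retries = 0;
  while (state->status != FINISHED && retries < 10) {
    std::this_thread::sleep_for(std::chrono::milliseconds(300));
    ++retries;
  }
  std::printf("task %d status: %s\n", state->task_id,
              state->status == FINISHED ? "FINISHED" : "PENDING/RUNNING");
  return 0;
}

Tracking status per request rather than through one global boolean is what allows up to n_parallel completions and embeddings to run through the queue concurrently without one connection's teardown clearing another's busy state.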