Merge pull request #433 from janhq/refactor/simplify-state-with-queue-system

refactor: simplify state with queued system
tikikun authored Feb 14, 2024
2 parents 9c1d8b6 + 5a3432f commit fb7bc74
Showing 2 changed files with 65 additions and 77 deletions.
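
For orientation: the change drops the file-scope trantor::SerialTaskQueue and the single_queue_is_busy flag, and instead gives each llamaCPP instance a trantor::ConcurrentTaskQueue sized by llama.params.n_parallel, with every request tracked through its own InferenceStatus. A schematic of that lifecycle (an illustrative sketch, not code taken from the diff):

  // One inferenceState per request; no shared "busy" flag anymore.
  enum InferenceStatus { PENDING, RUNNING, FINISHED };
  // PENDING  -> task has been queued, the stream callback has not run yet
  // RUNNING  -> the stream callback is emitting chunks
  // FINISHED -> streaming ended, was cancelled, or the client disconnected
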
134 changes: 59 additions & 75 deletions controllers/llamaCPP.cc
@@ -1,20 +1,12 @@
#include "llamaCPP.h"

#include <trantor/utils/SerialTaskQueue.h>

#include "llama.h"
#include "log.h"
#include "utils/nitro_utils.h"

using namespace inferences;
using json = nlohmann::json;

/**
* Queue to handle the inference task, this is to ensure that the inference
* task is handled in a sequential manner
*/
static trantor::SerialTaskQueue queue("worker");

/**
* The state of the inference task
*/
@@ -32,7 +24,6 @@ enum InferenceStatus {
* associated with.
*/
struct inferenceState {
bool is_stopped = false;
int task_id;
InferenceStatus inferenceStatus = PENDING;
llamaCPP *instance;
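
create_inference_state(this) is called further down in this diff, but its body sits outside the shown hunks. Consistent with the struct above, a minimal reconstruction might look like this (hypothetical; the actual helper lives elsewhere in llamaCPP.cc):

  #include <memory>

  // Each request gets its own shared state, handed to the queued task and the
  // chunked stream callback; the shared_ptr keeps it alive until both are done.
  std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
    auto state = std::make_shared<inferenceState>();
    state->instance = instance;
    return state;
  }
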
@@ -150,7 +141,7 @@ std::string create_return_json(const std::string &id, const std::string &model,
return Json::writeString(writer, root);
}

llamaCPP::llamaCPP() {
llamaCPP::llamaCPP(): queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP")) {
// Some default values for now below
log_disable(); // Disable the log to file feature, reduce bloat for
// target
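
The constructor now owns a trantor::ConcurrentTaskQueue sized by llama.params.n_parallel. A minimal standalone example of how that queue type is used (assuming the trantor API as of this commit; not repository code):

  #include <trantor/utils/ConcurrentTaskQueue.h>
  #include <chrono>
  #include <iostream>
  #include <thread>

  int main() {
    // Four worker threads; tasks are spread across them, unlike the old
    // SerialTaskQueue, which ran everything on a single thread in order.
    trantor::ConcurrentTaskQueue queue(4, "llamaCPP-demo");
    for (int i = 0; i < 8; ++i) {
      queue.runTaskInQueue([i]() {
        std::cout << "task " << i << " running on a worker thread\n";
      });
    }
    // Toy example: give the workers a moment to drain the queue before exit.
    std::this_thread::sleep_for(std::chrono::seconds(1));
    return 0;
  }
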
@@ -341,18 +332,17 @@ void llamaCPP::inferenceImpl(

if(state->inferenceStatus == PENDING) {
state->inferenceStatus = RUNNING;
} else if (state->inferenceStatus == FINISHED) {
return 0;
}

if (!pBuffer) {
LOG_INFO << "Connection closed or buffer is null. Reset context";
state->instance->llama.request_cancel(state->task_id);
state->instance->single_queue_is_busy = false;
return 0;
}
if (state->is_stopped) {
state->instance->single_queue_is_busy = false;
state->inferenceStatus = FINISHED;
return 0;
}


task_result result = state->instance->llama.next_result(state->task_id);
if (!result.error) {
@@ -377,31 +367,27 @@
std::size_t nRead = std::min(str.size(), nBuffSize);
memcpy(pBuffer, str.data(), nRead);
LOG_INFO << "reached result stop";
state->is_stopped = true;
state->instance->llama.request_cancel(state->task_id);
state->instance->single_queue_is_busy = false;
state->inferenceStatus = FINISHED;
}

// Make sure nBufferSize is not zero
// Otherwise it stop streaming
if(!nRead) {
state->instance->single_queue_is_busy = false;
state->inferenceStatus = FINISHED;
}

return nRead;
}
state->instance->single_queue_is_busy = false;
state->inferenceStatus = FINISHED;
return 0;
};

// Run task in serial queue
queue.runTaskInQueue([callback, state, data,
// Queued task
state->instance->queue->runTaskInQueue([callback, state, data,
chunked_content_provider]() {
state->task_id =
state->instance->llama.request_completion(data, false, false, -1);

state->instance->single_queue_is_busy = true;

// Start streaming response
auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
"chat_completions.txt");
@@ -410,16 +396,14 @@
int retries = 0;

// Since this is an async task, we will wait for the task to be completed
while (state->instance->single_queue_is_busy && retries < 10) {
while (state->inferenceStatus != FINISHED && retries < 10) {
// Should wait chunked_content_provider lambda to be called within 3s
if(state->inferenceStatus == PENDING) {
retries += 1;
}
LOG_INFO << "Wait for task to be released:" << state->task_id;
std::this_thread::sleep_for(std::chrono::milliseconds(300));
}

state->inferenceStatus = FINISHED;
});
return;
} else {
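
Inside the queued task, the old spin on single_queue_is_busy becomes a poll on the per-request status, and only iterations spent in PENDING count against the retry budget. A stripped-down sketch of that waiting pattern (illustrative; the real code reads the status from inferenceState rather than a local atomic):

  #include <atomic>
  #include <chrono>
  #include <thread>

  enum Status { PENDING, RUNNING, FINISHED };

  // Block until the stream callback reports FINISHED, but give up if it never
  // starts: with a 300 ms sleep and 10 PENDING retries this is roughly a 3 s cap.
  void waitForCompletion(const std::atomic<Status> &status) {
    int retries = 0;
    while (status.load() != FINISHED && retries < 10) {
      if (status.load() == PENDING) {
        ++retries;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(300));
    }
  }
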
@@ -466,59 +450,51 @@ void llamaCPP::embeddingImpl(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr &)> &callback) {

Json::Value responseData(Json::arrayValue);
// Queue embedding task
auto state = create_inference_state(this);
if (jsonBody->isMember("input")) {
// If single queue is busy, we will wait if not we will just go ahead and
// process and make it busy, and yet i'm aware not DRY, i have the same
// stuff on chatcompletion as well
if (state->instance->llama.params.n_parallel == 1) {
while (state->instance->single_queue_is_busy) {
LOG_INFO << "Waiting for task to be released status:"
<< state->instance->single_queue_is_busy;
std::this_thread::sleep_for(
std::chrono::milliseconds(500)); // Waiting in 500 miliseconds step
}
}
const Json::Value &input = (*jsonBody)["input"];
if (input.isString()) {
// Process the single string input
state->task_id = llama.request_completion(
{{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
state->instance->single_queue_is_busy = true;
task_result result = llama.next_result(state->task_id);
std::vector<float> embedding_result = result.result_json["embedding"];
responseData.append(create_embedding_payload(embedding_result, 0));
} else if (input.isArray()) {
// Process each element in the array input
for (const auto &elem : input) {
if (elem.isString()) {
const int task_id = llama.request_completion(
{{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1);
task_result result = llama.next_result(task_id);
std::vector<float> embedding_result = result.result_json["embedding"];
responseData.append(create_embedding_payload(embedding_result, 0));

state->instance->queue->runTaskInQueue([this, state, jsonBody, callback]() {
Json::Value responseData(Json::arrayValue);

if (jsonBody->isMember("input")) {
const Json::Value &input = (*jsonBody)["input"];
if (input.isString()) {
// Process the single string input
state->task_id = llama.request_completion(
{{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
task_result result = llama.next_result(state->task_id);
std::vector<float> embedding_result = result.result_json["embedding"];
responseData.append(create_embedding_payload(embedding_result, 0));
} else if (input.isArray()) {
// Process each element in the array input
for (const auto &elem : input) {
if (elem.isString()) {
const int task_id = llama.request_completion(
{{"prompt", elem.asString()}, {"n_predict", 0}}, false, true,
-1);
task_result result = llama.next_result(task_id);
std::vector<float> embedding_result =
result.result_json["embedding"];
responseData.append(create_embedding_payload(embedding_result, 0));
}
}
}
}
}

// We already got result of the embedding so no longer busy
state->instance->single_queue_is_busy = false;

auto resp = nitro_utils::nitroHttpResponse();
Json::Value root;
root["data"] = responseData;
root["model"] = "_";
root["object"] = "list";
Json::Value usage;
usage["prompt_tokens"] = 0;
usage["total_tokens"] = 0;
root["usage"] = usage;

resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
resp->setContentTypeString("application/json");
callback(resp);
auto resp = nitro_utils::nitroHttpResponse();
Json::Value root;
root["data"] = responseData;
root["model"] = "_";
root["object"] = "list";
Json::Value usage;
usage["prompt_tokens"] = 0;
usage["total_tokens"] = 0;
root["usage"] = usage;

resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
resp->setContentTypeString("application/json");
callback(resp);
});
}
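
embeddingImpl now computes its result inside the queued task and invokes the Drogon callback from that worker thread rather than from the request handler. A minimal sketch of the deferred-response pattern (assumed Drogon/trantor usage, not repository code):

  #include <drogon/HttpResponse.h>
  #include <trantor/utils/ConcurrentTaskQueue.h>
  #include <json/json.h>
  #include <functional>

  void respondFromQueue(trantor::ConcurrentTaskQueue &queue,
                        std::function<void(const drogon::HttpResponsePtr &)> callback) {
    // The callback is captured by value and may be invoked later from a
    // different thread, once the (potentially slow) work has produced a result.
    queue.runTaskInQueue([callback]() {
      Json::Value root;
      root["object"] = "list";         // same shape as the OpenAI-style reply above
      root["data"] = Json::arrayValue; // filled with embedding payloads in the real code
      callback(drogon::HttpResponse::newHttpJsonResponse(root));
    });
  }
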

void llamaCPP::unloadModel(
@@ -539,6 +515,7 @@ void llamaCPP::unloadModel(
callback(resp);
return;
}

void llamaCPP::modelStatus(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -555,6 +532,7 @@ void llamaCPP::modelStatus(
callback(resp);
return;
}

void llamaCPP::loadModel(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -674,6 +652,12 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
}
llama.initialize();

if (queue != nullptr) {
delete queue;
}

queue = new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP");

llama.model_loaded_external = true;

LOG_INFO << "Started background task here!";
8 changes: 6 additions & 2 deletions controllers/llamaCPP.h
@@ -26,6 +26,7 @@

#include "common/base.h"
#include "utils/json.hpp"
#include <trantor/utils/ConcurrentTaskQueue.h>

// auto generated files (update with ./deps.sh)

@@ -2562,10 +2563,13 @@ class llamaCPP : public drogon::HttpController<llamaCPP>, public ChatProvider {
bool caching_enabled;
std::atomic<int> no_of_chats = 0;
int clean_cache_threshold;
std::atomic<bool> single_queue_is_busy; // This value only used under the
// condition n_parallel is 1
std::string grammar_file_content;

/**
* Queue to handle the inference tasks
*/
trantor::ConcurrentTaskQueue *queue;

bool loadModelImpl(std::shared_ptr<Json::Value> jsonBody);
void inferenceImpl(std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr &)> &callback);
