Commit
Merge pull request #419 from janhq/417-feat-refactor-some-parts-of-the-code

417 feat refactor some parts of the code
tikikun authored Feb 5, 2024
2 parents 8f6d281 + 96deb0e commit a305d60
Showing 2 changed files with 112 additions and 87 deletions.
170 changes: 97 additions & 73 deletions controllers/llamaCPP.cc
@@ -6,6 +6,13 @@
using namespace inferences;
using json = nlohmann::json;

/**
* Saves the state of an ongoing inference for a given handler; this struct
* exists so that the inference status can be tracked while it is in flight.
*
* @param inst Pointer to the llamaCPP instance this inference task is
* associated with.
*/
struct inferenceState {
bool is_stopped = false;
bool is_streaming = false;
@@ -15,15 +22,20 @@ struct inferenceState {
inferenceState(llamaCPP *inst) : instance(inst) {}
};

/**
* Creates a shared pointer to an inferenceState so that the state persists
* even if the streaming lambda goes out of scope and the handler has
* already moved on.
*/
std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
return std::make_shared<inferenceState>(instance);
}
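
// Illustration only (not part of this commit): inside a llamaCPP member, the
// streaming chunk callback captures the shared_ptr by value, so the
// inferenceState outlives the handler that created it. The names
// chunk_provider and max_size are made up for this sketch.
auto state = create_inference_state(this);
auto chunk_provider = [state](char *buffer, std::size_t max_size) -> std::size_t {
  if (state->is_stopped) {
    return 0; // end of stream; the last copy of `state` is released afterwards
  }
  // ... copy the next generated tokens into `buffer`, up to max_size ...
  return 0;
};
// Handing chunk_provider to drogon's stream response keeps `state` alive for
// as long as the client keeps reading, even after this handler returns.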

// --------------------------------------------

// Function to check if the model is loaded
void check_model_loaded(
llama_server_context &llama, const HttpRequestPtr &req,
/**
* Checks whether a model is already loaded; if not, returns an error
* message to the user.
* @param callback the function used to return the message to the user
*/
void llamaCPP::checkModelLoaded(
std::function<void(const HttpResponsePtr &)> &callback) {
if (!llama.model_loaded_external) {
Json::Value jsonResp;
@@ -136,7 +148,7 @@ void llamaCPP::warmupModel() {
return;
}

void llamaCPP::chatCompletionPrelight(
void llamaCPP::handlePrelight(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
auto resp = drogon::HttpResponse::newHttpResponse();
@@ -151,10 +163,17 @@ void llamaCPP::chatCompletion(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {

const auto &jsonBody = req->getJsonObject();
// Check if model is loaded
check_model_loaded(llama, req, callback);
checkModelLoaded(callback);

chatCompletionImpl(jsonBody, callback);
}

void llamaCPP::chatCompletionImpl(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr &)> &callback) {

const auto &jsonBody = req->getJsonObject();
std::string formatted_output = pre_prompt;

json data;
Expand Down Expand Up @@ -402,17 +421,23 @@ void llamaCPP::chatCompletion(
}
}
}

void llamaCPP::embedding(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
check_model_loaded(llama, req, callback);
checkModelLoaded(callback);
const auto &jsonBody = req->getJsonObject();

auto state = create_inference_state(this);
embeddingImpl(jsonBody, callback);
return;
}

const auto &jsonBody = req->getJsonObject();
void llamaCPP::embeddingImpl(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr &)> &callback) {

Json::Value responseData(Json::arrayValue);

auto state = create_inference_state(this);
if (jsonBody->isMember("input")) {
// If single queue is busy, we will wait if not we will just go ahead and
// process and make it busy, and yet i'm aware not DRY, i have the same
Expand Down Expand Up @@ -464,7 +489,6 @@ void llamaCPP::embedding(
resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
resp->setContentTypeString("application/json");
callback(resp);
return;
}

void llamaCPP::unloadModel(
Expand Down Expand Up @@ -501,31 +525,61 @@ void llamaCPP::modelStatus(
callback(resp);
return;
}
void llamaCPP::loadModel(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {

bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
if (llama.model_loaded_external) {
LOG_INFO << "model loaded";
Json::Value jsonResp;
jsonResp["message"] = "Model already loaded";
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
resp->setStatusCode(drogon::k409Conflict);
callback(resp);
return;
}

gpt_params params;
const auto &jsonBody = req->getJsonObject();
if (!loadModelImpl(jsonBody)) {
// Error occurred during model loading
Json::Value jsonResp;
jsonResp["message"] = "Failed to load model";
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
resp->setStatusCode(drogon::k500InternalServerError);
callback(resp);
} else {
// Model loaded successfully
Json::Value jsonResp;
jsonResp["message"] = "Model loaded successfully";
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
callback(resp);
}
}

bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {

gpt_params params;
// By default, settings are based on the number of handlers
if (jsonBody) {
if (!jsonBody["mmproj"].isNull()) {
if (!jsonBody->operator[]("mmproj").isNull()) {
LOG_INFO << "MMPROJ FILE detected, multi-model enabled!";
params.mmproj = jsonBody["mmproj"].asString();
params.mmproj = jsonBody->operator[]("mmproj").asString();
}
if (!jsonBody["grp_attn_n"].isNull()) {
if (!jsonBody->operator[]("grp_attn_n").isNull()) {

params.grp_attn_n = jsonBody["grp_attn_n"].asInt();
params.grp_attn_n = jsonBody->operator[]("grp_attn_n").asInt();
}
if (!jsonBody["grp_attn_w"].isNull()) {
if (!jsonBody->operator[]("grp_attn_w").isNull()) {

params.grp_attn_w = jsonBody["grp_attn_w"].asInt();
params.grp_attn_w = jsonBody->operator[]("grp_attn_w").asInt();
}
if (!jsonBody["mlock"].isNull()) {
params.use_mlock = jsonBody["mlock"].asBool();
if (!jsonBody->operator[]("mlock").isNull()) {
params.use_mlock = jsonBody->operator[]("mlock").asBool();
}

if (!jsonBody["grammar_file"].isNull()) {
std::string grammar_file = jsonBody["grammar_file"].asString();
if (!jsonBody->operator[]("grammar_file").isNull()) {
std::string grammar_file =
jsonBody->operator[]("grammar_file").asString();
std::ifstream file(grammar_file);
if (!file) {
LOG_ERROR << "Grammar file not found";
@@ -536,30 +590,31 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
}
};

params.model = jsonBody["llama_model_path"].asString();
params.n_gpu_layers = jsonBody.get("ngl", 100).asInt();
params.n_ctx = jsonBody.get("ctx_len", 2048).asInt();
params.embedding = jsonBody.get("embedding", true).asBool();
params.model = jsonBody->operator[]("llama_model_path").asString();
params.n_gpu_layers = jsonBody->get("ngl", 100).asInt();
params.n_ctx = jsonBody->get("ctx_len", 2048).asInt();
params.embedding = jsonBody->get("embedding", true).asBool();
// Check if n_parallel exists in jsonBody, if not, set to drogon_thread
params.n_batch = jsonBody.get("n_batch", 512).asInt();
params.n_parallel = jsonBody.get("n_parallel", 1).asInt();
params.n_batch = jsonBody->get("n_batch", 512).asInt();
params.n_parallel = jsonBody->get("n_parallel", 1).asInt();
params.n_threads =
jsonBody.get("cpu_threads", std::thread::hardware_concurrency())
jsonBody->get("cpu_threads", std::thread::hardware_concurrency())
.asInt();
params.cont_batching = jsonBody.get("cont_batching", false).asBool();
params.cont_batching = jsonBody->get("cont_batching", false).asBool();
this->clean_cache_threshold =
jsonBody.get("clean_cache_threshold", 5).asInt();
this->caching_enabled = jsonBody.get("caching_enabled", false).asBool();
this->user_prompt = jsonBody.get("user_prompt", "USER: ").asString();
this->ai_prompt = jsonBody.get("ai_prompt", "ASSISTANT: ").asString();
jsonBody->get("clean_cache_threshold", 5).asInt();
this->caching_enabled = jsonBody->get("caching_enabled", false).asBool();
this->user_prompt = jsonBody->get("user_prompt", "USER: ").asString();
this->ai_prompt = jsonBody->get("ai_prompt", "ASSISTANT: ").asString();
this->system_prompt =
jsonBody.get("system_prompt", "ASSISTANT's RULE: ").asString();
this->pre_prompt = jsonBody.get("pre_prompt", "").asString();
this->repeat_last_n = jsonBody.get("repeat_last_n", 32).asInt();
jsonBody->get("system_prompt", "ASSISTANT's RULE: ").asString();
this->pre_prompt = jsonBody->get("pre_prompt", "").asString();
this->repeat_last_n = jsonBody->get("repeat_last_n", 32).asInt();

if (!jsonBody["llama_log_folder"].isNull()) {
if (!jsonBody->operator[]("llama_log_folder").isNull()) {
log_enable();
std::string llama_log_folder = jsonBody["llama_log_folder"].asString();
std::string llama_log_folder =
jsonBody->operator[]("llama_log_folder").asString();
log_set_target(llama_log_folder + "llama.log");
} // Set folder for llama log
}
Expand Down Expand Up @@ -597,37 +652,6 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
return true;
}
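
// Illustration only (not part of this commit): a request body that exercises
// the parameters parsed above. Every key shown here appears in loadModelImpl;
// the concrete values are made-up examples.
auto body = std::make_shared<Json::Value>();
(*body)["llama_model_path"] = "/path/to/model.gguf"; // example path
(*body)["ngl"] = 100;       // number of GPU layers
(*body)["ctx_len"] = 2048;  // context length
(*body)["embedding"] = true;
(*body)["n_batch"] = 512;
(*body)["n_parallel"] = 1;
(*body)["cpu_threads"] = 8; // defaults to hardware_concurrency() when omitted
(*body)["cont_batching"] = false;
(*body)["caching_enabled"] = false;
(*body)["pre_prompt"] = "";
bool loaded = loadModelImpl(body); // returns false if loading fails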

void llamaCPP::loadModel(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {

if (llama.model_loaded_external) {
LOG_INFO << "model loaded";
Json::Value jsonResp;
jsonResp["message"] = "Model already loaded";
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
resp->setStatusCode(drogon::k409Conflict);
callback(resp);
return;
}

const auto &jsonBody = req->getJsonObject();
if (!loadModelImpl(*jsonBody)) {
// Error occurred during model loading
Json::Value jsonResp;
jsonResp["message"] = "Failed to load model";
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
resp->setStatusCode(drogon::k500InternalServerError);
callback(resp);
} else {
// Model loaded successfully
Json::Value jsonResp;
jsonResp["message"] = "Model loaded successfully";
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
callback(resp);
}
}

void llamaCPP::backgroundTask() {
while (llama.model_loaded_external) {
// model_loaded =
29 changes: 15 additions & 14 deletions controllers/llamaCPP.h
@@ -2530,36 +2530,26 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {

// OpenAI-compatible path
ADD_METHOD_TO(llamaCPP::chatCompletion, "/v1/chat/completions", Post);
ADD_METHOD_TO(llamaCPP::chatCompletionPrelight, "/v1/chat/completions",
Options);
ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/chat/completions", Options);

ADD_METHOD_TO(llamaCPP::embedding, "/v1/embeddings", Post);
ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/embeddings", Options);

// PATH_ADD("/llama/chat_completion", Post);
METHOD_LIST_END
void chatCompletion(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void chatCompletionPrelight(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void handlePrelight(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void embedding(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void loadModel(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void unloadModel(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);

void modelStatus(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);

bool loadModelImpl(const Json::Value &jsonBody);

void warmupModel();

void backgroundTask();

void stopBackgroundTask();

private:
llama_server_context llama;
// std::atomic<bool> model_loaded = false;
@@ -2577,5 +2567,16 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
std::atomic<bool> single_queue_is_busy; // This value only used under the
// condition n_parallel is 1
std::string grammar_file_content;

bool loadModelImpl(std::shared_ptr<Json::Value> jsonBody);
void
chatCompletionImpl(std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr &)> &callback);
void embeddingImpl(std::shared_ptr<Json::Value> jsonBody,
std::function<void(const HttpResponsePtr &)> &callback);
void checkModelLoaded(std::function<void(const HttpResponsePtr &)> &callback);
void warmupModel();
void backgroundTask();
void stopBackgroundTask();
};
}; // namespace inferences
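
The header above wires the OpenAI-compatible routes to the controller handlers. As a rough, self-contained sketch (not part of this commit), the following drogon client exercises the /v1/embeddings route declared there; the host and port are assumptions, and only the path and the "input" field come from this diff.

#include <drogon/drogon.h>
#include <drogon/HttpClient.h>
#include <json/json.h>
#include <iostream>

int main() {
  // Assumed server address; the actual host and port may differ.
  auto client = drogon::HttpClient::newHttpClient("http://127.0.0.1:3928");

  Json::Value body;
  body["input"] = "hello world"; // embeddingImpl checks for the "input" member

  auto req = drogon::HttpRequest::newHttpJsonRequest(body);
  req->setMethod(drogon::Post);
  req->setPath("/v1/embeddings");

  client->sendRequest(req, [](drogon::ReqResult result,
                              const drogon::HttpResponsePtr &resp) {
    if (result == drogon::ReqResult::Ok && resp) {
      // Print the JSON response body as-is.
      std::cout << resp->getBody() << std::endl;
    }
    drogon::app().quit(); // stop the event loop once the reply arrives
  });

  drogon::app().run(); // keep the event loop alive for the async callback
  return 0;
}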
