Merge pull request #111 from janhq/82-feat-add-custom-user-assistant-prompt-option-as-a-server-option-for-nitro

82 feat add custom user assistant prompt option as a server option for nitro
tikikun authored Nov 9, 2023
2 parents 1f1564c + 2f141f4 commit a56f33d
Showing 2 changed files with 28 additions and 15 deletions.
40 changes: 25 additions & 15 deletions controllers/llamaCPP.cc
@@ -77,7 +77,8 @@ void llamaCPP::chatCompletion(

   const auto &jsonBody = req->getJsonObject();
   std::string formatted_output =
-      "Below is a conversation between an AI system named ASSISTANT and USER\n";
+      "Below is a conversation between an AI system named " + ai_prompt +
+      " and " + user_prompt + "\n";

   json data;
   json stopWords;
@@ -94,9 +95,19 @@

   const Json::Value &messages = (*jsonBody)["messages"];
   for (const auto &message : messages) {
-    std::string role = message["role"].asString();
+    std::string input_role = message["role"].asString();
+    std::string role;
+    if (input_role == "user") {
+      role = user_prompt;
+    } else if (input_role == "assistant") {
+      role = ai_prompt;
+    } else if (input_role == "system") {
+      role = system_prompt;
+    } else {
+      role = input_role;
+    }
     std::string content = message["content"].asString();
-    formatted_output += role + ": " + content + "\n";
+    formatted_output += role + content + "\n";
   }
   formatted_output += "assistant:";

@@ -105,8 +116,7 @@

       stopWords.push_back(stop_word.asString());
     }
     // specify default stop words
-    stopWords.push_back("user:");
-    stopWords.push_back("### USER:");
+    stopWords.push_back(user_prompt);
     data["stop"] = stopWords;
   }
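
For illustration only (not part of the commit): with the default role strings set in loadModel below ("USER: ", "ASSISTANT: ", "ASSISTANT's RULE: "), a request whose messages array is

  [{"role": "system", "content": "Be concise."},
   {"role": "user", "content": "Hello"}]

would be flattened by chatCompletion into roughly:

  Below is a conversation between an AI system named ASSISTANT:  and USER:
  ASSISTANT's RULE: Be concise.
  USER: Hello
  assistant:

The role strings already carry their own ": " suffix, which is why the concatenation changes from role + ": " + content to role + content, and why the single default stop word is now the configured user_prompt instead of the hard-coded "user:" variants.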

@@ -202,19 +212,19 @@ void llamaCPP::loadModel(
   LOG_INFO << "Drogon thread is:" << drogon_thread;
   if (jsonBody) {
     params.model = (*jsonBody)["llama_model_path"].asString();
-    params.n_gpu_layers = (*jsonBody)["ngl"].asInt();
-    params.n_ctx = (*jsonBody)["ctx_len"].asInt();
-    params.embedding = (*jsonBody)["embedding"].asBool();
+    params.n_gpu_layers = (*jsonBody).get("ngl", 100).asInt();
+    params.n_ctx = (*jsonBody).get("ctx_len", 2048).asInt();
+    params.embedding = (*jsonBody).get("embedding", true).asBool();
-    // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
-    if ((*jsonBody).isMember("n_parallel")) {
-      params.n_parallel = (*jsonBody)["n_parallel"].asInt();
-    } else {
-      params.n_parallel = drogon_thread;
-    }
+    params.n_parallel = (*jsonBody).get("n_parallel", drogon_thread).asInt();

     params.cont_batching = (*jsonBody)["cont_batching"].asBool();
     // params.n_threads = (*jsonBody)["n_threads"].asInt();
     // params.n_threads_batch = params.n_threads;

+    this->user_prompt = (*jsonBody).get("user_prompt", "USER: ").asString();
+    this->ai_prompt = (*jsonBody).get("ai_prompt", "ASSISTANT: ").asString();
+    this->system_prompt =
+        (*jsonBody).get("system_prompt", "ASSISTANT's RULE: ").asString();
   }
 #ifdef GGML_USE_CUBLAS
   LOG_INFO << "Setting up GGML CUBLAS PARAMS";
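
An illustrative sketch (not taken from the repository's documentation): after this change, a loadModel request body can override the three role strings, and each field read with .get() falls back to the default shown above when omitted. The values below are made up for the example:

  {
    "llama_model_path": "/path/to/model.bin",
    "ngl": 100,
    "ctx_len": 2048,
    "embedding": true,
    "cont_batching": false,
    "user_prompt": "### USER: ",
    "ai_prompt": "### ASSISTANT: ",
    "system_prompt": "### SYSTEM: "
  }

Only the keys and defaults come from the diff; the model path and the overridden prompt strings are assumptions for illustration.
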
3 changes: 3 additions & 0 deletions controllers/llamaCPP.h
@@ -2142,5 +2142,8 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
   size_t sent_count = 0;
   size_t sent_token_probs_index = 0;
   std::thread backgroundThread;
+  std::string user_prompt;
+  std::string ai_prompt;
+  std::string system_prompt;
 };
 }; // namespace inferences
