Merge pull request #143 from janhq/141-feat-non-stream-chat-completion
feat: make non stream completion possible to be fully compatible with…
tikikun authored Nov 15, 2023
2 parents e0cef1e + 2be0d28 commit 514fd2e
Showing 1 changed file with 108 additions and 44 deletions.
152 changes: 108 additions & 44 deletions controllers/llamaCPP.cc
@@ -8,7 +8,6 @@
 #include <regex>
 #include <string>
 #include <thread>
-#include <trantor/utils/Logger.h>
 
 using namespace inferences;
 using json = nlohmann::json;
@@ -28,6 +27,45 @@ std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {
 
 // --------------------------------------------
 
+std::string create_full_return_json(const std::string &id,
+                                    const std::string &model,
+                                    const std::string &content,
+                                    const std::string &system_fingerprint,
+                                    int prompt_tokens, int completion_tokens,
+                                    Json::Value finish_reason = Json::Value()) {
+
+  Json::Value root;
+
+  root["id"] = id;
+  root["model"] = model;
+  root["created"] = static_cast<int>(std::time(nullptr));
+  root["object"] = "chat.completion";
+  root["system_fingerprint"] = system_fingerprint;
+
+  Json::Value choicesArray(Json::arrayValue);
+  Json::Value choice;
+
+  choice["index"] = 0;
+  Json::Value message;
+  message["role"] = "assistant";
+  message["content"] = content;
+  choice["message"] = message;
+  choice["finish_reason"] = finish_reason;
+
+  choicesArray.append(choice);
+  root["choices"] = choicesArray;
+
+  Json::Value usage;
+  usage["prompt_tokens"] = prompt_tokens;
+  usage["completion_tokens"] = completion_tokens;
+  usage["total_tokens"] = prompt_tokens + completion_tokens;
+  root["usage"] = usage;
+
+  Json::StreamWriterBuilder writer;
+  writer["indentation"] = ""; // Compact output
+  return Json::writeString(writer, root);
+}
+
 std::string create_return_json(const std::string &id, const std::string &model,
                                const std::string &content,
                                Json::Value finish_reason = Json::Value()) {
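For reference, a minimal usage sketch of the new helper and the OpenAI-style body it builds. The id, content, timestamp, and token counts below are illustrative placeholders (not values from the commit), and the snippet assumes it lives in the same translation unit as create_full_return_json above.

// Illustrative only: call the helper the way the non-streaming branch does.
std::string example_body = create_full_return_json(
    /*id=*/"chatcmpl-abc123",   // the controller passes a random 20-char id here
    /*model=*/"_",
    /*content=*/"Hello there!",
    /*system_fingerprint=*/"_",
    /*prompt_tokens=*/10,
    /*completion_tokens=*/3);   // finish_reason defaults to null

// example_body is compact JSON; pretty-printed here for readability:
// {
//   "id": "chatcmpl-abc123",
//   "object": "chat.completion",
//   "created": 1700000000,
//   "model": "_",
//   "system_fingerprint": "_",
//   "choices": [
//     { "index": 0,
//       "message": { "role": "assistant", "content": "Hello there!" },
//       "finish_reason": null }
//   ],
//   "usage": { "prompt_tokens": 10, "completion_tokens": 3, "total_tokens": 13 }
// }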
@@ -82,9 +120,9 @@ void llamaCPP::chatCompletion(
   json data;
   json stopWords;
   // To set default value
-  data["stream"] = true;
 
   if (jsonBody) {
+    data["stream"] = (*jsonBody).get("stream", false).asBool();
     data["n_predict"] = (*jsonBody).get("max_tokens", 500).asInt();
     data["top_p"] = (*jsonBody).get("top_p", 0.95).asFloat();
     data["temperature"] = (*jsonBody).get("temperature", 0.8).asFloat();
@@ -119,62 +157,87 @@
     data["stop"] = stopWords;
   }
 
+  bool is_streamed = data["stream"];
+
   const int task_id = llama.request_completion(data, false, false);
   LOG_INFO << "Resolved request for task_id:" << task_id;
 
-  auto state = createState(task_id, this);
+  if (is_streamed) {
+    auto state = createState(task_id, this);
 
-  auto chunked_content_provider =
-      [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-    if (!pBuffer) {
-      LOG_INFO << "Connection closed or buffer is null. Reset context";
-      state->instance->llama.request_cancel(state->task_id);
-      return 0;
-    }
-    if (state->isStopped) {
-      return 0;
-    }
+    auto chunked_content_provider =
+        [state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+      if (!pBuffer) {
+        LOG_INFO << "Connection closed or buffer is null. Reset context";
+        state->instance->llama.request_cancel(state->task_id);
+        return 0;
+      }
+      if (state->isStopped) {
+        return 0;
+      }
 
-    task_result result = state->instance->llama.next_result(state->task_id);
-    if (!result.error) {
-      const std::string to_send = result.result_json["content"];
-      const std::string str =
-          "data: " +
-          create_return_json(nitro_utils::generate_random_string(20), "_",
-                             to_send) +
-          "\n\n";
+      task_result result = state->instance->llama.next_result(state->task_id);
+      if (!result.error) {
+        const std::string to_send = result.result_json["content"];
+        const std::string str =
+            "data: " +
+            create_return_json(nitro_utils::generate_random_string(20), "_",
+                               to_send) +
+            "\n\n";
 
-      std::size_t nRead = std::min(str.size(), nBuffSize);
-      memcpy(pBuffer, str.data(), nRead);
+        std::size_t nRead = std::min(str.size(), nBuffSize);
+        memcpy(pBuffer, str.data(), nRead);
 
-      if (result.stop) {
-        const std::string str =
-            "data: " +
-            create_return_json(nitro_utils::generate_random_string(20), "_", "",
-                               "stop") +
-            "\n\n" + "data: [DONE]" + "\n\n";
+        if (result.stop) {
+          const std::string str =
+              "data: " +
+              create_return_json(nitro_utils::generate_random_string(20), "_",
+                                 "", "stop") +
+              "\n\n" + "data: [DONE]" + "\n\n";
 
-        LOG_VERBOSE("data stream", {{"to_send", str}});
-        std::size_t nRead = std::min(str.size(), nBuffSize);
-        memcpy(pBuffer, str.data(), nRead);
-        LOG_INFO << "reached result stop";
-        state->isStopped = true;
-        state->instance->llama.request_cancel(state->task_id);
-      }
-      return nRead;
-    } else {
-      return 0;
-    }
-    return 0;
-  };
-  auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
-                                               "chat_completions.txt");
-  callback(resp);
+          LOG_VERBOSE("data stream", {{"to_send", str}});
+          std::size_t nRead = std::min(str.size(), nBuffSize);
+          memcpy(pBuffer, str.data(), nRead);
+          LOG_INFO << "reached result stop";
+          state->isStopped = true;
+          state->instance->llama.request_cancel(state->task_id);
+          return nRead;
+        }
+        return nRead;
+      } else {
+        return 0;
+      }
+      return 0;
+    };
+    auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
+                                                 "chat_completions.txt");
+    callback(resp);
 
-  return;
+    return;
+  } else {
+    Json::Value respData;
+    auto resp = nitro_utils::nitroHttpResponse();
+    respData["testing"] = "thunghiem value moi";
+    if (!json_value(data, "stream", false)) {
+      std::string completion_text;
+      task_result result = llama.next_result(task_id);
+      if (!result.error && result.stop) {
+        int prompt_tokens = result.result_json["tokens_evaluated"];
+        int predicted_tokens = result.result_json["tokens_predicted"];
+        std::string full_return =
+            create_full_return_json(nitro_utils::generate_random_string(20),
+                                    "_", result.result_json["content"], "_",
+                                    prompt_tokens, predicted_tokens);
+        resp->setBody(full_return);
+      } else {
+        resp->setBody("internal error during inference");
+        return;
+      }
+      callback(resp);
+      return;
+    }
+  }
 }
 
 void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
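A shape-only sketch of the two bodies the handler can now return, based on the branch above; the chunk ids are random per frame and the JSON content is abbreviated, so this is illustrative rather than literal server output.

// Illustrative only: the two response shapes this endpoint can now produce.
#include <string>

// stream == true: Server-Sent-Event style "data:" frames, one per chunk from
// create_return_json, closed by a [DONE] frame.
const std::string kStreamedShape =
    "data: { ...chunk json... }\n\n"
    "data: { ...final chunk with finish_reason \"stop\"... }\n\n"
    "data: [DONE]\n\n";

// stream == false: a single chat.completion object from create_full_return_json,
// or this plain string if no stop result is produced.
const std::string kNonStreamedError = "internal error during inference";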
@@ -262,7 +325,8 @@ void llamaCPP::loadModel(
   this->pre_prompt =
       (*jsonBody)
           .get("pre_prompt",
-               "A chat between a curious user and an artificial intelligence "
+               "A chat between a curious user and an artificial "
+               "intelligence "
                "assistant. The assistant follows the given rules no matter "
                "what.\\n")
           .asString();
