feat: rendering chat_template
namchuai committed Dec 19, 2024
1 parent 5414e02 commit 65f9790
Showing 15 changed files with 4,321 additions and 136 deletions.
5 changes: 2 additions & 3 deletions engine/cli/commands/chat_completion_cmd.cc
@@ -151,9 +151,8 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
json_data["model"] = model_handle;
json_data["stream"] = true;

std::string json_payload = json_data.toStyledString();

curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_payload.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDS,
                 json_data.toStyledString().c_str());

std::string ai_chat;
StreamingCallback callback;
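One caveat worth noting on the hunk above: libcurl's CURLOPT_POSTFIELDS does not copy the data it is given, so the pointer taken from the temporary string returned by json_data.toStyledString() is only guaranteed to stay valid until the end of that statement. The sketch below is not part of the commit; it reuses the curl handle and json_data from the surrounding function and shows two ways to keep the payload valid for the whole request.

// Variant A: keep the payload alive in a named string for the lifetime of the
// request (this is what the removed lines did).
std::string json_payload = json_data.toStyledString();
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_payload.c_str());

// Variant B: ask libcurl to take its own copy of the null-terminated payload.
curl_easy_setopt(curl, CURLOPT_COPYPOSTFIELDS, json_data.toStyledString().c_str());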
29 changes: 29 additions & 0 deletions engine/common/model_metadata.h
@@ -0,0 +1,29 @@
#pragma once

#include "common/tokenizer.h"
#include <sstream>

struct ModelMetadata {
  uint32_t version;
  uint64_t tensor_count;
  uint64_t metadata_kv_count;
  std::unique_ptr<Tokenizer> tokenizer;

  std::string ToString() const {
    std::ostringstream ss;
    ss << "ModelMetadata {\n"
       << "version: " << version << "\n"
       << "tensor_count: " << tensor_count << "\n"
       << "metadata_kv_count: " << metadata_kv_count << "\n"
       << "tokenizer: ";

    if (tokenizer) {
      ss << "\n" << tokenizer->ToString();
    } else {
      ss << "null";
    }

    ss << "\n}";
    return ss.str();
  }
};
68 changes: 68 additions & 0 deletions engine/common/tokenizer.h
@@ -0,0 +1,68 @@
#pragma once

#include <sstream>
#include <string>

struct Tokenizer {
  std::string eos_token = "";
  bool add_eos_token = true;

  std::string bos_token = "";
  bool add_bos_token = true;

  std::string unknown_token = "";
  std::string padding_token = "";

  std::string chat_template = "";

  // Helper function for common fields
  std::string BaseToString() const {
    std::ostringstream ss;
    ss << "eos_token: \"" << eos_token << "\"\n"
       << "add_eos_token: " << (add_eos_token ? "true" : "false") << "\n"
       << "bos_token: \"" << bos_token << "\"\n"
       << "add_bos_token: " << (add_bos_token ? "true" : "false") << "\n"
       << "unknown_token: \"" << unknown_token << "\"\n"
       << "padding_token: \"" << padding_token << "\"\n"
       << "chat_template: \"" << chat_template << "\"";
    return ss.str();
  }

  virtual ~Tokenizer() = default;

  virtual std::string ToString() = 0;
};

struct GgufTokenizer : public Tokenizer {
  std::string pre = "";

  ~GgufTokenizer() override = default;

  std::string ToString() override {
    std::ostringstream ss;
    ss << "GgufTokenizer {\n";
    // Add base class members
    ss << BaseToString() << "\n";
    // Add derived class members
    ss << "pre: \"" << pre << "\"\n";
    ss << "}";
    return ss.str();
  }
};

struct SafeTensorTokenizer : public Tokenizer {
  bool add_prefix_space = true;

  ~SafeTensorTokenizer() = default;

  std::string ToString() override {
    std::ostringstream ss;
    ss << "SafeTensorTokenizer {\n";
    // Add base class members
    ss << BaseToString() << "\n";
    // Add derived class members
    ss << "add_prefix_space: " << (add_prefix_space ? "true" : "false") << "\n";
    ss << "}";
    return ss.str();
  }
};
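As a quick illustration of how the two new headers fit together, here is a minimal standalone sketch (not part of the commit; the field values are made up) that fills in a GgufTokenizer, hands it to a ModelMetadata, and prints the nested dump produced by ToString():

#include <iostream>
#include <memory>
#include <utility>

#include "common/model_metadata.h"

int main() {
  auto tokenizer = std::make_unique<GgufTokenizer>();
  tokenizer->bos_token = "<s>";
  tokenizer->eos_token = "</s>";
  tokenizer->chat_template =
      "{% for message in messages %}{{ message.content }}{% endfor %}";
  tokenizer->pre = "default";

  ModelMetadata metadata;
  metadata.version = 3;
  metadata.tensor_count = 291;
  metadata.metadata_kv_count = 24;
  metadata.tokenizer = std::move(tokenizer);

  // Prints the ModelMetadata { ... } dump with the GgufTokenizer { ... } block
  // nested inside it.
  std::cout << metadata.ToString() << std::endl;
  return 0;
}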
36 changes: 36 additions & 0 deletions engine/controllers/engines.cc
@@ -3,6 +3,7 @@
#include "utils/archive_utils.h"
#include "utils/cortex_utils.h"
#include "utils/engine_constants.h"
#include "utils/jinja_utils.h"
#include "utils/logging_utils.h"
#include "utils/string_utils.h"

@@ -20,6 +21,41 @@ std::string NormalizeEngine(const std::string& engine) {
};
} // namespace

void Engines::TestJinja(
    const HttpRequestPtr& req,
    std::function<void(const HttpResponsePtr&)>&& callback) {
  auto body = req->getJsonObject();
  if (body == nullptr) {
    Json::Value ret;
    ret["message"] = "Body can't be empty";
    auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
    resp->setStatusCode(k400BadRequest);
    callback(resp);
    return;
  }

  auto jinja = body->get("jinja", "").asString();
  auto data = body->get("data", {});
  auto bos_token = data.get("bos_token", "").asString();
  auto eos_token = data.get("eos_token", "").asString();

  auto rendered_data = jinja::RenderTemplate(jinja, data, bos_token, eos_token);

  if (rendered_data.has_error()) {
    Json::Value ret;
    ret["message"] = rendered_data.error();
    auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
    resp->setStatusCode(k400BadRequest);
    callback(resp);
    return;
  }
  // TODO(namh): recheck all the APIs that use this; we previously hit an issue
  // with the German locale.
  auto resp = HttpResponse::newHttpResponse();
  resp->setBody(rendered_data.value());
  resp->setContentTypeCode(drogon::CT_TEXT_PLAIN);
  callback(resp);
}
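To make the new endpoint concrete, the sketch below (not part of the commit) builds the kind of body TestJinja reads: a "jinja" template string plus a "data" object. The handler forwards data's bos_token and eos_token to jinja::RenderTemplate; the sketch also assumes the rest of data serves as the template's variable context, which is how the chat-completion path in inference_service.cc uses it with "messages". The tokens and template here are illustrative only.

#include <iostream>
#include <json/value.h>

int main() {
  Json::Value message;
  message["role"] = "user";
  message["content"] = "Hello";

  Json::Value data;           // read by the handler and passed to RenderTemplate
  data["bos_token"] = "<s>";  // hypothetical tokens
  data["eos_token"] = "</s>";
  data["messages"] = Json::Value(Json::arrayValue);
  data["messages"].append(message);

  Json::Value body;
  body["jinja"] =
      "{% for m in messages %}{{ m.role }}: {{ m.content }}{% endfor %}";
  body["data"] = data;

  // body.toStyledString() is what a client would POST to /v1/jinja.
  // On success the handler returns the rendered text as text/plain;
  // on a render error it returns 400 with a {"message": ...} JSON body.
  std::cout << body.toStyledString() << std::endl;
  return 0;
}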

void Engines::ListEngine(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const {
5 changes: 5 additions & 0 deletions engine/controllers/engines.h
@@ -12,6 +12,8 @@ class Engines : public drogon::HttpController<Engines, false> {
 public:
  METHOD_LIST_BEGIN

  ADD_METHOD_TO(Engines::TestJinja, "/v1/jinja", Options, Post);

  // install engine
  METHOD_ADD(Engines::InstallEngine, "/{1}/install", Options, Post);
  ADD_METHOD_TO(Engines::InstallEngine, "/v1/engines/{1}/install", Options,
@@ -110,6 +112,9 @@ class Engines : public drogon::HttpController<Engines, false> {
                  std::function<void(const HttpResponsePtr&)>&& callback,
                  const std::string& engine) const;

  void TestJinja(const HttpRequestPtr& req,
                 std::function<void(const HttpResponsePtr&)>&& callback);

  void LoadEngine(const HttpRequestPtr& req,
                  std::function<void(const HttpResponsePtr&)>&& callback,
                  const std::string& engine);
9 changes: 8 additions & 1 deletion engine/controllers/server.cc
@@ -3,7 +3,6 @@
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"
#include "utils/function_calling/common.h"
#include "utils/http_util.h"

using namespace inferences;

@@ -27,6 +26,14 @@ void server::ChatCompletion(
    std::function<void(const HttpResponsePtr&)>&& callback) {
  LOG_DEBUG << "Start chat completion";
  auto json_body = req->getJsonObject();
  if (json_body == nullptr) {
    Json::Value ret;
    ret["message"] = "Body can't be empty";
    auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
    resp->setStatusCode(k400BadRequest);
    callback(resp);
    return;
  }
  bool is_stream = (*json_body).get("stream", false).asBool();
  auto model_id = (*json_body).get("model", "invalid_model").asString();
  auto engine_type = [this, &json_body]() -> std::string {
6 changes: 0 additions & 6 deletions engine/services/engine_service.h
@@ -4,7 +4,6 @@
#include <mutex>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

@@ -17,7 +16,6 @@
#include "utils/cpuid/cpu_info.h"
#include "utils/dylib.h"
#include "utils/dylib_path_manager.h"
#include "utils/engine_constants.h"
#include "utils/github_release_utils.h"
#include "utils/result.hpp"
#include "utils/system_info_utils.h"
@@ -48,10 +46,6 @@ class EngineService : public EngineServiceI {
  struct EngineInfo {
    std::unique_ptr<cortex_cpp::dylib> dl;
    EngineV engine;
#if defined(_WIN32)
    DLL_DIRECTORY_COOKIE cookie;
    DLL_DIRECTORY_COOKIE cuda_cookie;
#endif
  };

  std::mutex engines_mutex_;
40 changes: 39 additions & 1 deletion engine/services/inference_service.cc
@@ -1,7 +1,10 @@
#include "inference_service.h"
#include <drogon/HttpTypes.h>
#include "utils/engine_constants.h"
#include "utils/file_manager_utils.h"
#include "utils/function_calling/common.h"
#include "utils/gguf_metadata_reader.h"
#include "utils/jinja_utils.h"

namespace services {
cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
@@ -24,6 +27,41 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
    return cpp::fail(std::make_pair(stt, res));
  }

  {
    if (json_body->isMember("files") && !(*json_body)["files"].empty()) {
      auto file = (*json_body)["files"][0].asString();
      auto model_metadata_res = cortex_utils::ReadGgufMetadata(
          file_manager_utils::ToAbsoluteCortexDataPath(
              std::filesystem::path(file)));
      if (model_metadata_res.has_value()) {
        auto metadata = model_metadata_res.value().get();
        if (!metadata->tokenizer->chat_template.empty()) {
          auto messages = (*json_body)["messages"];
          Json::Value messages_jsoncpp(Json::arrayValue);
          for (auto message : messages) {
            messages_jsoncpp.append(message);
          }

          Json::Value tools(Json::arrayValue);
          Json::Value template_data_json;
          template_data_json["messages"] = messages_jsoncpp;
          // template_data_json["tools"] = tools;

          auto prompt_result = jinja::RenderTemplate(
              metadata->tokenizer->chat_template, template_data_json,
              metadata->tokenizer->bos_token, metadata->tokenizer->eos_token);
          if (prompt_result.has_value()) {
            (*json_body)["prompt"] = prompt_result.value();
          } else {
            CTL_ERR("Failed to render prompt: " + prompt_result.error());
          }
        }
      }
    }
  }

  CTL_INF("Prompt is: " + json_body->get("prompt", "").asString());

  auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
    if (!tool_choice.isNull()) {
      res["tool_choice"] = tool_choice;
@@ -297,4 +335,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr<Json::Value> json_body,
  }
  return true;
}
} // namespace services
} // namespace services
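For a sense of what ends up in (*json_body)["prompt"], here is a small sketch (not part of the commit) that calls jinja::RenderTemplate directly with a made-up chat template. It assumes the call signature and result type used in the hunk above, the header names as they appear in these hunks, and ordinary Jinja semantics for the template itself:

#include <string>
#include <json/value.h>

#include "utils/jinja_utils.h"
#include "utils/logging_utils.h"

void RenderPromptExample() {
  std::string chat_template =
      "{% for message in messages %}<|{{ message.role }}|>"
      "{{ message.content }}{% endfor %}<|assistant|>";

  Json::Value message;
  message["role"] = "user";
  message["content"] = "What is GGUF?";

  Json::Value template_data;
  template_data["messages"] = Json::Value(Json::arrayValue);
  template_data["messages"].append(message);

  auto prompt_result =
      jinja::RenderTemplate(chat_template, template_data, "<s>", "</s>");
  if (prompt_result.has_value()) {
    // Expected rendered prompt: <|user|>What is GGUF?<|assistant|>
    CTL_INF("Rendered prompt: " + prompt_result.value());
  } else {
    CTL_ERR("Failed to render prompt: " + prompt_result.error());
  }
}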
3 changes: 2 additions & 1 deletion engine/services/inference_service.h
@@ -5,8 +5,9 @@
#include <queue>
#include "services/engine_service.h"
#include "utils/result.hpp"
#include "extensions/remote-engine/remote_engine.h"

namespace services {

// Status and result
using InferResult = std::pair<Json::Value, Json::Value>;

1 change: 0 additions & 1 deletion engine/services/model_service.cc
@@ -10,7 +10,6 @@
#include "database/models.h"
#include "hardware_service.h"
#include "utils/cli_selection_utils.h"
#include "utils/cortex_utils.h"
#include "utils/engine_constants.h"
#include "utils/file_manager_utils.h"
#include "utils/huggingface_utils.h"