feat: rendering chat_template
namchuai committed Dec 20, 2024
1 parent 5414e02 commit 71f06d1
Showing 19 changed files with 4,384 additions and 164 deletions.
5 changes: 2 additions & 3 deletions engine/cli/commands/chat_completion_cmd.cc
@@ -151,9 +151,8 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
   json_data["model"] = model_handle;
   json_data["stream"] = true;

-  std::string json_payload = json_data.toStyledString();
-
-  curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_payload.c_str());
+  curl_easy_setopt(curl, CURLOPT_POSTFIELDS,
+                   json_data.toStyledString().c_str());

   std::string ai_chat;
   StreamingCallback callback;
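One caveat about the new form: libcurl does not copy the buffer handed to CURLOPT_POSTFIELDS, and the temporary returned by toStyledString() is destroyed before curl_easy_perform() reads it. CURLOPT_COPYPOSTFIELDS is the variant that takes its own copy; a minimal sketch:

// CURLOPT_COPYPOSTFIELDS copies the data immediately, so passing a
// temporary's c_str() here is safe, unlike with CURLOPT_POSTFIELDS.
curl_easy_setopt(curl, CURLOPT_COPYPOSTFIELDS,
                 json_data.toStyledString().c_str());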
29 changes: 29 additions & 0 deletions engine/common/model_metadata.h
@@ -0,0 +1,29 @@
#pragma once

#include <memory>
#include <sstream>

#include "common/tokenizer.h"

struct ModelMetadata {
  uint32_t version;
  uint64_t tensor_count;
  uint64_t metadata_kv_count;
  std::unique_ptr<Tokenizer> tokenizer;

  std::string ToString() const {
    std::ostringstream ss;
    ss << "ModelMetadata {\n"
       << "version: " << version << "\n"
       << "tensor_count: " << tensor_count << "\n"
       << "metadata_kv_count: " << metadata_kv_count << "\n"
       << "tokenizer: ";

    if (tokenizer) {
      ss << "\n" << tokenizer->ToString();
    } else {
      ss << "null";
    }

    ss << "\n}";
    return ss.str();
  }
};
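For context, a sketch of how this struct ends up being consumed. The return type of ReadGgufMetadata is an assumption here, inferred from the `model_metadata_res.value().get()` call in inference_service.cc further down:

#include <filesystem>
#include <iostream>

#include "utils/gguf_metadata_reader.h"

// Assumed usage: ReadGgufMetadata returns a result wrapping a smart pointer
// to ModelMetadata; on success, ToString() dumps the parsed header.
void DumpGgufMetadata(const std::filesystem::path& gguf_path) {
  auto res = cortex_utils::ReadGgufMetadata(gguf_path);
  if (res.has_value()) {
    std::cout << res.value()->ToString() << std::endl;
  } else {
    std::cerr << "Failed to read metadata: " << res.error() << std::endl;
  }
}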
72 changes: 72 additions & 0 deletions engine/common/tokenizer.h
@@ -0,0 +1,72 @@
#pragma once

#include <sstream>
#include <string>

struct Tokenizer {
  std::string eos_token = "";
  bool add_eos_token = true;

  std::string bos_token = "";
  bool add_bos_token = true;

  std::string unknown_token = "";
  std::string padding_token = "";

  std::string chat_template = "";

  bool add_generation_prompt = true;

  // Helper function for common fields
  std::string BaseToString() const {
    std::ostringstream ss;
    ss << "eos_token: \"" << eos_token << "\"\n"
       << "add_eos_token: " << (add_eos_token ? "true" : "false") << "\n"
       << "bos_token: \"" << bos_token << "\"\n"
       << "add_bos_token: " << (add_bos_token ? "true" : "false") << "\n"
       << "unknown_token: \"" << unknown_token << "\"\n"
       << "padding_token: \"" << padding_token << "\"\n"
       << "chat_template: \"" << chat_template << "\"\n"
       << "add_generation_prompt: "
       << (add_generation_prompt ? "true" : "false");
    return ss.str();
  }

  virtual ~Tokenizer() = default;

  virtual std::string ToString() = 0;
};

struct GgufTokenizer : public Tokenizer {
  std::string pre = "";

  ~GgufTokenizer() override = default;

  std::string ToString() override {
    std::ostringstream ss;
    ss << "GgufTokenizer {\n";
    // Add base class members
    ss << BaseToString() << "\n";
    // Add derived class members
    ss << "pre: \"" << pre << "\"\n";
    ss << "}";
    return ss.str();
  }
};

struct SafeTensorTokenizer : public Tokenizer {
  bool add_prefix_space = true;

  ~SafeTensorTokenizer() override = default;

  std::string ToString() override {
    std::ostringstream ss;
    ss << "SafeTensorTokenizer {\n";
    // Add base class members
    ss << BaseToString() << "\n";
    // Add derived class members
    ss << "add_prefix_space: " << (add_prefix_space ? "true" : "false")
       << "\n";
    ss << "}";
    return ss.str();
  }
};
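A minimal usage sketch of the new hierarchy, using only this header (token and template values are made up):

#include <iostream>

#include "common/tokenizer.h"

int main() {
  GgufTokenizer tok;
  tok.bos_token = "<s>";
  tok.eos_token = "</s>";
  tok.chat_template =
      "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}";
  // BaseToString() renders the shared fields; GgufTokenizer adds `pre`.
  std::cout << tok.ToString() << std::endl;
  return 0;
}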
34 changes: 34 additions & 0 deletions engine/controllers/engines.cc
@@ -3,6 +3,7 @@
#include "utils/archive_utils.h"
#include "utils/cortex_utils.h"
#include "utils/engine_constants.h"
#include "utils/jinja_utils.h"
#include "utils/logging_utils.h"
#include "utils/string_utils.h"

Expand All @@ -20,6 +21,39 @@ std::string NormalizeEngine(const std::string& engine) {
};
} // namespace

void Engines::TestJinja(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
auto body = req->getJsonObject();
if (body == nullptr) {
Json::Value ret;
ret["message"] = "Body can't be empty";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
return;
}

auto jinja = body->get("jinja", "").asString();
auto data = body->get("data", {});
auto bos_token = data.get("bos_token", "").asString();
auto eos_token = data.get("eos_token", "").asString();

auto rendered_data = jinja::RenderTemplate(jinja, data, bos_token, eos_token);

if (rendered_data.has_error()) {
Json::Value ret;
ret["message"] = rendered_data.error();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
return;
}

auto resp = cortex_utils::CreateTextPlainResponse(rendered_data.value());
callback(resp);
}

void Engines::ListEngine(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const {
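For reference, a sketch of a request body the new handler accepts, reconstructed from the reads above. The field names `jinja`, `data`, `data.bos_token`, and `data.eos_token` come from the code; whether the template can reference other keys of `data`, such as `messages`, depends on jinja::RenderTemplate, which is not shown in this excerpt:

#include <json/json.h>

// Build a body for POST /v1/jinja; all literal values are illustrative.
Json::Value MakeTestJinjaBody() {
  Json::Value message;
  message["role"] = "user";
  message["content"] = "Hello";

  Json::Value data;
  data["bos_token"] = "<s>";
  data["eos_token"] = "</s>";
  data["messages"] = Json::Value(Json::arrayValue);
  data["messages"].append(message);

  Json::Value body;
  body["jinja"] =
      "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}";
  body["data"] = data;
  return body;
}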
5 changes: 5 additions & 0 deletions engine/controllers/engines.h
@@ -12,6 +12,8 @@ class Engines : public drogon::HttpController<Engines, false> {
  public:
  METHOD_LIST_BEGIN

+  ADD_METHOD_TO(Engines::TestJinja, "/v1/jinja", Options, Post);
+
   // install engine
   METHOD_ADD(Engines::InstallEngine, "/{1}/install", Options, Post);
   ADD_METHOD_TO(Engines::InstallEngine, "/v1/engines/{1}/install", Options,

@@ -110,6 +112,9 @@
                      std::function<void(const HttpResponsePtr&)>&& callback,
                      const std::string& engine) const;

+  void TestJinja(const HttpRequestPtr& req,
+                 std::function<void(const HttpResponsePtr&)>&& callback);
+
   void LoadEngine(const HttpRequestPtr& req,
                   std::function<void(const HttpResponsePtr&)>&& callback,
                   const std::string& engine);
17 changes: 5 additions & 12 deletions engine/controllers/files.cc
@@ -216,10 +216,8 @@ void Files::RetrieveFileContent(
       return;
     }

-    auto [buffer, size] = std::move(res.value());
-    auto resp = HttpResponse::newHttpResponse();
-    resp->setBody(std::string(buffer.get(), size));
-    resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
+    auto resp =
+        cortex_utils::CreateCortexContentResponse(std::move(res.value()));
     callback(resp);
   } else {
     if (!msg_res->rel_path.has_value()) {

@@ -243,10 +241,8 @@
         return;
       }

-      auto [buffer, size] = std::move(content_res.value());
-      auto resp = HttpResponse::newHttpResponse();
-      resp->setBody(std::string(buffer.get(), size));
-      resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
+      auto resp = cortex_utils::CreateCortexContentResponse(
+          std::move(content_res.value()));
       callback(resp);
     }
   }

@@ -261,9 +257,6 @@
     return;
   }

-  auto [buffer, size] = std::move(res.value());
-  auto resp = HttpResponse::newHttpResponse();
-  resp->setBody(std::string(buffer.get(), size));
-  resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
+  auto resp = cortex_utils::CreateCortexContentResponse(std::move(res.value()));
   callback(resp);
 }
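CreateCortexContentResponse itself is not part of this excerpt; judging from the three blocks it replaces, it presumably looks roughly like this (the buffer type in the pair is an assumption):

#include <memory>
#include <string>
#include <utility>

#include <drogon/HttpResponse.h>

// Assumed shape of the new helper: wrap a (buffer, size) pair in an
// octet-stream response, exactly what the removed code did inline.
inline drogon::HttpResponsePtr CreateCortexContentResponse(
    std::pair<std::unique_ptr<char[]>, size_t> content) {
  auto [buffer, size] = std::move(content);
  auto resp = drogon::HttpResponse::newHttpResponse();
  resp->setBody(std::string(buffer.get(), size));
  resp->setContentTypeCode(drogon::CT_APPLICATION_OCTET_STREAM);
  return resp;
}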
9 changes: 8 additions & 1 deletion engine/controllers/server.cc
@@ -3,7 +3,6 @@
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"
#include "utils/function_calling/common.h"
#include "utils/http_util.h"

using namespace inferences;

Expand All @@ -27,6 +26,14 @@ void server::ChatCompletion(
std::function<void(const HttpResponsePtr&)>&& callback) {
LOG_DEBUG << "Start chat completion";
auto json_body = req->getJsonObject();
if (json_body == nullptr) {
Json::Value ret;
ret["message"] = "Body can't be empty";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
return;
}
bool is_stream = (*json_body).get("stream", false).asBool();
auto model_id = (*json_body).get("model", "invalid_model").asString();
auto engine_type = [this, &json_body]() -> std::string {
1 change: 1 addition & 0 deletions engine/main.cc
@@ -159,6 +159,7 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
   auto model_src_svc = std::make_shared<services::ModelSourceService>();
   auto model_service = std::make_shared<ModelService>(
       download_service, inference_svc, engine_service);
+  inference_svc->SetModelService(model_service);

   auto file_watcher_srv = std::make_shared<FileWatcherService>(
       model_dir_path.string(), model_service);
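The setter exists because ModelService is constructed with inference_svc, so the two services reference each other. The `model_service_.lock()` call in inference_service.cc below suggests the back-reference is held as a weak_ptr, roughly as sketched here (the real declaration lives in inference_service.h, which is not shown):

#include <memory>

class ModelService;

class InferenceService {
 public:
  void SetModelService(std::shared_ptr<ModelService> model_service) {
    model_service_ = model_service;
  }

 private:
  // Held weakly: ModelService already owns a shared_ptr to InferenceService,
  // so a strong pointer here would create a reference cycle.
  std::weak_ptr<ModelService> model_service_;
};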
20 changes: 8 additions & 12 deletions engine/services/engine_service.h
@@ -4,7 +4,6 @@
 #include <mutex>
 #include <optional>
 #include <string>
-#include <string_view>
 #include <unordered_map>
 #include <vector>

@@ -17,7 +16,6 @@
 #include "utils/cpuid/cpu_info.h"
 #include "utils/dylib.h"
 #include "utils/dylib_path_manager.h"
-#include "utils/engine_constants.h"
 #include "utils/github_release_utils.h"
 #include "utils/result.hpp"
 #include "utils/system_info_utils.h"

@@ -48,10 +46,6 @@ class EngineService : public EngineServiceI {
   struct EngineInfo {
     std::unique_ptr<cortex_cpp::dylib> dl;
     EngineV engine;
-#if defined(_WIN32)
-    DLL_DIRECTORY_COOKIE cookie;
-    DLL_DIRECTORY_COOKIE cuda_cookie;
-#endif
   };

   std::mutex engines_mutex_;

@@ -105,21 +99,23 @@

   cpp::result<DefaultEngineVariant, std::string> SetDefaultEngineVariant(
       const std::string& engine, const std::string& version,
-      const std::string& variant);
+      const std::string& variant) override;

   cpp::result<DefaultEngineVariant, std::string> GetDefaultEngineVariant(
-      const std::string& engine);
+      const std::string& engine) override;

   cpp::result<std::vector<EngineVariantResponse>, std::string>
-  GetInstalledEngineVariants(const std::string& engine) const;
+  GetInstalledEngineVariants(const std::string& engine) const override;

   cpp::result<EngineV, std::string> GetLoadedEngine(
       const std::string& engine_name);

   std::vector<EngineV> GetLoadedEngines();

-  cpp::result<void, std::string> LoadEngine(const std::string& engine_name);
-  cpp::result<void, std::string> UnloadEngine(const std::string& engine_name);
+  cpp::result<void, std::string> LoadEngine(
+      const std::string& engine_name) override;
+  cpp::result<void, std::string> UnloadEngine(
+      const std::string& engine_name) override;

   cpp::result<github_release_utils::GitHubRelease, std::string>
   GetLatestEngineVersion(const std::string& engine) const;

@@ -137,7 +133,7 @@

   cpp::result<cortex::db::EngineEntry, std::string> GetEngineByNameAndVariant(
       const std::string& engine_name,
-      const std::optional<std::string> variant = std::nullopt);
+      const std::optional<std::string> variant = std::nullopt) override;

   cpp::result<cortex::db::EngineEntry, std::string> UpsertEngine(
       const std::string& engine_name, const std::string& type,
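Adding `override` here is more than style: it turns a silent mismatch with the interface into a compile error. A generic illustration with simplified signatures (not the real declarations):

#include <string>

struct EngineServiceI {
  virtual ~EngineServiceI() = default;
  virtual void LoadEngine(const std::string& engine_name) = 0;
};

struct EngineService : EngineServiceI {
  // With `override`, renaming the base method or changing a parameter type
  // becomes a compile error here instead of quietly introducing a new,
  // unrelated virtual function.
  void LoadEngine(const std::string& engine_name) override;
};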
52 changes: 51 additions & 1 deletion engine/services/inference_service.cc
@@ -1,7 +1,10 @@
#include "inference_service.h"
#include <drogon/HttpTypes.h>
#include "utils/engine_constants.h"
#include "utils/file_manager_utils.h"
#include "utils/function_calling/common.h"
#include "utils/gguf_metadata_reader.h"
#include "utils/jinja_utils.h"

namespace services {
cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
Expand All @@ -24,6 +27,53 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
return cpp::fail(std::make_pair(stt, res));
}

{
// TODO: we can cache this one so we don't have to read the file every inference
auto model_id = json_body->get("model", "").asString();
if (!model_id.empty()) {
if (auto model_service = model_service_.lock()) {
auto model_config = model_service->GetDownloadedModel(model_id);
if (model_config.has_value() && !model_config->files.empty()) {
auto file = model_config->files[0];

auto model_metadata_res = cortex_utils::ReadGgufMetadata(
file_manager_utils::ToAbsoluteCortexDataPath(
std::filesystem::path(file)));
if (model_metadata_res.has_value()) {
auto metadata = model_metadata_res.value().get();
if (!metadata->tokenizer->chat_template.empty()) {
auto messages = (*json_body)["messages"];
Json::Value messages_jsoncpp(Json::arrayValue);
for (auto message : messages) {
messages_jsoncpp.append(message);
}

Json::Value tools(Json::arrayValue);
Json::Value template_data_json;
template_data_json["messages"] = messages_jsoncpp;
// template_data_json["tools"] = tools;

auto prompt_result = jinja::RenderTemplate(
metadata->tokenizer->chat_template, template_data_json,
metadata->tokenizer->bos_token,
metadata->tokenizer->eos_token,
metadata->tokenizer->add_generation_prompt);
if (prompt_result.has_value()) {
(*json_body)["prompt"] = prompt_result.value();
} else {
CTL_ERR("Failed to render prompt: " + prompt_result.error());
}
}
} else {
CTL_ERR("Failed to read metadata: " + model_metadata_res.error());
}
}
}
}
}

CTL_INF("Prompt is: " + json_body->get("prompt", "").asString());

auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
if (!tool_choice.isNull()) {
res["tool_choice"] = tool_choice;
Expand Down Expand Up @@ -297,4 +347,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr<Json::Value> json_body,
}
return true;
}
} // namespace services
} // namespace services
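The jinja_utils.h header is not shown in this excerpt. From the two call sites (four arguments in engines.cc, five here), its signature is presumably along these lines, with a defaulted last parameter reconciling both calls (assumed, not the actual declaration):

#include <string>

#include <json/json.h>
#include "utils/result.hpp"

namespace jinja {
// Assumed declaration inferred from the call sites in this commit.
cpp::result<std::string, std::string> RenderTemplate(
    const std::string& chat_template, const Json::Value& data,
    const std::string& bos_token, const std::string& eos_token,
    bool add_generation_prompt = true);
}  // namespace jinja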