feat: add chat_template
namchuai committed Dec 18, 2024
1 parent cf47200 commit 88127a7
Showing 2 changed files with 80 additions and 70 deletions.
2 changes: 2 additions & 0 deletions src/chat_completion_request.h
@@ -67,6 +67,7 @@ struct ChatCompletionRequest {
   Json::Value stop = Json::Value(Json::arrayValue);
   Json::Value messages = Json::Value(Json::arrayValue);
   std::string model_id;
+  std::string prompt;
 
   int seed = -1;
   float dynatemp_range = 0.0f;
@@ -125,6 +126,7 @@ inline ChatCompletionRequest fromJson(std::shared_ptr<Json::Value> jsonBody) {
   completion.presence_penalty =
       (*jsonBody).get("presence_penalty", 0).asFloat();
   completion.messages = (*jsonBody)["messages"];
+  completion.prompt = jsonBody->get("prompt", "").asString();
   completion.stop = (*jsonBody)["stop"];
   completion.model_id = (*jsonBody).get("model", {}).asString();

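For reference, a minimal sketch of how the new field could be exercised through fromJson (not part of the commit; the include path, model id string, and main harness are hypothetical, and fromJson is assumed visible from the header above):

#include <iostream>
#include <memory>

#include <json/json.h>                // jsoncpp
#include "chat_completion_request.h"  // header changed in this commit (path assumed)

int main() {
  auto body = std::make_shared<Json::Value>();
  (*body)["model"] = "my-model";                  // hypothetical model id
  (*body)["prompt"] = "<|user|>Hi<|assistant|>";  // caller-templated prompt
  ChatCompletionRequest req = fromJson(body);
  // "prompt" is copied verbatim; when the key is absent it defaults to "",
  // and the engine falls back to formatting `messages` (see llama_engine.cc).
  std::cout << req.prompt << std::endl;
  return 0;
}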
148 changes: 78 additions & 70 deletions src/llama_engine.cc
@@ -751,7 +751,7 @@ void LlamaEngine::HandleInferenceImpl(
     callback(std::move(status), std::move(jsonResp));
     return;
   }
-  std::string formatted_output = si.pre_prompt;
+  auto formatted_output = si.pre_prompt;
   int request_id = ++no_of_requests_;
   LOG_INFO << "Request " << request_id << ", " << "model "
            << completion.model_id << ": "
@@ -830,87 +830,95 @@ void LlamaEngine::HandleInferenceImpl(
return "";
};

for (const auto& message : messages) {
std::string input_role = message["role"].asString();
std::string role;
if (input_role == "user") {
role = si.user_prompt;
} else if (input_role == "assistant") {
role = si.ai_prompt;
} else if (input_role == "system") {
role = si.system_prompt;
} else {
role = input_role;
}
if (!completion.prompt.empty()) {
// If prompt is provided, use it as the prompt
formatted_output = completion.prompt;
} else {
for (const auto& message : messages) {
std::string input_role = message["role"].asString();
std::string role;
if (input_role == "user") {
role = si.user_prompt;
} else if (input_role == "assistant") {
role = si.ai_prompt;
} else if (input_role == "system") {
role = si.system_prompt;
} else {
role = input_role;
}

if (auto content = get_message(message["content"]); !content.empty()) {
formatted_output += role + content;
if (auto content = get_message(message["content"]); !content.empty()) {
formatted_output += role + content;
}
}
formatted_output += si.ai_prompt;
}
formatted_output += si.ai_prompt;
} else {
data["image_data"] = json::array();
for (const auto& message : messages) {
std::string input_role = message["role"].asString();
std::string role;
if (input_role == "user") {
formatted_output += role;
for (auto content_piece : message["content"]) {
role = si.user_prompt;
if (!completion.prompt.empty()) {
formatted_output = completion.prompt;
} else {
for (const auto& message : messages) {
std::string input_role = message["role"].asString();
std::string role;
if (input_role == "user") {
formatted_output += role;
for (auto content_piece : message["content"]) {
role = si.user_prompt;

json content_piece_image_data;
content_piece_image_data["data"] = "";

auto content_piece_type = content_piece["type"].asString();
if (content_piece_type == "text") {
auto text = content_piece["text"].asString();
formatted_output += text;
} else if (content_piece_type == "image_url") {
auto image_url = content_piece["image_url"]["url"].asString();
std::string base64_image_data;
if (image_url.find("http") != std::string::npos) {
LOG_INFO << "Request " << request_id << ": "
<< "Remote image detected but not supported yet";
} else if (image_url.find("data:image") != std::string::npos) {
LOG_INFO << "Request " << request_id << ": "
<< "Base64 image detected";
base64_image_data = llama_utils::extractBase64(image_url);
// LOG_INFO << "Request " << request_id << ": " << base64_image_data;
} else {
LOG_INFO << "Request " << request_id << ": "
<< "Local image detected";
llama_utils::processLocalImage(
image_url, [&](const std::string& base64Image) {
base64_image_data = base64Image;
});
// LOG_INFO << "Request " << request_id << ": " << base64_image_data;
}
content_piece_image_data["data"] = base64_image_data;

json content_piece_image_data;
content_piece_image_data["data"] = "";

auto content_piece_type = content_piece["type"].asString();
if (content_piece_type == "text") {
auto text = content_piece["text"].asString();
formatted_output += text;
} else if (content_piece_type == "image_url") {
auto image_url = content_piece["image_url"]["url"].asString();
std::string base64_image_data;
if (image_url.find("http") != std::string::npos) {
LOG_INFO << "Request " << request_id << ": "
<< "Remote image detected but not supported yet";
} else if (image_url.find("data:image") != std::string::npos) {
LOG_INFO << "Request " << request_id << ": "
<< "Base64 image detected";
base64_image_data = llama_utils::extractBase64(image_url);
// LOG_INFO << "Request " << request_id << ": " << base64_image_data;
} else {
LOG_INFO << "Request " << request_id << ": "
<< "Local image detected";
llama_utils::processLocalImage(
image_url, [&](const std::string& base64Image) {
base64_image_data = base64Image;
});
// LOG_INFO << "Request " << request_id << ": " << base64_image_data;
formatted_output += "[img-" + std::to_string(no_images) + "]";
content_piece_image_data["id"] = no_images;
data["image_data"].push_back(content_piece_image_data);
no_images++;
}
content_piece_image_data["data"] = base64_image_data;

formatted_output += "[img-" + std::to_string(no_images) + "]";
content_piece_image_data["id"] = no_images;
data["image_data"].push_back(content_piece_image_data);
no_images++;
}
}

} else if (input_role == "assistant") {
role = si.ai_prompt;
std::string content = message["content"].asString();
formatted_output += role + content;
} else if (input_role == "system") {
role = si.system_prompt;
std::string content = message["content"].asString();
formatted_output = role + content + formatted_output;
} else if (input_role == "assistant") {
role = si.ai_prompt;
std::string content = message["content"].asString();
formatted_output += role + content;
} else if (input_role == "system") {
role = si.system_prompt;
std::string content = message["content"].asString();
formatted_output = role + content + formatted_output;

} else {
role = input_role;
std::string content = message["content"].asString();
formatted_output += role + content;
} else {
role = input_role;
std::string content = message["content"].asString();
formatted_output += role + content;
}
}
formatted_output += si.ai_prompt;
}
formatted_output += si.ai_prompt;
// LOG_INFO << "Request " << request_id << ": " << formatted_output;
}

data["prompt"] = formatted_output;
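Behaviorally, a non-empty completion.prompt now short-circuits chat-template formatting in both the text-only and multimodal branches. A self-contained sketch of the text-only precedence (the names Msg and BuildPrompt are hypothetical, not from the codebase; the real code also threads si.pre_prompt, image data, and logging):

#include <string>
#include <vector>

struct Msg {
  std::string role;     // "user", "assistant", "system", or custom
  std::string content;
};

// Mirrors the text-only branch: a caller-supplied prompt wins outright;
// otherwise each message is prefixed with its template role string, and the
// assistant prefix is appended so generation starts at the assistant turn.
std::string BuildPrompt(const std::string& prompt,          // completion.prompt
                        const std::vector<Msg>& messages,
                        const std::string& user_prompt,     // si.user_prompt
                        const std::string& ai_prompt,       // si.ai_prompt
                        const std::string& system_prompt) { // si.system_prompt
  if (!prompt.empty()) {
    return prompt;  // no template applied
  }
  std::string out;
  for (const auto& m : messages) {
    const std::string& role = m.role == "user"      ? user_prompt
                            : m.role == "assistant" ? ai_prompt
                            : m.role == "system"    ? system_prompt
                                                    : m.role;
    out += role + m.content;
  }
  return out + ai_prompt;
}

With an empty prompt, a single user message "Hi" yields user_prompt + "Hi" + ai_prompt; with a non-empty prompt, messages are ignored entirely.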

0 comments on commit 88127a7
