From 88127a7dad56e7f6b4fe6dd950077a1831cd3f9e Mon Sep 17 00:00:00 2001 From: James Date: Wed, 18 Dec 2024 15:29:51 +0700 Subject: [PATCH] feat: add chat_template --- src/chat_completion_request.h | 2 + src/llama_engine.cc | 148 ++++++++++++++++++---------------- 2 files changed, 80 insertions(+), 70 deletions(-) diff --git a/src/chat_completion_request.h b/src/chat_completion_request.h index c37ce9e..a56e49e 100644 --- a/src/chat_completion_request.h +++ b/src/chat_completion_request.h @@ -67,6 +67,7 @@ struct ChatCompletionRequest { Json::Value stop = Json::Value(Json::arrayValue); Json::Value messages = Json::Value(Json::arrayValue); std::string model_id; + std::string prompt; int seed = -1; float dynatemp_range = 0.0f; @@ -125,6 +126,7 @@ inline ChatCompletionRequest fromJson(std::shared_ptr jsonBody) { completion.presence_penalty = (*jsonBody).get("presence_penalty", 0).asFloat(); completion.messages = (*jsonBody)["messages"]; + completion.prompt = jsonBody->get("prompt", "").asString(); completion.stop = (*jsonBody)["stop"]; completion.model_id = (*jsonBody).get("model", {}).asString(); diff --git a/src/llama_engine.cc b/src/llama_engine.cc index 3b80bd0..5560645 100644 --- a/src/llama_engine.cc +++ b/src/llama_engine.cc @@ -751,7 +751,7 @@ void LlamaEngine::HandleInferenceImpl( callback(std::move(status), std::move(jsonResp)); return; } - std::string formatted_output = si.pre_prompt; + auto formatted_output = si.pre_prompt; int request_id = ++no_of_requests_; LOG_INFO << "Request " << request_id << ", " << "model " << completion.model_id << ": " @@ -830,87 +830,95 @@ void LlamaEngine::HandleInferenceImpl( return ""; }; - for (const auto& message : messages) { - std::string input_role = message["role"].asString(); - std::string role; - if (input_role == "user") { - role = si.user_prompt; - } else if (input_role == "assistant") { - role = si.ai_prompt; - } else if (input_role == "system") { - role = si.system_prompt; - } else { - role = input_role; - } + if (!completion.prompt.empty()) { + // If prompt is provided, use it as the prompt + formatted_output = completion.prompt; + } else { + for (const auto& message : messages) { + std::string input_role = message["role"].asString(); + std::string role; + if (input_role == "user") { + role = si.user_prompt; + } else if (input_role == "assistant") { + role = si.ai_prompt; + } else if (input_role == "system") { + role = si.system_prompt; + } else { + role = input_role; + } - if (auto content = get_message(message["content"]); !content.empty()) { - formatted_output += role + content; + if (auto content = get_message(message["content"]); !content.empty()) { + formatted_output += role + content; + } } + formatted_output += si.ai_prompt; } - formatted_output += si.ai_prompt; } else { data["image_data"] = json::array(); - for (const auto& message : messages) { - std::string input_role = message["role"].asString(); - std::string role; - if (input_role == "user") { - formatted_output += role; - for (auto content_piece : message["content"]) { - role = si.user_prompt; + if (!completion.prompt.empty()) { + formatted_output = completion.prompt; + } else { + for (const auto& message : messages) { + std::string input_role = message["role"].asString(); + std::string role; + if (input_role == "user") { + formatted_output += role; + for (auto content_piece : message["content"]) { + role = si.user_prompt; + + json content_piece_image_data; + content_piece_image_data["data"] = ""; + + auto content_piece_type = content_piece["type"].asString(); + if (content_piece_type == "text") { + auto text = content_piece["text"].asString(); + formatted_output += text; + } else if (content_piece_type == "image_url") { + auto image_url = content_piece["image_url"]["url"].asString(); + std::string base64_image_data; + if (image_url.find("http") != std::string::npos) { + LOG_INFO << "Request " << request_id << ": " + << "Remote image detected but not supported yet"; + } else if (image_url.find("data:image") != std::string::npos) { + LOG_INFO << "Request " << request_id << ": " + << "Base64 image detected"; + base64_image_data = llama_utils::extractBase64(image_url); + // LOG_INFO << "Request " << request_id << ": " << base64_image_data; + } else { + LOG_INFO << "Request " << request_id << ": " + << "Local image detected"; + llama_utils::processLocalImage( + image_url, [&](const std::string& base64Image) { + base64_image_data = base64Image; + }); + // LOG_INFO << "Request " << request_id << ": " << base64_image_data; + } + content_piece_image_data["data"] = base64_image_data; - json content_piece_image_data; - content_piece_image_data["data"] = ""; - - auto content_piece_type = content_piece["type"].asString(); - if (content_piece_type == "text") { - auto text = content_piece["text"].asString(); - formatted_output += text; - } else if (content_piece_type == "image_url") { - auto image_url = content_piece["image_url"]["url"].asString(); - std::string base64_image_data; - if (image_url.find("http") != std::string::npos) { - LOG_INFO << "Request " << request_id << ": " - << "Remote image detected but not supported yet"; - } else if (image_url.find("data:image") != std::string::npos) { - LOG_INFO << "Request " << request_id << ": " - << "Base64 image detected"; - base64_image_data = llama_utils::extractBase64(image_url); - // LOG_INFO << "Request " << request_id << ": " << base64_image_data; - } else { - LOG_INFO << "Request " << request_id << ": " - << "Local image detected"; - llama_utils::processLocalImage( - image_url, [&](const std::string& base64Image) { - base64_image_data = base64Image; - }); - // LOG_INFO << "Request " << request_id << ": " << base64_image_data; + formatted_output += "[img-" + std::to_string(no_images) + "]"; + content_piece_image_data["id"] = no_images; + data["image_data"].push_back(content_piece_image_data); + no_images++; } - content_piece_image_data["data"] = base64_image_data; - - formatted_output += "[img-" + std::to_string(no_images) + "]"; - content_piece_image_data["id"] = no_images; - data["image_data"].push_back(content_piece_image_data); - no_images++; } - } - } else if (input_role == "assistant") { - role = si.ai_prompt; - std::string content = message["content"].asString(); - formatted_output += role + content; - } else if (input_role == "system") { - role = si.system_prompt; - std::string content = message["content"].asString(); - formatted_output = role + content + formatted_output; + } else if (input_role == "assistant") { + role = si.ai_prompt; + std::string content = message["content"].asString(); + formatted_output += role + content; + } else if (input_role == "system") { + role = si.system_prompt; + std::string content = message["content"].asString(); + formatted_output = role + content + formatted_output; - } else { - role = input_role; - std::string content = message["content"].asString(); - formatted_output += role + content; + } else { + role = input_role; + std::string content = message["content"].asString(); + formatted_output += role + content; + } } + formatted_output += si.ai_prompt; } - formatted_output += si.ai_prompt; - // LOG_INFO << "Request " << request_id << ": " << formatted_output; } data["prompt"] = formatted_output;