From dab2d6df537b93a75269e1d0f0ef8a9924e09f22 Mon Sep 17 00:00:00 2001 From: James Date: Fri, 20 Dec 2024 00:15:12 +0700 Subject: [PATCH] update --- engine/cli/commands/chat_completion_cmd.cc | 3 +- engine/common/model_metadata.h | 29 ++ engine/common/model_tokenizer.h | 33 -- engine/common/tokenizer.h | 68 ++++ engine/config/chat_template_renderer.h | 20 +- engine/utils/gguf_metadata_reader.h | 420 +++++++++++++++++++++ 6 files changed, 528 insertions(+), 45 deletions(-) create mode 100644 engine/common/model_metadata.h delete mode 100644 engine/common/model_tokenizer.h create mode 100644 engine/common/tokenizer.h create mode 100644 engine/utils/gguf_metadata_reader.h diff --git a/engine/cli/commands/chat_completion_cmd.cc b/engine/cli/commands/chat_completion_cmd.cc index bb2c55879..9841fdf15 100644 --- a/engine/cli/commands/chat_completion_cmd.cc +++ b/engine/cli/commands/chat_completion_cmd.cc @@ -37,7 +37,7 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { } try { - std::string content = + auto content = json_helper::ParseJsonString(chunk)["choices"][0]["delta"]["content"] .asString(); std::cout << content << std::flush; @@ -51,7 +51,6 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { return data_length; } - } // namespace void ChatCompletionCmd::Exec(const std::string& host, int port, diff --git a/engine/common/model_metadata.h b/engine/common/model_metadata.h new file mode 100644 index 000000000..739a0af3d --- /dev/null +++ b/engine/common/model_metadata.h @@ -0,0 +1,29 @@ +#pragma once + +#include "common/tokenizer.h" +#include + +struct ModelMetadata { + uint32_t version; + uint64_t tensor_count; + uint64_t metadata_kv_count; + std::unique_ptr tokenizer; + + std::string ToString() const { + std::ostringstream ss; + ss << "ModelMetadata {\n" + << "version: " << version << "\n" + << "tensor_count: " << tensor_count << "\n" + << "metadata_kv_count: " << metadata_kv_count << "\n" + << "tokenizer: "; + + if (tokenizer) { + ss << "\n" << tokenizer->ToString(); + } else { + ss << "null"; + } + + ss << "\n}"; + return ss.str(); + } +}; diff --git a/engine/common/model_tokenizer.h b/engine/common/model_tokenizer.h deleted file mode 100644 index b8d3622cb..000000000 --- a/engine/common/model_tokenizer.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once - -#include "common/json_serializable.h" - -struct ModelTokenizer : JsonSerializable { - std::string model; - - std::string pre; - - std::vector tokens; - - std::vector token_type; - - std::vector merges; - - // TODO: clean this up - size_t eos_token_id; - - size_t padding_token_id; - - size_t bos_token_id; - - bool add_bos_token; - - std::string chat_template; - - cpp::result ToJson() override { - Json::Value root; - // TODO: namh handle this - - return root; - } -}; diff --git a/engine/common/tokenizer.h b/engine/common/tokenizer.h new file mode 100644 index 000000000..36ccaa694 --- /dev/null +++ b/engine/common/tokenizer.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include + +struct Tokenizer { + std::string eos_token = ""; + bool add_eos_token = true; + + std::string bos_token = ""; + bool add_bos_token = true; + + std::string unknown_token = ""; + std::string padding_token = ""; + + std::string chat_template = ""; + + // Helper function for common fields + std::string BaseToString() const { + std::ostringstream ss; + ss << "eos_token: \"" << eos_token << "\"\n" + << "add_eos_token: " << (add_eos_token ? "true" : "false") << "\n" + << "bos_token: \"" << bos_token << "\"\n" + << "add_bos_token: " << (add_bos_token ? "true" : "false") << "\n" + << "unknown_token: \"" << unknown_token << "\"\n" + << "padding_token: \"" << padding_token << "\"\n" + << "chat_template: \"" << chat_template << "\""; + return ss.str(); + } + + virtual ~Tokenizer() = default; + + virtual std::string ToString() = 0; +}; + +struct GgufTokenizer : public Tokenizer { + std::string pre = ""; + + ~GgufTokenizer() override = default; + + std::string ToString() override { + std::ostringstream ss; + ss << "GgufTokenizer {\n"; + // Add base class members + ss << BaseToString() << "\n"; + // Add derived class members + ss << "pre: \"" << pre << "\"\n"; + ss << "}"; + return ss.str(); + } +}; + +struct SafeTensorTokenizer : public Tokenizer { + bool add_prefix_space = true; + + ~SafeTensorTokenizer() = default; + + std::string ToString() override { + std::ostringstream ss; + ss << "SafeTensorTokenizer {\n"; + // Add base class members + ss << BaseToString() << "\n"; + // Add derived class members + ss << "add_prefix_space: " << (add_prefix_space ? "true" : "false") << "\n"; + ss << "}"; + return ss.str(); + } +}; diff --git a/engine/config/chat_template_renderer.h b/engine/config/chat_template_renderer.h index 881186a9d..1d208518b 100644 --- a/engine/config/chat_template_renderer.h +++ b/engine/config/chat_template_renderer.h @@ -276,15 +276,15 @@ static int32_t llama_chat_apply_template_internal( } } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) { - // Llama 3 - for (auto message : chat) { - std::string role(message->role); - ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" - << trim(message->content) << "<|eot_id|>"; - } - if (add_ass) { - ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; - } + // // Llama 3 + // for (auto message : chat) { + // std::string role(message->role); + // ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" + // << trim(message->content) << "<|eot_id|>"; + // } + // if (add_ass) { + // ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; + // } } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) { // chatglm3-6b ss << "[gMASK]" << "sop"; @@ -426,4 +426,4 @@ std::string llama_chat_apply_template(const std::string& tmpl, std::string formatted_chat(buf.data(), res); return formatted_chat; } -} // namespace config \ No newline at end of file +} // namespace config diff --git a/engine/utils/gguf_metadata_reader.h b/engine/utils/gguf_metadata_reader.h new file mode 100644 index 000000000..0dba0c1bb --- /dev/null +++ b/engine/utils/gguf_metadata_reader.h @@ -0,0 +1,420 @@ +#pragma once + +#include +#include +#include +#include +#include "common/model_metadata.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +/** + * Parsing the GGUF metadata. + * + * Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md + */ +namespace cortex_utils { +namespace { +// present in the first 4 bytes of a GGUF file +constexpr uint32_t GGUF_MAGIC_NUMBER = 1179993927; + +constexpr static auto GGUF_VERSION_LENGTH = 4; +constexpr static auto TENSOR_COUNT_LENGTH = 8; +constexpr static auto METADATA_KV_COUNT = 8; + +constexpr static auto TOKEN_LIST_KEY = "tokenizer.ggml.tokens"; +constexpr static auto BOS_ID_KEY = "tokenizer.ggml.bos_token_id"; +constexpr static auto EOS_ID_KEY = "tokenizer.ggml.eos_token_id"; +constexpr static auto UNK_ID_KEY = "tokenizer.ggml.unknown_token_id"; +constexpr static auto PADDING_ID_KEY = "tokenizer.ggml.padding_token_id"; + +constexpr static auto CHAT_TEMPLATE_ID_KEY = "tokenizer.chat_template"; +constexpr static auto ADD_BOS_TOKEN_KEY = "tokenizer.ggml.add_bos_token"; +constexpr static auto ADD_EOS_TOKEN_KEY = "tokenizer.ggml.add_eos_token"; +const std::vector kSpecialTokenIds{BOS_ID_KEY, EOS_ID_KEY, + UNK_ID_KEY, PADDING_ID_KEY}; + +struct MetadataArrayElement; + +// clang-format off +using MetadataValue = std::variant< + uint8_t, int8_t, + uint16_t, int16_t, + uint32_t, int32_t, + uint64_t, int64_t, + float, double, + bool, std::string, + std::vector +>; + +// clang-format on + +struct MetadataArrayElement { + MetadataValue value; + + // Add constructors for different types + MetadataArrayElement(uint8_t v) : value(v) {} + MetadataArrayElement(int8_t v) : value(v) {} + MetadataArrayElement(uint16_t v) : value(v) {} + MetadataArrayElement(int16_t v) : value(v) {} + MetadataArrayElement(uint32_t v) : value(v) {} + MetadataArrayElement(int32_t v) : value(v) {} + MetadataArrayElement(uint64_t v) : value(v) {} + MetadataArrayElement(int64_t v) : value(v) {} + MetadataArrayElement(float v) : value(v) {} + MetadataArrayElement(double v) : value(v) {} + MetadataArrayElement(bool v) : value(v) {} + MetadataArrayElement(const std::string& v) : value(v) {} + MetadataArrayElement(std::string&& v) : value(std::move(v)) {} + + MetadataArrayElement(MetadataValue&& v) : value(std::move(v)) {} +}; + +struct MetadataValueResult { + size_t bytes_read; + MetadataValue value; + + template + MetadataValueResult(size_t br, T&& val) + : bytes_read(br), value(std::forward(val)) {} +}; + +std::pair ReadString(std::ifstream& file) { + uint64_t length; + file.read(reinterpret_cast(&length), sizeof(uint64_t)); + + if (!file) { + throw std::runtime_error("Failed to read string length"); + } + + if (length > 1024 * 1024 * 1024) { + throw std::runtime_error("String length too large: " + + std::to_string(length)); + } + + std::string value(length, '\0'); + file.read(value.data(), length); + + if (!file) { + throw std::runtime_error("Failed to read string content of length " + + std::to_string(length)); + } + + return {8 + length, value}; +} + +inline MetadataValueResult ReadMetadataValue(uint32_t value_type, + std::ifstream& file, + const std::string& key) { + switch (value_type) { + case 0: { // uint8 + uint8_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(uint8_t), value}; + } + case 1: { // int8 + int8_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(int8_t), value}; + } + case 2: { // uint16 + uint16_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(uint16_t), value}; + } + case 3: { // int16 + int16_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(int16_t), value}; + } + case 4: { // uint32 + uint32_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(uint32_t), value}; + } + case 5: { // int32 + int32_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(int32_t), value}; + } + case 6: { // float32 + float value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(float), value}; + } + case 7: { // bool + bool value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(bool), value}; + } + case 8: { // string + auto [length, value] = ReadString(file); + return {length, value}; + } + case 9: { // array + uint32_t array_type; + file.read(reinterpret_cast(&array_type), sizeof(uint32_t)); + + uint64_t array_length; + file.read(reinterpret_cast(&array_length), sizeof(uint64_t)); + + size_t bytes_read = 12; // 4 for type + 8 for length + + std::vector array_values_string; + std::vector array_values_float; + + for (uint64_t i = 0; i < array_length; ++i) { + auto result = ReadMetadataValue(array_type, file, + key + "[" + std::to_string(i) + "]"); + bytes_read += result.bytes_read; + + if (array_type == 8) { + array_values_string.push_back(std::get(result.value)); + } else { + float float_value; + switch (result.value.index()) { + case 0: + float_value = static_cast(std::get(result.value)); + break; + case 1: + float_value = static_cast(std::get(result.value)); + break; + case 2: + float_value = + static_cast(std::get(result.value)); + break; + case 3: + float_value = static_cast(std::get(result.value)); + break; + case 4: + float_value = + static_cast(std::get(result.value)); + break; + case 5: + float_value = static_cast(std::get(result.value)); + break; + case 6: + float_value = + static_cast(std::get(result.value)); + break; + case 7: + float_value = static_cast(std::get(result.value)); + break; + case 8: + float_value = std::get(result.value); + break; + case 9: + float_value = static_cast(std::get(result.value)); + break; + case 10: + float_value = static_cast(std::get(result.value)); + break; + default: + throw std::runtime_error( + "Unexpected type in array element conversion"); + } + array_values_float.push_back(float_value); + } + } + + if (!array_values_string.empty()) { + std::vector result; + result.reserve(array_values_string.size()); + for (const auto& str : array_values_string) { + result.emplace_back(str); + } + return {bytes_read, std::move(result)}; + } else { + std::vector result; + result.reserve(array_values_float.size()); + for (float val : array_values_float) { + result.emplace_back(val); + } + return {bytes_read, std::move(result)}; + } + } + + case 10: { // uint64 + uint64_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(uint64_t), value}; + } + case 11: { // int64 + int64_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(int64_t), value}; + } + case 12: { // float64/double + double value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(double), value}; + } + default: + throw std::runtime_error("Unknown value type: " + + std::to_string(value_type) + " for key: " + key); + } +} + +void PrintMetadataValue(const std::string& key, const MetadataValue& value) { + std::ostringstream oss; + oss << "Key: " << key << " = "; + + switch (value.index()) { + case 0: // uint8_t + oss << "uint8: " << static_cast(std::get(value)); + break; + case 1: // int8_t + oss << "int8: " << static_cast(std::get(value)); + break; + case 2: // uint16_t + oss << "uint16: " << std::get(value); + break; + case 3: // int16_t + oss << "int16: " << std::get(value); + break; + case 4: // uint32_t + oss << "uint32: " << std::get(value); + break; + case 5: // int32_t + oss << "int32: " << std::get(value); + break; + case 6: // uint64_t + oss << "uint64: " << std::get(value); + break; + case 7: // int64_t + oss << "int64: " << std::get(value); + break; + case 8: // float + oss << "float: " << std::get(value); + break; + case 9: // double + oss << "double: " << std::get(value); + break; + case 10: // bool + oss << "bool: " << (std::get(value) ? "true" : "false"); + break; + case 11: // string + oss << "string: " << std::get(value); + break; + case 12: { // vector + const auto& arr = std::get>(value); + oss << "array[" << arr.size() << "]: {"; + for (size_t i = 0; i < arr.size(); ++i) { + if (i > 0) + oss << ", "; + std::ostringstream key_oss; + key_oss << key << "[" << i << "]"; + PrintMetadataValue(key_oss.str(), arr[i].value); + } + oss << "}"; + break; + } + } + + CTL_INF(oss.str()); +} +} // namespace + +inline cpp::result, std::string> +ReadGgufMetadata(const std::filesystem::path& path) { + if (!std::filesystem::exists(path)) { + return cpp::fail("Gguf file does not exist at " + path.string()); + } + + std::ifstream file(path, std::ios::binary); + if (!file) { + return cpp::fail("Failed to open file: " + path.string()); + } + + uint32_t magic_number; + file.read(reinterpret_cast(&magic_number), sizeof(magic_number)); + if (magic_number != GGUF_MAGIC_NUMBER) { + return cpp::fail("Invalid GGUF file: incorrect magic number"); + } + + auto metadata_ptr = std::make_unique(); + + uint32_t version; + file.read(reinterpret_cast(&version), GGUF_VERSION_LENGTH); + metadata_ptr->version = version; + + uint64_t tensor_count; + file.read(reinterpret_cast(&tensor_count), TENSOR_COUNT_LENGTH); + metadata_ptr->tensor_count = tensor_count; + + uint64_t metadata_kv_count; + file.read(reinterpret_cast(&metadata_kv_count), METADATA_KV_COUNT); + metadata_ptr->metadata_kv_count = metadata_kv_count; + + std::unordered_map kv; + for (uint64_t i = 0; i < metadata_kv_count; ++i) { + auto [key_byte_length, key] = ReadString(file); + + char value_type_bytes[4]; + file.read(value_type_bytes, 4); + uint32_t value_type = + static_cast(static_cast(value_type_bytes[0])) | + (static_cast(static_cast(value_type_bytes[1])) + << 8) | + (static_cast(static_cast(value_type_bytes[2])) + << 16) | + (static_cast(static_cast(value_type_bytes[3])) + << 24); + + try { + auto result = ReadMetadataValue(value_type, file, key); + kv.emplace(key, result); + } catch (const std::exception& e) { + CTL_ERR("Error reading metadata value for key '" + key + + "': " + e.what()); + return cpp::fail("Error reading metadata value for key '" + key + "'"); + } + } + + { + metadata_ptr->tokenizer = std::make_unique(); + // initialize tokenizer + if (auto it = kv.find(CHAT_TEMPLATE_ID_KEY); it != kv.end()) { + metadata_ptr->tokenizer->chat_template = + std::get(it->second.value); + } + + for (const auto& key : kSpecialTokenIds) { + if (auto it = kv.find(key); it != kv.end()) { + auto id = std::get(it->second.value); + if (auto token_it = kv.find(TOKEN_LIST_KEY); token_it != kv.end()) { + auto& tokens = std::get>( + token_it->second.value); + + if (key == BOS_ID_KEY) { + metadata_ptr->tokenizer->bos_token = + std::get(tokens[id].value); + } else if (key == EOS_ID_KEY) { + metadata_ptr->tokenizer->eos_token = + std::get(tokens[id].value); + } else if (key == UNK_ID_KEY) { + metadata_ptr->tokenizer->unknown_token = + std::get(tokens[id].value); + } else if (key == PADDING_ID_KEY) { + metadata_ptr->tokenizer->padding_token = + std::get(tokens[id].value); + } else { + CTL_ERR("Unknown special token key: " + key); + } + } + } + } + + if (auto it = kv.find(ADD_BOS_TOKEN_KEY); it != kv.end()) { + metadata_ptr->tokenizer->add_bos_token = std::get(it->second.value); + } + + if (auto it = kv.find(ADD_EOS_TOKEN_KEY); it != kv.end()) { + metadata_ptr->tokenizer->add_eos_token = std::get(it->second.value); + } + } + + CTL_INF("Parsed GGUF metadata successfully: " + metadata_ptr->ToString()); + return metadata_ptr; +} +} // namespace cortex_utils