Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
namchuai committed Dec 19, 2024
1 parent 79f2678 commit dab2d6d
Show file tree
Hide file tree
Showing 6 changed files with 528 additions and 45 deletions.
3 changes: 1 addition & 2 deletions engine/cli/commands/chat_completion_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) {
}

try {
std::string content =
auto content =
json_helper::ParseJsonString(chunk)["choices"][0]["delta"]["content"]
.asString();
std::cout << content << std::flush;
Expand All @@ -51,7 +51,6 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) {

return data_length;
}

} // namespace

void ChatCompletionCmd::Exec(const std::string& host, int port,
Expand Down
29 changes: 29 additions & 0 deletions engine/common/model_metadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include <cstdint>
#include <memory>
#include <sstream>
#include <string>

#include "common/tokenizer.h"

// High-level metadata parsed from a model file, plus the tokenizer
// extracted from it (if any).
struct ModelMetadata {
  // NOTE(review): field names mirror the GGUF header (version, tensor
  // count, metadata KV count) — presumably filled by a GGUF parser; the
  // parser itself is not visible here. Members are brace-initialized so
  // a default-constructed ModelMetadata is never read as indeterminate.
  uint32_t version = 0;
  uint64_t tensor_count = 0;
  uint64_t metadata_kv_count = 0;
  // Owned tokenizer; null when the model file carried no tokenizer data.
  std::unique_ptr<Tokenizer> tokenizer;

  // Human-readable dump of every field, intended for logging/debugging.
  std::string ToString() const {
    std::ostringstream ss;
    ss << "ModelMetadata {\n"
       << "version: " << version << "\n"
       << "tensor_count: " << tensor_count << "\n"
       << "metadata_kv_count: " << metadata_kv_count << "\n"
       << "tokenizer: ";

    if (tokenizer) {
      // Delegate to the concrete tokenizer's own dump.
      ss << "\n" << tokenizer->ToString();
    } else {
      ss << "null";
    }

    ss << "\n}";
    return ss.str();
  }
};
33 changes: 0 additions & 33 deletions engine/common/model_tokenizer.h

This file was deleted.

68 changes: 68 additions & 0 deletions engine/common/tokenizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#pragma once

#include <sstream>
#include <string>

struct Tokenizer {
std::string eos_token = "";
bool add_eos_token = true;

std::string bos_token = "";
bool add_bos_token = true;

std::string unknown_token = "";
std::string padding_token = "";

std::string chat_template = "";

// Helper function for common fields
std::string BaseToString() const {
std::ostringstream ss;
ss << "eos_token: \"" << eos_token << "\"\n"
<< "add_eos_token: " << (add_eos_token ? "true" : "false") << "\n"
<< "bos_token: \"" << bos_token << "\"\n"
<< "add_bos_token: " << (add_bos_token ? "true" : "false") << "\n"
<< "unknown_token: \"" << unknown_token << "\"\n"
<< "padding_token: \"" << padding_token << "\"\n"
<< "chat_template: \"" << chat_template << "\"";
return ss.str();
}

virtual ~Tokenizer() = default;

virtual std::string ToString() = 0;
};

// Tokenizer metadata read from a GGUF model file.
struct GgufTokenizer : public Tokenizer {
  // GGUF-specific pre-tokenizer identifier (empty when absent).
  std::string pre = "";

  ~GgufTokenizer() override = default;

  // Dump: base fields followed by the GGUF-specific `pre` field.
  std::string ToString() override {
    std::string repr = "GgufTokenizer {\n";
    repr += BaseToString();
    repr += "\npre: \"" + pre + "\"\n}";
    return repr;
  }
};

struct SafeTensorTokenizer : public Tokenizer {
bool add_prefix_space = true;

~SafeTensorTokenizer() = default;

std::string ToString() override {
std::ostringstream ss;
ss << "SafeTensorTokenizer {\n";
// Add base class members
ss << BaseToString() << "\n";
// Add derived class members
ss << "add_prefix_space: " << (add_prefix_space ? "true" : "false") << "\n";
ss << "}";
return ss.str();
}
};
20 changes: 10 additions & 10 deletions engine/config/chat_template_renderer.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,15 +276,15 @@ static int32_t llama_chat_apply_template_internal(
}
} else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") &&
tmpl_contains("<|end_header_id|>"))) {
// Llama 3
for (auto message : chat) {
std::string role(message->role);
ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n"
<< trim(message->content) << "<|eot_id|>";
}
if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
}
// // Llama 3
// for (auto message : chat) {
// std::string role(message->role);
// ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n"
// << trim(message->content) << "<|eot_id|>";
// }
// if (add_ass) {
// ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
// }
} else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) {
// chatglm3-6b
ss << "[gMASK]" << "sop";
Expand Down Expand Up @@ -426,4 +426,4 @@ std::string llama_chat_apply_template(const std::string& tmpl,
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
} // namespace config
} // namespace config
Loading

0 comments on commit dab2d6d

Please sign in to comment.