Skip to content

Commit

Permalink
chore: return model size after pulled (#1626)
Browse files Browse the repository at this point in the history
* chore: return model size after pulled

* chore: remove double cast checking
  • Loading branch information
louis-jan authored Nov 4, 2024
1 parent 8961a0d commit 5684fe6
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 3 deletions.
5 changes: 5 additions & 0 deletions docs/static/openapi/cortex.json
Original file line number Diff line number Diff line change
Expand Up @@ -3434,6 +3434,11 @@
"description": "To enable mmap, default is true",
"example": true
},
"size": {
"type": "number",
"description": "The model file size in bytes",
"example": 1073741824
},
"engine": {
"type": "string",
"description": "The engine to use.",
Expand Down
5 changes: 5 additions & 0 deletions engine/config/model_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ struct ModelConfig {
bool ignore_eos = false;
int n_probs = 0;
int min_keep = 0;
uint64_t size = 0;
std::string grammar;

void FromJson(const Json::Value& json) {
Expand All @@ -70,6 +71,8 @@ struct ModelConfig {
// model = json["model"].asString();
if (json.isMember("version"))
version = json["version"].asString();
if (json.isMember("size"))
size = json["size"].asUInt64();

if (json.isMember("stop") && json["stop"].isArray()) {
stop.clear();
Expand Down Expand Up @@ -176,6 +179,7 @@ struct ModelConfig {
obj["name"] = name;
obj["model"] = model;
obj["version"] = version;
obj["size"] = size;

Json::Value stop_array(Json::arrayValue);
for (const auto& s : stop) {
Expand Down Expand Up @@ -269,6 +273,7 @@ struct ModelConfig {
oss << format_utils::print_comment("END REQUIRED");
oss << format_utils::print_comment("BEGIN OPTIONAL");

oss << format_utils::print_float("size", size);
oss << format_utils::print_bool("stream", stream);
oss << format_utils::print_float("top_p", top_p);
oss << format_utils::print_float("temperature", temperature);
Expand Down
5 changes: 5 additions & 0 deletions engine/config/yaml_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ void YamlHandler::ModelConfigFromYaml() {
tmp.model = yaml_node_["model"].as<std::string>();
if (yaml_node_["version"])
tmp.version = yaml_node_["version"].as<std::string>();
if (yaml_node_["size"])
tmp.size = yaml_node_["size"].as<uint64_t>();
if (yaml_node_["engine"])
tmp.engine = yaml_node_["engine"].as<std::string>();
if (yaml_node_["prompt_template"]) {
Expand Down Expand Up @@ -266,6 +268,8 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
if (!model_config_.grammar.empty())
yaml_node_["grammar"] = model_config_.grammar;

yaml_node_["size"] = model_config_.size;

yaml_node_["created"] = std::time(nullptr);
} catch (const std::exception& e) {
std::cerr << "Error when update model config : " << e.what() << std::endl;
Expand Down Expand Up @@ -318,6 +322,7 @@ void YamlHandler::WriteYamlFile(const std::string& file_path) const {
outFile << "# END REQUIRED\n";
outFile << "\n";
outFile << "# BEGIN OPTIONAL\n";
outFile << format_utils::writeKeyValue("size", yaml_node_["size"]);
outFile << format_utils::writeKeyValue("stream", yaml_node_["stream"],
"Default true?");
outFile << format_utils::writeKeyValue("top_p", yaml_node_["top_p"],
Expand Down
18 changes: 15 additions & 3 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
namespace {
void ParseGguf(const DownloadItem& ggufDownloadItem,
std::optional<std::string> author,
std::optional<std::string> name) {
std::optional<std::string> name,
std::optional<std::uint64_t> size) {
namespace fs = std::filesystem;
namespace fmu = file_manager_utils;
config::GGUFHandler gguf_handler;
Expand All @@ -35,6 +36,7 @@ void ParseGguf(const DownloadItem& ggufDownloadItem,
model_config.model = ggufDownloadItem.id;
model_config.name =
name.has_value() ? name.value() : gguf_handler.GetModelConfig().name;
model_config.size = size.value_or(0);
yaml_handler.UpdateModelConfig(model_config);

auto yaml_path{ggufDownloadItem.localPath};
Expand Down Expand Up @@ -284,8 +286,13 @@ cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
}}}};

auto on_finished = [author, temp_name](const DownloadTask& finishedTask) {
// Sum downloadedBytes from all items
uint64_t model_size = 0;
for (const auto& item : finishedTask.items) {
model_size = model_size + item.bytes.value_or(0);
}
auto gguf_download_item = finishedTask.items[0];
ParseGguf(gguf_download_item, author, temp_name);
ParseGguf(gguf_download_item, author, temp_name, model_size);
};

downloadTask.id = unique_model_id;
Expand Down Expand Up @@ -349,8 +356,13 @@ cpp::result<std::string, std::string> ModelService::HandleUrl(
}}}};

auto on_finished = [author](const DownloadTask& finishedTask) {
// Sum downloadedBytes from all items
uint64_t model_size = 0;
for (const auto& item : finishedTask.items) {
model_size = model_size + item.bytes.value_or(0);
}
auto gguf_download_item = finishedTask.items[0];
ParseGguf(gguf_download_item, author, std::nullopt);
ParseGguf(gguf_download_item, author, std::nullopt, model_size);
};

auto result = download_service_->AddDownloadTask(downloadTask, on_finished);
Expand Down

0 comments on commit 5684fe6

Please sign in to comment.