From 5684fe635e3cbdc5e705e072cb7ef1f622e17982 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 4 Nov 2024 16:44:27 +0700 Subject: [PATCH] chore: return model size after pulled (#1626) * chore: return model size after pulled * chore: remove double cast checking --- docs/static/openapi/cortex.json | 5 +++++ engine/config/model_config.h | 5 +++++ engine/config/yaml_config.cc | 5 +++++ engine/services/model_service.cc | 18 +++++++++++++++--- 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index f6120a4ad..1378ade12 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -3434,6 +3434,11 @@ "description": "To enable mmap, default is true", "example": true }, + "size": { + "type": "number", + "description": "The model file size in bytes", + "example": 1073741824 + }, "engine": { "type": "string", "description": "The engine to use.", diff --git a/engine/config/model_config.h b/engine/config/model_config.h index bc3a7ec25..044fd8dd3 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -58,6 +58,7 @@ struct ModelConfig { bool ignore_eos = false; int n_probs = 0; int min_keep = 0; + uint64_t size = 0; std::string grammar; void FromJson(const Json::Value& json) { @@ -70,6 +71,8 @@ struct ModelConfig { // model = json["model"].asString(); if (json.isMember("version")) version = json["version"].asString(); + if (json.isMember("size")) + size = json["size"].asUInt64(); if (json.isMember("stop") && json["stop"].isArray()) { stop.clear(); @@ -176,6 +179,7 @@ struct ModelConfig { obj["name"] = name; obj["model"] = model; obj["version"] = version; + obj["size"] = size; Json::Value stop_array(Json::arrayValue); for (const auto& s : stop) { @@ -269,6 +273,7 @@ struct ModelConfig { oss << format_utils::print_comment("END REQUIRED"); oss << format_utils::print_comment("BEGIN OPTIONAL"); + oss << format_utils::print_float("size", size); oss << format_utils::print_bool("stream", stream); oss << format_utils::print_float("top_p", top_p); oss << format_utils::print_float("temperature", temperature); diff --git a/engine/config/yaml_config.cc b/engine/config/yaml_config.cc index 99f8103d8..e4932c9c3 100644 --- a/engine/config/yaml_config.cc +++ b/engine/config/yaml_config.cc @@ -75,6 +75,8 @@ void YamlHandler::ModelConfigFromYaml() { tmp.model = yaml_node_["model"].as(); if (yaml_node_["version"]) tmp.version = yaml_node_["version"].as(); + if (yaml_node_["size"]) + tmp.size = yaml_node_["size"].as(); if (yaml_node_["engine"]) tmp.engine = yaml_node_["engine"].as(); if (yaml_node_["prompt_template"]) { @@ -266,6 +268,8 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) { if (!model_config_.grammar.empty()) yaml_node_["grammar"] = model_config_.grammar; + yaml_node_["size"] = model_config_.size; + yaml_node_["created"] = std::time(nullptr); } catch (const std::exception& e) { std::cerr << "Error when update model config : " << e.what() << std::endl; @@ -318,6 +322,7 @@ void YamlHandler::WriteYamlFile(const std::string& file_path) const { outFile << "# END REQUIRED\n"; outFile << "\n"; outFile << "# BEGIN OPTIONAL\n"; + outFile << format_utils::writeKeyValue("size", yaml_node_["size"]); outFile << format_utils::writeKeyValue("stream", yaml_node_["stream"], "Default true?"); outFile << format_utils::writeKeyValue("top_p", yaml_node_["top_p"], diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 4967b1dd9..d9656073e 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -18,7 +18,8 @@ namespace { void ParseGguf(const DownloadItem& ggufDownloadItem, std::optional author, - std::optional name) { + std::optional name, + std::optional size) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; config::GGUFHandler gguf_handler; @@ -35,6 +36,7 @@ void ParseGguf(const DownloadItem& ggufDownloadItem, model_config.model = ggufDownloadItem.id; model_config.name = name.has_value() ? name.value() : gguf_handler.GetModelConfig().name; + model_config.size = size.value_or(0); yaml_handler.UpdateModelConfig(model_config); auto yaml_path{ggufDownloadItem.localPath}; @@ -284,8 +286,13 @@ cpp::result ModelService::HandleDownloadUrlAsync( }}}}; auto on_finished = [author, temp_name](const DownloadTask& finishedTask) { + // Sum downloadedBytes from all items + uint64_t model_size = 0; + for (const auto& item : finishedTask.items) { + model_size = model_size + item.bytes.value_or(0); + } auto gguf_download_item = finishedTask.items[0]; - ParseGguf(gguf_download_item, author, temp_name); + ParseGguf(gguf_download_item, author, temp_name, model_size); }; downloadTask.id = unique_model_id; @@ -349,8 +356,13 @@ cpp::result ModelService::HandleUrl( }}}}; auto on_finished = [author](const DownloadTask& finishedTask) { + // Sum downloadedBytes from all items + uint64_t model_size = 0; + for (const auto& item : finishedTask.items) { + model_size = model_size + item.bytes.value_or(0); + } auto gguf_download_item = finishedTask.items[0]; - ParseGguf(gguf_download_item, author, std::nullopt); + ParseGguf(gguf_download_item, author, std::nullopt, model_size); }; auto result = download_service_->AddDownloadTask(downloadTask, on_finished);