From f473b0b2d78074d4ebb2e61540de470b62740ea1 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 12 Dec 2024 11:29:11 +0700 Subject: [PATCH] feat: model sources (#1777) * feat: prioritize GPUs * fix: migrate db * fix: add priority * fix: db * fix: more * feat: model sources * feat: support delete API * feat: cli: support models sources add * feat: cli: model source delete * feat: cli: add model source list * feat: sync cortex.db * chore: cleanup * feat: add metadata for model * fix: migration * chore: unit tests: cleanup * fix: add metadata * fix: pull model * chore: unit tests: update * chore: add e2e tests for models sources * chore: add API docs * chore: rename --------- Co-authored-by: vansangpfiev --- docs/static/openapi/cortex.json | 99 ++++ engine/cli/command_line_parser.cc | 76 ++- engine/cli/command_line_parser.h | 2 + engine/cli/commands/model_list_cmd.cc | 78 +-- engine/cli/commands/model_list_cmd.h | 3 +- engine/cli/commands/model_source_add_cmd.cc | 38 ++ engine/cli/commands/model_source_add_cmd.h | 12 + engine/cli/commands/model_source_del_cmd.cc | 39 ++ engine/cli/commands/model_source_del_cmd.h | 12 + engine/cli/commands/model_source_list_cmd.cc | 56 +++ engine/cli/commands/model_source_list_cmd.h | 11 + engine/controllers/models.cc | 98 +++- engine/controllers/models.h | 25 +- engine/database/models.cc | 222 ++++----- engine/database/models.h | 22 +- engine/e2e-test/test_api_model.py | 15 +- engine/main.cc | 5 +- engine/services/model_service.cc | 107 ++-- engine/services/model_source_service.cc | 493 +++++++++++++++++++ engine/services/model_source_service.h | 53 ++ engine/test/components/test_models_db.cc | 70 +-- engine/utils/huggingface_utils.h | 2 + engine/utils/json_parser_utils.h | 2 +- 23 files changed, 1269 insertions(+), 271 deletions(-) create mode 100644 engine/cli/commands/model_source_add_cmd.cc create mode 100644 engine/cli/commands/model_source_add_cmd.h create mode 100644 engine/cli/commands/model_source_del_cmd.cc create mode 100644 engine/cli/commands/model_source_del_cmd.h create mode 100644 engine/cli/commands/model_source_list_cmd.cc create mode 100644 engine/cli/commands/model_source_list_cmd.h create mode 100644 engine/services/model_source_service.cc create mode 100644 engine/services/model_source_service.h diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 9cdd5c7b4..2ff239ce2 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -807,6 +807,105 @@ "tags": ["Pulling Models"] } }, + "/v1/models/sources": { + "post": { + "summary": "Add a model source", + "description": "User can add a Huggingface Organization or Repository", + "requestBody": { + "required": false, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "source": { + "type": "string", + "description": "The url of model source to add", + "example": "https://huggingface.co/cortexso/tinyllama" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Successful installation", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Added model source" + } + } + } + } + } + } + }, + "tags": ["Pulling Models"] + }, + "delete": { + "summary": "Remove a model source", + "description": "User can remove a Huggingface Organization or Repository", + "requestBody": { + "required": false, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "source": { + 
"type": "string", + "description": "The url of model source to remove", + "example": "https://huggingface.co/cortexso/tinyllama" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Successful uninstallation", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "Removed model source successfully!", + "example": "Removed model source successfully!" + } + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "error": { + "type": "string", + "description": "Error message describing the issue with the request" + } + } + } + } + } + } + }, + "tags": ["Pulling Models"] + } + }, "/v1/threads": { "post": { "operationId": "ThreadsController_create", diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 9d5d83ffc..624ccd3dd 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -20,6 +20,9 @@ #include "commands/model_import_cmd.h" #include "commands/model_list_cmd.h" #include "commands/model_pull_cmd.h" +#include "commands/model_source_add_cmd.h" +#include "commands/model_source_del_cmd.h" +#include "commands/model_source_list_cmd.h" #include "commands/model_start_cmd.h" #include "commands/model_stop_cmd.h" #include "commands/model_upd_cmd.h" @@ -253,6 +256,8 @@ void CommandLineParser::SetupModelCommands() { "Display cpu mode"); list_models_cmd->add_flag("--gpu_mode", cml_data_.display_gpu_mode, "Display gpu mode"); + list_models_cmd->add_flag("--available", cml_data_.display_available_model, + "Display available models to download"); list_models_cmd->group(kSubcommands); list_models_cmd->callback([this]() { if (std::exchange(executed_, true)) @@ -261,7 +266,8 @@ void CommandLineParser::SetupModelCommands() { cml_data_.config.apiServerHost, std::stoi(cml_data_.config.apiServerPort), cml_data_.filter, cml_data_.display_engine, cml_data_.display_version, - cml_data_.display_cpu_mode, cml_data_.display_gpu_mode); + cml_data_.display_cpu_mode, cml_data_.display_gpu_mode, + cml_data_.display_available_model); }); auto get_models_cmd = @@ -329,6 +335,74 @@ void CommandLineParser::SetupModelCommands() { std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id, cml_data_.model_path); }); + + auto model_source_cmd = models_cmd->add_subcommand( + "sources", "Subcommands for managing model sources"); + model_source_cmd->usage("Usage:\n" + commands::GetCortexBinary() + + " models sources [options] [subcommand]"); + model_source_cmd->group(kSubcommands); + + model_source_cmd->callback([this, model_source_cmd] { + if (std::exchange(executed_, true)) + return; + if (model_source_cmd->get_subcommands().empty()) { + CLI_LOG(model_source_cmd->help()); + } + }); + + auto model_src_add_cmd = + model_source_cmd->add_subcommand("add", "Add a model source"); + model_src_add_cmd->usage("Usage:\n" + commands::GetCortexBinary() + + " models sources add [model_source]"); + model_src_add_cmd->group(kSubcommands); + model_src_add_cmd->add_option("source", cml_data_.model_src, ""); + model_src_add_cmd->callback([&]() { + if (std::exchange(executed_, true)) + return; + if (cml_data_.model_src.empty()) { + CLI_LOG("[model_source] is required\n"); + CLI_LOG(model_src_add_cmd->help()); + return; + }; + + commands::ModelSourceAddCmd().Exec( + cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), cml_data_.model_src); + }); + + auto 
model_src_del_cmd = + model_source_cmd->add_subcommand("remove", "Remove a model source"); + model_src_del_cmd->usage("Usage:\n" + commands::GetCortexBinary() + + " models sources remove [model_source]"); + model_src_del_cmd->group(kSubcommands); + model_src_del_cmd->add_option("source", cml_data_.model_src, ""); + model_src_del_cmd->callback([&]() { + if (std::exchange(executed_, true)) + return; + if (cml_data_.model_src.empty()) { + CLI_LOG("[model_source] is required\n"); + CLI_LOG(model_src_del_cmd->help()); + return; + }; + + commands::ModelSourceDelCmd().Exec( + cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), cml_data_.model_src); + }); + + auto model_src_list_cmd = + model_source_cmd->add_subcommand("list", "List all model sources"); + model_src_list_cmd->usage("Usage:\n" + commands::GetCortexBinary() + + " models sources list"); + model_src_list_cmd->group(kSubcommands); + model_src_list_cmd->callback([&]() { + if (std::exchange(executed_, true)) + return; + + commands::ModelSourceListCmd().Exec( + cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort)); + }); } void CommandLineParser::SetupConfigsCommands() { diff --git a/engine/cli/command_line_parser.h b/engine/cli/command_line_parser.h index aec10dcb4..896c026d0 100644 --- a/engine/cli/command_line_parser.h +++ b/engine/cli/command_line_parser.h @@ -66,6 +66,7 @@ class CommandLineParser { bool display_version = false; bool display_cpu_mode = false; bool display_gpu_mode = false; + bool display_available_model = false; std::string filter = ""; std::string log_level = "INFO"; @@ -74,6 +75,7 @@ class CommandLineParser { int port; config_yaml_utils::CortexConfig config; std::unordered_map model_update_options; + std::string model_src; }; CmlData cml_data_; std::unordered_map config_update_opts_; diff --git a/engine/cli/commands/model_list_cmd.cc b/engine/cli/commands/model_list_cmd.cc index 7990563f3..96ff2885d 100644 --- a/engine/cli/commands/model_list_cmd.cc +++ b/engine/cli/commands/model_list_cmd.cc @@ -21,7 +21,7 @@ using Row_t = void ModelListCmd::Exec(const std::string& host, int port, const std::string& filter, bool display_engine, bool display_version, bool display_cpu_mode, - bool display_gpu_mode) { + bool display_gpu_mode, bool available) { // Start server if server is not started yet if (!commands::IsServerAlive(host, port)) { CLI_LOG("Starting server ..."); @@ -73,40 +73,62 @@ void ModelListCmd::Exec(const std::string& host, int port, continue; } - count += 1; + if (available) { + if (v["status"].asString() != "downloadable") { + continue; + } - std::vector row = {std::to_string(count), - v["model"].asString()}; - if (display_engine) { - row.push_back(v["engine"].asString()); - } - if (display_version) { - row.push_back(v["version"].asString()); - } + count += 1; - if (auto& r = v["recommendation"]; !r.isNull()) { - if (display_cpu_mode) { - if (!r["cpu_mode"].isNull()) { - row.push_back("RAM: " + r["cpu_mode"]["ram"].asString() + " MiB"); - } + std::vector row = {std::to_string(count), + v["model"].asString()}; + if (display_engine) { + row.push_back(v["engine"].asString()); + } + if (display_version) { + row.push_back(v["version"].asString()); + } + table.add_row({row.begin(), row.end()}); + } else { + if (v["status"].asString() == "downloadable") { + continue; + } + + count += 1; + + std::vector row = {std::to_string(count), + v["model"].asString()}; + if (display_engine) { + row.push_back(v["engine"].asString()); + } + if (display_version) { + 
row.push_back(v["version"].asString()); } - if (display_gpu_mode) { - if (!r["gpu_mode"].isNull()) { - std::string s; - s += "ngl: " + r["gpu_mode"][0]["ngl"].asString() + " - "; - s += "context: " + r["gpu_mode"][0]["context_length"].asString() + - " - "; - s += "RAM: " + r["gpu_mode"][0]["ram"].asString() + " MiB - "; - s += "VRAM: " + r["gpu_mode"][0]["vram"].asString() + " MiB - "; - s += "recommended ngl: " + - r["gpu_mode"][0]["recommend_ngl"].asString(); - row.push_back(s); + if (auto& r = v["recommendation"]; !r.isNull()) { + if (display_cpu_mode) { + if (!r["cpu_mode"].isNull()) { + row.push_back("RAM: " + r["cpu_mode"]["ram"].asString() + " MiB"); + } + } + + if (display_gpu_mode) { + if (!r["gpu_mode"].isNull()) { + std::string s; + s += "ngl: " + r["gpu_mode"][0]["ngl"].asString() + " - "; + s += "context: " + r["gpu_mode"][0]["context_length"].asString() + + " - "; + s += "RAM: " + r["gpu_mode"][0]["ram"].asString() + " MiB - "; + s += "VRAM: " + r["gpu_mode"][0]["vram"].asString() + " MiB - "; + s += "recommended ngl: " + + r["gpu_mode"][0]["recommend_ngl"].asString(); + row.push_back(s); + } } } - } - table.add_row({row.begin(), row.end()}); + table.add_row({row.begin(), row.end()}); + } } } diff --git a/engine/cli/commands/model_list_cmd.h b/engine/cli/commands/model_list_cmd.h index 791c1ecf6..85dd76de9 100644 --- a/engine/cli/commands/model_list_cmd.h +++ b/engine/cli/commands/model_list_cmd.h @@ -8,6 +8,7 @@ class ModelListCmd { public: void Exec(const std::string& host, int port, const std::string& filter, bool display_engine = false, bool display_version = false, - bool display_cpu_mode = false, bool display_gpu_mode = false); + bool display_cpu_mode = false, bool display_gpu_mode = false, + bool available = false); }; } // namespace commands diff --git a/engine/cli/commands/model_source_add_cmd.cc b/engine/cli/commands/model_source_add_cmd.cc new file mode 100644 index 000000000..2fadbe8ec --- /dev/null +++ b/engine/cli/commands/model_source_add_cmd.cc @@ -0,0 +1,38 @@ +#include "model_source_add_cmd.h" +#include "server_start_cmd.h" +#include "utils/json_helper.h" +#include "utils/logging_utils.h" +namespace commands { +bool ModelSourceAddCmd::Exec(const std::string& host, int port, const std::string& model_source) { + // Start server if server is not started yet + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return false; + } + } + + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "sources"}, + }; + + Json::Value json_data; + json_data["source"] = model_source; + + auto data_str = json_data.toStyledString(); + auto res = curl_utils::SimplePostJson(url.ToFullPath(), data_str); + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); + return false; + } + + CLI_LOG("Added model source: " << model_source); + return true; +} + + +}; // namespace commands diff --git a/engine/cli/commands/model_source_add_cmd.h b/engine/cli/commands/model_source_add_cmd.h new file mode 100644 index 000000000..6d3bcc6c0 --- /dev/null +++ b/engine/cli/commands/model_source_add_cmd.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace commands { + +class ModelSourceAddCmd { + public: + bool Exec(const std::string& host, int port, const std::string& model_source); +}; +} // namespace commands diff --git 
a/engine/cli/commands/model_source_del_cmd.cc b/engine/cli/commands/model_source_del_cmd.cc new file mode 100644 index 000000000..c3c1694e7 --- /dev/null +++ b/engine/cli/commands/model_source_del_cmd.cc @@ -0,0 +1,39 @@ +#include "model_source_del_cmd.h" +#include "server_start_cmd.h" +#include "utils/json_helper.h" +#include "utils/logging_utils.h" + +namespace commands { +bool ModelSourceDelCmd::Exec(const std::string& host, int port, const std::string& model_source) { + // Start server if server is not started yet + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return false; + } + } + + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "sources"}, + }; + + Json::Value json_data; + json_data["source"] = model_source; + + auto data_str = json_data.toStyledString(); + auto res = curl_utils::SimpleDeleteJson(url.ToFullPath(), data_str); + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); + return false; + } + + CLI_LOG("Removed model source: " << model_source); + return true; +} + + +}; // namespace commands diff --git a/engine/cli/commands/model_source_del_cmd.h b/engine/cli/commands/model_source_del_cmd.h new file mode 100644 index 000000000..5015a609a --- /dev/null +++ b/engine/cli/commands/model_source_del_cmd.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace commands { + +class ModelSourceDelCmd { + public: + bool Exec(const std::string& host, int port, const std::string& model_source); +}; +} // namespace commands diff --git a/engine/cli/commands/model_source_list_cmd.cc b/engine/cli/commands/model_source_list_cmd.cc new file mode 100644 index 000000000..ae69c5aef --- /dev/null +++ b/engine/cli/commands/model_source_list_cmd.cc @@ -0,0 +1,56 @@ +#include "model_source_list_cmd.h" +#include +#include +#include +#include +#include "server_start_cmd.h" +#include "utils/curl_utils.h" +#include "utils/json_helper.h" +#include "utils/logging_utils.h" +#include "utils/string_utils.h" +#include "utils/url_parser.h" +// clang-format off +#include +// clang-format on + +namespace commands { + +bool ModelSourceListCmd::Exec(const std::string& host, int port) { + // Start server if server is not started yet + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return false; + } + } + + tabulate::Table table; + table.add_row({"#", "Model Source"}); + + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "sources"}, + }; + auto result = curl_utils::SimpleGetJson(url.ToFullPath()); + if (result.has_error()) { + CTL_ERR(result.error()); + return false; + } + table.format().font_color(tabulate::Color::green); + int count = 0; + + if (!result.value()["data"].isNull()) { + for (auto const& v : result.value()["data"]) { + auto model_source = v.asString(); + count += 1; + std::vector row = {std::to_string(count), model_source}; + table.add_row({row.begin(), row.end()}); + } + } + + std::cout << table << std::endl; + return true; +} +}; // namespace commands diff --git a/engine/cli/commands/model_source_list_cmd.h b/engine/cli/commands/model_source_list_cmd.h new file mode 100644 index 000000000..99116f592 --- /dev/null +++ b/engine/cli/commands/model_source_list_cmd.h @@ -0,0 +1,11 
@@ +#pragma once + +#include + +namespace commands { + +class ModelSourceListCmd { + public: + bool Exec(const std::string& host, int port); +}; +} // namespace commands diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 3f91da848..affa45d52 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -172,6 +172,28 @@ void Models::ListModel( if (list_entry) { for (const auto& model_entry : list_entry.value()) { try { + if (model_entry.status == cortex::db::ModelStatus::Downloadable) { + Json::Value obj; + obj["id"] = model_entry.model; + obj["model"] = model_entry.model; + auto status_to_string = [](cortex::db::ModelStatus status) { + switch (status) { + case cortex::db::ModelStatus::Remote: + return "remote"; + case cortex::db::ModelStatus::Downloaded: + return "downloaded"; + case cortex::db::ModelStatus::Downloadable: + return "downloadable"; + } + return "unknown"; + }; + obj["modelSource"] = model_entry.model_source; + obj["status"] = status_to_string(model_entry.status); + obj["engine"] = model_entry.engine; + obj["metadata"] = model_entry.metadata; + data.append(std::move(obj)); + continue; + } yaml_handler.ModelConfigFromFile( fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.path_to_model_yaml)) @@ -182,7 +204,7 @@ void Models::ListModel( Json::Value obj = model_config.ToJson(); obj["id"] = model_entry.model; obj["model"] = model_entry.model; - obj["model"] = model_entry.model; + obj["status"] = "downloaded"; auto es = model_service_->GetEstimation(model_entry.model); if (es.has_value() && !!es.value()) { obj["recommendation"] = hardware::ToJson(*(es.value())); @@ -723,4 +745,78 @@ void Models::AddRemoteModel( resp->setStatusCode(k400BadRequest); callback(resp); } +} + +void Models::AddModelSource( + const HttpRequestPtr& req, + std::function&& callback) { + if (!http_util::HasFieldInReq(req, callback, "source")) { + return; + } + + auto model_source = (*(req->getJsonObject())).get("source", "").asString(); + auto res = model_src_svc_->AddModelSource(model_source); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto const& info = res.value(); + Json::Value ret; + ret["message"] = "Model source is added successfully!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + } +} + +void Models::DeleteModelSource( + const HttpRequestPtr& req, + std::function&& callback) { + if (!http_util::HasFieldInReq(req, callback, "source")) { + return; + } + + auto model_source = (*(req->getJsonObject())).get("source", "").asString(); + auto res = model_src_svc_->RemoveModelSource(model_source); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto const& info = res.value(); + Json::Value ret; + ret["message"] = "Model source is deleted successfully!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + } +} + +void Models::GetModelSources( + const HttpRequestPtr& req, + std::function&& callback) { + auto res = model_src_svc_->GetModelSources(); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + 
resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto const& info = res.value(); + Json::Value ret; + Json::Value data(Json::arrayValue); + for (auto const& i : info) { + data.append(i); + } + ret["data"] = data; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + } } \ No newline at end of file diff --git a/engine/controllers/models.h b/engine/controllers/models.h index b2b288adc..d3200f33a 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -4,6 +4,7 @@ #include #include "services/engine_service.h" #include "services/model_service.h" +#include "services/model_source_service.h" using namespace drogon; @@ -23,6 +24,9 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::GetModelStatus, "/status/{1}", Get); METHOD_ADD(Models::AddRemoteModel, "/add", Options, Post); METHOD_ADD(Models::GetRemoteModels, "/remote/{1}", Get); + METHOD_ADD(Models::AddModelSource, "/sources", Post); + METHOD_ADD(Models::DeleteModelSource, "/sources", Delete); + METHOD_ADD(Models::GetModelSources, "/sources", Get); ADD_METHOD_TO(Models::PullModel, "/v1/models/pull", Options, Post); ADD_METHOD_TO(Models::AbortPullModel, "/v1/models/pull", Options, Delete); @@ -36,11 +40,17 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get); ADD_METHOD_TO(Models::AddRemoteModel, "/v1/models/add", Options, Post); ADD_METHOD_TO(Models::GetRemoteModels, "/v1/models/remote/{1}", Get); + ADD_METHOD_TO(Models::AddModelSource, "/v1/models/sources", Post); + ADD_METHOD_TO(Models::DeleteModelSource, "/v1/models/sources", Delete); + ADD_METHOD_TO(Models::GetModelSources, "/v1/models/sources", Get); METHOD_LIST_END explicit Models(std::shared_ptr model_service, - std::shared_ptr engine_service) - : model_service_{model_service}, engine_service_{engine_service} {} + std::shared_ptr engine_service, + std::shared_ptr mss) + : model_service_{model_service}, + engine_service_{engine_service}, + model_src_svc_(mss) {} void PullModel(const HttpRequestPtr& req, std::function&& callback); @@ -84,7 +94,18 @@ class Models : public drogon::HttpController { std::function&& callback, const std::string& engine_id); + void AddModelSource(const HttpRequestPtr& req, + std::function&& callback); + + void DeleteModelSource( + const HttpRequestPtr& req, + std::function&& callback); + + void GetModelSources(const HttpRequestPtr& req, + std::function&& callback); + private: std::shared_ptr model_service_; std::shared_ptr engine_service_; + std::shared_ptr model_src_svc_; }; diff --git a/engine/database/models.cc b/engine/database/models.cc index 8c8be9eaf..67ff1a8c9 100644 --- a/engine/database/models.cc +++ b/engine/database/models.cc @@ -18,8 +18,8 @@ std::string Models::StatusToString(ModelStatus status) const { return "remote"; case ModelStatus::Downloaded: return "downloaded"; - case ModelStatus::Undownloaded: - return "undownloaded"; + case ModelStatus::Downloadable: + return "downloadable"; } return "unknown"; } @@ -31,8 +31,8 @@ ModelStatus Models::StringToStatus(const std::string& status_str) const { return ModelStatus::Remote; } else if (status_str == "downloaded" || status_str.empty()) { return ModelStatus::Downloaded; - } else if (status_str == "undownloaded") { - return ModelStatus::Undownloaded; + } else if (status_str == "downloadable") { + return ModelStatus::Downloadable; } throw std::invalid_argument("Invalid status string"); } @@ -50,23 +50,21 @@ cpp::result, 
std::string> Models::LoadModelList() } bool Models::IsUnique(const std::vector& entries, - const std::string& model_id, - const std::string& model_alias) const { + const std::string& model_id) const { return std::none_of( - entries.begin(), entries.end(), [&](const ModelEntry& entry) { - return entry.model == model_id || entry.model_alias == model_id || - entry.model == model_alias || entry.model_alias == model_alias; - }); + entries.begin(), entries.end(), + [&](const ModelEntry& entry) { return entry.model == model_id; }); } cpp::result, std::string> Models::LoadModelListNoLock() const { try { std::vector entries; - SQLite::Statement query(db_, - "SELECT model_id, author_repo_id, branch_name, " - "path_to_model_yaml, model_alias, model_format, " - "model_source, status, engine FROM models"); + SQLite::Statement query( + db_, + "SELECT model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias, model_format, " + "model_source, status, engine, metadata FROM models"); while (query.executeStep()) { ModelEntry entry; @@ -79,6 +77,7 @@ cpp::result, std::string> Models::LoadModelListNoLock() entry.model_source = query.getColumn(6).getString(); entry.status = StringToStatus(query.getColumn(7).getString()); entry.engine = query.getColumn(8).getString(); + entry.metadata = query.getColumn(9).getString(); entries.push_back(entry); } return entries; @@ -88,77 +87,17 @@ cpp::result, std::string> Models::LoadModelListNoLock() } } -std::string Models::GenerateShortenedAlias( - const std::string& model_id, const std::vector& entries) const { - std::vector parts; - std::istringstream iss(model_id); - std::string part; - while (std::getline(iss, part, ':')) { - parts.push_back(part); - } - - if (parts.empty()) { - return model_id; // Return original if no parts - } - - // Extract the filename without extension - std::string filename = parts.back(); - size_t last_dot_pos = filename.find_last_of('.'); - if (last_dot_pos != std::string::npos) { - filename = filename.substr(0, last_dot_pos); - } - - // Convert to lowercase - std::transform(filename.begin(), filename.end(), filename.begin(), - [](unsigned char c) { return std::tolower(c); }); - - // Generate alias candidates - std::vector candidates; - candidates.push_back(filename); - - if (parts.size() >= 2) { - candidates.push_back(parts[parts.size() - 2] + ":" + filename); - } - - if (parts.size() >= 3) { - candidates.push_back(parts[parts.size() - 3] + ":" + - parts[parts.size() - 2] + ":" + filename); - } - - if (parts.size() >= 4) { - candidates.push_back(parts[0] + ":" + parts[1] + ":" + - parts[parts.size() - 2] + ":" + filename); - } - - // Find the first unique candidate - for (const auto& candidate : candidates) { - if (IsUnique(entries, model_id, candidate)) { - return candidate; - } - } - - // If all candidates are taken, append a number to the last candidate - std::string base_candidate = candidates.back(); - int suffix = 1; - std::string unique_candidate = base_candidate; - while (!IsUnique(entries, model_id, unique_candidate)) { - unique_candidate = base_candidate + "-" + std::to_string(suffix++); - } - - return unique_candidate; -} - cpp::result Models::GetModelInfo( const std::string& identifier) const { try { - SQLite::Statement query(db_, - "SELECT model_id, author_repo_id, branch_name, " - "path_to_model_yaml, model_alias, model_format, " - "model_source, status, engine FROM models " - "WHERE model_id = ? 
OR model_alias = ?"); + SQLite::Statement query( + db_, + "SELECT model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias, model_format, " + "model_source, status, engine, metadata FROM models " + "WHERE model_id = ?"); query.bind(1, identifier); - query.bind(2, identifier); if (query.executeStep()) { ModelEntry entry; entry.model = query.getColumn(0).getString(); @@ -170,6 +109,7 @@ cpp::result Models::GetModelInfo( entry.model_source = query.getColumn(6).getString(); entry.status = StringToStatus(query.getColumn(7).getString()); entry.engine = query.getColumn(8).getString(); + entry.metadata = query.getColumn(9).getString(); return entry; } else { return cpp::fail("Model not found: " + identifier); @@ -189,10 +129,10 @@ void Models::PrintModelInfo(const ModelEntry& entry) const { LOG_INFO << "Model Source: " << entry.model_source; LOG_INFO << "Status: " << StatusToString(entry.status); LOG_INFO << "Engine: " << entry.engine; + LOG_INFO << "Metadata: " << entry.metadata; } -cpp::result Models::AddModelEntry(ModelEntry new_entry, - bool use_short_alias) { +cpp::result Models::AddModelEntry(ModelEntry new_entry) { try { db_.exec("BEGIN TRANSACTION;"); cortex::utils::ScopeExit se([this] { db_.exec("COMMIT;"); }); @@ -201,17 +141,13 @@ cpp::result Models::AddModelEntry(ModelEntry new_entry, CTL_WRN(model_list.error()); return cpp::fail(model_list.error()); } - if (IsUnique(model_list.value(), new_entry.model, new_entry.model_alias)) { - if (use_short_alias) { - new_entry.model_alias = - GenerateShortenedAlias(new_entry.model, model_list.value()); - } + if (IsUnique(model_list.value(), new_entry.model)) { SQLite::Statement insert( db_, "INSERT INTO models (model_id, author_repo_id, branch_name, " "path_to_model_yaml, model_alias, model_format, model_source, " - "status, engine) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"); + "status, engine, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"); insert.bind(1, new_entry.model); insert.bind(2, new_entry.author_repo_id); insert.bind(3, new_entry.branch_name); @@ -221,6 +157,7 @@ cpp::result Models::AddModelEntry(ModelEntry new_entry, insert.bind(7, new_entry.model_source); insert.bind(8, StatusToString(new_entry.status)); insert.bind(9, new_entry.engine); + insert.bind(10, new_entry.metadata); insert.exec(); return true; @@ -242,7 +179,7 @@ cpp::result Models::UpdateModelEntry( db_, "UPDATE models SET author_repo_id = ?, branch_name = ?, " "path_to_model_yaml = ?, model_format = ?, model_source = ?, status = " - "?, engine = ? WHERE model_id = ? OR model_alias = ?"); + "?, engine = ?, metadata = ? 
WHERE model_id = ?"); upd.bind(1, updated_entry.author_repo_id); upd.bind(2, updated_entry.branch_name); upd.bind(3, updated_entry.path_to_model_yaml); @@ -250,7 +187,7 @@ cpp::result Models::UpdateModelEntry( upd.bind(5, updated_entry.model_source); upd.bind(6, StatusToString(updated_entry.status)); upd.bind(7, updated_entry.engine); - upd.bind(8, identifier); + upd.bind(8, updated_entry.metadata); upd.bind(9, identifier); return upd.exec() == 1; } catch (const std::exception& e) { @@ -258,36 +195,6 @@ cpp::result Models::UpdateModelEntry( } } -cpp::result Models::UpdateModelAlias( - const std::string& model_id, const std::string& new_model_alias) { - if (!HasModel(model_id)) { - return cpp::fail("Model not found: " + model_id); - } - try { - db_.exec("BEGIN TRANSACTION;"); - cortex::utils::ScopeExit se([this] { db_.exec("COMMIT;"); }); - auto model_list = LoadModelListNoLock(); - if (model_list.has_error()) { - CTL_WRN(model_list.error()); - return cpp::fail(model_list.error()); - } - // Check new_model_alias is unique - if (IsUnique(model_list.value(), new_model_alias, new_model_alias)) { - SQLite::Statement upd(db_, - "UPDATE models " - "SET model_alias = ? " - "WHERE model_id = ? OR model_alias = ?"); - upd.bind(1, new_model_alias); - upd.bind(2, model_id); - upd.bind(3, model_id); - return upd.exec() == 1; - } - return false; - } catch (const std::exception& e) { - return cpp::fail(e.what()); - } -} - cpp::result Models::DeleteModelEntry( const std::string& identifier) { try { @@ -296,10 +203,34 @@ cpp::result Models::DeleteModelEntry( return true; } - SQLite::Statement del( - db_, "DELETE from models WHERE model_id = ? OR model_alias = ?"); + SQLite::Statement del(db_, "DELETE from models WHERE model_id = ?"); del.bind(1, identifier); - del.bind(2, identifier); + return del.exec() == 1; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + +cpp::result Models::DeleteModelEntryWithOrg( + const std::string& src) { + try { + SQLite::Statement del(db_, + "DELETE from models WHERE model_source LIKE ? AND " + "status = \"downloadable\""); + del.bind(1, src + "%"); + return del.exec() == 1; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + +cpp::result Models::DeleteModelEntryWithRepo( + const std::string& src) { + try { + SQLite::Statement del(db_, + "DELETE from models WHERE model_source = ? AND " + "status = \"downloadable\""); + del.bind(1, src); return del.exec() == 1; } catch (const std::exception& e) { return cpp::fail(e.what()); @@ -310,8 +241,9 @@ cpp::result, std::string> Models::FindRelatedModel( const std::string& identifier) const { try { std::vector related_models; - SQLite::Statement query( - db_, "SELECT model_id FROM models WHERE model_id LIKE ?"); + SQLite::Statement query(db_, + "SELECT model_id FROM models WHERE model_id LIKE ? " + "AND status = \"downloaded\""); query.bind(1, "%" + identifier + "%"); while (query.executeStep()) { @@ -325,11 +257,9 @@ cpp::result, std::string> Models::FindRelatedModel( bool Models::HasModel(const std::string& identifier) const { try { - SQLite::Statement query( - db_, - "SELECT COUNT(*) FROM models WHERE model_id = ? 
OR model_alias = ?"); + SQLite::Statement query(db_, + "SELECT COUNT(*) FROM models WHERE model_id = ?"); query.bind(1, identifier); - query.bind(2, identifier); if (query.executeStep()) { return query.getColumn(0).getInt() > 0; } @@ -340,4 +270,38 @@ bool Models::HasModel(const std::string& identifier) const { } } +cpp::result, std::string> Models::GetModelSources() + const { + try { + std::vector sources; + SQLite::Statement query(db_, + "SELECT DISTINCT model_source FROM models WHERE " + "status = \"downloadable\""); + + while (query.executeStep()) { + sources.push_back(query.getColumn(0).getString()); + } + return sources; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + +cpp::result, std::string> Models::GetModels( + const std::string& model_src) const { + try { + std::vector ids; + SQLite::Statement query(db_, + "SELECT model_id FROM models WHERE model_source = " + "? AND status = \"downloadable\""); + query.bind(1, model_src); + while (query.executeStep()) { + ids.push_back(query.getColumn(0).getString()); + } + return ids; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + } // namespace cortex::db diff --git a/engine/database/models.h b/engine/database/models.h index 5c855cf1b..b0c4bc258 100644 --- a/engine/database/models.h +++ b/engine/database/models.h @@ -8,7 +8,8 @@ namespace cortex::db { -enum class ModelStatus { Remote, Downloaded, Undownloaded }; +enum class ModelStatus { Remote, Downloaded, Downloadable }; + struct ModelEntry { std::string model; @@ -20,6 +21,7 @@ struct ModelEntry { std::string model_source; ModelStatus status; std::string engine; + std::string metadata; }; class Models { @@ -28,8 +30,7 @@ class Models { SQLite::Database& db_; bool IsUnique(const std::vector& entries, - const std::string& model_id, - const std::string& model_alias) const; + const std::string& model_id) const; cpp::result, std::string> LoadModelListNoLock() const; @@ -41,23 +42,24 @@ class Models { Models(); Models(SQLite::Database& db); ~Models(); - std::string GenerateShortenedAlias( - const std::string& model_id, - const std::vector& entries) const; cpp::result GetModelInfo( const std::string& identifier) const; void PrintModelInfo(const ModelEntry& entry) const; - cpp::result AddModelEntry(ModelEntry new_entry, - bool use_short_alias = false); + cpp::result AddModelEntry(ModelEntry new_entry); cpp::result UpdateModelEntry( const std::string& identifier, const ModelEntry& updated_entry); cpp::result DeleteModelEntry( const std::string& identifier); - cpp::result UpdateModelAlias( - const std::string& model_id, const std::string& model_alias); + cpp::result DeleteModelEntryWithOrg( + const std::string& src); + cpp::result DeleteModelEntryWithRepo( + const std::string& src); cpp::result, std::string> FindRelatedModel( const std::string& identifier) const; bool HasModel(const std::string& identifier) const; + cpp::result, std::string> GetModelSources() const; + cpp::result, std::string> GetModels( + const std::string& model_src) const; }; } // namespace cortex::db diff --git a/engine/e2e-test/test_api_model.py b/engine/e2e-test/test_api_model.py index c2723d2ca..8f2e4b07a 100644 --- a/engine/e2e-test/test_api_model.py +++ b/engine/e2e-test/test_api_model.py @@ -129,4 +129,17 @@ async def test_models_start_stop_should_be_successful(self): # delete API print("Delete model") response = requests.delete("http://localhost:3928/v1/models/tinyllama:gguf") - assert response.status_code == 200 \ No newline at end of file + assert 
response.status_code == 200 + + def test_models_sources_api(self): + json_body = {"source": "https://huggingface.co/cortexso/tinyllama"} + response = requests.post( + "http://localhost:3928/v1/models/sources", json=json_body + ) + assert response.status_code == 200, f"status_code: {response.status_code}" + + json_body = {"source": "https://huggingface.co/cortexso/tinyllama"} + response = requests.delete( + "http://localhost:3928/v1/models/sources", json=json_body + ) + assert response.status_code == 200, f"status_code: {response.status_code}" \ No newline at end of file diff --git a/engine/main.cc b/engine/main.cc index 5222ac5c2..13583dc00 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -22,6 +22,7 @@ #include "services/file_watcher_service.h" #include "services/message_service.h" #include "services/model_service.h" +#include "services/model_source_service.h" #include "services/thread_service.h" #include "utils/archive_utils.h" #include "utils/cortex_utils.h" @@ -141,6 +142,7 @@ void RunServer(std::optional port, bool ignore_cout) { auto engine_service = std::make_shared(download_service); auto inference_svc = std::make_shared(engine_service); + auto model_src_svc = std::make_shared(); auto model_service = std::make_shared( download_service, inference_svc, engine_service); @@ -154,7 +156,8 @@ void RunServer(std::optional port, bool ignore_cout) { auto thread_ctl = std::make_shared(thread_srv, message_srv); auto message_ctl = std::make_shared(message_srv); auto engine_ctl = std::make_shared(engine_service); - auto model_ctl = std::make_shared(model_service, engine_service); + auto model_ctl = + std::make_shared(model_service, engine_service, model_src_svc); auto event_ctl = std::make_shared(event_queue_ptr); auto pm_ctl = std::make_shared(); auto hw_ctl = std::make_shared(engine_service, hw_service); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 7f79ddaf7..15fee15be 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -64,16 +64,30 @@ void ParseGguf(const DownloadItem& ggufDownloadItem, auto author_id = author.has_value() ? 
author.value() : "cortexso"; cortex::db::Models modellist_utils_obj; - cortex::db::ModelEntry model_entry{ - .model = ggufDownloadItem.id, - .author_repo_id = author_id, - .branch_name = branch, - .path_to_model_yaml = rel.string(), - .model_alias = ggufDownloadItem.id, - .status = cortex::db::ModelStatus::Downloaded}; - auto result = modellist_utils_obj.AddModelEntry(model_entry, true); - if (result.has_error()) { - CTL_WRN("Error adding model to modellist: " + result.error()); + if (!modellist_utils_obj.HasModel(ggufDownloadItem.id)) { + cortex::db::ModelEntry model_entry{ + .model = ggufDownloadItem.id, + .author_repo_id = author_id, + .branch_name = branch, + .path_to_model_yaml = rel.string(), + .model_alias = ggufDownloadItem.id, + .status = cortex::db::ModelStatus::Downloaded}; + auto result = modellist_utils_obj.AddModelEntry(model_entry); + + if (result.has_error()) { + CTL_ERR("Error adding model to modellist: " + result.error()); + } + } else { + if (auto m = modellist_utils_obj.GetModelInfo(ggufDownloadItem.id); + m.has_value()) { + auto upd_m = m.value(); + upd_m.status = cortex::db::ModelStatus::Downloaded; + if (auto r = + modellist_utils_obj.UpdateModelEntry(ggufDownloadItem.id, upd_m); + r.has_error()) { + CTL_ERR(r.error()); + } + } } } @@ -136,6 +150,9 @@ void ModelService::ForceIndexingModelList() { CTL_DBG("Database model size: " + std::to_string(list_entry.value().size())); for (const auto& model_entry : list_entry.value()) { + if (model_entry.status != cortex::db::ModelStatus::Downloaded) { + continue; + } try { yaml_handler.ModelConfigFromFile( fmu::ToAbsoluteCortexDataPath( @@ -301,7 +318,8 @@ cpp::result ModelService::HandleDownloadUrlAsync( } auto model_entry = modellist_handler.GetModelInfo(unique_model_id); - if (model_entry.has_value()) { + if (model_entry.has_value() && + model_entry->status == cortex::db::ModelStatus::Downloaded) { CLI_LOG("Model already downloaded: " << unique_model_id); return cpp::fail("Please delete the model before downloading again"); } @@ -491,7 +509,8 @@ ModelService::DownloadModelFromCortexsoAsync( } auto model_entry = modellist_handler.GetModelInfo(unique_model_id); - if (model_entry.has_value()) { + if (model_entry.has_value() && + model_entry->status == cortex::db::ModelStatus::Downloaded) { return cpp::fail("Please delete the model before downloading again"); } @@ -532,14 +551,32 @@ ModelService::DownloadModelFromCortexsoAsync( CTL_INF("path_to_model_yaml: " << rel.string()); cortex::db::Models modellist_utils_obj; - cortex::db::ModelEntry model_entry{.model = unique_model_id, - .author_repo_id = "cortexso", - .branch_name = branch, - .path_to_model_yaml = rel.string(), - .model_alias = unique_model_id}; - auto result = modellist_utils_obj.AddModelEntry(model_entry); - if (result.has_error()) { - CTL_ERR("Error adding model to modellist: " + result.error()); + if (!modellist_utils_obj.HasModel(unique_model_id)) { + cortex::db::ModelEntry model_entry{ + .model = unique_model_id, + .author_repo_id = "cortexso", + .branch_name = branch, + .path_to_model_yaml = rel.string(), + .model_alias = unique_model_id, + .status = cortex::db::ModelStatus::Downloaded}; + auto result = modellist_utils_obj.AddModelEntry(model_entry); + + if (result.has_error()) { + CTL_ERR("Error adding model to modellist: " + result.error()); + } + } else { + if (auto m = modellist_utils_obj.GetModelInfo(unique_model_id); + m.has_value()) { + auto upd_m = m.value(); + upd_m.status = cortex::db::ModelStatus::Downloaded; + if (auto r = + 
modellist_utils_obj.UpdateModelEntry(unique_model_id, upd_m); + r.has_error()) { + CTL_ERR(r.error()); + } + } else { + CTL_WRN("Could not get model entry with model id: " << unique_model_id); + } } }; @@ -585,14 +622,28 @@ cpp::result ModelService::DownloadModelFromCortexso( CTL_INF("path_to_model_yaml: " << rel.string()); cortex::db::Models modellist_utils_obj; - cortex::db::ModelEntry model_entry{.model = model_id, - .author_repo_id = "cortexso", - .branch_name = branch, - .path_to_model_yaml = rel.string(), - .model_alias = model_id}; - auto result = modellist_utils_obj.AddModelEntry(model_entry); - if (result.has_error()) { - CTL_ERR("Error adding model to modellist: " + result.error()); + if (!modellist_utils_obj.HasModel(model_id)) { + cortex::db::ModelEntry model_entry{ + .model = model_id, + .author_repo_id = "cortexso", + .branch_name = branch, + .path_to_model_yaml = rel.string(), + .model_alias = model_id, + .status = cortex::db::ModelStatus::Downloaded}; + auto result = modellist_utils_obj.AddModelEntry(model_entry); + + if (result.has_error()) { + CTL_ERR("Error adding model to modellist: " + result.error()); + } + } else { + if (auto m = modellist_utils_obj.GetModelInfo(model_id); m.has_value()) { + auto upd_m = m.value(); + upd_m.status = cortex::db::ModelStatus::Downloaded; + if (auto r = modellist_utils_obj.UpdateModelEntry(model_id, upd_m); + r.has_error()) { + CTL_ERR(r.error()); + } + } } }; diff --git a/engine/services/model_source_service.cc b/engine/services/model_source_service.cc new file mode 100644 index 000000000..a7d9d5e6e --- /dev/null +++ b/engine/services/model_source_service.cc @@ -0,0 +1,493 @@ +#include "model_source_service.h" +#include +#include +#include "database/models.h" +#include "json/json.h" +#include "utils/curl_utils.h" +#include "utils/huggingface_utils.h" +#include "utils/logging_utils.h" +#include "utils/string_utils.h" +#include "utils/url_parser.h" + +namespace services { +namespace hu = huggingface_utils; + +namespace { +struct ModelInfo { + std::string id; + int likes; + int trending_score; + bool is_private; + int downloads; + std::vector tags; + std::string created_at; + std::string model_id; +}; + +std::vector ParseJsonString(const std::string& json_str) { + std::vector models; + + // Parse the JSON string + Json::Value root; + Json::Reader reader; + bool parsing_successful = reader.parse(json_str, root); + + if (!parsing_successful) { + std::cerr << "Failed to parse JSON" << std::endl; + return models; + } + + // Iterate over the JSON array + for (const auto& model : root) { + ModelInfo info; + info.id = model["id"].asString(); + info.likes = model["likes"].asInt(); + info.trending_score = model["trendingScore"].asInt(); + info.is_private = model["private"].asBool(); + info.downloads = model["downloads"].asInt(); + + const Json::Value& tags = model["tags"]; + for (const auto& tag : tags) { + info.tags.push_back(tag.asString()); + } + + info.created_at = model["createdAt"].asString(); + info.model_id = model["modelId"].asString(); + models.push_back(info); + } + + return models; +} + +} // namespace + +ModelSourceService::ModelSourceService() { + sync_db_thread_ = std::thread(&ModelSourceService::SyncModelSource, this); + running_ = true; +} +ModelSourceService::~ModelSourceService() { + running_ = false; + if (sync_db_thread_.joinable()) { + sync_db_thread_.join(); + } + CTL_INF("Done cleanup thread"); +} + +cpp::result ModelSourceService::AddModelSource( + const std::string& model_source) { + auto res = 
url_parser::FromUrlString(model_source); + if (res.has_error()) { + return cpp::fail(res.error()); + } else { + auto& r = res.value(); + if (r.pathParams.empty() || r.pathParams.size() > 2) { + return cpp::fail("Invalid model source url: " + model_source); + } + + if (auto is_org = r.pathParams.size() == 1; is_org) { + auto& author = r.pathParams[0]; + if (author == "cortexso") { + return AddCortexsoOrg(model_source); + } else { + return AddHfOrg(model_source, author); + } + } else { // Repo + auto const& author = r.pathParams[0]; + auto const& model_name = r.pathParams[1]; + if (r.pathParams[0] == "cortexso") { + return AddCortexsoRepo(model_source, author, model_name); + } else { + return AddHfRepo(model_source, author, model_name); + } + } + } + return true; +} + +cpp::result ModelSourceService::RemoveModelSource( + const std::string& model_source) { + cortex::db::Models model_db; + auto srcs = model_db.GetModelSources(); + if (srcs.has_error()) { + return cpp::fail(srcs.error()); + } else { + auto& v = srcs.value(); + if (std::find(v.begin(), v.end(), model_source) == v.end()) { + return cpp::fail("Model source does not exist: " + model_source); + } + } + CTL_INF("Remove model source: " << model_source); + auto res = url_parser::FromUrlString(model_source); + if (res.has_error()) { + return cpp::fail(res.error()); + } else { + auto& r = res.value(); + if (r.pathParams.empty() || r.pathParams.size() > 2) { + return cpp::fail("Invalid model source url: " + model_source); + } + + if (r.pathParams.size() == 1) { + if (auto del_res = model_db.DeleteModelEntryWithOrg(model_source); + del_res.has_error()) { + CTL_INF(del_res.error()); + return cpp::fail(del_res.error()); + } + } else { + if (auto del_res = model_db.DeleteModelEntryWithRepo(model_source); + del_res.has_error()) { + CTL_INF(del_res.error()); + return cpp::fail(del_res.error()); + } + } + } + return true; +} + +cpp::result, std::string> +ModelSourceService::GetModelSources() { + cortex::db::Models model_db; + return model_db.GetModelSources(); +} + +cpp::result ModelSourceService::AddHfOrg( + const std::string& model_source, const std::string& author) { + auto res = curl_utils::SimpleGet("https://huggingface.co/api/models?author=" + + author); + if (res.has_value()) { + auto models = ParseJsonString(res.value()); + // Get models from db + cortex::db::Models model_db; + + auto model_list_before = + model_db.GetModels(model_source).value_or(std::vector{}); + std::unordered_set updated_model_list; + // Add new models + for (auto const& m : models) { + CTL_DBG(m.id); + auto author_model = string_utils::SplitBy(m.id, "/"); + if (author_model.size() == 2) { + auto const& author = author_model[0]; + auto const& model_name = author_model[1]; + auto add_res = AddRepoSiblings(model_source, author, model_name) + .value_or(std::unordered_set{}); + for (auto const& a : add_res) { + updated_model_list.insert(a); + } + } + } + + // Clean up + for (auto const& mid : model_list_before) { + if (updated_model_list.find(mid) == updated_model_list.end()) { + if (auto del_res = model_db.DeleteModelEntry(mid); + del_res.has_error()) { + CTL_INF(del_res.error()); + } + } + } + } else { + return cpp::fail(res.error()); + } + return true; +} + +cpp::result ModelSourceService::AddHfRepo( + const std::string& model_source, const std::string& author, + const std::string& model_name) { + // Get models from db + cortex::db::Models model_db; + + auto model_list_before = + model_db.GetModels(model_source).value_or(std::vector{}); + std::unordered_set 
updated_model_list; + auto add_res = AddRepoSiblings(model_source, author, model_name); + if (add_res.has_error()) { + return cpp::fail(add_res.error()); + } else { + updated_model_list = add_res.value(); + } + for (auto const& mid : model_list_before) { + if (updated_model_list.find(mid) == updated_model_list.end()) { + if (auto del_res = model_db.DeleteModelEntry(mid); del_res.has_error()) { + CTL_INF(del_res.error()); + } + } + } + return true; +} + +cpp::result, std::string> +ModelSourceService::AddRepoSiblings(const std::string& model_source, + const std::string& author, + const std::string& model_name) { + std::unordered_set res; + auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name); + if (repo_info.has_error()) { + return cpp::fail(repo_info.error()); + } + + if (!repo_info->gguf.has_value()) { + return cpp::fail( + "Not a GGUF model. Currently, only GGUF single file is " + "supported."); + } + + for (const auto& sibling : repo_info->siblings) { + if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { + cortex::db::Models model_db; + std::string model_id = + author + ":" + model_name + ":" + sibling.rfilename; + cortex::db::ModelEntry e = { + .model = model_id, + .author_repo_id = author, + .branch_name = "main", + .path_to_model_yaml = "", + .model_alias = "", + .model_format = "hf-gguf", + .model_source = model_source, + .status = cortex::db::ModelStatus::Downloadable, + .engine = "llama-cpp", + .metadata = repo_info->metadata}; + if (!model_db.HasModel(model_id)) { + if (auto add_res = model_db.AddModelEntry(e); add_res.has_error()) { + CTL_INF(add_res.error()); + } + } else { + if (auto m = model_db.GetModelInfo(model_id); + m.has_value() && + m->status == cortex::db::ModelStatus::Downloadable) { + if (auto upd_res = model_db.UpdateModelEntry(model_id, e); + upd_res.has_error()) { + CTL_INF(upd_res.error()); + } + } + } + res.insert(model_id); + } + } + + return res; +} + +cpp::result ModelSourceService::AddCortexsoOrg( + const std::string& model_source) { + auto res = curl_utils::SimpleGet( + "https://huggingface.co/api/models?author=cortexso"); + if (res.has_value()) { + auto models = ParseJsonString(res.value()); + // Get models from db + cortex::db::Models model_db; + + auto model_list_before = + model_db.GetModels(model_source).value_or(std::vector{}); + std::unordered_set updated_model_list; + for (auto const& m : models) { + CTL_INF(m.id); + auto author_model = string_utils::SplitBy(m.id, "/"); + if (author_model.size() == 2) { + auto const& author = author_model[0]; + auto const& model_name = author_model[1]; + auto branches = huggingface_utils::GetModelRepositoryBranches( + "cortexso", model_name); + if (branches.has_error()) { + CTL_INF(branches.error()); + continue; + } + + auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name); + if (repo_info.has_error()) { + CTL_INF(repo_info.error()); + continue; + } + for (auto const& [branch, _] : branches.value()) { + CTL_INF(branch); + auto add_res = AddCortexsoRepoBranch(model_source, author, model_name, + branch, repo_info->metadata) + .value_or(std::unordered_set{}); + for (auto const& a : add_res) { + updated_model_list.insert(a); + } + } + } + } + // Clean up + for (auto const& mid : model_list_before) { + if (updated_model_list.find(mid) == updated_model_list.end()) { + if (auto del_res = model_db.DeleteModelEntry(mid); + del_res.has_error()) { + CTL_INF(del_res.error()); + } + } + } + } else { + return cpp::fail(res.error()); + } + + return true; +} + +cpp::result 
ModelSourceService::AddCortexsoRepo( + const std::string& model_source, const std::string& author, + const std::string& model_name) { + auto branches = + huggingface_utils::GetModelRepositoryBranches("cortexso", model_name); + if (branches.has_error()) { + return cpp::fail(branches.error()); + } + + auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name); + if (repo_info.has_error()) { + return cpp::fail(repo_info.error()); + } + // Get models from db + cortex::db::Models model_db; + + auto model_list_before = + model_db.GetModels(model_source).value_or(std::vector{}); + std::unordered_set updated_model_list; + + for (auto const& [branch, _] : branches.value()) { + CTL_INF(branch); + auto add_res = AddCortexsoRepoBranch(model_source, author, model_name, + branch, repo_info->metadata) + .value_or(std::unordered_set{}); + for (auto const& a : add_res) { + updated_model_list.insert(a); + } + } + + // Clean up + for (auto const& mid : model_list_before) { + if (updated_model_list.find(mid) == updated_model_list.end()) { + if (auto del_res = model_db.DeleteModelEntry(mid); del_res.has_error()) { + CTL_INF(del_res.error()); + } + } + } + return true; +} + +cpp::result, std::string> +ModelSourceService::AddCortexsoRepoBranch(const std::string& model_source, + const std::string& author, + const std::string& model_name, + const std::string& branch, + const std::string& metadata) { + std::unordered_set res; + + url_parser::Url url = { + .protocol = "https", + .host = kHuggingFaceHost, + .pathParams = {"api", "models", "cortexso", model_name, "tree", branch}, + }; + + auto result = curl_utils::SimpleGetJson(url.ToFullPath()); + if (result.has_error()) { + return cpp::fail("Model " + model_name + " not found"); + } + + bool has_gguf = false; + for (const auto& value : result.value()) { + auto path = value["path"].asString(); + if (path.find(".gguf") != std::string::npos) { + has_gguf = true; + } + } + if (!has_gguf) { + CTL_INF("Only support gguf file format! 
- branch: " << branch); + return {}; + } else { + cortex::db::Models model_db; + std::string model_id = model_name + ":" + branch; + cortex::db::ModelEntry e = {.model = model_id, + .author_repo_id = author, + .branch_name = branch, + .path_to_model_yaml = "", + .model_alias = "", + .model_format = "cortexso", + .model_source = model_source, + .status = cortex::db::ModelStatus::Downloadable, + .engine = "llama-cpp", + .metadata = metadata}; + if (!model_db.HasModel(model_id)) { + CTL_INF("Adding model to db: " << model_name << ":" << branch); + if (auto res = model_db.AddModelEntry(e); + res.has_error() || !res.value()) { + CTL_DBG("Cannot add model to db: " << model_id); + } + } else { + if (auto m = model_db.GetModelInfo(model_id); + m.has_value() && m->status == cortex::db::ModelStatus::Downloadable) { + if (auto upd_res = model_db.UpdateModelEntry(model_id, e); + upd_res.has_error()) { + CTL_INF(upd_res.error()); + } + } + } + res.insert(model_id); + } + return res; +} + +void ModelSourceService::SyncModelSource() { + // Do interval check for 10 minutes + constexpr const int kIntervalCheck = 10 * 60; + auto start_time = std::chrono::steady_clock::now(); + while (running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + auto current_time = std::chrono::steady_clock::now(); + auto elapsed_time = std::chrono::duration_cast( + current_time - start_time) + .count(); + + if (elapsed_time > kIntervalCheck) { + CTL_DBG("Start to sync cortex.db"); + start_time = current_time; + + cortex::db::Models model_db; + auto res = model_db.GetModelSources(); + if (res.has_error()) { + CTL_INF(res.error()); + } else { + for (auto const& src : res.value()) { + CTL_DBG(src); + } + + std::unordered_set orgs; + std::vector repos; + for (auto const& src : res.value()) { + auto url_res = url_parser::FromUrlString(src); + if (url_res.has_value()) { + if (url_res->pathParams.size() == 1) { + orgs.insert(src); + } else if (url_res->pathParams.size() == 2) { + repos.push_back(src); + } + } + } + + // Get list to update + std::vector update_cand(orgs.begin(), orgs.end()); + auto get_org = [](const std::string& rp) { + return rp.substr(0, rp.find_last_of("/")); + }; + for (auto const& repo : repos) { + if (orgs.find(get_org(repo)) != orgs.end()) { + update_cand.push_back(repo); + } + } + + // Sync cortex.db with the upstream data + for (auto const& c : update_cand) { + if (auto res = AddModelSource(c); res.has_error()) { + CTL_INF(res.error();) + } + } + } + + CTL_DBG("Done sync cortex.db"); + } + } +} + +} // namespace services \ No newline at end of file diff --git a/engine/services/model_source_service.h b/engine/services/model_source_service.h new file mode 100644 index 000000000..aa0b37259 --- /dev/null +++ b/engine/services/model_source_service.h @@ -0,0 +1,53 @@ +#pragma once +#include +#include +#include +#include "utils/result.hpp" + +namespace services { +class ModelSourceService { + public: + explicit ModelSourceService(); + ~ModelSourceService(); + + cpp::result AddModelSource( + const std::string& model_source); + + cpp::result RemoveModelSource( + const std::string& model_source); + + cpp::result, std::string> GetModelSources(); + + private: + cpp::result AddHfOrg(const std::string& model_source, + const std::string& author); + + cpp::result AddHfRepo( + const std::string& model_source, const std::string& author, + const std::string& model_name); + + cpp::result, std::string> AddRepoSiblings( + const std::string& model_source, const std::string& author, + const std::string& 
+      const std::string& model_name);
+
+  cpp::result<bool, std::string> AddCortexsoOrg(
+      const std::string& model_source);
+
+  cpp::result<bool, std::string> AddCortexsoRepo(
+      const std::string& model_source, const std::string& author,
+      const std::string& model_name);
+
+  cpp::result<std::unordered_set<std::string>, std::string>
+  AddCortexsoRepoBranch(const std::string& model_source,
+                        const std::string& author,
+                        const std::string& model_name,
+                        const std::string& branch,
+                        const std::string& metadata);
+
+  void SyncModelSource();
+
+ private:
+  std::thread sync_db_thread_;
+  std::atomic<bool> running_;
+};
+}  // namespace services
\ No newline at end of file
diff --git a/engine/test/components/test_models_db.cc b/engine/test/components/test_models_db.cc
index ab0ea9f70..06294aa8c 100644
--- a/engine/test/components/test_models_db.cc
+++ b/engine/test/components/test_models_db.cc
@@ -24,7 +24,8 @@ class ModelsTestSuite : public ::testing::Test {
           "model_format TEXT,"
           "model_source TEXT,"
           "status TEXT,"
-          "engine TEXT"
+          "engine TEXT,"
+          "metadata TEXT"
           ")");
     } catch (const std::exception& e) {}
   }
@@ -70,10 +71,6 @@ TEST_F(ModelsTestSuite, TestGetModelInfo) {
   EXPECT_TRUE(model_by_id.has_value());
   EXPECT_EQ(model_by_id.value().model, kTestModel.model);
 
-  auto model_by_alias = model_list_.GetModelInfo("test_alias");
-  EXPECT_TRUE(model_by_alias);
-  EXPECT_EQ(model_by_alias.value().model, kTestModel.model);
-
   EXPECT_TRUE(model_list_.GetModelInfo("non_existent_model").has_error());
 
   // Clean up
@@ -104,26 +101,6 @@ TEST_F(ModelsTestSuite, TestDeleteModelEntry) {
   EXPECT_TRUE(model_list_.GetModelInfo(kTestModel.model).has_error());
 }
 
-TEST_F(ModelsTestSuite, TestGenerateShortenedAlias) {
-  EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value());
-  auto models1 = model_list_.LoadModelList();
-  auto alias = model_list_.GenerateShortenedAlias(
-      "huggingface.co:bartowski:llama3.1-7b-gguf:Model_ID_Xxx.gguf",
-      models1.value());
-  EXPECT_EQ(alias, "model_id_xxx");
-  EXPECT_TRUE(model_list_.UpdateModelAlias(kTestModel.model, alias).value());
-
-  // Test with existing entries to force longer alias
-  auto models2 = model_list_.LoadModelList();
-  alias = model_list_.GenerateShortenedAlias(
-      "huggingface.co:bartowski:llama3.1-7b-gguf:Model_ID_Xxx.gguf",
-      models2.value());
-  EXPECT_EQ(alias, "llama3.1-7b-gguf:model_id_xxx");
-
-  // Clean up
-  EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value());
-}
-
 TEST_F(ModelsTestSuite, TestPersistence) {
   EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value());
 
@@ -136,53 +113,10 @@ TEST_F(ModelsTestSuite, TestPersistence) {
   EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value());
 }
 
-TEST_F(ModelsTestSuite, TestUpdateModelAlias) {
-  constexpr const auto kNewTestAlias = "new_test_alias";
-  constexpr const auto kNonExistentModel = "non_existent_model";
-  constexpr const auto kAnotherAlias = "another_alias";
-  constexpr const auto kFinalTestAlias = "final_test_alias";
-  constexpr const auto kAnotherModelId = "another_model_id";
-  // Add the test model
-  ASSERT_TRUE(model_list_.AddModelEntry(kTestModel).value());
-
-  // Test successful update
-  EXPECT_TRUE(
-      model_list_.UpdateModelAlias(kTestModel.model, kNewTestAlias).value());
-  auto updated_model = model_list_.GetModelInfo(kNewTestAlias);
-  EXPECT_TRUE(updated_model.has_value());
-  EXPECT_EQ(updated_model.value().model_alias, kNewTestAlias);
-  EXPECT_EQ(updated_model.value().model, kTestModel.model);
-
-  // Test update with non-existent model
-  EXPECT_TRUE(model_list_.UpdateModelAlias(kNonExistentModel, kAnotherAlias)
-                  .has_error());
-
-  // Test update with non-unique alias
-  cortex::db::ModelEntry another_model = kTestModel;
-  another_model.model = kAnotherModelId;
-  another_model.model_alias = kAnotherAlias;
-  ASSERT_TRUE(model_list_.AddModelEntry(another_model).value());
-
-  EXPECT_FALSE(
-      model_list_.UpdateModelAlias(kTestModel.model, kAnotherAlias).value());
-
-  // Test update using model alias instead of model ID
-  EXPECT_TRUE(model_list_.UpdateModelAlias(kNewTestAlias, kFinalTestAlias));
-  updated_model = model_list_.GetModelInfo(kFinalTestAlias);
-  EXPECT_TRUE(updated_model);
-  EXPECT_EQ(updated_model.value().model_alias, kFinalTestAlias);
-  EXPECT_EQ(updated_model.value().model, kTestModel.model);
-
-  // Clean up
-  EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value());
-  EXPECT_TRUE(model_list_.DeleteModelEntry(kAnotherModelId).value());
-}
-
 TEST_F(ModelsTestSuite, TestHasModel) {
   EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value());
 
   EXPECT_TRUE(model_list_.HasModel(kTestModel.model));
-  EXPECT_TRUE(model_list_.HasModel("test_alias"));
   EXPECT_FALSE(model_list_.HasModel("non_existent_model"));
 
   // Clean up
   EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value());
diff --git a/engine/utils/huggingface_utils.h b/engine/utils/huggingface_utils.h
index f2895c363..1d1040612 100644
--- a/engine/utils/huggingface_utils.h
+++ b/engine/utils/huggingface_utils.h
@@ -67,6 +67,7 @@ struct HuggingFaceModelRepoInfo {
   std::vector<HuggingFaceFileSibling> siblings;
   std::vector<std::string> spaces;
   std::string createdAt;
+  std::string metadata;
 
   static cpp::result<HuggingFaceModelRepoInfo, std::string> FromJson(
       const Json::Value& body) {
@@ -104,6 +105,7 @@ struct HuggingFaceModelRepoInfo {
         .spaces =
             json_parser_utils::ParseJsonArray<std::string>(body["spaces"]),
         .createdAt = body["createdAt"].asString(),
+        .metadata = body.toStyledString(),
     };
   }
diff --git a/engine/utils/json_parser_utils.h b/engine/utils/json_parser_utils.h
index 3ebd2c546..b4ea1a7e1 100644
--- a/engine/utils/json_parser_utils.h
+++ b/engine/utils/json_parser_utils.h
@@ -10,7 +10,7 @@ template <typename T>
 T jsonToValue(const Json::Value& value);
 
 template <>
-std::string jsonToValue<std::string>(const Json::Value& value) {
+inline std::string jsonToValue<std::string>(const Json::Value& value) {
   return value.asString();
 }