From 4ba2b4992e0c0867d63b13928446e29799ade6dc Mon Sep 17 00:00:00 2001 From: Louis Le Date: Fri, 1 Nov 2024 10:32:30 +0700 Subject: [PATCH] chore: desired name - model pull API --- engine/controllers/models.cc | 11 ++++++++-- .../test_api_model_pull_direct_url.py | 19 ++++++++++++++++ engine/services/model_service.cc | 22 ++++++++++++------- engine/services/model_service.h | 5 +++-- 4 files changed, 45 insertions(+), 12 deletions(-) diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 826471487..939f63f31 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -33,12 +33,19 @@ void Models::PullModel(const HttpRequestPtr& req, desired_model_id = id; } + std::optional desired_model_name = std::nullopt; + auto name_value = (*(req->getJsonObject())).get("name", "").asString(); + + if (!name_value.empty()) { + desired_model_name = name_value; + } + auto handle_model_input = [&, model_handle]() -> cpp::result { CTL_INF("Handle model input, model handle: " + model_handle); if (string_utils::StartsWith(model_handle, "https")) { - return model_service_->HandleDownloadUrlAsync(model_handle, - desired_model_id); + return model_service_->HandleDownloadUrlAsync( + model_handle, desired_model_id, desired_model_name); } else if (model_handle.find(":") != std::string::npos) { auto model_and_branch = string_utils::SplitBy(model_handle, ":"); return model_service_->DownloadModelFromCortexsoAsync( diff --git a/engine/e2e-test/test_api_model_pull_direct_url.py b/engine/e2e-test/test_api_model_pull_direct_url.py index e93ca2ddd..27969216a 100644 --- a/engine/e2e-test/test_api_model_pull_direct_url.py +++ b/engine/e2e-test/test_api_model_pull_direct_url.py @@ -53,3 +53,22 @@ async def test_model_pull_with_direct_url_should_be_success(self): get_model_response.json()["model"] == "TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf" ) + + @pytest.mark.asyncio + async def test_model_pull_with_direct_url_should_have_desired_name(self): + myobj = { + "model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf", + "name": "smol_llama_100m" + } + response = requests.post("http://localhost:3928/models/pull", json=myobj) + assert response.status_code == 200 + await wait_for_websocket_download_success_event(timeout=None) + get_model_response = requests.get( + "http://127.0.0.1:3928/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf" + ) + assert get_model_response.status_code == 200 + print(get_model_response.json()["name"]) + assert ( + get_model_response.json()["name"] + == "smol_llama_100m" + ) diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index b49df3420..4967b1dd9 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -17,7 +17,8 @@ namespace { void ParseGguf(const DownloadItem& ggufDownloadItem, - std::optional author) { + std::optional author, + std::optional name) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; config::GGUFHandler gguf_handler; @@ -32,6 +33,8 @@ void ParseGguf(const DownloadItem& ggufDownloadItem, fmu::ToRelativeCortexDataPath(fs::path(ggufDownloadItem.localPath)); model_config.files = {file_rel_path.string()}; model_config.model = ggufDownloadItem.id; + model_config.name = + name.has_value() ? name.value() : gguf_handler.GetModelConfig().name; yaml_handler.UpdateModelConfig(model_config); auto yaml_path{ggufDownloadItem.localPath}; @@ -223,7 +226,8 @@ std::optional ModelService::GetDownloadedModel( } cpp::result ModelService::HandleDownloadUrlAsync( - const std::string& url, std::optional temp_model_id) { + const std::string& url, std::optional temp_model_id, + std::optional temp_name) { auto url_obj = url_parser::FromUrlString(url); if (url_obj.host == kHuggingFaceHost) { @@ -279,9 +283,9 @@ cpp::result ModelService::HandleDownloadUrlAsync( .localPath = local_path, }}}}; - auto on_finished = [author](const DownloadTask& finishedTask) { + auto on_finished = [author, temp_name](const DownloadTask& finishedTask) { auto gguf_download_item = finishedTask.items[0]; - ParseGguf(gguf_download_item, author); + ParseGguf(gguf_download_item, author, temp_name); }; downloadTask.id = unique_model_id; @@ -346,7 +350,7 @@ cpp::result ModelService::HandleUrl( auto on_finished = [author](const DownloadTask& finishedTask) { auto gguf_download_item = finishedTask.items[0]; - ParseGguf(gguf_download_item, author); + ParseGguf(gguf_download_item, author, std::nullopt); }; auto result = download_service_->AddDownloadTask(downloadTask, on_finished); @@ -770,7 +774,7 @@ cpp::result ModelService::GetModelPullInfo( auto author{url_obj.pathParams[0]}; auto model_id{url_obj.pathParams[1]}; auto file_name{url_obj.pathParams.back()}; - if (author == "cortexso") { + if (author == "cortexso") { return ModelPullInfo{.id = model_id + ":" + url_obj.pathParams[3], .downloaded_models = {}, .available_models = {}, @@ -787,8 +791,10 @@ cpp::result ModelService::GetModelPullInfo( if (parsed.size() != 2) { return cpp::fail("Invalid model handle: " + input); } - return ModelPullInfo{ - .id = input, .downloaded_models = {}, .available_models = {}, .download_url = input}; + return ModelPullInfo{.id = input, + .downloaded_models = {}, + .available_models = {}, + .download_url = input}; } if (input.find("/") != std::string::npos) { diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 495685982..c1600e2a6 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -39,7 +39,7 @@ class ModelService { std::shared_ptr download_service, std::shared_ptr inference_service) : download_service_{download_service}, - inference_svc_(inference_service) {}; + inference_svc_(inference_service){}; /** * Return model id if download successfully @@ -81,7 +81,8 @@ class ModelService { cpp::result HandleUrl(const std::string& url); cpp::result HandleDownloadUrlAsync( - const std::string& url, std::optional temp_model_id); + const std::string& url, std::optional temp_model_id, + std::optional temp_name); private: /**