diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json
index 0f715456d..8b3acb0e2 100644
--- a/docs/static/openapi/cortex.json
+++ b/docs/static/openapi/cortex.json
@@ -554,6 +554,46 @@
         "tags": ["Models"]
       }
     },
+    "/v1/models/import": {
+      "post": {
+        "operationId": "ModelsController_importModel",
+        "summary": "Import model",
+        "description": "Imports a model from a specified path.",
+        "requestBody": {
+          "required": true,
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/ImportModelRequest"
+              },
+              "example": {
+                "model": "model-id",
+                "modelPath": "/path/to/gguf",
+                "name": "model display name"
+              }
+            }
+          }
+        },
+        "responses": {
+          "200": {
+            "description": "Model is imported successfully!",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ImportModelResponse"
+                },
+                "example": {
+                  "message": "Model is imported successfully!",
+                  "modelHandle": "model-id",
+                  "result": "OK"
+                }
+              }
+            }
+          }
+        },
+        "tags": ["Models"]
+      }
+    },
     "/v1/threads": {
       "post": {
         "operationId": "ThreadsController_create",
@@ -1660,6 +1700,15 @@
             "value": "my-custom-model-id"
           }
         ]
+      },
+      "name": {
+        "type": "string",
+        "description": "The name which will be used to overwrite the model name.",
+        "examples": [
+          {
+            "value": "my-custom-model-name"
+          }
+        ]
       }
     }
   },
@@ -1803,6 +1852,43 @@
       }
     }
   },
+    "ImportModelRequest": {
+      "type": "object",
+      "properties": {
+        "model": {
+          "type": "string",
+          "description": "The unique identifier of the model."
+        },
+        "modelPath": {
+          "type": "string",
+          "description": "The file path to the model."
+        },
+        "name": {
+          "type": "string",
+          "description": "The display name of the model."
+        }
+      },
+      "required": ["model", "modelPath"]
+    },
+    "ImportModelResponse": {
+      "type": "object",
+      "properties": {
+        "message": {
+          "type": "string",
+          "description": "Success message."
+        },
+        "modelHandle": {
+          "type": "string",
+          "description": "The unique identifier of the imported model."
+        },
+        "result": {
+          "type": "string",
+          "description": "Result status.",
+          "example": "OK"
+        }
+      },
+      "required": ["message", "modelHandle", "result"]
+    },
    "CommonResponseDto": {
      "type": "object",
      "properties": {
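For reference, a client call against the endpoint specified above might look like the following minimal sketch. The port (3928) and the unversioned `/models/import` path are taken from the e2e tests further down; the model id, GGUF path, and display name are placeholder values, not anything mandated by the spec.

```python
import requests

# Sketch of a request to the new import endpoint. Port 3928 and the
# unversioned path mirror the e2e tests below; the id, GGUF path, and
# display name are placeholders.
body = {
    "model": "my-custom-model-id",       # required: unique model id
    "modelPath": "/path/to/model.gguf",  # required: local GGUF file
    "name": "My Custom Model",           # optional: display-name override
}
response = requests.post("http://localhost:3928/models/import", json=body)
response.raise_for_status()

payload = response.json()
# ImportModelResponse declares message, modelHandle, and result as required.
assert payload["result"] == "OK"
print(payload["message"], payload["modelHandle"])
```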
diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc
index 602c81ab6..939f63f31 100644
--- a/engine/controllers/models.cc
+++ b/engine/controllers/models.cc
@@ -33,12 +33,19 @@ void Models::PullModel(const HttpRequestPtr& req,
     desired_model_id = id;
   }
 
+  std::optional<std::string> desired_model_name = std::nullopt;
+  auto name_value = (*(req->getJsonObject())).get("name", "").asString();
+
+  if (!name_value.empty()) {
+    desired_model_name = name_value;
+  }
+
   auto handle_model_input =
       [&, model_handle]() -> cpp::result<std::string, std::string> {
     CTL_INF("Handle model input, model handle: " + model_handle);
     if (string_utils::StartsWith(model_handle, "https")) {
-      return model_service_->HandleDownloadUrlAsync(model_handle,
-                                                    desired_model_id);
+      return model_service_->HandleDownloadUrlAsync(
+          model_handle, desired_model_id, desired_model_name);
     } else if (model_handle.find(":") != std::string::npos) {
       auto model_and_branch = string_utils::SplitBy(model_handle, ":");
       return model_service_->DownloadModelFromCortexsoAsync(
@@ -312,6 +319,7 @@ void Models::ImportModel(
   }
   auto modelHandle = (*(req->getJsonObject())).get("model", "").asString();
   auto modelPath = (*(req->getJsonObject())).get("modelPath", "").asString();
+  auto modelName = (*(req->getJsonObject())).get("name", "").asString();
   config::GGUFHandler gguf_handler;
   config::YamlHandler yaml_handler;
   cortex::db::Models modellist_utils_obj;
@@ -333,6 +341,7 @@ void Models::ImportModel(
     config::ModelConfig model_config = gguf_handler.GetModelConfig();
     model_config.files.push_back(modelPath);
     model_config.model = modelHandle;
+    model_config.name = modelName.empty() ? model_config.name : modelName;
     yaml_handler.UpdateModelConfig(model_config);
 
     if (modellist_utils_obj.AddModelEntry(model_entry).value()) {
diff --git a/engine/e2e-test/test_api_model_import.py b/engine/e2e-test/test_api_model_import.py
index 8dd34ea7a..3f8a82a0d 100644
--- a/engine/e2e-test/test_api_model_import.py
+++ b/engine/e2e-test/test_api_model_import.py
@@ -18,5 +18,25 @@ def setup_and_teardown(self):
 
     def test_model_import_should_be_success(self):
         body_json = {'model': 'tinyllama:gguf', 'modelPath': '/path/to/local/gguf'}
-        response = requests.post("http://localhost:3928/models/import", json = body_json)
-        assert response.status_code == 200
\ No newline at end of file
+        response = requests.post("http://localhost:3928/models/import", json=body_json)
+        assert response.status_code == 200
+
+    @pytest.mark.skipif(True, reason="Expensive test. Only test when you have local gguf file.")
+    def test_model_import_with_name_should_be_success(self):
+        body_json = {'model': 'tinyllama:gguf',
+                     'modelPath': '/path/to/local/gguf',
+                     'name': 'test_model'}
+        response = requests.post("http://localhost:3928/models/import", json=body_json)
+        assert response.status_code == 200
+
+    def test_model_import_with_invalid_path_should_fail(self):
+        body_json = {'model': 'tinyllama:gguf',
+                     'modelPath': '/invalid/path/to/gguf'}
+        response = requests.post("http://localhost:3928/models/import", json=body_json)
+        assert response.status_code == 400
+
+    def test_model_import_with_missing_model_should_fail(self):
+        body_json = {'modelPath': '/path/to/local/gguf'}
+        response = requests.post("http://localhost:3928/models/import", json=body_json)
+        print(response)
+        assert response.status_code == 409
\ No newline at end of file
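The negative tests above pin down an error contract: 400 for an unreadable `modelPath`, 409 when the `model` id is absent or conflicting. A hedged client-side sketch built on those observed codes follows; the interpretation of each status comes from the tests, not from documented API semantics.

```python
import requests

def import_model(model_id: str, model_path: str, name: str | None = None) -> str:
    """Import a local GGUF file and return its model handle.

    Status-code handling follows the e2e tests above: 400 is assumed to
    mean a bad modelPath, 409 a missing or conflicting model id.
    """
    body = {"model": model_id, "modelPath": model_path}
    if name is not None:
        body["name"] = name  # optional display-name override
    r = requests.post("http://localhost:3928/models/import", json=body)
    if r.status_code == 400:
        raise ValueError(f"invalid modelPath: {model_path}")
    if r.status_code == 409:
        raise ValueError(f"missing or conflicting model id: {model_id}")
    r.raise_for_status()
    return r.json()["modelHandle"]
```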
diff --git a/engine/e2e-test/test_api_model_pull_direct_url.py b/engine/e2e-test/test_api_model_pull_direct_url.py
index e93ca2ddd..aa15fbfba 100644
--- a/engine/e2e-test/test_api_model_pull_direct_url.py
+++ b/engine/e2e-test/test_api_model_pull_direct_url.py
@@ -21,7 +21,7 @@ def setup_and_teardown(self):
         [
             "models",
             "delete",
-            "TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf",
+            "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
         ],
     )
     yield
@@ -32,7 +32,7 @@ def setup_and_teardown(self):
         [
             "models",
             "delete",
-            "TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf",
+            "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
         ],
     )
     stop_server()
@@ -40,16 +40,35 @@ def setup_and_teardown(self):
     @pytest.mark.asyncio
     async def test_model_pull_with_direct_url_should_be_success(self):
         myobj = {
-            "model": "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
+            "model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf"
         }
         response = requests.post("http://localhost:3928/models/pull", json=myobj)
         assert response.status_code == 200
         await wait_for_websocket_download_success_event(timeout=None)
         get_model_response = requests.get(
-            "http://127.0.0.1:3928/models/TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
+            "http://127.0.0.1:3928/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
         )
         assert get_model_response.status_code == 200
         assert (
             get_model_response.json()["model"]
-            == "TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
+            == "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
+        )
+
+    @pytest.mark.asyncio
+    async def test_model_pull_with_direct_url_should_have_desired_name(self):
+        myobj = {
+            "model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf",
+            "name": "smol_llama_100m"
+        }
+        response = requests.post("http://localhost:3928/models/pull", json=myobj)
+        assert response.status_code == 200
+        await wait_for_websocket_download_success_event(timeout=None)
+        get_model_response = requests.get(
+            "http://127.0.0.1:3928/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
+        )
+        assert get_model_response.status_code == 200
+        print(get_model_response.json()["name"])
+        assert (
+            get_model_response.json()["name"]
+            == "smol_llama_100m"
         )
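A note on the new pull test: the `name` field only overrides the stored display name, while the model id is still derived from the source URL. That is why the test fetches the model by its `afrideva:...` id and then asserts on the returned `name` field rather than on the id.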
diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc
index b49df3420..4967b1dd9 100644
--- a/engine/services/model_service.cc
+++ b/engine/services/model_service.cc
@@ -17,7 +17,8 @@ namespace {
 void ParseGguf(const DownloadItem& ggufDownloadItem,
-               std::optional<std::string> author) {
+               std::optional<std::string> author,
+               std::optional<std::string> name) {
   namespace fs = std::filesystem;
   namespace fmu = file_manager_utils;
   config::GGUFHandler gguf_handler;
@@ -32,6 +33,8 @@ void ParseGguf(const DownloadItem& ggufDownloadItem,
       fmu::ToRelativeCortexDataPath(fs::path(ggufDownloadItem.localPath));
   model_config.files = {file_rel_path.string()};
   model_config.model = ggufDownloadItem.id;
+  model_config.name =
+      name.has_value() ? name.value() : gguf_handler.GetModelConfig().name;
   yaml_handler.UpdateModelConfig(model_config);
 
   auto yaml_path{ggufDownloadItem.localPath};
@@ -223,7 +226,8 @@ std::optional<config::ModelConfig> ModelService::GetDownloadedModel(
 }
 
 cpp::result<std::string, std::string> ModelService::HandleDownloadUrlAsync(
-    const std::string& url, std::optional<std::string> temp_model_id) {
+    const std::string& url, std::optional<std::string> temp_model_id,
+    std::optional<std::string> temp_name) {
   auto url_obj = url_parser::FromUrlString(url);
 
   if (url_obj.host == kHuggingFaceHost) {
@@ -279,9 +283,9 @@ cpp::result<std::string, std::string> ModelService::HandleDownloadUrlAsync(
                   .localPath = local_path,
               }}}};
 
-  auto on_finished = [author](const DownloadTask& finishedTask) {
+  auto on_finished = [author, temp_name](const DownloadTask& finishedTask) {
     auto gguf_download_item = finishedTask.items[0];
-    ParseGguf(gguf_download_item, author);
+    ParseGguf(gguf_download_item, author, temp_name);
   };
 
   downloadTask.id = unique_model_id;
@@ -346,7 +350,7 @@ cpp::result<bool, std::string> ModelService::HandleUrl(
 
   auto on_finished = [author](const DownloadTask& finishedTask) {
     auto gguf_download_item = finishedTask.items[0];
-    ParseGguf(gguf_download_item, author);
+    ParseGguf(gguf_download_item, author, std::nullopt);
   };
 
   auto result = download_service_->AddDownloadTask(downloadTask, on_finished);
@@ -770,7 +774,7 @@ cpp::result<ModelPullInfo, std::string> ModelService::GetModelPullInfo(
       auto author{url_obj.pathParams[0]};
       auto model_id{url_obj.pathParams[1]};
       auto file_name{url_obj.pathParams.back()};
-      if (author == "cortexso") {
+      if (author == "cortexso") {
         return ModelPullInfo{.id = model_id + ":" + url_obj.pathParams[3],
                              .downloaded_models = {},
                              .available_models = {},
@@ -787,8 +791,10 @@
     if (parsed.size() != 2) {
       return cpp::fail("Invalid model handle: " + input);
     }
-    return ModelPullInfo{
-        .id = input, .downloaded_models = {}, .available_models = {}, .download_url = input};
+    return ModelPullInfo{.id = input,
+                         .downloaded_models = {},
+                         .available_models = {},
+                         .download_url = input};
   }
 
   if (input.find("/") != std::string::npos) {
diff --git a/engine/services/model_service.h b/engine/services/model_service.h
index 495685982..c1600e2a6 100644
--- a/engine/services/model_service.h
+++ b/engine/services/model_service.h
@@ -39,7 +39,7 @@ class ModelService {
                         std::shared_ptr<DownloadService> download_service,
                         std::shared_ptr<InferenceService> inference_service)
       : download_service_{download_service},
-        inference_svc_(inference_service) {};
+        inference_svc_(inference_service){};
   /**
    * Return model id if download successfully
    */
@@ -81,7 +81,8 @@ class ModelService {
  cpp::result<bool, std::string> HandleUrl(const std::string& url);
 
  cpp::result<std::string, std::string> HandleDownloadUrlAsync(
-      const std::string& url, std::optional<std::string> temp_model_id);
+      const std::string& url, std::optional<std::string> temp_model_id,
+      std::optional<std::string> temp_name);
 
 private:
  /**
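Net effect of the service-layer changes: an explicit `name` now takes precedence over the name parsed from GGUF metadata (`ParseGguf` applies the override), and the override is threaded through the async download path via `HandleDownloadUrlAsync` and through the import controller. The synchronous `HandleUrl` path passes `std::nullopt`, so models pulled through it keep the metadata-derived name.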