Skip to content

Commit

Permalink
chore: add support for model name as a parameter during import via API (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan authored Nov 1, 2024
1 parent 166cdb5 commit f5fbad6
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 19 deletions.
86 changes: 86 additions & 0 deletions docs/static/openapi/cortex.json
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,46 @@
"tags": ["Models"]
}
},
"/v1/models/import": {
"post": {
"operationId": "ModelsController_importModel",
"summary": "Import model",
"description": "Imports a model from a specified path.",
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ImportModelRequest"
},
"example": {
"model": "model-id",
"modelPath": "/path/to/gguf",
"name": "model display name"
}
}
}
},
"responses": {
"200": {
"description": "Model is imported successfully!",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ImportModelResponse"
},
"example": {
"message": "Model is imported successfully!",
"modelHandle": "model-id",
"result": "OK"
}
}
}
}
},
"tags": ["Models"]
}
},
"/v1/threads": {
"post": {
"operationId": "ThreadsController_create",
Expand Down Expand Up @@ -1660,6 +1700,15 @@
"value": "my-custom-model-id"
}
]
},
"name": {
"type": "string",
"description": "The name which will be used to override the model's display name.",
"examples": [
{
"value": "my-custom-model-name"
}
]
}
}
},
Expand Down Expand Up @@ -1803,6 +1852,43 @@
}
}
},
"ImportModelRequest": {
"type": "object",
"properties": {
"model": {
"type": "string",
"description": "The unique identifier of the model."
},
"modelPath": {
"type": "string",
"description": "The file path to the model."
},
"name": {
"type": "string",
"description": "The display name of the model."
}
},
"required": ["model", "modelPath"]
},
"ImportModelResponse": {
"type": "object",
"properties": {
"message": {
"type": "string",
"description": "Success message."
},
"modelHandle": {
"type": "string",
"description": "The unique identifier of the imported model."
},
"result": {
"type": "string",
"description": "Result status.",
"example": "OK"
}
},
"required": ["message", "modelHandle", "result"]
},
"CommonResponseDto": {
"type": "object",
"properties": {
Expand Down
13 changes: 11 additions & 2 deletions engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,19 @@ void Models::PullModel(const HttpRequestPtr& req,
desired_model_id = id;
}

std::optional<std::string> desired_model_name = std::nullopt;
auto name_value = (*(req->getJsonObject())).get("name", "").asString();

if (!name_value.empty()) {
desired_model_name = name_value;
}

auto handle_model_input =
[&, model_handle]() -> cpp::result<DownloadTask, std::string> {
CTL_INF("Handle model input, model handle: " + model_handle);
if (string_utils::StartsWith(model_handle, "https")) {
return model_service_->HandleDownloadUrlAsync(model_handle,
desired_model_id);
return model_service_->HandleDownloadUrlAsync(
model_handle, desired_model_id, desired_model_name);
} else if (model_handle.find(":") != std::string::npos) {
auto model_and_branch = string_utils::SplitBy(model_handle, ":");
return model_service_->DownloadModelFromCortexsoAsync(
Expand Down Expand Up @@ -312,6 +319,7 @@ void Models::ImportModel(
}
auto modelHandle = (*(req->getJsonObject())).get("model", "").asString();
auto modelPath = (*(req->getJsonObject())).get("modelPath", "").asString();
auto modelName = (*(req->getJsonObject())).get("name", "").asString();
config::GGUFHandler gguf_handler;
config::YamlHandler yaml_handler;
cortex::db::Models modellist_utils_obj;
Expand All @@ -333,6 +341,7 @@ void Models::ImportModel(
config::ModelConfig model_config = gguf_handler.GetModelConfig();
model_config.files.push_back(modelPath);
model_config.model = modelHandle;
model_config.name = modelName.empty() ? model_config.name : modelName;
yaml_handler.UpdateModelConfig(model_config);

if (modellist_utils_obj.AddModelEntry(model_entry).value()) {
Expand Down
24 changes: 22 additions & 2 deletions engine/e2e-test/test_api_model_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,25 @@ def setup_and_teardown(self):
def test_model_import_should_be_success(self):
body_json = {'model': 'tinyllama:gguf',
'modelPath': '/path/to/local/gguf'}
response = requests.post("http://localhost:3928/models/import", json = body_json)
assert response.status_code == 200
response = requests.post("http://localhost:3928/models/import", json=body_json)
assert response.status_code == 200

@pytest.mark.skipif(True, reason="Expensive test. Only test when you have local gguf file.")
def test_model_import_with_name_should_be_success(self):
    """Importing a local gguf with an explicit 'name' should succeed (200)."""
    payload = {
        'model': 'tinyllama:gguf',
        'modelPath': '/path/to/local/gguf',
        'name': 'test_model',
    }
    result = requests.post("http://localhost:3928/models/import", json=payload)
    assert result.status_code == 200

def test_model_import_with_invalid_path_should_fail(self):
    """A model path that does not exist must be rejected with 400 Bad Request."""
    payload = {
        'model': 'tinyllama:gguf',
        'modelPath': '/invalid/path/to/gguf',
    }
    result = requests.post("http://localhost:3928/models/import", json=payload)
    assert result.status_code == 400

def test_model_import_with_missing_model_should_fail(self):
    """Importing without the required 'model' field should be rejected.

    The server answers 409 in this case — TODO confirm 409 (vs 400) is the
    intended status for a missing required field.
    """
    body_json = {'modelPath': '/path/to/local/gguf'}
    response = requests.post("http://localhost:3928/models/import", json=body_json)
    # Removed leftover debug print(response); only the status code is asserted.
    assert response.status_code == 409
29 changes: 24 additions & 5 deletions engine/e2e-test/test_api_model_pull_direct_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setup_and_teardown(self):
[
"models",
"delete",
"TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf",
"afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
],
)
yield
Expand All @@ -32,24 +32,43 @@ def setup_and_teardown(self):
[
"models",
"delete",
"TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf",
"afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
],
)
stop_server()

@pytest.mark.asyncio
async def test_model_pull_with_direct_url_should_be_success(self):
myobj = {
"model": "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
"model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf"
}
response = requests.post("http://localhost:3928/models/pull", json=myobj)
assert response.status_code == 200
await wait_for_websocket_download_success_event(timeout=None)
get_model_response = requests.get(
"http://127.0.0.1:3928/models/TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
"http://127.0.0.1:3928/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
)
assert get_model_response.status_code == 200
assert (
get_model_response.json()["model"]
== "TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
== "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
)

@pytest.mark.asyncio
async def test_model_pull_with_direct_url_should_have_desired_name(self):
    """Pulling via direct URL with a 'name' field should store that name as the
    model's display name."""
    myobj = {
        "model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf",
        "name": "smol_llama_100m"
    }
    response = requests.post("http://localhost:3928/models/pull", json=myobj)
    assert response.status_code == 200
    # Block until the server signals download completion over the websocket.
    await wait_for_websocket_download_success_event(timeout=None)
    get_model_response = requests.get(
        "http://127.0.0.1:3928/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
    )
    assert get_model_response.status_code == 200
    # Removed leftover debug print; assert the display name override took effect.
    assert get_model_response.json()["name"] == "smol_llama_100m"
22 changes: 14 additions & 8 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

namespace {
void ParseGguf(const DownloadItem& ggufDownloadItem,
std::optional<std::string> author) {
std::optional<std::string> author,
std::optional<std::string> name) {
namespace fs = std::filesystem;
namespace fmu = file_manager_utils;
config::GGUFHandler gguf_handler;
Expand All @@ -32,6 +33,8 @@ void ParseGguf(const DownloadItem& ggufDownloadItem,
fmu::ToRelativeCortexDataPath(fs::path(ggufDownloadItem.localPath));
model_config.files = {file_rel_path.string()};
model_config.model = ggufDownloadItem.id;
model_config.name =
name.has_value() ? name.value() : gguf_handler.GetModelConfig().name;
yaml_handler.UpdateModelConfig(model_config);

auto yaml_path{ggufDownloadItem.localPath};
Expand Down Expand Up @@ -223,7 +226,8 @@ std::optional<config::ModelConfig> ModelService::GetDownloadedModel(
}

cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
const std::string& url, std::optional<std::string> temp_model_id) {
const std::string& url, std::optional<std::string> temp_model_id,
std::optional<std::string> temp_name) {
auto url_obj = url_parser::FromUrlString(url);

if (url_obj.host == kHuggingFaceHost) {
Expand Down Expand Up @@ -279,9 +283,9 @@ cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
.localPath = local_path,
}}}};

auto on_finished = [author](const DownloadTask& finishedTask) {
auto on_finished = [author, temp_name](const DownloadTask& finishedTask) {
auto gguf_download_item = finishedTask.items[0];
ParseGguf(gguf_download_item, author);
ParseGguf(gguf_download_item, author, temp_name);
};

downloadTask.id = unique_model_id;
Expand Down Expand Up @@ -346,7 +350,7 @@ cpp::result<std::string, std::string> ModelService::HandleUrl(

auto on_finished = [author](const DownloadTask& finishedTask) {
auto gguf_download_item = finishedTask.items[0];
ParseGguf(gguf_download_item, author);
ParseGguf(gguf_download_item, author, std::nullopt);
};

auto result = download_service_->AddDownloadTask(downloadTask, on_finished);
Expand Down Expand Up @@ -770,7 +774,7 @@ cpp::result<ModelPullInfo, std::string> ModelService::GetModelPullInfo(
auto author{url_obj.pathParams[0]};
auto model_id{url_obj.pathParams[1]};
auto file_name{url_obj.pathParams.back()};
if (author == "cortexso") {
if (author == "cortexso") {
return ModelPullInfo{.id = model_id + ":" + url_obj.pathParams[3],
.downloaded_models = {},
.available_models = {},
Expand All @@ -787,8 +791,10 @@ cpp::result<ModelPullInfo, std::string> ModelService::GetModelPullInfo(
if (parsed.size() != 2) {
return cpp::fail("Invalid model handle: " + input);
}
return ModelPullInfo{
.id = input, .downloaded_models = {}, .available_models = {}, .download_url = input};
return ModelPullInfo{.id = input,
.downloaded_models = {},
.available_models = {},
.download_url = input};
}

if (input.find("/") != std::string::npos) {
Expand Down
5 changes: 3 additions & 2 deletions engine/services/model_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ModelService {
std::shared_ptr<DownloadService> download_service,
std::shared_ptr<services::InferenceService> inference_service)
: download_service_{download_service},
inference_svc_(inference_service) {};
inference_svc_(inference_service){};

/**
* Return model id if download successfully
Expand Down Expand Up @@ -81,7 +81,8 @@ class ModelService {
cpp::result<std::string, std::string> HandleUrl(const std::string& url);

cpp::result<DownloadTask, std::string> HandleDownloadUrlAsync(
const std::string& url, std::optional<std::string> temp_model_id);
const std::string& url, std::optional<std::string> temp_model_id,
std::optional<std::string> temp_name);

private:
/**
Expand Down

0 comments on commit f5fbad6

Please sign in to comment.