Skip to content

Commit

Permalink
chore: desired name - model pull API
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan committed Nov 1, 2024
1 parent 52a2f69 commit 4ba2b49
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 12 deletions.
11 changes: 9 additions & 2 deletions engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,19 @@ void Models::PullModel(const HttpRequestPtr& req,
desired_model_id = id;
}

std::optional<std::string> desired_model_name = std::nullopt;
auto name_value = (*(req->getJsonObject())).get("name", "").asString();

if (!name_value.empty()) {
desired_model_name = name_value;
}

auto handle_model_input =
[&, model_handle]() -> cpp::result<DownloadTask, std::string> {
CTL_INF("Handle model input, model handle: " + model_handle);
if (string_utils::StartsWith(model_handle, "https")) {
return model_service_->HandleDownloadUrlAsync(model_handle,
desired_model_id);
return model_service_->HandleDownloadUrlAsync(
model_handle, desired_model_id, desired_model_name);
} else if (model_handle.find(":") != std::string::npos) {
auto model_and_branch = string_utils::SplitBy(model_handle, ":");
return model_service_->DownloadModelFromCortexsoAsync(
Expand Down
19 changes: 19 additions & 0 deletions engine/e2e-test/test_api_model_pull_direct_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,22 @@ async def test_model_pull_with_direct_url_should_be_success(self):
get_model_response.json()["model"]
== "TheBloke:TinyLlama-1.1B-Chat-v0.3-GGUF:tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
)

@pytest.mark.asyncio
async def test_model_pull_with_direct_url_should_have_desired_name(self):
    """Pull a model by direct URL with a custom "name" and check it is reported back.

    Posts a pull request carrying both "model" (a direct GGUF URL) and
    "name", waits for the download helper to return, then fetches the model
    by its id and asserts that the returned "name" equals the requested one.
    """
    desired_name = "smol_llama_100m"
    pull_request_body = {
        "model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf",
        "name": desired_name,
    }

    pull_response = requests.post(
        "http://localhost:3928/models/pull", json=pull_request_body
    )
    assert pull_response.status_code == 200

    # NOTE(review): timeout=None presumably means "wait indefinitely" for the
    # download-success event; confirm against the helper's implementation.
    await wait_for_websocket_download_success_event(timeout=None)

    model_url = "http://127.0.0.1:3928/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
    model_response = requests.get(model_url)
    assert model_response.status_code == 200

    # Echo the stored name so it shows up in the test output.
    print(model_response.json()["name"])
    assert model_response.json()["name"] == desired_name
22 changes: 14 additions & 8 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

namespace {
void ParseGguf(const DownloadItem& ggufDownloadItem,
std::optional<std::string> author) {
std::optional<std::string> author,
std::optional<std::string> name) {
namespace fs = std::filesystem;
namespace fmu = file_manager_utils;
config::GGUFHandler gguf_handler;
Expand All @@ -32,6 +33,8 @@ void ParseGguf(const DownloadItem& ggufDownloadItem,
fmu::ToRelativeCortexDataPath(fs::path(ggufDownloadItem.localPath));
model_config.files = {file_rel_path.string()};
model_config.model = ggufDownloadItem.id;
model_config.name =
name.has_value() ? name.value() : gguf_handler.GetModelConfig().name;
yaml_handler.UpdateModelConfig(model_config);

auto yaml_path{ggufDownloadItem.localPath};
Expand Down Expand Up @@ -223,7 +226,8 @@ std::optional<config::ModelConfig> ModelService::GetDownloadedModel(
}

cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
const std::string& url, std::optional<std::string> temp_model_id) {
const std::string& url, std::optional<std::string> temp_model_id,
std::optional<std::string> temp_name) {
auto url_obj = url_parser::FromUrlString(url);

if (url_obj.host == kHuggingFaceHost) {
Expand Down Expand Up @@ -279,9 +283,9 @@ cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
.localPath = local_path,
}}}};

auto on_finished = [author](const DownloadTask& finishedTask) {
auto on_finished = [author, temp_name](const DownloadTask& finishedTask) {
auto gguf_download_item = finishedTask.items[0];
ParseGguf(gguf_download_item, author);
ParseGguf(gguf_download_item, author, temp_name);
};

downloadTask.id = unique_model_id;
Expand Down Expand Up @@ -346,7 +350,7 @@ cpp::result<std::string, std::string> ModelService::HandleUrl(

auto on_finished = [author](const DownloadTask& finishedTask) {
auto gguf_download_item = finishedTask.items[0];
ParseGguf(gguf_download_item, author);
ParseGguf(gguf_download_item, author, std::nullopt);
};

auto result = download_service_->AddDownloadTask(downloadTask, on_finished);
Expand Down Expand Up @@ -770,7 +774,7 @@ cpp::result<ModelPullInfo, std::string> ModelService::GetModelPullInfo(
auto author{url_obj.pathParams[0]};
auto model_id{url_obj.pathParams[1]};
auto file_name{url_obj.pathParams.back()};
if (author == "cortexso") {
if (author == "cortexso") {
return ModelPullInfo{.id = model_id + ":" + url_obj.pathParams[3],
.downloaded_models = {},
.available_models = {},
Expand All @@ -787,8 +791,10 @@ cpp::result<ModelPullInfo, std::string> ModelService::GetModelPullInfo(
if (parsed.size() != 2) {
return cpp::fail("Invalid model handle: " + input);
}
return ModelPullInfo{
.id = input, .downloaded_models = {}, .available_models = {}, .download_url = input};
return ModelPullInfo{.id = input,
.downloaded_models = {},
.available_models = {},
.download_url = input};
}

if (input.find("/") != std::string::npos) {
Expand Down
5 changes: 3 additions & 2 deletions engine/services/model_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ModelService {
std::shared_ptr<DownloadService> download_service,
std::shared_ptr<services::InferenceService> inference_service)
: download_service_{download_service},
inference_svc_(inference_service) {};
inference_svc_(inference_service){};

/**
 * Return the model id if the download succeeds
Expand Down Expand Up @@ -81,7 +81,8 @@ class ModelService {
cpp::result<std::string, std::string> HandleUrl(const std::string& url);

cpp::result<DownloadTask, std::string> HandleDownloadUrlAsync(
const std::string& url, std::optional<std::string> temp_model_id);
const std::string& url, std::optional<std::string> temp_model_id,
std::optional<std::string> temp_name);

private:
/**
Expand Down

0 comments on commit 4ba2b49

Please sign in to comment.