Skip to content

Commit

Permalink
Merge pull request #1430 from janhq/j/update-pull-command
Browse files Browse the repository at this point in the history
feat: uplift pull and run cmd
  • Loading branch information
namchuai authored Oct 4, 2024
2 parents 8bb9523 + 75a98b9 commit c3a97f7
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 27 deletions.
2 changes: 1 addition & 1 deletion engine/commands/model_pull_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ void ModelPullCmd::Exec(const std::string& input) {
auto result = model_service_.DownloadModel(input);
if (result.has_error()) {
CLI_LOG(result.error());
}
}
}
}; // namespace commands
21 changes: 14 additions & 7 deletions engine/commands/run_cmd.cc
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#include "run_cmd.h"
#include "chat_completion_cmd.h"
#include "config/yaml_config.h"
#include "cortex_upd_cmd.h"
#include "database/models.h"
#include "model_start_cmd.h"
#include "model_status_cmd.h"
#include "server_start_cmd.h"
#include "utils/cli_selection_utils.h"
#include "utils/logging_utils.h"

#include "cortex_upd_cmd.h"

namespace commands {

namespace {
Expand All @@ -33,14 +33,21 @@ void RunCmd::Exec(bool chat_flag) {

// Download model if it does not exist
{
if (!modellist_handler.HasModel(model_handle_)) {
auto related_models_ids = modellist_handler.FindRelatedModel(model_handle_);
if (related_models_ids.has_error() || related_models_ids.value().empty()) {
auto result = model_service_.DownloadModel(model_handle_);
if (result.has_error()) {
CTL_ERR("Error: " << result.error());
return;
}
model_id = result.value();
CTL_INF("model_id: " << model_id.value());
} else if (related_models_ids.value().size() == 1) {
model_id = related_models_ids.value().front();
} else { // multiple models with nearly same name found
auto selection = cli_selection_utils::PrintSelection(
related_models_ids.value(), "Local Models: (press enter to select)");
if (!selection.has_value()) {
return;
}
model_id = selection.value();
CLI_LOG("Selected: " << selection.value());
}
}

Expand Down
2 changes: 1 addition & 1 deletion engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void Models::PullModel(const HttpRequestPtr& req,
CTL_INF("Handle model input, model handle: " + model_handle);
if (string_utils::StartsWith(model_handle, "https")) {
return model_service_.HandleUrl(model_handle, true);
} else if (model_handle.find(":") == std::string::npos) {
} else if (model_handle.find(":") != std::string::npos) {
auto model_and_branch = string_utils::SplitBy(model_handle, ":");
return model_service_.DownloadModelFromCortexso(
model_and_branch[0], model_and_branch[1], true);
Expand Down
23 changes: 23 additions & 0 deletions engine/database/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,29 @@ cpp::result<bool, std::string> Models::DeleteModelEntry(
}
}

cpp::result<std::vector<std::string>, std::string> Models::FindRelatedModel(
const std::string& identifier) const {
// TODO (namh): add check for alias as well
try {
std::vector<std::string> related_models;
SQLite::Statement query(
db_,
"SELECT model_id FROM models WHERE model_id LIKE ? OR model_id LIKE ? "
"OR model_id LIKE ? OR model_id LIKE ?");
query.bind(1, identifier + ":%");
query.bind(2, "%:" + identifier);
query.bind(3, "%:" + identifier + ":%");
query.bind(4, identifier);

while (query.executeStep()) {
related_models.push_back(query.getColumn(0).getString());
}
return related_models;
} catch (const std::exception& e) {
return cpp::fail(e.what());
}
}

bool Models::HasModel(const std::string& identifier) const {
try {
SQLite::Statement query(
Expand Down
18 changes: 11 additions & 7 deletions engine/database/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Models {
const std::string& model_id,
const std::string& model_alias) const;

cpp::result<std::vector<ModelEntry>, std::string> LoadModelListNoLock() const;
cpp::result<std::vector<ModelEntry>, std::string> LoadModelListNoLock() const;

public:
static const std::string kModelListPath;
Expand All @@ -35,15 +35,19 @@ class Models {
std::string GenerateShortenedAlias(
const std::string& model_id,
const std::vector<ModelEntry>& entries) const;
cpp::result<ModelEntry, std::string> GetModelInfo(const std::string& identifier) const;
cpp::result<ModelEntry, std::string> GetModelInfo(
const std::string& identifier) const;
void PrintModelInfo(const ModelEntry& entry) const;
cpp::result<bool, std::string> AddModelEntry(ModelEntry new_entry,
bool use_short_alias = false);
cpp::result<bool, std::string> UpdateModelEntry(const std::string& identifier,
const ModelEntry& updated_entry);
cpp::result<bool, std::string> DeleteModelEntry(const std::string& identifier);
cpp::result<bool, std::string> UpdateModelAlias(const std::string& model_id,
const std::string& model_alias);
cpp::result<bool, std::string> UpdateModelEntry(
const std::string& identifier, const ModelEntry& updated_entry);
cpp::result<bool, std::string> DeleteModelEntry(
const std::string& identifier);
cpp::result<bool, std::string> UpdateModelAlias(
const std::string& model_id, const std::string& model_alias);
cpp::result<std::vector<std::string>, std::string> FindRelatedModel(
const std::string& identifier) const;
bool HasModel(const std::string& identifier) const;
};
} // namespace cortex::db
41 changes: 35 additions & 6 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,46 @@ cpp::result<std::string, std::string> ModelService::HandleCortexsoModel(
return cpp::fail(branches.error());
}

std::vector<std::string> options{};
auto default_model_branch = huggingface_utils::GetDefaultBranch(modelName);

cortex::db::Models modellist_handler;
auto downloaded_model_ids =
modellist_handler.FindRelatedModel(modelName).value_or(
std::vector<std::string>{});

std::vector<std::string> avai_download_opts{};
for (const auto& branch : branches.value()) {
if (branch.second.name != "main") {
options.emplace_back(branch.second.name);
if (branch.second.name == "main") { // main branch only have metadata. skip
continue;
}
auto model_id = modelName + ":" + branch.second.name;
if (std::find(downloaded_model_ids.begin(), downloaded_model_ids.end(),
model_id) !=
downloaded_model_ids.end()) { // if downloaded, we skip it
continue;
}
avai_download_opts.emplace_back(model_id);
}
if (options.empty()) {

if (avai_download_opts.empty()) {
// TODO: only with pull, we return
return cpp::fail("No variant available");
}
auto selection = cli_selection_utils::PrintSelection(options);
return DownloadModelFromCortexso(modelName, selection.value());
std::optional<std::string> normalized_def_branch = std::nullopt;
if (default_model_branch.has_value()) {
normalized_def_branch = modelName + ":" + default_model_branch.value();
}
string_utils::SortStrings(downloaded_model_ids);
string_utils::SortStrings(avai_download_opts);
auto selection = cli_selection_utils::PrintModelSelection(
downloaded_model_ids, avai_download_opts, normalized_def_branch);
if (!selection.has_value()) {
return cpp::fail("Invalid selection");
}

CLI_LOG("Selected: " << selection.value());
auto branch_name = selection.value().substr(modelName.size() + 1);
return DownloadModelFromCortexso(modelName, branch_name, false);
}

std::optional<config::ModelConfig> ModelService::GetDownloadedModel(
Expand Down
56 changes: 53 additions & 3 deletions engine/utils/cli_selection_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,67 @@
#include <optional>
#include <string>
#include <vector>
#include "utils/logging_utils.h"

namespace cli_selection_utils {
inline void PrintMenu(const std::vector<std::string>& options) {
auto index{1};
const std::string indent = std::string(4, ' ');
inline void PrintMenu(
const std::vector<std::string>& options,
const std::optional<std::string> default_option = std::nullopt,
const int start_index = 1) {
auto index{start_index};
for (const auto& option : options) {
std::cout << index << ". " << option << "\n";
bool is_default = false;
if (default_option.has_value() && option == default_option.value()) {
is_default = true;
}
std::string selection{std::to_string(index) + ". " + option +
(is_default ? " (default)" : "") + "\n"};
std::cout << indent << selection;
index++;
}
std::endl(std::cout);
}

inline std::optional<std::string> PrintModelSelection(
const std::vector<std::string>& downloaded,
const std::vector<std::string>& availables,
const std::optional<std::string> default_selection = std::nullopt) {

std::string selection{""};
if (!downloaded.empty()) {
std::cout << "Downloaded models:\n";
for (const auto& option : downloaded) {
std::cout << indent << option << "\n";
}
std::endl(std::cout);
}

if (!availables.empty()) {
std::cout << "Available to download:\n";
PrintMenu(availables, default_selection, 1);
}

std::cout << "Select a model (" << 1 << "-" << availables.size() << "): ";
std::getline(std::cin, selection);

// if selection is empty and default selection is inside availables, return default_selection
if (selection.empty()) {
if (default_selection.has_value() &&
std::find(availables.begin(), availables.end(),
default_selection.value()) != availables.end()) {
return default_selection;
}
return std::nullopt;
}

if (std::stoi(selection) > availables.size() || std::stoi(selection) < 1) {
return std::nullopt;
}

return availables[std::stoi(selection) - 1];
}

inline std::optional<std::string> PrintSelection(
const std::vector<std::string>& options,
const std::string& title = "Select an option") {
Expand Down
3 changes: 1 addition & 2 deletions engine/utils/curl_utils.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include <curl/curl.h>
#include <nlohmann/json.hpp>
#include <string>
#include "utils/logging_utils.h"
#include "utils/result.hpp"
#include "yaml-cpp/yaml.h"

Expand Down Expand Up @@ -74,4 +73,4 @@ inline cpp::result<nlohmann::json, std::string> SimpleGetJson(
" parsing error: " + std::string(e.what()));
}
}
} // namespace curl_utils
} // namespace curl_utils
26 changes: 26 additions & 0 deletions engine/utils/huggingface_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ GetHuggingFaceModelRepoInfo(const std::string& author,
return model_repo_info;
}

inline std::string GetMetadataUrl(const std::string& model_id) {
auto url_obj = url_parser::Url{
.protocol = "https",
.host = kHuggingfaceHost,
.pathParams = {"cortexso", model_id, "resolve", "main", "metadata.yml"}};

return url_obj.ToFullPath();
}

inline std::string GetDownloadableUrl(const std::string& author,
const std::string& modelName,
const std::string& fileName,
Expand All @@ -151,4 +160,21 @@ inline std::string GetDownloadableUrl(const std::string& author,
};
return url_parser::FromUrl(url_obj);
}

inline std::optional<std::string> GetDefaultBranch(
const std::string& model_name) {
auto default_model_branch =
curl_utils::ReadRemoteYaml(GetMetadataUrl(model_name));

if (default_model_branch.has_error()) {
return std::nullopt;
}

auto metadata = default_model_branch.value();
auto default_branch = metadata["default"];
if (default_branch.IsDefined()) {
return default_branch.as<std::string>();
}
return std::nullopt;
}
} // namespace huggingface_utils
4 changes: 4 additions & 0 deletions engine/utils/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ inline bool StartsWith(const std::string& str, const std::string& prefix) {
return str.rfind(prefix, 0) == 0;
}

inline void SortStrings(std::vector<std::string>& strings) {
std::sort(strings.begin(), strings.end());
}

inline bool EndsWith(const std::string& str, const std::string& suffix) {
if (str.length() >= suffix.length()) {
return (0 == str.compare(str.length() - suffix.length(), suffix.length(),
Expand Down

0 comments on commit c3a97f7

Please sign in to comment.