feat: models get command #1035

Merged Aug 29, 2024 (3 commits)
Changes from 1 commit
3 changes: 2 additions & 1 deletion engine/.gitignore
@@ -563,4 +563,5 @@ build
build-deps
.DS_Store

uploads/**
uploads/**
CMakePresets.json
129 changes: 129 additions & 0 deletions engine/commands/model_get_cmd.cc
@@ -0,0 +1,129 @@
#include "model_get_cmd.h"
#include <filesystem>
#include <iostream>
#include <vector>
#include "config/yaml_config.h"
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"

namespace commands {
ModelGetCmd::ModelGetCmd(std::string modelHandle)
    : modelHandle_(std::move(modelHandle)) {}

void ModelGetCmd::Exec() {
  if (std::filesystem::exists(cortex_utils::models_folder) &&
      std::filesystem::is_directory(cortex_utils::models_folder)) {
    bool found_model = false;
    // Iterate through directory
    for (const auto& entry :
         std::filesystem::directory_iterator(cortex_utils::models_folder)) {
      if (entry.is_regular_file() && entry.path().stem() == modelHandle_ &&
          entry.path().extension() == ".yaml") {
        try {
          config::YamlHandler handler;
          handler.ModelConfigFromFile(entry.path().string());
          const auto& model_config = handler.GetModelConfig();
          std::cout << "ModelConfig Details:\n";
          std::cout << "-------------------\n";

          // Print non-null strings
          if (!model_config.id.empty())
            std::cout << "id: " << model_config.id << "\n";
          if (!model_config.name.empty())
            std::cout << "name: " << model_config.name << "\n";
          if (!model_config.model.empty())
            std::cout << "model: " << model_config.model << "\n";
          if (!model_config.version.empty())
            std::cout << "version: " << model_config.version << "\n";

          // Print non-empty vectors
          if (!model_config.stop.empty()) {
            std::cout << "stop: [";
            for (size_t i = 0; i < model_config.stop.size(); ++i) {
              std::cout << model_config.stop[i];
              if (i < model_config.stop.size() - 1)
                std::cout << ", ";
            }
            std::cout << "]\n";
          }

          // Print valid numbers
          if (!std::isnan(static_cast<double>(model_config.top_p)))
            std::cout << "top_p: " << model_config.top_p << "\n";
          if (!std::isnan(static_cast<double>(model_config.temperature)))
            std::cout << "temperature: " << model_config.temperature << "\n";
          if (!std::isnan(static_cast<double>(model_config.frequency_penalty)))
            std::cout << "frequency_penalty: " << model_config.frequency_penalty
                      << "\n";
          if (!std::isnan(static_cast<double>(model_config.presence_penalty)))
            std::cout << "presence_penalty: " << model_config.presence_penalty
                      << "\n";
          if (!std::isnan(static_cast<double>(model_config.max_tokens)))
            std::cout << "max_tokens: " << model_config.max_tokens << "\n";
          if (!std::isnan(static_cast<double>(model_config.stream)))
            std::cout << "stream: " << std::boolalpha << model_config.stream
                      << "\n";
          if (!std::isnan(static_cast<double>(model_config.ngl)))
            std::cout << "ngl: " << model_config.ngl << "\n";
          if (!std::isnan(static_cast<double>(model_config.ctx_len)))
            std::cout << "ctx_len: " << model_config.ctx_len << "\n";

          // Print non-null strings
          if (!model_config.engine.empty())
            std::cout << "engine: " << model_config.engine << "\n";
          if (!model_config.prompt_template.empty())
            std::cout << "prompt_template: " << model_config.prompt_template
                      << "\n";
          if (!model_config.system_template.empty())
            std::cout << "system_template: " << model_config.system_template
                      << "\n";
          if (!model_config.user_template.empty())
            std::cout << "user_template: " << model_config.user_template
                      << "\n";
          if (!model_config.ai_template.empty())
            std::cout << "ai_template: " << model_config.ai_template << "\n";
          if (!model_config.os.empty())
            std::cout << "os: " << model_config.os << "\n";
          if (!model_config.gpu_arch.empty())
            std::cout << "gpu_arch: " << model_config.gpu_arch << "\n";
          if (!model_config.quantization_method.empty())
            std::cout << "quantization_method: "
                      << model_config.quantization_method << "\n";
          if (!model_config.precision.empty())
            std::cout << "precision: " << model_config.precision << "\n";

          if (!std::isnan(static_cast<double>(model_config.tp)))
            std::cout << "tp: " << model_config.tp << "\n";

          // Print non-null strings
          if (!model_config.trtllm_version.empty())
            std::cout << "trtllm_version: " << model_config.trtllm_version
                      << "\n";
          if (!std::isnan(static_cast<double>(model_config.text_model)))
            std::cout << "text_model: " << std::boolalpha
                      << model_config.text_model << "\n";

          // Print non-empty vectors
          if (!model_config.files.empty()) {
            std::cout << "files: [";
            for (size_t i = 0; i < model_config.files.size(); ++i) {
              std::cout << model_config.files[i];
              if (i < model_config.files.size() - 1)
                std::cout << ", ";
            }
            std::cout << "]\n";
          }

          // Print valid size_t number
          if (model_config.created != 0)
            std::cout << "created: " << model_config.created << "\n";

          if (!model_config.object.empty())
            std::cout << "object: " << model_config.object << "\n";
          if (!model_config.owned_by.empty())
            std::cout << "owned_by: " << model_config.owned_by << "\n";

          found_model = true;
          break;
        } catch (const std::exception& e) {
          LOG_ERROR << "Error reading yaml file '" << entry.path().string()
                    << "': " << e.what();
        }
      }
    }
  }
}
}; // namespace commands
15 changes: 15 additions & 0 deletions engine/commands/model_get_cmd.h
@@ -0,0 +1,15 @@
#pragma once

#include <string>
#include <cmath> // For std::isnan
namespace commands {

class ModelGetCmd {
 public:
  ModelGetCmd(std::string modelHandle);
  void Exec();

 private:
  std::string modelHandle_;
};
} // namespace commands
9 changes: 9 additions & 0 deletions engine/controllers/command_line_parser.cc
@@ -2,6 +2,7 @@
#include "commands/engine_init_cmd.h"
#include "commands/model_pull_cmd.h"
#include "commands/model_list_cmd.h"
#include "commands/model_get_cmd.h"
#include "commands/start_model_cmd.h"
#include "commands/stop_model_cmd.h"
#include "commands/stop_server_cmd.h"
@@ -49,6 +50,14 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) {
    command.Exec();
  });

  auto get_models_cmd =
      models_cmd->add_subcommand("get", "Get info of {model_id} locally");
  get_models_cmd->add_option("model_id", model_id, "");
  get_models_cmd->callback([&model_id]() {
    commands::ModelGetCmd command(model_id);
    command.Exec();
  });

  auto model_pull_cmd =
      app_.add_subcommand("pull",
                          "Download a model from a registry. Working with "
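With this wiring, the new subcommand should be reachable from the CLI as something like `cortex models get <model_id>` (the executable name depends on the build target and is an assumption here); it looks up `<model_id>.yaml` in the local models folder and prints the configuration fields shown above.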
66 changes: 66 additions & 0 deletions engine/controllers/models.cc
@@ -101,4 +101,70 @@ void Models::ListModel(
  auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
  resp->setStatusCode(k200OK);
  callback(resp);
}

void Models::GetModel(
    const HttpRequestPtr& req,
    std::function<void(const HttpResponsePtr&)>&& callback) const {
  if (!http_util::HasFieldInReq(req, callback, "modelId")) {
    return;
  }
  auto model_handle = (*(req->getJsonObject())).get("modelId", "").asString();
  LOG_DEBUG << "GetModel, Model handle: " << model_handle;
  Json::Value ret;
  ret["object"] = "list";
  Json::Value data(Json::arrayValue);
  if (std::filesystem::exists(cortex_utils::models_folder) &&
      std::filesystem::is_directory(cortex_utils::models_folder)) {
    // Iterate through directory
    for (const auto& entry :
         std::filesystem::directory_iterator(cortex_utils::models_folder)) {
      if (entry.is_regular_file() && entry.path().extension() == ".yaml" &&
          entry.path().stem() == model_handle) {
        try {
          config::YamlHandler handler;
          handler.ModelConfigFromFile(entry.path().string());
          auto const& model_config = handler.GetModelConfig();
          Json::Value obj;
          obj["name"] = model_config.name;
          obj["model"] = model_config.model;
          obj["version"] = model_config.version;
          Json::Value stop_array(Json::arrayValue);
          for (const std::string& stop : model_config.stop)
            stop_array.append(stop);
          obj["stop"] = stop_array;
          obj["top_p"] = model_config.top_p;
          obj["temperature"] = model_config.temperature;
          obj["presence_penalty"] = model_config.presence_penalty;
          obj["max_tokens"] = model_config.max_tokens;
          obj["stream"] = model_config.stream;
          obj["ngl"] = model_config.ngl;
          obj["ctx_len"] = model_config.ctx_len;
          obj["engine"] = model_config.engine;
          obj["prompt_template"] = model_config.prompt_template;

          Json::Value files_array(Json::arrayValue);
          for (const std::string& file : model_config.files)
            files_array.append(file);
          obj["files"] = files_array;
          obj["id"] = model_config.id;
          obj["created"] = static_cast<uint32_t>(model_config.created);
          obj["object"] = model_config.object;
          obj["owned_by"] = model_config.owned_by;
          if (model_config.engine == "cortex.tensorrt-llm") {
            obj["trtllm_version"] = model_config.trtllm_version;
          }
          data.append(std::move(obj));
        } catch (const std::exception& e) {
          LOG_ERROR << "Error reading yaml file '" << entry.path().string()
                    << "': " << e.what();
        }
      }
    }
  }
  ret["data"] = data;
  ret["result"] = "OK";
  auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
  resp->setStatusCode(k200OK);
  callback(resp);
}
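
For the HTTP side, here is a minimal client sketch that exercises the new handler. The host, port (3928), route prefix (/models/get), and model handle ("tinyllama") are assumptions for illustration only; the actual route prefix is derived from how drogon mounts the Models controller.

// get_model_client.cc -- illustrative sketch, not part of this PR.
#include <drogon/drogon.h>
#include <drogon/HttpClient.h>
#include <json/json.h>
#include <iostream>

int main() {
  // Assumed server address; adjust to the actual deployment.
  auto client = drogon::HttpClient::newHttpClient("http://127.0.0.1:3928");

  Json::Value body;
  body["modelId"] = "tinyllama";  // hypothetical handle, i.e. tinyllama.yaml in the models folder

  auto req = drogon::HttpRequest::newHttpJsonRequest(body);
  req->setMethod(drogon::Post);
  req->setPath("/models/get");  // assumed prefix for the Models controller

  client->sendRequest(
      req, [](drogon::ReqResult result, const drogon::HttpResponsePtr& resp) {
        if (result == drogon::ReqResult::Ok && resp) {
          // The handler responds with {"object":"list","data":[...],"result":"OK"}.
          std::cout << resp->getBody() << std::endl;
        }
        drogon::app().quit();
      });

  drogon::app().run();  // run the event loop so the async request completes
  return 0;
}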
3 changes: 3 additions & 0 deletions engine/controllers/models.h
@@ -14,10 +14,13 @@ class Models : public drogon::HttpController<Models> {
  METHOD_LIST_BEGIN
  METHOD_ADD(Models::PullModel, "/pull", Post);
  METHOD_ADD(Models::ListModel, "/list", Get);
  METHOD_ADD(Models::GetModel, "/get", Post);
  METHOD_LIST_END

  void PullModel(const HttpRequestPtr& req,
                 std::function<void(const HttpResponsePtr&)>&& callback) const;
  void ListModel(const HttpRequestPtr& req,
                 std::function<void(const HttpResponsePtr&)>&& callback) const;
  void GetModel(const HttpRequestPtr& req,
                std::function<void(const HttpResponsePtr&)>&& callback) const;
};