Skip to content

Commit

Permalink
feat: models get command (#1035)
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenhoangthuan99 authored Aug 29, 2024
1 parent bbc3e31 commit ba6816f
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 1 deletion.
3 changes: 2 additions & 1 deletion engine/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -563,4 +563,5 @@ build
build-deps
.DS_Store

uploads/**
uploads/**
CMakePresets.json
135 changes: 135 additions & 0 deletions engine/commands/model_get_cmd.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#include "model_get_cmd.h"
#include <filesystem>
#include <iostream>
#include <vector>
#include "config/yaml_config.h"
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"

namespace commands {
ModelGetCmd::ModelGetCmd(std::string model_handle)
: model_handle_(std::move(model_handle)) {}

// Searches the local models folder for "<model_handle_>.yaml", parses it,
// and pretty-prints every meaningful field of the model configuration to
// stdout. Reports to stderr when the model cannot be found (previously the
// `found_model` flag was computed but never surfaced to the user).
void ModelGetCmd::Exec() {
  bool found_model = false;

  if (std::filesystem::exists(cortex_utils::models_folder) &&
      std::filesystem::is_directory(cortex_utils::models_folder)) {
    // Prints "key: value" only when the string field is non-empty.
    auto print_str = [](const char* key, const std::string& value) {
      if (!value.empty())
        std::cout << key << ": " << value << "\n";
    };
    // Prints "key: [a, b, c]" only when the vector field is non-empty.
    auto print_vec = [](const char* key,
                        const std::vector<std::string>& values) {
      if (values.empty())
        return;
      std::cout << key << ": [";
      for (size_t i = 0; i < values.size(); ++i) {
        std::cout << values[i];
        if (i + 1 < values.size())
          std::cout << ", ";
      }
      std::cout << "]\n";
    };

    // Iterate the models folder until the matching yaml file is found.
    for (const auto& entry :
         std::filesystem::directory_iterator(cortex_utils::models_folder)) {
      if (!(entry.is_regular_file() && entry.path().stem() == model_handle_ &&
            entry.path().extension() == ".yaml")) {
        continue;
      }
      try {
        config::YamlHandler handler;
        handler.ModelConfigFromFile(entry.path().string());
        const auto& model_config = handler.GetModelConfig();
        std::cout << "ModelConfig Details:\n";
        std::cout << "-------------------\n";

        // Non-empty strings.
        print_str("id", model_config.id);
        print_str("name", model_config.name);
        print_str("model", model_config.model);
        print_str("version", model_config.version);

        // Non-empty vectors.
        print_vec("stop", model_config.stop);

        // Valid numbers. NOTE(review): presumably NaN marks "unset" for the
        // float fields; for integral/bool fields the cast can never yield
        // NaN, so those checks always print — mirrors the original behavior.
        if (!std::isnan(static_cast<double>(model_config.top_p)))
          std::cout << "top_p: " << model_config.top_p << "\n";
        if (!std::isnan(static_cast<double>(model_config.temperature)))
          std::cout << "temperature: " << model_config.temperature << "\n";
        if (!std::isnan(static_cast<double>(model_config.frequency_penalty)))
          std::cout << "frequency_penalty: " << model_config.frequency_penalty
                    << "\n";
        if (!std::isnan(static_cast<double>(model_config.presence_penalty)))
          std::cout << "presence_penalty: " << model_config.presence_penalty
                    << "\n";
        if (!std::isnan(static_cast<double>(model_config.max_tokens)))
          std::cout << "max_tokens: " << model_config.max_tokens << "\n";
        if (!std::isnan(static_cast<double>(model_config.stream)))
          std::cout << "stream: " << std::boolalpha << model_config.stream
                    << "\n";
        if (!std::isnan(static_cast<double>(model_config.ngl)))
          std::cout << "ngl: " << model_config.ngl << "\n";
        if (!std::isnan(static_cast<double>(model_config.ctx_len)))
          std::cout << "ctx_len: " << model_config.ctx_len << "\n";

        // Engine / template strings.
        print_str("engine", model_config.engine);
        print_str("prompt_template", model_config.prompt_template);
        print_str("system_template", model_config.system_template);
        print_str("user_template", model_config.user_template);
        print_str("ai_template", model_config.ai_template);
        print_str("os", model_config.os);
        print_str("gpu_arch", model_config.gpu_arch);
        print_str("quantization_method", model_config.quantization_method);
        print_str("precision", model_config.precision);

        if (!std::isnan(static_cast<double>(model_config.tp)))
          std::cout << "tp: " << model_config.tp << "\n";

        print_str("trtllm_version", model_config.trtllm_version);
        if (!std::isnan(static_cast<double>(model_config.text_model)))
          std::cout << "text_model: " << std::boolalpha
                    << model_config.text_model << "\n";

        print_vec("files", model_config.files);

        // Valid size_t number (0 means "unset").
        if (model_config.created != 0)
          std::cout << "created: " << model_config.created << "\n";

        print_str("object", model_config.object);
        print_str("owned_by", model_config.owned_by);

        found_model = true;
        break;
      } catch (const std::exception& e) {
        // Keep scanning: another file may still match on a later iteration.
        LOG_ERROR << "Error reading yaml file '" << entry.path().string()
                  << "': " << e.what();
      }
    }
  }

  if (!found_model) {
    std::cerr << "Model '" << model_handle_ << "' not found!\n";
  }
}
}; // namespace commands
15 changes: 15 additions & 0 deletions engine/commands/model_get_cmd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include <cmath> // For std::isnan
#include <string>
namespace commands {

// CLI command object: prints the stored configuration (model.yaml contents)
// of a locally downloaded model, looked up by its handle.
class ModelGetCmd {
 public:
  // model_handle: stem of the model's .yaml file under the models folder.
  // `explicit` prevents accidental implicit std::string -> command conversion.
  explicit ModelGetCmd(std::string model_handle);

  // Locates "<model_handle>.yaml" in the models folder and prints every
  // non-empty / valid field of the parsed model config to stdout.
  void Exec();

 private:
  std::string model_handle_;  // model id to look up
};
} // namespace commands
10 changes: 10 additions & 0 deletions engine/controllers/command_line_parser.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "command_line_parser.h"
#include "commands/engine_init_cmd.h"
#include "commands/model_list_cmd.h"
#include "commands/model_get_cmd.h"

#include "commands/model_pull_cmd.h"
#include "commands/start_model_cmd.h"
#include "commands/stop_model_cmd.h"
Expand Down Expand Up @@ -51,6 +53,14 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) {
command.Exec();
});

// Register `models get <model_id>`: print the stored config of a local model.
auto get_models_cmd =
models_cmd->add_subcommand("get", "Get info of {model_id} locally");
// Positional argument parsed into the enclosing scope's `model_id` string.
get_models_cmd->add_option("model_id", model_id, "");
// NOTE(review): captures model_id by reference; looks safe because the
// callback fires during parsing while model_id is still alive — confirm.
get_models_cmd->callback([&model_id](){
commands::ModelGetCmd command(model_id);
command.Exec();
});

auto model_pull_cmd =
app_.add_subcommand("pull",
"Download a model from a registry. Working with "
Expand Down
66 changes: 66 additions & 0 deletions engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,70 @@ void Models::ListModel(
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k200OK);
callback(resp);
}

// HTTP handler: returns the stored configuration of one local model.
// Expects a JSON body with a "modelId" field; responds 200 with a list-shaped
// payload containing at most one entry (empty `data` when nothing matches).
void Models::GetModel(
    const HttpRequestPtr& req,
    std::function<void(const HttpResponsePtr&)>&& callback) const {
  // Reject requests missing the required field (helper sends the error).
  if (!http_util::HasFieldInReq(req, callback, "modelId")) {
    return;
  }
  auto model_handle = (*(req->getJsonObject())).get("modelId", "").asString();
  LOG_DEBUG << "GetModel, Model handle: " << model_handle;
  Json::Value ret;
  ret["object"] = "list";
  Json::Value data(Json::arrayValue);
  if (std::filesystem::exists(cortex_utils::models_folder) &&
      std::filesystem::is_directory(cortex_utils::models_folder)) {
    // Scan the models folder for "<model_handle>.yaml".
    for (const auto& entry :
         std::filesystem::directory_iterator(cortex_utils::models_folder)) {
      if (entry.is_regular_file() && entry.path().extension() == ".yaml" &&
          entry.path().stem() == model_handle) {
        try {
          config::YamlHandler handler;
          handler.ModelConfigFromFile(entry.path().string());
          auto const& model_config = handler.GetModelConfig();
          Json::Value obj;
          obj["name"] = model_config.name;
          obj["model"] = model_config.model;
          obj["version"] = model_config.version;
          Json::Value stop_array(Json::arrayValue);
          for (const std::string& stop : model_config.stop)
            stop_array.append(stop);
          obj["stop"] = stop_array;
          obj["top_p"] = model_config.top_p;
          obj["temperature"] = model_config.temperature;
          // Was missing: the CLI `models get` command reports this field,
          // so the HTTP payload should expose it too (additive change).
          obj["frequency_penalty"] = model_config.frequency_penalty;
          obj["presence_penalty"] = model_config.presence_penalty;
          obj["max_tokens"] = model_config.max_tokens;
          obj["stream"] = model_config.stream;
          obj["ngl"] = model_config.ngl;
          obj["ctx_len"] = model_config.ctx_len;
          obj["engine"] = model_config.engine;
          obj["prompt_template"] = model_config.prompt_template;

          Json::Value files_array(Json::arrayValue);
          for (const std::string& file : model_config.files)
            files_array.append(file);
          obj["files"] = files_array;
          obj["id"] = model_config.id;
          obj["created"] = static_cast<uint32_t>(model_config.created);
          obj["object"] = model_config.object;
          obj["owned_by"] = model_config.owned_by;
          // trtllm_version only applies to the TensorRT-LLM engine.
          if (model_config.engine == "cortex.tensorrt-llm") {
            obj["trtllm_version"] = model_config.trtllm_version;
          }
          data.append(std::move(obj));
          // File stems are unique within a directory, so at most one entry
          // can match — stop scanning once it is found.
          break;
        } catch (const std::exception& e) {
          LOG_ERROR << "Error reading yaml file '" << entry.path().string()
                    << "': " << e.what();
        }
      }
    }
  }
  ret["data"] = data;
  // NOTE(review): "OK" is returned even when no model matched — callers must
  // check `data` length; confirm this is the intended contract.
  ret["result"] = "OK";
  auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
  resp->setStatusCode(k200OK);
  callback(resp);
}
3 changes: 3 additions & 0 deletions engine/controllers/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@ class Models : public drogon::HttpController<Models> {
METHOD_LIST_BEGIN
// Route table for model-management endpoints handled by this controller.
METHOD_ADD(Models::PullModel, "/pull", Post);
METHOD_ADD(Models::ListModel, "/list", Get);
// POST (not GET) because the model id travels in the JSON request body.
METHOD_ADD(Models::GetModel, "/get", Post);
METHOD_LIST_END

// Downloads a model; expects the target id in the JSON body.
void PullModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const;
// Lists every model found in the local models folder.
void ListModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const;
// Returns the parsed config of one local model ("modelId" in the JSON body).
void GetModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const;
};

0 comments on commit ba6816f

Please sign in to comment.