Skip to content

Commit

Permalink
Merge pull request #1695 from janhq/j/add-proxy-support
Browse files Browse the repository at this point in the history
feat: add proxy support
  • Loading branch information
namchuai authored Nov 18, 2024
2 parents 6892823 + b95b857 commit 0ffe3d4
Show file tree
Hide file tree
Showing 15 changed files with 729 additions and 735 deletions.
929 changes: 283 additions & 646 deletions docs/static/openapi/cortex.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions engine/cli/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ add_executable(${TARGET_NAME} main.cc
${CMAKE_CURRENT_SOURCE_DIR}/../utils/cpuid/cpu_info.cc
${CMAKE_CURRENT_SOURCE_DIR}/../utils/file_logger.cc
${CMAKE_CURRENT_SOURCE_DIR}/command_line_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/config_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/download_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/engine_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc
Expand Down
34 changes: 14 additions & 20 deletions engine/cli/command_line_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,8 @@ void CommandLineParser::SetupModelCommands() {
void CommandLineParser::SetupConfigsCommands() {
auto config_cmd =
app_.add_subcommand("config", "Subcommands for managing configurations");
config_cmd->usage(
"Usage:\n" + commands::GetCortexBinary() +
" config status for listing all API server configuration.\n" +
commands::GetCortexBinary() +
" config --cors [on/off] to toggle CORS.\n" +
commands::GetCortexBinary() +
" config --allowed_origins [comma separated origin] to set a list of "
"allowed origin");
config_cmd->usage("Usage:\n" + commands::GetCortexBinary() +
" config [option] [value]");
config_cmd->group(kConfigGroup);
auto config_status_cmd =
config_cmd->add_subcommand("status", "Print all configurations");
Expand All @@ -344,18 +338,18 @@ void CommandLineParser::SetupConfigsCommands() {
std::stoi(cml_data_.config.apiServerPort));
});

// TODO: this can be improved
std::vector<std::string> avai_opts{"cors", "allowed_origins"};
std::unordered_map<std::string, std::string> description{
{"cors", "[on/off] Toggling CORS."},
{"allowed_origins",
"Allowed origins for CORS. Comma separated. E.g. "
"http://localhost,https://cortex.so"}};
for (const auto& opt : avai_opts) {
std::string option = "--" + opt;
config_cmd->add_option(option, config_update_opts_[opt], description[opt])
->expected(0, 1)
->default_str("*");
for (const auto& [key, opt] : CONFIGURATIONS) {
std::string option = "--" + opt.name;
auto option_cmd =
config_cmd->add_option(option, config_update_opts_[opt.name], opt.desc)
->group(opt.group)
->default_str(opt.default_value);

if (opt.allow_empty) {
option_cmd->expected(0, 1);
} else {
option_cmd->expected(1);
}
}

config_cmd->callback([this, config_cmd] {
Expand Down
30 changes: 21 additions & 9 deletions engine/cli/commands/config_upd_cmd.cc
Original file line number Diff line number Diff line change
@@ -1,35 +1,39 @@
#include "config_upd_cmd.h"
#include "commands/server_start_cmd.h"
#include "common/api_server_configuration.h"
#include "utils/curl_utils.h"
#include "utils/logging_utils.h"
#include "utils/string_utils.h"
#include "utils/url_parser.h"

namespace {
const std::vector<std::string> config_keys{"cors", "allowed_origins"};

inline Json::Value NormalizeJson(
const std::unordered_map<std::string, std::string> options) {
Json::Value root;
for (const auto& [key, value] : options) {
if (std::find(config_keys.begin(), config_keys.end(), key) ==
config_keys.end()) {
if (CONFIGURATIONS.find(key) == CONFIGURATIONS.end()) {
continue;
}
auto config = CONFIGURATIONS.at(key);

if (key == "cors") {
if (config.accept_value == "[on|off]") {
if (string_utils::EqualsIgnoreCase("on", value)) {
root["cors"] = true;
root[key] = true;
} else if (string_utils::EqualsIgnoreCase("off", value)) {
root["cors"] = false;
root[key] = false;
}
} else if (key == "allowed_origins") {
} else if (config.accept_value == "comma separated") {
auto origins = string_utils::SplitBy(value, ",");
Json::Value origin_array(Json::arrayValue);
for (const auto& origin : origins) {
origin_array.append(origin);
}
root[key] = origin_array;
} else if (config.accept_value == "string") {
root[key] = value;
} else {
CTL_ERR("Not support configuration type: " << config.accept_value
<< " for config key: " << key);
}
}

Expand All @@ -50,13 +54,21 @@ void commands::ConfigUpdCmd::Exec(
}
}

auto non_null_opts = std::unordered_map<std::string, std::string>();
for (const auto& [key, value] : options) {
if (value.empty()) {
continue;
}
non_null_opts[key] = value;
}

auto url = url_parser::Url{
.protocol = "http",
.host = host + ":" + std::to_string(port),
.pathParams = {"v1", "configs"},
};

auto json = NormalizeJson(options);
auto json = NormalizeJson(non_null_opts);
if (json.empty()) {
CLI_LOG_ERROR("Invalid configuration options provided");
return;
Expand Down
200 changes: 197 additions & 3 deletions engine/common/api_server_configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,144 @@
#include <unordered_map>
#include <vector>

// current only support basic auth
enum class ProxyAuthMethod {
Basic,
Digest,
DigestIe,
Bearer,
Negotiate,
Ntlm,
NtlmWb,
Any,
AnySafe,
AuthOnly,
AwsSigV4
};

struct ApiConfigurationMetadata {
std::string name;
std::string desc;
std::string group;
std::string accept_value;
std::string default_value;

bool allow_empty = false;
};

static const std::unordered_map<std::string, ApiConfigurationMetadata>
CONFIGURATIONS = {
{"cors",
ApiConfigurationMetadata{
.name = "cors",
.desc = "Cross-Origin Resource Sharing configuration.",
.group = "CORS",
.accept_value = "[on|off]",
.default_value = "on"}},
{"allowed_origins",
ApiConfigurationMetadata{
.name = "allowed_origins",
.desc = "Allowed origins for CORS. Comma separated. E.g. "
"http://localhost,https://cortex.so",
.group = "CORS",
.accept_value = "comma separated",
.default_value = "*",
.allow_empty = true}},
{"proxy_url", ApiConfigurationMetadata{.name = "proxy_url",
.desc = "Proxy URL",
.group = "Proxy",
.accept_value = "string",
.default_value = ""}},
{"proxy_username", ApiConfigurationMetadata{.name = "proxy_username",
.desc = "Proxy Username",
.group = "Proxy",
.accept_value = "string",
.default_value = ""}},
{"proxy_password", ApiConfigurationMetadata{.name = "proxy_password",
.desc = "Proxy Password",
.group = "Proxy",
.accept_value = "string",
.default_value = ""}},
{"verify_proxy_ssl",
ApiConfigurationMetadata{.name = "verify_proxy_ssl",
.desc = "Verify SSL for proxy",
.group = "Proxy",
.accept_value = "[on|off]",
.default_value = "on"}},
{"verify_proxy_host_ssl",
ApiConfigurationMetadata{.name = "verify_proxy_host_ssl",
.desc = "Verify SSL for proxy",
.group = "Proxy",
.accept_value = "[on|off]",
.default_value = "on"}},
{"no_proxy", ApiConfigurationMetadata{.name = "no_proxy",
.desc = "No proxy for hosts",
.group = "Proxy",
.accept_value = "string",
.default_value = ""}},
{"verify_peer_ssl", ApiConfigurationMetadata{.name = "verify_peer_ssl",
.desc = "Verify peer SSL",
.group = "Proxy",
.accept_value = "[on|off]",
.default_value = "on"}},
{"verify_host_ssl", ApiConfigurationMetadata{.name = "verify_host_ssl",
.desc = "Verify host SSL",
.group = "Proxy",
.accept_value = "[on|off]",
.default_value = "on"}},
};

class ApiServerConfiguration {
public:
ApiServerConfiguration(bool cors = true,
std::vector<std::string> allowed_origins = {})
: cors{cors}, allowed_origins{allowed_origins} {}
ApiServerConfiguration(
bool cors = true, std::vector<std::string> allowed_origins = {},
bool verify_proxy_ssl = true, bool verify_proxy_host_ssl = true,
const std::string& proxy_url = "", const std::string& proxy_username = "",
const std::string& proxy_password = "", const std::string& no_proxy = "",
bool verify_peer_ssl = true, bool verify_host_ssl = true)
: cors{cors},
allowed_origins{allowed_origins},
verify_proxy_ssl{verify_proxy_ssl},
verify_proxy_host_ssl{verify_proxy_host_ssl},
proxy_url{proxy_url},
proxy_username{proxy_username},
proxy_password{proxy_password},
no_proxy{no_proxy},
verify_peer_ssl{verify_peer_ssl},
verify_host_ssl{verify_host_ssl} {}

// cors
bool cors{true};
std::vector<std::string> allowed_origins;

// proxy
bool verify_proxy_ssl{true};
bool verify_proxy_host_ssl{true};
ProxyAuthMethod proxy_auth_method{ProxyAuthMethod::Basic};
std::string proxy_url{""};
std::string proxy_username{""};
std::string proxy_password{""};
std::string no_proxy{""};

bool verify_peer_ssl{true};
bool verify_host_ssl{true};

Json::Value ToJson() const {
Json::Value root;
root["cors"] = cors;
root["allowed_origins"] = Json::Value(Json::arrayValue);
for (const auto& origin : allowed_origins) {
root["allowed_origins"].append(origin);
}
root["verify_proxy_ssl"] = verify_proxy_ssl;
root["verify_proxy_host_ssl"] = verify_proxy_host_ssl;
root["proxy_url"] = proxy_url;
root["proxy_username"] = proxy_username;
root["proxy_password"] = proxy_password;
root["no_proxy"] = no_proxy;
root["verify_peer_ssl"] = verify_peer_ssl;
root["verify_host_ssl"] = verify_host_ssl;

return root;
}

Expand All @@ -31,6 +153,78 @@ class ApiServerConfiguration {
const std::unordered_map<std::string,
std::function<bool(const Json::Value&)>>
field_updater{
{"verify_peer_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_peer_ssl = value.asBool();
return true;
}},

{"verify_host_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_host_ssl = value.asBool();
return true;
}},

{"verify_proxy_host_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_proxy_host_ssl = value.asBool();
return true;
}},

{"verify_proxy_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_proxy_ssl = value.asBool();
return true;
}},

{"no_proxy",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
no_proxy = value.asString();
return true;
}},

{"proxy_url",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
proxy_url = value.asString();
return true;
}},

{"proxy_username",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
proxy_username = value.asString();
return true;
}},

{"proxy_password",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
proxy_password = value.asString();
return true;
}},

{"cors",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
Expand Down
7 changes: 5 additions & 2 deletions engine/config/yaml_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "utils/file_manager_utils.h"
#include "utils/format_utils.h"
#include "yaml_config.h"

namespace config {
// Method to read YAML file
void YamlHandler::Reset() {
Expand Down Expand Up @@ -44,6 +45,7 @@ void YamlHandler::ReadYamlFile(const std::string& file_path) {
throw;
}
}

void YamlHandler::SplitPromptTemplate(ModelConfig& mc) {
if (mc.prompt_template.size() > 0) {
auto& pt = mc.prompt_template;
Expand Down Expand Up @@ -220,7 +222,7 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
yaml_node_["ngl"] = model_config_.ngl;
if (!std::isnan(static_cast<double>(model_config_.ctx_len)))
yaml_node_["ctx_len"] = model_config_.ctx_len;
if (!std::isnan(static_cast<double>(model_config_.n_parallel)))
if (!std::isnan(static_cast<double>(model_config_.n_parallel)))
yaml_node_["n_parallel"] = model_config_.n_parallel;
if (!std::isnan(static_cast<double>(model_config_.tp)))
yaml_node_["tp"] = model_config_.tp;
Expand Down Expand Up @@ -377,7 +379,8 @@ void YamlHandler::WriteYamlFile(const std::string& file_path) const {
outFile << format_utils::writeKeyValue(
"ctx_len", yaml_node_["ctx_len"],
"llama.context_length | 0 or undefined = loaded from model");
outFile << format_utils::writeKeyValue("n_parallel", yaml_node_["n_parallel"]);
outFile << format_utils::writeKeyValue("n_parallel",
yaml_node_["n_parallel"]);
outFile << format_utils::writeKeyValue("ngl", yaml_node_["ngl"],
"Undefined = loaded from model");
outFile << "# END OPTIONAL\n";
Expand Down
Loading

0 comments on commit 0ffe3d4

Please sign in to comment.