diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index f6f7b7145..96ce082e1 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -515,8 +515,8 @@ "/v1/models/add": { "post": { "operationId": "ModelsController_addModel", - "summary": "Add a model", - "description": "Add a new model configuration to the system.", + "summary": "Add a remote model", + "description": "Add a new remote model configuration to the system.", "requestBody": { "required": true, "content": { @@ -1509,17 +1509,17 @@ }, "type": { "type": "string", - "description": "The type of connection", + "description": "The type of connection, remote or local", "example": "remote" }, "url": { "type": "string", - "description": "The URL for the API endpoint", + "description": "The URL for the API endpoint for remote engine", "example": "https://api.openai.com" }, "api_key": { "type": "string", - "description": "The API key for authentication", + "description": "The API key for authentication for remote engine", "example": "" }, "metadata": { diff --git a/engine/controllers/models.h b/engine/controllers/models.h index 3227c0999..b2b288adc 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -35,7 +35,7 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Options, Post); ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get); ADD_METHOD_TO(Models::AddRemoteModel, "/v1/models/add", Options, Post); - ADD_METHOD_TO(Models::GetRemoteModels, "/v1/remote/{1}", Get); + ADD_METHOD_TO(Models::GetRemoteModels, "/v1/models/remote/{1}", Get); METHOD_LIST_END explicit Models(std::shared_ptr model_service, diff --git a/engine/cortex-common/remote_enginei.h b/engine/cortex-common/remote_enginei.h new file mode 100644 index 000000000..81ffbf5cd --- /dev/null +++ b/engine/cortex-common/remote_enginei.h @@ -0,0 +1,37 @@ +#pragma once + +#pragma once + +#include +#include + +#include "json/value.h" +#include "trantor/utils/Logger.h" +class RemoteEngineI { + public: + virtual ~RemoteEngineI() {} + + virtual void HandleChatCompletion( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void HandleEmbedding( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void LoadModel( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void UnloadModel( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void GetModelStatus( + std::shared_ptr json_body, + std::function&& callback) = 0; + + // Get list of running models + virtual void GetModels( + std::shared_ptr jsonBody, + std::function&& callback) = 0; + + // Get available remote models + virtual Json::Value GetRemoteModels() = 0; +}; diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index d9aea2f41..04effb457 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -263,9 +263,7 @@ CurlResponse RemoteEngine::MakeGetModelsRequest() { std::string full_url = metadata_["get_models_url"].asString(); struct curl_slist* headers = nullptr; - headers = curl_slist_append(headers, api_key_template_.c_str()); - headers = curl_slist_append(headers, "Content-Type: application/json"); curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); @@ -304,7 +302,6 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest( struct curl_slist* headers = nullptr; if (!config.api_key.empty()) { - headers = curl_slist_append(headers, api_key_template_.c_str()); } @@ -707,50 +704,9 @@ void RemoteEngine::HandleEmbedding( callback(Json::Value(), Json::Value()); } -bool RemoteEngine::IsSupported(const std::string& f) { - if (f == "HandleChatCompletion" || f == "LoadModel" || f == "UnloadModel" || - f == "GetModelStatus" || f == "GetModels" || f == "SetFileLogger" || - f == "SetLogLevel") { - return true; - } - return false; -} - -bool RemoteEngine::SetFileLogger(int max_log_lines, - const std::string& log_path) { - if (!async_file_logger_) { - async_file_logger_ = std::make_unique(); - } - - async_file_logger_->setFileName(log_path); - async_file_logger_->setMaxLines(max_log_lines); // Keep last 100000 lines - async_file_logger_->startLogging(); - trantor::Logger::setOutputFunction( - [&](const char* msg, const uint64_t len) { - if (async_file_logger_) - async_file_logger_->output_(msg, len); - }, - [&]() { - if (async_file_logger_) - async_file_logger_->flush(); - }); - freopen(log_path.c_str(), "w", stderr); - freopen(log_path.c_str(), "w", stdout); - return true; -} - -void RemoteEngine::SetLogLevel(trantor::Logger::LogLevel log_level) { - trantor::Logger::setLogLevel(log_level); -} - Json::Value RemoteEngine::GetRemoteModels() { CTL_WRN("Not implemented yet!"); return {}; } -extern "C" { -EngineI* get_engine() { - return new RemoteEngine(); -} -} } // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 153ec6408..8ce6fa652 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -7,7 +7,7 @@ #include #include #include -#include "cortex-common/EngineI.h" +#include "cortex-common/remote_enginei.h" #include "extensions/remote-engine/template_renderer.h" #include "utils/engine_constants.h" #include "utils/file_logger.h" @@ -31,7 +31,7 @@ struct CurlResponse { std::string error_message; }; -class RemoteEngine : public EngineI { +class RemoteEngine : public RemoteEngineI { protected: // Model configuration struct ModelConfig { @@ -95,9 +95,7 @@ class RemoteEngine : public EngineI { void HandleEmbedding( std::shared_ptr json_body, std::function&& callback) override; - bool IsSupported(const std::string& feature) override; - bool SetFileLogger(int max_log_lines, const std::string& log_path) override; - void SetLogLevel(trantor::Logger::LogLevel logLevel) override; + Json::Value GetRemoteModels() override; }; diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 4634a0254..c91fd0dd0 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -694,21 +694,6 @@ cpp::result EngineService::LoadEngine( engines_[engine_name].engine = new remote_engine::AnthropicEngine(); } - auto& en = std::get(engines_[ne].engine); - auto config = file_manager_utils::GetCortexConfig(); - if (en->IsSupported("SetFileLogger")) { - en->SetFileLogger(config.maxLogLines, - (std::filesystem::path(config.logFolderPath) / - std::filesystem::path(config.logLlamaCppPath)) - .string()); - } else { - CTL_WRN("Method SetFileLogger is not supported yet"); - } - if (en->IsSupported("SetLogLevel")) { - en->SetLogLevel(trantor::Logger::logLevel()); - } else { - CTL_WRN("Method SetLogLevel is not supported yet"); - } CTL_INF("Loaded engine: " << engine_name); return {}; } @@ -883,8 +868,11 @@ cpp::result EngineService::UnloadEngine( if (!IsEngineLoaded(ne)) { return cpp::fail("Engine " + ne + " is not loaded yet!"); } - EngineI* e = std::get(engines_[ne].engine); - delete e; + if (std::holds_alternative(engines_[ne].engine)) { + delete std::get(engines_[ne].engine); + } else { + delete std::get(engines_[ne].engine); + } #if defined(_WIN32) if (!RemoveDllDirectory(engines_[ne].cookie)) { @@ -1100,7 +1088,22 @@ cpp::result EngineService::GetRemoteModels( return cpp::fail(r.error()); } - auto& e = std::get(engines_[engine_name].engine); + if (!IsEngineLoaded(engine_name)) { + auto exist_engine = GetEngineByNameAndVariant(engine_name); + if (exist_engine.has_error()) { + return cpp::fail("Remote engine '" + engine_name + "' is not installed"); + } + + if (engine_name == kOpenAiEngine) { + engines_[engine_name].engine = new remote_engine::OpenAiEngine(); + } else { + engines_[engine_name].engine = new remote_engine::AnthropicEngine(); + } + + CTL_INF("Loaded engine: " << engine_name); + } + + auto& e = std::get(engines_[engine_name].engine); auto res = e->GetRemoteModels(); if (!res["error"].isNull()) { return cpp::fail(res["error"].asString()); diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 692f7d5f5..8c8bfbbe6 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -11,6 +11,7 @@ #include "common/engine_servicei.h" #include "cortex-common/EngineI.h" #include "cortex-common/cortexpythoni.h" +#include "cortex-common/remote_enginei.h" #include "database/engines.h" #include "extensions/remote-engine/remote_engine.h" #include "services/download_service.h" @@ -37,7 +38,7 @@ struct EngineUpdateResult { } }; -using EngineV = std::variant; +using EngineV = std::variant; class EngineService : public EngineServiceI { private: diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index 46309823d..ace7e675f 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -24,14 +24,26 @@ cpp::result InferenceService::HandleChatCompletion( return cpp::fail(std::make_pair(stt, res)); } - auto engine = std::get(engine_result.value()); - engine->HandleChatCompletion( - json_body, [q, tool_choice](Json::Value status, Json::Value res) { - if (!tool_choice.isNull()) { - res["tool_choice"] = tool_choice; - } - q->push(std::make_pair(status, res)); - }); + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->HandleChatCompletion( + json_body, [q, tool_choice](Json::Value status, Json::Value res) { + if (!tool_choice.isNull()) { + res["tool_choice"] = tool_choice; + } + q->push(std::make_pair(status, res)); + }); + } else { + std::get(engine_result.value()) + ->HandleChatCompletion( + json_body, [q, tool_choice](Json::Value status, Json::Value res) { + if (!tool_choice.isNull()) { + res["tool_choice"] = tool_choice; + } + q->push(std::make_pair(status, res)); + }); + } + return {}; } @@ -53,10 +65,18 @@ cpp::result InferenceService::HandleEmbedding( LOG_WARN << "Engine is not loaded yet"; return cpp::fail(std::make_pair(stt, res)); } - auto engine = std::get(engine_result.value()); - engine->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { - q->push(std::make_pair(status, res)); - }); + + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { + q->push(std::make_pair(status, res)); + }); + } else { + std::get(engine_result.value()) + ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { + q->push(std::make_pair(status, res)); + }); + } return {}; } @@ -83,11 +103,20 @@ InferResult InferenceService::LoadModel( // might need mutex here auto engine_result = engine_service_->GetLoadedEngine(engine_type); - auto engine = std::get(engine_result.value()); - engine->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { - stt = status; - r = res; - }); + + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } else { + std::get(engine_result.value()) + ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } return std::make_pair(stt, r); } @@ -110,12 +139,22 @@ InferResult InferenceService::UnloadModel(const std::string& engine_name, json_body["model"] = model_id; LOG_TRACE << "Start unload model"; - auto engine = std::get(engine_result.value()); - engine->UnloadModel(std::make_shared(json_body), + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->UnloadModel(std::make_shared(json_body), + [&r, &stt](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } else { + std::get(engine_result.value()) + ->UnloadModel(std::make_shared(json_body), [&r, &stt](Json::Value status, Json::Value res) { stt = status; r = res; }); + } + return std::make_pair(stt, r); } @@ -141,12 +180,23 @@ InferResult InferenceService::GetModelStatus( } LOG_TRACE << "Start to get model status"; - auto engine = std::get(engine_result.value()); - engine->GetModelStatus(json_body, + + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->GetModelStatus(json_body, + [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } else { + std::get(engine_result.value()) + ->GetModelStatus(json_body, [&stt, &r](Json::Value status, Json::Value res) { stt = status; r = res; }); + } + return std::make_pair(stt, r); }