fix: use different interface for remote engine
vansangpfiev committed Dec 4, 2024
1 parent 90694c4 commit a7e4659
Showing 8 changed files with 140 additions and 95 deletions.
10 changes: 5 additions & 5 deletions docs/static/openapi/cortex.json
@@ -515,8 +515,8 @@
"/v1/models/add": {
"post": {
"operationId": "ModelsController_addModel",
"summary": "Add a model",
"description": "Add a new model configuration to the system.",
"summary": "Add a remote model",
"description": "Add a new remote model configuration to the system.",
"requestBody": {
"required": true,
"content": {
@@ -1509,17 +1509,17 @@
},
"type": {
"type": "string",
"description": "The type of connection",
"description": "The type of connection, remote or local",
"example": "remote"
},
"url": {
"type": "string",
"description": "The URL for the API endpoint",
"description": "The URL for the API endpoint for remote engine",
"example": "https://api.openai.com"
},
"api_key": {
"type": "string",
"description": "The API key for authentication",
"description": "The API key for authentication for remote engine",
"example": ""
},
"metadata": {
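For orientation, a request to the updated endpoint might look like the sketch below. Only the type, url, api_key, and metadata fields appear in the hunk above; the model and engine fields are assumptions, since the schema is truncated here.

POST /v1/models/add
{
  "model": "gpt-4o",
  "engine": "openai",
  "type": "remote",
  "url": "https://api.openai.com",
  "api_key": "sk-...",
  "metadata": {}
}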
2 changes: 1 addition & 1 deletion engine/controllers/models.h
@@ -35,7 +35,7 @@ class Models : public drogon::HttpController<Models, false> {
ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Options, Post);
ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get);
ADD_METHOD_TO(Models::AddRemoteModel, "/v1/models/add", Options, Post);
ADD_METHOD_TO(Models::GetRemoteModels, "/v1/remote/{1}", Get);
ADD_METHOD_TO(Models::GetRemoteModels, "/v1/models/remote/{1}", Get);
METHOD_LIST_END

explicit Models(std::shared_ptr<ModelService> model_service,
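The remote-models listing route moves under the /v1/models prefix, so a request now targets a path like the following ({1} is the engine name; host and port are illustrative):

curl http://127.0.0.1:39281/v1/models/remote/openai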
37 changes: 37 additions & 0 deletions engine/cortex-common/remote_enginei.h
@@ -0,0 +1,37 @@
#pragma once

#include <functional>
#include <memory>

#include "json/value.h"
#include "trantor/utils/Logger.h"
class RemoteEngineI {
public:
virtual ~RemoteEngineI() {}

virtual void HandleChatCompletion(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;
virtual void HandleEmbedding(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;
virtual void LoadModel(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;
virtual void UnloadModel(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;
virtual void GetModelStatus(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;

// Get list of running models
virtual void GetModels(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;

// Get available remote models
virtual Json::Value GetRemoteModels() = 0;
};
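A minimal sketch of driving the new interface from a caller, assuming some concrete RemoteEngineI implementation (the model id is illustrative):

#include <memory>
#include "json/value.h"

// Sketch: every RemoteEngineI method takes a shared JSON body and a
// callback receiving a (status, response) pair of Json::Value objects.
void ChatOnce(RemoteEngineI& engine) {
  auto body = std::make_shared<Json::Value>();
  (*body)["model"] = "gpt-4o";  // illustrative
  engine.HandleChatCompletion(
      body, [](Json::Value&& status, Json::Value&& response) {
        // Inspect status for errors; response carries the completion.
      });
}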
44 changes: 0 additions & 44 deletions engine/extensions/remote-engine/remote_engine.cc
@@ -263,9 +263,7 @@ CurlResponse RemoteEngine::MakeGetModelsRequest() {
std::string full_url = metadata_["get_models_url"].asString();

struct curl_slist* headers = nullptr;
-
headers = curl_slist_append(headers, api_key_template_.c_str());
-
headers = curl_slist_append(headers, "Content-Type: application/json");

curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str());
@@ -304,7 +302,6 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest(

struct curl_slist* headers = nullptr;
if (!config.api_key.empty()) {
-
headers = curl_slist_append(headers, api_key_template_.c_str());
}

@@ -707,50 +704,9 @@ void RemoteEngine::HandleEmbedding(
callback(Json::Value(), Json::Value());
}

bool RemoteEngine::IsSupported(const std::string& f) {
if (f == "HandleChatCompletion" || f == "LoadModel" || f == "UnloadModel" ||
f == "GetModelStatus" || f == "GetModels" || f == "SetFileLogger" ||
f == "SetLogLevel") {
return true;
}
return false;
}

bool RemoteEngine::SetFileLogger(int max_log_lines,
const std::string& log_path) {
if (!async_file_logger_) {
async_file_logger_ = std::make_unique<trantor::FileLogger>();
}

async_file_logger_->setFileName(log_path);
async_file_logger_->setMaxLines(max_log_lines); // Keep last 100000 lines
async_file_logger_->startLogging();
trantor::Logger::setOutputFunction(
[&](const char* msg, const uint64_t len) {
if (async_file_logger_)
async_file_logger_->output_(msg, len);
},
[&]() {
if (async_file_logger_)
async_file_logger_->flush();
});
freopen(log_path.c_str(), "w", stderr);
freopen(log_path.c_str(), "w", stdout);
return true;
}

void RemoteEngine::SetLogLevel(trantor::Logger::LogLevel log_level) {
trantor::Logger::setLogLevel(log_level);
}

Json::Value RemoteEngine::GetRemoteModels() {
CTL_WRN("Not implemented yet!");
return {};
}

extern "C" {
EngineI* get_engine() {
return new RemoteEngine();
}
}
} // namespace remote_engine
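With the switch to RemoteEngineI, the EngineI-only surface (IsSupported, SetFileLogger, SetLogLevel) and the extern "C" get_engine export are deleted: remote engines are no longer loaded as EngineI plugins. Instead they are constructed in-process, as the engine_service.cc diff below shows:

// From engine_service.cc below: direct construction replaces get_engine().
engines_[engine_name].engine = new remote_engine::OpenAiEngine();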
8 changes: 3 additions & 5 deletions engine/extensions/remote-engine/remote_engine.h
@@ -7,7 +7,7 @@
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include "cortex-common/EngineI.h"
#include "cortex-common/remote_enginei.h"
#include "extensions/remote-engine/template_renderer.h"
#include "utils/engine_constants.h"
#include "utils/file_logger.h"
@@ -31,7 +31,7 @@ struct CurlResponse {
std::string error_message;
};

- class RemoteEngine : public EngineI {
+ class RemoteEngine : public RemoteEngineI {
protected:
// Model configuration
struct ModelConfig {
@@ -95,9 +95,7 @@ class RemoteEngine : public EngineI {
void HandleEmbedding(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
bool IsSupported(const std::string& feature) override;
bool SetFileLogger(int max_log_lines, const std::string& log_path) override;
void SetLogLevel(trantor::Logger::LogLevel logLevel) override;

Json::Value GetRemoteModels() override;
};

39 changes: 21 additions & 18 deletions engine/services/engine_service.cc
@@ -694,21 +694,6 @@ cpp::result<void, std::string> EngineService::LoadEngine(
engines_[engine_name].engine = new remote_engine::AnthropicEngine();
}

auto& en = std::get<EngineI*>(engines_[ne].engine);
auto config = file_manager_utils::GetCortexConfig();
if (en->IsSupported("SetFileLogger")) {
en->SetFileLogger(config.maxLogLines,
(std::filesystem::path(config.logFolderPath) /
std::filesystem::path(config.logLlamaCppPath))
.string());
} else {
CTL_WRN("Method SetFileLogger is not supported yet");
}
if (en->IsSupported("SetLogLevel")) {
en->SetLogLevel(trantor::Logger::logLevel());
} else {
CTL_WRN("Method SetLogLevel is not supported yet");
}
CTL_INF("Loaded engine: " << engine_name);
return {};
}
@@ -883,8 +868,11 @@ cpp::result<void, std::string> EngineService::UnloadEngine(
if (!IsEngineLoaded(ne)) {
return cpp::fail("Engine " + ne + " is not loaded yet!");
}
- EngineI* e = std::get<EngineI*>(engines_[ne].engine);
- delete e;
+ if (std::holds_alternative<EngineI*>(engines_[ne].engine)) {
+ delete std::get<EngineI*>(engines_[ne].engine);
+ } else {
+ delete std::get<RemoteEngineI*>(engines_[ne].engine);
+ }

#if defined(_WIN32)
if (!RemoveDllDirectory(engines_[ne].cookie)) {
@@ -1100,7 +1088,22 @@ cpp::result<Json::Value, std::string> EngineService::GetRemoteModels(
return cpp::fail(r.error());
}

auto& e = std::get<EngineI*>(engines_[engine_name].engine);
if (!IsEngineLoaded(engine_name)) {
auto exist_engine = GetEngineByNameAndVariant(engine_name);
if (exist_engine.has_error()) {
return cpp::fail("Remote engine '" + engine_name + "' is not installed");
}

if (engine_name == kOpenAiEngine) {
engines_[engine_name].engine = new remote_engine::OpenAiEngine();
} else {
engines_[engine_name].engine = new remote_engine::AnthropicEngine();
}

CTL_INF("Loaded engine: " << engine_name);
}

auto& e = std::get<RemoteEngineI*>(engines_[engine_name].engine);
auto res = e->GetRemoteModels();
if (!res["error"].isNull()) {
return cpp::fail(res["error"].asString());
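A design note on the UnloadEngine hunk above: with raw pointers in a variant, a std::visit over a generic lambda expresses the same delete in one line. This is a hypothetical alternative, not what the commit does:

// Sketch only: deletes whichever alternative the variant currently holds.
std::visit([](auto* e) { delete e; }, engines_[ne].engine);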
3 changes: 2 additions & 1 deletion engine/services/engine_service.h
@@ -11,6 +11,7 @@
#include "common/engine_servicei.h"
#include "cortex-common/EngineI.h"
#include "cortex-common/cortexpythoni.h"
#include "cortex-common/remote_enginei.h"
#include "database/engines.h"
#include "extensions/remote-engine/remote_engine.h"
#include "services/download_service.h"
@@ -37,7 +38,7 @@ struct EngineUpdateResult {
}
};

using EngineV = std::variant<EngineI*, CortexPythonEngineI*>;
using EngineV = std::variant<EngineI*, CortexPythonEngineI*, RemoteEngineI*>;

class EngineService : public EngineServiceI {
private:
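Note that EngineV still carries CortexPythonEngineI* as a third alternative. Call sites that branch only between EngineI* and RemoteEngineI* (as in inference_service.cc below) implicitly assume the Python engine never reaches those paths; a sketch of what happens otherwise:

// Sketch: std::get on the wrong alternative throws.
EngineV v = static_cast<CortexPythonEngineI*>(nullptr);
auto* e = std::get<RemoteEngineI*>(v);  // throws std::bad_variant_access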
92 changes: 71 additions & 21 deletions engine/services/inference_service.cc
@@ -24,14 +24,26 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
return cpp::fail(std::make_pair(stt, res));
}

auto engine = std::get<EngineI*>(engine_result.value());
engine->HandleChatCompletion(
json_body, [q, tool_choice](Json::Value status, Json::Value res) {
if (!tool_choice.isNull()) {
res["tool_choice"] = tool_choice;
}
q->push(std::make_pair(status, res));
});
+ if (std::holds_alternative<EngineI*>(engine_result.value())) {
+ std::get<EngineI*>(engine_result.value())
+ ->HandleChatCompletion(
+ json_body, [q, tool_choice](Json::Value status, Json::Value res) {
+ if (!tool_choice.isNull()) {
+ res["tool_choice"] = tool_choice;
+ }
+ q->push(std::make_pair(status, res));
+ });
+ } else {
+ std::get<RemoteEngineI*>(engine_result.value())
+ ->HandleChatCompletion(
+ json_body, [q, tool_choice](Json::Value status, Json::Value res) {
+ if (!tool_choice.isNull()) {
+ res["tool_choice"] = tool_choice;
+ }
+ q->push(std::make_pair(status, res));
+ });
+ }

return {};
}

@@ -53,10 +65,18 @@ cpp::result<void, InferResult> InferenceService::HandleEmbedding(
LOG_WARN << "Engine is not loaded yet";
return cpp::fail(std::make_pair(stt, res));
}
auto engine = std::get<EngineI*>(engine_result.value());
engine->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) {
q->push(std::make_pair(status, res));
});

+ if (std::holds_alternative<EngineI*>(engine_result.value())) {
+ std::get<EngineI*>(engine_result.value())
+ ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) {
+ q->push(std::make_pair(status, res));
+ });
+ } else {
+ std::get<RemoteEngineI*>(engine_result.value())
+ ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) {
+ q->push(std::make_pair(status, res));
+ });
+ }
return {};
}

@@ -83,11 +103,20 @@ InferResult InferenceService::LoadModel(

// might need mutex here
auto engine_result = engine_service_->GetLoadedEngine(engine_type);
auto engine = std::get<EngineI*>(engine_result.value());
engine->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
stt = status;
r = res;
});

+ if (std::holds_alternative<EngineI*>(engine_result.value())) {
+ std::get<EngineI*>(engine_result.value())
+ ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
+ stt = status;
+ r = res;
+ });
+ } else {
+ std::get<RemoteEngineI*>(engine_result.value())
+ ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
+ stt = status;
+ r = res;
+ });
+ }
return std::make_pair(stt, r);
}

@@ -110,12 +139,22 @@ InferResult InferenceService::UnloadModel(const std::string& engine_name,
json_body["model"] = model_id;

LOG_TRACE << "Start unload model";
auto engine = std::get<EngineI*>(engine_result.value());
engine->UnloadModel(std::make_shared<Json::Value>(json_body),
+ if (std::holds_alternative<EngineI*>(engine_result.value())) {
+ std::get<EngineI*>(engine_result.value())
+ ->UnloadModel(std::make_shared<Json::Value>(json_body),
[&r, &stt](Json::Value status, Json::Value res) {
stt = status;
r = res;
});
+ } else {
+ std::get<RemoteEngineI*>(engine_result.value())
+ ->UnloadModel(std::make_shared<Json::Value>(json_body),
+ [&r, &stt](Json::Value status, Json::Value res) {
+ stt = status;
+ r = res;
+ });
+ }

return std::make_pair(stt, r);
}

@@ -141,12 +180,23 @@ InferResult InferenceService::GetModelStatus(
}

LOG_TRACE << "Start to get model status";
auto engine = std::get<EngineI*>(engine_result.value());
engine->GetModelStatus(json_body,

+ if (std::holds_alternative<EngineI*>(engine_result.value())) {
+ std::get<EngineI*>(engine_result.value())
+ ->GetModelStatus(json_body,
[&stt, &r](Json::Value status, Json::Value res) {
stt = status;
r = res;
});
+ } else {
+ std::get<RemoteEngineI*>(engine_result.value())
+ ->GetModelStatus(json_body,
+ [&stt, &r](Json::Value status, Json::Value res) {
+ stt = status;
+ r = res;
+ });
+ }

return std::make_pair(stt, r);
}

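The four call sites above repeat the same two-way branch. A hypothetical helper, not part of this commit, could factor the dispatch out, since both interfaces expose identical method signatures on these paths:

// Sketch: forward a call to whichever engine alternative is held.
// Assumes the variant holds EngineI* or RemoteEngineI* here, as the
// commit's own else-branches already do.
template <typename Fn>
void WithEngine(EngineV& engine, Fn&& fn) {
  if (std::holds_alternative<EngineI*>(engine)) {
    fn(std::get<EngineI*>(engine));
  } else {
    fn(std::get<RemoteEngineI*>(engine));
  }
}

// Usage, e.g. in LoadModel:
// WithEngine(engine_result.value(), [&](auto* e) {
//   e->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
//     stt = status;
//     r = res;
//   });
// });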