diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 206ee381d..9cdd5c7b4 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -512,6 +512,73 @@ } } }, + "/v1/models/add": { + "post": { + "operationId": "ModelsController_addModel", + "summary": "Add a remote model", + "description": "Add a new remote model configuration to the system.", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddModelRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string" + }, + "model": { + "type": "object", + "properties": { + "model": { + "type": "string" + }, + "engine": { + "type": "string" + }, + "version": { + "type": "string" + } + } + } + } + }, + "example": { + "message": "Model added successfully!", + "model": { + "model": "claude-3-5-sonnet-20241022", + "engine": "anthropic", + "version": "2023-06-01" + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SimpleErrorResponse" + } + } + } + } + }, + "tags": ["Pulling Models"] + } + }, "/v1/models": { "get": { "operationId": "ModelsController_findAll", @@ -1417,7 +1484,7 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], "default": "llama-cpp" }, "description": "The type of engine" @@ -1439,6 +1506,31 @@ "type": "string", "description": "The variant of the engine to install (optional)", "example": "mac-arm64" + }, + "type": { + "type": "string", + "description": "The type of connection: remote or local", + "example": "remote" + }, + "url": { + "type": "string", + "description": "The URL of the API endpoint for the remote engine", + "example": "https://api.openai.com" + }, + "api_key": { + "type": "string", + "description": "The API key used to authenticate with the remote engine", + "example": "" + }, + "metadata": { + "type": "object", + "properties": { + "get_models_url": { + "type": "string", + "description": "The URL used to list available models", + "example": "https://api.openai.com/v1/models" + } + } } } } @@ -1475,7 +1567,7 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], "default": "llama-cpp" }, "description": "The type of engine" @@ -1690,7 +1782,7 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], "default": "llama-cpp" }, "description": "The name of the engine to update" @@ -3636,6 +3728,109 @@ } } }, + "AddModelRequest": { + "type": "object", + "required": ["model", "engine", "version", "inference_params", "TransformReq", "TransformResp", "metadata"], + "properties": { + "model": { + "type": "string", + "description": "The identifier of the model." + }, + "api_key_template": { + "type": "string", + "description": "Template for the API key header." + }, + "engine": { + "type": "string", + "description": "The engine used for the model." + }, + "version": { + "type": "string", + "description": "The version of the model."
+ }, + "inference_params": { + "type": "object", + "properties": { + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + }, + "frequency_penalty": { + "type": "number" + }, + "presence_penalty": { + "type": "number" + }, + "max_tokens": { + "type": "integer" + }, + "stream": { + "type": "boolean" + } + } + }, + "TransformReq": { + "type": "object", + "properties": { + "get_models": { + "type": "object" + }, + "chat_completions": { + "type": "object", + "properties": { + "url": { + "type": "string" + }, + "template": { + "type": "string" + } + } + }, + "embeddings": { + "type": "object" + } + } + }, + "TransformResp": { + "type": "object", + "properties": { + "chat_completions": { + "type": "object", + "properties": { + "template": { + "type": "string" + } + } + }, + "embeddings": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "properties": { + "author": { + "type": "string" + }, + "description": { + "type": "string" + }, + "end_point": { + "type": "string" + }, + "logo": { + "type": "string" + }, + "api_key_url": { + "type": "string" + } + } + } + } + }, "CreateModelDto": { "type": "object", "properties": { @@ -4305,6 +4500,37 @@ "type": "integer", "description": "Number of GPU layers.", "example": 33 + }, + "api_key_template": { + "type": "string", + "description": "Template for the API key header." + }, + "version": { + "type": "string", + "description": "The version of the model." + }, + "inference_params": { + "type": "object", + "properties": { + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + }, + "frequency_penalty": { + "type": "number" + }, + "presence_penalty": { + "type": "number" + }, + "max_tokens": { + "type": "integer" + }, + "stream": { + "type": "boolean" + } + } } } }, diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index eae09d439..7cac3421c 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -142,6 +142,10 @@ file(APPEND "${CMAKE_CURRENT_BINARY_DIR}/cortex_openapi.h" add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/remote_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/openai_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/anthropic_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/template_renderer.cc ) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) @@ -171,17 +175,17 @@ endif() aux_source_directory(controllers CTL_SRC) aux_source_directory(repositories REPO_SRC) aux_source_directory(services SERVICES_SRC) -aux_source_directory(common COMMON_SRC) aux_source_directory(models MODEL_SRC) aux_source_directory(cortex-common CORTEX_COMMON) aux_source_directory(config CONFIG_SRC) aux_source_directory(database DB_SRC) +aux_source_directory(extensions EX_SRC) aux_source_directory(migrations MIGR_SRC) aux_source_directory(utils UTILS_SRC) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ) -target_sources(${TARGET_NAME} PRIVATE ${UTILS_SRC} ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC} ${MIGR_SRC} ${REPO_SRC}) +target_sources(${TARGET_NAME} PRIVATE ${UTILS_SRC} ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC} ${EX_SRC} ${MIGR_SRC} ${REPO_SRC}) set_target_properties(${TARGET_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR} diff --git a/engine/cli/CMakeLists.txt 
b/engine/cli/CMakeLists.txt index 42d00ebd5..51382dc13 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -82,6 +82,10 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/openai_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/anthropic_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/template_renderer.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc ${CMAKE_CURRENT_SOURCE_DIR}/../utils/config_yaml_utils.cc @@ -121,11 +125,12 @@ aux_source_directory(../cortex-common CORTEX_COMMON) aux_source_directory(../config CONFIG_SRC) aux_source_directory(commands COMMANDS_SRC) aux_source_directory(../database DB_SRC) +aux_source_directory(../extensions EX_SRC) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/.. ) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) -target_sources(${TARGET_NAME} PRIVATE ${COMMANDS_SRC} ${CONFIG_SRC} ${COMMON_SRC} ${DB_SRC}) +target_sources(${TARGET_NAME} PRIVATE ${COMMANDS_SRC} ${CONFIG_SRC} ${COMMON_SRC} ${DB_SRC} ${EX_SRC}) set_target_properties(${TARGET_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR} diff --git a/engine/common/engine_servicei.h b/engine/common/engine_servicei.h index bd4f099ab..85fa87d76 100644 --- a/engine/common/engine_servicei.h +++ b/engine/common/engine_servicei.h @@ -3,8 +3,8 @@ #include #include #include +#include "database/engines.h" #include "utils/result.hpp" - // TODO: namh think of the other name struct DefaultEngineVariant { std::string engine; @@ -54,4 +54,8 @@ class EngineServiceI { virtual cpp::result UnloadEngine( const std::string& engine_name) = 0; + virtual cpp::result<cortex::db::EngineEntry, std::string> + GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional<std::string> variant = std::nullopt) = 0; }; diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 7d4076ee5..701547873 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -1,13 +1,194 @@ #pragma once #include +#include +#include +#include +#include #include #include +#include #include #include #include "utils/format_utils.h" +#include "utils/remote_models_utils.h" +#include "yaml-cpp/yaml.h" namespace config { + +namespace { +const std::string kOpenAITransformReqTemplate = + R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; +const std::string kOpenAITransformRespTemplate = + R"({ {%- set first = true -%} {%- for key,
value in input_request -%} {%- if key == \"id\" or key == \"choices\" or key == \"created\" or key == \"model\" or key == \"service_tier\" or key == \"system_fingerprint\" or key == \"object\" or key == \"usage\" -%} {%- if not first -%},{%- endif -%} \"{{ key }}\": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; +const std::string kAnthropicTransformReqTemplate = + R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"system\" or key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; +const std::string kAnthropicTransformRespTemplate = R"({ + "id": "{{ input_request.id }}", + "created": null, + "object": "chat.completion", + "model": "{{ input_request.model }}", + "choices": [ + { + "index": 0, + "message": { + "role": "{{ input_request.role }}", + "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}", + "refusal": null + }, + "logprobs": null, + "finish_reason": "{{ input_request.stop_reason }}" + } + ], + "usage": { + "prompt_tokens": {{ input_request.usage.input_tokens }}, + "completion_tokens": {{ input_request.usage.output_tokens }}, + "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "system_fingerprint": "fp_6b68a8204b" + })"; +} // namespace + +struct RemoteModelConfig { + std::string model; + std::string api_key_template; + std::string engine; + std::string version; + std::size_t created; + std::string object = "model"; + std::string owned_by = ""; + Json::Value inference_params; + Json::Value TransformReq; + Json::Value TransformResp; + Json::Value metadata; + void LoadFromJson(const Json::Value& json) { + if (!json.isObject()) { + throw std::runtime_error("Input JSON must be an object"); + } + + // Load basic string fields + model = json.get("model", model).asString(); + api_key_template = + json.get("api_key_template", api_key_template).asString(); + engine = json.get("engine", engine).asString(); + version = json.get("version", version).asString(); + created = + json.get("created", static_cast<Json::UInt64>(created)).asUInt64(); + object = json.get("object", object).asString(); + owned_by = json.get("owned_by", owned_by).asString(); + + // Load JSON object fields directly + inference_params = json.get("inference_params", inference_params); + TransformReq = json.get("TransformReq", TransformReq); + // Use default template if it is empty, currently we only support 2 remote engines + auto is_anthropic = [](const std::string& model) { + return model.find("claude") != std::string::npos; + }; + if (TransformReq["chat_completions"]["template"].isNull()) { + if
(is_anthropic(model)) { + TransformReq["chat_completions"]["template"] = + kAnthropicTransformReqTemplate; + } else { + TransformReq["chat_completions"]["template"] = + kOpenAITransformReqTemplate; + } + } + TransformResp = json.get("TransformResp", TransformResp); + if (TransformResp["chat_completions"]["template"].isNull()) { + if (is_anthropic(model)) { + TransformResp["chat_completions"]["template"] = + kAnthropicTransformRespTemplate; + } else { + TransformResp["chat_completions"]["template"] = + kOpenAITransformRespTemplate; + } + } + metadata = json.get("metadata", metadata); + } + + Json::Value ToJson() const { + Json::Value json; + + // Add basic string fields + json["model"] = model; + json["api_key_template"] = api_key_template; + json["engine"] = engine; + json["version"] = version; + json["created"] = static_cast<Json::UInt64>(created); + json["object"] = object; + json["owned_by"] = owned_by; + + // Add JSON object fields directly + json["inference_params"] = inference_params; + json["TransformReq"] = TransformReq; + json["TransformResp"] = TransformResp; + json["metadata"] = metadata; + + return json; + }; + + void SaveToYamlFile(const std::string& filepath) const { + YAML::Node root; + + // Convert basic fields + root["model"] = model; + root["api_key_template"] = api_key_template; + root["engine"] = engine; + root["version"] = version; + root["object"] = object; + root["owned_by"] = owned_by; + root["created"] = std::time(nullptr); + + // Convert Json::Value to YAML::Node using utility function + root["inference_params"] = + remote_models_utils::jsonToYaml(inference_params); + root["TransformReq"] = remote_models_utils::jsonToYaml(TransformReq); + root["TransformResp"] = remote_models_utils::jsonToYaml(TransformResp); + root["metadata"] = remote_models_utils::jsonToYaml(metadata); + + // Save to file + std::ofstream fout(filepath); + if (!fout.is_open()) { + throw std::runtime_error("Failed to open file for writing: " + filepath); + } + fout << root; + } + + void LoadFromYamlFile(const std::string& filepath) { + YAML::Node root; + try { + root = YAML::LoadFile(filepath); + } catch (const YAML::Exception& e) { + throw std::runtime_error("Failed to parse YAML file: " + + std::string(e.what())); + } + + // Load basic fields + model = root["model"].as<std::string>(""); + api_key_template = root["api_key_template"].as<std::string>(""); + engine = root["engine"].as<std::string>(""); + version = root["version"] ? root["version"].as<std::string>() : ""; + created = root["created"] ? root["created"].as<std::size_t>() : 0; + object = root["object"] ? root["object"].as<std::string>() : "model"; + owned_by = root["owned_by"] ?
root["owned_by"].as() : ""; + + // Load complex fields using utility function + inference_params = + remote_models_utils::yamlToJson(root["inference_params"]); + TransformReq = remote_models_utils::yamlToJson(root["TransformReq"]); + TransformResp = remote_models_utils::yamlToJson(root["TransformResp"]); + metadata = remote_models_utils::yamlToJson(root["metadata"]); + } +}; + struct ModelConfig { std::string name; std::string model; diff --git a/engine/controllers/engines.cc b/engine/controllers/engines.cc index 9e110bd66..3d3c0c037 100644 --- a/engine/controllers/engines.cc +++ b/engine/controllers/engines.cc @@ -3,9 +3,9 @@ #include "utils/archive_utils.h" #include "utils/cortex_utils.h" #include "utils/engine_constants.h" +#include "utils/http_util.h" #include "utils/logging_utils.h" #include "utils/string_utils.h" - namespace { // Need to change this after we rename repositories std::string NormalizeEngine(const std::string& engine) { @@ -38,6 +38,18 @@ void Engines::ListEngine( } ret[engine] = variants; } + // Add remote engine + auto remote_engines = engine_service_->GetEngines(); + if (remote_engines.has_value()) { + for (auto engine : remote_engines.value()) { + if (engine.type == "remote") { + auto engine_json = engine.ToJson(); + Json::Value list_engine(Json::arrayValue); + list_engine.append(engine_json); + ret[engine.engine_name] = list_engine; + } + } + } auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k200OK); @@ -162,6 +174,86 @@ void Engines::InstallEngine( norm_version = version; } + if ((req->getJsonObject()) && + (*(req->getJsonObject())).get("type", "").asString() == "remote") { + auto type = (*(req->getJsonObject())).get("type", "").asString(); + auto api_key = (*(req->getJsonObject())).get("api_key", "").asString(); + auto url = (*(req->getJsonObject())).get("url", "").asString(); + auto variant = norm_variant.value_or("all-platforms"); + auto status = (*(req->getJsonObject())).get("status", "Default").asString(); + std::string metadata; + if ((*(req->getJsonObject())).isMember("metadata") && + (*(req->getJsonObject()))["metadata"].isObject()) { + metadata = (*(req->getJsonObject())) + .get("metadata", Json::Value(Json::objectValue)) + .toStyledString(); + } else if ((*(req->getJsonObject())).isMember("metadata") && + !(*(req->getJsonObject()))["metadata"].isObject()) { + Json::Value res; + res["message"] = "metadata must be object"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto get_models_url = (*(req->getJsonObject())) + .get("metadata", Json::Value(Json::objectValue)) + .get("get_models_url", "") + .asString(); + + if (engine.empty() || type.empty() || url.empty()) { + Json::Value res; + res["message"] = "Engine name, type, url are required"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + auto exist_engine = engine_service_->GetEngineByNameAndVariant(engine); + // only allow 1 variant 1 version of a remote engine name + if (exist_engine.has_value()) { + Json::Value res; + if (get_models_url.empty()) { + res["warning"] = + "'get_models_url' not found in metadata, You'll not able to search " + "remote models with this engine"; + } + res["message"] = "Engine '" + engine + "' already exists"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto result = 
engine_service_->UpsertEngine( + engine, type, api_key, url, norm_version, variant, status, metadata); + if (result.has_error()) { + Json::Value res; + if (get_models_url.empty()) { + res["warning"] = + "'get_models_url' not found in metadata. You will not be able to " + "search remote models with this engine"; + } + res["message"] = result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + Json::Value res; + if (get_models_url.empty()) { + res["warning"] = + "'get_models_url' not found in metadata. You will not be able to " + "search remote models with this engine"; + } + res["message"] = "Remote engine installed successfully!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k200OK); + callback(resp); + } + return; + } + auto result = engine_service_->InstallEngineAsync(engine, norm_version, norm_variant); if (result.has_error()) { @@ -169,12 +261,14 @@ res["message"] = result.error(); auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); resp->setStatusCode(k400BadRequest); + CTL_INF("Error: " << result.error()); callback(resp); } else { Json::Value res; res["message"] = "Engine starts installing!"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); resp->setStatusCode(k200OK); + CTL_INF("Engine starts installing!"); callback(resp); } } diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 2760663d0..de14886da 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -7,6 +7,7 @@ #include "models.h" #include "trantor/utils/Logger.h" #include "utils/cortex_utils.h" +#include "utils/engine_constants.h" #include "utils/file_manager_utils.h" #include "utils/http_util.h" #include "utils/logging_utils.h" @@ -176,15 +177,29 @@ void Models::ListModel( fs::path(model_entry.path_to_model_yaml)) .string()); auto model_config = yaml_handler.GetModelConfig(); - Json::Value obj = model_config.ToJson(); - obj["id"] = model_entry.model; - obj["model"] = model_entry.model; - auto es = model_service_->GetEstimation(model_entry.model); - if (es.has_value()) { - obj["recommendation"] = hardware::ToJson(es.value()); + + if (!remote_engine::IsRemoteEngine(model_config.engine)) { + Json::Value obj = model_config.ToJson(); + obj["id"] = model_entry.model; + obj["model"] = model_entry.model; + auto es = model_service_->GetEstimation(model_entry.model); + if (es.has_value()) { + obj["recommendation"] = hardware::ToJson(es.value()); + } + data.append(std::move(obj)); + yaml_handler.Reset(); + } else { + config::RemoteModelConfig remote_model_config; + remote_model_config.LoadFromYamlFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.path_to_model_yaml)) + .string()); + Json::Value obj = remote_model_config.ToJson(); + obj["id"] = model_entry.model; + obj["model"] = model_entry.model; + data.append(std::move(obj)); } - data.append(std::move(obj)); - yaml_handler.Reset(); } catch (const std::exception& e) { LOG_ERROR << "Failed to load yaml file for model: " << model_entry.path_to_model_yaml << ", error: " << e.what(); @@ -232,16 +247,34 @@ void Models::GetModel(const HttpRequestPtr& req, callback(resp); return; } + yaml_handler.ModelConfigFromFile( fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.value().path_to_model_yaml)) .string()); auto model_config = yaml_handler.GetModelConfig(); + if (model_config.engine == kOnnxEngine || + model_config.engine ==
kLlamaEngine || + model_config.engine == kTrtLlmEngine) { + auto ret = model_config.ToJsonString(); + auto resp = cortex_utils::CreateCortexHttpTextAsJsonResponse(ret); + resp->setStatusCode(drogon::k200OK); + callback(resp); + } else { + config::RemoteModelConfig remote_model_config; + remote_model_config.LoadFromYamlFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + auto ret = remote_model_config.ToJson(); + ret["id"] = remote_model_config.model; + ret["object"] = "model"; + ret["result"] = "OK"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + } - auto ret = model_config.ToJsonString(); - auto resp = cortex_utils::CreateCortexHttpTextAsJsonResponse(ret); - resp->setStatusCode(drogon::k200OK); - callback(resp); } catch (const std::exception& e) { std::string message = "Fail to get model information with ID '" + model_id + "': " + e.what(); @@ -289,11 +322,23 @@ void Models::UpdateModel(const HttpRequestPtr& req, fs::path(model_entry.value().path_to_model_yaml)); yaml_handler.ModelConfigFromFile(yaml_fp.string()); config::ModelConfig model_config = yaml_handler.GetModelConfig(); - model_config.FromJson(json_body); - yaml_handler.UpdateModelConfig(model_config); - yaml_handler.WriteYamlFile(yaml_fp.string()); - std::string message = "Successfully update model ID '" + model_id + - "': " + json_body.toStyledString(); + std::string message; + if (model_config.engine == kOnnxEngine || + model_config.engine == kLlamaEngine || + model_config.engine == kTrtLlmEngine) { + model_config.FromJson(json_body); + yaml_handler.UpdateModelConfig(model_config); + yaml_handler.WriteYamlFile(yaml_fp.string()); + message = "Successfully updated model ID '" + model_id + + "': " + json_body.toStyledString(); + } else { + config::RemoteModelConfig remote_model_config; + remote_model_config.LoadFromYamlFile(yaml_fp.string()); + remote_model_config.LoadFromJson(json_body); + remote_model_config.SaveToYamlFile(yaml_fp.string()); + message = "Successfully updated model ID '" + model_id + + "': " + json_body.toStyledString(); + } LOG_INFO << message; Json::Value ret; ret["result"] = "Updated successfully!"; @@ -344,8 +389,10 @@ void Models::ImportModel( // Use relative path for model_yaml_path.
In case of import, we use absolute path for model auto yaml_rel_path = fmu::ToRelativeCortexDataPath(fs::path(model_yaml_path)); - cortex::db::ModelEntry model_entry{modelHandle, "local", "imported", - yaml_rel_path.string(), modelHandle}; + cortex::db::ModelEntry model_entry{ + modelHandle, "", "", yaml_rel_path.string(), + modelHandle, "local", "imported", cortex::db::ModelStatus::Downloaded, + ""}; std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); @@ -558,3 +605,122 @@ void Models::GetModelStatus( callback(resp); } } + +void Models::GetRemoteModels( + const HttpRequestPtr& req, + std::function<void(const HttpResponsePtr&)>&& callback, + const std::string& engine_id) { + if (!remote_engine::IsRemoteEngine(engine_id)) { + Json::Value ret; + ret["message"] = "Not a remote engine: " + engine_id; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + + auto result = engine_service_->GetRemoteModels(engine_id); + + if (result.has_error()) { + Json::Value ret; + ret["message"] = result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + } else { + auto resp = cortex_utils::CreateCortexHttpJsonResponse(result.value()); + resp->setStatusCode(k200OK); + callback(resp); + } +} + +void Models::AddRemoteModel( + const HttpRequestPtr& req, + std::function<void(const HttpResponsePtr&)>&& callback) const { + namespace fs = std::filesystem; + namespace fmu = file_manager_utils; + if (!http_util::HasFieldInReq(req, callback, "model") || + !http_util::HasFieldInReq(req, callback, "engine")) { + return; + } + + auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); + auto engine_name = (*(req->getJsonObject())).get("engine", "").asString(); + /* TODO: uncomment when the remote engine is ready + + auto engine_validate = engine_service_->IsEngineReady(engine_name); + if (engine_validate.has_error()) { + Json::Value ret; + ret["message"] = engine_validate.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + if (!engine_validate.value()) { + Json::Value ret; + ret["message"] = "Engine is not ready! Please install first!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + */ + config::RemoteModelConfig model_config; + model_config.LoadFromJson(*(req->getJsonObject())); + cortex::db::Models modellist_utils_obj; + std::string model_yaml_path = (file_manager_utils::GetModelsContainerPath() / + std::filesystem::path("remote") / + std::filesystem::path(model_handle + ".yml")) + .string(); + try { + // Use relative path for model_yaml_path.
In case of import, we use absolute path for model + auto yaml_rel_path = + fmu::ToRelativeCortexDataPath(fs::path(model_yaml_path)); + // TODO: remove hardcoded "openai" when the engine is finished + cortex::db::ModelEntry model_entry{ + model_handle, "", "", yaml_rel_path.string(), + model_handle, "remote", "imported", cortex::db::ModelStatus::Remote, + "openai"}; + std::filesystem::create_directories( + std::filesystem::path(model_yaml_path).parent_path()); + if (modellist_utils_obj.AddModelEntry(model_entry).value()) { + model_config.SaveToYamlFile(model_yaml_path); + std::string success_message = "Model imported successfully!"; + LOG_INFO << success_message; + Json::Value ret; + ret["result"] = "OK"; + ret["modelHandle"] = model_handle; + ret["message"] = success_message; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + + } else { + std::string error_message = "Failed to import model: model_id '" + + model_handle + "' already exists!"; + LOG_ERROR << error_message; + Json::Value ret; + ret["result"] = "Import failed!"; + ret["modelHandle"] = model_handle; + ret["message"] = error_message; + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } + } catch (const std::exception& e) { + std::string error_message = + "Error while adding remote model with model_id '" + model_handle + + "': " + e.what(); + LOG_ERROR << error_message; + Json::Value ret; + ret["result"] = "Add failed!"; + ret["modelHandle"] = model_handle; + ret["message"] = error_message; + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } +} \ No newline at end of file diff --git a/engine/controllers/models.h b/engine/controllers/models.h index da6caf024..b2b288adc 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -21,6 +21,8 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::StartModel, "/start", Options, Post); METHOD_ADD(Models::StopModel, "/stop", Options, Post); METHOD_ADD(Models::GetModelStatus, "/status/{1}", Get); + METHOD_ADD(Models::AddRemoteModel, "/add", Options, Post); + METHOD_ADD(Models::GetRemoteModels, "/remote/{1}", Get); ADD_METHOD_TO(Models::PullModel, "/v1/models/pull", Options, Post); ADD_METHOD_TO(Models::AbortPullModel, "/v1/models/pull", Options, Delete); @@ -32,6 +34,8 @@ ADD_METHOD_TO(Models::StartModel, "/v1/models/start", Options, Post); ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Options, Post); ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get); + ADD_METHOD_TO(Models::AddRemoteModel, "/v1/models/add", Options, Post); + ADD_METHOD_TO(Models::GetRemoteModels, "/v1/models/remote/{1}", Get); METHOD_LIST_END explicit Models(std::shared_ptr<ModelService> model_service, @@ -56,6 +60,9 @@ void ImportModel( const HttpRequestPtr& req, std::function<void(const HttpResponsePtr&)>&& callback) const; + void AddRemoteModel( + const HttpRequestPtr& req, + std::function<void(const HttpResponsePtr&)>&& callback) const; void DeleteModel(const HttpRequestPtr& req, std::function<void(const HttpResponsePtr&)>&& callback, const std::string& model_id); @@ -73,6 +80,10 @@ std::function<void(const HttpResponsePtr&)>&& callback, const std::string& model_id); + + void GetRemoteModels(const HttpRequestPtr& req, + std::function<void(const HttpResponsePtr&)>&& callback, + const std::string& engine_id); + private: std::shared_ptr<ModelService> model_service_; std::shared_ptr<EngineService> engine_service_;
diff --git a/engine/cortex-common/EngineI.h b/engine/cortex-common/EngineI.h index 95ce605de..51e19c124 100644 --- a/engine/cortex-common/EngineI.h +++ b/engine/cortex-common/EngineI.h @@ -37,4 +37,6 @@ class EngineI { virtual bool SetFileLogger(int max_log_lines, const std::string& log_path) = 0; virtual void SetLogLevel(trantor::Logger::LogLevel logLevel) = 0; + + virtual Json::Value GetRemoteModels() = 0; }; diff --git a/engine/cortex-common/remote_enginei.h b/engine/cortex-common/remote_enginei.h new file mode 100644 index 000000000..81ffbf5cd --- /dev/null +++ b/engine/cortex-common/remote_enginei.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +#include "json/value.h" +#include "trantor/utils/Logger.h" +class RemoteEngineI { + public: + virtual ~RemoteEngineI() {} + + virtual void HandleChatCompletion( + std::shared_ptr<Json::Value> json_body, + std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0; + virtual void HandleEmbedding( + std::shared_ptr<Json::Value> json_body, + std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0; + virtual void LoadModel( + std::shared_ptr<Json::Value> json_body, + std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0; + virtual void UnloadModel( + std::shared_ptr<Json::Value> json_body, + std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0; + virtual void GetModelStatus( + std::shared_ptr<Json::Value> json_body, + std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0; + + // Get list of running models + virtual void GetModels( + std::shared_ptr<Json::Value> jsonBody, + std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0; + + // Get available remote models + virtual Json::Value GetRemoteModels() = 0; +}; diff --git a/engine/database/engines.cc b/engine/database/engines.cc new file mode 100644 index 000000000..a4d13ef79 --- /dev/null +++ b/engine/database/engines.cc @@ -0,0 +1,173 @@ +#include "engines.h" +#include +#include "database.h" + +namespace cortex::db { + +void CreateTable(SQLite::Database& db) {} + +Engines::Engines() : db_(cortex::db::Database::GetInstance().db()) { + CreateTable(db_); +} + +Engines::Engines(SQLite::Database& db) : db_(db) { + CreateTable(db_); +} + +Engines::~Engines() {} + +std::optional<EngineEntry> Engines::UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata) { + try { + SQLite::Statement query( + db_, + "INSERT INTO engines (engine_name, type, api_key, url, version, " + "variant, status, metadata) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?)
" + "ON CONFLICT(engine_name, variant) DO UPDATE SET " + "type = excluded.type, " + "api_key = excluded.api_key, " + "url = excluded.url, " + "version = excluded.version, " + "status = excluded.status, " + "metadata = excluded.metadata, " + "date_updated = CURRENT_TIMESTAMP " + "RETURNING id, engine_name, type, api_key, url, version, variant, " + "status, metadata, date_created, date_updated;"); + + query.bind(1, engine_name); + query.bind(2, type); + query.bind(3, api_key); + query.bind(4, url); + query.bind(5, version); + query.bind(6, variant); + query.bind(7, status); + query.bind(8, metadata); + + if (query.executeStep()) { + return EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}; + } else { + return std::nullopt; + } + } catch (const std::exception& e) { + return std::nullopt; + } +} + +std::optional> Engines::GetEngines() const { + try { + SQLite::Statement query( + db_, + "SELECT id, engine_name, type, api_key, url, version, variant, status, " + "metadata, date_created, date_updated " + "FROM engines " + "WHERE status = 'Default' " + "ORDER BY date_updated DESC"); + + std::vector engines; + while (query.executeStep()) { + engines.push_back(EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}); + } + + return engines; + } catch (const std::exception& e) { + return std::nullopt; + } +} + +std::optional Engines::GetEngineById(int id) const { + try { + SQLite::Statement query( + db_, + "SELECT id, engine_name, type, api_key, url, version, variant, status, " + "metadata, date_created, date_updated " + "FROM engines " + "WHERE id = ? AND status = 'Default' " + "ORDER BY date_updated DESC LIMIT 1"); + + query.bind(1, id); + + if (query.executeStep()) { + return EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}; + } else { + return std::nullopt; + } + } catch (const std::exception& e) { + return std::nullopt; + } +} + +std::optional Engines::GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant) const { + try { + std::string queryStr = + "SELECT id, engine_name, type, api_key, url, version, variant, status, " + "metadata, date_created, date_updated " + "FROM engines " + "WHERE engine_name = ? AND status = 'Default' "; + + if (variant) { + queryStr += "AND variant = ? 
"; + } + + queryStr += "ORDER BY date_updated DESC LIMIT 1"; + + SQLite::Statement query(db_, queryStr); + + query.bind(1, engine_name); + + if (variant) { + query.bind(2, variant.value()); + } + + if (query.executeStep()) { + return EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}; + } else { + return std::nullopt; + } + } catch (const std::exception& e) { + return std::nullopt; + } +} + +std::optional Engines::DeleteEngineById(int id) { + try { + SQLite::Statement query(db_, "DELETE FROM engines WHERE id = ?"); + + query.bind(1, id); + query.exec(); + return std::nullopt; + } catch (const std::exception& e) { + return std::string("Failed to delete engine: ") + e.what(); + } +} + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/database/engines.h b/engine/database/engines.h new file mode 100644 index 000000000..7429d0fa2 --- /dev/null +++ b/engine/database/engines.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace cortex::db { + +struct EngineEntry { + int id; + std::string engine_name; + std::string type; + std::string api_key; + std::string url; + std::string version; + std::string variant; + std::string status; + std::string metadata; + std::string date_created; + std::string date_updated; + Json::Value ToJson() const { + Json::Value root; + Json::Reader reader; + + // Convert basic fields + root["id"] = id; + root["engine_name"] = engine_name; + root["type"] = type; + root["api_key"] = api_key; + root["url"] = url; + root["version"] = version; + root["variant"] = variant; + root["status"] = status; + root["date_created"] = date_created; + root["date_updated"] = date_updated; + + // Parse metadata string into JSON object + Json::Value metadataJson; + if (!metadata.empty()) { + bool success = reader.parse(metadata, metadataJson, + false); // false = don't collect comments + if (success) { + root["metadata"] = metadataJson; + } else { + root["metadata"] = Json::Value::null; + } + } else { + root["metadata"] = Json::Value(Json::objectValue); // empty object + } + + return root; + } +}; + +class Engines { + private: + SQLite::Database& db_; + + bool IsUnique(const std::vector& entries, + const std::string& model_id, + const std::string& model_alias) const; + + std::optional> LoadModelListNoLock() const; + + public: + Engines(); + Engines(SQLite::Database& db); + ~Engines(); + + std::optional UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata); + + std::optional> GetEngines() const; + std::optional GetEngineById(int id) const; + std::optional GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant = std::nullopt) const; + + std::optional DeleteEngineById(int id); +}; + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/database/models.cc b/engine/database/models.cc index 3e81fbab2..fb2128396 100644 --- a/engine/database/models.cc +++ b/engine/database/models.cc @@ -9,9 +9,32 @@ namespace cortex::db { Models::Models() : 
db_(cortex::db::Database::GetInstance().db()) {} +Models::~Models() {} + +std::string Models::StatusToString(ModelStatus status) const { + switch (status) { + case ModelStatus::Remote: + return "remote"; + case ModelStatus::Downloaded: + return "downloaded"; + case ModelStatus::Undownloaded: + return "undownloaded"; + } + return "unknown"; +} + Models::Models(SQLite::Database& db) : db_(db) {} -Models::~Models() {} +ModelStatus Models::StringToStatus(const std::string& status_str) const { + if (status_str == "remote") { + return ModelStatus::Remote; + } else if (status_str == "downloaded" || status_str.empty()) { + return ModelStatus::Downloaded; + } else if (status_str == "undownloaded") { + return ModelStatus::Undownloaded; + } + throw std::invalid_argument("Invalid status string"); +} cpp::result<std::vector<ModelEntry>, std::string> Models::LoadModelList() const { @@ -41,7 +64,8 @@ cpp::result<std::vector<ModelEntry>, std::string> Models::LoadModelListNoLock() std::vector<ModelEntry> entries; SQLite::Statement query(db_, "SELECT model_id, author_repo_id, branch_name, " - "path_to_model_yaml, model_alias FROM models"); + "path_to_model_yaml, model_alias, model_format, " + "model_source, status, engine FROM models"); while (query.executeStep()) { ModelEntry entry; @@ -50,6 +74,10 @@ entry.branch_name = query.getColumn(2).getString(); entry.path_to_model_yaml = query.getColumn(3).getString(); entry.model_alias = query.getColumn(4).getString(); + entry.model_format = query.getColumn(5).getString(); + entry.model_source = query.getColumn(6).getString(); + entry.status = StringToStatus(query.getColumn(7).getString()); + entry.engine = query.getColumn(8).getString(); entries.push_back(entry); } return entries; @@ -124,7 +152,8 @@ cpp::result<ModelEntry, std::string> Models::GetModelInfo( try { SQLite::Statement query(db_, "SELECT model_id, author_repo_id, branch_name, " - "path_to_model_yaml, model_alias FROM models " + "path_to_model_yaml, model_alias, model_format, " + "model_source, status, engine FROM models " "WHERE model_id = ?
OR model_alias = ?"); query.bind(1, identifier); @@ -136,6 +165,10 @@ cpp::result Models::GetModelInfo( entry.branch_name = query.getColumn(2).getString(); entry.path_to_model_yaml = query.getColumn(3).getString(); entry.model_alias = query.getColumn(4).getString(); + entry.model_format = query.getColumn(5).getString(); + entry.model_source = query.getColumn(6).getString(); + entry.status = StringToStatus(query.getColumn(7).getString()); + entry.engine = query.getColumn(8).getString(); return entry; } else { return cpp::fail("Model not found: " + identifier); @@ -151,6 +184,10 @@ void Models::PrintModelInfo(const ModelEntry& entry) const { LOG_INFO << "Branch Name: " << entry.branch_name; LOG_INFO << "Path to model.yaml: " << entry.path_to_model_yaml; LOG_INFO << "Model Alias: " << entry.model_alias; + LOG_INFO << "Model Format: " << entry.model_format; + LOG_INFO << "Model Source: " << entry.model_source; + LOG_INFO << "Status: " << StatusToString(entry.status); + LOG_INFO << "Engine: " << entry.engine; } cpp::result Models::AddModelEntry(ModelEntry new_entry, @@ -171,14 +208,18 @@ cpp::result Models::AddModelEntry(ModelEntry new_entry, SQLite::Statement insert( db_, - "INSERT INTO models (model_id, author_repo_id, " - "branch_name, path_to_model_yaml, model_alias) VALUES (?, ?, " - "?, ?, ?)"); + "INSERT INTO models (model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias, model_format, model_source, " + "status, engine) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"); insert.bind(1, new_entry.model); insert.bind(2, new_entry.author_repo_id); insert.bind(3, new_entry.branch_name); insert.bind(4, new_entry.path_to_model_yaml); insert.bind(5, new_entry.model_alias); + insert.bind(6, new_entry.model_format); + insert.bind(7, new_entry.model_source); + insert.bind(8, StatusToString(new_entry.status)); + insert.bind(9, new_entry.engine); insert.exec(); return true; @@ -196,16 +237,20 @@ cpp::result Models::UpdateModelEntry( return cpp::fail("Model not found: " + identifier); } try { - SQLite::Statement upd(db_, - "UPDATE models " - "SET author_repo_id = ?, branch_name = ?, " - "path_to_model_yaml = ? " - "WHERE model_id = ? OR model_alias = ?"); + SQLite::Statement upd( + db_, + "UPDATE models SET author_repo_id = ?, branch_name = ?, " + "path_to_model_yaml = ?, model_format = ?, model_source = ?, status = " + "?, engine = ? WHERE model_id = ? 
OR model_alias = ?"); upd.bind(1, updated_entry.author_repo_id); upd.bind(2, updated_entry.branch_name); upd.bind(3, updated_entry.path_to_model_yaml); - upd.bind(4, identifier); - upd.bind(5, identifier); + upd.bind(4, updated_entry.model_format); + upd.bind(5, updated_entry.model_source); + upd.bind(6, StatusToString(updated_entry.status)); + upd.bind(7, updated_entry.engine); + upd.bind(8, identifier); + upd.bind(9, identifier); return upd.exec() == 1; } catch (const std::exception& e) { return cpp::fail(e.what()); @@ -293,4 +338,5 @@ bool Models::HasModel(const std::string& identifier) const { return false; } } -} // namespace cortex::db + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/database/models.h b/engine/database/models.h index 197996ab8..dd6e2a5a1 100644 --- a/engine/database/models.h +++ b/engine/database/models.h @@ -7,12 +7,23 @@ #include "utils/result.hpp" namespace cortex::db { + +enum class ModelStatus { + Remote, + Downloaded, + Undownloaded +}; + struct ModelEntry { - std::string model; + std::string model; std::string author_repo_id; std::string branch_name; std::string path_to_model_yaml; std::string model_alias; + std::string model_format; + std::string model_source; + ModelStatus status; + std::string engine; }; class Models { @@ -26,6 +37,9 @@ class Models { cpp::result, std::string> LoadModelListNoLock() const; + std::string StatusToString(ModelStatus status) const; + ModelStatus StringToStatus(const std::string& status_str) const; + public: cpp::result, std::string> LoadModelList() const; Models(); @@ -49,4 +63,5 @@ class Models { const std::string& identifier) const; bool HasModel(const std::string& identifier) const; }; -} // namespace cortex::db + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/extensions/remote-engine/anthropic_engine.cc b/engine/extensions/remote-engine/anthropic_engine.cc new file mode 100644 index 000000000..847cba566 --- /dev/null +++ b/engine/extensions/remote-engine/anthropic_engine.cc @@ -0,0 +1,62 @@ +#include "anthropic_engine.h" +#include +#include +#include "utils/logging_utils.h" + +namespace remote_engine { +namespace { +constexpr const std::array kAnthropicModels = { + "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", + "claude-3-opus-20240229", "claude-3-sonnet-20240229", + "claude-3-haiku-20240307"}; +} +void AnthropicEngine::GetModels( + std::shared_ptr json_body, + std::function&& callback) { + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + { + std::shared_lock l(models_mtx_); + for (const auto& [m, _] : models_) { + Json::Value val; + val["id"] = m; + val["engine"] = "anthropic"; + val["start_time"] = "_"; + val["model_size"] = "_"; + val["vram"] = "_"; + val["ram"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = 200; + callback(std::move(status), std::move(json_resp)); + CTL_INF("Running models responded"); +} + +Json::Value AnthropicEngine::GetRemoteModels() { + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + for (const auto& m : kAnthropicModels) { + Json::Value val; + val["id"] = std::string(m); + val["engine"] = "anthropic"; + val["created"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + 
CTL_INF("Remote models responded"); + return json_resp; +} +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/anthropic_engine.h b/engine/extensions/remote-engine/anthropic_engine.h new file mode 100644 index 000000000..bcd3dfaf7 --- /dev/null +++ b/engine/extensions/remote-engine/anthropic_engine.h @@ -0,0 +1,13 @@ +#pragma once +#include "remote_engine.h" + +namespace remote_engine { + class AnthropicEngine: public RemoteEngine { +public: + void GetModels( + std::shared_ptr json_body, + std::function&& callback) override; + + Json::Value GetRemoteModels() override; + }; +} \ No newline at end of file diff --git a/engine/extensions/remote-engine/openai_engine.cc b/engine/extensions/remote-engine/openai_engine.cc new file mode 100644 index 000000000..7c7d70385 --- /dev/null +++ b/engine/extensions/remote-engine/openai_engine.cc @@ -0,0 +1,54 @@ +#include "openai_engine.h" +#include "utils/logging_utils.h" + +namespace remote_engine { + +void OpenAiEngine::GetModels( + std::shared_ptr json_body, + std::function&& callback) { + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + { + std::shared_lock l(models_mtx_); + for (const auto& [m, _] : models_) { + Json::Value val; + val["id"] = m; + val["engine"] = "openai"; + val["start_time"] = "_"; + val["model_size"] = "_"; + val["vram"] = "_"; + val["ram"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = 200; + callback(std::move(status), std::move(json_resp)); + CTL_INF("Running models responded"); +} + +Json::Value OpenAiEngine::GetRemoteModels() { + auto response = MakeGetModelsRequest(); + if (response.error) { + Json::Value error; + error["error"] = response.error_message; + return error; + } + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value error; + error["error"] = "Failed to parse response"; + return error; + } + return response_json; +} +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/openai_engine.h b/engine/extensions/remote-engine/openai_engine.h new file mode 100644 index 000000000..61dc68f0c --- /dev/null +++ b/engine/extensions/remote-engine/openai_engine.h @@ -0,0 +1,14 @@ +#pragma once + +#include "remote_engine.h" + +namespace remote_engine { +class OpenAiEngine : public RemoteEngine { + public: + void GetModels( + std::shared_ptr json_body, + std::function&& callback) override; + + Json::Value GetRemoteModels() override; +}; +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc new file mode 100644 index 000000000..04effb457 --- /dev/null +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -0,0 +1,712 @@ +#include "remote_engine.h" +#include +#include +#include +#include +#include "utils/json_helper.h" +#include "utils/logging_utils.h" +namespace remote_engine { +namespace { +constexpr const int k200OK = 200; +constexpr const int k400BadRequest = 400; +constexpr const int k409Conflict = 409; +constexpr const int k500InternalServerError = 500; +constexpr const int kFileLoggerOption = 0; +bool is_anthropic(const std::string& model) { + return model.find("claude") != std::string::npos; +} + 
+struct AnthropicChunk { + std::string type; + std::string id; + int index; + std::string msg; + std::string model; + std::string stop_reason; + bool should_ignore = false; + + AnthropicChunk(const std::string& str) { + if (str.size() > 6) { + std::string s = str.substr(6); + try { + auto root = json_helper::ParseJsonString(s); + type = root["type"].asString(); + if (type == "message_start") { + id = root["message"]["id"].asString(); + model = root["message"]["model"].asString(); + } else if (type == "content_block_delta") { + index = root["index"].asInt(); + if (root["delta"]["type"].asString() == "text_delta") { + msg = root["delta"]["text"].asString(); + } + } else if (type == "message_delta") { + stop_reason = root["delta"]["stop_reason"].asString(); + } else { + // ignore other messages + should_ignore = true; + } + } catch (const std::exception& e) { + should_ignore = true; + CTL_WRN("JSON parse error: " << e.what()); + } + } else { + should_ignore = true; + } + } + + std::string ToOpenAiFormatString() { + Json::Value root; + root["id"] = id; + root["object"] = "chat.completion.chunk"; + root["created"] = Json::Value(); + root["model"] = model; + root["system_fingerprint"] = "fp_e76890f0c3"; + Json::Value choices(Json::arrayValue); + Json::Value choice; + Json::Value content; + choice["index"] = 0; + content["content"] = msg; + if (type == "message_start") { + content["role"] = "assistant"; + content["refusal"] = Json::Value(); + } + choice["delta"] = content; + choice["finish_reason"] = stop_reason.empty() ? Json::Value() : stop_reason; + choices.append(choice); + root["choices"] = choices; + return "data: " + json_helper::DumpJsonString(root); + } +}; + +} // namespace + +size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, + void* userdata) { + auto* context = static_cast<StreamContext*>(userdata); + std::string chunk(ptr, size * nmemb); + + context->buffer += chunk; + + // Process complete lines + size_t pos; + while ((pos = context->buffer.find('\n')) != std::string::npos) { + std::string line = context->buffer.substr(0, pos); + context->buffer = context->buffer.substr(pos + 1); + CTL_TRC(line); + + // Skip empty lines + if (line.empty() || line == "\r" || + line.find("event:") != std::string::npos) + continue; + + // Remove "data: " prefix if present + // if (line.substr(0, 6) == "data: ") + // { + // line = line.substr(6); + // } + + // Skip [DONE] message + // std::cout << line << std::endl; + if (line == "data: [DONE]" || + line.find("message_stop") != std::string::npos) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + (*context->callback)(std::move(status), Json::Value()); + break; + } + + // Parse the JSON + Json::Value chunk_json; + if (is_anthropic(context->model)) { + AnthropicChunk ac(line); + if (ac.should_ignore) + continue; + ac.model = context->model; + if (ac.type == "message_start") { + context->id = ac.id; + } else { + ac.id = context->id; + } + chunk_json["data"] = ac.ToOpenAiFormatString() + "\n\n"; + } else { + chunk_json["data"] = line + "\n\n"; + } + Json::Reader reader; + + Json::Value status; + status["is_done"] = false; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + (*context->callback)(std::move(status), std::move(chunk_json)); + } + + return size * nmemb; +} + +CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( + const ModelConfig& config, const std::string& body, + const std::function<void(Json::Value&&, Json::Value&&)>& callback) { +
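+  // Streaming flow: POST the transformed request body with SSE headers and let StreamWriteCallback (above) accumulate partial chunks in context.buffer, split them on newlines, and forward each complete event to the caller's callback.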
CURL* curl = curl_easy_init(); + CurlResponse response; + + if (!curl) { + response.error = true; + response.error_message = "Failed to initialize CURL"; + return response; + } + + std::string full_url = + config.transform_req["chat_completions"]["url"].as<std::string>(); + + struct curl_slist* headers = nullptr; + if (!config.api_key.empty()) { + headers = curl_slist_append(headers, api_key_template_.c_str()); + } + + if (is_anthropic(config.model)) { + std::string v = "anthropic-version: " + config.version; + headers = curl_slist_append(headers, v.c_str()); + } + + headers = curl_slist_append(headers, "Content-Type: application/json"); + headers = curl_slist_append(headers, "Accept: text/event-stream"); + headers = curl_slist_append(headers, "Cache-Control: no-cache"); + headers = curl_slist_append(headers, "Connection: keep-alive"); + + StreamContext context{ + std::make_shared<std::function<void(Json::Value&&, Json::Value&&)>>( + callback), + "", "", config.model}; + + curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, StreamWriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &context); + curl_easy_setopt(curl, CURLOPT_TRANSFER_ENCODING, 1L); + + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + response.error = true; + response.error_message = curl_easy_strerror(res); + + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = true; + status["status_code"] = 500; + + Json::Value error; + error["error"] = response.error_message; + callback(std::move(status), std::move(error)); + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + return response; +} + +std::string ReplaceApiKeyPlaceholder(const std::string& templateStr, + const std::string& apiKey) { + const std::string placeholder = "{{api_key}}"; + std::string result = templateStr; + size_t pos = result.find(placeholder); + + if (pos != std::string::npos) { + result.replace(pos, placeholder.length(), apiKey); + } + + return result; +} + +static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, + std::string* data) { + data->append(ptr, size * nmemb); + return size * nmemb; +} + +RemoteEngine::RemoteEngine() { + curl_global_init(CURL_GLOBAL_ALL); +} + +RemoteEngine::~RemoteEngine() { + curl_global_cleanup(); +} + +RemoteEngine::ModelConfig* RemoteEngine::GetModelConfig( + const std::string& model) { + std::shared_lock lock(models_mtx_); + auto it = models_.find(model); + if (it != models_.end()) { + return &it->second; + } + return nullptr; +} + +CurlResponse RemoteEngine::MakeGetModelsRequest() { + CURL* curl = curl_easy_init(); + CurlResponse response; + + if (!curl) { + response.error = true; + response.error_message = "Failed to initialize CURL"; + return response; + } + + std::string full_url = metadata_["get_models_url"].asString(); + + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, api_key_template_.c_str()); + headers = curl_slist_append(headers, "Content-Type: application/json"); + + curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + std::string response_string; + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_string); + + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + response.error = true; +
+static size_t WriteCallback(char* ptr, size_t size, size_t nmemb,
+                            std::string* data) {
+  data->append(ptr, size * nmemb);
+  return size * nmemb;
+}
+
+RemoteEngine::RemoteEngine() {
+  curl_global_init(CURL_GLOBAL_ALL);
+}
+
+RemoteEngine::~RemoteEngine() {
+  curl_global_cleanup();
+}
+
+RemoteEngine::ModelConfig* RemoteEngine::GetModelConfig(
+    const std::string& model) {
+  std::shared_lock lock(models_mtx_);
+  auto it = models_.find(model);
+  if (it != models_.end()) {
+    return &it->second;
+  }
+  return nullptr;
+}
+
+CurlResponse RemoteEngine::MakeGetModelsRequest() {
+  CURL* curl = curl_easy_init();
+  CurlResponse response;
+
+  if (!curl) {
+    response.error = true;
+    response.error_message = "Failed to initialize CURL";
+    return response;
+  }
+
+  std::string full_url = metadata_["get_models_url"].asString();
+
+  struct curl_slist* headers = nullptr;
+  headers = curl_slist_append(headers, api_key_template_.c_str());
+  headers = curl_slist_append(headers, "Content-Type: application/json");
+
+  curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str());
+  curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
+
+  std::string response_string;
+  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
+  curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_string);
+
+  CURLcode res = curl_easy_perform(curl);
+  if (res != CURLE_OK) {
+    response.error = true;
+    response.error_message = curl_easy_strerror(res);
+  } else {
+    response.body = response_string;
+  }
+
+  curl_slist_free_all(headers);
+  curl_easy_cleanup(curl);
+  return response;
+}
+
+CurlResponse RemoteEngine::MakeChatCompletionRequest(
+    const ModelConfig& config, const std::string& body,
+    const std::string& method) {
+  CURL* curl = curl_easy_init();
+  CurlResponse response;
+
+  if (!curl) {
+    response.error = true;
+    response.error_message = "Failed to initialize CURL";
+    return response;
+  }
+  std::string full_url =
+      config.transform_req["chat_completions"]["url"].as<std::string>();
+
+  struct curl_slist* headers = nullptr;
+  if (!config.api_key.empty()) {
+    headers = curl_slist_append(headers, api_key_template_.c_str());
+  }
+
+  if (is_anthropic(config.model)) {
+    std::string v = "anthropic-version: " + config.version;
+    headers = curl_slist_append(headers, v.c_str());
+  }
+  headers = curl_slist_append(headers, "Content-Type: application/json");
+
+  curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str());
+  curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
+
+  if (method == "POST") {
+    curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str());
+  }
+
+  std::string response_string;
+  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
+  curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_string);
+
+  CURLcode res = curl_easy_perform(curl);
+  if (res != CURLE_OK) {
+    response.error = true;
+    response.error_message = curl_easy_strerror(res);
+  } else {
+    response.body = response_string;
+  }
+
+  curl_slist_free_all(headers);
+  curl_easy_cleanup(curl);
+  return response;
+}
+
+bool RemoteEngine::LoadModelConfig(const std::string& model,
+                                   const std::string& yaml_path,
+                                   const std::string& api_key) {
+  try {
+    YAML::Node config = YAML::LoadFile(yaml_path);
+
+    ModelConfig model_config;
+    model_config.model = model;
+    if (is_anthropic(model)) {
+      if (!config["version"]) {
+        CTL_ERR("Missing version for model: " << model);
+        return false;
+      }
+      model_config.version = config["version"].as<std::string>();
+    }
+
+    // Required fields
+    if (!config["api_key_template"]) {
+      LOG_ERROR << "Missing required fields in config for model " << model;
+      return false;
+    }
+
+    model_config.api_key = api_key;
+    api_key_template_ = ReplaceApiKeyPlaceholder(
+        config["api_key_template"].as<std::string>(), api_key);
+
+    // Optional fields
+    if (config["TransformReq"]) {
+      model_config.transform_req = config["TransformReq"];
+    } else {
+      LOG_WARN << "Missing TransformReq in config for model " << model;
+    }
+    if (config["TransformResp"]) {
+      model_config.transform_resp = config["TransformResp"];
+    } else {
+      LOG_WARN << "Missing TransformResp in config for model " << model;
+    }
+
+    model_config.is_loaded = true;
+
+    // Thread-safe update of models map
+    {
+      std::unique_lock lock(models_mtx_);
+      models_[model] = std::move(model_config);
+    }
+    CTL_DBG("LoadModelConfig successful: " << model << ", " << yaml_path);
+
+    return true;
+  } catch (const YAML::Exception& e) {
+    LOG_ERROR << "Failed to load config for model " << model << ": "
+              << e.what();
+    return false;
+  }
+}
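`LoadModelConfig` expects the model YAML to carry at least `api_key_template`, the `TransformReq`/`TransformResp` template nodes, and, for Anthropic, a `version`. A sketch of reading such a file with yaml-cpp; the field values below are invented for illustration:

```cpp
#include <iostream>
#include <string>
#include <yaml-cpp/yaml.h>

int main() {
  // Hypothetical minimal config in the shape LoadModelConfig reads.
  const char* yaml = R"(
api_key_template: "x-api-key: {{api_key}}"
version: "2023-06-01"
TransformReq:
  chat_completions:
    url: "https://api.anthropic.com/v1/messages"
    template: "{{ tojson(input_request) }}"
TransformResp:
  chat_completions:
    template: "{{ tojson(input_request) }}"
)";
  YAML::Node config = YAML::Load(yaml);
  if (!config["api_key_template"]) {  // the required field
    std::cerr << "missing api_key_template\n";
    return 1;
  }
  std::cout << config["TransformReq"]["chat_completions"]["url"]
                   .as<std::string>()
            << '\n';
}
```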
+void RemoteEngine::GetModels(
+    std::shared_ptr<Json::Value> json_body,
+    std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
+  CTL_WRN("Not implemented yet!");
+}
+
+void RemoteEngine::LoadModel(
+    std::shared_ptr<Json::Value> json_body,
+    std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
+  if (!json_body->isMember("model") || !json_body->isMember("model_path") ||
+      !json_body->isMember("api_key")) {
+    Json::Value error;
+    error["error"] = "Missing required fields: model, model_path or api_key";
+    Json::Value status;
+    status["is_done"] = true;
+    status["has_error"] = true;
+    status["is_stream"] = false;
+    status["status_code"] = k400BadRequest;
+    callback(std::move(status), std::move(error));
+    return;
+  }
+
+  const std::string& model = (*json_body)["model"].asString();
+  const std::string& model_path = (*json_body)["model_path"].asString();
+  const std::string& api_key = (*json_body)["api_key"].asString();
+
+  if (!LoadModelConfig(model, model_path, api_key)) {
+    Json::Value error;
+    error["error"] = "Failed to load model configuration";
+    Json::Value status;
+    status["is_done"] = true;
+    status["has_error"] = true;
+    status["is_stream"] = false;
+    status["status_code"] = k500InternalServerError;
+    callback(std::move(status), std::move(error));
+    return;
+  }
+  if (json_body->isMember("metadata")) {
+    metadata_ = (*json_body)["metadata"];
+  }
+
+  Json::Value response;
+  response["status"] = "Model loaded successfully";
+  Json::Value status;
+  status["is_done"] = true;
+  status["has_error"] = false;
+  status["is_stream"] = false;
+  status["status_code"] = k200OK;
+  callback(std::move(status), std::move(response));
+  CTL_INF("Model loaded successfully: " << model);
+}
+
+void RemoteEngine::UnloadModel(
+    std::shared_ptr<Json::Value> json_body,
+    std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
+  if (!json_body->isMember("model")) {
+    Json::Value error;
+    error["error"] = "Missing required field: model";
+    Json::Value status;
+    status["is_done"] = true;
+    status["has_error"] = true;
+    status["is_stream"] = false;
+    status["status_code"] = k400BadRequest;
+    callback(std::move(status), std::move(error));
+    return;
+  }
+
+  const std::string& model = (*json_body)["model"].asString();
+
+  {
+    std::unique_lock lock(models_mtx_);
+    models_.erase(model);
+  }
+
+  Json::Value response;
+  response["status"] = "Model unloaded successfully";
+  Json::Value status;
+  status["is_done"] = true;
+  status["has_error"] = false;
+  status["is_stream"] = false;
+  status["status_code"] = k200OK;
+  callback(std::move(status), std::move(response));
+}
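Per the checks in `LoadModel` above, a load request for a remote model must carry `model`, `model_path`, and `api_key`, with `metadata` optional. A sketch of that payload built with jsoncpp; all values are placeholders:

```cpp
#include <iostream>
#include <json/json.h>

int main() {
  Json::Value req;
  req["model"] = "claude-3-5-sonnet-20241022";  // model identifier
  req["model_path"] = "/path/to/model.yml";     // YAML config on disk
  req["api_key"] = "sk-ant-...";                // engine credential
  req["metadata"]["get_models_url"] = "https://api.anthropic.com/v1/models";
  std::cout << req.toStyledString();
}
```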
+void RemoteEngine::HandleChatCompletion(
+    std::shared_ptr<Json::Value> json_body,
+    std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
+  if (!json_body->isMember("model")) {
+    Json::Value status;
+    status["is_done"] = true;
+    status["has_error"] = true;
+    status["is_stream"] = false;
+    status["status_code"] = k400BadRequest;
+    Json::Value error;
+    error["error"] = "Missing required fields: model";
+    callback(std::move(status), std::move(error));
+    return;
+  }
+
+  const std::string& model = (*json_body)["model"].asString();
+  auto* model_config = GetModelConfig(model);
+
+  if (!model_config || !model_config->is_loaded) {
+    Json::Value status;
+    status["is_done"] = true;
+    status["has_error"] = true;
+    status["is_stream"] = false;
+    status["status_code"] = k400BadRequest;
+    Json::Value error;
+    error["error"] = "Model not found or not loaded: " + model;
+    callback(std::move(status), std::move(error));
+    return;
+  }
+  bool is_stream =
+      json_body->isMember("stream") && (*json_body)["stream"].asBool();
+
+  // Transform request
+  std::string result;
+  try {
+    // Check if required YAML nodes exist
+    if (!model_config->transform_req["chat_completions"]) {
+      throw std::runtime_error(
+          "Missing 'chat_completions' node in transform_req");
+    }
+    if (!model_config->transform_req["chat_completions"]["template"]) {
+      throw std::runtime_error("Missing 'template' node in chat_completions");
+    }
+
+    // Validate JSON body
+    if (!json_body || json_body->isNull()) {
+      throw std::runtime_error("Invalid or null JSON body");
+    }
+
+    // Get template string with error check
+    std::string template_str;
+    try {
+      template_str =
+          model_config->transform_req["chat_completions"]["template"]
+              .as<std::string>();
+    } catch (const YAML::BadConversion& e) {
+      throw std::runtime_error("Failed to convert template node to string: " +
+                               std::string(e.what()));
+    }
+
+    // Hoist system messages into the top-level "system" field for Anthropic
+    if (is_anthropic(model)) {
+      bool has_system = false;
+      Json::Value msgs(Json::arrayValue);
+      for (auto& kv : (*json_body)["messages"]) {
+        if (kv["role"].asString() == "system") {
+          (*json_body)["system"] = kv["content"].asString();
+          has_system = true;
+        } else {
+          msgs.append(kv);
+        }
+      }
+      if (has_system) {
+        (*json_body)["messages"] = msgs;
+      }
+    }
+
+    // Render with error handling
+    try {
+      result = renderer_.Render(template_str, *json_body);
+    } catch (const std::exception& e) {
+      throw std::runtime_error("Template rendering error: " +
+                               std::string(e.what()));
+    }
+  } catch (const std::exception& e) {
+    // Log the error and fall back to the original request body
+    LOG_WARN << "Error in TransformRequest: " << e.what();
+    LOG_WARN << "Using original request body";
+    result = (*json_body).toStyledString();
+  }
+
+  if (is_stream) {
+    MakeStreamingChatCompletionRequest(*model_config, result, callback);
+  } else {
+    auto response = MakeChatCompletionRequest(*model_config, result);
+
+    if (response.error) {
+      Json::Value status;
+      status["is_done"] = true;
+      status["has_error"] = true;
+      status["is_stream"] = false;
+      status["status_code"] = k400BadRequest;
+      Json::Value error;
+      error["error"] = response.error_message;
+      callback(std::move(status), std::move(error));
+      return;
+    }
+
+    Json::Value response_json;
+    Json::Reader reader;
+    if (!reader.parse(response.body, response_json)) {
+      Json::Value status;
+      status["is_done"] = true;
+      status["has_error"] = true;
+      status["is_stream"] = false;
+      status["status_code"] = k500InternalServerError;
+      Json::Value error;
+      error["error"] = "Failed to parse response";
+      callback(std::move(status), std::move(error));
+      return;
+    }
+
+    // Transform response
+    std::string response_str;
+    try {
+      // Check if required YAML nodes exist
+      if (!model_config->transform_resp["chat_completions"]) {
+        throw std::runtime_error(
+            "Missing 'chat_completions' node in transform_resp");
+      }
+      if (!model_config->transform_resp["chat_completions"]["template"]) {
+        throw std::runtime_error("Missing 'template' node in chat_completions");
+      }
+
+      // Validate JSON body
+      if (!response_json || response_json.isNull()) {
+        throw std::runtime_error("Invalid or null JSON body");
+      }
+
+      // Get template string with error check
+      std::string template_str;
+      try {
+        template_str =
+            model_config->transform_resp["chat_completions"]["template"]
+                .as<std::string>();
+      } catch (const YAML::BadConversion& e) {
+        throw std::runtime_error("Failed to convert template node to string: " +
+                                 std::string(e.what()));
+      }
+
+      // Render with error handling
+      try {
+        response_str = renderer_.Render(template_str, response_json);
+      } catch (const std::exception& e) {
+        throw std::runtime_error("Template rendering error: " +
+                                 std::string(e.what()));
+      }
+    } catch (const std::exception& e) {
+      // Log the error and fall back to the original response body
+      LOG_WARN << "Error in TransformResponse: " << e.what();
+      LOG_WARN << "Using original response body";
+      response_str = response_json.toStyledString();
+    }
+
+    Json::Reader reader_final;
+    Json::Value response_json_final;
+    if (!reader_final.parse(response_str, response_json_final)) {
+      Json::Value status;
+      status["is_done"] = true;
+      status["has_error"] = true;
+      status["is_stream"] = false;
+      status["status_code"] = k500InternalServerError;
+      Json::Value error;
+      error["error"] = "Failed to parse response";
+      callback(std::move(status), std::move(error));
+      return;
+    }
+
+    Json::Value status;
+    status["is_done"] = true;
+    status["has_error"] = false;
+    status["is_stream"] = false;
+    status["status_code"] = k200OK;
+
+    callback(std::move(status), std::move(response_json_final));
+  }
+}
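Anthropic's Messages API takes the system prompt as a top-level `system` field rather than as a `system`-role message, which is what the loop inside `HandleChatCompletion` rewrites. A standalone sketch of the same transformation:

```cpp
#include <iostream>
#include <json/json.h>

// Move any "system" role message into a top-level "system" field,
// keeping the remaining messages in order (mirrors the logic above).
void HoistSystemPrompt(Json::Value& body) {
  Json::Value msgs(Json::arrayValue);
  for (const auto& m : body["messages"]) {
    if (m["role"].asString() == "system")
      body["system"] = m["content"].asString();
    else
      msgs.append(m);
  }
  body["messages"] = msgs;
}

int main() {
  Json::Value body;
  Json::Value sys, user;
  sys["role"] = "system";
  sys["content"] = "You are terse.";
  user["role"] = "user";
  user["content"] = "Hi";
  body["messages"].append(sys);
  body["messages"].append(user);
  HoistSystemPrompt(body);
  std::cout << body.toStyledString();  // "system" is now a top-level field
}
```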
+void RemoteEngine::GetModelStatus(
+    std::shared_ptr<Json::Value> json_body,
+    std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
+  if (!json_body->isMember("model")) {
+    Json::Value error;
+    error["error"] = "Missing required field: model";
+    callback(Json::Value(), std::move(error));
+    return;
+  }
+
+  const std::string& model = (*json_body)["model"].asString();
+  auto* model_config = GetModelConfig(model);
+
+  if (!model_config) {
+    Json::Value error;
+    error["error"] = "Model not found: " + model;
+    callback(Json::Value(), std::move(error));
+    return;
+  }
+
+  Json::Value response;
+  response["model"] = model;
+  response["model_loaded"] = model_config->is_loaded;
+  response["model_data"] = model_config->url;
+
+  Json::Value status;
+  status["is_done"] = true;
+  status["has_error"] = false;
+  status["is_stream"] = false;
+  status["status_code"] = k200OK;
+  callback(std::move(status), std::move(response));
+}
+
+// Implement remaining virtual functions
+void RemoteEngine::HandleEmbedding(
+    std::shared_ptr<Json::Value>,
+    std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
+  callback(Json::Value(), Json::Value());
+}
+
+Json::Value RemoteEngine::GetRemoteModels() {
+  CTL_WRN("Not implemented yet!");
+  return {};
+}
+
+}  // namespace remote_engine
\ No newline at end of file
diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h
new file mode 100644
index 000000000..8ce6fa652
--- /dev/null
+++ b/engine/extensions/remote-engine/remote_engine.h
@@ -0,0 +1,102 @@
+#pragma once
+
+#include <curl/curl.h>
+#include <yaml-cpp/yaml.h>
+#include <functional>
+#include <memory>
+#include <shared_mutex>
+#include <string_view>
+#include <unordered_map>
+#include "cortex-common/remote_enginei.h"
+#include "extensions/remote-engine/template_renderer.h"
+#include "utils/engine_constants.h"
+#include "utils/file_logger.h"
+// Helper for CURL response
+
+namespace remote_engine {
+inline bool IsRemoteEngine(std::string_view e) {
+  return e == kAnthropicEngine || e == kOpenAiEngine;
+}
+
+struct StreamContext {
+  std::shared_ptr<std::function<void(Json::Value&&, Json::Value&&)>> callback;
+  std::string buffer;
+  // Cached values for Anthropic
+  std::string id;
+  std::string model;
+};
+struct CurlResponse {
+  std::string body;
+  bool error{false};
+  std::string error_message;
+};
+
+class RemoteEngine : public RemoteEngineI {
+ protected:
+  // Model configuration
+  struct ModelConfig {
+    std::string model;
+    std::string version;
+    std::string api_key;
+    std::string url;
+    YAML::Node transform_req;
+    YAML::Node transform_resp;
+    bool is_loaded{false};
+  };
+
+  // Thread-safe model config storage
+  mutable std::shared_mutex models_mtx_;
+  std::unordered_map<std::string, ModelConfig> models_;
+  TemplateRenderer renderer_;
+  Json::Value metadata_;
+  std::string api_key_template_;
+  std::unique_ptr<trantor::FileLogger> async_file_logger_;
+
+  // Helper functions
+  CurlResponse MakeChatCompletionRequest(const ModelConfig& config,
+                                         const std::string& body,
+                                         const std::string& method = "POST");
+  CurlResponse MakeStreamingChatCompletionRequest(
+      const ModelConfig& config, const std::string& body,
+      const std::function<void(Json::Value&&, Json::Value&&)>& callback);
+  CurlResponse MakeGetModelsRequest();
+
+  // Internal model management
+  bool LoadModelConfig(const std::string& model, const std::string& yaml_path,
+                       const std::string& api_key);
+  ModelConfig* GetModelConfig(const std::string& model);
+
+ public:
+  RemoteEngine();
+  virtual ~RemoteEngine();
+
+  // Main interface implementations
+  void GetModels(
+      std::shared_ptr<Json::Value> json_body,
+      std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
+
+  void HandleChatCompletion(
+      std::shared_ptr<Json::Value> json_body,
+      std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
+
+  void LoadModel(
+      std::shared_ptr<Json::Value> json_body,
+      std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
+
+  void UnloadModel(
+      std::shared_ptr<Json::Value> json_body,
+      std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
+
+  void GetModelStatus(
+      std::shared_ptr<Json::Value> json_body,
+      std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
+
+  // Other required virtual functions
+  void HandleEmbedding(
+      std::shared_ptr<Json::Value> json_body,
+      std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
+
+  Json::Value GetRemoteModels() override;
+};
+
+}  // namespace remote_engine
\ No newline at end of file
diff --git a/engine/extensions/remote-engine/template_renderer.cc b/engine/extensions/remote-engine/template_renderer.cc
new file mode 100644
index 000000000..15514d17c
--- /dev/null
+++ b/engine/extensions/remote-engine/template_renderer.cc
@@ -0,0 +1,136 @@
+#if defined(_WIN32) || defined(_WIN64)
+#define NOMINMAX
+#undef min
+#undef max
+#endif
+#include "template_renderer.h"
+#include <regex>
+#include <stdexcept>
+#include "utils/logging_utils.h"
+namespace remote_engine {
+TemplateRenderer::TemplateRenderer() {
+  // Configure Inja environment
+  env_.set_trim_blocks(true);
+  env_.set_lstrip_blocks(true);
+
+  // Add tojson function for all value types
+  env_.add_callback("tojson", 1, [](inja::Arguments& args) {
+    if (args.empty()) {
+      return nlohmann::json(nullptr);
+    }
+    const auto& value = *args[0];
+
+    if (value.is_string()) {
+      return nlohmann::json(std::string("\"") + value.get<std::string>() +
+                            "\"");
+    }
+    return value;
+  });
+}
+
+std::string TemplateRenderer::Render(const std::string& tmpl,
+                                     const Json::Value& data) {
+  try {
+    // Convert Json::Value to nlohmann::json
+    auto json_data = ConvertJsonValue(data);
+
+    // Create the input data structure expected by the template
+    nlohmann::json template_data;
+    template_data["input_request"] = json_data;
+
+    // Debug output
+    LOG_DEBUG << "Template: " << tmpl;
+    LOG_DEBUG << "Data: " << template_data.dump(2);
+
+    // Render template
+    std::string result = env_.render(tmpl, template_data);
+
+    // Clean up any potential double quotes in JSON strings
+    result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");
+
+    LOG_DEBUG << "Result: " << result;
+
+    // Validate JSON
+    auto parsed = nlohmann::json::parse(result);
+
+    return result;
+  } catch (const std::exception& e) {
+    LOG_ERROR << "Template rendering failed: " << e.what();
+    LOG_ERROR << "Template: " << tmpl;
+    throw std::runtime_error(std::string("Template rendering failed: ") +
+                             e.what());
+  }
+}
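`Render` exposes a `tojson` callback and nests the request under `input_request`, so templates can whitelist exactly the fields a provider accepts (the constants in `remote_models_utils.h` below use this). A minimal inja example in that style; the template and data are illustrative:

```cpp
#include <iostream>
#include <inja/inja.hpp>
#include <nlohmann/json.hpp>

int main() {
  inja::Environment env;
  env.set_trim_blocks(true);
  env.set_lstrip_blocks(true);
  // Serialize a value to JSON, quoting strings (same idea as the callback above).
  env.add_callback("tojson", 1, [](inja::Arguments& args) {
    const auto& v = *args[0];
    if (v.is_string())
      return nlohmann::json("\"" + v.get<std::string>() + "\"");
    return v;
  });

  nlohmann::json data;
  data["input_request"] = {{"model", "gpt-4o"}, {"max_tokens", 128}};

  // Copy only the fields the template names; everything else is dropped.
  std::string tmpl =
      R"({ "model": {{ tojson(input_request.model) }}, "max_tokens": {{ input_request.max_tokens }} })";
  std::cout << env.render(tmpl, data) << '\n';
}
```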
+nlohmann::json TemplateRenderer::ConvertJsonValue(const Json::Value& input) {
+  if (input.isNull()) {
+    return nullptr;
+  } else if (input.isBool()) {
+    return input.asBool();
+  } else if (input.isInt()) {
+    return input.asInt();
+  } else if (input.isUInt()) {
+    return input.asUInt();
+  } else if (input.isDouble()) {
+    return input.asDouble();
+  } else if (input.isString()) {
+    return input.asString();
+  } else if (input.isArray()) {
+    nlohmann::json arr = nlohmann::json::array();
+    for (const auto& element : input) {
+      arr.push_back(ConvertJsonValue(element));
+    }
+    return arr;
+  } else if (input.isObject()) {
+    nlohmann::json obj = nlohmann::json::object();
+    for (const auto& key : input.getMemberNames()) {
+      obj[key] = ConvertJsonValue(input[key]);
+    }
+    return obj;
+  }
+  return nullptr;
+}
+
+Json::Value TemplateRenderer::ConvertNlohmannJson(const nlohmann::json& input) {
+  if (input.is_null()) {
+    return Json::Value();
+  } else if (input.is_boolean()) {
+    return Json::Value(input.get<bool>());
+  } else if (input.is_number_integer()) {
+    return Json::Value(static_cast<Json::Int64>(input.get<int64_t>()));
+  } else if (input.is_number_unsigned()) {
+    return Json::Value(static_cast<Json::UInt64>(input.get<uint64_t>()));
+  } else if (input.is_number_float()) {
+    return Json::Value(input.get<double>());
+  } else if (input.is_string()) {
+    return Json::Value(input.get<std::string>());
+  } else if (input.is_array()) {
+    Json::Value arr(Json::arrayValue);
+    for (const auto& element : input) {
+      arr.append(ConvertNlohmannJson(element));
+    }
+    return arr;
+  } else if (input.is_object()) {
+    Json::Value obj(Json::objectValue);
+    for (auto it = input.begin(); it != input.end(); ++it) {
+      obj[it.key()] = ConvertNlohmannJson(it.value());
+    }
+    return obj;
+  }
+  return Json::Value();
+}
+
+std::string TemplateRenderer::RenderFile(const std::string& template_path,
+                                         const Json::Value& data) {
+  try {
+    // Convert Json::Value to nlohmann::json
+    auto json_data = ConvertJsonValue(data);
+
+    // Load and render template
+    return env_.render_file(template_path, json_data);
+  } catch (const std::exception& e) {
+    throw std::runtime_error(std::string("Template file rendering failed: ") +
+                             e.what());
+  }
+}
+}  // namespace remote_engine
\ No newline at end of file
diff --git a/engine/extensions/remote-engine/template_renderer.h b/engine/extensions/remote-engine/template_renderer.h
new file mode 100644
index 000000000..f59e7cc93
--- /dev/null
+++ b/engine/extensions/remote-engine/template_renderer.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include <string>
+
+#include "json/json.h"
+#include "trantor/utils/Logger.h"
+// clang-format off
+#if defined(_WIN32) || defined(_WIN64)
+#define NOMINMAX
+#undef min
+#undef max
+#endif
+#include <inja/inja.hpp>
+#include <nlohmann/json.hpp>
+// clang-format on
+namespace remote_engine {
+class TemplateRenderer {
+ public:
+  TemplateRenderer();
+  ~TemplateRenderer() = default;
+
+  // Convert Json::Value to nlohmann::json
+  static nlohmann::json ConvertJsonValue(const Json::Value& input);
+
+  // Convert nlohmann::json to Json::Value
+  static Json::Value ConvertNlohmannJson(const nlohmann::json& input);
+
+  // Render template with data
+  std::string Render(const std::string& tmpl, const Json::Value& data);
+
+  // Load template from file and render
+  std::string RenderFile(const std::string& template_path,
+                         const Json::Value& data);
+
+ private:
+  inja::Environment env_;
+};
+
+}  // namespace remote_engine
\ No newline at end of file
diff --git a/engine/migrations/db_helper.h b/engine/migrations/db_helper.h
new file mode 100644
index 000000000..0990426bf
--- /dev/null
+++ b/engine/migrations/db_helper.h
@@ -0,0 +1,26 @@
+#pragma once
+#include <string>
+
+#include <SQLiteCpp/SQLiteCpp.h>
+
+namespace cortex::mgr {
+
+inline bool ColumnExists(SQLite::Database& db, const std::string& table_name,
+                         const std::string& column_name) {
+  try {
+    SQLite::Statement query(db, "SELECT " + column_name + " FROM " +
+                                    table_name + " LIMIT 0");
+    return true;
+  } catch (std::exception&) {
+    return false;
+  }
+}
+
+inline void AddColumnIfNotExists(SQLite::Database& db,
+                                 const std::string& table_name,
+                                 const std::string& column_name,
+                                 const std::string& column_type) {
+  if (!ColumnExists(db, table_name, column_name)) {
+    std::string sql = "ALTER TABLE " + table_name + " ADD COLUMN " +
+                      column_name + " " + column_type;
+    db.exec(sql);
+  }
+}
+}  // namespace cortex::mgr
\ No newline at end of file
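`ColumnExists` probes with `SELECT <column> FROM <table> LIMIT 0`, so `AddColumnIfNotExists` is idempotent and safe to run on every startup migration. A standalone sketch against an in-memory database (assuming the `engine/` include path so `migrations/db_helper.h` resolves):

```cpp
#include <iostream>
#include <SQLiteCpp/SQLiteCpp.h>
#include "migrations/db_helper.h"

int main() {
  SQLite::Database db(":memory:",
                      SQLite::OPEN_READWRITE | SQLite::OPEN_CREATE);
  db.exec("CREATE TABLE models (model_id TEXT PRIMARY KEY)");

  // Safe to call repeatedly; the second call is a no-op.
  cortex::mgr::AddColumnIfNotExists(db, "models", "engine", "TEXT");
  cortex::mgr::AddColumnIfNotExists(db, "models", "engine", "TEXT");

  std::cout << std::boolalpha
            << cortex::mgr::ColumnExists(db, "models", "engine") << '\n';  // true
}
```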
diff --git a/engine/migrations/migration_helper.cc b/engine/migrations/migration_helper.cc
index 42cc8d453..b02435cd2 100644
--- a/engine/migrations/migration_helper.cc
+++ b/engine/migrations/migration_helper.cc
@@ -7,7 +7,6 @@ cpp::result<bool, std::string> MigrationHelper::BackupDatabase(
   try {
     SQLite::Database src_db(src_db_path, SQLite::OPEN_READONLY);
     sqlite3* backup_db;
-
     if (sqlite3_open(backup_db_path.c_str(), &backup_db) != SQLITE_OK) {
       throw std::runtime_error("Failed to open backup database");
     }
diff --git a/engine/migrations/migration_manager.cc b/engine/migrations/migration_manager.cc
index 2c2b6ddfd..0e2e41e4e 100644
--- a/engine/migrations/migration_manager.cc
+++ b/engine/migrations/migration_manager.cc
@@ -5,6 +5,8 @@
 #include "utils/file_manager_utils.h"
 #include "utils/scope_exit.h"
 #include "utils/widechar_conv.h"
+#include "v0/migration.h"
+#include "v1/migration.h"
 
 namespace cortex::migr {
 
@@ -140,6 +142,9 @@ cpp::result<bool, std::string> MigrationManager::DoUpFolderStructure(
     case 0:
       return v0::MigrateFolderStructureUp();
       break;
+    case 1:
+      return v1::MigrateFolderStructureUp();
+      break;
 
     default:
       return true;
@@ -151,6 +156,9 @@ cpp::result<bool, std::string> MigrationManager::DoDownFolderStructure(
     case 0:
       return v0::MigrateFolderStructureDown();
       break;
+    case 1:
+      return v1::MigrateFolderStructureDown();
+      break;
 
     default:
       return true;
@@ -184,6 +192,9 @@ cpp::result<bool, std::string> MigrationManager::DoUpDB(int version) {
     case 0:
       return v0::MigrateDBUp(db_);
       break;
+    case 1:
+      return v1::MigrateDBUp(db_);
+      break;
 
     default:
       return true;
@@ -195,6 +206,9 @@ cpp::result<bool, std::string> MigrationManager::DoDownDB(int version) {
     case 0:
      return v0::MigrateDBDown(db_);
       break;
+    case 1:
+      return v1::MigrateDBDown(db_);
+      break;
 
     default:
       return true;
diff --git a/engine/migrations/schema_version.h b/engine/migrations/schema_version.h
index 7cfccf27a..1e64110e3 100644
--- a/engine/migrations/schema_version.h
+++ b/engine/migrations/schema_version.h
@@ -1,4 +1,4 @@
 #pragma once
 
 //Track the current schema version
-#define SCHEMA_VERSION 0
\ No newline at end of file
+#define SCHEMA_VERSION 1
\ No newline at end of file
diff --git a/engine/migrations/v1/migration.h b/engine/migrations/v1/migration.h
new file mode 100644
index 000000000..f9a8038e3
--- /dev/null
+++ b/engine/migrations/v1/migration.h
@@ -0,0 +1,165 @@
+#pragma once
+#include <filesystem>
+#include <string>
+#include <SQLiteCpp/SQLiteCpp.h>
+#include "migrations/db_helper.h"
+#include "utils/file_manager_utils.h"
+#include "utils/logging_utils.h"
+#include "utils/result.hpp"
+
+namespace cortex::migr::v1 {
+// Data folder
+namespace fmu = file_manager_utils;
+
+// cortexcpp
+//   |__ models
+//   |   |__ cortex.so
+//   |       |__ tinyllama
+//   |           |__ gguf
+//   |__ engines
+//   |   |__ cortex.llamacpp
+//   |       |__ deps
+//   |       |__ windows-amd64-avx
+//   |__ logs
+//
+inline cpp::result<bool, std::string> MigrateFolderStructureUp() {
+  if (!std::filesystem::exists(fmu::GetCortexDataPath() / "models")) {
+    std::filesystem::create_directory(fmu::GetCortexDataPath() / "models");
+  }
+
+  if (!std::filesystem::exists(fmu::GetCortexDataPath() / "engines")) {
+    std::filesystem::create_directory(fmu::GetCortexDataPath() / "engines");
+  }
+
+  if (!std::filesystem::exists(fmu::GetCortexDataPath() / "logs")) {
+    std::filesystem::create_directory(fmu::GetCortexDataPath() / "logs");
+  }
+
+  return true;
+}
+
+inline cpp::result<bool, std::string> MigrateFolderStructureDown() {
+  // CTL_INF("Folder structure already up to date!");
+  return true;
+}
+
+// Database
+inline cpp::result<bool, std::string> MigrateDBUp(SQLite::Database& db) {
+  try {
+    db.exec(
+        "CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER PRIMARY "
+        "KEY);");
+
+    // models
+    {
+      // Check if the table exists
+      SQLite::Statement query(db,
+                              "SELECT name FROM sqlite_master WHERE "
+                              "type='table' AND name='models'");
+      auto table_exists = query.executeStep();
+
+      if (table_exists) {
+        // Alter existing table
+        cortex::mgr::AddColumnIfNotExists(db, "models", "model_format", "TEXT");
+        cortex::mgr::AddColumnIfNotExists(db, "models", "model_source", "TEXT");
+        cortex::mgr::AddColumnIfNotExists(db, "models", "status", "TEXT");
+        cortex::mgr::AddColumnIfNotExists(db, "models", "engine", "TEXT");
+      } else {
+        // Create new table
+        db.exec(
+            "CREATE TABLE models ("
+            "model_id TEXT PRIMARY KEY,"
+            "author_repo_id TEXT,"
+            "branch_name TEXT,"
+            "path_to_model_yaml TEXT,"
+            "model_alias TEXT,"
+            "model_format TEXT,"
+            "model_source TEXT,"
+            "status TEXT,"
+            "engine TEXT"
+            ")");
+      }
+    }
+
+    db.exec(
+        "CREATE TABLE IF NOT EXISTS hardware ("
+        "uuid TEXT PRIMARY KEY, "
+        "type TEXT NOT NULL, "
+        "hardware_id INTEGER NOT NULL, "
+        "software_id INTEGER NOT NULL, "
+        "activated INTEGER NOT NULL CHECK (activated IN (0, 1)));");
+
+    // engines
+    db.exec(
+        "CREATE TABLE IF NOT EXISTS engines ("
+        "id INTEGER PRIMARY KEY AUTOINCREMENT,"
+        "engine_name TEXT,"
+        "type TEXT,"
+        "api_key TEXT,"
+        "url TEXT,"
+        "version TEXT,"
+        "variant TEXT,"
+        "status TEXT,"
+        "metadata TEXT,"
+        "date_created TEXT DEFAULT CURRENT_TIMESTAMP,"
+        "date_updated TEXT DEFAULT CURRENT_TIMESTAMP,"
+        "UNIQUE(engine_name, variant));");
+
+    // CTL_INF("Database migration up completed successfully.");
+    return true;
+  } catch (const std::exception& e) {
+    CTL_WRN("Migration up failed: " << e.what());
+    return cpp::fail(e.what());
+  }
+}
+
+inline cpp::result<bool, std::string> MigrateDBDown(SQLite::Database& db) {
+  try {
+    // models
+    {
+      SQLite::Statement query(db,
+                              "SELECT name FROM sqlite_master WHERE "
+                              "type='table' AND name='models'");
+      auto table_exists = query.executeStep();
+      if (table_exists) {
+        // Create a new table with the old schema
+        db.exec(
+            "CREATE TABLE models_old ("
+            "model_id TEXT PRIMARY KEY,"
+            "author_repo_id TEXT,"
+            "branch_name TEXT,"
+            "path_to_model_yaml TEXT,"
+            "model_alias TEXT"
+            ")");
+
+        // Copy data from the current table to the new table
+        db.exec(
+            "INSERT INTO models_old (model_id, author_repo_id, branch_name, "
+            "path_to_model_yaml, model_alias) "
+            "SELECT model_id, author_repo_id, branch_name, path_to_model_yaml, "
+            "model_alias FROM models");
+
+        // Drop the current table
+        db.exec("DROP TABLE models");
+
+        // Rename the new table to the original name
+        db.exec("ALTER TABLE models_old RENAME TO models");
+      }
+    }
+
+    // hardware
+    {
+      // Do nothing
+    }
+
+    // engines
+    db.exec("DROP TABLE IF EXISTS engines;");
+    // CTL_INF("Migration down completed successfully.");
+    return true;
+  } catch (const std::exception& e) {
+    CTL_WRN("Migration down failed: " << e.what());
+    return cpp::fail(e.what());
+  }
+}
+
+}  // namespace cortex::migr::v1
"utils/semantic_version_utils.h" #include "utils/system_info_utils.h" #include "utils/url_parser.h" - namespace { std::string GetSuitableCudaVersion(const std::string& engine, const std::string& cuda_driver_version) { @@ -179,6 +182,18 @@ cpp::result EngineService::UninstallEngineVariant( const std::string& engine, const std::optional version, const std::optional variant) { auto ne = NormalizeEngine(engine); + // TODO: handle uninstall remote engine + // only delete a remote engine if no model are using it + auto exist_engine = GetEngineByNameAndVariant(engine); + if (exist_engine.has_value() && exist_engine.value().type == "remote") { + auto result = DeleteEngine(exist_engine.value().id); + if (!result.empty()) { // This mean no error when delete model + CTL_ERR("Failed to delete engine: " << result); + return cpp::fail(result); + } + return cpp::result(true); + } + if (IsEngineLoaded(ne)) { CTL_INF("Engine " << ne << " is already loaded, unloading it"); auto unload_res = UnloadEngine(ne); @@ -226,21 +241,19 @@ cpp::result EngineService::UninstallEngineVariant( cpp::result EngineService::DownloadEngine( const std::string& engine, const std::string& version, const std::optional variant_name) { + auto normalized_version = version == "latest" ? "latest" : string_utils::RemoveSubstring(version, "v"); - auto res = GetEngineVariants(engine, version); if (res.has_error()) { return cpp::fail("Failed to fetch engine releases: " + res.error()); } - if (res.value().empty()) { return cpp::fail("No release found for " + version); } std::optional selected_variant = std::nullopt; - if (variant_name.has_value()) { auto latest_version_semantic = normalized_version == "latest" ? res.value()[0].version @@ -269,9 +282,10 @@ cpp::result EngineService::DownloadEngine( } } - if (selected_variant == std::nullopt) { + if (!selected_variant) { return cpp::fail("Failed to find a suitable variant for " + engine); } + if (IsEngineLoaded(engine)) { CTL_INF("Engine " << engine << " is already loaded, unloading it"); auto unload_res = UnloadEngine(engine); @@ -282,17 +296,17 @@ cpp::result EngineService::DownloadEngine( CTL_INF("Engine " << engine << " unloaded successfully"); } } - auto normalize_version = "v" + selected_variant->version; + auto normalize_version = "v" + selected_variant->version; auto variant_folder_name = engine_matcher_utils::GetVariantFromNameAndVersion( selected_variant->name, engine, selected_variant->version); - auto variant_folder_path = file_manager_utils::GetEnginesContainerPath() / engine / variant_folder_name.value() / normalize_version; - auto variant_path = variant_folder_path / selected_variant->name; + std::filesystem::create_directories(variant_folder_path); + CTL_INF("variant_folder_path: " + variant_folder_path.string()); auto on_finished = [this, engine, selected_variant, variant_folder_path, normalize_version](const DownloadTask& finishedTask) { @@ -301,14 +315,15 @@ cpp::result EngineService::DownloadEngine( CTL_INF("Version: " + normalize_version); auto extract_path = finishedTask.items[0].localPath.parent_path(); - archive_utils::ExtractArchive(finishedTask.items[0].localPath.string(), extract_path.string(), true); auto variant = engine_matcher_utils::GetVariantFromNameAndVersion( selected_variant->name, engine, normalize_version); + CTL_INF("Extracted variant: " + variant.value()); // set as default + auto res = SetDefaultEngineVariant(engine, normalize_version, variant.value()); if (res.has_error()) { @@ -316,10 +331,21 @@ cpp::result EngineService::DownloadEngine( } else { 
CTL_INF("Set default engine variant: " << res.value().variant); } - - // remove other engines - auto engine_directories = file_manager_utils::GetEnginesContainerPath() / - engine / selected_variant->name; + auto create_res = + EngineService::UpsertEngine(engine, // engine_name + "local", // todo - luke + "", // todo - luke + "", // todo - luke + normalize_version, variant.value(), + "Default", // todo - luke + "" // todo - luke + ); + + if (create_res.has_value()) { + CTL_ERR("Failed to create engine entry: " << create_res->engine_name); + } else { + CTL_INF("Engine entry created successfully"); + } for (const auto& entry : std::filesystem::directory_iterator( variant_folder_path.parent_path())) { @@ -333,7 +359,6 @@ cpp::result EngineService::DownloadEngine( } } - // remove the downloaded file try { std::filesystem::remove(finishedTask.items[0].localPath); } catch (const std::exception& e) { @@ -342,18 +367,18 @@ cpp::result EngineService::DownloadEngine( CTL_INF("Finished!"); }; - auto downloadTask{ + auto downloadTask = DownloadTask{.id = engine, .type = DownloadType::Engine, .items = {DownloadItem{ .id = engine, .downloadUrl = selected_variant->browser_download_url, .localPath = variant_path, - }}}}; + }}}; auto add_task_result = download_service_->AddTask(downloadTask, on_finished); - if (res.has_error()) { - return cpp::fail(res.error()); + if (add_task_result.has_error()) { + return cpp::fail(add_task_result.error()); } return {}; } @@ -656,6 +681,25 @@ cpp::result EngineService::LoadEngine( return {}; } + // Check for remote engine + if (remote_engine::IsRemoteEngine(engine_name)) { + auto exist_engine = GetEngineByNameAndVariant(engine_name); + if (exist_engine.has_error()) { + return cpp::fail("Remote engine '" + engine_name + "' is not installed"); + } + + if (engine_name == kOpenAiEngine) { + engines_[engine_name].engine = new remote_engine::OpenAiEngine(); + } else { + engines_[engine_name].engine = new remote_engine::AnthropicEngine(); + } + + CTL_INF("Loaded engine: " << engine_name); + return {}; + } + + // End hard code + CTL_INF("Loading engine: " << ne); auto selected_engine_variant = GetDefaultEngineVariant(ne); @@ -824,8 +868,11 @@ cpp::result EngineService::UnloadEngine( if (!IsEngineLoaded(ne)) { return cpp::fail("Engine " + ne + " is not loaded yet!"); } - EngineI* e = std::get(engines_[ne].engine); - delete e; + if (std::holds_alternative(engines_[ne].engine)) { + delete std::get(engines_[ne].engine); + } else { + delete std::get(engines_[ne].engine); + } #if defined(_WIN32) if (!RemoveDllDirectory(engines_[ne].cookie)) { @@ -867,9 +914,20 @@ EngineService::GetLatestEngineVersion(const std::string& engine) const { } cpp::result EngineService::IsEngineReady( - const std::string& engine) const { + const std::string& engine) { auto ne = NormalizeEngine(engine); + // Check for remote engine + if (remote_engine::IsRemoteEngine(engine)) { + auto exist_engine = GetEngineByNameAndVariant(engine); + if (exist_engine.has_error()) { + return cpp::fail("Remote engine '" + engine + "' is not installed"); + } + return true; + } + + // End hard code + auto os = hw_inf_.sys_inf->os; if (os == kMacOs && (ne == kOnnxRepo || ne == kTrtLlmRepo)) { return cpp::fail("Engine " + engine + " is not supported on macOS"); @@ -955,3 +1013,101 @@ cpp::result EngineService::UpdateEngine( .from = default_variant->version, .to = latest_version->tag_name}; } + +cpp::result, std::string> +EngineService::GetEngines() { + cortex::db::Engines engines; + auto get_res = engines.GetEngines(); + + if 
diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h
index 47d7c272f..8c8bfbbe6 100644
--- a/engine/services/engine_service.h
+++ b/engine/services/engine_service.h
@@ -2,12 +2,18 @@
 
 #include 
 #include 
+#include 
 #include 
 #include 
+#include 
 #include 
+
 #include "common/engine_servicei.h"
 #include "cortex-common/EngineI.h"
 #include "cortex-common/cortexpythoni.h"
+#include "cortex-common/remote_enginei.h"
+#include "database/engines.h"
+#include "extensions/remote-engine/remote_engine.h"
 #include "services/download_service.h"
 #include "utils/cpuid/cpu_info.h"
 #include "utils/dylib.h"
@@ -32,11 +38,7 @@ struct EngineUpdateResult {
   }
 };
 
-namespace system_info_utils {
-struct SystemInfo;
-}
-
-using EngineV = std::variant<EngineI*, CortexPythonEngineI*>;
+using EngineV = std::variant<EngineI*, CortexPythonEngineI*, RemoteEngineI*>;
 
 class EngineService : public EngineServiceI {
  private:
@@ -54,6 +56,14 @@ class EngineService : public EngineServiceI {
   std::mutex engines_mutex_;
   std::unordered_map<std::string, EngineInfo> engines_{};
+  std::shared_ptr<DownloadService> download_service_;
+
+  struct HardwareInfo {
+    std::unique_ptr<system_info_utils::SystemInfo> sys_inf;
+    cortex::cpuid::CpuInfo cpu_inf;
+    std::string cuda_driver_version;
+  };
+  HardwareInfo hw_inf_;
 
  public:
   const std::vector<std::string> kSupportEngines = {
@@ -70,7 +80,7 @@ class EngineService : public EngineServiceI {
   /**
    * Check if an engine is ready (has at least one variant installed)
    */
-  cpp::result<bool, std::string> IsEngineReady(const std::string& engine) const;
+  cpp::result<bool, std::string> IsEngineReady(const std::string& engine);
 
   /**
    * Handling install engine variant.
@@ -110,7 +120,6 @@ class EngineService : public EngineServiceI {
   std::vector<EngineV> GetLoadedEngines();
 
   cpp::result<void, std::string> LoadEngine(const std::string& engine_name);
-
   cpp::result<void, std::string> UnloadEngine(const std::string& engine_name);
 
   cpp::result<DefaultEngineVariant, std::string>
@@ -123,6 +132,25 @@ class EngineService : public EngineServiceI {
   cpp::result<EngineUpdateResult, std::string> UpdateEngine(
       const std::string& engine);
 
+  cpp::result<std::vector<cortex::db::EngineEntry>, std::string> GetEngines();
+
+  cpp::result<cortex::db::EngineEntry, std::string> GetEngineById(int id);
+
+  cpp::result<cortex::db::EngineEntry, std::string> GetEngineByNameAndVariant(
+      const std::string& engine_name,
+      const std::optional<std::string> variant = std::nullopt);
+
+  cpp::result<cortex::db::EngineEntry, std::string> UpsertEngine(
+      const std::string& engine_name, const std::string& type,
+      const std::string& api_key, const std::string& url,
+      const std::string& version, const std::string& variant,
+      const std::string& status, const std::string& metadata);
+
+  std::string DeleteEngine(int id);
+
+  cpp::result<Json::Value, std::string> GetRemoteModels(
+      const std::string& engine_name);
+
  private:
   cpp::result<bool, std::string> DownloadEngine(
       const std::string& engine, const std::string& version = "latest",
@@ -137,13 +165,4 @@ class EngineService : public EngineServiceI {
   cpp::result<bool, std::string> IsEngineVariantReady(
       const std::string& engine, const std::string& version,
       const std::string& variant);
-
-  std::shared_ptr<DownloadService> download_service_;
-
-  struct HardwareInfo {
-    std::unique_ptr<system_info_utils::SystemInfo> sys_inf;
-    cortex::cpuid::CpuInfo cpu_inf;
-    std::string cuda_driver_version;
-  };
-  HardwareInfo hw_inf_;
-};
+};
\ No newline at end of file
diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc
index 46309823d..ace7e675f 100644
--- a/engine/services/inference_service.cc
+++ b/engine/services/inference_service.cc
@@ -24,14 +24,26 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
     return cpp::fail(std::make_pair(stt, res));
   }
 
-  auto engine = std::get<EngineI*>(engine_result.value());
-  engine->HandleChatCompletion(
-      json_body, [q, tool_choice](Json::Value status, Json::Value res) {
-        if (!tool_choice.isNull()) {
-          res["tool_choice"] = tool_choice;
-        }
-        q->push(std::make_pair(status, res));
-      });
+  if (std::holds_alternative<EngineI*>(engine_result.value())) {
+    std::get<EngineI*>(engine_result.value())
+        ->HandleChatCompletion(
+            json_body, [q, tool_choice](Json::Value status, Json::Value res) {
+              if (!tool_choice.isNull()) {
+                res["tool_choice"] = tool_choice;
+              }
+              q->push(std::make_pair(status, res));
+            });
+  } else {
+    std::get<RemoteEngineI*>(engine_result.value())
+        ->HandleChatCompletion(
+            json_body, [q, tool_choice](Json::Value status, Json::Value res) {
+              if (!tool_choice.isNull()) {
+                res["tool_choice"] = tool_choice;
+              }
+              q->push(std::make_pair(status, res));
+            });
+  }
+
   return {};
 }
 
@@ -53,10 +65,18 @@ cpp::result<void, InferResult> InferenceService::HandleEmbedding(
     LOG_WARN << "Engine is not loaded yet";
     return cpp::fail(std::make_pair(stt, res));
   }
-  auto engine = std::get<EngineI*>(engine_result.value());
-  engine->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) {
-    q->push(std::make_pair(status, res));
-  });
+
+  if (std::holds_alternative<EngineI*>(engine_result.value())) {
+    std::get<EngineI*>(engine_result.value())
+        ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) {
+          q->push(std::make_pair(status, res));
+        });
+  } else {
+    std::get<RemoteEngineI*>(engine_result.value())
+        ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) {
+          q->push(std::make_pair(status, res));
+        });
+  }
   return {};
 }
 
@@ -83,11 +103,20 @@ InferResult InferenceService::LoadModel(
   // might need mutex here
   auto engine_result = engine_service_->GetLoadedEngine(engine_type);
 
-  auto engine = std::get<EngineI*>(engine_result.value());
-  engine->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
-    stt = status;
-    r = res;
-  });
+
+  if (std::holds_alternative<EngineI*>(engine_result.value())) {
+    std::get<EngineI*>(engine_result.value())
+        ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
+          stt = status;
+          r = res;
+        });
+  } else {
+    std::get<RemoteEngineI*>(engine_result.value())
+        ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
+          stt = status;
+          r = res;
+        });
+  }
   return std::make_pair(stt, r);
 }
 
@@ -110,12 +139,22 @@ InferResult InferenceService::UnloadModel(const std::string& engine_name,
   json_body["model"] = model_id;
 
   LOG_TRACE << "Start unload model";
-  auto engine = std::get<EngineI*>(engine_result.value());
-  engine->UnloadModel(std::make_shared<Json::Value>(json_body),
+  if (std::holds_alternative<EngineI*>(engine_result.value())) {
+    std::get<EngineI*>(engine_result.value())
+        ->UnloadModel(std::make_shared<Json::Value>(json_body),
+                      [&r, &stt](Json::Value status, Json::Value res) {
+                        stt = status;
+                        r = res;
+                      });
+  } else {
+    std::get<RemoteEngineI*>(engine_result.value())
+        ->UnloadModel(std::make_shared<Json::Value>(json_body),
                       [&r, &stt](Json::Value status, Json::Value res) {
                         stt = status;
                         r = res;
                       });
+  }
+
   return std::make_pair(stt, r);
 }
 
@@ -141,12 +180,23 @@ InferResult InferenceService::GetModelStatus(
   }
 
   LOG_TRACE << "Start to get model status";
-  auto engine = std::get<EngineI*>(engine_result.value());
-  engine->GetModelStatus(json_body,
+
+  if (std::holds_alternative<EngineI*>(engine_result.value())) {
+    std::get<EngineI*>(engine_result.value())
+        ->GetModelStatus(json_body,
+                         [&stt, &r](Json::Value status, Json::Value res) {
+                           stt = status;
+                           r = res;
+                         });
+  } else {
+    std::get<RemoteEngineI*>(engine_result.value())
+        ->GetModelStatus(json_body,
                          [&stt, &r](Json::Value status, Json::Value res) {
                            stt = status;
                            r = res;
                          });
+  }
+
   return std::make_pair(stt, r);
 }
diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h
index 7c09156ff..94097132a 100644
--- a/engine/services/inference_service.h
+++ b/engine/services/inference_service.h
@@ -5,7 +5,7 @@
 #include 
 #include "services/engine_service.h"
 #include "utils/result.hpp"
-
+#include "extensions/remote-engine/remote_engine.h"
 namespace services {
 // Status and result
 using InferResult = std::pair<Json::Value, Json::Value>;
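Each `InferenceService` method above repeats a `holds_alternative`/`std::get` pair per engine alternative. Since both interfaces expose the same callback-style methods, the dispatch could also be written once with `std::visit`; a sketch of that alternative using stand-in types (this is not what the patch does):

```cpp
#include <iostream>
#include <variant>

// Stand-ins for EngineI / RemoteEngineI: both expose the same call shape.
struct LocalEngine {
  void LoadModel() { std::cout << "local load\n"; }
};
struct RemoteEngine {
  void LoadModel() { std::cout << "remote load\n"; }
};

using EngineV = std::variant<LocalEngine*, RemoteEngine*>;

// One generic lambda replaces the per-alternative if/else chains.
void DispatchLoadModel(EngineV engine) {
  std::visit([](auto* e) { e->LoadModel(); }, engine);
}

int main() {
  LocalEngine l;
  RemoteEngine r;
  DispatchLoadModel(&l);
  DispatchLoadModel(&r);
}
```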
diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc
index 3cfff5cb2..d81a9b649 100644
--- a/engine/services/model_service.cc
+++ b/engine/services/model_service.cc
@@ -64,11 +64,13 @@ void ParseGguf(const DownloadItem& ggufDownloadItem,
   auto author_id = author.has_value() ? author.value() : "cortexso";
 
   cortex::db::Models modellist_utils_obj;
-  cortex::db::ModelEntry model_entry{.model = ggufDownloadItem.id,
-                                     .author_repo_id = author_id,
-                                     .branch_name = branch,
-                                     .path_to_model_yaml = rel.string(),
-                                     .model_alias = ggufDownloadItem.id};
+  cortex::db::ModelEntry model_entry{
+      .model = ggufDownloadItem.id,
+      .author_repo_id = author_id,
+      .branch_name = branch,
+      .path_to_model_yaml = rel.string(),
+      .model_alias = ggufDownloadItem.id,
+      .status = cortex::db::ModelStatus::Downloaded};
   auto result = modellist_utils_obj.AddModelEntry(model_entry, true);
   if (result.has_error()) {
     CTL_WRN("Error adding model to modellist: " + result.error());
@@ -718,6 +720,49 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
                          .string());
   auto mc = yaml_handler.GetModelConfig();
 
+  // Running remote model
+  if (remote_engine::IsRemoteEngine(mc.engine)) {
+
+    config::RemoteModelConfig remote_mc;
+    remote_mc.LoadFromYamlFile(
+        fmu::ToAbsoluteCortexDataPath(
+            fs::path(model_entry.value().path_to_model_yaml))
+            .string());
+    auto remote_engine_entry =
+        engine_svc_->GetEngineByNameAndVariant(mc.engine);
+    if (remote_engine_entry.has_error()) {
+      CTL_WRN("Remote engine error: " + remote_engine_entry.error());
+      return cpp::fail(remote_engine_entry.error());
+    }
+    auto remote_engine_json = remote_engine_entry.value().ToJson();
+    json_data = remote_mc.ToJson();
+
+    json_data["api_key"] = std::move(remote_engine_json["api_key"]);
+    json_data["model_path"] =
+        fmu::ToAbsoluteCortexDataPath(
+            fs::path(model_entry.value().path_to_model_yaml))
+            .string();
+    json_data["metadata"] = std::move(remote_engine_json["metadata"]);
+
+    auto ir =
+        inference_svc_->LoadModel(std::make_shared<Json::Value>(json_data));
+    auto status = std::get<0>(ir)["status_code"].asInt();
+    auto data = std::get<1>(ir);
+    if (status == drogon::k200OK) {
+      return StartModelResult{.success = true, .warning = ""};
+    } else if (status == drogon::k409Conflict) {
+      CTL_INF("Model '" + model_handle + "' is already loaded");
+      return StartModelResult{.success = true, .warning = ""};
+    } else {
+      // only report the error to the user
+      CTL_ERR("Model failed to start with status code: " << status);
+      return cpp::fail("Model failed to start: " +
+                       data["message"].asString());
+    }
+  }
+
+  // end hard code
+
   json_data = mc.ToJson();
   if (mc.files.size() > 0) {
 #if defined(_WIN32)
"test_format", + "test_source", cortex::db::ModelStatus::Downloaded, + "test_engine"}; }; TEST_F(ModelsTestSuite, TestAddModelEntry) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); auto retrieved_model = model_list_.GetModelInfo(kTestModel.model); - EXPECT_TRUE(retrieved_model); + EXPECT_TRUE(retrieved_model.has_value()); EXPECT_EQ(retrieved_model.value().model, kTestModel.model); EXPECT_EQ(retrieved_model.value().author_repo_id, kTestModel.author_repo_id); + EXPECT_EQ(retrieved_model.value().model_format, kTestModel.model_format); + EXPECT_EQ(retrieved_model.value().model_source, kTestModel.model_source); + EXPECT_EQ(retrieved_model.value().status, kTestModel.status); + EXPECT_EQ(retrieved_model.value().engine, kTestModel.engine); - // // Clean up + // Clean up EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); } @@ -54,7 +67,7 @@ TEST_F(ModelsTestSuite, TestGetModelInfo) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); auto model_by_id = model_list_.GetModelInfo(kTestModel.model); - EXPECT_TRUE(model_by_id); + EXPECT_TRUE(model_by_id.has_value()); EXPECT_EQ(model_by_id.value().model, kTestModel.model); auto model_by_alias = model_list_.GetModelInfo("test_alias"); @@ -71,14 +84,14 @@ TEST_F(ModelsTestSuite, TestUpdateModelEntry) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); cortex::db::ModelEntry updated_model = kTestModel; + updated_model.status = cortex::db::ModelStatus::Downloaded; EXPECT_TRUE( model_list_.UpdateModelEntry(kTestModel.model, updated_model).value()); auto retrieved_model = model_list_.GetModelInfo(kTestModel.model); - EXPECT_TRUE(retrieved_model); - EXPECT_TRUE( - model_list_.UpdateModelEntry(kTestModel.model, updated_model).value()); + EXPECT_TRUE(retrieved_model.has_value()); + EXPECT_EQ(retrieved_model.value().status, updated_model.status); // Clean up EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); @@ -117,7 +130,7 @@ TEST_F(ModelsTestSuite, TestPersistence) { // Create a new ModelListUtils instance to test if it loads from file cortex::db::Models new_model_list(db_); auto retrieved_model = new_model_list.GetModelInfo(kTestModel.model); - EXPECT_TRUE(retrieved_model); + EXPECT_TRUE(retrieved_model.has_value()); EXPECT_EQ(retrieved_model.value().model, kTestModel.model); EXPECT_EQ(retrieved_model.value().author_repo_id, kTestModel.author_repo_id); EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); @@ -136,7 +149,7 @@ TEST_F(ModelsTestSuite, TestUpdateModelAlias) { EXPECT_TRUE( model_list_.UpdateModelAlias(kTestModel.model, kNewTestAlias).value()); auto updated_model = model_list_.GetModelInfo(kNewTestAlias); - EXPECT_TRUE(updated_model); + EXPECT_TRUE(updated_model.has_value()); EXPECT_EQ(updated_model.value().model_alias, kNewTestAlias); EXPECT_EQ(updated_model.value().model, kTestModel.model); @@ -174,4 +187,5 @@ TEST_F(ModelsTestSuite, TestHasModel) { // Clean up EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); } -} // namespace cortex::db + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/utils/engine_constants.h b/engine/utils/engine_constants.h index 5dab49936..020109fd8 100644 --- a/engine/utils/engine_constants.h +++ b/engine/utils/engine_constants.h @@ -3,6 +3,8 @@ constexpr const auto kOnnxEngine = "onnxruntime"; constexpr const auto kLlamaEngine = "llama-cpp"; constexpr const auto kTrtLlmEngine = "tensorrt-llm"; +constexpr const auto kOpenAiEngine = "openai"; +constexpr const auto kAnthropicEngine = "anthropic"; 
constexpr const auto kOnnxRepo = "cortex.onnx"; constexpr const auto kLlamaRepo = "cortex.llamacpp"; diff --git a/engine/utils/logging_utils.h b/engine/utils/logging_utils.h index d2c04a7e8..7d4cf35f1 100644 --- a/engine/utils/logging_utils.h +++ b/engine/utils/logging_utils.h @@ -9,6 +9,8 @@ inline bool log_verbose = false; inline bool is_server = false; // Only use trantor log +#define CTL_TRC(msg) LOG_TRACE << msg; + #define CTL_DBG(msg) LOG_DEBUG << msg; #define CTL_INF(msg) LOG_INFO << msg; diff --git a/engine/utils/remote_models_utils.h b/engine/utils/remote_models_utils.h new file mode 100644 index 000000000..7b7906f2c --- /dev/null +++ b/engine/utils/remote_models_utils.h @@ -0,0 +1,132 @@ +#pragma once + +#include +#include +#include + +namespace remote_models_utils { +constexpr char chat_completion_request_template[] = + "{ {% set first = true %} {% for key, value in input_request %} {% if key " + "== \"messages\" or key == \"model\" or key == \"temperature\" or key == " + "\"store\" or key == \"max_tokens\" or key == \"stream\" or key == " + "\"presence_penalty\" or key == \"metadata\" or key == " + "\"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or " + "key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" " + "or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key " + "== \"response_format\" or key == \"service_tier\" or key == \"seed\" or " + "key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key " + "== \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% " + "endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% " + "endif %} {% endfor %} }"; + +constexpr char chat_completion_response_template[] = + "{ {% set first = true %} {% for key, value in input_request %} {% if key " + "== \"messages\" or key == \"model\" or key == \"temperature\" or key == " + "\"store\" or key == \"max_tokens\" or key == \"stream\" or key == " + "\"presence_penalty\" or key == \"metadata\" or key == " + "\"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or " + "key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" " + "or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key " + "== \"response_format\" or key == \"service_tier\" or key == \"seed\" or " + "key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key " + "== \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% " + "endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% " + "endif %} {% endfor %} }"; + +constexpr char chat_completion_url[] = + "https://api.openai.com/v1/chat/completions"; + +inline Json::Value yamlToJson(const YAML::Node& node) { + Json::Value result; + + switch (node.Type()) { + case YAML::NodeType::Null: + return Json::Value(); + case YAML::NodeType::Scalar: { + // For scalar types, we'll first try to parse as string + std::string str_val = node.as(); + + // Try to parse as boolean + if (str_val == "true" || str_val == "True" || str_val == "TRUE") + return Json::Value(true); + if (str_val == "false" || str_val == "False" || str_val == "FALSE") + return Json::Value(false); + + // Try to parse as number + try { + // Check if it's an integer + size_t pos; + long long int_val = std::stoll(str_val, &pos); + if (pos == str_val.length()) { + return Json::Value(static_cast(int_val)); + } + + // Check if it's a float + double float_val = std::stod(str_val, &pos); + if (pos == str_val.length()) { + return Json::Value(float_val); 
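`yamlToJson` infers scalar types by trial parsing: bool first, then integer, then float, falling back to string. A quick standalone check of that behavior (the YAML snippet is illustrative):

```cpp
#include <iostream>
#include <yaml-cpp/yaml.h>
#include "utils/remote_models_utils.h"

int main() {
  YAML::Node n = YAML::Load(
      "{temperature: 0.7, max_tokens: 256, stream: true, model: gpt-4o}");
  Json::Value j = remote_models_utils::yamlToJson(n);
  // Each scalar lands in the expected JSON type:
  std::cout << j["temperature"].isDouble() << j["max_tokens"].isInt64()
            << j["stream"].isBool() << j["model"].isString() << '\n';  // 1111
}
```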
+inline YAML::Node jsonToYaml(const Json::Value& json) {
+  YAML::Node result;
+
+  switch (json.type()) {
+    case Json::nullValue:
+      result = YAML::Node(YAML::NodeType::Null);
+      break;
+    case Json::intValue:
+      result = json.asInt64();
+      break;
+    case Json::uintValue:
+      result = json.asUInt64();
+      break;
+    case Json::realValue:
+      result = json.asDouble();
+      break;
+    case Json::stringValue:
+      result = json.asString();
+      break;
+    case Json::booleanValue:
+      result = json.asBool();
+      break;
+    case Json::arrayValue:
+      result = YAML::Node(YAML::NodeType::Sequence);
+      for (const auto& elem : json)
+        result.push_back(jsonToYaml(elem));
+      break;
+    case Json::objectValue:
+      result = YAML::Node(YAML::NodeType::Map);
+      for (const auto& key : json.getMemberNames())
+        result[key] = jsonToYaml(json[key]);
+      break;
+  }
+  return result;
+}
+
+}  // namespace remote_models_utils
\ No newline at end of file
diff --git a/engine/utils/result.hpp b/engine/utils/result.hpp
index 96243f72e..7f7356b84 100644
--- a/engine/utils/result.hpp
+++ b/engine/utils/result.hpp
@@ -34,7 +34,6 @@
 #include <cstddef>      // std::size_t
 #include <type_traits>  // std::enable_if, std::is_constructible, etc
-#include <new>          // placement-new
 #include <memory>       // std::address_of
 #include <functional>   // std::reference_wrapper, std::invoke
 #include <utility>      // std::in_place_t, std::forward
diff --git a/engine/vcpkg.json b/engine/vcpkg.json
index 36fa322a3..962d06ffd 100644
--- a/engine/vcpkg.json
+++ b/engine/vcpkg.json
@@ -13,6 +13,7 @@
     "sqlitecpp",
     "trantor",
     "indicators",
+    "inja",
     "lfreist-hwinfo"
   ]
 }