From c67dc4e6de7d665dffea6f16181fc8420f00e62a Mon Sep 17 00:00:00 2001 From: James Date: Wed, 25 Dec 2024 23:42:50 +0700 Subject: [PATCH] update --- engine/common/assistant.h | 120 +++++++------- engine/common/dto/assistant_create_dto.h | 39 +++-- engine/common/dto/assistant_update_dto.h | 26 +-- engine/common/dto/base_dto.h | 5 +- engine/common/message_attachment.h | 13 +- .../common/repository/assistant_repository.h | 2 +- engine/common/thread.h | 12 +- engine/common/thread_tool_resources.h | 54 ------- engine/common/tool_resources.h | 100 ++++++++++++ engine/controllers/assistants.cc | 150 +++++++++++++++++- engine/controllers/assistants.h | 31 ++++ engine/main.cc | 6 +- .../repositories/assistant_fs_repository.cc | 2 +- engine/repositories/assistant_fs_repository.h | 4 +- engine/repositories/file_fs_repository.h | 2 +- engine/repositories/message_fs_repository.h | 2 +- engine/services/assistant_service.cc | 92 ++++++++++- engine/services/assistant_service.h | 6 +- engine/services/thread_service.cc | 4 +- engine/services/thread_service.h | 5 +- 20 files changed, 511 insertions(+), 164 deletions(-) delete mode 100644 engine/common/thread_tool_resources.h create mode 100644 engine/common/tool_resources.h diff --git a/engine/common/assistant.h b/engine/common/assistant.h index 03643f1ae..b9592e3e9 100644 --- a/engine/common/assistant.h +++ b/engine/common/assistant.h @@ -1,9 +1,11 @@ #pragma once #include +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" #include "common/assistant_tool.h" -#include "common/message_attachment.h" -#include "common/thread_tool_resources.h" +#include "common/tool_resources.h" #include "common/variant_map.h" #include "utils/logging_utils.h" #include "utils/result.hpp" @@ -87,18 +89,19 @@ struct Assistant : JsonSerializable { Assistant& operator=(const Assistant&) = delete; Assistant(Assistant&& other) noexcept - : id(std::move(other.id)), - object(std::move(other.object)), - created_at(other.created_at), - name(std::move(other.name)), - description(std::move(other.description)), + : id{std::move(other.id)}, + object{std::move(other.object)}, + created_at{other.created_at}, + name{std::move(other.name)}, + description{std::move(other.description)}, model(std::move(other.model)), instructions(std::move(other.instructions)), tools(std::move(other.tools)), tool_resources(std::move(other.tool_resources)), metadata(std::move(other.metadata)), - temperature(std::move(other.temperature)), - top_p(std::move(other.top_p)) {} + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} Assistant& operator=(Assistant&& other) noexcept { if (this != &other) { @@ -114,6 +117,7 @@ struct Assistant : JsonSerializable { metadata = std::move(other.metadata); temperature = std::move(other.temperature); top_p = std::move(other.top_p); + response_format = std::move(other.response_format); } return *this; } @@ -168,8 +172,7 @@ struct Assistant : JsonSerializable { * requires a list of file IDs, while the file_search tool requires a list * of vector store IDs. */ - std::optional> - tool_resources; + std::optional> tool_resources; /** * Set of 16 key-value pairs that can be attached to an object. This can be @@ -196,6 +199,8 @@ struct Assistant : JsonSerializable { */ std::optional top_p; + std::variant response_format; + cpp::result ToJson() override { try { Json::Value root; @@ -226,7 +231,7 @@ struct Assistant : JsonSerializable { Json::Value tool_resources_json; if (auto* code_interpreter = - std::get_if(&tool_resources.value())) { + std::get_if(&tool_resources.value())) { if (auto result = code_interpreter->ToJson(); result.has_value()) { tool_resources_json["code_interpreter"] = result.value(); } else { @@ -234,7 +239,7 @@ struct Assistant : JsonSerializable { result.error()); } } else if (auto* file_search = - std::get_if(&tool_resources.value())) { + std::get_if(&tool_resources.value())) { if (auto result = file_search->ToJson(); result.has_value()) { tool_resources_json["file_search"] = result.value(); } else { @@ -312,60 +317,61 @@ struct Assistant : JsonSerializable { // Parse tools array if (json.isMember("tools") && json["tools"].isArray()) { - // TODO: namh implement - // for (const auto& tool_json : json["tools"]) { - // auto tool = AssistantTool::FromJson(tool_json); - // if (!tool.has_value()) { - // return cpp::fail("Failed to parse tool: " + tool.error()); - // } - // assistant.tools.push_back(std::move(tool.value())); - // } + auto tools_array = json["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = AssistantCodeInterpreterTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + + result.error()); + } + } else if (tool_type == "function") { + auto result = AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } } // Parse tool_resources if (json.isMember("tool_resources") && - json["tool_resources"].isObject()) { - // const auto& resources_json = json["tool_resources"]; - // - // if (resources_json.isMember("code_interpreter")) { - // auto code_interpreter = ThreadCodeInterpreter::FromJson( - // resources_json["code_interpreter"]); - // if (!code_interpreter.has_value()) { - // return cpp::fail("Failed to parse code_interpreter: " + - // code_interpreter.error()); - // } - // assistant.tool_resources = std::move(code_interpreter.value()); - // } else if (resources_json.isMember("file_search")) { - // auto file_search = - // ThreadFileSearch::FromJson(resources_json["file_search"]); - // if (!file_search.has_value()) { - // return cpp::fail("Failed to parse file_search: " + - // file_search.error()); - // } - // assistant.tool_resources = std::move(file_search.value()); - // } - } + json["tool_resources"].isObject()) {} // Parse metadata if (json.isMember("metadata") && json["metadata"].isObject()) { - const auto& metadata_json = json["metadata"]; - for (const auto& key : metadata_json.getMemberNames()) { - const auto& value = metadata_json[key]; - if (value.isBool()) { - assistant.metadata[key] = value.asBool(); - } else if (value.isUInt64()) { - assistant.metadata[key] = value.asUInt64(); - } else if (value.isDouble()) { - assistant.metadata[key] = value.asDouble(); - } else if (value.isString()) { - assistant.metadata[key] = value.asString(); - } else { - return cpp::fail("Invalid metadata value type for key: " + key); - } + auto res = Cortex::ConvertJsonValueToMap(json["metadata"]); + if (res.has_value()) { + assistant.metadata = res.value(); + } else { + CTL_WRN("Failed to convert metadata to map: " + res.error()); } } - // Parse optional numerical fields if (json.isMember("temperature") && json["temperature"].isDouble()) { assistant.temperature = json["temperature"].asFloat(); } diff --git a/engine/common/dto/assistant_create_dto.h b/engine/common/dto/assistant_create_dto.h index b30774003..b27e020b6 100644 --- a/engine/common/dto/assistant_create_dto.h +++ b/engine/common/dto/assistant_create_dto.h @@ -20,7 +20,7 @@ struct CreateAssistantDto : public BaseDto { : model{std::move(other.model)}, name{std::move(other.name)}, description{std::move(other.description)}, - instruction{std::move(other.instruction)}, + instructions{std::move(other.instructions)}, metadata{std::move(other.metadata)}, temperature{std::move(other.temperature)}, top_p{std::move(other.top_p)}, @@ -31,7 +31,7 @@ struct CreateAssistantDto : public BaseDto { model = std::move(other.model); name = std::move(other.name); description = std::move(other.description); - instruction = std::move(other.instruction); + instructions = std::move(other.instructions); metadata = std::move(other.metadata); temperature = std::move(other.temperature); top_p = std::move(other.top_p); @@ -46,7 +46,7 @@ struct CreateAssistantDto : public BaseDto { std::optional description; - std::optional instruction; + std::optional instructions; // namH: implement tools @@ -58,17 +58,26 @@ struct CreateAssistantDto : public BaseDto { std::optional top_p; - std::optional response_format; + std::optional> response_format; - bool Validate() const override { + cpp::result Validate() const override { if (model.empty()) { - return false; + return cpp::fail("Model is mandatory"); } - return true; + if (response_format.has_value()) { + const auto& variant_value = response_format.value(); + if (std::holds_alternative(variant_value)) { + if (std::get(variant_value) != "auto") { + return cpp::fail("Invalid response_format"); + } + } + } + + return {}; } - CreateAssistantDto FromJson(Json::Value&& root) override { + static CreateAssistantDto FromJson(Json::Value&& root) { if (root.empty()) { throw std::runtime_error("Json passed in FromJson can't be empty"); } @@ -80,8 +89,8 @@ struct CreateAssistantDto : public BaseDto { if (root.isMember("description")) { dto.description = std::move(root["description"].asString()); } - if (root.isMember("instruction")) { - dto.instruction = std::move(root["instruction"].asString()); + if (root.isMember("instructions")) { + dto.instructions = std::move(root["instructions"].asString()); } if (root["metadata"].isObject() && !root["metadata"].empty()) { auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); @@ -98,7 +107,15 @@ struct CreateAssistantDto : public BaseDto { dto.top_p = root["top_p"].asFloat(); } if (root.isMember("response_format")) { - dto.response_format = std::move(root["response_format"].asString()); + const auto& response_format = root["response_format"]; + if (response_format.isString()) { + dto.response_format = response_format.asString(); + } else if (response_format.isObject()) { + dto.response_format = response_format; + } else { + throw std::runtime_error( + "response_format must be either a string or an object"); + } } return dto; } diff --git a/engine/common/dto/assistant_update_dto.h b/engine/common/dto/assistant_update_dto.h index baed82a26..620cc35a5 100644 --- a/engine/common/dto/assistant_update_dto.h +++ b/engine/common/dto/assistant_update_dto.h @@ -12,7 +12,7 @@ struct UpdateAssistantDto : public BaseDto { std::optional description; - std::optional instruction; + std::optional instructions; // namH: implement tools @@ -24,20 +24,20 @@ struct UpdateAssistantDto : public BaseDto { std::optional top_p; - std::optional response_format; + std::optional> response_format; - bool Validate() const override { + cpp::result Validate() const override { if (!model.has_value() && !name.has_value() && !description.has_value() && - !instruction.has_value() && !metadata.has_value() && + !instructions.has_value() && !metadata.has_value() && !temperature.has_value() && !top_p.has_value() && !response_format.has_value()) { - return false; + return cpp::fail("At least one field must be provided"); } - return true; + return {}; } - UpdateAssistantDto FromJson(Json::Value&& root) override { + static UpdateAssistantDto FromJson(Json::Value&& root) { if (root.empty()) { throw std::runtime_error("Json passed in FromJson can't be empty"); } @@ -50,7 +50,7 @@ struct UpdateAssistantDto : public BaseDto { dto.description = std::move(root["description"].asString()); } if (root.isMember("instruction")) { - dto.instruction = std::move(root["instruction"].asString()); + dto.instructions = std::move(root["instruction"].asString()); } if (root["metadata"].isObject() && !root["metadata"].empty()) { auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); @@ -67,7 +67,15 @@ struct UpdateAssistantDto : public BaseDto { dto.top_p = root["top_p"].asFloat(); } if (root.isMember("response_format")) { - dto.response_format = std::move(root["response_format"].asString()); + const auto& response_format = root["response_format"]; + if (response_format.isString()) { + dto.response_format = response_format.asString(); + } else if (response_format.isObject()) { + dto.response_format = response_format; + } else { + throw std::runtime_error( + "response_format must be either a string or an object"); + } } return dto; }; diff --git a/engine/common/dto/base_dto.h b/engine/common/dto/base_dto.h index d21cdcb85..ed7460aa3 100644 --- a/engine/common/dto/base_dto.h +++ b/engine/common/dto/base_dto.h @@ -1,6 +1,7 @@ #pragma once #include +#include "utils/result.hpp" namespace dto { template @@ -10,8 +11,6 @@ struct BaseDto { /** * Validate itself. */ - virtual bool Validate() const = 0; - - virtual T FromJson(Json::Value&& root) = 0; + virtual cpp::result Validate() const = 0; }; } // namespace dto diff --git a/engine/common/message_attachment.h b/engine/common/message_attachment.h index d564b7609..6a0fb02e9 100644 --- a/engine/common/message_attachment.h +++ b/engine/common/message_attachment.h @@ -4,7 +4,6 @@ #include "common/json_serializable.h" namespace OpenAi { - // The tools to add this file to. struct Tool { std::string type; @@ -15,13 +14,17 @@ struct Tool { }; // The type of tool being defined: code_interpreter -struct CodeInterpreter : Tool { - CodeInterpreter() : Tool{"code_interpreter"} {} +struct MessageCodeInterpreter : Tool { + MessageCodeInterpreter() : Tool{"code_interpreter"} {} + + ~MessageCodeInterpreter() = default; }; // The type of tool being defined: file_search -struct FileSearch : Tool { - FileSearch() : Tool{"file_search"} {} +struct MessageFileSearch : Tool { + MessageFileSearch() : Tool{"file_search"} {} + + ~MessageFileSearch() = default; }; // A list of files attached to the message, and the tools they were added to. diff --git a/engine/common/repository/assistant_repository.h b/engine/common/repository/assistant_repository.h index d2e59ba91..d0ff1908d 100644 --- a/engine/common/repository/assistant_repository.h +++ b/engine/common/repository/assistant_repository.h @@ -12,7 +12,7 @@ class AssistantRepository { virtual cpp::result CreateAssistant( OpenAi::Assistant& assistant) = 0; - virtual cpp::result RetrieveAssisant( + virtual cpp::result RetrieveAssistant( const std::string assistant_id) const = 0; virtual cpp::result ModifyAssistant( diff --git a/engine/common/thread.h b/engine/common/thread.h index 2bd5d866b..dc57ba32d 100644 --- a/engine/common/thread.h +++ b/engine/common/thread.h @@ -4,7 +4,7 @@ #include #include #include "common/assistant.h" -#include "common/thread_tool_resources.h" +#include "common/tool_resources.h" #include "common/variant_map.h" #include "json_serializable.h" #include "utils/logging_utils.h" @@ -36,7 +36,7 @@ struct Thread : JsonSerializable { * of tool. For example, the code_interpreter tool requires a list of * file IDs, while the file_search tool requires a list of vector store IDs. */ - std::unique_ptr tool_resources; + std::unique_ptr tool_resources; /** * Set of 16 key-value pairs that can be attached to an object. @@ -65,7 +65,7 @@ struct Thread : JsonSerializable { const auto& tool_json = json["tool_resources"]; if (tool_json.isMember("code_interpreter")) { - auto code_interpreter = std::make_unique(); + auto code_interpreter = std::make_unique(); const auto& file_ids = tool_json["code_interpreter"]["file_ids"]; if (file_ids.isArray()) { for (const auto& file_id : file_ids) { @@ -74,7 +74,7 @@ struct Thread : JsonSerializable { } thread.tool_resources = std::move(code_interpreter); } else if (tool_json.isMember("file_search")) { - auto file_search = std::make_unique(); + auto file_search = std::make_unique(); const auto& store_ids = tool_json["file_search"]["vector_store_ids"]; if (store_ids.isArray()) { for (const auto& store_id : store_ids) { @@ -148,10 +148,10 @@ struct Thread : JsonSerializable { Json::Value tool_json; if (auto code_interpreter = - dynamic_cast(tool_resources.get())) { + dynamic_cast(tool_resources.get())) { tool_json["code_interpreter"] = tool_result.value(); } else if (auto file_search = - dynamic_cast(tool_resources.get())) { + dynamic_cast(tool_resources.get())) { tool_json["file_search"] = tool_result.value(); } json["tool_resources"] = tool_json; diff --git a/engine/common/thread_tool_resources.h b/engine/common/thread_tool_resources.h deleted file mode 100644 index 4332bc92f..000000000 --- a/engine/common/thread_tool_resources.h +++ /dev/null @@ -1,54 +0,0 @@ -#pragma once - -#include -#include -#include "common/json_serializable.h" - -namespace OpenAi { - -struct ThreadToolResources : JsonSerializable { - virtual ~ThreadToolResources() = default; - - virtual cpp::result ToJson() override = 0; -}; - -struct ThreadCodeInterpreter : ThreadToolResources { - ~ThreadCodeInterpreter() override = default; - - std::vector file_ids; - - cpp::result ToJson() override { - try { - Json::Value json; - Json::Value file_ids_json{Json::arrayValue}; - for (auto& file_id : file_ids) { - file_ids_json.append(file_id); - } - json["file_ids"] = file_ids_json; - return json; - } catch (const std::exception& e) { - return cpp::fail(std::string("ToJson failed: ") + e.what()); - } - } -}; - -struct ThreadFileSearch : ThreadToolResources { - ~ThreadFileSearch() override = default; - - std::vector vector_store_ids; - - cpp::result ToJson() override { - try { - Json::Value json; - Json::Value vector_store_ids_json{Json::arrayValue}; - for (auto& vector_store_id : vector_store_ids) { - vector_store_ids_json.append(vector_store_id); - } - json["vector_store_ids"] = vector_store_ids_json; - return json; - } catch (const std::exception& e) { - return cpp::fail(std::string("ToJson failed: ") + e.what()); - } - } -}; -} // namespace OpenAi diff --git a/engine/common/tool_resources.h b/engine/common/tool_resources.h new file mode 100644 index 000000000..c026ef60d --- /dev/null +++ b/engine/common/tool_resources.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include "common/json_serializable.h" + +namespace OpenAi { + +struct ToolResources : JsonSerializable { + virtual ~ToolResources() = default; + + virtual cpp::result ToJson() override = 0; +}; + +struct CodeInterpreter : ToolResources { + CodeInterpreter() = default; + + ~CodeInterpreter() override = default; + + CodeInterpreter(const CodeInterpreter&) = delete; + + CodeInterpreter& operator=(const CodeInterpreter&) = delete; + + CodeInterpreter(CodeInterpreter&& other) noexcept + : file_ids(std::move(other.file_ids)) {} + + CodeInterpreter& operator=(CodeInterpreter&& other) noexcept { + if (this != &other) { + file_ids = std::move(other.file_ids); + } + return *this; + } + + std::vector file_ids; + + static cpp::result FromJson( + const Json::Value& json) { + CodeInterpreter code_interpreter; + if (json.isMember("file_ids")) { + for (const auto& file_id : json["file_ids"]) { + code_interpreter.file_ids.push_back(file_id.asString()); + } + } + return code_interpreter; + } + + cpp::result ToJson() override { + Json::Value json; + Json::Value file_ids_json{Json::arrayValue}; + for (auto& file_id : file_ids) { + file_ids_json.append(file_id); + } + json["file_ids"] = file_ids_json; + return json; + } +}; + +struct FileSearch : ToolResources { + FileSearch() = default; + + ~FileSearch() override = default; + + FileSearch(const FileSearch&) = delete; + + FileSearch& operator=(const FileSearch&) = delete; + + FileSearch(FileSearch&& other) noexcept + : vector_store_ids{std::move(other.vector_store_ids)} {} + + FileSearch& operator=(FileSearch&& other) noexcept { + if (this != &other) { + vector_store_ids = std::move(other.vector_store_ids); + } + return *this; + } + + std::vector vector_store_ids; + + static cpp::result FromJson( + const Json::Value& json) { + FileSearch file_search; + if (json.isMember("vector_store_ids")) { + for (const auto& vector_store_id : json["vector_store_ids"]) { + file_search.vector_store_ids.push_back(vector_store_id.asString()); + } + } + return file_search; + } + + cpp::result ToJson() override { + Json::Value json; + Json::Value vector_store_ids_json{Json::arrayValue}; + for (auto& vector_store_id : vector_store_ids) { + vector_store_ids_json.append(vector_store_id); + } + json["vector_store_ids"] = vector_store_ids_json; + return json; + } +}; +} // namespace OpenAi diff --git a/engine/controllers/assistants.cc b/engine/controllers/assistants.cc index ba61abf83..530e180a5 100644 --- a/engine/controllers/assistants.cc +++ b/engine/controllers/assistants.cc @@ -1,4 +1,6 @@ #include "assistants.h" +#include "common/api-dto/delete_success_response.h" +#include "common/dto/assistant_create_dto.h" #include "utils/cortex_utils.h" #include "utils/logging_utils.h" @@ -6,7 +8,12 @@ void Assistants::RetrieveAssistant( const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) const { - CTL_INF("RetrieveAssistant: " + assistant_id); + const auto& headers = req->headers(); + auto it = headers.find(kOpenAiAssistantKeyV2); + if (it != headers.end() && it->second == kOpenAiAssistantValueV2) { + return RetrieveAssistantV2(req, std::move(callback), assistant_id); + } + auto res = assistant_service_->RetrieveAssistant(assistant_id); if (res.has_error()) { Json::Value ret; @@ -33,6 +40,78 @@ void Assistants::RetrieveAssistant( } } +void Assistants::RetrieveAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const { + auto res = assistant_service_->RetrieveAssistantV2(assistant_id); + + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto to_json_res = res->ToJson(); + if (to_json_res.has_error()) { + CTL_ERR("Failed to convert assistant to json: " + to_json_res.error()); + Json::Value ret; + ret["message"] = to_json_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + // TODO: namh need to use the text response because it contains model config + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Assistants::CreateAssistantV2( + const HttpRequestPtr& req, + std::function&& callback) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto dto = dto::CreateAssistantDto::FromJson(std::move(*json_body)); + CTL_INF("CreateAssistantV2: " << dto.model); + auto validate_res = dto.Validate(); + if (validate_res.has_error()) { + Json::Value ret; + ret["message"] = validate_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = assistant_service_->CreateAssistantV2(dto); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto to_json_res = res->ToJson(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(to_json_res.value()); + resp->setStatusCode(k200OK); + callback(resp); +} + void Assistants::CreateAssistant( const HttpRequestPtr& req, std::function&& callback, @@ -88,10 +167,55 @@ void Assistants::CreateAssistant( callback(resp); } +void Assistants::ModifyAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto dto = dto::UpdateAssistantDto::FromJson(std::move(*json_body)); + auto validate_res = dto.Validate(); + if (validate_res.has_error()) { + Json::Value ret; + ret["message"] = validate_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = assistant_service_->ModifyAssistantV2(assistant_id, dto); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} + void Assistants::ModifyAssistant( const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) { + const auto& headers = req->headers(); + auto it = headers.find(kOpenAiAssistantKeyV2); + if (it != headers.end() && it->second == kOpenAiAssistantValueV2) { + return ModifyAssistantV2(req, std::move(callback), assistant_id); + } auto json_body = req->getJsonObject(); if (json_body == nullptr) { Json::Value ret; @@ -177,3 +301,27 @@ void Assistants::ListAssistants( response->setStatusCode(k200OK); callback(response); } + +void Assistants::DeleteAssistant( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto res = assistant_service_->DeleteAssistantV2(assistant_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + api_response::DeleteSuccessResponse response; + response.id = assistant_id; + response.object = "assistant.deleted"; + response.deleted = true; + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(response.ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} diff --git a/engine/controllers/assistants.h b/engine/controllers/assistants.h index 4c6902793..30111bb01 100644 --- a/engine/controllers/assistants.h +++ b/engine/controllers/assistants.h @@ -7,14 +7,28 @@ using namespace drogon; class Assistants : public drogon::HttpController { + constexpr static auto kOpenAiAssistantKeyV2 = "openai-beta"; + constexpr static auto kOpenAiAssistantValueV2 = "assistants=v2"; + public: METHOD_LIST_BEGIN + ADD_METHOD_TO( + Assistants::ListAssistants, + "/v1/" + "assistants?limit={limit}&order={order}&after={after}&before={before}", + Get); + + ADD_METHOD_TO(Assistants::DeleteAssistant, "/v1/assistants/{assistant_id}", + Options, Delete); + ADD_METHOD_TO(Assistants::RetrieveAssistant, "/v1/assistants/{assistant_id}", Get); ADD_METHOD_TO(Assistants::CreateAssistant, "/v1/assistants/{assistant_id}", Options, Post); + ADD_METHOD_TO(Assistants::CreateAssistantV2, "/v1/assistants", Options, Post); + ADD_METHOD_TO(Assistants::ModifyAssistant, "/v1/assistants/{assistant_id}", Options, Patch); @@ -34,14 +48,31 @@ class Assistants : public drogon::HttpController { std::function&& callback, const std::string& assistant_id) const; + void RetrieveAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const; + + void DeleteAssistant(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + void CreateAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id); + void CreateAssistantV2( + const HttpRequestPtr& req, + std::function&& callback); + void ModifyAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id); + void ModifyAssistantV2(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + private: std::shared_ptr assistant_service_; }; diff --git a/engine/main.cc b/engine/main.cc index ddf1eefd8..938392bf0 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -15,6 +15,7 @@ #include "controllers/threads.h" #include "database/database.h" #include "migrations/migration_manager.h" +#include "repositories/assistant_fs_repository.h" #include "repositories/file_fs_repository.h" #include "repositories/message_fs_repository.h" #include "repositories/thread_fs_repository.h" @@ -142,9 +143,12 @@ void RunServer(std::optional host, std::optional port, auto file_repo = std::make_shared(data_folder_path); auto msg_repo = std::make_shared(data_folder_path); auto thread_repo = std::make_shared(data_folder_path); + auto assistant_repo = + std::make_shared(data_folder_path); auto file_srv = std::make_shared(file_repo); - auto assistant_srv = std::make_shared(thread_repo); + auto assistant_srv = + std::make_shared(thread_repo, assistant_repo); auto thread_srv = std::make_shared(thread_repo); auto message_srv = std::make_shared(msg_repo); diff --git a/engine/repositories/assistant_fs_repository.cc b/engine/repositories/assistant_fs_repository.cc index b28af690e..f5103f5c0 100644 --- a/engine/repositories/assistant_fs_repository.cc +++ b/engine/repositories/assistant_fs_repository.cc @@ -71,7 +71,7 @@ AssistantFsRepository::ListAssistants(uint8_t limit, const std::string& order, } cpp::result -AssistantFsRepository::RetrieveAssisant(const std::string assistant_id) const { +AssistantFsRepository::RetrieveAssistant(const std::string assistant_id) const { std::shared_lock lock(GrabAssistantMutex(assistant_id)); return LoadAssistant(assistant_id); } diff --git a/engine/repositories/assistant_fs_repository.h b/engine/repositories/assistant_fs_repository.h index c5eaffd6a..f310bd54e 100644 --- a/engine/repositories/assistant_fs_repository.h +++ b/engine/repositories/assistant_fs_repository.h @@ -17,7 +17,7 @@ class AssistantFsRepository : public AssistantRepository { cpp::result CreateAssistant( OpenAi::Assistant& assistant) override; - cpp::result RetrieveAssisant( + cpp::result RetrieveAssistant( const std::string assistant_id) const override; cpp::result ModifyAssistant( @@ -26,7 +26,7 @@ class AssistantFsRepository : public AssistantRepository { cpp::result DeleteAssistant( const std::string& assitant_id) override; - explicit AssistantFsRepository(std::filesystem::path data_folder_path) + explicit AssistantFsRepository(const std::filesystem::path& data_folder_path) : data_folder_path_{data_folder_path} { CTL_INF("Constructing AssistantFsRepository.."); auto path = data_folder_path_ / kAssistantContainerFolderName; diff --git a/engine/repositories/file_fs_repository.h b/engine/repositories/file_fs_repository.h index 974e81fa4..77af60dfc 100644 --- a/engine/repositories/file_fs_repository.h +++ b/engine/repositories/file_fs_repository.h @@ -28,7 +28,7 @@ class FileFsRepository : public FileRepository { cpp::result DeleteFileLocal( const std::string& file_id) override; - explicit FileFsRepository(std::filesystem::path data_folder_path) + explicit FileFsRepository(const std::filesystem::path& data_folder_path) : data_folder_path_{data_folder_path} { CTL_INF("Constructing FileFsRepository.."); auto file_container_path = data_folder_path_ / kFileContainerFolderName; diff --git a/engine/repositories/message_fs_repository.h b/engine/repositories/message_fs_repository.h index 2146778bf..0ca6e89b3 100644 --- a/engine/repositories/message_fs_repository.h +++ b/engine/repositories/message_fs_repository.h @@ -32,7 +32,7 @@ class MessageFsRepository : public MessageRepository { const std::string& thread_id, std::optional> messages) override; - explicit MessageFsRepository(std::filesystem::path data_folder_path) + explicit MessageFsRepository(const std::filesystem::path& data_folder_path) : data_folder_path_{data_folder_path} { CTL_INF("Constructing MessageFsRepository.."); auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; diff --git a/engine/services/assistant_service.cc b/engine/services/assistant_service.cc index fea713eb5..2938787c6 100644 --- a/engine/services/assistant_service.cc +++ b/engine/services/assistant_service.cc @@ -37,17 +37,101 @@ AssistantService::ListAssistants(uint8_t limit, const std::string& order, cpp::result AssistantService::CreateAssistantV2( const dto::CreateAssistantDto& create_dto) { - OpenAi::Assistant assistant; + OpenAi::Assistant assistant; + assistant.model = create_dto.model; + if (create_dto.name) { + assistant.name = *create_dto.name; + } + if (create_dto.description) { + assistant.description = *create_dto.description; + } + if (create_dto.instructions) { + assistant.instructions = *create_dto.instructions; + } + if (create_dto.metadata) { + assistant.metadata = *create_dto.metadata; + } + if (create_dto.temperature) { + assistant.temperature = *create_dto.temperature; + } + if (create_dto.top_p) { + assistant.top_p = *create_dto.top_p; + } + if (create_dto.response_format) { + assistant.response_format = *create_dto.response_format; + } return assistant_repository_->CreateAssistant(assistant); } cpp::result -AssistantService::RetrieveAssistantV2(const std::string& assistant_id) const {} +AssistantService::RetrieveAssistantV2(const std::string& assistant_id) const { + if (assistant_id.empty()) { + return cpp::failure("Assistant ID cannot be empty"); + } + + return assistant_repository_->RetrieveAssistant(assistant_id); +} cpp::result AssistantService::ModifyAssistantV2( const std::string& assistant_id, - const dto::UpdateAssistantDto& update_dto) {} + const dto::UpdateAssistantDto& update_dto) { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + if (!update_dto.Validate()) { + return cpp::fail("Invalid update assistant dto"); + } + + // First retrieve the existing assistant + auto existing_assistant = + assistant_repository_->RetrieveAssistant(assistant_id); + if (existing_assistant.has_error()) { + return cpp::fail(existing_assistant.error()); + } + + OpenAi::Assistant updated_assistant; + + // Update fields if they are present in the DTO + if (update_dto.model) { + updated_assistant.model = *update_dto.model; + } + if (update_dto.name) { + updated_assistant.name = *update_dto.name; + } + if (update_dto.description) { + updated_assistant.description = *update_dto.description; + } + if (update_dto.instructions) { + updated_assistant.instructions = *update_dto.instructions; + } + if (update_dto.metadata) { + updated_assistant.metadata = *update_dto.metadata; + } + if (update_dto.temperature) { + updated_assistant.temperature = *update_dto.temperature; + } + if (update_dto.top_p) { + updated_assistant.top_p = *update_dto.top_p; + } + if (update_dto.response_format) { + updated_assistant.response_format = *update_dto.response_format; + } + + auto res = assistant_repository_->ModifyAssistant(updated_assistant); + if (res.has_error()) { + return cpp::fail(res.error()); + } + + return updated_assistant; +} cpp::result AssistantService::DeleteAssistantV2( - const std::string& assistant_id) {} + const std::string& assistant_id) { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + return assistant_repository_->DeleteAssistant(assistant_id); +} diff --git a/engine/services/assistant_service.h b/engine/services/assistant_service.h index 72bbdcf79..ad31104ff 100644 --- a/engine/services/assistant_service.h +++ b/engine/services/assistant_service.h @@ -37,8 +37,10 @@ class AssistantService { const std::string& assistant_id); explicit AssistantService( - std::shared_ptr thread_repository) - : thread_repository_{thread_repository} {} + std::shared_ptr thread_repository, + std::shared_ptr assistant_repository) + : thread_repository_{thread_repository}, + assistant_repository_{assistant_repository} {} private: std::shared_ptr thread_repository_; diff --git a/engine/services/thread_service.cc b/engine/services/thread_service.cc index 25784c2ee..827c4ea83 100644 --- a/engine/services/thread_service.cc +++ b/engine/services/thread_service.cc @@ -3,7 +3,7 @@ #include "utils/ulid/ulid.hh" cpp::result ThreadService::CreateThread( - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "CreateThread"; @@ -48,7 +48,7 @@ cpp::result ThreadService::RetrieveThread( cpp::result ThreadService::ModifyThread( const std::string& thread_id, - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "ModifyThread " << thread_id; auto retrieve_res = RetrieveThread(thread_id); diff --git a/engine/services/thread_service.h b/engine/services/thread_service.h index 966b0ab01..7011f46f3 100644 --- a/engine/services/thread_service.h +++ b/engine/services/thread_service.h @@ -2,7 +2,6 @@ #include #include "common/repository/thread_repository.h" -#include "common/thread_tool_resources.h" #include "common/variant_map.h" #include "utils/result.hpp" @@ -12,7 +11,7 @@ class ThreadService { : thread_repository_{thread_repository} {} cpp::result CreateThread( - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result, std::string> ListThreads( @@ -24,7 +23,7 @@ class ThreadService { cpp::result ModifyThread( const std::string& thread_id, - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result DeleteThread(