From f94527fd43e576ca096596f94ab2e7005ad82267 Mon Sep 17 00:00:00 2001 From: NamH Date: Fri, 27 Dec 2024 21:49:19 +0700 Subject: [PATCH] feat: add openai assistant (#1826) --- docs/static/openapi/cortex.json | 542 ++++++++++++++++-- engine/common/assistant.h | 271 ++++++++- .../common/assistant_code_interpreter_tool.h | 32 ++ engine/common/assistant_file_search_tool.h | 151 +++++ engine/common/assistant_function_tool.h | 130 +++++ engine/common/assistant_tool.h | 88 +-- engine/common/dto/assistant_create_dto.h | 211 +++++++ engine/common/dto/assistant_update_dto.h | 201 +++++++ engine/common/dto/base_dto.h | 16 + engine/common/message_attachment.h | 15 +- .../common/repository/assistant_repository.h | 25 + engine/common/thread.h | 12 +- engine/common/thread_tool_resources.h | 50 -- engine/common/tool_resources.h | 114 ++++ engine/controllers/assistants.cc | 185 +++++- engine/controllers/assistants.h | 39 ++ engine/main.cc | 6 +- .../repositories/assistant_fs_repository.cc | 214 +++++++ engine/repositories/assistant_fs_repository.h | 59 ++ engine/repositories/file_fs_repository.h | 2 +- engine/repositories/message_fs_repository.h | 2 +- engine/services/assistant_service.cc | 180 ++++++ engine/services/assistant_service.h | 32 +- engine/services/thread_service.cc | 4 +- engine/services/thread_service.h | 5 +- engine/test/components/test_assistant.cc | 194 +++++++ .../test_assistant_tool_code_interpreter.cc | 49 ++ .../test_assistant_tool_file_search.cc | 207 +++++++ .../test_assistant_tool_function.cc | 240 ++++++++ engine/test/components/test_tool_resources.cc | 212 +++++++ 30 files changed, 3289 insertions(+), 199 deletions(-) create mode 100644 engine/common/assistant_code_interpreter_tool.h create mode 100644 engine/common/assistant_file_search_tool.h create mode 100644 engine/common/assistant_function_tool.h create mode 100644 engine/common/dto/assistant_create_dto.h create mode 100644 engine/common/dto/assistant_update_dto.h create mode 100644 
engine/common/dto/base_dto.h create mode 100644 engine/common/repository/assistant_repository.h delete mode 100644 engine/common/thread_tool_resources.h create mode 100644 engine/common/tool_resources.h create mode 100644 engine/repositories/assistant_fs_repository.cc create mode 100644 engine/repositories/assistant_fs_repository.h create mode 100644 engine/test/components/test_assistant.cc create mode 100644 engine/test/components/test_assistant_tool_code_interpreter.cc create mode 100644 engine/test/components/test_assistant_tool_file_search.cc create mode 100644 engine/test/components/test_assistant_tool_function.cc create mode 100644 engine/test/components/test_tool_resources.cc diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 479e300ce..d006f0f2d 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -5,77 +5,470 @@ "post": { "operationId": "AssistantsController_create", "summary": "Create assistant", - "description": "Creates a new assistant.", - "parameters": [], + "description": "Creates a new assistant with the specified configuration.", "requestBody": { "required": true, "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateAssistantDto" + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The model identifier to use for the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." + }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant. 
Maximum of 128 tools.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs for the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." + }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "required": ["model"] } } } }, "responses": { - "201": { - "description": "The assistant has been successfully created." + "200": { + "description": "Ok", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." + }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." + }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." 
+ }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs that can be attached to the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." + }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] + } + } + } } }, "tags": ["Assistants"] }, - "get": { - "operationId": "AssistantsController_findAll", - "summary": "List assistants", - "description": "Returns a list of assistants.", + "patch": { + "operationId": "AssistantsController_update", + "summary": "Update assistant", + "description": "Updates an assistant. Requires at least one modifiable field.", "parameters": [ { - "name": "limit", - "required": false, - "in": "query", - "description": "A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.", - "schema": { - "type": "number" - } - }, - { - "name": "order", - "required": false, - "in": "query", - "description": "Sort order by the created_at timestamp of the objects. 
asc for ascending order and desc for descending order.", - "schema": { - "type": "string" - } - }, - { - "name": "after", - "required": false, - "in": "query", - "description": "A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.", + "name": "id", + "required": true, + "in": "path", + "description": "The unique identifier of the assistant.", "schema": { "type": "string" } }, { - "name": "before", - "required": false, - "in": "query", - "description": "A cursor for use in pagination. before is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.", + "name": "OpenAI-Beta", + "required": true, + "in": "header", + "description": "Beta feature header.", "schema": { - "type": "string" + "type": "string", + "enum": ["assistants=v2"] } } ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The model identifier to use for the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." + }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant. 
Maximum of 128 tools.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs for the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." + }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "minProperties": 1 + } + } + } + }, "responses": { "200": { "description": "Ok", "content": { "application/json": { "schema": { - "type": "array", - "items": { - "$ref": "#/components/schemas/AssistantEntity" - } + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." + }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." + }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." 
+ }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs that can be attached to the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." + }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] + } + } + } + } + }, + "tags": ["Assistants"] + }, + "get": { + "operationId": "AssistantsController_list", + "summary": "List assistants", + "description": "Returns a list of assistants.", + "responses": { + "200": { + "description": "Ok", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "object": { + "type": "string", + "enum": ["list"], + "description": "The object type, which is always 'list' for a list response." + }, + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." + }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." 
+ }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs that can be attached to the assistant.", + "additionalProperties": true + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] + } + } + }, + "required": ["object", "data"] } } } @@ -88,7 +481,7 @@ "get": { "operationId": "AssistantsController_findOne", "summary": "Get assistant", - "description": "Retrieves a specific assistant defined by an assistant's `id`.", + "description": "Retrieves a specific assistant by ID.", "parameters": [ { "name": "id", @@ -98,6 +491,16 @@ "schema": { "type": "string" } + }, + { + "name": "OpenAI-Beta", + "required": true, + "in": "header", + "description": "Beta feature header.", + "schema": { + "type": "string", + "enum": ["assistants=v2"] + } } ], "responses": { @@ -106,7 +509,38 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/AssistantEntity" + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." + }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." + }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." 
+ }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs attached to the assistant.", + "additionalProperties": true + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] } } } @@ -117,7 +551,7 @@ "delete": { "operationId": "AssistantsController_remove", "summary": "Delete assistant", - "description": "Deletes a specific assistant defined by an assistant's `id`.", + "description": "Deletes a specific assistant by ID.", "parameters": [ { "name": "id", @@ -131,11 +565,28 @@ ], "responses": { "200": { - "description": "The assistant has been successfully deleted.", + "description": "Ok", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/DeleteAssistantResponseDto" + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the deleted assistant." + }, + "object": { + "type": "string", + "enum": ["assistant.deleted"], + "description": "The object type for a deleted assistant." + }, + "deleted": { + "type": "boolean", + "enum": [true], + "description": "Indicates the assistant was successfully deleted." 
+ } + }, + "required": ["id", "object", "deleted"] } } } @@ -3456,6 +3907,7 @@ "Files", "Hardware", "Events", + "Assistants", "Threads", "Messages", "Pulling Models", diff --git a/engine/common/assistant.h b/engine/common/assistant.h index e49147e9e..6210a0c2c 100644 --- a/engine/common/assistant.h +++ b/engine/common/assistant.h @@ -1,9 +1,13 @@ #pragma once #include +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" #include "common/assistant_tool.h" -#include "common/thread_tool_resources.h" +#include "common/tool_resources.h" #include "common/variant_map.h" +#include "utils/logging_utils.h" #include "utils/result.hpp" namespace OpenAi { @@ -75,7 +79,49 @@ struct JanAssistant : JsonSerializable { } }; -struct Assistant { +struct Assistant : JsonSerializable { + Assistant() = default; + + ~Assistant() = default; + + Assistant(const Assistant&) = delete; + + Assistant& operator=(const Assistant&) = delete; + + Assistant(Assistant&& other) noexcept + : id{std::move(other.id)}, + object{std::move(other.object)}, + created_at{other.created_at}, + name{std::move(other.name)}, + description{std::move(other.description)}, + model(std::move(other.model)), + instructions(std::move(other.instructions)), + tools(std::move(other.tools)), + tool_resources(std::move(other.tool_resources)), + metadata(std::move(other.metadata)), + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} + + Assistant& operator=(Assistant&& other) noexcept { + if (this != &other) { + id = std::move(other.id); + object = std::move(other.object); + created_at = other.created_at; + name = std::move(other.name); + description = std::move(other.description); + model = std::move(other.model); + instructions = std::move(other.instructions); + tools = std::move(other.tools); + tool_resources = std::move(other.tool_resources); + metadata 
= std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + response_format = std::move(other.response_format); + } + return *this; + } + /** * The identifier, which can be referenced in API endpoints. */ @@ -126,8 +172,7 @@ struct Assistant { * requires a list of file IDs, while the file_search tool requires a list * of vector store IDs. */ - std::optional> - tool_resources; + std::unique_ptr tool_resources; /** * Set of 16 key-value pairs that can be attached to an object. This can be @@ -153,5 +198,223 @@ struct Assistant { * We generally recommend altering this or temperature but not both. */ std::optional top_p; + + std::variant response_format; + + cpp::result ToJson() override { + try { + Json::Value root; + + root["id"] = std::move(id); + root["object"] = "assistant"; + root["created_at"] = created_at; + if (name.has_value()) { + root["name"] = name.value(); + } + if (description.has_value()) { + root["description"] = description.value(); + } + root["model"] = model; + if (instructions.has_value()) { + root["instructions"] = instructions.value(); + } + + Json::Value tools_jarr{Json::arrayValue}; + for (auto& tool_ptr : tools) { + if (auto it = tool_ptr->ToJson(); it.has_value()) { + tools_jarr.append(it.value()); + } else { + CTL_WRN("Failed to convert content to json: " + it.error()); + } + } + root["tools"] = tools_jarr; + if (tool_resources) { + Json::Value tool_resources_json{Json::objectValue}; + + if (auto* code_interpreter = + dynamic_cast(tool_resources.get())) { + auto result = code_interpreter->ToJson(); + if (result.has_value()) { + tool_resources_json["code_interpreter"] = result.value(); + } else { + CTL_WRN("Failed to convert code_interpreter to json: " + + result.error()); + } + } else if (auto* file_search = dynamic_cast( + tool_resources.get())) { + auto result = file_search->ToJson(); + if (result.has_value()) { + tool_resources_json["file_search"] = result.value(); + } else { + 
CTL_WRN("Failed to convert file_search to json: " + result.error()); + } + } + + // Only add tool_resources to root if we successfully serialized some resources + if (!tool_resources_json.empty()) { + root["tool_resources"] = tool_resources_json; + } + } + Json::Value metadata_json{Json::objectValue}; + for (const auto& [key, value] : metadata) { + if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else { + metadata_json[key] = std::get(value); + } + } + root["metadata"] = metadata_json; + + if (temperature.has_value()) { + root["temperature"] = temperature.value(); + } + if (top_p.has_value()) { + root["top_p"] = top_p.value(); + } + return root; + } catch (const std::exception& e) { + return cpp::fail("ToJson failed: " + std::string(e.what())); + } + } + + static cpp::result FromJson(Json::Value&& json) { + try { + Assistant assistant; + + // Parse required fields + if (!json.isMember("id") || !json["id"].isString()) { + return cpp::fail("Missing or invalid 'id' field"); + } + assistant.id = json["id"].asString(); + + if (!json.isMember("object") || !json["object"].isString() || + json["object"].asString() != "assistant") { + return cpp::fail("Missing or invalid 'object' field"); + } + + if (!json.isMember("created_at") || !json["created_at"].isUInt64()) { + return cpp::fail("Missing or invalid 'created_at' field"); + } + assistant.created_at = json["created_at"].asUInt64(); + + if (!json.isMember("model") || !json["model"].isString()) { + return cpp::fail("Missing or invalid 'model' field"); + } + assistant.model = json["model"].asString(); + + // Parse optional fields + if (json.isMember("name") && json["name"].isString()) { + assistant.name = json["name"].asString(); + } + + if (json.isMember("description") && json["description"].isString()) { + 
assistant.description = json["description"].asString(); + } + + if (json.isMember("instructions") && json["instructions"].isString()) { + assistant.instructions = json["instructions"].asString(); + } + + // Parse tools array + if (json.isMember("tools") && json["tools"].isArray()) { + auto tools_array = json["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = AssistantCodeInterpreterTool::FromJson(); + if (result.has_value()) { + assistant.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + + result.error()); + } + } else if (tool_type == "function") { + auto result = AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + } + + if (json.isMember("tool_resources") && + json["tool_resources"].isObject()) { + const auto& tool_resources_json = json["tool_resources"]; + + // Parse code interpreter resources + if (tool_resources_json.isMember("code_interpreter")) { + auto result = OpenAi::CodeInterpreter::FromJson( + tool_resources_json["code_interpreter"]); + if (result.has_value()) { + assistant.tool_resources = + std::make_unique( + std::move(result.value())); + } else { + CTL_WRN("Failed to parse code_interpreter resources: " + + 
result.error()); + } + } + + // Parse file search resources + if (tool_resources_json.isMember("file_search")) { + auto result = + OpenAi::FileSearch::FromJson(tool_resources_json["file_search"]); + if (result.has_value()) { + assistant.tool_resources = + std::make_unique(std::move(result.value())); + } else { + CTL_WRN("Failed to parse file_search resources: " + result.error()); + } + } + } + + // Parse metadata + if (json.isMember("metadata") && json["metadata"].isObject()) { + auto res = Cortex::ConvertJsonValueToMap(json["metadata"]); + if (res.has_value()) { + assistant.metadata = res.value(); + } else { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } + } + + if (json.isMember("temperature") && json["temperature"].isDouble()) { + assistant.temperature = json["temperature"].asFloat(); + } + + if (json.isMember("top_p") && json["top_p"].isDouble()) { + assistant.top_p = json["top_p"].asFloat(); + } + + return assistant; + } catch (const std::exception& e) { + return cpp::fail("FromJson failed: " + std::string(e.what())); + } + } }; } // namespace OpenAi diff --git a/engine/common/assistant_code_interpreter_tool.h b/engine/common/assistant_code_interpreter_tool.h new file mode 100644 index 000000000..43bfac47c --- /dev/null +++ b/engine/common/assistant_code_interpreter_tool.h @@ -0,0 +1,32 @@ +#pragma once + +#include "common/assistant_tool.h" + +namespace OpenAi { +struct AssistantCodeInterpreterTool : public AssistantTool { + AssistantCodeInterpreterTool() : AssistantTool("code_interpreter") {} + + AssistantCodeInterpreterTool(const AssistantCodeInterpreterTool&) = delete; + + AssistantCodeInterpreterTool& operator=(const AssistantCodeInterpreterTool&) = + delete; + + AssistantCodeInterpreterTool(AssistantCodeInterpreterTool&&) = default; + + AssistantCodeInterpreterTool& operator=(AssistantCodeInterpreterTool&&) = + default; + + ~AssistantCodeInterpreterTool() = default; + + static cpp::result FromJson() { + AssistantCodeInterpreterTool 
tool; + return std::move(tool); + } + + cpp::result ToJson() override { + Json::Value json; + json["type"] = type; + return json; + } +}; +} // namespace OpenAi diff --git a/engine/common/assistant_file_search_tool.h b/engine/common/assistant_file_search_tool.h new file mode 100644 index 000000000..2abaa7f6e --- /dev/null +++ b/engine/common/assistant_file_search_tool.h @@ -0,0 +1,151 @@ +#pragma once + +#include "common/assistant_tool.h" +#include "common/json_serializable.h" + +namespace OpenAi { +struct FileSearchRankingOption : public JsonSerializable { + /** + * The ranker to use for the file search. If not specified will use the auto ranker. + */ + std::string ranker; + + /** + * The score threshold for the file search. All values must be a + * floating point number between 0 and 1. + */ + float score_threshold; + + FileSearchRankingOption(float score_threshold, + const std::string& ranker = "auto") + : ranker{ranker}, score_threshold{score_threshold} {} + + FileSearchRankingOption(const FileSearchRankingOption&) = delete; + + FileSearchRankingOption& operator=(const FileSearchRankingOption&) = delete; + + FileSearchRankingOption(FileSearchRankingOption&&) = default; + + FileSearchRankingOption& operator=(FileSearchRankingOption&&) = default; + + ~FileSearchRankingOption() = default; + + static cpp::result FromJson( + const Json::Value& json) { + if (!json.isMember("score_threshold")) { + return cpp::fail("score_threshold must be provided"); + } + + FileSearchRankingOption option{ + json["score_threshold"].asFloat(), + std::move(json.get("ranker", "auto").asString())}; + return option; + } + + cpp::result ToJson() override { + Json::Value json; + json["ranker"] = ranker; + json["score_threshold"] = score_threshold; + return json; + } +}; + +/** + * Overrides for the file search tool. + */ +struct AssistantFileSearch : public JsonSerializable { + /** + * The maximum number of results the file search tool should output. 
+ * The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. + * This number should be between 1 and 50 inclusive. + * + * Note that the file search tool may output fewer than max_num_results results. + * See the file search tool documentation for more information. + */ + int max_num_results; + + /** + * The ranking options for the file search. If not specified, + * the file search tool will use the auto ranker and a score_threshold of 0. + * + * See the file search tool documentation for more information. + */ + FileSearchRankingOption ranking_options; + + AssistantFileSearch(int max_num_results, + FileSearchRankingOption&& ranking_options) + : max_num_results{max_num_results}, + ranking_options{std::move(ranking_options)} {} + + AssistantFileSearch(const AssistantFileSearch&) = delete; + + AssistantFileSearch& operator=(const AssistantFileSearch&) = delete; + + AssistantFileSearch(AssistantFileSearch&&) = default; + + AssistantFileSearch& operator=(AssistantFileSearch&&) = default; + + ~AssistantFileSearch() = default; + + static cpp::result FromJson( + const Json::Value& json) { + try { + AssistantFileSearch search{ + json["max_num_results"].asInt(), + FileSearchRankingOption::FromJson(json["ranking_options"]).value()}; + return search; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + Json::Value root; + root["max_num_results"] = max_num_results; + root["ranking_options"] = ranking_options.ToJson().value(); + return root; + } +}; + +struct AssistantFileSearchTool : public AssistantTool { + AssistantFileSearch file_search; + + AssistantFileSearchTool(AssistantFileSearch& file_search) + : AssistantTool("file_search"), file_search{std::move(file_search)} {} + + AssistantFileSearchTool(const AssistantFileSearchTool&) = delete; + + AssistantFileSearchTool& operator=(const AssistantFileSearchTool&) = delete; + + AssistantFileSearchTool(AssistantFileSearchTool&&) = 
default; + + AssistantFileSearchTool& operator=(AssistantFileSearchTool&&) = default; + + ~AssistantFileSearchTool() = default; + + static cpp::result FromJson( + const Json::Value& json) { + try { + AssistantFileSearch search{json["file_search"]["max_num_results"].asInt(), + FileSearchRankingOption::FromJson( + json["file_search"]["ranking_options"]) + .value()}; + AssistantFileSearchTool tool{search}; + return tool; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value root; + root["type"] = type; + root["file_search"] = file_search.ToJson().value(); + return root; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +}; // namespace OpenAi diff --git a/engine/common/assistant_function_tool.h b/engine/common/assistant_function_tool.h new file mode 100644 index 000000000..7998cb8ff --- /dev/null +++ b/engine/common/assistant_function_tool.h @@ -0,0 +1,130 @@ +#pragma once + +#include +#include "common/assistant_tool.h" +#include "common/json_serializable.h" + +namespace OpenAi { +struct AssistantFunction : public JsonSerializable { + AssistantFunction(const std::string& description, const std::string& name, + const Json::Value& parameters, + const std::optional& strict) + : description{std::move(description)}, + name{std::move(name)}, + parameters{std::move(parameters)}, + strict{strict} {} + + AssistantFunction(const AssistantFunction&) = delete; + + AssistantFunction& operator=(const AssistantFunction&) = delete; + + AssistantFunction(AssistantFunction&&) = default; + + AssistantFunction& operator=(AssistantFunction&&) = default; + + ~AssistantFunction() = default; + + /** + * A description of what the function does, used by the model to choose + * when and how to call the function. + */ + std::string description; + + /** + * The name of the function to be called. 
Must be a-z, A-Z, 0-9, or contain + * underscores and dashes, with a maximum length of 64. + */ + std::string name; + + /** + * The parameters the functions accepts, described as a JSON Schema object. + * See the guide for examples, and the JSON Schema reference for documentation + * about the format. + * + * Omitting parameters defines a function with an empty parameter list. + */ + Json::Value parameters; + + /** + * Whether to enable strict schema adherence when generating the function call. + * If set to true, the model will follow the exact schema defined in the parameters + * field. Only a subset of JSON Schema is supported when strict is true. + * + * Learn more about Structured Outputs in the function calling guide. + */ + std::optional strict; + + static cpp::result FromJson( + const Json::Value& json) { + if (json.empty()) { + return cpp::fail("Function json can't be empty"); + } + + if (!json.isMember("name") || json.get("name", "").asString().empty()) { + return cpp::fail("Function name can't be empty"); + } + + if (!json.isMember("description")) { + return cpp::fail("Function description is mandatory"); + } + + if (!json.isMember("parameters")) { + return cpp::fail("Function parameters are mandatory"); + } + + std::optional is_strict = std::nullopt; + if (json.isMember("strict")) { + is_strict = json["strict"].asBool(); + } + AssistantFunction function{json["description"].asString(), + json["name"].asString(), json["parameters"], + is_strict}; + function.parameters = json["parameters"]; + return function; + } + + cpp::result ToJson() override { + Json::Value json; + json["description"] = description; + json["name"] = name; + if (strict.has_value()) { + json["strict"] = *strict; + } + json["parameters"] = parameters; + return json; + } +}; + +struct AssistantFunctionTool : public AssistantTool { + AssistantFunctionTool(AssistantFunction& function) + : AssistantTool("function"), function{std::move(function)} {} + + AssistantFunctionTool(const 
AssistantFunctionTool&) = delete; + + AssistantFunctionTool& operator=(const AssistantFunctionTool&) = delete; + + AssistantFunctionTool(AssistantFunctionTool&&) = default; + + AssistantFunctionTool& operator=(AssistantFunctionTool&&) = default; + + ~AssistantFunctionTool() = default; + + AssistantFunction function; + + static cpp::result FromJson( + const Json::Value& json) { + auto function_res = AssistantFunction::FromJson(json["function"]); + if (function_res.has_error()) { + return cpp::fail("Failed to parse function: " + function_res.error()); + } + return AssistantFunctionTool{function_res.value()}; + } + + cpp::result ToJson() override { + Json::Value root; + root["type"] = type; + root["function"] = function.ToJson().value(); + return root; + } +}; +}; // namespace OpenAi diff --git a/engine/common/assistant_tool.h b/engine/common/assistant_tool.h index 622721708..d02392392 100644 --- a/engine/common/assistant_tool.h +++ b/engine/common/assistant_tool.h @@ -1,91 +1,27 @@ #pragma once -#include #include +#include "common/json_serializable.h" namespace OpenAi { -struct AssistantTool { +struct AssistantTool : public JsonSerializable { std::string type; AssistantTool(const std::string& type) : type{type} {} - virtual ~AssistantTool() = default; -}; - -struct AssistantCodeInterpreterTool : public AssistantTool { - AssistantCodeInterpreterTool() : AssistantTool{"code_interpreter"} {} - - ~AssistantCodeInterpreterTool() = default; -}; - -struct AssistantFileSearchTool : public AssistantTool { - AssistantFileSearchTool() : AssistantTool("file_search") {} - - ~AssistantFileSearchTool() = default; + AssistantTool(const AssistantTool&) = delete; - /** - * The ranking options for the file search. If not specified, - * the file search tool will use the auto ranker and a score_threshold of 0. - * - * See the file search tool documentation for more information. - */ - struct RankingOption { - /** - * The ranker to use for the file search. 
If not specified will use the auto ranker. - */ - std::string ranker; + AssistantTool& operator=(const AssistantTool&) = delete; - /** - * The score threshold for the file search. All values must be a - * floating point number between 0 and 1. - */ - float score_threshold; - }; + AssistantTool(AssistantTool&& other) noexcept : type{std::move(other.type)} {} - /** - * Overrides for the file search tool. - */ - struct FileSearch { - /** - * The maximum number of results the file search tool should output. - * The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. - * This number should be between 1 and 50 inclusive. - * - * Note that the file search tool may output fewer than max_num_results results. - * See the file search tool documentation for more information. - */ - int max_num_result; - }; -}; - -struct AssistantFunctionTool : public AssistantTool { - AssistantFunctionTool() : AssistantTool("function") {} - - ~AssistantFunctionTool() = default; - - struct Function { - /** - * A description of what the function does, used by the model to choose - * when and how to call the function. - */ - std::string description; + AssistantTool& operator=(AssistantTool&& other) noexcept { + if (this != &other) { + type = std::move(other.type); + } + return *this; + } - /** - * The name of the function to be called. Must be a-z, A-Z, 0-9, or contain - * underscores and dashes, with a maximum length of 64. - */ - std::string name; - - // TODO: namh handle parameters - - /** - * Whether to enable strict schema adherence when generating the function call. - * If set to true, the model will follow the exact schema defined in the parameters - * field. Only a subset of JSON Schema is supported when strict is true. - * - * Learn more about Structured Outputs in the function calling guide. 
- */ - std::optional strict; - }; + virtual ~AssistantTool() = default; }; } // namespace OpenAi diff --git a/engine/common/dto/assistant_create_dto.h b/engine/common/dto/assistant_create_dto.h new file mode 100644 index 000000000..19d79b833 --- /dev/null +++ b/engine/common/dto/assistant_create_dto.h @@ -0,0 +1,211 @@ +#pragma once + +#include +#include +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" +#include "common/assistant_tool.h" +#include "common/dto/base_dto.h" +#include "common/tool_resources.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" + +namespace dto { +struct CreateAssistantDto : public BaseDto { + CreateAssistantDto() = default; + + ~CreateAssistantDto() = default; + + CreateAssistantDto(const CreateAssistantDto&) = delete; + + CreateAssistantDto& operator=(const CreateAssistantDto&) = delete; + + CreateAssistantDto(CreateAssistantDto&& other) noexcept + : model{std::move(other.model)}, + name{std::move(other.name)}, + description{std::move(other.description)}, + instructions{std::move(other.instructions)}, + tools{std::move(other.tools)}, + tool_resources{std::move(other.tool_resources)}, + metadata{std::move(other.metadata)}, + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} + /** Move-assigns every member from other; self-assignment is a no-op. */ + CreateAssistantDto& operator=(CreateAssistantDto&& other) noexcept { + if (this != &other) { + model = std::move(other.model); + name = std::move(other.name); + description = std::move(other.description); + instructions = std::move(other.instructions); + tools = std::move(other.tools); + tool_resources = std::move(other.tool_resources); + metadata = std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + response_format = std::move(other.response_format); + } + return *this; + } + + std::string model; + + std::optional 
name; + + std::optional description; + + std::optional instructions; + + /** + * A list of tool enabled on the assistant. There can be a maximum of + * 128 tools per assistant. Tools can be of types code_interpreter, + * file_search, or function. + */ + std::vector> tools; + + /** + * A set of resources that are used by the assistant's tools. The resources + * are specific to the type of tool. For example, the code_interpreter tool + * requires a list of file IDs, while the file_search tool requires a list + * of vector store IDs. + */ + std::unique_ptr tool_resources; + + std::optional metadata; + + std::optional temperature; + + std::optional top_p; + + std::optional> response_format; + + cpp::result Validate() const override { + if (model.empty()) { + return cpp::fail("Model is mandatory"); + } + + if (response_format.has_value()) { + const auto& variant_value = response_format.value(); + if (std::holds_alternative(variant_value)) { + if (std::get(variant_value) != "auto") { + return cpp::fail("Invalid response_format"); + } + } + } + + return {}; + } + + static CreateAssistantDto FromJson(Json::Value&& root) { + if (root.empty()) { + throw std::runtime_error("Json passed in FromJson can't be empty"); + } + CreateAssistantDto dto; + dto.model = std::move(root["model"].asString()); + if (root.isMember("name")) { + dto.name = std::move(root["name"].asString()); + } + if (root.isMember("description")) { + dto.description = std::move(root["description"].asString()); + } + if (root.isMember("instructions")) { + dto.instructions = std::move(root["instructions"].asString()); + } + if (root["metadata"].isObject() && !root["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + dto.metadata = std::move(res.value()); + } + } + if (root.isMember("temperature")) { + dto.temperature = root["temperature"].asFloat(); + } + if 
(root.isMember("top_p")) { + dto.top_p = root["top_p"].asFloat(); + } + if (root.isMember("tools") && root["tools"].isArray()) { + auto tools_array = root["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = OpenAi::AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = OpenAi::AssistantCodeInterpreterTool::FromJson(); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + result.error()); + } + } else if (tool_type == "function") { + auto result = OpenAi::AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + } + if (root.isMember("tool_resources") && root["tool_resources"].isObject()) { + const auto& tool_resources_json = root["tool_resources"]; + + // Parse code interpreter resources + if (tool_resources_json.isMember("code_interpreter")) { + auto result = OpenAi::CodeInterpreter::FromJson( + tool_resources_json["code_interpreter"]); + if (result.has_value()) { + dto.tool_resources = std::make_unique( + std::move(result.value())); + } else { + CTL_WRN("Failed to parse code_interpreter resources: " + + result.error()); + } + } + + // Parse file search resources + if (tool_resources_json.isMember("file_search")) { + auto result = + 
OpenAi::FileSearch::FromJson(tool_resources_json["file_search"]); + if (result.has_value()) { + dto.tool_resources = + std::make_unique(std::move(result.value())); + } else { + CTL_WRN("Failed to parse file_search resources: " + result.error()); + } + } + } + if (root.isMember("response_format")) { + const auto& response_format = root["response_format"]; + if (response_format.isString()) { + dto.response_format = response_format.asString(); + } else if (response_format.isObject()) { + dto.response_format = response_format; + } else { + throw std::runtime_error( + "response_format must be either a string or an object"); + } + } + return dto; + } +}; +} // namespace dto diff --git a/engine/common/dto/assistant_update_dto.h b/engine/common/dto/assistant_update_dto.h new file mode 100644 index 000000000..01e5844d7 --- /dev/null +++ b/engine/common/dto/assistant_update_dto.h @@ -0,0 +1,201 @@ +#pragma once + +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" +#include "common/dto/base_dto.h" +#include "common/tool_resources.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" + +namespace dto { +struct UpdateAssistantDto : public BaseDto { + UpdateAssistantDto() = default; + + ~UpdateAssistantDto() = default; + + UpdateAssistantDto(const UpdateAssistantDto&) = delete; + + UpdateAssistantDto& operator=(const UpdateAssistantDto&) = delete; + + UpdateAssistantDto(UpdateAssistantDto&& other) noexcept + : model{std::move(other.model)}, + name{std::move(other.name)}, + description{std::move(other.description)}, + instructions{std::move(other.instructions)}, + tools{std::move(other.tools)}, + tool_resources{std::move(other.tool_resources)}, + metadata{std::move(other.metadata)}, + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} + + UpdateAssistantDto& operator=(UpdateAssistantDto&& 
other) noexcept { + if (this != &other) { // self-assignment: nothing to move + model = std::move(other.model); + name = std::move(other.name); + description = std::move(other.description); + instructions = std::move(other.instructions); + tools = std::move(other.tools); + tool_resources = std::move(other.tool_resources); + metadata = std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + response_format = std::move(other.response_format); + } + return *this; + } + std::optional model; + + std::optional name; + + std::optional description; + + std::optional instructions; + + /** + * A list of tools enabled on the assistant. There can be a maximum of + * 128 tools per assistant. Tools can be of types code_interpreter, + * file_search, or function. + */ + std::vector> tools; + + /** + * A set of resources that are used by the assistant's tools. The resources + * are specific to the type of tool. For example, the code_interpreter tool + * requires a list of file IDs, while the file_search tool requires a list + * of vector store IDs. 
+ */ + std::unique_ptr tool_resources; + + std::optional metadata; + + std::optional temperature; + + std::optional top_p; + + std::optional> response_format; + /** Rejects an update that changes nothing: no field, no tool and no tool resources set. */ + cpp::result Validate() const override { + if (!model.has_value() && !name.has_value() && !description.has_value() && + !instructions.has_value() && !metadata.has_value() && + !temperature.has_value() && !top_p.has_value() && + !response_format.has_value() && tools.empty() && !tool_resources) { + return cpp::fail("At least one field must be provided"); + } + + return {}; + } + /** Builds the DTO from a request body; keys absent from the JSON stay unset. Throws std::runtime_error on an empty body or a response_format that is neither a string nor an object. */ + static UpdateAssistantDto FromJson(Json::Value&& root) { + if (root.empty()) { + throw std::runtime_error("Json passed in FromJson can't be empty"); + } + UpdateAssistantDto dto; + if (root.isMember("model")) { dto.model = root["model"].asString(); } // leave unset when absent so Validate() can detect an empty update + if (root.isMember("name")) { + dto.name = std::move(root["name"].asString()); + } + if (root.isMember("description")) { + dto.description = std::move(root["description"].asString()); + } + if (root.isMember("instructions")) { + dto.instructions = std::move(root["instructions"].asString()); + } + if (root["metadata"].isObject() && !root["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + dto.metadata = std::move(res.value()); + } + } + if (root.isMember("temperature")) { + dto.temperature = root["temperature"].asFloat(); + } + if (root.isMember("top_p")) { + dto.top_p = root["top_p"].asFloat(); + } + if (root.isMember("tools") && root["tools"].isArray()) { + auto tools_array = root["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = OpenAi::AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); 
+ } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = OpenAi::AssistantCodeInterpreterTool::FromJson(); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + result.error()); + } + } else if (tool_type == "function") { + auto result = OpenAi::AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + } + if (root.isMember("tool_resources") && root["tool_resources"].isObject()) { + const auto& tool_resources_json = root["tool_resources"]; + + // Parse code interpreter resources + if (tool_resources_json.isMember("code_interpreter")) { + auto result = OpenAi::CodeInterpreter::FromJson( + tool_resources_json["code_interpreter"]); + if (result.has_value()) { + dto.tool_resources = std::make_unique( + std::move(result.value())); + } else { + CTL_WRN("Failed to parse code_interpreter resources: " + + result.error()); + } + } + + // Parse file search resources + if (tool_resources_json.isMember("file_search")) { + auto result = + OpenAi::FileSearch::FromJson(tool_resources_json["file_search"]); + if (result.has_value()) { + dto.tool_resources = + std::make_unique(std::move(result.value())); + } else { + CTL_WRN("Failed to parse file_search resources: " + result.error()); + } + } + } + if (root.isMember("response_format")) { + const auto& response_format = root["response_format"]; + if (response_format.isString()) { + dto.response_format = response_format.asString(); + } else if (response_format.isObject()) { + dto.response_format = response_format; + } else { + throw std::runtime_error( + "response_format must be either a string or an object"); + } + } 
+ return dto; + }; +}; +} // namespace dto diff --git a/engine/common/dto/base_dto.h b/engine/common/dto/base_dto.h new file mode 100644 index 000000000..ed7460aa3 --- /dev/null +++ b/engine/common/dto/base_dto.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include "utils/result.hpp" + +namespace dto { +template +struct BaseDto { + virtual ~BaseDto() = default; + + /** + * Validate itself. + */ + virtual cpp::result Validate() const = 0; +}; +} // namespace dto diff --git a/engine/common/message_attachment.h b/engine/common/message_attachment.h index 767ec9bea..6a0fb02e9 100644 --- a/engine/common/message_attachment.h +++ b/engine/common/message_attachment.h @@ -4,22 +4,27 @@ #include "common/json_serializable.h" namespace OpenAi { - // The tools to add this file to. struct Tool { std::string type; Tool(const std::string& type) : type{type} {} + + virtual ~Tool() = default; }; // The type of tool being defined: code_interpreter -struct CodeInterpreter : Tool { - CodeInterpreter() : Tool{"code_interpreter"} {} +struct MessageCodeInterpreter : Tool { + MessageCodeInterpreter() : Tool{"code_interpreter"} {} + + ~MessageCodeInterpreter() = default; }; // The type of tool being defined: file_search -struct FileSearch : Tool { - FileSearch() : Tool{"file_search"} {} +struct MessageFileSearch : Tool { + MessageFileSearch() : Tool{"file_search"} {} + + ~MessageFileSearch() = default; }; // A list of files attached to the message, and the tools they were added to. 
diff --git a/engine/common/repository/assistant_repository.h b/engine/common/repository/assistant_repository.h new file mode 100644 index 000000000..d0ff1908d --- /dev/null +++ b/engine/common/repository/assistant_repository.h @@ -0,0 +1,25 @@ +#pragma once + +#include "common/assistant.h" +#include "utils/result.hpp" + +class AssistantRepository { + public: + virtual cpp::result, std::string> + ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, const std::string& before) const = 0; + + virtual cpp::result CreateAssistant( + OpenAi::Assistant& assistant) = 0; + + virtual cpp::result RetrieveAssistant( + const std::string assistant_id) const = 0; + + virtual cpp::result ModifyAssistant( + OpenAi::Assistant& assistant) = 0; + + virtual cpp::result DeleteAssistant( + const std::string& assitant_id) = 0; + + virtual ~AssistantRepository() = default; +}; diff --git a/engine/common/thread.h b/engine/common/thread.h index 2bd5d866b..dc57ba32d 100644 --- a/engine/common/thread.h +++ b/engine/common/thread.h @@ -4,7 +4,7 @@ #include #include #include "common/assistant.h" -#include "common/thread_tool_resources.h" +#include "common/tool_resources.h" #include "common/variant_map.h" #include "json_serializable.h" #include "utils/logging_utils.h" @@ -36,7 +36,7 @@ struct Thread : JsonSerializable { * of tool. For example, the code_interpreter tool requires a list of * file IDs, while the file_search tool requires a list of vector store IDs. */ - std::unique_ptr tool_resources; + std::unique_ptr tool_resources; /** * Set of 16 key-value pairs that can be attached to an object. 
@@ -65,7 +65,7 @@ struct Thread : JsonSerializable { const auto& tool_json = json["tool_resources"]; if (tool_json.isMember("code_interpreter")) { - auto code_interpreter = std::make_unique(); + auto code_interpreter = std::make_unique(); const auto& file_ids = tool_json["code_interpreter"]["file_ids"]; if (file_ids.isArray()) { for (const auto& file_id : file_ids) { @@ -74,7 +74,7 @@ struct Thread : JsonSerializable { } thread.tool_resources = std::move(code_interpreter); } else if (tool_json.isMember("file_search")) { - auto file_search = std::make_unique(); + auto file_search = std::make_unique(); const auto& store_ids = tool_json["file_search"]["vector_store_ids"]; if (store_ids.isArray()) { for (const auto& store_id : store_ids) { @@ -148,10 +148,10 @@ struct Thread : JsonSerializable { Json::Value tool_json; if (auto code_interpreter = - dynamic_cast(tool_resources.get())) { + dynamic_cast(tool_resources.get())) { tool_json["code_interpreter"] = tool_result.value(); } else if (auto file_search = - dynamic_cast(tool_resources.get())) { + dynamic_cast(tool_resources.get())) { tool_json["file_search"] = tool_result.value(); } json["tool_resources"] = tool_json; diff --git a/engine/common/thread_tool_resources.h b/engine/common/thread_tool_resources.h deleted file mode 100644 index 3c22a4480..000000000 --- a/engine/common/thread_tool_resources.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include -#include -#include "common/json_serializable.h" - -namespace OpenAi { - -struct ThreadToolResources : JsonSerializable { - ~ThreadToolResources() = default; - - virtual cpp::result ToJson() override = 0; -}; - -struct ThreadCodeInterpreter : ThreadToolResources { - std::vector file_ids; - - cpp::result ToJson() override { - try { - Json::Value json; - Json::Value file_ids_json{Json::arrayValue}; - for (auto& file_id : file_ids) { - file_ids_json.append(file_id); - } - json["file_ids"] = file_ids_json; - return json; - } catch (const std::exception& e) { - return 
cpp::fail(std::string("ToJson failed: ") + e.what()); - } - } -}; - -struct ThreadFileSearch : ThreadToolResources { - std::vector vector_store_ids; - - cpp::result ToJson() override { - try { - Json::Value json; - Json::Value vector_store_ids_json{Json::arrayValue}; - for (auto& vector_store_id : vector_store_ids) { - vector_store_ids_json.append(vector_store_id); - } - json["vector_store_ids"] = vector_store_ids_json; - return json; - } catch (const std::exception& e) { - return cpp::fail(std::string("ToJson failed: ") + e.what()); - } - } -}; -} // namespace OpenAi diff --git a/engine/common/tool_resources.h b/engine/common/tool_resources.h new file mode 100644 index 000000000..5aadb3f8b --- /dev/null +++ b/engine/common/tool_resources.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include +#include "common/json_serializable.h" + +namespace OpenAi { + +struct ToolResources : JsonSerializable { + ToolResources() = default; + + ToolResources(const ToolResources&) = delete; + + ToolResources& operator=(const ToolResources&) = delete; + + ToolResources(ToolResources&&) noexcept = default; + + ToolResources& operator=(ToolResources&&) noexcept = default; + + virtual ~ToolResources() = default; + + virtual cpp::result ToJson() override = 0; +}; + +struct CodeInterpreter : ToolResources { + CodeInterpreter() = default; + + ~CodeInterpreter() override = default; + + CodeInterpreter(const CodeInterpreter&) = delete; + + CodeInterpreter& operator=(const CodeInterpreter&) = delete; + + CodeInterpreter(CodeInterpreter&& other) noexcept + : ToolResources(std::move(other)), file_ids(std::move(other.file_ids)) {} + + CodeInterpreter& operator=(CodeInterpreter&& other) noexcept { + if (this != &other) { + ToolResources::operator=(std::move(other)); + file_ids = std::move(other.file_ids); + } + return *this; + } + + std::vector file_ids; + + static cpp::result FromJson( + const Json::Value& json) { + CodeInterpreter code_interpreter; + if (json.isMember("file_ids")) { + for 
(const auto& file_id : json["file_ids"]) { + code_interpreter.file_ids.push_back(file_id.asString()); + } + } + return code_interpreter; + } + + cpp::result ToJson() override { + Json::Value json; + Json::Value file_ids_json{Json::arrayValue}; + for (auto& file_id : file_ids) { + file_ids_json.append(file_id); + } + json["file_ids"] = file_ids_json; + return json; + } +}; + +struct FileSearch : ToolResources { + FileSearch() = default; + + ~FileSearch() override = default; + + FileSearch(const FileSearch&) = delete; + + FileSearch& operator=(const FileSearch&) = delete; + + FileSearch(FileSearch&& other) noexcept + : ToolResources(std::move(other)), + vector_store_ids{std::move(other.vector_store_ids)} {} + + FileSearch& operator=(FileSearch&& other) noexcept { + if (this != &other) { + ToolResources::operator=(std::move(other)); + + vector_store_ids = std::move(other.vector_store_ids); + } + return *this; + } + + std::vector vector_store_ids; + + static cpp::result FromJson( + const Json::Value& json) { + FileSearch file_search; + if (json.isMember("vector_store_ids")) { + for (const auto& vector_store_id : json["vector_store_ids"]) { + file_search.vector_store_ids.push_back(vector_store_id.asString()); + } + } + return file_search; + } + + cpp::result ToJson() override { + Json::Value json; + Json::Value vector_store_ids_json{Json::arrayValue}; + for (auto& vector_store_id : vector_store_ids) { + vector_store_ids_json.append(vector_store_id); + } + json["vector_store_ids"] = vector_store_ids_json; + return json; + } +}; +} // namespace OpenAi diff --git a/engine/controllers/assistants.cc b/engine/controllers/assistants.cc index 405d7ed3c..530e180a5 100644 --- a/engine/controllers/assistants.cc +++ b/engine/controllers/assistants.cc @@ -1,4 +1,6 @@ #include "assistants.h" +#include "common/api-dto/delete_success_response.h" +#include "common/dto/assistant_create_dto.h" #include "utils/cortex_utils.h" #include "utils/logging_utils.h" @@ -6,7 +8,12 @@ void 
Assistants::RetrieveAssistant( const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) const { - CTL_INF("RetrieveAssistant: " + assistant_id); + const auto& headers = req->headers(); + auto it = headers.find(kOpenAiAssistantKeyV2); + if (it != headers.end() && it->second == kOpenAiAssistantValueV2) { + return RetrieveAssistantV2(req, std::move(callback), assistant_id); + } + auto res = assistant_service_->RetrieveAssistant(assistant_id); if (res.has_error()) { Json::Value ret; @@ -33,6 +40,78 @@ void Assistants::RetrieveAssistant( } } +void Assistants::RetrieveAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const { + auto res = assistant_service_->RetrieveAssistantV2(assistant_id); + + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto to_json_res = res->ToJson(); + if (to_json_res.has_error()) { + CTL_ERR("Failed to convert assistant to json: " + to_json_res.error()); + Json::Value ret; + ret["message"] = to_json_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + // TODO: namh need to use the text response because it contains model config + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Assistants::CreateAssistantV2( + const HttpRequestPtr& req, + std::function&& callback) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto dto = dto::CreateAssistantDto::FromJson(std::move(*json_body)); + 
CTL_INF("CreateAssistantV2: " << dto.model); + auto validate_res = dto.Validate(); + if (validate_res.has_error()) { + Json::Value ret; + ret["message"] = validate_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = assistant_service_->CreateAssistantV2(dto); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto to_json_res = res->ToJson(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(to_json_res.value()); + resp->setStatusCode(k200OK); + callback(resp); +} + void Assistants::CreateAssistant( const HttpRequestPtr& req, std::function&& callback, @@ -88,10 +167,55 @@ void Assistants::CreateAssistant( callback(resp); } +void Assistants::ModifyAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto dto = dto::UpdateAssistantDto::FromJson(std::move(*json_body)); + auto validate_res = dto.Validate(); + if (validate_res.has_error()) { + Json::Value ret; + ret["message"] = validate_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = assistant_service_->ModifyAssistantV2(assistant_id, dto); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto resp = 
cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} + void Assistants::ModifyAssistant( const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) { + const auto& headers = req->headers(); + auto it = headers.find(kOpenAiAssistantKeyV2); + if (it != headers.end() && it->second == kOpenAiAssistantValueV2) { + return ModifyAssistantV2(req, std::move(callback), assistant_id); + } auto json_body = req->getJsonObject(); if (json_body == nullptr) { Json::Value ret; @@ -142,3 +266,88 @@ void Assistants::ModifyAssistant( resp->setStatusCode(k200OK); callback(resp); } + +/** + * Lists stored assistants as an OpenAI-style "list" object. + * + * The limit query parameter defaults to 20 and is clamped to [1, 100]; a + * value std::stoi cannot parse is answered with 400 Bad Request. + */ +void Assistants::ListAssistants( + const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, std::optional order, + std::optional after, std::optional before) const { + int page_size = 20; + if (limit.has_value()) { + try { + page_size = std::stoi(limit.value()); + } catch (const std::exception&) { + Json::Value root; + root["message"] = "limit must be an integer"; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + } + // AssistantRepository::ListAssistants takes the limit as uint8_t, so keep + // it in a range that cannot wrap around when narrowed. + if (page_size < 1) { + page_size = 1; + } else if (page_size > 100) { + page_size = 100; + } + + auto res = assistant_service_->ListAssistants( + page_size, order.value_or("desc"), after.value_or(""), + before.value_or("")); + if (res.has_error()) { + Json::Value root; + root["message"] = res.error(); + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + + Json::Value assistant_list(Json::arrayValue); + for (auto& msg : res.value()) { + if (auto it = msg.ToJson(); it.has_value()) { + assistant_list.append(it.value()); + } else { + CTL_WRN("Failed to convert message to json: " + it.error()); + } + } + + Json::Value root; + root["object"] = "list"; + root["data"] = assistant_list; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k200OK); + callback(response); +} + +void Assistants::DeleteAssistant( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto res = assistant_service_->DeleteAssistantV2(assistant_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = 
cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + api_response::DeleteSuccessResponse response; + response.id = assistant_id; + response.object = "assistant.deleted"; + response.deleted = true; + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(response.ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} diff --git a/engine/controllers/assistants.h b/engine/controllers/assistants.h index 94ddd14b1..30111bb01 100644 --- a/engine/controllers/assistants.h +++ b/engine/controllers/assistants.h @@ -7,33 +7,72 @@ using namespace drogon; class Assistants : public drogon::HttpController { + constexpr static auto kOpenAiAssistantKeyV2 = "openai-beta"; + constexpr static auto kOpenAiAssistantValueV2 = "assistants=v2"; + public: METHOD_LIST_BEGIN + ADD_METHOD_TO( + Assistants::ListAssistants, + "/v1/" + "assistants?limit={limit}&order={order}&after={after}&before={before}", + Get); + + ADD_METHOD_TO(Assistants::DeleteAssistant, "/v1/assistants/{assistant_id}", + Options, Delete); + ADD_METHOD_TO(Assistants::RetrieveAssistant, "/v1/assistants/{assistant_id}", Get); ADD_METHOD_TO(Assistants::CreateAssistant, "/v1/assistants/{assistant_id}", Options, Post); + ADD_METHOD_TO(Assistants::CreateAssistantV2, "/v1/assistants", Options, Post); + ADD_METHOD_TO(Assistants::ModifyAssistant, "/v1/assistants/{assistant_id}", Options, Patch); + METHOD_LIST_END explicit Assistants(std::shared_ptr assistant_srv) : assistant_service_{assistant_srv} {}; + void ListAssistants(const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, + std::optional order, + std::optional after, + std::optional before) const; + void RetrieveAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) const; + void RetrieveAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const; + + void 
DeleteAssistant(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + void CreateAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id); + void CreateAssistantV2( + const HttpRequestPtr& req, + std::function&& callback); + void ModifyAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id); + void ModifyAssistantV2(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + private: std::shared_ptr assistant_service_; }; diff --git a/engine/main.cc b/engine/main.cc index ddf1eefd8..938392bf0 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -15,6 +15,7 @@ #include "controllers/threads.h" #include "database/database.h" #include "migrations/migration_manager.h" +#include "repositories/assistant_fs_repository.h" #include "repositories/file_fs_repository.h" #include "repositories/message_fs_repository.h" #include "repositories/thread_fs_repository.h" @@ -142,9 +143,12 @@ void RunServer(std::optional host, std::optional port, auto file_repo = std::make_shared(data_folder_path); auto msg_repo = std::make_shared(data_folder_path); auto thread_repo = std::make_shared(data_folder_path); + auto assistant_repo = + std::make_shared(data_folder_path); auto file_srv = std::make_shared(file_repo); - auto assistant_srv = std::make_shared(thread_repo); + auto assistant_srv = + std::make_shared(thread_repo, assistant_repo); auto thread_srv = std::make_shared(thread_repo); auto message_srv = std::make_shared(msg_repo); diff --git a/engine/repositories/assistant_fs_repository.cc b/engine/repositories/assistant_fs_repository.cc new file mode 100644 index 000000000..87b4174fd --- /dev/null +++ b/engine/repositories/assistant_fs_repository.cc @@ -0,0 +1,214 @@ +#include "assistant_fs_repository.h" +#include +#include +#include +#include +#include "utils/result.hpp" + +cpp::result, std::string> 
+AssistantFsRepository::ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + std::vector assistants; + try { + auto assistant_container_path = + data_folder_path_ / kAssistantContainerFolderName; + std::vector all_assistants; + + for (const auto& entry : + std::filesystem::directory_iterator(assistant_container_path)) { + if (!entry.is_directory()) { + continue; + } + + auto assistant_file = entry.path() / kAssistantFileName; + if (!std::filesystem::exists(assistant_file)) { + continue; + } + + auto current_assistant_id = entry.path().filename().string(); + + if (!after.empty() && current_assistant_id <= after) { + continue; + } + + if (!before.empty() && current_assistant_id >= before) { + continue; + } + + std::shared_lock assistant_lock(GrabAssistantMutex(current_assistant_id)); + auto assistant_res = LoadAssistant(current_assistant_id); + if (assistant_res.has_value()) { + all_assistants.push_back(std::move(assistant_res.value())); + } + assistant_lock.unlock(); + } + + // sorting + if (order == "desc") { + std::sort(all_assistants.begin(), all_assistants.end(), + [](const OpenAi::Assistant& assistant1, + const OpenAi::Assistant& assistant2) { + return assistant1.created_at > assistant2.created_at; + }); + } else { + std::sort(all_assistants.begin(), all_assistants.end(), + [](const OpenAi::Assistant& assistant1, + const OpenAi::Assistant& assistant2) { + return assistant1.created_at < assistant2.created_at; + }); + } + + size_t assistant_count = + std::min(static_cast(limit), all_assistants.size()); + for (size_t i = 0; i < assistant_count; i++) { + assistants.push_back(std::move(all_assistants[i])); + } + + return assistants; + } catch (const std::exception& e) { + return cpp::fail("Failed to list assistants: " + std::string(e.what())); + } +} + +cpp::result +AssistantFsRepository::RetrieveAssistant(const std::string assistant_id) const { + std::shared_lock 
lock(GrabAssistantMutex(assistant_id)); + return LoadAssistant(assistant_id); +} + +cpp::result AssistantFsRepository::ModifyAssistant( + OpenAi::Assistant& assistant) { + { + std::unique_lock lock(GrabAssistantMutex(assistant.id)); + auto path = GetAssistantPath(assistant.id); + + if (!std::filesystem::exists(path)) { + lock.unlock(); + return cpp::fail("Assistant doesn't exist: " + assistant.id); + } + } + + return SaveAssistant(assistant); +} + +cpp::result AssistantFsRepository::DeleteAssistant( + const std::string& assitant_id) { + { + std::unique_lock assistant_lock(GrabAssistantMutex(assitant_id)); + auto path = GetAssistantPath(assitant_id); + if (!std::filesystem::exists(path)) { + return cpp::fail("Assistant doesn't exist: " + assitant_id); + } + try { + std::filesystem::remove_all(path); + } catch (const std::exception& e) { + return cpp::fail(""); + } + } + + std::unique_lock map_lock(map_mutex_); + assistant_mutexes_.erase(assitant_id); + return {}; +} + +cpp::result +AssistantFsRepository::CreateAssistant(OpenAi::Assistant& assistant) { + CTL_INF("CreateAssistant: " + assistant.id); + { + std::unique_lock lock(GrabAssistantMutex(assistant.id)); + auto path = GetAssistantPath(assistant.id); + + if (std::filesystem::exists(path)) { + return cpp::fail("Assistant already exists: " + assistant.id); + } + + std::filesystem::create_directories(path); + auto assistant_file_path = path / kAssistantFileName; + std::ofstream assistant_file(assistant_file_path); + assistant_file.close(); + + CTL_INF("CreateAssistant created new file: " + assistant.id); + auto save_result = SaveAssistant(assistant); + if (save_result.has_error()) { + lock.unlock(); + return cpp::fail("Failed to save assistant: " + save_result.error()); + } + } + return RetrieveAssistant(assistant.id); +} + +cpp::result AssistantFsRepository::SaveAssistant( + OpenAi::Assistant& assistant) { + auto path = GetAssistantPath(assistant.id) / kAssistantFileName; + if (!std::filesystem::exists(path)) { + 
std::filesystem::create_directories(path); + } + + std::ofstream file(path); + if (!file) { + return cpp::fail("Failed to open file: " + path.string()); + } + try { + file << assistant.ToJson()->toStyledString(); + file.flush(); + file.close(); + return {}; + } catch (const std::exception& e) { + file.close(); + return cpp::fail("Failed to save assistant: " + std::string(e.what())); + } +} + +std::filesystem::path AssistantFsRepository::GetAssistantPath( + const std::string& assistant_id) const { + auto container_folder_path = + data_folder_path_ / kAssistantContainerFolderName; + if (!std::filesystem::exists(container_folder_path)) { + std::filesystem::create_directories(container_folder_path); + } + + return data_folder_path_ / kAssistantContainerFolderName / assistant_id; +} + +std::shared_mutex& AssistantFsRepository::GrabAssistantMutex( + const std::string& assistant_id) const { + std::shared_lock map_lock(map_mutex_); + auto it = assistant_mutexes_.find(assistant_id); + if (it != assistant_mutexes_.end()) { + return *it->second; + } + + map_lock.unlock(); + std::unique_lock map_write_lock(map_mutex_); + return *assistant_mutexes_ + .try_emplace(assistant_id, std::make_unique()) + .first->second; +} + +cpp::result +AssistantFsRepository::LoadAssistant(const std::string& assistant_id) const { + auto path = GetAssistantPath(assistant_id) / kAssistantFileName; + if (!std::filesystem::exists(path)) { + return cpp::fail("Path does not exist: " + path.string()); + } + + try { + std::ifstream file(path); + if (!file.is_open()) { + return cpp::fail("Failed to open file: " + path.string()); + } + + Json::Value root; + Json::CharReaderBuilder builder; + JSONCPP_STRING errs; + + if (!parseFromStream(builder, file, &root, &errs)) { + return cpp::fail("Failed to parse JSON: " + errs); + } + + return OpenAi::Assistant::FromJson(std::move(root)); + } catch (const std::exception& e) { + return cpp::fail("Failed to load assistant: " + std::string(e.what())); + } +} diff --git 
a/engine/repositories/assistant_fs_repository.h b/engine/repositories/assistant_fs_repository.h new file mode 100644 index 000000000..f310bd54e --- /dev/null +++ b/engine/repositories/assistant_fs_repository.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include + +#include "common/repository/assistant_repository.h" + +class AssistantFsRepository : public AssistantRepository { + public: + constexpr static auto kAssistantContainerFolderName = "assistants"; + constexpr static auto kAssistantFileName = "assistant.json"; + + cpp::result, std::string> ListAssistants( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const override; + + cpp::result CreateAssistant( + OpenAi::Assistant& assistant) override; + + cpp::result RetrieveAssistant( + const std::string assistant_id) const override; + + cpp::result ModifyAssistant( + OpenAi::Assistant& assistant) override; + + cpp::result DeleteAssistant( + const std::string& assitant_id) override; + + explicit AssistantFsRepository(const std::filesystem::path& data_folder_path) + : data_folder_path_{data_folder_path} { + CTL_INF("Constructing AssistantFsRepository.."); + auto path = data_folder_path_ / kAssistantContainerFolderName; + + if (!std::filesystem::exists(path)) { + std::filesystem::create_directories(path); + } + } + + ~AssistantFsRepository() = default; + + private: + std::filesystem::path GetAssistantPath(const std::string& assistant_id) const; + + std::shared_mutex& GrabAssistantMutex(const std::string& assistant_id) const; + + cpp::result SaveAssistant(OpenAi::Assistant& assistant); + + cpp::result LoadAssistant( + const std::string& assistant_id) const; + + /** + * The path to the data folder. 
+ */ + std::filesystem::path data_folder_path_; + + mutable std::shared_mutex map_mutex_; + mutable std::unordered_map> + assistant_mutexes_; +}; diff --git a/engine/repositories/file_fs_repository.h b/engine/repositories/file_fs_repository.h index 974e81fa4..77af60dfc 100644 --- a/engine/repositories/file_fs_repository.h +++ b/engine/repositories/file_fs_repository.h @@ -28,7 +28,7 @@ class FileFsRepository : public FileRepository { cpp::result DeleteFileLocal( const std::string& file_id) override; - explicit FileFsRepository(std::filesystem::path data_folder_path) + explicit FileFsRepository(const std::filesystem::path& data_folder_path) : data_folder_path_{data_folder_path} { CTL_INF("Constructing FileFsRepository.."); auto file_container_path = data_folder_path_ / kFileContainerFolderName; diff --git a/engine/repositories/message_fs_repository.h b/engine/repositories/message_fs_repository.h index 2146778bf..0ca6e89b3 100644 --- a/engine/repositories/message_fs_repository.h +++ b/engine/repositories/message_fs_repository.h @@ -32,7 +32,7 @@ class MessageFsRepository : public MessageRepository { const std::string& thread_id, std::optional> messages) override; - explicit MessageFsRepository(std::filesystem::path data_folder_path) + explicit MessageFsRepository(const std::filesystem::path& data_folder_path) : data_folder_path_{data_folder_path} { CTL_INF("Constructing MessageFsRepository.."); auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; diff --git a/engine/services/assistant_service.cc b/engine/services/assistant_service.cc index e769bf23f..08a5a743f 100644 --- a/engine/services/assistant_service.cc +++ b/engine/services/assistant_service.cc @@ -1,5 +1,7 @@ #include "assistant_service.h" +#include #include "utils/logging_utils.h" +#include "utils/ulid_generator.h" cpp::result AssistantService::CreateAssistant(const std::string& thread_id, @@ -26,3 +28,181 @@ AssistantService::ModifyAssistant(const std::string& thread_id, 
CTL_INF("RetrieveAssistant: " + thread_id); return thread_repository_->ModifyAssistant(thread_id, assistant); } + +cpp::result, std::string> +AssistantService::ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + CTL_INF("List assistants invoked"); + return assistant_repository_->ListAssistants(limit, order, after, before); +} + +cpp::result AssistantService::CreateAssistantV2( + const dto::CreateAssistantDto& create_dto) { + + OpenAi::Assistant assistant; + assistant.id = "asst_" + ulid::GenerateUlid(); + assistant.model = create_dto.model; + if (create_dto.name) { + assistant.name = *create_dto.name; + } + if (create_dto.description) { + assistant.description = *create_dto.description; + } + if (create_dto.instructions) { + assistant.instructions = *create_dto.instructions; + } + if (create_dto.metadata) { + assistant.metadata = *create_dto.metadata; + } + if (create_dto.temperature) { + assistant.temperature = *create_dto.temperature; + } + if (create_dto.top_p) { + assistant.top_p = *create_dto.top_p; + } + for (auto& tool_ptr : create_dto.tools) { + // Create a new unique_ptr in assistant.tools that takes ownership + if (auto* function_tool = + dynamic_cast(tool_ptr.get())) { + assistant.tools.push_back(std::make_unique( + std::move(*function_tool))); + } else if (auto* code_tool = + dynamic_cast( + tool_ptr.get())) { + assistant.tools.push_back( + std::make_unique( + std::move(*code_tool))); + } else if (auto* search_tool = + dynamic_cast( + tool_ptr.get())) { + assistant.tools.push_back( + std::make_unique( + std::move(*search_tool))); + } + } + if (create_dto.tool_resources) { + if (auto* code_interpreter = dynamic_cast( + create_dto.tool_resources.get())) { + assistant.tool_resources = std::make_unique( + std::move(*code_interpreter)); + } else if (auto* file_search = dynamic_cast( + create_dto.tool_resources.get())) { + assistant.tool_resources = + 
std::make_unique(std::move(*file_search)); + } + } + if (create_dto.response_format) { + assistant.response_format = *create_dto.response_format; + } + auto seconds_since_epoch = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + assistant.created_at = seconds_since_epoch; + return assistant_repository_->CreateAssistant(assistant); +} +cpp::result +AssistantService::RetrieveAssistantV2(const std::string& assistant_id) const { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + return assistant_repository_->RetrieveAssistant(assistant_id); +} + +cpp::result AssistantService::ModifyAssistantV2( + const std::string& assistant_id, + const dto::UpdateAssistantDto& update_dto) { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + if (!update_dto.Validate()) { + return cpp::fail("Invalid update assistant dto"); + } + + // First retrieve the existing assistant + auto existing_assistant = + assistant_repository_->RetrieveAssistant(assistant_id); + if (existing_assistant.has_error()) { + return cpp::fail(existing_assistant.error()); + } + + OpenAi::Assistant updated_assistant; + updated_assistant.id = assistant_id; + + // Update fields if they are present in the DTO + if (update_dto.model) { + updated_assistant.model = *update_dto.model; + } + if (update_dto.name) { + updated_assistant.name = *update_dto.name; + } + if (update_dto.description) { + updated_assistant.description = *update_dto.description; + } + if (update_dto.instructions) { + updated_assistant.instructions = *update_dto.instructions; + } + if (update_dto.metadata) { + updated_assistant.metadata = *update_dto.metadata; + } + if (update_dto.temperature) { + updated_assistant.temperature = *update_dto.temperature; + } + if (update_dto.top_p) { + updated_assistant.top_p = *update_dto.top_p; + } + for (auto& tool_ptr : update_dto.tools) { + if (auto* function_tool = + 
dynamic_cast(tool_ptr.get())) { + updated_assistant.tools.push_back( + std::make_unique( + std::move(*function_tool))); + } else if (auto* code_tool = + dynamic_cast( + tool_ptr.get())) { + updated_assistant.tools.push_back( + std::make_unique( + std::move(*code_tool))); + } else if (auto* search_tool = + dynamic_cast( + tool_ptr.get())) { + updated_assistant.tools.push_back( + std::make_unique( + std::move(*search_tool))); + } + } + if (update_dto.tool_resources) { + if (auto* code_interpreter = dynamic_cast( + update_dto.tool_resources.get())) { + updated_assistant.tool_resources = + std::make_unique( + std::move(*code_interpreter)); + } else if (auto* file_search = dynamic_cast( + update_dto.tool_resources.get())) { + updated_assistant.tool_resources = + std::make_unique(std::move(*file_search)); + } + } + if (update_dto.response_format) { + updated_assistant.response_format = *update_dto.response_format; + } + + auto res = assistant_repository_->ModifyAssistant(updated_assistant); + if (res.has_error()) { + return cpp::fail(res.error()); + } + + return updated_assistant; +} + +cpp::result AssistantService::DeleteAssistantV2( + const std::string& assistant_id) { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + return assistant_repository_->DeleteAssistant(assistant_id); +} diff --git a/engine/services/assistant_service.h b/engine/services/assistant_service.h index e7f7414d1..ad31104ff 100644 --- a/engine/services/assistant_service.h +++ b/engine/services/assistant_service.h @@ -1,15 +1,14 @@ #pragma once #include "common/assistant.h" +#include "common/dto/assistant_create_dto.h" +#include "common/dto/assistant_update_dto.h" +#include "common/repository/assistant_repository.h" #include "repositories/thread_fs_repository.h" #include "utils/result.hpp" class AssistantService { public: - explicit AssistantService( - std::shared_ptr thread_repository) - : thread_repository_{thread_repository} {} - cpp::result CreateAssistant( 
const std::string& thread_id, const OpenAi::JanAssistant& assistant); @@ -19,6 +18,31 @@ class AssistantService { cpp::result ModifyAssistant( const std::string& thread_id, const OpenAi::JanAssistant& assistant); + // V2 + cpp::result CreateAssistantV2( + const dto::CreateAssistantDto& create_dto); + + cpp::result, std::string> ListAssistants( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const; + + cpp::result RetrieveAssistantV2( + const std::string& assistant_id) const; + + cpp::result ModifyAssistantV2( + const std::string& assistant_id, + const dto::UpdateAssistantDto& update_dto); + + cpp::result DeleteAssistantV2( + const std::string& assistant_id); + + explicit AssistantService( + std::shared_ptr thread_repository, + std::shared_ptr assistant_repository) + : thread_repository_{thread_repository}, + assistant_repository_{assistant_repository} {} + private: std::shared_ptr thread_repository_; + std::shared_ptr assistant_repository_; }; diff --git a/engine/services/thread_service.cc b/engine/services/thread_service.cc index 0ec0ac89d..9c5e7e857 100644 --- a/engine/services/thread_service.cc +++ b/engine/services/thread_service.cc @@ -4,7 +4,7 @@ #include "utils/ulid_generator.h" cpp::result ThreadService::CreateThread( - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "CreateThread"; @@ -46,7 +46,7 @@ cpp::result ThreadService::RetrieveThread( cpp::result ThreadService::ModifyThread( const std::string& thread_id, - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "ModifyThread " << thread_id; auto retrieve_res = RetrieveThread(thread_id); diff --git a/engine/services/thread_service.h b/engine/services/thread_service.h index 966b0ab01..7011f46f3 100644 --- a/engine/services/thread_service.h +++ b/engine/services/thread_service.h @@ -2,7 +2,6 @@ #include #include 
"common/repository/thread_repository.h" -#include "common/thread_tool_resources.h" #include "common/variant_map.h" #include "utils/result.hpp" @@ -12,7 +11,7 @@ class ThreadService { : thread_repository_{thread_repository} {} cpp::result CreateThread( - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result, std::string> ListThreads( @@ -24,7 +23,7 @@ class ThreadService { cpp::result ModifyThread( const std::string& thread_id, - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result DeleteThread( diff --git a/engine/test/components/test_assistant.cc b/engine/test/components/test_assistant.cc new file mode 100644 index 000000000..20ba08f34 --- /dev/null +++ b/engine/test/components/test_assistant.cc @@ -0,0 +1,194 @@ +#include +#include "common/assistant.h" + +namespace OpenAi { +namespace { + +class AssistantTest : public ::testing::Test { + protected: + void SetUp() override { + // Set up base assistant with minimal required fields + base_assistant.id = "asst_123"; + base_assistant.object = "assistant"; + base_assistant.created_at = 1702000000; + base_assistant.model = "gpt-4"; + } + + Assistant base_assistant; +}; + +TEST_F(AssistantTest, MinimalAssistantToJson) { + auto result = base_assistant.ToJson(); + ASSERT_TRUE(result.has_value()); + + Json::Value json = result.value(); + EXPECT_EQ(json["id"].asString(), "asst_123"); + EXPECT_EQ(json["object"].asString(), "assistant"); + EXPECT_EQ(json["created_at"].asUInt64(), 1702000000); + EXPECT_EQ(json["model"].asString(), "gpt-4"); +} + +TEST_F(AssistantTest, FullAssistantToJson) { + base_assistant.name = "Test Assistant"; + base_assistant.description = "Test Description"; + base_assistant.instructions = "Test Instructions"; + base_assistant.temperature = 0.7f; + base_assistant.top_p = 0.9f; + + // Add a code interpreter tool + auto code_tool = std::make_unique(); + base_assistant.tools.push_back(std::move(code_tool)); + 
+ // Add metadata + base_assistant.metadata["key1"] = std::string("value1"); + base_assistant.metadata["key2"] = true; + base_assistant.metadata["key3"] = static_cast(42ULL); + + auto result = base_assistant.ToJson(); + ASSERT_TRUE(result.has_value()); + + Json::Value json = result.value(); + EXPECT_EQ(json["name"].asString(), "Test Assistant"); + EXPECT_EQ(json["description"].asString(), "Test Description"); + EXPECT_EQ(json["instructions"].asString(), "Test Instructions"); + EXPECT_FLOAT_EQ(json["temperature"].asFloat(), 0.7f); + EXPECT_FLOAT_EQ(json["top_p"].asFloat(), 0.9f); + + EXPECT_TRUE(json["tools"].isArray()); + EXPECT_EQ(json["tools"].size(), 1); + EXPECT_EQ(json["tools"][0]["type"].asString(), "code_interpreter"); + + EXPECT_TRUE(json["metadata"].isObject()); + EXPECT_EQ(json["metadata"]["key1"].asString(), "value1"); + EXPECT_EQ(json["metadata"]["key2"].asBool(), true); + EXPECT_EQ(json["metadata"]["key3"].asUInt64(), 42ULL); +} + +TEST_F(AssistantTest, FromJsonMinimal) { + Json::Value input; + input["id"] = "asst_123"; + input["object"] = "assistant"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + ASSERT_TRUE(result.has_value()); + + const auto& assistant = result.value(); + EXPECT_EQ(assistant.id, "asst_123"); + EXPECT_EQ(assistant.object, "assistant"); + EXPECT_EQ(assistant.created_at, 1702000000); + EXPECT_EQ(assistant.model, "gpt-4"); +} + +TEST_F(AssistantTest, FromJsonComplete) { + Json::Value input; + input["id"] = "asst_123"; + input["object"] = "assistant"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + input["name"] = "Test Assistant"; + input["description"] = "Test Description"; + input["instructions"] = "Test Instructions"; + input["temperature"] = 0.7; + input["top_p"] = 0.9; + + // Add tools + Json::Value tools(Json::arrayValue); + Json::Value code_tool; + code_tool["type"] = "code_interpreter"; + tools.append(code_tool); + + Json::Value 
function_tool; + function_tool["type"] = "function"; + function_tool["function"] = Json::Value(Json::objectValue); + function_tool["function"]["name"] = "test_function"; + function_tool["function"]["description"] = "Test function"; + function_tool["function"]["parameters"] = Json::Value(Json::objectValue); + tools.append(function_tool); + input["tools"] = tools; + + // Add metadata + Json::Value metadata(Json::objectValue); + metadata["key1"] = "value1"; + metadata["key2"] = true; + metadata["key3"] = 42; + input["metadata"] = metadata; + + auto result = Assistant::FromJson(std::move(input)); + ASSERT_TRUE(result.has_value()); + + const auto& assistant = result.value(); + EXPECT_EQ(assistant.name.value(), "Test Assistant"); + EXPECT_EQ(assistant.description.value(), "Test Description"); + EXPECT_EQ(assistant.instructions.value(), "Test Instructions"); + EXPECT_FLOAT_EQ(assistant.temperature.value(), 0.7f); + EXPECT_FLOAT_EQ(assistant.top_p.value(), 0.9f); + + EXPECT_EQ(assistant.tools.size(), 2); + EXPECT_TRUE(dynamic_cast(assistant.tools[0].get()) != nullptr); + EXPECT_TRUE(dynamic_cast(assistant.tools[1].get()) != nullptr); + + EXPECT_EQ(assistant.metadata.size(), 3); + EXPECT_EQ(std::get(assistant.metadata.at("key1")), "value1"); + EXPECT_EQ(std::get(assistant.metadata.at("key2")), true); + EXPECT_EQ(std::get(assistant.metadata.at("key3")), 42ULL); +} + +TEST_F(AssistantTest, FromJsonInvalidInput) { + // Missing required field 'id' + { + Json::Value input; + input["object"] = "assistant"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + EXPECT_FALSE(result.has_value()); + } + + // Invalid object type + { + Json::Value input; + input["id"] = "asst_123"; + input["object"] = "invalid"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + EXPECT_FALSE(result.has_value()); + } + + // Invalid created_at type + { + Json::Value 
input; + input["id"] = "asst_123"; + input["object"] = "assistant"; + input["created_at"] = "invalid"; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + EXPECT_FALSE(result.has_value()); + } +} + +TEST_F(AssistantTest, MoveConstructorAndAssignment) { + base_assistant.name = "Test Assistant"; + base_assistant.tools.push_back(std::make_unique()); + + // Test move constructor + Assistant moved_assistant(std::move(base_assistant)); + EXPECT_EQ(moved_assistant.id, "asst_123"); + EXPECT_EQ(moved_assistant.name.value(), "Test Assistant"); + EXPECT_EQ(moved_assistant.tools.size(), 1); + + // Test move assignment + Assistant another_assistant; + another_assistant = std::move(moved_assistant); + EXPECT_EQ(another_assistant.id, "asst_123"); + EXPECT_EQ(another_assistant.name.value(), "Test Assistant"); + EXPECT_EQ(another_assistant.tools.size(), 1); +} + +} // namespace +} // namespace OpenAi diff --git a/engine/test/components/test_assistant_tool_code_interpreter.cc b/engine/test/components/test_assistant_tool_code_interpreter.cc new file mode 100644 index 000000000..f32526504 --- /dev/null +++ b/engine/test/components/test_assistant_tool_code_interpreter.cc @@ -0,0 +1,49 @@ +#include +#include +#include "common/assistant_code_interpreter_tool.h" + +namespace OpenAi { +namespace { + +class AssistantCodeInterpreterToolTest : public ::testing::Test {}; + +TEST_F(AssistantCodeInterpreterToolTest, BasicConstruction) { + AssistantCodeInterpreterTool tool; + EXPECT_EQ(tool.type, "code_interpreter"); +} + +TEST_F(AssistantCodeInterpreterToolTest, MoveConstructor) { + AssistantCodeInterpreterTool original; + AssistantCodeInterpreterTool moved(std::move(original)); + EXPECT_EQ(moved.type, "code_interpreter"); +} + +TEST_F(AssistantCodeInterpreterToolTest, MoveAssignment) { + AssistantCodeInterpreterTool original; + AssistantCodeInterpreterTool target; + target = std::move(original); + EXPECT_EQ(target.type, "code_interpreter"); +} + 
+TEST_F(AssistantCodeInterpreterToolTest, FromJson) { + Json::Value json; // Empty JSON is fine for this tool + auto result = AssistantCodeInterpreterTool::FromJson(); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value().type, "code_interpreter"); +} + +TEST_F(AssistantCodeInterpreterToolTest, ToJson) { + AssistantCodeInterpreterTool tool; + auto result = tool.ToJson(); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value()["type"].asString(), "code_interpreter"); + + // Verify no extra fields + Json::Value::Members members = result.value().getMemberNames(); + EXPECT_EQ(members.size(), 1); // Only "type" field should be present + EXPECT_EQ(members[0], "type"); +} +} // namespace +} // namespace OpenAi diff --git a/engine/test/components/test_assistant_tool_file_search.cc b/engine/test/components/test_assistant_tool_file_search.cc new file mode 100644 index 000000000..25a2ffc05 --- /dev/null +++ b/engine/test/components/test_assistant_tool_file_search.cc @@ -0,0 +1,207 @@ +#include +#include +#include "common/assistant_file_search_tool.h" + +namespace OpenAi { +namespace { + +class AssistantFileSearchToolTest : public ::testing::Test {}; + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionBasicConstruction) { + const float threshold = 0.75f; + const std::string ranker = "test_ranker"; + FileSearchRankingOption option{threshold, ranker}; + + EXPECT_EQ(option.score_threshold, threshold); + EXPECT_EQ(option.ranker, ranker); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionDefaultRanker) { + const float threshold = 0.5f; + FileSearchRankingOption option{threshold}; + + EXPECT_EQ(option.score_threshold, threshold); + EXPECT_EQ(option.ranker, "auto"); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionFromValidJson) { + Json::Value json; + json["score_threshold"] = 0.8f; + json["ranker"] = "custom_ranker"; + + auto result = FileSearchRankingOption::FromJson(json); + ASSERT_TRUE(result.has_value()); + + 
EXPECT_EQ(result.value().score_threshold, 0.8f); + EXPECT_EQ(result.value().ranker, "custom_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionFromInvalidJson) { + Json::Value json; + auto result = FileSearchRankingOption::FromJson(json); + EXPECT_FALSE(result.has_value()); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionToJson) { + FileSearchRankingOption option{0.9f, "special_ranker"}; + auto json_result = option.ToJson(); + + ASSERT_TRUE(json_result.has_value()); + Json::Value json = json_result.value(); + + EXPECT_EQ(json["score_threshold"].asFloat(), 0.9f); + EXPECT_EQ(json["ranker"].asString(), "special_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchBasicConstruction) { + FileSearchRankingOption ranking_option{0.7f, "test_ranker"}; + AssistantFileSearch search{10, std::move(ranking_option)}; + + EXPECT_EQ(search.max_num_results, 10); + EXPECT_EQ(search.ranking_options.score_threshold, 0.7f); + EXPECT_EQ(search.ranking_options.ranker, "test_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchFromValidJson) { + Json::Value json; + json["max_num_results"] = 15; + + Json::Value ranking_json; + ranking_json["score_threshold"] = 0.85f; + ranking_json["ranker"] = "custom_ranker"; + json["ranking_options"] = ranking_json; + + auto result = AssistantFileSearch::FromJson(json); + ASSERT_TRUE(result.has_value()); + + EXPECT_EQ(result.value().max_num_results, 15); + EXPECT_EQ(result.value().ranking_options.score_threshold, 0.85f); + EXPECT_EQ(result.value().ranking_options.ranker, "custom_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchFromInvalidJson) { + Json::Value json; + // Missing required fields + auto result = AssistantFileSearch::FromJson(json); + EXPECT_FALSE(result.has_value()); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToJson) { + FileSearchRankingOption ranking_option{0.95f, "advanced_ranker"}; + AssistantFileSearch search{20, 
std::move(ranking_option)};
+
+  auto json_result = search.ToJson();
+  ASSERT_TRUE(json_result.has_value());
+
+  Json::Value json = json_result.value();
+  EXPECT_EQ(json["max_num_results"].asInt(), 20);
+  EXPECT_EQ(json["ranking_options"]["score_threshold"].asFloat(), 0.95f);
+  EXPECT_EQ(json["ranking_options"]["ranker"].asString(), "advanced_ranker");
+}
+
+TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolConstruction) {
+  FileSearchRankingOption ranking_option{0.8f, "tool_ranker"};
+  AssistantFileSearch search{25, std::move(ranking_option)};
+  AssistantFileSearchTool tool{search};
+
+  EXPECT_EQ(tool.type, "file_search");
+  EXPECT_EQ(tool.file_search.max_num_results, 25);
+  EXPECT_EQ(tool.file_search.ranking_options.score_threshold, 0.8f);
+  EXPECT_EQ(tool.file_search.ranking_options.ranker, "tool_ranker");
+}
+
+TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolFromValidJson) {
+  Json::Value json;
+  json["type"] = "file_search";
+
+  Json::Value file_search;
+  file_search["max_num_results"] = 30;
+
+  Json::Value ranking_options;
+  ranking_options["score_threshold"] = 0.75f;
+  ranking_options["ranker"] = "json_ranker";
+  file_search["ranking_options"] = ranking_options;
+
+  json["file_search"] = file_search;
+
+  auto result = AssistantFileSearchTool::FromJson(json);
+  ASSERT_TRUE(result.has_value());
+
+  EXPECT_EQ(result.value().type, "file_search");
+  EXPECT_EQ(result.value().file_search.max_num_results, 30);
+  EXPECT_FLOAT_EQ(result.value().file_search.ranking_options.score_threshold, 0.75f);
+  EXPECT_EQ(result.value().file_search.ranking_options.ranker, "json_ranker");
+}
+
+TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolFromInvalidJson) {
+  Json::Value json;
+  // No "type" or "file_search" keys are set, so parsing is expected to fail
+  auto result = AssistantFileSearchTool::FromJson(json);
+  EXPECT_FALSE(result.has_value());
+}
+
+TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolToJson) {
+  FileSearchRankingOption ranking_option{0.65f, "final_ranker"};
+  AssistantFileSearch search{35, std::move(ranking_option)};
+  AssistantFileSearchTool tool{search};
+
+  auto json_result = tool.ToJson();
+  ASSERT_TRUE(json_result.has_value());
+
+  Json::Value json = json_result.value();
+  EXPECT_EQ(json["type"].asString(), "file_search");
+  EXPECT_EQ(json["file_search"]["max_num_results"].asInt(), 35);
+  EXPECT_FLOAT_EQ(json["file_search"]["ranking_options"]["score_threshold"].asFloat(),
+                  0.65f);
+  EXPECT_EQ(json["file_search"]["ranking_options"]["ranker"].asString(),
+            "final_ranker");
+}
+
+TEST_F(AssistantFileSearchToolTest, MoveConstructorsAndAssignments) {
+  // Test FileSearchRankingOption move operations
+  FileSearchRankingOption original_option{0.8f, "original_ranker"};
+  FileSearchRankingOption moved_option{std::move(original_option)};
+  EXPECT_EQ(moved_option.score_threshold, 0.8f);
+  EXPECT_EQ(moved_option.ranker, "original_ranker");
+
+  FileSearchRankingOption assign_target{0.5f};
+  assign_target = std::move(moved_option);
+  EXPECT_EQ(assign_target.score_threshold, 0.8f);
+  EXPECT_EQ(assign_target.ranker, "original_ranker");
+
+  // Test AssistantFileSearch move operations
+  FileSearchRankingOption search_option{0.9f, "search_ranker"};
+  AssistantFileSearch original_search{40, std::move(search_option)};
+  AssistantFileSearch moved_search{std::move(original_search)};
+  EXPECT_EQ(moved_search.max_num_results, 40);
+  EXPECT_EQ(moved_search.ranking_options.score_threshold, 0.9f);
+
+  // Test AssistantFileSearchTool move operations
+  FileSearchRankingOption tool_option{0.7f, "tool_ranker"};
+  AssistantFileSearch tool_search{45, std::move(tool_option)};
+  AssistantFileSearchTool original_tool{tool_search};
+  AssistantFileSearchTool moved_tool{std::move(original_tool)};
+  EXPECT_EQ(moved_tool.type, "file_search");
+  EXPECT_EQ(moved_tool.file_search.max_num_results, 45);
+}
+
+TEST_F(AssistantFileSearchToolTest, EdgeCases) {
+  // Test boundary values for score_threshold
+  FileSearchRankingOption min_threshold{0.0f};
+  EXPECT_EQ(min_threshold.score_threshold, 0.0f);
+
+  FileSearchRankingOption max_threshold{1.0f};
+  EXPECT_EQ(max_threshold.score_threshold, 1.0f);
+
+  // Test boundary values for max_num_results
+  FileSearchRankingOption ranking_option{0.5f};
+  AssistantFileSearch min_results{1, std::move(ranking_option)};
+  EXPECT_EQ(min_results.max_num_results, 1);
+
+  FileSearchRankingOption ranking_option2{0.5f};
+  AssistantFileSearch max_results{50, std::move(ranking_option2)};
+  EXPECT_EQ(max_results.max_num_results, 50);
+}
+
+}  // namespace
+}  // namespace OpenAi
diff --git a/engine/test/components/test_assistant_tool_function.cc b/engine/test/components/test_assistant_tool_function.cc
new file mode 100644
index 000000000..6f59df693
--- /dev/null
+++ b/engine/test/components/test_assistant_tool_function.cc
@@ -0,0 +1,256 @@
+#include <gtest/gtest.h>
+#include "common/assistant_function_tool.h"
+#include <json/json.h>
+
+namespace OpenAi {
+namespace {
+
+// Fixture supplying a description, a name and a parameter object to each test.
+class AssistantFunctionTest : public ::testing::Test {
+protected:
+  void SetUp() override {
+    // Inputs shared by every AssistantFunction test below
+    basic_description = "Test function description";
+    basic_name = "test_function";
+    basic_params = Json::Value(Json::objectValue);
+    basic_params["type"] = "object";
+    basic_params["properties"] = Json::Value(Json::objectValue);
+  }
+
+  std::string basic_description;
+  std::string basic_name;
+  Json::Value basic_params;
+};
+
+// Constructing with std::nullopt leaves `strict` unset.
+TEST_F(AssistantFunctionTest, BasicConstructionWithoutStrict) {
+  AssistantFunction function(basic_description, basic_name, basic_params, std::nullopt);
+
+  EXPECT_EQ(function.description, basic_description);
+  EXPECT_EQ(function.name, basic_name);
+  EXPECT_EQ(function.parameters, basic_params);
+  EXPECT_FALSE(function.strict.has_value());
+}
+
+// Constructing with `true` stores an engaged `strict` holding true.
+TEST_F(AssistantFunctionTest, BasicConstructionWithStrict) {
+  AssistantFunction function(basic_description, basic_name, basic_params, true);
+
+  EXPECT_EQ(function.description, basic_description);
+  EXPECT_EQ(function.name, basic_name);
+  EXPECT_EQ(function.parameters, basic_params);
+  ASSERT_TRUE(function.strict.has_value());
+  EXPECT_TRUE(*function.strict);
+}
+
+// Move construction carries every field over to the new object.
+TEST_F(AssistantFunctionTest, MoveConstructor) {
+  AssistantFunction original(basic_description, basic_name, basic_params, true);
+
+  AssistantFunction moved(std::move(original));
+
+  EXPECT_EQ(moved.description, basic_description);
+  EXPECT_EQ(moved.name, basic_name);
+  EXPECT_EQ(moved.parameters, basic_params);
+  ASSERT_TRUE(moved.strict.has_value());
+  EXPECT_TRUE(*moved.strict);
+}
+
+// Move assignment replaces all of the target's previous field values.
+TEST_F(AssistantFunctionTest, MoveAssignment) {
+  AssistantFunction original(basic_description, basic_name, basic_params, true);
+
+  AssistantFunction target("other", "other_name", Json::Value(Json::objectValue), false);
+  target = std::move(original);
+
+  EXPECT_EQ(target.description, basic_description);
+  EXPECT_EQ(target.name, basic_name);
+  EXPECT_EQ(target.parameters, basic_params);
+  ASSERT_TRUE(target.strict.has_value());
+  EXPECT_TRUE(*target.strict);
+}
+
+// A JSON object carrying all four keys parses into matching fields.
+TEST_F(AssistantFunctionTest, FromValidJson) {
+  Json::Value json;
+  json["description"] = basic_description;
+  json["name"] = basic_name;
+  json["strict"] = true;
+  json["parameters"] = basic_params;
+
+  auto result = AssistantFunction::FromJson(json);
+  ASSERT_TRUE(result.has_value());
+
+  const auto& function = result.value();
+  EXPECT_EQ(function.description, basic_description);
+  EXPECT_EQ(function.name, basic_name);
+  EXPECT_EQ(function.parameters, basic_params);
+  ASSERT_TRUE(function.strict.has_value());
+  EXPECT_TRUE(*function.strict);
+}
+
+// A default-constructed Json::Value is rejected with the "empty" error.
+TEST_F(AssistantFunctionTest, FromJsonValidationEmptyJson) {
+  Json::Value json;
+  auto result = AssistantFunction::FromJson(json);
+  ASSERT_TRUE(result.has_error());
+  EXPECT_EQ(result.error(), "Function json can't be empty");
+}
+
+// A missing name and an empty-string name produce the same error.
+TEST_F(AssistantFunctionTest, FromJsonValidationEmptyName) {
+  Json::Value json;
+  json["description"] = basic_description;
+  json["parameters"] = basic_params;
+
+  auto result = AssistantFunction::FromJson(json);
+  ASSERT_TRUE(result.has_error());
+  EXPECT_EQ(result.error(), "Function name can't be empty");
+
+  // Test with empty name value
+  json["name"] = "";
+  result = AssistantFunction::FromJson(json);
+  ASSERT_TRUE(result.has_error());
+  EXPECT_EQ(result.error(), "Function name can't be empty");
+}
+
+// Omitting "description" is reported as an error.
+TEST_F(AssistantFunctionTest, FromJsonValidationMissingDescription) {
+  Json::Value json;
+  json["name"] = basic_name;
+  json["parameters"] = basic_params;
+
+  auto result = AssistantFunction::FromJson(json);
+  ASSERT_TRUE(result.has_error());
+  EXPECT_EQ(result.error(), "Function description is mandatory");
+}
+
+// Omitting "parameters" is reported as an error.
+TEST_F(AssistantFunctionTest, FromJsonValidationMissingParameters) {
+  Json::Value json;
+  json["name"] = basic_name;
+  json["description"] = basic_description;
+
+  auto result = AssistantFunction::FromJson(json);
+  ASSERT_TRUE(result.has_error());
+  EXPECT_EQ(result.error(), "Function parameters are mandatory");
+}
+
+// Serialisation emits every field, including "strict" when it is set.
+TEST_F(AssistantFunctionTest, ToJsonWithStrict) {
+  AssistantFunction function(basic_description, basic_name, basic_params, true);
+
+  auto result = function.ToJson();
+  ASSERT_TRUE(result.has_value());
+
+  const auto& json = result.value();
+  EXPECT_EQ(json["description"].asString(), basic_description);
+  EXPECT_EQ(json["name"].asString(), basic_name);
+  EXPECT_EQ(json["parameters"], basic_params);
+  EXPECT_TRUE(json["strict"].asBool());
+}
+
+// Serialisation omits the "strict" key when it is unset.
+TEST_F(AssistantFunctionTest, ToJsonWithoutStrict) {
+  AssistantFunction function(basic_description, basic_name, basic_params, std::nullopt);
+
+  auto result = function.ToJson();
+  ASSERT_TRUE(result.has_value());
+
+  const auto& json = result.value();
+  EXPECT_EQ(json["description"].asString(), basic_description);
+  EXPECT_EQ(json["name"].asString(), basic_name);
+  EXPECT_EQ(json["parameters"], basic_params);
+  EXPECT_FALSE(json.isMember("strict"));
+}
+
+// Fixture for the AssistantFunctionTool tests: inputs for the wrapped function.
+class AssistantFunctionToolTest : public ::testing::Test {
+protected:
+  void SetUp() override {
+    description = "Test tool description";
+    name = "test_tool";
+    params = Json::Value(Json::objectValue);
+    params["type"] = "object";
+  }
+
+  std::string description;
+  std::string name;
+  Json::Value params;
+};
+
+// Wrapping a function yields type "function" and keeps the function's fields.
+TEST_F(AssistantFunctionToolTest, BasicConstruction) {
+  AssistantFunction function(description, name, params, true);
+  AssistantFunctionTool tool(function);
+
+  EXPECT_EQ(tool.type, "function");
+  EXPECT_EQ(tool.function.description, description);
+  EXPECT_EQ(tool.function.name, name);
+  EXPECT_EQ(tool.function.parameters, params);
+  ASSERT_TRUE(tool.function.strict.has_value());
+  EXPECT_TRUE(*tool.function.strict);
+}
+
+// Move construction keeps the type tag and the wrapped function's fields.
+TEST_F(AssistantFunctionToolTest, MoveConstructor) {
+  AssistantFunction function(description, name, params, true);
+  AssistantFunctionTool original(function);
+
+  AssistantFunctionTool moved(std::move(original));
+
+  EXPECT_EQ(moved.type, "function");
+  EXPECT_EQ(moved.function.description, description);
+  EXPECT_EQ(moved.function.name, name);
+  EXPECT_EQ(moved.function.parameters, params);
+}
+
+// A {"type": "function", "function": {...}} object parses into a tool.
+TEST_F(AssistantFunctionToolTest, FromValidJson) {
+  Json::Value function_json;
+  function_json["description"] = description;
+  function_json["name"] = name;
+  function_json["strict"] = true;
+  function_json["parameters"] = params;
+
+  Json::Value json;
+  json["type"] = "function";
+  json["function"] = function_json;
+
+  auto result = AssistantFunctionTool::FromJson(json);
+  ASSERT_TRUE(result.has_value());
+
+  const auto& tool = result.value();
+  EXPECT_EQ(tool.type, "function");
+  EXPECT_EQ(tool.function.description, description);
+  EXPECT_EQ(tool.function.name, name);
+  EXPECT_EQ(tool.function.parameters, params);
+  ASSERT_TRUE(tool.function.strict.has_value());
+  EXPECT_TRUE(*tool.function.strict);
+}
+
+// Empty input fails; ASSERT stops the test before error() is read on a success.
+TEST_F(AssistantFunctionToolTest, FromInvalidJson) {
+  Json::Value json;
+  auto result = AssistantFunctionTool::FromJson(json);
+  ASSERT_TRUE(result.has_error());
+  EXPECT_EQ(result.error(), "Failed to parse function: Function json can't be empty");
+}
+
+TEST_F(AssistantFunctionToolTest, ToJson) {
+  AssistantFunction function(description, name, params, true);
+  AssistantFunctionTool tool(function);
+
+  auto result = tool.ToJson();
+  ASSERT_TRUE(result.has_value());
+
+  const auto& json = result.value();
+  EXPECT_EQ(json["type"].asString(), "function");
+  EXPECT_EQ(json["function"]["description"].asString(), description);
+  EXPECT_EQ(json["function"]["name"].asString(), name);
+  EXPECT_EQ(json["function"]["parameters"], params);
+  EXPECT_TRUE(json["function"]["strict"].asBool());
+}
+
+}  // namespace
+}  // namespace OpenAi
diff --git a/engine/test/components/test_tool_resources.cc b/engine/test/components/test_tool_resources.cc
new file mode 100644
index 000000000..2b78e6494
--- /dev/null
+++ b/engine/test/components/test_tool_resources.cc
@@ -0,0 +1,234 @@
+#include <gtest/gtest.h>
+#include <json/json.h>
+#include "common/tool_resources.h"
+
+namespace OpenAi {
+namespace {
+
+// Minimal concrete subclass so the abstract ToolResources base can be tested.
+class MockToolResources : public ToolResources {
+ public:
+  cpp::result<Json::Value, std::string> ToJson() override {
+    Json::Value json;
+    json["mock"] = "value";
+    return json;
+  }
+};
+
+// Stateless fixture for tests that exercise ToolResources through the mock.
+class ToolResourcesTest : public ::testing::Test {};
+
+// A move-constructed mock still serialises through ToJson().
+TEST_F(ToolResourcesTest, MoveConstructor) {
+  MockToolResources original;
+  MockToolResources moved(std::move(original));
+
+  auto json_result = moved.ToJson();
+  ASSERT_TRUE(json_result.has_value());
+  EXPECT_EQ(json_result.value()["mock"].asString(), "value");
+}
+
+// A move-assigned mock still serialises through ToJson().
+TEST_F(ToolResourcesTest, MoveAssignment) {
+  MockToolResources original;
+  MockToolResources target;
+  target = std::move(original);
+
+  auto json_result = target.ToJson();
+  ASSERT_TRUE(json_result.has_value());
+  EXPECT_EQ(json_result.value()["mock"].asString(), "value");
+}
+
+// Fixture providing a reusable list of file ids for the CodeInterpreter tests.
+class CodeInterpreterTest : public ::testing::Test {
+ protected:
+  void SetUp() override { sample_file_ids = {"file1", "file2", "file3"}; }
+
+  std::vector<std::string> sample_file_ids;
+};
+
+// A default-constructed CodeInterpreter has no file ids.
+TEST_F(CodeInterpreterTest, DefaultConstruction) {
+  CodeInterpreter interpreter;
+  EXPECT_TRUE(interpreter.file_ids.empty());
+}
+
+// Move construction transfers file_ids and leaves the source list empty.
+TEST_F(CodeInterpreterTest, MoveConstructor) {
+  CodeInterpreter original;
+  original.file_ids = sample_file_ids;
+
+  CodeInterpreter moved(std::move(original));
+  EXPECT_EQ(moved.file_ids, sample_file_ids);
+  EXPECT_TRUE(original.file_ids.empty());  // NOLINT: Checking moved-from state
+}
+
+// Move assignment transfers file_ids and leaves the source list empty.
+TEST_F(CodeInterpreterTest, MoveAssignment) {
+  CodeInterpreter original;
+  original.file_ids = sample_file_ids;
+
+  CodeInterpreter target;
+  target = std::move(original);
+  EXPECT_EQ(target.file_ids, sample_file_ids);
+  EXPECT_TRUE(original.file_ids.empty());  // NOLINT: Checking moved-from state
+}
+
+// A "file_ids" array is parsed into file_ids in the same order.
+TEST_F(CodeInterpreterTest, FromJsonWithFileIds) {
+  Json::Value json;
+  Json::Value file_ids(Json::arrayValue);
+  for (const auto& id : sample_file_ids) {
+    file_ids.append(id);
+  }
+  json["file_ids"] = file_ids;
+
+  auto result = CodeInterpreter::FromJson(json);
+  ASSERT_TRUE(result.has_value());
+  EXPECT_EQ(result.value().file_ids, sample_file_ids);
+}
+
+// A missing "file_ids" key is accepted and yields an empty list.
+TEST_F(CodeInterpreterTest, FromJsonWithoutFileIds) {
+  Json::Value json;  // Empty JSON
+  auto result = CodeInterpreter::FromJson(json);
+  ASSERT_TRUE(result.has_value());
+  EXPECT_TRUE(result.value().file_ids.empty());
+}
+
+// ToJson() writes file_ids as a JSON array, preserving element order.
+TEST_F(CodeInterpreterTest, ToJson) {
+  CodeInterpreter interpreter;
+  interpreter.file_ids = sample_file_ids;
+
+  auto result = interpreter.ToJson();
+  ASSERT_TRUE(result.has_value());
+
+  const auto& json = result.value();
+  ASSERT_TRUE(json.isMember("file_ids"));
+  ASSERT_TRUE(json["file_ids"].isArray());
+  ASSERT_EQ(json["file_ids"].size(), sample_file_ids.size());
+
+  for (Json::ArrayIndex i = 0; i < json["file_ids"].size(); ++i) {
+    EXPECT_EQ(json["file_ids"][i].asString(), sample_file_ids[i]);
+  }
+}
+
+// With no ids set, ToJson() still emits "file_ids" as an empty array.
+TEST_F(CodeInterpreterTest, ToJsonEmptyFileIds) {
+  CodeInterpreter interpreter;
+
+  auto result = interpreter.ToJson();
+  ASSERT_TRUE(result.has_value());
+
+  const auto& json = result.value();
+  ASSERT_TRUE(json.isMember("file_ids"));
+  ASSERT_TRUE(json["file_ids"].isArray());
+  EXPECT_EQ(json["file_ids"].size(), 0u);  // unsigned literal: size() is unsigned
+}
+
+// Fixture providing a reusable list of vector store ids for FileSearch tests.
+class FileSearchTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    sample_vector_store_ids = {"store1", "store2", "store3"};
+  }
+
+  std::vector<std::string> sample_vector_store_ids;
+};
+
+// A default-constructed FileSearch has no vector store ids.
+TEST_F(FileSearchTest, DefaultConstruction) {
+  FileSearch search;
+  EXPECT_TRUE(search.vector_store_ids.empty());
+}
+
+// Move construction transfers vector_store_ids and empties the source list.
+TEST_F(FileSearchTest, MoveConstructor) {
+  FileSearch original;
+  original.vector_store_ids = sample_vector_store_ids;
+
+  FileSearch moved(std::move(original));
+  EXPECT_EQ(moved.vector_store_ids, sample_vector_store_ids);
+  EXPECT_TRUE(
+      original.vector_store_ids.empty());  // NOLINT: Checking moved-from state
+}
+
+// Move assignment transfers vector_store_ids and empties the source list.
+TEST_F(FileSearchTest, MoveAssignment) {
+  FileSearch original;
+  original.vector_store_ids = sample_vector_store_ids;
+
+  FileSearch target;
+  target = std::move(original);
+  EXPECT_EQ(target.vector_store_ids, sample_vector_store_ids);
+  EXPECT_TRUE(
+      original.vector_store_ids.empty());  // NOLINT: Checking moved-from state
+}
+
+// A "vector_store_ids" array is parsed into vector_store_ids in order.
+TEST_F(FileSearchTest, FromJsonWithVectorStoreIds) {
+  Json::Value json;
+  Json::Value vector_store_ids(Json::arrayValue);
+  for (const auto& id : sample_vector_store_ids) {
+    vector_store_ids.append(id);
+  }
+  json["vector_store_ids"] = vector_store_ids;
+
+  auto result = FileSearch::FromJson(json);
+  ASSERT_TRUE(result.has_value());
+  EXPECT_EQ(result.value().vector_store_ids, sample_vector_store_ids);
+}
+
+// A missing "vector_store_ids" key is accepted and yields an empty list.
+TEST_F(FileSearchTest, FromJsonWithoutVectorStoreIds) {
+  Json::Value json;  // Empty JSON
+  auto result = FileSearch::FromJson(json);
+  ASSERT_TRUE(result.has_value());
+  EXPECT_TRUE(result.value().vector_store_ids.empty());
+}
+
+// ToJson() writes vector_store_ids as a JSON array, preserving order.
+TEST_F(FileSearchTest, ToJson) {
+  FileSearch search;
+  search.vector_store_ids = sample_vector_store_ids;
+
+  auto result = search.ToJson();
+  ASSERT_TRUE(result.has_value());
+
+  const auto& json = result.value();
+  ASSERT_TRUE(json.isMember("vector_store_ids"));
+  ASSERT_TRUE(json["vector_store_ids"].isArray());
+  ASSERT_EQ(json["vector_store_ids"].size(), sample_vector_store_ids.size());
+
+  for (Json::ArrayIndex i = 0; i < json["vector_store_ids"].size(); ++i) {
+    EXPECT_EQ(json["vector_store_ids"][i].asString(),
+              sample_vector_store_ids[i]);
+  }
+}
+
+// With no ids set, ToJson() still emits "vector_store_ids" as an empty array.
+TEST_F(FileSearchTest, ToJsonEmptyVectorStoreIds) {
+  FileSearch search;
+
+  auto result = search.ToJson();
+  ASSERT_TRUE(result.has_value());
+
+  const auto& json = result.value();
+  ASSERT_TRUE(json.isMember("vector_store_ids"));
+  ASSERT_TRUE(json["vector_store_ids"].isArray());
+  EXPECT_EQ(json["vector_store_ids"].size(), 0u);  // unsigned, like size()
+}
+
+// Self-move-assignment is expected to leave vector_store_ids untouched.
+TEST_F(FileSearchTest, SelfAssignment) {
+  FileSearch search;
+  search.vector_store_ids = sample_vector_store_ids;
+
+  // Same self-move as before, routed through an alias to avoid -Wself-move.
+  FileSearch& alias = search;
+  search = std::move(alias);  // Self-assignment with move
+  EXPECT_EQ(search.vector_store_ids, sample_vector_store_ids);
+}
+}  // namespace
+}  // namespace OpenAi