diff --git a/engine/common/assistant.h b/engine/common/assistant.h index e49147e9e..b9592e3e9 100644 --- a/engine/common/assistant.h +++ b/engine/common/assistant.h @@ -1,9 +1,13 @@ #pragma once #include +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" #include "common/assistant_tool.h" -#include "common/thread_tool_resources.h" +#include "common/tool_resources.h" #include "common/variant_map.h" +#include "utils/logging_utils.h" #include "utils/result.hpp" namespace OpenAi { @@ -75,7 +79,49 @@ struct JanAssistant : JsonSerializable { } }; -struct Assistant { +struct Assistant : JsonSerializable { + Assistant() = default; + + ~Assistant() = default; + + Assistant(const Assistant&) = delete; + + Assistant& operator=(const Assistant&) = delete; + + Assistant(Assistant&& other) noexcept + : id{std::move(other.id)}, + object{std::move(other.object)}, + created_at{other.created_at}, + name{std::move(other.name)}, + description{std::move(other.description)}, + model(std::move(other.model)), + instructions(std::move(other.instructions)), + tools(std::move(other.tools)), + tool_resources(std::move(other.tool_resources)), + metadata(std::move(other.metadata)), + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} + + Assistant& operator=(Assistant&& other) noexcept { + if (this != &other) { + id = std::move(other.id); + object = std::move(other.object); + created_at = other.created_at; + name = std::move(other.name); + description = std::move(other.description); + model = std::move(other.model); + instructions = std::move(other.instructions); + tools = std::move(other.tools); + tool_resources = std::move(other.tool_resources); + metadata = std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + response_format = std::move(other.response_format); + } + return *this; + } + /** * The identifier, which can be referenced in API endpoints. */ @@ -126,8 +172,7 @@ struct Assistant { * requires a list of file IDs, while the file_search tool requires a list * of vector store IDs. */ - std::optional> - tool_resources; + std::optional> tool_resources; /** * Set of 16 key-value pairs that can be attached to an object. This can be @@ -153,5 +198,192 @@ struct Assistant { * We generally recommend altering this or temperature but not both. */ std::optional top_p; + + std::variant response_format; + + cpp::result ToJson() override { + try { + Json::Value root; + + root["id"] = std::move(id); + root["object"] = "assistant"; + root["created_at"] = created_at; + if (name.has_value()) { + root["name"] = name.value(); + } + if (description.has_value()) { + root["description"] = description.value(); + } + root["model"] = model; + if (instructions.has_value()) { + root["instructions"] = instructions.value(); + } + + Json::Value tools_jarr{Json::arrayValue}; + for (auto& tool_ptr : tools) { + if (auto it = tool_ptr->ToJson(); it.has_value()) { + tools_jarr.append(it.value()); + } else { + CTL_WRN("Failed to convert content to json: " + it.error()); + } + } + if (tool_resources.has_value()) { + Json::Value tool_resources_json; + + if (auto* code_interpreter = + std::get_if(&tool_resources.value())) { + if (auto result = code_interpreter->ToJson(); result.has_value()) { + tool_resources_json["code_interpreter"] = result.value(); + } else { + CTL_WRN("Failed to convert code_interpreter to json: " + + result.error()); + } + } else if (auto* file_search = + std::get_if(&tool_resources.value())) { + if (auto result = file_search->ToJson(); result.has_value()) { + tool_resources_json["file_search"] = result.value(); + } else { + CTL_WRN("Failed to convert file_search to json: " + result.error()); + } + } + + if (!tool_resources_json.empty()) { + root["tool_resources"] = tool_resources_json; + } + } + Json::Value metadata_json{Json::objectValue}; + for (const auto& [key, value] : metadata) { + if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else { + metadata_json[key] = std::get(value); + } + } + root["metadata"] = metadata_json; + + if (temperature.has_value()) { + root["temperature"] = temperature.value(); + } + if (top_p.has_value()) { + root["top_p"] = top_p.value(); + } + return root; + } catch (const std::exception& e) { + return cpp::fail("ToJson failed: " + std::string(e.what())); + } + } + + static cpp::result FromJson(Json::Value&& json) { + try { + Assistant assistant; + + // Parse required fields + if (!json.isMember("id") || !json["id"].isString()) { + return cpp::fail("Missing or invalid 'id' field"); + } + assistant.id = json["id"].asString(); + + if (!json.isMember("object") || !json["object"].isString() || + json["object"].asString() != "assistant") { + return cpp::fail("Missing or invalid 'object' field"); + } + + if (!json.isMember("created_at") || !json["created_at"].isUInt64()) { + return cpp::fail("Missing or invalid 'created_at' field"); + } + assistant.created_at = json["created_at"].asUInt64(); + + if (!json.isMember("model") || !json["model"].isString()) { + return cpp::fail("Missing or invalid 'model' field"); + } + assistant.model = json["model"].asString(); + + // Parse optional fields + if (json.isMember("name") && json["name"].isString()) { + assistant.name = json["name"].asString(); + } + + if (json.isMember("description") && json["description"].isString()) { + assistant.description = json["description"].asString(); + } + + if (json.isMember("instructions") && json["instructions"].isString()) { + assistant.instructions = json["instructions"].asString(); + } + + // Parse tools array + if (json.isMember("tools") && json["tools"].isArray()) { + auto tools_array = json["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = AssistantCodeInterpreterTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + + result.error()); + } + } else if (tool_type == "function") { + auto result = AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + } + + // Parse tool_resources + if (json.isMember("tool_resources") && + json["tool_resources"].isObject()) {} + + // Parse metadata + if (json.isMember("metadata") && json["metadata"].isObject()) { + auto res = Cortex::ConvertJsonValueToMap(json["metadata"]); + if (res.has_value()) { + assistant.metadata = res.value(); + } else { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } + } + + if (json.isMember("temperature") && json["temperature"].isDouble()) { + assistant.temperature = json["temperature"].asFloat(); + } + + if (json.isMember("top_p") && json["top_p"].isDouble()) { + assistant.top_p = json["top_p"].asFloat(); + } + + return assistant; + } catch (const std::exception& e) { + return cpp::fail("FromJson failed: " + std::string(e.what())); + } + } }; } // namespace OpenAi diff --git a/engine/common/assistant_code_interpreter_tool.h b/engine/common/assistant_code_interpreter_tool.h new file mode 100644 index 000000000..98e2d9fd1 --- /dev/null +++ b/engine/common/assistant_code_interpreter_tool.h @@ -0,0 +1,34 @@ +#include "common/assistant_tool.h" + +namespace OpenAi { +struct AssistantCodeInterpreterTool : public AssistantTool { + AssistantCodeInterpreterTool() : AssistantTool("code_interpreter") {} + + AssistantCodeInterpreterTool(AssistantCodeInterpreterTool&) = delete; + + AssistantCodeInterpreterTool& operator=(AssistantCodeInterpreterTool&) = + delete; + + AssistantCodeInterpreterTool(AssistantCodeInterpreterTool&&) = default; + + AssistantCodeInterpreterTool& operator=(AssistantCodeInterpreterTool&&) = + default; + + ~AssistantCodeInterpreterTool() = default; + + static cpp::result FromJson( + const Json::Value& json) { + if (json.empty()) { + return cpp::fail("Empty JSON"); + } + AssistantCodeInterpreterTool tool; + return std::move(tool); + } + + cpp::result ToJson() override { + Json::Value json; + json["type"] = type; + return json; + } +}; +} // namespace OpenAi diff --git a/engine/common/assistant_file_search_tool.h b/engine/common/assistant_file_search_tool.h new file mode 100644 index 000000000..8b1081355 --- /dev/null +++ b/engine/common/assistant_file_search_tool.h @@ -0,0 +1,148 @@ +#include "common/assistant_tool.h" +#include "common/json_serializable.h" + +namespace OpenAi { +struct FileSearchRankingOption : public JsonSerializable { + /** + * The ranker to use for the file search. If not specified will use the auto ranker. + */ + std::string ranker; + + /** + * The score threshold for the file search. All values must be a + * floating point number between 0 and 1. + */ + float score_threshold; + + FileSearchRankingOption(float score_threshold, + const std::string& ranker = "auto") + : ranker{ranker}, score_threshold{score_threshold} {} + + FileSearchRankingOption(const FileSearchRankingOption&) = delete; + + FileSearchRankingOption& operator=(const FileSearchRankingOption&) = delete; + + FileSearchRankingOption(FileSearchRankingOption&&) = default; + + FileSearchRankingOption& operator=(FileSearchRankingOption&&) = default; + + ~FileSearchRankingOption() = default; + + static cpp::result FromJson( + const Json::Value& json) { + try { + FileSearchRankingOption option{json["score_threshold"].asFloat(), + std::move(json["ranker"].asString())}; + return option; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + Json::Value json; + json["ranker"] = ranker; + json["score_threshold"] = score_threshold; + return json; + } +}; + +/** + * Overrides for the file search tool. + */ +struct AssistantFileSearch : public JsonSerializable { + /** + * The maximum number of results the file search tool should output. + * The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. + * This number should be between 1 and 50 inclusive. + * + * Note that the file search tool may output fewer than max_num_results results. + * See the file search tool documentation for more information. + */ + int max_num_results; + + /** + * The ranking options for the file search. If not specified, + * the file search tool will use the auto ranker and a score_threshold of 0. + * + * See the file search tool documentation for more information. + */ + FileSearchRankingOption ranking_options; + + AssistantFileSearch(int max_num_results, + FileSearchRankingOption&& ranking_options) + : max_num_results{max_num_results}, + ranking_options{std::move(ranking_options)} {} + + AssistantFileSearch(const AssistantFileSearch&) = delete; + + AssistantFileSearch& operator=(const AssistantFileSearch&) = delete; + + AssistantFileSearch(AssistantFileSearch&&) = default; + + AssistantFileSearch& operator=(AssistantFileSearch&&) = default; + + ~AssistantFileSearch() = default; + + static cpp::result FromJson( + const Json::Value& json) { + try { + AssistantFileSearch search{ + json["max_num_results"].asInt(), + FileSearchRankingOption::FromJson(json["ranking_options"]).value()}; + return search; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + Json::Value root; + root["max_num_results"] = max_num_results; + root["ranking_options"] = ranking_options.ToJson().value(); + return root; + } +}; + +struct AssistantFileSearchTool : public AssistantTool { + AssistantFileSearch file_search; + + AssistantFileSearchTool(AssistantFileSearch& file_search) + : AssistantTool("file_search"), file_search{std::move(file_search)} {} + + AssistantFileSearchTool(const AssistantFileSearchTool&) = delete; + + AssistantFileSearchTool& operator=(const AssistantFileSearchTool&) = delete; + + AssistantFileSearchTool(AssistantFileSearchTool&&) = default; + + AssistantFileSearchTool& operator=(AssistantFileSearchTool&&) = default; + + ~AssistantFileSearchTool() = default; + + static cpp::result FromJson( + const Json::Value& json) { + try { + AssistantFileSearch search{json["file_search"]["max_num_results"].asInt(), + FileSearchRankingOption::FromJson( + json["file_search"]["ranking_options"]) + .value()}; + AssistantFileSearchTool tool{search}; + return tool; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value root; + root["type"] = type; + root["file_search"] = file_search.ToJson().value(); + return root; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +}; // namespace OpenAi diff --git a/engine/common/assistant_function_tool.h b/engine/common/assistant_function_tool.h new file mode 100644 index 000000000..4e36fdbd3 --- /dev/null +++ b/engine/common/assistant_function_tool.h @@ -0,0 +1,106 @@ +#include "common/assistant_tool.h" +#include "common/json_serializable.h" + +namespace OpenAi { +struct AssistantFunction : public JsonSerializable { + AssistantFunction(const std::string& description, const std::string& name, + const std::optional& strict) + : description{description}, name{name}, strict{strict} {} + + AssistantFunction(AssistantFunction&) = delete; + + AssistantFunction& operator=(AssistantFunction&) = delete; + + AssistantFunction(AssistantFunction&&) = default; + + AssistantFunction& operator=(AssistantFunction&&) = default; + + ~AssistantFunction() = default; + + /** + * A description of what the function does, used by the model to choose + * when and how to call the function. + */ + std::string description; + + /** + * The name of the function to be called. Must be a-z, A-Z, 0-9, or contain + * underscores and dashes, with a maximum length of 64. + */ + std::string name; + + /** + * The parameters the functions accepts, described as a JSON Schema object. + * See the guide for examples, and the JSON Schema reference for documentation + * about the format. + * + * Omitting parameters defines a function with an empty parameter list. + */ + Json::Value parameters; + + /** + * Whether to enable strict schema adherence when generating the function call. + * If set to true, the model will follow the exact schema defined in the parameters + * field. Only a subset of JSON Schema is supported when strict is true. + * + * Learn more about Structured Outputs in the function calling guide. + */ + std::optional strict; + + static cpp::result FromJson( + const Json::Value& json) { + std::optional is_strict = std::nullopt; + if (json.isMember("strict")) { + is_strict = json["strict"].asBool(); + } + AssistantFunction function{json["description"].asString(), + json["name"].asString(), is_strict}; + function.parameters = json["parameters"]; + return function; + } + + cpp::result ToJson() override { + Json::Value json; + json["description"] = description; + json["name"] = name; + if (strict.has_value()) { + json["strict"] = *strict; + } + json["parameters"] = parameters; + return json; + } +}; + +struct AssistantFunctionTool : public AssistantTool { + AssistantFunctionTool(AssistantFunction& function) + : AssistantTool("function"), function{std::move(function)} {} + + AssistantFunctionTool(AssistantFunctionTool&) = delete; + + AssistantFunctionTool& operator=(AssistantFunctionTool&) = delete; + + AssistantFunctionTool(AssistantFunctionTool&&) = default; + + AssistantFunctionTool& operator=(AssistantFunctionTool&&) = default; + + ~AssistantFunctionTool() = default; + + AssistantFunction function; + + static cpp::result FromJson( + const Json::Value& json) { + auto function_res = AssistantFunction::FromJson(json["function"]); + if (function_res.has_error()) { + return cpp::fail("Failed to parse function: " + function_res.error()); + } + return AssistantFunctionTool{function_res.value()}; + } + + cpp::result ToJson() override { + Json::Value root; + root["type"] = type; + root["function"] = function.ToJson().value(); + return root; + } +}; +}; // namespace OpenAi diff --git a/engine/common/assistant_tool.h b/engine/common/assistant_tool.h index 622721708..2f951eab6 100644 --- a/engine/common/assistant_tool.h +++ b/engine/common/assistant_tool.h @@ -1,91 +1,14 @@ #pragma once -#include #include +#include "common/json_serializable.h" namespace OpenAi { -struct AssistantTool { +struct AssistantTool : public JsonSerializable { std::string type; AssistantTool(const std::string& type) : type{type} {} virtual ~AssistantTool() = default; }; - -struct AssistantCodeInterpreterTool : public AssistantTool { - AssistantCodeInterpreterTool() : AssistantTool{"code_interpreter"} {} - - ~AssistantCodeInterpreterTool() = default; -}; - -struct AssistantFileSearchTool : public AssistantTool { - AssistantFileSearchTool() : AssistantTool("file_search") {} - - ~AssistantFileSearchTool() = default; - - /** - * The ranking options for the file search. If not specified, - * the file search tool will use the auto ranker and a score_threshold of 0. - * - * See the file search tool documentation for more information. - */ - struct RankingOption { - /** - * The ranker to use for the file search. If not specified will use the auto ranker. - */ - std::string ranker; - - /** - * The score threshold for the file search. All values must be a - * floating point number between 0 and 1. - */ - float score_threshold; - }; - - /** - * Overrides for the file search tool. - */ - struct FileSearch { - /** - * The maximum number of results the file search tool should output. - * The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. - * This number should be between 1 and 50 inclusive. - * - * Note that the file search tool may output fewer than max_num_results results. - * See the file search tool documentation for more information. - */ - int max_num_result; - }; -}; - -struct AssistantFunctionTool : public AssistantTool { - AssistantFunctionTool() : AssistantTool("function") {} - - ~AssistantFunctionTool() = default; - - struct Function { - /** - * A description of what the function does, used by the model to choose - * when and how to call the function. - */ - std::string description; - - /** - * The name of the function to be called. Must be a-z, A-Z, 0-9, or contain - * underscores and dashes, with a maximum length of 64. - */ - std::string name; - - // TODO: namh handle parameters - - /** - * Whether to enable strict schema adherence when generating the function call. - * If set to true, the model will follow the exact schema defined in the parameters - * field. Only a subset of JSON Schema is supported when strict is true. - * - * Learn more about Structured Outputs in the function calling guide. - */ - std::optional strict; - }; -}; } // namespace OpenAi diff --git a/engine/common/dto/assistant_create_dto.h b/engine/common/dto/assistant_create_dto.h new file mode 100644 index 000000000..b27e020b6 --- /dev/null +++ b/engine/common/dto/assistant_create_dto.h @@ -0,0 +1,123 @@ +#pragma once + +#include +#include +#include "common/dto/base_dto.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" + +namespace dto { +struct CreateAssistantDto : public BaseDto { + CreateAssistantDto() = default; + + ~CreateAssistantDto() = default; + + CreateAssistantDto(const CreateAssistantDto&) = delete; + + CreateAssistantDto& operator=(const CreateAssistantDto&) = delete; + + CreateAssistantDto(CreateAssistantDto&& other) noexcept + : model{std::move(other.model)}, + name{std::move(other.name)}, + description{std::move(other.description)}, + instructions{std::move(other.instructions)}, + metadata{std::move(other.metadata)}, + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} + + CreateAssistantDto& operator=(CreateAssistantDto&& other) noexcept { + if (this != &other) { + model = std::move(other.model); + name = std::move(other.name); + description = std::move(other.description); + instructions = std::move(other.instructions); + metadata = std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + response_format = std::move(other.response_format); + } + return *this; + } + + std::string model; + + std::optional name; + + std::optional description; + + std::optional instructions; + + // namH: implement tools + + // namh: implement tool_resources + + std::optional metadata; + + std::optional temperature; + + std::optional top_p; + + std::optional> response_format; + + cpp::result Validate() const override { + if (model.empty()) { + return cpp::fail("Model is mandatory"); + } + + if (response_format.has_value()) { + const auto& variant_value = response_format.value(); + if (std::holds_alternative(variant_value)) { + if (std::get(variant_value) != "auto") { + return cpp::fail("Invalid response_format"); + } + } + } + + return {}; + } + + static CreateAssistantDto FromJson(Json::Value&& root) { + if (root.empty()) { + throw std::runtime_error("Json passed in FromJson can't be empty"); + } + CreateAssistantDto dto; + dto.model = std::move(root["model"].asString()); + if (root.isMember("name")) { + dto.name = std::move(root["name"].asString()); + } + if (root.isMember("description")) { + dto.description = std::move(root["description"].asString()); + } + if (root.isMember("instructions")) { + dto.instructions = std::move(root["instructions"].asString()); + } + if (root["metadata"].isObject() && !root["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + dto.metadata = std::move(res.value()); + } + } + if (root.isMember("temperature")) { + dto.temperature = root["temperature"].asFloat(); + } + if (root.isMember("top_p")) { + dto.top_p = root["top_p"].asFloat(); + } + if (root.isMember("response_format")) { + const auto& response_format = root["response_format"]; + if (response_format.isString()) { + dto.response_format = response_format.asString(); + } else if (response_format.isObject()) { + dto.response_format = response_format; + } else { + throw std::runtime_error( + "response_format must be either a string or an object"); + } + } + return dto; + } +}; +} // namespace dto diff --git a/engine/common/dto/assistant_update_dto.h b/engine/common/dto/assistant_update_dto.h new file mode 100644 index 000000000..620cc35a5 --- /dev/null +++ b/engine/common/dto/assistant_update_dto.h @@ -0,0 +1,83 @@ +#pragma once + +#include "common/dto/base_dto.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" + +namespace dto { +struct UpdateAssistantDto : public BaseDto { + std::optional model; + + std::optional name; + + std::optional description; + + std::optional instructions; + + // namH: implement tools + + // namh: implement tool_resources + + std::optional metadata; + + std::optional temperature; + + std::optional top_p; + + std::optional> response_format; + + cpp::result Validate() const override { + if (!model.has_value() && !name.has_value() && !description.has_value() && + !instructions.has_value() && !metadata.has_value() && + !temperature.has_value() && !top_p.has_value() && + !response_format.has_value()) { + return cpp::fail("At least one field must be provided"); + } + + return {}; + } + + static UpdateAssistantDto FromJson(Json::Value&& root) { + if (root.empty()) { + throw std::runtime_error("Json passed in FromJson can't be empty"); + } + UpdateAssistantDto dto; + dto.model = std::move(root["model"].asString()); + if (root.isMember("name")) { + dto.name = std::move(root["name"].asString()); + } + if (root.isMember("description")) { + dto.description = std::move(root["description"].asString()); + } + if (root.isMember("instruction")) { + dto.instructions = std::move(root["instruction"].asString()); + } + if (root["metadata"].isObject() && !root["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + dto.metadata = std::move(res.value()); + } + } + if (root.isMember("temperature")) { + dto.temperature = root["temperature"].asFloat(); + } + if (root.isMember("top_p")) { + dto.top_p = root["top_p"].asFloat(); + } + if (root.isMember("response_format")) { + const auto& response_format = root["response_format"]; + if (response_format.isString()) { + dto.response_format = response_format.asString(); + } else if (response_format.isObject()) { + dto.response_format = response_format; + } else { + throw std::runtime_error( + "response_format must be either a string or an object"); + } + } + return dto; + }; +}; +} // namespace dto diff --git a/engine/common/dto/base_dto.h b/engine/common/dto/base_dto.h new file mode 100644 index 000000000..ed7460aa3 --- /dev/null +++ b/engine/common/dto/base_dto.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include "utils/result.hpp" + +namespace dto { +template +struct BaseDto { + virtual ~BaseDto() = default; + + /** + * Validate itself. + */ + virtual cpp::result Validate() const = 0; +}; +} // namespace dto diff --git a/engine/common/message_attachment.h b/engine/common/message_attachment.h index 767ec9bea..6a0fb02e9 100644 --- a/engine/common/message_attachment.h +++ b/engine/common/message_attachment.h @@ -4,22 +4,27 @@ #include "common/json_serializable.h" namespace OpenAi { - // The tools to add this file to. struct Tool { std::string type; Tool(const std::string& type) : type{type} {} + + virtual ~Tool() = default; }; // The type of tool being defined: code_interpreter -struct CodeInterpreter : Tool { - CodeInterpreter() : Tool{"code_interpreter"} {} +struct MessageCodeInterpreter : Tool { + MessageCodeInterpreter() : Tool{"code_interpreter"} {} + + ~MessageCodeInterpreter() = default; }; // The type of tool being defined: file_search -struct FileSearch : Tool { - FileSearch() : Tool{"file_search"} {} +struct MessageFileSearch : Tool { + MessageFileSearch() : Tool{"file_search"} {} + + ~MessageFileSearch() = default; }; // A list of files attached to the message, and the tools they were added to. diff --git a/engine/common/repository/assistant_repository.h b/engine/common/repository/assistant_repository.h new file mode 100644 index 000000000..d0ff1908d --- /dev/null +++ b/engine/common/repository/assistant_repository.h @@ -0,0 +1,25 @@ +#pragma once + +#include "common/assistant.h" +#include "utils/result.hpp" + +class AssistantRepository { + public: + virtual cpp::result, std::string> + ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, const std::string& before) const = 0; + + virtual cpp::result CreateAssistant( + OpenAi::Assistant& assistant) = 0; + + virtual cpp::result RetrieveAssistant( + const std::string assistant_id) const = 0; + + virtual cpp::result ModifyAssistant( + OpenAi::Assistant& assistant) = 0; + + virtual cpp::result DeleteAssistant( + const std::string& assitant_id) = 0; + + virtual ~AssistantRepository() = default; +}; diff --git a/engine/common/thread.h b/engine/common/thread.h index 2bd5d866b..dc57ba32d 100644 --- a/engine/common/thread.h +++ b/engine/common/thread.h @@ -4,7 +4,7 @@ #include #include #include "common/assistant.h" -#include "common/thread_tool_resources.h" +#include "common/tool_resources.h" #include "common/variant_map.h" #include "json_serializable.h" #include "utils/logging_utils.h" @@ -36,7 +36,7 @@ struct Thread : JsonSerializable { * of tool. For example, the code_interpreter tool requires a list of * file IDs, while the file_search tool requires a list of vector store IDs. */ - std::unique_ptr tool_resources; + std::unique_ptr tool_resources; /** * Set of 16 key-value pairs that can be attached to an object. @@ -65,7 +65,7 @@ struct Thread : JsonSerializable { const auto& tool_json = json["tool_resources"]; if (tool_json.isMember("code_interpreter")) { - auto code_interpreter = std::make_unique(); + auto code_interpreter = std::make_unique(); const auto& file_ids = tool_json["code_interpreter"]["file_ids"]; if (file_ids.isArray()) { for (const auto& file_id : file_ids) { @@ -74,7 +74,7 @@ struct Thread : JsonSerializable { } thread.tool_resources = std::move(code_interpreter); } else if (tool_json.isMember("file_search")) { - auto file_search = std::make_unique(); + auto file_search = std::make_unique(); const auto& store_ids = tool_json["file_search"]["vector_store_ids"]; if (store_ids.isArray()) { for (const auto& store_id : store_ids) { @@ -148,10 +148,10 @@ struct Thread : JsonSerializable { Json::Value tool_json; if (auto code_interpreter = - dynamic_cast(tool_resources.get())) { + dynamic_cast(tool_resources.get())) { tool_json["code_interpreter"] = tool_result.value(); } else if (auto file_search = - dynamic_cast(tool_resources.get())) { + dynamic_cast(tool_resources.get())) { tool_json["file_search"] = tool_result.value(); } json["tool_resources"] = tool_json; diff --git a/engine/common/thread_tool_resources.h b/engine/common/thread_tool_resources.h deleted file mode 100644 index 3c22a4480..000000000 --- a/engine/common/thread_tool_resources.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include -#include -#include "common/json_serializable.h" - -namespace OpenAi { - -struct ThreadToolResources : JsonSerializable { - ~ThreadToolResources() = default; - - virtual cpp::result ToJson() override = 0; -}; - -struct ThreadCodeInterpreter : ThreadToolResources { - std::vector file_ids; - - cpp::result ToJson() override { - try { - Json::Value json; - Json::Value file_ids_json{Json::arrayValue}; - for (auto& file_id : file_ids) { - file_ids_json.append(file_id); - } - json["file_ids"] = file_ids_json; - return json; - } catch (const std::exception& e) { - return cpp::fail(std::string("ToJson failed: ") + e.what()); - } - } -}; - -struct ThreadFileSearch : ThreadToolResources { - std::vector vector_store_ids; - - cpp::result ToJson() override { - try { - Json::Value json; - Json::Value vector_store_ids_json{Json::arrayValue}; - for (auto& vector_store_id : vector_store_ids) { - vector_store_ids_json.append(vector_store_id); - } - json["vector_store_ids"] = vector_store_ids_json; - return json; - } catch (const std::exception& e) { - return cpp::fail(std::string("ToJson failed: ") + e.what()); - } - } -}; -} // namespace OpenAi diff --git a/engine/common/tool_resources.h b/engine/common/tool_resources.h new file mode 100644 index 000000000..c026ef60d --- /dev/null +++ b/engine/common/tool_resources.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include "common/json_serializable.h" + +namespace OpenAi { + +struct ToolResources : JsonSerializable { + virtual ~ToolResources() = default; + + virtual cpp::result ToJson() override = 0; +}; + +struct CodeInterpreter : ToolResources { + CodeInterpreter() = default; + + ~CodeInterpreter() override = default; + + CodeInterpreter(const CodeInterpreter&) = delete; + + CodeInterpreter& operator=(const CodeInterpreter&) = delete; + + CodeInterpreter(CodeInterpreter&& other) noexcept + : file_ids(std::move(other.file_ids)) {} + + CodeInterpreter& operator=(CodeInterpreter&& other) noexcept { + if (this != &other) { + file_ids = std::move(other.file_ids); + } + return *this; + } + + std::vector file_ids; + + static cpp::result FromJson( + const Json::Value& json) { + CodeInterpreter code_interpreter; + if (json.isMember("file_ids")) { + for (const auto& file_id : json["file_ids"]) { + code_interpreter.file_ids.push_back(file_id.asString()); + } + } + return code_interpreter; + } + + cpp::result ToJson() override { + Json::Value json; + Json::Value file_ids_json{Json::arrayValue}; + for (auto& file_id : file_ids) { + file_ids_json.append(file_id); + } + json["file_ids"] = file_ids_json; + return json; + } +}; + +struct FileSearch : ToolResources { + FileSearch() = default; + + ~FileSearch() override = default; + + FileSearch(const FileSearch&) = delete; + + FileSearch& operator=(const FileSearch&) = delete; + + FileSearch(FileSearch&& other) noexcept + : vector_store_ids{std::move(other.vector_store_ids)} {} + + FileSearch& operator=(FileSearch&& other) noexcept { + if (this != &other) { + vector_store_ids = std::move(other.vector_store_ids); + } + return *this; + } + + std::vector vector_store_ids; + + static cpp::result FromJson( + const Json::Value& json) { + FileSearch file_search; + if (json.isMember("vector_store_ids")) { + for (const auto& vector_store_id : json["vector_store_ids"]) { + file_search.vector_store_ids.push_back(vector_store_id.asString()); + } + } + return file_search; + } + + cpp::result ToJson() override { + Json::Value json; + Json::Value vector_store_ids_json{Json::arrayValue}; + for (auto& vector_store_id : vector_store_ids) { + vector_store_ids_json.append(vector_store_id); + } + json["vector_store_ids"] = vector_store_ids_json; + return json; + } +}; +} // namespace OpenAi diff --git a/engine/controllers/assistants.cc b/engine/controllers/assistants.cc index 405d7ed3c..530e180a5 100644 --- a/engine/controllers/assistants.cc +++ b/engine/controllers/assistants.cc @@ -1,4 +1,6 @@ #include "assistants.h" +#include "common/api-dto/delete_success_response.h" +#include "common/dto/assistant_create_dto.h" #include "utils/cortex_utils.h" #include "utils/logging_utils.h" @@ -6,7 +8,12 @@ void Assistants::RetrieveAssistant( const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) const { - CTL_INF("RetrieveAssistant: " + assistant_id); + const auto& headers = req->headers(); + auto it = headers.find(kOpenAiAssistantKeyV2); + if (it != headers.end() && it->second == kOpenAiAssistantValueV2) { + return RetrieveAssistantV2(req, std::move(callback), assistant_id); + } + auto res = assistant_service_->RetrieveAssistant(assistant_id); if (res.has_error()) { Json::Value ret; @@ -33,6 +40,78 @@ void Assistants::RetrieveAssistant( } } +void Assistants::RetrieveAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const { + auto res = assistant_service_->RetrieveAssistantV2(assistant_id); + + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto to_json_res = res->ToJson(); + if (to_json_res.has_error()) { + CTL_ERR("Failed to convert assistant to json: " + to_json_res.error()); + Json::Value ret; + ret["message"] = to_json_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + // TODO: namh need to use the text response because it contains model config + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Assistants::CreateAssistantV2( + const HttpRequestPtr& req, + std::function&& callback) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto dto = dto::CreateAssistantDto::FromJson(std::move(*json_body)); + CTL_INF("CreateAssistantV2: " << dto.model); + auto validate_res = dto.Validate(); + if (validate_res.has_error()) { + Json::Value ret; + ret["message"] = validate_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = assistant_service_->CreateAssistantV2(dto); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto to_json_res = res->ToJson(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(to_json_res.value()); + resp->setStatusCode(k200OK); + callback(resp); +} + void Assistants::CreateAssistant( const HttpRequestPtr& req, std::function&& callback, @@ -88,10 +167,55 @@ void Assistants::CreateAssistant( callback(resp); } +void Assistants::ModifyAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto dto = dto::UpdateAssistantDto::FromJson(std::move(*json_body)); + auto validate_res = dto.Validate(); + if (validate_res.has_error()) { + Json::Value ret; + ret["message"] = validate_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = assistant_service_->ModifyAssistantV2(assistant_id, dto); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} + void Assistants::ModifyAssistant( const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) { + const auto& headers = req->headers(); + auto it = headers.find(kOpenAiAssistantKeyV2); + if (it != headers.end() && it->second == kOpenAiAssistantValueV2) { + return ModifyAssistantV2(req, std::move(callback), assistant_id); + } auto json_body = req->getJsonObject(); if (json_body == nullptr) { Json::Value ret; @@ -142,3 +266,62 @@ void Assistants::ModifyAssistant( resp->setStatusCode(k200OK); callback(resp); } + +void Assistants::ListAssistants( + const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, std::optional order, + std::optional after, std::optional before) const { + + auto res = assistant_service_->ListAssistants( + std::stoi(limit.value_or("20")), order.value_or("desc"), + after.value_or(""), before.value_or("")); + if (res.has_error()) { + Json::Value root; + root["message"] = res.error(); + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + + Json::Value assistant_list(Json::arrayValue); + for (auto& msg : res.value()) { + if (auto it = msg.ToJson(); it.has_value()) { + assistant_list.append(it.value()); + } else { + CTL_WRN("Failed to convert message to json: " + it.error()); + } + } + + Json::Value root; + root["object"] = "list"; + root["data"] = assistant_list; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k200OK); + callback(response); +} + +void Assistants::DeleteAssistant( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto res = assistant_service_->DeleteAssistantV2(assistant_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + api_response::DeleteSuccessResponse response; + response.id = assistant_id; + response.object = "assistant.deleted"; + response.deleted = true; + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(response.ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} diff --git a/engine/controllers/assistants.h b/engine/controllers/assistants.h index 94ddd14b1..30111bb01 100644 --- a/engine/controllers/assistants.h +++ b/engine/controllers/assistants.h @@ -7,33 +7,72 @@ using namespace drogon; class Assistants : public drogon::HttpController { + constexpr static auto kOpenAiAssistantKeyV2 = "openai-beta"; + constexpr static auto kOpenAiAssistantValueV2 = "assistants=v2"; + public: METHOD_LIST_BEGIN + ADD_METHOD_TO( + Assistants::ListAssistants, + "/v1/" + "assistants?limit={limit}&order={order}&after={after}&before={before}", + Get); + + ADD_METHOD_TO(Assistants::DeleteAssistant, "/v1/assistants/{assistant_id}", + Options, Delete); + ADD_METHOD_TO(Assistants::RetrieveAssistant, "/v1/assistants/{assistant_id}", Get); ADD_METHOD_TO(Assistants::CreateAssistant, "/v1/assistants/{assistant_id}", Options, Post); + ADD_METHOD_TO(Assistants::CreateAssistantV2, "/v1/assistants", Options, Post); + ADD_METHOD_TO(Assistants::ModifyAssistant, "/v1/assistants/{assistant_id}", Options, Patch); + METHOD_LIST_END explicit Assistants(std::shared_ptr assistant_srv) : assistant_service_{assistant_srv} {}; + void ListAssistants(const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, + std::optional order, + std::optional after, + std::optional before) const; + void RetrieveAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) const; + void RetrieveAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const; + + void DeleteAssistant(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + void CreateAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id); + void CreateAssistantV2( + const HttpRequestPtr& req, + std::function&& callback); + void ModifyAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id); + void ModifyAssistantV2(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + private: std::shared_ptr assistant_service_; }; diff --git a/engine/main.cc b/engine/main.cc index ddf1eefd8..938392bf0 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -15,6 +15,7 @@ #include "controllers/threads.h" #include "database/database.h" #include "migrations/migration_manager.h" +#include "repositories/assistant_fs_repository.h" #include "repositories/file_fs_repository.h" #include "repositories/message_fs_repository.h" #include "repositories/thread_fs_repository.h" @@ -142,9 +143,12 @@ void RunServer(std::optional host, std::optional port, auto file_repo = std::make_shared(data_folder_path); auto msg_repo = std::make_shared(data_folder_path); auto thread_repo = std::make_shared(data_folder_path); + auto assistant_repo = + std::make_shared(data_folder_path); auto file_srv = std::make_shared(file_repo); - auto assistant_srv = std::make_shared(thread_repo); + auto assistant_srv = + std::make_shared(thread_repo, assistant_repo); auto thread_srv = std::make_shared(thread_repo); auto message_srv = std::make_shared(msg_repo); diff --git a/engine/repositories/assistant_fs_repository.cc b/engine/repositories/assistant_fs_repository.cc new file mode 100644 index 000000000..f5103f5c0 --- /dev/null +++ b/engine/repositories/assistant_fs_repository.cc @@ -0,0 +1,206 @@ +#include "assistant_fs_repository.h" +#include +#include +#include +#include "utils/result.hpp" + +cpp::result, std::string> +AssistantFsRepository::ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + std::vector assistants; + try { + auto assistant_container_path = + data_folder_path_ / kAssistantContainerFolderName; + std::vector all_assistants; + + for (const auto& entry : + std::filesystem::directory_iterator(assistant_container_path)) { + if (!entry.is_directory()) { + continue; + } + + auto assistant_file = entry.path() / kAssistantFileName; + if (!std::filesystem::exists(assistant_file)) { + continue; + } + + auto current_assistant_id = entry.path().filename().string(); + + if (!after.empty() && current_assistant_id <= after) { + continue; + } + + if (!before.empty() && current_assistant_id >= before) { + continue; + } + + std::shared_lock assistant_lock(GrabAssistantMutex(current_assistant_id)); + auto assistant_res = LoadAssistant(current_assistant_id); + if (assistant_res.has_value()) { + all_assistants.push_back(std::move(assistant_res.value())); + } + assistant_lock.unlock(); + } + + // sorting + if (order == "desc") { + std::sort(all_assistants.begin(), all_assistants.end(), + [](const OpenAi::Assistant& assistant1, + const OpenAi::Assistant& assistant2) { + return assistant1.created_at > assistant2.created_at; + }); + } else { + std::sort(all_assistants.begin(), all_assistants.end(), + [](const OpenAi::Assistant& assistant1, + const OpenAi::Assistant& assistant2) { + return assistant1.created_at < assistant2.created_at; + }); + } + + size_t assistant_count = + std::min(static_cast(limit), all_assistants.size()); + for (size_t i = 0; i < assistant_count; i++) { + assistants.push_back(std::move(all_assistants[i])); + } + + return assistants; + } catch (const std::exception& e) { + return cpp::fail("Failed to list assistants: " + std::string(e.what())); + } +} + +cpp::result +AssistantFsRepository::RetrieveAssistant(const std::string assistant_id) const { + std::shared_lock lock(GrabAssistantMutex(assistant_id)); + return LoadAssistant(assistant_id); +} + +cpp::result AssistantFsRepository::ModifyAssistant( + OpenAi::Assistant& assistant) { + std::unique_lock lock(GrabAssistantMutex(assistant.id)); + auto path = GetAssistantPath(assistant.id); + + if (!std::filesystem::exists(path)) { + return cpp::fail("Assistant doesn't exist: " + assistant.id); + } + + return SaveAssistant(assistant); +} + +cpp::result AssistantFsRepository::DeleteAssistant( + const std::string& assitant_id) { + { + std::unique_lock assistant_lock(GrabAssistantMutex(assitant_id)); + auto path = GetAssistantPath(assitant_id); + if (!std::filesystem::exists(path)) { + return cpp::fail("Assistant doesn't exist: " + assitant_id); + } + try { + std::filesystem::remove_all(path); + } catch (const std::exception& e) { + return cpp::fail(""); + } + } + + std::unique_lock map_lock(map_mutex_); + assistant_mutexes_.erase(assitant_id); + return {}; +} + +cpp::result +AssistantFsRepository::CreateAssistant(OpenAi::Assistant& assistant) { + std::unique_lock lock(GrabAssistantMutex(assistant.id)); + auto path = GetAssistantPath(assistant.id); + + if (std::filesystem::exists(path)) { + return cpp::fail("Assistant already exists: " + assistant.id); + } + + std::filesystem::create_directories(path); + auto assistant_file_path = path / kAssistantFileName; + std::ofstream assistant_file(assistant_file_path); + assistant_file.close(); + + auto save_result = SaveAssistant(assistant); + if (save_result.has_error()) { + return cpp::fail("Failed to save assistant: " + save_result.error()); + } + + return std::move(assistant); +} + +cpp::result AssistantFsRepository::SaveAssistant( + OpenAi::Assistant& assistant) { + auto path = GetAssistantPath(assistant.id) / kAssistantFileName; + if (!std::filesystem::exists(path)) { + std::filesystem::create_directories(path); + } + + std::ofstream file(path); + if (!file) { + return cpp::fail("Failed to open file: " + path.string()); + } + try { + file << assistant.ToJson()->toStyledString(); + file.flush(); + file.close(); + return {}; + } catch (const std::exception& e) { + file.close(); + return cpp::fail("Failed to save assistant: " + std::string(e.what())); + } +} + +std::filesystem::path AssistantFsRepository::GetAssistantPath( + const std::string& assistant_id) const { + auto container_folder_path = + data_folder_path_ / kAssistantContainerFolderName; + if (!std::filesystem::exists(container_folder_path)) { + std::filesystem::create_directories(container_folder_path); + } + + return data_folder_path_ / kAssistantContainerFolderName / assistant_id; +} + +std::shared_mutex& AssistantFsRepository::GrabAssistantMutex( + const std::string& assistant_id) const { + std::shared_lock map_lock(map_mutex_); + auto it = assistant_mutexes_.find(assistant_id); + if (it != assistant_mutexes_.end()) { + return *it->second; + } + + map_lock.unlock(); + std::unique_lock map_write_lock(map_mutex_); + return *assistant_mutexes_ + .try_emplace(assistant_id, std::make_unique()) + .first->second; +} + +cpp::result +AssistantFsRepository::LoadAssistant(const std::string& assistant_id) const { + auto path = GetAssistantPath(assistant_id) / kAssistantFileName; + if (!std::filesystem::exists(path)) { + return cpp::fail("Path does not exist: " + path.string()); + } + + try { + std::ifstream file(path); + if (!file.is_open()) { + return cpp::fail("Failed to open file: " + path.string()); + } + + Json::Value root; + Json::CharReaderBuilder builder; + JSONCPP_STRING errs; + + if (!parseFromStream(builder, file, &root, &errs)) { + return cpp::fail("Failed to parse JSON: " + errs); + } + + return OpenAi::Assistant::FromJson(std::move(root)); + } catch (const std::exception& e) { + return cpp::fail("Failed to load assistant: " + std::string(e.what())); + } +} diff --git a/engine/repositories/assistant_fs_repository.h b/engine/repositories/assistant_fs_repository.h new file mode 100644 index 000000000..f310bd54e --- /dev/null +++ b/engine/repositories/assistant_fs_repository.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include + +#include "common/repository/assistant_repository.h" + +class AssistantFsRepository : public AssistantRepository { + public: + constexpr static auto kAssistantContainerFolderName = "assistants"; + constexpr static auto kAssistantFileName = "assistant.json"; + + cpp::result, std::string> ListAssistants( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const override; + + cpp::result CreateAssistant( + OpenAi::Assistant& assistant) override; + + cpp::result RetrieveAssistant( + const std::string assistant_id) const override; + + cpp::result ModifyAssistant( + OpenAi::Assistant& assistant) override; + + cpp::result DeleteAssistant( + const std::string& assitant_id) override; + + explicit AssistantFsRepository(const std::filesystem::path& data_folder_path) + : data_folder_path_{data_folder_path} { + CTL_INF("Constructing AssistantFsRepository.."); + auto path = data_folder_path_ / kAssistantContainerFolderName; + + if (!std::filesystem::exists(path)) { + std::filesystem::create_directories(path); + } + } + + ~AssistantFsRepository() = default; + + private: + std::filesystem::path GetAssistantPath(const std::string& assistant_id) const; + + std::shared_mutex& GrabAssistantMutex(const std::string& assistant_id) const; + + cpp::result SaveAssistant(OpenAi::Assistant& assistant); + + cpp::result LoadAssistant( + const std::string& assistant_id) const; + + /** + * The path to the data folder. + */ + std::filesystem::path data_folder_path_; + + mutable std::shared_mutex map_mutex_; + mutable std::unordered_map> + assistant_mutexes_; +}; diff --git a/engine/repositories/file_fs_repository.h b/engine/repositories/file_fs_repository.h index 974e81fa4..77af60dfc 100644 --- a/engine/repositories/file_fs_repository.h +++ b/engine/repositories/file_fs_repository.h @@ -28,7 +28,7 @@ class FileFsRepository : public FileRepository { cpp::result DeleteFileLocal( const std::string& file_id) override; - explicit FileFsRepository(std::filesystem::path data_folder_path) + explicit FileFsRepository(const std::filesystem::path& data_folder_path) : data_folder_path_{data_folder_path} { CTL_INF("Constructing FileFsRepository.."); auto file_container_path = data_folder_path_ / kFileContainerFolderName; diff --git a/engine/repositories/message_fs_repository.h b/engine/repositories/message_fs_repository.h index 2146778bf..0ca6e89b3 100644 --- a/engine/repositories/message_fs_repository.h +++ b/engine/repositories/message_fs_repository.h @@ -32,7 +32,7 @@ class MessageFsRepository : public MessageRepository { const std::string& thread_id, std::optional> messages) override; - explicit MessageFsRepository(std::filesystem::path data_folder_path) + explicit MessageFsRepository(const std::filesystem::path& data_folder_path) : data_folder_path_{data_folder_path} { CTL_INF("Constructing MessageFsRepository.."); auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; diff --git a/engine/services/assistant_service.cc b/engine/services/assistant_service.cc index e769bf23f..4ff8cb5df 100644 --- a/engine/services/assistant_service.cc +++ b/engine/services/assistant_service.cc @@ -1,5 +1,6 @@ #include "assistant_service.h" #include "utils/logging_utils.h" +#include "utils/ulid_generator.h" cpp::result AssistantService::CreateAssistant(const std::string& thread_id, @@ -26,3 +27,119 @@ AssistantService::ModifyAssistant(const std::string& thread_id, CTL_INF("RetrieveAssistant: " + thread_id); return thread_repository_->ModifyAssistant(thread_id, assistant); } + +cpp::result, std::string> +AssistantService::ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + CTL_INF("List assistants invoked"); + return assistant_repository_->ListAssistants(limit, order, after, before); +} + +cpp::result AssistantService::CreateAssistantV2( + const dto::CreateAssistantDto& create_dto) { + + OpenAi::Assistant assistant; + assistant.id = "asst_" + ulid::GenerateUlid(); + assistant.model = create_dto.model; + if (create_dto.name) { + assistant.name = *create_dto.name; + } + if (create_dto.description) { + assistant.description = *create_dto.description; + } + if (create_dto.instructions) { + assistant.instructions = *create_dto.instructions; + } + if (create_dto.metadata) { + assistant.metadata = *create_dto.metadata; + } + if (create_dto.temperature) { + assistant.temperature = *create_dto.temperature; + } + if (create_dto.top_p) { + assistant.top_p = *create_dto.top_p; + } + if (create_dto.response_format) { + assistant.response_format = *create_dto.response_format; + } + auto seconds_since_epoch = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + assistant.created_at = seconds_since_epoch; + return assistant_repository_->CreateAssistant(assistant); +} + +cpp::result +AssistantService::RetrieveAssistantV2(const std::string& assistant_id) const { + if (assistant_id.empty()) { + return cpp::failure("Assistant ID cannot be empty"); + } + + return assistant_repository_->RetrieveAssistant(assistant_id); +} + +cpp::result AssistantService::ModifyAssistantV2( + const std::string& assistant_id, + const dto::UpdateAssistantDto& update_dto) { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + if (!update_dto.Validate()) { + return cpp::fail("Invalid update assistant dto"); + } + + // First retrieve the existing assistant + auto existing_assistant = + assistant_repository_->RetrieveAssistant(assistant_id); + if (existing_assistant.has_error()) { + return cpp::fail(existing_assistant.error()); + } + + OpenAi::Assistant updated_assistant; + updated_assistant.id = assistant_id; + + // Update fields if they are present in the DTO + if (update_dto.model) { + updated_assistant.model = *update_dto.model; + } + if (update_dto.name) { + updated_assistant.name = *update_dto.name; + } + if (update_dto.description) { + updated_assistant.description = *update_dto.description; + } + if (update_dto.instructions) { + updated_assistant.instructions = *update_dto.instructions; + } + if (update_dto.metadata) { + updated_assistant.metadata = *update_dto.metadata; + } + if (update_dto.temperature) { + updated_assistant.temperature = *update_dto.temperature; + } + if (update_dto.top_p) { + updated_assistant.top_p = *update_dto.top_p; + } + if (update_dto.response_format) { + updated_assistant.response_format = *update_dto.response_format; + } + + auto res = assistant_repository_->ModifyAssistant(updated_assistant); + if (res.has_error()) { + return cpp::fail(res.error()); + } + + return updated_assistant; +} + +cpp::result AssistantService::DeleteAssistantV2( + const std::string& assistant_id) { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + return assistant_repository_->DeleteAssistant(assistant_id); +} diff --git a/engine/services/assistant_service.h b/engine/services/assistant_service.h index e7f7414d1..ad31104ff 100644 --- a/engine/services/assistant_service.h +++ b/engine/services/assistant_service.h @@ -1,15 +1,14 @@ #pragma once #include "common/assistant.h" +#include "common/dto/assistant_create_dto.h" +#include "common/dto/assistant_update_dto.h" +#include "common/repository/assistant_repository.h" #include "repositories/thread_fs_repository.h" #include "utils/result.hpp" class AssistantService { public: - explicit AssistantService( - std::shared_ptr thread_repository) - : thread_repository_{thread_repository} {} - cpp::result CreateAssistant( const std::string& thread_id, const OpenAi::JanAssistant& assistant); @@ -19,6 +18,31 @@ class AssistantService { cpp::result ModifyAssistant( const std::string& thread_id, const OpenAi::JanAssistant& assistant); + // V2 + cpp::result CreateAssistantV2( + const dto::CreateAssistantDto& create_dto); + + cpp::result, std::string> ListAssistants( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const; + + cpp::result RetrieveAssistantV2( + const std::string& assistant_id) const; + + cpp::result ModifyAssistantV2( + const std::string& assistant_id, + const dto::UpdateAssistantDto& update_dto); + + cpp::result DeleteAssistantV2( + const std::string& assistant_id); + + explicit AssistantService( + std::shared_ptr thread_repository, + std::shared_ptr assistant_repository) + : thread_repository_{thread_repository}, + assistant_repository_{assistant_repository} {} + private: std::shared_ptr thread_repository_; + std::shared_ptr assistant_repository_; }; diff --git a/engine/services/thread_service.cc b/engine/services/thread_service.cc index 25784c2ee..827c4ea83 100644 --- a/engine/services/thread_service.cc +++ b/engine/services/thread_service.cc @@ -3,7 +3,7 @@ #include "utils/ulid/ulid.hh" cpp::result ThreadService::CreateThread( - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "CreateThread"; @@ -48,7 +48,7 @@ cpp::result ThreadService::RetrieveThread( cpp::result ThreadService::ModifyThread( const std::string& thread_id, - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "ModifyThread " << thread_id; auto retrieve_res = RetrieveThread(thread_id); diff --git a/engine/services/thread_service.h b/engine/services/thread_service.h index 966b0ab01..7011f46f3 100644 --- a/engine/services/thread_service.h +++ b/engine/services/thread_service.h @@ -2,7 +2,6 @@ #include #include "common/repository/thread_repository.h" -#include "common/thread_tool_resources.h" #include "common/variant_map.h" #include "utils/result.hpp" @@ -12,7 +11,7 @@ class ThreadService { : thread_repository_{thread_repository} {} cpp::result CreateThread( - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result, std::string> ListThreads( @@ -24,7 +23,7 @@ class ThreadService { cpp::result ModifyThread( const std::string& thread_id, - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result DeleteThread(