Skip to content

Commit

Permalink
feat: add openai assistant (#1826)
Browse files Browse the repository at this point in the history
  • Loading branch information
namchuai authored Dec 27, 2024
1 parent 3456c7b commit f94527f
Show file tree
Hide file tree
Showing 30 changed files with 3,289 additions and 199 deletions.
542 changes: 497 additions & 45 deletions docs/static/openapi/cortex.json

Large diffs are not rendered by default.

271 changes: 267 additions & 4 deletions engine/common/assistant.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#pragma once

#include <string>
#include "common/assistant_code_interpreter_tool.h"
#include "common/assistant_file_search_tool.h"
#include "common/assistant_function_tool.h"
#include "common/assistant_tool.h"
#include "common/thread_tool_resources.h"
#include "common/tool_resources.h"
#include "common/variant_map.h"
#include "utils/logging_utils.h"
#include "utils/result.hpp"

namespace OpenAi {
Expand Down Expand Up @@ -75,7 +79,49 @@ struct JanAssistant : JsonSerializable {
}
};

struct Assistant {
struct Assistant : JsonSerializable {
Assistant() = default;

~Assistant() = default;

Assistant(const Assistant&) = delete;

Assistant& operator=(const Assistant&) = delete;

Assistant(Assistant&& other) noexcept
: id{std::move(other.id)},
object{std::move(other.object)},
created_at{other.created_at},
name{std::move(other.name)},
description{std::move(other.description)},
model(std::move(other.model)),
instructions(std::move(other.instructions)),
tools(std::move(other.tools)),
tool_resources(std::move(other.tool_resources)),
metadata(std::move(other.metadata)),
temperature{std::move(other.temperature)},
top_p{std::move(other.top_p)},
response_format{std::move(other.response_format)} {}

Assistant& operator=(Assistant&& other) noexcept {
if (this != &other) {
id = std::move(other.id);
object = std::move(other.object);
created_at = other.created_at;
name = std::move(other.name);
description = std::move(other.description);
model = std::move(other.model);
instructions = std::move(other.instructions);
tools = std::move(other.tools);
tool_resources = std::move(other.tool_resources);
metadata = std::move(other.metadata);
temperature = std::move(other.temperature);
top_p = std::move(other.top_p);
response_format = std::move(other.response_format);
}
return *this;
}

/**
* The identifier, which can be referenced in API endpoints.
*/
Expand Down Expand Up @@ -126,8 +172,7 @@ struct Assistant {
* requires a list of file IDs, while the file_search tool requires a list
* of vector store IDs.
*/
std::optional<std::variant<ThreadCodeInterpreter, ThreadFileSearch>>
tool_resources;
std::unique_ptr<OpenAi::ToolResources> tool_resources;

/**
* Set of 16 key-value pairs that can be attached to an object. This can be
Expand All @@ -153,5 +198,223 @@ struct Assistant {
* We generally recommend altering this or temperature but not both.
*/
std::optional<float> top_p;

std::variant<std::string, Json::Value> response_format;

cpp::result<Json::Value, std::string> ToJson() override {
try {
Json::Value root;

root["id"] = std::move(id);
root["object"] = "assistant";
root["created_at"] = created_at;
if (name.has_value()) {
root["name"] = name.value();
}
if (description.has_value()) {
root["description"] = description.value();
}
root["model"] = model;
if (instructions.has_value()) {
root["instructions"] = instructions.value();
}

Json::Value tools_jarr{Json::arrayValue};
for (auto& tool_ptr : tools) {
if (auto it = tool_ptr->ToJson(); it.has_value()) {
tools_jarr.append(it.value());
} else {
CTL_WRN("Failed to convert content to json: " + it.error());
}
}
root["tools"] = tools_jarr;
if (tool_resources) {
Json::Value tool_resources_json{Json::objectValue};

if (auto* code_interpreter =
dynamic_cast<OpenAi::CodeInterpreter*>(tool_resources.get())) {
auto result = code_interpreter->ToJson();
if (result.has_value()) {
tool_resources_json["code_interpreter"] = result.value();
} else {
CTL_WRN("Failed to convert code_interpreter to json: " +
result.error());
}
} else if (auto* file_search = dynamic_cast<OpenAi::FileSearch*>(
tool_resources.get())) {
auto result = file_search->ToJson();
if (result.has_value()) {
tool_resources_json["file_search"] = result.value();
} else {
CTL_WRN("Failed to convert file_search to json: " + result.error());
}
}

// Only add tool_resources to root if we successfully serialized some resources
if (!tool_resources_json.empty()) {
root["tool_resources"] = tool_resources_json;
}
}
Json::Value metadata_json{Json::objectValue};
for (const auto& [key, value] : metadata) {
if (std::holds_alternative<bool>(value)) {
metadata_json[key] = std::get<bool>(value);
} else if (std::holds_alternative<uint64_t>(value)) {
metadata_json[key] = std::get<uint64_t>(value);
} else if (std::holds_alternative<double>(value)) {
metadata_json[key] = std::get<double>(value);
} else {
metadata_json[key] = std::get<std::string>(value);
}
}
root["metadata"] = metadata_json;

if (temperature.has_value()) {
root["temperature"] = temperature.value();
}
if (top_p.has_value()) {
root["top_p"] = top_p.value();
}
return root;
} catch (const std::exception& e) {
return cpp::fail("ToJson failed: " + std::string(e.what()));
}
}

static cpp::result<Assistant, std::string> FromJson(Json::Value&& json) {
try {
Assistant assistant;

// Parse required fields
if (!json.isMember("id") || !json["id"].isString()) {
return cpp::fail("Missing or invalid 'id' field");
}
assistant.id = json["id"].asString();

if (!json.isMember("object") || !json["object"].isString() ||
json["object"].asString() != "assistant") {
return cpp::fail("Missing or invalid 'object' field");
}

if (!json.isMember("created_at") || !json["created_at"].isUInt64()) {
return cpp::fail("Missing or invalid 'created_at' field");
}
assistant.created_at = json["created_at"].asUInt64();

if (!json.isMember("model") || !json["model"].isString()) {
return cpp::fail("Missing or invalid 'model' field");
}
assistant.model = json["model"].asString();

// Parse optional fields
if (json.isMember("name") && json["name"].isString()) {
assistant.name = json["name"].asString();
}

if (json.isMember("description") && json["description"].isString()) {
assistant.description = json["description"].asString();
}

if (json.isMember("instructions") && json["instructions"].isString()) {
assistant.instructions = json["instructions"].asString();
}

// Parse tools array
if (json.isMember("tools") && json["tools"].isArray()) {
auto tools_array = json["tools"];
for (const auto& tool : tools_array) {
if (!tool.isMember("type") || !tool["type"].isString()) {
CTL_WRN("Tool missing type field or invalid type");
continue;
}

std::string tool_type = tool["type"].asString();
if (tool_type == "file_search") {
auto result = AssistantFileSearchTool::FromJson(tool);
if (result.has_value()) {
assistant.tools.push_back(
std::make_unique<AssistantFileSearchTool>(
std::move(result.value())));
} else {
CTL_WRN("Failed to parse file_search tool: " + result.error());
}
} else if (tool_type == "code_interpreter") {
auto result = AssistantCodeInterpreterTool::FromJson();
if (result.has_value()) {
assistant.tools.push_back(
std::make_unique<AssistantCodeInterpreterTool>(
std::move(result.value())));
} else {
CTL_WRN("Failed to parse code_interpreter tool: " +
result.error());
}
} else if (tool_type == "function") {
auto result = AssistantFunctionTool::FromJson(tool);
if (result.has_value()) {
assistant.tools.push_back(std::make_unique<AssistantFunctionTool>(
std::move(result.value())));
} else {
CTL_WRN("Failed to parse function tool: " + result.error());
}
} else {
CTL_WRN("Unknown tool type: " + tool_type);
}
}
}

if (json.isMember("tool_resources") &&
json["tool_resources"].isObject()) {
const auto& tool_resources_json = json["tool_resources"];

// Parse code interpreter resources
if (tool_resources_json.isMember("code_interpreter")) {
auto result = OpenAi::CodeInterpreter::FromJson(
tool_resources_json["code_interpreter"]);
if (result.has_value()) {
assistant.tool_resources =
std::make_unique<OpenAi::CodeInterpreter>(
std::move(result.value()));
} else {
CTL_WRN("Failed to parse code_interpreter resources: " +
result.error());
}
}

// Parse file search resources
if (tool_resources_json.isMember("file_search")) {
auto result =
OpenAi::FileSearch::FromJson(tool_resources_json["file_search"]);
if (result.has_value()) {
assistant.tool_resources =
std::make_unique<OpenAi::FileSearch>(std::move(result.value()));
} else {
CTL_WRN("Failed to parse file_search resources: " + result.error());
}
}
}

// Parse metadata
if (json.isMember("metadata") && json["metadata"].isObject()) {
auto res = Cortex::ConvertJsonValueToMap(json["metadata"]);
if (res.has_value()) {
assistant.metadata = res.value();
} else {
CTL_WRN("Failed to convert metadata to map: " + res.error());
}
}

if (json.isMember("temperature") && json["temperature"].isDouble()) {
assistant.temperature = json["temperature"].asFloat();
}

if (json.isMember("top_p") && json["top_p"].isDouble()) {
assistant.top_p = json["top_p"].asFloat();
}

return assistant;
} catch (const std::exception& e) {
return cpp::fail("FromJson failed: " + std::string(e.what()));
}
}
};
} // namespace OpenAi
32 changes: 32 additions & 0 deletions engine/common/assistant_code_interpreter_tool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include "common/assistant_tool.h"

namespace OpenAi {
struct AssistantCodeInterpreterTool : public AssistantTool {
AssistantCodeInterpreterTool() : AssistantTool("code_interpreter") {}

AssistantCodeInterpreterTool(const AssistantCodeInterpreterTool&) = delete;

AssistantCodeInterpreterTool& operator=(const AssistantCodeInterpreterTool&) =
delete;

AssistantCodeInterpreterTool(AssistantCodeInterpreterTool&&) = default;

AssistantCodeInterpreterTool& operator=(AssistantCodeInterpreterTool&&) =
default;

~AssistantCodeInterpreterTool() = default;

static cpp::result<AssistantCodeInterpreterTool, std::string> FromJson() {
AssistantCodeInterpreterTool tool;
return std::move(tool);
}

cpp::result<Json::Value, std::string> ToJson() override {
Json::Value json;
json["type"] = type;
return json;
}
};
} // namespace OpenAi
Loading

0 comments on commit f94527f

Please sign in to comment.