Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
namchuai committed Dec 25, 2024
1 parent 6976e8d commit c67dc4e
Show file tree
Hide file tree
Showing 20 changed files with 511 additions and 164 deletions.
120 changes: 63 additions & 57 deletions engine/common/assistant.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#pragma once

#include <string>
#include "common/assistant_code_interpreter_tool.h"
#include "common/assistant_file_search_tool.h"
#include "common/assistant_function_tool.h"
#include "common/assistant_tool.h"
#include "common/message_attachment.h"
#include "common/thread_tool_resources.h"
#include "common/tool_resources.h"
#include "common/variant_map.h"
#include "utils/logging_utils.h"
#include "utils/result.hpp"
Expand Down Expand Up @@ -87,18 +89,19 @@ struct Assistant : JsonSerializable {
Assistant& operator=(const Assistant&) = delete;

Assistant(Assistant&& other) noexcept
: id(std::move(other.id)),
object(std::move(other.object)),
created_at(other.created_at),
name(std::move(other.name)),
description(std::move(other.description)),
: id{std::move(other.id)},
object{std::move(other.object)},
created_at{other.created_at},
name{std::move(other.name)},
description{std::move(other.description)},
model(std::move(other.model)),
instructions(std::move(other.instructions)),
tools(std::move(other.tools)),
tool_resources(std::move(other.tool_resources)),
metadata(std::move(other.metadata)),
temperature(std::move(other.temperature)),
top_p(std::move(other.top_p)) {}
temperature{std::move(other.temperature)},
top_p{std::move(other.top_p)},
response_format{std::move(other.response_format)} {}

Assistant& operator=(Assistant&& other) noexcept {
if (this != &other) {
Expand All @@ -114,6 +117,7 @@ struct Assistant : JsonSerializable {
metadata = std::move(other.metadata);
temperature = std::move(other.temperature);
top_p = std::move(other.top_p);
response_format = std::move(other.response_format);
}
return *this;
}
Expand Down Expand Up @@ -168,8 +172,7 @@ struct Assistant : JsonSerializable {
* requires a list of file IDs, while the file_search tool requires a list
* of vector store IDs.
*/
std::optional<std::variant<ThreadCodeInterpreter, ThreadFileSearch>>
tool_resources;
std::optional<std::variant<CodeInterpreter, FileSearch>> tool_resources;

/**
* Set of 16 key-value pairs that can be attached to an object. This can be
Expand All @@ -196,6 +199,8 @@ struct Assistant : JsonSerializable {
*/
std::optional<float> top_p;

std::variant<std::string, Json::Value> response_format;

cpp::result<Json::Value, std::string> ToJson() override {
try {
Json::Value root;
Expand Down Expand Up @@ -226,15 +231,15 @@ struct Assistant : JsonSerializable {
Json::Value tool_resources_json;

if (auto* code_interpreter =
std::get_if<ThreadCodeInterpreter>(&tool_resources.value())) {
std::get_if<CodeInterpreter>(&tool_resources.value())) {
if (auto result = code_interpreter->ToJson(); result.has_value()) {
tool_resources_json["code_interpreter"] = result.value();
} else {
CTL_WRN("Failed to convert code_interpreter to json: " +
result.error());
}
} else if (auto* file_search =
std::get_if<ThreadFileSearch>(&tool_resources.value())) {
std::get_if<FileSearch>(&tool_resources.value())) {
if (auto result = file_search->ToJson(); result.has_value()) {
tool_resources_json["file_search"] = result.value();
} else {
Expand Down Expand Up @@ -312,60 +317,61 @@ struct Assistant : JsonSerializable {

// Parse tools array
if (json.isMember("tools") && json["tools"].isArray()) {
// TODO: namh implement
// for (const auto& tool_json : json["tools"]) {
// auto tool = AssistantTool::FromJson(tool_json);
// if (!tool.has_value()) {
// return cpp::fail("Failed to parse tool: " + tool.error());
// }
// assistant.tools.push_back(std::move(tool.value()));
// }
auto tools_array = json["tools"];
for (const auto& tool : tools_array) {
if (!tool.isMember("type") || !tool["type"].isString()) {
CTL_WRN("Tool missing type field or invalid type");
continue;
}

std::string tool_type = tool["type"].asString();
if (tool_type == "file_search") {
auto result = AssistantFileSearchTool::FromJson(tool);
if (result.has_value()) {
assistant.tools.push_back(
std::make_unique<AssistantFileSearchTool>(
std::move(result.value())));
} else {
CTL_WRN("Failed to parse file_search tool: " + result.error());
}
} else if (tool_type == "code_interpreter") {
auto result = AssistantCodeInterpreterTool::FromJson(tool);
if (result.has_value()) {
assistant.tools.push_back(
std::make_unique<AssistantCodeInterpreterTool>(
std::move(result.value())));
} else {
CTL_WRN("Failed to parse code_interpreter tool: " +
result.error());
}
} else if (tool_type == "function") {
auto result = AssistantFunctionTool::FromJson(tool);
if (result.has_value()) {
assistant.tools.push_back(std::make_unique<AssistantFunctionTool>(
std::move(result.value())));
} else {
CTL_WRN("Failed to parse function tool: " + result.error());
}
} else {
CTL_WRN("Unknown tool type: " + tool_type);
}
}
}

// Parse tool_resources
if (json.isMember("tool_resources") &&
json["tool_resources"].isObject()) {
// const auto& resources_json = json["tool_resources"];
//
// if (resources_json.isMember("code_interpreter")) {
// auto code_interpreter = ThreadCodeInterpreter::FromJson(
// resources_json["code_interpreter"]);
// if (!code_interpreter.has_value()) {
// return cpp::fail("Failed to parse code_interpreter: " +
// code_interpreter.error());
// }
// assistant.tool_resources = std::move(code_interpreter.value());
// } else if (resources_json.isMember("file_search")) {
// auto file_search =
// ThreadFileSearch::FromJson(resources_json["file_search"]);
// if (!file_search.has_value()) {
// return cpp::fail("Failed to parse file_search: " +
// file_search.error());
// }
// assistant.tool_resources = std::move(file_search.value());
// }
}
json["tool_resources"].isObject()) {}

// Parse metadata
if (json.isMember("metadata") && json["metadata"].isObject()) {
const auto& metadata_json = json["metadata"];
for (const auto& key : metadata_json.getMemberNames()) {
const auto& value = metadata_json[key];
if (value.isBool()) {
assistant.metadata[key] = value.asBool();
} else if (value.isUInt64()) {
assistant.metadata[key] = value.asUInt64();
} else if (value.isDouble()) {
assistant.metadata[key] = value.asDouble();
} else if (value.isString()) {
assistant.metadata[key] = value.asString();
} else {
return cpp::fail("Invalid metadata value type for key: " + key);
}
auto res = Cortex::ConvertJsonValueToMap(json["metadata"]);
if (res.has_value()) {
assistant.metadata = res.value();
} else {
CTL_WRN("Failed to convert metadata to map: " + res.error());
}
}

// Parse optional numerical fields
if (json.isMember("temperature") && json["temperature"].isDouble()) {
assistant.temperature = json["temperature"].asFloat();
}
Expand Down
39 changes: 28 additions & 11 deletions engine/common/dto/assistant_create_dto.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct CreateAssistantDto : public BaseDto<CreateAssistantDto> {
: model{std::move(other.model)},
name{std::move(other.name)},
description{std::move(other.description)},
instruction{std::move(other.instruction)},
instructions{std::move(other.instructions)},
metadata{std::move(other.metadata)},
temperature{std::move(other.temperature)},
top_p{std::move(other.top_p)},
Expand All @@ -31,7 +31,7 @@ struct CreateAssistantDto : public BaseDto<CreateAssistantDto> {
model = std::move(other.model);
name = std::move(other.name);
description = std::move(other.description);
instruction = std::move(other.instruction);
instructions = std::move(other.instructions);
metadata = std::move(other.metadata);
temperature = std::move(other.temperature);
top_p = std::move(other.top_p);
Expand All @@ -46,7 +46,7 @@ struct CreateAssistantDto : public BaseDto<CreateAssistantDto> {

std::optional<std::string> description;

std::optional<std::string> instruction;
std::optional<std::string> instructions;

// namH: implement tools

Expand All @@ -58,17 +58,26 @@ struct CreateAssistantDto : public BaseDto<CreateAssistantDto> {

std::optional<float> top_p;

std::optional<std::string> response_format;
std::optional<std::variant<std::string, Json::Value>> response_format;

bool Validate() const override {
cpp::result<void, std::string> Validate() const override {
if (model.empty()) {
return false;
return cpp::fail("Model is mandatory");
}

return true;
if (response_format.has_value()) {
const auto& variant_value = response_format.value();
if (std::holds_alternative<std::string>(variant_value)) {
if (std::get<std::string>(variant_value) != "auto") {
return cpp::fail("Invalid response_format");
}
}
}

return {};
}

CreateAssistantDto FromJson(Json::Value&& root) override {
static CreateAssistantDto FromJson(Json::Value&& root) {
if (root.empty()) {
throw std::runtime_error("Json passed in FromJson can't be empty");
}
Expand All @@ -80,8 +89,8 @@ struct CreateAssistantDto : public BaseDto<CreateAssistantDto> {
if (root.isMember("description")) {
dto.description = std::move(root["description"].asString());
}
if (root.isMember("instruction")) {
dto.instruction = std::move(root["instruction"].asString());
if (root.isMember("instructions")) {
dto.instructions = std::move(root["instructions"].asString());
}
if (root["metadata"].isObject() && !root["metadata"].empty()) {
auto res = Cortex::ConvertJsonValueToMap(root["metadata"]);
Expand All @@ -98,7 +107,15 @@ struct CreateAssistantDto : public BaseDto<CreateAssistantDto> {
dto.top_p = root["top_p"].asFloat();
}
if (root.isMember("response_format")) {
dto.response_format = std::move(root["response_format"].asString());
const auto& response_format = root["response_format"];
if (response_format.isString()) {
dto.response_format = response_format.asString();
} else if (response_format.isObject()) {
dto.response_format = response_format;
} else {
throw std::runtime_error(
"response_format must be either a string or an object");
}
}
return dto;
}
Expand Down
26 changes: 17 additions & 9 deletions engine/common/dto/assistant_update_dto.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct UpdateAssistantDto : public BaseDto<UpdateAssistantDto> {

std::optional<std::string> description;

std::optional<std::string> instruction;
std::optional<std::string> instructions;

// namH: implement tools

Expand All @@ -24,20 +24,20 @@ struct UpdateAssistantDto : public BaseDto<UpdateAssistantDto> {

std::optional<float> top_p;

std::optional<std::string> response_format;
std::optional<std::variant<std::string, Json::Value>> response_format;

bool Validate() const override {
cpp::result<void, std::string> Validate() const override {
if (!model.has_value() && !name.has_value() && !description.has_value() &&
!instruction.has_value() && !metadata.has_value() &&
!instructions.has_value() && !metadata.has_value() &&
!temperature.has_value() && !top_p.has_value() &&
!response_format.has_value()) {
return false;
return cpp::fail("At least one field must be provided");
}

return true;
return {};
}

UpdateAssistantDto FromJson(Json::Value&& root) override {
static UpdateAssistantDto FromJson(Json::Value&& root) {
if (root.empty()) {
throw std::runtime_error("Json passed in FromJson can't be empty");
}
Expand All @@ -50,7 +50,7 @@ struct UpdateAssistantDto : public BaseDto<UpdateAssistantDto> {
dto.description = std::move(root["description"].asString());
}
if (root.isMember("instruction")) {
dto.instruction = std::move(root["instruction"].asString());
dto.instructions = std::move(root["instruction"].asString());
}
if (root["metadata"].isObject() && !root["metadata"].empty()) {
auto res = Cortex::ConvertJsonValueToMap(root["metadata"]);
Expand All @@ -67,7 +67,15 @@ struct UpdateAssistantDto : public BaseDto<UpdateAssistantDto> {
dto.top_p = root["top_p"].asFloat();
}
if (root.isMember("response_format")) {
dto.response_format = std::move(root["response_format"].asString());
const auto& response_format = root["response_format"];
if (response_format.isString()) {
dto.response_format = response_format.asString();
} else if (response_format.isObject()) {
dto.response_format = response_format;
} else {
throw std::runtime_error(
"response_format must be either a string or an object");
}
}
return dto;
};
Expand Down
5 changes: 2 additions & 3 deletions engine/common/dto/base_dto.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <json/value.h>
#include "utils/result.hpp"

namespace dto {
template <typename T>
Expand All @@ -10,8 +11,6 @@ struct BaseDto {
/**
* Validate itself.
*/
virtual bool Validate() const = 0;

virtual T FromJson(Json::Value&& root) = 0;
virtual cpp::result<void, std::string> Validate() const = 0;
};
} // namespace dto
13 changes: 8 additions & 5 deletions engine/common/message_attachment.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "common/json_serializable.h"

namespace OpenAi {

// The tools to add this file to.
struct Tool {
std::string type;
Expand All @@ -15,13 +14,17 @@ struct Tool {
};

// The type of tool being defined: code_interpreter
struct CodeInterpreter : Tool {
CodeInterpreter() : Tool{"code_interpreter"} {}
struct MessageCodeInterpreter : Tool {
MessageCodeInterpreter() : Tool{"code_interpreter"} {}

~MessageCodeInterpreter() = default;
};

// The type of tool being defined: file_search
struct FileSearch : Tool {
FileSearch() : Tool{"file_search"} {}
struct MessageFileSearch : Tool {
MessageFileSearch() : Tool{"file_search"} {}

~MessageFileSearch() = default;
};

// A list of files attached to the message, and the tools they were added to.
Expand Down
2 changes: 1 addition & 1 deletion engine/common/repository/assistant_repository.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class AssistantRepository {
virtual cpp::result<OpenAi::Assistant, std::string> CreateAssistant(
OpenAi::Assistant& assistant) = 0;

virtual cpp::result<OpenAi::Assistant, std::string> RetrieveAssisant(
virtual cpp::result<OpenAi::Assistant, std::string> RetrieveAssistant(
const std::string assistant_id) const = 0;

virtual cpp::result<void, std::string> ModifyAssistant(
Expand Down
Loading

0 comments on commit c67dc4e

Please sign in to comment.