From 3d30cdd339654165b1cb6f8fadc47cc2202b4941 Mon Sep 17 00:00:00 2001 From: John Wilger Date: Fri, 15 Dec 2023 20:44:52 -0800 Subject: [PATCH] Add user's message to the assistant thread --- .doctor.exs | 22 +---- .formatter.exs | 2 +- config/runtime.exs | 2 +- lib/gpt_agent.ex | 24 ++++- lib/gpt_agent/events/user_message_added.ex | 13 +++ lib/gpt_agent/value.ex | 98 +++++++++++++++++++ lib/gpt_agent/values/nonblank_string.ex | 27 +++++ mix.exs | 1 + .../gpt_agent/values/nonblank_string_test.exs | 25 +++++ test/gpt_agent_test.exs | 52 +++++++++- test/support/test_case.ex | 72 +++++++++++++- 11 files changed, 311 insertions(+), 27 deletions(-) create mode 100644 lib/gpt_agent/events/user_message_added.ex create mode 100644 lib/gpt_agent/value.ex create mode 100644 lib/gpt_agent/values/nonblank_string.ex create mode 100644 test/gpt_agent/values/nonblank_string_test.exs diff --git a/.doctor.exs b/.doctor.exs index f7ac36d..4007d27 100644 --- a/.doctor.exs +++ b/.doctor.exs @@ -1,25 +1,5 @@ %Doctor.Config{ - ignore_modules: [ - TaleForge.Accounts.User, - TaleForge.Accounts.UserToken, - TaleForge.Repo, - TaleForgeWeb, - TaleForgeWeb.Endpoint, - TaleForgeWeb.ErrorHTML, - TaleForgeWeb.ErrorJSON, - TaleForgeWeb.PageController, - TaleForgeWeb.PageHTML, - TaleForgeWeb.Router, - TaleForgeWeb.Telemetry, - TaleForgeWeb.UserConfirmationInstructionsLive, - TaleForgeWeb.UserConfirmationLive, - TaleForgeWeb.UserForgotPasswordLive, - TaleForgeWeb.UserLoginLive, - TaleForgeWeb.UserRegistrationLive, - TaleForgeWeb.UserResetPasswordLive, - TaleForgeWeb.UserSessionController, - TaleForgeWeb.UserSettingsLive - ], + ignore_modules: [GptAgent.Value], ignore_paths: [], min_module_doc_coverage: 40, min_module_spec_coverage: 0, diff --git a/.formatter.exs b/.formatter.exs index a9a7f0f..75b30fe 100644 --- a/.formatter.exs +++ b/.formatter.exs @@ -1,5 +1,5 @@ # Used by "mix format" [ inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"], - locals_without_parens: [field: 2, field: 3] + locals_without_parens: [field: 2, field: 3, type: 1, validate_with: 1] ] diff --git a/config/runtime.exs b/config/runtime.exs index 03dce7f..e8b2372 100644 --- a/config/runtime.exs +++ b/config/runtime.exs @@ -1,6 +1,6 @@ import Config config :open_ai_client, OpenAiClient, - base_url: System.get_env("OPENAI_BASE_URL") || "https://api.openai.com/v1", + base_url: System.get_env("OPENAI_BASE_URL") || "https://api.openai.com", openai_api_key: System.get_env("OPENAI_API_KEY") || raise("OPENAI_API_KEY is not set"), openai_organization_id: System.get_env("OPENAI_ORGANIZATION_ID") diff --git a/lib/gpt_agent.ex b/lib/gpt_agent.ex index b7edfbf..7b36b1e 100644 --- a/lib/gpt_agent.ex +++ b/lib/gpt_agent.ex @@ -6,7 +6,8 @@ defmodule GptAgent do use GenServer use TypedStruct - alias GptAgent.Events.ThreadCreated + alias GptAgent.Events.{ThreadCreated, UserMessageAdded} + alias GptAgent.Values.NonblankString typedstruct do field :pid, pid(), enforce: true @@ -34,7 +35,7 @@ defmodule GptAgent do end def handle_continue(:create_thread, %__MODULE__{thread_id: nil} = state) do - {:ok, %{body: %{"id" => thread_id}}} = OpenAiClient.post("/v1/threads", json: %{}) + {:ok, %{body: %{"id" => thread_id}}} = OpenAiClient.post("/v1/threads", json: "") state |> Map.put(:thread_id, thread_id) @@ -49,6 +50,21 @@ defmodule GptAgent do |> noreply() end + def handle_cast({:add_user_message, message}, state) do + {:ok, message} = NonblankString.new(message) + + {:ok, %{body: %{"id" => id}}} = + OpenAiClient.post("/v1/threads/#{state.thread_id}/messages", json: message) + + state + |> send_callback(%UserMessageAdded{ + id: id, + thread_id: state.thread_id, + content: message.value + }) + |> noreply() + end + @doc """ Starts the GPT Agent """ @@ -56,4 +72,8 @@ defmodule GptAgent do def start_link(callback_handler, thread_id \\ nil) when is_pid(callback_handler) do GenServer.start_link(__MODULE__, callback_handler: callback_handler, thread_id: thread_id) end + + def add_user_message(pid, message) do + GenServer.cast(pid, {:add_user_message, message}) + end end diff --git a/lib/gpt_agent/events/user_message_added.ex b/lib/gpt_agent/events/user_message_added.ex new file mode 100644 index 0000000..19ce26f --- /dev/null +++ b/lib/gpt_agent/events/user_message_added.ex @@ -0,0 +1,13 @@ +defmodule GptAgent.Events.UserMessageAdded do + @moduledoc """ + An OpenAI Assistants user message was added to a thread + """ + + use TypedStruct + + typedstruct do + field :id, binary(), enforce: true + field :thread_id, binary(), enforce: true + field :content, String.t(), enforce: true + end +end diff --git a/lib/gpt_agent/value.ex b/lib/gpt_agent/value.ex new file mode 100644 index 0000000..f25aa0c --- /dev/null +++ b/lib/gpt_agent/value.ex @@ -0,0 +1,98 @@ +defmodule GptAgent.Value do + @moduledoc """ + A module that provides a macro to define a new type and its validation. + """ + + @doc """ + Macro to use GptAgent.Value in the current module. + + ## Examples + + defmodule MyModule do + use GptAgent.Value + end + """ + defmacro __using__(_opts) do + quote do + require GptAgent.Value + import GptAgent.Value + end + end + + @doc """ + Macro to define a new type. + + ## Examples + + defmodule MyType do + use GptAgent.Value + type String.t() + end + """ + @spec type(atom()) :: Macro.t() + defmacro type(type) do + quote do + use TypedStruct + + typedstruct do + field :value, unquote(type), enforce: true + end + + @doc """ + Creates a new instance of the type. + + The function validates the value and returns a new struct if the value is valid. + If the value is invalid, it returns an error tuple with the error message. + + ## Params + + - `value`: The value to be validated and set in the struct. + + ## Examples + + iex> MyType.new("valid value") + %MyType{value: "valid value"} + + iex> MyType.new("invalid value") + {:error, "error message"} + + """ + @spec new(any()) :: {:ok, t()} | {:error, String.t()} + def new(value) do + case validate(value) do + :ok -> + {:ok, %__MODULE__{value: value}} + + {:error, error} -> + {:error, error} + end + end + end + end + + @doc """ + Macro to define a validation function for the type. + + ## Examples + + defmodule MyType do + use GptAgent.Value + type String.t() + validate_with fn + value when is_binary(value) -> + case String.trim(value) do + "" -> {:error, "must be a nonblank string"} + _ -> :ok + end + end + end + """ + @spec validate_with(Macro.t()) :: Macro.t() + defmacro validate_with(validate_with) do + quote do + defp validate(value) do + unquote(validate_with).(value) + end + end + end +end diff --git a/lib/gpt_agent/values/nonblank_string.ex b/lib/gpt_agent/values/nonblank_string.ex new file mode 100644 index 0000000..d016371 --- /dev/null +++ b/lib/gpt_agent/values/nonblank_string.ex @@ -0,0 +1,27 @@ +defmodule GptAgent.Values.NonblankString do + @moduledoc """ + A message that a user has sent to the GPT agent + """ + + use GptAgent.Value + + type String.t() + + validate_with fn + value when is_binary(value) -> + case String.trim(value) do + "" -> {:error, "must be a nonblank string"} + _ -> :ok + end + + _ -> + {:error, "must be a nonblank string"} + end + + defimpl Jason.Encoder do + def encode(%{value: value}, opts) do + %{role: :user, content: value} + |> Jason.Encoder.encode(opts) + end + end +end diff --git a/mix.exs b/mix.exs index 5879c0f..cda5033 100644 --- a/mix.exs +++ b/mix.exs @@ -53,6 +53,7 @@ defmodule GptAgent.MixProject do {:mix_test_interactive, "~> 1.2", only: :dev, runtime: false}, {:mox, "~> 1.1", only: :test}, {:open_ai_client, "~> 1.2"}, + {:stream_data, "~> 0.6"}, {:sobelow, "~> 0.12", only: [:dev, :test], runtime: false}, {:typed_struct, "~> 0.3"}, {:uuid, "~> 1.1"} diff --git a/test/gpt_agent/values/nonblank_string_test.exs b/test/gpt_agent/values/nonblank_string_test.exs new file mode 100644 index 0000000..419011a --- /dev/null +++ b/test/gpt_agent/values/nonblank_string_test.exs @@ -0,0 +1,25 @@ +defmodule GptAgent.Values.NonblankStringTest do + @moduledoc false + + use GptAgent.TestCase, async: true + + alias GptAgent.Values.NonblankString + + test "a non-blank string is valid" do + check all(value <- nonblank_string()) do + assert {:ok, %NonblankString{value: ^value}} = NonblankString.new(value) + end + end + + test "blank strings are invalid" do + assert {:error, "must be a nonblank string"} = NonblankString.new("") + assert {:error, "must be a nonblank string"} = NonblankString.new(" ") + assert {:error, "must be a nonblank string"} = NonblankString.new(" \n\t ") + end + + test "all other values are invalid" do + check all(value <- term_other_than_nonblank_string()) do + assert {:error, "must be a nonblank string"} = NonblankString.new(value) + end + end +end diff --git a/test/gpt_agent_test.exs b/test/gpt_agent_test.exs index c511077..a7251f0 100644 --- a/test/gpt_agent_test.exs +++ b/test/gpt_agent_test.exs @@ -4,7 +4,7 @@ defmodule GptAgentTest do use ExUnit.Case doctest GptAgent - alias GptAgent.Events.ThreadCreated + alias GptAgent.Events.{ThreadCreated, UserMessageAdded} setup _context do bypass = Bypass.open() @@ -89,4 +89,54 @@ defmodule GptAgentTest do refute_receive {GptAgent, ^pid, %ThreadCreated{}}, 100 end end + + describe "add_user_message/2" do + test "adds the user message to the agent's thread via the OpenAI API", %{ + bypass: bypass, + thread_id: thread_id + } do + {:ok, pid} = GptAgent.start_link(self(), thread_id) + + user_message_id = Faker.Lorem.word() + message_content = Faker.Lorem.paragraph() + + Bypass.expect_once(bypass, "POST", "/v1/threads/#{thread_id}/messages", fn conn -> + conn + |> Plug.Conn.put_resp_content_type("application/json") + |> Plug.Conn.resp( + 201, + Jason.encode!(%{ + "id" => user_message_id, + "object" => "thread.message", + "created_at" => "1699012949", + "thread_id" => thread_id, + "role" => "user", + "content" => [ + %{ + "type" => "text", + "text" => %{ + "value" => message_content, + "annotations" => [] + } + } + ], + "file_ids" => [], + "assistant_id" => nil, + "run_id" => nil, + "metadata" => %{} + }) + ) + end) + + :ok = GptAgent.add_user_message(pid, message_content) + + assert_receive {GptAgent, ^pid, + %UserMessageAdded{ + id: ^user_message_id, + thread_id: ^thread_id, + content: ^message_content + }}, + 5_000 + end + end end diff --git a/test/support/test_case.ex b/test/support/test_case.ex index 5e25625..873b63e 100644 --- a/test/support/test_case.ex +++ b/test/support/test_case.ex @@ -1,13 +1,83 @@ defmodule GptAgent.TestCase do @moduledoc """ - This module provides test case template for GptAgent tests. + Base test case template for the entire GptAgent application + + This case template includes setup and helper functions that are applicable to + all tests of the application. """ use ExUnit.CaseTemplate + use ExUnitProperties using do quote do + use ExUnitProperties import GptAgent.TestCase end end + + @doc """ + `StreamData` generator for UUIDs + """ + def uuid do + StreamData.frequency([ + {3, + StreamData.map(integer(), fn _n -> + UUID.uuid4() + end)}, + {1, + StreamData.map(integer(), fn _n -> + UUID.uuid4() |> UUID.string_to_binary!() + end)} + ]) + end + + def term_other_than_uuid do + StreamData.filter(term(), fn x -> !is_uuid?(x) end) + end + + def is_uuid?(value) do + case UUID.info(value) do + {:ok, _} -> true + _ -> false + end + end + + def url do + StreamData.map(integer(), fn _n -> + Faker.Internet.url() + end) + end + + def term_other_than_url do + StreamData.filter(term(), fn x -> + !is_binary(x) || !is_url(x) + end) + end + + defp is_url(data) do + case URI.new(data) do + {:ok, _uri} -> true + _ -> false + end + rescue + # why? Because the URI library will throw function clause errors on certain + # binaries instead of just returning a gosh-darned {:error, _} tuple + _ -> false + end + + def nonblank_string do + StreamData.string(:printable) + |> StreamData.filter(&(String.trim(&1) != "")) + end + + def term_other_than_nonblank_string do + StreamData.filter(term(), fn x -> + if is_binary(x) do + x |> String.trim() |> String.length() == 0 + else + true + end + end) + end end