Skip to content

Commit

Permalink
Add user's message to the assistant thread
Browse files Browse the repository at this point in the history
  • Loading branch information
jwilger committed Dec 16, 2023
1 parent 7f0dfa7 commit 3d30cdd
Show file tree
Hide file tree
Showing 11 changed files with 311 additions and 27 deletions.
22 changes: 1 addition & 21 deletions .doctor.exs
Original file line number Diff line number Diff line change
@@ -1,25 +1,5 @@
%Doctor.Config{
ignore_modules: [
TaleForge.Accounts.User,
TaleForge.Accounts.UserToken,
TaleForge.Repo,
TaleForgeWeb,
TaleForgeWeb.Endpoint,
TaleForgeWeb.ErrorHTML,
TaleForgeWeb.ErrorJSON,
TaleForgeWeb.PageController,
TaleForgeWeb.PageHTML,
TaleForgeWeb.Router,
TaleForgeWeb.Telemetry,
TaleForgeWeb.UserConfirmationInstructionsLive,
TaleForgeWeb.UserConfirmationLive,
TaleForgeWeb.UserForgotPasswordLive,
TaleForgeWeb.UserLoginLive,
TaleForgeWeb.UserRegistrationLive,
TaleForgeWeb.UserResetPasswordLive,
TaleForgeWeb.UserSessionController,
TaleForgeWeb.UserSettingsLive
],
ignore_modules: [GptAgent.Value],
ignore_paths: [],
min_module_doc_coverage: 40,
min_module_spec_coverage: 0,
Expand Down
2 changes: 1 addition & 1 deletion .formatter.exs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Used by "mix format"
[
inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"],
locals_without_parens: [field: 2, field: 3]
locals_without_parens: [field: 2, field: 3, type: 1, validate_with: 1]
]
2 changes: 1 addition & 1 deletion config/runtime.exs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import Config

config :open_ai_client, OpenAiClient,
base_url: System.get_env("OPENAI_BASE_URL") || "https://api.openai.com/v1",
base_url: System.get_env("OPENAI_BASE_URL") || "https://api.openai.com",
openai_api_key: System.get_env("OPENAI_API_KEY") || raise("OPENAI_API_KEY is not set"),
openai_organization_id: System.get_env("OPENAI_ORGANIZATION_ID")
24 changes: 22 additions & 2 deletions lib/gpt_agent.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ defmodule GptAgent do
use GenServer
use TypedStruct

alias GptAgent.Events.ThreadCreated
alias GptAgent.Events.{ThreadCreated, UserMessageAdded}
alias GptAgent.Values.NonblankString

typedstruct do
field :pid, pid(), enforce: true
Expand Down Expand Up @@ -34,7 +35,7 @@ defmodule GptAgent do
end

def handle_continue(:create_thread, %__MODULE__{thread_id: nil} = state) do
{:ok, %{body: %{"id" => thread_id}}} = OpenAiClient.post("/v1/threads", json: %{})
{:ok, %{body: %{"id" => thread_id}}} = OpenAiClient.post("/v1/threads", json: "")

state
|> Map.put(:thread_id, thread_id)
Expand All @@ -49,11 +50,30 @@ defmodule GptAgent do
|> noreply()
end

def handle_cast({:add_user_message, message}, state) do
{:ok, message} = NonblankString.new(message)

{:ok, %{body: %{"id" => id}}} =
OpenAiClient.post("/v1/threads/#{state.thread_id}/messages", json: message)

state
|> send_callback(%UserMessageAdded{
id: id,
thread_id: state.thread_id,
content: message.value
})
|> noreply()
end

@doc """
Starts the GPT Agent
"""
@spec start_link(pid(), binary() | nil) :: {:ok, pid()} | {:error, reason :: term()}
def start_link(callback_handler, thread_id \\ nil) when is_pid(callback_handler) do
GenServer.start_link(__MODULE__, callback_handler: callback_handler, thread_id: thread_id)
end

def add_user_message(pid, message) do
GenServer.cast(pid, {:add_user_message, message})
end
end
13 changes: 13 additions & 0 deletions lib/gpt_agent/events/user_message_added.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
defmodule GptAgent.Events.UserMessageAdded do
@moduledoc """
An OpenAI Assistants user message was added to a thread
"""

use TypedStruct

typedstruct do
field :id, binary(), enforce: true
field :thread_id, binary(), enforce: true
field :content, String.t(), enforce: true
end
end
98 changes: 98 additions & 0 deletions lib/gpt_agent/value.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
defmodule GptAgent.Value do
@moduledoc """
A module that provides a macro to define a new type and its validation.
"""

@doc """
Macro to use GptAgent.Value in the current module.
## Examples
defmodule MyModule do
use GptAgent.Value
end
"""
defmacro __using__(_opts) do
quote do
require GptAgent.Value
import GptAgent.Value
end
end

@doc """
Macro to define a new type.
## Examples
defmodule MyType do
use GptAgent.Value
type String.t()
end
"""
@spec type(atom()) :: Macro.t()
defmacro type(type) do
quote do
use TypedStruct

typedstruct do
field :value, unquote(type), enforce: true
end

@doc """
Creates a new instance of the type.
The function validates the value and returns a new struct if the value is valid.
If the value is invalid, it returns an error tuple with the error message.
## Params
- `value`: The value to be validated and set in the struct.
## Examples
iex> MyType.new("valid value")
%MyType{value: "valid value"}
iex> MyType.new("invalid value")
{:error, "error message"}
"""
@spec new(any()) :: {:ok, t()} | {:error, String.t()}
def new(value) do
case validate(value) do
:ok ->
{:ok, %__MODULE__{value: value}}

{:error, error} ->
{:error, error}
end
end
end
end

@doc """
Macro to define a validation function for the type.
## Examples
defmodule MyType do
use GptAgent.Value
type String.t()
validate_with fn
value when is_binary(value) ->
case String.trim(value) do
"" -> {:error, "must be a nonblank string"}
_ -> :ok
end
end
end
"""
@spec validate_with(Macro.t()) :: Macro.t()
defmacro validate_with(validate_with) do
quote do
defp validate(value) do
unquote(validate_with).(value)
end
end
end
end
27 changes: 27 additions & 0 deletions lib/gpt_agent/values/nonblank_string.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
defmodule GptAgent.Values.NonblankString do
@moduledoc """
A message that a user has sent to the GPT agent
"""

use GptAgent.Value

type String.t()

validate_with fn
value when is_binary(value) ->
case String.trim(value) do
"" -> {:error, "must be a nonblank string"}
_ -> :ok
end

_ ->
{:error, "must be a nonblank string"}
end

defimpl Jason.Encoder do
def encode(%{value: value}, opts) do
%{role: :user, content: value}
|> Jason.Encoder.encode(opts)
end
end
end
1 change: 1 addition & 0 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ defmodule GptAgent.MixProject do
{:mix_test_interactive, "~> 1.2", only: :dev, runtime: false},
{:mox, "~> 1.1", only: :test},
{:open_ai_client, "~> 1.2"},
{:stream_data, "~> 0.6"},
{:sobelow, "~> 0.12", only: [:dev, :test], runtime: false},
{:typed_struct, "~> 0.3"},
{:uuid, "~> 1.1"}
Expand Down
25 changes: 25 additions & 0 deletions test/gpt_agent/values/nonblank_string_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
defmodule GptAgent.Values.NonblankStringTest do
@moduledoc false

use GptAgent.TestCase, async: true

alias GptAgent.Values.NonblankString

test "a non-blank string is valid" do
check all(value <- nonblank_string()) do
assert {:ok, %NonblankString{value: ^value}} = NonblankString.new(value)
end
end

test "blank strings are invalid" do
assert {:error, "must be a nonblank string"} = NonblankString.new("")
assert {:error, "must be a nonblank string"} = NonblankString.new(" ")
assert {:error, "must be a nonblank string"} = NonblankString.new(" \n\t ")
end

test "all other values are invalid" do
check all(value <- term_other_than_nonblank_string()) do
assert {:error, "must be a nonblank string"} = NonblankString.new(value)
end
end
end
52 changes: 51 additions & 1 deletion test/gpt_agent_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defmodule GptAgentTest do
use ExUnit.Case
doctest GptAgent

alias GptAgent.Events.ThreadCreated
alias GptAgent.Events.{ThreadCreated, UserMessageAdded}

setup _context do
bypass = Bypass.open()
Expand Down Expand Up @@ -89,4 +89,54 @@ defmodule GptAgentTest do
refute_receive {GptAgent, ^pid, %ThreadCreated{}}, 100
end
end

describe "add_user_message/2" do
test "adds the user message to the agent's thread via the OpenAI API", %{
bypass: bypass,
thread_id: thread_id
} do
{:ok, pid} = GptAgent.start_link(self(), thread_id)

user_message_id = Faker.Lorem.word()
message_content = Faker.Lorem.paragraph()

Bypass.expect_once(bypass, "POST", "/v1/threads/#{thread_id}/messages", fn conn ->
conn
|> Plug.Conn.put_resp_content_type("application/json")
|> Plug.Conn.resp(
201,
Jason.encode!(%{
"id" => user_message_id,
"object" => "thread.message",
"created_at" => "1699012949",
"thread_id" => thread_id,
"role" => "user",
"content" => [
%{
"type" => "text",
"text" => %{
"value" => message_content,
"annotations" => []
}
}
],
"file_ids" => [],
"assistant_id" => nil,
"run_id" => nil,
"metadata" => %{}
})
)
end)

:ok = GptAgent.add_user_message(pid, message_content)

assert_receive {GptAgent, ^pid,
%UserMessageAdded{
id: ^user_message_id,
thread_id: ^thread_id,
content: ^message_content
}},
5_000
end
end
end
Loading

0 comments on commit 3d30cdd

Please sign in to comment.