Skip to content

Commit

Permalink
Change adding user message to a cast
Browse files Browse the repository at this point in the history
Making this a cast so that the caller doesn't need to synchronously wait
for OpenAI to actually receive the request, which can take some time.
This should make callers more resilient at the expense of needing to
watch for the UserMessageAdded event.
  • Loading branch information
jwilger committed Feb 18, 2024
1 parent 468b65c commit ccab0a3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 99 deletions.
59 changes: 26 additions & 33 deletions lib/gpt_agent.ex
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ defmodule GptAgent do
defp noreply(%__MODULE__{} = state), do: {:noreply, state, state.timeout_ms}
defp noreply(%__MODULE__{} = state, next), do: {:noreply, state, next}
defp reply(%__MODULE__{} = state, reply), do: {:reply, reply, state, state.timeout_ms}
defp reply(%__MODULE__{} = state, reply, next), do: {:reply, reply, state, next}
defp stop(%__MODULE__{} = state), do: {:stop, :normal, state}

defp log(message, level \\ :debug) when is_binary(message),
Expand Down Expand Up @@ -215,46 +214,21 @@ defmodule GptAgent do
{:noreply, %{state | assistant_id: assistant_id}}
end

@impl true
def handle_cast({:set_last_message_id, last_message_id}, %__MODULE__{} = state) do
log("Setting last message ID to #{last_message_id}")
{:noreply, %{state | last_message_id: last_message_id}}
end

@impl true
def handle_call(:run_in_progress?, _caller, %__MODULE__{} = state) do
reply(state, state.running?)
end

def handle_call(:shutdown, _caller, %__MODULE__{} = state) do
log("Shutting down")
Registry.unregister(GptAgent.Registry, state.thread_id)
stop(state)
end

@impl true
def handle_call(:thread_id, _caller, %__MODULE__{} = state) do
log("Returning thread ID #{inspect(state.thread_id)}")
reply(state, {:ok, state.thread_id})
end

@impl true
def handle_call(:assistant_id, _caller, %__MODULE__{} = state) do
log("Returning default assistant ID #{inspect(state.assistant_id)}")
reply(state, {:ok, state.assistant_id})
end

@impl true
def handle_call({:add_user_message, message}, _caller, %__MODULE__{running?: true} = state) do
def handle_cast({:add_user_message, message}, %__MODULE__{running?: true} = state) do
log(
"Attempting to add user message, but run in progress, cannot add user message: #{inspect(message)}"
)

reply(state, {:error, :run_in_progress})
GenServer.cast(self(), {:add_user_message, message})
noreply(state)
end

@impl true
def handle_call({:add_user_message, %UserMessage{} = message}, _caller, %__MODULE__{} = state) do
def handle_cast({:add_user_message, %UserMessage{} = message}, %__MODULE__{} = state) do
log("Adding user message #{inspect(message)}")

{:ok, %{body: %{"id" => id}}} =
Expand All @@ -271,10 +245,30 @@ defmodule GptAgent do
content: message
)
)
|> reply(:ok, {:continue, :run})
|> noreply({:continue, :run})
end

@impl true
def handle_call(:run_in_progress?, _caller, %__MODULE__{} = state) do
reply(state, state.running?)
end

def handle_call(:shutdown, _caller, %__MODULE__{} = state) do
log("Shutting down")
Registry.unregister(GptAgent.Registry, state.thread_id)
stop(state)
end

def handle_call(:thread_id, _caller, %__MODULE__{} = state) do
log("Returning thread ID #{inspect(state.thread_id)}")
reply(state, {:ok, state.thread_id})
end

def handle_call(:assistant_id, _caller, %__MODULE__{} = state) do
log("Returning default assistant ID #{inspect(state.assistant_id)}")
reply(state, {:ok, state.assistant_id})
end

def handle_call(
{:submit_tool_output, tool_call_id, tool_output},
_caller,
Expand All @@ -287,7 +281,6 @@ defmodule GptAgent do
reply(state, {:error, :run_not_in_progress})
end

@impl true
def handle_call(
{:submit_tool_output, tool_call_id, tool_output},
_caller,
Expand Down Expand Up @@ -593,7 +586,7 @@ defmodule GptAgent do
@impl true
def add_user_message(pid, message) do
if Process.alive?(pid) do
GenServer.call(pid, {:add_user_message, %UserMessage{content: message}})
GenServer.cast(pid, {:add_user_message, %UserMessage{content: message}})
else
handle_dead_process(pid)
end
Expand Down
73 changes: 7 additions & 66 deletions test/gpt_agent_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ defmodule GptAgentTest do
{:ok, pid} =
GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id)

assert {:error, :run_in_progress} = GptAgent.add_user_message(pid, Faker.Lorem.sentence())
assert GptAgent.run_in_progress?(pid)

GptAgent.shutdown(pid)

Expand Down Expand Up @@ -424,7 +424,7 @@ defmodule GptAgentTest do
{:ok, pid} =
GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id)

assert :ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence())
refute GptAgent.run_in_progress?(pid)
end

test "sets the last_message_id based on the passed value", %{
Expand Down Expand Up @@ -991,70 +991,7 @@ defmodule GptAgentTest do
5_000
end

test "returns {:error, :pending_tool_calls} if the agent is waiting on tool calls to be submitted",
%{
bypass: bypass,
assistant_id: assistant_id,
thread_id: thread_id,
run_id: run_id
} do
{:ok, pid} =
GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id)

Bypass.stub(bypass, "GET", "/v1/threads/#{thread_id}/runs/#{run_id}", fn conn ->
conn
|> Plug.Conn.put_resp_content_type("application/json")
|> Plug.Conn.resp(
200,
Jason.encode!(%{
"id" => run_id,
"object" => "thread.run",
"created_at" => 1_699_075_072,
"assistant_id" => assistant_id,
"thread_id" => thread_id,
"status" => "requires_action",
"required_action" => %{
"type" => "submit_tool_outputs",
"submit_tool_outputs" => %{
"tool_calls" => [
%{
"id" => UUID.uuid4(),
"type" => "function",
"function" => %{"name" => "tool_1", "arguments" => ~s({"foo":"bar","baz":1})}
},
%{
"id" => UUID.uuid4(),
"type" => "function",
"function" => %{
"name" => "tool_2",
"arguments" => ~s({"ham":"spam","wham":2})
}
}
]
}
},
"started_at" => 1_699_075_072,
"expires_at" => nil,
"cancelled_at" => nil,
"failed_at" => nil,
"completed_at" => 1_699_075_073,
"last_error" => nil,
"model" => "gpt-4-1106-preview",
"instructions" => nil,
"tools" => [],
"file_ids" => [],
"metadata" => %{}
})
)
end)

:ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence())

assert {:error, :run_in_progress} =
GptAgent.add_user_message(pid, Faker.Lorem.sentence())
end

test "allow adding additional messages if the run is complete", %{
test "allow adding additional messages if the run is not complete", %{
assistant_id: assistant_id,
thread_id: thread_id,
run_id: run_id
Expand All @@ -1063,6 +1000,10 @@ defmodule GptAgentTest do
GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id)

:ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence())
assert_receive {^pid, %UserMessageAdded{}}, 5_000

:ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence())
assert_receive {^pid, %UserMessageAdded{}}, 5_000

assert_receive {^pid,
%RunCompleted{
Expand Down

0 comments on commit ccab0a3

Please sign in to comment.