diff --git a/lib/gpt_agent.ex b/lib/gpt_agent.ex index d15c1e4..521f40b 100644 --- a/lib/gpt_agent.ex +++ b/lib/gpt_agent.ex @@ -57,7 +57,6 @@ defmodule GptAgent do defp noreply(%__MODULE__{} = state), do: {:noreply, state, state.timeout_ms} defp noreply(%__MODULE__{} = state, next), do: {:noreply, state, next} defp reply(%__MODULE__{} = state, reply), do: {:reply, reply, state, state.timeout_ms} - defp reply(%__MODULE__{} = state, reply, next), do: {:reply, reply, state, next} defp stop(%__MODULE__{} = state), do: {:stop, :normal, state} defp log(message, level \\ :debug) when is_binary(message), @@ -215,46 +214,21 @@ defmodule GptAgent do {:noreply, %{state | assistant_id: assistant_id}} end - @impl true def handle_cast({:set_last_message_id, last_message_id}, %__MODULE__{} = state) do log("Setting last message ID to #{last_message_id}") {:noreply, %{state | last_message_id: last_message_id}} end - @impl true - def handle_call(:run_in_progress?, _caller, %__MODULE__{} = state) do - reply(state, state.running?) - end - - def handle_call(:shutdown, _caller, %__MODULE__{} = state) do - log("Shutting down") - Registry.unregister(GptAgent.Registry, state.thread_id) - stop(state) - end - - @impl true - def handle_call(:thread_id, _caller, %__MODULE__{} = state) do - log("Returning thread ID #{inspect(state.thread_id)}") - reply(state, {:ok, state.thread_id}) - end - - @impl true - def handle_call(:assistant_id, _caller, %__MODULE__{} = state) do - log("Returning default assistant ID #{inspect(state.assistant_id)}") - reply(state, {:ok, state.assistant_id}) - end - - @impl true - def handle_call({:add_user_message, message}, _caller, %__MODULE__{running?: true} = state) do + def handle_cast({:add_user_message, message}, %__MODULE__{running?: true} = state) do log( "Attempting to add user message, but run in progress, cannot add user message: #{inspect(message)}" ) - reply(state, {:error, :run_in_progress}) + GenServer.cast(self(), {:add_user_message, message}) + noreply(state) end - @impl true - def handle_call({:add_user_message, %UserMessage{} = message}, _caller, %__MODULE__{} = state) do + def handle_cast({:add_user_message, %UserMessage{} = message}, %__MODULE__{} = state) do log("Adding user message #{inspect(message)}") {:ok, %{body: %{"id" => id}}} = @@ -271,10 +245,30 @@ defmodule GptAgent do content: message ) ) - |> reply(:ok, {:continue, :run}) + |> noreply({:continue, :run}) end @impl true + def handle_call(:run_in_progress?, _caller, %__MODULE__{} = state) do + reply(state, state.running?) + end + + def handle_call(:shutdown, _caller, %__MODULE__{} = state) do + log("Shutting down") + Registry.unregister(GptAgent.Registry, state.thread_id) + stop(state) + end + + def handle_call(:thread_id, _caller, %__MODULE__{} = state) do + log("Returning thread ID #{inspect(state.thread_id)}") + reply(state, {:ok, state.thread_id}) + end + + def handle_call(:assistant_id, _caller, %__MODULE__{} = state) do + log("Returning default assistant ID #{inspect(state.assistant_id)}") + reply(state, {:ok, state.assistant_id}) + end + def handle_call( {:submit_tool_output, tool_call_id, tool_output}, _caller, @@ -287,7 +281,6 @@ defmodule GptAgent do reply(state, {:error, :run_not_in_progress}) end - @impl true def handle_call( {:submit_tool_output, tool_call_id, tool_output}, _caller, @@ -593,7 +586,7 @@ defmodule GptAgent do @impl true def add_user_message(pid, message) do if Process.alive?(pid) do - GenServer.call(pid, {:add_user_message, %UserMessage{content: message}}) + GenServer.cast(pid, {:add_user_message, %UserMessage{content: message}}) else handle_dead_process(pid) end diff --git a/test/gpt_agent_test.exs b/test/gpt_agent_test.exs index 6e0e837..2d83c91 100644 --- a/test/gpt_agent_test.exs +++ b/test/gpt_agent_test.exs @@ -376,7 +376,7 @@ defmodule GptAgentTest do {:ok, pid} = GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) - assert {:error, :run_in_progress} = GptAgent.add_user_message(pid, Faker.Lorem.sentence()) + assert GptAgent.run_in_progress?(pid) GptAgent.shutdown(pid) @@ -424,7 +424,7 @@ defmodule GptAgentTest do {:ok, pid} = GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) - assert :ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence()) + refute GptAgent.run_in_progress?(pid) end test "sets the last_message_id based on the passed value", %{ @@ -991,70 +991,7 @@ defmodule GptAgentTest do 5_000 end - test "returns {:error, :pending_tool_calls} if the agent is waiting on tool calls to be submitted", - %{ - bypass: bypass, - assistant_id: assistant_id, - thread_id: thread_id, - run_id: run_id - } do - {:ok, pid} = - GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) - - Bypass.stub(bypass, "GET", "/v1/threads/#{thread_id}/runs/#{run_id}", fn conn -> - conn - |> Plug.Conn.put_resp_content_type("application/json") - |> Plug.Conn.resp( - 200, - Jason.encode!(%{ - "id" => run_id, - "object" => "thread.run", - "created_at" => 1_699_075_072, - "assistant_id" => assistant_id, - "thread_id" => thread_id, - "status" => "requires_action", - "required_action" => %{ - "type" => "submit_tool_outputs", - "submit_tool_outputs" => %{ - "tool_calls" => [ - %{ - "id" => UUID.uuid4(), - "type" => "function", - "function" => %{"name" => "tool_1", "arguments" => ~s({"foo":"bar","baz":1})} - }, - %{ - "id" => UUID.uuid4(), - "type" => "function", - "function" => %{ - "name" => "tool_2", - "arguments" => ~s({"ham":"spam","wham":2}) - } - } - ] - } - }, - "started_at" => 1_699_075_072, - "expires_at" => nil, - "cancelled_at" => nil, - "failed_at" => nil, - "completed_at" => 1_699_075_073, - "last_error" => nil, - "model" => "gpt-4-1106-preview", - "instructions" => nil, - "tools" => [], - "file_ids" => [], - "metadata" => %{} - }) - ) - end) - - :ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence()) - - assert {:error, :run_in_progress} = - GptAgent.add_user_message(pid, Faker.Lorem.sentence()) - end - - test "allow adding additional messages if the run is complete", %{ + test "allow adding additional messages if the run is not complete", %{ assistant_id: assistant_id, thread_id: thread_id, run_id: run_id @@ -1063,6 +1000,10 @@ defmodule GptAgentTest do GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) :ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence()) + assert_receive {^pid, %UserMessageAdded{}}, 5_000 + + :ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence()) + assert_receive {^pid, %UserMessageAdded{}}, 5_000 assert_receive {^pid, %RunCompleted{