diff --git a/lib/gpt_agent.ex b/lib/gpt_agent.ex index 8bc3a8c..f899cc5 100644 --- a/lib/gpt_agent.ex +++ b/lib/gpt_agent.ex @@ -28,11 +28,11 @@ defmodule GptAgent do typedstruct do field :assistant_id, Types.assistant_id(), enforce: true field :thread_id, Types.thread_id(), enforce: true + field :last_message_id, Types.message_id() | nil, enforce: true field :running?, boolean(), default: false field :run_id, Types.run_id() | nil field :tool_calls, [ToolCallRequested.t()], default: [] field :tool_outputs, [ToolCallOutputRecorded.t()], default: [] - field :last_message_id, Types.message_id() | nil field :timeout_ms, non_neg_integer(), default: @timeout_ms end @@ -201,6 +201,12 @@ defmodule GptAgent do {:noreply, %{state | assistant_id: assistant_id}} end + @impl true + def handle_cast({:set_last_message_id, last_message_id}, %__MODULE__{} = state) do + log("Setting last message ID to #{last_message_id}") + {:noreply, %{state | last_message_id: last_message_id}} + end + @impl true def handle_call(:shutdown, _caller, %__MODULE__{} = state) do log("Shutting down") @@ -400,6 +406,8 @@ defmodule GptAgent do defp log(message, level \\ :debug) when is_binary(message), do: Logger.log(level, "[GptAgent (#{inspect(self())})] " <> message) + defp ok(data), do: {:ok, data} + @impl true def create_thread do log("Creating thread") @@ -419,7 +427,7 @@ defmodule GptAgent do @impl true def connect(opts) when is_list(opts) do - opts = validate_and_convert_opts(opts) + {:ok, opts} = validate_and_convert_opts(opts) opts |> connect_to_new_or_existing_agent() @@ -431,24 +439,61 @@ defmodule GptAgent do case Registry.lookup(GptAgent.Registry, opts.thread_id) do [{pid, :gpt_agent}] -> - handle_existing_agent(pid) + handle_existing_agent(pid, opts.last_message_id) [] -> - handle_no_existing_agent(opts.thread_id, opts.assistant_id, opts.timeout_ms) + handle_no_existing_agent( + opts.thread_id, + opts.last_message_id, + opts.assistant_id, + opts.timeout_ms + ) end end defp validate_and_convert_opts(opts) do Keyword.validate!(opts, [ :thread_id, + :last_message_id, + :assistant_id, subscribe: true, - assistant_id: nil, - last_message_id: nil, timeout_ms: nil ]) |> Enum.into(%{}) + |> ok() + |> validate_thread_id() + |> validate_last_message_id() + |> validate_assistant_id() end + defp validate_thread_id({:ok, %{thread_id: _thread_id} = opts}) do + ok(opts) + end + + defp validate_thread_id({:ok, _opts}) do + {:error, :missing_thread_id} + end + + defp validate_last_message_id({:ok, %{last_message_id: _last_message_id} = opts}) do + ok(opts) + end + + defp validate_last_message_id({:ok, _opts}) do + {:error, :missing_last_message_id} + end + + defp validate_last_message_id({:error, _} = error), do: error + + defp validate_assistant_id({:ok, %{assistant_id: _assistant_id} = opts}) do + ok(opts) + end + + defp validate_assistant_id({:ok, _opts}) do + {:error, :missing_assistant_id} + end + + defp validate_assistant_id({:error, _} = error), do: error + defp maybe_subscribe({:ok, _pid} = result, opts) do if opts.subscribe do Phoenix.PubSub.subscribe(GptAgent.PubSub, "gpt_agent:#{opts.thread_id}") @@ -459,17 +504,20 @@ defmodule GptAgent do defp maybe_subscribe(result, _opts), do: result - defp handle_existing_agent(pid) do + defp handle_existing_agent(pid, last_message_id) do log("Found existing GPT Agent with PID #{inspect(pid)}") + log("Updating last message ID to #{inspect(last_message_id)}") + GenServer.cast(pid, {:set_last_message_id, last_message_id}) {:ok, pid} end - defp handle_no_existing_agent(thread_id, assistant_id, timeout_ms) do + defp handle_no_existing_agent(thread_id, last_message_id, assistant_id, timeout_ms) do log("No existing GPT Agent found, starting new one") state = GptAgent.new!( thread_id: thread_id, + last_message_id: last_message_id, assistant_id: assistant_id, timeout_ms: timeout_ms || default_timeout_ms() ) diff --git a/mix.exs b/mix.exs index 140925b..33bcf1c 100644 --- a/mix.exs +++ b/mix.exs @@ -4,7 +4,7 @@ defmodule GptAgent.MixProject do def project do [ app: :gpt_agent, - version: "5.1.1", + version: "6.0.0", elixir: "~> 1.16", start_permanent: Mix.env() == :prod, aliases: aliases(), diff --git a/test/gpt_agent_test.exs b/test/gpt_agent_test.exs index d69a96f..77bdf22 100644 --- a/test/gpt_agent_test.exs +++ b/test/gpt_agent_test.exs @@ -198,14 +198,24 @@ defmodule GptAgentTest do thread_id: thread_id, assistant_id: assistant_id } do - assert {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + assert {:ok, pid} = + GptAgent.connect( + thread_id: thread_id, + last_message_id: nil, + assistant_id: assistant_id + ) + assert Process.alive?(pid) end test "does not start a new GptAgent process for the given thread ID if one is already running", %{thread_id: thread_id, assistant_id: assistant_id} do - {:ok, pid1} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) - {:ok, pid2} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid1} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) + + {:ok, pid2} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) + assert pid1 == pid2 GptAgent.shutdown(pid1) end @@ -215,7 +225,12 @@ defmodule GptAgentTest do assistant_id: assistant_id } do {:ok, pid} = - GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id, timeout_ms: 10) + GptAgent.connect( + thread_id: thread_id, + last_message_id: nil, + assistant_id: assistant_id, + timeout_ms: 10 + ) assert Process.alive?(pid) refute_eventually(Process.alive?(pid), 20) @@ -232,14 +247,20 @@ defmodule GptAgentTest do end) assert {:error, :invalid_thread_id} = - GptAgent.connect(thread_id: thread_id, assistant_id: Faker.Lorem.word()) + GptAgent.connect( + thread_id: thread_id, + last_message_id: nil, + assistant_id: Faker.Lorem.word() + ) end test "subscribes to updates for the thread", %{ thread_id: thread_id, assistant_id: assistant_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) + :ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence()) assert_receive {^pid, %UserMessageAdded{}} end @@ -249,26 +270,25 @@ defmodule GptAgentTest do assistant_id: assistant_id } do {:ok, pid} = - GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id, subscribe: false) + GptAgent.connect( + thread_id: thread_id, + last_message_id: nil, + assistant_id: assistant_id, + subscribe: false + ) :ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence()) refute_receive {^pid, %UserMessageAdded{}} end - test "starts a GptAgent process for the given thread ID if no such process is running and sets the default assistant ID", - %{ - thread_id: thread_id, - assistant_id: assistant_id - } do - assert {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) - assert Process.alive?(pid) - assert %GptAgent{thread_id: ^thread_id, assistant_id: ^assistant_id} = :sys.get_state(pid) - end - test "does not update the assistant id on an agent if the agent is already running", %{thread_id: thread_id, assistant_id: assistant_id} do - {:ok, pid1} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) - {:ok, pid2} = GptAgent.connect(thread_id: thread_id, assistant_id: UUID.uuid4()) + {:ok, pid1} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) + + {:ok, pid2} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: UUID.uuid4()) + assert pid1 == pid2 assert %GptAgent{assistant_id: ^assistant_id} = :sys.get_state(pid1) end @@ -319,9 +339,46 @@ defmodule GptAgentTest do end ) - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) + assert {:error, :run_in_progress} = GptAgent.add_user_message(pid, Faker.Lorem.sentence()) end + + test "sets the last_message_id based on the passed value", %{ + thread_id: thread_id, + assistant_id: assistant_id + } do + {:ok, pid} = + GptAgent.connect( + thread_id: thread_id, + last_message_id: "msg_abc123", + assistant_id: assistant_id + ) + + assert %GptAgent{last_message_id: "msg_abc123"} = :sys.get_state(pid) + end + + test "updates the last_message_id based on the passed value", %{ + thread_id: thread_id, + assistant_id: assistant_id + } do + {:ok, _pid} = + GptAgent.connect( + thread_id: thread_id, + last_message_id: "msg_abc123", + assistant_id: assistant_id + ) + + {:ok, pid} = + GptAgent.connect( + thread_id: thread_id, + last_message_id: "msg_abc456", + assistant_id: assistant_id + ) + + assert %GptAgent{last_message_id: "msg_abc456"} = :sys.get_state(pid) + end end describe "shutdown/1" do @@ -329,7 +386,9 @@ defmodule GptAgentTest do thread_id: thread_id, assistant_id: assistant_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) + assert Process.alive?(pid) assert :ok = GptAgent.shutdown(pid) @@ -344,7 +403,8 @@ defmodule GptAgentTest do thread_id: thread_id, assistant_id: assistant_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) user_message_id = UUID.uuid4() message_content = Faker.Lorem.paragraph() @@ -395,7 +455,8 @@ defmodule GptAgentTest do thread_id: thread_id, assistant_id: assistant_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) Bypass.expect_once(bypass, "POST", "/v1/threads/#{thread_id}/runs", fn conn -> {:ok, body, conn} = Plug.Conn.read_body(conn) @@ -435,7 +496,8 @@ defmodule GptAgentTest do thread_id: thread_id, run_id: run_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) Bypass.expect_once(bypass, "GET", "/v1/threads/#{thread_id}/runs/#{run_id}", fn conn -> conn @@ -482,7 +544,8 @@ defmodule GptAgentTest do thread_id: thread_id, run_id: run_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) message_id = UUID.uuid4() @@ -569,7 +632,8 @@ defmodule GptAgentTest do thread_id: thread_id, run_id: run_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) message_id_1 = UUID.uuid4() message_content_1 = Faker.Lorem.paragraph() @@ -728,7 +792,8 @@ defmodule GptAgentTest do thread_id: thread_id, run_id: run_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) tool_1_id = UUID.uuid4() tool_2_id = UUID.uuid4() @@ -810,7 +875,8 @@ defmodule GptAgentTest do thread_id: thread_id, run_id: run_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) Bypass.stub(bypass, "GET", "/v1/threads/#{thread_id}/runs/#{run_id}", fn conn -> conn @@ -870,7 +936,8 @@ defmodule GptAgentTest do thread_id: thread_id, run_id: run_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) :ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence()) @@ -892,7 +959,8 @@ defmodule GptAgentTest do assistant_id: assistant_id, thread_id: thread_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) assert {:error, :run_not_in_progress} = GptAgent.submit_tool_output(pid, UUID.uuid4(), %{}) @@ -905,7 +973,8 @@ defmodule GptAgentTest do thread_id: thread_id, run_id: run_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) Bypass.stub(bypass, "GET", "/v1/threads/#{thread_id}/runs/#{run_id}", fn conn -> conn @@ -969,7 +1038,8 @@ defmodule GptAgentTest do thread_id: thread_id, run_id: run_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) tool_1_id = UUID.uuid4() tool_2_id = UUID.uuid4() @@ -1062,7 +1132,8 @@ defmodule GptAgentTest do thread_id: thread_id, run_id: run_id } do - {:ok, pid} = GptAgent.connect(thread_id: thread_id, assistant_id: assistant_id) + {:ok, pid} = + GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id) tool_1_id = UUID.uuid4() tool_2_id = UUID.uuid4()