diff --git a/config/config.exs b/config/config.exs index f850e6f..98bf6f4 100644 --- a/config/config.exs +++ b/config/config.exs @@ -4,6 +4,9 @@ case config_env() do :dev -> config :mix_test_interactive, clear: true + :test -> + config :bypass, enable_debug_log: true + _ -> nil end diff --git a/config/runtime.exs b/config/runtime.exs index e8b2372..4d7f399 100644 --- a/config/runtime.exs +++ b/config/runtime.exs @@ -4,3 +4,5 @@ config :open_ai_client, OpenAiClient, base_url: System.get_env("OPENAI_BASE_URL") || "https://api.openai.com", openai_api_key: System.get_env("OPENAI_API_KEY") || raise("OPENAI_API_KEY is not set"), openai_organization_id: System.get_env("OPENAI_ORGANIZATION_ID") + +config :gpt_agent, :heartbeat_interval_ms, if(config_env() == :test, do: 1, else: 1000) diff --git a/lib/gpt_agent.ex b/lib/gpt_agent.ex index 982eade..f0b0a13 100644 --- a/lib/gpt_agent.ex +++ b/lib/gpt_agent.ex @@ -6,7 +6,7 @@ defmodule GptAgent do use GenServer use TypedStruct - alias GptAgent.Events.{RunStarted, ThreadCreated, UserMessageAdded} + alias GptAgent.Events.{RunCompleted, RunStarted, ThreadCreated, UserMessageAdded} alias GptAgent.Values.NonblankString typedstruct do @@ -14,6 +14,7 @@ defmodule GptAgent do field :callback_handler, pid(), enforce: true field :assistant_id, binary(), enforce: true field :thread_id, binary() | nil + field :running?, boolean(), default: false end defp continue(state, continue_arg), do: {:ok, state, {:continue, continue_arg}} @@ -61,7 +62,10 @@ defmodule GptAgent do } ) + Process.send_after(self(), {:check_run_status, id}, heartbeat_interval_ms()) + state + |> Map.put(:running?, true) |> send_callback(%RunStarted{ id: id, thread_id: state.thread_id, @@ -70,6 +74,8 @@ defmodule GptAgent do |> noreply() end + defp heartbeat_interval_ms, do: Application.get_env(:gpt_agent, :heartbeat_interval_ms, 1000) + def handle_cast({:add_user_message, message}, state) do {:ok, message} = NonblankString.new(message) @@ -85,6 +91,24 @@ defmodule GptAgent do |> noreply({:continue, :run}) end + def handle_info({:check_run_status, id}, state) do + {:ok, %{body: %{"status" => status}}} = + OpenAiClient.get("/v1/threads/#{state.thread_id}/runs/#{id}", []) + + if status == "completed" do + state + |> send_callback(%RunCompleted{ + id: id, + thread_id: state.thread_id, + assistant_id: state.assistant_id + }) + |> noreply() + else + Process.send_after(self(), {:check_run_status, id}, heartbeat_interval_ms()) + noreply(state) + end + end + @doc """ Starts the GPT Agent """ diff --git a/lib/gpt_agent/events/run_completed.ex b/lib/gpt_agent/events/run_completed.ex new file mode 100644 index 0000000..d028b69 --- /dev/null +++ b/lib/gpt_agent/events/run_completed.ex @@ -0,0 +1,12 @@ +defmodule GptAgent.Events.RunCompleted do + @moduledoc """ + An OpenAI Assistants run was completed + """ + use TypedStruct + + typedstruct do + field :id, binary(), enforce: true + field :thread_id, binary(), enforce: true + field :assistant_id, binary(), enforce: true + end +end diff --git a/test/gpt_agent_test.exs b/test/gpt_agent_test.exs index 81145c6..f1b7685 100644 --- a/test/gpt_agent_test.exs +++ b/test/gpt_agent_test.exs @@ -4,7 +4,7 @@ defmodule GptAgentTest do use ExUnit.Case doctest GptAgent - alias GptAgent.Events.{RunStarted, ThreadCreated, UserMessageAdded} + alias GptAgent.Events.{RunCompleted, RunStarted, ThreadCreated, UserMessageAdded} alias GptAgent.Values.NonblankString setup _context do @@ -16,7 +16,9 @@ defmodule GptAgentTest do openai_organization_id: "test" ) + assistant_id = UUID.uuid4() thread_id = UUID.uuid4() + run_id = UUID.uuid4() Bypass.stub(bypass, "POST", "/v1/threads", fn conn -> conn @@ -60,7 +62,50 @@ defmodule GptAgentTest do ) end) - {:ok, bypass: bypass, thread_id: thread_id} + Bypass.stub(bypass, "POST", "/v1/threads/#{thread_id}/runs", fn conn -> + conn + |> Plug.Conn.put_resp_content_type("application/json") + |> Plug.Conn.resp( + 201, + Jason.encode!(%{ + "id" => run_id, + "object" => "thread.run", + "created_at" => "1699012949", + "thread_id" => thread_id, + "assistant_id" => assistant_id, + "metadata" => %{} + }) + ) + end) + + Bypass.stub(bypass, "GET", "/v1/threads/#{thread_id}/runs/#{run_id}", fn conn -> + conn + |> Plug.Conn.put_resp_content_type("application/json") + |> Plug.Conn.resp( + 200, + Jason.encode!(%{ + "id" => run_id, + "object" => "thread.run", + "created_at" => 1_699_075_072, + "assistant_id" => assistant_id, + "thread_id" => thread_id, + "status" => "completed", + "started_at" => 1_699_075_072, + "expires_at" => nil, + "cancelled_at" => nil, + "failed_at" => nil, + "completed_at" => 1_699_075_073, + "last_error" => nil, + "model" => "gpt-4-1106-preview", + "instructions" => nil, + "tools" => [], + "file_ids" => [], + "metadata" => %{} + }) + ) + end) + + {:ok, bypass: bypass, assistant_id: assistant_id, thread_id: thread_id, run_id: run_id} end describe "start_link/2" do @@ -190,7 +235,7 @@ defmodule GptAgentTest do "object" => "thread.run", "created_at" => "1699012949", "thread_id" => thread_id, - "assistant_id" => UUID.uuid4(), + "assistant_id" => assistant_id, "metadata" => %{} }) ) @@ -208,5 +253,51 @@ defmodule GptAgentTest do assert is_binary(run_id) end + + test "when the run is finished, sends the RunFinished event to the callback handler", %{ + bypass: bypass, + assistant_id: assistant_id, + thread_id: thread_id, + run_id: run_id + } do + {:ok, pid} = GptAgent.start_link(self(), assistant_id, thread_id) + + Bypass.expect_once(bypass, "GET", "/v1/threads/#{thread_id}/runs/#{run_id}", fn conn -> + conn + |> Plug.Conn.put_resp_content_type("application/json") + |> Plug.Conn.resp( + 200, + Jason.encode!(%{ + "id" => run_id, + "object" => "thread.run", + "created_at" => 1_699_075_072, + "assistant_id" => assistant_id, + "thread_id" => thread_id, + "status" => "completed", + "started_at" => 1_699_075_072, + "expires_at" => nil, + "cancelled_at" => nil, + "failed_at" => nil, + "completed_at" => 1_699_075_073, + "last_error" => nil, + "model" => "gpt-4-1106-preview", + "instructions" => nil, + "tools" => [], + "file_ids" => [], + "metadata" => %{} + }) + ) + end) + + :ok = GptAgent.add_user_message(pid, "Hello") + + assert_receive {GptAgent, ^pid, + %RunCompleted{ + id: ^run_id, + thread_id: ^thread_id, + assistant_id: ^assistant_id + }}, + 5_000 + end end end