Skip to content

Commit

Permalink
Submit tool output as a cast vs. call
Browse files Browse the repository at this point in the history
Submitting the tool output as a cast rather than a call provides better
performance and prevents the caller from potentially timing out when it
tries to submit multiple tool outputs at once.
  • Loading branch information
jwilger committed Feb 18, 2024
1 parent ccab0a3 commit 27f1728
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 106 deletions.
56 changes: 27 additions & 29 deletions lib/gpt_agent.ex
Original file line number Diff line number Diff line change
Expand Up @@ -248,50 +248,27 @@ defmodule GptAgent do
|> noreply({:continue, :run})
end

@impl true
def handle_call(:run_in_progress?, _caller, %__MODULE__{} = state) do
reply(state, state.running?)
end

def handle_call(:shutdown, _caller, %__MODULE__{} = state) do
log("Shutting down")
Registry.unregister(GptAgent.Registry, state.thread_id)
stop(state)
end

def handle_call(:thread_id, _caller, %__MODULE__{} = state) do
log("Returning thread ID #{inspect(state.thread_id)}")
reply(state, {:ok, state.thread_id})
end

def handle_call(:assistant_id, _caller, %__MODULE__{} = state) do
log("Returning default assistant ID #{inspect(state.assistant_id)}")
reply(state, {:ok, state.assistant_id})
end

def handle_call(
def handle_cast(
{:submit_tool_output, tool_call_id, tool_output},
_caller,
%__MODULE__{running?: false} = state
) do
log(
"Attempting to submit tool output, but no run in progress, cannot submit tool output for call #{inspect(tool_call_id)}: #{inspect(tool_output)}"
)

reply(state, {:error, :run_not_in_progress})
noreply(state)
end

def handle_call(
def handle_cast(
{:submit_tool_output, tool_call_id, tool_output},
_caller,
%__MODULE__{} = state
) do
log("Submitting tool output #{inspect(tool_output)}")

case Enum.find_index(state.tool_calls, fn %ToolCallRequested{id: id} -> id == tool_call_id end) do
nil ->
log("Tool call ID #{inspect(tool_call_id)} not found")
reply(state, {:error, :invalid_tool_call_id})
noreply(state)

index ->
log("Tool call ID #{inspect(tool_call_id)} found at index #{inspect(index)}")
Expand All @@ -313,7 +290,7 @@ defmodule GptAgent do
|> Map.put(:tool_calls, tool_calls)
|> Map.put(:tool_outputs, tool_outputs)
|> possibly_send_outputs_to_openai()
|> reply(:ok)
|> noreply()
end
end

Expand All @@ -335,6 +312,27 @@ defmodule GptAgent do

defp possibly_send_outputs_to_openai(%__MODULE__{} = state), do: state

# Synchronous GenServer request handlers. Each clause matches one request atom
# and requires the server state to be a %__MODULE__{} struct; the caller
# reference is unused in every clause.
# NOTE(review): reply/2, stop/1 and log/1 are helpers defined outside this
# view — presumably they build the {:reply, ...} / {:stop, ...} callback
# tuples and write a log line; confirm against their definitions.

# :run_in_progress? — answers with the current value of state.running?.
@impl true
def handle_call(:run_in_progress?, _caller, %__MODULE__{} = state) do
reply(state, state.running?)
end

# :shutdown — logs, removes this process's thread_id entry from
# GptAgent.Registry, then hands the state to stop/1.
def handle_call(:shutdown, _caller, %__MODULE__{} = state) do
log("Shutting down")
Registry.unregister(GptAgent.Registry, state.thread_id)
stop(state)
end

# :thread_id — answers with {:ok, thread_id} taken from the state.
def handle_call(:thread_id, _caller, %__MODULE__{} = state) do
log("Returning thread ID #{inspect(state.thread_id)}")
reply(state, {:ok, state.thread_id})
end

# :assistant_id — answers with {:ok, assistant_id} taken from the state
# (the log message refers to it as the "default" assistant ID).
def handle_call(:assistant_id, _caller, %__MODULE__{} = state) do
log("Returning default assistant ID #{inspect(state.assistant_id)}")
reply(state, {:ok, state.assistant_id})
end

@impl true
def handle_info(:timeout, %__MODULE__{} = state) do
log("Timeout Received")
Expand Down Expand Up @@ -595,7 +593,7 @@ defmodule GptAgent do
@impl true
def submit_tool_output(pid, tool_call_id, tool_output) do
if Process.alive?(pid) do
GenServer.call(pid, {:submit_tool_output, tool_call_id, tool_output})
GenServer.cast(pid, {:submit_tool_output, tool_call_id, tool_output})
else
handle_dead_process(pid)
end
Expand Down
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defmodule GptAgent.MixProject do
def project do
[
app: :gpt_agent,
version: "8.1.1",
version: "9.0.0",
elixir: "~> 1.16",
start_permanent: Mix.env() == :prod,
aliases: aliases(),
Expand Down
111 changes: 35 additions & 76 deletions test/gpt_agent_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ defmodule GptAgentTest do
AssistantMessageAdded,
RunCompleted,
RunStarted,
ToolCallOutputRecorded,
ToolCallRequested,
UserMessageAdded
}
Expand Down Expand Up @@ -1033,82 +1034,6 @@ defmodule GptAgentTest do
GptAgent.submit_tool_output(pid, UUID.uuid4(), %{})
end

test "returns {:error, :not_running} if there is no run in progress", %{
assistant_id: assistant_id,
thread_id: thread_id
} do
{:ok, pid} =
GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id)

assert {:error, :run_not_in_progress} =
GptAgent.submit_tool_output(pid, UUID.uuid4(), %{})
end

test "returns {:error, :invalid_tool_call_id} if the tool call ID is not one of the outstanding tool calls",
%{
bypass: bypass,
assistant_id: assistant_id,
thread_id: thread_id,
run_id: run_id
} do
{:ok, pid} =
GptAgent.connect(thread_id: thread_id, last_message_id: nil, assistant_id: assistant_id)

Bypass.stub(bypass, "GET", "/v1/threads/#{thread_id}/runs/#{run_id}", fn conn ->
conn
|> Plug.Conn.put_resp_content_type("application/json")
|> Plug.Conn.resp(
200,
Jason.encode!(%{
"id" => run_id,
"object" => "thread.run",
"created_at" => 1_699_075_072,
"assistant_id" => assistant_id,
"thread_id" => thread_id,
"status" => "requires_action",
"required_action" => %{
"type" => "submit_tool_outputs",
"submit_tool_outputs" => %{
"tool_calls" => [
%{
"id" => UUID.uuid4(),
"type" => "function",
"function" => %{"name" => "tool_1", "arguments" => ~s({"foo":"bar","baz":1})}
},
%{
"id" => UUID.uuid4(),
"type" => "function",
"function" => %{
"name" => "tool_2",
"arguments" => ~s({"ham":"spam","wham":2})
}
}
]
}
},
"started_at" => 1_699_075_072,
"expires_at" => nil,
"cancelled_at" => nil,
"failed_at" => nil,
"completed_at" => 1_699_075_073,
"last_error" => nil,
"model" => "gpt-4-1106-preview",
"instructions" => nil,
"tools" => [],
"file_ids" => [],
"metadata" => %{}
})
)
end)

:ok = GptAgent.add_user_message(pid, Faker.Lorem.sentence())

assert_receive {^pid, %ToolCallRequested{}}, 5_000

assert {:error, :invalid_tool_call_id} =
GptAgent.submit_tool_output(pid, UUID.uuid4(), %{})
end

test "if there are other tool calls still outstanding, do not submit the tool calls to openai yet",
%{
bypass: bypass,
Expand Down Expand Up @@ -1352,8 +1277,42 @@ defmodule GptAgentTest do
end
)

Bypass.stub(bypass, "GET", "/v1/threads/#{thread_id}/runs/#{run_id}", fn conn ->
conn
|> Plug.Conn.put_resp_content_type("application/json")
|> Plug.Conn.resp(
200,
Jason.encode!(%{
"id" => run_id,
"object" => "thread.run",
"created_at" => 1_699_075_072,
"assistant_id" => assistant_id,
"thread_id" => thread_id,
"status" => "completed",
"started_at" => 1_699_075_072,
"expires_at" => nil,
"cancelled_at" => nil,
"failed_at" => nil,
"completed_at" => 1_699_075_073,
"last_error" => nil,
"model" => "gpt-4-1106-preview",
"instructions" => nil,
"tools" => [],
"file_ids" => [],
"metadata" => %{}
})
)
end)

:ok = GptAgent.submit_tool_output(pid, tool_2_id, %{some: "result"})

assert_receive {^pid, %ToolCallOutputRecorded{}}, 5_000

:ok = GptAgent.submit_tool_output(pid, tool_1_id, %{another: "answer"})

assert_receive {^pid, %ToolCallOutputRecorded{}}, 5_000

assert_receive {^pid, %RunCompleted{}}, 5_000
end
end

Expand Down

0 comments on commit 27f1728

Please sign in to comment.