From 19a51890efe613918d573d91ee268797c2c7ad2f Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Fri, 19 Jul 2024 10:46:26 -0700 Subject: [PATCH] chore: return missing tool call to LLM, don't fail --- pkg/engine/engine.go | 24 +++++--- pkg/runner/runner.go | 12 ++++ pkg/tests/runner_test.go | 15 +++++ .../TestMissingTool/call1-resp.golden | 14 +++++ .../testdata/TestMissingTool/call1.golden | 32 ++++++++++ .../TestMissingTool/call2-resp.golden | 9 +++ .../testdata/TestMissingTool/call2.golden | 61 +++++++++++++++++++ pkg/tests/testdata/TestMissingTool/test.gpt | 10 +++ pkg/tests/tester/runner.go | 15 ++++- 9 files changed, 181 insertions(+), 11 deletions(-) create mode 100644 pkg/tests/testdata/TestMissingTool/call1-resp.golden create mode 100644 pkg/tests/testdata/TestMissingTool/call1.golden create mode 100644 pkg/tests/testdata/TestMissingTool/call2-resp.golden create mode 100644 pkg/tests/testdata/TestMissingTool/call2.golden create mode 100644 pkg/tests/testdata/TestMissingTool/test.gpt diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 250e9578..0ea72ff6 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -45,8 +45,9 @@ type Return struct { } type Call struct { - ToolID string `json:"toolID,omitempty"` - Input string `json:"input,omitempty"` + Missing bool `json:"missing,omitempty"` + ToolID string `json:"toolID,omitempty"` + Input string `json:"input,omitempty"` } type CallResult struct { @@ -216,10 +217,7 @@ func NewContext(ctx context.Context, prg *types.Program, input string) (Context, } func (c *Context) SubCallContext(ctx context.Context, input, toolID, callID string, toolCategory ToolCategory) (Context, error) { - tool, ok := c.Program.ToolSet[toolID] - if !ok { - return Context{}, fmt.Errorf("failed to file tool for id [%s]", toolID) - } + tool := c.Program.ToolSet[toolID] if callID == "" { callID = counter.Next() @@ -387,19 +385,25 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) { state.Pending = map[string]types.CompletionToolCall{} for _, content := range resp.Content { if content.ToolCall != nil { - var toolID string + var ( + toolID string + missing bool + ) for _, tool := range state.Completion.Tools { if strings.EqualFold(tool.Function.Name, content.ToolCall.Function.Name) { toolID = tool.Function.ToolID } } if toolID == "" { - return nil, fmt.Errorf("failed to find tool id for tool %s in tool_call result", content.ToolCall.Function.Name) + log.Debugf("failed to find tool id for tool %s in tool_call result", content.ToolCall.Function.Name) + toolID = content.ToolCall.Function.Name + missing = true } state.Pending[content.ToolCall.ID] = *content.ToolCall ret.Calls[content.ToolCall.ID] = Call{ - ToolID: toolID, - Input: content.ToolCall.Function.Arguments, + ToolID: toolID, + Missing: missing, + Input: content.ToolCall.Function.Arguments, } } else { cp := content.Text diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 36bac826..9e8695a7 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -802,6 +802,18 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, for _, id := range ids { call := state.Continuation.Calls[id] + if call.Missing { + resultLock.Lock() + callResults = append(callResults, SubCallResult{ + ToolID: call.ToolID, + CallID: id, + State: &State{ + Result: &[]string{fmt.Sprintf("ERROR: can not call unknown tool named [%s]", call.ToolID)}[0], + }, + }) + resultLock.Unlock() + continue + } d.Run(func(ctx context.Context) error { result, err := r.subCall(ctx, callCtx, monitor, env, call.ToolID, call.Input, id, toolCategory) if err != nil { diff --git a/pkg/tests/runner_test.go b/pkg/tests/runner_test.go index 0b75c8a2..a38de6a2 100644 --- a/pkg/tests/runner_test.go +++ b/pkg/tests/runner_test.go @@ -948,3 +948,18 @@ func TestSysContext(t *testing.T) { require.Len(t, context.Call.AgentGroup, 1) assert.Equal(t, context.Call.AgentGroup[0].Named, "iAmSuperman") } + +func TestMissingTool(t *testing.T) { + r := tester.NewRunner(t) + + r.RespondWith(tester.Result{ + Func: types.CompletionFunctionCall{ + Name: "not bob", + }, + }) + + resp, err := r.Run("", "Input 1") + require.NoError(t, err) + r.AssertResponded(t) + autogold.Expect("TEST RESULT CALL: 2").Equal(t, resp) +} diff --git a/pkg/tests/testdata/TestMissingTool/call1-resp.golden b/pkg/tests/testdata/TestMissingTool/call1-resp.golden new file mode 100644 index 00000000..c9799ee8 --- /dev/null +++ b/pkg/tests/testdata/TestMissingTool/call1-resp.golden @@ -0,0 +1,14 @@ +`{ + "role": "assistant", + "content": [ + { + "toolCall": { + "id": "call_1", + "function": { + "name": "not bob" + } + } + } + ], + "usage": {} +}` diff --git a/pkg/tests/testdata/TestMissingTool/call1.golden b/pkg/tests/testdata/TestMissingTool/call1.golden new file mode 100644 index 00000000..f1bcc4f0 --- /dev/null +++ b/pkg/tests/testdata/TestMissingTool/call1.golden @@ -0,0 +1,32 @@ +`{ + "model": "gpt-4o", + "tools": [ + { + "function": { + "toolID": "testdata/TestMissingTool/test.gpt:Bob", + "name": "Bob", + "parameters": null + } + } + ], + "messages": [ + { + "role": "system", + "content": [ + { + "text": "Call tool Bob" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "Input 1" + } + ], + "usage": {} + } + ] +}` diff --git a/pkg/tests/testdata/TestMissingTool/call2-resp.golden b/pkg/tests/testdata/TestMissingTool/call2-resp.golden new file mode 100644 index 00000000..997ca1b9 --- /dev/null +++ b/pkg/tests/testdata/TestMissingTool/call2-resp.golden @@ -0,0 +1,9 @@ +`{ + "role": "assistant", + "content": [ + { + "text": "TEST RESULT CALL: 2" + } + ], + "usage": {} +}` diff --git a/pkg/tests/testdata/TestMissingTool/call2.golden b/pkg/tests/testdata/TestMissingTool/call2.golden new file mode 100644 index 00000000..2fe99e81 --- /dev/null +++ b/pkg/tests/testdata/TestMissingTool/call2.golden @@ -0,0 +1,61 @@ +`{ + "model": "gpt-4o", + "tools": [ + { + "function": { + "toolID": "testdata/TestMissingTool/test.gpt:Bob", + "name": "Bob", + "parameters": null + } + } + ], + "messages": [ + { + "role": "system", + "content": [ + { + "text": "Call tool Bob" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "Input 1" + } + ], + "usage": {} + }, + { + "role": "assistant", + "content": [ + { + "toolCall": { + "id": "call_1", + "function": { + "name": "not bob" + } + } + } + ], + "usage": {} + }, + { + "role": "tool", + "content": [ + { + "text": "ERROR: can not call unknown tool named [not bob]" + } + ], + "toolCall": { + "id": "call_1", + "function": { + "name": "not bob" + } + }, + "usage": {} + } + ] +}` diff --git a/pkg/tests/testdata/TestMissingTool/test.gpt b/pkg/tests/testdata/TestMissingTool/test.gpt new file mode 100644 index 00000000..2613ffd2 --- /dev/null +++ b/pkg/tests/testdata/TestMissingTool/test.gpt @@ -0,0 +1,10 @@ +tools: Bob + +Call tool Bob + +--- +name: Bob + +#!sys.echo + +You called? \ No newline at end of file diff --git a/pkg/tests/tester/runner.go b/pkg/tests/tester/runner.go index fe21ba92..775f0248 100644 --- a/pkg/tests/tester/runner.go +++ b/pkg/tests/tester/runner.go @@ -104,7 +104,20 @@ func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, } if result.Func.Name != "" { - c.t.Fatalf("failed to find tool %s", result.Func.Name) + return &types.CompletionMessage{ + Role: types.CompletionMessageRoleTypeAssistant, + Content: []types.ContentPart{ + { + ToolCall: &types.CompletionToolCall{ + ID: fmt.Sprintf("call_%d", c.id), + Function: types.CompletionFunctionCall{ + Name: result.Func.Name, + Arguments: result.Func.Arguments, + }, + }, + }, + }, + }, nil } return &types.CompletionMessage{