Skip to content

Commit

Permalink
Merge pull request #523 from ibuildthecloud/main
Browse files Browse the repository at this point in the history
feat: add input filters
  • Loading branch information
ibuildthecloud authored Jun 21, 2024
2 parents 7f17683 + 7f5bfb6 commit d7d9ca6
Show file tree
Hide file tree
Showing 18 changed files with 428 additions and 62 deletions.
3 changes: 2 additions & 1 deletion pkg/engine/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ func appendInputAsEnv(env []string, input string) []string {
dec := json.NewDecoder(bytes.NewReader([]byte(input)))
dec.UseNumber()

env = appendEnv(env, "GPTSCRIPT_INPUT", input)

if err := json.Unmarshal([]byte(input), &data); err != nil {
// ignore invalid JSON
return env
Expand All @@ -206,7 +208,6 @@ func appendInputAsEnv(env []string, input string) []string {
}
}

env = appendEnv(env, "GPTSCRIPT_INPUT", input)
return env
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ const (
ProviderToolCategory ToolCategory = "provider"
CredentialToolCategory ToolCategory = "credential"
ContextToolCategory ToolCategory = "context"
InputToolCategory ToolCategory = "input"
NoCategory ToolCategory = ""
)

Expand Down Expand Up @@ -180,7 +181,7 @@ func NewContext(ctx context.Context, prg *types.Program, input string) Context {
return callCtx
}

func (c *Context) SubCall(ctx context.Context, input, toolID, callID string, toolCategory ToolCategory) (Context, error) {
func (c *Context) SubCallContext(ctx context.Context, input, toolID, callID string, toolCategory ToolCategory) (Context, error) {
tool, ok := c.Program.ToolSet[toolID]
if !ok {
return Context{}, fmt.Errorf("failed to file tool for id [%s]", toolID)
Expand Down
9 changes: 7 additions & 2 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package openai

import (
"context"
"fmt"
"io"
"log/slog"
"os"
Expand All @@ -16,6 +15,7 @@ import (
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/hash"
"github.com/gptscript-ai/gptscript/pkg/mvl"
"github.com/gptscript-ai/gptscript/pkg/prompt"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
Expand All @@ -29,6 +29,7 @@ const (
var (
key = os.Getenv("OPENAI_API_KEY")
url = os.Getenv("OPENAI_BASE_URL")
log = mvl.Package()
)

type InvalidAuthError struct{}
Expand Down Expand Up @@ -305,7 +306,11 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
}

if len(msgs) == 0 {
return nil, fmt.Errorf("invalid request, no messages to send to LLM")
log.Errorf("invalid request, no messages to send to LLM")
return &types.CompletionMessage{
Role: types.CompletionMessageRoleTypeAssistant,
Content: types.Text(""),
}, nil
}

request := openai.ChatCompletionRequest{
Expand Down
11 changes: 9 additions & 2 deletions pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
tool.Parameters.Export = append(tool.Parameters.Export, csv(value)...)
case "tool", "tools":
tool.Parameters.Tools = append(tool.Parameters.Tools, csv(value)...)
case "inputfilter", "inputfilters":
tool.Parameters.InputFilters = append(tool.Parameters.InputFilters, csv(value)...)
case "shareinputfilter", "shareinputfilters":
tool.Parameters.ExportInputFilters = append(tool.Parameters.ExportInputFilters, csv(value)...)
case "agent", "agents":
tool.Parameters.Agents = append(tool.Parameters.Agents, csv(value)...)
case "globaltool", "globaltools":
Expand Down Expand Up @@ -183,10 +187,13 @@ type context struct {

func (c *context) finish(tools *[]Node) {
c.tool.Instructions = strings.TrimSpace(strings.Join(c.instructions, ""))
if c.tool.Instructions != "" || c.tool.Parameters.Name != "" ||
len(c.tool.Export) > 0 || len(c.tool.Tools) > 0 ||
if c.tool.Instructions != "" ||
c.tool.Parameters.Name != "" ||
len(c.tool.Export) > 0 ||
len(c.tool.Tools) > 0 ||
c.tool.GlobalModelName != "" ||
len(c.tool.GlobalTools) > 0 ||
len(c.tool.ExportInputFilters) > 0 ||
c.tool.Chat {
*tools = append(*tools, Node{
ToolNode: &ToolNode{
Expand Down
24 changes: 24 additions & 0 deletions pkg/parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,27 @@ name: bad
},
}}).Equal(t, out)
}

func TestParseInput(t *testing.T) {
input := `
input filters: input
share input filters: shared
`
out, err := Parse(strings.NewReader(input))
require.NoError(t, err)
autogold.Expect(Document{Nodes: []Node{
{ToolNode: &ToolNode{
Tool: types.Tool{
ToolDef: types.ToolDef{
Parameters: types.Parameters{
InputFilters: []string{
"input",
},
ExportInputFilters: []string{"shared"},
},
},
Source: types.ToolSource{LineNo: 1},
},
}},
}}).Equal(t, out)
}
27 changes: 27 additions & 0 deletions pkg/runner/input.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package runner

import (
"fmt"

"github.com/gptscript-ai/gptscript/pkg/engine"
)

func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []string, input string) (string, error) {
inputToolRefs, err := callCtx.Tool.GetInputFilterTools(*callCtx.Program)
if err != nil {
return "", err
}

for _, inputToolRef := range inputToolRefs {
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, input, "", engine.InputToolCategory)
if err != nil {
return "", err
}
if res.Result == nil {
return "", fmt.Errorf("invalid state: input tool [%s] can not result in a chat continuation", inputToolRef.Reference)
}
input = *res.Result
}

return input, nil
}
34 changes: 23 additions & 11 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,11 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
Content: input,
})

input, err := r.handleInput(callCtx, monitor, env, input)
if err != nil {
return nil, err
}

if len(callCtx.Tool.Credentials) > 0 {
var err error
env, err = r.handleCredentials(callCtx, monitor, env)
Expand All @@ -417,7 +422,6 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
}

var (
err error
newState *State
)
callCtx.InputContext, newState, err = r.getContext(callCtx, state, monitor, env, input)
Expand Down Expand Up @@ -446,7 +450,10 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
}

if !authResp.Accept {
msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message)
msg := authResp.Message
if msg == "" {
msg = "Tool call request has been denied"
}
return &State{
Continuation: &engine.Return{
Result: &msg,
Expand Down Expand Up @@ -631,8 +638,12 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
}

if state.ResumeInput != nil {
input, err := r.handleInput(callCtx, monitor, env, *state.ResumeInput)
if err != nil {
return state, err
}
engineResults = append(engineResults, engine.CallResult{
User: *state.ResumeInput,
User: input,
})
}

Expand Down Expand Up @@ -689,16 +700,22 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp
}

func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, input, callID string, toolCategory engine.ToolCategory) (*State, error) {
callCtx, err := parentContext.SubCall(ctx, input, toolID, callID, toolCategory)
callCtx, err := parentContext.SubCallContext(ctx, input, toolID, callID, toolCategory)
if err != nil {
return nil, err
}

if toolCategory == engine.ContextToolCategory && callCtx.Tool.IsNoop() {
return &State{
Result: new(string),
}, nil
}

return r.call(callCtx, monitor, env, input)
}

func (r *Runner) subCallResume(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, callID string, state *State, toolCategory engine.ToolCategory) (*State, error) {
callCtx, err := parentContext.SubCall(ctx, "", toolID, callID, toolCategory)
callCtx, err := parentContext.SubCallContext(ctx, "", toolID, callID, toolCategory)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -882,12 +899,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
input = string(inputBytes)
}

subCtx, err := callCtx.SubCall(callCtx.Ctx, input, credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
if err != nil {
return nil, fmt.Errorf("failed to create subcall context for tool %s: %w", credToolName, err)
}

res, err := r.call(subCtx, monitor, env, input)
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, credToolRefs[0].ToolID, input, "", engine.CredentialToolCategory)
if err != nil {
return nil, fmt.Errorf("failed to run credential tool %s: %w", credToolName, err)
}
Expand Down
25 changes: 25 additions & 0 deletions pkg/tests/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -822,3 +822,28 @@ func TestAgents(t *testing.T) {
autogold.Expect("TEST RESULT CALL: 4").Equal(t, resp.Content)
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1"))
}

func TestInput(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip()
}

r := tester.NewRunner(t)

prg, err := r.Load("")
require.NoError(t, err)

resp, err := r.Chat(context.Background(), nil, prg, nil, "You're stupid")
require.NoError(t, err)
r.AssertResponded(t)
assert.False(t, resp.Done)
autogold.Expect("TEST RESULT CALL: 1").Equal(t, resp.Content)
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1"))

resp, err = r.Chat(context.Background(), resp.State, prg, nil, "You're ugly")
require.NoError(t, err)
r.AssertResponded(t)
assert.False(t, resp.Done)
autogold.Expect("TEST RESULT CALL: 2").Equal(t, resp.Content)
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step2"))
}
9 changes: 9 additions & 0 deletions pkg/tests/testdata/TestInput/call1-resp.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
`{
"role": "assistant",
"content": [
{
"text": "TEST RESULT CALL: 1"
}
],
"usage": {}
}`
24 changes: 24 additions & 0 deletions pkg/tests/testdata/TestInput/call1.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
`{
"model": "gpt-4o",
"internalSystemPrompt": false,
"messages": [
{
"role": "system",
"content": [
{
"text": "\nTool body"
}
],
"usage": {}
},
{
"role": "user",
"content": [
{
"text": "No, You're stupid!\n ha ha ha\n"
}
],
"usage": {}
}
]
}`
9 changes: 9 additions & 0 deletions pkg/tests/testdata/TestInput/call2-resp.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
`{
"role": "assistant",
"content": [
{
"text": "TEST RESULT CALL: 2"
}
],
"usage": {}
}`
42 changes: 42 additions & 0 deletions pkg/tests/testdata/TestInput/call2.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
`{
"model": "gpt-4o",
"internalSystemPrompt": false,
"messages": [
{
"role": "system",
"content": [
{
"text": "\nTool body"
}
],
"usage": {}
},
{
"role": "user",
"content": [
{
"text": "No, You're stupid!\n ha ha ha\n"
}
],
"usage": {}
},
{
"role": "assistant",
"content": [
{
"text": "TEST RESULT CALL: 1"
}
],
"usage": {}
},
{
"role": "user",
"content": [
{
"text": "No, You're ugly!\n ha ha ha\n"
}
],
"usage": {}
}
]
}`
Loading

0 comments on commit d7d9ca6

Please sign in to comment.