From e9c2bf91a55dfb3f4ad272689aba84cdf3e3f6e8 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Thu, 20 Jun 2024 10:18:35 -0700 Subject: [PATCH] chore: add progress output for builtins specifically sys.exec --- pkg/builtin/builtin.go | 73 ++++++++++++++++++++++++++----------- pkg/builtin/builtin_test.go | 6 ++- pkg/engine/cmd.go | 26 ++++++++++++- pkg/prompt/credential.go | 3 +- pkg/prompt/prompt.go | 2 +- pkg/types/tool.go | 2 +- 6 files changed, 84 insertions(+), 28 deletions(-) diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index 989f523c..291384c3 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -1,6 +1,7 @@ package builtin import ( + "bytes" "context" "encoding/json" "errors" @@ -264,7 +265,7 @@ func Builtin(name string) (types.Tool, bool) { return SetDefaults(t), ok } -func SysFind(_ context.Context, _ []string, input string) (string, error) { +func SysFind(_ context.Context, _ []string, input string, _ chan<- string) (string, error) { var result []string var params struct { Pattern string `json:"pattern,omitempty"` @@ -305,7 +306,7 @@ func SysFind(_ context.Context, _ []string, input string) (string, error) { return strings.Join(result, "\n"), nil } -func SysExec(_ context.Context, env []string, input string) (string, error) { +func SysExec(_ context.Context, env []string, input string, progress chan<- string) (string, error) { var params struct { Command string `json:"command,omitempty"` Directory string `json:"directory,omitempty"` @@ -328,13 +329,30 @@ func SysExec(_ context.Context, env []string, input string) (string, error) { cmd = exec.Command("/bin/sh", "-c", params.Command) } + var ( + out bytes.Buffer + pw = progressWriter{ + out: progress, + } + combined = io.MultiWriter(&out, &pw) + ) cmd.Env = env cmd.Dir = params.Directory - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Sprintf("ERROR: %s\nOUTPUT:\n%s", err, out), nil + cmd.Stdout = combined + cmd.Stderr = combined + if err := cmd.Run(); err != nil { + return fmt.Sprintf("ERROR: %s\nOUTPUT:\n%s", err, &out), nil } - return string(out), nil + return out.String(), nil +} + +type progressWriter struct { + out chan<- string +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + pw.out <- string(p) + return len(p), nil } func getWorkspaceDir(envs []string) (string, error) { @@ -347,7 +365,7 @@ func getWorkspaceDir(envs []string) (string, error) { return "", fmt.Errorf("no workspace directory found in env") } -func SysLs(_ context.Context, _ []string, input string) (string, error) { +func SysLs(_ context.Context, _ []string, input string, _ chan<- string) (string, error) { var params struct { Dir string `json:"dir,omitempty"` } @@ -383,7 +401,7 @@ func SysLs(_ context.Context, _ []string, input string) (string, error) { return strings.Join(result, "\n"), nil } -func SysRead(_ context.Context, _ []string, input string) (string, error) { +func SysRead(_ context.Context, _ []string, input string, _ chan<- string) (string, error) { var params struct { Filename string `json:"filename,omitempty"` } @@ -411,7 +429,7 @@ func SysRead(_ context.Context, _ []string, input string) (string, error) { return string(data), nil } -func SysWrite(_ context.Context, _ []string, input string) (string, error) { +func SysWrite(_ context.Context, _ []string, input string, _ chan<- string) (string, error) { var params struct { Filename string `json:"filename,omitempty"` Content string `json:"content,omitempty"` @@ -443,7 +461,7 @@ func SysWrite(_ context.Context, _ []string, input string) (string, error) { return fmt.Sprintf("Wrote (%d) bytes to file %s", len(data), file), nil } -func SysAppend(_ context.Context, _ []string, input string) (string, error) { +func SysAppend(_ context.Context, _ []string, input string, _ chan<- string) (string, error) { var params struct { Filename string `json:"filename,omitempty"` Content string `json:"content,omitempty"` @@ -489,7 +507,7 @@ func fixQueries(u string) string { return url.String() } -func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err error) { +func SysHTTPGet(_ context.Context, _ []string, input string, _ chan<- string) (_ string, err error) { var params struct { URL string `json:"url,omitempty"` } @@ -523,8 +541,8 @@ func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err erro return string(data), nil } -func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string, error) { - content, err := SysHTTPGet(ctx, env, input) +func SysHTTPHtml2Text(ctx context.Context, env []string, input string, progress chan<- string) (string, error) { + content, err := SysHTTPGet(ctx, env, input, progress) if err != nil { return "", err } @@ -533,7 +551,7 @@ func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string, }) } -func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err error) { +func SysHTTPPost(ctx context.Context, _ []string, input string, _ chan<- string) (_ string, err error) { var params struct { URL string `json:"url,omitempty"` Content string `json:"content,omitempty"` @@ -569,7 +587,18 @@ func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err e return fmt.Sprintf("Wrote %d to %s", len([]byte(params.Content)), params.URL), nil } -func SysGetenv(_ context.Context, env []string, input string) (string, error) { +func DiscardProgress() (progress chan<- string, closeFunc func()) { + ch := make(chan string) + go func() { + for range ch { + } + }() + return ch, func() { + close(ch) + } +} + +func SysGetenv(_ context.Context, env []string, input string, _ chan<- string) (string, error) { var params struct { Name string `json:"name,omitempty"` } @@ -597,7 +626,7 @@ func invalidArgument(input string, err error) string { return fmt.Sprintf("Failed to parse arguments %s: %v", input, err) } -func SysChatHistory(ctx context.Context, _ []string, _ string) (string, error) { +func SysChatHistory(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) { engineContext, _ := engine.FromContext(ctx) data, err := json.Marshal(engine.ChatHistory{ @@ -627,7 +656,7 @@ func writeHistory(ctx *engine.Context) (result []engine.ChatHistoryCall) { return } -func SysChatFinish(_ context.Context, _ []string, input string) (string, error) { +func SysChatFinish(_ context.Context, _ []string, input string, _ chan<- string) (string, error) { var params struct { Message string `json:"return,omitempty"` } @@ -641,7 +670,7 @@ func SysChatFinish(_ context.Context, _ []string, input string) (string, error) } } -func SysAbort(_ context.Context, _ []string, input string) (string, error) { +func SysAbort(_ context.Context, _ []string, input string, _ chan<- string) (string, error) { var params struct { Message string `json:"message,omitempty"` } @@ -651,7 +680,7 @@ func SysAbort(_ context.Context, _ []string, input string) (string, error) { return "", fmt.Errorf("ABORT: %s", params.Message) } -func SysRemove(_ context.Context, _ []string, input string) (string, error) { +func SysRemove(_ context.Context, _ []string, input string, _ chan<- string) (string, error) { var params struct { Location string `json:"location,omitempty"` } @@ -670,7 +699,7 @@ func SysRemove(_ context.Context, _ []string, input string) (string, error) { return fmt.Sprintf("Removed file: %s", params.Location), nil } -func SysStat(_ context.Context, _ []string, input string) (string, error) { +func SysStat(_ context.Context, _ []string, input string, _ chan<- string) (string, error) { var params struct { Filepath string `json:"filepath,omitempty"` } @@ -690,7 +719,7 @@ func SysStat(_ context.Context, _ []string, input string) (string, error) { return fmt.Sprintf("%s %s mode: %s, size: %d bytes, modtime: %s", title, params.Filepath, stat.Mode().String(), stat.Size(), stat.ModTime().String()), nil } -func SysDownload(_ context.Context, env []string, input string) (_ string, err error) { +func SysDownload(_ context.Context, env []string, input string, _ chan<- string) (_ string, err error) { var params struct { URL string `json:"url,omitempty"` Location string `json:"location,omitempty"` @@ -763,6 +792,6 @@ func SysDownload(_ context.Context, env []string, input string) (_ string, err e return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil } -func SysTimeNow(context.Context, []string, string) (string, error) { +func SysTimeNow(context.Context, []string, string, chan<- string) (string, error) { return time.Now().Format(time.RFC3339), nil } diff --git a/pkg/builtin/builtin_test.go b/pkg/builtin/builtin_test.go index 313b9718..c12a68f6 100644 --- a/pkg/builtin/builtin_test.go +++ b/pkg/builtin/builtin_test.go @@ -10,15 +10,17 @@ import ( ) func TestSysGetenv(t *testing.T) { + p, c := DiscardProgress() + defer c() v, err := SysGetenv(context.Background(), []string{ "MAGIC=VALUE", - }, `{"name":"MAGIC"}`) + }, `{"name":"MAGIC"}`, nil) require.NoError(t, err) autogold.Expect("VALUE").Equal(t, v) v, err = SysGetenv(context.Background(), []string{ "MAGIC=VALUE", - }, `{"name":"MAGIC2"}`) + }, `{"name":"MAGIC2"}`, p) require.NoError(t, err) autogold.Expect("MAGIC2 is not set or has no value").Equal(t, v) } diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index 4a697c69..3707205e 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -12,6 +12,7 @@ import ( "runtime" "sort" "strings" + "sync" "github.com/google/shlex" "github.com/gptscript-ai/gptscript/pkg/counter" @@ -64,7 +65,30 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate "input": input, }, } - return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input) + + var ( + progress = make(chan string) + wg sync.WaitGroup + ) + wg.Add(1) + defer wg.Wait() + defer close(progress) + go func() { + defer wg.Done() + buf := strings.Builder{} + for line := range progress { + buf.WriteString(line) + e.Progress <- types.CompletionStatus{ + CompletionID: id, + PartialResponse: &types.CompletionMessage{ + Role: types.CompletionMessageRoleTypeAssistant, + Content: types.Text(buf.String()), + }, + } + } + }() + + return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input, progress) } var instructions []string diff --git a/pkg/prompt/credential.go b/pkg/prompt/credential.go index a8bf6f76..9202ed49 100644 --- a/pkg/prompt/credential.go +++ b/pkg/prompt/credential.go @@ -18,7 +18,8 @@ func GetModelProviderCredential(ctx context.Context, credStore credentials.Crede if exists { k = cred.Env[env] } else { - result, err := SysPrompt(ctx, envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message)) + // we know progress isn't used so pass as nil + result, err := SysPrompt(ctx, envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message), nil) if err != nil { return "", err } diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go index 6cf8febd..4a9550a3 100644 --- a/pkg/prompt/prompt.go +++ b/pkg/prompt/prompt.go @@ -48,7 +48,7 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types. return string(data), err } -func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err error) { +func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string) (_ string, err error) { var params struct { Message string `json:"message,omitempty"` Fields string `json:"fields,omitempty"` diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 6c016d82..9468a04a 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -117,7 +117,7 @@ func (p Program) SetBlocking() Program { return p } -type BuiltinFunc func(ctx context.Context, env []string, input string) (string, error) +type BuiltinFunc func(ctx context.Context, env []string, input string, progress chan<- string) (string, error) type Parameters struct { Name string `json:"name,omitempty"`