diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index 67e59a60..0018a45c 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -26,6 +26,31 @@ import ( ) var tools = map[string]types.Tool{ + "sys.workspace.ls": { + Parameters: types.Parameters{ + Description: "Lists the contents of a directory relative to the current workspace", + Arguments: types.ObjectSchema( + "dir", "The directory to list"), + }, + BuiltinFunc: SysWorkspaceLs, + }, + "sys.workspace.write": { + Parameters: types.Parameters{ + Description: "Write the contents to a file relative to the current workspace", + Arguments: types.ObjectSchema( + "filename", "The name of the file to write to", + "content", "The content to write"), + }, + BuiltinFunc: SysWorkspaceWrite, + }, + "sys.workspace.read": { + Parameters: types.Parameters{ + Description: "Reads the contents of a file relative to the current workspace", + Arguments: types.ObjectSchema( + "filename", "The name of the file to read"), + }, + BuiltinFunc: SysWorkspaceRead, + }, "sys.ls": { Parameters: types.Parameters{ Description: "Lists the contents of a directory", @@ -297,7 +322,29 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) { return string(out), err } +func getWorkspaceDir(envs []string) (string, error) { + for _, env := range envs { + dir, ok := strings.CutPrefix(env, "GPTSCRIPT_WORKSPACE_DIR=") + if ok && dir != "" { + return dir, nil + } + } + return "", fmt.Errorf("no workspace directory found in env") +} + +func SysWorkspaceLs(_ context.Context, env []string, input string) (string, error) { + dir, err := getWorkspaceDir(env) + if err != nil { + return "", err + } + return sysLs(dir, input) +} + func SysLs(_ context.Context, _ []string, input string) (string, error) { + return sysLs("", input) +} + +func sysLs(base, input string) (string, error) { var params struct { Dir string `json:"dir,omitempty"` } @@ -305,11 +352,16 @@ func SysLs(_ context.Context, _ []string, input string) (string, error) { return "", err } - if params.Dir == "" { - params.Dir = "." + dir := params.Dir + if dir == "" { + dir = "." + } + + if base != "" { + dir = filepath.Join(base, dir) } - entries, err := os.ReadDir(params.Dir) + entries, err := os.ReadDir(dir) if errors.Is(err, fs.ErrNotExist) { return fmt.Sprintf("directory does not exist: %s", params.Dir), nil } else if err != nil { @@ -328,7 +380,20 @@ func SysLs(_ context.Context, _ []string, input string) (string, error) { return strings.Join(result, "\n"), nil } +func SysWorkspaceRead(ctx context.Context, env []string, input string) (string, error) { + dir, err := getWorkspaceDir(env) + if err != nil { + return "", err + } + + return sysRead(ctx, dir, env, input) +} + func SysRead(ctx context.Context, env []string, input string) (string, error) { + return sysRead(ctx, "", env, input) +} + +func sysRead(ctx context.Context, base string, env []string, input string) (string, error) { var params struct { Filename string `json:"filename,omitempty"` } @@ -336,12 +401,17 @@ func SysRead(ctx context.Context, env []string, input string) (string, error) { return "", err } + file := params.Filename + if base != "" { + file = filepath.Join(base, file) + } + // Lock the file to prevent concurrent writes from other tool calls. - locker.RLock(params.Filename) - defer locker.RUnlock(params.Filename) + locker.RLock(file) + defer locker.RUnlock(file) - log.Debugf("Reading file %s", params.Filename) - data, err := os.ReadFile(params.Filename) + log.Debugf("Reading file %s", file) + data, err := os.ReadFile(file) if errors.Is(err, fs.ErrNotExist) { return fmt.Sprintf("The file %s does not exist", params.Filename), nil } else if err != nil { @@ -354,7 +424,19 @@ func SysRead(ctx context.Context, env []string, input string) (string, error) { return string(data), nil } +func SysWorkspaceWrite(ctx context.Context, env []string, input string) (string, error) { + dir, err := getWorkspaceDir(env) + if err != nil { + return "", err + } + return sysWrite(ctx, dir, env, input) +} + func SysWrite(ctx context.Context, env []string, input string) (string, error) { + return sysWrite(ctx, "", env, input) +} + +func sysWrite(ctx context.Context, base string, env []string, input string) (string, error) { var params struct { Filename string `json:"filename,omitempty"` Content string `json:"content,omitempty"` @@ -363,11 +445,16 @@ func SysWrite(ctx context.Context, env []string, input string) (string, error) { return "", err } + file := params.Filename + if base != "" { + file = filepath.Join(base, file) + } + // Lock the file to prevent concurrent writes from other tool calls. - locker.Lock(params.Filename) - defer locker.Unlock(params.Filename) + locker.Lock(file) + defer locker.Unlock(file) - dir := filepath.Dir(params.Filename) + dir := filepath.Dir(file) if _, err := os.Stat(dir); errors.Is(err, fs.ErrNotExist) { log.Debugf("Creating dir %s", dir) if err := os.MkdirAll(dir, 0755); err != nil { @@ -375,16 +462,16 @@ func SysWrite(ctx context.Context, env []string, input string) (string, error) { } } - if _, err := os.Stat(params.Filename); err == nil { + if _, err := os.Stat(file); err == nil { if err := confirm.Promptf(ctx, "Overwrite: %s", params.Filename); err != nil { return "", err } } data := []byte(params.Content) - log.Debugf("Wrote %d bytes to file %s", len(data), params.Filename) + log.Debugf("Wrote %d bytes to file %s", len(data), file) - return "", os.WriteFile(params.Filename, data, 0644) + return "", os.WriteFile(file, data, 0644) } func SysAppend(ctx context.Context, env []string, input string) (string, error) { diff --git a/pkg/chat/chat.go b/pkg/chat/chat.go index e1008374..23c9d132 100644 --- a/pkg/chat/chat.go +++ b/pkg/chat/chat.go @@ -35,7 +35,7 @@ func Start(ctx context.Context, prevState runner.ChatState, chatter Chatter, prg prompter Prompter ) - prompter, err := newReadlinePrompter() + prompter, err := newReadlinePrompter(prg) if err != nil { return err } diff --git a/pkg/chat/readline.go b/pkg/chat/readline.go index 029ddb90..bf0e779f 100644 --- a/pkg/chat/readline.go +++ b/pkg/chat/readline.go @@ -9,6 +9,7 @@ import ( "github.com/adrg/xdg" "github.com/chzyer/readline" "github.com/fatih/color" + "github.com/gptscript-ai/gptscript/pkg/hash" "github.com/gptscript-ai/gptscript/pkg/mvl" ) @@ -18,8 +19,13 @@ type readlinePrompter struct { readliner *readline.Instance } -func newReadlinePrompter() (*readlinePrompter, error) { - historyFile, err := xdg.CacheFile("gptscript/chat.history") +func newReadlinePrompter(prg GetProgram) (*readlinePrompter, error) { + targetProgram, err := prg() + if err != nil { + return nil, err + } + + historyFile, err := xdg.CacheFile(fmt.Sprintf("gptscript/chat-%s.history", hash.ID(targetProgram.EntryToolID))) if err != nil { historyFile = "" } diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 095131d7..caebc68c 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -61,6 +61,7 @@ type GPTScript struct { CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"` ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state"` ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool"` + Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"` readData []byte } @@ -123,6 +124,7 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) { Quiet: r.Quiet, Env: os.Environ(), CredentialContext: r.CredentialContext, + Workspace: r.Workspace, } if r.Ports != "" { diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index ff9c70f2..4f638ea9 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -2,13 +2,16 @@ package gptscript import ( "context" + "fmt" "os" "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/hash" "github.com/gptscript-ai/gptscript/pkg/llm" "github.com/gptscript-ai/gptscript/pkg/monitor" + "github.com/gptscript-ai/gptscript/pkg/mvl" "github.com/gptscript-ai/gptscript/pkg/openai" "github.com/gptscript-ai/gptscript/pkg/remote" "github.com/gptscript-ai/gptscript/pkg/repos/runtimes" @@ -16,9 +19,13 @@ import ( "github.com/gptscript-ai/gptscript/pkg/types" ) +var log = mvl.Package() + type GPTScript struct { - Registry *llm.Registry - Runner *runner.Runner + Registry *llm.Registry + Runner *runner.Runner + WorkspacePath string + DeleteWorkspaceOnClose bool } type Options struct { @@ -26,9 +33,10 @@ type Options struct { OpenAI openai.Options Monitor monitor.Options Runner runner.Options - CredentialContext string `usage:"Context name in which to store credentials" default:"default"` - Quiet *bool `usage:"No output logging (set --quiet=false to force on even when there is no TTY)" short:"q"` - Env []string `usage:"-"` + CredentialContext string + Quiet *bool + Workspace string + Env []string } func complete(opts *Options) (result *Options) { @@ -89,21 +97,55 @@ func New(opts *Options) (*GPTScript, error) { } return &GPTScript{ - Registry: registry, - Runner: runner, + Registry: registry, + Runner: runner, + WorkspacePath: opts.Workspace, + DeleteWorkspaceOnClose: opts.Workspace == "", }, nil } -func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, env []string, input string) (runner.ChatResponse, error) { - return g.Runner.Chat(ctx, prevState, prg, env, input) +func (g *GPTScript) getEnv(env []string) ([]string, error) { + if g.WorkspacePath == "" { + var err error + g.WorkspacePath, err = os.MkdirTemp("", "gptscript-workspace-*") + if err != nil { + return nil, err + } + } + if err := os.MkdirAll(g.WorkspacePath, 0700); err != nil { + return nil, err + } + return append([]string{ + fmt.Sprintf("GPTSCRIPT_WORKSPACE_DIR=%s", g.WorkspacePath), + fmt.Sprintf("GPTSCRIPT_WORKSPACE_ID=%s", hash.ID(g.WorkspacePath)), + }, env...), nil +} + +func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, envs []string, input string) (runner.ChatResponse, error) { + envs, err := g.getEnv(envs) + if err != nil { + return runner.ChatResponse{}, err + } + + return g.Runner.Chat(ctx, prevState, prg, envs, input) } func (g *GPTScript) Run(ctx context.Context, prg types.Program, envs []string, input string) (string, error) { + envs, err := g.getEnv(envs) + if err != nil { + return "", err + } + return g.Runner.Run(ctx, prg, envs, input) } func (g *GPTScript) Close() { g.Runner.Close() + if g.DeleteWorkspaceOnClose && g.WorkspacePath != "" { + if err := os.RemoveAll(g.WorkspacePath); err != nil { + log.Errorf("failed to delete workspace %s: %s", g.WorkspacePath, err) + } + } } func (g *GPTScript) GetModel() engine.Model { diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index 25ebab1b..33f42c80 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -79,7 +79,7 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { value = strings.TrimSpace(value) switch normalize(key) { case "name": - tool.Parameters.Name = strings.ToLower(value) + tool.Parameters.Name = value case "modelprovider": tool.Parameters.ModelProvider = true case "model", "modelname": diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 0da12ebf..4a4f0251 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -419,28 +419,6 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s } } - var ( - err error - contentInput string - ) - - if state.Continuation != nil && state.Continuation.State != nil { - contentInput = state.Continuation.State.Input - } - - callCtx.InputContext, err = r.getContext(callCtx, monitor, env, contentInput) - if err != nil { - return nil, err - } - - e := engine.Engine{ - Model: r.c, - RuntimeManager: r.runtimeManager, - Progress: progress, - Env: env, - Ports: &r.ports, - } - for { if state.Continuation.Result != nil && len(state.Continuation.Calls) == 0 && state.SubCallID == "" && state.ResumeInput == nil { progressClose() @@ -512,6 +490,27 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s ToolResults: len(callResults), }) + e := engine.Engine{ + Model: r.c, + RuntimeManager: r.runtimeManager, + Progress: progress, + Env: env, + Ports: &r.ports, + } + + var ( + contentInput string + ) + + if state.Continuation != nil && state.Continuation.State != nil { + contentInput = state.Continuation.State.Input + } + + callCtx.InputContext, err = r.getContext(callCtx, monitor, env, contentInput) + if err != nil { + return nil, err + } + nextContinuation, err := e.Continue(callCtx, state.Continuation.State, engineResults...) if err != nil { return nil, err