Skip to content

Commit

Permalink
Merge pull request #292 from ibuildthecloud/main
Browse files Browse the repository at this point in the history
feat: add workspace functions
  • Loading branch information
ibuildthecloud authored Apr 26, 2024
2 parents 0a7d53a + 7132a36 commit cb54f3b
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 48 deletions.
113 changes: 100 additions & 13 deletions pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,31 @@ import (
)

var tools = map[string]types.Tool{
"sys.workspace.ls": {
Parameters: types.Parameters{
Description: "Lists the contents of a directory relative to the current workspace",
Arguments: types.ObjectSchema(
"dir", "The directory to list"),
},
BuiltinFunc: SysWorkspaceLs,
},
"sys.workspace.write": {
Parameters: types.Parameters{
Description: "Write the contents to a file relative to the current workspace",
Arguments: types.ObjectSchema(
"filename", "The name of the file to write to",
"content", "The content to write"),
},
BuiltinFunc: SysWorkspaceWrite,
},
"sys.workspace.read": {
Parameters: types.Parameters{
Description: "Reads the contents of a file relative to the current workspace",
Arguments: types.ObjectSchema(
"filename", "The name of the file to read"),
},
BuiltinFunc: SysWorkspaceRead,
},
"sys.ls": {
Parameters: types.Parameters{
Description: "Lists the contents of a directory",
Expand Down Expand Up @@ -297,19 +322,46 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) {
return string(out), err
}

func getWorkspaceDir(envs []string) (string, error) {
for _, env := range envs {
dir, ok := strings.CutPrefix(env, "GPTSCRIPT_WORKSPACE_DIR=")
if ok && dir != "" {
return dir, nil
}
}
return "", fmt.Errorf("no workspace directory found in env")
}

func SysWorkspaceLs(_ context.Context, env []string, input string) (string, error) {
dir, err := getWorkspaceDir(env)
if err != nil {
return "", err
}
return sysLs(dir, input)
}

func SysLs(_ context.Context, _ []string, input string) (string, error) {
return sysLs("", input)
}

func sysLs(base, input string) (string, error) {
var params struct {
Dir string `json:"dir,omitempty"`
}
if err := json.Unmarshal([]byte(input), &params); err != nil {
return "", err
}

if params.Dir == "" {
params.Dir = "."
dir := params.Dir
if dir == "" {
dir = "."
}

if base != "" {
dir = filepath.Join(base, dir)
}

entries, err := os.ReadDir(params.Dir)
entries, err := os.ReadDir(dir)
if errors.Is(err, fs.ErrNotExist) {
return fmt.Sprintf("directory does not exist: %s", params.Dir), nil
} else if err != nil {
Expand All @@ -328,20 +380,38 @@ func SysLs(_ context.Context, _ []string, input string) (string, error) {
return strings.Join(result, "\n"), nil
}

func SysWorkspaceRead(ctx context.Context, env []string, input string) (string, error) {
dir, err := getWorkspaceDir(env)
if err != nil {
return "", err
}

return sysRead(ctx, dir, env, input)
}

func SysRead(ctx context.Context, env []string, input string) (string, error) {
return sysRead(ctx, "", env, input)
}

func sysRead(ctx context.Context, base string, env []string, input string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
}
if err := json.Unmarshal([]byte(input), &params); err != nil {
return "", err
}

file := params.Filename
if base != "" {
file = filepath.Join(base, file)
}

// Lock the file to prevent concurrent writes from other tool calls.
locker.RLock(params.Filename)
defer locker.RUnlock(params.Filename)
locker.RLock(file)
defer locker.RUnlock(file)

log.Debugf("Reading file %s", params.Filename)
data, err := os.ReadFile(params.Filename)
log.Debugf("Reading file %s", file)
data, err := os.ReadFile(file)
if errors.Is(err, fs.ErrNotExist) {
return fmt.Sprintf("The file %s does not exist", params.Filename), nil
} else if err != nil {
Expand All @@ -354,7 +424,19 @@ func SysRead(ctx context.Context, env []string, input string) (string, error) {
return string(data), nil
}

func SysWorkspaceWrite(ctx context.Context, env []string, input string) (string, error) {
dir, err := getWorkspaceDir(env)
if err != nil {
return "", err
}
return sysWrite(ctx, dir, env, input)
}

func SysWrite(ctx context.Context, env []string, input string) (string, error) {
return sysWrite(ctx, "", env, input)
}

func sysWrite(ctx context.Context, base string, env []string, input string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
Content string `json:"content,omitempty"`
Expand All @@ -363,28 +445,33 @@ func SysWrite(ctx context.Context, env []string, input string) (string, error) {
return "", err
}

file := params.Filename
if base != "" {
file = filepath.Join(base, file)
}

// Lock the file to prevent concurrent writes from other tool calls.
locker.Lock(params.Filename)
defer locker.Unlock(params.Filename)
locker.Lock(file)
defer locker.Unlock(file)

dir := filepath.Dir(params.Filename)
dir := filepath.Dir(file)
if _, err := os.Stat(dir); errors.Is(err, fs.ErrNotExist) {
log.Debugf("Creating dir %s", dir)
if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("creating dir %s: %w", dir, err)
}
}

if _, err := os.Stat(params.Filename); err == nil {
if _, err := os.Stat(file); err == nil {
if err := confirm.Promptf(ctx, "Overwrite: %s", params.Filename); err != nil {
return "", err
}
}

data := []byte(params.Content)
log.Debugf("Wrote %d bytes to file %s", len(data), params.Filename)
log.Debugf("Wrote %d bytes to file %s", len(data), file)

return "", os.WriteFile(params.Filename, data, 0644)
return "", os.WriteFile(file, data, 0644)
}

func SysAppend(ctx context.Context, env []string, input string) (string, error) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/chat/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func Start(ctx context.Context, prevState runner.ChatState, chatter Chatter, prg
prompter Prompter
)

prompter, err := newReadlinePrompter()
prompter, err := newReadlinePrompter(prg)
if err != nil {
return err
}
Expand Down
10 changes: 8 additions & 2 deletions pkg/chat/readline.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/adrg/xdg"
"github.com/chzyer/readline"
"github.com/fatih/color"
"github.com/gptscript-ai/gptscript/pkg/hash"
"github.com/gptscript-ai/gptscript/pkg/mvl"
)

Expand All @@ -18,8 +19,13 @@ type readlinePrompter struct {
readliner *readline.Instance
}

func newReadlinePrompter() (*readlinePrompter, error) {
historyFile, err := xdg.CacheFile("gptscript/chat.history")
func newReadlinePrompter(prg GetProgram) (*readlinePrompter, error) {
targetProgram, err := prg()
if err != nil {
return nil, err
}

historyFile, err := xdg.CacheFile(fmt.Sprintf("gptscript/chat-%s.history", hash.ID(targetProgram.EntryToolID)))
if err != nil {
historyFile = ""
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type GPTScript struct {
CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state"`
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool"`
Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"`

readData []byte
}
Expand Down Expand Up @@ -123,6 +124,7 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) {
Quiet: r.Quiet,
Env: os.Environ(),
CredentialContext: r.CredentialContext,
Workspace: r.Workspace,
}

if r.Ports != "" {
Expand Down
60 changes: 51 additions & 9 deletions pkg/gptscript/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,41 @@ package gptscript

import (
"context"
"fmt"
"os"

"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/hash"
"github.com/gptscript-ai/gptscript/pkg/llm"
"github.com/gptscript-ai/gptscript/pkg/monitor"
"github.com/gptscript-ai/gptscript/pkg/mvl"
"github.com/gptscript-ai/gptscript/pkg/openai"
"github.com/gptscript-ai/gptscript/pkg/remote"
"github.com/gptscript-ai/gptscript/pkg/repos/runtimes"
"github.com/gptscript-ai/gptscript/pkg/runner"
"github.com/gptscript-ai/gptscript/pkg/types"
)

var log = mvl.Package()

type GPTScript struct {
Registry *llm.Registry
Runner *runner.Runner
Registry *llm.Registry
Runner *runner.Runner
WorkspacePath string
DeleteWorkspaceOnClose bool
}

type Options struct {
Cache cache.Options
OpenAI openai.Options
Monitor monitor.Options
Runner runner.Options
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
Quiet *bool `usage:"No output logging (set --quiet=false to force on even when there is no TTY)" short:"q"`
Env []string `usage:"-"`
CredentialContext string
Quiet *bool
Workspace string
Env []string
}

func complete(opts *Options) (result *Options) {
Expand Down Expand Up @@ -89,21 +97,55 @@ func New(opts *Options) (*GPTScript, error) {
}

return &GPTScript{
Registry: registry,
Runner: runner,
Registry: registry,
Runner: runner,
WorkspacePath: opts.Workspace,
DeleteWorkspaceOnClose: opts.Workspace == "",
}, nil
}

func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, env []string, input string) (runner.ChatResponse, error) {
return g.Runner.Chat(ctx, prevState, prg, env, input)
func (g *GPTScript) getEnv(env []string) ([]string, error) {
if g.WorkspacePath == "" {
var err error
g.WorkspacePath, err = os.MkdirTemp("", "gptscript-workspace-*")
if err != nil {
return nil, err
}
}
if err := os.MkdirAll(g.WorkspacePath, 0700); err != nil {
return nil, err
}
return append([]string{
fmt.Sprintf("GPTSCRIPT_WORKSPACE_DIR=%s", g.WorkspacePath),
fmt.Sprintf("GPTSCRIPT_WORKSPACE_ID=%s", hash.ID(g.WorkspacePath)),
}, env...), nil
}

func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, envs []string, input string) (runner.ChatResponse, error) {
envs, err := g.getEnv(envs)
if err != nil {
return runner.ChatResponse{}, err
}

return g.Runner.Chat(ctx, prevState, prg, envs, input)
}

func (g *GPTScript) Run(ctx context.Context, prg types.Program, envs []string, input string) (string, error) {
envs, err := g.getEnv(envs)
if err != nil {
return "", err
}

return g.Runner.Run(ctx, prg, envs, input)
}

func (g *GPTScript) Close() {
g.Runner.Close()
if g.DeleteWorkspaceOnClose && g.WorkspacePath != "" {
if err := os.RemoveAll(g.WorkspacePath); err != nil {
log.Errorf("failed to delete workspace %s: %s", g.WorkspacePath, err)
}
}
}

func (g *GPTScript) GetModel() engine.Model {
Expand Down
2 changes: 1 addition & 1 deletion pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
value = strings.TrimSpace(value)
switch normalize(key) {
case "name":
tool.Parameters.Name = strings.ToLower(value)
tool.Parameters.Name = value
case "modelprovider":
tool.Parameters.ModelProvider = true
case "model", "modelname":
Expand Down
Loading

0 comments on commit cb54f3b

Please sign in to comment.