From f4284f432af03d6c222950aa811788d617efe5d2 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Thu, 11 Apr 2024 17:04:46 -0400 Subject: [PATCH] feat: credentials framework (#212) Signed-off-by: Grant Linville --- docs/docs/03-tools/04-credentials.md | 99 ++++++++++++++++ docs/docs/07-gpt-file-reference.md | 23 ++-- go.mod | 4 + go.sum | 8 ++ pkg/builtin/builtin.go | 57 ++++++++++ pkg/cli/credential.go | 78 +++++++++++++ pkg/cli/credential_delete.go | 37 ++++++ pkg/cli/gptscript.go | 44 ++++---- pkg/config/cliconfig.go | 161 +++++++++++++++++++++++++++ pkg/credentials/credential.go | 58 ++++++++++ pkg/credentials/helper.go | 75 +++++++++++++ pkg/credentials/store.go | 120 ++++++++++++++++++++ pkg/engine/cmd.go | 1 + pkg/engine/engine.go | 37 ++++-- pkg/gptscript/gptscript.go | 15 +-- pkg/loader/loader.go | 3 +- pkg/monitor/display.go | 39 ++++++- pkg/mvl/log.go | 4 + pkg/openai/client.go | 1 + pkg/parser/parser.go | 2 + pkg/runner/monitor.go | 4 + pkg/runner/runner.go | 94 +++++++++++++++- pkg/server/server.go | 24 +++- pkg/tests/tester/runner.go | 2 +- pkg/types/tool.go | 4 + 25 files changed, 927 insertions(+), 67 deletions(-) create mode 100644 docs/docs/03-tools/04-credentials.md create mode 100644 pkg/cli/credential.go create mode 100644 pkg/cli/credential_delete.go create mode 100644 pkg/config/cliconfig.go create mode 100644 pkg/credentials/credential.go create mode 100644 pkg/credentials/helper.go create mode 100644 pkg/credentials/store.go diff --git a/docs/docs/03-tools/04-credentials.md b/docs/docs/03-tools/04-credentials.md new file mode 100644 index 00000000..617587aa --- /dev/null +++ b/docs/docs/03-tools/04-credentials.md @@ -0,0 +1,99 @@ +# Credentials + +GPTScript supports credential provider tools. These tools can be used to fetch credentials from a secure location (or +directly from user input) and conveniently set them in the environment before running a script. + +## Writing a Credential Provider Tool + +A credential provider tool looks just like any other GPTScript, with the following caveats: +- It cannot call the LLM and must run a command. +- It must print contents to stdout in the format `{"env":{"ENV_VAR_1":"value1","ENV_VAR_2":"value2"}}`. +- Any args defined on the tool will be ignored. + +Here is a simple example of a credential provider tool that uses the builtin `sys.prompt` to ask the user for some input: + +```yaml +# my-credential-tool.gpt +name: my-credential-tool + +#!/usr/bin/env bash + +output=$(gptscript -q --cache=false sys.prompt '{"message":"Please enter your fake credential.","fields":"credential","sensitive":"true"}') +credential=$(echo $output | jq -r '.credential') +echo "{\"env\":{\"MY_ENV_VAR\":\"$credential\"}}" +``` + +## Using a Credential Provider Tool + +Continuing with the above example, this is how you can use it in a script: + +```yaml +credentials: my-credential-tool.gpt + +#!/usr/bin/env bash + +echo "The value of MY_ENV_VAR is $MY_ENV_VAR" +``` + +When you run the script, GPTScript will call the credential provider tool first, set the environment variables from its +output, and then run the script body. The credential provider tool is called by GPTScript itself. GPTScript does not ask the +LLM about it or even tell the LLM about the tool. + +If GPTScript has called the credential provider tool in the same context (more on that later), then it will use the stored +credential instead of fetching it again. + +You can also specify multiple credential tools for the same script: + +```yaml +credentials: credential-tool-1.gpt, credential-tool-2.gpt + +(tool stuff here) +``` + +## Storing Credentials + +By default, credentials are automatically stored in a config file at `$XDG_CONFIG_HOME/gptscript/config.json`. +This config file also has another parameter, `credsStore`, which indicates where the credentials are being stored. + +- `file` (default): The credentials are stored directly in the config file. +- `osxkeychain`: The credentials are stored in the macOS Keychain. + +In order to use `osxkeychain` as the credsStore, you must have the `gptscript-credential-osxkeychain` executable +available in your PATH. There will probably be better packaging for this in the future, but for now, you can build it +from the [repo](https://github.com/gptscript-ai/gptscript-credential-helpers). + +There will likely be support added for other credential stores in the future. + +:::note +Credentials received from credential provider tools that are not on GitHub (such as a local file) will not be stored +in the credentials store. +::: + +## Credential Contexts + +Each stored credential is uniquely identified by the name of its provider tool and the name of its context. A credential +context is basically a namespace for credentials. If you have multiple credentials from the same provider tool, you can +switch between them by defining them in different credential contexts. The default context is called `default`, and this +is used if none is specified. + +You can set the credential context to use with the `--credential-context` flag when running GPTScript. For +example: + +```bash +gptscript --credential-context my-azure-workspace my-azure-script.gpt +``` + +Any credentials fetched for that script will be stored in the `my-azure-workspace` context. If you were to call it again +with a different context, you would be able to give it a different set of credentials. + +## Listing and Deleting Stored Credentials + +The `gptscript credential` command can be used to list and delete stored credentials. Running the command with no +`--credential-context` set will use the `default` credential context. You can also specify that it should list +credentials in all contexts with `--all-contexts`. + +You can delete a credential by running the following command: + +```bash +gptscript credential delete --credential-context +``` diff --git a/docs/docs/07-gpt-file-reference.md b/docs/docs/07-gpt-file-reference.md index fd57bfe5..ba40bcf8 100644 --- a/docs/docs/07-gpt-file-reference.md +++ b/docs/docs/07-gpt-file-reference.md @@ -43,17 +43,18 @@ Tool instructions go here. Tool parameters are key-value pairs defined at the beginning of a tool block, before any instructional text. They are specified in the format `key: value`. The parser recognizes the following keys (case-insensitive and spaces are ignored): -| Key | Description | -|------------------|-----------------------------------------------------------------------------------------------------------------------------------------| -| `Name` | The name of the tool. | -| `Model Name` | The OpenAI model to use, by default it uses "gpt-4-turbo-preview" | -| `Description` | The description of the tool. It is important that this properly describes the tool's purpose as the description is used by the LLM. | -| `Internal Prompt`| Setting this to `false` will disable the built-in system prompt for this tool. | -| `Tools` | A comma-separated list of tools that are available to be called by this tool. | -| `Args` | Arguments for the tool. Each argument is defined in the format `arg-name: description`. | -| `Max Tokens` | Set to a number if you wish to limit the maximum number of tokens that can be generated by the LLM. | -| `JSON Response` | Setting to `true` will cause the LLM to respond in a JSON format. If you set true you must also include instructions in the tool. | -| `Temperature` | A floating-point number representing the temperature parameter. By default, the temperature is 0. Set to a higher number for more creativity. | +| Key | Description | +|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------------| +| `Name` | The name of the tool. | +| `Model Name` | The OpenAI model to use, by default it uses "gpt-4-turbo-preview" | +| `Description` | The description of the tool. It is important that this properly describes the tool's purpose as the description is used by the LLM. | +| `Internal Prompt` | Setting this to `false` will disable the built-in system prompt for this tool. | +| `Tools` | A comma-separated list of tools that are available to be called by this tool. | +| `Credentials` | A comma-separated list of credential tools to run before the main tool. | +| `Args` | Arguments for the tool. Each argument is defined in the format `arg-name: description`. | +| `Max Tokens` | Set to a number if you wish to limit the maximum number of tokens that can be generated by the LLM. | +| `JSON Response` | Setting to `true` will cause the LLM to respond in a JSON format. If you set true you must also include instructions in the tool. | +| `Temperature` | A floating-point number representing the temperature parameter. By default, the temperature is 0. Set to a higher number for more creativity. | diff --git a/go.mod b/go.mod index 6e4f3df6..afb085dd 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,8 @@ require ( github.com/acorn-io/broadcaster v0.0.0-20240105011354-bfadd4a7b45d github.com/acorn-io/cmd v0.0.0-20240404013709-34f690bde37b github.com/adrg/xdg v0.4.0 + github.com/docker/cli v26.0.0+incompatible + github.com/docker/docker-credential-helpers v0.8.1 github.com/fatih/color v1.16.0 github.com/getkin/kin-openapi v0.123.0 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 @@ -63,6 +65,7 @@ require ( github.com/olekukonko/tablewriter v0.0.6-0.20230925090304-df64c4bbad77 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pierrec/lz4/v4 v4.1.15 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect @@ -77,5 +80,6 @@ require ( golang.org/x/sys v0.16.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.17.0 // indirect + gotest.tools/v3 v3.5.1 // indirect mvdan.cc/gofumpt v0.6.0 // indirect ) diff --git a/go.sum b/go.sum index 95a7dfc5..5bc013f3 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,10 @@ github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/docker/cli v26.0.0+incompatible h1:90BKrx1a1HKYpSnnBFR6AgDq/FqkHxwlUyzJVPxD30I= +github.com/docker/cli v26.0.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/docker-credential-helpers v0.8.1 h1:j/eKUktUltBtMzKqmfLB0PAgqYyMHOp5vfsD1807oKo= +github.com/docker/docker-credential-helpers v0.8.1/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M= github.com/dsnet/compress v0.0.1 h1:PlZu0n3Tuv04TzpfPbrnI0HW/YwodEXDS+oPKahKF0Q= github.com/dsnet/compress v0.0.1/go.mod h1:Aw8dCMJ7RioblQeTqt88akK31OvO8Dhf5JflhBbQEHo= github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= @@ -191,6 +195,8 @@ github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0V github.com/pierrec/lz4/v4 v4.1.15 h1:MO0/ucJhngq7299dKLwIMtgTfbkoSPF6AoMYDd8Q4q0= github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -432,6 +438,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index e696e4c1..1ca10067 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -17,9 +17,11 @@ import ( "strings" "time" + "github.com/AlecAivazis/survey/v2" "github.com/BurntSushi/locker" "github.com/google/shlex" "github.com/gptscript-ai/gptscript/pkg/confirm" + "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/jaytaylor/html2text" ) @@ -149,6 +151,17 @@ var tools = map[string]types.Tool{ }, BuiltinFunc: SysStat, }, + "sys.prompt": { + Parameters: types.Parameters{ + Description: "Prompts the user for input", + Arguments: types.ObjectSchema( + "message", "The message to display to the user", + "fields", "A comma-separated list of fields to prompt for", + "sensitive", "(true or false) Whether the input should be hidden", + ), + }, + BuiltinFunc: SysPrompt, + }, } func SysProgram() *types.Program { @@ -633,3 +646,47 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err return params.Location, nil } + +func SysPrompt(ctx context.Context, _ []string, input string) (_ string, err error) { + monitor := ctx.Value(runner.MonitorKey{}) + if monitor == nil { + return "", errors.New("no monitor in context") + } + + unpause := monitor.(runner.Monitor).Pause() + defer unpause() + + var params struct { + Message string `json:"message,omitempty"` + Fields string `json:"fields,omitempty"` + Sensitive string `json:"sensitive,omitempty"` + } + if err := json.Unmarshal([]byte(input), ¶ms); err != nil { + return "", err + } + + if params.Message != "" { + _, _ = fmt.Fprintln(os.Stderr, params.Message) + } + + results := map[string]string{} + for _, f := range strings.Split(params.Fields, ",") { + var value string + if params.Sensitive == "true" { + err = survey.AskOne(&survey.Password{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) + } else { + err = survey.AskOne(&survey.Input{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) + } + if err != nil { + return "", err + } + results[f] = value + } + + resultsStr, err := json.Marshal(results) + if err != nil { + return "", err + } + + return string(resultsStr), nil +} diff --git a/pkg/cli/credential.go b/pkg/cli/credential.go new file mode 100644 index 00000000..9c999032 --- /dev/null +++ b/pkg/cli/credential.go @@ -0,0 +1,78 @@ +package cli + +import ( + "fmt" + "os" + "sort" + "text/tabwriter" + + cmd2 "github.com/acorn-io/cmd" + "github.com/gptscript-ai/gptscript/pkg/config" + "github.com/gptscript-ai/gptscript/pkg/credentials" + "github.com/gptscript-ai/gptscript/pkg/version" + "github.com/spf13/cobra" +) + +type Credential struct { + root *GPTScript + AllContexts bool `usage:"List credentials for all contexts" local:"true"` +} + +func (c *Credential) Customize(cmd *cobra.Command) { + cmd.Use = "credential" + cmd.Version = version.Get().String() + cmd.Aliases = []string{"cred", "creds", "credentials"} + cmd.Short = "List stored credentials" + cmd.Args = cobra.NoArgs + cmd.AddCommand(cmd2.Command(&Delete{root: c.root})) +} + +func (c *Credential) Run(_ *cobra.Command, _ []string) error { + cfg, err := config.ReadCLIConfig(c.root.ConfigFile) + if err != nil { + return fmt.Errorf("failed to read CLI config: %w", err) + } + + ctx := c.root.CredentialContext + if c.AllContexts { + ctx = "*" + } + + store, err := credentials.NewStore(cfg, ctx) + if err != nil { + return fmt.Errorf("failed to get credentials store: %w", err) + } + + creds, err := store.List() + if err != nil { + return fmt.Errorf("failed to list credentials: %w", err) + } + + if c.AllContexts { + // Sort credentials by context + sort.Slice(creds, func(i, j int) bool { + if creds[i].Context == creds[j].Context { + return creds[i].ToolName < creds[j].ToolName + } + return creds[i].Context < creds[j].Context + }) + + w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0) + defer w.Flush() + _, _ = w.Write([]byte("CONTEXT\tTOOL\n")) + for _, cred := range creds { + _, _ = fmt.Fprintf(w, "%s\t%s\n", cred.Context, cred.ToolName) + } + } else { + // Sort credentials by tool name + sort.Slice(creds, func(i, j int) bool { + return creds[i].ToolName < creds[j].ToolName + }) + + for _, cred := range creds { + fmt.Println(cred.ToolName) + } + } + + return nil +} diff --git a/pkg/cli/credential_delete.go b/pkg/cli/credential_delete.go new file mode 100644 index 00000000..5df56509 --- /dev/null +++ b/pkg/cli/credential_delete.go @@ -0,0 +1,37 @@ +package cli + +import ( + "fmt" + + "github.com/gptscript-ai/gptscript/pkg/config" + "github.com/gptscript-ai/gptscript/pkg/credentials" + "github.com/spf13/cobra" +) + +type Delete struct { + root *GPTScript +} + +func (c *Delete) Customize(cmd *cobra.Command) { + cmd.Use = "delete " + cmd.SilenceUsage = true + cmd.Short = "Delete a stored credential" + cmd.Args = cobra.ExactArgs(1) +} + +func (c *Delete) Run(_ *cobra.Command, args []string) error { + cfg, err := config.ReadCLIConfig(c.root.ConfigFile) + if err != nil { + return fmt.Errorf("failed to read CLI config: %w", err) + } + + store, err := credentials.NewStore(cfg, c.root.CredentialContext) + if err != nil { + return fmt.Errorf("failed to get credentials store: %w", err) + } + + if err = store.Remove(args[0]); err != nil { + return fmt.Errorf("failed to remove credential: %w", err) + } + return nil +} diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index dc91adf8..637ecddf 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -38,28 +38,29 @@ type GPTScript struct { CacheOptions OpenAIOptions DisplayOptions - Color *bool `usage:"Use color in output (default true)" default:"true"` - Confirm bool `usage:"Prompt before running potentially dangerous commands"` - Debug bool `usage:"Enable debug logging"` - Quiet *bool `usage:"No output logging (set --quiet=false to force on even when there is no TTY)" short:"q"` - Output string `usage:"Save output to a file, or - for stdout" short:"o"` - Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f"` - SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"` - Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"` - ListModels bool `usage:"List the models available and exit" local:"true"` - ListTools bool `usage:"List built-in tools and exit" local:"true"` - Server bool `usage:"Start server" local:"true"` - ListenAddress string `usage:"Server listen address" default:"127.0.0.1:9090" local:"true"` - Chdir string `usage:"Change current working directory" short:"C"` - Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"` - Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"` + Color *bool `usage:"Use color in output (default true)" default:"true"` + Confirm bool `usage:"Prompt before running potentially dangerous commands"` + Debug bool `usage:"Enable debug logging"` + Quiet *bool `usage:"No output logging (set --quiet=false to force on even when there is no TTY)" short:"q"` + Output string `usage:"Save output to a file, or - for stdout" short:"o"` + Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f"` + SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"` + Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"` + ListModels bool `usage:"List the models available and exit" local:"true"` + ListTools bool `usage:"List built-in tools and exit" local:"true"` + Server bool `usage:"Start server" local:"true"` + ListenAddress string `usage:"Server listen address" default:"127.0.0.1:9090" local:"true"` + Chdir string `usage:"Change current working directory" short:"C"` + Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"` + Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"` + CredentialContext string `usage:"Context name in which to store credentials" default:"default"` } func New() *cobra.Command { root := &GPTScript{} return cmd.Command(root, &Eval{ gptscript: root, - }) + }, &Credential{root: root}) } func (r *GPTScript) NewRunContext(cmd *cobra.Command) context.Context { @@ -72,11 +73,12 @@ func (r *GPTScript) NewRunContext(cmd *cobra.Command) context.Context { func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) { opts := gptscript.Options{ - Cache: cache.Options(r.CacheOptions), - OpenAI: openai.Options(r.OpenAIOptions), - Monitor: monitor.Options(r.DisplayOptions), - Quiet: r.Quiet, - Env: os.Environ(), + Cache: cache.Options(r.CacheOptions), + OpenAI: openai.Options(r.OpenAIOptions), + Monitor: monitor.Options(r.DisplayOptions), + Quiet: r.Quiet, + Env: os.Environ(), + CredentialContext: r.CredentialContext, } if r.Ports != "" { diff --git a/pkg/config/cliconfig.go b/pkg/config/cliconfig.go new file mode 100644 index 00000000..54ab0c3e --- /dev/null +++ b/pkg/config/cliconfig.go @@ -0,0 +1,161 @@ +package config + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "os" + "os/exec" + "runtime" + "strings" + "sync" + + "github.com/adrg/xdg" + "github.com/docker/cli/cli/config/types" +) + +const GPTScriptHelperPrefix = "gptscript-credential-" + +type AuthConfig types.AuthConfig + +func (a AuthConfig) MarshalJSON() ([]byte, error) { + cp := a + if cp.Username != "" || cp.Password != "" { + cp.Auth = base64.StdEncoding.EncodeToString([]byte(cp.Username + ":" + cp.Password)) + cp.Username = "" + cp.Password = "" + } + cp.ServerAddress = "" + return json.Marshal((types.AuthConfig)(cp)) +} + +func (a *AuthConfig) UnmarshalJSON(data []byte) error { + if err := json.Unmarshal(data, (*types.AuthConfig)(a)); err != nil { + return err + } + if a.Auth != "" { + data, err := base64.StdEncoding.DecodeString(a.Auth) + if err != nil { + return err + } + a.Username, a.Password, _ = strings.Cut(string(data), ":") + a.Auth = "" + } + return nil +} + +type CLIConfig struct { + Auths map[string]AuthConfig `json:"auths,omitempty"` + CredentialsStore string `json:"credsStore,omitempty"` + GPTScriptConfigFile string `json:"gptscriptConfig,omitempty"` + + auths map[string]types.AuthConfig + authsLock *sync.Mutex +} + +func (c *CLIConfig) Sanitize() *CLIConfig { + if c == nil { + return nil + } + cp := *c + cp.Auths = map[string]AuthConfig{} + for k := range c.Auths { + cp.Auths[k] = AuthConfig{ + Auth: "", + } + } + return &cp +} + +func (c *CLIConfig) Save() error { + if c.authsLock != nil { + c.authsLock.Lock() + defer c.authsLock.Unlock() + } + + if c.auths != nil { + c.Auths = map[string]AuthConfig{} + for k, v := range c.auths { + c.Auths[k] = (AuthConfig)(v) + } + c.auths = nil + } + data, err := json.Marshal(c) + if err != nil { + return err + } + return os.WriteFile(c.GPTScriptConfigFile, data, 0655) +} + +func (c *CLIConfig) GetAuthConfigs() map[string]types.AuthConfig { + if c.authsLock != nil { + c.authsLock.Lock() + defer c.authsLock.Unlock() + } + + if c.auths == nil { + c.auths = map[string]types.AuthConfig{} + for k, v := range c.Auths { + authConfig := (types.AuthConfig)(v) + c.auths[k] = authConfig + } + } + return c.auths +} + +func (c *CLIConfig) GetFilename() string { + return c.GPTScriptConfigFile +} + +func ReadCLIConfig(gptscriptConfigFile string) (*CLIConfig, error) { + if gptscriptConfigFile == "" { + // If gptscriptConfigFile isn't provided, check the environment variable + if gptscriptConfigFile = os.Getenv("GPTSCRIPT_CONFIG_FILE"); gptscriptConfigFile == "" { + // If an environment variable isn't provided, check the default location + var err error + if gptscriptConfigFile, err = xdg.ConfigFile("gptscript/config.json"); err != nil { + return nil, fmt.Errorf("failed to read user config from standard location: %w", err) + } + } + } + + data, err := readFile(gptscriptConfigFile) + if err != nil { + return nil, err + } + result := &CLIConfig{ + authsLock: &sync.Mutex{}, + GPTScriptConfigFile: gptscriptConfigFile, + } + if err := json.Unmarshal(data, result); err != nil { + return nil, err + } + + if result.CredentialsStore == "" { + result.setDefaultCredentialsStore() + } + + return result, nil +} + +func (c *CLIConfig) setDefaultCredentialsStore() { + if runtime.GOOS == "darwin" { + // Check for the existence of the helper program + fullPath, err := exec.LookPath(GPTScriptHelperPrefix + "osxkeychain") + if err == nil && fullPath != "" { + c.CredentialsStore = "osxkeychain" + } + } + c.CredentialsStore = "file" +} + +func readFile(path string) ([]byte, error) { + data, err := os.ReadFile(path) + if os.IsNotExist(err) { + return []byte("{}"), nil + } else if err != nil { + return nil, fmt.Errorf("failed to read user config %s: %w", path, err) + } + + return data, nil +} diff --git a/pkg/credentials/credential.go b/pkg/credentials/credential.go new file mode 100644 index 00000000..6344b8c5 --- /dev/null +++ b/pkg/credentials/credential.go @@ -0,0 +1,58 @@ +package credentials + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/docker/cli/cli/config/types" +) + +type Credential struct { + Context string `json:"context"` + ToolName string `json:"toolName"` + Env map[string]string `json:"env"` +} + +func (c Credential) toDockerAuthConfig() (types.AuthConfig, error) { + env, err := json.Marshal(c.Env) + if err != nil { + return types.AuthConfig{}, err + } + + return types.AuthConfig{ + Username: "gptscript", // Username is required, but not used + Password: string(env), + ServerAddress: toolNameWithCtx(c.ToolName, c.Context), + }, nil +} + +func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error) { + var env map[string]string + if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil { + return Credential{}, err + } + + tool, ctx, err := toolNameAndCtxFromAddress(strings.TrimPrefix(authCfg.ServerAddress, "https://")) + if err != nil { + return Credential{}, err + } + + return Credential{ + Context: ctx, + ToolName: tool, + Env: env, + }, nil +} + +func toolNameWithCtx(toolName, credCtx string) string { + return toolName + "///" + credCtx +} + +func toolNameAndCtxFromAddress(address string) (string, string, error) { + parts := strings.Split(address, "///") + if len(parts) != 2 { + return "", "", fmt.Errorf("error parsing tool name and context %q. Tool names cannot contain '///'", address) + } + return parts[0], parts[1], nil +} diff --git a/pkg/credentials/helper.go b/pkg/credentials/helper.go new file mode 100644 index 00000000..49dd1900 --- /dev/null +++ b/pkg/credentials/helper.go @@ -0,0 +1,75 @@ +package credentials + +import ( + "errors" + + "github.com/docker/cli/cli/config/credentials" + "github.com/docker/cli/cli/config/types" + "github.com/docker/docker-credential-helpers/client" + credentials2 "github.com/docker/docker-credential-helpers/credentials" + "github.com/gptscript-ai/gptscript/pkg/config" +) + +func NewHelper(c *config.CLIConfig, helper string) (credentials.Store, error) { + return &HelperStore{ + file: credentials.NewFileStore(c), + program: client.NewShellProgramFunc(helper), + }, nil +} + +type HelperStore struct { + file credentials.Store + program client.ProgramFunc +} + +func (h *HelperStore) Erase(serverAddress string) error { + var errs []error + if err := client.Erase(h.program, serverAddress); err != nil { + errs = append(errs, err) + } + if err := h.file.Erase(serverAddress); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) +} + +func (h *HelperStore) Get(serverAddress string) (types.AuthConfig, error) { + creds, err := client.Get(h.program, serverAddress) + if credentials2.IsErrCredentialsNotFound(err) { + return h.file.Get(serverAddress) + } else if err != nil { + return types.AuthConfig{}, err + } + return types.AuthConfig{ + Username: creds.Username, + Password: creds.Secret, + ServerAddress: serverAddress, + }, nil +} + +func (h *HelperStore) GetAll() (map[string]types.AuthConfig, error) { + result := map[string]types.AuthConfig{} + + serverAddresses, err := client.List(h.program) + if err != nil { + return nil, err + } + + for serverAddress := range serverAddresses { + ac, err := h.Get(serverAddress) + if err != nil { + return nil, err + } + result[serverAddress] = ac + } + + return result, nil +} + +func (h *HelperStore) Store(authConfig types.AuthConfig) error { + return client.Store(h.program, &credentials2.Credentials{ + ServerURL: authConfig.ServerAddress, + Username: authConfig.Username, + Secret: authConfig.Password, + }) +} diff --git a/pkg/credentials/store.go b/pkg/credentials/store.go new file mode 100644 index 00000000..86b6fe50 --- /dev/null +++ b/pkg/credentials/store.go @@ -0,0 +1,120 @@ +package credentials + +import ( + "fmt" + "regexp" + + "github.com/docker/cli/cli/config/credentials" + "github.com/gptscript-ai/gptscript/pkg/config" +) + +type Store struct { + credCtx string + cfg *config.CLIConfig +} + +func NewStore(cfg *config.CLIConfig, credCtx string) (*Store, error) { + if err := validateCredentialCtx(credCtx); err != nil { + return nil, err + } + return &Store{ + credCtx: credCtx, + cfg: cfg, + }, nil +} + +func (s *Store) Get(toolName string) (*Credential, bool, error) { + store, err := s.getStore() + if err != nil { + return nil, false, err + } + auth, err := store.Get(toolNameWithCtx(toolName, s.credCtx)) + if err != nil { + return nil, false, err + } else if auth.Password == "" { + return nil, false, nil + } + + cred, err := credentialFromDockerAuthConfig(auth) + if err != nil { + return nil, false, err + } + return &cred, true, nil +} + +func (s *Store) Add(cred Credential) error { + cred.Context = s.credCtx + store, err := s.getStore() + if err != nil { + return err + } + auth, err := cred.toDockerAuthConfig() + if err != nil { + return err + } + return store.Store(auth) +} + +func (s *Store) Remove(toolName string) error { + store, err := s.getStore() + if err != nil { + return err + } + return store.Erase(toolNameWithCtx(toolName, s.credCtx)) +} + +func (s *Store) List() ([]Credential, error) { + store, err := s.getStore() + if err != nil { + return nil, err + } + list, err := store.GetAll() + if err != nil { + return nil, err + } + + var creds []Credential + for serverAddress, authCfg := range list { + if authCfg.ServerAddress == "" { + authCfg.ServerAddress = serverAddress // Not sure why we have to do this, but we do. + } + + c, err := credentialFromDockerAuthConfig(authCfg) + if err != nil { + return nil, err + } + if s.credCtx == "*" || c.Context == s.credCtx { + creds = append(creds, c) + } + } + + return creds, nil +} + +func (s *Store) getStore() (credentials.Store, error) { + return s.getStoreByHelper(config.GPTScriptHelperPrefix + s.cfg.CredentialsStore) +} + +func (s *Store) getStoreByHelper(helper string) (credentials.Store, error) { + if helper == "" || helper == config.GPTScriptHelperPrefix+"file" { + return credentials.NewFileStore(s.cfg), nil + } + return NewHelper(s.cfg, helper) +} + +func validateCredentialCtx(ctx string) error { + if ctx == "" { + return fmt.Errorf("credential context cannot be empty") + } + + if ctx == "*" { // this represents "all contexts" and is allowed + return nil + } + + // check alphanumeric + r := regexp.MustCompile("^[a-zA-Z0-9]+$") + if !r.MatchString(ctx) { + return fmt.Errorf("credential context must be alphanumeric") + } + return nil +} diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index 7e318eb9..da7b3bca 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -60,6 +60,7 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string) output := &bytes.Buffer{} all := &bytes.Buffer{} + cmd.Stdin = os.Stdin cmd.Stderr = io.MultiWriter(all, os.Stderr) cmd.Stdout = io.MultiWriter(all, output) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index fda1f06a..09c037a1 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -54,12 +54,15 @@ type CallResult struct { } type Context struct { - ID string - Ctx context.Context - Parent *Context - Program *types.Program - Tool types.Tool - InputContext []InputContext + ID string + Ctx context.Context + Parent *Context + Program *types.Program + Tool types.Tool + InputContext []InputContext + CredentialContext string + // IsCredential indicates that the current call is for a credential tool + IsCredential bool } type InputContext struct { @@ -103,17 +106,23 @@ func NewContext(ctx context.Context, prg *types.Program) Context { return callCtx } -func (c *Context) SubCall(ctx context.Context, toolID, callID string) (Context, error) { +func (c *Context) SubCall(ctx context.Context, toolID, callID string, isCredentialTool bool) (Context, error) { tool, ok := c.Program.ToolSet[toolID] if !ok { return Context{}, fmt.Errorf("failed to file tool for id [%s]", toolID) } + + if callID == "" { + callID = fmt.Sprint(atomic.AddInt32(&execID, 1)) + } + return Context{ - ID: callID, - Ctx: ctx, - Parent: c, - Program: c.Program, - Tool: tool, + ID: callID, + Ctx: ctx, + Parent: c, + Program: c.Program, + Tool: tool, + IsCredential: isCredentialTool, // disallow calls to the LLM if this is a credential tool }, nil } @@ -148,6 +157,10 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) { }, nil } + if ctx.IsCredential { + return nil, fmt.Errorf("credential tools cannot make calls to the LLM") + } + completion := types.CompletionRequest{ Model: tool.Parameters.ModelName, MaxTokens: tool.Parameters.MaxTokens, diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index 41a0a593..08c85b4d 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -22,12 +22,13 @@ type GPTScript struct { } type Options struct { - Cache cache.Options - OpenAI openai.Options - Monitor monitor.Options - Runner runner.Options - Quiet *bool `usage:"No output logging (set --quiet=false to force on even when there is no TTY)" short:"q"` - Env []string `usage:"-"` + Cache cache.Options + OpenAI openai.Options + Monitor monitor.Options + Runner runner.Options + CredentialContext string `usage:"Context name in which to store credentials" default:"default"` + Quiet *bool `usage:"No output logging (set --quiet=false to force on even when there is no TTY)" short:"q"` + Env []string `usage:"-"` } func complete(opts *Options) (result *Options) { @@ -76,7 +77,7 @@ func New(opts *Options) (*GPTScript, error) { opts.Runner.RuntimeManager = runtimes.Default(cacheClient.CacheDir()) } - runner, err := runner.New(registry, opts.Runner) + runner, err := runner.New(registry, opts.CredentialContext, opts.Runner) if err != nil { return nil, err } diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index eb06a88e..3ed7725b 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -199,7 +199,8 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool for _, targetToolName := range slices.Concat(tool.Parameters.Tools, tool.Parameters.Export, tool.Parameters.ExportContext, - tool.Parameters.Context) { + tool.Parameters.Context, + tool.Parameters.Credentials) { localTool, ok := localTools[targetToolName] if ok { var linkedTool types.Tool diff --git a/pkg/monitor/display.go b/pkg/monitor/display.go index 5aaffbb0..55cd8bcb 100644 --- a/pkg/monitor/display.go +++ b/pkg/monitor/display.go @@ -12,6 +12,7 @@ import ( "sync/atomic" "time" + "github.com/fatih/color" "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/types" ) @@ -215,6 +216,16 @@ func (d *display) Event(event runner.Event) { call: ¤tCall, prg: d.dump.Program, calls: d.dump.Calls, + credential: event.CallContext.IsCredential, + } + + if event.CallContext.Parent != nil { + for name, id := range event.CallContext.Parent.Tool.ToolMapping { + if id == event.CallContext.Tool.ID { + callName.userSpecifiedToolName = name + break + } + } } switch event.Type { @@ -266,6 +277,9 @@ func (d *display) Event(event runner.Event) { } func (d *display) Stop(output string, err error) { + d.callLock.Lock() + defer d.callLock.Unlock() + log.Fields("runID", d.dump.ID, "output", output, "err", err).Debugf("Run stopped") d.dump.Output = output d.dump.Err = err @@ -306,6 +320,13 @@ func (d *display) Dump(out io.Writer) error { return enc.Encode(d.dump) } +func (d *display) Pause() func() { + d.callLock.Lock() + return func() { + d.callLock.Unlock() + } +} + func toJSON(obj any) jsonDump { return jsonDump{obj: obj} } @@ -327,10 +348,12 @@ func (j jsonDump) String() string { } type callName struct { - prettyIDMap map[string]string - call *call - prg *types.Program - calls []call + prettyIDMap map[string]string + call *call + prg *types.Program + calls []call + credential bool + userSpecifiedToolName string } func (c callName) String() string { @@ -339,6 +362,14 @@ func (c callName) String() string { currentCall = c.call ) + if c.credential { + // We want to print the credential tool in the same format that the user referenced it, if possible. + if c.userSpecifiedToolName != "" { + return "credential: " + color.YellowString(c.userSpecifiedToolName) + } + return "credential: " + color.YellowString(currentCall.ToolID) + } + for { tool := c.prg.ToolSet[currentCall.ToolID] name := tool.Parameters.Name diff --git a/pkg/mvl/log.go b/pkg/mvl/log.go index 7bee5a93..dd40b27e 100644 --- a/pkg/mvl/log.go +++ b/pkg/mvl/log.go @@ -114,6 +114,10 @@ func (l *Logger) Tracef(msg string, args ...any) { l.log.WithFields(l.fields).Tracef(msg, args...) } +func (l *Logger) Warnf(msg string, args ...any) { + l.log.WithFields(l.fields).Warnf(msg, args...) +} + func (l *Logger) IsDebug() bool { return l.log.IsLevelEnabled(logrus.DebugLevel) } diff --git a/pkg/openai/client.go b/pkg/openai/client.go index ee86bd23..8d47a98b 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -48,6 +48,7 @@ type Options struct { APIType openai.APIType `usage:"OpenAI API Type (valid: OPEN_AI, AZURE, AZURE_AD)" name:"openai-api-type" env:"OPENAI_API_TYPE"` OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"` DefaultModel string `usage:"Default LLM model to use" default:"gpt-4-turbo-preview"` + ConfigFile string `usage:"Path to GPTScript config file" name:"config"` SetSeed bool `usage:"-"` CacheKey string `usage:"-"` Cache *cache.Client diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index 9557b340..afe5ce78 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -122,6 +122,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { if err != nil { return false, err } + case "credentials", "creds", "credential", "cred": + tool.Parameters.Credentials = append(tool.Parameters.Credentials, csv(strings.ToLower(value))...) default: return false, nil } diff --git a/pkg/runner/monitor.go b/pkg/runner/monitor.go index 99aafb1f..87543eda 100644 --- a/pkg/runner/monitor.go +++ b/pkg/runner/monitor.go @@ -20,3 +20,7 @@ func (n noopMonitor) Event(Event) { } func (n noopMonitor) Stop(string, error) {} + +func (n noopMonitor) Pause() func() { + return func() {} +} diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 9e6edbd9..82f1f617 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "github.com/gptscript-ai/gptscript/pkg/config" + "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/types" "golang.org/x/sync/errgroup" @@ -20,9 +22,12 @@ type MonitorFactory interface { type Monitor interface { Event(event Event) + Pause() func() Stop(output string, err error) } +type MonitorKey struct{} + type Options struct { MonitorFactory MonitorFactory `usage:"-"` RuntimeManager engine.RuntimeManager `usage:"-"` @@ -54,15 +59,17 @@ type Runner struct { factory MonitorFactory runtimeManager engine.RuntimeManager ports engine.Ports + credCtx string } -func New(client engine.Model, opts ...Options) (*Runner, error) { +func New(client engine.Model, credCtx string, opts ...Options) (*Runner, error) { opt := complete(opts...) runner := &Runner{ c: client, factory: opt.MonitorFactory, runtimeManager: opt.RuntimeManager, + credCtx: credCtx, } if opt.StartPort != 0 { @@ -123,7 +130,7 @@ func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []strin } for _, toolID := range toolIDs { - content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolID, "", "") + content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolID, "", "", false) if err != nil { return nil, err } @@ -139,6 +146,14 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp progress, progressClose := streamProgress(&callCtx, monitor) defer progressClose() + if len(callCtx.Tool.Credentials) > 0 { + var err error + env, err = r.handleCredentials(callCtx, monitor, env) + if err != nil { + return "", err + } + } + var err error callCtx.InputContext, err = r.getContext(callCtx, monitor, env) if err != nil { @@ -160,6 +175,11 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp Content: input, }) + // The sys.prompt tool is a special case where we need to pass the monitor to the builtin function. + if callCtx.Tool.BuiltinFunc != nil && callCtx.Tool.ID == "sys.prompt" { + callCtx.Ctx = context.WithValue(callCtx.Ctx, MonitorKey{}, monitor) + } + result, err := e.Start(callCtx, input) if err != nil { return "", err @@ -246,8 +266,8 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp } } -func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, input, callID string) (string, error) { - callCtx, err := parentContext.SubCall(ctx, toolID, callID) +func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, input, callID string, isCredentialTool bool) (string, error) { + callCtx, err := parentContext.SubCall(ctx, toolID, callID, isCredentialTool) if err != nil { return "", err } @@ -263,7 +283,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, eg, subCtx := errgroup.WithContext(callCtx.Ctx) for id, call := range lastReturn.Calls { eg.Go(func() error { - result, err := r.subCall(subCtx, callCtx, monitor, env, call.ToolID, call.Input, id) + result, err := r.subCall(subCtx, callCtx, monitor, env, call.ToolID, call.Input, id, false) if err != nil { return err } @@ -307,3 +327,67 @@ func recordStateMessage(state *engine.State) error { filename := filepath.Join(tmpdir, fmt.Sprintf("gptscript-state-%v-%v", hostname, time.Now().UnixMilli())) return os.WriteFile(filename, data, 0444) } + +func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string) ([]string, error) { + c, err := config.ReadCLIConfig("") + if err != nil { + return nil, fmt.Errorf("failed to read CLI config: %w", err) + } + + store, err := credentials.NewStore(c, r.credCtx) + if err != nil { + return nil, fmt.Errorf("failed to create credentials store: %w", err) + } + + for _, credToolName := range callCtx.Tool.Credentials { + cred, exists, err := store.Get(credToolName) + if err != nil { + return nil, fmt.Errorf("failed to get credentials for tool %s: %w", credToolName, err) + } + + // If the credential doesn't already exist in the store, run the credential tool in order to get the value, + // and save it in the store. + if !exists { + credToolID, ok := callCtx.Tool.ToolMapping[credToolName] + if !ok { + return nil, fmt.Errorf("failed to find ID for tool %s", credToolName) + } + + subCtx, err := callCtx.SubCall(callCtx.Ctx, credToolID, "", true) // leaving callID as "" will cause it to be set by the engine + if err != nil { + return nil, fmt.Errorf("failed to create subcall context for tool %s: %w", credToolName, err) + } + res, err := r.call(subCtx, monitor, env, "") + if err != nil { + return nil, fmt.Errorf("failed to run credential tool %s: %w", credToolName, err) + } + + var envMap struct { + Env map[string]string `json:"env"` + } + if err := json.Unmarshal([]byte(res), &envMap); err != nil { + return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", credToolName, err) + } + + cred = &credentials.Credential{ + ToolName: credToolName, + Env: envMap.Env, + } + + // Only store the credential if the tool is on GitHub. + if callCtx.Program.ToolSet[credToolID].Source.Repo != nil { + if err := store.Add(*cred); err != nil { + return nil, fmt.Errorf("failed to add credential for tool %s: %w", credToolName, err) + } + } else { + log.Warnf("Not saving credential for local tool %s - credentials will only be saved for tools from GitHub.", credToolName) + } + } + + for k, v := range cred.Env { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + } + + return env, nil +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 93cfc339..9e86adbf 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -11,6 +11,7 @@ import ( "os" "path/filepath" "strings" + "sync" "sync/atomic" "time" @@ -319,14 +320,17 @@ func (s SessionFactory) Start(ctx context.Context, prg *types.Program, env []str } type Session struct { - id string - prj *types.Program - env []string - input string - events *broadcaster.Broadcaster[Event] + id string + prj *types.Program + env []string + input string + events *broadcaster.Broadcaster[Event] + runLock sync.Mutex } func (s *Session) Event(event runner.Event) { + s.runLock.Lock() + defer s.runLock.Unlock() s.events.C <- Event{ Event: event, RunID: s.id, @@ -347,5 +351,15 @@ func (s *Session) Stop(output string, err error) { if err != nil { e.Err = err.Error() } + + s.runLock.Lock() + defer s.runLock.Unlock() s.events.C <- e } + +func (s *Session) Pause() func() { + s.runLock.Lock() + return func() { + s.runLock.Unlock() + } +} diff --git a/pkg/tests/tester/runner.go b/pkg/tests/tester/runner.go index 06f170d8..45905664 100644 --- a/pkg/tests/tester/runner.go +++ b/pkg/tests/tester/runner.go @@ -119,7 +119,7 @@ func NewRunner(t *testing.T) *Runner { t: t, } - run, err := runner.New(c) + run, err := runner.New(c, "default") require.NoError(t, err) return &Runner{ diff --git a/pkg/types/tool.go b/pkg/types/tool.go index f7469be9..d1553827 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -121,6 +121,7 @@ type Parameters struct { Context []string `json:"context,omitempty"` ExportContext []string `json:"exportContext,omitempty"` Export []string `json:"export,omitempty"` + Credentials []string `json:"credentials,omitempty"` Blocking bool `json:"-"` } @@ -203,6 +204,9 @@ func (t Tool) String() string { _, _ = fmt.Fprintln(buf) _, _ = fmt.Fprintln(buf, t.Instructions) } + if len(t.Parameters.Credentials) > 0 { + _, _ = fmt.Fprintf(buf, "Credentials: %s\n", strings.Join(t.Parameters.Credentials, ", ")) + } return buf.String() }