diff --git a/go.mod b/go.mod index 04b8f162..f3b7eb22 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/uuid v1.6.0 github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 - github.com/gptscript-ai/tui v0.0.0-20240614023948-004dc1597dd7 + github.com/gptscript-ai/tui v0.0.0-20240614062633-985091711b0a github.com/hexops/autogold/v2 v2.2.1 github.com/hexops/valast v1.4.4 github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 diff --git a/go.sum b/go.sum index fa9bd54b..bbaa5899 100644 --- a/go.sum +++ b/go.sum @@ -173,8 +173,8 @@ github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf037 github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo= github.com/gptscript-ai/go-gptscript v0.0.0-20240613214812-8111c2b02d71 h1:WehkkausLuXI91ePpIVrzZ6eBmfFIU/HfNsSA1CHiwo= github.com/gptscript-ai/go-gptscript v0.0.0-20240613214812-8111c2b02d71/go.mod h1:Dh6vYRAiVcyC3ElZIGzTvNF1FxtYwA07BHfSiFKQY7s= -github.com/gptscript-ai/tui v0.0.0-20240614023948-004dc1597dd7 h1:t+IuS+4JLUnwLHv+bgJQ2jHVT9ii0SLR3D7eNTZ47fg= -github.com/gptscript-ai/tui v0.0.0-20240614023948-004dc1597dd7/go.mod h1:ZlyM+BRiD6mV04w+Xw2mXP1VKGEUbn8BvwrosWlplUo= +github.com/gptscript-ai/tui v0.0.0-20240614062633-985091711b0a h1:LFsEDiIAx0Rq0V6aOMlRjXMMIfkA3uEhqqqjoggLlDQ= +github.com/gptscript-ai/tui v0.0.0-20240614062633-985091711b0a/go.mod h1:ZlyM+BRiD6mV04w+Xw2mXP1VKGEUbn8BvwrosWlplUo= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 2db70c46..e7f9e8ab 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -132,10 +132,11 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) { CredentialOverride: r.CredentialOverride, Sequential: r.ForceSequential, }, - Quiet: r.Quiet, - Env: os.Environ(), - CredentialContext: r.CredentialContext, - Workspace: r.Workspace, + Quiet: r.Quiet, + Env: os.Environ(), + CredentialContext: r.CredentialContext, + Workspace: r.Workspace, + DisablePromptServer: r.UI, } if r.Confirm { @@ -452,7 +453,6 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) { Workspace: r.Workspace, SaveChatStateFile: r.SaveChatStateFile, ChatState: chatState, - ExtraEnv: gptScript.ExtraEnv, }) } return chat.Start(cmd.Context(), chatState, gptScript, func() (types.Program, error) { diff --git a/pkg/context/context.go b/pkg/context/context.go index 31474f6c..0169d0e0 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -46,3 +46,14 @@ func GetLogger(ctx context.Context) mvl.Logger { return l } + +type envKey struct{} + +func WithEnv(ctx context.Context, env []string) context.Context { + return context.WithValue(ctx, envKey{}, env) +} + +func GetEnv(ctx context.Context) []string { + l, _ := ctx.Value(envKey{}).([]string) + return l +} diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index a2804fa7..c94b236a 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -8,6 +8,7 @@ import ( "sync" "github.com/gptscript-ai/gptscript/pkg/config" + gcontext "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/system" "github.com/gptscript-ai/gptscript/pkg/types" @@ -328,7 +329,7 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) { } }() - resp, err := e.Model.Call(ctx, state.Completion, progress) + resp, err := e.Model.Call(gcontext.WithEnv(ctx, e.Env), state.Completion, progress) if err != nil { return nil, err } diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index a177d85c..a083afa1 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -38,14 +38,15 @@ type GPTScript struct { } type Options struct { - Cache cache.Options - OpenAI openai.Options - Monitor monitor.Options - Runner runner.Options - CredentialContext string - Quiet *bool - Workspace string - Env []string + Cache cache.Options + OpenAI openai.Options + Monitor monitor.Options + Runner runner.Options + CredentialContext string + Quiet *bool + Workspace string + DisablePromptServer bool + Env []string } func complete(opts ...Options) Options { @@ -60,6 +61,7 @@ func complete(opts ...Options) Options { result.Quiet = types.FirstSet(opt.Quiet, result.Quiet) result.Workspace = types.FirstSet(opt.Workspace, result.Workspace) result.Env = append(result.Env, opt.Env...) + result.DisablePromptServer = types.FirstSet(opt.DisablePromptServer, result.DisablePromptServer) } if result.Quiet == nil { @@ -123,15 +125,21 @@ func New(o ...Options) (*GPTScript, error) { return nil, err } - ctx, closeServer := context.WithCancel(context2.AddPauseFuncToCtx(context.Background(), opts.Runner.MonitorFactory.Pause)) - extraEnv, err := prompt.NewServer(ctx, opts.Env) - if err != nil { - closeServer() - return nil, err + var ( + extraEnv []string + closeServer = func() {} + ) + if !opts.DisablePromptServer { + var ctx context.Context + ctx, closeServer = context.WithCancel(context2.AddPauseFuncToCtx(context.Background(), opts.Runner.MonitorFactory.Pause)) + extraEnv, err = prompt.NewServer(ctx, opts.Env) + if err != nil { + closeServer() + return nil, err + } } fullEnv := append(opts.Env, extraEnv...) - oaiClient.SetEnvs(fullEnv) remoteClient := remote.New(runner, fullEnv, cacheClient, credStore) if err := registry.AddClient(remoteClient); err != nil { diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 31c2cf6c..27a6317c 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -12,6 +12,7 @@ import ( openai "github.com/gptscript-ai/chat-completion-client" "github.com/gptscript-ai/gptscript/pkg/cache" + gcontext "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/hash" @@ -43,7 +44,6 @@ type Client struct { invalidAuth bool cacheKeyBase string setSeed bool - envs []string credStore credentials.CredentialStore } @@ -136,10 +136,6 @@ func (c *Client) ValidAuth() error { return nil } -func (c *Client) SetEnvs(env []string) { - c.envs = env -} - func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) { models, err := c.ListModels(ctx) if err != nil { @@ -546,7 +542,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, } func (c *Client) RetrieveAPIKey(ctx context.Context) error { - k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", c.envs) + k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", gcontext.GetEnv(ctx)) if err != nil { return err } diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go index 047a6abc..6cf8febd 100644 --- a/pkg/prompt/prompt.go +++ b/pkg/prompt/prompt.go @@ -41,7 +41,7 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types. defer resp.Body.Close() if resp.StatusCode != 200 { - return "", fmt.Errorf("invalid status code [%d], expected 200", resp.StatusCode) + return "", fmt.Errorf("getting prompt response invalid status code [%d], expected 200", resp.StatusCode) } data, err = io.ReadAll(resp.Body) @@ -75,17 +75,23 @@ func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) { defer context2.GetPauseFuncFromCtx(ctx)()() - if req.Message != "" { + if req.Message != "" && len(req.Fields) != 1 { _, _ = fmt.Fprintln(os.Stderr, req.Message) } results := map[string]string{} for _, f := range req.Fields { - var value string + var ( + value string + msg = f + ) + if len(req.Fields) == 1 && req.Message != "" { + msg = req.Message + } if req.Sensitive { - err = survey.AskOne(&survey.Password{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) + err = survey.AskOne(&survey.Password{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) } else { - err = survey.AskOne(&survey.Input{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) + err = survey.AskOne(&survey.Input{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) } if err != nil { return "", err diff --git a/pkg/prompt/server.go b/pkg/prompt/server.go index e84ab4ec..2f7748b3 100644 --- a/pkg/prompt/server.go +++ b/pkg/prompt/server.go @@ -15,9 +15,11 @@ import ( func NewServer(ctx context.Context, envs []string) ([]string, error) { for _, env := range envs { - v, ok := strings.CutPrefix(env, types.PromptTokenEnvVar+"=") - if ok && v != "" { - return nil, nil + for _, k := range []string{types.PromptURLEnvVar, types.PromptTokenEnvVar} { + v, ok := strings.CutPrefix(env, k+"=") + if ok && v != "" { + return nil, nil + } } } @@ -34,7 +36,7 @@ func NewServer(ctx context.Context, envs []string) ([]string, error) { Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer "+token { rw.WriteHeader(http.StatusUnauthorized) - _, _ = rw.Write([]byte("Unauthorized")) + _, _ = rw.Write([]byte("Unauthorized (invalid token)")) return } diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index f707c8d3..3837879a 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/gptscript-ai/gptscript/pkg/cache" + gcontext "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/engine" env2 "github.com/gptscript-ai/gptscript/pkg/env" @@ -176,5 +177,5 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err } func (c *Client) retrieveAPIKey(ctx context.Context, env, url string) (string, error) { - return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), c.envs) + return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), append(gcontext.GetEnv(ctx), c.envs...)) }