Skip to content

Commit

Permalink
bug: respect run level env in openai prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
ibuildthecloud committed Jun 14, 2024
1 parent 2aafa62 commit 1c6faca
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 39 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ require (
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/google/uuid v1.6.0
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
github.com/gptscript-ai/tui v0.0.0-20240614023948-004dc1597dd7
github.com/gptscript-ai/tui v0.0.0-20240614062633-985091711b0a
github.com/hexops/autogold/v2 v2.2.1
github.com/hexops/valast v1.4.4
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf037
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/go-gptscript v0.0.0-20240613214812-8111c2b02d71 h1:WehkkausLuXI91ePpIVrzZ6eBmfFIU/HfNsSA1CHiwo=
github.com/gptscript-ai/go-gptscript v0.0.0-20240613214812-8111c2b02d71/go.mod h1:Dh6vYRAiVcyC3ElZIGzTvNF1FxtYwA07BHfSiFKQY7s=
github.com/gptscript-ai/tui v0.0.0-20240614023948-004dc1597dd7 h1:t+IuS+4JLUnwLHv+bgJQ2jHVT9ii0SLR3D7eNTZ47fg=
github.com/gptscript-ai/tui v0.0.0-20240614023948-004dc1597dd7/go.mod h1:ZlyM+BRiD6mV04w+Xw2mXP1VKGEUbn8BvwrosWlplUo=
github.com/gptscript-ai/tui v0.0.0-20240614062633-985091711b0a h1:LFsEDiIAx0Rq0V6aOMlRjXMMIfkA3uEhqqqjoggLlDQ=
github.com/gptscript-ai/tui v0.0.0-20240614062633-985091711b0a/go.mod h1:ZlyM+BRiD6mV04w+Xw2mXP1VKGEUbn8BvwrosWlplUo=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
Expand Down
10 changes: 5 additions & 5 deletions pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,11 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) {
CredentialOverride: r.CredentialOverride,
Sequential: r.ForceSequential,
},
Quiet: r.Quiet,
Env: os.Environ(),
CredentialContext: r.CredentialContext,
Workspace: r.Workspace,
Quiet: r.Quiet,
Env: os.Environ(),
CredentialContext: r.CredentialContext,
Workspace: r.Workspace,
DisablePromptServer: r.UI,
}

if r.Confirm {
Expand Down Expand Up @@ -452,7 +453,6 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
Workspace: r.Workspace,
SaveChatStateFile: r.SaveChatStateFile,
ChatState: chatState,
ExtraEnv: gptScript.ExtraEnv,
})
}
return chat.Start(cmd.Context(), chatState, gptScript, func() (types.Program, error) {
Expand Down
11 changes: 11 additions & 0 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,14 @@ func GetLogger(ctx context.Context) mvl.Logger {

return l
}

type envKey struct{}

func WithEnv(ctx context.Context, env []string) context.Context {
return context.WithValue(ctx, envKey{}, env)
}

func GetEnv(ctx context.Context) []string {
l, _ := ctx.Value(envKey{}).([]string)
return l
}
3 changes: 2 additions & 1 deletion pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"sync"

"github.com/gptscript-ai/gptscript/pkg/config"
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
Expand Down Expand Up @@ -328,7 +329,7 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
}
}()

resp, err := e.Model.Call(ctx, state.Completion, progress)
resp, err := e.Model.Call(gcontext.WithEnv(ctx, e.Env), state.Completion, progress)
if err != nil {
return nil, err
}
Expand Down
36 changes: 22 additions & 14 deletions pkg/gptscript/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ type GPTScript struct {
}

type Options struct {
Cache cache.Options
OpenAI openai.Options
Monitor monitor.Options
Runner runner.Options
CredentialContext string
Quiet *bool
Workspace string
Env []string
Cache cache.Options
OpenAI openai.Options
Monitor monitor.Options
Runner runner.Options
CredentialContext string
Quiet *bool
Workspace string
DisablePromptServer bool
Env []string
}

func complete(opts ...Options) Options {
Expand All @@ -60,6 +61,7 @@ func complete(opts ...Options) Options {
result.Quiet = types.FirstSet(opt.Quiet, result.Quiet)
result.Workspace = types.FirstSet(opt.Workspace, result.Workspace)
result.Env = append(result.Env, opt.Env...)
result.DisablePromptServer = types.FirstSet(opt.DisablePromptServer, result.DisablePromptServer)
}

if result.Quiet == nil {
Expand Down Expand Up @@ -123,15 +125,21 @@ func New(o ...Options) (*GPTScript, error) {
return nil, err
}

ctx, closeServer := context.WithCancel(context2.AddPauseFuncToCtx(context.Background(), opts.Runner.MonitorFactory.Pause))
extraEnv, err := prompt.NewServer(ctx, opts.Env)
if err != nil {
closeServer()
return nil, err
var (
extraEnv []string
closeServer = func() {}
)
if !opts.DisablePromptServer {
var ctx context.Context
ctx, closeServer = context.WithCancel(context2.AddPauseFuncToCtx(context.Background(), opts.Runner.MonitorFactory.Pause))
extraEnv, err = prompt.NewServer(ctx, opts.Env)
if err != nil {
closeServer()
return nil, err
}
}

fullEnv := append(opts.Env, extraEnv...)
oaiClient.SetEnvs(fullEnv)

remoteClient := remote.New(runner, fullEnv, cacheClient, credStore)
if err := registry.AddClient(remoteClient); err != nil {
Expand Down
8 changes: 2 additions & 6 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

openai "github.com/gptscript-ai/chat-completion-client"
"github.com/gptscript-ai/gptscript/pkg/cache"
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/hash"
Expand Down Expand Up @@ -43,7 +44,6 @@ type Client struct {
invalidAuth bool
cacheKeyBase string
setSeed bool
envs []string
credStore credentials.CredentialStore
}

Expand Down Expand Up @@ -136,10 +136,6 @@ func (c *Client) ValidAuth() error {
return nil
}

func (c *Client) SetEnvs(env []string) {
c.envs = env
}

func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
models, err := c.ListModels(ctx)
if err != nil {
Expand Down Expand Up @@ -546,7 +542,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
}

func (c *Client) RetrieveAPIKey(ctx context.Context) error {
k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", c.envs)
k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", gcontext.GetEnv(ctx))
if err != nil {
return err
}
Expand Down
16 changes: 11 additions & 5 deletions pkg/prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types.
defer resp.Body.Close()

if resp.StatusCode != 200 {
return "", fmt.Errorf("invalid status code [%d], expected 200", resp.StatusCode)
return "", fmt.Errorf("getting prompt response invalid status code [%d], expected 200", resp.StatusCode)
}

data, err = io.ReadAll(resp.Body)
Expand Down Expand Up @@ -75,17 +75,23 @@ func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err
func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) {
defer context2.GetPauseFuncFromCtx(ctx)()()

if req.Message != "" {
if req.Message != "" && len(req.Fields) != 1 {
_, _ = fmt.Fprintln(os.Stderr, req.Message)
}

results := map[string]string{}
for _, f := range req.Fields {
var value string
var (
value string
msg = f
)
if len(req.Fields) == 1 && req.Message != "" {
msg = req.Message
}
if req.Sensitive {
err = survey.AskOne(&survey.Password{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
err = survey.AskOne(&survey.Password{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
} else {
err = survey.AskOne(&survey.Input{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
err = survey.AskOne(&survey.Input{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
}
if err != nil {
return "", err
Expand Down
10 changes: 6 additions & 4 deletions pkg/prompt/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ import (

func NewServer(ctx context.Context, envs []string) ([]string, error) {
for _, env := range envs {
v, ok := strings.CutPrefix(env, types.PromptTokenEnvVar+"=")
if ok && v != "" {
return nil, nil
for _, k := range []string{types.PromptURLEnvVar, types.PromptTokenEnvVar} {
v, ok := strings.CutPrefix(env, k+"=")
if ok && v != "" {
return nil, nil
}
}
}

Expand All @@ -34,7 +36,7 @@ func NewServer(ctx context.Context, envs []string) ([]string, error) {
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "Bearer "+token {
rw.WriteHeader(http.StatusUnauthorized)
_, _ = rw.Write([]byte("Unauthorized"))
_, _ = rw.Write([]byte("Unauthorized (invalid token)"))
return
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/remote/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"

"github.com/gptscript-ai/gptscript/pkg/cache"
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/engine"
env2 "github.com/gptscript-ai/gptscript/pkg/env"
Expand Down Expand Up @@ -176,5 +177,5 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
}

func (c *Client) retrieveAPIKey(ctx context.Context, env, url string) (string, error) {
return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), c.envs)
return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), append(gcontext.GetEnv(ctx), c.envs...))
}

0 comments on commit 1c6faca

Please sign in to comment.