Skip to content

Commit

Permalink
enhance: update credentials framework for OAuth support (#305)
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville authored Jun 24, 2024
1 parent abcd863 commit 60da900
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 64 deletions.
80 changes: 50 additions & 30 deletions pkg/cli/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sort"
"strings"
"text/tabwriter"
"time"

cmd2 "github.com/acorn-io/cmd"
"github.com/gptscript-ai/gptscript/pkg/cache"
Expand All @@ -14,6 +15,11 @@ import (
"github.com/spf13/cobra"
)

const (
expiresNever = "never"
expiresExpired = "expired"
)

type Credential struct {
root *GPTScript
AllContexts bool `usage:"List credentials for all contexts" local:"true"`
Expand Down Expand Up @@ -46,6 +52,7 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
}
opts.Cache = cache.Complete(opts.Cache)

// Initialize the credential store and get all the credentials.
store, err := credentials.NewStore(cfg, ctx, opts.Cache.CacheDir)
if err != nil {
return fmt.Errorf("failed to get credentials store: %w", err)
Expand All @@ -56,6 +63,10 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
return fmt.Errorf("failed to list credentials: %w", err)
}

w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0)
defer w.Flush()

// Sort credentials and print column names, depending on the options.
if c.AllContexts {
// Sort credentials by context
sort.Slice(creds, func(i, j int) bool {
Expand All @@ -65,25 +76,10 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
return creds[i].Context < creds[j].Context
})

w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0)
defer w.Flush()

if c.ShowEnvVars {
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tENVIRONMENT VARIABLES\n"))

for _, cred := range creds {
envVars := make([]string, 0, len(cred.Env))
for envVar := range cred.Env {
envVars = append(envVars, envVar)
}
sort.Strings(envVars)
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\n", cred.Context, cred.ToolName, strings.Join(envVars, ", "))
}
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tEXPIRES IN\tENV\n"))
} else {
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\n"))
for _, cred := range creds {
_, _ = fmt.Fprintf(w, "%s\t%s\n", cred.Context, cred.ToolName)
}
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tEXPIRES IN\n"))
}
} else {
// Sort credentials by tool name
Expand All @@ -92,24 +88,48 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
})

if c.ShowEnvVars {
w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0)
defer w.Flush()
_, _ = w.Write([]byte("CREDENTIAL\tENVIRONMENT VARIABLES\n"))

for _, cred := range creds {
envVars := make([]string, 0, len(cred.Env))
for envVar := range cred.Env {
envVars = append(envVars, envVar)
}
sort.Strings(envVars)
_, _ = fmt.Fprintf(w, "%s\t%s\n", cred.ToolName, strings.Join(envVars, ", "))
_, _ = w.Write([]byte("CREDENTIAL\tEXPIRES IN\tENV\n"))
} else {
_, _ = w.Write([]byte("CREDENTIAL\tEXPIRES IN\n"))
}
}

for _, cred := range creds {
expires := expiresNever
if cred.ExpiresAt != nil {
expires = expiresExpired
if !cred.IsExpired() {
expires = time.Until(*cred.ExpiresAt).Truncate(time.Second).String()
}
}

var fields []any
if c.AllContexts {
fields = []any{cred.Context, cred.ToolName, expires}
} else {
for _, cred := range creds {
fmt.Println(cred.ToolName)
fields = []any{cred.ToolName, expires}
}

if c.ShowEnvVars {
envVars := make([]string, 0, len(cred.Env))
for envVar := range cred.Env {
envVars = append(envVars, envVar)
}
sort.Strings(envVars)
fields = append(fields, strings.Join(envVars, ", "))
}

printFields(w, fields)
}

return nil
}

func printFields(w *tabwriter.Writer, fields []any) {
if len(fields) == 0 {
return
}

fmtStr := strings.Repeat("%s\t", len(fields)-1) + "%s\n"
_, _ = fmt.Fprintf(w, fmtStr, fields...)
}
47 changes: 32 additions & 15 deletions pkg/credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,58 @@ import (
"encoding/json"
"fmt"
"strings"
"time"

"github.com/docker/cli/cli/config/types"
)

const ctxSeparator = "///"

type CredentialType string

const (
ctxSeparator = "///"
CredentialTypeTool CredentialType = "tool"
CredentialTypeModelProvider CredentialType = "modelProvider"
ExistingCredential = "GPTSCRIPT_EXISTING_CREDENTIAL"
)

type Credential struct {
Context string `json:"context"`
ToolName string `json:"toolName"`
Type CredentialType `json:"type"`
Env map[string]string `json:"env"`
Context string `json:"context"`
ToolName string `json:"toolName"`
Type CredentialType `json:"type"`
Env map[string]string `json:"env"`
ExpiresAt *time.Time `json:"expiresAt"`
RefreshToken string `json:"refreshToken"`
}

func (c Credential) IsExpired() bool {
if c.ExpiresAt == nil {
return false
}
return time.Now().After(*c.ExpiresAt)
}

func (c Credential) toDockerAuthConfig() (types.AuthConfig, error) {
env, err := json.Marshal(c.Env)
cred, err := json.Marshal(c)
if err != nil {
return types.AuthConfig{}, err
}

return types.AuthConfig{
Username: string(c.Type),
Password: string(env),
Password: string(cred),
ServerAddress: toolNameWithCtx(c.ToolName, c.Context),
}, nil
}

func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error) {
var env map[string]string
if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil {
return Credential{}, err
var cred Credential
if err := json.Unmarshal([]byte(authCfg.Password), &cred); err != nil || len(cred.Env) == 0 {
// Legacy: try unmarshalling into just an env map
var env map[string]string
if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil {
return Credential{}, err
}
cred.Env = env
}

// We used to hardcode the username as "gptscript" before CredentialType was introduced, so
Expand All @@ -62,10 +77,12 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error
}

return Credential{
Context: ctx,
ToolName: tool,
Type: CredentialType(credType),
Env: env,
Context: ctx,
ToolName: tool,
Type: CredentialType(credType),
Env: cred.Env,
ExpiresAt: cred.ExpiresAt,
RefreshToken: cred.RefreshToken,
}, nil
}

Expand Down
45 changes: 26 additions & 19 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ var (
EventTypeRunFinish EventType = "runFinish"
)

func getContextInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
func getToolRefInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
if ref.Arg == "" {
return "", nil
}
Expand Down Expand Up @@ -355,7 +355,7 @@ func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monito
continue
}

contextInput, err := getContextInput(callCtx.Program, toolRef, input)
contextInput, err := getToolRefInput(callCtx.Program, toolRef, input)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -867,7 +867,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
}

var (
cred *credentials.Credential
c *credentials.Credential
exists bool
)

Expand All @@ -879,25 +879,39 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
// Only try to look up the cred if the tool is on GitHub or has an alias.
// If it is a GitHub tool and has an alias, the alias overrides the tool name, so we use it as the credential name.
if isGitHubTool(toolName) && credentialAlias == "" {
cred, exists, err = r.credStore.Get(toolName)
c, exists, err = r.credStore.Get(toolName)
if err != nil {
return nil, fmt.Errorf("failed to get credentials for tool %s: %w", toolName, err)
}
} else if credentialAlias != "" {
cred, exists, err = r.credStore.Get(credentialAlias)
c, exists, err = r.credStore.Get(credentialAlias)
if err != nil {
return nil, fmt.Errorf("failed to get credentials for tool %s: %w", credentialAlias, err)
}
}

if c == nil {
c = &credentials.Credential{}
}

// If the credential doesn't already exist in the store, run the credential tool in order to get the value,
// and save it in the store.
if !exists {
if !exists || c.IsExpired() {
credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName]
if !ok || len(credToolRefs) != 1 {
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
}

// If the existing credential is expired, we need to provide it to the cred tool through the environment.
if exists && c.IsExpired() {
credJSON, err := json.Marshal(c)
if err != nil {
return nil, fmt.Errorf("failed to marshal credential: %w", err)
}
env = append(env, fmt.Sprintf("%s=%s", credentials.ExistingCredential, string(credJSON)))
}

// Get the input for the credential tool, if there is any.
var input string
if args != nil {
inputBytes, err := json.Marshal(args)
Expand All @@ -916,21 +930,14 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", credToolName)
}

var envMap struct {
Env map[string]string `json:"env"`
}
if err := json.Unmarshal([]byte(*res.Result), &envMap); err != nil {
if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", credToolName, err)
}

cred = &credentials.Credential{
Type: credentials.CredentialTypeTool,
Env: envMap.Env,
ToolName: credName,
}
c.ToolName = credName
c.Type = credentials.CredentialTypeTool

isEmpty := true
for _, v := range cred.Env {
for _, v := range c.Env {
if v != "" {
isEmpty = false
break
Expand All @@ -941,15 +948,15 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil) || credentialAlias != "" {
if isEmpty {
log.Warnf("Not saving empty credential for tool %s", toolName)
} else if err := r.credStore.Add(*cred); err != nil {
} else if err := r.credStore.Add(*c); err != nil {
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
}
} else {
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
}
}

for k, v := range cred.Env {
for k, v := range c.Env {
env = append(env, fmt.Sprintf("%s=%s", k, v))
}
}
Expand Down

0 comments on commit 60da900

Please sign in to comment.