diff --git a/pkg/cli/credential.go b/pkg/cli/credential.go index 5f617941..6bfb1ed6 100644 --- a/pkg/cli/credential.go +++ b/pkg/cli/credential.go @@ -6,6 +6,7 @@ import ( "sort" "strings" "text/tabwriter" + "time" cmd2 "github.com/acorn-io/cmd" "github.com/gptscript-ai/gptscript/pkg/cache" @@ -14,6 +15,11 @@ import ( "github.com/spf13/cobra" ) +const ( + expiresNever = "never" + expiresExpired = "expired" +) + type Credential struct { root *GPTScript AllContexts bool `usage:"List credentials for all contexts" local:"true"` @@ -46,6 +52,7 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error { } opts.Cache = cache.Complete(opts.Cache) + // Initialize the credential store and get all the credentials. store, err := credentials.NewStore(cfg, ctx, opts.Cache.CacheDir) if err != nil { return fmt.Errorf("failed to get credentials store: %w", err) @@ -56,6 +63,10 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error { return fmt.Errorf("failed to list credentials: %w", err) } + w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0) + defer w.Flush() + + // Sort credentials and print column names, depending on the options. if c.AllContexts { // Sort credentials by context sort.Slice(creds, func(i, j int) bool { @@ -65,25 +76,10 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error { return creds[i].Context < creds[j].Context }) - w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0) - defer w.Flush() - if c.ShowEnvVars { - _, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tENVIRONMENT VARIABLES\n")) - - for _, cred := range creds { - envVars := make([]string, 0, len(cred.Env)) - for envVar := range cred.Env { - envVars = append(envVars, envVar) - } - sort.Strings(envVars) - _, _ = fmt.Fprintf(w, "%s\t%s\t%s\n", cred.Context, cred.ToolName, strings.Join(envVars, ", ")) - } + _, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tEXPIRES IN\tENV\n")) } else { - _, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\n")) - for _, cred := range creds { - _, _ = fmt.Fprintf(w, "%s\t%s\n", cred.Context, cred.ToolName) - } + _, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tEXPIRES IN\n")) } } else { // Sort credentials by tool name @@ -92,24 +88,48 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error { }) if c.ShowEnvVars { - w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0) - defer w.Flush() - _, _ = w.Write([]byte("CREDENTIAL\tENVIRONMENT VARIABLES\n")) - - for _, cred := range creds { - envVars := make([]string, 0, len(cred.Env)) - for envVar := range cred.Env { - envVars = append(envVars, envVar) - } - sort.Strings(envVars) - _, _ = fmt.Fprintf(w, "%s\t%s\n", cred.ToolName, strings.Join(envVars, ", ")) + _, _ = w.Write([]byte("CREDENTIAL\tEXPIRES IN\tENV\n")) + } else { + _, _ = w.Write([]byte("CREDENTIAL\tEXPIRES IN\n")) + } + } + + for _, cred := range creds { + expires := expiresNever + if cred.ExpiresAt != nil { + expires = expiresExpired + if !cred.IsExpired() { + expires = time.Until(*cred.ExpiresAt).Truncate(time.Second).String() } + } + + var fields []any + if c.AllContexts { + fields = []any{cred.Context, cred.ToolName, expires} } else { - for _, cred := range creds { - fmt.Println(cred.ToolName) + fields = []any{cred.ToolName, expires} + } + + if c.ShowEnvVars { + envVars := make([]string, 0, len(cred.Env)) + for envVar := range cred.Env { + envVars = append(envVars, envVar) } + sort.Strings(envVars) + fields = append(fields, strings.Join(envVars, ", ")) } + + printFields(w, fields) } return nil } + +func printFields(w *tabwriter.Writer, fields []any) { + if len(fields) == 0 { + return + } + + fmtStr := strings.Repeat("%s\t", len(fields)-1) + "%s\n" + _, _ = fmt.Fprintf(w, fmtStr, fields...) +} diff --git a/pkg/credentials/credential.go b/pkg/credentials/credential.go index 46f705dc..fc247b38 100644 --- a/pkg/credentials/credential.go +++ b/pkg/credentials/credential.go @@ -4,43 +4,58 @@ import ( "encoding/json" "fmt" "strings" + "time" "github.com/docker/cli/cli/config/types" ) -const ctxSeparator = "///" - type CredentialType string const ( + ctxSeparator = "///" CredentialTypeTool CredentialType = "tool" CredentialTypeModelProvider CredentialType = "modelProvider" + ExistingCredential = "GPTSCRIPT_EXISTING_CREDENTIAL" ) type Credential struct { - Context string `json:"context"` - ToolName string `json:"toolName"` - Type CredentialType `json:"type"` - Env map[string]string `json:"env"` + Context string `json:"context"` + ToolName string `json:"toolName"` + Type CredentialType `json:"type"` + Env map[string]string `json:"env"` + ExpiresAt *time.Time `json:"expiresAt"` + RefreshToken string `json:"refreshToken"` +} + +func (c Credential) IsExpired() bool { + if c.ExpiresAt == nil { + return false + } + return time.Now().After(*c.ExpiresAt) } func (c Credential) toDockerAuthConfig() (types.AuthConfig, error) { - env, err := json.Marshal(c.Env) + cred, err := json.Marshal(c) if err != nil { return types.AuthConfig{}, err } return types.AuthConfig{ Username: string(c.Type), - Password: string(env), + Password: string(cred), ServerAddress: toolNameWithCtx(c.ToolName, c.Context), }, nil } func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error) { - var env map[string]string - if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil { - return Credential{}, err + var cred Credential + if err := json.Unmarshal([]byte(authCfg.Password), &cred); err != nil || len(cred.Env) == 0 { + // Legacy: try unmarshalling into just an env map + var env map[string]string + if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil { + return Credential{}, err + } + cred.Env = env } // We used to hardcode the username as "gptscript" before CredentialType was introduced, so @@ -62,10 +77,12 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error } return Credential{ - Context: ctx, - ToolName: tool, - Type: CredentialType(credType), - Env: env, + Context: ctx, + ToolName: tool, + Type: CredentialType(credType), + Env: cred.Env, + ExpiresAt: cred.ExpiresAt, + RefreshToken: cred.RefreshToken, }, nil } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index fb2cba0d..e4794535 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -250,7 +250,7 @@ var ( EventTypeRunFinish EventType = "runFinish" ) -func getContextInput(prg *types.Program, ref types.ToolReference, input string) (string, error) { +func getToolRefInput(prg *types.Program, ref types.ToolReference, input string) (string, error) { if ref.Arg == "" { return "", nil } @@ -355,7 +355,7 @@ func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monito continue } - contextInput, err := getContextInput(callCtx.Program, toolRef, input) + contextInput, err := getToolRefInput(callCtx.Program, toolRef, input) if err != nil { return nil, nil, err } @@ -867,7 +867,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env } var ( - cred *credentials.Credential + c *credentials.Credential exists bool ) @@ -879,25 +879,39 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env // Only try to look up the cred if the tool is on GitHub or has an alias. // If it is a GitHub tool and has an alias, the alias overrides the tool name, so we use it as the credential name. if isGitHubTool(toolName) && credentialAlias == "" { - cred, exists, err = r.credStore.Get(toolName) + c, exists, err = r.credStore.Get(toolName) if err != nil { return nil, fmt.Errorf("failed to get credentials for tool %s: %w", toolName, err) } } else if credentialAlias != "" { - cred, exists, err = r.credStore.Get(credentialAlias) + c, exists, err = r.credStore.Get(credentialAlias) if err != nil { return nil, fmt.Errorf("failed to get credentials for tool %s: %w", credentialAlias, err) } } + if c == nil { + c = &credentials.Credential{} + } + // If the credential doesn't already exist in the store, run the credential tool in order to get the value, // and save it in the store. - if !exists { + if !exists || c.IsExpired() { credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName] if !ok || len(credToolRefs) != 1 { return nil, fmt.Errorf("failed to find ID for tool %s", credToolName) } + // If the existing credential is expired, we need to provide it to the cred tool through the environment. + if exists && c.IsExpired() { + credJSON, err := json.Marshal(c) + if err != nil { + return nil, fmt.Errorf("failed to marshal credential: %w", err) + } + env = append(env, fmt.Sprintf("%s=%s", credentials.ExistingCredential, string(credJSON))) + } + + // Get the input for the credential tool, if there is any. var input string if args != nil { inputBytes, err := json.Marshal(args) @@ -916,21 +930,14 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", credToolName) } - var envMap struct { - Env map[string]string `json:"env"` - } - if err := json.Unmarshal([]byte(*res.Result), &envMap); err != nil { + if err := json.Unmarshal([]byte(*res.Result), &c); err != nil { return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", credToolName, err) } - - cred = &credentials.Credential{ - Type: credentials.CredentialTypeTool, - Env: envMap.Env, - ToolName: credName, - } + c.ToolName = credName + c.Type = credentials.CredentialTypeTool isEmpty := true - for _, v := range cred.Env { + for _, v := range c.Env { if v != "" { isEmpty = false break @@ -941,7 +948,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env if (isGitHubTool(toolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil) || credentialAlias != "" { if isEmpty { log.Warnf("Not saving empty credential for tool %s", toolName) - } else if err := r.credStore.Add(*cred); err != nil { + } else if err := r.credStore.Add(*c); err != nil { return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err) } } else { @@ -949,7 +956,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env } } - for k, v := range cred.Env { + for k, v := range c.Env { env = append(env, fmt.Sprintf("%s=%s", k, v)) } }