diff --git a/integration/cred_test.go b/integration/cred_test.go index b92ccd55..67298ef8 100644 --- a/integration/cred_test.go +++ b/integration/cred_test.go @@ -11,3 +11,19 @@ func TestGPTScriptCredential(t *testing.T) { require.NoError(t, err) require.Contains(t, out, "CREDENTIAL") } + +// TestCredentialScopes makes sure that environment variables set by credential tools and shared credential tools +// are only available to the correct tools. See scripts/credscopes.gpt for more details. +func TestCredentialScopes(t *testing.T) { + out, err := RunScript("scripts/credscopes.gpt", "--sub-tool", "oneOne") + require.NoError(t, err) + require.Contains(t, out, "good") + + out, err = RunScript("scripts/credscopes.gpt", "--sub-tool", "twoOne") + require.NoError(t, err) + require.Contains(t, out, "good") + + out, err = RunScript("scripts/credscopes.gpt", "--sub-tool", "twoTwo") + require.NoError(t, err) + require.Contains(t, out, "good") +} diff --git a/integration/helpers.go b/integration/helpers.go index 8af581c3..33304676 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -14,3 +14,7 @@ func GPTScriptExec(args ...string) (string, error) { out, err := cmd.CombinedOutput() return string(out), err } + +func RunScript(script string, options ...string) (string, error) { + return GPTScriptExec(append(options, "--quiet", script)...) +} diff --git a/integration/scripts/credscopes.gpt b/integration/scripts/credscopes.gpt new file mode 100644 index 00000000..7319f163 --- /dev/null +++ b/integration/scripts/credscopes.gpt @@ -0,0 +1,160 @@ +# This script sets up a chain of tools in a tree structure. +# The root is oneOne, with children twoOne and twoTwo, with children threeOne, threeTwo, and threeThree, with only +# threeTwo shared between them. +# Each tool should only have access to any credentials it defines and any credentials exported/shared by its +# immediate children (but not grandchildren). +# This script checks to make sure that this is working properly. +name: oneOne +tools: twoOne, twoTwo +cred: getcred with oneOne as var and 11 as val + +#!python3 + +import os + +oneOne = os.getenv('oneOne') +twoOne = os.getenv('twoOne') +twoTwo = os.getenv('twoTwo') +threeOne = os.getenv('threeOne') +threeTwo = os.getenv('threeTwo') +threeThree = os.getenv('threeThree') + +if oneOne != '11': + print('error: oneOne is not 11') + exit(1) + +if twoOne != '21': + print('error: twoOne is not 21') + exit(1) + +if twoTwo != '22': + print('error: twoTwo is not 22') + exit(1) + +if threeOne is not None: + print('error: threeOne is not None') + exit(1) + +if threeTwo is not None: + print('error: threeTwo is not None') + exit(1) + +if threeThree is not None: + print('error: threeThree is not None') + exit(1) + +print('good') + +--- +name: twoOne +tools: threeOne, threeTwo +sharecred: getcred with twoOne as var and 21 as val + +#!python3 + +import os + +oneOne = os.getenv('oneOne') +twoOne = os.getenv('twoOne') +twoTwo = os.getenv('twoTwo') +threeOne = os.getenv('threeOne') +threeTwo = os.getenv('threeTwo') +threeThree = os.getenv('threeThree') + +if oneOne is not None: + print('error: oneOne is not None') + exit(1) + +if twoOne is not None: + print('error: twoOne is not None') + exit(1) + +if twoTwo is not None: + print('error: twoTwo is not None') + exit(1) + +if threeOne != '31': + print('error: threeOne is not 31') + exit(1) + +if threeTwo != '32': + print('error: threeTwo is not 32') + exit(1) + +if threeThree is not None: + print('error: threeThree is not None') + exit(1) + +print('good') + +--- +name: twoTwo +tools: threeTwo, threeThree +sharecred: getcred with twoTwo as var and 22 as val + +#!python3 + +import os + +oneOne = os.getenv('oneOne') +twoOne = os.getenv('twoOne') +twoTwo = os.getenv('twoTwo') +threeOne = os.getenv('threeOne') +threeTwo = os.getenv('threeTwo') +threeThree = os.getenv('threeThree') + +if oneOne is not None: + print('error: oneOne is not None') + exit(1) + +if twoOne is not None: + print('error: twoOne is not None') + exit(1) + +if twoTwo is not None: + print('error: twoTwo is not None') + exit(1) + +if threeOne is not None: + print('error: threeOne is not None') + exit(1) + +if threeTwo != '32': + print('error: threeTwo is not 32') + exit(1) + +if threeThree != '33': + print('error: threeThree is not 33') + exit(1) + +print('good') + +--- +name: threeOne +sharecred: getcred with threeOne as var and 31 as val + +--- +name: threeTwo +sharecred: getcred with threeTwo as var and 32 as val + +--- +name: threeThree +sharecred: getcred with threeThree as var and 33 as val + +--- +name: getcred + +#!python3 + +import os +import json + +var = os.getenv('var') +val = os.getenv('val') + +output = { + "env": { + var: val + } +} +print(json.dumps(output)) diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index 22cd5e9e..f7c750c1 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -148,6 +148,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { } case "credentials", "creds", "credential", "cred": tool.Parameters.Credentials = append(tool.Parameters.Credentials, value) + case "sharecredentials", "sharecreds", "sharecredential", "sharecred": + tool.Parameters.ExportCredentials = append(tool.Parameters.ExportCredentials, value) default: return false, nil } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index cc5a3927..36bac826 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -419,9 +419,13 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en return nil, err } - if len(callCtx.Tool.Credentials) > 0 { + credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup) + if err != nil { + return nil, err + } + if len(credTools) > 0 { var err error - env, err = r.handleCredentials(callCtx, monitor, env) + env, err = r.handleCredentials(callCtx, monitor, env, credTools) if err != nil { return nil, err } @@ -552,9 +556,13 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s progress, progressClose := streamProgress(&callCtx, monitor) defer progressClose() - if len(callCtx.Tool.Credentials) > 0 { + credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup) + if err != nil { + return nil, err + } + if len(credTools) > 0 { var err error - env, err = r.handleCredentials(callCtx, monitor, env) + env, err = r.handleCredentials(callCtx, monitor, env, credTools) if err != nil { return nil, err } @@ -828,7 +836,7 @@ func getEventContent(content string, callCtx engine.Context) string { return content } -func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string) ([]string, error) { +func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string, credToolRefs []types.ToolReference) ([]string, error) { // Since credential tools (usually) prompt the user, we want to only run one at a time. r.credMutex.Lock() defer r.credMutex.Unlock() @@ -845,10 +853,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env } } - for _, credToolName := range callCtx.Tool.Credentials { - toolName, credentialAlias, args, err := types.ParseCredentialArgs(credToolName, callCtx.Input) + for _, ref := range credToolRefs { + toolName, credentialAlias, args, err := types.ParseCredentialArgs(ref.Reference, callCtx.Input) if err != nil { - return nil, fmt.Errorf("failed to parse credential tool %q: %w", credToolName, err) + return nil, fmt.Errorf("failed to parse credential tool %q: %w", ref.Reference, err) } credName := toolName @@ -895,11 +903,6 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env // If the credential doesn't already exist in the store, run the credential tool in order to get the value, // and save it in the store. if !exists || c.IsExpired() { - credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName] - if !ok || len(credToolRefs) != 1 { - return nil, fmt.Errorf("failed to find ID for tool %s", credToolName) - } - // If the existing credential is expired, we need to provide it to the cred tool through the environment. if exists && c.IsExpired() { credJSON, err := json.Marshal(c) @@ -914,22 +917,22 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env if args != nil { inputBytes, err := json.Marshal(args) if err != nil { - return nil, fmt.Errorf("failed to marshal args for tool %s: %w", credToolName, err) + return nil, fmt.Errorf("failed to marshal args for tool %s: %w", ref.Reference, err) } input = string(inputBytes) } - res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, credToolRefs[0].ToolID, input, "", engine.CredentialToolCategory) + res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, ref.ToolID, input, "", engine.CredentialToolCategory) if err != nil { - return nil, fmt.Errorf("failed to run credential tool %s: %w", credToolName, err) + return nil, fmt.Errorf("failed to run credential tool %s: %w", ref.Reference, err) } if res.Result == nil { - return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", credToolName) + return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", ref.Reference) } if err := json.Unmarshal([]byte(*res.Result), &c); err != nil { - return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", credToolName, err) + return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err) } c.ToolName = credName c.Type = credentials.CredentialTypeTool @@ -943,7 +946,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env } // Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty. - if (isGitHubTool(toolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil) || credentialAlias != "" { + if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" { if isEmpty { log.Warnf("Not saving empty credential for tool %s", toolName) } else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil { diff --git a/pkg/types/tool.go b/pkg/types/tool.go index e4b3424a..ad483984 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -139,6 +139,7 @@ type Parameters struct { Export []string `json:"export,omitempty"` Agents []string `json:"agents,omitempty"` Credentials []string `json:"credentials,omitempty"` + ExportCredentials []string `json:"exportCredentials,omitempty"` InputFilters []string `json:"inputFilters,omitempty"` ExportInputFilters []string `json:"exportInputFilters,omitempty"` OutputFilters []string `json:"outputFilters,omitempty"` @@ -154,6 +155,7 @@ func (p Parameters) ToolRefNames() []string { p.ExportContext, p.Context, p.Credentials, + p.ExportCredentials, p.InputFilters, p.ExportInputFilters, p.OutputFilters, @@ -466,6 +468,11 @@ func (t ToolDef) String() string { _, _ = fmt.Fprintf(buf, "Credential: %s\n", cred) } } + if len(t.Parameters.ExportCredentials) > 0 { + for _, exportCred := range t.Parameters.ExportCredentials { + _, _ = fmt.Fprintf(buf, "Share Credential: %s\n", exportCred) + } + } if t.Parameters.Chat { _, _ = fmt.Fprintf(buf, "Chat: true\n") } @@ -675,6 +682,23 @@ func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference) ([] return result.List() } +func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]ToolReference, error) { + result := toolRefSet{} + + result.AddAll(t.GetToolRefsFromNames(t.Credentials)) + + toolRefs, err := t.getCompletionToolRefs(prg, agentGroup) + if err != nil { + return nil, err + } + for _, toolRef := range toolRefs { + referencedTool := prg.ToolSet[toolRef.ToolID] + result.AddAll(referencedTool.GetToolRefsFromNames(referencedTool.ExportCredentials)) + } + + return result.List() +} + func toolRefsToCompletionTools(completionTools []ToolReference, prg Program) (result []CompletionTool) { toolNames := map[string]struct{}{}