Skip to content

Commit

Permalink
enhance: share credential (#634)
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville authored Jul 17, 2024
1 parent a5a0538 commit 3055632
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 19 deletions.
16 changes: 16 additions & 0 deletions integration/cred_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,19 @@ func TestGPTScriptCredential(t *testing.T) {
require.NoError(t, err)
require.Contains(t, out, "CREDENTIAL")
}

// TestCredentialScopes makes sure that environment variables set by credential tools and shared credential tools
// are only available to the correct tools. See scripts/credscopes.gpt for more details.
func TestCredentialScopes(t *testing.T) {
out, err := RunScript("scripts/credscopes.gpt", "--sub-tool", "oneOne")
require.NoError(t, err)
require.Contains(t, out, "good")

out, err = RunScript("scripts/credscopes.gpt", "--sub-tool", "twoOne")
require.NoError(t, err)
require.Contains(t, out, "good")

out, err = RunScript("scripts/credscopes.gpt", "--sub-tool", "twoTwo")
require.NoError(t, err)
require.Contains(t, out, "good")
}
4 changes: 4 additions & 0 deletions integration/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ func GPTScriptExec(args ...string) (string, error) {
out, err := cmd.CombinedOutput()
return string(out), err
}

func RunScript(script string, options ...string) (string, error) {
return GPTScriptExec(append(options, "--quiet", script)...)
}
160 changes: 160 additions & 0 deletions integration/scripts/credscopes.gpt
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# This script sets up a chain of tools in a tree structure.
# The root is oneOne, with children twoOne and twoTwo, with children threeOne, threeTwo, and threeThree, with only
# threeTwo shared between them.
# Each tool should only have access to any credentials it defines and any credentials exported/shared by its
# immediate children (but not grandchildren).
# This script checks to make sure that this is working properly.
name: oneOne
tools: twoOne, twoTwo
cred: getcred with oneOne as var and 11 as val

#!python3

import os

oneOne = os.getenv('oneOne')
twoOne = os.getenv('twoOne')
twoTwo = os.getenv('twoTwo')
threeOne = os.getenv('threeOne')
threeTwo = os.getenv('threeTwo')
threeThree = os.getenv('threeThree')

if oneOne != '11':
print('error: oneOne is not 11')
exit(1)

if twoOne != '21':
print('error: twoOne is not 21')
exit(1)

if twoTwo != '22':
print('error: twoTwo is not 22')
exit(1)

if threeOne is not None:
print('error: threeOne is not None')
exit(1)

if threeTwo is not None:
print('error: threeTwo is not None')
exit(1)

if threeThree is not None:
print('error: threeThree is not None')
exit(1)

print('good')

---
name: twoOne
tools: threeOne, threeTwo
sharecred: getcred with twoOne as var and 21 as val

#!python3

import os

oneOne = os.getenv('oneOne')
twoOne = os.getenv('twoOne')
twoTwo = os.getenv('twoTwo')
threeOne = os.getenv('threeOne')
threeTwo = os.getenv('threeTwo')
threeThree = os.getenv('threeThree')

if oneOne is not None:
print('error: oneOne is not None')
exit(1)

if twoOne is not None:
print('error: twoOne is not None')
exit(1)

if twoTwo is not None:
print('error: twoTwo is not None')
exit(1)

if threeOne != '31':
print('error: threeOne is not 31')
exit(1)

if threeTwo != '32':
print('error: threeTwo is not 32')
exit(1)

if threeThree is not None:
print('error: threeThree is not None')
exit(1)

print('good')

---
name: twoTwo
tools: threeTwo, threeThree
sharecred: getcred with twoTwo as var and 22 as val

#!python3

import os

oneOne = os.getenv('oneOne')
twoOne = os.getenv('twoOne')
twoTwo = os.getenv('twoTwo')
threeOne = os.getenv('threeOne')
threeTwo = os.getenv('threeTwo')
threeThree = os.getenv('threeThree')

if oneOne is not None:
print('error: oneOne is not None')
exit(1)

if twoOne is not None:
print('error: twoOne is not None')
exit(1)

if twoTwo is not None:
print('error: twoTwo is not None')
exit(1)

if threeOne is not None:
print('error: threeOne is not None')
exit(1)

if threeTwo != '32':
print('error: threeTwo is not 32')
exit(1)

if threeThree != '33':
print('error: threeThree is not 33')
exit(1)

print('good')

---
name: threeOne
sharecred: getcred with threeOne as var and 31 as val

---
name: threeTwo
sharecred: getcred with threeTwo as var and 32 as val

---
name: threeThree
sharecred: getcred with threeThree as var and 33 as val

---
name: getcred

#!python3

import os
import json

var = os.getenv('var')
val = os.getenv('val')

output = {
"env": {
var: val
}
}
print(json.dumps(output))
2 changes: 2 additions & 0 deletions pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
}
case "credentials", "creds", "credential", "cred":
tool.Parameters.Credentials = append(tool.Parameters.Credentials, value)
case "sharecredentials", "sharecreds", "sharecredential", "sharecred":
tool.Parameters.ExportCredentials = append(tool.Parameters.ExportCredentials, value)
default:
return false, nil
}
Expand Down
41 changes: 22 additions & 19 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,13 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
return nil, err
}

if len(callCtx.Tool.Credentials) > 0 {
credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup)
if err != nil {
return nil, err
}
if len(credTools) > 0 {
var err error
env, err = r.handleCredentials(callCtx, monitor, env)
env, err = r.handleCredentials(callCtx, monitor, env, credTools)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -552,9 +556,13 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
progress, progressClose := streamProgress(&callCtx, monitor)
defer progressClose()

if len(callCtx.Tool.Credentials) > 0 {
credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup)
if err != nil {
return nil, err
}
if len(credTools) > 0 {
var err error
env, err = r.handleCredentials(callCtx, monitor, env)
env, err = r.handleCredentials(callCtx, monitor, env, credTools)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -828,7 +836,7 @@ func getEventContent(content string, callCtx engine.Context) string {
return content
}

func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string) ([]string, error) {
func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string, credToolRefs []types.ToolReference) ([]string, error) {
// Since credential tools (usually) prompt the user, we want to only run one at a time.
r.credMutex.Lock()
defer r.credMutex.Unlock()
Expand All @@ -845,10 +853,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
}
}

for _, credToolName := range callCtx.Tool.Credentials {
toolName, credentialAlias, args, err := types.ParseCredentialArgs(credToolName, callCtx.Input)
for _, ref := range credToolRefs {
toolName, credentialAlias, args, err := types.ParseCredentialArgs(ref.Reference, callCtx.Input)
if err != nil {
return nil, fmt.Errorf("failed to parse credential tool %q: %w", credToolName, err)
return nil, fmt.Errorf("failed to parse credential tool %q: %w", ref.Reference, err)
}

credName := toolName
Expand Down Expand Up @@ -895,11 +903,6 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
// If the credential doesn't already exist in the store, run the credential tool in order to get the value,
// and save it in the store.
if !exists || c.IsExpired() {
credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName]
if !ok || len(credToolRefs) != 1 {
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
}

// If the existing credential is expired, we need to provide it to the cred tool through the environment.
if exists && c.IsExpired() {
credJSON, err := json.Marshal(c)
Expand All @@ -914,22 +917,22 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
if args != nil {
inputBytes, err := json.Marshal(args)
if err != nil {
return nil, fmt.Errorf("failed to marshal args for tool %s: %w", credToolName, err)
return nil, fmt.Errorf("failed to marshal args for tool %s: %w", ref.Reference, err)
}
input = string(inputBytes)
}

res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, credToolRefs[0].ToolID, input, "", engine.CredentialToolCategory)
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, ref.ToolID, input, "", engine.CredentialToolCategory)
if err != nil {
return nil, fmt.Errorf("failed to run credential tool %s: %w", credToolName, err)
return nil, fmt.Errorf("failed to run credential tool %s: %w", ref.Reference, err)
}

if res.Result == nil {
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", credToolName)
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", ref.Reference)
}

if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", credToolName, err)
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err)
}
c.ToolName = credName
c.Type = credentials.CredentialTypeTool
Expand All @@ -943,7 +946,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
}

// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil) || credentialAlias != "" {
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
if isEmpty {
log.Warnf("Not saving empty credential for tool %s", toolName)
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
Expand Down
24 changes: 24 additions & 0 deletions pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ type Parameters struct {
Export []string `json:"export,omitempty"`
Agents []string `json:"agents,omitempty"`
Credentials []string `json:"credentials,omitempty"`
ExportCredentials []string `json:"exportCredentials,omitempty"`
InputFilters []string `json:"inputFilters,omitempty"`
ExportInputFilters []string `json:"exportInputFilters,omitempty"`
OutputFilters []string `json:"outputFilters,omitempty"`
Expand All @@ -154,6 +155,7 @@ func (p Parameters) ToolRefNames() []string {
p.ExportContext,
p.Context,
p.Credentials,
p.ExportCredentials,
p.InputFilters,
p.ExportInputFilters,
p.OutputFilters,
Expand Down Expand Up @@ -466,6 +468,11 @@ func (t ToolDef) String() string {
_, _ = fmt.Fprintf(buf, "Credential: %s\n", cred)
}
}
if len(t.Parameters.ExportCredentials) > 0 {
for _, exportCred := range t.Parameters.ExportCredentials {
_, _ = fmt.Fprintf(buf, "Share Credential: %s\n", exportCred)
}
}
if t.Parameters.Chat {
_, _ = fmt.Fprintf(buf, "Chat: true\n")
}
Expand Down Expand Up @@ -675,6 +682,23 @@ func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference) ([]
return result.List()
}

func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]ToolReference, error) {
result := toolRefSet{}

result.AddAll(t.GetToolRefsFromNames(t.Credentials))

toolRefs, err := t.getCompletionToolRefs(prg, agentGroup)
if err != nil {
return nil, err
}
for _, toolRef := range toolRefs {
referencedTool := prg.ToolSet[toolRef.ToolID]
result.AddAll(referencedTool.GetToolRefsFromNames(referencedTool.ExportCredentials))
}

return result.List()
}

func toolRefsToCompletionTools(completionTools []ToolReference, prg Program) (result []CompletionTool) {
toolNames := map[string]struct{}{}

Expand Down

0 comments on commit 3055632

Please sign in to comment.