Skip to content

Commit

Permalink
x/exp/batch: use a MapSet[types.String] instead of map[types.String]{…
Browse files Browse the repository at this point in the history
…} for detecting unbound or unused variables

Signed-off-by: Patrick Jakubowski <[email protected]>
  • Loading branch information
patjakdev committed Sep 23, 2024
1 parent 14b20af commit 4cf812d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
27 changes: 17 additions & 10 deletions x/exp/batch/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/cedar-policy/cedar-go/internal/ast"
"github.com/cedar-policy/cedar-go/internal/consts"
"github.com/cedar-policy/cedar-go/internal/eval"
"github.com/cedar-policy/cedar-go/internal/sets"
"github.com/cedar-policy/cedar-go/types"
)

Expand Down Expand Up @@ -103,18 +104,24 @@ var errInvalidPart = fmt.Errorf("invalid part")
// The result passed to the callback must be used / cloned immediately and not modified.
func Authorize(ctx context.Context, ps *cedar.PolicySet, entityMap types.Entities, request Request, cb Callback) error {
be := &batchEvaler{}
found := map[types.String]struct{}{}
findVariables(found, request.Principal)
findVariables(found, request.Action)
findVariables(found, request.Resource)
findVariables(found, request.Context)
for key := range found {
var found sets.MapSet[types.String]
findVariables(&found, request.Principal)
findVariables(&found, request.Action)
findVariables(&found, request.Resource)
findVariables(&found, request.Context)
var err error
found.Iterate(func(key types.String) bool {
if _, ok := request.Variables[key]; !ok {
return fmt.Errorf("%w: %v", errUnboundVariable, key)
err = fmt.Errorf("%w: %v", errUnboundVariable, key)
return false
}
return true
})
if err != nil {
return err
}
for k := range request.Variables {
if _, ok := found[k]; !ok {
if !found.Contains(k) {
return fmt.Errorf("%w: %v", errUnusedVariable, k)
}
}
Expand Down Expand Up @@ -375,11 +382,11 @@ func cloneSub(r types.Value, k types.String, v types.Value) (types.Value, bool)
return r, false
}

func findVariables(found map[types.String]struct{}, r types.Value) {
func findVariables(found *sets.MapSet[types.String], r types.Value) {
switch t := r.(type) {
case types.EntityUID:
if key, ok := eval.ToVariable(t); ok {
found[key] = struct{}{}
found.Add(key)
}
case types.Record:
t.Iterate(func(_ types.String, vv types.Value) bool {
Expand Down
19 changes: 10 additions & 9 deletions x/exp/batch/batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/cedar-policy/cedar-go"
publicast "github.com/cedar-policy/cedar-go/ast"
"github.com/cedar-policy/cedar-go/internal/ast"
"github.com/cedar-policy/cedar-go/internal/sets"
"github.com/cedar-policy/cedar-go/internal/testutil"
"github.com/cedar-policy/cedar-go/types"
)
Expand Down Expand Up @@ -663,22 +664,22 @@ func TestFindVariables(t *testing.T) {
tests := []struct {
name string
in types.Value
out map[types.String]struct{}
out []types.String
}{
{"record", types.NewRecord(types.RecordMap{"key": Variable("bananas")}), map[types.String]struct{}{"bananas": {}}},
{"set", types.NewSet([]types.Value{Variable("bananas")}), map[types.String]struct{}{"bananas": {}}},
{"dupes", types.NewSet([]types.Value{Variable("bananas"), Variable("bananas")}), map[types.String]struct{}{"bananas": {}}},
{"none", types.String("test"), map[types.String]struct{}{}},
{"multi", types.NewSet([]types.Value{Variable("bananas"), Variable("test")}), map[types.String]struct{}{"bananas": {}, "test": {}}},
{"record", types.NewRecord(types.RecordMap{"key": Variable("bananas")}), []types.String{"bananas"}},
{"set", types.NewSet([]types.Value{Variable("bananas")}), []types.String{"bananas"}},
{"dupes", types.NewSet([]types.Value{Variable("bananas"), Variable("bananas")}), []types.String{"bananas"}},
{"none", types.String("test"), nil},
{"multi", types.NewSet([]types.Value{Variable("bananas"), Variable("test")}), []types.String{"bananas", "test"}},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
out := map[types.String]struct{}{}
findVariables(out, tt.in)
testutil.Equals(t, out, tt.out)
var out sets.MapSet[types.String]
findVariables(&out, tt.in)
testutil.Equals(t, out, sets.NewMapSetFromSlice(tt.out))
})
}

Expand Down

0 comments on commit 4cf812d

Please sign in to comment.