Skip to content

Commit

Permalink
Merge pull request #72 from strongdm/improve-batch-error-handling
Browse files Browse the repository at this point in the history
Improve batch error handling
  • Loading branch information
philhassey authored Dec 3, 2024
2 parents 224fafe + 1753724 commit 2401ad8
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 21 deletions.
3 changes: 2 additions & 1 deletion corpus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,10 @@ func TestCorpus(t *testing.T) {
"resource": []cedar.Value{resource},
"context": []cedar.Value{context},
},
}, func(r batch.Result) {
}, func(r batch.Result) error {
res = r
total++
return nil
})
testutil.OK(t, err)
testutil.Equals(t, total, 1)
Expand Down
9 changes: 5 additions & 4 deletions x/exp/batch/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package batch

import (
"context"
"errors"
"fmt"
"maps"
"slices"
Expand Down Expand Up @@ -55,7 +56,7 @@ type Result struct {

// Callback is a function that is called for each single batch authorization with
// a Result.
type Callback func(Result)
type Callback func(Result) error

type idEvaler struct {
Policy *ast.Policy
Expand Down Expand Up @@ -107,6 +108,7 @@ var errInvalidPart = fmt.Errorf("invalid part")
// - It will error in case any of PARC are an incorrect type at authorization.
// - It will error in case there are unbound variables.
// - It will error in case there are unused variables.
// - It will error in case of a callback error.
//
// The result passed to the callback must be used / cloned immediately and not modified.
func Authorize(ctx context.Context, ps *cedar.PolicySet, entities types.EntityMap, request Request, cb Callback) error {
Expand Down Expand Up @@ -174,7 +176,7 @@ func Authorize(ctx context.Context, ps *cedar.PolicySet, entities types.EntityMa
fixIgnores(be)
}

return doBatch(ctx, be)
return errors.Join(doBatch(ctx, be), ctx.Err())
}

func doPartial(be *batchEvaler) {
Expand Down Expand Up @@ -284,8 +286,7 @@ func diagnosticAuthzWithCallback(be *batchEvaler) error {
res.Values = be.Values
batchCompile(be)
res.Decision, res.Diagnostic = isAuthorized(be.evalers, be.env)
be.callback(res)
return nil
return be.callback(res)
}

func isAuthorized(ps map[types.PolicyID]*idEvaler, env eval.Env) (types.Decision, types.Diagnostic) {
Expand Down
114 changes: 98 additions & 16 deletions x/exp/batch/batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package batch

import (
"context"
"fmt"
"maps"
"reflect"
"slices"
Expand Down Expand Up @@ -235,10 +236,11 @@ func TestBatch(t *testing.T) {
ps := cedar.NewPolicySet()
ps.Add("0", cedar.NewPolicyFromAST((*publicast.Policy)(tt.policy)))

err := Authorize(context.Background(), ps, tt.entities, tt.request, func(br Result) {
err := Authorize(context.Background(), ps, tt.entities, tt.request, func(br Result) error {
// Need to clone this because it could be mutated in successive authorizations
br.Values = maps.Clone(br.Values)
res = append(res, br)
return nil
})
testutil.OK(t, err)
testutil.Equals(t, len(res), len(tt.results))
Expand All @@ -259,7 +261,7 @@ func TestBatchErrors(t *testing.T) {
t.Parallel()
err := Authorize(context.Background(), cedar.NewPolicySet(), types.EntityMap{}, Request{
Principal: Variable("bananas"),
}, func(_ Result) {},
}, func(_ Result) error { return nil },
)
testutil.ErrorIs(t, err, errUnboundVariable)
})
Expand All @@ -268,7 +270,7 @@ func TestBatchErrors(t *testing.T) {
Variables: Variables{
"bananas": []types.Value{types.String("test")},
},
}, func(_ Result) {},
}, func(_ Result) error { return nil },
)
testutil.ErrorIs(t, err, errUnusedVariable)
})
Expand All @@ -280,7 +282,7 @@ func TestBatchErrors(t *testing.T) {
Variables: Variables{
"bananas": nil,
},
}, func(_ Result) { total++ },
}, func(_ Result) error { total++; return nil },
)
testutil.OK(t, err)
testutil.Equals(t, total, 0)
Expand All @@ -293,7 +295,7 @@ func TestBatchErrors(t *testing.T) {
Action: types.NewEntityUID("Action", "action"),
Resource: types.NewEntityUID("Resource", "resource"),
Context: types.Record{},
}, func(_ Result) {},
}, func(_ Result) error { return nil },
)
testutil.ErrorIs(t, err, errMissingPart)
})
Expand All @@ -303,7 +305,7 @@ func TestBatchErrors(t *testing.T) {
Action: nil,
Resource: types.NewEntityUID("Resource", "resource"),
Context: types.Record{},
}, func(_ Result) {},
}, func(_ Result) error { return nil },
)
testutil.ErrorIs(t, err, errMissingPart)
})
Expand All @@ -313,7 +315,7 @@ func TestBatchErrors(t *testing.T) {
Action: types.NewEntityUID("Action", "action"),
Resource: nil,
Context: types.Record{},
}, func(_ Result) {},
}, func(_ Result) error { return nil },
)
testutil.ErrorIs(t, err, errMissingPart)
})
Expand All @@ -323,7 +325,7 @@ func TestBatchErrors(t *testing.T) {
Action: types.NewEntityUID("Action", "action"),
Resource: types.NewEntityUID("Resource", "resource"),
Context: nil,
}, func(_ Result) {},
}, func(_ Result) error { return nil },
)
testutil.ErrorIs(t, err, errMissingPart)
})
Expand All @@ -336,12 +338,12 @@ func TestBatchErrors(t *testing.T) {
Action: types.NewEntityUID("Action", "action"),
Resource: types.NewEntityUID("Resource", "resource"),
Context: types.Record{},
}, func(_ Result) {},
}, func(_ Result) error { return nil },
)
testutil.ErrorIs(t, err, context.Canceled)
})

t.Run("lateContextCancelled", func(t *testing.T) {
t.Run("firstContextCancelled", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
var total int
err := Authorize(ctx, cedar.NewPolicySet(), types.EntityMap{}, Request{
Expand All @@ -356,22 +358,101 @@ func TestBatchErrors(t *testing.T) {
types.NewEntityUID("Resource", "3"),
},
},
}, func(_ Result) {
}, func(_ Result) error {
total++
cancel()
return nil
},
)
testutil.ErrorIs(t, err, context.Canceled)
testutil.Equals(t, total, 1)
})

t.Run("lastContextCancelled", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
var total int
err := Authorize(ctx, cedar.NewPolicySet(), types.EntityMap{}, Request{
Principal: types.NewEntityUID("Principal", "principal"),
Action: types.NewEntityUID("Action", "action"),
Resource: Variable("resource"),
Context: types.Record{},
Variables: Variables{
"resource": []types.Value{
types.NewEntityUID("Resource", "1"),
types.NewEntityUID("Resource", "2"),
types.NewEntityUID("Resource", "3"),
},
},
}, func(_ Result) error {
total++
if total == 3 {
cancel()
}
return nil
},
)
testutil.ErrorIs(t, err, context.Canceled)
testutil.Equals(t, total, 3)
})

t.Run("callbackErrored", func(t *testing.T) {
var total int
errWant := fmt.Errorf("errWant")
err := Authorize(context.Background(), cedar.NewPolicySet(), types.EntityMap{}, Request{
Principal: types.NewEntityUID("Principal", "principal"),
Action: types.NewEntityUID("Action", "action"),
Resource: Variable("resource"),
Context: types.Record{},
Variables: Variables{
"resource": []types.Value{
types.NewEntityUID("Resource", "1"),
types.NewEntityUID("Resource", "2"),
types.NewEntityUID("Resource", "3"),
},
},
}, func(_ Result) error {
total++
if total == 2 {
return errWant
}
return nil
},
)
testutil.ErrorIs(t, err, errWant)
testutil.Equals(t, total, 2)
})

t.Run("contextAndCallbackErrored", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
errWant := fmt.Errorf("errWant")
err := Authorize(ctx, cedar.NewPolicySet(), types.EntityMap{}, Request{
Principal: types.NewEntityUID("Principal", "principal"),
Action: types.NewEntityUID("Action", "action"),
Resource: Variable("resource"),
Context: types.Record{},
Variables: Variables{
"resource": []types.Value{
types.NewEntityUID("Resource", "1"),
types.NewEntityUID("Resource", "2"),
types.NewEntityUID("Resource", "3"),
},
},
}, func(_ Result) error {
cancel()
return errWant
},
)
testutil.ErrorIs(t, err, context.Canceled)
testutil.ErrorIs(t, err, errWant)
})

t.Run("invalidPrincipal", func(t *testing.T) {
err := Authorize(context.Background(), cedar.NewPolicySet(), types.EntityMap{}, Request{
Principal: types.String("invalid"),
Action: types.NewEntityUID("Action", "action"),
Resource: types.NewEntityUID("Resource", "resource"),
Context: types.Record{},
}, func(_ Result) {},
}, func(_ Result) error { return nil },
)
testutil.ErrorIs(t, err, errInvalidPart)
})
Expand All @@ -381,7 +462,7 @@ func TestBatchErrors(t *testing.T) {
Action: types.String("invalid"),
Resource: types.NewEntityUID("Resource", "resource"),
Context: types.Record{},
}, func(_ Result) {},
}, func(_ Result) error { return nil },
)
testutil.ErrorIs(t, err, errInvalidPart)
})
Expand All @@ -391,7 +472,7 @@ func TestBatchErrors(t *testing.T) {
Action: types.NewEntityUID("Action", "action"),
Resource: types.String("invalid"),
Context: types.Record{},
}, func(_ Result) {},
}, func(_ Result) error { return nil },
)
testutil.ErrorIs(t, err, errInvalidPart)
})
Expand All @@ -401,7 +482,7 @@ func TestBatchErrors(t *testing.T) {
Action: types.NewEntityUID("Action", "action"),
Resource: types.NewEntityUID("Resource", "resource"),
Context: types.String("invalid"),
}, func(_ Result) {},
}, func(_ Result) error { return nil },
)
testutil.ErrorIs(t, err, errInvalidPart)
})
Expand Down Expand Up @@ -596,14 +677,15 @@ func TestIgnoreReasons(t *testing.T) {

var reasons []types.PolicyID
var total int
err := Authorize(context.Background(), ps, types.EntityMap{}, tt.Request, func(r Result) {
err := Authorize(context.Background(), ps, types.EntityMap{}, tt.Request, func(r Result) error {
total++
testutil.Equals(t, r.Decision, tt.Decision)
for _, v := range r.Diagnostic.Reasons {
if !slices.Contains(reasons, v.PolicyID) {
reasons = append(reasons, v.PolicyID)
}
}
return nil
})
testutil.OK(t, err)
testutil.Equals(t, total, tt.Total)
Expand Down

0 comments on commit 2401ad8

Please sign in to comment.