diff --git a/corpus_test.go b/corpus_test.go index 7067097..978b639 100644 --- a/corpus_test.go +++ b/corpus_test.go @@ -208,9 +208,10 @@ func TestCorpus(t *testing.T) { "resource": []cedar.Value{resource}, "context": []cedar.Value{context}, }, - }, func(r batch.Result) { + }, func(r batch.Result) error { res = r total++ + return nil }) testutil.OK(t, err) testutil.Equals(t, total, 1) diff --git a/x/exp/batch/batch.go b/x/exp/batch/batch.go index 9ae0216..023e725 100644 --- a/x/exp/batch/batch.go +++ b/x/exp/batch/batch.go @@ -9,6 +9,7 @@ package batch import ( "context" + "errors" "fmt" "maps" "slices" @@ -55,7 +56,7 @@ type Result struct { // Callback is a function that is called for each single batch authorization with // a Result. -type Callback func(Result) +type Callback func(Result) error type idEvaler struct { Policy *ast.Policy @@ -107,6 +108,7 @@ var errInvalidPart = fmt.Errorf("invalid part") // - It will error in case any of PARC are an incorrect type at authorization. // - It will error in case there are unbound variables. // - It will error in case there are unused variables. +// - It will error in case of a callback error. // // The result passed to the callback must be used / cloned immediately and not modified. func Authorize(ctx context.Context, ps *cedar.PolicySet, entities types.EntityMap, request Request, cb Callback) error { @@ -174,7 +176,7 @@ func Authorize(ctx context.Context, ps *cedar.PolicySet, entities types.EntityMa fixIgnores(be) } - return doBatch(ctx, be) + return errors.Join(doBatch(ctx, be), ctx.Err()) } func doPartial(be *batchEvaler) { @@ -284,8 +286,7 @@ func diagnosticAuthzWithCallback(be *batchEvaler) error { res.Values = be.Values batchCompile(be) res.Decision, res.Diagnostic = isAuthorized(be.evalers, be.env) - be.callback(res) - return nil + return be.callback(res) } func isAuthorized(ps map[types.PolicyID]*idEvaler, env eval.Env) (types.Decision, types.Diagnostic) { diff --git a/x/exp/batch/batch_test.go b/x/exp/batch/batch_test.go index 47adbb2..fd352cf 100644 --- a/x/exp/batch/batch_test.go +++ b/x/exp/batch/batch_test.go @@ -2,6 +2,7 @@ package batch import ( "context" + "fmt" "maps" "reflect" "slices" @@ -235,10 +236,11 @@ func TestBatch(t *testing.T) { ps := cedar.NewPolicySet() ps.Add("0", cedar.NewPolicyFromAST((*publicast.Policy)(tt.policy))) - err := Authorize(context.Background(), ps, tt.entities, tt.request, func(br Result) { + err := Authorize(context.Background(), ps, tt.entities, tt.request, func(br Result) error { // Need to clone this because it could be mutated in successive authorizations br.Values = maps.Clone(br.Values) res = append(res, br) + return nil }) testutil.OK(t, err) testutil.Equals(t, len(res), len(tt.results)) @@ -259,7 +261,7 @@ func TestBatchErrors(t *testing.T) { t.Parallel() err := Authorize(context.Background(), cedar.NewPolicySet(), types.EntityMap{}, Request{ Principal: Variable("bananas"), - }, func(_ Result) {}, + }, func(_ Result) error { return nil }, ) testutil.ErrorIs(t, err, errUnboundVariable) }) @@ -268,7 +270,7 @@ func TestBatchErrors(t *testing.T) { Variables: Variables{ "bananas": []types.Value{types.String("test")}, }, - }, func(_ Result) {}, + }, func(_ Result) error { return nil }, ) testutil.ErrorIs(t, err, errUnusedVariable) }) @@ -280,7 +282,7 @@ func TestBatchErrors(t *testing.T) { Variables: Variables{ "bananas": nil, }, - }, func(_ Result) { total++ }, + }, func(_ Result) error { total++; return nil }, ) testutil.OK(t, err) testutil.Equals(t, total, 0) @@ -293,7 +295,7 @@ func TestBatchErrors(t *testing.T) { Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("Resource", "resource"), Context: types.Record{}, - }, func(_ Result) {}, + }, func(_ Result) error { return nil }, ) testutil.ErrorIs(t, err, errMissingPart) }) @@ -303,7 +305,7 @@ func TestBatchErrors(t *testing.T) { Action: nil, Resource: types.NewEntityUID("Resource", "resource"), Context: types.Record{}, - }, func(_ Result) {}, + }, func(_ Result) error { return nil }, ) testutil.ErrorIs(t, err, errMissingPart) }) @@ -313,7 +315,7 @@ func TestBatchErrors(t *testing.T) { Action: types.NewEntityUID("Action", "action"), Resource: nil, Context: types.Record{}, - }, func(_ Result) {}, + }, func(_ Result) error { return nil }, ) testutil.ErrorIs(t, err, errMissingPart) }) @@ -323,7 +325,7 @@ func TestBatchErrors(t *testing.T) { Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("Resource", "resource"), Context: nil, - }, func(_ Result) {}, + }, func(_ Result) error { return nil }, ) testutil.ErrorIs(t, err, errMissingPart) }) @@ -336,12 +338,12 @@ func TestBatchErrors(t *testing.T) { Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("Resource", "resource"), Context: types.Record{}, - }, func(_ Result) {}, + }, func(_ Result) error { return nil }, ) testutil.ErrorIs(t, err, context.Canceled) }) - t.Run("lateContextCancelled", func(t *testing.T) { + t.Run("firstContextCancelled", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) var total int err := Authorize(ctx, cedar.NewPolicySet(), types.EntityMap{}, Request{ @@ -356,22 +358,101 @@ func TestBatchErrors(t *testing.T) { types.NewEntityUID("Resource", "3"), }, }, - }, func(_ Result) { + }, func(_ Result) error { total++ cancel() + return nil }, ) testutil.ErrorIs(t, err, context.Canceled) testutil.Equals(t, total, 1) }) + t.Run("lastContextCancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var total int + err := Authorize(ctx, cedar.NewPolicySet(), types.EntityMap{}, Request{ + Principal: types.NewEntityUID("Principal", "principal"), + Action: types.NewEntityUID("Action", "action"), + Resource: Variable("resource"), + Context: types.Record{}, + Variables: Variables{ + "resource": []types.Value{ + types.NewEntityUID("Resource", "1"), + types.NewEntityUID("Resource", "2"), + types.NewEntityUID("Resource", "3"), + }, + }, + }, func(_ Result) error { + total++ + if total == 3 { + cancel() + } + return nil + }, + ) + testutil.ErrorIs(t, err, context.Canceled) + testutil.Equals(t, total, 3) + }) + + t.Run("callbackErrored", func(t *testing.T) { + var total int + errWant := fmt.Errorf("errWant") + err := Authorize(context.Background(), cedar.NewPolicySet(), types.EntityMap{}, Request{ + Principal: types.NewEntityUID("Principal", "principal"), + Action: types.NewEntityUID("Action", "action"), + Resource: Variable("resource"), + Context: types.Record{}, + Variables: Variables{ + "resource": []types.Value{ + types.NewEntityUID("Resource", "1"), + types.NewEntityUID("Resource", "2"), + types.NewEntityUID("Resource", "3"), + }, + }, + }, func(_ Result) error { + total++ + if total == 2 { + return errWant + } + return nil + }, + ) + testutil.ErrorIs(t, err, errWant) + testutil.Equals(t, total, 2) + }) + + t.Run("contextAndCallbackErrored", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + errWant := fmt.Errorf("errWant") + err := Authorize(ctx, cedar.NewPolicySet(), types.EntityMap{}, Request{ + Principal: types.NewEntityUID("Principal", "principal"), + Action: types.NewEntityUID("Action", "action"), + Resource: Variable("resource"), + Context: types.Record{}, + Variables: Variables{ + "resource": []types.Value{ + types.NewEntityUID("Resource", "1"), + types.NewEntityUID("Resource", "2"), + types.NewEntityUID("Resource", "3"), + }, + }, + }, func(_ Result) error { + cancel() + return errWant + }, + ) + testutil.ErrorIs(t, err, context.Canceled) + testutil.ErrorIs(t, err, errWant) + }) + t.Run("invalidPrincipal", func(t *testing.T) { err := Authorize(context.Background(), cedar.NewPolicySet(), types.EntityMap{}, Request{ Principal: types.String("invalid"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("Resource", "resource"), Context: types.Record{}, - }, func(_ Result) {}, + }, func(_ Result) error { return nil }, ) testutil.ErrorIs(t, err, errInvalidPart) }) @@ -381,7 +462,7 @@ func TestBatchErrors(t *testing.T) { Action: types.String("invalid"), Resource: types.NewEntityUID("Resource", "resource"), Context: types.Record{}, - }, func(_ Result) {}, + }, func(_ Result) error { return nil }, ) testutil.ErrorIs(t, err, errInvalidPart) }) @@ -391,7 +472,7 @@ func TestBatchErrors(t *testing.T) { Action: types.NewEntityUID("Action", "action"), Resource: types.String("invalid"), Context: types.Record{}, - }, func(_ Result) {}, + }, func(_ Result) error { return nil }, ) testutil.ErrorIs(t, err, errInvalidPart) }) @@ -401,7 +482,7 @@ func TestBatchErrors(t *testing.T) { Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("Resource", "resource"), Context: types.String("invalid"), - }, func(_ Result) {}, + }, func(_ Result) error { return nil }, ) testutil.ErrorIs(t, err, errInvalidPart) }) @@ -596,7 +677,7 @@ func TestIgnoreReasons(t *testing.T) { var reasons []types.PolicyID var total int - err := Authorize(context.Background(), ps, types.EntityMap{}, tt.Request, func(r Result) { + err := Authorize(context.Background(), ps, types.EntityMap{}, tt.Request, func(r Result) error { total++ testutil.Equals(t, r.Decision, tt.Decision) for _, v := range r.Diagnostic.Reasons { @@ -604,6 +685,7 @@ func TestIgnoreReasons(t *testing.T) { reasons = append(reasons, v.PolicyID) } } + return nil }) testutil.OK(t, err) testutil.Equals(t, total, tt.Total)