From 8a95a232d5d206913a5898982f9f44179bd3f49f Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 1 Oct 2024 10:28:50 -0700 Subject: [PATCH] internal/eval: remove the inCache We don't currently have a benchmark that's telling us that we need this cache. In fact, several of our benchmarks show that for large, shallow entity graphs, the inCache actually slows down authorizations and batch evaluations because of all the extra allocation that goes on to build the map. For deep entity graphs, such a cache might be useful, but I think instead that we'll try to put such a cache in the entity graph itself by keeping track of the transitive closure of every entity's parents. That way, the cache will be effective across multiple authorizations. Signed-off-by: Patrick Jakubowski --- authorize.go | 6 +- internal/eval/compile.go | 2 +- internal/eval/compile_test.go | 8 +- internal/eval/convert_test.go | 4 +- internal/eval/evalers.go | 161 +++++++++++++------------------- internal/eval/evalers_test.go | 167 +++++++++++++++------------------- internal/eval/fold.go | 2 +- internal/eval/partial.go | 28 +++--- internal/eval/partial_test.go | 86 ++++++++--------- x/exp/batch/batch.go | 12 +-- 10 files changed, 210 insertions(+), 266 deletions(-) diff --git a/authorize.go b/authorize.go index 741b38fe..cdda72ca 100644 --- a/authorize.go +++ b/authorize.go @@ -19,13 +19,13 @@ const ( // IsAuthorized uses the combination of the PolicySet and Entities to determine // if the given Request to determine Decision and Diagnostic. func (p PolicySet) IsAuthorized(entityMap Entities, req Request) (Decision, Diagnostic) { - c := eval.InitEnv(&eval.Env{ + env := eval.Env{ Entities: entityMap, Principal: req.Principal, Action: req.Action, Resource: req.Resource, Context: req.Context, - }) + } var diag Diagnostic var forbids []DiagnosticReason var permits []DiagnosticReason @@ -35,7 +35,7 @@ func (p PolicySet) IsAuthorized(entityMap Entities, req Request) (Decision, Diag // - For permit, all permits must be run to collect annotations // - For forbid, forbids must be run to collect annotations for id, po := range p.policies { - result, err := po.eval.Eval(c) + result, err := po.eval.Eval(env) if err != nil { diag.Errors = append(diag.Errors, DiagnosticError{PolicyID: id, Position: po.Position(), Message: err.Error()}) continue diff --git a/internal/eval/compile.go b/internal/eval/compile.go index 35da2dfd..0b217d58 100644 --- a/internal/eval/compile.go +++ b/internal/eval/compile.go @@ -11,7 +11,7 @@ type BoolEvaler struct { eval Evaler } -func (e *BoolEvaler) Eval(env *Env) (types.Boolean, error) { +func (e *BoolEvaler) Eval(env Env) (types.Boolean, error) { v, err := e.eval.Eval(env) if err != nil { return false, err diff --git a/internal/eval/compile_test.go b/internal/eval/compile_test.go index 042423b0..7cea952f 100644 --- a/internal/eval/compile_test.go +++ b/internal/eval/compile_test.go @@ -12,7 +12,7 @@ import ( func TestCompile(t *testing.T) { t.Parallel() e := Compile(ast.Permit()) - res, err := e.Eval(nil) + res, err := e.Eval(Env{}) testutil.OK(t, err) testutil.Equals(t, res, types.True) } @@ -22,7 +22,7 @@ func TestBoolEvaler(t *testing.T) { t.Run("Happy", func(t *testing.T) { t.Parallel() b := BoolEvaler{eval: newLiteralEval(types.True)} - v, err := b.Eval(nil) + v, err := b.Eval(Env{}) testutil.OK(t, err) testutil.Equals(t, v, true) }) @@ -31,7 +31,7 @@ func TestBoolEvaler(t *testing.T) { t.Parallel() errWant := fmt.Errorf("error") b := BoolEvaler{eval: newErrorEval(errWant)} - v, err := b.Eval(nil) + v, err := b.Eval(Env{}) testutil.ErrorIs(t, err, errWant) testutil.Equals(t, v, false) }) @@ -39,7 +39,7 @@ func TestBoolEvaler(t *testing.T) { t.Run("NonBool", func(t *testing.T) { t.Parallel() b := BoolEvaler{eval: newLiteralEval(types.String("bad"))} - v, err := b.Eval(nil) + v, err := b.Eval(Env{}) testutil.ErrorIs(t, err, ErrType) testutil.Equals(t, v, false) }) diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index 85389f23..1febc614 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -350,12 +350,12 @@ func TestToEval(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() e := toEval(tt.in.AsIsNode()) - out, err := e.Eval(InitEnv(&Env{ + out, err := e.Eval(Env{ Principal: types.NewEntityUID("Actor", "principal"), Action: types.NewEntityUID("Action", "test"), Resource: types.NewEntityUID("Resource", "database"), Context: types.Record{}, - })) + }) tt.err(t, err) testutil.Equals(t, out, tt.out) }) diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index bb864aac..5238e32b 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -25,34 +25,13 @@ type Env struct { Entities types.Entities Principal, Action, Resource types.Value Context types.Value - - inCache map[inKey]bool -} - -type inKey struct { - a, b types.EntityUID -} - -func NewEnv() *Env { - return InitEnv(&Env{}) -} - -func InitEnv(in *Env) *Env { - // add caches if applicable - in.inCache = map[inKey]bool{} - return in -} - -func InitEnvWithCacheFrom(in *Env, parent *Env) *Env { - in.inCache = parent.inCache - return in } type Evaler interface { - Eval(*Env) (types.Value, error) + Eval(Env) (types.Value, error) } -func evalBool(n Evaler, env *Env) (types.Boolean, error) { +func evalBool(n Evaler, env Env) (types.Boolean, error) { v, err := n.Eval(env) if err != nil { return false, err @@ -64,7 +43,7 @@ func evalBool(n Evaler, env *Env) (types.Boolean, error) { return b, nil } -func evalLong(n Evaler, env *Env) (types.Long, error) { +func evalLong(n Evaler, env Env) (types.Long, error) { v, err := n.Eval(env) if err != nil { return 0, err @@ -76,7 +55,7 @@ func evalLong(n Evaler, env *Env) (types.Long, error) { return l, nil } -func evalComparableValue(n Evaler, env *Env) (ComparableValue, error) { +func evalComparableValue(n Evaler, env Env) (ComparableValue, error) { v, err := n.Eval(env) if err != nil { return nil, err @@ -88,7 +67,7 @@ func evalComparableValue(n Evaler, env *Env) (ComparableValue, error) { return l, nil } -func evalString(n Evaler, env *Env) (types.String, error) { +func evalString(n Evaler, env Env) (types.String, error) { v, err := n.Eval(env) if err != nil { return "", err @@ -100,7 +79,7 @@ func evalString(n Evaler, env *Env) (types.String, error) { return s, nil } -func evalSet(n Evaler, env *Env) (types.Set, error) { +func evalSet(n Evaler, env Env) (types.Set, error) { v, err := n.Eval(env) if err != nil { return types.Set{}, err @@ -112,7 +91,7 @@ func evalSet(n Evaler, env *Env) (types.Set, error) { return s, nil } -func evalEntity(n Evaler, env *Env) (types.EntityUID, error) { +func evalEntity(n Evaler, env Env) (types.EntityUID, error) { v, err := n.Eval(env) if err != nil { return types.EntityUID{}, err @@ -124,7 +103,7 @@ func evalEntity(n Evaler, env *Env) (types.EntityUID, error) { return e, nil } -func evalDatetime(n Evaler, env *Env) (types.Datetime, error) { +func evalDatetime(n Evaler, env Env) (types.Datetime, error) { v, err := n.Eval(env) if err != nil { return types.Datetime{}, err @@ -136,7 +115,7 @@ func evalDatetime(n Evaler, env *Env) (types.Datetime, error) { return d, nil } -func evalDecimal(n Evaler, env *Env) (types.Decimal, error) { +func evalDecimal(n Evaler, env Env) (types.Decimal, error) { v, err := n.Eval(env) if err != nil { return types.Decimal{}, err @@ -148,7 +127,7 @@ func evalDecimal(n Evaler, env *Env) (types.Decimal, error) { return d, nil } -func evalDuration(n Evaler, env *Env) (types.Duration, error) { +func evalDuration(n Evaler, env Env) (types.Duration, error) { v, err := n.Eval(env) if err != nil { return types.Duration{}, err @@ -160,7 +139,7 @@ func evalDuration(n Evaler, env *Env) (types.Duration, error) { return d, nil } -func evalIP(n Evaler, env *Env) (types.IPAddr, error) { +func evalIP(n Evaler, env Env) (types.IPAddr, error) { v, err := n.Eval(env) if err != nil { @@ -184,7 +163,7 @@ func newErrorEval(err error) *errorEval { } } -func (n *errorEval) Eval(_ *Env) (types.Value, error) { +func (n *errorEval) Eval(Env) (types.Value, error) { return zeroValue(), n.err } @@ -197,7 +176,7 @@ func newLiteralEval(value types.Value) *literalEval { return &literalEval{value: value} } -func (n *literalEval) Eval(_ *Env) (types.Value, error) { +func (n *literalEval) Eval(Env) (types.Value, error) { return n.value, nil } @@ -214,7 +193,7 @@ func newOrEval(lhs Evaler, rhs Evaler) Evaler { } } -func (n *orEval) Eval(env *Env) (types.Value, error) { +func (n *orEval) Eval(env Env) (types.Value, error) { v, err := n.lhs.Eval(env) if err != nil { return zeroValue(), err @@ -250,7 +229,7 @@ func newAndEval(lhs Evaler, rhs Evaler) Evaler { } } -func (n *andEval) Eval(env *Env) (types.Value, error) { +func (n *andEval) Eval(env Env) (types.Value, error) { v, err := n.lhs.Eval(env) if err != nil { return zeroValue(), err @@ -284,7 +263,7 @@ func newNotEval(inner Evaler) Evaler { } } -func (n *notEval) Eval(env *Env) (types.Value, error) { +func (n *notEval) Eval(env Env) (types.Value, error) { v, err := n.inner.Eval(env) if err != nil { return zeroValue(), err @@ -353,7 +332,7 @@ func newAddEval(lhs Evaler, rhs Evaler) Evaler { } } -func (n *addEval) Eval(env *Env) (types.Value, error) { +func (n *addEval) Eval(env Env) (types.Value, error) { lhs, err := evalLong(n.lhs, env) if err != nil { return zeroValue(), err @@ -382,7 +361,7 @@ func newSubtractEval(lhs Evaler, rhs Evaler) Evaler { } } -func (n *subtractEval) Eval(env *Env) (types.Value, error) { +func (n *subtractEval) Eval(env Env) (types.Value, error) { lhs, err := evalLong(n.lhs, env) if err != nil { return zeroValue(), err @@ -411,7 +390,7 @@ func newMultiplyEval(lhs Evaler, rhs Evaler) Evaler { } } -func (n *multiplyEval) Eval(env *Env) (types.Value, error) { +func (n *multiplyEval) Eval(env Env) (types.Value, error) { lhs, err := evalLong(n.lhs, env) if err != nil { return zeroValue(), err @@ -438,7 +417,7 @@ func newNegateEval(inner Evaler) Evaler { } } -func (n *negateEval) Eval(env *Env) (types.Value, error) { +func (n *negateEval) Eval(env Env) (types.Value, error) { inner, err := evalLong(n.inner, env) if err != nil { return zeroValue(), err @@ -463,7 +442,7 @@ func newLongLessThanEval(lhs Evaler, rhs Evaler) Evaler { } } -func (n *longLessThanEval) Eval(env *Env) (types.Value, error) { +func (n *longLessThanEval) Eval(env Env) (types.Value, error) { lhs, err := evalLong(n.lhs, env) if err != nil { return zeroValue(), err @@ -488,7 +467,7 @@ func newLongLessThanOrEqualEval(lhs Evaler, rhs Evaler) Evaler { } } -func (n *longLessThanOrEqualEval) Eval(env *Env) (types.Value, error) { +func (n *longLessThanOrEqualEval) Eval(env Env) (types.Value, error) { lhs, err := evalLong(n.lhs, env) if err != nil { return zeroValue(), err @@ -513,7 +492,7 @@ func newLongGreaterThanEval(lhs Evaler, rhs Evaler) Evaler { } } -func (n *longGreaterThanEval) Eval(env *Env) (types.Value, error) { +func (n *longGreaterThanEval) Eval(env Env) (types.Value, error) { lhs, err := evalLong(n.lhs, env) if err != nil { return zeroValue(), err @@ -538,7 +517,7 @@ func newLongGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) Evaler { } } -func (n *longGreaterThanOrEqualEval) Eval(env *Env) (types.Value, error) { +func (n *longGreaterThanOrEqualEval) Eval(env Env) (types.Value, error) { lhs, err := evalLong(n.lhs, env) if err != nil { return zeroValue(), err @@ -563,7 +542,7 @@ func newDecimalLessThanEval(lhs Evaler, rhs Evaler) *decimalLessThanEval { } } -func (n *decimalLessThanEval) Eval(env *Env) (types.Value, error) { +func (n *decimalLessThanEval) Eval(env Env) (types.Value, error) { lhs, err := evalDecimal(n.lhs, env) if err != nil { return zeroValue(), err @@ -588,7 +567,7 @@ func newDecimalLessThanOrEqualEval(lhs Evaler, rhs Evaler) *decimalLessThanOrEqu } } -func (n *decimalLessThanOrEqualEval) Eval(env *Env) (types.Value, error) { +func (n *decimalLessThanOrEqualEval) Eval(env Env) (types.Value, error) { lhs, err := evalDecimal(n.lhs, env) if err != nil { return zeroValue(), err @@ -613,7 +592,7 @@ func newDecimalGreaterThanEval(lhs Evaler, rhs Evaler) *decimalGreaterThanEval { } } -func (n *decimalGreaterThanEval) Eval(env *Env) (types.Value, error) { +func (n *decimalGreaterThanEval) Eval(env Env) (types.Value, error) { lhs, err := evalDecimal(n.lhs, env) if err != nil { return zeroValue(), err @@ -638,7 +617,7 @@ func newDecimalGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) *decimalGreaterTha } } -func (n *decimalGreaterThanOrEqualEval) Eval(env *Env) (types.Value, error) { +func (n *decimalGreaterThanOrEqualEval) Eval(env Env) (types.Value, error) { lhs, err := evalDecimal(n.lhs, env) if err != nil { return zeroValue(), err @@ -665,7 +644,7 @@ func newIfThenElseEval(if_, then, else_ Evaler) *ifThenElseEval { } } -func (n *ifThenElseEval) Eval(env *Env) (types.Value, error) { +func (n *ifThenElseEval) Eval(env Env) (types.Value, error) { cond, err := evalBool(n.if_, env) if err != nil { return zeroValue(), err @@ -688,7 +667,7 @@ func newEqualEval(lhs, rhs Evaler) Evaler { } } -func (n *equalEval) Eval(env *Env) (types.Value, error) { +func (n *equalEval) Eval(env Env) (types.Value, error) { lv, err := n.lhs.Eval(env) if err != nil { return zeroValue(), err @@ -712,7 +691,7 @@ func newNotEqualEval(lhs, rhs Evaler) Evaler { } } -func (n *notEqualEval) Eval(env *Env) (types.Value, error) { +func (n *notEqualEval) Eval(env Env) (types.Value, error) { lv, err := n.lhs.Eval(env) if err != nil { return zeroValue(), err @@ -733,7 +712,7 @@ func newSetLiteralEval(elements []Evaler) *setLiteralEval { return &setLiteralEval{elements: elements} } -func (n *setLiteralEval) Eval(env *Env) (types.Value, error) { +func (n *setLiteralEval) Eval(env Env) (types.Value, error) { vals := make([]types.Value, len(n.elements)) for i, e := range n.elements { v, err := e.Eval(env) @@ -757,7 +736,7 @@ func newContainsEval(lhs, rhs Evaler) Evaler { } } -func (n *containsEval) Eval(env *Env) (types.Value, error) { +func (n *containsEval) Eval(env Env) (types.Value, error) { lhs, err := evalSet(n.lhs, env) if err != nil { return zeroValue(), err @@ -781,7 +760,7 @@ func newContainsAllEval(lhs, rhs Evaler) Evaler { } } -func (n *containsAllEval) Eval(env *Env) (types.Value, error) { +func (n *containsAllEval) Eval(env Env) (types.Value, error) { lhs, err := evalSet(n.lhs, env) if err != nil { return zeroValue(), err @@ -813,7 +792,7 @@ func newContainsAnyEval(lhs, rhs Evaler) Evaler { } } -func (n *containsAnyEval) Eval(env *Env) (types.Value, error) { +func (n *containsAnyEval) Eval(env Env) (types.Value, error) { lhs, err := evalSet(n.lhs, env) if err != nil { return zeroValue(), err @@ -842,7 +821,7 @@ func newRecordLiteralEval(elements map[types.String]Evaler) *recordLiteralEval { return &recordLiteralEval{elements: elements} } -func (n *recordLiteralEval) Eval(env *Env) (types.Value, error) { +func (n *recordLiteralEval) Eval(env Env) (types.Value, error) { vals := types.RecordMap{} for k, en := range n.elements { v, err := en.Eval(env) @@ -864,7 +843,7 @@ func newAttributeAccessEval(record Evaler, attribute types.String) *attributeAcc return &attributeAccessEval{object: record, attribute: attribute} } -func (n *attributeAccessEval) Eval(env *Env) (types.Value, error) { +func (n *attributeAccessEval) Eval(env Env) (types.Value, error) { v, err := n.object.Eval(env) if err != nil { return zeroValue(), err @@ -905,7 +884,7 @@ func newHasEval(record Evaler, attribute types.String) *hasEval { return &hasEval{object: record, attribute: attribute} } -func (n *hasEval) Eval(env *Env) (types.Value, error) { +func (n *hasEval) Eval(env Env) (types.Value, error) { v, err := n.object.Eval(env) if err != nil { return zeroValue(), err @@ -935,7 +914,7 @@ func newLikeEval(lhs Evaler, pattern types.Pattern) *likeEval { return &likeEval{lhs: lhs, pattern: pattern} } -func (l *likeEval) Eval(env *Env) (types.Value, error) { +func (l *likeEval) Eval(env Env) (types.Value, error) { v, err := evalString(l.lhs, env) if err != nil { return zeroValue(), err @@ -952,7 +931,7 @@ func newVariableEval(variableName types.String) *variableEval { return &variableEval{variableName: variableName} } -func (n *variableEval) Eval(env *Env) (types.Value, error) { +func (n *variableEval) Eval(env Env) (types.Value, error) { switch n.variableName { case consts.Principal: return env.Principal, nil @@ -974,17 +953,7 @@ func newInEval(lhs, rhs Evaler) Evaler { return &inEval{lhs: lhs, rhs: rhs} } -func entityInOne(env *Env, entity types.EntityUID, parent types.EntityUID) bool { - key := inKey{a: entity, b: parent} - if cached, ok := env.inCache[key]; ok { - return cached - } - result := entityInOneWork(env, entity, parent) - env.inCache[key] = result - return result -} - -func entityInOneWork(env *Env, entity types.EntityUID, parent types.EntityUID) bool { +func entityInOne(env Env, entity types.EntityUID, parent types.EntityUID) bool { if entity == parent { return true } @@ -1013,7 +982,7 @@ func entityInOneWork(env *Env, entity types.EntityUID, parent types.EntityUID) b } } -func entityInSet(env *Env, entity types.EntityUID, parents mapset.Container[types.EntityUID]) bool { +func entityInSet(env Env, entity types.EntityUID, parents mapset.Container[types.EntityUID]) bool { if parents.Contains(entity) { return true } @@ -1042,7 +1011,7 @@ func entityInSet(env *Env, entity types.EntityUID, parents mapset.Container[type } } -func (n *inEval) Eval(env *Env) (types.Value, error) { +func (n *inEval) Eval(env Env) (types.Value, error) { lhs, err := evalEntity(n.lhs, env) if err != nil { return zeroValue(), err @@ -1056,7 +1025,7 @@ func (n *inEval) Eval(env *Env) (types.Value, error) { return doInEval(env, lhs, rhs) } -func doInEval(env *Env, lhs types.EntityUID, rhs types.Value) (types.Value, error) { +func doInEval(env Env, lhs types.EntityUID, rhs types.Value) (types.Value, error) { switch rhsv := rhs.(type) { case types.EntityUID: return types.Boolean(entityInOne(env, lhs, rhsv)), nil @@ -1090,7 +1059,7 @@ func newIsEval(lhs Evaler, rhs types.EntityType) *isEval { return &isEval{lhs: lhs, rhs: rhs} } -func (n *isEval) Eval(env *Env) (types.Value, error) { +func (n *isEval) Eval(env Env) (types.Value, error) { lhs, err := evalEntity(n.lhs, env) if err != nil { return zeroValue(), err @@ -1108,7 +1077,7 @@ func newIsInEval(lhs Evaler, is types.EntityType, rhs Evaler) Evaler { return &isInEval{lhs: lhs, is: is, rhs: rhs} } -func (n *isInEval) Eval(env *Env) (types.Value, error) { +func (n *isInEval) Eval(env Env) (types.Value, error) { lhs, err := evalEntity(n.lhs, env) if err != nil { return zeroValue(), err @@ -1132,7 +1101,7 @@ func newDecimalLiteralEval(literal Evaler) *decimalLiteralEval { return &decimalLiteralEval{literal: literal} } -func (n *decimalLiteralEval) Eval(env *Env) (types.Value, error) { +func (n *decimalLiteralEval) Eval(env Env) (types.Value, error) { literal, err := evalString(n.literal, env) if err != nil { return zeroValue(), err @@ -1155,7 +1124,7 @@ func newDatetimeLiteralEval(literal Evaler) *datetimeLiteralEval { return &datetimeLiteralEval{literal: literal} } -func (n *datetimeLiteralEval) Eval(env *Env) (types.Value, error) { +func (n *datetimeLiteralEval) Eval(env Env) (types.Value, error) { literal, err := evalString(n.literal, env) if err != nil { return zeroValue(), err @@ -1177,7 +1146,7 @@ func newDurationLiteralEval(literal Evaler) *durationLiteralEval { return &durationLiteralEval{literal: literal} } -func (n *durationLiteralEval) Eval(env *Env) (types.Value, error) { +func (n *durationLiteralEval) Eval(env Env) (types.Value, error) { literal, err := evalString(n.literal, env) if err != nil { return zeroValue(), err @@ -1199,7 +1168,7 @@ func newIPLiteralEval(literal Evaler) *ipLiteralEval { return &ipLiteralEval{literal: literal} } -func (n *ipLiteralEval) Eval(env *Env) (types.Value, error) { +func (n *ipLiteralEval) Eval(env Env) (types.Value, error) { literal, err := evalString(n.literal, env) if err != nil { return zeroValue(), err @@ -1230,7 +1199,7 @@ func newIPTestEval(object Evaler, test ipTestType) *ipTestEval { return &ipTestEval{object: object, test: test} } -func (n *ipTestEval) Eval(env *Env) (types.Value, error) { +func (n *ipTestEval) Eval(env Env) (types.Value, error) { i, err := evalIP(n.object, env) if err != nil { return zeroValue(), err @@ -1248,7 +1217,7 @@ func newIPIsInRangeEval(lhs, rhs Evaler) *ipIsInRangeEval { return &ipIsInRangeEval{lhs: lhs, rhs: rhs} } -func (n *ipIsInRangeEval) Eval(env *Env) (types.Value, error) { +func (n *ipIsInRangeEval) Eval(env Env) (types.Value, error) { lhs, err := evalIP(n.lhs, env) if err != nil { return zeroValue(), err @@ -1339,7 +1308,7 @@ func newComparableValueLessThanEval(lhs Evaler, rhs Evaler) *comparableValueLess } } -func (n *comparableValueLessThanEval) Eval(env *Env) (types.Value, error) { +func (n *comparableValueLessThanEval) Eval(env Env) (types.Value, error) { lhs, err := evalComparableValue(n.lhs, env) if err != nil { return zeroValue(), err @@ -1370,7 +1339,7 @@ func newComparableValueGreaterThanEval(lhs Evaler, rhs Evaler) *comparableValueG } } -func (n *comparableValueGreaterThanEval) Eval(env *Env) (types.Value, error) { +func (n *comparableValueGreaterThanEval) Eval(env Env) (types.Value, error) { lhs, err := evalComparableValue(n.lhs, env) if err != nil { return zeroValue(), err @@ -1399,7 +1368,7 @@ func newComparableValueLessThanOrEqualEval(lhs Evaler, rhs Evaler) *comparableVa } } -func (n *comparableValueLessThanOrEqualEval) Eval(env *Env) (types.Value, error) { +func (n *comparableValueLessThanOrEqualEval) Eval(env Env) (types.Value, error) { lhs, err := evalComparableValue(n.lhs, env) if err != nil { return zeroValue(), err @@ -1428,7 +1397,7 @@ func newComparableValueGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) *comparabl } } -func (n *comparableValueGreaterThanOrEqualEval) Eval(env *Env) (types.Value, error) { +func (n *comparableValueGreaterThanOrEqualEval) Eval(env Env) (types.Value, error) { lhs, err := evalComparableValue(n.lhs, env) if err != nil { return zeroValue(), err @@ -1452,7 +1421,7 @@ func newToDateEval(lhs Evaler) *toDateEval { return &toDateEval{lhs: lhs} } -func (n *toDateEval) Eval(env *Env) (types.Value, error) { +func (n *toDateEval) Eval(env Env) (types.Value, error) { lhs, err := evalDatetime(n.lhs, env) if err != nil { return zeroValue(), err @@ -1468,7 +1437,7 @@ func newToTimeEval(lhs Evaler) *toTimeEval { return &toTimeEval{lhs: lhs} } -func (n *toTimeEval) Eval(env *Env) (types.Value, error) { +func (n *toTimeEval) Eval(env Env) (types.Value, error) { lhs, err := evalDatetime(n.lhs, env) if err != nil { return zeroValue(), err @@ -1484,7 +1453,7 @@ func newToMillisecondsEval(lhs Evaler) *toMillisecondsEval { return &toMillisecondsEval{lhs: lhs} } -func (n *toMillisecondsEval) Eval(env *Env) (types.Value, error) { +func (n *toMillisecondsEval) Eval(env Env) (types.Value, error) { lhs, err := evalDuration(n.lhs, env) if err != nil { return zeroValue(), err @@ -1500,7 +1469,7 @@ func newToSecondsEval(lhs Evaler) *toSecondsEval { return &toSecondsEval{lhs: lhs} } -func (n *toSecondsEval) Eval(env *Env) (types.Value, error) { +func (n *toSecondsEval) Eval(env Env) (types.Value, error) { lhs, err := evalDuration(n.lhs, env) if err != nil { return zeroValue(), err @@ -1516,7 +1485,7 @@ func newToMinutesEval(lhs Evaler) *toMinutesEval { return &toMinutesEval{lhs: lhs} } -func (n *toMinutesEval) Eval(env *Env) (types.Value, error) { +func (n *toMinutesEval) Eval(env Env) (types.Value, error) { lhs, err := evalDuration(n.lhs, env) if err != nil { return zeroValue(), err @@ -1532,7 +1501,7 @@ func newToHoursEval(lhs Evaler) *toHoursEval { return &toHoursEval{lhs: lhs} } -func (n *toHoursEval) Eval(env *Env) (types.Value, error) { +func (n *toHoursEval) Eval(env Env) (types.Value, error) { lhs, err := evalDuration(n.lhs, env) if err != nil { return zeroValue(), err @@ -1548,7 +1517,7 @@ func newToDaysEval(lhs Evaler) *toDaysEval { return &toDaysEval{lhs: lhs} } -func (n *toDaysEval) Eval(env *Env) (types.Value, error) { +func (n *toDaysEval) Eval(env Env) (types.Value, error) { lhs, err := evalDuration(n.lhs, env) if err != nil { return zeroValue(), err @@ -1565,7 +1534,7 @@ func newOffsetEval(lhs Evaler, rhs Evaler) *offsetEval { return &offsetEval{lhs: lhs, rhs: rhs} } -func (n *offsetEval) Eval(env *Env) (types.Value, error) { +func (n *offsetEval) Eval(env Env) (types.Value, error) { lhs, err := evalDatetime(n.lhs, env) if err != nil { return zeroValue(), err @@ -1586,7 +1555,7 @@ func newDurationSinceEval(lhs Evaler, rhs Evaler) *durationSinceEval { return &durationSinceEval{lhs: lhs, rhs: rhs} } -func (n *durationSinceEval) Eval(env *Env) (types.Value, error) { +func (n *durationSinceEval) Eval(env Env) (types.Value, error) { lhs, err := evalDatetime(n.lhs, env) if err != nil { return zeroValue(), err diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 6262940b..ed5b3bfc 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -61,7 +61,7 @@ func TestOrNode(t *testing.T) { t.Run(fmt.Sprintf("%v%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newOrEval(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -72,7 +72,7 @@ func TestOrNode(t *testing.T) { t.Parallel() n := newOrEval( newLiteralEval(types.True), newLiteralEval(types.Long(1))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, true) }) @@ -93,7 +93,7 @@ func TestOrNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newOrEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -116,7 +116,7 @@ func TestAndNode(t *testing.T) { t.Run(fmt.Sprintf("%v%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newAndEval(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -127,7 +127,7 @@ func TestAndNode(t *testing.T) { t.Parallel() n := newAndEval( newLiteralEval(types.False), newLiteralEval(types.Long(1))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, false) }) @@ -148,7 +148,7 @@ func TestAndNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newAndEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -169,7 +169,7 @@ func TestNotNode(t *testing.T) { t.Run(fmt.Sprintf("%v", tt.arg), func(t *testing.T) { t.Parallel() n := newNotEval(newLiteralEval(types.Boolean(tt.arg))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -190,7 +190,7 @@ func TestNotNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newNotEval(tt.arg) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -370,7 +370,7 @@ func TestAddNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newAddEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertLongValue(t, v, 3) }) @@ -398,7 +398,7 @@ func TestAddNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newAddEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -409,7 +409,7 @@ func TestSubtractNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newSubtractEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertLongValue(t, v, -1) }) @@ -437,7 +437,7 @@ func TestSubtractNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newSubtractEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -448,7 +448,7 @@ func TestMultiplyNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newMultiplyEval(newLiteralEval(types.Long(-3)), newLiteralEval(types.Long(2))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertLongValue(t, v, -6) }) @@ -476,7 +476,7 @@ func TestMultiplyNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newMultiplyEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -487,7 +487,7 @@ func TestNegateNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newNegateEval(newLiteralEval(types.Long(-3))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertLongValue(t, v, 3) }) @@ -506,7 +506,7 @@ func TestNegateNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newNegateEval(tt.arg) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -535,7 +535,7 @@ func TestLongLessThanNode(t *testing.T) { t.Parallel() n := newLongLessThanEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -557,7 +557,7 @@ func TestLongLessThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongLessThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -587,7 +587,7 @@ func TestLongLessThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongLessThanOrEqualEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -609,7 +609,7 @@ func TestLongLessThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongLessThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -639,7 +639,7 @@ func TestLongGreaterThanNode(t *testing.T) { t.Parallel() n := newLongGreaterThanEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -661,7 +661,7 @@ func TestLongGreaterThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongGreaterThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -691,7 +691,7 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongGreaterThanOrEqualEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -713,7 +713,7 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongGreaterThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -748,7 +748,7 @@ func TestDecimalLessThanNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalLessThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -770,7 +770,7 @@ func TestDecimalLessThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalLessThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -805,7 +805,7 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalLessThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -827,7 +827,7 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalLessThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -862,7 +862,7 @@ func TestDecimalGreaterThanNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalGreaterThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -884,7 +884,7 @@ func TestDecimalGreaterThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalGreaterThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -919,7 +919,7 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalGreaterThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -941,7 +941,7 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalGreaterThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(NewEnv()) + _, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) }) } @@ -1182,7 +1182,7 @@ func TestComparableValueComparisonNodes(t *testing.T) { t.Run(fmt.Sprintf("%v_%s_%v", tt.lhs, tc.name, tt.rhs), func(t *testing.T) { t.Parallel() n := tc.evaler(toEvaler(tt.lhs), toEvaler(tt.rhs)) - v, err := n.Eval(&Env{}) + v, err := n.Eval(Env{}) if tt.wantErr == nil { testutil.OK(t, err) AssertBoolValue(t, v, tt.result) @@ -1218,7 +1218,7 @@ func TestIfThenElseNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIfThenElseEval(tt.if_, tt.then, tt.else_) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) testutil.Equals(t, v, tt.result) }) @@ -1244,7 +1244,7 @@ func TestEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newEqualEval(tt.lhs, tt.rhs) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -1270,7 +1270,7 @@ func TestNotEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newNotEqualEval(tt.lhs, tt.rhs) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -1311,7 +1311,7 @@ func TestSetLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newSetLiteralEval(tt.elems) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -1335,7 +1335,7 @@ func TestContainsNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsEval(tt.lhs, tt.rhs) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertZeroValue(t, v) }) @@ -1364,7 +1364,7 @@ func TestContainsNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsEval(tt.lhs, tt.rhs) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -1390,7 +1390,7 @@ func TestContainsAllNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAllEval(tt.lhs, tt.rhs) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertZeroValue(t, v) }) @@ -1418,7 +1418,7 @@ func TestContainsAllNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAllEval(tt.lhs, tt.rhs) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -1444,7 +1444,7 @@ func TestContainsAnyNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAnyEval(tt.lhs, tt.rhs) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertZeroValue(t, v) }) @@ -1475,7 +1475,7 @@ func TestContainsAnyNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAnyEval(tt.lhs, tt.rhs) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.OK(t, err) AssertBoolValue(t, v, tt.result) }) @@ -1498,7 +1498,7 @@ func TestContainsAnyNode(t *testing.T) { n := newContainsAnyEval(newLiteralEval(types.NewSet(set1)), newLiteralEval(types.NewSet(set2))) // This call would take several minutes if the evaluation of ContainsAny was quadratic - val, err := n.Eval(NewEnv()) + val, err := n.Eval(Env{}) testutil.OK(t, err) testutil.Equals(t, val.(types.Boolean), types.False) @@ -1529,7 +1529,7 @@ func TestRecordLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newRecordLiteralEval(tt.elems) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -1587,11 +1587,11 @@ func TestAttributeAccessNode(t *testing.T) { UID: types.NewEntityUID("knownType", "knownID"), Attributes: types.NewRecord(types.RecordMap{"knownAttr": types.Long(42)}), } - v, err := n.Eval(InitEnv(&Env{ + v, err := n.Eval(Env{ Entities: types.Entities{ entity.UID: entity, }, - })) + }) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -1644,11 +1644,11 @@ func TestHasNode(t *testing.T) { UID: types.NewEntityUID("knownType", "knownID"), Attributes: types.NewRecord(types.RecordMap{"knownAttr": types.Long(42)}), } - v, err := n.Eval(InitEnv(&Env{ + v, err := n.Eval(Env{ Entities: types.Entities{ entity.UID: entity, }, - })) + }) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -1702,7 +1702,7 @@ func TestLikeNode(t *testing.T) { pat, err := parser.ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) testutil.OK(t, err) n := newLikeEval(tt.str, pat) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -1739,7 +1739,7 @@ func TestVariableNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newVariableEval(tt.variable) - v, err := n.Eval(InitEnv(&tt.env)) + v, err := n.Eval(tt.env) testutil.OK(t, err) AssertValue(t, v, tt.result) }) @@ -1849,7 +1849,7 @@ func TestEntityIn(t *testing.T) { Parents: types.NewEntityUIDSet(ps...), } } - res := entityInSet(&Env{Entities: entityMap}, strEnt(tt.lhs), types.NewEntityUIDSet(rhs...)) + res := entityInSet(Env{Entities: entityMap}, strEnt(tt.lhs), types.NewEntityUIDSet(rhs...)) testutil.Equals(t, res, tt.result) }) } @@ -1877,7 +1877,7 @@ func TestEntityIn(t *testing.T) { } res := entityInSet( - &Env{Entities: entityMap}, + Env{Entities: entityMap}, types.NewEntityUID("0", "1"), types.NewEntityUIDSet(types.NewEntityUID("0", "3")), ) @@ -1903,7 +1903,7 @@ func TestIsNode(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - got, err := newIsEval(tt.lhs, tt.rhs).Eval(NewEnv()) + got, err := newIsEval(tt.lhs, tt.rhs).Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, got, tt.result) }) @@ -2019,7 +2019,7 @@ func TestInNode(t *testing.T) { Parents: types.NewEntityUIDSet(ps...), } } - ec := InitEnv(&Env{Entities: entityMap}) + ec := Env{Entities: entityMap} v, err := n.Eval(ec) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) @@ -2159,7 +2159,7 @@ func TestIsInNode(t *testing.T) { Parents: types.NewEntityUIDSet(ps...), } } - ec := InitEnv(&Env{Entities: entityMap}) + ec := Env{Entities: entityMap} v, err := n.Eval(ec) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) @@ -2185,7 +2185,7 @@ func TestDecimalLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalLiteralEval(tt.arg) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2212,7 +2212,7 @@ func TestIPLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIPLiteralEval(tt.arg) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2250,7 +2250,7 @@ func TestIPTestNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIPTestEval(tt.lhs, tt.rhs) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2288,7 +2288,7 @@ func TestIPIsInRangeNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIPIsInRangeEval(tt.lhs, tt.rhs) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2325,31 +2325,6 @@ func TestCedarString(t *testing.T) { } } -func TestCache(t *testing.T) { - t.Parallel() - env := NewEnv() - e1Eval := newLiteralEval(types.NewEntityUID("T", "1")) - e2Eval := newLiteralEval(types.NewEntityUID("T", "2")) - var res types.Value - var err error - res, err = newInEval(e1Eval, e1Eval).Eval(env) - testutil.OK(t, err) - testutil.Equals(t, res, types.Value(types.True)) - - res, err = newInEval(e1Eval, e2Eval).Eval(env) - testutil.OK(t, err) - testutil.Equals(t, res, types.Value(types.False)) - - env = InitEnvWithCacheFrom(&Env{}, env) - res, err = newInEval(e1Eval, e1Eval).Eval(env) - testutil.OK(t, err) - testutil.Equals(t, res, types.Value(types.True)) - - res, err = newInEval(e1Eval, e2Eval).Eval(env) - testutil.OK(t, err) - testutil.Equals(t, res, types.Value(types.False)) -} - func TestDatetimeLiteralNode(t *testing.T) { t.Parallel() tests := []struct { @@ -2368,7 +2343,7 @@ func TestDatetimeLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDatetimeLiteralEval(tt.arg) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2395,7 +2370,7 @@ func TestDatetimeToDate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newToDateEval(tt.arg) - v, err := n.Eval(&Env{}) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2428,7 +2403,7 @@ func TestDatetimeDurationSince(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDurationSinceEval(tt.lhs, tt.rhs) - v, err := n.Eval(&Env{}) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2461,7 +2436,7 @@ func TestDatetimeOffset(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newOffsetEval(tt.lhs, tt.rhs) - v, err := n.Eval(&Env{}) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2488,7 +2463,7 @@ func TestDatetimeToTime(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newToTimeEval(tt.arg) - v, err := n.Eval(&Env{}) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2513,7 +2488,7 @@ func TestDurationLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDurationLiteralEval(tt.arg) - v, err := n.Eval(NewEnv()) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2540,7 +2515,7 @@ func TestDurationToMilliseconds(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newToMillisecondsEval(tt.arg) - v, err := n.Eval(&Env{}) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2567,7 +2542,7 @@ func TestDurationToSeconds(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newToSecondsEval(tt.arg) - v, err := n.Eval(&Env{}) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2594,7 +2569,7 @@ func TestDurationToMinutes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newToMinutesEval(tt.arg) - v, err := n.Eval(&Env{}) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2621,7 +2596,7 @@ func TestDurationToHours(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newToHoursEval(tt.arg) - v, err := n.Eval(&Env{}) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) @@ -2648,7 +2623,7 @@ func TestDurationToDays(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newToDaysEval(tt.arg) - v, err := n.Eval(&Env{}) + v, err := n.Eval(Env{}) testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) diff --git a/internal/eval/fold.go b/internal/eval/fold.go index 20f440fe..17ef965c 100644 --- a/internal/eval/fold.go +++ b/internal/eval/fold.go @@ -55,7 +55,7 @@ func tryFold(nodes []ast.IsNode, } if allFolded { eval := mkEval(values) - v, err := eval.Eval(nil) + v, err := eval.Eval(Env{}) if err == nil { return ast.NodeValue{Value: v} } diff --git a/internal/eval/partial.go b/internal/eval/partial.go index 5b86a21f..4fc87b12 100644 --- a/internal/eval/partial.go +++ b/internal/eval/partial.go @@ -45,7 +45,7 @@ func IsIgnore(v types.Value) bool { // PartialPolicy returns a partially evaluated version of the policy and a boolean indicating if the policy should be kept. // (Policies that are determined to evaluate to false are not kept.) -func PartialPolicy(env *Env, p *ast.Policy) (policy *ast.Policy, keep bool) { +func PartialPolicy(env Env, p *ast.Policy) (policy *ast.Policy, keep bool) { p2 := *p if p2.Principal, keep = partialPrincipalScope(env, env.Principal, p2.Principal); !keep { return nil, false @@ -87,7 +87,7 @@ func PartialPolicy(env *Env, p *ast.Policy) (policy *ast.Policy, keep bool) { return &p2, true } -func partialPrincipalScope(env *Env, ent types.Value, scope ast.IsPrincipalScopeNode) (ast.IsPrincipalScopeNode, bool) { +func partialPrincipalScope(env Env, ent types.Value, scope ast.IsPrincipalScopeNode) (ast.IsPrincipalScopeNode, bool) { evaled, result := partialScopeEval(env, ent, scope) switch { case evaled && !result: @@ -99,7 +99,7 @@ func partialPrincipalScope(env *Env, ent types.Value, scope ast.IsPrincipalScope } } -func partialActionScope(env *Env, ent types.Value, scope ast.IsActionScopeNode) (ast.IsActionScopeNode, bool) { +func partialActionScope(env Env, ent types.Value, scope ast.IsActionScopeNode) (ast.IsActionScopeNode, bool) { evaled, result := partialScopeEval(env, ent, scope) switch { case evaled && !result: @@ -111,7 +111,7 @@ func partialActionScope(env *Env, ent types.Value, scope ast.IsActionScopeNode) } } -func partialResourceScope(env *Env, ent types.Value, scope ast.IsResourceScopeNode) (ast.IsResourceScopeNode, bool) { +func partialResourceScope(env Env, ent types.Value, scope ast.IsResourceScopeNode) (ast.IsResourceScopeNode, bool) { evaled, result := partialScopeEval(env, ent, scope) switch { case evaled && !result: @@ -123,7 +123,7 @@ func partialResourceScope(env *Env, ent types.Value, scope ast.IsResourceScopeNo } } -func partialScopeEval(env *Env, ent types.Value, in ast.IsScopeNode) (evaled bool, result bool) { +func partialScopeEval(env Env, ent types.Value, in ast.IsScopeNode) (evaled bool, result bool) { if IsVariable(ent) { return false, false } else if IsIgnore(ent) { @@ -156,7 +156,7 @@ var errVariable = fmt.Errorf("variable") var errIgnore = fmt.Errorf("ignore") // NOTE: nodes is modified in place, so be sure to send unique copy in -func tryPartial(env *Env, nodes []ast.IsNode, +func tryPartial(env Env, nodes []ast.IsNode, mkEval func(values []types.Value) Evaler, mkNode func(nodes []ast.IsNode) ast.IsNode, ) (ast.IsNode, error) { @@ -196,13 +196,13 @@ func tryPartial(env *Env, nodes []ast.IsNode, return mkNode(nodes), nil } -func tryPartialBinary(env *Env, v ast.BinaryNode, mkEval func(a, b Evaler) Evaler, wrap func(b ast.BinaryNode) ast.IsNode) (ast.IsNode, error) { +func tryPartialBinary(env Env, v ast.BinaryNode, mkEval func(a, b Evaler) Evaler, wrap func(b ast.BinaryNode) ast.IsNode) (ast.IsNode, error) { return tryPartial(env, []ast.IsNode{v.Left, v.Right}, func(values []types.Value) Evaler { return mkEval(newLiteralEval(values[0]), newLiteralEval(values[1])) }, func(nodes []ast.IsNode) ast.IsNode { return wrap(ast.BinaryNode{Left: nodes[0], Right: nodes[1]}) }, ) } -func tryPartialUnary(env *Env, v ast.UnaryNode, mkEval func(a Evaler) Evaler, wrap func(b ast.UnaryNode) ast.IsNode) (ast.IsNode, error) { +func tryPartialUnary(env Env, v ast.UnaryNode, mkEval func(a Evaler) Evaler, wrap func(b ast.UnaryNode) ast.IsNode) (ast.IsNode, error) { return tryPartial(env, []ast.IsNode{v.Arg}, func(values []types.Value) Evaler { return mkEval(newLiteralEval(values[0])) }, func(nodes []ast.IsNode) ast.IsNode { return wrap(ast.UnaryNode{Arg: nodes[0]}) }, @@ -210,7 +210,7 @@ func tryPartialUnary(env *Env, v ast.UnaryNode, mkEval func(a Evaler) Evaler, wr } // partial takes in an ast.Node and finds does as much as is possible given the context -func partial(env *Env, n ast.IsNode) (ast.IsNode, error) { +func partial(env Env, n ast.IsNode) (ast.IsNode, error) { switch v := n.(type) { case ast.NodeTypeAccess: return tryPartial(env, @@ -400,7 +400,7 @@ func isFalse(in ast.IsNode) bool { return v == types.Boolean(false) } -func partialIfThenElse(env *Env, v ast.NodeTypeIfThenElse) (ast.IsNode, error) { +func partialIfThenElse(env Env, v ast.NodeTypeIfThenElse) (ast.IsNode, error) { if_, ifErr := partial(env, v.If) switch { case errors.Is(ifErr, errVariable): @@ -428,7 +428,7 @@ func partialIfThenElse(env *Env, v ast.NodeTypeIfThenElse) (ast.IsNode, error) { return ast.NodeTypeIfThenElse{If: if_, Then: then, Else: else_}, nil } -func partialAnd(env *Env, v ast.NodeTypeAnd) (ast.IsNode, error) { +func partialAnd(env Env, v ast.NodeTypeAnd) (ast.IsNode, error) { left, leftErr := partial(env, v.Left) switch { case errors.Is(leftErr, errVariable): @@ -454,7 +454,7 @@ func partialAnd(env *Env, v ast.NodeTypeAnd) (ast.IsNode, error) { return ast.NodeTypeAnd{BinaryNode: ast.BinaryNode{Left: left, Right: right}}, nil } -func partialOr(env *Env, v ast.NodeTypeOr) (ast.IsNode, error) { +func partialOr(env Env, v ast.NodeTypeOr) (ast.IsNode, error) { left, leftErr := partial(env, v.Left) switch { case errors.Is(leftErr, errVariable): @@ -496,7 +496,7 @@ func newPartialHasEval(record Evaler, attribute types.String) *partialHasEval { return &partialHasEval{object: record, attribute: attribute} } -func (n *partialHasEval) Eval(env *Env) (types.Value, error) { +func (n *partialHasEval) Eval(env Env) (types.Value, error) { v, err := n.object.Eval(env) if err != nil { return zeroValue(), err @@ -530,7 +530,7 @@ func newPartialErrorEval(err Evaler) *partialErrorEval { } } -func (n *partialErrorEval) Eval(env *Env) (types.Value, error) { +func (n *partialErrorEval) Eval(env Env) (types.Value, error) { v, err := evalString(n.arg, env) if err != nil { return nil, err diff --git a/internal/eval/partial_test.go b/internal/eval/partial_test.go index a8b7dbd5..8a3a82b4 100644 --- a/internal/eval/partial_test.go +++ b/internal/eval/partial_test.go @@ -142,7 +142,7 @@ func TestPartialScopeEval(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - evaled, result := partialScopeEval(InitEnv(&tt.env), tt.ent, tt.in) + evaled, result := partialScopeEval(tt.env, tt.ent, tt.in) testutil.Equals(t, evaled, tt.evaled) testutil.Equals(t, result, tt.result) }) @@ -153,7 +153,7 @@ func TestPartialScopeEval(t *testing.T) { func TestPartialScopeEvalPanic(t *testing.T) { t.Parallel() testutil.Panic(t, func() { - partialScopeEval(NewEnv(), types.NewEntityUID("T", "1"), nil) + partialScopeEval(Env{}, types.NewEntityUID("T", "1"), nil) }) } @@ -162,19 +162,19 @@ func TestPartialPolicy(t *testing.T) { tests := []struct { name string in *ast.Policy - env *Env + env Env out *ast.Policy keep bool }{ {"smokeTest", ast.Permit(), - &Env{}, + Env{}, ast.Permit(), true, }, {"principalEqual", ast.Permit().PrincipalEq(types.NewEntityUID("Account", "42")), - &Env{ + Env{ Principal: types.NewEntityUID("Account", "42"), }, ast.Permit(), @@ -182,7 +182,7 @@ func TestPartialPolicy(t *testing.T) { }, {"principalNotEqual", ast.Permit().PrincipalEq(types.NewEntityUID("Account", "42")), - &Env{ + Env{ Principal: types.NewEntityUID("Account", "Other"), }, nil, @@ -190,7 +190,7 @@ func TestPartialPolicy(t *testing.T) { }, {"actionEqual", ast.Permit().ActionEq(types.NewEntityUID("Action", "42")), - &Env{ + Env{ Action: types.NewEntityUID("Action", "42"), }, ast.Permit(), @@ -198,7 +198,7 @@ func TestPartialPolicy(t *testing.T) { }, {"actionNotEqual", ast.Permit().ActionEq(types.NewEntityUID("Action", "42")), - &Env{ + Env{ Action: types.NewEntityUID("Action", "Other"), }, nil, @@ -206,7 +206,7 @@ func TestPartialPolicy(t *testing.T) { }, {"resourceEqual", ast.Permit().ResourceEq(types.NewEntityUID("Resource", "42")), - &Env{ + Env{ Resource: types.NewEntityUID("Resource", "42"), }, ast.Permit(), @@ -214,7 +214,7 @@ func TestPartialPolicy(t *testing.T) { }, {"resourceNotEqual", ast.Permit().ResourceEq(types.NewEntityUID("Resource", "42")), - &Env{ + Env{ Resource: types.NewEntityUID("Resource", "Other"), }, nil, @@ -222,37 +222,37 @@ func TestPartialPolicy(t *testing.T) { }, {"conditionOmitTrue", ast.Permit().When(ast.True()), - &Env{}, + Env{}, ast.Permit(), true, }, {"conditionDropFalse", ast.Permit().When(ast.False()), - &Env{}, + Env{}, nil, false, }, {"conditionDropError", ast.Permit().When(ast.Long(42).GreaterThan(ast.String("bananas"))), - &Env{}, + Env{}, ast.Permit().When(ast.NewNode(extError(errors.New("type error: expected long, got string")))), true, }, {"conditionDropTypeError", ast.Permit().When(ast.Long(42)), - &Env{}, + Env{}, ast.Permit().When(ast.NewNode(extError(errors.New("type error: condition expected bool")))), true, }, {"conditionKeepUnfolded", ast.Permit().When(ast.Context().GreaterThan(ast.Long(42))), - &Env{Context: Variable("context")}, + Env{Context: Variable("context")}, ast.Permit().When(ast.Context().GreaterThan(ast.Long(42))), true, }, {"conditionOmitTrueFolded", ast.Permit().When(ast.Context().GreaterThan(ast.Long(42))), - &Env{ + Env{ Context: types.Long(43), }, ast.Permit(), @@ -260,7 +260,7 @@ func TestPartialPolicy(t *testing.T) { }, {"conditionDropFalseFolded", ast.Permit().When(ast.Context().GreaterThan(ast.Long(42))), - &Env{ + Env{ Context: types.Long(41), }, nil, @@ -268,7 +268,7 @@ func TestPartialPolicy(t *testing.T) { }, {"conditionDropErrorFolded", ast.Permit().When(ast.Context().GreaterThan(ast.Long(42))), - &Env{ + Env{ Context: types.String("bananas"), }, ast.Permit().When(ast.NewNode(extError(errors.New("type error: expected long, got string")))), @@ -276,7 +276,7 @@ func TestPartialPolicy(t *testing.T) { }, {"contextVariableAccess", ast.Permit().When(ast.Context().Access("key").Equal(ast.Long(42))), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "key": Variable("var"), }), @@ -287,7 +287,7 @@ func TestPartialPolicy(t *testing.T) { {"ignorePermitContext", ast.Permit().When(ast.Context().Equal(ast.Long(42))), - &Env{ + Env{ Context: Ignore(), }, ast.Permit(), @@ -295,7 +295,7 @@ func TestPartialPolicy(t *testing.T) { }, {"ignoreForbidContext", ast.Forbid().When(ast.Context().Equal(ast.Long(42))), - &Env{ + Env{ Context: Ignore(), }, nil, @@ -303,7 +303,7 @@ func TestPartialPolicy(t *testing.T) { }, {"ignorePermitScope", ast.Permit().PrincipalEq(types.NewEntityUID("T", "42")), - &Env{ + Env{ Principal: Ignore(), }, ast.Permit(), @@ -311,7 +311,7 @@ func TestPartialPolicy(t *testing.T) { }, {"ignoreForbidScope", ast.Forbid().PrincipalEq(types.NewEntityUID("T", "42")), - &Env{ + Env{ Principal: Ignore(), }, ast.Forbid(), @@ -319,7 +319,7 @@ func TestPartialPolicy(t *testing.T) { }, {"ignoreAnd", ast.Permit().When(ast.Context().Access("variable").And(ast.Context().Access("ignore").Equal(ast.Long(42)))), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "ignore": Ignore(), "variable": Variable("variable"), @@ -330,7 +330,7 @@ func TestPartialPolicy(t *testing.T) { }, {"ignoreOr", ast.Permit().When(ast.Context().Access("variable").Or(ast.Context().Access("ignore").Equal(ast.Long(42)))), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "ignore": Ignore(), "variable": Variable("variable"), @@ -341,7 +341,7 @@ func TestPartialPolicy(t *testing.T) { }, {"ignoreIfThen", ast.Permit().When(ast.IfThenElse(ast.Context().Access("variable"), ast.Context().Access("ignore").Equal(ast.Long(42)), ast.True())), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "ignore": Ignore(), "variable": Variable("variable"), @@ -352,7 +352,7 @@ func TestPartialPolicy(t *testing.T) { }, {"ignoreIfElse", ast.Permit().When(ast.IfThenElse(ast.Context().Access("variable"), ast.True(), ast.Context().Access("ignore").Equal(ast.Long(42)))), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "ignore": Ignore(), "variable": Variable("variable"), @@ -363,7 +363,7 @@ func TestPartialPolicy(t *testing.T) { }, {"ignoreHas", ast.Permit().When(ast.Context().Has("ignore")), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "ignore": Ignore(), "variable": Variable("variable"), @@ -374,7 +374,7 @@ func TestPartialPolicy(t *testing.T) { }, {"ignoreHasNot", ast.Permit().When(ast.Not(ast.Context().Has("ignore"))), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "ignore": Ignore(), "variable": Variable("variable"), @@ -385,7 +385,7 @@ func TestPartialPolicy(t *testing.T) { }, {"errorShortCircuit", ast.Permit().When(ast.True()).When(ast.String("test").LessThan(ast.Long(42))).When(ast.Context().Access("variable")), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "variable": Variable("variable"), }), @@ -395,7 +395,7 @@ func TestPartialPolicy(t *testing.T) { }, {"errorShortCircuitKept", ast.Permit().When(ast.Context().Access("variable")).When(ast.String("test").LessThan(ast.Long(42))).When(ast.Context().Access("variable")), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "variable": Variable("variable"), }), @@ -405,7 +405,7 @@ func TestPartialPolicy(t *testing.T) { }, {"errorConditionShortCircuit", ast.Permit().When(ast.True()).When(ast.String("test")).When(ast.Context().Access("variable")), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "variable": Variable("variable"), }), @@ -415,7 +415,7 @@ func TestPartialPolicy(t *testing.T) { }, {"errorConditionShortCircuitKept", ast.Permit().When(ast.Context().Access("variable")).When(ast.String("test")).When(ast.Context().Access("variable")), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "variable": Variable("variable"), }), @@ -425,7 +425,7 @@ func TestPartialPolicy(t *testing.T) { }, {"errorConditionShortCircuitKeptDeeper", ast.Permit().When(ast.Context().Access("variable")).When(ast.String("test")).When(ast.Context().Access("variable")), - &Env{ + Env{ Context: types.NewRecord(types.RecordMap{ "variable": Variable("variable"), }), @@ -435,7 +435,7 @@ func TestPartialPolicy(t *testing.T) { }, {"keepDeepVariables", ast.Permit().When(ast.True().Equal(ast.False().Equal(ast.Context()))), - &Env{ + Env{ Context: Variable("context"), }, ast.Permit().When(ast.True().Equal(ast.False().Equal(ast.Context()))), @@ -446,7 +446,7 @@ func TestPartialPolicy(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - out, keep := PartialPolicy(InitEnv(tt.env), tt.in) + out, keep := PartialPolicy(tt.env, tt.in) if keep { testutil.Equals(t, out, tt.out) // gotP := (*parser.Policy)(out) @@ -499,7 +499,7 @@ func TestPartialIfThenElse(t *testing.T) { t.Run(tt.name, func(t *testing.T) { n, ok := tt.in.AsIsNode().(ast.NodeTypeIfThenElse) testutil.Equals(t, ok, true) - out, err := partialIfThenElse(&Env{ + out, err := partialIfThenElse(Env{ Context: Variable("context"), }, n) tt.errTest(t, err) @@ -563,7 +563,7 @@ func TestPartialAnd(t *testing.T) { t.Run(tt.name, func(t *testing.T) { n, ok := tt.in.AsIsNode().(ast.NodeTypeAnd) testutil.Equals(t, ok, true) - out, err := partialAnd(&Env{ + out, err := partialAnd(Env{ Context: Variable("context"), }, n) tt.errTest(t, err) @@ -627,7 +627,7 @@ func TestPartialOr(t *testing.T) { t.Run(tt.name, func(t *testing.T) { n, ok := tt.in.AsIsNode().(ast.NodeTypeOr) testutil.Equals(t, ok, true) - out, err := partialOr(&Env{ + out, err := partialOr(Env{ Context: Variable("context"), }, n) tt.errTest(t, err) @@ -1196,7 +1196,7 @@ func TestPartialBasic(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - out, err := partial(InitEnv(&Env{ + out, err := partial((Env{ Principal: Variable("principal"), Action: Variable("action"), Resource: Variable("resource"), @@ -1211,7 +1211,7 @@ func TestPartialBasic(t *testing.T) { func TestPartialPanic(t *testing.T) { t.Parallel() testutil.Panic(t, func() { - partial(NewEnv(), nil) + partial(Env{}, nil) }) } @@ -1241,7 +1241,7 @@ func TestPartialErrorEval(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - out, err := tt.in.Eval(InitEnv(&tt.env)) + out, err := tt.in.Eval(tt.env) testutil.Equals(t, out, tt.out) tt.err(t, err) }) @@ -1288,7 +1288,7 @@ func TestPartialHasEval(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - out, err := tt.in.Eval(InitEnv(&tt.env)) + out, err := tt.in.Eval(tt.env) testutil.Equals(t, out, tt.out) tt.err(t, err) }) diff --git a/x/exp/batch/batch.go b/x/exp/batch/batch.go index 11884b66..cf645863 100644 --- a/x/exp/batch/batch.go +++ b/x/exp/batch/batch.go @@ -62,7 +62,7 @@ type batchEvaler struct { policies map[types.PolicyID]*ast.Policy compiled bool evalers map[types.PolicyID]*idEvaler - env *eval.Env + env eval.Env callback Callback } @@ -146,13 +146,13 @@ func Authorize(ctx context.Context, ps *cedar.PolicySet, entityMap types.Entitie case request.Context == nil: return fmt.Errorf("%w: context", errMissingPart) } - be.env = eval.InitEnv(&eval.Env{ + be.env = eval.Env{ Entities: entityMap, Principal: request.Principal, Action: request.Action, Resource: request.Resource, Context: request.Context, - }) + } be.Values = Values{} for k, v := range request.Variables { be.Variables = append(be.Variables, variableItem{Key: k, Values: v}) @@ -225,7 +225,7 @@ func doBatch(ctx context.Context, be *batchEvaler) error { } // then loop the current variable - loopEnv := *be.env + loopEnv := be.env u := be.Variables[0] dummyVal := types.True _, chPrincipal := cloneSub(be.env.Principal, u.Key, dummyVal) @@ -235,7 +235,7 @@ func doBatch(ctx context.Context, be *batchEvaler) error { be.Variables = be.Variables[1:] be.Values = maps.Clone(be.Values) for _, v := range u.Values { - *be.env = loopEnv + be.env = loopEnv be.Values[u.Key] = v if chPrincipal { be.env.Principal, _ = cloneSub(loopEnv.Principal, u.Key, v) @@ -281,7 +281,7 @@ func diagnosticAuthzWithCallback(be *batchEvaler) error { return nil } -func isAuthorized(ps map[types.PolicyID]*idEvaler, env *eval.Env) (types.Decision, types.Diagnostic) { +func isAuthorized(ps map[types.PolicyID]*idEvaler, env eval.Env) (types.Decision, types.Diagnostic) { var diag types.Diagnostic var forbids []types.DiagnosticReason var permits []types.DiagnosticReason