From b7d2dc65d9d164833b3345427572c2c96d1ddb43 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Mon, 29 Jul 2024 08:33:49 -0700 Subject: [PATCH 001/216] cedar-go/x/exp/ast: Add initial constructor functions for stable AST Signed-off-by: philhassey --- x/exp/ast/annotation.go | 10 +++ x/exp/ast/ast_test.go | 40 ++++++++++++ x/exp/ast/node.go | 48 ++++++++++++++ x/exp/ast/operator.go | 140 ++++++++++++++++++++++++++++++++++++++++ x/exp/ast/policy.go | 35 ++++++++++ x/exp/ast/scope.go | 55 ++++++++++++++++ x/exp/ast/value.go | 51 +++++++++++++++ x/exp/ast/variable.go | 33 ++++++++++ x/exp/types/types.go | 46 +++++++++++++ 9 files changed, 458 insertions(+) create mode 100644 x/exp/ast/annotation.go create mode 100644 x/exp/ast/ast_test.go create mode 100644 x/exp/ast/node.go create mode 100644 x/exp/ast/operator.go create mode 100644 x/exp/ast/policy.go create mode 100644 x/exp/ast/scope.go create mode 100644 x/exp/ast/value.go create mode 100644 x/exp/ast/variable.go create mode 100644 x/exp/types/types.go diff --git a/x/exp/ast/annotation.go b/x/exp/ast/annotation.go new file mode 100644 index 00000000..55a67031 --- /dev/null +++ b/x/exp/ast/annotation.go @@ -0,0 +1,10 @@ +package ast + +func (p *Policy) Annotate(name string, value string) *Policy { + p.annotations = append(p.annotations, newAnnotationNode(name, value)) + return p +} + +func newAnnotationNode(name, value string) Node { + return newValueNode(nodeTypeAnnotation, []string{name, value}) +} diff --git a/x/exp/ast/ast_test.go b/x/exp/ast/ast_test.go new file mode 100644 index 00000000..1f6225b7 --- /dev/null +++ b/x/exp/ast/ast_test.go @@ -0,0 +1,40 @@ +package ast_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/x/exp/ast" + "github.com/cedar-policy/cedar-go/x/exp/types" +) + +// These tests mostly verify that policy ASTs compile +func TestAst(t *testing.T) { + t.Parallel() + + johnny := types.EntityUID{"user", "johnny"} + sow := types.EntityUID{"Action", "sow"} + cast := types.EntityUID{"Action", "cast"} + + _ = ast.Permit(). + Annotate("example", "one"). + PrincipalEq(johnny). + ActionIn(sow, cast). + When(ast.True()). + Unless(ast.False()) + + _ = ast.Forbid(). + Annotate("example", "two"). + PrincipalEq(johnny). + ResourceIn(types.EntityUID{"Classification", "Poisonous"}) + + private := types.String("private") + + _ = ast.Forbid().Annotate("example", "three"). + When( + // TODO: It's a little annoying that we have to wrap private in ast.String here. + ast.Resource().Access("tags").Contains(ast.String(private)), + ). + Unless( + ast.Resource().In(ast.Principal().Access("allowed_resources")), + ) +} diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go new file mode 100644 index 00000000..5a9732ea --- /dev/null +++ b/x/exp/ast/node.go @@ -0,0 +1,48 @@ +package ast + +type opType uint8 + +const ( + nodeTypeAccess opType = iota + nodeTypeAdd + nodeTypeAnd + nodeTypeAnnotation + nodeTypeBoolean + nodeTypeContains + nodeTypeContainsAll + nodeTypeContainsAny + nodeTypeEntity + nodeTypeEntityType + nodeTypeEquals + nodeTypeGreater + nodeTypeGreaterEqual + nodeTypeHas + nodeTypeIf + nodeTypeIn + nodeTypeIpAddr + nodeTypeIs + nodeTypeIsInRange + nodeTypeIsIpv4 + nodeTypeIsIpv6 + nodeTypeIsLoopback + nodeTypeIsMulticast + nodeTypeLess + nodeTypeLessEqual + nodeTypeLong + nodeTypeMap + nodeTypeMult + nodeTypeNot + nodeTypeNotEquals + nodeTypeOr + nodeTypeSet + nodeTypeSub + nodeTypeString + nodeTypeVariable +) + +type Node struct { + op opType + // TODO: Should we just have `value any`? + args []Node + value any +} diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go new file mode 100644 index 00000000..8a49b59d --- /dev/null +++ b/x/exp/ast/operator.go @@ -0,0 +1,140 @@ +package ast + +import "github.com/cedar-policy/cedar-go/x/exp/types" + +// ____ _ +// / ___|___ _ __ ___ _ __ __ _ _ __(_)___ ___ _ __ +// | | / _ \| '_ ` _ \| '_ \ / _` | '__| / __|/ _ \| '_ \ +// | |__| (_) | | | | | | |_) | (_| | | | \__ \ (_) | | | | +// \____\___/|_| |_| |_| .__/ \__,_|_| |_|___/\___/|_| |_| +// |_| + +func (lhs Node) Equals(rhs Node) Node { + return newOpNode(nodeTypeEquals, lhs, rhs) +} + +func (lhs Node) NotEquals(rhs Node) Node { + return newOpNode(nodeTypeNotEquals, lhs, rhs) +} + +func (lhs Node) LessThan(rhs Node) Node { + return newOpNode(nodeTypeLess, lhs, rhs) +} + +func (lhs Node) LessThanOrEqual(rhs Node) Node { + return newOpNode(nodeTypeLessEqual, lhs, rhs) +} + +func (lhs Node) GreaterThan(rhs Node) Node { + return newOpNode(nodeTypeGreater, lhs, rhs) +} + +func (lhs Node) GreaterThanOrEqual(rhs Node) Node { + return newOpNode(nodeTypeGreaterEqual, lhs, rhs) +} + +// _ _ _ +// | | ___ __ _(_) ___ __ _| | +// | | / _ \ / _` | |/ __/ _` | | +// | |__| (_) | (_| | | (_| (_| | | +// |_____\___/ \__, |_|\___\__,_|_| +// |___/ + +func (lhs Node) And(rhs Node) Node { + return newOpNode(nodeTypeAnd, lhs, rhs) +} + +func (lhs Node) Or(rhs Node) Node { + return newOpNode(nodeTypeOr, lhs, rhs) +} + +func Not(rhs Node) Node { + return newOpNode(nodeTypeNot, rhs) +} + +func If(condition Node, ifTrue Node, ifFalse Node) Node { + return newOpNode(nodeTypeIf, condition, ifTrue, ifFalse) +} + +// _ _ _ _ _ _ +// / \ _ __(_) |_| |__ _ __ ___ ___| |_(_) ___ +// / _ \ | '__| | __| '_ \| '_ ` _ \ / _ \ __| |/ __| +// / ___ \| | | | |_| | | | | | | | | __/ |_| | (__ +// /_/ \_\_| |_|\__|_| |_|_| |_| |_|\___|\__|_|\___| + +func (lhs Node) Plus(rhs Node) Node { + return newOpNode(nodeTypeAdd, lhs, rhs) +} + +func (lhs Node) Minus(rhs Node) Node { + return newOpNode(nodeTypeSub, lhs, rhs) +} + +func (lhs Node) Times(rhs Node) Node { + return newOpNode(nodeTypeMult, lhs, rhs) +} + +// _ _ _ _ +// | | | (_) ___ _ __ __ _ _ __ ___| |__ _ _ +// | |_| | |/ _ \ '__/ _` | '__/ __| '_ \| | | | +// | _ | | __/ | | (_| | | | (__| | | | |_| | +// |_| |_|_|\___|_| \__,_|_| \___|_| |_|\__, | +// |___/ + +func (lhs Node) In(rhs Node) Node { + return newOpNode(nodeTypeIn, lhs, rhs) +} + +func (lhs Node) Has(rhs Node) Node { + return newOpNode(nodeTypeHas, lhs, rhs) +} + +func (lhs Node) Is(rhs Node) Node { + return newOpNode(nodeTypeIs, lhs, rhs) +} + +func (lhs Node) Contains(rhs Node) Node { + return newOpNode(nodeTypeContains, lhs, rhs) +} + +func (lhs Node) ContainsAll(rhs Node) Node { + return newOpNode(nodeTypeContainsAll, lhs, rhs) +} + +func (lhs Node) ContainsAny(rhs Node) Node { + return newOpNode(nodeTypeContainsAny, lhs, rhs) +} + +func (lhs Node) Access(rhs string) Node { + return newOpNode(nodeTypeAccess, lhs, String(types.String(rhs))) +} + +// ___ ____ _ _ _ +// |_ _| _ \ / \ __| | __| |_ __ ___ ___ ___ +// | || |_) / _ \ / _` |/ _` | '__/ _ \/ __/ __| +// | || __/ ___ \ (_| | (_| | | | __/\__ \__ \ +// |___|_| /_/ \_\__,_|\__,_|_| \___||___/___/ + +func (lhs Node) IsIpv4() Node { + return newOpNode(nodeTypeIsIpv4, lhs) +} + +func (lhs Node) IsIpv6() Node { + return newOpNode(nodeTypeIsIpv6, lhs) +} + +func (lhs Node) IsMulticast() Node { + return newOpNode(nodeTypeIsMulticast, lhs) +} + +func (lhs Node) IsLoopback() Node { + return newOpNode(nodeTypeIsLoopback, lhs) +} + +func (lhs Node) IsInRange(rhs Node) Node { + return newOpNode(nodeTypeIsInRange, lhs, rhs) +} + +func newOpNode(op opType, args ...Node) Node { + return Node{op: op, args: args} +} diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go new file mode 100644 index 00000000..0e8dc49b --- /dev/null +++ b/x/exp/ast/policy.go @@ -0,0 +1,35 @@ +package ast + +type Policy struct { + effect effect + annotations []Node + principal Node + action Node + resource Node + conditions []Node +} + +func Permit() *Policy { + return &Policy{effect: effectPermit} +} + +func Forbid() *Policy { + return &Policy{effect: effectForbid} +} + +func (p *Policy) When(node Node) *Policy { + p.conditions = append(p.conditions, node) + return p +} + +func (p *Policy) Unless(node Node) *Policy { + p.conditions = append(p.conditions, Not(node)) + return p +} + +type effect bool + +const ( + effectPermit effect = true + effectForbid effect = false +) diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go new file mode 100644 index 00000000..9a754e2c --- /dev/null +++ b/x/exp/ast/scope.go @@ -0,0 +1,55 @@ +package ast + +import "github.com/cedar-policy/cedar-go/x/exp/types" + +func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { + p.principal = Principal().Equals(Entity(entity)) + return p +} + +func (p *Policy) PrincipalIn(entities ...types.EntityUID) *Policy { + var entityValues []types.Value + for _, e := range entities { + entities = append(entities, e) + } + p.principal = Principal().In(Set(entityValues)) + return p +} + +func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { + p.principal = Principal().Is(EntityType(entityType)) + return p +} + +func (p *Policy) ActionEq(entity types.EntityUID) *Policy { + p.action = Action().Equals(Entity(entity)) + return p +} + +func (p *Policy) ActionIn(entities ...types.EntityUID) *Policy { + var entityValues []types.Value + for _, e := range entities { + entities = append(entities, e) + } + p.action = Action().In(Set(entityValues)) + return p +} + +func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { + p.principal = Resource().Equals(Entity(entity)) + return p +} + +func (p *Policy) ResourceIn(entities ...types.EntityUID) *Policy { + var entityValues []types.Value + for _, e := range entities { + entities = append(entities, e) + } + p.principal = Resource().In(Set(entityValues)) + return p +} + +func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { + p.principal = Resource().Is(EntityType(entityType)) + return p +} diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go new file mode 100644 index 00000000..6c5451cf --- /dev/null +++ b/x/exp/ast/value.go @@ -0,0 +1,51 @@ +package ast + +import "github.com/cedar-policy/cedar-go/x/exp/types" + +func Boolean(b types.Boolean) Node { + return newValueNode(nodeTypeBoolean, b) +} + +func True() Node { + return Boolean(true) +} + +func False() Node { + return Boolean(false) +} + +func String(s types.String) Node { + return newValueNode(nodeTypeString, s) +} + +func Long(l types.Long) Node { + return newValueNode(nodeTypeLong, l) +} + +func Set(s types.Set) Node { + return newValueNode(nodeTypeSet, s) +} + +func Record(r types.Record) Node { + return newValueNode(nodeTypeMap, r) +} + +func EntityType(e types.EntityType) Node { + return newValueNode(nodeTypeEntityType, e) +} + +func Entity(e types.EntityUID) Node { + return newValueNode(nodeTypeEntity, e) +} + +func Decimal(d types.Decimal) Node { + return newValueNode(nodeTypeEntity, d) +} + +func IpAddr(i types.IpAddr) Node { + return newValueNode(nodeTypeIpAddr, i) +} + +func newValueNode(op opType, v any) Node { + return Node{op: op, value: v} +} diff --git a/x/exp/ast/variable.go b/x/exp/ast/variable.go new file mode 100644 index 00000000..5ef83687 --- /dev/null +++ b/x/exp/ast/variable.go @@ -0,0 +1,33 @@ +package ast + +func Principal() Node { + return newPrincipalNode() +} + +func Action() Node { + return newPrincipalNode() +} + +func Resource() Node { + return newResourceNode() +} + +func Context() Node { + return newContextNode() +} + +func newPrincipalNode() Node { + return newValueNode(nodeTypeVariable, "principal") +} + +func newActionNode() Node { + return newValueNode(nodeTypeVariable, "action") +} + +func newResourceNode() Node { + return newValueNode(nodeTypeVariable, "resource") +} + +func newContextNode() Node { + return newValueNode(nodeTypeVariable, "context") +} diff --git a/x/exp/types/types.go b/x/exp/types/types.go new file mode 100644 index 00000000..7463c0fe --- /dev/null +++ b/x/exp/types/types.go @@ -0,0 +1,46 @@ +package types + +import "net" + +type Value interface { + isValue() +} + +type Boolean bool + +func (Boolean) isValue() {} + +type String string + +func (String) isValue() {} + +type Long int64 + +func (Long) isValue() {} + +type Set []Value + +func (Set) isValue() {} + +type Record map[string]Value + +func (Record) isValue() {} + +type EntityType string + +type EntityUID struct { + Type string + ID string +} + +func (EntityUID) isValue() {} + +type Decimal []float64 + +func (Decimal) isValue() {} + +type IpAddr net.IPAddr + +func (IpAddr) isValue() {} + +// TODO: Variable? What if you want a Set{1, 2, context.bananaCount}? From 2e14d50dce2a1fb6d3f9f793c68665139899a5bd Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Mon, 29 Jul 2024 13:44:18 -0700 Subject: [PATCH 002/216] cedar-go/x/exp/ast: Add functions that allow the construction of Access, Set, and Record nodes given other AST nodes Also, beef up the documentation on the tests a bit to show how Cedar text is translated to the Golang representation. Signed-off-by: philhassey --- x/exp/ast/ast_test.go | 53 ++++++++++++++++++++++++------ x/exp/ast/node.go | 2 +- x/exp/ast/operator.go | 21 ++++++++++-- x/exp/ast/value.go | 75 +++++++++++++++++++++++++++++++++++++++++-- x/exp/types/types.go | 2 -- 5 files changed, 136 insertions(+), 17 deletions(-) diff --git a/x/exp/ast/ast_test.go b/x/exp/ast/ast_test.go index 1f6225b7..6c4aaa4f 100644 --- a/x/exp/ast/ast_test.go +++ b/x/exp/ast/ast_test.go @@ -11,10 +11,18 @@ import ( func TestAst(t *testing.T) { t.Parallel() - johnny := types.EntityUID{"user", "johnny"} + johnny := types.EntityUID{"User", "johnny"} sow := types.EntityUID{"Action", "sow"} cast := types.EntityUID{"Action", "cast"} + // @example("one") + // permit ( + // principal == User::"johnny" + // action in [Action::"sow", Action::"cast"] + // resource + // ) + // when { true } + // unless { false }; _ = ast.Permit(). Annotate("example", "one"). PrincipalEq(johnny). @@ -22,19 +30,46 @@ func TestAst(t *testing.T) { When(ast.True()). Unless(ast.False()) - _ = ast.Forbid(). - Annotate("example", "two"). - PrincipalEq(johnny). - ResourceIn(types.EntityUID{"Classification", "Poisonous"}) - + // @example("two") + // forbid (principal, action, resource) + // when { resource.tags.contains("private") } + // unless { resource in principal.allowed_resources }; private := types.String("private") - - _ = ast.Forbid().Annotate("example", "three"). + _ = ast.Forbid().Annotate("example", "two"). When( - // TODO: It's a little annoying that we have to wrap private in ast.String here. ast.Resource().Access("tags").Contains(ast.String(private)), ). Unless( ast.Resource().In(ast.Principal().Access("allowed_resources")), ) + + // forbid (principal, action, resource) + // when { resource[context.resourceField] == "specialValue" } + // when { {x: "value"}.x == "value" } + // when { {x: 1 + context.fooCount}.x == 3 } + // when { [1, 2 + 3, context.fooCount].contains(1) }; + simpleRecord := types.Record{ + "x": types.String("value"), + } + _ = ast.Forbid(). + When( + ast.Resource().AccessNode( + ast.Context().Access("resourceField"), + ).Equals(ast.String("specialValue")), + ). + When( + ast.Record(simpleRecord).Access("x").Equals(ast.String("value")), + ). + When( + ast.RecordNodes(map[string]ast.Node{ + "x": ast.Long(1).Plus(ast.Context().Access("fooCount")), + }).Access("x").Equals(ast.Long(3)), + ). + When( + ast.SetNodes([]ast.Node{ + ast.Long(1), + ast.Long(2).Plus(ast.Long(3)), + ast.Context().Access("fooCount"), + }).Contains(ast.Long(1)), + ) } diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 5a9732ea..38f2b7e4 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -29,11 +29,11 @@ const ( nodeTypeLess nodeTypeLessEqual nodeTypeLong - nodeTypeMap nodeTypeMult nodeTypeNot nodeTypeNotEquals nodeTypeOr + nodeTypeRecord nodeTypeSet nodeTypeSub nodeTypeString diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index 8a49b59d..c3b89b21 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -105,8 +105,25 @@ func (lhs Node) ContainsAny(rhs Node) Node { return newOpNode(nodeTypeContainsAny, lhs, rhs) } -func (lhs Node) Access(rhs string) Node { - return newOpNode(nodeTypeAccess, lhs, String(types.String(rhs))) +// Access is a convenience function that wraps a simple string +// in an ast.String() and passes it along to AccessNode. +func (lhs Node) Access(attr string) Node { + return lhs.AccessNode(String(types.String(attr))) +} + +// AccessNode is a version of the access operator which allows +// more complex access of attributes, such as might be expressed +// by this Cedar text: +// +// resource[context.resourceAttribute] == "foo" +// +// In Golang, this could be expressed as: +// +// ast.Resource().AccessNode( +// ast.Context().Access("resourceAttribute") +// ).Equals(ast.String("foo")) +func (lhs Node) AccessNode(rhs Node) Node { + return newOpNode(nodeTypeAccess, lhs, rhs) } // ___ ____ _ _ _ diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index 6c5451cf..d5d46f26 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -1,6 +1,10 @@ package ast -import "github.com/cedar-policy/cedar-go/x/exp/types" +import ( + "fmt" + + "github.com/cedar-policy/cedar-go/x/exp/types" +) func Boolean(b types.Boolean) Node { return newValueNode(nodeTypeBoolean, b) @@ -22,12 +26,54 @@ func Long(l types.Long) Node { return newValueNode(nodeTypeLong, l) } +// Set is a convenience function that wraps concrete instances of a Cedar Set type +// types in AST value nodes and passes them along to SetNodes. func Set(s types.Set) Node { - return newValueNode(nodeTypeSet, s) + var nodes []Node + for _, v := range s { + nodes = append(nodes, valueToNode(v)) + } + return SetNodes(nodes) +} + +// SetNodes allows for a complex set definition with values potentially +// being Cedar expressions of their own. For example, this Cedar text: +// +// [1, 2 + 3, context.fooCount] +// +// could be expressed in Golang as: +// +// ast.SetNodes([]ast.Node{ +// ast.Long(1), +// ast.Long(2).Plus(ast.Long(3)), +// ast.Context().Access("fooCount"), +// }) +func SetNodes(nodes []Node) Node { + return newValueNode(nodeTypeSet, nodes) } +// Record is a convenience function that wraps concrete instances of a Cedar Record type +// types in AST value nodes and passes them along to RecordNodes. func Record(r types.Record) Node { - return newValueNode(nodeTypeMap, r) + recordNodes := map[string]Node{} + for k, v := range r { + recordNodes[k] = valueToNode(v) + } + return RecordNodes(recordNodes) +} + +// RecordNodes allows for a complex record definition with values potentially +// being Cedar expressions of their own. For example, this Cedar text: +// +// {"x": 1 + context.fooCount} +// +// could be expressed in Golang as: +// +// ast.RecordNodes([]ast.RecordNode{ +// {Key: "x", Value: ast.Long(1).Plus(ast.Context().Access("resourceField"))}, +// }) +func RecordNodes(nodes map[string]Node) Node { + return newValueNode(nodeTypeRecord, nodes) } func EntityType(e types.EntityType) Node { @@ -49,3 +95,26 @@ func IpAddr(i types.IpAddr) Node { func newValueNode(op opType, v any) Node { return Node{op: op, value: v} } + +func valueToNode(v types.Value) Node { + switch x := v.(type) { + case types.Boolean: + return Boolean(x) + case types.String: + return String(x) + case types.Long: + return Long(x) + case types.Set: + return Set(x) + case types.Record: + return Record(x) + case types.EntityUID: + return Entity(x) + case types.Decimal: + return Decimal(x) + case types.IpAddr: + return IpAddr(x) + default: + panic(fmt.Sprintf("unexpected value type: %T(%v)", v, v)) + } +} diff --git a/x/exp/types/types.go b/x/exp/types/types.go index 7463c0fe..bbf6c3aa 100644 --- a/x/exp/types/types.go +++ b/x/exp/types/types.go @@ -42,5 +42,3 @@ func (Decimal) isValue() {} type IpAddr net.IPAddr func (IpAddr) isValue() {} - -// TODO: Variable? What if you want a Set{1, 2, context.bananaCount}? From ef411f589a9c814f84c0660ce303da65e9e601b2 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Mon, 29 Jul 2024 18:06:46 -0700 Subject: [PATCH 003/216] cedar-go/types: move Cedar data types into a separate library This will allow the new AST and parser to depend on the types without introducing a dependecy cycle by depending on cedar-go. Signed-off-by: philhassey --- cedar.go | 19 +- cedar_test.go | 686 +++++++++++--------- corpus_test.go | 28 +- eval.go | 466 ++++++-------- eval_test.go | 902 ++++++++++++++------------- match_test.go | 5 +- testutil/testutil.go | 44 ++ testutil_test.go | 83 --- toeval.go | 17 +- toeval_test.go | 20 +- json.go => types/json.go | 2 +- json_test.go => types/json_test.go | 162 +---- types/testutil.go | 36 ++ value.go => types/value.go | 201 ++++-- value_test.go => types/value_test.go | 420 ++++++------- x/exp/ast/ast_test.go | 2 +- x/exp/ast/node.go | 6 +- x/exp/ast/operator.go | 10 +- x/exp/ast/scope.go | 6 +- x/exp/ast/value.go | 14 +- x/exp/ast/variable.go | 2 +- x/exp/types/types.go | 44 -- 22 files changed, 1578 insertions(+), 1597 deletions(-) create mode 100644 testutil/testutil.go delete mode 100644 testutil_test.go rename json.go => types/json.go (99%) rename json_test.go => types/json_test.go (70%) create mode 100644 types/testutil.go rename value.go => types/value.go (79%) rename value_test.go => types/value_test.go (62%) delete mode 100644 x/exp/types/types.go diff --git a/cedar.go b/cedar.go index 4c72106e..6340a9df 100644 --- a/cedar.go +++ b/cedar.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/parser" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -94,13 +95,13 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { // An Entities is a collection of all the Entities that are needed to evaluate // authorization requests. The key is an EntityUID which uniquely identifies // the Entity (it must be the same as the UID within the Entity itself.) -type Entities map[EntityUID]Entity +type Entities map[types.EntityUID]Entity // An Entity defines the parents and attributes for an EntityUID. type Entity struct { - UID EntityUID `json:"uid"` - Parents []EntityUID `json:"parents,omitempty"` - Attributes Record `json:"attrs"` + UID types.EntityUID `json:"uid"` + Parents []types.EntityUID `json:"parents,omitempty"` + Attributes types.Record `json:"attrs"` } func (e Entities) MarshalJSON() ([]byte, error) { @@ -188,10 +189,10 @@ type Reason struct { // A Request is the Principal, Action, Resource, and Context portion of an // authorization request. type Request struct { - Principal EntityUID `json:"principal"` - Action EntityUID `json:"action"` - Resource EntityUID `json:"resource"` - Context Record `json:"context"` + Principal types.EntityUID `json:"principal"` + Action types.EntityUID `json:"action"` + Resource types.EntityUID `json:"resource"` + Context types.Record `json:"context"` } // IsAuthorized uses the combination of the PolicySet and Entities to determine @@ -220,7 +221,7 @@ func (p PolicySet) IsAuthorized(entities Entities, req Request) (Decision, Diagn diag.Errors = append(diag.Errors, Error{Policy: n, Position: po.Position, Message: err.Error()}) continue } - vb, err := valueToBool(v) + vb, err := types.ValueToBool(v) if err != nil { // should never happen, maybe remove this case diag.Errors = append(diag.Errors, Error{Policy: n, Position: po.Position, Message: err.Error()}) diff --git a/cedar_test.go b/cedar_test.go index 1f34b863..774e4233 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -1,27 +1,31 @@ package cedar import ( + "encoding/json" "net/netip" "testing" + + "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/types" ) func TestEntityIsZero(t *testing.T) { t.Parallel() tests := []struct { name string - uid EntityUID + uid types.EntityUID want bool }{ - {"empty", EntityUID{}, true}, - {"empty-type", NewEntityUID("one", ""), false}, - {"empty-id", NewEntityUID("", "one"), false}, - {"not-empty", NewEntityUID("one", "two"), false}, + {"empty", types.EntityUID{}, true}, + {"empty-type", types.NewEntityUID("one", ""), false}, + {"empty-id", types.NewEntityUID("", "one"), false}, + {"not-empty", types.NewEntityUID("one", "two"), false}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - testutilEquals(t, tt.uid.IsZero(), tt.want) + testutil.Equals(t, tt.uid.IsZero(), tt.want) }) } } @@ -31,18 +35,18 @@ func TestNewPolicySet(t *testing.T) { t.Run("err-in-tokenize", func(t *testing.T) { t.Parallel() _, err := NewPolicySet("policy.cedar", []byte(`"`)) - testutilError(t, err) + testutil.Error(t, err) }) t.Run("err-in-parse", func(t *testing.T) { t.Parallel() _, err := NewPolicySet("policy.cedar", []byte(`err`)) - testutilError(t, err) + testutil.Error(t, err) }) t.Run("annotations", func(t *testing.T) { t.Parallel() ps, err := NewPolicySet("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) - testutilOK(t, err) - testutilEquals(t, ps[0].Annotations, Annotations{"key": "value"}) + testutil.OK(t, err) + testutil.Equals(t, ps[0].Annotations, Annotations{"key": "value"}) }) } @@ -52,8 +56,8 @@ func TestIsAuthorized(t *testing.T) { Name string Policy string Entities Entities - Principal, Action, Resource EntityUID - Context Record + Principal, Action, Resource types.EntityUID + Context types.Record Want Decision DiagErr int }{ @@ -61,10 +65,10 @@ func TestIsAuthorized(t *testing.T) { Name: "simple-permit", Policy: `permit(principal,action,resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -72,10 +76,10 @@ func TestIsAuthorized(t *testing.T) { Name: "simple-forbid", Policy: `forbid(principal,action,resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 0, }, @@ -83,10 +87,10 @@ func TestIsAuthorized(t *testing.T) { Name: "no-permit", Policy: `permit(principal,action,resource in asdf::"1234");`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 0, }, @@ -94,10 +98,10 @@ func TestIsAuthorized(t *testing.T) { Name: "error-in-policy", Policy: `permit(principal,action,resource) when { resource in "foo" };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -107,10 +111,10 @@ func TestIsAuthorized(t *testing.T) { permit(principal,action,resource); `, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 1, }, @@ -118,10 +122,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-requires-context-success", Policy: `permit(principal,action,resource) when { context.x == 42 };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{"x": Long(42)}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{"x": types.Long(42)}, Want: true, DiagErr: 0, }, @@ -129,10 +133,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-requires-context-fail", Policy: `permit(principal,action,resource) when { context.x == 42 };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{"x": Long(43)}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{"x": types.Long(43)}, Want: false, DiagErr: 0, }, @@ -141,14 +145,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal.x == 42 };`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"coder", "cuzco"}, - Attributes: Record{"x": Long(42)}, + UID: types.EntityUID{"coder", "cuzco"}, + Attributes: types.Record{"x": types.Long(42)}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -157,14 +161,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal.x == 42 };`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"coder", "cuzco"}, - Attributes: Record{"x": Long(43)}, + UID: types.EntityUID{"coder", "cuzco"}, + Attributes: types.Record{"x": types.Long(43)}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 0, }, @@ -173,14 +177,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal in parent::"bob" };`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"coder", "cuzco"}, - Parents: []EntityUID{{"parent", "bob"}}, + UID: types.EntityUID{"coder", "cuzco"}, + Parents: []types.EntityUID{{"parent", "bob"}}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -188,10 +192,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-principal-equals", Policy: `permit(principal == coder::"cuzco",action,resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -200,14 +204,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal in team::"osiris",action,resource);`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"coder", "cuzco"}, - Parents: []EntityUID{{"team", "osiris"}}, + UID: types.EntityUID{"coder", "cuzco"}, + Parents: []types.EntityUID{{"team", "osiris"}}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -215,10 +219,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-action-equals", Policy: `permit(principal,action == table::"drop",resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -227,14 +231,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action in scary::"stuff",resource);`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"table", "drop"}, - Parents: []EntityUID{{"scary", "stuff"}}, + UID: types.EntityUID{"table", "drop"}, + Parents: []types.EntityUID{{"scary", "stuff"}}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -243,14 +247,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action in [scary::"stuff"],resource);`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"table", "drop"}, - Parents: []EntityUID{{"scary", "stuff"}}, + UID: types.EntityUID{"table", "drop"}, + Parents: []types.EntityUID{{"scary", "stuff"}}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -258,10 +262,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-resource-equals", Policy: `permit(principal,action,resource == table::"whatever");`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -269,10 +273,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-unless", Policy: `permit(principal,action,resource) unless { false };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -280,10 +284,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-if", Policy: `permit(principal,action,resource) when { (if true then true else true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -291,10 +295,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-or", Policy: `permit(principal,action,resource) when { (true || false) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -302,10 +306,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-and", Policy: `permit(principal,action,resource) when { (true && true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -313,10 +317,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-relations", Policy: `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -324,10 +328,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-relations-in", Policy: `permit(principal,action,resource) when { principal in principal };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -336,14 +340,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal has name };`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"coder", "cuzco"}, - Attributes: Record{"name": String("bob")}, + UID: types.EntityUID{"coder", "cuzco"}, + Attributes: types.Record{"name": types.String("bob")}, }, }), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -351,10 +355,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-add-sub", Policy: `permit(principal,action,resource) when { 40+3-1==42 };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -362,10 +366,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-mul", Policy: `permit(principal,action,resource) when { 6*7==42 };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -373,10 +377,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-negate", Policy: `permit(principal,action,resource) when { -42==-42 };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -384,10 +388,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-not", Policy: `permit(principal,action,resource) when { !(1+1==42) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -395,10 +399,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -406,10 +410,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-record", Policy: `permit(principal,action,resource) when { {name:"bob"} has name };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -417,10 +421,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-action", Policy: `permit(principal,action,resource) when { action in action };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -428,10 +432,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-contains-ok", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -439,10 +443,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-contains-error", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -450,10 +454,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAll-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -461,10 +465,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAll-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -472,10 +476,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAny-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -483,10 +487,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAny-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -494,10 +498,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-record-attr", Policy: `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -505,10 +509,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-unknown-method", Policy: `permit(principal,action,resource) when { [1,2,3].shuffle() };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -516,10 +520,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-like", Policy: `permit(principal,action,resource) when { "bananas" like "*nan*" };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -527,10 +531,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-unknown-ext-fun", Policy: `permit(principal,action,resource) when { fooBar("10") };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -542,10 +546,10 @@ func TestIsAuthorized(t *testing.T) { decimal("10.0").greaterThan(decimal("9.0")) && decimal("10.0").greaterThanOrEqual(decimal("9.0")) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -553,10 +557,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-decimal-fun-wrong-arity", Policy: `permit(principal,action,resource) when { decimal(1, 2) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -569,10 +573,10 @@ func TestIsAuthorized(t *testing.T) { ip("224.1.2.3").isMulticast() && ip("127.0.0.1").isInRange(ip("127.0.0.0/16"))};`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -580,10 +584,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-ip-fun-wrong-arity", Policy: `permit(principal,action,resource) when { ip() };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -591,10 +595,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isIpv4-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -602,10 +606,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isIpv6-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -613,10 +617,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isLoopback-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -624,10 +628,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isMulticast-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -635,10 +639,10 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isInRange-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"coder", "cuzco"}, - Action: EntityUID{"table", "drop"}, - Resource: EntityUID{"table", "whatever"}, - Context: Record{}, + Principal: types.EntityUID{"coder", "cuzco"}, + Action: types.EntityUID{"table", "drop"}, + Resource: types.EntityUID{"table", "whatever"}, + Context: types.Record{}, Want: false, DiagErr: 1, }, @@ -646,7 +650,7 @@ func TestIsAuthorized(t *testing.T) { Name: "negative-unary-op", Policy: `permit(principal,action,resource) when { -context.value > 0 };`, Entities: entitiesFromSlice(nil), - Context: Record{"value": Long(-42)}, + Context: types.Record{"value": types.Long(-42)}, Want: true, DiagErr: 0, }, @@ -654,10 +658,10 @@ func TestIsAuthorized(t *testing.T) { Name: "principal-is", Policy: `permit(principal is Actor,action,resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -665,10 +669,10 @@ func TestIsAuthorized(t *testing.T) { Name: "principal-is-in", Policy: `permit(principal is Actor in Actor::"cuzco",action,resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -676,10 +680,10 @@ func TestIsAuthorized(t *testing.T) { Name: "resource-is", Policy: `permit(principal,action,resource is Resource);`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -687,10 +691,10 @@ func TestIsAuthorized(t *testing.T) { Name: "resource-is-in", Policy: `permit(principal,action,resource is Resource in Resource::"table");`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -698,10 +702,10 @@ func TestIsAuthorized(t *testing.T) { Name: "when-is", Policy: `permit(principal,action,resource) when { resource is Resource };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -709,10 +713,10 @@ func TestIsAuthorized(t *testing.T) { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Resource::"table" };`, Entities: entitiesFromSlice(nil), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -721,14 +725,14 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { resource is Resource in Parent::"id" };`, Entities: entitiesFromSlice([]Entity{ { - UID: EntityUID{"Resource", "table"}, - Parents: []EntityUID{{"Parent", "id"}}, + UID: types.EntityUID{"Resource", "table"}, + Parents: []types.EntityUID{{"Parent", "id"}}, }, }), - Principal: EntityUID{"Actor", "cuzco"}, - Action: EntityUID{"Action", "drop"}, - Resource: EntityUID{"Resource", "table"}, - Context: Record{}, + Principal: types.EntityUID{"Actor", "cuzco"}, + Action: types.EntityUID{"Action", "drop"}, + Resource: types.EntityUID{"Resource", "table"}, + Context: types.Record{}, Want: true, DiagErr: 0, }, @@ -738,15 +742,15 @@ func TestIsAuthorized(t *testing.T) { t.Run(tt.Name, func(t *testing.T) { t.Parallel() ps, err := NewPolicySet("policy.cedar", []byte(tt.Policy)) - testutilOK(t, err) + testutil.OK(t, err) ok, diag := ps.IsAuthorized(tt.Entities, Request{ Principal: tt.Principal, Action: tt.Action, Resource: tt.Resource, Context: tt.Context, }) - testutilEquals(t, ok, tt.Want) - testutilEquals(t, len(diag.Errors), tt.DiagErr) + testutil.Equals(t, ok, tt.Want) + testutil.Equals(t, len(diag.Errors), tt.DiagErr) }) } } @@ -757,41 +761,41 @@ func TestEntities(t *testing.T) { t.Parallel() s := []Entity{ { - UID: EntityUID{Type: "A", ID: "A"}, + UID: types.EntityUID{Type: "A", ID: "A"}, }, { - UID: EntityUID{Type: "A", ID: "B"}, + UID: types.EntityUID{Type: "A", ID: "B"}, }, { - UID: EntityUID{Type: "B", ID: "A"}, + UID: types.EntityUID{Type: "B", ID: "A"}, }, { - UID: EntityUID{Type: "B", ID: "B"}, + UID: types.EntityUID{Type: "B", ID: "B"}, }, } entities := entitiesFromSlice(s) s2 := entities.toSlice() - testutilEquals(t, s2, s) + testutil.Equals(t, s2, s) }) t.Run("Clone", func(t *testing.T) { t.Parallel() s := []Entity{ { - UID: EntityUID{Type: "A", ID: "A"}, + UID: types.EntityUID{Type: "A", ID: "A"}, }, { - UID: EntityUID{Type: "A", ID: "B"}, + UID: types.EntityUID{Type: "A", ID: "B"}, }, { - UID: EntityUID{Type: "B", ID: "A"}, + UID: types.EntityUID{Type: "B", ID: "A"}, }, { - UID: EntityUID{Type: "B", ID: "B"}, + UID: types.EntityUID{Type: "B", ID: "B"}, }, } entities := entitiesFromSlice(s) clone := entities.Clone() - testutilEquals(t, clone, entities) + testutil.Equals(t, clone, entities) }) } @@ -800,37 +804,37 @@ func TestValueFrom(t *testing.T) { t.Parallel() tests := []struct { name string - in Value + in types.Value outJSON string }{ { name: "string", - in: String("hello"), + in: types.String("hello"), outJSON: `"hello"`, }, { name: "bool", - in: Boolean(true), + in: types.Boolean(true), outJSON: `true`, }, { name: "int64", - in: Long(42), + in: types.Long(42), outJSON: `42`, }, { name: "int64", - in: EntityUID{Type: "T", ID: "0"}, + in: types.EntityUID{Type: "T", ID: "0"}, outJSON: `{"__entity":{"type":"T","id":"0"}}`, }, { name: "record", - in: Record{"K": Boolean(true)}, + in: types.Record{"K": types.Boolean(true)}, outJSON: `{"K":true}`, }, { name: "netipPrefix", - in: IPAddr(netip.MustParsePrefix("192.168.0.42/32")), + in: types.IPAddr(netip.MustParsePrefix("192.168.0.42/32")), outJSON: `{"__extn":{"fn":"ip","arg":"192.168.0.42"}}`, }, } @@ -840,8 +844,8 @@ func TestValueFrom(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() out, err := tt.in.ExplicitMarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(out), tt.outJSON) + testutil.OK(t, err) + testutil.Equals(t, string(out), tt.outJSON) }) } } @@ -849,7 +853,7 @@ func TestValueFrom(t *testing.T) { func TestError(t *testing.T) { t.Parallel() e := Error{Policy: 42, Message: "bad error"} - testutilEquals(t, e.String(), "while evaluating policy `policy42`: bad error") + testutil.Equals(t, e.String(), "while evaluating policy `policy42`: bad error") } func TestInvalidPolicy(t *testing.T) { @@ -858,12 +862,12 @@ func TestInvalidPolicy(t *testing.T) { ps := PolicySet{ { Effect: Forbid, - eval: newLiteralEval(Long(42)), + eval: newLiteralEval(types.Long(42)), }, } ok, diag := ps.IsAuthorized(Entities{}, Request{}) - testutilEquals(t, ok, Deny) - testutilEquals(t, diag, Diagnostic{ + testutil.Equals(t, ok, Deny) + testutil.Equals(t, diag, Diagnostic{ Errors: []Error{ { Policy: 0, @@ -892,7 +896,7 @@ func TestCorpusRelated(t *testing.T) { ) when { (true && (((!870985681610) == principal) == principal)) && principal };`, - Request{Principal: NewEntityUID("a", "\u0000\u0000"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "\u0000\u0000")}, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, []int{0}, @@ -907,7 +911,7 @@ func TestCorpusRelated(t *testing.T) { ) when { (((!870985681610) == principal) == principal) };`, - Request{Principal: NewEntityUID("a", "\u0000\u0000"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "\u0000\u0000")}, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, []int{0}, @@ -921,7 +925,7 @@ func TestCorpusRelated(t *testing.T) { ) when { ((!870985681610) == principal) };`, - Request{Principal: NewEntityUID("a", "\u0000\u0000"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "\u0000\u0000")}, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, []int{0}, @@ -936,7 +940,7 @@ func TestCorpusRelated(t *testing.T) { ) when { (!870985681610) };`, - Request{Principal: NewEntityUID("a", "\u0000\u0000"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "\u0000\u0000")}, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, []int{0}, @@ -980,7 +984,7 @@ func TestCorpusRelated(t *testing.T) { ) when { true && ((if (principal in action) then (ip("")) else (if true then (ip("6b6b:f00::32ff:ffff:6368/00")) else (ip("7265:6c69:706d:6f43:5f74:6f70:7374:6f68")))).isMulticast()) };`, - Request{Principal: NewEntityUID("a", "\u0000\b\u0011\u0000R"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "\u0000\b\u0011\u0000R")}, + Request{Principal: types.NewEntityUID("a", "\u0000\b\u0011\u0000R"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\b\u0011\u0000R")}, Deny, nil, []int{0}, @@ -1008,7 +1012,7 @@ func TestCorpusRelated(t *testing.T) { ) when { true && (([ip("c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5/68")].containsAll([ip("c5c5:c5c5:c5c5:c5c5:c5c5:5cc5:c5c5:c5c5/68")])) || ((ip("")) == (ip("")))) };`, - request: Request{Principal: NewEntityUID("a", "\u0000\u0000(W\u0000\u0000\u0000"), Action: NewEntityUID("Action", "action"), Resource: NewEntityUID("a", "")}, + request: Request{Principal: types.NewEntityUID("a", "\u0000\u0000(W\u0000\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "")}, decision: Deny, reasons: nil, errors: []int{0}, @@ -1019,19 +1023,123 @@ func TestCorpusRelated(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() policy, err := NewPolicySet("", []byte(tt.policy)) - testutilOK(t, err) + testutil.OK(t, err) ok, diag := policy.IsAuthorized(Entities{}, tt.request) - testutilEquals(t, ok, tt.decision) + testutil.Equals(t, ok, tt.decision) var reasons []int for _, n := range diag.Reasons { reasons = append(reasons, n.Policy) } - testutilEquals(t, reasons, tt.reasons) + testutil.Equals(t, reasons, tt.reasons) var errors []int for _, n := range diag.Errors { errors = append(errors, n.Policy) } - testutilEquals(t, errors, tt.errors) + testutil.Equals(t, errors, tt.errors) }) } } + +func TestEntitiesJSON(t *testing.T) { + t.Parallel() + t.Run("Marshal", func(t *testing.T) { + t.Parallel() + e := Entities{} + ent := Entity{ + UID: types.NewEntityUID("Type", "id"), + Parents: []types.EntityUID{}, + Attributes: types.Record{"key": types.Long(42)}, + } + e[ent.UID] = ent + b, err := e.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `[{"uid":{"type":"Type","id":"id"},"attrs":{"key":42}}]`) + }) + + t.Run("Unmarshal", func(t *testing.T) { + t.Parallel() + b := []byte(`[{"uid":{"type":"Type","id":"id"},"parents":[],"attrs":{"key":42}}]`) + var e Entities + err := json.Unmarshal(b, &e) + testutil.OK(t, err) + want := Entities{} + ent := Entity{ + UID: types.NewEntityUID("Type", "id"), + Parents: []types.EntityUID{}, + Attributes: types.Record{"key": types.Long(42)}, + } + want[ent.UID] = ent + testutil.Equals(t, e, want) + }) + + t.Run("UnmarshalErr", func(t *testing.T) { + t.Parallel() + var e Entities + err := e.UnmarshalJSON([]byte(`!@#$`)) + testutil.Error(t, err) + }) +} + +func TestJSONEffect(t *testing.T) { + t.Parallel() + t.Run("MarshalPermit", func(t *testing.T) { + t.Parallel() + e := Permit + b, err := e.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `"permit"`) + }) + t.Run("MarshalForbid", func(t *testing.T) { + t.Parallel() + e := Forbid + b, err := e.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `"forbid"`) + }) + t.Run("UnmarshalPermit", func(t *testing.T) { + t.Parallel() + var e Effect + err := json.Unmarshal([]byte(`"permit"`), &e) + testutil.OK(t, err) + testutil.Equals(t, e, Permit) + }) + t.Run("UnmarshalForbid", func(t *testing.T) { + t.Parallel() + var e Effect + err := json.Unmarshal([]byte(`"forbid"`), &e) + testutil.OK(t, err) + testutil.Equals(t, e, Forbid) + }) +} + +func TestJSONDecision(t *testing.T) { + t.Parallel() + t.Run("MarshalAllow", func(t *testing.T) { + t.Parallel() + d := Allow + b, err := d.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `"allow"`) + }) + t.Run("MarshalDeny", func(t *testing.T) { + t.Parallel() + d := Deny + b, err := d.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `"deny"`) + }) + t.Run("UnmarshalAllow", func(t *testing.T) { + t.Parallel() + var d Decision + err := json.Unmarshal([]byte(`"allow"`), &d) + testutil.OK(t, err) + testutil.Equals(t, d, Allow) + }) + t.Run("UnmarshalDeny", func(t *testing.T) { + t.Parallel() + var d Decision + err := json.Unmarshal([]byte(`"deny"`), &d) + testutil.OK(t, err) + testutil.Equals(t, d, Deny) + }) +} diff --git a/corpus_test.go b/corpus_test.go index e27353d9..305c54f9 100644 --- a/corpus_test.go +++ b/corpus_test.go @@ -11,18 +11,20 @@ import ( "slices" "strings" "testing" + + "github.com/cedar-policy/cedar-go/types" ) // jsonEntity is not part of entityValue as I can find // no evidence this is part of the JSON spec. It also // requires creating a parser, so it's quite expensive. -type jsonEntity EntityUID +type jsonEntity types.EntityUID func (e *jsonEntity) UnmarshalJSON(b []byte) error { if string(b) == "null" { return nil } - var input EntityUID + var input types.EntityUID if err := json.Unmarshal(b, &input); err != nil { return err } @@ -36,14 +38,14 @@ type corpusTest struct { ShouldValidate bool `json:"shouldValidate"` Entities string `json:"entities"` Requests []struct { - Desc string `json:"description"` - Principal jsonEntity `json:"principal"` - Action jsonEntity `json:"action"` - Resource jsonEntity `json:"resource"` - Context Record `json:"context"` - Decision string `json:"decision"` - Reasons []string `json:"reason"` - Errors []string `json:"errors"` + Desc string `json:"description"` + Principal jsonEntity `json:"principal"` + Action jsonEntity `json:"action"` + Resource jsonEntity `json:"resource"` + Context types.Record `json:"context"` + Decision string `json:"decision"` + Reasons []string `json:"reason"` + Errors []string `json:"errors"` } `json:"requests"` } @@ -157,9 +159,9 @@ func TestCorpus(t *testing.T) { ok, diag := policySet.IsAuthorized( entities, Request{ - Principal: EntityUID(request.Principal), - Action: EntityUID(request.Action), - Resource: EntityUID(request.Resource), + Principal: types.EntityUID(request.Principal), + Action: types.EntityUID(request.Action), + Resource: types.EntityUID(request.Resource), Context: request.Context, }) diff --git a/eval.go b/eval.go index 30a739a7..b38159c5 100644 --- a/eval.go +++ b/eval.go @@ -3,194 +3,120 @@ package cedar import ( "fmt" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/parser" ) var errOverflow = fmt.Errorf("integer overflow") -var errType = fmt.Errorf("type error") var errUnknownMethod = fmt.Errorf("unknown method") var errUnknownExtensionFunction = fmt.Errorf("function does not exist") var errArity = fmt.Errorf("wrong number of arguments provided to extension function") var errAttributeAccess = fmt.Errorf("does not have the attribute") -var errDecimal = fmt.Errorf("error parsing decimal value") -var errIP = fmt.Errorf("error parsing ip value") var errEntityNotExist = fmt.Errorf("does not exist") var errUnspecifiedEntity = fmt.Errorf("unspecified entity") type evalContext struct { Entities Entities - Principal, Action, Resource Value - Context Value + Principal, Action, Resource types.Value + Context types.Value } type evaler interface { - Eval(*evalContext) (Value, error) + Eval(*evalContext) (types.Value, error) } -func valueToBool(v Value) (Boolean, error) { - bv, ok := v.(Boolean) - if !ok { - return false, fmt.Errorf("%w: expected bool, got %v", errType, v.typeName()) - } - return bv, nil -} - -func evalBool(n evaler, ctx *evalContext) (Boolean, error) { +func evalBool(n evaler, ctx *evalContext) (types.Boolean, error) { v, err := n.Eval(ctx) if err != nil { return false, err } - b, err := valueToBool(v) + b, err := types.ValueToBool(v) if err != nil { return false, err } return b, nil } -func valueToLong(v Value) (Long, error) { - lv, ok := v.(Long) - if !ok { - return 0, fmt.Errorf("%w: expected long, got %v", errType, v.typeName()) - } - return lv, nil -} - -func evalLong(n evaler, ctx *evalContext) (Long, error) { +func evalLong(n evaler, ctx *evalContext) (types.Long, error) { v, err := n.Eval(ctx) if err != nil { return 0, err } - l, err := valueToLong(v) + l, err := types.ValueToLong(v) if err != nil { return 0, err } return l, nil } -func valueToString(v Value) (String, error) { - sv, ok := v.(String) - if !ok { - return "", fmt.Errorf("%w: expected string, got %v", errType, v.typeName()) - } - return sv, nil -} - -func evalString(n evaler, ctx *evalContext) (String, error) { +func evalString(n evaler, ctx *evalContext) (types.String, error) { v, err := n.Eval(ctx) if err != nil { return "", err } - s, err := valueToString(v) + s, err := types.ValueToString(v) if err != nil { return "", err } return s, nil } -func valueToSet(v Value) (Set, error) { - sv, ok := v.(Set) - if !ok { - return nil, fmt.Errorf("%w: expected set, got %v", errType, v.typeName()) - } - return sv, nil -} - -func evalSet(n evaler, ctx *evalContext) (Set, error) { +func evalSet(n evaler, ctx *evalContext) (types.Set, error) { v, err := n.Eval(ctx) if err != nil { return nil, err } - s, err := valueToSet(v) + s, err := types.ValueToSet(v) if err != nil { return nil, err } return s, nil } -func valueToRecord(v Value) (Record, error) { - rv, ok := v.(Record) - if !ok { - return nil, fmt.Errorf("%w: expected record got %v", errType, v.typeName()) - } - return rv, nil -} - -func valueToEntity(v Value) (EntityUID, error) { - ev, ok := v.(EntityUID) - if !ok { - return EntityUID{}, fmt.Errorf("%w: expected (entity of type `any_entity_type`), got %v", errType, v.typeName()) - } - return ev, nil -} - -func valueToPath(v Value) (path, error) { - ev, ok := v.(path) - if !ok { - return "", fmt.Errorf("%w: expected (path of type `any_entity_type`), got %v", errType, v.typeName()) - } - return ev, nil -} - -func evalEntity(n evaler, ctx *evalContext) (EntityUID, error) { +func evalEntity(n evaler, ctx *evalContext) (types.EntityUID, error) { v, err := n.Eval(ctx) if err != nil { - return EntityUID{}, err + return types.EntityUID{}, err } - e, err := valueToEntity(v) + e, err := types.ValueToEntity(v) if err != nil { - return EntityUID{}, err + return types.EntityUID{}, err } return e, nil } -func evalPath(n evaler, ctx *evalContext) (path, error) { +func evalPath(n evaler, ctx *evalContext) (types.Path, error) { v, err := n.Eval(ctx) if err != nil { return "", err } - e, err := valueToPath(v) + e, err := types.ValueToPath(v) if err != nil { return "", err } return e, nil } -func valueToDecimal(v Value) (Decimal, error) { - d, ok := v.(Decimal) - if !ok { - return 0, fmt.Errorf("%w: expected decimal, got %v", errType, v.typeName()) - } - return d, nil -} - -func evalDecimal(n evaler, ctx *evalContext) (Decimal, error) { +func evalDecimal(n evaler, ctx *evalContext) (types.Decimal, error) { v, err := n.Eval(ctx) if err != nil { - return Decimal(0), err + return types.Decimal(0), err } - d, err := valueToDecimal(v) + d, err := types.ValueToDecimal(v) if err != nil { - return Decimal(0), err + return types.Decimal(0), err } return d, nil } -func valueToIP(v Value) (IPAddr, error) { - i, ok := v.(IPAddr) - if !ok { - return IPAddr{}, fmt.Errorf("%w: expected ipaddr, got %v", errType, v.typeName()) - } - return i, nil -} - -func evalIP(n evaler, ctx *evalContext) (IPAddr, error) { +func evalIP(n evaler, ctx *evalContext) (types.IPAddr, error) { v, err := n.Eval(ctx) if err != nil { - return IPAddr{}, err + return types.IPAddr{}, err } - i, err := valueToIP(v) + i, err := types.ValueToIP(v) if err != nil { - return IPAddr{}, err + return types.IPAddr{}, err } return i, nil } @@ -206,20 +132,20 @@ func newErrorEval(err error) *errorEval { } } -func (n *errorEval) Eval(_ *evalContext) (Value, error) { - return zeroValue(), n.err +func (n *errorEval) Eval(_ *evalContext) (types.Value, error) { + return types.ZeroValue(), n.err } // literalEval type literalEval struct { - value Value + value types.Value } -func newLiteralEval(value Value) *literalEval { +func newLiteralEval(value types.Value) *literalEval { return &literalEval{value: value} } -func (n *literalEval) Eval(_ *evalContext) (Value, error) { +func (n *literalEval) Eval(_ *evalContext) (types.Value, error) { return n.value, nil } @@ -236,25 +162,25 @@ func newOrNode(lhs evaler, rhs evaler) *orEval { } } -func (n *orEval) Eval(ctx *evalContext) (Value, error) { +func (n *orEval) Eval(ctx *evalContext) (types.Value, error) { v, err := n.lhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - b, err := valueToBool(v) + b, err := types.ValueToBool(v) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } if b { return v, nil } v, err = n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - _, err = valueToBool(v) + _, err = types.ValueToBool(v) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } return v, nil } @@ -272,25 +198,25 @@ func newAndEval(lhs evaler, rhs evaler) *andEval { } } -func (n *andEval) Eval(ctx *evalContext) (Value, error) { +func (n *andEval) Eval(ctx *evalContext) (types.Value, error) { v, err := n.lhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - b, err := valueToBool(v) + b, err := types.ValueToBool(v) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } if !b { return v, nil } v, err = n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - _, err = valueToBool(v) + _, err = types.ValueToBool(v) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } return v, nil } @@ -306,14 +232,14 @@ func newNotEval(inner evaler) *notEval { } } -func (n *notEval) Eval(ctx *evalContext) (Value, error) { +func (n *notEval) Eval(ctx *evalContext) (types.Value, error) { v, err := n.inner.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - b, err := valueToBool(v) + b, err := types.ValueToBool(v) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } return !b, nil } @@ -323,7 +249,7 @@ func (n *notEval) Eval(ctx *evalContext) (Value, error) { // behavior (https://go.dev/ref/spec#Integer_overflow), so we can go ahead and // do the operations and then check for overflow ex post facto. -func checkedAddI64(lhs, rhs Long) (Long, bool) { +func checkedAddI64(lhs, rhs types.Long) (types.Long, bool) { result := lhs + rhs if (result > lhs) != (rhs > 0) { return result, false @@ -331,7 +257,7 @@ func checkedAddI64(lhs, rhs Long) (Long, bool) { return result, true } -func checkedSubI64(lhs, rhs Long) (Long, bool) { +func checkedSubI64(lhs, rhs types.Long) (types.Long, bool) { result := lhs - rhs if (result > lhs) != (rhs < 0) { return result, false @@ -339,7 +265,7 @@ func checkedSubI64(lhs, rhs Long) (Long, bool) { return result, true } -func checkedMulI64(lhs, rhs Long) (Long, bool) { +func checkedMulI64(lhs, rhs types.Long) (types.Long, bool) { if lhs == 0 || rhs == 0 { return 0, true } @@ -355,7 +281,7 @@ func checkedMulI64(lhs, rhs Long) (Long, bool) { return result, true } -func checkedNegI64(a Long) (Long, bool) { +func checkedNegI64(a types.Long) (types.Long, bool) { if a == -9_223_372_036_854_775_808 { return 0, false } @@ -375,18 +301,18 @@ func newAddEval(lhs evaler, rhs evaler) *addEval { } } -func (n *addEval) Eval(ctx *evalContext) (Value, error) { +func (n *addEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } res, ok := checkedAddI64(lhs, rhs) if !ok { - return zeroValue(), fmt.Errorf("%w while attempting to add `%d` with `%d`", errOverflow, lhs, rhs) + return types.ZeroValue(), fmt.Errorf("%w while attempting to add `%d` with `%d`", errOverflow, lhs, rhs) } return res, nil } @@ -404,18 +330,18 @@ func newSubtractEval(lhs evaler, rhs evaler) *subtractEval { } } -func (n *subtractEval) Eval(ctx *evalContext) (Value, error) { +func (n *subtractEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } res, ok := checkedSubI64(lhs, rhs) if !ok { - return zeroValue(), fmt.Errorf("%w while attempting to subtract `%d` from `%d`", errOverflow, rhs, lhs) + return types.ZeroValue(), fmt.Errorf("%w while attempting to subtract `%d` from `%d`", errOverflow, rhs, lhs) } return res, nil } @@ -433,18 +359,18 @@ func newMultiplyEval(lhs evaler, rhs evaler) *multiplyEval { } } -func (n *multiplyEval) Eval(ctx *evalContext) (Value, error) { +func (n *multiplyEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } res, ok := checkedMulI64(lhs, rhs) if !ok { - return zeroValue(), fmt.Errorf("%w while attempting to multiply `%d` by `%d`", errOverflow, lhs, rhs) + return types.ZeroValue(), fmt.Errorf("%w while attempting to multiply `%d` by `%d`", errOverflow, lhs, rhs) } return res, nil } @@ -460,14 +386,14 @@ func newNegateEval(inner evaler) *negateEval { } } -func (n *negateEval) Eval(ctx *evalContext) (Value, error) { +func (n *negateEval) Eval(ctx *evalContext) (types.Value, error) { inner, err := evalLong(n.inner, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } res, ok := checkedNegI64(inner) if !ok { - return zeroValue(), fmt.Errorf("%w while attempting to negate `%d`", errOverflow, inner) + return types.ZeroValue(), fmt.Errorf("%w while attempting to negate `%d`", errOverflow, inner) } return res, nil } @@ -485,16 +411,16 @@ func newLongLessThanEval(lhs evaler, rhs evaler) *longLessThanEval { } } -func (n *longLessThanEval) Eval(ctx *evalContext) (Value, error) { +func (n *longLessThanEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs < rhs), nil + return types.Boolean(lhs < rhs), nil } // longLessThanOrEqualEval @@ -510,16 +436,16 @@ func newLongLessThanOrEqualEval(lhs evaler, rhs evaler) *longLessThanOrEqualEval } } -func (n *longLessThanOrEqualEval) Eval(ctx *evalContext) (Value, error) { +func (n *longLessThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs <= rhs), nil + return types.Boolean(lhs <= rhs), nil } // longGreaterThanEval @@ -535,16 +461,16 @@ func newLongGreaterThanEval(lhs evaler, rhs evaler) *longGreaterThanEval { } } -func (n *longGreaterThanEval) Eval(ctx *evalContext) (Value, error) { +func (n *longGreaterThanEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs > rhs), nil + return types.Boolean(lhs > rhs), nil } // longGreaterThanOrEqualEval @@ -560,16 +486,16 @@ func newLongGreaterThanOrEqualEval(lhs evaler, rhs evaler) *longGreaterThanOrEqu } } -func (n *longGreaterThanOrEqualEval) Eval(ctx *evalContext) (Value, error) { +func (n *longGreaterThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs >= rhs), nil + return types.Boolean(lhs >= rhs), nil } // decimalLessThanEval @@ -585,16 +511,16 @@ func newDecimalLessThanEval(lhs evaler, rhs evaler) *decimalLessThanEval { } } -func (n *decimalLessThanEval) Eval(ctx *evalContext) (Value, error) { +func (n *decimalLessThanEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs < rhs), nil + return types.Boolean(lhs < rhs), nil } // decimalLessThanOrEqualEval @@ -610,16 +536,16 @@ func newDecimalLessThanOrEqualEval(lhs evaler, rhs evaler) *decimalLessThanOrEqu } } -func (n *decimalLessThanOrEqualEval) Eval(ctx *evalContext) (Value, error) { +func (n *decimalLessThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs <= rhs), nil + return types.Boolean(lhs <= rhs), nil } // decimalGreaterThanEval @@ -635,16 +561,16 @@ func newDecimalGreaterThanEval(lhs evaler, rhs evaler) *decimalGreaterThanEval { } } -func (n *decimalGreaterThanEval) Eval(ctx *evalContext) (Value, error) { +func (n *decimalGreaterThanEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs > rhs), nil + return types.Boolean(lhs > rhs), nil } // decimalGreaterThanOrEqualEval @@ -660,16 +586,16 @@ func newDecimalGreaterThanOrEqualEval(lhs evaler, rhs evaler) *decimalGreaterTha } } -func (n *decimalGreaterThanOrEqualEval) Eval(ctx *evalContext) (Value, error) { +func (n *decimalGreaterThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs >= rhs), nil + return types.Boolean(lhs >= rhs), nil } // ifThenElseEval @@ -687,10 +613,10 @@ func newIfThenElseEval(if_, then, else_ evaler) *ifThenElseEval { } } -func (n *ifThenElseEval) Eval(ctx *evalContext) (Value, error) { +func (n *ifThenElseEval) Eval(ctx *evalContext) (types.Value, error) { cond, err := evalBool(n.if_, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } if cond { return n.then.Eval(ctx) @@ -710,16 +636,16 @@ func newEqualEval(lhs, rhs evaler) *equalEval { } } -func (n *equalEval) Eval(ctx *evalContext) (Value, error) { +func (n *equalEval) Eval(ctx *evalContext) (types.Value, error) { lv, err := n.lhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rv, err := n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lv.equal(rv)), nil + return types.Boolean(lv.Equal(rv)), nil } // notEqualEval @@ -734,16 +660,16 @@ func newNotEqualEval(lhs, rhs evaler) *notEqualEval { } } -func (n *notEqualEval) Eval(ctx *evalContext) (Value, error) { +func (n *notEqualEval) Eval(ctx *evalContext) (types.Value, error) { lv, err := n.lhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rv, err := n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(!lv.equal(rv)), nil + return types.Boolean(!lv.Equal(rv)), nil } // setLiteralEval @@ -755,12 +681,12 @@ func newSetLiteralEval(elements []evaler) *setLiteralEval { return &setLiteralEval{elements: elements} } -func (n *setLiteralEval) Eval(ctx *evalContext) (Value, error) { - var vals Set +func (n *setLiteralEval) Eval(ctx *evalContext) (types.Value, error) { + var vals types.Set for _, e := range n.elements { v, err := e.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } vals = append(vals, v) } @@ -779,16 +705,16 @@ func newContainsEval(lhs, rhs evaler) *containsEval { } } -func (n *containsEval) Eval(ctx *evalContext) (Value, error) { +func (n *containsEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(lhs.contains(rhs)), nil + return types.Boolean(lhs.Contains(rhs)), nil } // containsAllEval @@ -803,23 +729,23 @@ func newContainsAllEval(lhs, rhs evaler) *containsAllEval { } } -func (n *containsAllEval) Eval(ctx *evalContext) (Value, error) { +func (n *containsAllEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalSet(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } result := true for _, e := range rhs { - if !lhs.contains(e) { + if !lhs.Contains(e) { result = false break } } - return Boolean(result), nil + return types.Boolean(result), nil } // containsAnyEval @@ -834,23 +760,23 @@ func newContainsAnyEval(lhs, rhs evaler) *containsAnyEval { } } -func (n *containsAnyEval) Eval(ctx *evalContext) (Value, error) { +func (n *containsAnyEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalSet(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } result := false for _, e := range rhs { - if lhs.contains(e) { + if lhs.Contains(e) { result = true break } } - return Boolean(result), nil + return types.Boolean(result), nil } // recordLiteralEval @@ -862,12 +788,12 @@ func newRecordLiteralEval(elements map[string]evaler) *recordLiteralEval { return &recordLiteralEval{elements: elements} } -func (n *recordLiteralEval) Eval(ctx *evalContext) (Value, error) { - vals := Record{} +func (n *recordLiteralEval) Eval(ctx *evalContext) (types.Value, error) { + vals := types.Record{} for k, en := range n.elements { v, err := en.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } vals[k] = v } @@ -884,34 +810,34 @@ func newAttributeAccessEval(record evaler, attribute string) *attributeAccessEva return &attributeAccessEval{object: record, attribute: attribute} } -func (n *attributeAccessEval) Eval(ctx *evalContext) (Value, error) { +func (n *attributeAccessEval) Eval(ctx *evalContext) (types.Value, error) { v, err := n.object.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - var record Record + var record types.Record key := "record" switch vv := v.(type) { - case EntityUID: + case types.EntityUID: key = "`" + vv.String() + "`" - var unspecified EntityUID + var unspecified types.EntityUID if vv == unspecified { - return zeroValue(), fmt.Errorf("cannot access attribute `%s` of %w", n.attribute, errUnspecifiedEntity) + return types.ZeroValue(), fmt.Errorf("cannot access attribute `%s` of %w", n.attribute, errUnspecifiedEntity) } rec, ok := ctx.Entities[vv] if !ok { - return zeroValue(), fmt.Errorf("entity `%v` %w", vv.String(), errEntityNotExist) + return types.ZeroValue(), fmt.Errorf("entity `%v` %w", vv.String(), errEntityNotExist) } else { record = rec.Attributes } - case Record: + case types.Record: record = vv default: - return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", errType, v.typeName()) + return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", types.ErrType, v.TypeName()) } val, ok := record[n.attribute] if !ok { - return zeroValue(), fmt.Errorf("%s %w `%s`", key, errAttributeAccess, n.attribute) + return types.ZeroValue(), fmt.Errorf("%s %w `%s`", key, errAttributeAccess, n.attribute) } return val, nil } @@ -926,27 +852,27 @@ func newHasEval(record evaler, attribute string) *hasEval { return &hasEval{object: record, attribute: attribute} } -func (n *hasEval) Eval(ctx *evalContext) (Value, error) { +func (n *hasEval) Eval(ctx *evalContext) (types.Value, error) { v, err := n.object.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - var record Record + var record types.Record switch vv := v.(type) { - case EntityUID: + case types.EntityUID: rec, ok := ctx.Entities[vv] if !ok { - record = Record{} + record = types.Record{} } else { record = rec.Attributes } - case Record: + case types.Record: record = vv default: - return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", errType, v.typeName()) + return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", types.ErrType, v.TypeName()) } _, ok := record[n.attribute] - return Boolean(ok), nil + return types.Boolean(ok), nil } // likeEval @@ -959,20 +885,20 @@ func newLikeEval(lhs evaler, pattern parser.Pattern) *likeEval { return &likeEval{lhs: lhs, pattern: pattern} } -func (l *likeEval) Eval(ctx *evalContext) (Value, error) { +func (l *likeEval) Eval(ctx *evalContext) (types.Value, error) { v, err := evalString(l.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(match(l.pattern, string(v))), nil + return types.Boolean(match(l.pattern, string(v))), nil } -type variableName func(ctx *evalContext) Value +type variableName func(ctx *evalContext) types.Value -func variableNamePrincipal(ctx *evalContext) Value { return ctx.Principal } -func variableNameAction(ctx *evalContext) Value { return ctx.Action } -func variableNameResource(ctx *evalContext) Value { return ctx.Resource } -func variableNameContext(ctx *evalContext) Value { return ctx.Context } +func variableNamePrincipal(ctx *evalContext) types.Value { return ctx.Principal } +func variableNameAction(ctx *evalContext) types.Value { return ctx.Action } +func variableNameResource(ctx *evalContext) types.Value { return ctx.Resource } +func variableNameContext(ctx *evalContext) types.Value { return ctx.Context } // variableEval type variableEval struct { @@ -983,7 +909,7 @@ func newVariableEval(variableName variableName) *variableEval { return &variableEval{variableName: variableName} } -func (n *variableEval) Eval(ctx *evalContext) (Value, error) { +func (n *variableEval) Eval(ctx *evalContext) (types.Value, error) { return n.variableName(ctx), nil } @@ -996,11 +922,11 @@ func newInEval(lhs, rhs evaler) *inEval { return &inEval{lhs: lhs, rhs: rhs} } -func entityIn(entity EntityUID, query map[EntityUID]struct{}, entities Entities) bool { - checked := map[EntityUID]struct{}{} - toCheck := []EntityUID{entity} +func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entities Entities) bool { + checked := map[types.EntityUID]struct{}{} + toCheck := []types.EntityUID{entity} for len(toCheck) > 0 { - var candidate EntityUID + var candidate types.EntityUID candidate, toCheck = toCheck[len(toCheck)-1], toCheck[:len(toCheck)-1] if _, ok := checked[candidate]; ok { continue @@ -1014,34 +940,34 @@ func entityIn(entity EntityUID, query map[EntityUID]struct{}, entities Entities) return false } -func (n *inEval) Eval(ctx *evalContext) (Value, error) { +func (n *inEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalEntity(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := n.rhs.Eval(ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - query := map[EntityUID]struct{}{} + query := map[types.EntityUID]struct{}{} switch rhsv := rhs.(type) { - case EntityUID: + case types.EntityUID: query[rhsv] = struct{}{} - case Set: + case types.Set: for _, rhv := range rhsv { - e, err := valueToEntity(rhv) + e, err := types.ValueToEntity(rhv) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } query[e] = struct{}{} } default: - return zeroValue(), fmt.Errorf( - "%w: expected one of [set, (entity of type `any_entity_type`)], got %v", errType, rhs.typeName()) + return types.ZeroValue(), fmt.Errorf( + "%w: expected one of [set, (entity of type `any_entity_type`)], got %v", types.ErrType, rhs.TypeName()) } - return Boolean(entityIn(lhs, query, ctx.Entities)), nil + return types.Boolean(entityIn(lhs, query, ctx.Entities)), nil } // isEval @@ -1053,18 +979,18 @@ func newIsEval(lhs, rhs evaler) *isEval { return &isEval{lhs: lhs, rhs: rhs} } -func (n *isEval) Eval(ctx *evalContext) (Value, error) { +func (n *isEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalEntity(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalPath(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(path(lhs.Type) == rhs), nil + return types.Boolean(types.Path(lhs.Type) == rhs), nil } // decimalLiteralEval @@ -1076,15 +1002,15 @@ func newDecimalLiteralEval(literal evaler) *decimalLiteralEval { return &decimalLiteralEval{literal: literal} } -func (n *decimalLiteralEval) Eval(ctx *evalContext) (Value, error) { +func (n *decimalLiteralEval) Eval(ctx *evalContext) (types.Value, error) { literal, err := evalString(n.literal, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - d, err := ParseDecimal(string(literal)) + d, err := types.ParseDecimal(string(literal)) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } return d, nil @@ -1098,26 +1024,26 @@ func newIPLiteralEval(literal evaler) *ipLiteralEval { return &ipLiteralEval{literal: literal} } -func (n *ipLiteralEval) Eval(ctx *evalContext) (Value, error) { +func (n *ipLiteralEval) Eval(ctx *evalContext) (types.Value, error) { literal, err := evalString(n.literal, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - i, err := ParseIPAddr(string(literal)) + i, err := types.ParseIPAddr(string(literal)) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } return i, nil } -type ipTestType func(v IPAddr) bool +type ipTestType func(v types.IPAddr) bool -func ipTestIPv4(v IPAddr) bool { return v.isIPv4() } -func ipTestIPv6(v IPAddr) bool { return v.isIPv6() } -func ipTestLoopback(v IPAddr) bool { return v.isLoopback() } -func ipTestMulticast(v IPAddr) bool { return v.isMulticast() } +func ipTestIPv4(v types.IPAddr) bool { return v.IsIPv4() } +func ipTestIPv6(v types.IPAddr) bool { return v.IsIPv6() } +func ipTestLoopback(v types.IPAddr) bool { return v.IsLoopback() } +func ipTestMulticast(v types.IPAddr) bool { return v.IsMulticast() } // ipTestEval type ipTestEval struct { @@ -1129,12 +1055,12 @@ func newIPTestEval(object evaler, test ipTestType) *ipTestEval { return &ipTestEval{object: object, test: test} } -func (n *ipTestEval) Eval(ctx *evalContext) (Value, error) { +func (n *ipTestEval) Eval(ctx *evalContext) (types.Value, error) { i, err := evalIP(n.object, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(n.test(i)), nil + return types.Boolean(n.test(i)), nil } // ipIsInRangeEval @@ -1147,14 +1073,14 @@ func newIPIsInRangeEval(lhs, rhs evaler) *ipIsInRangeEval { return &ipIsInRangeEval{lhs: lhs, rhs: rhs} } -func (n *ipIsInRangeEval) Eval(ctx *evalContext) (Value, error) { +func (n *ipIsInRangeEval) Eval(ctx *evalContext) (types.Value, error) { lhs, err := evalIP(n.lhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } rhs, err := evalIP(n.rhs, ctx) if err != nil { - return zeroValue(), err + return types.ZeroValue(), err } - return Boolean(rhs.contains(lhs)), nil + return types.Boolean(rhs.Contains(lhs)), nil } diff --git a/eval_test.go b/eval_test.go index 6bceca3a..bd5027d3 100644 --- a/eval_test.go +++ b/eval_test.go @@ -6,15 +6,17 @@ import ( "strings" "testing" + "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/parser" ) var errTest = fmt.Errorf("test error") // not a real parser -func strEnt(v string) EntityUID { +func strEnt(v string) types.EntityUID { p := strings.Split(v, "::\"") - return EntityUID{Type: p[0], ID: p[1][:len(p[1])-1]} + return types.EntityUID{Type: p[0], ID: p[1][:len(p[1])-1]} } func TestOrNode(t *testing.T) { @@ -32,10 +34,10 @@ func TestOrNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%v%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - n := newOrNode(newLiteralEval(Boolean(tt.lhs)), newLiteralEval(Boolean(tt.rhs))) + n := newOrNode(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -43,10 +45,10 @@ func TestOrNode(t *testing.T) { t.Run("TrueXShortCircuit", func(t *testing.T) { t.Parallel() n := newOrNode( - newLiteralEval(Boolean(true)), newLiteralEval(Long(1))) + newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(1))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, true) + testutil.OK(t, err) + types.AssertBoolValue(t, v, true) }) { @@ -55,10 +57,10 @@ func TestOrNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Boolean(true)), errTest}, - {"LhsTypeError", newLiteralEval(Long(1)), newLiteralEval(Boolean(true)), errType}, - {"RhsError", newLiteralEval(Boolean(false)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Boolean(false)), newLiteralEval(Long(1)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Boolean(true)), errTest}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsError", newLiteralEval(types.Boolean(false)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Boolean(false)), newLiteralEval(types.Long(1)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -66,7 +68,7 @@ func TestOrNode(t *testing.T) { t.Parallel() n := newOrNode(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -87,10 +89,10 @@ func TestAndNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%v%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - n := newAndEval(newLiteralEval(Boolean(tt.lhs)), newLiteralEval(Boolean(tt.rhs))) + n := newAndEval(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -98,10 +100,10 @@ func TestAndNode(t *testing.T) { t.Run("FalseXShortCircuit", func(t *testing.T) { t.Parallel() n := newAndEval( - newLiteralEval(Boolean(false)), newLiteralEval(Long(1))) + newLiteralEval(types.Boolean(false)), newLiteralEval(types.Long(1))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, false) + testutil.OK(t, err) + types.AssertBoolValue(t, v, false) }) { @@ -110,10 +112,10 @@ func TestAndNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Boolean(true)), errTest}, - {"LhsTypeError", newLiteralEval(Long(1)), newLiteralEval(Boolean(true)), errType}, - {"RhsError", newLiteralEval(Boolean(true)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(1)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Boolean(true)), errTest}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsError", newLiteralEval(types.Boolean(true)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(1)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -121,7 +123,7 @@ func TestAndNode(t *testing.T) { t.Parallel() n := newAndEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -140,10 +142,10 @@ func TestNotNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%v", tt.arg), func(t *testing.T) { t.Parallel() - n := newNotEval(newLiteralEval(Boolean(tt.arg))) + n := newNotEval(newLiteralEval(types.Boolean(tt.arg))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -155,7 +157,7 @@ func TestNotNode(t *testing.T) { err error }{ {"Error", newErrorEval(errTest), errTest}, - {"TypeError", newLiteralEval(Long(1)), errType}, + {"TypeError", newLiteralEval(types.Long(1)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -163,7 +165,7 @@ func TestNotNode(t *testing.T) { t.Parallel() n := newNotEval(tt.arg) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -172,7 +174,7 @@ func TestNotNode(t *testing.T) { func TestCheckedAddI64(t *testing.T) { t.Parallel() tests := []struct { - lhs, rhs, result Long + lhs, rhs, result types.Long ok bool }{ {1, 1, 2, true}, @@ -198,8 +200,8 @@ func TestCheckedAddI64(t *testing.T) { t.Run(fmt.Sprintf("%v+%v=%v(%v)", tt.lhs, tt.rhs, tt.result, tt.ok), func(t *testing.T) { t.Parallel() result, ok := checkedAddI64(tt.lhs, tt.rhs) - testutilEquals(t, ok, tt.ok) - testutilEquals(t, result, tt.result) + testutil.Equals(t, ok, tt.ok) + testutil.Equals(t, result, tt.result) }) } } @@ -207,7 +209,7 @@ func TestCheckedAddI64(t *testing.T) { func TestCheckedSubI64(t *testing.T) { t.Parallel() tests := []struct { - lhs, rhs, result Long + lhs, rhs, result types.Long ok bool }{ {1, 1, 0, true}, @@ -233,8 +235,8 @@ func TestCheckedSubI64(t *testing.T) { t.Run(fmt.Sprintf("%v-%v=%v(%v)", tt.lhs, tt.rhs, tt.result, tt.ok), func(t *testing.T) { t.Parallel() result, ok := checkedSubI64(tt.lhs, tt.rhs) - testutilEquals(t, ok, tt.ok) - testutilEquals(t, result, tt.result) + testutil.Equals(t, ok, tt.ok) + testutil.Equals(t, result, tt.result) }) } } @@ -242,7 +244,7 @@ func TestCheckedSubI64(t *testing.T) { func TestCheckedMulI64(t *testing.T) { t.Parallel() tests := []struct { - lhs, rhs, result Long + lhs, rhs, result types.Long ok bool }{ {2, 3, 6, true}, @@ -307,8 +309,8 @@ func TestCheckedMulI64(t *testing.T) { t.Run(fmt.Sprintf("%v*%v=%v(%v)", tt.lhs, tt.rhs, tt.result, tt.ok), func(t *testing.T) { t.Parallel() result, ok := checkedMulI64(tt.lhs, tt.rhs) - testutilEquals(t, ok, tt.ok) - testutilEquals(t, result, tt.result) + testutil.Equals(t, ok, tt.ok) + testutil.Equals(t, result, tt.result) }) } } @@ -316,7 +318,7 @@ func TestCheckedMulI64(t *testing.T) { func TestCheckedNegI64(t *testing.T) { t.Parallel() tests := []struct { - arg, result Long + arg, result types.Long ok bool }{ {2, -2, true}, @@ -331,8 +333,8 @@ func TestCheckedNegI64(t *testing.T) { t.Run(fmt.Sprintf("-%v*=%v(%v)", tt.arg, tt.result, tt.ok), func(t *testing.T) { t.Parallel() result, ok := checkedNegI64(tt.arg) - testutilEquals(t, ok, tt.ok) - testutilEquals(t, result, tt.result) + testutil.Equals(t, ok, tt.ok) + testutil.Equals(t, result, tt.result) }) } } @@ -341,10 +343,10 @@ func TestAddNode(t *testing.T) { t.Parallel() t.Run("Basic", func(t *testing.T) { t.Parallel() - n := newAddEval(newLiteralEval(Long(1)), newLiteralEval(Long(2))) + n := newAddEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertLongValue(t, v, 3) + testutil.OK(t, err) + types.AssertLongValue(t, v, 3) }) tests := []struct { @@ -352,17 +354,17 @@ func TestAddNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, {"PositiveOverflow", - newLiteralEval(Long(9_223_372_036_854_775_807)), - newLiteralEval(Long(1)), + newLiteralEval(types.Long(9_223_372_036_854_775_807)), + newLiteralEval(types.Long(1)), errOverflow}, {"NegativeOverflow", - newLiteralEval(Long(-9_223_372_036_854_775_808)), - newLiteralEval(Long(-1)), + newLiteralEval(types.Long(-9_223_372_036_854_775_808)), + newLiteralEval(types.Long(-1)), errOverflow}, } for _, tt := range tests { @@ -371,7 +373,7 @@ func TestAddNode(t *testing.T) { t.Parallel() n := newAddEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -380,10 +382,10 @@ func TestSubtractNode(t *testing.T) { t.Parallel() t.Run("Basic", func(t *testing.T) { t.Parallel() - n := newSubtractEval(newLiteralEval(Long(1)), newLiteralEval(Long(2))) + n := newSubtractEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertLongValue(t, v, -1) + testutil.OK(t, err) + types.AssertLongValue(t, v, -1) }) tests := []struct { @@ -391,17 +393,17 @@ func TestSubtractNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, {"PositiveOverflow", - newLiteralEval(Long(9_223_372_036_854_775_807)), - newLiteralEval(Long(-1)), + newLiteralEval(types.Long(9_223_372_036_854_775_807)), + newLiteralEval(types.Long(-1)), errOverflow}, {"NegativeOverflow", - newLiteralEval(Long(-9_223_372_036_854_775_808)), - newLiteralEval(Long(1)), + newLiteralEval(types.Long(-9_223_372_036_854_775_808)), + newLiteralEval(types.Long(1)), errOverflow}, } for _, tt := range tests { @@ -410,7 +412,7 @@ func TestSubtractNode(t *testing.T) { t.Parallel() n := newSubtractEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -419,10 +421,10 @@ func TestMultiplyNode(t *testing.T) { t.Parallel() t.Run("Basic", func(t *testing.T) { t.Parallel() - n := newMultiplyEval(newLiteralEval(Long(-3)), newLiteralEval(Long(2))) + n := newMultiplyEval(newLiteralEval(types.Long(-3)), newLiteralEval(types.Long(2))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertLongValue(t, v, -6) + testutil.OK(t, err) + types.AssertLongValue(t, v, -6) }) tests := []struct { @@ -430,17 +432,17 @@ func TestMultiplyNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, {"PositiveOverflow", - newLiteralEval(Long(9_223_372_036_854_775_807)), - newLiteralEval(Long(2)), + newLiteralEval(types.Long(9_223_372_036_854_775_807)), + newLiteralEval(types.Long(2)), errOverflow}, {"NegativeOverflow", - newLiteralEval(Long(-9_223_372_036_854_775_808)), - newLiteralEval(Long(2)), + newLiteralEval(types.Long(-9_223_372_036_854_775_808)), + newLiteralEval(types.Long(2)), errOverflow}, } for _, tt := range tests { @@ -449,7 +451,7 @@ func TestMultiplyNode(t *testing.T) { t.Parallel() n := newMultiplyEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -458,10 +460,10 @@ func TestNegateNode(t *testing.T) { t.Parallel() t.Run("Basic", func(t *testing.T) { t.Parallel() - n := newNegateEval(newLiteralEval(Long(-3))) + n := newNegateEval(newLiteralEval(types.Long(-3))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertLongValue(t, v, 3) + testutil.OK(t, err) + types.AssertLongValue(t, v, 3) }) tests := []struct { @@ -470,8 +472,8 @@ func TestNegateNode(t *testing.T) { err error }{ {"Error", newErrorEval(errTest), errTest}, - {"TypeError", newLiteralEval(Boolean(true)), errType}, - {"Overflow", newLiteralEval(Long(-9_223_372_036_854_775_808)), errOverflow}, + {"TypeError", newLiteralEval(types.Boolean(true)), types.ErrType}, + {"Overflow", newLiteralEval(types.Long(-9_223_372_036_854_775_808)), errOverflow}, } for _, tt := range tests { tt := tt @@ -479,7 +481,7 @@ func TestNegateNode(t *testing.T) { t.Parallel() n := newNegateEval(tt.arg) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -506,10 +508,10 @@ func TestLongLessThanNode(t *testing.T) { t.Run(fmt.Sprintf("%v<%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newLongLessThanEval( - newLiteralEval(Long(tt.lhs)), newLiteralEval(Long(tt.rhs))) + newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -519,10 +521,10 @@ func TestLongLessThanNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -530,7 +532,7 @@ func TestLongLessThanNode(t *testing.T) { t.Parallel() n := newLongLessThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -558,10 +560,10 @@ func TestLongLessThanOrEqualNode(t *testing.T) { t.Run(fmt.Sprintf("%v<=%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newLongLessThanOrEqualEval( - newLiteralEval(Long(tt.lhs)), newLiteralEval(Long(tt.rhs))) + newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -571,10 +573,10 @@ func TestLongLessThanOrEqualNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -582,7 +584,7 @@ func TestLongLessThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongLessThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -610,10 +612,10 @@ func TestLongGreaterThanNode(t *testing.T) { t.Run(fmt.Sprintf("%v>%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newLongGreaterThanEval( - newLiteralEval(Long(tt.lhs)), newLiteralEval(Long(tt.rhs))) + newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -623,10 +625,10 @@ func TestLongGreaterThanNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -634,7 +636,7 @@ func TestLongGreaterThanNode(t *testing.T) { t.Parallel() n := newLongGreaterThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -662,10 +664,10 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { t.Run(fmt.Sprintf("%v>=%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newLongGreaterThanOrEqualEval( - newLiteralEval(Long(tt.lhs)), newLiteralEval(Long(tt.rhs))) + newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -675,10 +677,10 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Long(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -686,7 +688,7 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongGreaterThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -713,16 +715,16 @@ func TestDecimalLessThanNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%s<%s", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - lhsd, err := ParseDecimal(tt.lhs) - testutilOK(t, err) + lhsd, err := types.ParseDecimal(tt.lhs) + testutil.OK(t, err) lhsv := lhsd - rhsd, err := ParseDecimal(tt.rhs) - testutilOK(t, err) + rhsd, err := types.ParseDecimal(tt.rhs) + testutil.OK(t, err) rhsv := rhsd n := newDecimalLessThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -732,10 +734,10 @@ func TestDecimalLessThanNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Decimal(0)), errType}, - {"RhsError", newLiteralEval(Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Decimal(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -743,7 +745,7 @@ func TestDecimalLessThanNode(t *testing.T) { t.Parallel() n := newDecimalLessThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -770,16 +772,16 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%s<=%s", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - lhsd, err := ParseDecimal(tt.lhs) - testutilOK(t, err) + lhsd, err := types.ParseDecimal(tt.lhs) + testutil.OK(t, err) lhsv := lhsd - rhsd, err := ParseDecimal(tt.rhs) - testutilOK(t, err) + rhsd, err := types.ParseDecimal(tt.rhs) + testutil.OK(t, err) rhsv := rhsd n := newDecimalLessThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -789,10 +791,10 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Decimal(0)), errType}, - {"RhsError", newLiteralEval(Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Decimal(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -800,7 +802,7 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { t.Parallel() n := newDecimalLessThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -827,16 +829,16 @@ func TestDecimalGreaterThanNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%s>%s", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - lhsd, err := ParseDecimal(tt.lhs) - testutilOK(t, err) + lhsd, err := types.ParseDecimal(tt.lhs) + testutil.OK(t, err) lhsv := lhsd - rhsd, err := ParseDecimal(tt.rhs) - testutilOK(t, err) + rhsd, err := types.ParseDecimal(tt.rhs) + testutil.OK(t, err) rhsv := rhsd n := newDecimalGreaterThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -846,10 +848,10 @@ func TestDecimalGreaterThanNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Decimal(0)), errType}, - {"RhsError", newLiteralEval(Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Decimal(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -857,7 +859,7 @@ func TestDecimalGreaterThanNode(t *testing.T) { t.Parallel() n := newDecimalGreaterThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -884,16 +886,16 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { tt := tt t.Run(fmt.Sprintf("%s>=%s", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() - lhsd, err := ParseDecimal(tt.lhs) - testutilOK(t, err) + lhsd, err := types.ParseDecimal(tt.lhs) + testutil.OK(t, err) lhsv := lhsd - rhsd, err := ParseDecimal(tt.rhs) - testutilOK(t, err) + rhsd, err := types.ParseDecimal(tt.rhs) + testutil.OK(t, err) rhsv := rhsd n := newDecimalGreaterThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -903,10 +905,10 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Decimal(0)), errType}, - {"RhsError", newLiteralEval(Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Decimal(0)), newLiteralEval(Boolean(true)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -914,7 +916,7 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { t.Parallel() n := newDecimalGreaterThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) + testutil.AssertError(t, err, tt.err) }) } } @@ -925,19 +927,19 @@ func TestIfThenElseNode(t *testing.T) { tests := []struct { name string if_, then, else_ evaler - result Value + result types.Value err error }{ - {"Then", newLiteralEval(Boolean(true)), newLiteralEval(Long(42)), - newLiteralEval(Long(-1)), Long(42), + {"Then", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(42)), + newLiteralEval(types.Long(-1)), types.Long(42), nil}, - {"Else", newLiteralEval(Boolean(false)), newLiteralEval(Long(-1)), - newLiteralEval(Long(42)), Long(42), + {"Else", newLiteralEval(types.Boolean(false)), newLiteralEval(types.Long(-1)), + newLiteralEval(types.Long(42)), types.Long(42), nil}, - {"Err", newErrorEval(errTest), newLiteralEval(zeroValue()), newLiteralEval(zeroValue()), zeroValue(), + {"Err", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), newLiteralEval(types.ZeroValue()), types.ZeroValue(), errTest}, - {"ErrType", newLiteralEval(Long(123)), newLiteralEval(zeroValue()), newLiteralEval(zeroValue()), zeroValue(), - errType}, + {"ErrType", newLiteralEval(types.Long(123)), newLiteralEval(types.ZeroValue()), newLiteralEval(types.ZeroValue()), types.ZeroValue(), + types.ErrType}, } for _, tt := range tests { tt := tt @@ -945,8 +947,8 @@ func TestIfThenElseNode(t *testing.T) { t.Parallel() n := newIfThenElseEval(tt.if_, tt.then, tt.else_) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - testutilEquals(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + testutil.Equals(t, v, tt.result) }) } } @@ -956,14 +958,14 @@ func TestEqualNode(t *testing.T) { tests := []struct { name string lhs, rhs evaler - result Value + result types.Value err error }{ - {"equals", newLiteralEval(Long(42)), newLiteralEval(Long(42)), Boolean(true), nil}, - {"notEquals", newLiteralEval(Long(42)), newLiteralEval(Long(1234)), Boolean(false), nil}, - {"leftErr", newErrorEval(errTest), newLiteralEval(zeroValue()), zeroValue(), errTest}, - {"rightErr", newLiteralEval(zeroValue()), newErrorEval(errTest), zeroValue(), errTest}, - {"typesNotEqual", newLiteralEval(Long(1)), newLiteralEval(Boolean(true)), Boolean(false), nil}, + {"equals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(42)), types.Boolean(true), nil}, + {"notEquals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(1234)), types.Boolean(false), nil}, + {"leftErr", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), types.ZeroValue(), errTest}, + {"rightErr", newLiteralEval(types.ZeroValue()), newErrorEval(errTest), types.ZeroValue(), errTest}, + {"typesNotEqual", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.Boolean(false), nil}, } for _, tt := range tests { tt := tt @@ -971,8 +973,8 @@ func TestEqualNode(t *testing.T) { t.Parallel() n := newEqualEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -982,14 +984,14 @@ func TestNotEqualNode(t *testing.T) { tests := []struct { name string lhs, rhs evaler - result Value + result types.Value err error }{ - {"equals", newLiteralEval(Long(42)), newLiteralEval(Long(42)), Boolean(false), nil}, - {"notEquals", newLiteralEval(Long(42)), newLiteralEval(Long(1234)), Boolean(true), nil}, - {"leftErr", newErrorEval(errTest), newLiteralEval(zeroValue()), zeroValue(), errTest}, - {"rightErr", newLiteralEval(zeroValue()), newErrorEval(errTest), zeroValue(), errTest}, - {"typesNotEqual", newLiteralEval(Long(1)), newLiteralEval(Boolean(true)), Boolean(true), nil}, + {"equals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(42)), types.Boolean(false), nil}, + {"notEquals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(1234)), types.Boolean(true), nil}, + {"leftErr", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), types.ZeroValue(), errTest}, + {"rightErr", newLiteralEval(types.ZeroValue()), newErrorEval(errTest), types.ZeroValue(), errTest}, + {"typesNotEqual", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.Boolean(true), nil}, } for _, tt := range tests { tt := tt @@ -997,8 +999,8 @@ func TestNotEqualNode(t *testing.T) { t.Parallel() n := newNotEqualEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1008,27 +1010,27 @@ func TestSetLiteralNode(t *testing.T) { tests := []struct { name string elems []evaler - result Value + result types.Value err error }{ - {"empty", []evaler{}, Set{}, nil}, - {"errorNode", []evaler{newErrorEval(errTest)}, zeroValue(), errTest}, + {"empty", []evaler{}, types.Set{}, nil}, + {"errorNode", []evaler{newErrorEval(errTest)}, types.ZeroValue(), errTest}, {"nested", []evaler{ - newLiteralEval(Boolean(true)), - newLiteralEval(Set{ - Boolean(false), - Long(1), + newLiteralEval(types.Boolean(true)), + newLiteralEval(types.Set{ + types.Boolean(false), + types.Long(1), }), - newLiteralEval(Long(10)), + newLiteralEval(types.Long(10)), }, - Set{ - Boolean(true), - Set{ - Boolean(false), - Long(1), + types.Set{ + types.Boolean(true), + types.Set{ + types.Boolean(false), + types.Long(1), }, - Long(10), + types.Long(10), }, nil}, } @@ -1038,8 +1040,8 @@ func TestSetLiteralNode(t *testing.T) { t.Parallel() n := newSetLiteralEval(tt.elems) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1052,9 +1054,9 @@ func TestContainsNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Long(0)), errType}, - {"RhsError", newLiteralEval(Set{}), newErrorEval(errTest), errTest}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, } for _, tt := range tests { tt := tt @@ -1062,28 +1064,28 @@ func TestContainsNode(t *testing.T) { t.Parallel() n := newContainsEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertZeroValue(t, v) + testutil.AssertError(t, err, tt.err) + types.AssertZeroValue(t, v) }) } } { - empty := Set{} - trueAndOne := Set{Boolean(true), Long(1)} - nested := Set{trueAndOne, Boolean(false), Long(2)} + empty := types.Set{} + trueAndOne := types.Set{types.Boolean(true), types.Long(1)} + nested := types.Set{trueAndOne, types.Boolean(false), types.Long(2)} tests := []struct { name string lhs, rhs evaler result bool }{ - {"empty", newLiteralEval(empty), newLiteralEval(Boolean(true)), false}, - {"trueAndOneContainsTrue", newLiteralEval(trueAndOne), newLiteralEval(Boolean(true)), true}, - {"trueAndOneContainsOne", newLiteralEval(trueAndOne), newLiteralEval(Long(1)), true}, - {"trueAndOneDoesNotContainTwo", newLiteralEval(trueAndOne), newLiteralEval(Long(2)), false}, - {"nestedContainsFalse", newLiteralEval(nested), newLiteralEval(Boolean(false)), true}, + {"empty", newLiteralEval(empty), newLiteralEval(types.Boolean(true)), false}, + {"trueAndOneContainsTrue", newLiteralEval(trueAndOne), newLiteralEval(types.Boolean(true)), true}, + {"trueAndOneContainsOne", newLiteralEval(trueAndOne), newLiteralEval(types.Long(1)), true}, + {"trueAndOneDoesNotContainTwo", newLiteralEval(trueAndOne), newLiteralEval(types.Long(2)), false}, + {"nestedContainsFalse", newLiteralEval(nested), newLiteralEval(types.Boolean(false)), true}, {"nestedContainsSet", newLiteralEval(nested), newLiteralEval(trueAndOne), true}, - {"nestedDoesNotContainTrue", newLiteralEval(nested), newLiteralEval(Boolean(true)), false}, + {"nestedDoesNotContainTrue", newLiteralEval(nested), newLiteralEval(types.Boolean(true)), false}, } for _, tt := range tests { tt := tt @@ -1091,8 +1093,8 @@ func TestContainsNode(t *testing.T) { t.Parallel() n := newContainsEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -1106,10 +1108,10 @@ func TestContainsAllNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Set{}), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Set{}), errType}, - {"RhsError", newLiteralEval(Set{}), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Set{}), newLiteralEval(Long(0)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Set{}), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Set{}), types.ErrType}, + {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Set{}), newLiteralEval(types.Long(0)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -1117,16 +1119,16 @@ func TestContainsAllNode(t *testing.T) { t.Parallel() n := newContainsAllEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertZeroValue(t, v) + testutil.AssertError(t, err, tt.err) + types.AssertZeroValue(t, v) }) } } { - empty := Set{} - trueOnly := Set{Boolean(true)} - trueAndOne := Set{Boolean(true), Long(1)} - nested := Set{trueAndOne, Boolean(false), Long(2)} + empty := types.Set{} + trueOnly := types.Set{types.Boolean(true)} + trueAndOne := types.Set{types.Boolean(true), types.Long(1)} + nested := types.Set{trueAndOne, types.Boolean(false), types.Long(2)} tests := []struct { name string @@ -1145,8 +1147,8 @@ func TestContainsAllNode(t *testing.T) { t.Parallel() n := newContainsAllEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -1160,10 +1162,10 @@ func TestContainsAnyNode(t *testing.T) { lhs, rhs evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(Set{}), errTest}, - {"LhsTypeError", newLiteralEval(Boolean(true)), newLiteralEval(Set{}), errType}, - {"RhsError", newLiteralEval(Set{}), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(Set{}), newLiteralEval(Long(0)), errType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Set{}), errTest}, + {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Set{}), types.ErrType}, + {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Set{}), newLiteralEval(types.Long(0)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -1171,17 +1173,17 @@ func TestContainsAnyNode(t *testing.T) { t.Parallel() n := newContainsAnyEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertZeroValue(t, v) + testutil.AssertError(t, err, tt.err) + types.AssertZeroValue(t, v) }) } } { - empty := Set{} - trueOnly := Set{Boolean(true)} - trueAndOne := Set{Boolean(true), Long(1)} - trueAndTwo := Set{Boolean(true), Long(2)} - nested := Set{trueAndOne, Boolean(false), Long(2)} + empty := types.Set{} + trueOnly := types.Set{types.Boolean(true)} + trueAndOne := types.Set{types.Boolean(true), types.Long(1)} + trueAndTwo := types.Set{types.Boolean(true), types.Long(2)} + nested := types.Set{trueAndOne, types.Boolean(false), types.Long(2)} tests := []struct { name string @@ -1202,8 +1204,8 @@ func TestContainsAnyNode(t *testing.T) { t.Parallel() n := newContainsAnyEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - testutilOK(t, err) - assertBoolValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertBoolValue(t, v, tt.result) }) } } @@ -1214,18 +1216,18 @@ func TestRecordLiteralNode(t *testing.T) { tests := []struct { name string elems map[string]evaler - result Value + result types.Value err error }{ - {"empty", map[string]evaler{}, Record{}, nil}, - {"errorNode", map[string]evaler{"foo": newErrorEval(errTest)}, zeroValue(), errTest}, + {"empty", map[string]evaler{}, types.Record{}, nil}, + {"errorNode", map[string]evaler{"foo": newErrorEval(errTest)}, types.ZeroValue(), errTest}, {"ok", map[string]evaler{ - "foo": newLiteralEval(Boolean(true)), - "bar": newLiteralEval(String("baz")), - }, Record{ - "foo": Boolean(true), - "bar": String("baz"), + "foo": newLiteralEval(types.Boolean(true)), + "bar": newLiteralEval(types.String("baz")), + }, types.Record{ + "foo": types.Boolean(true), + "bar": types.String("baz"), }, nil}, } for _, tt := range tests { @@ -1234,8 +1236,8 @@ func TestRecordLiteralNode(t *testing.T) { t.Parallel() n := newRecordLiteralEval(tt.elems) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1246,35 +1248,35 @@ func TestAttributeAccessNode(t *testing.T) { name string object evaler attribute string - result Value + result types.Value err error }{ - {"RecordError", newErrorEval(errTest), "foo", zeroValue(), errTest}, - {"RecordTypeError", newLiteralEval(Boolean(true)), "foo", zeroValue(), errType}, + {"RecordError", newErrorEval(errTest), "foo", types.ZeroValue(), errTest}, + {"RecordTypeError", newLiteralEval(types.Boolean(true)), "foo", types.ZeroValue(), types.ErrType}, {"UnknownAttribute", - newLiteralEval(Record{}), + newLiteralEval(types.Record{}), "foo", - zeroValue(), + types.ZeroValue(), errAttributeAccess}, {"KnownAttribute", - newLiteralEval(Record{"foo": Long(42)}), + newLiteralEval(types.Record{"foo": types.Long(42)}), "foo", - Long(42), + types.Long(42), nil}, {"KnownAttributeOnEntity", - newLiteralEval(EntityUID{"knownType", "knownID"}), + newLiteralEval(types.EntityUID{"knownType", "knownID"}), "knownAttr", - Long(42), + types.Long(42), nil}, {"UnknownEntity", - newLiteralEval(EntityUID{"unknownType", "unknownID"}), + newLiteralEval(types.EntityUID{"unknownType", "unknownID"}), "unknownAttr", - zeroValue(), + types.ZeroValue(), errEntityNotExist}, {"UnspecifiedEntity", - newLiteralEval(EntityUID{"", ""}), + newLiteralEval(types.EntityUID{"", ""}), "knownAttr", - zeroValue(), + types.ZeroValue(), errUnspecifiedEntity}, } for _, tt := range tests { @@ -1285,13 +1287,13 @@ func TestAttributeAccessNode(t *testing.T) { v, err := n.Eval(&evalContext{ Entities: entitiesFromSlice([]Entity{ { - UID: NewEntityUID("knownType", "knownID"), - Attributes: Record{"knownAttr": Long(42)}, + UID: types.NewEntityUID("knownType", "knownID"), + Attributes: types.Record{"knownAttr": types.Long(42)}, }, }), }) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1302,35 +1304,35 @@ func TestHasNode(t *testing.T) { name string record evaler attribute string - result Value + result types.Value err error }{ - {"RecordError", newErrorEval(errTest), "foo", zeroValue(), errTest}, - {"RecordTypeError", newLiteralEval(Boolean(true)), "foo", zeroValue(), errType}, + {"RecordError", newErrorEval(errTest), "foo", types.ZeroValue(), errTest}, + {"RecordTypeError", newLiteralEval(types.Boolean(true)), "foo", types.ZeroValue(), types.ErrType}, {"UnknownAttribute", - newLiteralEval(Record{}), + newLiteralEval(types.Record{}), "foo", - Boolean(false), + types.Boolean(false), nil}, {"KnownAttribute", - newLiteralEval(Record{"foo": Long(42)}), + newLiteralEval(types.Record{"foo": types.Long(42)}), "foo", - Boolean(true), + types.Boolean(true), nil}, {"KnownAttributeOnEntity", - newLiteralEval(EntityUID{"knownType", "knownID"}), + newLiteralEval(types.EntityUID{"knownType", "knownID"}), "knownAttr", - Boolean(true), + types.Boolean(true), nil}, {"UnknownAttributeOnEntity", - newLiteralEval(EntityUID{"knownType", "knownID"}), + newLiteralEval(types.EntityUID{"knownType", "knownID"}), "unknownAttr", - Boolean(false), + types.Boolean(false), nil}, {"UnknownEntity", - newLiteralEval(EntityUID{"unknownType", "unknownID"}), + newLiteralEval(types.EntityUID{"unknownType", "unknownID"}), "unknownAttr", - Boolean(false), + types.Boolean(false), nil}, } for _, tt := range tests { @@ -1341,13 +1343,13 @@ func TestHasNode(t *testing.T) { v, err := n.Eval(&evalContext{ Entities: entitiesFromSlice([]Entity{ { - UID: NewEntityUID("knownType", "knownID"), - Attributes: Record{"knownAttr": Long(42)}, + UID: types.NewEntityUID("knownType", "knownID"), + Attributes: types.Record{"knownAttr": types.Long(42)}, }, }), }) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1358,50 +1360,50 @@ func TestLikeNode(t *testing.T) { name string str evaler pattern string - result Value + result types.Value err error }{ - {"leftError", newErrorEval(errTest), `"foo"`, zeroValue(), errTest}, - {"leftTypeError", newLiteralEval(Boolean(true)), `"foo"`, zeroValue(), errType}, - {"noMatch", newLiteralEval(String("test")), `"zebra"`, Boolean(false), nil}, - {"match", newLiteralEval(String("test")), `"*es*"`, Boolean(true), nil}, - - {"case-1", newLiteralEval(String("eggs")), `"ham*"`, Boolean(false), nil}, - {"case-2", newLiteralEval(String("eggs")), `"*ham"`, Boolean(false), nil}, - {"case-3", newLiteralEval(String("eggs")), `"*ham*"`, Boolean(false), nil}, - {"case-4", newLiteralEval(String("ham and eggs")), `"ham*"`, Boolean(true), nil}, - {"case-5", newLiteralEval(String("ham and eggs")), `"*ham"`, Boolean(false), nil}, - {"case-6", newLiteralEval(String("ham and eggs")), `"*ham*"`, Boolean(true), nil}, - {"case-7", newLiteralEval(String("ham and eggs")), `"*h*a*m*"`, Boolean(true), nil}, - {"case-8", newLiteralEval(String("eggs and ham")), `"ham*"`, Boolean(false), nil}, - {"case-9", newLiteralEval(String("eggs and ham")), `"*ham"`, Boolean(true), nil}, - {"case-10", newLiteralEval(String("eggs, ham, and spinach")), `"ham*"`, Boolean(false), nil}, - {"case-11", newLiteralEval(String("eggs, ham, and spinach")), `"*ham"`, Boolean(false), nil}, - {"case-12", newLiteralEval(String("eggs, ham, and spinach")), `"*ham*"`, Boolean(true), nil}, - {"case-13", newLiteralEval(String("Gotham")), `"ham*"`, Boolean(false), nil}, - {"case-14", newLiteralEval(String("Gotham")), `"*ham"`, Boolean(true), nil}, - {"case-15", newLiteralEval(String("ham")), `"ham"`, Boolean(true), nil}, - {"case-16", newLiteralEval(String("ham")), `"ham*"`, Boolean(true), nil}, - {"case-17", newLiteralEval(String("ham")), `"*ham"`, Boolean(true), nil}, - {"case-18", newLiteralEval(String("ham")), `"*h*a*m*"`, Boolean(true), nil}, - {"case-19", newLiteralEval(String("ham and ham")), `"ham*"`, Boolean(true), nil}, - {"case-20", newLiteralEval(String("ham and ham")), `"*ham"`, Boolean(true), nil}, - {"case-21", newLiteralEval(String("ham")), `"*ham and eggs*"`, Boolean(false), nil}, - {"case-22", newLiteralEval(String("\\afterslash")), `"\\*"`, Boolean(true), nil}, - {"case-23", newLiteralEval(String("string\\with\\backslashes")), `"string\\with\\backslashes"`, Boolean(true), nil}, - {"case-24", newLiteralEval(String("string\\with\\backslashes")), `"string*with*backslashes"`, Boolean(true), nil}, - {"case-25", newLiteralEval(String("string*with*stars")), `"string\*with\*stars"`, Boolean(true), nil}, + {"leftError", newErrorEval(errTest), `"foo"`, types.ZeroValue(), errTest}, + {"leftTypeError", newLiteralEval(types.Boolean(true)), `"foo"`, types.ZeroValue(), types.ErrType}, + {"noMatch", newLiteralEval(types.String("test")), `"zebra"`, types.Boolean(false), nil}, + {"match", newLiteralEval(types.String("test")), `"*es*"`, types.Boolean(true), nil}, + + {"case-1", newLiteralEval(types.String("eggs")), `"ham*"`, types.Boolean(false), nil}, + {"case-2", newLiteralEval(types.String("eggs")), `"*ham"`, types.Boolean(false), nil}, + {"case-3", newLiteralEval(types.String("eggs")), `"*ham*"`, types.Boolean(false), nil}, + {"case-4", newLiteralEval(types.String("ham and eggs")), `"ham*"`, types.Boolean(true), nil}, + {"case-5", newLiteralEval(types.String("ham and eggs")), `"*ham"`, types.Boolean(false), nil}, + {"case-6", newLiteralEval(types.String("ham and eggs")), `"*ham*"`, types.Boolean(true), nil}, + {"case-7", newLiteralEval(types.String("ham and eggs")), `"*h*a*m*"`, types.Boolean(true), nil}, + {"case-8", newLiteralEval(types.String("eggs and ham")), `"ham*"`, types.Boolean(false), nil}, + {"case-9", newLiteralEval(types.String("eggs and ham")), `"*ham"`, types.Boolean(true), nil}, + {"case-10", newLiteralEval(types.String("eggs, ham, and spinach")), `"ham*"`, types.Boolean(false), nil}, + {"case-11", newLiteralEval(types.String("eggs, ham, and spinach")), `"*ham"`, types.Boolean(false), nil}, + {"case-12", newLiteralEval(types.String("eggs, ham, and spinach")), `"*ham*"`, types.Boolean(true), nil}, + {"case-13", newLiteralEval(types.String("Gotham")), `"ham*"`, types.Boolean(false), nil}, + {"case-14", newLiteralEval(types.String("Gotham")), `"*ham"`, types.Boolean(true), nil}, + {"case-15", newLiteralEval(types.String("ham")), `"ham"`, types.Boolean(true), nil}, + {"case-16", newLiteralEval(types.String("ham")), `"ham*"`, types.Boolean(true), nil}, + {"case-17", newLiteralEval(types.String("ham")), `"*ham"`, types.Boolean(true), nil}, + {"case-18", newLiteralEval(types.String("ham")), `"*h*a*m*"`, types.Boolean(true), nil}, + {"case-19", newLiteralEval(types.String("ham and ham")), `"ham*"`, types.Boolean(true), nil}, + {"case-20", newLiteralEval(types.String("ham and ham")), `"*ham"`, types.Boolean(true), nil}, + {"case-21", newLiteralEval(types.String("ham")), `"*ham and eggs*"`, types.Boolean(false), nil}, + {"case-22", newLiteralEval(types.String("\\afterslash")), `"\\*"`, types.Boolean(true), nil}, + {"case-23", newLiteralEval(types.String("string\\with\\backslashes")), `"string\\with\\backslashes"`, types.Boolean(true), nil}, + {"case-24", newLiteralEval(types.String("string\\with\\backslashes")), `"string*with*backslashes"`, types.Boolean(true), nil}, + {"case-25", newLiteralEval(types.String("string*with*stars")), `"string\*with\*stars"`, types.Boolean(true), nil}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() pat, err := parser.NewPattern(tt.pattern) - testutilOK(t, err) + testutil.OK(t, err) n := newLikeEval(tt.str, pat) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1412,24 +1414,24 @@ func TestVariableNode(t *testing.T) { name string context evalContext variable variableName - result Value + result types.Value }{ {"principal", - evalContext{Principal: String("foo")}, + evalContext{Principal: types.String("foo")}, variableNamePrincipal, - String("foo")}, + types.String("foo")}, {"action", - evalContext{Action: String("bar")}, + evalContext{Action: types.String("bar")}, variableNameAction, - String("bar")}, + types.String("bar")}, {"resource", - evalContext{Resource: String("baz")}, + evalContext{Resource: types.String("baz")}, variableNameResource, - String("baz")}, + types.String("baz")}, {"context", - evalContext{Context: String("frob")}, + evalContext{Context: types.String("frob")}, variableNameContext, - String("frob")}, + types.String("frob")}, } for _, tt := range tests { tt := tt @@ -1437,8 +1439,8 @@ func TestVariableNode(t *testing.T) { t.Parallel() n := newVariableEval(tt.variable) v, err := n.Eval(&tt.context) - testutilOK(t, err) - assertValue(t, v, tt.result) + testutil.OK(t, err) + types.AssertValue(t, v, tt.result) }) } } @@ -1530,13 +1532,13 @@ func TestEntityIn(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - rhs := map[EntityUID]struct{}{} + rhs := map[types.EntityUID]struct{}{} for _, v := range tt.rhs { rhs[strEnt(v)] = struct{}{} } entities := Entities{} for k, p := range tt.parents { - var ps []EntityUID + var ps []types.EntityUID for _, pp := range p { ps = append(ps, strEnt(pp)) } @@ -1547,7 +1549,7 @@ func TestEntityIn(t *testing.T) { } } res := entityIn(strEnt(tt.lhs), rhs, entities) - testutilEquals(t, res, tt.result) + testutil.Equals(t, res, tt.result) }) } t.Run("exponentialWithoutCaching", func(t *testing.T) { @@ -1557,24 +1559,24 @@ func TestEntityIn(t *testing.T) { entities := Entities{} for i := 0; i < 100; i++ { - p := []EntityUID{ - NewEntityUID(fmt.Sprint(i+1), "1"), - NewEntityUID(fmt.Sprint(i+1), "2"), + p := []types.EntityUID{ + types.NewEntityUID(fmt.Sprint(i+1), "1"), + types.NewEntityUID(fmt.Sprint(i+1), "2"), } - uid1 := NewEntityUID(fmt.Sprint(i), "1") + uid1 := types.NewEntityUID(fmt.Sprint(i), "1") entities[uid1] = Entity{ UID: uid1, Parents: p, } - uid2 := NewEntityUID(fmt.Sprint(i), "2") + uid2 := types.NewEntityUID(fmt.Sprint(i), "2") entities[uid2] = Entity{ UID: uid2, Parents: p, } } - res := entityIn(NewEntityUID("0", "1"), map[EntityUID]struct{}{NewEntityUID("0", "3"): {}}, entities) - testutilEquals(t, res, false) + res := entityIn(types.NewEntityUID("0", "1"), map[types.EntityUID]struct{}{types.NewEntityUID("0", "3"): {}}, entities) + testutil.Equals(t, res, false) }) } @@ -1583,23 +1585,23 @@ func TestIsNode(t *testing.T) { tests := []struct { name string lhs, rhs evaler - result Value + result types.Value err error }{ - {"happyEq", newLiteralEval(NewEntityUID("X", "z")), newLiteralEval(path("X")), Boolean(true), nil}, - {"happyNeq", newLiteralEval(NewEntityUID("X", "z")), newLiteralEval(path("Y")), Boolean(false), nil}, - {"badLhs", newLiteralEval(Long(42)), newLiteralEval(path("X")), zeroValue(), errType}, - {"badRhs", newLiteralEval(NewEntityUID("X", "z")), newLiteralEval(Long(42)), zeroValue(), errType}, - {"errLhs", newErrorEval(errTest), newLiteralEval(path("X")), zeroValue(), errTest}, - {"errRhs", newLiteralEval(NewEntityUID("X", "z")), newErrorEval(errTest), zeroValue(), errTest}, + {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("X")), types.Boolean(true), nil}, + {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("Y")), types.Boolean(false), nil}, + {"badLhs", newLiteralEval(types.Long(42)), newLiteralEval(types.Path("X")), types.ZeroValue(), types.ErrType}, + {"badRhs", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Long(42)), types.ZeroValue(), types.ErrType}, + {"errLhs", newErrorEval(errTest), newLiteralEval(types.Path("X")), types.ZeroValue(), errTest}, + {"errRhs", newLiteralEval(types.NewEntityUID("X", "z")), newErrorEval(errTest), types.ZeroValue(), errTest}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() got, err := newIsEval(tt.lhs, tt.rhs).Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, got, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, got, tt.result) }) } } @@ -1610,89 +1612,89 @@ func TestInNode(t *testing.T) { name string lhs, rhs evaler parents map[string][]string - result Value + result types.Value err error }{ { "LhsError", newErrorEval(errTest), - newLiteralEval(Set{}), + newLiteralEval(types.Set{}), map[string][]string{}, - zeroValue(), + types.ZeroValue(), errTest, }, { "LhsTypeError", - newLiteralEval(String("foo")), - newLiteralEval(Set{}), + newLiteralEval(types.String("foo")), + newLiteralEval(types.Set{}), map[string][]string{}, - zeroValue(), - errType, + types.ZeroValue(), + types.ErrType, }, { "RhsError", - newLiteralEval(EntityUID{"human", "joe"}), + newLiteralEval(types.EntityUID{"human", "joe"}), newErrorEval(errTest), map[string][]string{}, - zeroValue(), + types.ZeroValue(), errTest, }, { "RhsTypeError1", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(String("foo")), + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.String("foo")), map[string][]string{}, - zeroValue(), - errType, + types.ZeroValue(), + types.ErrType, }, { "RhsTypeError2", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(Set{ - String("foo"), + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.Set{ + types.String("foo"), }), map[string][]string{}, - zeroValue(), - errType, + types.ZeroValue(), + types.ErrType, }, { "Reflexive1", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(EntityUID{"human", "joe"}), + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.EntityUID{"human", "joe"}), map[string][]string{}, - Boolean(true), + types.Boolean(true), nil, }, { "Reflexive2", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(Set{ - EntityUID{"human", "joe"}, + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.Set{ + types.EntityUID{"human", "joe"}, }), map[string][]string{}, - Boolean(true), + types.Boolean(true), nil, }, { "BasicTrue", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(EntityUID{"kingdom", "animal"}), + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.EntityUID{"kingdom", "animal"}), map[string][]string{ `human::"joe"`: {`species::"human"`}, `species::"human"`: {`kingdom::"animal"`}, }, - Boolean(true), + types.Boolean(true), nil, }, { "BasicFalse", - newLiteralEval(EntityUID{"human", "joe"}), - newLiteralEval(EntityUID{"kingdom", "plant"}), + newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.EntityUID{"kingdom", "plant"}), map[string][]string{ `human::"joe"`: {`species::"human"`}, `species::"human"`: {`kingdom::"animal"`}, }, - Boolean(false), + types.Boolean(false), nil, }, } @@ -1703,7 +1705,7 @@ func TestInNode(t *testing.T) { n := newInEval(tt.lhs, tt.rhs) entities := Entities{} for k, p := range tt.parents { - var ps []EntityUID + var ps []types.EntityUID for _, pp := range p { ps = append(ps, strEnt(pp)) } @@ -1715,8 +1717,8 @@ func TestInNode(t *testing.T) { } evalContext := evalContext{Entities: entities} v, err := n.Eval(&evalContext) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1726,13 +1728,13 @@ func TestDecimalLiteralNode(t *testing.T) { tests := []struct { name string arg evaler - result Value + result types.Value err error }{ - {"Error", newErrorEval(errTest), zeroValue(), errTest}, - {"TypeError", newLiteralEval(Long(1)), zeroValue(), errType}, - {"DecimalError", newLiteralEval(String("frob")), zeroValue(), errDecimal}, - {"Success", newLiteralEval(String("1.0")), Decimal(10000), nil}, + {"Error", newErrorEval(errTest), types.ZeroValue(), errTest}, + {"TypeError", newLiteralEval(types.Long(1)), types.ZeroValue(), types.ErrType}, + {"DecimalError", newLiteralEval(types.String("frob")), types.ZeroValue(), types.ErrDecimal}, + {"Success", newLiteralEval(types.String("1.0")), types.Decimal(10000), nil}, } for _, tt := range tests { tt := tt @@ -1740,26 +1742,26 @@ func TestDecimalLiteralNode(t *testing.T) { t.Parallel() n := newDecimalLiteralEval(tt.arg) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } func TestIPLiteralNode(t *testing.T) { t.Parallel() - ipv6Loopback, err := ParseIPAddr("::1") - testutilOK(t, err) + ipv6Loopback, err := types.ParseIPAddr("::1") + testutil.OK(t, err) tests := []struct { name string arg evaler - result Value + result types.Value err error }{ - {"Error", newErrorEval(errTest), zeroValue(), errTest}, - {"TypeError", newLiteralEval(Long(1)), zeroValue(), errType}, - {"IPError", newLiteralEval(String("not-an-IP-address")), zeroValue(), errIP}, - {"Success", newLiteralEval(String("::1/128")), ipv6Loopback, nil}, + {"Error", newErrorEval(errTest), types.ZeroValue(), errTest}, + {"TypeError", newLiteralEval(types.Long(1)), types.ZeroValue(), types.ErrType}, + {"IPError", newLiteralEval(types.String("not-an-IP-address")), types.ZeroValue(), types.ErrIP}, + {"Success", newLiteralEval(types.String("::1/128")), ipv6Loopback, nil}, } for _, tt := range tests { tt := tt @@ -1767,37 +1769,37 @@ func TestIPLiteralNode(t *testing.T) { t.Parallel() n := newIPLiteralEval(tt.arg) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } func TestIPTestNode(t *testing.T) { t.Parallel() - ipv4Loopback, err := ParseIPAddr("127.0.0.1") - testutilOK(t, err) - ipv6Loopback, err := ParseIPAddr("::1") - testutilOK(t, err) - ipv4Multicast, err := ParseIPAddr("224.0.0.1") - testutilOK(t, err) + ipv4Loopback, err := types.ParseIPAddr("127.0.0.1") + testutil.OK(t, err) + ipv6Loopback, err := types.ParseIPAddr("::1") + testutil.OK(t, err) + ipv4Multicast, err := types.ParseIPAddr("224.0.0.1") + testutil.OK(t, err) tests := []struct { name string lhs evaler rhs ipTestType - result Value + result types.Value err error }{ - {"Error", newErrorEval(errTest), ipTestIPv4, zeroValue(), errTest}, - {"TypeError", newLiteralEval(Long(1)), ipTestIPv4, zeroValue(), errType}, - {"IPv4True", newLiteralEval(ipv4Loopback), ipTestIPv4, Boolean(true), nil}, - {"IPv4False", newLiteralEval(ipv6Loopback), ipTestIPv4, Boolean(false), nil}, - {"IPv6True", newLiteralEval(ipv6Loopback), ipTestIPv6, Boolean(true), nil}, - {"IPv6False", newLiteralEval(ipv4Loopback), ipTestIPv6, Boolean(false), nil}, - {"LoopbackTrue", newLiteralEval(ipv6Loopback), ipTestLoopback, Boolean(true), nil}, - {"LoopbackFalse", newLiteralEval(ipv4Multicast), ipTestLoopback, Boolean(false), nil}, - {"MulticastTrue", newLiteralEval(ipv4Multicast), ipTestMulticast, Boolean(true), nil}, - {"MulticastFalse", newLiteralEval(ipv6Loopback), ipTestMulticast, Boolean(false), nil}, + {"Error", newErrorEval(errTest), ipTestIPv4, types.ZeroValue(), errTest}, + {"TypeError", newLiteralEval(types.Long(1)), ipTestIPv4, types.ZeroValue(), types.ErrType}, + {"IPv4True", newLiteralEval(ipv4Loopback), ipTestIPv4, types.Boolean(true), nil}, + {"IPv4False", newLiteralEval(ipv6Loopback), ipTestIPv4, types.Boolean(false), nil}, + {"IPv6True", newLiteralEval(ipv6Loopback), ipTestIPv6, types.Boolean(true), nil}, + {"IPv6False", newLiteralEval(ipv4Loopback), ipTestIPv6, types.Boolean(false), nil}, + {"LoopbackTrue", newLiteralEval(ipv6Loopback), ipTestLoopback, types.Boolean(true), nil}, + {"LoopbackFalse", newLiteralEval(ipv4Multicast), ipTestLoopback, types.Boolean(false), nil}, + {"MulticastTrue", newLiteralEval(ipv4Multicast), ipTestMulticast, types.Boolean(true), nil}, + {"MulticastFalse", newLiteralEval(ipv6Loopback), ipTestMulticast, types.Boolean(false), nil}, } for _, tt := range tests { tt := tt @@ -1805,37 +1807,37 @@ func TestIPTestNode(t *testing.T) { t.Parallel() n := newIPTestEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } func TestIPIsInRangeNode(t *testing.T) { t.Parallel() - ipv4A, err := ParseIPAddr("1.2.3.4") - testutilOK(t, err) - ipv4B, err := ParseIPAddr("1.2.3.0/24") - testutilOK(t, err) - ipv4C, err := ParseIPAddr("1.2.4.0/24") - testutilOK(t, err) + ipv4A, err := types.ParseIPAddr("1.2.3.4") + testutil.OK(t, err) + ipv4B, err := types.ParseIPAddr("1.2.3.0/24") + testutil.OK(t, err) + ipv4C, err := types.ParseIPAddr("1.2.4.0/24") + testutil.OK(t, err) tests := []struct { name string lhs, rhs evaler - result Value + result types.Value err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(ipv4A), zeroValue(), errTest}, - {"LhsTypeError", newLiteralEval(Long(1)), newLiteralEval(ipv4A), zeroValue(), errType}, - {"RhsError", newLiteralEval(ipv4A), newErrorEval(errTest), zeroValue(), errTest}, - {"RhsTypeError", newLiteralEval(ipv4A), newLiteralEval(Long(1)), zeroValue(), errType}, - {"AA", newLiteralEval(ipv4A), newLiteralEval(ipv4A), Boolean(true), nil}, - {"AB", newLiteralEval(ipv4A), newLiteralEval(ipv4B), Boolean(true), nil}, - {"BA", newLiteralEval(ipv4B), newLiteralEval(ipv4A), Boolean(false), nil}, - {"AC", newLiteralEval(ipv4A), newLiteralEval(ipv4C), Boolean(false), nil}, - {"CA", newLiteralEval(ipv4C), newLiteralEval(ipv4A), Boolean(false), nil}, - {"BC", newLiteralEval(ipv4B), newLiteralEval(ipv4C), Boolean(false), nil}, - {"CB", newLiteralEval(ipv4C), newLiteralEval(ipv4B), Boolean(false), nil}, + {"LhsError", newErrorEval(errTest), newLiteralEval(ipv4A), types.ZeroValue(), errTest}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(ipv4A), types.ZeroValue(), types.ErrType}, + {"RhsError", newLiteralEval(ipv4A), newErrorEval(errTest), types.ZeroValue(), errTest}, + {"RhsTypeError", newLiteralEval(ipv4A), newLiteralEval(types.Long(1)), types.ZeroValue(), types.ErrType}, + {"AA", newLiteralEval(ipv4A), newLiteralEval(ipv4A), types.Boolean(true), nil}, + {"AB", newLiteralEval(ipv4A), newLiteralEval(ipv4B), types.Boolean(true), nil}, + {"BA", newLiteralEval(ipv4B), newLiteralEval(ipv4A), types.Boolean(false), nil}, + {"AC", newLiteralEval(ipv4A), newLiteralEval(ipv4C), types.Boolean(false), nil}, + {"CA", newLiteralEval(ipv4C), newLiteralEval(ipv4A), types.Boolean(false), nil}, + {"BC", newLiteralEval(ipv4B), newLiteralEval(ipv4C), types.Boolean(false), nil}, + {"CB", newLiteralEval(ipv4C), newLiteralEval(ipv4B), types.Boolean(false), nil}, } for _, tt := range tests { tt := tt @@ -1843,8 +1845,8 @@ func TestIPIsInRangeNode(t *testing.T) { t.Parallel() n := newIPIsInRangeEval(tt.lhs, tt.rhs) v, err := n.Eval(&evalContext{}) - assertError(t, err, tt.err) - assertValue(t, v, tt.result) + testutil.AssertError(t, err, tt.err) + types.AssertValue(t, v, tt.result) }) } } @@ -1853,27 +1855,27 @@ func TestCedarString(t *testing.T) { t.Parallel() tests := []struct { name string - in Value + in types.Value wantString string wantCedar string }{ - {"string", String("hello"), `hello`, `"hello"`}, - {"number", Long(42), `42`, `42`}, - {"bool", Boolean(true), `true`, `true`}, - {"record", Record{"a": Long(42), "b": Long(43)}, `{"a":42,"b":43}`, `{"a":42,"b":43}`}, - {"set", Set{Long(42), Long(43)}, `[42,43]`, `[42,43]`}, - {"singleIP", IPAddr(netip.MustParsePrefix("192.168.0.42/32")), `192.168.0.42`, `ip("192.168.0.42")`}, - {"ipPrefix", IPAddr(netip.MustParsePrefix("192.168.0.42/24")), `192.168.0.42/24`, `ip("192.168.0.42/24")`}, - {"decimal", Decimal(12345678), `1234.5678`, `decimal("1234.5678")`}, + {"string", types.String("hello"), `hello`, `"hello"`}, + {"number", types.Long(42), `42`, `42`}, + {"bool", types.Boolean(true), `true`, `true`}, + {"record", types.Record{"a": types.Long(42), "b": types.Long(43)}, `{"a":42,"b":43}`, `{"a":42,"b":43}`}, + {"set", types.Set{types.Long(42), types.Long(43)}, `[42,43]`, `[42,43]`}, + {"singleIP", types.IPAddr(netip.MustParsePrefix("192.168.0.42/32")), `192.168.0.42`, `ip("192.168.0.42")`}, + {"ipPrefix", types.IPAddr(netip.MustParsePrefix("192.168.0.42/24")), `192.168.0.42/24`, `ip("192.168.0.42/24")`}, + {"decimal", types.Decimal(12345678), `1234.5678`, `decimal("1234.5678")`}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() gotString := tt.in.String() - testutilEquals(t, gotString, tt.wantString) + testutil.Equals(t, gotString, tt.wantString) gotCedar := tt.in.Cedar() - testutilEquals(t, gotCedar, tt.wantCedar) + testutil.Equals(t, gotCedar, tt.wantCedar) }) } } diff --git a/match_test.go b/match_test.go index 39b710a2..783ed174 100644 --- a/match_test.go +++ b/match_test.go @@ -3,6 +3,7 @@ package cedar import ( "testing" + "github.com/cedar-policy/cedar-go/testutil" "github.com/cedar-policy/cedar-go/x/exp/parser" ) @@ -38,9 +39,9 @@ func TestMatch(t *testing.T) { t.Run(tt.pattern+":"+tt.target, func(t *testing.T) { t.Parallel() pat, err := parser.NewPattern(tt.pattern) - testutilOK(t, err) + testutil.OK(t, err) got := match(pat, tt.target) - testutilEquals(t, got, tt.want) + testutil.Equals(t, got, tt.want) }) } } diff --git a/testutil/testutil.go b/testutil/testutil.go new file mode 100644 index 00000000..6d897b67 --- /dev/null +++ b/testutil/testutil.go @@ -0,0 +1,44 @@ +package testutil + +import ( + "errors" + "reflect" + "testing" +) + +func Equals[T any](t testing.TB, a, b T) { + t.Helper() + if reflect.DeepEqual(a, b) { + return + } + t.Fatalf("got %+v want %+v", a, b) +} + +func FatalIf(t testing.TB, c bool, f string, args ...any) { + t.Helper() + if !c { + return + } + t.Fatalf(f, args...) +} + +func OK(t testing.TB, err error) { + t.Helper() + if err == nil { + return + } + t.Fatalf("got %v want nil", err) +} + +func Error(t testing.TB, err error) { + t.Helper() + if err != nil { + return + } + t.Fatalf("got nil want error") +} + +func AssertError(t *testing.T, got, want error) { + t.Helper() + FatalIf(t, !errors.Is(got, want), "err got %v want %v", got, want) +} diff --git a/testutil_test.go b/testutil_test.go deleted file mode 100644 index 2de50a82..00000000 --- a/testutil_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package cedar - -import ( - "errors" - "fmt" - "reflect" - "testing" -) - -func testutilEquals[T any](t testing.TB, a, b T) { - t.Helper() - if reflect.DeepEqual(a, b) { - return - } - t.Fatalf("got %+v want %+v", a, b) -} - -func testutilFatalIf(t testing.TB, c bool, f string, args ...any) { - t.Helper() - if !c { - return - } - t.Fatalf(f, args...) -} - -func testutilOK(t testing.TB, err error) { - t.Helper() - if err == nil { - return - } - t.Fatalf("got %v want nil", err) -} - -func testutilError(t testing.TB, err error) { - t.Helper() - if err != nil { - return - } - t.Fatalf("got nil want error") -} - -func assertError(t *testing.T, got, want error) { - t.Helper() - testutilFatalIf(t, !errors.Is(got, want), "err got %v want %v", got, want) -} - -func assertValue(t *testing.T, got, want Value) { - t.Helper() - testutilFatalIf( - t, - !((got == zeroValue() && want == zeroValue()) || - (got != zeroValue() && want != zeroValue() && got.equal(want))), - "got %v want %v", got, want) -} - -func assertBoolValue(t *testing.T, got Value, want bool) { - t.Helper() - testutilEquals[Value](t, got, Boolean(want)) -} - -func assertLongValue(t *testing.T, got Value, want int64) { - t.Helper() - testutilEquals[Value](t, got, Long(want)) -} - -func assertZeroValue(t *testing.T, got Value) { - t.Helper() - testutilEquals(t, got, zeroValue()) -} - -func assertValueString(t *testing.T, v Value, want string) { - t.Helper() - testutilEquals(t, v.String(), want) -} - -func safeDoErr(f func() error) (err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("%v", r) - } - }() - return f() -} diff --git a/toeval.go b/toeval.go index 1d2f1977..9988178e 100644 --- a/toeval.go +++ b/toeval.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/parser" ) @@ -21,7 +22,7 @@ func toEval(n any) evaler { var res evaler switch v.Type { case parser.MatchAny: - res = newLiteralEval(Boolean(true)) + res = newLiteralEval(types.Boolean(true)) case parser.MatchEquals: res = newEqualEval(newVariableEval(variableNamePrincipal), toEval(v.Entity)) case parser.MatchIn: @@ -38,7 +39,7 @@ func toEval(n any) evaler { var res evaler switch v.Type { case parser.MatchAny: - res = newLiteralEval(Boolean(true)) + res = newLiteralEval(types.Boolean(true)) case parser.MatchEquals: res = newEqualEval(newVariableEval(variableNameAction), toEval(v.Entities[0])) case parser.MatchIn: @@ -56,7 +57,7 @@ func toEval(n any) evaler { var res evaler switch v.Type { case parser.MatchAny: - res = newLiteralEval(Boolean(true)) + res = newLiteralEval(types.Boolean(true)) case parser.MatchEquals: res = newEqualEval(newVariableEval(variableNameResource), toEval(v.Entity)) case parser.MatchIn: @@ -70,9 +71,9 @@ func toEval(n any) evaler { } return res case parser.Entity: - return newLiteralEval(entityValueFromSlice(v.Path)) + return newLiteralEval(types.EntityValueFromSlice(v.Path)) case parser.Path: - return newLiteralEval(pathFromSlice(v.Path)) + return newLiteralEval(types.PathFromSlice(v.Path)) case parser.Condition: var res evaler switch v.Type { @@ -210,11 +211,11 @@ func toEval(n any) evaler { case parser.Literal: switch v.Type { case parser.LiteralBool: - return newLiteralEval(Boolean(v.Bool)) + return newLiteralEval(types.Boolean(v.Bool)) case parser.LiteralInt: - return newLiteralEval(Long(v.Long)) + return newLiteralEval(types.Long(v.Long)) case parser.LiteralString: - return newLiteralEval(String(v.Str)) + return newLiteralEval(types.String(v.Str)) default: panic("missing LiteralType case") } diff --git a/toeval_test.go b/toeval_test.go index 5a64a9d3..037e58ef 100644 --- a/toeval_test.go +++ b/toeval_test.go @@ -1,12 +1,24 @@ package cedar import ( + "fmt" "strings" "testing" + "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/parser" ) +func safeDoErr(f func() error) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + return f() +} + func TestToEval(t *testing.T) { t.Parallel() tests := []struct { @@ -18,7 +30,7 @@ func TestToEval(t *testing.T) { {"happy", parser.Entity{ Path: []string{"Action", "test"}, }, - newLiteralEval(entityValueFromSlice([]string{"Action", "test"})), ""}, + newLiteralEval(types.EntityValueFromSlice([]string{"Action", "test"})), ""}, {"missingRelOp", parser.Relation{ Add: parser.Add{ Mults: []parser.Mult{ @@ -148,10 +160,10 @@ func TestToEval(t *testing.T) { out = toEval(tt.in) return nil }) - testutilEquals(t, out, tt.out) - testutilEquals(t, err != nil, tt.panic != "") + testutil.Equals(t, out, tt.out) + testutil.Equals(t, err != nil, tt.panic != "") if tt.panic != "" { - testutilFatalIf(t, !strings.Contains(err.Error(), tt.panic), "panic got %v want %v", err.Error(), tt.panic) + testutil.FatalIf(t, !strings.Contains(err.Error(), tt.panic), "panic got %v want %v", err.Error(), tt.panic) } }) } diff --git a/json.go b/types/json.go similarity index 99% rename from json.go rename to types/json.go index 0ae74298..1df5e0f8 100644 --- a/json.go +++ b/types/json.go @@ -1,4 +1,4 @@ -package cedar +package types import ( "bytes" diff --git a/json_test.go b/types/json_test.go similarity index 70% rename from json_test.go rename to types/json_test.go index 4b62a2ad..c4e2d834 100644 --- a/json_test.go +++ b/types/json_test.go @@ -1,9 +1,11 @@ -package cedar +package types import ( "encoding/json" "fmt" "testing" + + "github.com/cedar-policy/cedar-go/testutil" ) func mustDecimalValue(v string) Decimal { @@ -28,15 +30,15 @@ func TestJSON_Value(t *testing.T) { {"explicitEntity", `{ "__entity": { "type": "User", "id": "alice" } }`, EntityUID{Type: "User", ID: "alice"}, nil}, {"impliedLongEntity", `{ "type": "User::External", "id": "alice" }`, EntityUID{Type: "User::External", ID: "alice"}, nil}, {"explicitLongEntity", `{ "__entity": { "type": "User::External", "id": "alice" } }`, EntityUID{Type: "User::External", ID: "alice"}, nil}, - {"invalidJSON", `!@#$`, zeroValue(), errJSONDecode}, - {"numericOverflow", "12341234123412341234", zeroValue(), errJSONLongOutOfRange}, - {"unsupportedNull", "null", zeroValue(), errJSONUnsupportedType}, + {"invalidJSON", `!@#$`, ZeroValue(), errJSONDecode}, + {"numericOverflow", "12341234123412341234", ZeroValue(), errJSONLongOutOfRange}, + {"unsupportedNull", "null", ZeroValue(), errJSONUnsupportedType}, {"explicitIP", `{ "__extn": { "fn": "ip", "arg": "222.222.222.7" } }`, mustIPValue("222.222.222.7"), nil}, {"explicitSubnet", `{ "__extn": { "fn": "ip", "arg": "192.168.0.0/16" } }`, mustIPValue("192.168.0.0/16"), nil}, {"explicitDecimal", `{ "__extn": { "fn": "decimal", "arg": "33.57" } }`, mustDecimalValue("33.57"), nil}, - {"invalidExtension", `{ "__extn": { "fn": "asdf", "arg": "blah" } }`, zeroValue(), errJSONInvalidExtn}, - {"badIP", `{ "__extn": { "fn": "ip", "arg": "bad" } }`, zeroValue(), errIP}, - {"badDecimal", `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, zeroValue(), errDecimal}, + {"invalidExtension", `{ "__extn": { "fn": "asdf", "arg": "blah" } }`, ZeroValue(), errJSONInvalidExtn}, + {"badIP", `{ "__extn": { "fn": "ip", "arg": "bad" } }`, ZeroValue(), ErrIP}, + {"badDecimal", `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, ZeroValue(), ErrDecimal}, {"set", `[42]`, Set{Long(42)}, nil}, {"record", `{"a":"b"}`, Record{"a": String("b")}, nil}, {"bool", `false`, Boolean(false), nil}, @@ -48,8 +50,8 @@ func TestJSON_Value(t *testing.T) { var got Value ptr := &got err := unmarshalJSON([]byte(tt.in), ptr) - assertError(t, err, tt.err) - assertValue(t, got, tt.want) + testutil.AssertError(t, err, tt.err) + AssertValue(t, got, tt.want) if tt.err != nil { return } @@ -57,12 +59,12 @@ func TestJSON_Value(t *testing.T) { // Now assert that when we Marshal/Unmarshal that value, we still // have what we started with gotJSON, err := (*ptr).ExplicitMarshalJSON() - testutilOK(t, err) + testutil.OK(t, err) var gotRetry Value ptr = &gotRetry err = unmarshalJSON(gotJSON, ptr) - testutilOK(t, err) - testutilEquals(t, gotRetry, tt.want) + testutil.OK(t, err) + testutil.Equals(t, gotRetry, tt.want) }) } } @@ -129,7 +131,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { }, in: `{ "__extn": { "fn": "ip", "arg": "bad" } }`, wantValue: IPAddr{}, - wantErr: errIP, + wantErr: ErrIP, }, { name: "ip/badJSON", @@ -207,7 +209,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { }, in: `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, wantValue: Decimal(0), - wantErr: errDecimal, + wantErr: ErrDecimal, }, { name: "decimal/badJSON", @@ -248,8 +250,8 @@ func TestTypedJSONUnmarshal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() gotValue, gotErr := tt.f([]byte(tt.in)) - testutilEquals(t, gotValue, tt.wantValue) - assertError(t, gotErr, tt.wantErr) + testutil.Equals(t, gotValue, tt.wantValue) + testutil.AssertError(t, gotErr, tt.wantErr) }) } } @@ -286,11 +288,11 @@ func TestJSONMarshal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() outExplicit, err := tt.in.ExplicitMarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(outExplicit), tt.outExplicit) + testutil.OK(t, err) + testutil.Equals(t, string(outExplicit), tt.outExplicit) outImplicit, err := json.Marshal(tt.in) - testutilOK(t, err) - testutilEquals(t, string(outImplicit), tt.outImplicit) + testutil.OK(t, err) + testutil.Equals(t, string(outImplicit), tt.outImplicit) }) } } @@ -299,9 +301,9 @@ type jsonErr struct{} func (j *jsonErr) String() string { return "" } func (j *jsonErr) Cedar() string { return "" } -func (j *jsonErr) equal(Value) bool { return false } +func (j *jsonErr) Equal(Value) bool { return false } func (j *jsonErr) ExplicitMarshalJSON() ([]byte, error) { return nil, fmt.Errorf("jsonErr") } -func (j *jsonErr) typeName() string { return "jsonErr" } +func (j *jsonErr) TypeName() string { return "jsonErr" } func (j *jsonErr) deepClone() Value { return nil } func TestJSONSet(t *testing.T) { @@ -310,13 +312,13 @@ func TestJSONSet(t *testing.T) { t.Parallel() var s Set err := json.Unmarshal([]byte(`[{"__extn":{"fn":"err"}}]`), &s) - testutilError(t, err) + testutil.Error(t, err) }) t.Run("MarshalErr", func(t *testing.T) { t.Parallel() s := Set{&jsonErr{}} _, err := json.Marshal(s) - testutilError(t, err) + testutil.Error(t, err) }) } @@ -326,7 +328,7 @@ func TestJSONRecord(t *testing.T) { t.Parallel() var r Record err := json.Unmarshal([]byte(`{"key":{"__extn":{"fn":"err"}}}`), &r) - testutilError(t, err) + testutil.Error(t, err) }) t.Run("MarshalKeyErrImpossible", func(t *testing.T) { t.Parallel() @@ -335,117 +337,13 @@ func TestJSONRecord(t *testing.T) { r[string(k)] = Boolean(false) v, err := json.Marshal(r) // this demonstrates that invalid keys will still result in json - testutilEquals(t, string(v), `{"\ufffd\u0001":false}`) - testutilOK(t, err) + testutil.Equals(t, string(v), `{"\ufffd\u0001":false}`) + testutil.OK(t, err) }) t.Run("MarshalValueErr", func(t *testing.T) { t.Parallel() r := Record{"key": &jsonErr{}} _, err := json.Marshal(r) - testutilError(t, err) - }) -} - -func TestEntitiesJSON(t *testing.T) { - t.Parallel() - t.Run("Marshal", func(t *testing.T) { - t.Parallel() - e := Entities{} - ent := Entity{ - UID: NewEntityUID("Type", "id"), - Parents: []EntityUID{}, - Attributes: Record{"key": Long(42)}, - } - e[ent.UID] = ent - b, err := e.MarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(b), `[{"uid":{"type":"Type","id":"id"},"attrs":{"key":42}}]`) - }) - - t.Run("Unmarshal", func(t *testing.T) { - t.Parallel() - b := []byte(`[{"uid":{"type":"Type","id":"id"},"parents":[],"attrs":{"key":42}}]`) - var e Entities - err := json.Unmarshal(b, &e) - testutilOK(t, err) - want := Entities{} - ent := Entity{ - UID: NewEntityUID("Type", "id"), - Parents: []EntityUID{}, - Attributes: Record{"key": Long(42)}, - } - want[ent.UID] = ent - testutilEquals(t, e, want) - }) - - t.Run("UnmarshalErr", func(t *testing.T) { - t.Parallel() - var e Entities - err := e.UnmarshalJSON([]byte(`!@#$`)) - testutilError(t, err) - }) -} - -func TestJSONEffect(t *testing.T) { - t.Parallel() - t.Run("MarshalPermit", func(t *testing.T) { - t.Parallel() - e := Permit - b, err := e.MarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(b), `"permit"`) - }) - t.Run("MarshalForbid", func(t *testing.T) { - t.Parallel() - e := Forbid - b, err := e.MarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(b), `"forbid"`) - }) - t.Run("UnmarshalPermit", func(t *testing.T) { - t.Parallel() - var e Effect - err := json.Unmarshal([]byte(`"permit"`), &e) - testutilOK(t, err) - testutilEquals(t, e, Permit) - }) - t.Run("UnmarshalForbid", func(t *testing.T) { - t.Parallel() - var e Effect - err := json.Unmarshal([]byte(`"forbid"`), &e) - testutilOK(t, err) - testutilEquals(t, e, Forbid) - }) -} - -func TestJSONDecision(t *testing.T) { - t.Parallel() - t.Run("MarshalAllow", func(t *testing.T) { - t.Parallel() - d := Allow - b, err := d.MarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(b), `"allow"`) - }) - t.Run("MarshalDeny", func(t *testing.T) { - t.Parallel() - d := Deny - b, err := d.MarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(b), `"deny"`) - }) - t.Run("UnmarshalAllow", func(t *testing.T) { - t.Parallel() - var d Decision - err := json.Unmarshal([]byte(`"allow"`), &d) - testutilOK(t, err) - testutilEquals(t, d, Allow) - }) - t.Run("UnmarshalDeny", func(t *testing.T) { - t.Parallel() - var d Decision - err := json.Unmarshal([]byte(`"deny"`), &d) - testutilOK(t, err) - testutilEquals(t, d, Deny) + testutil.Error(t, err) }) } diff --git a/types/testutil.go b/types/testutil.go new file mode 100644 index 00000000..787f96b9 --- /dev/null +++ b/types/testutil.go @@ -0,0 +1,36 @@ +package types + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/testutil" +) + +func AssertValue(t *testing.T, got, want Value) { + t.Helper() + testutil.FatalIf( + t, + !((got == ZeroValue() && want == ZeroValue()) || + (got != ZeroValue() && want != ZeroValue() && got.Equal(want))), + "got %v want %v", got, want) +} + +func AssertBoolValue(t *testing.T, got Value, want bool) { + t.Helper() + testutil.Equals[Value](t, got, Boolean(want)) +} + +func AssertLongValue(t *testing.T, got Value, want int64) { + t.Helper() + testutil.Equals[Value](t, got, Long(want)) +} + +func AssertZeroValue(t *testing.T, got Value) { + t.Helper() + testutil.Equals(t, got, ZeroValue()) +} + +func AssertValueString(t *testing.T, v Value, want string) { + t.Helper() + testutil.Equals(t, v.String(), want) +} diff --git a/value.go b/types/value.go similarity index 79% rename from value.go rename to types/value.go index 07cef673..27304ddf 100644 --- a/value.go +++ b/types/value.go @@ -1,4 +1,4 @@ -package cedar +package types import ( "bytes" @@ -10,11 +10,14 @@ import ( "strings" "unicode" - "github.com/cedar-policy/cedar-go/x/exp/parser" "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) +var ErrDecimal = fmt.Errorf("error parsing decimal value") +var ErrIP = fmt.Errorf("error parsing ip value") +var ErrType = fmt.Errorf("type error") + type Value interface { // String produces a string representation of the Value. String() string @@ -24,23 +27,23 @@ type Value interface { // applicable) JSON form, which is necessary for marshalling values within // Sets or Records where the type is not defined. ExplicitMarshalJSON() ([]byte, error) - equal(Value) bool - typeName() string + Equal(Value) bool + TypeName() string deepClone() Value } -func zeroValue() Value { +func ZeroValue() Value { return nil } // A Boolean is a value that is either true or false. type Boolean bool -func (a Boolean) equal(bi Value) bool { +func (a Boolean) Equal(bi Value) bool { b, ok := bi.(Boolean) return ok && a == b } -func (v Boolean) typeName() string { return "bool" } +func (v Boolean) TypeName() string { return "bool" } // String produces a string representation of the Boolean, e.g. `true`. func (v Boolean) String() string { return v.Cedar() } @@ -54,17 +57,25 @@ func (v Boolean) Cedar() string { func (v Boolean) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } func (v Boolean) deepClone() Value { return v } +func ValueToBool(v Value) (Boolean, error) { + bv, ok := v.(Boolean) + if !ok { + return false, fmt.Errorf("%w: expected bool, got %v", ErrType, v.TypeName()) + } + return bv, nil +} + // A Long is a whole number without decimals that can range from -9223372036854775808 to 9223372036854775807. type Long int64 -func (a Long) equal(bi Value) bool { +func (a Long) Equal(bi Value) bool { b, ok := bi.(Long) return ok && a == b } // ExplicitMarshalJSON marshals the Long into JSON. func (v Long) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } -func (v Long) typeName() string { return "long" } +func (v Long) TypeName() string { return "long" } // String produces a string representation of the Long, e.g. `42`. func (v Long) String() string { return v.Cedar() } @@ -75,17 +86,25 @@ func (v Long) Cedar() string { } func (v Long) deepClone() Value { return v } +func ValueToLong(v Value) (Long, error) { + lv, ok := v.(Long) + if !ok { + return 0, fmt.Errorf("%w: expected long, got %v", ErrType, v.TypeName()) + } + return lv, nil +} + // A String is a sequence of characters consisting of letters, numbers, or symbols. type String string -func (a String) equal(bi Value) bool { +func (a String) Equal(bi Value) bool { b, ok := bi.(String) return ok && a == b } // ExplicitMarshalJSON marshals the String into JSON. func (v String) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } -func (v String) typeName() string { return "string" } +func (v String) TypeName() string { return "string" } // String produces an unquoted string representation of the String, e.g. `hello`. func (v String) String() string { @@ -94,37 +113,45 @@ func (v String) String() string { // Cedar produces a valid Cedar language representation of the String, e.g. `"hello"`. func (v String) Cedar() string { - return parser.FakeRustQuote(string(v)) + return strconv.Quote(string(v)) } func (v String) deepClone() Value { return v } +func ValueToString(v Value) (String, error) { + sv, ok := v.(String) + if !ok { + return "", fmt.Errorf("%w: expected string, got %v", ErrType, v.TypeName()) + } + return sv, nil +} + // A Set is a collection of elements that can be of the same or different types. type Set []Value -func (s Set) contains(v Value) bool { +func (s Set) Contains(v Value) bool { for _, e := range s { - if e.equal(v) { + if e.Equal(v) { return true } } return false } -// Equals returns true if the sets are equal. -func (s Set) Equals(b Set) bool { return s.equal(b) } +// Equals returns true if the sets are Equal. +func (s Set) Equals(b Set) bool { return s.Equal(b) } -func (as Set) equal(bi Value) bool { +func (as Set) Equal(bi Value) bool { bs, ok := bi.(Set) if !ok { return false } for _, a := range as { - if !bs.contains(a) { + if !bs.Contains(a) { return false } } for _, b := range bs { - if !as.contains(b) { + if !as.Contains(b) { return false } } @@ -170,7 +197,7 @@ func (v Set) MarshalJSON() ([]byte, error) { // explicit JSON form for all the values in the Set. func (v Set) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } -func (v Set) typeName() string { return "set" } +func (v Set) TypeName() string { return "set" } // String produces a string representation of the Set, e.g. `[1,2,3]`. func (v Set) String() string { return v.Cedar() } @@ -202,21 +229,29 @@ func (v Set) DeepClone() Set { return res } +func ValueToSet(v Value) (Set, error) { + sv, ok := v.(Set) + if !ok { + return nil, fmt.Errorf("%w: expected set, got %v", ErrType, v.TypeName()) + } + return sv, nil +} + // A Record is a collection of attributes. Each attribute consists of a name and // an associated value. Names are simple strings. Values can be of any type. type Record map[string]Value -// Equals returns true if the records are equal. -func (r Record) Equals(b Record) bool { return r.equal(b) } +// Equals returns true if the records are Equal. +func (r Record) Equals(b Record) bool { return r.Equal(b) } -func (a Record) equal(bi Value) bool { +func (a Record) Equal(bi Value) bool { b, ok := bi.(Record) if !ok || len(a) != len(b) { return false } for k, av := range a { bv, ok := b[k] - if !ok || !av.equal(bv) { + if !ok || !av.Equal(bv) { return false } } @@ -264,7 +299,7 @@ func (v Record) MarshalJSON() ([]byte, error) { // ExplicitMarshalJSON marshals the Record into JSON, the marshaller uses the // explicit JSON form for all the values in the Record. func (v Record) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } -func (r Record) typeName() string { return "record" } +func (r Record) TypeName() string { return "record" } // String produces a string representation of the Record, e.g. `{"a":1,"b":2,"c":3}`. func (r Record) String() string { return r.Cedar() } @@ -282,7 +317,7 @@ func (r Record) Cedar() string { sb.WriteString(",") } first = false - sb.WriteString(parser.FakeRustQuote(k)) + sb.WriteString(strconv.Quote(k)) sb.WriteString(":") sb.WriteString(v.Cedar()) } @@ -303,6 +338,14 @@ func (v Record) DeepClone() Record { return res } +func ValueToRecord(v Value) (Record, error) { + rv, ok := v.(Record) + if !ok { + return nil, fmt.Errorf("%w: expected record got %v", ErrType, v.TypeName()) + } + return rv, nil +} + // An EntityUID is the identifier for a principal, action, or resource. type EntityUID struct { Type string @@ -321,18 +364,18 @@ func (a EntityUID) IsZero() bool { return a.Type == "" && a.ID == "" } -func (a EntityUID) equal(bi Value) bool { +func (a EntityUID) Equal(bi Value) bool { b, ok := bi.(EntityUID) return ok && a == b } -func (v EntityUID) typeName() string { return fmt.Sprintf("(entity of type `%s`)", v.Type) } +func (v EntityUID) TypeName() string { return fmt.Sprintf("(entity of type `%s`)", v.Type) } // String produces a string representation of the EntityUID, e.g. `Type::"id"`. func (v EntityUID) String() string { return v.Cedar() } // Cedar produces a valid Cedar language representation of the EntityUID, e.g. `Type::"id"`. func (v EntityUID) Cedar() string { - return v.Type + "::" + parser.FakeRustQuote(v.ID) + return v.Type + "::" + strconv.Quote(v.ID) } func (v *EntityUID) UnmarshalJSON(b []byte) error { @@ -372,29 +415,45 @@ func (v EntityUID) ExplicitMarshalJSON() ([]byte, error) { } func (v EntityUID) deepClone() Value { return v } -func entityValueFromSlice(v []string) EntityUID { +func ValueToEntity(v Value) (EntityUID, error) { + ev, ok := v.(EntityUID) + if !ok { + return EntityUID{}, fmt.Errorf("%w: expected (entity of type `any_entity_type`), got %v", ErrType, v.TypeName()) + } + return ev, nil +} + +func EntityValueFromSlice(v []string) EntityUID { return EntityUID{ Type: strings.Join(v[:len(v)-1], "::"), ID: v[len(v)-1], } } -// path is the type portion of an EntityUID -type path string +// Path is the type portion of an EntityUID +type Path string -func (a path) equal(bi Value) bool { - b, ok := bi.(path) +func (a Path) Equal(bi Value) bool { + b, ok := bi.(Path) return ok && a == b } -func (v path) typeName() string { return fmt.Sprintf("(path of type `%s`)", v) } +func (v Path) TypeName() string { return fmt.Sprintf("(Path of type `%s`)", v) } + +func (v Path) String() string { return string(v) } +func (v Path) Cedar() string { return string(v) } +func (v Path) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(string(v)) } +func (v Path) deepClone() Value { return v } -func (v path) String() string { return string(v) } -func (v path) Cedar() string { return string(v) } -func (v path) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(string(v)) } -func (v path) deepClone() Value { return v } +func ValueToPath(v Value) (Path, error) { + ev, ok := v.(Path) + if !ok { + return "", fmt.Errorf("%w: expected (Path of type `any_entity_type`), got %v", ErrType, v.TypeName()) + } + return ev, nil +} -func pathFromSlice(v []string) path { - return path(strings.Join(v, "::")) +func PathFromSlice(v []string) Path { + return Path(strings.Join(v, "::")) } // A Decimal is a value with both a whole number part and a decimal part of no @@ -409,7 +468,7 @@ const DecimalPrecision = 10000 func ParseDecimal(s string) (Decimal, error) { // Check for empty string. if len(s) == 0 { - return Decimal(0), fmt.Errorf("%w: string too short", errDecimal) + return Decimal(0), fmt.Errorf("%w: string too short", ErrDecimal) } i := 0 @@ -419,14 +478,14 @@ func ParseDecimal(s string) (Decimal, error) { negative = true i++ if i == len(s) { - return Decimal(0), fmt.Errorf("%w: string too short", errDecimal) + return Decimal(0), fmt.Errorf("%w: string too short", ErrDecimal) } } // Parse the required first digit. c := rune(s[i]) if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", errDecimal, strconv.QuoteRune(c)) + return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) } integer := int64(c - '0') i++ @@ -434,18 +493,18 @@ func ParseDecimal(s string) (Decimal, error) { // Parse any other digits, ending with i pointing to '.'. for ; ; i++ { if i == len(s) { - return Decimal(0), fmt.Errorf("%w: string missing decimal point", errDecimal) + return Decimal(0), fmt.Errorf("%w: string missing decimal point", ErrDecimal) } c = rune(s[i]) if c == '.' { break } if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", errDecimal, strconv.QuoteRune(c)) + return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) } integer = 10*integer + int64(c-'0') if integer > 922337203685477 { - return Decimal(0), fmt.Errorf("%w: overflow", errDecimal) + return Decimal(0), fmt.Errorf("%w: overflow", ErrDecimal) } } @@ -458,7 +517,7 @@ func ParseDecimal(s string) (Decimal, error) { for ; i < len(s); i++ { c = rune(s[i]) if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", errDecimal, strconv.QuoteRune(c)) + return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) } fraction = 10*fraction + int64(c-'0') fractionDigits++ @@ -467,7 +526,7 @@ func ParseDecimal(s string) (Decimal, error) { // Adjust the fraction part based on how many digits we parsed. switch fractionDigits { case 0: - return Decimal(0), fmt.Errorf("%w: missing digits after decimal point", errDecimal) + return Decimal(0), fmt.Errorf("%w: missing digits after decimal point", ErrDecimal) case 1: fraction *= 1000 case 2: @@ -476,12 +535,12 @@ func ParseDecimal(s string) (Decimal, error) { fraction *= 10 case 4: default: - return Decimal(0), fmt.Errorf("%w: too many digits after decimal point", errDecimal) + return Decimal(0), fmt.Errorf("%w: too many digits after decimal point", ErrDecimal) } // Check for overflow before we put the number together. if integer >= 922337203685477 && (fraction > 5808 || (!negative && fraction == 5808)) { - return Decimal(0), fmt.Errorf("%w: overflow", errDecimal) + return Decimal(0), fmt.Errorf("%w: overflow", ErrDecimal) } // Put the number together. @@ -496,12 +555,12 @@ func ParseDecimal(s string) (Decimal, error) { } } -func (a Decimal) equal(bi Value) bool { +func (a Decimal) Equal(bi Value) bool { b, ok := bi.(Decimal) return ok && a == b } -func (v Decimal) typeName() string { return "decimal" } +func (v Decimal) TypeName() string { return "decimal" } // Cedar produces a valid Cedar language representation of the Decimal, e.g. `decimal("12.34")`. func (v Decimal) Cedar() string { return `decimal("` + v.String() + `")` } @@ -573,6 +632,14 @@ func (v Decimal) ExplicitMarshalJSON() ([]byte, error) { } func (v Decimal) deepClone() Value { return v } +func ValueToDecimal(v Value) (Decimal, error) { + d, ok := v.(Decimal) + if !ok { + return 0, fmt.Errorf("%w: expected decimal, got %v", ErrType, v.TypeName()) + } + return d, nil +} + // An IPAddr is value that represents an IP address. It can be either IPv4 or IPv6. // The value can represent an individual address or a range of addresses. type IPAddr netip.Prefix @@ -581,22 +648,22 @@ type IPAddr netip.Prefix func ParseIPAddr(s string) (IPAddr, error) { // We disallow IPv4-mapped IPv6 addresses in dotted notation because Cedar does. if strings.Count(s, ":") >= 2 && strings.Count(s, ".") >= 2 { - return IPAddr{}, fmt.Errorf("%w: cannot parse IPv4 addresses embedded in IPv6 addresses", errIP) + return IPAddr{}, fmt.Errorf("%w: cannot parse IPv4 addresses embedded in IPv6 addresses", ErrIP) } else if net, err := netip.ParsePrefix(s); err == nil { return IPAddr(net), nil } else if addr, err := netip.ParseAddr(s); err == nil { return IPAddr(netip.PrefixFrom(addr, addr.BitLen())), nil } else { - return IPAddr{}, fmt.Errorf("%w: error parsing IP address %s", errIP, s) + return IPAddr{}, fmt.Errorf("%w: error parsing IP address %s", ErrIP, s) } } -func (a IPAddr) equal(bi Value) bool { +func (a IPAddr) Equal(bi Value) bool { b, ok := bi.(IPAddr) return ok && a == b } -func (v IPAddr) typeName() string { return "IP" } +func (v IPAddr) TypeName() string { return "IP" } // Cedar produces a valid Cedar language representation of the IPAddr, e.g. `ip("127.0.0.1")`. func (v IPAddr) Cedar() string { return `ip("` + v.String() + `")` } @@ -613,15 +680,15 @@ func (v IPAddr) Prefix() netip.Prefix { return netip.Prefix(v) } -func (v IPAddr) isIPv4() bool { +func (v IPAddr) IsIPv4() bool { return v.Addr().Is4() } -func (v IPAddr) isIPv6() bool { +func (v IPAddr) IsIPv6() bool { return v.Addr().Is6() } -func (v IPAddr) isLoopback() bool { +func (v IPAddr) IsLoopback() bool { // This comment is in the Cedar Rust implementation: // // Loopback addresses are "127.0.0.0/8" for IpV4 and "::1" for IpV6 @@ -640,7 +707,7 @@ func (v IPAddr) Addr() netip.Addr { return netip.Prefix(v).Addr() } -func (v IPAddr) isMulticast() bool { +func (v IPAddr) IsMulticast() bool { // This comment is in the Cedar Rust implementation: // // Multicast addresses are "224.0.0.0/4" for IpV4 and "ff00::/8" for @@ -654,7 +721,7 @@ func (v IPAddr) isMulticast() bool { // range `ip2/prefix2`, then `ip1` is in `ip2/prefix2` and `prefix1 >= // prefix2` var min_prefix_len int - if v.isIPv4() { + if v.IsIPv4() { min_prefix_len = 4 } else { min_prefix_len = 8 @@ -662,7 +729,7 @@ func (v IPAddr) isMulticast() bool { return v.Addr().IsMulticast() && v.Prefix().Bits() >= min_prefix_len } -func (c IPAddr) contains(o IPAddr) bool { +func (c IPAddr) Contains(o IPAddr) bool { return c.Prefix().Contains(o.Addr()) && c.Prefix().Bits() <= o.Prefix().Bits() } @@ -721,3 +788,11 @@ func (v IPAddr) ExplicitMarshalJSON() ([]byte, error) { // in this case, netip.Prefix does contain a pointer, but // the interface given is immutable, so it is safe to return func (v IPAddr) deepClone() Value { return v } + +func ValueToIP(v Value) (IPAddr, error) { + i, ok := v.(IPAddr) + if !ok { + return IPAddr{}, fmt.Errorf("%w: expected ipaddr, got %v", ErrType, v.TypeName()) + } + return i, nil +} diff --git a/value_test.go b/types/value_test.go similarity index 62% rename from value_test.go rename to types/value_test.go index 08eadf2b..9381e78e 100644 --- a/value_test.go +++ b/types/value_test.go @@ -1,48 +1,50 @@ -package cedar +package types import ( "fmt" "testing" + + "github.com/cedar-policy/cedar-go/testutil" ) func TestBool(t *testing.T) { t.Parallel() t.Run("roundTrip", func(t *testing.T) { t.Parallel() - v, err := valueToBool(Boolean(true)) - testutilOK(t, err) - testutilEquals(t, v, true) + v, err := ValueToBool(Boolean(true)) + testutil.OK(t, err) + testutil.Equals(t, v, true) }) t.Run("toBoolOnNonBool", func(t *testing.T) { t.Parallel() - v, err := valueToBool(Long(0)) - assertError(t, err, errType) - testutilEquals(t, v, false) + v, err := ValueToBool(Long(0)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, false) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() t1 := Boolean(true) t2 := Boolean(true) f := Boolean(false) zero := Long(0) - testutilFatalIf(t, !t1.equal(t1), "%v not equal to %v", t1, t1) - testutilFatalIf(t, !t1.equal(t2), "%v not equal to %v", t1, t2) - testutilFatalIf(t, t1.equal(f), "%v equal to %v", t1, f) - testutilFatalIf(t, f.equal(t1), "%v equal to %v", f, t1) - testutilFatalIf(t, f.equal(zero), "%v equal to %v", f, zero) + testutil.FatalIf(t, !t1.Equal(t1), "%v not Equal to %v", t1, t1) + testutil.FatalIf(t, !t1.Equal(t2), "%v not Equal to %v", t1, t2) + testutil.FatalIf(t, t1.Equal(f), "%v Equal to %v", t1, f) + testutil.FatalIf(t, f.Equal(t1), "%v Equal to %v", f, t1) + testutil.FatalIf(t, f.Equal(zero), "%v Equal to %v", f, zero) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, Boolean(true), "true") + AssertValueString(t, Boolean(true), "true") }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := Boolean(true).typeName() - testutilEquals(t, tn, "bool") + tn := Boolean(true).TypeName() + testutil.Equals(t, tn, "bool") }) } @@ -50,40 +52,40 @@ func TestLong(t *testing.T) { t.Parallel() t.Run("roundTrip", func(t *testing.T) { t.Parallel() - v, err := valueToLong(Long(42)) - testutilOK(t, err) - testutilEquals(t, v, 42) + v, err := ValueToLong(Long(42)) + testutil.OK(t, err) + testutil.Equals(t, v, 42) }) t.Run("toLongOnNonLong", func(t *testing.T) { t.Parallel() - v, err := valueToLong(Boolean(true)) - assertError(t, err, errType) - testutilEquals(t, v, 0) + v, err := ValueToLong(Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, 0) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() one := Long(1) one2 := Long(1) zero := Long(0) f := Boolean(false) - testutilFatalIf(t, !one.equal(one), "%v not equal to %v", one, one) - testutilFatalIf(t, !one.equal(one2), "%v not equal to %v", one, one2) - testutilFatalIf(t, one.equal(zero), "%v equal to %v", one, zero) - testutilFatalIf(t, zero.equal(one), "%v equal to %v", zero, one) - testutilFatalIf(t, zero.equal(f), "%v equal to %v", zero, f) + testutil.FatalIf(t, !one.Equal(one), "%v not Equal to %v", one, one) + testutil.FatalIf(t, !one.Equal(one2), "%v not Equal to %v", one, one2) + testutil.FatalIf(t, one.Equal(zero), "%v Equal to %v", one, zero) + testutil.FatalIf(t, zero.Equal(one), "%v Equal to %v", zero, one) + testutil.FatalIf(t, zero.Equal(f), "%v Equal to %v", zero, f) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, Long(1), "1") + AssertValueString(t, Long(1), "1") }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := Long(1).typeName() - testutilEquals(t, tn, "long") + tn := Long(1).TypeName() + testutil.Equals(t, tn, "long") }) } @@ -91,38 +93,38 @@ func TestString(t *testing.T) { t.Parallel() t.Run("roundTrip", func(t *testing.T) { t.Parallel() - v, err := valueToString(String("hello")) - testutilOK(t, err) - testutilEquals(t, v, "hello") + v, err := ValueToString(String("hello")) + testutil.OK(t, err) + testutil.Equals(t, v, "hello") }) t.Run("toStringOnNonString", func(t *testing.T) { t.Parallel() - v, err := valueToString(Boolean(true)) - assertError(t, err, errType) - testutilEquals(t, v, "") + v, err := ValueToString(Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, "") }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() hello := String("hello") hello2 := String("hello") goodbye := String("goodbye") - testutilFatalIf(t, !hello.equal(hello), "%v not equal to %v", hello, hello) - testutilFatalIf(t, !hello.equal(hello2), "%v not equal to %v", hello, hello2) - testutilFatalIf(t, hello.equal(goodbye), "%v equal to %v", hello, goodbye) + testutil.FatalIf(t, !hello.Equal(hello), "%v not Equal to %v", hello, hello) + testutil.FatalIf(t, !hello.Equal(hello2), "%v not Equal to %v", hello, hello2) + testutil.FatalIf(t, hello.Equal(goodbye), "%v Equal to %v", hello, goodbye) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, String("hello"), `hello`) - assertValueString(t, String("hello\ngoodbye"), "hello\ngoodbye") + AssertValueString(t, String("hello"), `hello`) + AssertValueString(t, String("hello\ngoodbye"), "hello\ngoodbye") }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := String("hello").typeName() - testutilEquals(t, tn, "string") + tn := String("hello").TypeName() + testutil.Equals(t, tn, "string") }) } @@ -131,20 +133,20 @@ func TestSet(t *testing.T) { t.Run("roundTrip", func(t *testing.T) { t.Parallel() v := Set{Boolean(true), Long(1)} - slice, err := valueToSet(v) - testutilOK(t, err) + slice, err := ValueToSet(v) + testutil.OK(t, err) v2 := slice - testutilFatalIf(t, !v.equal(v2), "got %v want %v", v, v2) + testutil.FatalIf(t, !v.Equal(v2), "got %v want %v", v, v2) }) t.Run("ToSetOnNonSet", func(t *testing.T) { t.Parallel() - v, err := valueToSet(Boolean(true)) - assertError(t, err, errType) - testutilEquals(t, v, nil) + v, err := ValueToSet(Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, nil) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() empty := Set{} empty2 := Set{} @@ -162,34 +164,34 @@ func TestSet(t *testing.T) { Long(3), Long(2), Long(2), Long(1), } - testutilFatalIf(t, !empty.Equals(empty), "%v not equal to %v", empty, empty) - testutilFatalIf(t, !empty.Equals(empty2), "%v not equal to %v", empty, empty2) - testutilFatalIf(t, !oneTrue.Equals(oneTrue), "%v not equal to %v", oneTrue, oneTrue) - testutilFatalIf(t, !oneTrue.Equals(oneTrue2), "%v not equal to %v", oneTrue, oneTrue2) - testutilFatalIf(t, !nestedOnce.Equals(nestedOnce), "%v not equal to %v", nestedOnce, nestedOnce) - testutilFatalIf(t, !nestedOnce.Equals(nestedOnce2), "%v not equal to %v", nestedOnce, nestedOnce2) - testutilFatalIf(t, !nestedTwice.Equals(nestedTwice), "%v not equal to %v", nestedTwice, nestedTwice) - testutilFatalIf(t, !nestedTwice.Equals(nestedTwice2), "%v not equal to %v", nestedTwice, nestedTwice2) - testutilFatalIf(t, !oneTwoThree.Equals(threeTwoTwoOne), "%v not equal to %v", oneTwoThree, threeTwoTwoOne) + testutil.FatalIf(t, !empty.Equals(empty), "%v not Equal to %v", empty, empty) + testutil.FatalIf(t, !empty.Equals(empty2), "%v not Equal to %v", empty, empty2) + testutil.FatalIf(t, !oneTrue.Equals(oneTrue), "%v not Equal to %v", oneTrue, oneTrue) + testutil.FatalIf(t, !oneTrue.Equals(oneTrue2), "%v not Equal to %v", oneTrue, oneTrue2) + testutil.FatalIf(t, !nestedOnce.Equals(nestedOnce), "%v not Equal to %v", nestedOnce, nestedOnce) + testutil.FatalIf(t, !nestedOnce.Equals(nestedOnce2), "%v not Equal to %v", nestedOnce, nestedOnce2) + testutil.FatalIf(t, !nestedTwice.Equals(nestedTwice), "%v not Equal to %v", nestedTwice, nestedTwice) + testutil.FatalIf(t, !nestedTwice.Equals(nestedTwice2), "%v not Equal to %v", nestedTwice, nestedTwice2) + testutil.FatalIf(t, !oneTwoThree.Equals(threeTwoTwoOne), "%v not Equal to %v", oneTwoThree, threeTwoTwoOne) - testutilFatalIf(t, empty.Equals(oneFalse), "%v equal to %v", empty, oneFalse) - testutilFatalIf(t, oneTrue.Equals(oneFalse), "%v equal to %v", oneTrue, oneFalse) - testutilFatalIf(t, nestedOnce.Equals(nestedTwice), "%v equal to %v", nestedOnce, nestedTwice) + testutil.FatalIf(t, empty.Equals(oneFalse), "%v Equal to %v", empty, oneFalse) + testutil.FatalIf(t, oneTrue.Equals(oneFalse), "%v Equal to %v", oneTrue, oneFalse) + testutil.FatalIf(t, nestedOnce.Equals(nestedTwice), "%v Equal to %v", nestedOnce, nestedTwice) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, Set{}, "[]") - assertValueString( + AssertValueString(t, Set{}, "[]") + AssertValueString( t, Set{Boolean(true), Long(1)}, "[true,1]") }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := Set{}.typeName() - testutilEquals(t, tn, "set") + tn := Set{}.TypeName() + testutil.Equals(t, tn, "set") }) } @@ -201,20 +203,20 @@ func TestRecord(t *testing.T) { "foo": Boolean(true), "bar": Long(1), } - map_, err := valueToRecord(v) - testutilOK(t, err) + map_, err := ValueToRecord(v) + testutil.OK(t, err) v2 := map_ - testutilFatalIf(t, !v.equal(v2), "got %v want %v", v, v2) + testutil.FatalIf(t, !v.Equal(v2), "got %v want %v", v, v2) }) t.Run("toRecordOnNonRecord", func(t *testing.T) { t.Parallel() - v, err := valueToRecord(String("hello")) - assertError(t, err, errType) - testutilEquals(t, v, nil) + v, err := ValueToRecord(String("hello")) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, nil) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() empty := Record{} empty2 := Record{} @@ -245,28 +247,28 @@ func TestRecord(t *testing.T) { "nest": twoElems, } - testutilFatalIf(t, !empty.Equals(empty), "%v not equal to %v", empty, empty) - testutilFatalIf(t, !empty.Equals(empty2), "%v not equal to %v", empty, empty2) + testutil.FatalIf(t, !empty.Equals(empty), "%v not Equal to %v", empty, empty) + testutil.FatalIf(t, !empty.Equals(empty2), "%v not Equal to %v", empty, empty2) - testutilFatalIf(t, !twoElems.Equals(twoElems), "%v not equal to %v", twoElems, twoElems) - testutilFatalIf(t, !twoElems.Equals(twoElems2), "%v not equal to %v", twoElems, twoElems2) + testutil.FatalIf(t, !twoElems.Equals(twoElems), "%v not Equal to %v", twoElems, twoElems) + testutil.FatalIf(t, !twoElems.Equals(twoElems2), "%v not Equal to %v", twoElems, twoElems2) - testutilFatalIf(t, !nested.Equals(nested), "%v not equal to %v", nested, nested) - testutilFatalIf(t, !nested.Equals(nested2), "%v not equal to %v", nested, nested2) + testutil.FatalIf(t, !nested.Equals(nested), "%v not Equal to %v", nested, nested) + testutil.FatalIf(t, !nested.Equals(nested2), "%v not Equal to %v", nested, nested2) - testutilFatalIf(t, nested.Equals(twoElems), "%v equal to %v", nested, twoElems) - testutilFatalIf(t, twoElems.Equals(differentValues), "%v equal to %v", twoElems, differentValues) - testutilFatalIf(t, twoElems.Equals(differentKeys), "%v equal to %v", twoElems, differentKeys) + testutil.FatalIf(t, nested.Equals(twoElems), "%v Equal to %v", nested, twoElems) + testutil.FatalIf(t, twoElems.Equals(differentValues), "%v Equal to %v", twoElems, differentValues) + testutil.FatalIf(t, twoElems.Equals(differentKeys), "%v Equal to %v", twoElems, differentKeys) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, Record{}, "{}") - assertValueString( + AssertValueString(t, Record{}, "{}") + AssertValueString( t, Record{"foo": Boolean(true)}, `{"foo":true}`) - assertValueString( + AssertValueString( t, Record{ "foo": Boolean(true), @@ -275,10 +277,10 @@ func TestRecord(t *testing.T) { `{"bar":"blah","foo":true}`) }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := Record{}.typeName() - testutilEquals(t, tn, "record") + tn := Record{}.TypeName() + testutil.Equals(t, tn, "record") }) } @@ -287,37 +289,37 @@ func TestEntity(t *testing.T) { t.Run("roundTrip", func(t *testing.T) { t.Parallel() want := EntityUID{Type: "User", ID: "bananas"} - v, err := valueToEntity(want) - testutilOK(t, err) - testutilEquals(t, v, want) + v, err := ValueToEntity(want) + testutil.OK(t, err) + testutil.Equals(t, v, want) }) t.Run("ToEntityOnNonEntity", func(t *testing.T) { t.Parallel() - v, err := valueToEntity(String("hello")) - assertError(t, err, errType) - testutilEquals(t, v, EntityUID{}) + v, err := ValueToEntity(String("hello")) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, EntityUID{}) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() twoElems := EntityUID{"type", "id"} twoElems2 := EntityUID{"type", "id"} differentValues := EntityUID{"asdf", "vfds"} - testutilFatalIf(t, !twoElems.equal(twoElems), "%v not equal to %v", twoElems, twoElems) - testutilFatalIf(t, !twoElems.equal(twoElems2), "%v not equal to %v", twoElems, twoElems2) - testutilFatalIf(t, twoElems.equal(differentValues), "%v equal to %v", twoElems, differentValues) + testutil.FatalIf(t, !twoElems.Equal(twoElems), "%v not Equal to %v", twoElems, twoElems) + testutil.FatalIf(t, !twoElems.Equal(twoElems2), "%v not Equal to %v", twoElems, twoElems2) + testutil.FatalIf(t, twoElems.Equal(differentValues), "%v Equal to %v", twoElems, differentValues) }) t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, EntityUID{Type: "type", ID: "id"}, `type::"id"`) - assertValueString(t, EntityUID{Type: "namespace::type", ID: "id"}, `namespace::type::"id"`) + AssertValueString(t, EntityUID{Type: "type", ID: "id"}, `type::"id"`) + AssertValueString(t, EntityUID{Type: "namespace::type", ID: "id"}, `namespace::type::"id"`) }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := EntityUID{"T", "id"}.typeName() - testutilEquals(t, tn, "(entity of type `T`)") + tn := EntityUID{"T", "id"}.TypeName() + testutil.Equals(t, tn, "(entity of type `T`)") }) } @@ -378,8 +380,8 @@ func TestDecimal(t *testing.T) { t.Run(fmt.Sprintf("%s->%s", tt.in, tt.out), func(t *testing.T) { t.Parallel() d, err := ParseDecimal(tt.in) - testutilOK(t, err) - testutilEquals(t, d.String(), tt.out) + testutil.OK(t, err) + testutil.Equals(t, d.String(), tt.out) }) } } @@ -414,8 +416,8 @@ func TestDecimal(t *testing.T) { t.Run(fmt.Sprintf("%s->%s", tt.in, tt.errStr), func(t *testing.T) { t.Parallel() _, err := ParseDecimal(tt.in) - assertError(t, err, errDecimal) - testutilEquals(t, err.Error(), tt.errStr) + testutil.AssertError(t, err, ErrDecimal) + testutil.Equals(t, err.Error(), tt.errStr) }) } } @@ -423,36 +425,36 @@ func TestDecimal(t *testing.T) { t.Run("roundTrip", func(t *testing.T) { t.Parallel() dv, err := ParseDecimal("1.20") - testutilOK(t, err) - v, err := valueToDecimal(dv) - testutilOK(t, err) - testutilFatalIf(t, !v.equal(dv), "got %v want %v", v, dv) + testutil.OK(t, err) + v, err := ValueToDecimal(dv) + testutil.OK(t, err) + testutil.FatalIf(t, !v.Equal(dv), "got %v want %v", v, dv) }) t.Run("toDecimalOnNonDecimal", func(t *testing.T) { t.Parallel() - v, err := valueToDecimal(Boolean(true)) - assertError(t, err, errType) - testutilEquals(t, v, 0) + v, err := ValueToDecimal(Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, 0) }) - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() one := Decimal(10000) one2 := Decimal(10000) zero := Decimal(0) f := Boolean(false) - testutilFatalIf(t, !one.equal(one), "%v not equal to %v", one, one) - testutilFatalIf(t, !one.equal(one2), "%v not equal to %v", one, one2) - testutilFatalIf(t, one.equal(zero), "%v equal to %v", one, zero) - testutilFatalIf(t, zero.equal(one), "%v equal to %v", zero, one) - testutilFatalIf(t, zero.equal(f), "%v equal to %v", zero, f) + testutil.FatalIf(t, !one.Equal(one), "%v not Equal to %v", one, one) + testutil.FatalIf(t, !one.Equal(one2), "%v not Equal to %v", one, one2) + testutil.FatalIf(t, one.Equal(zero), "%v Equal to %v", one, zero) + testutil.FatalIf(t, zero.Equal(one), "%v Equal to %v", zero, one) + testutil.FatalIf(t, zero.Equal(f), "%v Equal to %v", zero, f) }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := Decimal(0).typeName() - testutilEquals(t, tn, "decimal") + tn := Decimal(0).TypeName() + testutil.Equals(t, tn, "decimal") }) } @@ -500,10 +502,10 @@ func TestIP(t *testing.T) { t.Parallel() i, err := ParseIPAddr(tt.in) if tt.parses { - testutilOK(t, err) - testutilEquals(t, i.String(), tt.out) + testutil.OK(t, err) + testutil.Equals(t, i.String(), tt.out) } else { - testutilError(t, err) + testutil.Error(t, err) } }) } @@ -511,9 +513,9 @@ func TestIP(t *testing.T) { t.Run("toIPOnNonIP", func(t *testing.T) { t.Parallel() - v, err := valueToIP(Boolean(true)) - assertError(t, err, errType) - testutilEquals(t, v, IPAddr{}) + v, err := ValueToIP(Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, IPAddr{}) }) t.Run("Equal", func(t *testing.T) { @@ -547,25 +549,25 @@ func TestIP(t *testing.T) { } for _, tt := range tests { tt := tt - t.Run(fmt.Sprintf("ip(%v).equal(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) { + t.Run(fmt.Sprintf("ip(%v).Equal(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() lhs, err := ParseIPAddr(tt.lhs) - testutilOK(t, err) + testutil.OK(t, err) rhs, err := ParseIPAddr(tt.rhs) - testutilOK(t, err) - equal := lhs.equal(rhs) + testutil.OK(t, err) + equal := lhs.Equal(rhs) if equal != tt.equal { - t.Fatalf("expected ip(%v).equal(ip(%v)) to be %v instead of %v", tt.lhs, tt.rhs, tt.equal, equal) + t.Fatalf("expected ip(%v).Equal(ip(%v)) to be %v instead of %v", tt.lhs, tt.rhs, tt.equal, equal) } if equal { - testutilFatalIf( + testutil.FatalIf( t, - !lhs.contains(rhs), - "ip(%v) and ip(%v) compare equal but !ip(%v).contains(ip(%v))", tt.lhs, tt.rhs, tt.lhs, tt.rhs) - testutilFatalIf( + !lhs.Contains(rhs), + "ip(%v) and ip(%v) compare Equal but !ip(%v).contains(ip(%v))", tt.lhs, tt.rhs, tt.lhs, tt.rhs) + testutil.FatalIf( t, - !rhs.contains(lhs), - "ip(%v) and ip(%v) compare equal but !ip(%v).contains(ip(%v))", tt.rhs, tt.lhs, tt.rhs, tt.lhs) + !rhs.Contains(lhs), + "ip(%v) and ip(%v) compare Equal but !ip(%v).contains(ip(%v))", tt.rhs, tt.lhs, tt.rhs, tt.lhs) } }) } @@ -598,12 +600,12 @@ func TestIP(t *testing.T) { t.Run(fmt.Sprintf("ip(%v).isIPv{4,6}()", tt.val), func(t *testing.T) { t.Parallel() val, err := ParseIPAddr(tt.val) - testutilOK(t, err) - isIPv4 := val.isIPv4() + testutil.OK(t, err) + isIPv4 := val.IsIPv4() if isIPv4 != tt.isIPv4 { t.Fatalf("expected ip(%v).isIPv4() to be %v instead of %v", tt.val, tt.isIPv4, isIPv4) } - isIPv6 := val.isIPv6() + isIPv6 := val.IsIPv6() if isIPv6 != tt.isIPv6 { t.Fatalf("expected ip(%v).isIPv6() to be %v instead of %v", tt.val, tt.isIPv6, isIPv6) } @@ -647,8 +649,8 @@ func TestIP(t *testing.T) { t.Run(fmt.Sprintf("ip(%v).isLoopback()", tt.val), func(t *testing.T) { t.Parallel() val, err := ParseIPAddr(tt.val) - testutilOK(t, err) - isLoopback := val.isLoopback() + testutil.OK(t, err) + isLoopback := val.IsLoopback() if isLoopback != tt.isLoopback { t.Fatalf("expected ip(%v).isLoopback() to be %v instead of %v", tt.val, tt.isLoopback, isLoopback) } @@ -681,8 +683,8 @@ func TestIP(t *testing.T) { t.Run(fmt.Sprintf("ip(%v).isMulticast()", tt.val), func(t *testing.T) { t.Parallel() val, err := ParseIPAddr(tt.val) - testutilOK(t, err) - isMulticast := val.isMulticast() + testutil.OK(t, err) + isMulticast := val.IsMulticast() if isMulticast != tt.isMulticast { t.Fatalf("expected ip(%v).isMulticast() to be %v instead of %v", tt.val, tt.isMulticast, isMulticast) } @@ -714,10 +716,10 @@ func TestIP(t *testing.T) { t.Run(fmt.Sprintf("ip(%v).contains(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() lhs, err := ParseIPAddr(tt.lhs) - testutilOK(t, err) + testutil.OK(t, err) rhs, err := ParseIPAddr(tt.rhs) - testutilOK(t, err) - contains := lhs.contains(rhs) + testutil.OK(t, err) + contains := lhs.Contains(rhs) if contains != tt.contains { t.Fatalf("expected ip(%v).contains(ip(%v)) to be %v instead of %v", tt.lhs, tt.rhs, tt.contains, contains) } @@ -725,10 +727,10 @@ func TestIP(t *testing.T) { } }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - tn := IPAddr{}.typeName() - testutilEquals(t, tn, "IP") + tn := IPAddr{}.TypeName() + testutil.Equals(t, tn, "IP") }) } @@ -738,140 +740,140 @@ func TestDeepClone(t *testing.T) { t.Parallel() a := Boolean(true) b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a = Boolean(false) - testutilEquals(t, a, Boolean(false)) - testutilEquals(t, b, Value(Boolean(true))) + testutil.Equals(t, a, Boolean(false)) + testutil.Equals(t, b, Value(Boolean(true))) }) t.Run("Long", func(t *testing.T) { t.Parallel() a := Long(42) b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a = Long(43) - testutilEquals(t, a, Long(43)) - testutilEquals(t, b, Value(Long(42))) + testutil.Equals(t, a, Long(43)) + testutil.Equals(t, b, Value(Long(42))) }) t.Run("String", func(t *testing.T) { t.Parallel() a := String("cedar") b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a = String("policy") - testutilEquals(t, a, String("policy")) - testutilEquals(t, b, Value(String("cedar"))) + testutil.Equals(t, a, String("policy")) + testutil.Equals(t, b, Value(String("cedar"))) }) t.Run("EntityUID", func(t *testing.T) { t.Parallel() a := NewEntityUID("Action", "test") b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a.ID = "bananas" - testutilEquals(t, a, NewEntityUID("Action", "bananas")) - testutilEquals(t, b, Value(NewEntityUID("Action", "test"))) + testutil.Equals(t, a, NewEntityUID("Action", "bananas")) + testutil.Equals(t, b, Value(NewEntityUID("Action", "test"))) }) t.Run("Set", func(t *testing.T) { t.Parallel() a := Set{Long(42)} b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a[0] = String("bananas") - testutilEquals(t, a, Set{String("bananas")}) - testutilEquals(t, b, Value(Set{Long(42)})) + testutil.Equals(t, a, Set{String("bananas")}) + testutil.Equals(t, b, Value(Set{Long(42)})) }) t.Run("NilSet", func(t *testing.T) { t.Parallel() var a Set b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) }) t.Run("Record", func(t *testing.T) { t.Parallel() a := Record{"key": Long(42)} b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a["key"] = String("bananas") - testutilEquals(t, a, Record{"key": String("bananas")}) - testutilEquals(t, b, Value(Record{"key": Long(42)})) + testutil.Equals(t, a, Record{"key": String("bananas")}) + testutil.Equals(t, b, Value(Record{"key": Long(42)})) }) t.Run("NilRecord", func(t *testing.T) { t.Parallel() var a Record b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) }) t.Run("Decimal", func(t *testing.T) { t.Parallel() a := Decimal(42) b := a.deepClone() - testutilEquals(t, Value(a), b) + testutil.Equals(t, Value(a), b) a = Decimal(43) - testutilEquals(t, a, Decimal(43)) - testutilEquals(t, b, Value(Decimal(42))) + testutil.Equals(t, a, Decimal(43)) + testutil.Equals(t, b, Value(Decimal(42))) }) t.Run("IPAddr", func(t *testing.T) { t.Parallel() a := mustIPValue("127.0.0.42") b := a.deepClone() - testutilEquals(t, a.Cedar(), b.Cedar()) + testutil.Equals(t, a.Cedar(), b.Cedar()) a = mustIPValue("127.0.0.43") - testutilEquals(t, a.Cedar(), mustIPValue("127.0.0.43").Cedar()) - testutilEquals(t, b.Cedar(), mustIPValue("127.0.0.42").Cedar()) + testutil.Equals(t, a.Cedar(), mustIPValue("127.0.0.43").Cedar()) + testutil.Equals(t, b.Cedar(), mustIPValue("127.0.0.42").Cedar()) }) } func TestPath(t *testing.T) { t.Parallel() - t.Run("equal", func(t *testing.T) { + t.Run("Equal", func(t *testing.T) { t.Parallel() - a := path("X") - b := path("X") - c := path("Y") - testutilEquals(t, a.equal(b), true) - testutilEquals(t, b.equal(a), true) - testutilEquals(t, a.equal(c), false) - testutilEquals(t, c.equal(a), false) + a := Path("X") + b := Path("X") + c := Path("Y") + testutil.Equals(t, a.Equal(b), true) + testutil.Equals(t, b.Equal(a), true) + testutil.Equals(t, a.Equal(c), false) + testutil.Equals(t, c.Equal(a), false) }) - t.Run("typeName", func(t *testing.T) { + t.Run("TypeName", func(t *testing.T) { t.Parallel() - a := path("X") - testutilEquals(t, a.typeName(), "(path of type `X`)") + a := Path("X") + testutil.Equals(t, a.TypeName(), "(Path of type `X`)") }) t.Run("String", func(t *testing.T) { t.Parallel() - a := path("X") - testutilEquals(t, a.String(), "X") + a := Path("X") + testutil.Equals(t, a.String(), "X") }) t.Run("Cedar", func(t *testing.T) { t.Parallel() - a := path("X") - testutilEquals(t, a.Cedar(), "X") + a := Path("X") + testutil.Equals(t, a.Cedar(), "X") }) t.Run("ExplicitMarshalJSON", func(t *testing.T) { t.Parallel() - a := path("X") + a := Path("X") v, err := a.ExplicitMarshalJSON() - testutilOK(t, err) - testutilEquals(t, string(v), `"X"`) + testutil.OK(t, err) + testutil.Equals(t, string(v), `"X"`) }) t.Run("deepClone", func(t *testing.T) { t.Parallel() - a := path("X") + a := Path("X") b := a.deepClone() - c, ok := b.(path) - testutilEquals(t, ok, true) - testutilEquals(t, c, a) + c, ok := b.(Path) + testutil.Equals(t, ok, true) + testutil.Equals(t, c, a) }) t.Run("pathFromSlice", func(t *testing.T) { t.Parallel() - a := pathFromSlice([]string{"X", "Y"}) - testutilEquals(t, a, path("X::Y")) + a := PathFromSlice([]string{"X", "Y"}) + testutil.Equals(t, a, Path("X::Y")) }) } diff --git a/x/exp/ast/ast_test.go b/x/exp/ast/ast_test.go index 6c4aaa4f..9e513e62 100644 --- a/x/exp/ast/ast_test.go +++ b/x/exp/ast/ast_test.go @@ -3,8 +3,8 @@ package ast_test import ( "testing" + "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/ast" - "github.com/cedar-policy/cedar-go/x/exp/types" ) // These tests mostly verify that policy ASTs compile diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 38f2b7e4..7e261cea 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -1,9 +1,9 @@ package ast -type opType uint8 +type nodeType uint8 const ( - nodeTypeAccess opType = iota + nodeTypeAccess nodeType = iota nodeTypeAdd nodeTypeAnd nodeTypeAnnotation @@ -41,7 +41,7 @@ const ( ) type Node struct { - op opType + nodeType nodeType // TODO: Should we just have `value any`? args []Node value any diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index c3b89b21..d8b0c177 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -1,9 +1,9 @@ package ast -import "github.com/cedar-policy/cedar-go/x/exp/types" +import "github.com/cedar-policy/cedar-go/types" -// ____ _ -// / ___|___ _ __ ___ _ __ __ _ _ __(_)___ ___ _ __ +// ____ _ +// / ___|___ _ __ ___ _ __ __ _ _ __(_)___ ___ _ __ // | | / _ \| '_ ` _ \| '_ \ / _` | '__| / __|/ _ \| '_ \ // | |__| (_) | | | | | | |_) | (_| | | | \__ \ (_) | | | | // \____\___/|_| |_| |_| .__/ \__,_|_| |_|___/\___/|_| |_| @@ -152,6 +152,6 @@ func (lhs Node) IsInRange(rhs Node) Node { return newOpNode(nodeTypeIsInRange, lhs, rhs) } -func newOpNode(op opType, args ...Node) Node { - return Node{op: op, args: args} +func newOpNode(op nodeType, args ...Node) Node { + return Node{nodeType: op, args: args} } diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index 9a754e2c..05914236 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -1,6 +1,6 @@ package ast -import "github.com/cedar-policy/cedar-go/x/exp/types" +import "github.com/cedar-policy/cedar-go/types" func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { p.principal = Principal().Equals(Entity(entity)) @@ -16,7 +16,7 @@ func (p *Policy) PrincipalIn(entities ...types.EntityUID) *Policy { return p } -func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { +func (p *Policy) PrincipalIs(entityType string) *Policy { p.principal = Principal().Is(EntityType(entityType)) return p } @@ -49,7 +49,7 @@ func (p *Policy) ResourceIn(entities ...types.EntityUID) *Policy { return p } -func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { +func (p *Policy) ResourceIs(entityType string) *Policy { p.principal = Resource().Is(EntityType(entityType)) return p } diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index d5d46f26..3d1f0f66 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -3,7 +3,7 @@ package ast import ( "fmt" - "github.com/cedar-policy/cedar-go/x/exp/types" + "github.com/cedar-policy/cedar-go/types" ) func Boolean(b types.Boolean) Node { @@ -76,7 +76,7 @@ func RecordNodes(nodes map[string]Node) Node { return newValueNode(nodeTypeRecord, nodes) } -func EntityType(e types.EntityType) Node { +func EntityType(e string) Node { return newValueNode(nodeTypeEntityType, e) } @@ -88,12 +88,12 @@ func Decimal(d types.Decimal) Node { return newValueNode(nodeTypeEntity, d) } -func IpAddr(i types.IpAddr) Node { +func IPAddr(i types.IPAddr) Node { return newValueNode(nodeTypeIpAddr, i) } -func newValueNode(op opType, v any) Node { - return Node{op: op, value: v} +func newValueNode(nodeType nodeType, v any) Node { + return Node{nodeType: nodeType, value: v} } func valueToNode(v types.Value) Node { @@ -112,8 +112,8 @@ func valueToNode(v types.Value) Node { return Entity(x) case types.Decimal: return Decimal(x) - case types.IpAddr: - return IpAddr(x) + case types.IPAddr: + return IPAddr(x) default: panic(fmt.Sprintf("unexpected value type: %T(%v)", v, v)) } diff --git a/x/exp/ast/variable.go b/x/exp/ast/variable.go index 5ef83687..8a7cb662 100644 --- a/x/exp/ast/variable.go +++ b/x/exp/ast/variable.go @@ -5,7 +5,7 @@ func Principal() Node { } func Action() Node { - return newPrincipalNode() + return newActionNode() } func Resource() Node { diff --git a/x/exp/types/types.go b/x/exp/types/types.go deleted file mode 100644 index bbf6c3aa..00000000 --- a/x/exp/types/types.go +++ /dev/null @@ -1,44 +0,0 @@ -package types - -import "net" - -type Value interface { - isValue() -} - -type Boolean bool - -func (Boolean) isValue() {} - -type String string - -func (String) isValue() {} - -type Long int64 - -func (Long) isValue() {} - -type Set []Value - -func (Set) isValue() {} - -type Record map[string]Value - -func (Record) isValue() {} - -type EntityType string - -type EntityUID struct { - Type string - ID string -} - -func (EntityUID) isValue() {} - -type Decimal []float64 - -func (Decimal) isValue() {} - -type IpAddr net.IPAddr - -func (IpAddr) isValue() {} From 1e62dcf81c0c7cd2a5dc6fed5e39efbdf8e6d0c5 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 30 Jul 2024 10:17:06 -0700 Subject: [PATCH 004/216] cedar-go: make it lint Signed-off-by: philhassey --- cedar_test.go | 362 +++++++++++++++++++++--------------------- eval_test.go | 34 ++-- x/exp/ast/ast_test.go | 6 +- x/exp/ast/scope.go | 6 +- 4 files changed, 204 insertions(+), 204 deletions(-) diff --git a/cedar_test.go b/cedar_test.go index 774e4233..0a0a61ee 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -65,9 +65,9 @@ func TestIsAuthorized(t *testing.T) { Name: "simple-permit", Policy: `permit(principal,action,resource);`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -76,9 +76,9 @@ func TestIsAuthorized(t *testing.T) { Name: "simple-forbid", Policy: `forbid(principal,action,resource);`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 0, @@ -87,9 +87,9 @@ func TestIsAuthorized(t *testing.T) { Name: "no-permit", Policy: `permit(principal,action,resource in asdf::"1234");`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 0, @@ -98,9 +98,9 @@ func TestIsAuthorized(t *testing.T) { Name: "error-in-policy", Policy: `permit(principal,action,resource) when { resource in "foo" };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -111,9 +111,9 @@ func TestIsAuthorized(t *testing.T) { permit(principal,action,resource); `, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 1, @@ -122,9 +122,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-requires-context-success", Policy: `permit(principal,action,resource) when { context.x == 42 };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{"x": types.Long(42)}, Want: true, DiagErr: 0, @@ -133,9 +133,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-requires-context-fail", Policy: `permit(principal,action,resource) when { context.x == 42 };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{"x": types.Long(43)}, Want: false, DiagErr: 0, @@ -145,13 +145,13 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal.x == 42 };`, Entities: entitiesFromSlice([]Entity{ { - UID: types.EntityUID{"coder", "cuzco"}, + UID: types.NewEntityUID("coder", "cuzco"), Attributes: types.Record{"x": types.Long(42)}, }, }), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -161,13 +161,13 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal.x == 42 };`, Entities: entitiesFromSlice([]Entity{ { - UID: types.EntityUID{"coder", "cuzco"}, + UID: types.NewEntityUID("coder", "cuzco"), Attributes: types.Record{"x": types.Long(43)}, }, }), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 0, @@ -177,13 +177,13 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal in parent::"bob" };`, Entities: entitiesFromSlice([]Entity{ { - UID: types.EntityUID{"coder", "cuzco"}, - Parents: []types.EntityUID{{"parent", "bob"}}, + UID: types.NewEntityUID("coder", "cuzco"), + Parents: []types.EntityUID{types.NewEntityUID("parent", "bob")}, }, }), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -192,9 +192,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-principal-equals", Policy: `permit(principal == coder::"cuzco",action,resource);`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -204,13 +204,13 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal in team::"osiris",action,resource);`, Entities: entitiesFromSlice([]Entity{ { - UID: types.EntityUID{"coder", "cuzco"}, - Parents: []types.EntityUID{{"team", "osiris"}}, + UID: types.NewEntityUID("coder", "cuzco"), + Parents: []types.EntityUID{types.NewEntityUID("team", "osiris")}, }, }), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -219,9 +219,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-action-equals", Policy: `permit(principal,action == table::"drop",resource);`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -231,13 +231,13 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action in scary::"stuff",resource);`, Entities: entitiesFromSlice([]Entity{ { - UID: types.EntityUID{"table", "drop"}, - Parents: []types.EntityUID{{"scary", "stuff"}}, + UID: types.NewEntityUID("table", "drop"), + Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, }, }), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -247,13 +247,13 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action in [scary::"stuff"],resource);`, Entities: entitiesFromSlice([]Entity{ { - UID: types.EntityUID{"table", "drop"}, - Parents: []types.EntityUID{{"scary", "stuff"}}, + UID: types.NewEntityUID("table", "drop"), + Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, }, }), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -262,9 +262,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-resource-equals", Policy: `permit(principal,action,resource == table::"whatever");`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -273,9 +273,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-unless", Policy: `permit(principal,action,resource) unless { false };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -284,9 +284,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-if", Policy: `permit(principal,action,resource) when { (if true then true else true) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -295,9 +295,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-or", Policy: `permit(principal,action,resource) when { (true || false) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -306,9 +306,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-and", Policy: `permit(principal,action,resource) when { (true && true) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -317,9 +317,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-relations", Policy: `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -328,9 +328,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-relations-in", Policy: `permit(principal,action,resource) when { principal in principal };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -340,13 +340,13 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { principal has name };`, Entities: entitiesFromSlice([]Entity{ { - UID: types.EntityUID{"coder", "cuzco"}, + UID: types.NewEntityUID("coder", "cuzco"), Attributes: types.Record{"name": types.String("bob")}, }, }), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -355,9 +355,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-add-sub", Policy: `permit(principal,action,resource) when { 40+3-1==42 };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -366,9 +366,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-mul", Policy: `permit(principal,action,resource) when { 6*7==42 };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -377,9 +377,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-negate", Policy: `permit(principal,action,resource) when { -42==-42 };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -388,9 +388,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-not", Policy: `permit(principal,action,resource) when { !(1+1==42) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -399,9 +399,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -410,9 +410,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-record", Policy: `permit(principal,action,resource) when { {name:"bob"} has name };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -421,9 +421,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-action", Policy: `permit(principal,action,resource) when { action in action };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -432,9 +432,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-contains-ok", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -443,9 +443,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-contains-error", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -454,9 +454,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAll-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -465,9 +465,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAll-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -476,9 +476,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAny-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -487,9 +487,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-set-containsAny-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -498,9 +498,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-record-attr", Policy: `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -509,9 +509,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-unknown-method", Policy: `permit(principal,action,resource) when { [1,2,3].shuffle() };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -520,9 +520,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-like", Policy: `permit(principal,action,resource) when { "bananas" like "*nan*" };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -531,9 +531,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-unknown-ext-fun", Policy: `permit(principal,action,resource) when { fooBar("10") };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -546,9 +546,9 @@ func TestIsAuthorized(t *testing.T) { decimal("10.0").greaterThan(decimal("9.0")) && decimal("10.0").greaterThanOrEqual(decimal("9.0")) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -557,9 +557,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-decimal-fun-wrong-arity", Policy: `permit(principal,action,resource) when { decimal(1, 2) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -573,9 +573,9 @@ func TestIsAuthorized(t *testing.T) { ip("224.1.2.3").isMulticast() && ip("127.0.0.1").isInRange(ip("127.0.0.0/16"))};`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -584,9 +584,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-ip-fun-wrong-arity", Policy: `permit(principal,action,resource) when { ip() };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -595,9 +595,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isIpv4-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -606,9 +606,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isIpv6-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -617,9 +617,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isLoopback-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -628,9 +628,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isMulticast-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -639,9 +639,9 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-isInRange-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"coder", "cuzco"}, - Action: types.EntityUID{"table", "drop"}, - Resource: types.EntityUID{"table", "whatever"}, + Principal: types.NewEntityUID("coder", "cuzco"), + Action: types.NewEntityUID("table", "drop"), + Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, DiagErr: 1, @@ -658,9 +658,9 @@ func TestIsAuthorized(t *testing.T) { Name: "principal-is", Policy: `permit(principal is Actor,action,resource);`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"Actor", "cuzco"}, - Action: types.EntityUID{"Action", "drop"}, - Resource: types.EntityUID{"Resource", "table"}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -669,9 +669,9 @@ func TestIsAuthorized(t *testing.T) { Name: "principal-is-in", Policy: `permit(principal is Actor in Actor::"cuzco",action,resource);`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"Actor", "cuzco"}, - Action: types.EntityUID{"Action", "drop"}, - Resource: types.EntityUID{"Resource", "table"}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -680,9 +680,9 @@ func TestIsAuthorized(t *testing.T) { Name: "resource-is", Policy: `permit(principal,action,resource is Resource);`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"Actor", "cuzco"}, - Action: types.EntityUID{"Action", "drop"}, - Resource: types.EntityUID{"Resource", "table"}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -691,9 +691,9 @@ func TestIsAuthorized(t *testing.T) { Name: "resource-is-in", Policy: `permit(principal,action,resource is Resource in Resource::"table");`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"Actor", "cuzco"}, - Action: types.EntityUID{"Action", "drop"}, - Resource: types.EntityUID{"Resource", "table"}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -702,9 +702,9 @@ func TestIsAuthorized(t *testing.T) { Name: "when-is", Policy: `permit(principal,action,resource) when { resource is Resource };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"Actor", "cuzco"}, - Action: types.EntityUID{"Action", "drop"}, - Resource: types.EntityUID{"Resource", "table"}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -713,9 +713,9 @@ func TestIsAuthorized(t *testing.T) { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Resource::"table" };`, Entities: entitiesFromSlice(nil), - Principal: types.EntityUID{"Actor", "cuzco"}, - Action: types.EntityUID{"Action", "drop"}, - Resource: types.EntityUID{"Resource", "table"}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), Context: types.Record{}, Want: true, DiagErr: 0, @@ -725,13 +725,13 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { resource is Resource in Parent::"id" };`, Entities: entitiesFromSlice([]Entity{ { - UID: types.EntityUID{"Resource", "table"}, - Parents: []types.EntityUID{{"Parent", "id"}}, + UID: types.NewEntityUID("Resource", "table"), + Parents: []types.EntityUID{types.NewEntityUID("Parent", "id")}, }, }), - Principal: types.EntityUID{"Actor", "cuzco"}, - Action: types.EntityUID{"Action", "drop"}, - Resource: types.EntityUID{"Resource", "table"}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), Context: types.Record{}, Want: true, DiagErr: 0, diff --git a/eval_test.go b/eval_test.go index bd5027d3..3ecdd2b8 100644 --- a/eval_test.go +++ b/eval_test.go @@ -1264,17 +1264,17 @@ func TestAttributeAccessNode(t *testing.T) { types.Long(42), nil}, {"KnownAttributeOnEntity", - newLiteralEval(types.EntityUID{"knownType", "knownID"}), + newLiteralEval(types.NewEntityUID("knownType", "knownID")), "knownAttr", types.Long(42), nil}, {"UnknownEntity", - newLiteralEval(types.EntityUID{"unknownType", "unknownID"}), + newLiteralEval(types.NewEntityUID("unknownType", "unknownID")), "unknownAttr", types.ZeroValue(), errEntityNotExist}, {"UnspecifiedEntity", - newLiteralEval(types.EntityUID{"", ""}), + newLiteralEval(types.NewEntityUID("", "")), "knownAttr", types.ZeroValue(), errUnspecifiedEntity}, @@ -1320,17 +1320,17 @@ func TestHasNode(t *testing.T) { types.Boolean(true), nil}, {"KnownAttributeOnEntity", - newLiteralEval(types.EntityUID{"knownType", "knownID"}), + newLiteralEval(types.NewEntityUID("knownType", "knownID")), "knownAttr", types.Boolean(true), nil}, {"UnknownAttributeOnEntity", - newLiteralEval(types.EntityUID{"knownType", "knownID"}), + newLiteralEval(types.NewEntityUID("knownType", "knownID")), "unknownAttr", types.Boolean(false), nil}, {"UnknownEntity", - newLiteralEval(types.EntityUID{"unknownType", "unknownID"}), + newLiteralEval(types.NewEntityUID("unknownType", "unknownID")), "unknownAttr", types.Boolean(false), nil}, @@ -1633,7 +1633,7 @@ func TestInNode(t *testing.T) { }, { "RhsError", - newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.NewEntityUID("human", "joe")), newErrorEval(errTest), map[string][]string{}, types.ZeroValue(), @@ -1641,7 +1641,7 @@ func TestInNode(t *testing.T) { }, { "RhsTypeError1", - newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.NewEntityUID("human", "joe")), newLiteralEval(types.String("foo")), map[string][]string{}, types.ZeroValue(), @@ -1649,7 +1649,7 @@ func TestInNode(t *testing.T) { }, { "RhsTypeError2", - newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.NewEntityUID("human", "joe")), newLiteralEval(types.Set{ types.String("foo"), }), @@ -1659,17 +1659,17 @@ func TestInNode(t *testing.T) { }, { "Reflexive1", - newLiteralEval(types.EntityUID{"human", "joe"}), - newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.NewEntityUID("human", "joe")), + newLiteralEval(types.NewEntityUID("human", "joe")), map[string][]string{}, types.Boolean(true), nil, }, { "Reflexive2", - newLiteralEval(types.EntityUID{"human", "joe"}), + newLiteralEval(types.NewEntityUID("human", "joe")), newLiteralEval(types.Set{ - types.EntityUID{"human", "joe"}, + types.NewEntityUID("human", "joe"), }), map[string][]string{}, types.Boolean(true), @@ -1677,8 +1677,8 @@ func TestInNode(t *testing.T) { }, { "BasicTrue", - newLiteralEval(types.EntityUID{"human", "joe"}), - newLiteralEval(types.EntityUID{"kingdom", "animal"}), + newLiteralEval(types.NewEntityUID("human", "joe")), + newLiteralEval(types.NewEntityUID("kingdom", "animal")), map[string][]string{ `human::"joe"`: {`species::"human"`}, `species::"human"`: {`kingdom::"animal"`}, @@ -1688,8 +1688,8 @@ func TestInNode(t *testing.T) { }, { "BasicFalse", - newLiteralEval(types.EntityUID{"human", "joe"}), - newLiteralEval(types.EntityUID{"kingdom", "plant"}), + newLiteralEval(types.NewEntityUID("human", "joe")), + newLiteralEval(types.NewEntityUID("kingdom", "plant")), map[string][]string{ `human::"joe"`: {`species::"human"`}, `species::"human"`: {`kingdom::"animal"`}, diff --git a/x/exp/ast/ast_test.go b/x/exp/ast/ast_test.go index 9e513e62..58687919 100644 --- a/x/exp/ast/ast_test.go +++ b/x/exp/ast/ast_test.go @@ -11,9 +11,9 @@ import ( func TestAst(t *testing.T) { t.Parallel() - johnny := types.EntityUID{"User", "johnny"} - sow := types.EntityUID{"Action", "sow"} - cast := types.EntityUID{"Action", "cast"} + johnny := types.NewEntityUID("User", "johnny") + sow := types.NewEntityUID("Action", "sow") + cast := types.NewEntityUID("Action", "cast") // @example("one") // permit ( diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index 05914236..469eda60 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -36,7 +36,7 @@ func (p *Policy) ActionIn(entities ...types.EntityUID) *Policy { } func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { - p.principal = Resource().Equals(Entity(entity)) + p.resource = Resource().Equals(Entity(entity)) return p } @@ -45,11 +45,11 @@ func (p *Policy) ResourceIn(entities ...types.EntityUID) *Policy { for _, e := range entities { entities = append(entities, e) } - p.principal = Resource().In(Set(entityValues)) + p.resource = Resource().In(Set(entityValues)) return p } func (p *Policy) ResourceIs(entityType string) *Policy { - p.principal = Resource().Is(EntityType(entityType)) + p.resource = Resource().Is(EntityType(entityType)) return p } From 2b0778bf64632b57d1d30ae5339ff12ca686c076 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 30 Jul 2024 11:47:01 -0700 Subject: [PATCH 005/216] cedar-go/x/exp/ast: make Set use the args slices to store its elements rather than the value Signed-off-by: philhassey --- x/exp/ast/value.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index 3d1f0f66..5dcd2aea 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -49,7 +49,7 @@ func Set(s types.Set) Node { // ast.Context().Access("fooCount"), // }) func SetNodes(nodes []Node) Node { - return newValueNode(nodeTypeSet, nodes) + return Node{nodeType: nodeTypeSet, args: nodes} } // Record is a convenience function that wraps concrete instances of a Cedar Record type From 2940ece21826aef6daa564bac72f0d9a82ae1232 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 30 Jul 2024 11:54:06 -0700 Subject: [PATCH 006/216] cedar-go/x/exp/ast: make Record AST node use args to store the entries in the record rather than value Signed-off-by: philhassey --- x/exp/ast/ast_test.go | 2 +- x/exp/ast/node.go | 1 + x/exp/ast/value.go | 18 ++++++++++++++---- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/x/exp/ast/ast_test.go b/x/exp/ast/ast_test.go index 58687919..c8a381ab 100644 --- a/x/exp/ast/ast_test.go +++ b/x/exp/ast/ast_test.go @@ -61,7 +61,7 @@ func TestAst(t *testing.T) { ast.Record(simpleRecord).Access("x").Equals(ast.String("value")), ). When( - ast.RecordNodes(map[string]ast.Node{ + ast.RecordNodes(map[types.String]ast.Node{ "x": ast.Long(1).Plus(ast.Context().Access("fooCount")), }).Access("x").Equals(ast.Long(3)), ). diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 7e261cea..aa58e987 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -34,6 +34,7 @@ const ( nodeTypeNotEquals nodeTypeOr nodeTypeRecord + nodeTypeRecordEntry nodeTypeSet nodeTypeSub nodeTypeString diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index 5dcd2aea..1f60a213 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -55,9 +55,9 @@ func SetNodes(nodes []Node) Node { // Record is a convenience function that wraps concrete instances of a Cedar Record type // types in AST value nodes and passes them along to RecordNodes. func Record(r types.Record) Node { - recordNodes := map[string]Node{} + recordNodes := map[types.String]Node{} for k, v := range r { - recordNodes[k] = valueToNode(v) + recordNodes[types.String(k)] = valueToNode(v) } return RecordNodes(recordNodes) } @@ -72,8 +72,18 @@ func Record(r types.Record) Node { // ast.RecordNodes([]ast.RecordNode{ // {Key: "x", Value: ast.Long(1).Plus(ast.Context().Access("resourceField"))}, // }) -func RecordNodes(nodes map[string]Node) Node { - return newValueNode(nodeTypeRecord, nodes) +func RecordNodes(entries map[types.String]Node) Node { + var nodes []Node + for k, v := range entries { + nodes = append( + nodes, + Node{ + nodeType: nodeTypeRecordEntry, + args: []Node{String(k), v}, + }, + ) + } + return Node{nodeType: nodeTypeRecord, args: nodes} } func EntityType(e string) Node { From aa2d1838a5f1ad8704970559cbf0e762924bbe1f Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 30 Jul 2024 11:56:43 -0700 Subject: [PATCH 007/216] cedar-go/x/exp/ast: make Annotation AST node use args for its name/value rather than value Signed-off-by: philhassey --- x/exp/ast/annotation.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/x/exp/ast/annotation.go b/x/exp/ast/annotation.go index 55a67031..5bcd931e 100644 --- a/x/exp/ast/annotation.go +++ b/x/exp/ast/annotation.go @@ -1,10 +1,12 @@ package ast -func (p *Policy) Annotate(name string, value string) *Policy { +import "github.com/cedar-policy/cedar-go/types" + +func (p *Policy) Annotate(name, value types.String) *Policy { p.annotations = append(p.annotations, newAnnotationNode(name, value)) return p } -func newAnnotationNode(name, value string) Node { - return newValueNode(nodeTypeAnnotation, []string{name, value}) +func newAnnotationNode(name, value types.String) Node { + return Node{nodeType: nodeTypeAnnotation, args: []Node{String(name), String(value)}} } From 003fa9e04e88470511cc54dd733881a758c98dd0 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 30 Jul 2024 12:00:33 -0700 Subject: [PATCH 008/216] cedar-go/x/exp/ast: give Node.value a type of types.Value and fix up users Signed-off-by: philhassey --- x/exp/ast/node.go | 7 ++++--- x/exp/ast/scope.go | 4 ++-- x/exp/ast/value.go | 4 ++-- x/exp/ast/variable.go | 10 ++++++---- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index aa58e987..4e107fcf 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -1,5 +1,7 @@ package ast +import "github.com/cedar-policy/cedar-go/types" + type nodeType uint8 const ( @@ -43,7 +45,6 @@ const ( type Node struct { nodeType nodeType - // TODO: Should we just have `value any`? - args []Node - value any + args []Node // For inner nodes like operators, records, etc + value types.Value // For leaf nodes like String, Long, EntityUID } diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index 469eda60..0d528f1f 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -16,7 +16,7 @@ func (p *Policy) PrincipalIn(entities ...types.EntityUID) *Policy { return p } -func (p *Policy) PrincipalIs(entityType string) *Policy { +func (p *Policy) PrincipalIs(entityType types.String) *Policy { p.principal = Principal().Is(EntityType(entityType)) return p } @@ -49,7 +49,7 @@ func (p *Policy) ResourceIn(entities ...types.EntityUID) *Policy { return p } -func (p *Policy) ResourceIs(entityType string) *Policy { +func (p *Policy) ResourceIs(entityType types.String) *Policy { p.resource = Resource().Is(EntityType(entityType)) return p } diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index 1f60a213..ed3d778b 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -86,7 +86,7 @@ func RecordNodes(entries map[types.String]Node) Node { return Node{nodeType: nodeTypeRecord, args: nodes} } -func EntityType(e string) Node { +func EntityType(e types.String) Node { return newValueNode(nodeTypeEntityType, e) } @@ -102,7 +102,7 @@ func IPAddr(i types.IPAddr) Node { return newValueNode(nodeTypeIpAddr, i) } -func newValueNode(nodeType nodeType, v any) Node { +func newValueNode(nodeType nodeType, v types.Value) Node { return Node{nodeType: nodeType, value: v} } diff --git a/x/exp/ast/variable.go b/x/exp/ast/variable.go index 8a7cb662..9bf750a7 100644 --- a/x/exp/ast/variable.go +++ b/x/exp/ast/variable.go @@ -1,5 +1,7 @@ package ast +import "github.com/cedar-policy/cedar-go/types" + func Principal() Node { return newPrincipalNode() } @@ -17,17 +19,17 @@ func Context() Node { } func newPrincipalNode() Node { - return newValueNode(nodeTypeVariable, "principal") + return newValueNode(nodeTypeVariable, types.String("principal")) } func newActionNode() Node { - return newValueNode(nodeTypeVariable, "action") + return newValueNode(nodeTypeVariable, types.String("action")) } func newResourceNode() Node { - return newValueNode(nodeTypeVariable, "resource") + return newValueNode(nodeTypeVariable, types.String("resource")) } func newContextNode() Node { - return newValueNode(nodeTypeVariable, "context") + return newValueNode(nodeTypeVariable, types.String("context")) } From 99e8336dfa4de04e81e4b5155642c5000243b502 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 30 Jul 2024 13:17:49 -0700 Subject: [PATCH 009/216] cedar-go/x/exp/parser: make tests depend on common testutil package Signed-off-by: philhassey --- x/exp/parser/parse_test.go | 36 ++++++------ x/exp/parser/testutil_test.go | 30 ---------- x/exp/parser/tokenize_test.go | 104 +++++++++++++++++----------------- 3 files changed, 72 insertions(+), 98 deletions(-) delete mode 100644 x/exp/parser/testutil_test.go diff --git a/x/exp/parser/parse_test.go b/x/exp/parser/parse_test.go index 2b8f0a78..7a6c0bb4 100644 --- a/x/exp/parser/parse_test.go +++ b/x/exp/parser/parse_test.go @@ -2,6 +2,8 @@ package parser import ( "testing" + + "github.com/cedar-policy/cedar-go/testutil" ) func TestParse(t *testing.T) { @@ -291,16 +293,16 @@ func TestParse(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() tokens, err := Tokenize([]byte(tt.in)) - testutilOK(t, err) + testutil.OK(t, err) got, err := Parse(tokens) - testutilEquals(t, err != nil, tt.err) + testutil.Equals(t, err != nil, tt.err) if err != nil { - testutilEquals(t, got, nil) + testutil.Equals(t, got, nil) return } gotTokens, err := Tokenize([]byte(got.String())) - testutilOK(t, err) + testutil.OK(t, err) var tokenStrs []string for _, t := range tokens { @@ -312,7 +314,7 @@ func TestParse(t *testing.T) { gotTokenStrs = append(gotTokenStrs, t.toString()) } - testutilEquals(t, gotTokenStrs, tokenStrs) + testutil.Equals(t, gotTokenStrs, tokenStrs) }) } } @@ -408,10 +410,10 @@ func TestParseTypes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() tokens, err := Tokenize([]byte(tt.in)) - testutilOK(t, err) + testutil.OK(t, err) got, err := Parse(tokens) - testutilOK(t, err) - testutilEquals(t, got, tt.out) + testutil.OK(t, err) + testutil.Equals(t, got, tt.out) }) } } @@ -424,16 +426,16 @@ func TestParseEntity(t *testing.T) { out Entity err func(testing.TB, error) }{ - {"happy", `Action::"test"`, Entity{Path: []string{"Action", "test"}}, testutilOK}, + {"happy", `Action::"test"`, Entity{Path: []string{"Action", "test"}}, testutil.OK}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() toks, err := Tokenize([]byte(tt.in)) - testutilOK(t, err) + testutil.OK(t, err) out, err := ParseEntity(toks) - testutilEquals(t, out, tt.out) + testutil.Equals(t, out, tt.out) tt.err(t, err) }) } @@ -453,11 +455,11 @@ permit( principal, action, resource ); @test("1234") permit (principal, action, resource ); ` toks, err := Tokenize([]byte(in)) - testutilOK(t, err) + testutil.OK(t, err) out, err := Parse(toks) - testutilOK(t, err) - testutilEquals(t, len(out), 3) - testutilEquals(t, out[0].Position, Position{Offset: 17, Line: 2, Column: 1}) - testutilEquals(t, out[1].Position, Position{Offset: 86, Line: 7, Column: 3}) - testutilEquals(t, out[2].Position, Position{Offset: 148, Line: 10, Column: 2}) + testutil.OK(t, err) + testutil.Equals(t, len(out), 3) + testutil.Equals(t, out[0].Position, Position{Offset: 17, Line: 2, Column: 1}) + testutil.Equals(t, out[1].Position, Position{Offset: 86, Line: 7, Column: 3}) + testutil.Equals(t, out[2].Position, Position{Offset: 148, Line: 10, Column: 2}) } diff --git a/x/exp/parser/testutil_test.go b/x/exp/parser/testutil_test.go deleted file mode 100644 index ce1858bc..00000000 --- a/x/exp/parser/testutil_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package parser - -import ( - "reflect" - "testing" -) - -func testutilEquals[T any](t testing.TB, a, b T) { - t.Helper() - if reflect.DeepEqual(a, b) { - return - } - t.Fatalf("got %+v want %+v", a, b) -} - -func testutilOK(t testing.TB, err error) { - t.Helper() - if err == nil { - return - } - t.Fatalf("got %v want nil", err) -} - -func testutilError(t testing.TB, err error) { - t.Helper() - if err != nil { - return - } - t.Fatalf("got nil want error") -} diff --git a/x/exp/parser/tokenize_test.go b/x/exp/parser/tokenize_test.go index 7281f1ba..42d9911a 100644 --- a/x/exp/parser/tokenize_test.go +++ b/x/exp/parser/tokenize_test.go @@ -6,6 +6,8 @@ import ( "strings" "testing" "unicode/utf8" + + "github.com/cedar-policy/cedar-go/testutil" ) func TestTokenize(t *testing.T) { @@ -89,8 +91,8 @@ multiline comment {Type: TokenEOF, Text: "", Pos: Position{Offset: 271, Line: 16, Column: 7}}, } got, err := Tokenize([]byte(input)) - testutilOK(t, err) - testutilEquals(t, got, want) + testutil.OK(t, err) + testutil.Equals(t, got, want) } func TestTokenizeErrors(t *testing.T) { @@ -125,9 +127,9 @@ func TestTokenizeErrors(t *testing.T) { t.Parallel() got, gotErr := Tokenize([]byte(tt.input)) wantErrStr := fmt.Sprintf("%v: %s", tt.wantErrPos, tt.wantErrStr) - testutilError(t, gotErr) - testutilEquals(t, gotErr.Error(), wantErrStr) - testutilEquals(t, got, nil) + testutil.Error(t, gotErr) + testutil.Equals(t, gotErr.Error(), wantErrStr) + testutil.Equals(t, got, nil) }) } } @@ -149,16 +151,16 @@ func TestIntTokenValues(t *testing.T) { t.Run(tt.input, func(t *testing.T) { t.Parallel() got, err := Tokenize([]byte(tt.input)) - testutilOK(t, err) - testutilEquals(t, len(got), 2) - testutilEquals(t, got[0].Type, TokenInt) + testutil.OK(t, err) + testutil.Equals(t, len(got), 2) + testutil.Equals(t, got[0].Type, TokenInt) gotInt, err := got[0].intValue() if err != nil { - testutilEquals(t, tt.wantOk, false) - testutilEquals(t, err.Error(), tt.wantErr) + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) } else { - testutilEquals(t, tt.wantOk, true) - testutilEquals(t, gotInt, tt.want) + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, gotInt, tt.want) } }) } @@ -201,16 +203,16 @@ func TestStringTokenValues(t *testing.T) { t.Run(tt.input, func(t *testing.T) { t.Parallel() got, err := Tokenize([]byte(tt.input)) - testutilOK(t, err) - testutilEquals(t, len(got), 2) - testutilEquals(t, got[0].Type, TokenString) + testutil.OK(t, err) + testutil.Equals(t, len(got), 2) + testutil.Equals(t, got[0].Type, TokenString) gotStr, err := got[0].stringValue() if err != nil { - testutilEquals(t, tt.wantOk, false) - testutilEquals(t, err.Error(), tt.wantErr) + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) } else { - testutilEquals(t, tt.wantOk, true) - testutilEquals(t, gotStr, tt.want) + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, gotStr, tt.want) } }) } @@ -225,17 +227,17 @@ func TestParseUnicodeEscape(t *testing.T) { outN int err func(t testing.TB, err error) }{ - {"happy", []byte{'{', '4', '2', '}'}, 0x42, 4, testutilOK}, - {"badRune", []byte{'{', 0x80, 0x81}, 0, 1, testutilError}, - {"notHex", []byte{'{', 'g'}, 0, 2, testutilError}, + {"happy", []byte{'{', '4', '2', '}'}, 0x42, 4, testutil.OK}, + {"badRune", []byte{'{', 0x80, 0x81}, 0, 1, testutil.Error}, + {"notHex", []byte{'{', 'g'}, 0, 2, testutil.Error}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() out, n, err := parseUnicodeEscape(tt.in, 0) - testutilEquals(t, out, tt.out) - testutilEquals(t, n, tt.outN) + testutil.Equals(t, out, tt.out) + testutil.Equals(t, n, tt.outN) tt.err(t, err) }) } @@ -249,14 +251,14 @@ func TestUnquote(t *testing.T) { out string err func(t testing.TB, err error) }{ - {"happy", `"test"`, `test`, testutilOK}, + {"happy", `"test"`, `test`, testutil.OK}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() out, err := Unquote(tt.in) - testutilEquals(t, out, tt.out) + testutil.Equals(t, out, tt.out) tt.err(t, err) }) } @@ -319,13 +321,13 @@ func TestRustUnquote(t *testing.T) { t.Parallel() got, rem, err := rustUnquote([]byte(tt.input), false) if err != nil { - testutilEquals(t, tt.wantOk, false) - testutilEquals(t, err.Error(), tt.wantErr) - testutilEquals(t, got, tt.want) + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + testutil.Equals(t, got, tt.want) } else { - testutilEquals(t, tt.wantOk, true) - testutilEquals(t, got, tt.want) - testutilEquals(t, rem, []byte("")) + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got, tt.want) + testutil.Equals(t, rem, []byte("")) } }) } @@ -390,13 +392,13 @@ func TestRustUnquote(t *testing.T) { t.Parallel() got, rem, err := rustUnquote([]byte(tt.input), true) if err != nil { - testutilEquals(t, tt.wantOk, false) - testutilEquals(t, err.Error(), tt.wantErr) - testutilEquals(t, got, tt.want) + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + testutil.Equals(t, got, tt.want) } else { - testutilEquals(t, tt.wantOk, true) - testutilEquals(t, got, tt.want) - testutilEquals(t, string(rem), tt.wantRem) + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got, tt.want) + testutil.Equals(t, string(rem), tt.wantRem) } }) } @@ -406,7 +408,7 @@ func TestRustUnquote(t *testing.T) { func TestFakeRustQuote(t *testing.T) { t.Parallel() out := FakeRustQuote("hello") - testutilEquals(t, out, `"hello"`) + testutil.Equals(t, out, `"hello"`) } func TestPatternFromStringLiteral(t *testing.T) { @@ -456,12 +458,12 @@ func TestPatternFromStringLiteral(t *testing.T) { t.Parallel() got, err := NewPattern(tt.input) if err != nil { - testutilEquals(t, tt.wantOk, false) - testutilEquals(t, err.Error(), tt.wantErr) + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) } else { - testutilEquals(t, tt.wantOk, true) - testutilEquals(t, got.Comps, tt.want) - testutilEquals(t, got.String(), tt.input) + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got.Comps, tt.want) + testutil.Equals(t, got.String(), tt.input) } }) } @@ -480,7 +482,7 @@ func TestScanner(t *testing.T) { var s scanner s.Init(r) out := s.next() - testutilEquals(t, out, specialRuneEOF) + testutil.Equals(t, out, specialRuneEOF) }) t.Run("MidEmojiEOF", func(t *testing.T) { @@ -500,9 +502,9 @@ func TestScanner(t *testing.T) { } s.Init(r) out := s.next() - testutilEquals(t, out, utf8.RuneError) + testutil.Equals(t, out, utf8.RuneError) out = s.next() - testutilEquals(t, out, specialRuneEOF) + testutil.Equals(t, out, specialRuneEOF) }) t.Run("NotAsciiEmoji", func(t *testing.T) { @@ -510,7 +512,7 @@ func TestScanner(t *testing.T) { var s scanner s.Init(strings.NewReader(`🐐`)) out := s.next() - testutilEquals(t, out, '🐐') + testutil.Equals(t, out, '🐐') }) t.Run("InvalidUTF8", func(t *testing.T) { @@ -518,7 +520,7 @@ func TestScanner(t *testing.T) { var s scanner s.Init(strings.NewReader(string([]byte{0x80, 0x81}))) out := s.next() - testutilEquals(t, out, utf8.RuneError) + testutil.Equals(t, out, utf8.RuneError) }) t.Run("tokenTextNone", func(t *testing.T) { @@ -526,7 +528,7 @@ func TestScanner(t *testing.T) { var s scanner s.Init(strings.NewReader("")) out := s.tokenText() - testutilEquals(t, out, "") + testutil.Equals(t, out, "") }) } @@ -546,7 +548,7 @@ func TestDigitVal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() out := digitVal(tt.in) - testutilEquals(t, out, tt.out) + testutil.Equals(t, out, tt.out) }) } } From 3f19bdb462413f6e93d9909336fd48a36d1eb94a Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 30 Jul 2024 13:19:38 -0700 Subject: [PATCH 010/216] cedar-go/types: rename testutil.go to testutil_test.go to prevent it being compiled into the package Signed-off-by: philhassey --- types/{testutil.go => testutil_test.go} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename types/{testutil.go => testutil_test.go} (100%) diff --git a/types/testutil.go b/types/testutil_test.go similarity index 100% rename from types/testutil.go rename to types/testutil_test.go From 0f8098821ecf43487229144e68e0d5dc9b7bd1c2 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 30 Jul 2024 13:33:45 -0700 Subject: [PATCH 011/216] cedar-go/types: put testutil back into the types package because cedar-go depends on it Signed-off-by: philhassey --- types/{testutil_test.go => testutil.go} | 0 x/exp/parser2/fuzz_test.go | 103 ++ x/exp/parser2/parse.go | 1462 +++++++++++++++++++++++ x/exp/parser2/parse_test.go | 465 +++++++ x/exp/parser2/tokenize.go | 705 +++++++++++ x/exp/parser2/tokenize_mocks_test.go | 74 ++ x/exp/parser2/tokenize_test.go | 554 +++++++++ 7 files changed, 3363 insertions(+) rename types/{testutil_test.go => testutil.go} (100%) create mode 100644 x/exp/parser2/fuzz_test.go create mode 100644 x/exp/parser2/parse.go create mode 100644 x/exp/parser2/parse_test.go create mode 100644 x/exp/parser2/tokenize.go create mode 100644 x/exp/parser2/tokenize_mocks_test.go create mode 100644 x/exp/parser2/tokenize_test.go diff --git a/types/testutil_test.go b/types/testutil.go similarity index 100% rename from types/testutil_test.go rename to types/testutil.go diff --git a/x/exp/parser2/fuzz_test.go b/x/exp/parser2/fuzz_test.go new file mode 100644 index 00000000..c6f89606 --- /dev/null +++ b/x/exp/parser2/fuzz_test.go @@ -0,0 +1,103 @@ +package parser + +import ( + "testing" +) + +// https://go.dev/doc/tutorial/fuzz +// mkdir testdata +// go test -fuzz=FuzzTokenize -fuzztime 60s +// go test -fuzz=FuzzParse -fuzztime 60s + +func FuzzTokenize(f *testing.F) { + tests := []string{ + `These are some identifiers`, + `0 1 1234`, + `-1 9223372036854775807 -9223372036854775808`, + `"" "string" "\"\'\n\r\t\\\0" "\x123" "\u{0}\u{10fFfF}"`, + `"*" "\*" "*\**"`, + `@.,;(){}[]+-*`, + `:::`, + `!!=<<=>>=`, + `||&&`, + `// single line comment`, + `/*`, + `multiline comment`, + `// embedded comment does nothing`, + `*/`, + `'/%|&=`, + } + for _, tt := range tests { + f.Add(tt) + } + f.Fuzz(func(t *testing.T, orig string) { + toks, err := Tokenize([]byte(orig)) + if err != nil { + if toks != nil { + t.Errorf("toks != nil on err") + } + } + }) +} + +func FuzzParse(f *testing.F) { + tests := []string{ + `permit(principal,action,resource);`, + `forbid(principal,action,resource);`, + `permit(principal,action,resource in asdf::"1234");`, + `permit(principal,action,resource) when { resource in "foo" };`, + `permit(principal,action,resource) when { context.x == 42 };`, + `permit(principal,action,resource) when { context.x == 42 };`, + `permit(principal,action,resource) when { principal.x == 42 };`, + `permit(principal,action,resource) when { principal.x == 42 };`, + `permit(principal,action,resource) when { principal in parent::"bob" };`, + `permit(principal == coder::"cuzco",action,resource);`, + `permit(principal in team::"osiris",action,resource);`, + `permit(principal,action == table::"drop",resource);`, + `permit(principal,action in scary::"stuff",resource);`, + `permit(principal,action in [scary::"stuff"],resource);`, + `permit(principal,action,resource == table::"whatever");`, + `permit(principal,action,resource) unless { false };`, + `permit(principal,action,resource) when { (if true then true else true) };`, + `permit(principal,action,resource) when { (true || false) };`, + `permit(principal,action,resource) when { (true && true) };`, + `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, + `permit(principal,action,resource) when { principal in principal };`, + `permit(principal,action,resource) when { principal has name };`, + `permit(principal,action,resource) when { 40+3-1==42 };`, + `permit(principal,action,resource) when { 6*7==42 };`, + `permit(principal,action,resource) when { -42==-42 };`, + `permit(principal,action,resource) when { !(1+1==42) };`, + `permit(principal,action,resource) when { [1,2,3].contains(2) };`, + `permit(principal,action,resource) when { {name:"bob"} has name };`, + `permit(principal,action,resource) when { action in action };`, + `permit(principal,action,resource) when { [1,2,3].contains(2) };`, + `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, + `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, + `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, + `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, + `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, + `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, + `permit(principal,action,resource) when { [1,2,3].shuffle() };`, + `permit(principal,action,resource) when { "bananas" like "*nan*" };`, + `permit(principal,action,resource) when { fooBar("10") };`, + `permit(principal,action,resource) when { decimal(1, 2) };`, + `permit(principal,action,resource) when { ip() };`, + `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, + `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, + `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, + `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, + `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, + } + for _, tt := range tests { + f.Add(tt) + } + f.Fuzz(func(_ *testing.T, orig string) { + toks, err := Tokenize([]byte(orig)) + if err != nil { + return + } + // intentionally ignore parse errors + _, _ = Parse(toks) + }) +} diff --git a/x/exp/parser2/parse.go b/x/exp/parser2/parse.go new file mode 100644 index 00000000..a6ca0690 --- /dev/null +++ b/x/exp/parser2/parse.go @@ -0,0 +1,1462 @@ +package parser + +import ( + "fmt" + "strconv" + "strings" +) + +func Parse(tokens []Token) (Policies, error) { + p := &parser{Tokens: tokens} + return p.Policies() +} + +func ParseEntity(tokens []Token) (Entity, error) { + p := &parser{Tokens: tokens} + return p.Entity() +} + +type parser struct { + Tokens []Token + Pos int +} + +func (p *parser) advance() Token { + t := p.peek() + if p.Pos < len(p.Tokens)-1 { + p.Pos++ + } + return t +} + +func (p *parser) peek() Token { + return p.Tokens[p.Pos] +} + +func (p *parser) exact(tok string) error { + t := p.advance() + if t.Text != tok { + return p.errorf("exact got %v want %v", t.Text, tok) + } + return nil +} + +func (p *parser) errorf(s string, args ...interface{}) error { + var t Token + if p.Pos < len(p.Tokens) { + t = p.Tokens[p.Pos] + } + err := fmt.Errorf(s, args...) + return fmt.Errorf("parse error at %v %q: %w", t.Pos, t.Text, err) +} + +// Policies := {Policy} + +type Policies []Policy + +func (c Policies) String() string { + var sb strings.Builder + for i, p := range c { + if i > 0 { + sb.WriteRune('\n') + } + sb.WriteString(p.String()) + } + return sb.String() +} + +func (p *parser) Policies() (Policies, error) { + var res Policies + for !p.peek().isEOF() { + policy, err := p.policy() + if err != nil { + return nil, err + } + res = append(res, policy) + } + return res, nil +} + +// Policy := {Annotation} Effect '(' Scope ')' {Conditions} ';' +// Scope := Principal ',' Action ',' Resource + +type Policy struct { + Position Position + Annotations []Annotation + Effect Effect + Principal Principal + Action Action + Resource Resource + Conditions []Condition +} + +func (p Policy) String() string { + var sb strings.Builder + for i, a := range p.Annotations { + if i > 0 { + sb.WriteRune('\n') + } + sb.WriteString(a.String()) + } + sb.WriteString(fmt.Sprintf("%s(\n%s,\n%s,\n%s\n)", + p.Effect, p.Principal, p.Action, p.Resource, + )) + for _, c := range p.Conditions { + sb.WriteRune('\n') + sb.WriteString(c.String()) + } + sb.WriteString(";") + return sb.String() +} + +func (p *parser) policy() (Policy, error) { + var res Policy + res.Position = p.peek().Pos + var err error + if res.Annotations, err = p.annotations(); err != nil { + return res, err + } + if res.Effect, err = p.effect(); err != nil { + return res, err + } + if err := p.exact("("); err != nil { + return res, err + } + if res.Principal, err = p.principal(); err != nil { + return res, err + } + if err := p.exact(","); err != nil { + return res, err + } + if res.Action, err = p.action(); err != nil { + return res, err + } + if err := p.exact(","); err != nil { + return res, err + } + if res.Resource, err = p.resource(); err != nil { + return res, err + } + if err := p.exact(")"); err != nil { + return res, err + } + if res.Conditions, err = p.conditions(); err != nil { + return res, err + } + if err := p.exact(";"); err != nil { + return res, err + } + return res, nil +} + +// Annotation := '@'IDENT'('STR')' + +type Annotation struct { + Key string + Value string +} + +func (a Annotation) String() string { + return fmt.Sprintf("@%s(%q)", a.Key, a.Value) +} + +func (p *parser) annotation() (Annotation, error) { + var res Annotation + var err error + t := p.advance() + if !t.isIdent() { + return res, p.errorf("expected ident") + } + res.Key = t.Text + if err := p.exact("("); err != nil { + return res, err + } + t = p.advance() + if !t.isString() { + return res, p.errorf("expected string") + } + if res.Value, err = t.stringValue(); err != nil { + return res, err + } + if err := p.exact(")"); err != nil { + return res, err + } + return res, nil +} + +func (p *parser) annotations() ([]Annotation, error) { + var res []Annotation + for p.peek().Text == "@" { + p.advance() + a, err := p.annotation() + if err != nil { + return res, err + } + for _, aa := range res { + if aa.Key == a.Key { + return res, p.errorf("duplicate annotation") + } + } + res = append(res, a) + } + return res, nil +} + +// Effect := 'permit' | 'forbid' + +type Effect string + +const ( + EffectPermit = Effect("permit") + EffectForbid = Effect("forbid") +) + +func (p *parser) effect() (Effect, error) { + next := p.advance() + res := Effect(next.Text) + switch res { + case EffectForbid: + case EffectPermit: + default: + return res, p.errorf("unexpected effect: %v", res) + } + return res, nil +} + +// MatchType + +type MatchType int + +const ( + MatchAny = MatchType(iota) + MatchEquals + MatchIn + MatchInList + MatchIs + MatchIsIn +) + +// Principal := 'principal' [('in' | '==') Entity] + +type Principal struct { + Type MatchType + Path Path + Entity Entity +} + +func (p Principal) String() string { + var res string + switch p.Type { + case MatchAny: + res = "principal" + case MatchEquals: + res = fmt.Sprintf("principal == %s", p.Entity) + case MatchIs: + res = fmt.Sprintf("principal is %s", p.Path) + case MatchIsIn: + res = fmt.Sprintf("principal is %s in %s", p.Path, p.Entity) + case MatchIn: + res = fmt.Sprintf("principal in %s", p.Entity) + } + return res +} + +func (p *parser) principal() (Principal, error) { + var res Principal + if err := p.exact("principal"); err != nil { + return res, err + } + switch p.peek().Text { + case "==": + p.advance() + var err error + res.Type = MatchEquals + res.Entity, err = p.Entity() + return res, err + case "is": + p.advance() + var err error + res.Type = MatchIs + res.Path, err = p.Path() + if err == nil && p.peek().Text == "in" { + p.advance() + res.Type = MatchIsIn + res.Entity, err = p.Entity() + return res, err + } + return res, err + case "in": + p.advance() + var err error + res.Type = MatchIn + res.Entity, err = p.Entity() + return res, err + default: + return Principal{ + Type: MatchAny, + }, nil + } +} + +// Action := 'action' [( '==' Entity | 'in' ('[' EntList ']' | Entity) )] + +type Action struct { + Type MatchType + Entities []Entity +} + +func (a Action) String() string { + var sb strings.Builder + switch a.Type { + case MatchAny: + sb.WriteString("action") + case MatchEquals: + sb.WriteString(fmt.Sprintf("action == %s", a.Entities[0])) + case MatchIn: + sb.WriteString(fmt.Sprintf("action in %s", a.Entities[0])) + case MatchInList: + sb.WriteString("action in [") + for i, e := range a.Entities { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(e.String()) + } + sb.WriteRune(']') + } + return sb.String() +} + +func (p *parser) action() (Action, error) { + var res Action + var err error + if err := p.exact("action"); err != nil { + return res, err + } + switch p.peek().Text { + case "==": + p.advance() + res.Type = MatchEquals + e, err := p.Entity() + if err != nil { + return res, err + } + res.Entities = append(res.Entities, e) + return res, nil + case "in": + p.advance() + if p.peek().Text == "[" { + res.Type = MatchInList + p.advance() + res.Entities, err = p.entlist() + if err != nil { + return res, err + } + p.advance() // entlist guarantees "]" + return res, nil + } else { + res.Type = MatchIn + e, err := p.Entity() + if err != nil { + return res, err + } + res.Entities = append(res.Entities, e) + return res, nil + } + default: + return Action{ + Type: MatchAny, + }, nil + } +} + +// Resource := 'resource' [('in' | '==') Entity)] + +type Resource struct { + Type MatchType + Path Path + Entity Entity +} + +func (r Resource) String() string { + var res string + switch r.Type { + case MatchAny: + res = "resource" + case MatchEquals: + res = fmt.Sprintf("resource == %s", r.Entity) + case MatchIs: + res = fmt.Sprintf("resource is %s", r.Path) + case MatchIsIn: + res = fmt.Sprintf("resource is %s in %s", r.Path, r.Entity) + case MatchIn: + res = fmt.Sprintf("resource in %s", r.Entity) + } + return res +} + +func (p *parser) resource() (Resource, error) { + var res Resource + if err := p.exact("resource"); err != nil { + return res, err + } + switch p.peek().Text { + case "==": + p.advance() + var err error + res.Type = MatchEquals + res.Entity, err = p.Entity() + return res, err + case "is": + p.advance() + var err error + res.Type = MatchIs + res.Path, err = p.Path() + if err == nil && p.peek().Text == "in" { + p.advance() + res.Type = MatchIsIn + res.Entity, err = p.Entity() + return res, err + } + return res, err + case "in": + p.advance() + var err error + res.Type = MatchIn + res.Entity, err = p.Entity() + return res, err + default: + return Resource{ + Type: MatchAny, + }, nil + } +} + +// Entity := Path '::' STR + +type Entity struct { + Path []string +} + +func (e Entity) String() string { + return fmt.Sprintf( + "%s::%q", + strings.Join(e.Path[0:len(e.Path)-1], "::"), + e.Path[len(e.Path)-1], + ) +} + +func (p *parser) Entity() (Entity, error) { + var res Entity + t := p.advance() + if !t.isIdent() { + return res, p.errorf("expected ident") + } + res.Path = append(res.Path, t.Text) + for { + if err := p.exact("::"); err != nil { + return res, err + } + t := p.advance() + switch { + case t.isIdent(): + res.Path = append(res.Path, t.Text) + case t.isString(): + component, err := t.stringValue() + if err != nil { + return res, err + } + res.Path = append(res.Path, component) + return res, nil + default: + return res, p.errorf("unexpected token") + } + } +} + +// Path ::= IDENT {'::' IDENT} + +type Path struct { + Path []string +} + +func (e Path) String() string { + return strings.Join(e.Path, "::") +} + +func (p *parser) Path() (Path, error) { + var res Path + t := p.advance() + if !t.isIdent() { + return res, p.errorf("expected ident") + } + res.Path = append(res.Path, t.Text) + for { + if p.peek().Text != "::" { + return res, nil + } + p.advance() + t := p.advance() + switch { + case t.isIdent(): + res.Path = append(res.Path, t.Text) + default: + return res, p.errorf("unexpected token") + } + } +} + +// EntList := Entity {',' Entity} + +func (p *parser) entlist() ([]Entity, error) { + var res []Entity + for p.peek().Text != "]" { + if len(res) > 0 { + if err := p.exact(","); err != nil { + return res, err + } + } + e, err := p.Entity() + if err != nil { + return res, err + } + res = append(res, e) + } + return res, nil +} + +// Condition := ('when' | 'unless') '{' Expr '}' + +type ConditionType string + +const ( + ConditionWhen ConditionType = "when" + ConditionUnless ConditionType = "unless" +) + +type Condition struct { + Type ConditionType + Expression Expression +} + +func (c Condition) String() string { + var res string + switch c.Type { + case ConditionWhen: + res = fmt.Sprintf("when {\n%s\n}", c.Expression) + case ConditionUnless: + res = fmt.Sprintf("unless {\n%s\n}", c.Expression) + } + return res +} + +func (p *parser) condition() (Condition, error) { + var res Condition + var err error + res.Type = ConditionType(p.advance().Text) + if err := p.exact("{"); err != nil { + return res, err + } + if res.Expression, err = p.expression(); err != nil { + return res, err + } + if err := p.exact("}"); err != nil { + return res, err + } + return res, nil +} + +func (p *parser) conditions() ([]Condition, error) { + var res []Condition + for { + switch p.peek().Text { + case "when", "unless": + c, err := p.condition() + if err != nil { + return res, err + } + res = append(res, c) + default: + return res, nil + } + } +} + +// Expr := Or | If + +type ExpressionType int + +const ( + ExpressionOr ExpressionType = iota + ExpressionIf +) + +type Expression struct { + Type ExpressionType + Or Or + If *If +} + +func (e Expression) String() string { + var res string + switch e.Type { + case ExpressionOr: + res = e.Or.String() + case ExpressionIf: + res = e.If.String() + } + return res +} + +func (p *parser) expression() (Expression, error) { + var res Expression + var err error + if p.peek().Text == "if" { + p.advance() + res.Type = ExpressionIf + i, err := p.ifExpr() + if err != nil { + return res, err + } + res.If = &i + return res, nil + } else { + res.Type = ExpressionOr + if res.Or, err = p.or(); err != nil { + return res, err + } + return res, nil + } +} + +// If := 'if' Expr 'then' Expr 'else' Expr + +type If struct { + If Expression + Then Expression + Else Expression +} + +func (i If) String() string { + return fmt.Sprintf("if %s then %s else %s", i.If, i.Then, i.Else) +} + +func (p *parser) ifExpr() (If, error) { + var res If + var err error + if res.If, err = p.expression(); err != nil { + return res, err + } + if err = p.exact("then"); err != nil { + return res, err + } + if res.Then, err = p.expression(); err != nil { + return res, err + } + if err = p.exact("else"); err != nil { + return res, err + } + if res.Else, err = p.expression(); err != nil { + return res, err + } + return res, err +} + +// Or := And {'||' And} + +type Or struct { + Ands []And +} + +func (o Or) String() string { + var sb strings.Builder + for i, and := range o.Ands { + if i > 0 { + sb.WriteString(" || ") + } + sb.WriteString(and.String()) + } + return sb.String() +} + +func (p *parser) or() (Or, error) { + var res Or + for { + a, err := p.and() + if err != nil { + return res, err + } + res.Ands = append(res.Ands, a) + if p.peek().Text != "||" { + return res, nil + } + p.advance() + } +} + +// And := Relation {'&&' Relation} + +type And struct { + Relations []Relation +} + +func (a And) String() string { + var sb strings.Builder + for i, rel := range a.Relations { + if i > 0 { + sb.WriteString(" && ") + } + sb.WriteString(rel.String()) + } + return sb.String() +} + +func (p *parser) and() (And, error) { + var res And + for { + r, err := p.relation() + if err != nil { + return res, err + } + res.Relations = append(res.Relations, r) + if p.peek().Text != "&&" { + return res, nil + } + p.advance() + } +} + +// Relation := Add [RELOP Add] | Add 'has' (IDENT | STR) | Add 'like' PAT + +type RelationType string + +const ( + RelationNone RelationType = "none" + RelationRelOp RelationType = "relop" + RelationHasIdent RelationType = "hasident" + RelationHasLiteral RelationType = "hasliteral" + RelationLike RelationType = "like" + RelationIs RelationType = "is" + RelationIsIn RelationType = "isIn" +) + +type Relation struct { + Add Add + Type RelationType + RelOp RelOp + RelOpRhs Add + Str string + Pat Pattern + Path Path + Entity Add +} + +func (r Relation) String() string { + var sb strings.Builder + sb.WriteString(r.Add.String()) + switch r.Type { + case RelationNone: + case RelationRelOp: + sb.WriteString(" ") + sb.WriteString(string(r.RelOp)) + sb.WriteString(" ") + sb.WriteString(r.RelOpRhs.String()) + case RelationHasIdent: + sb.WriteString(" has ") + sb.WriteString(r.Str) + case RelationHasLiteral: + sb.WriteString(" has ") + sb.WriteString(strconv.Quote(r.Str)) + case RelationLike: + sb.WriteString(" like ") + sb.WriteString(r.Pat.String()) + case RelationIs: + sb.WriteString(" is ") + sb.WriteString(r.Path.String()) + case RelationIsIn: + sb.WriteString(" is ") + sb.WriteString(r.Path.String()) + sb.WriteString(" in ") + sb.WriteString(r.Entity.String()) + } + return sb.String() +} + +func (p *parser) relation() (Relation, error) { + var res Relation + var err error + if res.Add, err = p.add(); err != nil { + return res, err + } + + t := p.peek() + switch t.Text { + case "<", "<=", ">=", ">", "!=", "==", "in": + p.advance() + res.Type = RelationRelOp + res.RelOp = RelOp(t.Text) + if res.RelOpRhs, err = p.add(); err != nil { + return res, err + } + case "has": + p.advance() + t := p.advance() + switch { + case t.isIdent(): + res.Type = RelationHasIdent + res.Str = t.Text + case t.isString(): + res.Type = RelationHasLiteral + if res.Str, err = t.stringValue(); err != nil { + return res, err + } + default: + return res, p.errorf("unexpected token") + } + case "like": + p.advance() + res.Type = RelationLike + t := p.advance() + if !t.isString() { + return res, p.errorf("unexpected token") + } + if res.Pat, err = t.patternValue(); err != nil { + return res, err + } + case "is": + p.advance() + var err error + res.Type = RelationIs + res.Path, err = p.Path() + if err == nil && p.peek().Text == "in" { + p.advance() + res.Type = RelationIsIn + res.Entity, err = p.add() + return res, err + } + return res, err + default: + res.Type = RelationNone + } + return res, nil +} + +// RELOP := '<' | '<=' | '>=' | '>' | '!=' | '==' | 'in' + +type RelOp string + +const ( + RelOpLt RelOp = "<" + RelOpLe RelOp = "<=" + RelOpGe RelOp = ">=" + RelOpGt RelOp = ">" + RelOpNe RelOp = "!=" + RelOpEq RelOp = "==" + RelOpIn RelOp = "in" +) + +// Add := Mult {ADDOP Mult} + +type Add struct { + Mults []Mult + AddOps []AddOp +} + +func (a Add) String() string { + var sb strings.Builder + sb.WriteString(a.Mults[0].String()) + for i, op := range a.AddOps { + sb.WriteString(fmt.Sprintf(" %s %s", op, a.Mults[i+1].String())) + } + return sb.String() +} + +func (p *parser) add() (Add, error) { + var res Add + var err error + mult, err := p.mult() + if err != nil { + return res, err + } + res.Mults = append(res.Mults, mult) + for { + op := AddOp(p.peek().Text) + switch op { + case AddOpAdd, AddOpSub: + default: + return res, nil + } + p.advance() + mult, err := p.mult() + if err != nil { + return res, err + } + res.AddOps = append(res.AddOps, op) + res.Mults = append(res.Mults, mult) + } +} + +// ADDOP := '+' | '-' + +type AddOp string + +const ( + AddOpAdd AddOp = "+" + AddOpSub AddOp = "-" +) + +// Mult := Unary { '*' Unary} + +type Mult struct { + Unaries []Unary +} + +func (m Mult) String() string { + var sb strings.Builder + for i, u := range m.Unaries { + if i > 0 { + sb.WriteString(" * ") + } + sb.WriteString(u.String()) + } + return sb.String() +} + +func (p *parser) mult() (Mult, error) { + var res Mult + for { + u, err := p.unary() + if err != nil { + return res, err + } + res.Unaries = append(res.Unaries, u) + if p.peek().Text != "*" { + return res, nil + } + p.advance() + } +} + +// Unary := [UNARYOP]x4 Member + +type Unary struct { + Ops []UnaryOp + Member Member +} + +func (u Unary) String() string { + var sb strings.Builder + for _, o := range u.Ops { + sb.WriteString(string(o)) + } + sb.WriteString(u.Member.String()) + return sb.String() +} + +func (p *parser) unary() (Unary, error) { + var res Unary + for { + o := UnaryOp(p.peek().Text) + switch o { + case UnaryOpNot: + p.advance() + res.Ops = append(res.Ops, o) + case UnaryOpMinus: + p.advance() + if p.peek().isInt() { + t := p.advance() + i, err := strconv.ParseInt("-"+t.Text, 10, 64) + if err != nil { + return res, err + } + res.Member = Member{ + Primary: Primary{ + Type: PrimaryLiteral, + Literal: Literal{ + Type: LiteralInt, + Long: i, + }, + }, + } + return res, nil + } + res.Ops = append(res.Ops, o) + default: + var err error + res.Member, err = p.member() + if err != nil { + return res, err + } + return res, nil + } + } +} + +// UNARYOP := '!' | '-' + +type UnaryOp string + +const ( + UnaryOpNot UnaryOp = "!" + UnaryOpMinus UnaryOp = "-" +) + +// Member := Primary {Access} + +type Member struct { + Primary Primary + Accesses []Access +} + +func (m Member) String() string { + var sb strings.Builder + sb.WriteString(m.Primary.String()) + for _, a := range m.Accesses { + sb.WriteString(a.String()) + } + return sb.String() +} + +func (p *parser) member() (Member, error) { + var res Member + var err error + if res.Primary, err = p.primary(); err != nil { + return res, err + } + for { + a, ok, err := p.access() + if !ok { + return res, err + } else { + res.Accesses = append(res.Accesses, a) + } + } +} + +// Primary := LITERAL +// | VAR +// | Entity +// | ExtFun '(' [ExprList] ')' +// | '(' Expr ')' +// | '[' [ExprList] ']' +// | '{' [RecInits] '}' + +type PrimaryType int + +const ( + PrimaryLiteral PrimaryType = iota + PrimaryVar + PrimaryEntity + PrimaryExtFun + PrimaryExpr + PrimaryExprList + PrimaryRecInits +) + +type Primary struct { + Type PrimaryType + Literal Literal + Var Var + Entity Entity + ExtFun ExtFun + Expression Expression + Expressions []Expression + RecInits []RecInit +} + +func (p Primary) String() string { + var res string + switch p.Type { + case PrimaryLiteral: + res = p.Literal.String() + case PrimaryVar: + res = p.Var.String() + case PrimaryEntity: + res = p.Entity.String() + case PrimaryExtFun: + res = p.ExtFun.String() + case PrimaryExpr: + res = fmt.Sprintf("(%s)", p.Expression) + case PrimaryExprList: + var sb strings.Builder + sb.WriteRune('[') + for i, e := range p.Expressions { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(e.String()) + } + sb.WriteRune(']') + res = sb.String() + case PrimaryRecInits: + var sb strings.Builder + sb.WriteRune('{') + for i, r := range p.RecInits { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(r.String()) + } + sb.WriteRune('}') + res = sb.String() + } + return res +} + +func (p *parser) primary() (Primary, error) { + var res Primary + var err error + t := p.advance() + switch { + case t.isInt(): + i, err := t.intValue() + if err != nil { + return res, err + } + res.Type = PrimaryLiteral + res.Literal = Literal{ + Type: LiteralInt, + Long: i, + } + case t.isString(): + res.Type = PrimaryLiteral + res.Literal.Type = LiteralString + if res.Literal.Str, err = t.stringValue(); err != nil { + return res, err + } + case t.Text == "true", t.Text == "false": + res.Type = PrimaryLiteral + res.Literal = Literal{ + Type: LiteralBool, + Bool: t.Text == "true", + } + case t.Text == string(VarPrincipal), + t.Text == string(VarAction), + t.Text == string(VarResource), + t.Text == string(VarContext): + res.Type = PrimaryVar + res.Var = Var{ + Type: VarType(t.Text), + } + case t.isIdent(): + e, f, err := p.entityOrExtFun(t.Text) + switch { + case e != nil: + res.Type = PrimaryEntity + res.Entity = *e + case f != nil: + res.Type = PrimaryExtFun + res.ExtFun = *f + default: + return res, err + } + case t.Text == "(": + res.Type = PrimaryExpr + if res.Expression, err = p.expression(); err != nil { + return res, err + } + if err := p.exact(")"); err != nil { + return res, err + } + case t.Text == "[": + res.Type = PrimaryExprList + if res.Expressions, err = p.expressions("]"); err != nil { + return res, err + } + p.advance() // expressions guarantees "]" + return res, err + case t.Text == "{": + res.Type = PrimaryRecInits + if res.RecInits, err = p.recInits(); err != nil { + return res, err + } + return res, err + default: + return res, p.errorf("invalid primary") + } + return res, nil +} + +func (p *parser) entityOrExtFun(first string) (*Entity, *ExtFun, error) { + path := []string{first} + for { + if p.peek().Text != "::" { + f, err := p.extFun(path) + if err != nil { + return nil, nil, err + } + return nil, &f, err + } + p.advance() + t := p.advance() + switch { + case t.isIdent(): + path = append(path, t.Text) + case t.isString(): + component, err := t.stringValue() + if err != nil { + return nil, nil, err + } + path = append(path, component) + return &Entity{Path: path}, nil, nil + default: + return nil, nil, p.errorf("unexpected token") + } + } +} + +func (p *parser) expressions(endOfListMarker string) ([]Expression, error) { + var res []Expression + for p.peek().Text != endOfListMarker { + if len(res) > 0 { + if err := p.exact(","); err != nil { + return res, err + } + } + e, err := p.expression() + if err != nil { + return res, err + } + res = append(res, e) + } + return res, nil +} + +func (p *parser) recInits() ([]RecInit, error) { + var res []RecInit + for { + t := p.peek() + if t.Text == "}" { + p.advance() + return res, nil + } + if len(res) > 0 { + if err := p.exact(","); err != nil { + return res, err + } + } + e, err := p.recInit() + if err != nil { + return res, err + } + res = append(res, e) + } +} + +// LITERAL := BOOL | INT | STR + +type LiteralType int + +const ( + LiteralBool LiteralType = iota + LiteralInt + LiteralString +) + +type Literal struct { + Type LiteralType + Bool bool + Long int64 + Str string +} + +func (l Literal) String() string { + var res string + switch l.Type { + case LiteralBool: + res = strconv.FormatBool(l.Bool) + case LiteralInt: + res = strconv.FormatInt(l.Long, 10) + case LiteralString: + res = strconv.Quote(l.Str) + } + return res +} + +// VAR := 'principal' | 'action' | 'resource' | 'context' + +type VarType string + +const ( + VarPrincipal VarType = "principal" + VarAction VarType = "action" + VarResource VarType = "resource" + VarContext VarType = "context" +) + +type Var struct { + Type VarType +} + +func (v Var) String() string { + return string(v.Type) +} + +// ExtFun := [Path '::'] IDENT + +type ExtFun struct { + Path []string + Expressions []Expression +} + +func (f ExtFun) String() string { + var sb strings.Builder + sb.WriteString(strings.Join(f.Path, "::")) + sb.WriteRune('(') + for i, e := range f.Expressions { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(e.String()) + } + sb.WriteRune(')') + return sb.String() +} + +func (p *parser) extFun(path []string) (ExtFun, error) { + res := ExtFun{Path: path} + if err := p.exact("("); err != nil { + return res, err + } + var err error + if res.Expressions, err = p.expressions(")"); err != nil { + return res, err + } + p.advance() // expressions guarantees ")" + return res, err +} + +// Access := '.' IDENT ['(' [ExprList] ')'] | '[' STR ']' + +type AccessType int + +const ( + AccessField AccessType = iota + AccessCall + AccessIndex +) + +type Access struct { + Type AccessType + Name string + Expressions []Expression +} + +func (a Access) String() string { + var sb strings.Builder + switch a.Type { + case AccessField: + sb.WriteRune('.') + sb.WriteString(a.Name) + case AccessCall: + sb.WriteRune('.') + sb.WriteString(a.Name) + sb.WriteRune('(') + for i, e := range a.Expressions { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(e.String()) + } + sb.WriteRune(')') + case AccessIndex: + sb.WriteRune('[') + sb.WriteString(strconv.Quote(a.Name)) + sb.WriteRune(']') + } + return sb.String() +} + +func (p *parser) access() (Access, bool, error) { + var res Access + var err error + t := p.peek() + switch t.Text { + case ".": + p.advance() + t := p.advance() + if !t.isIdent() { + return res, false, p.errorf("unexpected token") + } + res.Name = t.Text + if p.peek().Text == "(" { + p.advance() + res.Type = AccessCall + if res.Expressions, err = p.expressions(")"); err != nil { + return res, false, err + } + p.advance() // expressions guarantees ")" + } else { + res.Type = AccessField + } + case "[": + p.advance() + res.Type = AccessIndex + t := p.advance() + if !t.isString() { + return res, false, p.errorf("unexpected token") + } + if res.Name, err = t.stringValue(); err != nil { + return res, false, err + } + if err := p.exact("]"); err != nil { + return res, false, err + } + default: + return res, false, nil + } + return res, true, nil +} + +// RecInits := (IDENT | STR) ':' Expr {',' (IDENT | STR) ':' Expr} + +type RecKeyType int + +const ( + RecKeyIdent RecKeyType = iota + RecKeyString +) + +type RecInit struct { + KeyType RecKeyType + Key string + Value Expression +} + +func (r RecInit) String() string { + var sb strings.Builder + switch r.KeyType { + case RecKeyIdent: + sb.WriteString(r.Key) + case RecKeyString: + sb.WriteString(strconv.Quote(r.Key)) + } + sb.WriteString(": ") + sb.WriteString(r.Value.String()) + return sb.String() +} + +func (p *parser) recInit() (RecInit, error) { + var res RecInit + var err error + t := p.advance() + switch { + case t.isIdent(): + res.KeyType = RecKeyIdent + res.Key = t.Text + case t.isString(): + res.KeyType = RecKeyString + if res.Key, err = t.stringValue(); err != nil { + return res, err + } + default: + return res, p.errorf("unexpected token") + } + if err := p.exact(":"); err != nil { + return res, err + } + if res.Value, err = p.expression(); err != nil { + return res, err + } + return res, nil +} diff --git a/x/exp/parser2/parse_test.go b/x/exp/parser2/parse_test.go new file mode 100644 index 00000000..7a6c0bb4 --- /dev/null +++ b/x/exp/parser2/parse_test.go @@ -0,0 +1,465 @@ +package parser + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/testutil" +) + +func TestParse(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + err bool + }{ + // Success cases + // Test cases from https://github.com/cedar-policy/cedar/blob/main/cedar-policy-core/src/parser/testfiles/policies.cedar + {"empty", ``, false}, + {"ex1", `//@test_annotation("This is the annotation") + permit( + principal == User::"alice", + action == PhotoOp::"view", + resource == Photo::"VacationPhoto94.jpg" + );`, false}, + {"ex2", `permit( + principal in Team::"admins", + action in [PhotoOp::"view", PhotoOp::"edit", PhotoOp::"delete"], + resource in Album::"jane_vacation" + );`, false}, + {"ex3", `permit( + principal == User::"alice", + action in PhotoflashRole::"admin", + resource in Album::"jane_vacation" + );`, false}, + {"simplest", `permit( + principal, + action, + resource + );`, false}, + {"in", `permit( + principal in Team::"eng", + action in PhotoflashRole::"admin", + resource in Album::"jane_vacation" + ); + + permit( + principal in Team::"eng", + action in [PhotoflashRole::"admin"], + resource in Album::"jane_vacation" + ); + + permit( + principal in Team::"eng", + action in [PhotoflashRole::"admin", PhotoflashRole::"operator"], + resource in Album::"jane_vacation" + ); + `, false}, + {"multipleIdentEntities", `permit( + principal == Org::Team::User::"alice", + action, + resource + );`, false}, + {"multiplePolicies", `permit( + principal, + action, + resource + ); + + forbid( + principal in Team::"admins", + action in [PhotoOp::"view", PhotoOp::"edit", PhotoOp::"delete"], + resource in Album::"jane_vacation" + ); + `, false}, + {"annotations", `@first_annotation("This is the annotation") + @second_annotation("This is another annotation") + permit( + principal, + action, + resource + );`, false}, + + // Additional success cases + {"primaryInt", `permit(principal, action, resource) when { 1234 };`, false}, + {"primaryString", `permit(principal, action, resource) when { "test string" };`, false}, + {"primaryBool", `permit(principal, action, resource) when { true } unless { false };`, false}, + {"primaryVar", `permit(principal, action, resource) + when { principal } + unless { action } + when { resource } + unless { context }; + `, false}, + {"primaryEntity", `permit(principal, action, resource) + when { Org::User::"alice" }; + `, false}, + {"primaryExtFun", `permit(principal, action, resource) + when { foo() } + unless { foo::bar::as() } + when { foo("hello") } + unless { foo::bar(true, 42, "forty two") }; + `, false}, + {"ifElseThen", `permit(principal, action, resource) + when { if false then principal else principal };`, false}, + {"access", `permit(principal, action, resource) + when { resource.foo } + unless { resource.foo.bar } + when { principal.foo() } + unless { principal.bar(false) } + when { action.foo["bar"].baz() } + unless { principal.bar(false, 123, "foo") } + when { principal["foo"] };`, false}, + {"unary", `permit(principal, action, resource) + when { !resource.foo } + unless { -resource.bar } + when { !!resource.foo } + unless { --resource.bar } + when { !-!-resource.bar };`, false}, + {"mult", `permit(principal, action, resource) + when { resource.foo * 42 } + unless { 42 * resource.bar } + when { 42 * resource.bar * 43 } + when { resource.foo * principal.bar };`, false}, + {"add", `permit(principal, action, resource) + when { resource.foo + 42 } + unless { 42 - resource.bar } + when { 42 + resource.bar - 43 } + when { resource.foo + principal.bar };`, false}, + {"relations", `permit(principal, action, resource) + when { foo() } + unless { foo() < 3 } + unless { foo() <= 3 } + unless { foo() > 3 } + unless { foo() >= 3 } + unless { foo() != 3 } + unless { foo() == 3 } + unless { foo() in Domain::"value" } + unless { foo() has blah } + when { foo() has "bar" } + when { foo() like "h*ll*" };`, false}, + {"foo-like-foo", `permit(principal, action, resource) + when { "f*o" like "f\*o" };`, false}, + {"ands", `permit(principal, action, resource) + when { foo() && bar() && 3};`, false}, + {"ors_and_ands", `permit(principal, action, resource) + when { foo() && bar() || baz() || 1 < 2 && 2 < 3};`, false}, + {"primaryExpression", `permit(principal, action, resource) + when { (true) } + unless { ((if (foo() <= 234) then principal else principal) like "") };`, false}, + {"primaryExprList", `permit(principal, action, resource) + when { [] } + unless { [true] } + when { [123, (principal has "name" && principal.name == "alice")]};`, false}, + {"primaryRecInits", `permit(principal, action, resource) + when { {} } + unless { {"key": "some value"} } + when { {"key": "some value", id: "another value"} };`, false}, + {"most-positive-long", + `permit(principal,action,resource) when { 9223372036854775807 == -(-9223372036854775807) };`, + false}, + {"principal-is", `permit (principal is X, action, resource);`, false}, + {"principal-is-long", `permit (principal is X::Y, action, resource);`, false}, + {"principal-is-in", `permit (principal is X in X::"z", action, resource);`, false}, + {"resource-is", `permit (principal, action, resource is X);`, false}, + {"resource-is-long", `permit (principal, action, resource is X::Y);`, false}, + {"resource-is-in", `permit (principal, action, resource is X in X::"z");`, false}, + {"when-is", `permit (principal, action, resource) when { principal is X };`, false}, + {"when-is-long", `permit (principal, action, resource) when { principal is X::Y };`, false}, + {"when-is-in", `permit (principal, action, resource) when { principal is X in X::"z" };`, false}, + + {"most-negative-long", `permit(principal,action,resource) when { -9223372036854775808 == -9223372036854775808 };`, false}, + {"most-negative-long2", `permit(principal,action,resource) when { -9223372036854775808 < -9223372036854775807 };`, false}, + + // Error cases + {"missingEffect", `@id("test")`, true}, + {"invalidEffect", `invalidEffect(principal, action, resource);`, true}, + {"missingSemicolon", `permit(principal, action, resource)`, true}, + {"missingScope", `permit;`, true}, + {"missingPrincipal", `permit(resource, action);`, true}, + {"missingResourceAndAction", `permit(principal);`, true}, + {"missingResource", `permit(principal, action);`, true}, + {"eofInScope", `permit(principal`, true}, + {"missingAction", `permit(principal, resource);`, true}, + {"invalidResource", `permit(principal, action, other);`, true}, + {"missingScopeEndParen", `permit(principal, action, resource;`, true}, + {"missingEntity", `permit(principal ==`, true}, + {"invalidEntity", `permit(principal == "alice", action, resource);`, true}, + {"invalidEntity2", `permit(principal == User::, action, resource);`, true}, + {"invalidEntity3", `permit(principal == User::123, action, resource);`, true}, + {"invalidEntity3", `permit(principal == User::`, true}, + {"invalidEntities", `permit(principal, action in [invalidEntity], resource);`, true}, + {"invalidEntities2", `permit(principal, action in [User::"alice", invalidEntity], resource);`, true}, + {"invalidEntities3", `permit(principal, action in [User::"alice";], resource);`, true}, + {"invalidEntities4", `permit(principal, action in [User::"alice"`, true}, + {"invalidAnnotation1", `@`, true}, + {"invalidAnnotation2", `@"annotate"`, true}, + {"invalidAnnotation3", `@annotate(`, true}, + {"invalidAnnotation4", `@annotate[""]`, true}, + {"invalidAnnotation5", `@annotate("test"]`, true}, + {"invalidAnnotation6", `@annotate(test)`, true}, + {"invalidCondition1", `permit(principal, action, resource) when`, true}, + {"invalidCondition2", `permit(principal, action, resource) when {`, true}, + {"invalidCondition3", `permit(principal, action, resource) when {}`, true}, + {"invalidCondition4", `permit(principal, action, resource) when { true`, true}, + {"invalidPrimaryInteger", `permit(principal, action, resource) + when { 0xabcd };`, true}, + {"invalidPrimary", `permit(principal, action, resource) + when { ( };`, true}, + {"invalidExtFun1", `permit(principal, action, resource) + when { abcd`, true}, + {"invalidExtFun2", `permit(principal, action, resource) + when { abcd(`, true}, + {"invalidExtFun3", `permit(principal, action, resource) + when { abcd::`, true}, + {"invalidExtFun4", `permit(principal, action, resource) + when { abcd::123`, true}, + {"invalidExtFun5", `permit(principal, action, resource) + when { abcd(123`, true}, + {"invalidIfElseThen1", `permit(principal, action, resource) + when { if }`, true}, + {"invalidIfElseThen2", `permit(principal, action, resource) + when { if true }`, true}, + {"invalidIfElseThen3", `permit(principal, action, resource) + when { if true then }`, true}, + {"invalidIfElseThen4", `permit(principal, action, resource) + when { if true then principal }`, true}, + {"invalidIfElseThen5", `permit(principal, action, resource) + when { if true then principal else }`, true}, + {"invalidAccess1", `permit(principal, action, resource) + when { resource.`, true}, + {"invalidAccess2", `permit(principal, action, resource) + when { resource.bar.123 };`, true}, + {"invalidAccess3", `permit(principal, action, resource) + when { resource.bar(`, true}, + {"invalidAccess4", `permit(principal, action, resource) + when { resource.bar(]`, true}, + {"invalidAccess5", `permit(principal, action, resource) + when { resource.bar(,)`, true}, + {"invalidAccess6", `permit(principal, action, resource) + when { resource.bar[`, true}, + {"invalidAccess7", `permit(principal, action, resource) + when { resource.bar[baz]`, true}, + {"invalidAccess8", `permit(principal, action, resource) + when { resource.bar["baz")`, true}, + {"invalidUnaryOp", `permit(principal, action, resource) + when { +resource.bar };`, true}, + {"invalidAdd", `permit(principal, action, resource) + when { resource.foo +`, true}, + {"invalidRelation", `permit(principal, action, resource) + when { resource.name in`, true}, + {"invalidHas1", `permit(principal, action, resource) + when { resource.name has`, true}, + {"invalidHas2", `permit(principal, action, resource) + when { resource.name has 123`, true}, + {"invalidLike1", `permit(principal, action, resource) + when { resource.name like`, true}, + {"invalidLike2", `permit(principal, action, resource) + when { resource.name like foo`, true}, + {"invalidPrimaryExpr", `permit(principal, action, resource) + when { (true`, true}, + {"invalidPrimaryExprList", `permit(principal, action, resource) + when { [`, true}, + {"invalidActionEqRhs", `permit(principal, action == Foo, resource);`, true}, + {"invalidActionInRhs", `permit(principal, action in Foo, resource);`, true}, + {"invalidPrimaryRecInits1", `permit(principal, action, resource) + when { {`, true}, + {"invalidPrimaryRecInits2", `permit(principal, action, resource) + when { {123: "value"} };`, true}, + {"invalidPrimaryRecInits3", `permit(principal, action, resource) + when { {"key" "value"} };`, true}, + {"invalidPrimaryRecInits4", `permit(principal, action, resource) + when { {"key":`, true}, + {"invalidPrimaryRecInits5", `permit(principal, action, resource) + when { {"key1": "value1" "key2": "value2" };`, true}, + + {"invalidStringAnnotation", `@bananas("\*") permit (principal, action, resource);`, true}, + {"invalidStringEntityID", `permit(principal == User::"\*", action, resource);`, true}, + {"invalidStringHas", `permit(principal, action, resource) when { context has "\*" };`, true}, + {"invalidNumericLike", `permit(principal, action, resource) when { context.key like 42 };`, true}, + {"invalidPatternLike", `permit(principal, action, resource) when { context.key like "\u{DFFF}" };`, true}, + {"invalidStringPrimary", `permit(principal, action, resource) when { context.key == "\*" };`, true}, + {"invalidExtFun", `permit(principal, action, resource) when { principal == User::"\*" };`, true}, + {"invalidAccess", `permit(principal, action, resource) when { context["\*"] == 42 };`, true}, + {"invalidRecordKey", `permit(principal, action, resource) when { { "\*":42 } };`, true}, + {"invalidIs", `permit (principal is 1, action, resource);`, true}, + {"invalidIsLong", `permit (principal is X::1, action, resource);`, true}, + {"duplicateAnnotations", `@key("value") @key("value") permit (principal, action, resource);`, true}, + + {"very-negative-long-bad", `permit(principal,action,resource) when { -9223372036823454775808 < -9224323372036854775807 };`, true}, + {"very-positive-long-bad", `permit(principal,action,resource) when { 9223372036823454775808 < 9224323372036854775807 };`, true}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tokens, err := Tokenize([]byte(tt.in)) + testutil.OK(t, err) + got, err := Parse(tokens) + testutil.Equals(t, err != nil, tt.err) + if err != nil { + testutil.Equals(t, got, nil) + return + } + + gotTokens, err := Tokenize([]byte(got.String())) + testutil.OK(t, err) + + var tokenStrs []string + for _, t := range tokens { + tokenStrs = append(tokenStrs, t.toString()) + } + + var gotTokenStrs []string + for _, t := range gotTokens { + gotTokenStrs = append(gotTokenStrs, t.toString()) + } + + testutil.Equals(t, gotTokenStrs, tokenStrs) + }) + } +} + +func TestParseTypes(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + out Policies + }{ + { + "first", + "permit(principal, action, resource) when { 3 * 2 > 5 };", + Policies{ + Policy{ + Position: Position{Offset: 0, Line: 1, Column: 1}, + Annotations: []Annotation(nil), + Effect: "permit", + Conditions: []Condition{ + { + Type: "when", + Expression: Expression{ + Type: ExpressionOr, + Or: Or{ + Ands: []And{ + { + Relations: []Relation{ + { + Add: Add{ + Mults: []Mult{ + { + Unaries: []Unary{ + { + Ops: []UnaryOp(nil), + Member: Member{ + Primary: Primary{ + Type: PrimaryLiteral, + Literal: Literal{Type: LiteralInt, Long: 3}, + }, + Accesses: []Access(nil), + }, + }, + { + Ops: []UnaryOp(nil), + Member: Member{ + Primary: Primary{ + Type: PrimaryLiteral, + Literal: Literal{Type: LiteralInt, Long: 2}, + }, + Accesses: []Access(nil), + }, + }, + }, + }, + }, + }, + Type: "relop", + RelOp: ">", + RelOpRhs: Add{ + Mults: []Mult{ + { + Unaries: []Unary{ + { + Ops: []UnaryOp(nil), + Member: Member{ + Primary: Primary{ + Type: PrimaryLiteral, + Literal: Literal{Type: LiteralInt, Long: 5}, + }, + Accesses: []Access(nil), + }, + }, + }, + }, + }, + }, + Str: "", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tokens, err := Tokenize([]byte(tt.in)) + testutil.OK(t, err) + got, err := Parse(tokens) + testutil.OK(t, err) + testutil.Equals(t, got, tt.out) + }) + } +} + +func TestParseEntity(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + out Entity + err func(testing.TB, error) + }{ + {"happy", `Action::"test"`, Entity{Path: []string{"Action", "test"}}, testutil.OK}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + toks, err := Tokenize([]byte(tt.in)) + testutil.OK(t, err) + out, err := ParseEntity(toks) + testutil.Equals(t, out, tt.out) + tt.err(t, err) + }) + } +} + +func TestPolicyPositions(t *testing.T) { + t.Parallel() + in := `// idk a comment +@blah("asdf") +permit( principal, action, resource ); + + +// later on + permit (principal, action, resource) ; + +// annotation indent + @test("1234") permit (principal, action, resource ); +` + toks, err := Tokenize([]byte(in)) + testutil.OK(t, err) + out, err := Parse(toks) + testutil.OK(t, err) + testutil.Equals(t, len(out), 3) + testutil.Equals(t, out[0].Position, Position{Offset: 17, Line: 2, Column: 1}) + testutil.Equals(t, out[1].Position, Position{Offset: 86, Line: 7, Column: 3}) + testutil.Equals(t, out[2].Position, Position{Offset: 148, Line: 10, Column: 2}) +} diff --git a/x/exp/parser2/tokenize.go b/x/exp/parser2/tokenize.go new file mode 100644 index 00000000..e2e41d65 --- /dev/null +++ b/x/exp/parser2/tokenize.go @@ -0,0 +1,705 @@ +package parser + +import ( + "bytes" + "fmt" + "io" + "strconv" + "strings" + "unicode" + "unicode/utf8" +) + +//go:generate moq -pkg parser -fmt goimports -out tokenize_mocks_test.go . reader + +// This type alias is for test purposes only. +type reader = io.Reader + +type TokenType int + +const ( + TokenEOF = TokenType(iota) + TokenIdent + TokenInt + TokenString + TokenOperator + TokenUnknown +) + +type Token struct { + Type TokenType + Pos Position + Text string +} + +func (t Token) isEOF() bool { + return t.Type == TokenEOF +} + +func (t Token) isIdent() bool { + return t.Type == TokenIdent +} + +func (t Token) isInt() bool { + return t.Type == TokenInt +} + +func (t Token) isString() bool { + return t.Type == TokenString +} + +func (t Token) toString() string { + return t.Text +} + +func (t Token) stringValue() (string, error) { + s := t.Text + s = strings.TrimPrefix(s, "\"") + s = strings.TrimSuffix(s, "\"") + b := []byte(s) + res, _, err := rustUnquote(b, false) + return res, err +} + +func (t Token) patternValue() (Pattern, error) { + return NewPattern(t.Text) +} + +func nextRune(b []byte, i int) (rune, int, error) { + ch, size := utf8.DecodeRune(b[i:]) + if ch == utf8.RuneError { + return ch, i, fmt.Errorf("bad unicode rune") + } + return ch, i + size, nil +} + +func parseHexEscape(b []byte, i int) (rune, int, error) { + var ch rune + var err error + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if !isHexadecimal(ch) { + return 0, i, fmt.Errorf("bad hex escape sequence") + } + res := digitVal(ch) + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if !isHexadecimal(ch) { + return 0, i, fmt.Errorf("bad hex escape sequence") + } + res = 16*res + digitVal(ch) + if res > 127 { + return 0, i, fmt.Errorf("bad hex escape sequence") + } + return rune(res), i, nil +} + +func parseUnicodeEscape(b []byte, i int) (rune, int, error) { + var ch rune + var err error + + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if ch != '{' { + return 0, i, fmt.Errorf("bad unicode escape sequence") + } + + digits := 0 + res := 0 + for { + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if ch == '}' { + break + } + if !isHexadecimal(ch) { + return 0, i, fmt.Errorf("bad unicode escape sequence") + } + res = 16*res + digitVal(ch) + digits++ + } + + if digits == 0 || digits > 6 || !utf8.ValidRune(rune(res)) { + return 0, i, fmt.Errorf("bad unicode escape sequence") + } + + return rune(res), i, nil +} + +func Unquote(s string) (string, error) { + s = strings.TrimPrefix(s, "\"") + s = strings.TrimSuffix(s, "\"") + res, _, err := rustUnquote([]byte(s), false) + return res, err +} + +func rustUnquote(b []byte, star bool) (string, []byte, error) { + var sb strings.Builder + var ch rune + var err error + i := 0 + for i < len(b) { + ch, i, err = nextRune(b, i) + if err != nil { + return "", nil, err + } + if star && ch == '*' { + i-- + return sb.String(), b[i:], nil + } + if ch != '\\' { + sb.WriteRune(ch) + continue + } + ch, i, err = nextRune(b, i) + if err != nil { + return "", nil, err + } + switch ch { + case 'n': + sb.WriteRune('\n') + case 'r': + sb.WriteRune('\r') + case 't': + sb.WriteRune('\t') + case '\\': + sb.WriteRune('\\') + case '0': + sb.WriteRune('\x00') + case '\'': + sb.WriteRune('\'') + case '"': + sb.WriteRune('"') + case 'x': + ch, i, err = parseHexEscape(b, i) + if err != nil { + return "", nil, err + } + sb.WriteRune(ch) + case 'u': + ch, i, err = parseUnicodeEscape(b, i) + if err != nil { + return "", nil, err + } + sb.WriteRune(ch) + case '*': + if !star { + return "", nil, fmt.Errorf("bad char escape") + } + sb.WriteRune('*') + default: + return "", nil, fmt.Errorf("bad char escape") + } + } + return sb.String(), b[i:], nil +} + +type PatternComponent struct { + Star bool + Chunk string +} + +type Pattern struct { + Comps []PatternComponent + Raw string +} + +func (p Pattern) String() string { + return p.Raw +} + +func NewPattern(literal string) (Pattern, error) { + rawPat := literal + + literal = strings.TrimPrefix(literal, "\"") + literal = strings.TrimSuffix(literal, "\"") + + b := []byte(literal) + + var comps []PatternComponent + for len(b) > 0 { + var comp PatternComponent + var err error + for len(b) > 0 && b[0] == '*' { + b = b[1:] + comp.Star = true + } + comp.Chunk, b, err = rustUnquote(b, true) + if err != nil { + return Pattern{}, err + } + comps = append(comps, comp) + } + return Pattern{ + Comps: comps, + Raw: rawPat, + }, nil +} + +func isHexadecimal(ch rune) bool { + return isDecimal(ch) || ('a' <= lower(ch) && lower(ch) <= 'f') +} + +// TODO: make FakeRustQuote actually accurate in all cases +func FakeRustQuote(s string) string { + return strconv.Quote(s) +} + +func (t Token) intValue() (int64, error) { + return strconv.ParseInt(t.Text, 10, 64) +} + +func Tokenize(src []byte) ([]Token, error) { + var res []Token + var s scanner + s.Init(bytes.NewBuffer(src)) + for tok := s.nextToken(); s.err == nil && tok.Type != TokenEOF; tok = s.nextToken() { + res = append(res, tok) + } + if s.err != nil { + return nil, s.err + } + res = append(res, Token{Type: TokenEOF, Pos: s.position}) + return res, nil +} + +// Position is a value that represents a source position. +// A position is valid if Line > 0. +type Position struct { + Offset int // byte offset, starting at 0 + Line int // line number, starting at 1 + Column int // column number, starting at 1 (character count per line) +} + +func (pos Position) String() string { + return fmt.Sprintf(":%d:%d", pos.Line, pos.Column) +} + +const ( + specialRuneEOF = rune(-(iota + 1)) + specialRuneBOF +) + +const bufLen = 1024 // at least utf8.UTFMax + +// A scanner implements reading of Unicode characters and tokens from an io.Reader. +type scanner struct { + // Input + src io.Reader + + // Source buffer + srcBuf [bufLen + 1]byte // +1 for sentinel for common case of s.next() + srcPos int // reading position (srcBuf index) + srcEnd int // source end (srcBuf index) + + // Source position + srcBufOffset int // byte offset of srcBuf[0] in source + line int // line count + column int // character count + lastLineLen int // length of last line in characters (for correct column reporting) + lastCharLen int // length of last character in bytes + + // Token text buffer + // Typically, token text is stored completely in srcBuf, but in general + // the token text's head may be buffered in tokBuf while the token text's + // tail is stored in srcBuf. + tokBuf bytes.Buffer // token text head that is not in srcBuf anymore + tokPos int // token text tail position (srcBuf index); valid if >= 0 + tokEnd int // token text tail end (srcBuf index) + + // One character look-ahead + ch rune // character before current srcPos + + // Last error encountered by nextToken. + err error + + // Start position of most recently scanned token; set by nextToken. + // Calling Init or Next invalidates the position (Line == 0). + // If an error is reported (via Error) and position is invalid, + // the scanner is not inside a token. Call Pos to obtain an error + // position in that case, or to obtain the position immediately + // after the most recently scanned token. + position Position +} + +// Init initializes a Scanner with a new source and returns s. +func (s *scanner) Init(src io.Reader) *scanner { + s.src = src + + // initialize source buffer + // (the first call to next() will fill it by calling src.Read) + s.srcBuf[0] = utf8.RuneSelf // sentinel + s.srcPos = 0 + s.srcEnd = 0 + + // initialize source position + s.srcBufOffset = 0 + s.line = 1 + s.column = 0 + s.lastLineLen = 0 + s.lastCharLen = 0 + + // initialize token text buffer + // (required for first call to next()). + s.tokPos = -1 + + // initialize one character look-ahead + s.ch = specialRuneBOF // no char read yet, not EOF + + // initialize public fields + s.position.Line = 0 // invalidate token position + + return s +} + +// next reads and returns the next Unicode character. It is designed such +// that only a minimal amount of work needs to be done in the common ASCII +// case (one test to check for both ASCII and end-of-buffer, and one test +// to check for newlines). +func (s *scanner) next() rune { + ch, width := rune(s.srcBuf[s.srcPos]), 1 + + if ch >= utf8.RuneSelf { + // uncommon case: not ASCII or not enough bytes + for s.srcPos+utf8.UTFMax > s.srcEnd && !utf8.FullRune(s.srcBuf[s.srcPos:s.srcEnd]) { + // not enough bytes: read some more, but first + // save away token text if any + if s.tokPos >= 0 { + s.tokBuf.Write(s.srcBuf[s.tokPos:s.srcPos]) + s.tokPos = 0 + // s.tokEnd is set by nextToken() + } + // move unread bytes to beginning of buffer + copy(s.srcBuf[0:], s.srcBuf[s.srcPos:s.srcEnd]) + s.srcBufOffset += s.srcPos + // read more bytes + // (an io.Reader must return io.EOF when it reaches + // the end of what it is reading - simply returning + // n == 0 will make this loop retry forever; but the + // error is in the reader implementation in that case) + i := s.srcEnd - s.srcPos + n, err := s.src.Read(s.srcBuf[i:bufLen]) + s.srcPos = 0 + s.srcEnd = i + n + s.srcBuf[s.srcEnd] = utf8.RuneSelf // sentinel + if err != nil { + if err != io.EOF { + s.error(err.Error()) + } + if s.srcEnd == 0 { + if s.lastCharLen > 0 { + // previous character was not EOF + s.column++ + } + s.lastCharLen = 0 + return specialRuneEOF + } + // If err == EOF, we won't be getting more + // bytes; break to avoid infinite loop. If + // err is something else, we don't know if + // we can get more bytes; thus also break. + break + } + } + // at least one byte + ch = rune(s.srcBuf[s.srcPos]) + if ch >= utf8.RuneSelf { + // uncommon case: not ASCII + ch, width = utf8.DecodeRune(s.srcBuf[s.srcPos:s.srcEnd]) + if ch == utf8.RuneError && width == 1 { + // advance for correct error position + s.srcPos += width + s.lastCharLen = width + s.column++ + s.error("invalid UTF-8 encoding") + return ch + } + } + } + + // advance + s.srcPos += width + s.lastCharLen = width + s.column++ + + // special situations + switch ch { + case 0: + // for compatibility with other tools + s.error("invalid character NUL") + case '\n': + s.line++ + s.lastLineLen = s.column + s.column = 0 + } + + return ch +} + +func (s *scanner) error(msg string) { + s.tokEnd = s.srcPos - s.lastCharLen // make sure token text is terminated + s.err = fmt.Errorf("%v: %v", s.position, msg) +} + +func isIdentRune(ch rune, first bool) bool { + return ch == '_' || unicode.IsLetter(ch) || unicode.IsDigit(ch) && !first +} + +func (s *scanner) scanIdentifier() rune { + // we know the zeroth rune is OK; start scanning at the next one + ch := s.next() + for isIdentRune(ch, false) { + ch = s.next() + } + return ch +} + +func lower(ch rune) rune { return ('a' - 'A') | ch } // returns lower-case ch iff ch is ASCII letter +func isDecimal(ch rune) bool { return '0' <= ch && ch <= '9' } + +func (s *scanner) scanInteger(ch rune) rune { + for isDecimal(ch) { + ch = s.next() + } + return ch +} + +func digitVal(ch rune) int { + switch { + case '0' <= ch && ch <= '9': + return int(ch - '0') + case 'a' <= lower(ch) && lower(ch) <= 'f': + return int(lower(ch) - 'a' + 10) + } + return 16 // larger than any legal digit val +} + +func (s *scanner) scanHexDigits(ch rune, min, max int) rune { + n := 0 + for n < max && isHexadecimal(ch) { + ch = s.next() + n++ + } + if n < min || n > max { + s.error("invalid char escape") + } + return ch +} + +func (s *scanner) scanEscape() rune { + ch := s.next() // read character after '/' + switch ch { + case 'n', 'r', 't', '\\', '0', '\'', '"', '*': + // nothing to do + ch = s.next() + case 'x': + ch = s.scanHexDigits(s.next(), 2, 2) + case 'u': + ch = s.next() + if ch != '{' { + s.error("invalid char escape") + return ch + } + ch = s.scanHexDigits(s.next(), 1, 6) + if ch != '}' { + s.error("invalid char escape") + return ch + } + ch = s.next() + default: + s.error("invalid char escape") + } + return ch +} + +func (s *scanner) scanString() (n int) { + ch := s.next() // read character after quote + for ch != '"' { + if ch == '\n' || ch < 0 { + s.error("literal not terminated") + return + } + if ch == '\\' { + ch = s.scanEscape() + } else { + ch = s.next() + } + n++ + } + return +} + +func (s *scanner) scanComment(ch rune) rune { + // ch == '/' || ch == '*' + if ch == '/' { + // line comment + ch = s.next() // read character after "//" + for ch != '\n' && ch >= 0 { + ch = s.next() + } + return ch + } + + // general comment + ch = s.next() // read character after "/*" + for { + if ch < 0 { + s.error("comment not terminated") + break + } + ch0 := ch + ch = s.next() + if ch0 == '*' && ch == '/' { + ch = s.next() + break + } + } + return ch +} + +func (s *scanner) scanOperator(ch0, ch rune) (TokenType, rune) { + switch ch0 { + case '@', '.', ',', ';', '(', ')', '{', '}', '[', ']', '+', '-', '*': + case ':': + if ch == ':' { + ch = s.next() + } + case '!', '<', '>': + if ch == '=' { + ch = s.next() + } + case '=': + if ch != '=' { + return TokenUnknown, ch + } + ch = s.next() + case '|': + if ch != '|' { + return TokenUnknown, ch + } + ch = s.next() + case '&': + if ch != '&' { + return TokenUnknown, ch + } + ch = s.next() + default: + return TokenUnknown, ch + } + return TokenOperator, ch +} + +func isWhitespace(c rune) bool { + switch c { + case '\t', '\n', '\r', ' ': + return true + default: + return false + } +} + +// nextToken reads the next token or Unicode character from source and returns +// it. It returns specialRuneEOF at the end of the source. It reports scanner +// errors (read and token errors) by calling s.Error, if not nil; otherwise it +// prints an error message to os.Stderr. +func (s *scanner) nextToken() Token { + if s.ch == specialRuneBOF { + s.ch = s.next() + } + + ch := s.ch + + // reset token text position + s.tokPos = -1 + s.position.Line = 0 + +redo: + // skip white space + for isWhitespace(ch) { + ch = s.next() + } + + // start collecting token text + s.tokBuf.Reset() + s.tokPos = s.srcPos - s.lastCharLen + + // set token position + s.position.Offset = s.srcBufOffset + s.tokPos + if s.column > 0 { + // common case: last character was not a '\n' + s.position.Line = s.line + s.position.Column = s.column + } else { + // last character was a '\n' + // (we cannot be at the beginning of the source + // since we have called next() at least once) + s.position.Line = s.line - 1 + s.position.Column = s.lastLineLen + } + + // determine token value + var tt TokenType + switch { + case ch == specialRuneEOF: + tt = TokenEOF + case isIdentRune(ch, true): + ch = s.scanIdentifier() + tt = TokenIdent + case isDecimal(ch): + ch = s.scanInteger(ch) + tt = TokenInt + case ch == '"': + s.scanString() + ch = s.next() + tt = TokenString + case ch == '/': + ch0 := ch + ch = s.next() + if ch == '/' || ch == '*' { + s.tokPos = -1 // don't collect token text + ch = s.scanComment(ch) + goto redo + } + tt, ch = s.scanOperator(ch0, ch) + default: + tt, ch = s.scanOperator(ch, s.next()) + } + + // end of token text + s.tokEnd = s.srcPos - s.lastCharLen + s.ch = ch + + return Token{ + Type: tt, + Pos: s.position, + Text: s.tokenText(), + } +} + +// tokenText returns the string corresponding to the most recently scanned token. +// Valid after calling nextToken and in calls of Scanner.Error. +func (s *scanner) tokenText() string { + if s.tokPos < 0 { + // no token text + return "" + } + + if s.tokBuf.Len() == 0 { + // common case: the entire token text is still in srcBuf + return string(s.srcBuf[s.tokPos:s.tokEnd]) + } + + // part of the token text was saved in tokBuf: save the rest in + // tokBuf as well and return its content + s.tokBuf.Write(s.srcBuf[s.tokPos:s.tokEnd]) + s.tokPos = s.tokEnd // ensure idempotency of TokenText() call + return s.tokBuf.String() +} diff --git a/x/exp/parser2/tokenize_mocks_test.go b/x/exp/parser2/tokenize_mocks_test.go new file mode 100644 index 00000000..ff5a98fc --- /dev/null +++ b/x/exp/parser2/tokenize_mocks_test.go @@ -0,0 +1,74 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package parser + +import ( + "sync" +) + +// Ensure, that readerMock does implement reader. +// If this is not the case, regenerate this file with moq. +var _ reader = &readerMock{} + +// readerMock is a mock implementation of reader. +// +// func TestSomethingThatUsesreader(t *testing.T) { +// +// // make and configure a mocked reader +// mockedreader := &readerMock{ +// ReadFunc: func(p []byte) (int, error) { +// panic("mock out the Read method") +// }, +// } +// +// // use mockedreader in code that requires reader +// // and then make assertions. +// +// } +type readerMock struct { + // ReadFunc mocks the Read method. + ReadFunc func(p []byte) (int, error) + + // calls tracks calls to the methods. + calls struct { + // Read holds details about calls to the Read method. + Read []struct { + // P is the p argument value. + P []byte + } + } + lockRead sync.RWMutex +} + +// Read calls ReadFunc. +func (mock *readerMock) Read(p []byte) (int, error) { + if mock.ReadFunc == nil { + panic("readerMock.ReadFunc: method is nil but reader.Read was just called") + } + callInfo := struct { + P []byte + }{ + P: p, + } + mock.lockRead.Lock() + mock.calls.Read = append(mock.calls.Read, callInfo) + mock.lockRead.Unlock() + return mock.ReadFunc(p) +} + +// ReadCalls gets all the calls that were made to Read. +// Check the length with: +// +// len(mockedreader.ReadCalls()) +func (mock *readerMock) ReadCalls() []struct { + P []byte +} { + var calls []struct { + P []byte + } + mock.lockRead.RLock() + calls = mock.calls.Read + mock.lockRead.RUnlock() + return calls +} diff --git a/x/exp/parser2/tokenize_test.go b/x/exp/parser2/tokenize_test.go new file mode 100644 index 00000000..42d9911a --- /dev/null +++ b/x/exp/parser2/tokenize_test.go @@ -0,0 +1,554 @@ +package parser + +import ( + "fmt" + "io" + "strings" + "testing" + "unicode/utf8" + + "github.com/cedar-policy/cedar-go/testutil" +) + +func TestTokenize(t *testing.T) { + t.Parallel() + input := ` +These are some identifiers +0 1 1234 +-1 9223372036854775807 -9223372036854775808 +"" "string" "\"\'\n\r\t\\\0" "\x123" "\u{0}\u{10fFfF}" +"*" "\*" "*\**" +@.,;(){}[]+-* +::: +!!=<<=>>= +||&& +// single line comment +/* +multiline comment +// embedded comment does nothing +*/ +'/%|&=` + want := []Token{ + {Type: TokenIdent, Text: "These", Pos: Position{Offset: 1, Line: 2, Column: 1}}, + {Type: TokenIdent, Text: "are", Pos: Position{Offset: 7, Line: 2, Column: 7}}, + {Type: TokenIdent, Text: "some", Pos: Position{Offset: 11, Line: 2, Column: 11}}, + {Type: TokenIdent, Text: "identifiers", Pos: Position{Offset: 16, Line: 2, Column: 16}}, + + {Type: TokenInt, Text: "0", Pos: Position{Offset: 28, Line: 3, Column: 1}}, + {Type: TokenInt, Text: "1", Pos: Position{Offset: 30, Line: 3, Column: 3}}, + {Type: TokenInt, Text: "1234", Pos: Position{Offset: 32, Line: 3, Column: 5}}, + + {Type: TokenOperator, Text: "-", Pos: Position{Offset: 37, Line: 4, Column: 1}}, + {Type: TokenInt, Text: "1", Pos: Position{Offset: 38, Line: 4, Column: 2}}, + {Type: TokenInt, Text: "9223372036854775807", Pos: Position{Offset: 40, Line: 4, Column: 4}}, + {Type: TokenOperator, Text: "-", Pos: Position{Offset: 60, Line: 4, Column: 24}}, + {Type: TokenInt, Text: "9223372036854775808", Pos: Position{Offset: 61, Line: 4, Column: 25}}, + + {Type: TokenString, Text: `""`, Pos: Position{Offset: 81, Line: 5, Column: 1}}, + {Type: TokenString, Text: `"string"`, Pos: Position{Offset: 84, Line: 5, Column: 4}}, + {Type: TokenString, Text: `"\"\'\n\r\t\\\0"`, Pos: Position{Offset: 93, Line: 5, Column: 13}}, + {Type: TokenString, Text: `"\x123"`, Pos: Position{Offset: 110, Line: 5, Column: 30}}, + {Type: TokenString, Text: `"\u{0}\u{10fFfF}"`, Pos: Position{Offset: 118, Line: 5, Column: 38}}, + + {Type: TokenString, Text: `"*"`, Pos: Position{Offset: 136, Line: 6, Column: 1}}, + {Type: TokenString, Text: `"\*"`, Pos: Position{Offset: 140, Line: 6, Column: 5}}, + {Type: TokenString, Text: `"*\**"`, Pos: Position{Offset: 145, Line: 6, Column: 10}}, + + {Type: TokenOperator, Text: "@", Pos: Position{Offset: 152, Line: 7, Column: 1}}, + {Type: TokenOperator, Text: ".", Pos: Position{Offset: 153, Line: 7, Column: 2}}, + {Type: TokenOperator, Text: ",", Pos: Position{Offset: 154, Line: 7, Column: 3}}, + {Type: TokenOperator, Text: ";", Pos: Position{Offset: 155, Line: 7, Column: 4}}, + {Type: TokenOperator, Text: "(", Pos: Position{Offset: 156, Line: 7, Column: 5}}, + {Type: TokenOperator, Text: ")", Pos: Position{Offset: 157, Line: 7, Column: 6}}, + {Type: TokenOperator, Text: "{", Pos: Position{Offset: 158, Line: 7, Column: 7}}, + {Type: TokenOperator, Text: "}", Pos: Position{Offset: 159, Line: 7, Column: 8}}, + {Type: TokenOperator, Text: "[", Pos: Position{Offset: 160, Line: 7, Column: 9}}, + {Type: TokenOperator, Text: "]", Pos: Position{Offset: 161, Line: 7, Column: 10}}, + {Type: TokenOperator, Text: "+", Pos: Position{Offset: 162, Line: 7, Column: 11}}, + {Type: TokenOperator, Text: "-", Pos: Position{Offset: 163, Line: 7, Column: 12}}, + {Type: TokenOperator, Text: "*", Pos: Position{Offset: 164, Line: 7, Column: 13}}, + + {Type: TokenOperator, Text: "::", Pos: Position{Offset: 166, Line: 8, Column: 1}}, + {Type: TokenOperator, Text: ":", Pos: Position{Offset: 168, Line: 8, Column: 3}}, + + {Type: TokenOperator, Text: "!", Pos: Position{Offset: 170, Line: 9, Column: 1}}, + {Type: TokenOperator, Text: "!=", Pos: Position{Offset: 171, Line: 9, Column: 2}}, + {Type: TokenOperator, Text: "<", Pos: Position{Offset: 173, Line: 9, Column: 4}}, + {Type: TokenOperator, Text: "<=", Pos: Position{Offset: 174, Line: 9, Column: 5}}, + {Type: TokenOperator, Text: ">", Pos: Position{Offset: 176, Line: 9, Column: 7}}, + {Type: TokenOperator, Text: ">=", Pos: Position{Offset: 177, Line: 9, Column: 8}}, + + {Type: TokenOperator, Text: "||", Pos: Position{Offset: 180, Line: 10, Column: 1}}, + {Type: TokenOperator, Text: "&&", Pos: Position{Offset: 182, Line: 10, Column: 3}}, + + {Type: TokenUnknown, Text: "'", Pos: Position{Offset: 265, Line: 16, Column: 1}}, + {Type: TokenUnknown, Text: "/", Pos: Position{Offset: 266, Line: 16, Column: 2}}, + {Type: TokenUnknown, Text: "%", Pos: Position{Offset: 267, Line: 16, Column: 3}}, + {Type: TokenUnknown, Text: "|", Pos: Position{Offset: 268, Line: 16, Column: 4}}, + {Type: TokenUnknown, Text: "&", Pos: Position{Offset: 269, Line: 16, Column: 5}}, + {Type: TokenUnknown, Text: "=", Pos: Position{Offset: 270, Line: 16, Column: 6}}, + + {Type: TokenEOF, Text: "", Pos: Position{Offset: 271, Line: 16, Column: 7}}, + } + got, err := Tokenize([]byte(input)) + testutil.OK(t, err) + testutil.Equals(t, got, want) +} + +func TestTokenizeErrors(t *testing.T) { + t.Parallel() + tests := []struct { + input string + wantErrStr string + wantErrPos Position + }{ + {"okay\x00not okay", "invalid character NUL", Position{Line: 1, Column: 1}}, + {`okay /* + stuff + `, "comment not terminated", Position{Line: 1, Column: 6}}, + {`okay " + " foo bar`, "literal not terminated", Position{Line: 1, Column: 6}}, + {`"okay" "\a"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\b"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\f"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\v"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\1"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\x"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\x1"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\ubadf"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\U0000badf"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\u{}"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\u{0000000}"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\u{z"`, "invalid char escape", Position{Line: 1, Column: 8}}, + } + for i, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("%02d", i), func(t *testing.T) { + t.Parallel() + got, gotErr := Tokenize([]byte(tt.input)) + wantErrStr := fmt.Sprintf("%v: %s", tt.wantErrPos, tt.wantErrStr) + testutil.Error(t, gotErr) + testutil.Equals(t, gotErr.Error(), wantErrStr) + testutil.Equals(t, got, nil) + }) + } +} + +func TestIntTokenValues(t *testing.T) { + t.Parallel() + tests := []struct { + input string + wantOk bool + want int64 + wantErr string + }{ + {"0", true, 0, ""}, + {"9223372036854775807", true, 9223372036854775807, ""}, + {"9223372036854775808", false, 0, `strconv.ParseInt: parsing "9223372036854775808": value out of range`}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, err := Tokenize([]byte(tt.input)) + testutil.OK(t, err) + testutil.Equals(t, len(got), 2) + testutil.Equals(t, got[0].Type, TokenInt) + gotInt, err := got[0].intValue() + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, gotInt, tt.want) + } + }) + } +} + +func TestStringTokenValues(t *testing.T) { + t.Parallel() + tests := []struct { + input string + wantOk bool + want string + wantErr string + }{ + {`""`, true, "", ""}, + {`"hello"`, true, "hello", ""}, + {`"a\n\r\t\\\0b"`, true, "a\n\r\t\\\x00b", ""}, + {`"a\"b"`, true, "a\"b", ""}, + {`"a\'b"`, true, "a'b", ""}, + + {`"a\x00b"`, true, "a\x00b", ""}, + {`"a\x7fb"`, true, "a\x7fb", ""}, + {`"a\x80b"`, false, "", "bad hex escape sequence"}, + + {`"a\u{A}b"`, true, "a\u000ab", ""}, + {`"a\u{aB}b"`, true, "a\u00abb", ""}, + {`"a\u{AbC}b"`, true, "a\u0abcb", ""}, + {`"a\u{aBcD}b"`, true, "a\uabcdb", ""}, + {`"a\u{AbCdE}b"`, true, "a\U000abcdeb", ""}, + {`"a\u{10cDeF}b"`, true, "a\U0010cdefb", ""}, + {`"a\u{ffffff}b"`, false, "", "bad unicode escape sequence"}, + {`"a\u{d7ff}b"`, true, "a\ud7ffb", ""}, + {`"a\u{d800}b"`, false, "", "bad unicode escape sequence"}, + {`"a\u{dfff}b"`, false, "", "bad unicode escape sequence"}, + {`"a\u{e000}b"`, true, "a\ue000b", ""}, + {`"a\u{10ffff}b"`, true, "a\U0010ffffb", ""}, + {`"a\u{110000}b"`, false, "", "bad unicode escape sequence"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, err := Tokenize([]byte(tt.input)) + testutil.OK(t, err) + testutil.Equals(t, len(got), 2) + testutil.Equals(t, got[0].Type, TokenString) + gotStr, err := got[0].stringValue() + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, gotStr, tt.want) + } + }) + } +} + +func TestParseUnicodeEscape(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in []byte + out rune + outN int + err func(t testing.TB, err error) + }{ + {"happy", []byte{'{', '4', '2', '}'}, 0x42, 4, testutil.OK}, + {"badRune", []byte{'{', 0x80, 0x81}, 0, 1, testutil.Error}, + {"notHex", []byte{'{', 'g'}, 0, 2, testutil.Error}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out, n, err := parseUnicodeEscape(tt.in, 0) + testutil.Equals(t, out, tt.out) + testutil.Equals(t, n, tt.outN) + tt.err(t, err) + }) + } +} + +func TestUnquote(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + out string + err func(t testing.TB, err error) + }{ + {"happy", `"test"`, `test`, testutil.OK}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out, err := Unquote(tt.in) + testutil.Equals(t, out, tt.out) + tt.err(t, err) + }) + } +} + +func TestRustUnquote(t *testing.T) { + t.Parallel() + // star == false + { + tests := []struct { + input string + wantOk bool + want string + wantErr string + }{ + {``, true, "", ""}, + {`hello`, true, "hello", ""}, + {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", ""}, + {`a\"b`, true, "a\"b", ""}, + {`a\'b`, true, "a'b", ""}, + + {`a\x00b`, true, "a\x00b", ""}, + {`a\x7fb`, true, "a\x7fb", ""}, + {`a\x80b`, false, "", "bad hex escape sequence"}, + + {string([]byte{0x80, 0x81}), false, "", "bad unicode rune"}, + {`a\u`, false, "", "bad unicode rune"}, + {`a\uz`, false, "", "bad unicode escape sequence"}, + {`a\u{}b`, false, "", "bad unicode escape sequence"}, + {`a\u{A}b`, true, "a\u000ab", ""}, + {`a\u{aB}b`, true, "a\u00abb", ""}, + {`a\u{AbC}b`, true, "a\u0abcb", ""}, + {`a\u{aBcD}b`, true, "a\uabcdb", ""}, + {`a\u{AbCdE}b`, true, "a\U000abcdeb", ""}, + {`a\u{10cDeF}b`, true, "a\U0010cdefb", ""}, + {`a\u{ffffff}b`, false, "", "bad unicode escape sequence"}, + {`a\u{0000000}b`, false, "", "bad unicode escape sequence"}, + {`a\u{d7ff}b`, true, "a\ud7ffb", ""}, + {`a\u{d800}b`, false, "", "bad unicode escape sequence"}, + {`a\u{dfff}b`, false, "", "bad unicode escape sequence"}, + {`a\u{e000}b`, true, "a\ue000b", ""}, + {`a\u{10ffff}b`, true, "a\U0010ffffb", ""}, + {`a\u{110000}b`, false, "", "bad unicode escape sequence"}, + + {`\`, false, "", "bad unicode rune"}, + {`\a`, false, "", "bad char escape"}, + {`\*`, false, "", "bad char escape"}, + {`\x`, false, "", "bad unicode rune"}, + {`\xz`, false, "", "bad hex escape sequence"}, + {`\xa`, false, "", "bad unicode rune"}, + {`\xaz`, false, "", "bad hex escape sequence"}, + {`\{`, false, "", "bad char escape"}, + {`\{z`, false, "", "bad char escape"}, + {`\{0`, false, "", "bad char escape"}, + {`\{0z`, false, "", "bad char escape"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, rem, err := rustUnquote([]byte(tt.input), false) + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + testutil.Equals(t, got, tt.want) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got, tt.want) + testutil.Equals(t, rem, []byte("")) + } + }) + } + } + + // star == true + { + tests := []struct { + input string + wantOk bool + want string + wantRem string + wantErr string + }{ + {``, true, "", "", ""}, + {`hello`, true, "hello", "", ""}, + {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", "", ""}, + {`a\"b`, true, "a\"b", "", ""}, + {`a\'b`, true, "a'b", "", ""}, + + {`a\x00b`, true, "a\x00b", "", ""}, + {`a\x7fb`, true, "a\x7fb", "", ""}, + {`a\x80b`, false, "", "", "bad hex escape sequence"}, + + {`a\u`, false, "", "", "bad unicode rune"}, + {`a\uz`, false, "", "", "bad unicode escape sequence"}, + {`a\u{}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{A}b`, true, "a\u000ab", "", ""}, + {`a\u{aB}b`, true, "a\u00abb", "", ""}, + {`a\u{AbC}b`, true, "a\u0abcb", "", ""}, + {`a\u{aBcD}b`, true, "a\uabcdb", "", ""}, + {`a\u{AbCdE}b`, true, "a\U000abcdeb", "", ""}, + {`a\u{10cDeF}b`, true, "a\U0010cdefb", "", ""}, + {`a\u{ffffff}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{0000000}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{d7ff}b`, true, "a\ud7ffb", "", ""}, + {`a\u{d800}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{dfff}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{e000}b`, true, "a\ue000b", "", ""}, + {`a\u{10ffff}b`, true, "a\U0010ffffb", "", ""}, + {`a\u{110000}b`, false, "", "", "bad unicode escape sequence"}, + + {`*`, true, "", "*", ""}, + {`*hello*how*are*you`, true, "", "*hello*how*are*you", ""}, + {`hello*how*are*you`, true, "hello", "*how*are*you", ""}, + {`\**`, true, "*", "*", ""}, + + {`\`, false, "", "", "bad unicode rune"}, + {`\a`, false, "", "", "bad char escape"}, + {`\x`, false, "", "", "bad unicode rune"}, + {`\xz`, false, "", "", "bad hex escape sequence"}, + {`\xa`, false, "", "", "bad unicode rune"}, + {`\xaz`, false, "", "", "bad hex escape sequence"}, + {`\{`, false, "", "", "bad char escape"}, + {`\{z`, false, "", "", "bad char escape"}, + {`\{0`, false, "", "", "bad char escape"}, + {`\{0z`, false, "", "", "bad char escape"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, rem, err := rustUnquote([]byte(tt.input), true) + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + testutil.Equals(t, got, tt.want) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got, tt.want) + testutil.Equals(t, string(rem), tt.wantRem) + } + }) + } + } +} + +func TestFakeRustQuote(t *testing.T) { + t.Parallel() + out := FakeRustQuote("hello") + testutil.Equals(t, out, `"hello"`) +} + +func TestPatternFromStringLiteral(t *testing.T) { + t.Parallel() + tests := []struct { + input string + wantOk bool + want []PatternComponent + wantErr string + }{ + {`""`, true, nil, ""}, + {`"a"`, true, []PatternComponent{{false, "a"}}, ""}, + {`"*"`, true, []PatternComponent{{true, ""}}, ""}, + {`"*a"`, true, []PatternComponent{{true, "a"}}, ""}, + {`"a*"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, + {`"**"`, true, []PatternComponent{{true, ""}}, ""}, + {`"**a"`, true, []PatternComponent{{true, "a"}}, ""}, + {`"a**"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, + {`"*a*"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, + {`"**a**"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, + {`"abra*ca"`, true, []PatternComponent{ + {false, "abra"}, {true, "ca"}, + }, ""}, + {`"abra**ca"`, true, []PatternComponent{ + {false, "abra"}, {true, "ca"}, + }, ""}, + {`"*abra*ca"`, true, []PatternComponent{ + {true, "abra"}, {true, "ca"}, + }, ""}, + {`"abra*ca*"`, true, []PatternComponent{ + {false, "abra"}, {true, "ca"}, {true, ""}, + }, ""}, + {`"*abra*ca*"`, true, []PatternComponent{ + {true, "abra"}, {true, "ca"}, {true, ""}, + }, ""}, + {`"*abra*ca*dabra"`, true, []PatternComponent{ + {true, "abra"}, {true, "ca"}, {true, "dabra"}, + }, ""}, + {`"*abra*c\**da\*ra"`, true, []PatternComponent{ + {true, "abra"}, {true, "c*"}, {true, "da*ra"}, + }, ""}, + {`"\u"`, false, nil, "bad unicode rune"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, err := NewPattern(tt.input) + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got.Comps, tt.want) + testutil.Equals(t, got.String(), tt.input) + } + }) + } +} + +func TestScanner(t *testing.T) { + t.Parallel() + t.Run("SrcError", func(t *testing.T) { + t.Parallel() + wantErr := fmt.Errorf("wantErr") + r := &readerMock{ + ReadFunc: func(_ []byte) (int, error) { + return 0, wantErr + }, + } + var s scanner + s.Init(r) + out := s.next() + testutil.Equals(t, out, specialRuneEOF) + }) + + t.Run("MidEmojiEOF", func(t *testing.T) { + t.Parallel() + var s scanner + var eof bool + str := []byte(string(`🐐`)) + r := &readerMock{ + ReadFunc: func(p []byte) (int, error) { + if eof { + return 0, io.EOF + } + p[0] = str[0] + eof = true + return 1, nil + }, + } + s.Init(r) + out := s.next() + testutil.Equals(t, out, utf8.RuneError) + out = s.next() + testutil.Equals(t, out, specialRuneEOF) + }) + + t.Run("NotAsciiEmoji", func(t *testing.T) { + t.Parallel() + var s scanner + s.Init(strings.NewReader(`🐐`)) + out := s.next() + testutil.Equals(t, out, '🐐') + }) + + t.Run("InvalidUTF8", func(t *testing.T) { + t.Parallel() + var s scanner + s.Init(strings.NewReader(string([]byte{0x80, 0x81}))) + out := s.next() + testutil.Equals(t, out, utf8.RuneError) + }) + + t.Run("tokenTextNone", func(t *testing.T) { + t.Parallel() + var s scanner + s.Init(strings.NewReader("")) + out := s.tokenText() + testutil.Equals(t, out, "") + }) +} + +func TestDigitVal(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in rune + out int + }{ + {"happy", '0', 0}, + {"hex", 'f', 15}, + {"sad", 'g', 16}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out := digitVal(tt.in) + testutil.Equals(t, out, tt.out) + }) + } +} From 97a7c0a7121d01e37c1e49a40136a47efdd0bac0 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 15:27:55 -0600 Subject: [PATCH 012/216] x/exp/ast: make interface for PrincipalIn, ResourceIn more accurate, fix ActionIn Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/scope.go | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index 0d528f1f..27d0b94e 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -7,12 +7,8 @@ func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { return p } -func (p *Policy) PrincipalIn(entities ...types.EntityUID) *Policy { - var entityValues []types.Value - for _, e := range entities { - entities = append(entities, e) - } - p.principal = Principal().In(Set(entityValues)) +func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { + p.principal = Principal().In(Entity(entity)) return p } @@ -27,9 +23,13 @@ func (p *Policy) ActionEq(entity types.EntityUID) *Policy { } func (p *Policy) ActionIn(entities ...types.EntityUID) *Policy { + if len(entities) == 1 { + p.action = Action().In(Entity(entities[0])) + return p + } var entityValues []types.Value for _, e := range entities { - entities = append(entities, e) + entityValues = append(entityValues, e) } p.action = Action().In(Set(entityValues)) return p @@ -40,12 +40,8 @@ func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { return p } -func (p *Policy) ResourceIn(entities ...types.EntityUID) *Policy { - var entityValues []types.Value - for _, e := range entities { - entities = append(entities, e) - } - p.resource = Resource().In(Set(entityValues)) +func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { + p.resource = Resource().In(Entity(entity)) return p } From bf1fce64669a3542951aaed1c4534ab2499757f2 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 15:45:11 -0600 Subject: [PATCH 013/216] x/exp/ast: initial sketch of JSON unmarshaller Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/json.go | 160 +++++++++++++++++++++++++++++++++++++++++ x/exp/ast/json_test.go | 95 ++++++++++++++++++++++++ 2 files changed, 255 insertions(+) create mode 100644 x/exp/ast/json.go create mode 100644 x/exp/ast/json_test.go diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go new file mode 100644 index 00000000..a32d6f00 --- /dev/null +++ b/x/exp/ast/json.go @@ -0,0 +1,160 @@ +package ast + +import ( + "encoding/json" + "fmt" + + "github.com/cedar-policy/cedar-go/types" +) + +type policyJSON struct { + Effect string `json:"effect"` + Annotations map[string]string `json:"annotations"` + Principal scopeJSON `json:"principal"` + Action scopeJSON `json:"action"` + Resource scopeJSON `json:"resource"` + Conditions []conditionJSON `json:"conditions"` +} + +type scopeJSON struct { + Op string `json:"op"` + Entity types.EntityUID `json:"entity"` + Entities []types.EntityUID `json:"entities"` + EntityType string `json:"entity_type"` + In *struct { + Entity types.EntityUID `json:"entity"` + } `json:"in"` +} + +func (s *scopeJSON) ToNode(n Node) (Node, error) { + switch s.Op { + case "All": + return True(), nil + case "==": + return n.Equals(Entity(s.Entity)), nil + case "in": + var zero types.EntityUID + if s.Entity != zero { + return n.In(Entity(s.Entity)), nil // TODO: review shape, maybe .In vs .InNode + } + var set types.Set + for _, e := range s.Entities { + set = append(set, e) + } + return n.In(Set(set)), nil // TODO: maybe there is an In and an InSet Node? + case "is": + isNode := n.Is(String(types.String(s.EntityType))) // TODO: hmmm, I'm not sure can this be Stronger-typed? + if s.In == nil { + return isNode, nil + } + return isNode.And(n.In(Entity(s.In.Entity))), nil + } + return Node{}, fmt.Errorf("unknown op: %v", s.Op) +} + +type conditionJSON struct { + Kind string `json:"kind"` + Body nodeJSON `json:"body"` +} + +func (c *conditionJSON) ToNode() (Node, error) { + n, err := c.Body.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in body: %w", err) + } + if c.Kind == "unless" { + return Not(n), nil + } + return n, nil +} + +type binaryJSON struct { + Left nodeJSON `json:"left"` + Right nodeJSON `json:"right"` +} + +type accessJSON struct { + Left nodeJSON `json:"left"` + Attr string `json:"attr"` +} + +type nodeJSON struct { + Equals *binaryJSON `json:"=="` + Access *accessJSON `json:"."` + Var *string `json:"Var"` + Value *string `json:"Value"` // could be any +} + +func (j nodeJSON) ToNode() (Node, error) { + switch { + case j.Equals != nil: + left, err := j.Equals.Left.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in left of equals: %w", err) + } + right, err := j.Equals.Right.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in right of equals: %w", err) + } + return left.Equals(right), nil + case j.Access != nil: + left, err := j.Access.Left.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in left of access: %w", err) + } + return left.Access(j.Access.Attr), nil + case j.Var != nil: + switch *j.Var { + case "principal": + return Principal(), nil + case "action": + return Action(), nil + case "resource": + return Resource(), nil + case "context": + return Context(), nil + } + return Node{}, fmt.Errorf("unknown var: %v", j.Var) + case j.Value != nil: + return String(types.String(*j.Value)), nil + } + + return Node{}, fmt.Errorf("unknown node") +} + +func (p *Policy) UnmarshalJSON(b []byte) error { + var j policyJSON + if err := json.Unmarshal(b, &j); err != nil { + return err + } + if j.Effect == "permit" { + // TODO: use builder, overwrite *p + p.effect = effectPermit + } + for k, v := range j.Annotations { + p.Annotate(types.String(k), types.String(v)) + } + var err error + p.principal, err = j.Principal.ToNode(Principal()) + if err != nil { + return fmt.Errorf("error in principal: %w", err) + } + p.action, err = j.Action.ToNode(Action()) + if err != nil { + return fmt.Errorf("error in action: %w", err) + } + p.resource, err = j.Resource.ToNode(Resource()) + if err != nil { + return fmt.Errorf("error in resource: %w", err) + } + for _, c := range j.Conditions { + n, err := c.ToNode() + if err != nil { + return fmt.Errorf("error in conditions: %w", err) + } + p.conditions = append(p.conditions, n) + // TODO: use builder? + } + + return nil +} diff --git a/x/exp/ast/json_test.go b/x/exp/ast/json_test.go new file mode 100644 index 00000000..fb7cf53e --- /dev/null +++ b/x/exp/ast/json_test.go @@ -0,0 +1,95 @@ +package ast_test + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/cedar-policy/cedar-go/types" + "github.com/cedar-policy/cedar-go/x/exp/ast" +) + +func TestUnmarshalJSON(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input string + want *ast.Policy + wantErr bool + }{ + /* + @key("value") + permit ( + principal == User::"12UA45", + action == Action::"view", + resource in Folder::"abc" + ) when { + context.tls_version == "1.3" + }; + */ + {"exampleFromDocs", `{ + "annotations": { + "key": "value" + }, + "effect": "permit", + "principal": { + "op": "==", + "entity": { "type": "User", "id": "12UA45" } + }, + "action": { + "op": "==", + "entity": { "type": "Action", "id": "view" } + }, + "resource": { + "op": "in", + "entity": { "type": "Folder", "id": "abc" } + }, + "conditions": [ + { + "kind": "when", + "body": { + "==": { + "left": { + ".": { + "left": { + "Var": "context" + }, + "attr": "tls_version" + } + }, + "right": { + "Value": "1.3" + } + } + } + } + ] +}`, + ast.Permit(). + Annotate("key", "value"). + PrincipalEq(types.NewEntityUID("User", "12UA45")). + ActionEq(types.NewEntityUID("Action", "view")). + ResourceIn(types.NewEntityUID("Folder", "abc")). + When( + ast.Context().Access("tls_version").Equals(ast.String("1.3")), + ), + false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var p ast.Policy + err := json.Unmarshal([]byte(tt.input), &p) + if (err != nil) != tt.wantErr { + t.Errorf("error got: %v want: %v", err, tt.wantErr) + } + if !reflect.DeepEqual(&p, tt.want) { + t.Errorf("policy mismatch: got: %+v want: %+v", p, *tt.want) + } + }) + } + +} From 0bd8e68d79e195419575e69590a555fb4f251bc4 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 15:47:52 -0600 Subject: [PATCH 014/216] x/exp/ast: only use sugar for JSON unmarshaller Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/json.go | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index a32d6f00..7e515ccb 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -57,17 +57,6 @@ type conditionJSON struct { Body nodeJSON `json:"body"` } -func (c *conditionJSON) ToNode() (Node, error) { - n, err := c.Body.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in body: %w", err) - } - if c.Kind == "unless" { - return Not(n), nil - } - return n, nil -} - type binaryJSON struct { Left nodeJSON `json:"left"` Right nodeJSON `json:"right"` @@ -127,9 +116,13 @@ func (p *Policy) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(b, &j); err != nil { return err } - if j.Effect == "permit" { - // TODO: use builder, overwrite *p - p.effect = effectPermit + switch j.Effect { + case "permit": + *p = *Permit() + case "forbid": + *p = *Forbid() + default: + return fmt.Errorf("unknown effect: %v", j.Effect) } for k, v := range j.Annotations { p.Annotate(types.String(k), types.String(v)) @@ -148,12 +141,18 @@ func (p *Policy) UnmarshalJSON(b []byte) error { return fmt.Errorf("error in resource: %w", err) } for _, c := range j.Conditions { - n, err := c.ToNode() + n, err := c.Body.ToNode() if err != nil { return fmt.Errorf("error in conditions: %w", err) } - p.conditions = append(p.conditions, n) - // TODO: use builder? + switch c.Kind { + case "when": + p.When(n) + case "unless": + p.Unless(n) + default: + return fmt.Errorf("unknown condition kind: %v", c.Kind) + } } return nil From cba26ae953d0e7095c1c8325e5edb6e3cbb27cd5 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 16:01:26 -0600 Subject: [PATCH 015/216] x/exp/ast: add all binary ops to JSON Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/json.go | 67 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 10 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index 7e515ccb..f2ae2800 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -62,13 +62,40 @@ type binaryJSON struct { Right nodeJSON `json:"right"` } +func (j binaryJSON) ToNode(f func(a, b Node) Node) (Node, error) { + left, err := j.Left.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in left: %w", err) + } + right, err := j.Right.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in right: %w", err) + } + return f(left, right), nil +} + type accessJSON struct { Left nodeJSON `json:"left"` Attr string `json:"attr"` } type nodeJSON struct { - Equals *binaryJSON `json:"=="` + Equals *binaryJSON `json:"=="` + NotEquals *binaryJSON `json:"!="` + In *binaryJSON `json:"in"` + LessThan *binaryJSON `json:"<"` + LessThanOrEqual *binaryJSON `json:"<="` + GreaterThan *binaryJSON `json:">"` + GreaterThanOrEqual *binaryJSON `json:">="` + And *binaryJSON `json:"&&"` + Or *binaryJSON `json:"||"` + Plus *binaryJSON `json:"+"` + Minus *binaryJSON `json:"-"` + Times *binaryJSON `json:"*"` + Contains *binaryJSON `json:"contains"` + ContainsAll *binaryJSON `json:"containsAll"` + ContainsAny *binaryJSON `json:"containsAny"` + Access *accessJSON `json:"."` Var *string `json:"Var"` Value *string `json:"Value"` // could be any @@ -77,15 +104,35 @@ type nodeJSON struct { func (j nodeJSON) ToNode() (Node, error) { switch { case j.Equals != nil: - left, err := j.Equals.Left.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in left of equals: %w", err) - } - right, err := j.Equals.Right.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in right of equals: %w", err) - } - return left.Equals(right), nil + return j.Equals.ToNode(Node.Equals) + case j.NotEquals != nil: + return j.NotEquals.ToNode(Node.NotEquals) + case j.In != nil: + return j.In.ToNode(Node.In) + case j.LessThan != nil: + return j.LessThan.ToNode(Node.LessThan) + case j.LessThanOrEqual != nil: + return j.LessThanOrEqual.ToNode(Node.LessThanOrEqual) + case j.GreaterThan != nil: + return j.GreaterThan.ToNode(Node.GreaterThan) + case j.GreaterThanOrEqual != nil: + return j.GreaterThanOrEqual.ToNode(Node.GreaterThanOrEqual) + case j.And != nil: + return j.And.ToNode(Node.And) + case j.Or != nil: + return j.Or.ToNode(Node.Or) + case j.Plus != nil: + return j.Plus.ToNode(Node.Plus) + case j.Minus != nil: + return j.Minus.ToNode(Node.Minus) + case j.Times != nil: + return j.Times.ToNode(Node.Times) + case j.Contains != nil: + return j.Contains.ToNode(Node.Contains) + case j.ContainsAll != nil: + return j.ContainsAll.ToNode(Node.ContainsAll) + case j.ContainsAny != nil: + return j.ContainsAny.ToNode(Node.ContainsAny) case j.Access != nil: left, err := j.Access.Left.ToNode() if err != nil { From a83f23e984f0a3d07fc25c04fb1eb282f6fe22dc Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 30 Jul 2024 15:08:35 -0700 Subject: [PATCH 016/216] cedar-go/x/exp/ast: add a construct for creating annotations prior to the policy effect This allows Cedar AST to look more similar to textual Cedar and may be nice to have when we write the text parser. Signed-off-by: philhassey --- x/exp/ast/annotation.go | 33 +- x/exp/ast/ast_test.go | 7 +- x/exp/parser2/fuzz_test.go | 103 -- x/exp/parser2/parse.go | 1462 -------------------------- x/exp/parser2/parse_test.go | 465 -------- x/exp/parser2/tokenize.go | 705 ------------- x/exp/parser2/tokenize_mocks_test.go | 74 -- x/exp/parser2/tokenize_test.go | 554 ---------- 8 files changed, 36 insertions(+), 3367 deletions(-) delete mode 100644 x/exp/parser2/fuzz_test.go delete mode 100644 x/exp/parser2/parse.go delete mode 100644 x/exp/parser2/parse_test.go delete mode 100644 x/exp/parser2/tokenize.go delete mode 100644 x/exp/parser2/tokenize_mocks_test.go delete mode 100644 x/exp/parser2/tokenize_test.go diff --git a/x/exp/ast/annotation.go b/x/exp/ast/annotation.go index 5bcd931e..4ffac6f3 100644 --- a/x/exp/ast/annotation.go +++ b/x/exp/ast/annotation.go @@ -2,9 +2,40 @@ package ast import "github.com/cedar-policy/cedar-go/types" +type Annotations struct { + nodes []Node +} + +// Annotation allows AST constructors to make policy in a similar shape to textual Cedar with +// annotations appearing before the actual policy scope: +// +// ast := Annotation("foo", "bar"). +// Annotation("baz", "quux"). +// Permit(). +// PrincipalEq(superUser) +func Annotation(name, value types.String) *Annotations { + return &Annotations{nodes: []Node{newAnnotationNode(name, value)}} +} + +func (a *Annotations) Annotation(name, value types.String) *Annotations { + a.nodes = append(a.nodes, newAnnotationNode(name, value)) + return a +} + +func (a *Annotations) Permit() *Policy { + p := Permit() + p.annotations = a.nodes + return p +} + +func (a *Annotations) Forbid() *Policy { + p := Forbid() + p.annotations = a.nodes + return p +} + func (p *Policy) Annotate(name, value types.String) *Policy { p.annotations = append(p.annotations, newAnnotationNode(name, value)) - return p } func newAnnotationNode(name, value types.String) Node { diff --git a/x/exp/ast/ast_test.go b/x/exp/ast/ast_test.go index c8a381ab..d9f9faa5 100644 --- a/x/exp/ast/ast_test.go +++ b/x/exp/ast/ast_test.go @@ -23,8 +23,8 @@ func TestAst(t *testing.T) { // ) // when { true } // unless { false }; - _ = ast.Permit(). - Annotate("example", "one"). + _ = ast.Annotation("example", "one"). + Permit(). PrincipalEq(johnny). ActionIn(sow, cast). When(ast.True()). @@ -35,7 +35,8 @@ func TestAst(t *testing.T) { // when { resource.tags.contains("private") } // unless { resource in principal.allowed_resources }; private := types.String("private") - _ = ast.Forbid().Annotate("example", "two"). + _ = ast.Annotation("example", "two"). + Forbid(). When( ast.Resource().Access("tags").Contains(ast.String(private)), ). diff --git a/x/exp/parser2/fuzz_test.go b/x/exp/parser2/fuzz_test.go deleted file mode 100644 index c6f89606..00000000 --- a/x/exp/parser2/fuzz_test.go +++ /dev/null @@ -1,103 +0,0 @@ -package parser - -import ( - "testing" -) - -// https://go.dev/doc/tutorial/fuzz -// mkdir testdata -// go test -fuzz=FuzzTokenize -fuzztime 60s -// go test -fuzz=FuzzParse -fuzztime 60s - -func FuzzTokenize(f *testing.F) { - tests := []string{ - `These are some identifiers`, - `0 1 1234`, - `-1 9223372036854775807 -9223372036854775808`, - `"" "string" "\"\'\n\r\t\\\0" "\x123" "\u{0}\u{10fFfF}"`, - `"*" "\*" "*\**"`, - `@.,;(){}[]+-*`, - `:::`, - `!!=<<=>>=`, - `||&&`, - `// single line comment`, - `/*`, - `multiline comment`, - `// embedded comment does nothing`, - `*/`, - `'/%|&=`, - } - for _, tt := range tests { - f.Add(tt) - } - f.Fuzz(func(t *testing.T, orig string) { - toks, err := Tokenize([]byte(orig)) - if err != nil { - if toks != nil { - t.Errorf("toks != nil on err") - } - } - }) -} - -func FuzzParse(f *testing.F) { - tests := []string{ - `permit(principal,action,resource);`, - `forbid(principal,action,resource);`, - `permit(principal,action,resource in asdf::"1234");`, - `permit(principal,action,resource) when { resource in "foo" };`, - `permit(principal,action,resource) when { context.x == 42 };`, - `permit(principal,action,resource) when { context.x == 42 };`, - `permit(principal,action,resource) when { principal.x == 42 };`, - `permit(principal,action,resource) when { principal.x == 42 };`, - `permit(principal,action,resource) when { principal in parent::"bob" };`, - `permit(principal == coder::"cuzco",action,resource);`, - `permit(principal in team::"osiris",action,resource);`, - `permit(principal,action == table::"drop",resource);`, - `permit(principal,action in scary::"stuff",resource);`, - `permit(principal,action in [scary::"stuff"],resource);`, - `permit(principal,action,resource == table::"whatever");`, - `permit(principal,action,resource) unless { false };`, - `permit(principal,action,resource) when { (if true then true else true) };`, - `permit(principal,action,resource) when { (true || false) };`, - `permit(principal,action,resource) when { (true && true) };`, - `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, - `permit(principal,action,resource) when { principal in principal };`, - `permit(principal,action,resource) when { principal has name };`, - `permit(principal,action,resource) when { 40+3-1==42 };`, - `permit(principal,action,resource) when { 6*7==42 };`, - `permit(principal,action,resource) when { -42==-42 };`, - `permit(principal,action,resource) when { !(1+1==42) };`, - `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - `permit(principal,action,resource) when { {name:"bob"} has name };`, - `permit(principal,action,resource) when { action in action };`, - `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, - `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, - `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, - `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, - `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, - `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, - `permit(principal,action,resource) when { [1,2,3].shuffle() };`, - `permit(principal,action,resource) when { "bananas" like "*nan*" };`, - `permit(principal,action,resource) when { fooBar("10") };`, - `permit(principal,action,resource) when { decimal(1, 2) };`, - `permit(principal,action,resource) when { ip() };`, - `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, - `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, - `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, - `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, - `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, - } - for _, tt := range tests { - f.Add(tt) - } - f.Fuzz(func(_ *testing.T, orig string) { - toks, err := Tokenize([]byte(orig)) - if err != nil { - return - } - // intentionally ignore parse errors - _, _ = Parse(toks) - }) -} diff --git a/x/exp/parser2/parse.go b/x/exp/parser2/parse.go deleted file mode 100644 index a6ca0690..00000000 --- a/x/exp/parser2/parse.go +++ /dev/null @@ -1,1462 +0,0 @@ -package parser - -import ( - "fmt" - "strconv" - "strings" -) - -func Parse(tokens []Token) (Policies, error) { - p := &parser{Tokens: tokens} - return p.Policies() -} - -func ParseEntity(tokens []Token) (Entity, error) { - p := &parser{Tokens: tokens} - return p.Entity() -} - -type parser struct { - Tokens []Token - Pos int -} - -func (p *parser) advance() Token { - t := p.peek() - if p.Pos < len(p.Tokens)-1 { - p.Pos++ - } - return t -} - -func (p *parser) peek() Token { - return p.Tokens[p.Pos] -} - -func (p *parser) exact(tok string) error { - t := p.advance() - if t.Text != tok { - return p.errorf("exact got %v want %v", t.Text, tok) - } - return nil -} - -func (p *parser) errorf(s string, args ...interface{}) error { - var t Token - if p.Pos < len(p.Tokens) { - t = p.Tokens[p.Pos] - } - err := fmt.Errorf(s, args...) - return fmt.Errorf("parse error at %v %q: %w", t.Pos, t.Text, err) -} - -// Policies := {Policy} - -type Policies []Policy - -func (c Policies) String() string { - var sb strings.Builder - for i, p := range c { - if i > 0 { - sb.WriteRune('\n') - } - sb.WriteString(p.String()) - } - return sb.String() -} - -func (p *parser) Policies() (Policies, error) { - var res Policies - for !p.peek().isEOF() { - policy, err := p.policy() - if err != nil { - return nil, err - } - res = append(res, policy) - } - return res, nil -} - -// Policy := {Annotation} Effect '(' Scope ')' {Conditions} ';' -// Scope := Principal ',' Action ',' Resource - -type Policy struct { - Position Position - Annotations []Annotation - Effect Effect - Principal Principal - Action Action - Resource Resource - Conditions []Condition -} - -func (p Policy) String() string { - var sb strings.Builder - for i, a := range p.Annotations { - if i > 0 { - sb.WriteRune('\n') - } - sb.WriteString(a.String()) - } - sb.WriteString(fmt.Sprintf("%s(\n%s,\n%s,\n%s\n)", - p.Effect, p.Principal, p.Action, p.Resource, - )) - for _, c := range p.Conditions { - sb.WriteRune('\n') - sb.WriteString(c.String()) - } - sb.WriteString(";") - return sb.String() -} - -func (p *parser) policy() (Policy, error) { - var res Policy - res.Position = p.peek().Pos - var err error - if res.Annotations, err = p.annotations(); err != nil { - return res, err - } - if res.Effect, err = p.effect(); err != nil { - return res, err - } - if err := p.exact("("); err != nil { - return res, err - } - if res.Principal, err = p.principal(); err != nil { - return res, err - } - if err := p.exact(","); err != nil { - return res, err - } - if res.Action, err = p.action(); err != nil { - return res, err - } - if err := p.exact(","); err != nil { - return res, err - } - if res.Resource, err = p.resource(); err != nil { - return res, err - } - if err := p.exact(")"); err != nil { - return res, err - } - if res.Conditions, err = p.conditions(); err != nil { - return res, err - } - if err := p.exact(";"); err != nil { - return res, err - } - return res, nil -} - -// Annotation := '@'IDENT'('STR')' - -type Annotation struct { - Key string - Value string -} - -func (a Annotation) String() string { - return fmt.Sprintf("@%s(%q)", a.Key, a.Value) -} - -func (p *parser) annotation() (Annotation, error) { - var res Annotation - var err error - t := p.advance() - if !t.isIdent() { - return res, p.errorf("expected ident") - } - res.Key = t.Text - if err := p.exact("("); err != nil { - return res, err - } - t = p.advance() - if !t.isString() { - return res, p.errorf("expected string") - } - if res.Value, err = t.stringValue(); err != nil { - return res, err - } - if err := p.exact(")"); err != nil { - return res, err - } - return res, nil -} - -func (p *parser) annotations() ([]Annotation, error) { - var res []Annotation - for p.peek().Text == "@" { - p.advance() - a, err := p.annotation() - if err != nil { - return res, err - } - for _, aa := range res { - if aa.Key == a.Key { - return res, p.errorf("duplicate annotation") - } - } - res = append(res, a) - } - return res, nil -} - -// Effect := 'permit' | 'forbid' - -type Effect string - -const ( - EffectPermit = Effect("permit") - EffectForbid = Effect("forbid") -) - -func (p *parser) effect() (Effect, error) { - next := p.advance() - res := Effect(next.Text) - switch res { - case EffectForbid: - case EffectPermit: - default: - return res, p.errorf("unexpected effect: %v", res) - } - return res, nil -} - -// MatchType - -type MatchType int - -const ( - MatchAny = MatchType(iota) - MatchEquals - MatchIn - MatchInList - MatchIs - MatchIsIn -) - -// Principal := 'principal' [('in' | '==') Entity] - -type Principal struct { - Type MatchType - Path Path - Entity Entity -} - -func (p Principal) String() string { - var res string - switch p.Type { - case MatchAny: - res = "principal" - case MatchEquals: - res = fmt.Sprintf("principal == %s", p.Entity) - case MatchIs: - res = fmt.Sprintf("principal is %s", p.Path) - case MatchIsIn: - res = fmt.Sprintf("principal is %s in %s", p.Path, p.Entity) - case MatchIn: - res = fmt.Sprintf("principal in %s", p.Entity) - } - return res -} - -func (p *parser) principal() (Principal, error) { - var res Principal - if err := p.exact("principal"); err != nil { - return res, err - } - switch p.peek().Text { - case "==": - p.advance() - var err error - res.Type = MatchEquals - res.Entity, err = p.Entity() - return res, err - case "is": - p.advance() - var err error - res.Type = MatchIs - res.Path, err = p.Path() - if err == nil && p.peek().Text == "in" { - p.advance() - res.Type = MatchIsIn - res.Entity, err = p.Entity() - return res, err - } - return res, err - case "in": - p.advance() - var err error - res.Type = MatchIn - res.Entity, err = p.Entity() - return res, err - default: - return Principal{ - Type: MatchAny, - }, nil - } -} - -// Action := 'action' [( '==' Entity | 'in' ('[' EntList ']' | Entity) )] - -type Action struct { - Type MatchType - Entities []Entity -} - -func (a Action) String() string { - var sb strings.Builder - switch a.Type { - case MatchAny: - sb.WriteString("action") - case MatchEquals: - sb.WriteString(fmt.Sprintf("action == %s", a.Entities[0])) - case MatchIn: - sb.WriteString(fmt.Sprintf("action in %s", a.Entities[0])) - case MatchInList: - sb.WriteString("action in [") - for i, e := range a.Entities { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(e.String()) - } - sb.WriteRune(']') - } - return sb.String() -} - -func (p *parser) action() (Action, error) { - var res Action - var err error - if err := p.exact("action"); err != nil { - return res, err - } - switch p.peek().Text { - case "==": - p.advance() - res.Type = MatchEquals - e, err := p.Entity() - if err != nil { - return res, err - } - res.Entities = append(res.Entities, e) - return res, nil - case "in": - p.advance() - if p.peek().Text == "[" { - res.Type = MatchInList - p.advance() - res.Entities, err = p.entlist() - if err != nil { - return res, err - } - p.advance() // entlist guarantees "]" - return res, nil - } else { - res.Type = MatchIn - e, err := p.Entity() - if err != nil { - return res, err - } - res.Entities = append(res.Entities, e) - return res, nil - } - default: - return Action{ - Type: MatchAny, - }, nil - } -} - -// Resource := 'resource' [('in' | '==') Entity)] - -type Resource struct { - Type MatchType - Path Path - Entity Entity -} - -func (r Resource) String() string { - var res string - switch r.Type { - case MatchAny: - res = "resource" - case MatchEquals: - res = fmt.Sprintf("resource == %s", r.Entity) - case MatchIs: - res = fmt.Sprintf("resource is %s", r.Path) - case MatchIsIn: - res = fmt.Sprintf("resource is %s in %s", r.Path, r.Entity) - case MatchIn: - res = fmt.Sprintf("resource in %s", r.Entity) - } - return res -} - -func (p *parser) resource() (Resource, error) { - var res Resource - if err := p.exact("resource"); err != nil { - return res, err - } - switch p.peek().Text { - case "==": - p.advance() - var err error - res.Type = MatchEquals - res.Entity, err = p.Entity() - return res, err - case "is": - p.advance() - var err error - res.Type = MatchIs - res.Path, err = p.Path() - if err == nil && p.peek().Text == "in" { - p.advance() - res.Type = MatchIsIn - res.Entity, err = p.Entity() - return res, err - } - return res, err - case "in": - p.advance() - var err error - res.Type = MatchIn - res.Entity, err = p.Entity() - return res, err - default: - return Resource{ - Type: MatchAny, - }, nil - } -} - -// Entity := Path '::' STR - -type Entity struct { - Path []string -} - -func (e Entity) String() string { - return fmt.Sprintf( - "%s::%q", - strings.Join(e.Path[0:len(e.Path)-1], "::"), - e.Path[len(e.Path)-1], - ) -} - -func (p *parser) Entity() (Entity, error) { - var res Entity - t := p.advance() - if !t.isIdent() { - return res, p.errorf("expected ident") - } - res.Path = append(res.Path, t.Text) - for { - if err := p.exact("::"); err != nil { - return res, err - } - t := p.advance() - switch { - case t.isIdent(): - res.Path = append(res.Path, t.Text) - case t.isString(): - component, err := t.stringValue() - if err != nil { - return res, err - } - res.Path = append(res.Path, component) - return res, nil - default: - return res, p.errorf("unexpected token") - } - } -} - -// Path ::= IDENT {'::' IDENT} - -type Path struct { - Path []string -} - -func (e Path) String() string { - return strings.Join(e.Path, "::") -} - -func (p *parser) Path() (Path, error) { - var res Path - t := p.advance() - if !t.isIdent() { - return res, p.errorf("expected ident") - } - res.Path = append(res.Path, t.Text) - for { - if p.peek().Text != "::" { - return res, nil - } - p.advance() - t := p.advance() - switch { - case t.isIdent(): - res.Path = append(res.Path, t.Text) - default: - return res, p.errorf("unexpected token") - } - } -} - -// EntList := Entity {',' Entity} - -func (p *parser) entlist() ([]Entity, error) { - var res []Entity - for p.peek().Text != "]" { - if len(res) > 0 { - if err := p.exact(","); err != nil { - return res, err - } - } - e, err := p.Entity() - if err != nil { - return res, err - } - res = append(res, e) - } - return res, nil -} - -// Condition := ('when' | 'unless') '{' Expr '}' - -type ConditionType string - -const ( - ConditionWhen ConditionType = "when" - ConditionUnless ConditionType = "unless" -) - -type Condition struct { - Type ConditionType - Expression Expression -} - -func (c Condition) String() string { - var res string - switch c.Type { - case ConditionWhen: - res = fmt.Sprintf("when {\n%s\n}", c.Expression) - case ConditionUnless: - res = fmt.Sprintf("unless {\n%s\n}", c.Expression) - } - return res -} - -func (p *parser) condition() (Condition, error) { - var res Condition - var err error - res.Type = ConditionType(p.advance().Text) - if err := p.exact("{"); err != nil { - return res, err - } - if res.Expression, err = p.expression(); err != nil { - return res, err - } - if err := p.exact("}"); err != nil { - return res, err - } - return res, nil -} - -func (p *parser) conditions() ([]Condition, error) { - var res []Condition - for { - switch p.peek().Text { - case "when", "unless": - c, err := p.condition() - if err != nil { - return res, err - } - res = append(res, c) - default: - return res, nil - } - } -} - -// Expr := Or | If - -type ExpressionType int - -const ( - ExpressionOr ExpressionType = iota - ExpressionIf -) - -type Expression struct { - Type ExpressionType - Or Or - If *If -} - -func (e Expression) String() string { - var res string - switch e.Type { - case ExpressionOr: - res = e.Or.String() - case ExpressionIf: - res = e.If.String() - } - return res -} - -func (p *parser) expression() (Expression, error) { - var res Expression - var err error - if p.peek().Text == "if" { - p.advance() - res.Type = ExpressionIf - i, err := p.ifExpr() - if err != nil { - return res, err - } - res.If = &i - return res, nil - } else { - res.Type = ExpressionOr - if res.Or, err = p.or(); err != nil { - return res, err - } - return res, nil - } -} - -// If := 'if' Expr 'then' Expr 'else' Expr - -type If struct { - If Expression - Then Expression - Else Expression -} - -func (i If) String() string { - return fmt.Sprintf("if %s then %s else %s", i.If, i.Then, i.Else) -} - -func (p *parser) ifExpr() (If, error) { - var res If - var err error - if res.If, err = p.expression(); err != nil { - return res, err - } - if err = p.exact("then"); err != nil { - return res, err - } - if res.Then, err = p.expression(); err != nil { - return res, err - } - if err = p.exact("else"); err != nil { - return res, err - } - if res.Else, err = p.expression(); err != nil { - return res, err - } - return res, err -} - -// Or := And {'||' And} - -type Or struct { - Ands []And -} - -func (o Or) String() string { - var sb strings.Builder - for i, and := range o.Ands { - if i > 0 { - sb.WriteString(" || ") - } - sb.WriteString(and.String()) - } - return sb.String() -} - -func (p *parser) or() (Or, error) { - var res Or - for { - a, err := p.and() - if err != nil { - return res, err - } - res.Ands = append(res.Ands, a) - if p.peek().Text != "||" { - return res, nil - } - p.advance() - } -} - -// And := Relation {'&&' Relation} - -type And struct { - Relations []Relation -} - -func (a And) String() string { - var sb strings.Builder - for i, rel := range a.Relations { - if i > 0 { - sb.WriteString(" && ") - } - sb.WriteString(rel.String()) - } - return sb.String() -} - -func (p *parser) and() (And, error) { - var res And - for { - r, err := p.relation() - if err != nil { - return res, err - } - res.Relations = append(res.Relations, r) - if p.peek().Text != "&&" { - return res, nil - } - p.advance() - } -} - -// Relation := Add [RELOP Add] | Add 'has' (IDENT | STR) | Add 'like' PAT - -type RelationType string - -const ( - RelationNone RelationType = "none" - RelationRelOp RelationType = "relop" - RelationHasIdent RelationType = "hasident" - RelationHasLiteral RelationType = "hasliteral" - RelationLike RelationType = "like" - RelationIs RelationType = "is" - RelationIsIn RelationType = "isIn" -) - -type Relation struct { - Add Add - Type RelationType - RelOp RelOp - RelOpRhs Add - Str string - Pat Pattern - Path Path - Entity Add -} - -func (r Relation) String() string { - var sb strings.Builder - sb.WriteString(r.Add.String()) - switch r.Type { - case RelationNone: - case RelationRelOp: - sb.WriteString(" ") - sb.WriteString(string(r.RelOp)) - sb.WriteString(" ") - sb.WriteString(r.RelOpRhs.String()) - case RelationHasIdent: - sb.WriteString(" has ") - sb.WriteString(r.Str) - case RelationHasLiteral: - sb.WriteString(" has ") - sb.WriteString(strconv.Quote(r.Str)) - case RelationLike: - sb.WriteString(" like ") - sb.WriteString(r.Pat.String()) - case RelationIs: - sb.WriteString(" is ") - sb.WriteString(r.Path.String()) - case RelationIsIn: - sb.WriteString(" is ") - sb.WriteString(r.Path.String()) - sb.WriteString(" in ") - sb.WriteString(r.Entity.String()) - } - return sb.String() -} - -func (p *parser) relation() (Relation, error) { - var res Relation - var err error - if res.Add, err = p.add(); err != nil { - return res, err - } - - t := p.peek() - switch t.Text { - case "<", "<=", ">=", ">", "!=", "==", "in": - p.advance() - res.Type = RelationRelOp - res.RelOp = RelOp(t.Text) - if res.RelOpRhs, err = p.add(); err != nil { - return res, err - } - case "has": - p.advance() - t := p.advance() - switch { - case t.isIdent(): - res.Type = RelationHasIdent - res.Str = t.Text - case t.isString(): - res.Type = RelationHasLiteral - if res.Str, err = t.stringValue(); err != nil { - return res, err - } - default: - return res, p.errorf("unexpected token") - } - case "like": - p.advance() - res.Type = RelationLike - t := p.advance() - if !t.isString() { - return res, p.errorf("unexpected token") - } - if res.Pat, err = t.patternValue(); err != nil { - return res, err - } - case "is": - p.advance() - var err error - res.Type = RelationIs - res.Path, err = p.Path() - if err == nil && p.peek().Text == "in" { - p.advance() - res.Type = RelationIsIn - res.Entity, err = p.add() - return res, err - } - return res, err - default: - res.Type = RelationNone - } - return res, nil -} - -// RELOP := '<' | '<=' | '>=' | '>' | '!=' | '==' | 'in' - -type RelOp string - -const ( - RelOpLt RelOp = "<" - RelOpLe RelOp = "<=" - RelOpGe RelOp = ">=" - RelOpGt RelOp = ">" - RelOpNe RelOp = "!=" - RelOpEq RelOp = "==" - RelOpIn RelOp = "in" -) - -// Add := Mult {ADDOP Mult} - -type Add struct { - Mults []Mult - AddOps []AddOp -} - -func (a Add) String() string { - var sb strings.Builder - sb.WriteString(a.Mults[0].String()) - for i, op := range a.AddOps { - sb.WriteString(fmt.Sprintf(" %s %s", op, a.Mults[i+1].String())) - } - return sb.String() -} - -func (p *parser) add() (Add, error) { - var res Add - var err error - mult, err := p.mult() - if err != nil { - return res, err - } - res.Mults = append(res.Mults, mult) - for { - op := AddOp(p.peek().Text) - switch op { - case AddOpAdd, AddOpSub: - default: - return res, nil - } - p.advance() - mult, err := p.mult() - if err != nil { - return res, err - } - res.AddOps = append(res.AddOps, op) - res.Mults = append(res.Mults, mult) - } -} - -// ADDOP := '+' | '-' - -type AddOp string - -const ( - AddOpAdd AddOp = "+" - AddOpSub AddOp = "-" -) - -// Mult := Unary { '*' Unary} - -type Mult struct { - Unaries []Unary -} - -func (m Mult) String() string { - var sb strings.Builder - for i, u := range m.Unaries { - if i > 0 { - sb.WriteString(" * ") - } - sb.WriteString(u.String()) - } - return sb.String() -} - -func (p *parser) mult() (Mult, error) { - var res Mult - for { - u, err := p.unary() - if err != nil { - return res, err - } - res.Unaries = append(res.Unaries, u) - if p.peek().Text != "*" { - return res, nil - } - p.advance() - } -} - -// Unary := [UNARYOP]x4 Member - -type Unary struct { - Ops []UnaryOp - Member Member -} - -func (u Unary) String() string { - var sb strings.Builder - for _, o := range u.Ops { - sb.WriteString(string(o)) - } - sb.WriteString(u.Member.String()) - return sb.String() -} - -func (p *parser) unary() (Unary, error) { - var res Unary - for { - o := UnaryOp(p.peek().Text) - switch o { - case UnaryOpNot: - p.advance() - res.Ops = append(res.Ops, o) - case UnaryOpMinus: - p.advance() - if p.peek().isInt() { - t := p.advance() - i, err := strconv.ParseInt("-"+t.Text, 10, 64) - if err != nil { - return res, err - } - res.Member = Member{ - Primary: Primary{ - Type: PrimaryLiteral, - Literal: Literal{ - Type: LiteralInt, - Long: i, - }, - }, - } - return res, nil - } - res.Ops = append(res.Ops, o) - default: - var err error - res.Member, err = p.member() - if err != nil { - return res, err - } - return res, nil - } - } -} - -// UNARYOP := '!' | '-' - -type UnaryOp string - -const ( - UnaryOpNot UnaryOp = "!" - UnaryOpMinus UnaryOp = "-" -) - -// Member := Primary {Access} - -type Member struct { - Primary Primary - Accesses []Access -} - -func (m Member) String() string { - var sb strings.Builder - sb.WriteString(m.Primary.String()) - for _, a := range m.Accesses { - sb.WriteString(a.String()) - } - return sb.String() -} - -func (p *parser) member() (Member, error) { - var res Member - var err error - if res.Primary, err = p.primary(); err != nil { - return res, err - } - for { - a, ok, err := p.access() - if !ok { - return res, err - } else { - res.Accesses = append(res.Accesses, a) - } - } -} - -// Primary := LITERAL -// | VAR -// | Entity -// | ExtFun '(' [ExprList] ')' -// | '(' Expr ')' -// | '[' [ExprList] ']' -// | '{' [RecInits] '}' - -type PrimaryType int - -const ( - PrimaryLiteral PrimaryType = iota - PrimaryVar - PrimaryEntity - PrimaryExtFun - PrimaryExpr - PrimaryExprList - PrimaryRecInits -) - -type Primary struct { - Type PrimaryType - Literal Literal - Var Var - Entity Entity - ExtFun ExtFun - Expression Expression - Expressions []Expression - RecInits []RecInit -} - -func (p Primary) String() string { - var res string - switch p.Type { - case PrimaryLiteral: - res = p.Literal.String() - case PrimaryVar: - res = p.Var.String() - case PrimaryEntity: - res = p.Entity.String() - case PrimaryExtFun: - res = p.ExtFun.String() - case PrimaryExpr: - res = fmt.Sprintf("(%s)", p.Expression) - case PrimaryExprList: - var sb strings.Builder - sb.WriteRune('[') - for i, e := range p.Expressions { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(e.String()) - } - sb.WriteRune(']') - res = sb.String() - case PrimaryRecInits: - var sb strings.Builder - sb.WriteRune('{') - for i, r := range p.RecInits { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(r.String()) - } - sb.WriteRune('}') - res = sb.String() - } - return res -} - -func (p *parser) primary() (Primary, error) { - var res Primary - var err error - t := p.advance() - switch { - case t.isInt(): - i, err := t.intValue() - if err != nil { - return res, err - } - res.Type = PrimaryLiteral - res.Literal = Literal{ - Type: LiteralInt, - Long: i, - } - case t.isString(): - res.Type = PrimaryLiteral - res.Literal.Type = LiteralString - if res.Literal.Str, err = t.stringValue(); err != nil { - return res, err - } - case t.Text == "true", t.Text == "false": - res.Type = PrimaryLiteral - res.Literal = Literal{ - Type: LiteralBool, - Bool: t.Text == "true", - } - case t.Text == string(VarPrincipal), - t.Text == string(VarAction), - t.Text == string(VarResource), - t.Text == string(VarContext): - res.Type = PrimaryVar - res.Var = Var{ - Type: VarType(t.Text), - } - case t.isIdent(): - e, f, err := p.entityOrExtFun(t.Text) - switch { - case e != nil: - res.Type = PrimaryEntity - res.Entity = *e - case f != nil: - res.Type = PrimaryExtFun - res.ExtFun = *f - default: - return res, err - } - case t.Text == "(": - res.Type = PrimaryExpr - if res.Expression, err = p.expression(); err != nil { - return res, err - } - if err := p.exact(")"); err != nil { - return res, err - } - case t.Text == "[": - res.Type = PrimaryExprList - if res.Expressions, err = p.expressions("]"); err != nil { - return res, err - } - p.advance() // expressions guarantees "]" - return res, err - case t.Text == "{": - res.Type = PrimaryRecInits - if res.RecInits, err = p.recInits(); err != nil { - return res, err - } - return res, err - default: - return res, p.errorf("invalid primary") - } - return res, nil -} - -func (p *parser) entityOrExtFun(first string) (*Entity, *ExtFun, error) { - path := []string{first} - for { - if p.peek().Text != "::" { - f, err := p.extFun(path) - if err != nil { - return nil, nil, err - } - return nil, &f, err - } - p.advance() - t := p.advance() - switch { - case t.isIdent(): - path = append(path, t.Text) - case t.isString(): - component, err := t.stringValue() - if err != nil { - return nil, nil, err - } - path = append(path, component) - return &Entity{Path: path}, nil, nil - default: - return nil, nil, p.errorf("unexpected token") - } - } -} - -func (p *parser) expressions(endOfListMarker string) ([]Expression, error) { - var res []Expression - for p.peek().Text != endOfListMarker { - if len(res) > 0 { - if err := p.exact(","); err != nil { - return res, err - } - } - e, err := p.expression() - if err != nil { - return res, err - } - res = append(res, e) - } - return res, nil -} - -func (p *parser) recInits() ([]RecInit, error) { - var res []RecInit - for { - t := p.peek() - if t.Text == "}" { - p.advance() - return res, nil - } - if len(res) > 0 { - if err := p.exact(","); err != nil { - return res, err - } - } - e, err := p.recInit() - if err != nil { - return res, err - } - res = append(res, e) - } -} - -// LITERAL := BOOL | INT | STR - -type LiteralType int - -const ( - LiteralBool LiteralType = iota - LiteralInt - LiteralString -) - -type Literal struct { - Type LiteralType - Bool bool - Long int64 - Str string -} - -func (l Literal) String() string { - var res string - switch l.Type { - case LiteralBool: - res = strconv.FormatBool(l.Bool) - case LiteralInt: - res = strconv.FormatInt(l.Long, 10) - case LiteralString: - res = strconv.Quote(l.Str) - } - return res -} - -// VAR := 'principal' | 'action' | 'resource' | 'context' - -type VarType string - -const ( - VarPrincipal VarType = "principal" - VarAction VarType = "action" - VarResource VarType = "resource" - VarContext VarType = "context" -) - -type Var struct { - Type VarType -} - -func (v Var) String() string { - return string(v.Type) -} - -// ExtFun := [Path '::'] IDENT - -type ExtFun struct { - Path []string - Expressions []Expression -} - -func (f ExtFun) String() string { - var sb strings.Builder - sb.WriteString(strings.Join(f.Path, "::")) - sb.WriteRune('(') - for i, e := range f.Expressions { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(e.String()) - } - sb.WriteRune(')') - return sb.String() -} - -func (p *parser) extFun(path []string) (ExtFun, error) { - res := ExtFun{Path: path} - if err := p.exact("("); err != nil { - return res, err - } - var err error - if res.Expressions, err = p.expressions(")"); err != nil { - return res, err - } - p.advance() // expressions guarantees ")" - return res, err -} - -// Access := '.' IDENT ['(' [ExprList] ')'] | '[' STR ']' - -type AccessType int - -const ( - AccessField AccessType = iota - AccessCall - AccessIndex -) - -type Access struct { - Type AccessType - Name string - Expressions []Expression -} - -func (a Access) String() string { - var sb strings.Builder - switch a.Type { - case AccessField: - sb.WriteRune('.') - sb.WriteString(a.Name) - case AccessCall: - sb.WriteRune('.') - sb.WriteString(a.Name) - sb.WriteRune('(') - for i, e := range a.Expressions { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(e.String()) - } - sb.WriteRune(')') - case AccessIndex: - sb.WriteRune('[') - sb.WriteString(strconv.Quote(a.Name)) - sb.WriteRune(']') - } - return sb.String() -} - -func (p *parser) access() (Access, bool, error) { - var res Access - var err error - t := p.peek() - switch t.Text { - case ".": - p.advance() - t := p.advance() - if !t.isIdent() { - return res, false, p.errorf("unexpected token") - } - res.Name = t.Text - if p.peek().Text == "(" { - p.advance() - res.Type = AccessCall - if res.Expressions, err = p.expressions(")"); err != nil { - return res, false, err - } - p.advance() // expressions guarantees ")" - } else { - res.Type = AccessField - } - case "[": - p.advance() - res.Type = AccessIndex - t := p.advance() - if !t.isString() { - return res, false, p.errorf("unexpected token") - } - if res.Name, err = t.stringValue(); err != nil { - return res, false, err - } - if err := p.exact("]"); err != nil { - return res, false, err - } - default: - return res, false, nil - } - return res, true, nil -} - -// RecInits := (IDENT | STR) ':' Expr {',' (IDENT | STR) ':' Expr} - -type RecKeyType int - -const ( - RecKeyIdent RecKeyType = iota - RecKeyString -) - -type RecInit struct { - KeyType RecKeyType - Key string - Value Expression -} - -func (r RecInit) String() string { - var sb strings.Builder - switch r.KeyType { - case RecKeyIdent: - sb.WriteString(r.Key) - case RecKeyString: - sb.WriteString(strconv.Quote(r.Key)) - } - sb.WriteString(": ") - sb.WriteString(r.Value.String()) - return sb.String() -} - -func (p *parser) recInit() (RecInit, error) { - var res RecInit - var err error - t := p.advance() - switch { - case t.isIdent(): - res.KeyType = RecKeyIdent - res.Key = t.Text - case t.isString(): - res.KeyType = RecKeyString - if res.Key, err = t.stringValue(); err != nil { - return res, err - } - default: - return res, p.errorf("unexpected token") - } - if err := p.exact(":"); err != nil { - return res, err - } - if res.Value, err = p.expression(); err != nil { - return res, err - } - return res, nil -} diff --git a/x/exp/parser2/parse_test.go b/x/exp/parser2/parse_test.go deleted file mode 100644 index 7a6c0bb4..00000000 --- a/x/exp/parser2/parse_test.go +++ /dev/null @@ -1,465 +0,0 @@ -package parser - -import ( - "testing" - - "github.com/cedar-policy/cedar-go/testutil" -) - -func TestParse(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in string - err bool - }{ - // Success cases - // Test cases from https://github.com/cedar-policy/cedar/blob/main/cedar-policy-core/src/parser/testfiles/policies.cedar - {"empty", ``, false}, - {"ex1", `//@test_annotation("This is the annotation") - permit( - principal == User::"alice", - action == PhotoOp::"view", - resource == Photo::"VacationPhoto94.jpg" - );`, false}, - {"ex2", `permit( - principal in Team::"admins", - action in [PhotoOp::"view", PhotoOp::"edit", PhotoOp::"delete"], - resource in Album::"jane_vacation" - );`, false}, - {"ex3", `permit( - principal == User::"alice", - action in PhotoflashRole::"admin", - resource in Album::"jane_vacation" - );`, false}, - {"simplest", `permit( - principal, - action, - resource - );`, false}, - {"in", `permit( - principal in Team::"eng", - action in PhotoflashRole::"admin", - resource in Album::"jane_vacation" - ); - - permit( - principal in Team::"eng", - action in [PhotoflashRole::"admin"], - resource in Album::"jane_vacation" - ); - - permit( - principal in Team::"eng", - action in [PhotoflashRole::"admin", PhotoflashRole::"operator"], - resource in Album::"jane_vacation" - ); - `, false}, - {"multipleIdentEntities", `permit( - principal == Org::Team::User::"alice", - action, - resource - );`, false}, - {"multiplePolicies", `permit( - principal, - action, - resource - ); - - forbid( - principal in Team::"admins", - action in [PhotoOp::"view", PhotoOp::"edit", PhotoOp::"delete"], - resource in Album::"jane_vacation" - ); - `, false}, - {"annotations", `@first_annotation("This is the annotation") - @second_annotation("This is another annotation") - permit( - principal, - action, - resource - );`, false}, - - // Additional success cases - {"primaryInt", `permit(principal, action, resource) when { 1234 };`, false}, - {"primaryString", `permit(principal, action, resource) when { "test string" };`, false}, - {"primaryBool", `permit(principal, action, resource) when { true } unless { false };`, false}, - {"primaryVar", `permit(principal, action, resource) - when { principal } - unless { action } - when { resource } - unless { context }; - `, false}, - {"primaryEntity", `permit(principal, action, resource) - when { Org::User::"alice" }; - `, false}, - {"primaryExtFun", `permit(principal, action, resource) - when { foo() } - unless { foo::bar::as() } - when { foo("hello") } - unless { foo::bar(true, 42, "forty two") }; - `, false}, - {"ifElseThen", `permit(principal, action, resource) - when { if false then principal else principal };`, false}, - {"access", `permit(principal, action, resource) - when { resource.foo } - unless { resource.foo.bar } - when { principal.foo() } - unless { principal.bar(false) } - when { action.foo["bar"].baz() } - unless { principal.bar(false, 123, "foo") } - when { principal["foo"] };`, false}, - {"unary", `permit(principal, action, resource) - when { !resource.foo } - unless { -resource.bar } - when { !!resource.foo } - unless { --resource.bar } - when { !-!-resource.bar };`, false}, - {"mult", `permit(principal, action, resource) - when { resource.foo * 42 } - unless { 42 * resource.bar } - when { 42 * resource.bar * 43 } - when { resource.foo * principal.bar };`, false}, - {"add", `permit(principal, action, resource) - when { resource.foo + 42 } - unless { 42 - resource.bar } - when { 42 + resource.bar - 43 } - when { resource.foo + principal.bar };`, false}, - {"relations", `permit(principal, action, resource) - when { foo() } - unless { foo() < 3 } - unless { foo() <= 3 } - unless { foo() > 3 } - unless { foo() >= 3 } - unless { foo() != 3 } - unless { foo() == 3 } - unless { foo() in Domain::"value" } - unless { foo() has blah } - when { foo() has "bar" } - when { foo() like "h*ll*" };`, false}, - {"foo-like-foo", `permit(principal, action, resource) - when { "f*o" like "f\*o" };`, false}, - {"ands", `permit(principal, action, resource) - when { foo() && bar() && 3};`, false}, - {"ors_and_ands", `permit(principal, action, resource) - when { foo() && bar() || baz() || 1 < 2 && 2 < 3};`, false}, - {"primaryExpression", `permit(principal, action, resource) - when { (true) } - unless { ((if (foo() <= 234) then principal else principal) like "") };`, false}, - {"primaryExprList", `permit(principal, action, resource) - when { [] } - unless { [true] } - when { [123, (principal has "name" && principal.name == "alice")]};`, false}, - {"primaryRecInits", `permit(principal, action, resource) - when { {} } - unless { {"key": "some value"} } - when { {"key": "some value", id: "another value"} };`, false}, - {"most-positive-long", - `permit(principal,action,resource) when { 9223372036854775807 == -(-9223372036854775807) };`, - false}, - {"principal-is", `permit (principal is X, action, resource);`, false}, - {"principal-is-long", `permit (principal is X::Y, action, resource);`, false}, - {"principal-is-in", `permit (principal is X in X::"z", action, resource);`, false}, - {"resource-is", `permit (principal, action, resource is X);`, false}, - {"resource-is-long", `permit (principal, action, resource is X::Y);`, false}, - {"resource-is-in", `permit (principal, action, resource is X in X::"z");`, false}, - {"when-is", `permit (principal, action, resource) when { principal is X };`, false}, - {"when-is-long", `permit (principal, action, resource) when { principal is X::Y };`, false}, - {"when-is-in", `permit (principal, action, resource) when { principal is X in X::"z" };`, false}, - - {"most-negative-long", `permit(principal,action,resource) when { -9223372036854775808 == -9223372036854775808 };`, false}, - {"most-negative-long2", `permit(principal,action,resource) when { -9223372036854775808 < -9223372036854775807 };`, false}, - - // Error cases - {"missingEffect", `@id("test")`, true}, - {"invalidEffect", `invalidEffect(principal, action, resource);`, true}, - {"missingSemicolon", `permit(principal, action, resource)`, true}, - {"missingScope", `permit;`, true}, - {"missingPrincipal", `permit(resource, action);`, true}, - {"missingResourceAndAction", `permit(principal);`, true}, - {"missingResource", `permit(principal, action);`, true}, - {"eofInScope", `permit(principal`, true}, - {"missingAction", `permit(principal, resource);`, true}, - {"invalidResource", `permit(principal, action, other);`, true}, - {"missingScopeEndParen", `permit(principal, action, resource;`, true}, - {"missingEntity", `permit(principal ==`, true}, - {"invalidEntity", `permit(principal == "alice", action, resource);`, true}, - {"invalidEntity2", `permit(principal == User::, action, resource);`, true}, - {"invalidEntity3", `permit(principal == User::123, action, resource);`, true}, - {"invalidEntity3", `permit(principal == User::`, true}, - {"invalidEntities", `permit(principal, action in [invalidEntity], resource);`, true}, - {"invalidEntities2", `permit(principal, action in [User::"alice", invalidEntity], resource);`, true}, - {"invalidEntities3", `permit(principal, action in [User::"alice";], resource);`, true}, - {"invalidEntities4", `permit(principal, action in [User::"alice"`, true}, - {"invalidAnnotation1", `@`, true}, - {"invalidAnnotation2", `@"annotate"`, true}, - {"invalidAnnotation3", `@annotate(`, true}, - {"invalidAnnotation4", `@annotate[""]`, true}, - {"invalidAnnotation5", `@annotate("test"]`, true}, - {"invalidAnnotation6", `@annotate(test)`, true}, - {"invalidCondition1", `permit(principal, action, resource) when`, true}, - {"invalidCondition2", `permit(principal, action, resource) when {`, true}, - {"invalidCondition3", `permit(principal, action, resource) when {}`, true}, - {"invalidCondition4", `permit(principal, action, resource) when { true`, true}, - {"invalidPrimaryInteger", `permit(principal, action, resource) - when { 0xabcd };`, true}, - {"invalidPrimary", `permit(principal, action, resource) - when { ( };`, true}, - {"invalidExtFun1", `permit(principal, action, resource) - when { abcd`, true}, - {"invalidExtFun2", `permit(principal, action, resource) - when { abcd(`, true}, - {"invalidExtFun3", `permit(principal, action, resource) - when { abcd::`, true}, - {"invalidExtFun4", `permit(principal, action, resource) - when { abcd::123`, true}, - {"invalidExtFun5", `permit(principal, action, resource) - when { abcd(123`, true}, - {"invalidIfElseThen1", `permit(principal, action, resource) - when { if }`, true}, - {"invalidIfElseThen2", `permit(principal, action, resource) - when { if true }`, true}, - {"invalidIfElseThen3", `permit(principal, action, resource) - when { if true then }`, true}, - {"invalidIfElseThen4", `permit(principal, action, resource) - when { if true then principal }`, true}, - {"invalidIfElseThen5", `permit(principal, action, resource) - when { if true then principal else }`, true}, - {"invalidAccess1", `permit(principal, action, resource) - when { resource.`, true}, - {"invalidAccess2", `permit(principal, action, resource) - when { resource.bar.123 };`, true}, - {"invalidAccess3", `permit(principal, action, resource) - when { resource.bar(`, true}, - {"invalidAccess4", `permit(principal, action, resource) - when { resource.bar(]`, true}, - {"invalidAccess5", `permit(principal, action, resource) - when { resource.bar(,)`, true}, - {"invalidAccess6", `permit(principal, action, resource) - when { resource.bar[`, true}, - {"invalidAccess7", `permit(principal, action, resource) - when { resource.bar[baz]`, true}, - {"invalidAccess8", `permit(principal, action, resource) - when { resource.bar["baz")`, true}, - {"invalidUnaryOp", `permit(principal, action, resource) - when { +resource.bar };`, true}, - {"invalidAdd", `permit(principal, action, resource) - when { resource.foo +`, true}, - {"invalidRelation", `permit(principal, action, resource) - when { resource.name in`, true}, - {"invalidHas1", `permit(principal, action, resource) - when { resource.name has`, true}, - {"invalidHas2", `permit(principal, action, resource) - when { resource.name has 123`, true}, - {"invalidLike1", `permit(principal, action, resource) - when { resource.name like`, true}, - {"invalidLike2", `permit(principal, action, resource) - when { resource.name like foo`, true}, - {"invalidPrimaryExpr", `permit(principal, action, resource) - when { (true`, true}, - {"invalidPrimaryExprList", `permit(principal, action, resource) - when { [`, true}, - {"invalidActionEqRhs", `permit(principal, action == Foo, resource);`, true}, - {"invalidActionInRhs", `permit(principal, action in Foo, resource);`, true}, - {"invalidPrimaryRecInits1", `permit(principal, action, resource) - when { {`, true}, - {"invalidPrimaryRecInits2", `permit(principal, action, resource) - when { {123: "value"} };`, true}, - {"invalidPrimaryRecInits3", `permit(principal, action, resource) - when { {"key" "value"} };`, true}, - {"invalidPrimaryRecInits4", `permit(principal, action, resource) - when { {"key":`, true}, - {"invalidPrimaryRecInits5", `permit(principal, action, resource) - when { {"key1": "value1" "key2": "value2" };`, true}, - - {"invalidStringAnnotation", `@bananas("\*") permit (principal, action, resource);`, true}, - {"invalidStringEntityID", `permit(principal == User::"\*", action, resource);`, true}, - {"invalidStringHas", `permit(principal, action, resource) when { context has "\*" };`, true}, - {"invalidNumericLike", `permit(principal, action, resource) when { context.key like 42 };`, true}, - {"invalidPatternLike", `permit(principal, action, resource) when { context.key like "\u{DFFF}" };`, true}, - {"invalidStringPrimary", `permit(principal, action, resource) when { context.key == "\*" };`, true}, - {"invalidExtFun", `permit(principal, action, resource) when { principal == User::"\*" };`, true}, - {"invalidAccess", `permit(principal, action, resource) when { context["\*"] == 42 };`, true}, - {"invalidRecordKey", `permit(principal, action, resource) when { { "\*":42 } };`, true}, - {"invalidIs", `permit (principal is 1, action, resource);`, true}, - {"invalidIsLong", `permit (principal is X::1, action, resource);`, true}, - {"duplicateAnnotations", `@key("value") @key("value") permit (principal, action, resource);`, true}, - - {"very-negative-long-bad", `permit(principal,action,resource) when { -9223372036823454775808 < -9224323372036854775807 };`, true}, - {"very-positive-long-bad", `permit(principal,action,resource) when { 9223372036823454775808 < 9224323372036854775807 };`, true}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - tokens, err := Tokenize([]byte(tt.in)) - testutil.OK(t, err) - got, err := Parse(tokens) - testutil.Equals(t, err != nil, tt.err) - if err != nil { - testutil.Equals(t, got, nil) - return - } - - gotTokens, err := Tokenize([]byte(got.String())) - testutil.OK(t, err) - - var tokenStrs []string - for _, t := range tokens { - tokenStrs = append(tokenStrs, t.toString()) - } - - var gotTokenStrs []string - for _, t := range gotTokens { - gotTokenStrs = append(gotTokenStrs, t.toString()) - } - - testutil.Equals(t, gotTokenStrs, tokenStrs) - }) - } -} - -func TestParseTypes(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in string - out Policies - }{ - { - "first", - "permit(principal, action, resource) when { 3 * 2 > 5 };", - Policies{ - Policy{ - Position: Position{Offset: 0, Line: 1, Column: 1}, - Annotations: []Annotation(nil), - Effect: "permit", - Conditions: []Condition{ - { - Type: "when", - Expression: Expression{ - Type: ExpressionOr, - Or: Or{ - Ands: []And{ - { - Relations: []Relation{ - { - Add: Add{ - Mults: []Mult{ - { - Unaries: []Unary{ - { - Ops: []UnaryOp(nil), - Member: Member{ - Primary: Primary{ - Type: PrimaryLiteral, - Literal: Literal{Type: LiteralInt, Long: 3}, - }, - Accesses: []Access(nil), - }, - }, - { - Ops: []UnaryOp(nil), - Member: Member{ - Primary: Primary{ - Type: PrimaryLiteral, - Literal: Literal{Type: LiteralInt, Long: 2}, - }, - Accesses: []Access(nil), - }, - }, - }, - }, - }, - }, - Type: "relop", - RelOp: ">", - RelOpRhs: Add{ - Mults: []Mult{ - { - Unaries: []Unary{ - { - Ops: []UnaryOp(nil), - Member: Member{ - Primary: Primary{ - Type: PrimaryLiteral, - Literal: Literal{Type: LiteralInt, Long: 5}, - }, - Accesses: []Access(nil), - }, - }, - }, - }, - }, - }, - Str: "", - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - tokens, err := Tokenize([]byte(tt.in)) - testutil.OK(t, err) - got, err := Parse(tokens) - testutil.OK(t, err) - testutil.Equals(t, got, tt.out) - }) - } -} - -func TestParseEntity(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in string - out Entity - err func(testing.TB, error) - }{ - {"happy", `Action::"test"`, Entity{Path: []string{"Action", "test"}}, testutil.OK}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - toks, err := Tokenize([]byte(tt.in)) - testutil.OK(t, err) - out, err := ParseEntity(toks) - testutil.Equals(t, out, tt.out) - tt.err(t, err) - }) - } -} - -func TestPolicyPositions(t *testing.T) { - t.Parallel() - in := `// idk a comment -@blah("asdf") -permit( principal, action, resource ); - - -// later on - permit (principal, action, resource) ; - -// annotation indent - @test("1234") permit (principal, action, resource ); -` - toks, err := Tokenize([]byte(in)) - testutil.OK(t, err) - out, err := Parse(toks) - testutil.OK(t, err) - testutil.Equals(t, len(out), 3) - testutil.Equals(t, out[0].Position, Position{Offset: 17, Line: 2, Column: 1}) - testutil.Equals(t, out[1].Position, Position{Offset: 86, Line: 7, Column: 3}) - testutil.Equals(t, out[2].Position, Position{Offset: 148, Line: 10, Column: 2}) -} diff --git a/x/exp/parser2/tokenize.go b/x/exp/parser2/tokenize.go deleted file mode 100644 index e2e41d65..00000000 --- a/x/exp/parser2/tokenize.go +++ /dev/null @@ -1,705 +0,0 @@ -package parser - -import ( - "bytes" - "fmt" - "io" - "strconv" - "strings" - "unicode" - "unicode/utf8" -) - -//go:generate moq -pkg parser -fmt goimports -out tokenize_mocks_test.go . reader - -// This type alias is for test purposes only. -type reader = io.Reader - -type TokenType int - -const ( - TokenEOF = TokenType(iota) - TokenIdent - TokenInt - TokenString - TokenOperator - TokenUnknown -) - -type Token struct { - Type TokenType - Pos Position - Text string -} - -func (t Token) isEOF() bool { - return t.Type == TokenEOF -} - -func (t Token) isIdent() bool { - return t.Type == TokenIdent -} - -func (t Token) isInt() bool { - return t.Type == TokenInt -} - -func (t Token) isString() bool { - return t.Type == TokenString -} - -func (t Token) toString() string { - return t.Text -} - -func (t Token) stringValue() (string, error) { - s := t.Text - s = strings.TrimPrefix(s, "\"") - s = strings.TrimSuffix(s, "\"") - b := []byte(s) - res, _, err := rustUnquote(b, false) - return res, err -} - -func (t Token) patternValue() (Pattern, error) { - return NewPattern(t.Text) -} - -func nextRune(b []byte, i int) (rune, int, error) { - ch, size := utf8.DecodeRune(b[i:]) - if ch == utf8.RuneError { - return ch, i, fmt.Errorf("bad unicode rune") - } - return ch, i + size, nil -} - -func parseHexEscape(b []byte, i int) (rune, int, error) { - var ch rune - var err error - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if !isHexadecimal(ch) { - return 0, i, fmt.Errorf("bad hex escape sequence") - } - res := digitVal(ch) - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if !isHexadecimal(ch) { - return 0, i, fmt.Errorf("bad hex escape sequence") - } - res = 16*res + digitVal(ch) - if res > 127 { - return 0, i, fmt.Errorf("bad hex escape sequence") - } - return rune(res), i, nil -} - -func parseUnicodeEscape(b []byte, i int) (rune, int, error) { - var ch rune - var err error - - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if ch != '{' { - return 0, i, fmt.Errorf("bad unicode escape sequence") - } - - digits := 0 - res := 0 - for { - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if ch == '}' { - break - } - if !isHexadecimal(ch) { - return 0, i, fmt.Errorf("bad unicode escape sequence") - } - res = 16*res + digitVal(ch) - digits++ - } - - if digits == 0 || digits > 6 || !utf8.ValidRune(rune(res)) { - return 0, i, fmt.Errorf("bad unicode escape sequence") - } - - return rune(res), i, nil -} - -func Unquote(s string) (string, error) { - s = strings.TrimPrefix(s, "\"") - s = strings.TrimSuffix(s, "\"") - res, _, err := rustUnquote([]byte(s), false) - return res, err -} - -func rustUnquote(b []byte, star bool) (string, []byte, error) { - var sb strings.Builder - var ch rune - var err error - i := 0 - for i < len(b) { - ch, i, err = nextRune(b, i) - if err != nil { - return "", nil, err - } - if star && ch == '*' { - i-- - return sb.String(), b[i:], nil - } - if ch != '\\' { - sb.WriteRune(ch) - continue - } - ch, i, err = nextRune(b, i) - if err != nil { - return "", nil, err - } - switch ch { - case 'n': - sb.WriteRune('\n') - case 'r': - sb.WriteRune('\r') - case 't': - sb.WriteRune('\t') - case '\\': - sb.WriteRune('\\') - case '0': - sb.WriteRune('\x00') - case '\'': - sb.WriteRune('\'') - case '"': - sb.WriteRune('"') - case 'x': - ch, i, err = parseHexEscape(b, i) - if err != nil { - return "", nil, err - } - sb.WriteRune(ch) - case 'u': - ch, i, err = parseUnicodeEscape(b, i) - if err != nil { - return "", nil, err - } - sb.WriteRune(ch) - case '*': - if !star { - return "", nil, fmt.Errorf("bad char escape") - } - sb.WriteRune('*') - default: - return "", nil, fmt.Errorf("bad char escape") - } - } - return sb.String(), b[i:], nil -} - -type PatternComponent struct { - Star bool - Chunk string -} - -type Pattern struct { - Comps []PatternComponent - Raw string -} - -func (p Pattern) String() string { - return p.Raw -} - -func NewPattern(literal string) (Pattern, error) { - rawPat := literal - - literal = strings.TrimPrefix(literal, "\"") - literal = strings.TrimSuffix(literal, "\"") - - b := []byte(literal) - - var comps []PatternComponent - for len(b) > 0 { - var comp PatternComponent - var err error - for len(b) > 0 && b[0] == '*' { - b = b[1:] - comp.Star = true - } - comp.Chunk, b, err = rustUnquote(b, true) - if err != nil { - return Pattern{}, err - } - comps = append(comps, comp) - } - return Pattern{ - Comps: comps, - Raw: rawPat, - }, nil -} - -func isHexadecimal(ch rune) bool { - return isDecimal(ch) || ('a' <= lower(ch) && lower(ch) <= 'f') -} - -// TODO: make FakeRustQuote actually accurate in all cases -func FakeRustQuote(s string) string { - return strconv.Quote(s) -} - -func (t Token) intValue() (int64, error) { - return strconv.ParseInt(t.Text, 10, 64) -} - -func Tokenize(src []byte) ([]Token, error) { - var res []Token - var s scanner - s.Init(bytes.NewBuffer(src)) - for tok := s.nextToken(); s.err == nil && tok.Type != TokenEOF; tok = s.nextToken() { - res = append(res, tok) - } - if s.err != nil { - return nil, s.err - } - res = append(res, Token{Type: TokenEOF, Pos: s.position}) - return res, nil -} - -// Position is a value that represents a source position. -// A position is valid if Line > 0. -type Position struct { - Offset int // byte offset, starting at 0 - Line int // line number, starting at 1 - Column int // column number, starting at 1 (character count per line) -} - -func (pos Position) String() string { - return fmt.Sprintf(":%d:%d", pos.Line, pos.Column) -} - -const ( - specialRuneEOF = rune(-(iota + 1)) - specialRuneBOF -) - -const bufLen = 1024 // at least utf8.UTFMax - -// A scanner implements reading of Unicode characters and tokens from an io.Reader. -type scanner struct { - // Input - src io.Reader - - // Source buffer - srcBuf [bufLen + 1]byte // +1 for sentinel for common case of s.next() - srcPos int // reading position (srcBuf index) - srcEnd int // source end (srcBuf index) - - // Source position - srcBufOffset int // byte offset of srcBuf[0] in source - line int // line count - column int // character count - lastLineLen int // length of last line in characters (for correct column reporting) - lastCharLen int // length of last character in bytes - - // Token text buffer - // Typically, token text is stored completely in srcBuf, but in general - // the token text's head may be buffered in tokBuf while the token text's - // tail is stored in srcBuf. - tokBuf bytes.Buffer // token text head that is not in srcBuf anymore - tokPos int // token text tail position (srcBuf index); valid if >= 0 - tokEnd int // token text tail end (srcBuf index) - - // One character look-ahead - ch rune // character before current srcPos - - // Last error encountered by nextToken. - err error - - // Start position of most recently scanned token; set by nextToken. - // Calling Init or Next invalidates the position (Line == 0). - // If an error is reported (via Error) and position is invalid, - // the scanner is not inside a token. Call Pos to obtain an error - // position in that case, or to obtain the position immediately - // after the most recently scanned token. - position Position -} - -// Init initializes a Scanner with a new source and returns s. -func (s *scanner) Init(src io.Reader) *scanner { - s.src = src - - // initialize source buffer - // (the first call to next() will fill it by calling src.Read) - s.srcBuf[0] = utf8.RuneSelf // sentinel - s.srcPos = 0 - s.srcEnd = 0 - - // initialize source position - s.srcBufOffset = 0 - s.line = 1 - s.column = 0 - s.lastLineLen = 0 - s.lastCharLen = 0 - - // initialize token text buffer - // (required for first call to next()). - s.tokPos = -1 - - // initialize one character look-ahead - s.ch = specialRuneBOF // no char read yet, not EOF - - // initialize public fields - s.position.Line = 0 // invalidate token position - - return s -} - -// next reads and returns the next Unicode character. It is designed such -// that only a minimal amount of work needs to be done in the common ASCII -// case (one test to check for both ASCII and end-of-buffer, and one test -// to check for newlines). -func (s *scanner) next() rune { - ch, width := rune(s.srcBuf[s.srcPos]), 1 - - if ch >= utf8.RuneSelf { - // uncommon case: not ASCII or not enough bytes - for s.srcPos+utf8.UTFMax > s.srcEnd && !utf8.FullRune(s.srcBuf[s.srcPos:s.srcEnd]) { - // not enough bytes: read some more, but first - // save away token text if any - if s.tokPos >= 0 { - s.tokBuf.Write(s.srcBuf[s.tokPos:s.srcPos]) - s.tokPos = 0 - // s.tokEnd is set by nextToken() - } - // move unread bytes to beginning of buffer - copy(s.srcBuf[0:], s.srcBuf[s.srcPos:s.srcEnd]) - s.srcBufOffset += s.srcPos - // read more bytes - // (an io.Reader must return io.EOF when it reaches - // the end of what it is reading - simply returning - // n == 0 will make this loop retry forever; but the - // error is in the reader implementation in that case) - i := s.srcEnd - s.srcPos - n, err := s.src.Read(s.srcBuf[i:bufLen]) - s.srcPos = 0 - s.srcEnd = i + n - s.srcBuf[s.srcEnd] = utf8.RuneSelf // sentinel - if err != nil { - if err != io.EOF { - s.error(err.Error()) - } - if s.srcEnd == 0 { - if s.lastCharLen > 0 { - // previous character was not EOF - s.column++ - } - s.lastCharLen = 0 - return specialRuneEOF - } - // If err == EOF, we won't be getting more - // bytes; break to avoid infinite loop. If - // err is something else, we don't know if - // we can get more bytes; thus also break. - break - } - } - // at least one byte - ch = rune(s.srcBuf[s.srcPos]) - if ch >= utf8.RuneSelf { - // uncommon case: not ASCII - ch, width = utf8.DecodeRune(s.srcBuf[s.srcPos:s.srcEnd]) - if ch == utf8.RuneError && width == 1 { - // advance for correct error position - s.srcPos += width - s.lastCharLen = width - s.column++ - s.error("invalid UTF-8 encoding") - return ch - } - } - } - - // advance - s.srcPos += width - s.lastCharLen = width - s.column++ - - // special situations - switch ch { - case 0: - // for compatibility with other tools - s.error("invalid character NUL") - case '\n': - s.line++ - s.lastLineLen = s.column - s.column = 0 - } - - return ch -} - -func (s *scanner) error(msg string) { - s.tokEnd = s.srcPos - s.lastCharLen // make sure token text is terminated - s.err = fmt.Errorf("%v: %v", s.position, msg) -} - -func isIdentRune(ch rune, first bool) bool { - return ch == '_' || unicode.IsLetter(ch) || unicode.IsDigit(ch) && !first -} - -func (s *scanner) scanIdentifier() rune { - // we know the zeroth rune is OK; start scanning at the next one - ch := s.next() - for isIdentRune(ch, false) { - ch = s.next() - } - return ch -} - -func lower(ch rune) rune { return ('a' - 'A') | ch } // returns lower-case ch iff ch is ASCII letter -func isDecimal(ch rune) bool { return '0' <= ch && ch <= '9' } - -func (s *scanner) scanInteger(ch rune) rune { - for isDecimal(ch) { - ch = s.next() - } - return ch -} - -func digitVal(ch rune) int { - switch { - case '0' <= ch && ch <= '9': - return int(ch - '0') - case 'a' <= lower(ch) && lower(ch) <= 'f': - return int(lower(ch) - 'a' + 10) - } - return 16 // larger than any legal digit val -} - -func (s *scanner) scanHexDigits(ch rune, min, max int) rune { - n := 0 - for n < max && isHexadecimal(ch) { - ch = s.next() - n++ - } - if n < min || n > max { - s.error("invalid char escape") - } - return ch -} - -func (s *scanner) scanEscape() rune { - ch := s.next() // read character after '/' - switch ch { - case 'n', 'r', 't', '\\', '0', '\'', '"', '*': - // nothing to do - ch = s.next() - case 'x': - ch = s.scanHexDigits(s.next(), 2, 2) - case 'u': - ch = s.next() - if ch != '{' { - s.error("invalid char escape") - return ch - } - ch = s.scanHexDigits(s.next(), 1, 6) - if ch != '}' { - s.error("invalid char escape") - return ch - } - ch = s.next() - default: - s.error("invalid char escape") - } - return ch -} - -func (s *scanner) scanString() (n int) { - ch := s.next() // read character after quote - for ch != '"' { - if ch == '\n' || ch < 0 { - s.error("literal not terminated") - return - } - if ch == '\\' { - ch = s.scanEscape() - } else { - ch = s.next() - } - n++ - } - return -} - -func (s *scanner) scanComment(ch rune) rune { - // ch == '/' || ch == '*' - if ch == '/' { - // line comment - ch = s.next() // read character after "//" - for ch != '\n' && ch >= 0 { - ch = s.next() - } - return ch - } - - // general comment - ch = s.next() // read character after "/*" - for { - if ch < 0 { - s.error("comment not terminated") - break - } - ch0 := ch - ch = s.next() - if ch0 == '*' && ch == '/' { - ch = s.next() - break - } - } - return ch -} - -func (s *scanner) scanOperator(ch0, ch rune) (TokenType, rune) { - switch ch0 { - case '@', '.', ',', ';', '(', ')', '{', '}', '[', ']', '+', '-', '*': - case ':': - if ch == ':' { - ch = s.next() - } - case '!', '<', '>': - if ch == '=' { - ch = s.next() - } - case '=': - if ch != '=' { - return TokenUnknown, ch - } - ch = s.next() - case '|': - if ch != '|' { - return TokenUnknown, ch - } - ch = s.next() - case '&': - if ch != '&' { - return TokenUnknown, ch - } - ch = s.next() - default: - return TokenUnknown, ch - } - return TokenOperator, ch -} - -func isWhitespace(c rune) bool { - switch c { - case '\t', '\n', '\r', ' ': - return true - default: - return false - } -} - -// nextToken reads the next token or Unicode character from source and returns -// it. It returns specialRuneEOF at the end of the source. It reports scanner -// errors (read and token errors) by calling s.Error, if not nil; otherwise it -// prints an error message to os.Stderr. -func (s *scanner) nextToken() Token { - if s.ch == specialRuneBOF { - s.ch = s.next() - } - - ch := s.ch - - // reset token text position - s.tokPos = -1 - s.position.Line = 0 - -redo: - // skip white space - for isWhitespace(ch) { - ch = s.next() - } - - // start collecting token text - s.tokBuf.Reset() - s.tokPos = s.srcPos - s.lastCharLen - - // set token position - s.position.Offset = s.srcBufOffset + s.tokPos - if s.column > 0 { - // common case: last character was not a '\n' - s.position.Line = s.line - s.position.Column = s.column - } else { - // last character was a '\n' - // (we cannot be at the beginning of the source - // since we have called next() at least once) - s.position.Line = s.line - 1 - s.position.Column = s.lastLineLen - } - - // determine token value - var tt TokenType - switch { - case ch == specialRuneEOF: - tt = TokenEOF - case isIdentRune(ch, true): - ch = s.scanIdentifier() - tt = TokenIdent - case isDecimal(ch): - ch = s.scanInteger(ch) - tt = TokenInt - case ch == '"': - s.scanString() - ch = s.next() - tt = TokenString - case ch == '/': - ch0 := ch - ch = s.next() - if ch == '/' || ch == '*' { - s.tokPos = -1 // don't collect token text - ch = s.scanComment(ch) - goto redo - } - tt, ch = s.scanOperator(ch0, ch) - default: - tt, ch = s.scanOperator(ch, s.next()) - } - - // end of token text - s.tokEnd = s.srcPos - s.lastCharLen - s.ch = ch - - return Token{ - Type: tt, - Pos: s.position, - Text: s.tokenText(), - } -} - -// tokenText returns the string corresponding to the most recently scanned token. -// Valid after calling nextToken and in calls of Scanner.Error. -func (s *scanner) tokenText() string { - if s.tokPos < 0 { - // no token text - return "" - } - - if s.tokBuf.Len() == 0 { - // common case: the entire token text is still in srcBuf - return string(s.srcBuf[s.tokPos:s.tokEnd]) - } - - // part of the token text was saved in tokBuf: save the rest in - // tokBuf as well and return its content - s.tokBuf.Write(s.srcBuf[s.tokPos:s.tokEnd]) - s.tokPos = s.tokEnd // ensure idempotency of TokenText() call - return s.tokBuf.String() -} diff --git a/x/exp/parser2/tokenize_mocks_test.go b/x/exp/parser2/tokenize_mocks_test.go deleted file mode 100644 index ff5a98fc..00000000 --- a/x/exp/parser2/tokenize_mocks_test.go +++ /dev/null @@ -1,74 +0,0 @@ -// Code generated by moq; DO NOT EDIT. -// github.com/matryer/moq - -package parser - -import ( - "sync" -) - -// Ensure, that readerMock does implement reader. -// If this is not the case, regenerate this file with moq. -var _ reader = &readerMock{} - -// readerMock is a mock implementation of reader. -// -// func TestSomethingThatUsesreader(t *testing.T) { -// -// // make and configure a mocked reader -// mockedreader := &readerMock{ -// ReadFunc: func(p []byte) (int, error) { -// panic("mock out the Read method") -// }, -// } -// -// // use mockedreader in code that requires reader -// // and then make assertions. -// -// } -type readerMock struct { - // ReadFunc mocks the Read method. - ReadFunc func(p []byte) (int, error) - - // calls tracks calls to the methods. - calls struct { - // Read holds details about calls to the Read method. - Read []struct { - // P is the p argument value. - P []byte - } - } - lockRead sync.RWMutex -} - -// Read calls ReadFunc. -func (mock *readerMock) Read(p []byte) (int, error) { - if mock.ReadFunc == nil { - panic("readerMock.ReadFunc: method is nil but reader.Read was just called") - } - callInfo := struct { - P []byte - }{ - P: p, - } - mock.lockRead.Lock() - mock.calls.Read = append(mock.calls.Read, callInfo) - mock.lockRead.Unlock() - return mock.ReadFunc(p) -} - -// ReadCalls gets all the calls that were made to Read. -// Check the length with: -// -// len(mockedreader.ReadCalls()) -func (mock *readerMock) ReadCalls() []struct { - P []byte -} { - var calls []struct { - P []byte - } - mock.lockRead.RLock() - calls = mock.calls.Read - mock.lockRead.RUnlock() - return calls -} diff --git a/x/exp/parser2/tokenize_test.go b/x/exp/parser2/tokenize_test.go deleted file mode 100644 index 42d9911a..00000000 --- a/x/exp/parser2/tokenize_test.go +++ /dev/null @@ -1,554 +0,0 @@ -package parser - -import ( - "fmt" - "io" - "strings" - "testing" - "unicode/utf8" - - "github.com/cedar-policy/cedar-go/testutil" -) - -func TestTokenize(t *testing.T) { - t.Parallel() - input := ` -These are some identifiers -0 1 1234 --1 9223372036854775807 -9223372036854775808 -"" "string" "\"\'\n\r\t\\\0" "\x123" "\u{0}\u{10fFfF}" -"*" "\*" "*\**" -@.,;(){}[]+-* -::: -!!=<<=>>= -||&& -// single line comment -/* -multiline comment -// embedded comment does nothing -*/ -'/%|&=` - want := []Token{ - {Type: TokenIdent, Text: "These", Pos: Position{Offset: 1, Line: 2, Column: 1}}, - {Type: TokenIdent, Text: "are", Pos: Position{Offset: 7, Line: 2, Column: 7}}, - {Type: TokenIdent, Text: "some", Pos: Position{Offset: 11, Line: 2, Column: 11}}, - {Type: TokenIdent, Text: "identifiers", Pos: Position{Offset: 16, Line: 2, Column: 16}}, - - {Type: TokenInt, Text: "0", Pos: Position{Offset: 28, Line: 3, Column: 1}}, - {Type: TokenInt, Text: "1", Pos: Position{Offset: 30, Line: 3, Column: 3}}, - {Type: TokenInt, Text: "1234", Pos: Position{Offset: 32, Line: 3, Column: 5}}, - - {Type: TokenOperator, Text: "-", Pos: Position{Offset: 37, Line: 4, Column: 1}}, - {Type: TokenInt, Text: "1", Pos: Position{Offset: 38, Line: 4, Column: 2}}, - {Type: TokenInt, Text: "9223372036854775807", Pos: Position{Offset: 40, Line: 4, Column: 4}}, - {Type: TokenOperator, Text: "-", Pos: Position{Offset: 60, Line: 4, Column: 24}}, - {Type: TokenInt, Text: "9223372036854775808", Pos: Position{Offset: 61, Line: 4, Column: 25}}, - - {Type: TokenString, Text: `""`, Pos: Position{Offset: 81, Line: 5, Column: 1}}, - {Type: TokenString, Text: `"string"`, Pos: Position{Offset: 84, Line: 5, Column: 4}}, - {Type: TokenString, Text: `"\"\'\n\r\t\\\0"`, Pos: Position{Offset: 93, Line: 5, Column: 13}}, - {Type: TokenString, Text: `"\x123"`, Pos: Position{Offset: 110, Line: 5, Column: 30}}, - {Type: TokenString, Text: `"\u{0}\u{10fFfF}"`, Pos: Position{Offset: 118, Line: 5, Column: 38}}, - - {Type: TokenString, Text: `"*"`, Pos: Position{Offset: 136, Line: 6, Column: 1}}, - {Type: TokenString, Text: `"\*"`, Pos: Position{Offset: 140, Line: 6, Column: 5}}, - {Type: TokenString, Text: `"*\**"`, Pos: Position{Offset: 145, Line: 6, Column: 10}}, - - {Type: TokenOperator, Text: "@", Pos: Position{Offset: 152, Line: 7, Column: 1}}, - {Type: TokenOperator, Text: ".", Pos: Position{Offset: 153, Line: 7, Column: 2}}, - {Type: TokenOperator, Text: ",", Pos: Position{Offset: 154, Line: 7, Column: 3}}, - {Type: TokenOperator, Text: ";", Pos: Position{Offset: 155, Line: 7, Column: 4}}, - {Type: TokenOperator, Text: "(", Pos: Position{Offset: 156, Line: 7, Column: 5}}, - {Type: TokenOperator, Text: ")", Pos: Position{Offset: 157, Line: 7, Column: 6}}, - {Type: TokenOperator, Text: "{", Pos: Position{Offset: 158, Line: 7, Column: 7}}, - {Type: TokenOperator, Text: "}", Pos: Position{Offset: 159, Line: 7, Column: 8}}, - {Type: TokenOperator, Text: "[", Pos: Position{Offset: 160, Line: 7, Column: 9}}, - {Type: TokenOperator, Text: "]", Pos: Position{Offset: 161, Line: 7, Column: 10}}, - {Type: TokenOperator, Text: "+", Pos: Position{Offset: 162, Line: 7, Column: 11}}, - {Type: TokenOperator, Text: "-", Pos: Position{Offset: 163, Line: 7, Column: 12}}, - {Type: TokenOperator, Text: "*", Pos: Position{Offset: 164, Line: 7, Column: 13}}, - - {Type: TokenOperator, Text: "::", Pos: Position{Offset: 166, Line: 8, Column: 1}}, - {Type: TokenOperator, Text: ":", Pos: Position{Offset: 168, Line: 8, Column: 3}}, - - {Type: TokenOperator, Text: "!", Pos: Position{Offset: 170, Line: 9, Column: 1}}, - {Type: TokenOperator, Text: "!=", Pos: Position{Offset: 171, Line: 9, Column: 2}}, - {Type: TokenOperator, Text: "<", Pos: Position{Offset: 173, Line: 9, Column: 4}}, - {Type: TokenOperator, Text: "<=", Pos: Position{Offset: 174, Line: 9, Column: 5}}, - {Type: TokenOperator, Text: ">", Pos: Position{Offset: 176, Line: 9, Column: 7}}, - {Type: TokenOperator, Text: ">=", Pos: Position{Offset: 177, Line: 9, Column: 8}}, - - {Type: TokenOperator, Text: "||", Pos: Position{Offset: 180, Line: 10, Column: 1}}, - {Type: TokenOperator, Text: "&&", Pos: Position{Offset: 182, Line: 10, Column: 3}}, - - {Type: TokenUnknown, Text: "'", Pos: Position{Offset: 265, Line: 16, Column: 1}}, - {Type: TokenUnknown, Text: "/", Pos: Position{Offset: 266, Line: 16, Column: 2}}, - {Type: TokenUnknown, Text: "%", Pos: Position{Offset: 267, Line: 16, Column: 3}}, - {Type: TokenUnknown, Text: "|", Pos: Position{Offset: 268, Line: 16, Column: 4}}, - {Type: TokenUnknown, Text: "&", Pos: Position{Offset: 269, Line: 16, Column: 5}}, - {Type: TokenUnknown, Text: "=", Pos: Position{Offset: 270, Line: 16, Column: 6}}, - - {Type: TokenEOF, Text: "", Pos: Position{Offset: 271, Line: 16, Column: 7}}, - } - got, err := Tokenize([]byte(input)) - testutil.OK(t, err) - testutil.Equals(t, got, want) -} - -func TestTokenizeErrors(t *testing.T) { - t.Parallel() - tests := []struct { - input string - wantErrStr string - wantErrPos Position - }{ - {"okay\x00not okay", "invalid character NUL", Position{Line: 1, Column: 1}}, - {`okay /* - stuff - `, "comment not terminated", Position{Line: 1, Column: 6}}, - {`okay " - " foo bar`, "literal not terminated", Position{Line: 1, Column: 6}}, - {`"okay" "\a"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\b"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\f"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\v"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\1"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\x"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\x1"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\ubadf"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\U0000badf"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\u{}"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\u{0000000}"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\u{z"`, "invalid char escape", Position{Line: 1, Column: 8}}, - } - for i, tt := range tests { - tt := tt - t.Run(fmt.Sprintf("%02d", i), func(t *testing.T) { - t.Parallel() - got, gotErr := Tokenize([]byte(tt.input)) - wantErrStr := fmt.Sprintf("%v: %s", tt.wantErrPos, tt.wantErrStr) - testutil.Error(t, gotErr) - testutil.Equals(t, gotErr.Error(), wantErrStr) - testutil.Equals(t, got, nil) - }) - } -} - -func TestIntTokenValues(t *testing.T) { - t.Parallel() - tests := []struct { - input string - wantOk bool - want int64 - wantErr string - }{ - {"0", true, 0, ""}, - {"9223372036854775807", true, 9223372036854775807, ""}, - {"9223372036854775808", false, 0, `strconv.ParseInt: parsing "9223372036854775808": value out of range`}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, err := Tokenize([]byte(tt.input)) - testutil.OK(t, err) - testutil.Equals(t, len(got), 2) - testutil.Equals(t, got[0].Type, TokenInt) - gotInt, err := got[0].intValue() - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, gotInt, tt.want) - } - }) - } -} - -func TestStringTokenValues(t *testing.T) { - t.Parallel() - tests := []struct { - input string - wantOk bool - want string - wantErr string - }{ - {`""`, true, "", ""}, - {`"hello"`, true, "hello", ""}, - {`"a\n\r\t\\\0b"`, true, "a\n\r\t\\\x00b", ""}, - {`"a\"b"`, true, "a\"b", ""}, - {`"a\'b"`, true, "a'b", ""}, - - {`"a\x00b"`, true, "a\x00b", ""}, - {`"a\x7fb"`, true, "a\x7fb", ""}, - {`"a\x80b"`, false, "", "bad hex escape sequence"}, - - {`"a\u{A}b"`, true, "a\u000ab", ""}, - {`"a\u{aB}b"`, true, "a\u00abb", ""}, - {`"a\u{AbC}b"`, true, "a\u0abcb", ""}, - {`"a\u{aBcD}b"`, true, "a\uabcdb", ""}, - {`"a\u{AbCdE}b"`, true, "a\U000abcdeb", ""}, - {`"a\u{10cDeF}b"`, true, "a\U0010cdefb", ""}, - {`"a\u{ffffff}b"`, false, "", "bad unicode escape sequence"}, - {`"a\u{d7ff}b"`, true, "a\ud7ffb", ""}, - {`"a\u{d800}b"`, false, "", "bad unicode escape sequence"}, - {`"a\u{dfff}b"`, false, "", "bad unicode escape sequence"}, - {`"a\u{e000}b"`, true, "a\ue000b", ""}, - {`"a\u{10ffff}b"`, true, "a\U0010ffffb", ""}, - {`"a\u{110000}b"`, false, "", "bad unicode escape sequence"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, err := Tokenize([]byte(tt.input)) - testutil.OK(t, err) - testutil.Equals(t, len(got), 2) - testutil.Equals(t, got[0].Type, TokenString) - gotStr, err := got[0].stringValue() - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, gotStr, tt.want) - } - }) - } -} - -func TestParseUnicodeEscape(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in []byte - out rune - outN int - err func(t testing.TB, err error) - }{ - {"happy", []byte{'{', '4', '2', '}'}, 0x42, 4, testutil.OK}, - {"badRune", []byte{'{', 0x80, 0x81}, 0, 1, testutil.Error}, - {"notHex", []byte{'{', 'g'}, 0, 2, testutil.Error}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - out, n, err := parseUnicodeEscape(tt.in, 0) - testutil.Equals(t, out, tt.out) - testutil.Equals(t, n, tt.outN) - tt.err(t, err) - }) - } -} - -func TestUnquote(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in string - out string - err func(t testing.TB, err error) - }{ - {"happy", `"test"`, `test`, testutil.OK}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - out, err := Unquote(tt.in) - testutil.Equals(t, out, tt.out) - tt.err(t, err) - }) - } -} - -func TestRustUnquote(t *testing.T) { - t.Parallel() - // star == false - { - tests := []struct { - input string - wantOk bool - want string - wantErr string - }{ - {``, true, "", ""}, - {`hello`, true, "hello", ""}, - {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", ""}, - {`a\"b`, true, "a\"b", ""}, - {`a\'b`, true, "a'b", ""}, - - {`a\x00b`, true, "a\x00b", ""}, - {`a\x7fb`, true, "a\x7fb", ""}, - {`a\x80b`, false, "", "bad hex escape sequence"}, - - {string([]byte{0x80, 0x81}), false, "", "bad unicode rune"}, - {`a\u`, false, "", "bad unicode rune"}, - {`a\uz`, false, "", "bad unicode escape sequence"}, - {`a\u{}b`, false, "", "bad unicode escape sequence"}, - {`a\u{A}b`, true, "a\u000ab", ""}, - {`a\u{aB}b`, true, "a\u00abb", ""}, - {`a\u{AbC}b`, true, "a\u0abcb", ""}, - {`a\u{aBcD}b`, true, "a\uabcdb", ""}, - {`a\u{AbCdE}b`, true, "a\U000abcdeb", ""}, - {`a\u{10cDeF}b`, true, "a\U0010cdefb", ""}, - {`a\u{ffffff}b`, false, "", "bad unicode escape sequence"}, - {`a\u{0000000}b`, false, "", "bad unicode escape sequence"}, - {`a\u{d7ff}b`, true, "a\ud7ffb", ""}, - {`a\u{d800}b`, false, "", "bad unicode escape sequence"}, - {`a\u{dfff}b`, false, "", "bad unicode escape sequence"}, - {`a\u{e000}b`, true, "a\ue000b", ""}, - {`a\u{10ffff}b`, true, "a\U0010ffffb", ""}, - {`a\u{110000}b`, false, "", "bad unicode escape sequence"}, - - {`\`, false, "", "bad unicode rune"}, - {`\a`, false, "", "bad char escape"}, - {`\*`, false, "", "bad char escape"}, - {`\x`, false, "", "bad unicode rune"}, - {`\xz`, false, "", "bad hex escape sequence"}, - {`\xa`, false, "", "bad unicode rune"}, - {`\xaz`, false, "", "bad hex escape sequence"}, - {`\{`, false, "", "bad char escape"}, - {`\{z`, false, "", "bad char escape"}, - {`\{0`, false, "", "bad char escape"}, - {`\{0z`, false, "", "bad char escape"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, rem, err := rustUnquote([]byte(tt.input), false) - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - testutil.Equals(t, got, tt.want) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got, tt.want) - testutil.Equals(t, rem, []byte("")) - } - }) - } - } - - // star == true - { - tests := []struct { - input string - wantOk bool - want string - wantRem string - wantErr string - }{ - {``, true, "", "", ""}, - {`hello`, true, "hello", "", ""}, - {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", "", ""}, - {`a\"b`, true, "a\"b", "", ""}, - {`a\'b`, true, "a'b", "", ""}, - - {`a\x00b`, true, "a\x00b", "", ""}, - {`a\x7fb`, true, "a\x7fb", "", ""}, - {`a\x80b`, false, "", "", "bad hex escape sequence"}, - - {`a\u`, false, "", "", "bad unicode rune"}, - {`a\uz`, false, "", "", "bad unicode escape sequence"}, - {`a\u{}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{A}b`, true, "a\u000ab", "", ""}, - {`a\u{aB}b`, true, "a\u00abb", "", ""}, - {`a\u{AbC}b`, true, "a\u0abcb", "", ""}, - {`a\u{aBcD}b`, true, "a\uabcdb", "", ""}, - {`a\u{AbCdE}b`, true, "a\U000abcdeb", "", ""}, - {`a\u{10cDeF}b`, true, "a\U0010cdefb", "", ""}, - {`a\u{ffffff}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{0000000}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{d7ff}b`, true, "a\ud7ffb", "", ""}, - {`a\u{d800}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{dfff}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{e000}b`, true, "a\ue000b", "", ""}, - {`a\u{10ffff}b`, true, "a\U0010ffffb", "", ""}, - {`a\u{110000}b`, false, "", "", "bad unicode escape sequence"}, - - {`*`, true, "", "*", ""}, - {`*hello*how*are*you`, true, "", "*hello*how*are*you", ""}, - {`hello*how*are*you`, true, "hello", "*how*are*you", ""}, - {`\**`, true, "*", "*", ""}, - - {`\`, false, "", "", "bad unicode rune"}, - {`\a`, false, "", "", "bad char escape"}, - {`\x`, false, "", "", "bad unicode rune"}, - {`\xz`, false, "", "", "bad hex escape sequence"}, - {`\xa`, false, "", "", "bad unicode rune"}, - {`\xaz`, false, "", "", "bad hex escape sequence"}, - {`\{`, false, "", "", "bad char escape"}, - {`\{z`, false, "", "", "bad char escape"}, - {`\{0`, false, "", "", "bad char escape"}, - {`\{0z`, false, "", "", "bad char escape"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, rem, err := rustUnquote([]byte(tt.input), true) - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - testutil.Equals(t, got, tt.want) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got, tt.want) - testutil.Equals(t, string(rem), tt.wantRem) - } - }) - } - } -} - -func TestFakeRustQuote(t *testing.T) { - t.Parallel() - out := FakeRustQuote("hello") - testutil.Equals(t, out, `"hello"`) -} - -func TestPatternFromStringLiteral(t *testing.T) { - t.Parallel() - tests := []struct { - input string - wantOk bool - want []PatternComponent - wantErr string - }{ - {`""`, true, nil, ""}, - {`"a"`, true, []PatternComponent{{false, "a"}}, ""}, - {`"*"`, true, []PatternComponent{{true, ""}}, ""}, - {`"*a"`, true, []PatternComponent{{true, "a"}}, ""}, - {`"a*"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, - {`"**"`, true, []PatternComponent{{true, ""}}, ""}, - {`"**a"`, true, []PatternComponent{{true, "a"}}, ""}, - {`"a**"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, - {`"*a*"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, - {`"**a**"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, - {`"abra*ca"`, true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, - }, ""}, - {`"abra**ca"`, true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, - }, ""}, - {`"*abra*ca"`, true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, - }, ""}, - {`"abra*ca*"`, true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, {true, ""}, - }, ""}, - {`"*abra*ca*"`, true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, {true, ""}, - }, ""}, - {`"*abra*ca*dabra"`, true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, {true, "dabra"}, - }, ""}, - {`"*abra*c\**da\*ra"`, true, []PatternComponent{ - {true, "abra"}, {true, "c*"}, {true, "da*ra"}, - }, ""}, - {`"\u"`, false, nil, "bad unicode rune"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, err := NewPattern(tt.input) - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got.Comps, tt.want) - testutil.Equals(t, got.String(), tt.input) - } - }) - } -} - -func TestScanner(t *testing.T) { - t.Parallel() - t.Run("SrcError", func(t *testing.T) { - t.Parallel() - wantErr := fmt.Errorf("wantErr") - r := &readerMock{ - ReadFunc: func(_ []byte) (int, error) { - return 0, wantErr - }, - } - var s scanner - s.Init(r) - out := s.next() - testutil.Equals(t, out, specialRuneEOF) - }) - - t.Run("MidEmojiEOF", func(t *testing.T) { - t.Parallel() - var s scanner - var eof bool - str := []byte(string(`🐐`)) - r := &readerMock{ - ReadFunc: func(p []byte) (int, error) { - if eof { - return 0, io.EOF - } - p[0] = str[0] - eof = true - return 1, nil - }, - } - s.Init(r) - out := s.next() - testutil.Equals(t, out, utf8.RuneError) - out = s.next() - testutil.Equals(t, out, specialRuneEOF) - }) - - t.Run("NotAsciiEmoji", func(t *testing.T) { - t.Parallel() - var s scanner - s.Init(strings.NewReader(`🐐`)) - out := s.next() - testutil.Equals(t, out, '🐐') - }) - - t.Run("InvalidUTF8", func(t *testing.T) { - t.Parallel() - var s scanner - s.Init(strings.NewReader(string([]byte{0x80, 0x81}))) - out := s.next() - testutil.Equals(t, out, utf8.RuneError) - }) - - t.Run("tokenTextNone", func(t *testing.T) { - t.Parallel() - var s scanner - s.Init(strings.NewReader("")) - out := s.tokenText() - testutil.Equals(t, out, "") - }) -} - -func TestDigitVal(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in rune - out int - }{ - {"happy", '0', 0}, - {"hex", 'f', 15}, - {"sad", 'g', 16}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - out := digitVal(tt.in) - testutil.Equals(t, out, tt.out) - }) - } -} From 2230a93d19ff269f52c9c7d65603ed8a137635db Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 30 Jul 2024 15:13:21 -0700 Subject: [PATCH 017/216] cedar-go/x/exp/ast: fix build break Signed-off-by: philhassey --- x/exp/ast/annotation.go | 1 + 1 file changed, 1 insertion(+) diff --git a/x/exp/ast/annotation.go b/x/exp/ast/annotation.go index 4ffac6f3..940a9f1b 100644 --- a/x/exp/ast/annotation.go +++ b/x/exp/ast/annotation.go @@ -36,6 +36,7 @@ func (a *Annotations) Forbid() *Policy { func (p *Policy) Annotate(name, value types.String) *Policy { p.annotations = append(p.annotations, newAnnotationNode(name, value)) + return p } func newAnnotationNode(name, value types.String) Node { From 9280536ebbacbc341a28c1889c5521089ea416b3 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 16:13:24 -0600 Subject: [PATCH 018/216] x/exp/ast: add not, negate to JSON unmarshaller Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/json.go | 87 +++++++++++++++++++++++++++++++++++-------- x/exp/ast/node.go | 1 + x/exp/ast/operator.go | 4 ++ 3 files changed, 76 insertions(+), 16 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index f2ae2800..cf1bd237 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -74,12 +74,39 @@ func (j binaryJSON) ToNode(f func(a, b Node) Node) (Node, error) { return f(left, right), nil } +type unaryJSON struct { + Arg nodeJSON `json:"arg"` +} + +func (j unaryJSON) ToNode(f func(a Node) Node) (Node, error) { + arg, err := j.Arg.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in arg: %w", err) + } + return f(arg), nil +} + type accessJSON struct { Left nodeJSON `json:"left"` Attr string `json:"attr"` } type nodeJSON struct { + + // Value + Value *string `json:"Value"` // could be any + + // Var + Var *string `json:"Var"` + + // Slot + // Unknown + + // ! or neg operators + Not *unaryJSON `json:"!"` + Negate *unaryJSON `json:"neg"` + + // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny Equals *binaryJSON `json:"=="` NotEquals *binaryJSON `json:"!="` In *binaryJSON `json:"in"` @@ -96,13 +123,47 @@ type nodeJSON struct { ContainsAll *binaryJSON `json:"containsAll"` ContainsAny *binaryJSON `json:"containsAny"` + // ., has Access *accessJSON `json:"."` - Var *string `json:"Var"` - Value *string `json:"Value"` // could be any + + // like + // if-then-else + // Set + // Record + // Any other key + } func (j nodeJSON) ToNode() (Node, error) { switch { + // Value + case j.Value != nil: + return String(types.String(*j.Value)), nil + + // Var + case j.Var != nil: + switch *j.Var { + case "principal": + return Principal(), nil + case "action": + return Action(), nil + case "resource": + return Resource(), nil + case "context": + return Context(), nil + } + return Node{}, fmt.Errorf("unknown var: %v", j.Var) + + // Slot + // Unknown + + // ! or neg operators + case j.Not != nil: + return j.Not.ToNode(Not) + case j.Negate != nil: + return j.Negate.ToNode(Negate) + + // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny case j.Equals != nil: return j.Equals.ToNode(Node.Equals) case j.NotEquals != nil: @@ -133,28 +194,22 @@ func (j nodeJSON) ToNode() (Node, error) { return j.ContainsAll.ToNode(Node.ContainsAll) case j.ContainsAny != nil: return j.ContainsAny.ToNode(Node.ContainsAny) + + // ., has case j.Access != nil: left, err := j.Access.Left.ToNode() if err != nil { return Node{}, fmt.Errorf("error in left of access: %w", err) } return left.Access(j.Access.Attr), nil - case j.Var != nil: - switch *j.Var { - case "principal": - return Principal(), nil - case "action": - return Action(), nil - case "resource": - return Resource(), nil - case "context": - return Context(), nil - } - return Node{}, fmt.Errorf("unknown var: %v", j.Var) - case j.Value != nil: - return String(types.String(*j.Value)), nil } + // like + // if-then-else + // Set + // Record + // Any other key + return Node{}, fmt.Errorf("unknown node") } diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 4e107fcf..e0b1d841 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -33,6 +33,7 @@ const ( nodeTypeLong nodeTypeMult nodeTypeNot + nodeTypeNegate nodeTypeNotEquals nodeTypeOr nodeTypeRecord diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index d8b0c177..0de88bf6 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -52,6 +52,10 @@ func Not(rhs Node) Node { return newOpNode(nodeTypeNot, rhs) } +func Negate(rhs Node) Node { + return newOpNode(nodeTypeNegate, rhs) +} + func If(condition Node, ifTrue Node, ifFalse Node) Node { return newOpNode(nodeTypeIf, condition, ifTrue, ifFalse) } From 9ad353098ea54d653d352e7caab3ed9918b34787 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 16:25:12 -0600 Subject: [PATCH 019/216] x/exp/ast: add has support to JSON Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/json.go | 21 ++++++++++++++------- x/exp/ast/operator.go | 8 ++++---- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index cf1bd237..7e68d7c9 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -86,11 +86,19 @@ func (j unaryJSON) ToNode(f func(a Node) Node) (Node, error) { return f(arg), nil } -type accessJSON struct { +type attrJSON struct { Left nodeJSON `json:"left"` Attr string `json:"attr"` } +func (j attrJSON) ToNode(f func(a Node, k string) Node) (Node, error) { + left, err := j.Left.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in left: %w", err) + } + return f(left, j.Attr), nil +} + type nodeJSON struct { // Value @@ -124,7 +132,8 @@ type nodeJSON struct { ContainsAny *binaryJSON `json:"containsAny"` // ., has - Access *accessJSON `json:"."` + Access *attrJSON `json:"."` + Has *attrJSON `json:"has"` // like // if-then-else @@ -197,11 +206,9 @@ func (j nodeJSON) ToNode() (Node, error) { // ., has case j.Access != nil: - left, err := j.Access.Left.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in left of access: %w", err) - } - return left.Access(j.Access.Attr), nil + return j.Access.ToNode(Node.Access) + case j.Has != nil: + return j.Has.ToNode(Node.Has) } // like diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index 0de88bf6..b4019124 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -89,10 +89,6 @@ func (lhs Node) In(rhs Node) Node { return newOpNode(nodeTypeIn, lhs, rhs) } -func (lhs Node) Has(rhs Node) Node { - return newOpNode(nodeTypeHas, lhs, rhs) -} - func (lhs Node) Is(rhs Node) Node { return newOpNode(nodeTypeIs, lhs, rhs) } @@ -130,6 +126,10 @@ func (lhs Node) AccessNode(rhs Node) Node { return newOpNode(nodeTypeAccess, lhs, rhs) } +func (lhs Node) Has(attr string) Node { + return newOpNode(nodeTypeHas, lhs, String(types.String(attr))) +} + // ___ ____ _ _ _ // |_ _| _ \ / \ __| | __| |_ __ ___ ___ ___ // | || |_) / _ \ / _` |/ _` | '__/ _ \/ __/ __| From 1f0fb26ebd35f68434713c6865040a66d9dacd93 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 16:28:13 -0600 Subject: [PATCH 020/216] x/exp/ast: add like to JSON Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/json.go | 23 ++++++++++++++--------- x/exp/ast/node.go | 1 + x/exp/ast/operator.go | 4 ++++ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index 7e68d7c9..1994b33a 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -86,12 +86,12 @@ func (j unaryJSON) ToNode(f func(a Node) Node) (Node, error) { return f(arg), nil } -type attrJSON struct { +type strJSON struct { Left nodeJSON `json:"left"` Attr string `json:"attr"` } -func (j attrJSON) ToNode(f func(a Node, k string) Node) (Node, error) { +func (j strJSON) ToNode(f func(a Node, k string) Node) (Node, error) { left, err := j.Left.ToNode() if err != nil { return Node{}, fmt.Errorf("error in left: %w", err) @@ -132,10 +132,12 @@ type nodeJSON struct { ContainsAny *binaryJSON `json:"containsAny"` // ., has - Access *attrJSON `json:"."` - Has *attrJSON `json:"has"` + Access *strJSON `json:"."` + Has *strJSON `json:"has"` // like + Like *strJSON `json:"like"` + // if-then-else // Set // Record @@ -209,13 +211,16 @@ func (j nodeJSON) ToNode() (Node, error) { return j.Access.ToNode(Node.Access) case j.Has != nil: return j.Has.ToNode(Node.Has) - } // like - // if-then-else - // Set - // Record - // Any other key + case j.Like != nil: + return j.Like.ToNode(Node.Like) + + // if-then-else + // Set + // Record + // Any other key + } return Node{}, fmt.Errorf("unknown node") } diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index e0b1d841..fce418a3 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -18,6 +18,7 @@ const ( nodeTypeEquals nodeTypeGreater nodeTypeGreaterEqual + nodeTypeLike nodeTypeHas nodeTypeIf nodeTypeIn diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index b4019124..6e54efde 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -33,6 +33,10 @@ func (lhs Node) GreaterThanOrEqual(rhs Node) Node { return newOpNode(nodeTypeGreaterEqual, lhs, rhs) } +func (lhs Node) Like(patt string) Node { + return newOpNode(nodeTypeLike, lhs, String(types.String(patt))) +} + // _ _ _ // | | ___ __ _(_) ___ __ _| | // | | / _ \ / _` | |/ __/ _` | | From 34186915756641c0cd573d6e94c9265043ff984b Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 16:32:17 -0600 Subject: [PATCH 021/216] x/exp/ast: add if-then-else node to JSON Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/json.go | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index 1994b33a..b5a95668 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -99,6 +99,28 @@ func (j strJSON) ToNode(f func(a Node, k string) Node) (Node, error) { return f(left, j.Attr), nil } +type ifThenElseJSON struct { + If nodeJSON `json:"if"` + Then nodeJSON `json:"then"` + Else nodeJSON `json:"else"` +} + +func (j ifThenElseJSON) ToNode() (Node, error) { + if_, err := j.If.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in if: %w", err) + } + then, err := j.Then.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in then: %w", err) + } + else_, err := j.Else.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in else: %w", err) + } + return If(if_, then, else_), nil +} + type nodeJSON struct { // Value @@ -139,6 +161,7 @@ type nodeJSON struct { Like *strJSON `json:"like"` // if-then-else + IfThenElse *ifThenElseJSON `json:"if-then-else"` // Set // Record // Any other key @@ -216,7 +239,10 @@ func (j nodeJSON) ToNode() (Node, error) { case j.Like != nil: return j.Like.ToNode(Node.Like) - // if-then-else + // if-then-else + case j.IfThenElse != nil: + return j.IfThenElse.ToNode() + // Set // Record // Any other key From 59a1a9b26dec9edc47216580bd3f88c2c58a9673 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 16:40:07 -0600 Subject: [PATCH 022/216] x/exp/ast: add set, record JSON Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/json.go | 43 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index b5a95668..4c2b4256 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -121,6 +121,34 @@ func (j ifThenElseJSON) ToNode() (Node, error) { return If(if_, then, else_), nil } +type jsonSet []nodeJSON + +func (j jsonSet) ToNode() (Node, error) { + var nodes []Node + for _, jj := range j { + n, err := jj.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in set: %w", err) + } + nodes = append(nodes, n) + } + return SetNodes(nodes), nil +} + +type jsonRecord map[string]nodeJSON + +func (j jsonRecord) ToNode() (Node, error) { + nodes := map[types.String]Node{} + for k, v := range j { + n, err := v.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in record: %w", err) + } + nodes[types.String(k)] = n + } + return RecordNodes(nodes), nil +} + type nodeJSON struct { // Value @@ -162,8 +190,13 @@ type nodeJSON struct { // if-then-else IfThenElse *ifThenElseJSON `json:"if-then-else"` + // Set + Set jsonSet `json:"Set"` + // Record + Record jsonRecord `json:"Record"` + // Any other key } @@ -243,8 +276,14 @@ func (j nodeJSON) ToNode() (Node, error) { case j.IfThenElse != nil: return j.IfThenElse.ToNode() - // Set - // Record + // Set + case j.Set != nil: + return j.Set.ToNode() + + // Record + case j.Record != nil: + return j.Record.ToNode() + // Any other key } From 891ac2057e3f5518c5748d518b5623582ffe5ce3 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 17:12:46 -0600 Subject: [PATCH 023/216] x/exp/ast: add extension support to JSON Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/json.go | 120 ++++++++++++++++++++++++++++++++++++++---- x/exp/ast/node.go | 5 ++ x/exp/ast/operator.go | 16 ++++++ x/exp/ast/value.go | 2 +- 4 files changed, 132 insertions(+), 11 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index 4c2b4256..e579cf5c 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -121,9 +121,9 @@ func (j ifThenElseJSON) ToNode() (Node, error) { return If(if_, then, else_), nil } -type jsonSet []nodeJSON +type arrayJSON []nodeJSON -func (j jsonSet) ToNode() (Node, error) { +func (j arrayJSON) ToNode() (Node, error) { var nodes []Node for _, jj := range j { n, err := jj.ToNode() @@ -135,9 +135,73 @@ func (j jsonSet) ToNode() (Node, error) { return SetNodes(nodes), nil } -type jsonRecord map[string]nodeJSON +func (j arrayJSON) ToExt1(f func(Node) Node) (Node, error) { + if len(j) != 1 { + return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) + } + arg, err := j[0].ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in extension: %w", err) + } + return f(arg), nil +} + +func (j arrayJSON) ToDecimalNode() (Node, error) { + if len(j) != 1 { + return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) + } + arg, err := j[0].ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in extension: %w", err) + } + s, ok := arg.value.(types.String) + if !ok { + return Node{}, fmt.Errorf("unexpected type for decimal") + } + v, err := types.ParseDecimal(string(s)) + if err != nil { + return Node{}, fmt.Errorf("error parsing decimal: %w", err) + } + return Decimal(v), nil +} + +func (j arrayJSON) ToIPAddrNode() (Node, error) { + if len(j) != 1 { + return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) + } + arg, err := j[0].ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in extension: %w", err) + } + s, ok := arg.value.(types.String) + if !ok { + return Node{}, fmt.Errorf("unexpected type for ipaddr") + } + v, err := types.ParseIPAddr(string(s)) + if err != nil { + return Node{}, fmt.Errorf("error parsing ipaddr: %w", err) + } + return IPAddr(v), nil +} + +func (j arrayJSON) ToExt2(f func(Node, Node) Node) (Node, error) { + if len(j) != 2 { + return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) + } + left, err := j[0].ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in argument 0: %w", err) + } + right, err := j[1].ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in argument 1: %w", err) + } + return f(left, right), nil +} + +type recordJSON map[string]nodeJSON -func (j jsonRecord) ToNode() (Node, error) { +func (j recordJSON) ToNode() (Node, error) { nodes := map[types.String]Node{} for k, v := range j { n, err := v.ToNode() @@ -192,13 +256,25 @@ type nodeJSON struct { IfThenElse *ifThenElseJSON `json:"if-then-else"` // Set - Set jsonSet `json:"Set"` + Set arrayJSON `json:"Set"` // Record - Record jsonRecord `json:"Record"` - - // Any other key - + Record recordJSON `json:"Record"` + + // Any other function: decimal, ip + Decimal arrayJSON `json:"decimal"` + IP arrayJSON `json:"ip"` + + // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange + LessThanExt arrayJSON `json:"lessThan"` + LessThanOrEqualExt arrayJSON `json:"lessThanOrEqual"` + GreaterThanExt arrayJSON `json:"greaterThan"` + GreaterThanOrEqualExt arrayJSON `json:"greaterThanOrEqual"` + IsIpv4Ext arrayJSON `json:"isIpv4"` + IsIpv6Ext arrayJSON `json:"isIpv6"` + IsLoopbackExt arrayJSON `json:"isLoopback"` + IsMulticastExt arrayJSON `json:"isMulticast"` + IsInRangeExt arrayJSON `json:"isInRange"` } func (j nodeJSON) ToNode() (Node, error) { @@ -284,7 +360,31 @@ func (j nodeJSON) ToNode() (Node, error) { case j.Record != nil: return j.Record.ToNode() - // Any other key + // Any other function: decimal, ip + case j.Decimal != nil: + return j.Decimal.ToDecimalNode() + case j.IP != nil: + return j.IP.ToIPAddrNode() + + // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange + case j.LessThanExt != nil: + return j.LessThanExt.ToExt2(Node.LessThanExt) + case j.LessThanOrEqualExt != nil: + return j.LessThanOrEqualExt.ToExt2(Node.LessThanOrEqualExt) + case j.GreaterThanExt != nil: + return j.GreaterThanExt.ToExt2(Node.GreaterThanExt) + case j.GreaterThanOrEqualExt != nil: + return j.GreaterThanOrEqualExt.ToExt2(Node.GreaterThanOrEqualExt) + case j.IsIpv4Ext != nil: + return j.IsIpv4Ext.ToExt1(Node.IsIpv4) + case j.IsIpv6Ext != nil: + return j.IsIpv6Ext.ToExt1(Node.IsIpv6) + case j.IsLoopbackExt != nil: + return j.IsLoopbackExt.ToExt1(Node.IsLoopback) + case j.IsMulticastExt != nil: + return j.IsMulticastExt.ToExt1(Node.IsMulticast) + case j.IsInRangeExt != nil: + return j.IsInRangeExt.ToExt2(Node.IsInRange) } return Node{}, fmt.Errorf("unknown node") diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index fce418a3..e680b117 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -23,6 +23,7 @@ const ( nodeTypeIf nodeTypeIn nodeTypeIpAddr + nodeTypeDecimal nodeTypeIs nodeTypeIsInRange nodeTypeIsIpv4 @@ -43,6 +44,10 @@ const ( nodeTypeSub nodeTypeString nodeTypeVariable + nodeTypeLessExt + nodeTypeLessEqualExt + nodeTypeGreaterExt + nodeTypeGreaterEqualExt ) type Node struct { diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index 6e54efde..898fe2f4 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -33,6 +33,22 @@ func (lhs Node) GreaterThanOrEqual(rhs Node) Node { return newOpNode(nodeTypeGreaterEqual, lhs, rhs) } +func (lhs Node) LessThanExt(rhs Node) Node { + return newOpNode(nodeTypeLessExt, lhs, rhs) +} + +func (lhs Node) LessThanOrEqualExt(rhs Node) Node { + return newOpNode(nodeTypeLessEqualExt, lhs, rhs) +} + +func (lhs Node) GreaterThanExt(rhs Node) Node { + return newOpNode(nodeTypeGreaterExt, lhs, rhs) +} + +func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { + return newOpNode(nodeTypeGreaterEqualExt, lhs, rhs) +} + func (lhs Node) Like(patt string) Node { return newOpNode(nodeTypeLike, lhs, String(types.String(patt))) } diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index ed3d778b..0b046ac2 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -95,7 +95,7 @@ func Entity(e types.EntityUID) Node { } func Decimal(d types.Decimal) Node { - return newValueNode(nodeTypeEntity, d) + return newValueNode(nodeTypeDecimal, d) } func IPAddr(i types.IPAddr) Node { From 8b667fba2a3a4eab61a29436f76da259cbaedead Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 17:27:10 -0600 Subject: [PATCH 024/216] x/exp/ast: add json parsing of value types Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/json.go | 44 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index e579cf5c..e3e6788d 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -1,6 +1,7 @@ package ast import ( + "bytes" "encoding/json" "fmt" @@ -216,7 +217,7 @@ func (j recordJSON) ToNode() (Node, error) { type nodeJSON struct { // Value - Value *string `json:"Value"` // could be any + Value *json.RawMessage `json:"Value"` // could be any // Var Var *string `json:"Var"` @@ -277,11 +278,50 @@ type nodeJSON struct { IsInRangeExt arrayJSON `json:"isInRange"` } +var ( // TODO: de-dupe from types? + errJSONDecode = fmt.Errorf("error decoding json") + errJSONLongOutOfRange = fmt.Errorf("long out of range") + errJSONUnsupportedType = fmt.Errorf("unsupported type") +) + +func parseRawMessage(j *json.RawMessage) (Node, error) { + // TODO: de-dupe from types? though it's not 100% compat, because of extensions :( + // TODO: make this faster if it matters + { + var res types.EntityUID + ptr := &res + if err := ptr.UnmarshalJSON(*j); err == nil { + return Entity(res), nil + } + } + + var res interface{} + dec := json.NewDecoder(bytes.NewBuffer(*j)) + dec.UseNumber() + if err := dec.Decode(&res); err != nil { + return Node{}, fmt.Errorf("%w: %w", errJSONDecode, err) + } + switch vv := res.(type) { + case string: + return String(types.String(vv)), nil + case bool: + return Boolean(types.Boolean(vv)), nil + case json.Number: + l, err := vv.Int64() + if err != nil { + return Node{}, fmt.Errorf("%w: %w", errJSONLongOutOfRange, err) + } + return Long(types.Long(l)), nil + } + return Node{}, errJSONUnsupportedType + +} + func (j nodeJSON) ToNode() (Node, error) { switch { // Value case j.Value != nil: - return String(types.String(*j.Value)), nil + return parseRawMessage(j.Value) // Var case j.Var != nil: From 67855de1d6b4231fc762f77f780a50e163908097 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 30 Jul 2024 17:30:19 -0600 Subject: [PATCH 025/216] x/exp/ast: add TODO note Addresses IDX-48 Signed-off-by: philhassey --- x/exp/ast/value.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index 0b046ac2..46bf7134 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -59,7 +59,7 @@ func Record(r types.Record) Node { for k, v := range r { recordNodes[types.String(k)] = valueToNode(v) } - return RecordNodes(recordNodes) + return RecordNodes(recordNodes) // TODO: maybe inline this to avoid the double conversion } // RecordNodes allows for a complex record definition with values potentially From 5013489cde94d8c52e523eb4a3c7441e003af143 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 31 Jul 2024 12:23:40 -0600 Subject: [PATCH 026/216] x/exp/ast: improve clarity of node types Addresses IDX-55 Signed-off-by: philhassey --- x/exp/ast/json.go | 39 +++++++++++++++++++++------------------ x/exp/ast/json_test.go | 19 ++++++++++++++++--- x/exp/ast/node.go | 3 +++ x/exp/ast/operator.go | 8 ++++++-- x/exp/ast/policy.go | 4 ++-- x/exp/ast/scope.go | 14 ++++++++++++-- 6 files changed, 60 insertions(+), 27 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index e3e6788d..eb19147e 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -9,46 +9,49 @@ import ( ) type policyJSON struct { + Annotations map[string]string `json:"annotations,omitempty"` Effect string `json:"effect"` - Annotations map[string]string `json:"annotations"` Principal scopeJSON `json:"principal"` Action scopeJSON `json:"action"` Resource scopeJSON `json:"resource"` - Conditions []conditionJSON `json:"conditions"` + Conditions []conditionJSON `json:"conditions,omitempty"` +} + +type inJSON struct { + Entity types.EntityUID `json:"entity"` } type scopeJSON struct { Op string `json:"op"` - Entity types.EntityUID `json:"entity"` - Entities []types.EntityUID `json:"entities"` - EntityType string `json:"entity_type"` - In *struct { - Entity types.EntityUID `json:"entity"` - } `json:"in"` + Entity *types.EntityUID `json:"entity,omitempty"` + Entities []types.EntityUID `json:"entities,omitempty"` + EntityType string `json:"entity_type,omitempty"` + In *inJSON `json:"in,omitempty"` } -func (s *scopeJSON) ToNode(n Node) (Node, error) { +func (s *scopeJSON) ToNode(variable Node) (Node, error) { switch s.Op { case "All": return True(), nil case "==": - return n.Equals(Entity(s.Entity)), nil + if s.Entity == nil { + return Node{}, fmt.Errorf("missing entity") + } + return variable.Equals(Entity(*s.Entity)), nil case "in": - var zero types.EntityUID - if s.Entity != zero { - return n.In(Entity(s.Entity)), nil // TODO: review shape, maybe .In vs .InNode + if s.Entity != nil { + return variable.In(Entity(*s.Entity)), nil // TODO: review shape, maybe .In vs .InNode } var set types.Set for _, e := range s.Entities { set = append(set, e) } - return n.In(Set(set)), nil // TODO: maybe there is an In and an InSet Node? + return variable.In(Set(set)), nil // TODO: maybe there is an In and an InSet Node? case "is": - isNode := n.Is(String(types.String(s.EntityType))) // TODO: hmmm, I'm not sure can this be Stronger-typed? if s.In == nil { - return isNode, nil + return variable.Is(types.String(s.EntityType)), nil // TODO: hmmm, I'm not sure can this be Stronger-typed? } - return isNode.And(n.In(Entity(s.In.Entity))), nil + return variable.IsIn(types.String(s.EntityType), Entity(s.In.Entity)), nil } return Node{}, fmt.Errorf("unknown op: %v", s.Op) } @@ -433,7 +436,7 @@ func (j nodeJSON) ToNode() (Node, error) { func (p *Policy) UnmarshalJSON(b []byte) error { var j policyJSON if err := json.Unmarshal(b, &j); err != nil { - return err + return fmt.Errorf("error unmarshalling json: %w", err) } switch j.Effect { case "permit": diff --git a/x/exp/ast/json_test.go b/x/exp/ast/json_test.go index fb7cf53e..bf2b75d2 100644 --- a/x/exp/ast/json_test.go +++ b/x/exp/ast/json_test.go @@ -34,15 +34,24 @@ func TestUnmarshalJSON(t *testing.T) { "effect": "permit", "principal": { "op": "==", - "entity": { "type": "User", "id": "12UA45" } + "entity": { + "type": "User", + "id": "12UA45" + } }, "action": { "op": "==", - "entity": { "type": "Action", "id": "view" } + "entity": { + "type": "Action", + "id": "view" + } }, "resource": { "op": "in", - "entity": { "type": "Folder", "id": "abc" } + "entity": { + "type": "Folder", + "id": "abc" + } }, "conditions": [ { @@ -89,6 +98,10 @@ func TestUnmarshalJSON(t *testing.T) { if !reflect.DeepEqual(&p, tt.want) { t.Errorf("policy mismatch: got: %+v want: %+v", p, *tt.want) } + + // b, err := json.MarshalIndent(&p, "", " ") + // testutil.OK(t, err) + // testutil.Equals(t, string(b), tt.input) }) } diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index e680b117..22bea5d7 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -48,6 +48,9 @@ const ( nodeTypeLessEqualExt nodeTypeGreaterExt nodeTypeGreaterEqualExt + nodeTypeWhen + nodeTypeUnless + nodeTypeIsIn ) type Node struct { diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index 898fe2f4..b20be4cb 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -109,8 +109,12 @@ func (lhs Node) In(rhs Node) Node { return newOpNode(nodeTypeIn, lhs, rhs) } -func (lhs Node) Is(rhs Node) Node { - return newOpNode(nodeTypeIs, lhs, rhs) +func (lhs Node) Is(entityType types.String) Node { + return newOpNode(nodeTypeIs, lhs, String(entityType)) +} + +func (lhs Node) IsIn(entityType types.String, rhs Node) Node { + return newOpNode(nodeTypeIsIn, lhs, String(entityType), rhs) } func (lhs Node) Contains(rhs Node) Node { diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index 0e8dc49b..dab6048d 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -18,12 +18,12 @@ func Forbid() *Policy { } func (p *Policy) When(node Node) *Policy { - p.conditions = append(p.conditions, node) + p.conditions = append(p.conditions, Node{nodeType: nodeTypeUnless, args: []Node{node}}) return p } func (p *Policy) Unless(node Node) *Policy { - p.conditions = append(p.conditions, Not(node)) + p.conditions = append(p.conditions, Node{nodeType: nodeTypeWhen, args: []Node{node}}) return p } diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index 27d0b94e..eb559481 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -13,7 +13,12 @@ func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { } func (p *Policy) PrincipalIs(entityType types.String) *Policy { - p.principal = Principal().Is(EntityType(entityType)) + p.principal = Principal().Is(entityType) + return p +} + +func (p *Policy) PrincipalIsIn(entityType types.String, entity types.EntityUID) *Policy { + p.principal = Principal().IsIn(entityType, Entity(entity)) return p } @@ -46,6 +51,11 @@ func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { } func (p *Policy) ResourceIs(entityType types.String) *Policy { - p.resource = Resource().Is(EntityType(entityType)) + p.resource = Resource().Is(entityType) + return p +} + +func (p *Policy) ResourceIsIn(entityType types.String, entity types.EntityUID) *Policy { + p.resource = Resource().IsIn(entityType, Entity(entity)) return p } From fd86bd40c330fec84c9c8cc5bbccf1966b1a9d79 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 31 Jul 2024 11:38:08 -0700 Subject: [PATCH 027/216] cedar-go/x/exp/ast: Add Cedar text parsing for the scope portion of the policy Signed-off-by: philhassey --- x/exp/ast/ast_test.go | 2 +- x/exp/ast/parser.go | 329 +++++++++++++++ x/exp/ast/parser_test.go | 137 ++++++ x/exp/ast/tokenize.go | 705 +++++++++++++++++++++++++++++++ x/exp/ast/tokenize_mocks_test.go | 74 ++++ x/exp/ast/tokenize_test.go | 554 ++++++++++++++++++++++++ 6 files changed, 1800 insertions(+), 1 deletion(-) create mode 100644 x/exp/ast/parser.go create mode 100644 x/exp/ast/parser_test.go create mode 100644 x/exp/ast/tokenize.go create mode 100644 x/exp/ast/tokenize_mocks_test.go create mode 100644 x/exp/ast/tokenize_test.go diff --git a/x/exp/ast/ast_test.go b/x/exp/ast/ast_test.go index d9f9faa5..9aadbf1a 100644 --- a/x/exp/ast/ast_test.go +++ b/x/exp/ast/ast_test.go @@ -25,7 +25,7 @@ func TestAst(t *testing.T) { // unless { false }; _ = ast.Annotation("example", "one"). Permit(). - PrincipalEq(johnny). + PrincipalIsIn("User", johnny). ActionIn(sow, cast). When(ast.True()). Unless(ast.False()) diff --git a/x/exp/ast/parser.go b/x/exp/ast/parser.go new file mode 100644 index 00000000..0996b798 --- /dev/null +++ b/x/exp/ast/parser.go @@ -0,0 +1,329 @@ +package ast + +import ( + "fmt" + + "github.com/cedar-policy/cedar-go/types" +) + +func PolicyFromCedar(p *parser) (*Policy, error) { + annotations, err := p.annotations() + if err != nil { + return nil, err + } + + policy, err := p.effect(&annotations) + if err != nil { + return nil, err + } + + if err = p.exact("("); err != nil { + return nil, err + } + if err = p.principal(policy); err != nil { + return nil, err + } + if err = p.exact(","); err != nil { + return nil, err + } + if err = p.action(policy); err != nil { + return nil, err + } + if err = p.exact(","); err != nil { + return nil, err + } + if err = p.resource(policy); err != nil { + return nil, err + } + if err = p.exact(")"); err != nil { + return nil, err + } + // if res.Conditions, err = p.conditions(); err != nil { + // return res, err + // } + if err = p.exact(";"); err != nil { + return nil, err + } + + return policy, nil +} + +type parser struct { + tokens []Token + pos int +} + +func newParser(tokens []Token) parser { + return parser{tokens: tokens, pos: 0} +} + +func (p *parser) advance() Token { + t := p.peek() + if p.pos < len(p.tokens)-1 { + p.pos++ + } + return t +} + +func (p *parser) peek() Token { + return p.tokens[p.pos] +} + +func (p *parser) exact(tok string) error { + t := p.advance() + if t.Text != tok { + return p.errorf("exact got %v want %v", t.Text, tok) + } + return nil +} + +func (p *parser) errorf(s string, args ...interface{}) error { + var t Token + if p.pos < len(p.tokens) { + t = p.tokens[p.pos] + } + err := fmt.Errorf(s, args...) + return fmt.Errorf("parse error at %v %q: %w", t.Pos, t.Text, err) +} + +func (p *parser) annotations() (Annotations, error) { + var res Annotations + for p.peek().Text == "@" { + p.advance() + err := p.annotation(&res) + if err != nil { + return res, err + } + } + return res, nil + +} + +func (p *parser) annotation(a *Annotations) error { + var err error + t := p.advance() + if !t.isIdent() { + return p.errorf("expected ident") + } + name := types.String(t.Text) + if err = p.exact("("); err != nil { + return err + } + t = p.advance() + if !t.isString() { + return p.errorf("expected string") + } + value, err := t.stringValue() + if err != nil { + return err + } + if err = p.exact(")"); err != nil { + return err + } + + a.Annotation(name, types.String(value)) + return nil +} + +func (p *parser) effect(a *Annotations) (*Policy, error) { + next := p.advance() + if next.Text == "permit" { + return a.Permit(), nil + } else if next.Text == "forbid" { + return a.Forbid(), nil + } + + return nil, p.errorf("unexpected effect: %v", next.Text) +} + +func (p *parser) principal(policy *Policy) error { + if err := p.exact("principal"); err != nil { + return err + } + switch p.peek().Text { + case "==": + p.advance() + entity, err := p.entity() + if err != nil { + return err + } + policy.PrincipalEq(entity) + return nil + case "is": + p.advance() + path, err := p.path() + if err != nil { + return err + } + if p.peek().Text == "in" { + p.advance() + entity, err := p.entity() + if err != nil { + return err + } + policy.PrincipalIsIn(path, entity) + return nil + } + + policy.PrincipalIs(path) + return nil + case "in": + p.advance() + entity, err := p.entity() + if err != nil { + return err + } + policy.PrincipalIn(entity) + return nil + } + + return nil +} + +func (p *parser) entity() (types.EntityUID, error) { + var res types.EntityUID + var err error + t := p.advance() + if !t.isIdent() { + return res, p.errorf("expected ident") + } + res.Type = t.Text + for { + if err := p.exact("::"); err != nil { + return res, err + } + t := p.advance() + switch { + case t.isIdent(): + res.Type = fmt.Sprintf("%v::%v", res.Type, t.Text) + case t.isString(): + res.ID, err = t.stringValue() + if err != nil { + return res, err + } + return res, nil + default: + return res, p.errorf("unexpected token") + } + } +} + +func (p *parser) path() (types.String, error) { + var res types.String + t := p.advance() + if !t.isIdent() { + return res, p.errorf("expected ident") + } + res = types.String(t.Text) + for { + if p.peek().Text != "::" { + return res, nil + } + p.advance() + t := p.advance() + switch { + case t.isIdent(): + res = types.String(fmt.Sprintf("%v::%v", res, t.Text)) + default: + return res, p.errorf("unexpected token") + } + } +} + +func (p *parser) action(policy *Policy) error { + if err := p.exact("action"); err != nil { + return err + } + switch p.peek().Text { + case "==": + p.advance() + entity, err := p.entity() + if err != nil { + return err + } + policy.ActionEq(entity) + return nil + case "in": + p.advance() + if p.peek().Text == "[" { + p.advance() + entities, err := p.entlist() + if err != nil { + return err + } + policy.ActionIn(entities...) + p.advance() // entlist guarantees "]" + return nil + } else { + entity, err := p.entity() + if err != nil { + return err + } + policy.ActionIn(entity) + return nil + } + } + + return nil +} + +func (p *parser) entlist() ([]types.EntityUID, error) { + var res []types.EntityUID + for p.peek().Text != "]" { + if len(res) > 0 { + if err := p.exact(","); err != nil { + return nil, err + } + } + e, err := p.entity() + if err != nil { + return nil, err + } + res = append(res, e) + } + return res, nil +} + +func (p *parser) resource(policy *Policy) error { + if err := p.exact("resource"); err != nil { + return err + } + switch p.peek().Text { + case "==": + p.advance() + entity, err := p.entity() + if err != nil { + return err + } + policy.ResourceEq(entity) + return nil + case "is": + p.advance() + path, err := p.path() + if err != nil { + return err + } + if p.peek().Text == "in" { + p.advance() + entity, err := p.entity() + if err != nil { + return err + } + policy.ResourceIsIn(path, entity) + return nil + } + + policy.ResourceIs(path) + return nil + case "in": + p.advance() + entity, err := p.entity() + if err != nil { + return err + } + policy.ResourceIn(entity) + return nil + } + + return nil +} diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/parser_test.go new file mode 100644 index 00000000..8dceef58 --- /dev/null +++ b/x/exp/ast/parser_test.go @@ -0,0 +1,137 @@ +package ast + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +var johnny = types.EntityUID{ + Type: "User", + ID: "johnny", +} +var folkHeroes = types.EntityUID{ + Type: "Group", + ID: "folkHeroes", +} +var sow = types.EntityUID{ + Type: "Action", + ID: "sow", +} +var farming = types.EntityUID{ + Type: "ActionType", + ID: "farming", +} +var forestry = types.EntityUID{ + Type: "ActionType", + ID: "forestry", +} +var apple = types.EntityUID{ + Type: "Crop", + ID: "apple", +} +var malus = types.EntityUID{ + Type: "Genus", + ID: "malus", +} + +var parseTests = []struct { + Text string + ExpectedPolicy *Policy +}{ + { + `permit ( + principal, + action, + resource + );`, + Permit(), + }, + { + `forbid ( + principal, + action, + resource + );`, + Forbid(), + }, + { + `@foo("bar") + permit ( + principal, + action, + resource + );`, + Annotation("foo", "bar").Permit(), + }, + { + `@foo("bar") + @baz("quux") + permit ( + principal, + action, + resource + );`, + Annotation("foo", "bar").Annotation("baz", "quux").Permit(), + }, + { + `permit ( + principal == User::"johnny", + action == Action::"sow", + resource == Crop::"apple" + );`, + Permit().PrincipalEq(johnny).ActionEq(sow).ResourceEq(apple), + }, + { + `permit ( + principal is User, + action, + resource is Crop + );`, + Permit().PrincipalIs("User").ResourceIs("Crop"), + }, + { + `permit ( + principal is User in Group::"folkHeroes", + action, + resource is Crop in Genus::"malus" + );`, + Permit().PrincipalIsIn("User", folkHeroes).ResourceIsIn("Crop", malus), + }, + { + `permit ( + principal in Group::"folkHeroes", + action in ActionType::"farming", + resource in Genus::"malus" + );`, + Permit().PrincipalIn(folkHeroes).ActionIn(farming).ResourceIn(malus), + }, + { + `permit ( + principal, + action in [ActionType::"farming", ActionType::"forestry"], + resource + );`, + Permit().ActionIn(farming, forestry), + }, +} + +func TestParse(t *testing.T) { + t.Parallel() + for _, tt := range parseTests { + t.Run(tt.Text, func(t *testing.T) { + t.Parallel() + + tokens, err := Tokenize([]byte(tt.Text)) + testutil.OK(t, err) + + parser := newParser(tokens) + + policy, err := PolicyFromCedar(&parser) + testutil.OK(t, err) + + testutil.Equals(t, policy, tt.ExpectedPolicy) + }) + } +} diff --git a/x/exp/ast/tokenize.go b/x/exp/ast/tokenize.go new file mode 100644 index 00000000..720d7bf2 --- /dev/null +++ b/x/exp/ast/tokenize.go @@ -0,0 +1,705 @@ +package ast + +import ( + "bytes" + "fmt" + "io" + "strconv" + "strings" + "unicode" + "unicode/utf8" +) + +//go:generate moq -pkg parser -fmt goimports -out tokenize_mocks_test.go . reader + +// This type alias is for test purposes only. +type reader = io.Reader + +type TokenType int + +const ( + TokenEOF = TokenType(iota) + TokenIdent + TokenInt + TokenString + TokenOperator + TokenUnknown +) + +type Token struct { + Type TokenType + Pos Position + Text string +} + +func (t Token) isEOF() bool { + return t.Type == TokenEOF +} + +func (t Token) isIdent() bool { + return t.Type == TokenIdent +} + +func (t Token) isInt() bool { + return t.Type == TokenInt +} + +func (t Token) isString() bool { + return t.Type == TokenString +} + +func (t Token) toString() string { + return t.Text +} + +func (t Token) stringValue() (string, error) { + s := t.Text + s = strings.TrimPrefix(s, "\"") + s = strings.TrimSuffix(s, "\"") + b := []byte(s) + res, _, err := rustUnquote(b, false) + return res, err +} + +func (t Token) patternValue() (Pattern, error) { + return NewPattern(t.Text) +} + +func nextRune(b []byte, i int) (rune, int, error) { + ch, size := utf8.DecodeRune(b[i:]) + if ch == utf8.RuneError { + return ch, i, fmt.Errorf("bad unicode rune") + } + return ch, i + size, nil +} + +func parseHexEscape(b []byte, i int) (rune, int, error) { + var ch rune + var err error + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if !isHexadecimal(ch) { + return 0, i, fmt.Errorf("bad hex escape sequence") + } + res := digitVal(ch) + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if !isHexadecimal(ch) { + return 0, i, fmt.Errorf("bad hex escape sequence") + } + res = 16*res + digitVal(ch) + if res > 127 { + return 0, i, fmt.Errorf("bad hex escape sequence") + } + return rune(res), i, nil +} + +func parseUnicodeEscape(b []byte, i int) (rune, int, error) { + var ch rune + var err error + + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if ch != '{' { + return 0, i, fmt.Errorf("bad unicode escape sequence") + } + + digits := 0 + res := 0 + for { + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if ch == '}' { + break + } + if !isHexadecimal(ch) { + return 0, i, fmt.Errorf("bad unicode escape sequence") + } + res = 16*res + digitVal(ch) + digits++ + } + + if digits == 0 || digits > 6 || !utf8.ValidRune(rune(res)) { + return 0, i, fmt.Errorf("bad unicode escape sequence") + } + + return rune(res), i, nil +} + +func Unquote(s string) (string, error) { + s = strings.TrimPrefix(s, "\"") + s = strings.TrimSuffix(s, "\"") + res, _, err := rustUnquote([]byte(s), false) + return res, err +} + +func rustUnquote(b []byte, star bool) (string, []byte, error) { + var sb strings.Builder + var ch rune + var err error + i := 0 + for i < len(b) { + ch, i, err = nextRune(b, i) + if err != nil { + return "", nil, err + } + if star && ch == '*' { + i-- + return sb.String(), b[i:], nil + } + if ch != '\\' { + sb.WriteRune(ch) + continue + } + ch, i, err = nextRune(b, i) + if err != nil { + return "", nil, err + } + switch ch { + case 'n': + sb.WriteRune('\n') + case 'r': + sb.WriteRune('\r') + case 't': + sb.WriteRune('\t') + case '\\': + sb.WriteRune('\\') + case '0': + sb.WriteRune('\x00') + case '\'': + sb.WriteRune('\'') + case '"': + sb.WriteRune('"') + case 'x': + ch, i, err = parseHexEscape(b, i) + if err != nil { + return "", nil, err + } + sb.WriteRune(ch) + case 'u': + ch, i, err = parseUnicodeEscape(b, i) + if err != nil { + return "", nil, err + } + sb.WriteRune(ch) + case '*': + if !star { + return "", nil, fmt.Errorf("bad char escape") + } + sb.WriteRune('*') + default: + return "", nil, fmt.Errorf("bad char escape") + } + } + return sb.String(), b[i:], nil +} + +type PatternComponent struct { + Star bool + Chunk string +} + +type Pattern struct { + Comps []PatternComponent + Raw string +} + +func (p Pattern) String() string { + return p.Raw +} + +func NewPattern(literal string) (Pattern, error) { + rawPat := literal + + literal = strings.TrimPrefix(literal, "\"") + literal = strings.TrimSuffix(literal, "\"") + + b := []byte(literal) + + var comps []PatternComponent + for len(b) > 0 { + var comp PatternComponent + var err error + for len(b) > 0 && b[0] == '*' { + b = b[1:] + comp.Star = true + } + comp.Chunk, b, err = rustUnquote(b, true) + if err != nil { + return Pattern{}, err + } + comps = append(comps, comp) + } + return Pattern{ + Comps: comps, + Raw: rawPat, + }, nil +} + +func isHexadecimal(ch rune) bool { + return isDecimal(ch) || ('a' <= lower(ch) && lower(ch) <= 'f') +} + +// TODO: make FakeRustQuote actually accurate in all cases +func FakeRustQuote(s string) string { + return strconv.Quote(s) +} + +func (t Token) intValue() (int64, error) { + return strconv.ParseInt(t.Text, 10, 64) +} + +func Tokenize(src []byte) ([]Token, error) { + var res []Token + var s scanner + s.Init(bytes.NewBuffer(src)) + for tok := s.nextToken(); s.err == nil && tok.Type != TokenEOF; tok = s.nextToken() { + res = append(res, tok) + } + if s.err != nil { + return nil, s.err + } + res = append(res, Token{Type: TokenEOF, Pos: s.position}) + return res, nil +} + +// Position is a value that represents a source position. +// A position is valid if Line > 0. +type Position struct { + Offset int // byte offset, starting at 0 + Line int // line number, starting at 1 + Column int // column number, starting at 1 (character count per line) +} + +func (pos Position) String() string { + return fmt.Sprintf(":%d:%d", pos.Line, pos.Column) +} + +const ( + specialRuneEOF = rune(-(iota + 1)) + specialRuneBOF +) + +const bufLen = 1024 // at least utf8.UTFMax + +// A scanner implements reading of Unicode characters and tokens from an io.Reader. +type scanner struct { + // Input + src io.Reader + + // Source buffer + srcBuf [bufLen + 1]byte // +1 for sentinel for common case of s.next() + srcPos int // reading position (srcBuf index) + srcEnd int // source end (srcBuf index) + + // Source position + srcBufOffset int // byte offset of srcBuf[0] in source + line int // line count + column int // character count + lastLineLen int // length of last line in characters (for correct column reporting) + lastCharLen int // length of last character in bytes + + // Token text buffer + // Typically, token text is stored completely in srcBuf, but in general + // the token text's head may be buffered in tokBuf while the token text's + // tail is stored in srcBuf. + tokBuf bytes.Buffer // token text head that is not in srcBuf anymore + tokPos int // token text tail position (srcBuf index); valid if >= 0 + tokEnd int // token text tail end (srcBuf index) + + // One character look-ahead + ch rune // character before current srcPos + + // Last error encountered by nextToken. + err error + + // Start position of most recently scanned token; set by nextToken. + // Calling Init or Next invalidates the position (Line == 0). + // If an error is reported (via Error) and position is invalid, + // the scanner is not inside a token. Call Pos to obtain an error + // position in that case, or to obtain the position immediately + // after the most recently scanned token. + position Position +} + +// Init initializes a Scanner with a new source and returns s. +func (s *scanner) Init(src io.Reader) *scanner { + s.src = src + + // initialize source buffer + // (the first call to next() will fill it by calling src.Read) + s.srcBuf[0] = utf8.RuneSelf // sentinel + s.srcPos = 0 + s.srcEnd = 0 + + // initialize source position + s.srcBufOffset = 0 + s.line = 1 + s.column = 0 + s.lastLineLen = 0 + s.lastCharLen = 0 + + // initialize token text buffer + // (required for first call to next()). + s.tokPos = -1 + + // initialize one character look-ahead + s.ch = specialRuneBOF // no char read yet, not EOF + + // initialize public fields + s.position.Line = 0 // invalidate token position + + return s +} + +// next reads and returns the next Unicode character. It is designed such +// that only a minimal amount of work needs to be done in the common ASCII +// case (one test to check for both ASCII and end-of-buffer, and one test +// to check for newlines). +func (s *scanner) next() rune { + ch, width := rune(s.srcBuf[s.srcPos]), 1 + + if ch >= utf8.RuneSelf { + // uncommon case: not ASCII or not enough bytes + for s.srcPos+utf8.UTFMax > s.srcEnd && !utf8.FullRune(s.srcBuf[s.srcPos:s.srcEnd]) { + // not enough bytes: read some more, but first + // save away token text if any + if s.tokPos >= 0 { + s.tokBuf.Write(s.srcBuf[s.tokPos:s.srcPos]) + s.tokPos = 0 + // s.tokEnd is set by nextToken() + } + // move unread bytes to beginning of buffer + copy(s.srcBuf[0:], s.srcBuf[s.srcPos:s.srcEnd]) + s.srcBufOffset += s.srcPos + // read more bytes + // (an io.Reader must return io.EOF when it reaches + // the end of what it is reading - simply returning + // n == 0 will make this loop retry forever; but the + // error is in the reader implementation in that case) + i := s.srcEnd - s.srcPos + n, err := s.src.Read(s.srcBuf[i:bufLen]) + s.srcPos = 0 + s.srcEnd = i + n + s.srcBuf[s.srcEnd] = utf8.RuneSelf // sentinel + if err != nil { + if err != io.EOF { + s.error(err.Error()) + } + if s.srcEnd == 0 { + if s.lastCharLen > 0 { + // previous character was not EOF + s.column++ + } + s.lastCharLen = 0 + return specialRuneEOF + } + // If err == EOF, we won't be getting more + // bytes; break to avoid infinite loop. If + // err is something else, we don't know if + // we can get more bytes; thus also break. + break + } + } + // at least one byte + ch = rune(s.srcBuf[s.srcPos]) + if ch >= utf8.RuneSelf { + // uncommon case: not ASCII + ch, width = utf8.DecodeRune(s.srcBuf[s.srcPos:s.srcEnd]) + if ch == utf8.RuneError && width == 1 { + // advance for correct error position + s.srcPos += width + s.lastCharLen = width + s.column++ + s.error("invalid UTF-8 encoding") + return ch + } + } + } + + // advance + s.srcPos += width + s.lastCharLen = width + s.column++ + + // special situations + switch ch { + case 0: + // for compatibility with other tools + s.error("invalid character NUL") + case '\n': + s.line++ + s.lastLineLen = s.column + s.column = 0 + } + + return ch +} + +func (s *scanner) error(msg string) { + s.tokEnd = s.srcPos - s.lastCharLen // make sure token text is terminated + s.err = fmt.Errorf("%v: %v", s.position, msg) +} + +func isIdentRune(ch rune, first bool) bool { + return ch == '_' || unicode.IsLetter(ch) || unicode.IsDigit(ch) && !first +} + +func (s *scanner) scanIdentifier() rune { + // we know the zeroth rune is OK; start scanning at the next one + ch := s.next() + for isIdentRune(ch, false) { + ch = s.next() + } + return ch +} + +func lower(ch rune) rune { return ('a' - 'A') | ch } // returns lower-case ch iff ch is ASCII letter +func isDecimal(ch rune) bool { return '0' <= ch && ch <= '9' } + +func (s *scanner) scanInteger(ch rune) rune { + for isDecimal(ch) { + ch = s.next() + } + return ch +} + +func digitVal(ch rune) int { + switch { + case '0' <= ch && ch <= '9': + return int(ch - '0') + case 'a' <= lower(ch) && lower(ch) <= 'f': + return int(lower(ch) - 'a' + 10) + } + return 16 // larger than any legal digit val +} + +func (s *scanner) scanHexDigits(ch rune, min, max int) rune { + n := 0 + for n < max && isHexadecimal(ch) { + ch = s.next() + n++ + } + if n < min || n > max { + s.error("invalid char escape") + } + return ch +} + +func (s *scanner) scanEscape() rune { + ch := s.next() // read character after '/' + switch ch { + case 'n', 'r', 't', '\\', '0', '\'', '"', '*': + // nothing to do + ch = s.next() + case 'x': + ch = s.scanHexDigits(s.next(), 2, 2) + case 'u': + ch = s.next() + if ch != '{' { + s.error("invalid char escape") + return ch + } + ch = s.scanHexDigits(s.next(), 1, 6) + if ch != '}' { + s.error("invalid char escape") + return ch + } + ch = s.next() + default: + s.error("invalid char escape") + } + return ch +} + +func (s *scanner) scanString() (n int) { + ch := s.next() // read character after quote + for ch != '"' { + if ch == '\n' || ch < 0 { + s.error("literal not terminated") + return + } + if ch == '\\' { + ch = s.scanEscape() + } else { + ch = s.next() + } + n++ + } + return +} + +func (s *scanner) scanComment(ch rune) rune { + // ch == '/' || ch == '*' + if ch == '/' { + // line comment + ch = s.next() // read character after "//" + for ch != '\n' && ch >= 0 { + ch = s.next() + } + return ch + } + + // general comment + ch = s.next() // read character after "/*" + for { + if ch < 0 { + s.error("comment not terminated") + break + } + ch0 := ch + ch = s.next() + if ch0 == '*' && ch == '/' { + ch = s.next() + break + } + } + return ch +} + +func (s *scanner) scanOperator(ch0, ch rune) (TokenType, rune) { + switch ch0 { + case '@', '.', ',', ';', '(', ')', '{', '}', '[', ']', '+', '-', '*': + case ':': + if ch == ':' { + ch = s.next() + } + case '!', '<', '>': + if ch == '=' { + ch = s.next() + } + case '=': + if ch != '=' { + return TokenUnknown, ch + } + ch = s.next() + case '|': + if ch != '|' { + return TokenUnknown, ch + } + ch = s.next() + case '&': + if ch != '&' { + return TokenUnknown, ch + } + ch = s.next() + default: + return TokenUnknown, ch + } + return TokenOperator, ch +} + +func isWhitespace(c rune) bool { + switch c { + case '\t', '\n', '\r', ' ': + return true + default: + return false + } +} + +// nextToken reads the next token or Unicode character from source and returns +// it. It returns specialRuneEOF at the end of the source. It reports scanner +// errors (read and token errors) by calling s.Error, if not nil; otherwise it +// prints an error message to os.Stderr. +func (s *scanner) nextToken() Token { + if s.ch == specialRuneBOF { + s.ch = s.next() + } + + ch := s.ch + + // reset token text position + s.tokPos = -1 + s.position.Line = 0 + +redo: + // skip white space + for isWhitespace(ch) { + ch = s.next() + } + + // start collecting token text + s.tokBuf.Reset() + s.tokPos = s.srcPos - s.lastCharLen + + // set token position + s.position.Offset = s.srcBufOffset + s.tokPos + if s.column > 0 { + // common case: last character was not a '\n' + s.position.Line = s.line + s.position.Column = s.column + } else { + // last character was a '\n' + // (we cannot be at the beginning of the source + // since we have called next() at least once) + s.position.Line = s.line - 1 + s.position.Column = s.lastLineLen + } + + // determine token value + var tt TokenType + switch { + case ch == specialRuneEOF: + tt = TokenEOF + case isIdentRune(ch, true): + ch = s.scanIdentifier() + tt = TokenIdent + case isDecimal(ch): + ch = s.scanInteger(ch) + tt = TokenInt + case ch == '"': + s.scanString() + ch = s.next() + tt = TokenString + case ch == '/': + ch0 := ch + ch = s.next() + if ch == '/' || ch == '*' { + s.tokPos = -1 // don't collect token text + ch = s.scanComment(ch) + goto redo + } + tt, ch = s.scanOperator(ch0, ch) + default: + tt, ch = s.scanOperator(ch, s.next()) + } + + // end of token text + s.tokEnd = s.srcPos - s.lastCharLen + s.ch = ch + + return Token{ + Type: tt, + Pos: s.position, + Text: s.tokenText(), + } +} + +// tokenText returns the string corresponding to the most recently scanned token. +// Valid after calling nextToken and in calls of Scanner.Error. +func (s *scanner) tokenText() string { + if s.tokPos < 0 { + // no token text + return "" + } + + if s.tokBuf.Len() == 0 { + // common case: the entire token text is still in srcBuf + return string(s.srcBuf[s.tokPos:s.tokEnd]) + } + + // part of the token text was saved in tokBuf: save the rest in + // tokBuf as well and return its content + s.tokBuf.Write(s.srcBuf[s.tokPos:s.tokEnd]) + s.tokPos = s.tokEnd // ensure idempotency of TokenText() call + return s.tokBuf.String() +} diff --git a/x/exp/ast/tokenize_mocks_test.go b/x/exp/ast/tokenize_mocks_test.go new file mode 100644 index 00000000..21d98b9e --- /dev/null +++ b/x/exp/ast/tokenize_mocks_test.go @@ -0,0 +1,74 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package ast + +import ( + "sync" +) + +// Ensure, that readerMock does implement reader. +// If this is not the case, regenerate this file with moq. +var _ reader = &readerMock{} + +// readerMock is a mock implementation of reader. +// +// func TestSomethingThatUsesreader(t *testing.T) { +// +// // make and configure a mocked reader +// mockedreader := &readerMock{ +// ReadFunc: func(p []byte) (int, error) { +// panic("mock out the Read method") +// }, +// } +// +// // use mockedreader in code that requires reader +// // and then make assertions. +// +// } +type readerMock struct { + // ReadFunc mocks the Read method. + ReadFunc func(p []byte) (int, error) + + // calls tracks calls to the methods. + calls struct { + // Read holds details about calls to the Read method. + Read []struct { + // P is the p argument value. + P []byte + } + } + lockRead sync.RWMutex +} + +// Read calls ReadFunc. +func (mock *readerMock) Read(p []byte) (int, error) { + if mock.ReadFunc == nil { + panic("readerMock.ReadFunc: method is nil but reader.Read was just called") + } + callInfo := struct { + P []byte + }{ + P: p, + } + mock.lockRead.Lock() + mock.calls.Read = append(mock.calls.Read, callInfo) + mock.lockRead.Unlock() + return mock.ReadFunc(p) +} + +// ReadCalls gets all the calls that were made to Read. +// Check the length with: +// +// len(mockedreader.ReadCalls()) +func (mock *readerMock) ReadCalls() []struct { + P []byte +} { + var calls []struct { + P []byte + } + mock.lockRead.RLock() + calls = mock.calls.Read + mock.lockRead.RUnlock() + return calls +} diff --git a/x/exp/ast/tokenize_test.go b/x/exp/ast/tokenize_test.go new file mode 100644 index 00000000..cb9292f3 --- /dev/null +++ b/x/exp/ast/tokenize_test.go @@ -0,0 +1,554 @@ +package ast + +import ( + "fmt" + "io" + "strings" + "testing" + "unicode/utf8" + + "github.com/cedar-policy/cedar-go/testutil" +) + +func TestTokenize(t *testing.T) { + t.Parallel() + input := ` +These are some identifiers +0 1 1234 +-1 9223372036854775807 -9223372036854775808 +"" "string" "\"\'\n\r\t\\\0" "\x123" "\u{0}\u{10fFfF}" +"*" "\*" "*\**" +@.,;(){}[]+-* +::: +!!=<<=>>= +||&& +// single line comment +/* +multiline comment +// embedded comment does nothing +*/ +'/%|&=` + want := []Token{ + {Type: TokenIdent, Text: "These", Pos: Position{Offset: 1, Line: 2, Column: 1}}, + {Type: TokenIdent, Text: "are", Pos: Position{Offset: 7, Line: 2, Column: 7}}, + {Type: TokenIdent, Text: "some", Pos: Position{Offset: 11, Line: 2, Column: 11}}, + {Type: TokenIdent, Text: "identifiers", Pos: Position{Offset: 16, Line: 2, Column: 16}}, + + {Type: TokenInt, Text: "0", Pos: Position{Offset: 28, Line: 3, Column: 1}}, + {Type: TokenInt, Text: "1", Pos: Position{Offset: 30, Line: 3, Column: 3}}, + {Type: TokenInt, Text: "1234", Pos: Position{Offset: 32, Line: 3, Column: 5}}, + + {Type: TokenOperator, Text: "-", Pos: Position{Offset: 37, Line: 4, Column: 1}}, + {Type: TokenInt, Text: "1", Pos: Position{Offset: 38, Line: 4, Column: 2}}, + {Type: TokenInt, Text: "9223372036854775807", Pos: Position{Offset: 40, Line: 4, Column: 4}}, + {Type: TokenOperator, Text: "-", Pos: Position{Offset: 60, Line: 4, Column: 24}}, + {Type: TokenInt, Text: "9223372036854775808", Pos: Position{Offset: 61, Line: 4, Column: 25}}, + + {Type: TokenString, Text: `""`, Pos: Position{Offset: 81, Line: 5, Column: 1}}, + {Type: TokenString, Text: `"string"`, Pos: Position{Offset: 84, Line: 5, Column: 4}}, + {Type: TokenString, Text: `"\"\'\n\r\t\\\0"`, Pos: Position{Offset: 93, Line: 5, Column: 13}}, + {Type: TokenString, Text: `"\x123"`, Pos: Position{Offset: 110, Line: 5, Column: 30}}, + {Type: TokenString, Text: `"\u{0}\u{10fFfF}"`, Pos: Position{Offset: 118, Line: 5, Column: 38}}, + + {Type: TokenString, Text: `"*"`, Pos: Position{Offset: 136, Line: 6, Column: 1}}, + {Type: TokenString, Text: `"\*"`, Pos: Position{Offset: 140, Line: 6, Column: 5}}, + {Type: TokenString, Text: `"*\**"`, Pos: Position{Offset: 145, Line: 6, Column: 10}}, + + {Type: TokenOperator, Text: "@", Pos: Position{Offset: 152, Line: 7, Column: 1}}, + {Type: TokenOperator, Text: ".", Pos: Position{Offset: 153, Line: 7, Column: 2}}, + {Type: TokenOperator, Text: ",", Pos: Position{Offset: 154, Line: 7, Column: 3}}, + {Type: TokenOperator, Text: ";", Pos: Position{Offset: 155, Line: 7, Column: 4}}, + {Type: TokenOperator, Text: "(", Pos: Position{Offset: 156, Line: 7, Column: 5}}, + {Type: TokenOperator, Text: ")", Pos: Position{Offset: 157, Line: 7, Column: 6}}, + {Type: TokenOperator, Text: "{", Pos: Position{Offset: 158, Line: 7, Column: 7}}, + {Type: TokenOperator, Text: "}", Pos: Position{Offset: 159, Line: 7, Column: 8}}, + {Type: TokenOperator, Text: "[", Pos: Position{Offset: 160, Line: 7, Column: 9}}, + {Type: TokenOperator, Text: "]", Pos: Position{Offset: 161, Line: 7, Column: 10}}, + {Type: TokenOperator, Text: "+", Pos: Position{Offset: 162, Line: 7, Column: 11}}, + {Type: TokenOperator, Text: "-", Pos: Position{Offset: 163, Line: 7, Column: 12}}, + {Type: TokenOperator, Text: "*", Pos: Position{Offset: 164, Line: 7, Column: 13}}, + + {Type: TokenOperator, Text: "::", Pos: Position{Offset: 166, Line: 8, Column: 1}}, + {Type: TokenOperator, Text: ":", Pos: Position{Offset: 168, Line: 8, Column: 3}}, + + {Type: TokenOperator, Text: "!", Pos: Position{Offset: 170, Line: 9, Column: 1}}, + {Type: TokenOperator, Text: "!=", Pos: Position{Offset: 171, Line: 9, Column: 2}}, + {Type: TokenOperator, Text: "<", Pos: Position{Offset: 173, Line: 9, Column: 4}}, + {Type: TokenOperator, Text: "<=", Pos: Position{Offset: 174, Line: 9, Column: 5}}, + {Type: TokenOperator, Text: ">", Pos: Position{Offset: 176, Line: 9, Column: 7}}, + {Type: TokenOperator, Text: ">=", Pos: Position{Offset: 177, Line: 9, Column: 8}}, + + {Type: TokenOperator, Text: "||", Pos: Position{Offset: 180, Line: 10, Column: 1}}, + {Type: TokenOperator, Text: "&&", Pos: Position{Offset: 182, Line: 10, Column: 3}}, + + {Type: TokenUnknown, Text: "'", Pos: Position{Offset: 265, Line: 16, Column: 1}}, + {Type: TokenUnknown, Text: "/", Pos: Position{Offset: 266, Line: 16, Column: 2}}, + {Type: TokenUnknown, Text: "%", Pos: Position{Offset: 267, Line: 16, Column: 3}}, + {Type: TokenUnknown, Text: "|", Pos: Position{Offset: 268, Line: 16, Column: 4}}, + {Type: TokenUnknown, Text: "&", Pos: Position{Offset: 269, Line: 16, Column: 5}}, + {Type: TokenUnknown, Text: "=", Pos: Position{Offset: 270, Line: 16, Column: 6}}, + + {Type: TokenEOF, Text: "", Pos: Position{Offset: 271, Line: 16, Column: 7}}, + } + got, err := Tokenize([]byte(input)) + testutil.OK(t, err) + testutil.Equals(t, got, want) +} + +func TestTokenizeErrors(t *testing.T) { + t.Parallel() + tests := []struct { + input string + wantErrStr string + wantErrPos Position + }{ + {"okay\x00not okay", "invalid character NUL", Position{Line: 1, Column: 1}}, + {`okay /* + stuff + `, "comment not terminated", Position{Line: 1, Column: 6}}, + {`okay " + " foo bar`, "literal not terminated", Position{Line: 1, Column: 6}}, + {`"okay" "\a"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\b"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\f"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\v"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\1"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\x"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\x1"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\ubadf"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\U0000badf"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\u{}"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\u{0000000}"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\u{z"`, "invalid char escape", Position{Line: 1, Column: 8}}, + } + for i, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("%02d", i), func(t *testing.T) { + t.Parallel() + got, gotErr := Tokenize([]byte(tt.input)) + wantErrStr := fmt.Sprintf("%v: %s", tt.wantErrPos, tt.wantErrStr) + testutil.Error(t, gotErr) + testutil.Equals(t, gotErr.Error(), wantErrStr) + testutil.Equals(t, got, nil) + }) + } +} + +func TestIntTokenValues(t *testing.T) { + t.Parallel() + tests := []struct { + input string + wantOk bool + want int64 + wantErr string + }{ + {"0", true, 0, ""}, + {"9223372036854775807", true, 9223372036854775807, ""}, + {"9223372036854775808", false, 0, `strconv.ParseInt: parsing "9223372036854775808": value out of range`}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, err := Tokenize([]byte(tt.input)) + testutil.OK(t, err) + testutil.Equals(t, len(got), 2) + testutil.Equals(t, got[0].Type, TokenInt) + gotInt, err := got[0].intValue() + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, gotInt, tt.want) + } + }) + } +} + +func TestStringTokenValues(t *testing.T) { + t.Parallel() + tests := []struct { + input string + wantOk bool + want string + wantErr string + }{ + {`""`, true, "", ""}, + {`"hello"`, true, "hello", ""}, + {`"a\n\r\t\\\0b"`, true, "a\n\r\t\\\x00b", ""}, + {`"a\"b"`, true, "a\"b", ""}, + {`"a\'b"`, true, "a'b", ""}, + + {`"a\x00b"`, true, "a\x00b", ""}, + {`"a\x7fb"`, true, "a\x7fb", ""}, + {`"a\x80b"`, false, "", "bad hex escape sequence"}, + + {`"a\u{A}b"`, true, "a\u000ab", ""}, + {`"a\u{aB}b"`, true, "a\u00abb", ""}, + {`"a\u{AbC}b"`, true, "a\u0abcb", ""}, + {`"a\u{aBcD}b"`, true, "a\uabcdb", ""}, + {`"a\u{AbCdE}b"`, true, "a\U000abcdeb", ""}, + {`"a\u{10cDeF}b"`, true, "a\U0010cdefb", ""}, + {`"a\u{ffffff}b"`, false, "", "bad unicode escape sequence"}, + {`"a\u{d7ff}b"`, true, "a\ud7ffb", ""}, + {`"a\u{d800}b"`, false, "", "bad unicode escape sequence"}, + {`"a\u{dfff}b"`, false, "", "bad unicode escape sequence"}, + {`"a\u{e000}b"`, true, "a\ue000b", ""}, + {`"a\u{10ffff}b"`, true, "a\U0010ffffb", ""}, + {`"a\u{110000}b"`, false, "", "bad unicode escape sequence"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, err := Tokenize([]byte(tt.input)) + testutil.OK(t, err) + testutil.Equals(t, len(got), 2) + testutil.Equals(t, got[0].Type, TokenString) + gotStr, err := got[0].stringValue() + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, gotStr, tt.want) + } + }) + } +} + +func TestParseUnicodeEscape(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in []byte + out rune + outN int + err func(t testing.TB, err error) + }{ + {"happy", []byte{'{', '4', '2', '}'}, 0x42, 4, testutil.OK}, + {"badRune", []byte{'{', 0x80, 0x81}, 0, 1, testutil.Error}, + {"notHex", []byte{'{', 'g'}, 0, 2, testutil.Error}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out, n, err := parseUnicodeEscape(tt.in, 0) + testutil.Equals(t, out, tt.out) + testutil.Equals(t, n, tt.outN) + tt.err(t, err) + }) + } +} + +func TestUnquote(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + out string + err func(t testing.TB, err error) + }{ + {"happy", `"test"`, `test`, testutil.OK}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out, err := Unquote(tt.in) + testutil.Equals(t, out, tt.out) + tt.err(t, err) + }) + } +} + +func TestRustUnquote(t *testing.T) { + t.Parallel() + // star == false + { + tests := []struct { + input string + wantOk bool + want string + wantErr string + }{ + {``, true, "", ""}, + {`hello`, true, "hello", ""}, + {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", ""}, + {`a\"b`, true, "a\"b", ""}, + {`a\'b`, true, "a'b", ""}, + + {`a\x00b`, true, "a\x00b", ""}, + {`a\x7fb`, true, "a\x7fb", ""}, + {`a\x80b`, false, "", "bad hex escape sequence"}, + + {string([]byte{0x80, 0x81}), false, "", "bad unicode rune"}, + {`a\u`, false, "", "bad unicode rune"}, + {`a\uz`, false, "", "bad unicode escape sequence"}, + {`a\u{}b`, false, "", "bad unicode escape sequence"}, + {`a\u{A}b`, true, "a\u000ab", ""}, + {`a\u{aB}b`, true, "a\u00abb", ""}, + {`a\u{AbC}b`, true, "a\u0abcb", ""}, + {`a\u{aBcD}b`, true, "a\uabcdb", ""}, + {`a\u{AbCdE}b`, true, "a\U000abcdeb", ""}, + {`a\u{10cDeF}b`, true, "a\U0010cdefb", ""}, + {`a\u{ffffff}b`, false, "", "bad unicode escape sequence"}, + {`a\u{0000000}b`, false, "", "bad unicode escape sequence"}, + {`a\u{d7ff}b`, true, "a\ud7ffb", ""}, + {`a\u{d800}b`, false, "", "bad unicode escape sequence"}, + {`a\u{dfff}b`, false, "", "bad unicode escape sequence"}, + {`a\u{e000}b`, true, "a\ue000b", ""}, + {`a\u{10ffff}b`, true, "a\U0010ffffb", ""}, + {`a\u{110000}b`, false, "", "bad unicode escape sequence"}, + + {`\`, false, "", "bad unicode rune"}, + {`\a`, false, "", "bad char escape"}, + {`\*`, false, "", "bad char escape"}, + {`\x`, false, "", "bad unicode rune"}, + {`\xz`, false, "", "bad hex escape sequence"}, + {`\xa`, false, "", "bad unicode rune"}, + {`\xaz`, false, "", "bad hex escape sequence"}, + {`\{`, false, "", "bad char escape"}, + {`\{z`, false, "", "bad char escape"}, + {`\{0`, false, "", "bad char escape"}, + {`\{0z`, false, "", "bad char escape"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, rem, err := rustUnquote([]byte(tt.input), false) + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + testutil.Equals(t, got, tt.want) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got, tt.want) + testutil.Equals(t, rem, []byte("")) + } + }) + } + } + + // star == true + { + tests := []struct { + input string + wantOk bool + want string + wantRem string + wantErr string + }{ + {``, true, "", "", ""}, + {`hello`, true, "hello", "", ""}, + {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", "", ""}, + {`a\"b`, true, "a\"b", "", ""}, + {`a\'b`, true, "a'b", "", ""}, + + {`a\x00b`, true, "a\x00b", "", ""}, + {`a\x7fb`, true, "a\x7fb", "", ""}, + {`a\x80b`, false, "", "", "bad hex escape sequence"}, + + {`a\u`, false, "", "", "bad unicode rune"}, + {`a\uz`, false, "", "", "bad unicode escape sequence"}, + {`a\u{}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{A}b`, true, "a\u000ab", "", ""}, + {`a\u{aB}b`, true, "a\u00abb", "", ""}, + {`a\u{AbC}b`, true, "a\u0abcb", "", ""}, + {`a\u{aBcD}b`, true, "a\uabcdb", "", ""}, + {`a\u{AbCdE}b`, true, "a\U000abcdeb", "", ""}, + {`a\u{10cDeF}b`, true, "a\U0010cdefb", "", ""}, + {`a\u{ffffff}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{0000000}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{d7ff}b`, true, "a\ud7ffb", "", ""}, + {`a\u{d800}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{dfff}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{e000}b`, true, "a\ue000b", "", ""}, + {`a\u{10ffff}b`, true, "a\U0010ffffb", "", ""}, + {`a\u{110000}b`, false, "", "", "bad unicode escape sequence"}, + + {`*`, true, "", "*", ""}, + {`*hello*how*are*you`, true, "", "*hello*how*are*you", ""}, + {`hello*how*are*you`, true, "hello", "*how*are*you", ""}, + {`\**`, true, "*", "*", ""}, + + {`\`, false, "", "", "bad unicode rune"}, + {`\a`, false, "", "", "bad char escape"}, + {`\x`, false, "", "", "bad unicode rune"}, + {`\xz`, false, "", "", "bad hex escape sequence"}, + {`\xa`, false, "", "", "bad unicode rune"}, + {`\xaz`, false, "", "", "bad hex escape sequence"}, + {`\{`, false, "", "", "bad char escape"}, + {`\{z`, false, "", "", "bad char escape"}, + {`\{0`, false, "", "", "bad char escape"}, + {`\{0z`, false, "", "", "bad char escape"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, rem, err := rustUnquote([]byte(tt.input), true) + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + testutil.Equals(t, got, tt.want) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got, tt.want) + testutil.Equals(t, string(rem), tt.wantRem) + } + }) + } + } +} + +func TestFakeRustQuote(t *testing.T) { + t.Parallel() + out := FakeRustQuote("hello") + testutil.Equals(t, out, `"hello"`) +} + +func TestPatternFromStringLiteral(t *testing.T) { + t.Parallel() + tests := []struct { + input string + wantOk bool + want []PatternComponent + wantErr string + }{ + {`""`, true, nil, ""}, + {`"a"`, true, []PatternComponent{{false, "a"}}, ""}, + {`"*"`, true, []PatternComponent{{true, ""}}, ""}, + {`"*a"`, true, []PatternComponent{{true, "a"}}, ""}, + {`"a*"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, + {`"**"`, true, []PatternComponent{{true, ""}}, ""}, + {`"**a"`, true, []PatternComponent{{true, "a"}}, ""}, + {`"a**"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, + {`"*a*"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, + {`"**a**"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, + {`"abra*ca"`, true, []PatternComponent{ + {false, "abra"}, {true, "ca"}, + }, ""}, + {`"abra**ca"`, true, []PatternComponent{ + {false, "abra"}, {true, "ca"}, + }, ""}, + {`"*abra*ca"`, true, []PatternComponent{ + {true, "abra"}, {true, "ca"}, + }, ""}, + {`"abra*ca*"`, true, []PatternComponent{ + {false, "abra"}, {true, "ca"}, {true, ""}, + }, ""}, + {`"*abra*ca*"`, true, []PatternComponent{ + {true, "abra"}, {true, "ca"}, {true, ""}, + }, ""}, + {`"*abra*ca*dabra"`, true, []PatternComponent{ + {true, "abra"}, {true, "ca"}, {true, "dabra"}, + }, ""}, + {`"*abra*c\**da\*ra"`, true, []PatternComponent{ + {true, "abra"}, {true, "c*"}, {true, "da*ra"}, + }, ""}, + {`"\u"`, false, nil, "bad unicode rune"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, err := NewPattern(tt.input) + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got.Comps, tt.want) + testutil.Equals(t, got.String(), tt.input) + } + }) + } +} + +func TestScanner(t *testing.T) { + t.Parallel() + t.Run("SrcError", func(t *testing.T) { + t.Parallel() + wantErr := fmt.Errorf("wantErr") + r := &readerMock{ + ReadFunc: func(_ []byte) (int, error) { + return 0, wantErr + }, + } + var s scanner + s.Init(r) + out := s.next() + testutil.Equals(t, out, specialRuneEOF) + }) + + t.Run("MidEmojiEOF", func(t *testing.T) { + t.Parallel() + var s scanner + var eof bool + str := []byte(string(`🐐`)) + r := &readerMock{ + ReadFunc: func(p []byte) (int, error) { + if eof { + return 0, io.EOF + } + p[0] = str[0] + eof = true + return 1, nil + }, + } + s.Init(r) + out := s.next() + testutil.Equals(t, out, utf8.RuneError) + out = s.next() + testutil.Equals(t, out, specialRuneEOF) + }) + + t.Run("NotAsciiEmoji", func(t *testing.T) { + t.Parallel() + var s scanner + s.Init(strings.NewReader(`🐐`)) + out := s.next() + testutil.Equals(t, out, '🐐') + }) + + t.Run("InvalidUTF8", func(t *testing.T) { + t.Parallel() + var s scanner + s.Init(strings.NewReader(string([]byte{0x80, 0x81}))) + out := s.next() + testutil.Equals(t, out, utf8.RuneError) + }) + + t.Run("tokenTextNone", func(t *testing.T) { + t.Parallel() + var s scanner + s.Init(strings.NewReader("")) + out := s.tokenText() + testutil.Equals(t, out, "") + }) +} + +func TestDigitVal(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in rune + out int + }{ + {"happy", '0', 0}, + {"hex", 'f', 15}, + {"sad", 'g', 16}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out := digitVal(tt.in) + testutil.Equals(t, out, tt.out) + }) + } +} From 0224d8ead26cb023c73ed1242698d7d27496d9ef Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 31 Jul 2024 11:42:00 -0700 Subject: [PATCH 028/216] cedar-go/x/exp/ast: make Cedar policy parsing private for now Signed-off-by: philhassey --- x/exp/ast/parser.go | 2 +- x/exp/ast/parser_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x/exp/ast/parser.go b/x/exp/ast/parser.go index 0996b798..d82a27c8 100644 --- a/x/exp/ast/parser.go +++ b/x/exp/ast/parser.go @@ -6,7 +6,7 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func PolicyFromCedar(p *parser) (*Policy, error) { +func policyFromCedar(p *parser) (*Policy, error) { annotations, err := p.annotations() if err != nil { return nil, err diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/parser_test.go index 8dceef58..672eb0f6 100644 --- a/x/exp/ast/parser_test.go +++ b/x/exp/ast/parser_test.go @@ -128,7 +128,7 @@ func TestParse(t *testing.T) { parser := newParser(tokens) - policy, err := PolicyFromCedar(&parser) + policy, err := policyFromCedar(&parser) testutil.OK(t, err) testutil.Equals(t, policy, tt.ExpectedPolicy) From 7520cca07292bb37bfbc020ddc390e16ecf21855 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 31 Jul 2024 11:50:10 -0700 Subject: [PATCH 029/216] cedar-go/x/exp/ast: scope tests and give them names This allows GoLand to see the individual tests and gives them more legible names Signed-off-by: philhassey --- x/exp/ast/parser_test.go | 96 ++++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 43 deletions(-) diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/parser_test.go index 672eb0f6..f8b6d9e8 100644 --- a/x/exp/ast/parser_test.go +++ b/x/exp/ast/parser_test.go @@ -36,89 +36,99 @@ var malus = types.EntityUID{ ID: "malus", } -var parseTests = []struct { - Text string - ExpectedPolicy *Policy -}{ - { - `permit ( +func TestParse(t *testing.T) { + t.Parallel() + parseTests := []struct { + Name string + Text string + ExpectedPolicy *Policy + }{ + { + "permit any scope", + `permit ( principal, action, resource );`, - Permit(), - }, - { - `forbid ( + Permit(), + }, + { + "forbid any scope", + `forbid ( principal, action, resource );`, - Forbid(), - }, - { - `@foo("bar") + Forbid(), + }, + { + "one annotation", + `@foo("bar") permit ( principal, action, resource );`, - Annotation("foo", "bar").Permit(), - }, - { - `@foo("bar") + Annotation("foo", "bar").Permit(), + }, + { + "two annotations", + `@foo("bar") @baz("quux") permit ( principal, action, resource );`, - Annotation("foo", "bar").Annotation("baz", "quux").Permit(), - }, - { - `permit ( + Annotation("foo", "bar").Annotation("baz", "quux").Permit(), + }, + { + "scope eq", + `permit ( principal == User::"johnny", action == Action::"sow", resource == Crop::"apple" );`, - Permit().PrincipalEq(johnny).ActionEq(sow).ResourceEq(apple), - }, - { - `permit ( + Permit().PrincipalEq(johnny).ActionEq(sow).ResourceEq(apple), + }, + { + "scope is", + `permit ( principal is User, action, resource is Crop );`, - Permit().PrincipalIs("User").ResourceIs("Crop"), - }, - { - `permit ( + Permit().PrincipalIs("User").ResourceIs("Crop"), + }, + { + "scope is in", + `permit ( principal is User in Group::"folkHeroes", action, resource is Crop in Genus::"malus" );`, - Permit().PrincipalIsIn("User", folkHeroes).ResourceIsIn("Crop", malus), - }, - { - `permit ( + Permit().PrincipalIsIn("User", folkHeroes).ResourceIsIn("Crop", malus), + }, + { + "scope in", + `permit ( principal in Group::"folkHeroes", action in ActionType::"farming", resource in Genus::"malus" );`, - Permit().PrincipalIn(folkHeroes).ActionIn(farming).ResourceIn(malus), - }, - { - `permit ( + Permit().PrincipalIn(folkHeroes).ActionIn(farming).ResourceIn(malus), + }, + { + "scope action in entities", + `permit ( principal, action in [ActionType::"farming", ActionType::"forestry"], resource );`, - Permit().ActionIn(farming, forestry), - }, -} + Permit().ActionIn(farming, forestry), + }, + } -func TestParse(t *testing.T) { - t.Parallel() for _, tt := range parseTests { t.Run(tt.Text, func(t *testing.T) { t.Parallel() From b7f565122291e9d53b58751742d5281c90cc7de5 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 31 Jul 2024 12:35:15 -0700 Subject: [PATCH 030/216] cedar-go/x/exp/ast: fix linting issues Signed-off-by: philhassey --- x/exp/ast/tokenize.go | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/x/exp/ast/tokenize.go b/x/exp/ast/tokenize.go index 720d7bf2..a3564985 100644 --- a/x/exp/ast/tokenize.go +++ b/x/exp/ast/tokenize.go @@ -32,26 +32,10 @@ type Token struct { Text string } -func (t Token) isEOF() bool { - return t.Type == TokenEOF -} - func (t Token) isIdent() bool { return t.Type == TokenIdent } -func (t Token) isInt() bool { - return t.Type == TokenInt -} - -func (t Token) isString() bool { - return t.Type == TokenString -} - -func (t Token) toString() string { - return t.Text -} - func (t Token) stringValue() (string, error) { s := t.Text s = strings.TrimPrefix(s, "\"") @@ -61,10 +45,6 @@ func (t Token) stringValue() (string, error) { return res, err } -func (t Token) patternValue() (Pattern, error) { - return NewPattern(t.Text) -} - func nextRune(b []byte, i int) (rune, int, error) { ch, size := utf8.DecodeRune(b[i:]) if ch == utf8.RuneError { From 51a39e413d087f9e54046de233e5db6c7a89a9bf Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 31 Jul 2024 12:37:23 -0700 Subject: [PATCH 031/216] cedar-go/x/exp/ast: fix broken build Signed-off-by: philhassey --- x/exp/ast/tokenize.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/x/exp/ast/tokenize.go b/x/exp/ast/tokenize.go index a3564985..5ecea604 100644 --- a/x/exp/ast/tokenize.go +++ b/x/exp/ast/tokenize.go @@ -36,6 +36,10 @@ func (t Token) isIdent() bool { return t.Type == TokenIdent } +func (t Token) isString() bool { + return t.Type == TokenString +} + func (t Token) stringValue() (string, error) { s := t.Text s = strings.TrimPrefix(s, "\"") From 82f145cfd595834c909160782a8e5221ce1a78d8 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 31 Jul 2024 14:56:13 -0600 Subject: [PATCH 032/216] x/exp/ast: add sketch of json marshaller using node helper facades Addresses IDX-55 Signed-off-by: philhassey --- x/exp/ast/annotation.go | 10 + x/exp/ast/ast_test.go | 2 +- x/exp/ast/json.go | 360 ----------------------------------- x/exp/ast/json_marshal.go | 79 ++++++++ x/exp/ast/json_unmarshal.go | 362 ++++++++++++++++++++++++++++++++++++ x/exp/ast/parser.go | 2 +- x/exp/ast/parser_test.go | 2 +- x/exp/ast/scope.go | 101 ++++++++-- 8 files changed, 536 insertions(+), 382 deletions(-) create mode 100644 x/exp/ast/json_marshal.go create mode 100644 x/exp/ast/json_unmarshal.go diff --git a/x/exp/ast/annotation.go b/x/exp/ast/annotation.go index 940a9f1b..ad7a2618 100644 --- a/x/exp/ast/annotation.go +++ b/x/exp/ast/annotation.go @@ -22,6 +22,16 @@ func (a *Annotations) Annotation(name, value types.String) *Annotations { return a } +type annotationNode Node + +func (n annotationNode) Key() types.String { + return n.args[0].value.(types.String) +} + +func (n annotationNode) Value() types.String { + return n.args[1].value.(types.String) +} + func (a *Annotations) Permit() *Policy { p := Permit() p.annotations = a.nodes diff --git a/x/exp/ast/ast_test.go b/x/exp/ast/ast_test.go index 9aadbf1a..d07d5745 100644 --- a/x/exp/ast/ast_test.go +++ b/x/exp/ast/ast_test.go @@ -26,7 +26,7 @@ func TestAst(t *testing.T) { _ = ast.Annotation("example", "one"). Permit(). PrincipalIsIn("User", johnny). - ActionIn(sow, cast). + ActionInSet(sow, cast). When(ast.True()). Unless(ast.False()) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index eb19147e..46863ede 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -1,9 +1,7 @@ package ast import ( - "bytes" "encoding/json" - "fmt" "github.com/cedar-policy/cedar-go/types" ) @@ -29,33 +27,6 @@ type scopeJSON struct { In *inJSON `json:"in,omitempty"` } -func (s *scopeJSON) ToNode(variable Node) (Node, error) { - switch s.Op { - case "All": - return True(), nil - case "==": - if s.Entity == nil { - return Node{}, fmt.Errorf("missing entity") - } - return variable.Equals(Entity(*s.Entity)), nil - case "in": - if s.Entity != nil { - return variable.In(Entity(*s.Entity)), nil // TODO: review shape, maybe .In vs .InNode - } - var set types.Set - for _, e := range s.Entities { - set = append(set, e) - } - return variable.In(Set(set)), nil // TODO: maybe there is an In and an InSet Node? - case "is": - if s.In == nil { - return variable.Is(types.String(s.EntityType)), nil // TODO: hmmm, I'm not sure can this be Stronger-typed? - } - return variable.IsIn(types.String(s.EntityType), Entity(s.In.Entity)), nil - } - return Node{}, fmt.Errorf("unknown op: %v", s.Op) -} - type conditionJSON struct { Kind string `json:"kind"` Body nodeJSON `json:"body"` @@ -66,157 +37,25 @@ type binaryJSON struct { Right nodeJSON `json:"right"` } -func (j binaryJSON) ToNode(f func(a, b Node) Node) (Node, error) { - left, err := j.Left.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in left: %w", err) - } - right, err := j.Right.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in right: %w", err) - } - return f(left, right), nil -} - type unaryJSON struct { Arg nodeJSON `json:"arg"` } -func (j unaryJSON) ToNode(f func(a Node) Node) (Node, error) { - arg, err := j.Arg.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in arg: %w", err) - } - return f(arg), nil -} - type strJSON struct { Left nodeJSON `json:"left"` Attr string `json:"attr"` } -func (j strJSON) ToNode(f func(a Node, k string) Node) (Node, error) { - left, err := j.Left.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in left: %w", err) - } - return f(left, j.Attr), nil -} - type ifThenElseJSON struct { If nodeJSON `json:"if"` Then nodeJSON `json:"then"` Else nodeJSON `json:"else"` } -func (j ifThenElseJSON) ToNode() (Node, error) { - if_, err := j.If.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in if: %w", err) - } - then, err := j.Then.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in then: %w", err) - } - else_, err := j.Else.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in else: %w", err) - } - return If(if_, then, else_), nil -} - type arrayJSON []nodeJSON -func (j arrayJSON) ToNode() (Node, error) { - var nodes []Node - for _, jj := range j { - n, err := jj.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in set: %w", err) - } - nodes = append(nodes, n) - } - return SetNodes(nodes), nil -} - -func (j arrayJSON) ToExt1(f func(Node) Node) (Node, error) { - if len(j) != 1 { - return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) - } - arg, err := j[0].ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in extension: %w", err) - } - return f(arg), nil -} - -func (j arrayJSON) ToDecimalNode() (Node, error) { - if len(j) != 1 { - return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) - } - arg, err := j[0].ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in extension: %w", err) - } - s, ok := arg.value.(types.String) - if !ok { - return Node{}, fmt.Errorf("unexpected type for decimal") - } - v, err := types.ParseDecimal(string(s)) - if err != nil { - return Node{}, fmt.Errorf("error parsing decimal: %w", err) - } - return Decimal(v), nil -} - -func (j arrayJSON) ToIPAddrNode() (Node, error) { - if len(j) != 1 { - return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) - } - arg, err := j[0].ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in extension: %w", err) - } - s, ok := arg.value.(types.String) - if !ok { - return Node{}, fmt.Errorf("unexpected type for ipaddr") - } - v, err := types.ParseIPAddr(string(s)) - if err != nil { - return Node{}, fmt.Errorf("error parsing ipaddr: %w", err) - } - return IPAddr(v), nil -} - -func (j arrayJSON) ToExt2(f func(Node, Node) Node) (Node, error) { - if len(j) != 2 { - return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) - } - left, err := j[0].ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in argument 0: %w", err) - } - right, err := j[1].ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in argument 1: %w", err) - } - return f(left, right), nil -} - type recordJSON map[string]nodeJSON -func (j recordJSON) ToNode() (Node, error) { - nodes := map[types.String]Node{} - for k, v := range j { - n, err := v.ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in record: %w", err) - } - nodes[types.String(k)] = n - } - return RecordNodes(nodes), nil -} - type nodeJSON struct { // Value @@ -280,202 +119,3 @@ type nodeJSON struct { IsMulticastExt arrayJSON `json:"isMulticast"` IsInRangeExt arrayJSON `json:"isInRange"` } - -var ( // TODO: de-dupe from types? - errJSONDecode = fmt.Errorf("error decoding json") - errJSONLongOutOfRange = fmt.Errorf("long out of range") - errJSONUnsupportedType = fmt.Errorf("unsupported type") -) - -func parseRawMessage(j *json.RawMessage) (Node, error) { - // TODO: de-dupe from types? though it's not 100% compat, because of extensions :( - // TODO: make this faster if it matters - { - var res types.EntityUID - ptr := &res - if err := ptr.UnmarshalJSON(*j); err == nil { - return Entity(res), nil - } - } - - var res interface{} - dec := json.NewDecoder(bytes.NewBuffer(*j)) - dec.UseNumber() - if err := dec.Decode(&res); err != nil { - return Node{}, fmt.Errorf("%w: %w", errJSONDecode, err) - } - switch vv := res.(type) { - case string: - return String(types.String(vv)), nil - case bool: - return Boolean(types.Boolean(vv)), nil - case json.Number: - l, err := vv.Int64() - if err != nil { - return Node{}, fmt.Errorf("%w: %w", errJSONLongOutOfRange, err) - } - return Long(types.Long(l)), nil - } - return Node{}, errJSONUnsupportedType - -} - -func (j nodeJSON) ToNode() (Node, error) { - switch { - // Value - case j.Value != nil: - return parseRawMessage(j.Value) - - // Var - case j.Var != nil: - switch *j.Var { - case "principal": - return Principal(), nil - case "action": - return Action(), nil - case "resource": - return Resource(), nil - case "context": - return Context(), nil - } - return Node{}, fmt.Errorf("unknown var: %v", j.Var) - - // Slot - // Unknown - - // ! or neg operators - case j.Not != nil: - return j.Not.ToNode(Not) - case j.Negate != nil: - return j.Negate.ToNode(Negate) - - // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny - case j.Equals != nil: - return j.Equals.ToNode(Node.Equals) - case j.NotEquals != nil: - return j.NotEquals.ToNode(Node.NotEquals) - case j.In != nil: - return j.In.ToNode(Node.In) - case j.LessThan != nil: - return j.LessThan.ToNode(Node.LessThan) - case j.LessThanOrEqual != nil: - return j.LessThanOrEqual.ToNode(Node.LessThanOrEqual) - case j.GreaterThan != nil: - return j.GreaterThan.ToNode(Node.GreaterThan) - case j.GreaterThanOrEqual != nil: - return j.GreaterThanOrEqual.ToNode(Node.GreaterThanOrEqual) - case j.And != nil: - return j.And.ToNode(Node.And) - case j.Or != nil: - return j.Or.ToNode(Node.Or) - case j.Plus != nil: - return j.Plus.ToNode(Node.Plus) - case j.Minus != nil: - return j.Minus.ToNode(Node.Minus) - case j.Times != nil: - return j.Times.ToNode(Node.Times) - case j.Contains != nil: - return j.Contains.ToNode(Node.Contains) - case j.ContainsAll != nil: - return j.ContainsAll.ToNode(Node.ContainsAll) - case j.ContainsAny != nil: - return j.ContainsAny.ToNode(Node.ContainsAny) - - // ., has - case j.Access != nil: - return j.Access.ToNode(Node.Access) - case j.Has != nil: - return j.Has.ToNode(Node.Has) - - // like - case j.Like != nil: - return j.Like.ToNode(Node.Like) - - // if-then-else - case j.IfThenElse != nil: - return j.IfThenElse.ToNode() - - // Set - case j.Set != nil: - return j.Set.ToNode() - - // Record - case j.Record != nil: - return j.Record.ToNode() - - // Any other function: decimal, ip - case j.Decimal != nil: - return j.Decimal.ToDecimalNode() - case j.IP != nil: - return j.IP.ToIPAddrNode() - - // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange - case j.LessThanExt != nil: - return j.LessThanExt.ToExt2(Node.LessThanExt) - case j.LessThanOrEqualExt != nil: - return j.LessThanOrEqualExt.ToExt2(Node.LessThanOrEqualExt) - case j.GreaterThanExt != nil: - return j.GreaterThanExt.ToExt2(Node.GreaterThanExt) - case j.GreaterThanOrEqualExt != nil: - return j.GreaterThanOrEqualExt.ToExt2(Node.GreaterThanOrEqualExt) - case j.IsIpv4Ext != nil: - return j.IsIpv4Ext.ToExt1(Node.IsIpv4) - case j.IsIpv6Ext != nil: - return j.IsIpv6Ext.ToExt1(Node.IsIpv6) - case j.IsLoopbackExt != nil: - return j.IsLoopbackExt.ToExt1(Node.IsLoopback) - case j.IsMulticastExt != nil: - return j.IsMulticastExt.ToExt1(Node.IsMulticast) - case j.IsInRangeExt != nil: - return j.IsInRangeExt.ToExt2(Node.IsInRange) - } - - return Node{}, fmt.Errorf("unknown node") -} - -func (p *Policy) UnmarshalJSON(b []byte) error { - var j policyJSON - if err := json.Unmarshal(b, &j); err != nil { - return fmt.Errorf("error unmarshalling json: %w", err) - } - switch j.Effect { - case "permit": - *p = *Permit() - case "forbid": - *p = *Forbid() - default: - return fmt.Errorf("unknown effect: %v", j.Effect) - } - for k, v := range j.Annotations { - p.Annotate(types.String(k), types.String(v)) - } - var err error - p.principal, err = j.Principal.ToNode(Principal()) - if err != nil { - return fmt.Errorf("error in principal: %w", err) - } - p.action, err = j.Action.ToNode(Action()) - if err != nil { - return fmt.Errorf("error in action: %w", err) - } - p.resource, err = j.Resource.ToNode(Resource()) - if err != nil { - return fmt.Errorf("error in resource: %w", err) - } - for _, c := range j.Conditions { - n, err := c.Body.ToNode() - if err != nil { - return fmt.Errorf("error in conditions: %w", err) - } - switch c.Kind { - case "when": - p.When(n) - case "unless": - p.Unless(n) - default: - return fmt.Errorf("unknown condition kind: %v", c.Kind) - } - } - - return nil -} diff --git a/x/exp/ast/json_marshal.go b/x/exp/ast/json_marshal.go new file mode 100644 index 00000000..7843ad3c --- /dev/null +++ b/x/exp/ast/json_marshal.go @@ -0,0 +1,79 @@ +package ast + +import ( + "encoding/json" + "fmt" +) + +func (s *scopeJSON) FromNode(src Node) error { + switch src.nodeType { + case nodeTypeBoolean: + s.Op = "All" + return nil + case nodeTypeEquals: + n := scopeEqNode(src) + s.Op = "==" + e := n.Entity() + s.Entity = &e + return nil + case nodeTypeIn: + n := scopeInNode(src) + s.Op = "in" + if n.IsSet() { + s.Entities = n.Set() + } else { + e := n.Entity() + s.Entity = &e + } + return nil + case nodeTypeIs: + n := scopeIsNode(src) + s.Op = "is" + s.EntityType = string(n.EntityType()) + return nil + case nodeTypeIsIn: // is in + n := scopeIsInNode(src) + s.Op = "is" + s.EntityType = string(n.EntityType()) + s.In = &inJSON{ + Entity: n.Entity(), + } + return nil + } + return fmt.Errorf("unexpected scope node: %v", src.nodeType) +} +func (j nodeJSON) FromNode(src Node) error { + // TODO: all this + return nil +} +func (p *Policy) MarshalJSON() ([]byte, error) { + var j policyJSON + j.Effect = "forbid" + if p.effect { + j.Effect = "permit" + } + if len(p.annotations) > 0 { + j.Annotations = map[string]string{} + } + for _, a := range p.annotations { + n := annotationNode(a) + j.Annotations[string(n.Key())] = string(n.Value()) + } + if err := j.Principal.FromNode(p.principal); err != nil { + return nil, fmt.Errorf("error in principal: %w", err) + } + if err := j.Action.FromNode(p.action); err != nil { + return nil, fmt.Errorf("error in action: %w", err) + } + if err := j.Resource.FromNode(p.resource); err != nil { + return nil, fmt.Errorf("error in resource: %w", err) + } + for _, c := range p.conditions { + var cond conditionJSON + if err := cond.Body.FromNode(c); err != nil { + return nil, fmt.Errorf("error in condition: %w", err) + } + j.Conditions = append(j.Conditions, cond) + } + return json.Marshal(j) +} diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go new file mode 100644 index 00000000..e313a92a --- /dev/null +++ b/x/exp/ast/json_unmarshal.go @@ -0,0 +1,362 @@ +package ast + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/cedar-policy/cedar-go/types" +) + +func (s *scopeJSON) ToNode(variable Node) (Node, error) { + switch s.Op { + case "All": + return True(), nil + case "==": + if s.Entity == nil { + return Node{}, fmt.Errorf("missing entity") + } + return variable.Equals(Entity(*s.Entity)), nil + case "in": + if s.Entity != nil { + return variable.In(Entity(*s.Entity)), nil // TODO: review shape, maybe .In vs .InNode + } + var set types.Set + for _, e := range s.Entities { + set = append(set, e) + } + return variable.In(Set(set)), nil // TODO: maybe there is an In and an InSet Node? + case "is": + if s.In == nil { + return variable.Is(types.String(s.EntityType)), nil // TODO: hmmm, I'm not sure can this be Stronger-typed? + } + return variable.IsIn(types.String(s.EntityType), Entity(s.In.Entity)), nil + } + return Node{}, fmt.Errorf("unknown op: %v", s.Op) +} + +func (j binaryJSON) ToNode(f func(a, b Node) Node) (Node, error) { + left, err := j.Left.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in left: %w", err) + } + right, err := j.Right.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in right: %w", err) + } + return f(left, right), nil +} +func (j unaryJSON) ToNode(f func(a Node) Node) (Node, error) { + arg, err := j.Arg.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in arg: %w", err) + } + return f(arg), nil +} +func (j strJSON) ToNode(f func(a Node, k string) Node) (Node, error) { + left, err := j.Left.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in left: %w", err) + } + return f(left, j.Attr), nil +} +func (j ifThenElseJSON) ToNode() (Node, error) { + if_, err := j.If.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in if: %w", err) + } + then, err := j.Then.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in then: %w", err) + } + else_, err := j.Else.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in else: %w", err) + } + return If(if_, then, else_), nil +} +func (j arrayJSON) ToNode() (Node, error) { + var nodes []Node + for _, jj := range j { + n, err := jj.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in set: %w", err) + } + nodes = append(nodes, n) + } + return SetNodes(nodes), nil +} + +func (j arrayJSON) ToExt1(f func(Node) Node) (Node, error) { + if len(j) != 1 { + return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) + } + arg, err := j[0].ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in extension: %w", err) + } + return f(arg), nil +} + +func (j arrayJSON) ToDecimalNode() (Node, error) { + if len(j) != 1 { + return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) + } + arg, err := j[0].ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in extension: %w", err) + } + s, ok := arg.value.(types.String) + if !ok { + return Node{}, fmt.Errorf("unexpected type for decimal") + } + v, err := types.ParseDecimal(string(s)) + if err != nil { + return Node{}, fmt.Errorf("error parsing decimal: %w", err) + } + return Decimal(v), nil +} + +func (j arrayJSON) ToIPAddrNode() (Node, error) { + if len(j) != 1 { + return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) + } + arg, err := j[0].ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in extension: %w", err) + } + s, ok := arg.value.(types.String) + if !ok { + return Node{}, fmt.Errorf("unexpected type for ipaddr") + } + v, err := types.ParseIPAddr(string(s)) + if err != nil { + return Node{}, fmt.Errorf("error parsing ipaddr: %w", err) + } + return IPAddr(v), nil +} + +func (j arrayJSON) ToExt2(f func(Node, Node) Node) (Node, error) { + if len(j) != 2 { + return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) + } + left, err := j[0].ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in argument 0: %w", err) + } + right, err := j[1].ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in argument 1: %w", err) + } + return f(left, right), nil +} +func (j recordJSON) ToNode() (Node, error) { + nodes := map[types.String]Node{} + for k, v := range j { + n, err := v.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in record: %w", err) + } + nodes[types.String(k)] = n + } + return RecordNodes(nodes), nil +} + +var ( // TODO: de-dupe from types? + errJSONDecode = fmt.Errorf("error decoding json") + errJSONLongOutOfRange = fmt.Errorf("long out of range") + errJSONUnsupportedType = fmt.Errorf("unsupported type") +) + +func parseRawMessage(j *json.RawMessage) (Node, error) { + // TODO: de-dupe from types? though it's not 100% compat, because of extensions :( + // TODO: make this faster if it matters + { + var res types.EntityUID + ptr := &res + if err := ptr.UnmarshalJSON(*j); err == nil { + return Entity(res), nil + } + } + + var res interface{} + dec := json.NewDecoder(bytes.NewBuffer(*j)) + dec.UseNumber() + if err := dec.Decode(&res); err != nil { + return Node{}, fmt.Errorf("%w: %w", errJSONDecode, err) + } + switch vv := res.(type) { + case string: + return String(types.String(vv)), nil + case bool: + return Boolean(types.Boolean(vv)), nil + case json.Number: + l, err := vv.Int64() + if err != nil { + return Node{}, fmt.Errorf("%w: %w", errJSONLongOutOfRange, err) + } + return Long(types.Long(l)), nil + } + return Node{}, errJSONUnsupportedType + +} + +func (j nodeJSON) ToNode() (Node, error) { + switch { + // Value + case j.Value != nil: + return parseRawMessage(j.Value) + + // Var + case j.Var != nil: + switch *j.Var { + case "principal": + return Principal(), nil + case "action": + return Action(), nil + case "resource": + return Resource(), nil + case "context": + return Context(), nil + } + return Node{}, fmt.Errorf("unknown var: %v", j.Var) + + // Slot + // Unknown + + // ! or neg operators + case j.Not != nil: + return j.Not.ToNode(Not) + case j.Negate != nil: + return j.Negate.ToNode(Negate) + + // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny + case j.Equals != nil: + return j.Equals.ToNode(Node.Equals) + case j.NotEquals != nil: + return j.NotEquals.ToNode(Node.NotEquals) + case j.In != nil: + return j.In.ToNode(Node.In) + case j.LessThan != nil: + return j.LessThan.ToNode(Node.LessThan) + case j.LessThanOrEqual != nil: + return j.LessThanOrEqual.ToNode(Node.LessThanOrEqual) + case j.GreaterThan != nil: + return j.GreaterThan.ToNode(Node.GreaterThan) + case j.GreaterThanOrEqual != nil: + return j.GreaterThanOrEqual.ToNode(Node.GreaterThanOrEqual) + case j.And != nil: + return j.And.ToNode(Node.And) + case j.Or != nil: + return j.Or.ToNode(Node.Or) + case j.Plus != nil: + return j.Plus.ToNode(Node.Plus) + case j.Minus != nil: + return j.Minus.ToNode(Node.Minus) + case j.Times != nil: + return j.Times.ToNode(Node.Times) + case j.Contains != nil: + return j.Contains.ToNode(Node.Contains) + case j.ContainsAll != nil: + return j.ContainsAll.ToNode(Node.ContainsAll) + case j.ContainsAny != nil: + return j.ContainsAny.ToNode(Node.ContainsAny) + + // ., has + case j.Access != nil: + return j.Access.ToNode(Node.Access) + case j.Has != nil: + return j.Has.ToNode(Node.Has) + + // like + case j.Like != nil: + return j.Like.ToNode(Node.Like) + + // if-then-else + case j.IfThenElse != nil: + return j.IfThenElse.ToNode() + + // Set + case j.Set != nil: + return j.Set.ToNode() + + // Record + case j.Record != nil: + return j.Record.ToNode() + + // Any other function: decimal, ip + case j.Decimal != nil: + return j.Decimal.ToDecimalNode() + case j.IP != nil: + return j.IP.ToIPAddrNode() + + // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange + case j.LessThanExt != nil: + return j.LessThanExt.ToExt2(Node.LessThanExt) + case j.LessThanOrEqualExt != nil: + return j.LessThanOrEqualExt.ToExt2(Node.LessThanOrEqualExt) + case j.GreaterThanExt != nil: + return j.GreaterThanExt.ToExt2(Node.GreaterThanExt) + case j.GreaterThanOrEqualExt != nil: + return j.GreaterThanOrEqualExt.ToExt2(Node.GreaterThanOrEqualExt) + case j.IsIpv4Ext != nil: + return j.IsIpv4Ext.ToExt1(Node.IsIpv4) + case j.IsIpv6Ext != nil: + return j.IsIpv6Ext.ToExt1(Node.IsIpv6) + case j.IsLoopbackExt != nil: + return j.IsLoopbackExt.ToExt1(Node.IsLoopback) + case j.IsMulticastExt != nil: + return j.IsMulticastExt.ToExt1(Node.IsMulticast) + case j.IsInRangeExt != nil: + return j.IsInRangeExt.ToExt2(Node.IsInRange) + } + + return Node{}, fmt.Errorf("unknown node") +} + +func (p *Policy) UnmarshalJSON(b []byte) error { + var j policyJSON + if err := json.Unmarshal(b, &j); err != nil { + return fmt.Errorf("error unmarshalling json: %w", err) + } + switch j.Effect { + case "permit": + *p = *Permit() + case "forbid": + *p = *Forbid() + default: + return fmt.Errorf("unknown effect: %v", j.Effect) + } + for k, v := range j.Annotations { + p.Annotate(types.String(k), types.String(v)) + } + var err error + p.principal, err = j.Principal.ToNode(Principal()) + if err != nil { + return fmt.Errorf("error in principal: %w", err) + } + p.action, err = j.Action.ToNode(Action()) + if err != nil { + return fmt.Errorf("error in action: %w", err) + } + p.resource, err = j.Resource.ToNode(Resource()) + if err != nil { + return fmt.Errorf("error in resource: %w", err) + } + for _, c := range j.Conditions { + n, err := c.Body.ToNode() + if err != nil { + return fmt.Errorf("error in conditions: %w", err) + } + switch c.Kind { + case "when": + p.When(n) + case "unless": + p.Unless(n) + default: + return fmt.Errorf("unknown condition kind: %v", c.Kind) + } + } + + return nil +} diff --git a/x/exp/ast/parser.go b/x/exp/ast/parser.go index d82a27c8..56fdd76f 100644 --- a/x/exp/ast/parser.go +++ b/x/exp/ast/parser.go @@ -251,7 +251,7 @@ func (p *parser) action(policy *Policy) error { if err != nil { return err } - policy.ActionIn(entities...) + policy.ActionInSet(entities...) p.advance() // entlist guarantees "]" return nil } else { diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/parser_test.go index f8b6d9e8..8df84e3c 100644 --- a/x/exp/ast/parser_test.go +++ b/x/exp/ast/parser_test.go @@ -125,7 +125,7 @@ func TestParse(t *testing.T) { action in [ActionType::"farming", ActionType::"forestry"], resource );`, - Permit().ActionIn(farming, forestry), + Permit().ActionInSet(farming, forestry), }, } diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index eb559481..e628c73f 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -2,60 +2,123 @@ package ast import "github.com/cedar-policy/cedar-go/types" +type scope Node + +func (s scope) Eq(entity types.EntityUID) Node { + return Node(s).Equals(Entity(entity)) +} + +type scopeEqNode Node + +func (n scopeEqNode) Entity() types.EntityUID { + return n.args[1].value.(types.EntityUID) +} + +func (s scope) In(entity types.EntityUID) Node { + return Node(s).In(Entity(entity)) +} + +func (s scope) InSet(entities []types.EntityUID) Node { + var entityValues []types.Value + for _, e := range entities { + entityValues = append(entityValues, e) + } + return Node(s).In(Set(entityValues)) +} + +type scopeInNode Node + +func (n scopeInNode) IsSet() bool { + return Node(n).args[1].nodeType == nodeTypeSet +} + +func (n scopeInNode) Entity() types.EntityUID { + return n.args[1].value.(types.EntityUID) +} + +func (n scopeInNode) Set() []types.EntityUID { + var res []types.EntityUID + for _, a := range n.args[1].args { + res = append(res, a.value.(types.EntityUID)) + } + return res +} + +func (s scope) Is(entityType types.String) Node { + return Node(s).Is(entityType) +} + +type scopeIsNode Node + +func (n scopeIsNode) EntityType() types.String { + return n.args[1].value.(types.String) +} + +func (s scope) IsIn(entityType types.String, entity types.EntityUID) Node { + return Node(s).IsIn(entityType, Entity(entity)) +} + +type scopeIsInNode Node + +func (n scopeIsInNode) EntityType() types.String { + return n.args[1].value.(types.String) +} + +func (n scopeIsInNode) Entity() types.EntityUID { + return n.args[2].value.(types.EntityUID) +} + func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { - p.principal = Principal().Equals(Entity(entity)) + p.principal = scope(Principal()).Eq(entity) return p } func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { - p.principal = Principal().In(Entity(entity)) + p.principal = scope(Principal()).In(entity) return p } func (p *Policy) PrincipalIs(entityType types.String) *Policy { - p.principal = Principal().Is(entityType) + p.principal = scope(Principal()).Is(entityType) return p } func (p *Policy) PrincipalIsIn(entityType types.String, entity types.EntityUID) *Policy { - p.principal = Principal().IsIn(entityType, Entity(entity)) + p.principal = scope(Principal()).IsIn(entityType, entity) return p } func (p *Policy) ActionEq(entity types.EntityUID) *Policy { - p.action = Action().Equals(Entity(entity)) + p.action = scope(Action()).Eq(entity) return p } -func (p *Policy) ActionIn(entities ...types.EntityUID) *Policy { - if len(entities) == 1 { - p.action = Action().In(Entity(entities[0])) - return p - } - var entityValues []types.Value - for _, e := range entities { - entityValues = append(entityValues, e) - } - p.action = Action().In(Set(entityValues)) +func (p *Policy) ActionIn(entity types.EntityUID) *Policy { + p.action = scope(Action()).In(entity) + return p +} + +func (p *Policy) ActionInSet(entities ...types.EntityUID) *Policy { + p.action = scope(Action()).InSet(entities) return p } func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { - p.resource = Resource().Equals(Entity(entity)) + p.resource = scope(Resource()).Eq(entity) return p } func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { - p.resource = Resource().In(Entity(entity)) + p.resource = scope(Resource()).In(entity) return p } func (p *Policy) ResourceIs(entityType types.String) *Policy { - p.resource = Resource().Is(entityType) + p.resource = scope(Resource()).Is(entityType) return p } func (p *Policy) ResourceIsIn(entityType types.String, entity types.EntityUID) *Policy { - p.resource = Resource().IsIn(entityType, Entity(entity)) + p.resource = scope(Resource()).IsIn(entityType, entity) return p } From 318fa73100a3ea81b537c824025c886b286df234 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 31 Jul 2024 14:36:44 -0700 Subject: [PATCH 033/216] cedar-go/x/exp/ast: default the scope nodes to a new "all" node type Signed-off-by: philhassey --- x/exp/ast/annotation.go | 8 ++------ x/exp/ast/node.go | 1 + x/exp/ast/policy.go | 14 ++++++++++++-- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/x/exp/ast/annotation.go b/x/exp/ast/annotation.go index ad7a2618..d2a3d59e 100644 --- a/x/exp/ast/annotation.go +++ b/x/exp/ast/annotation.go @@ -33,15 +33,11 @@ func (n annotationNode) Value() types.String { } func (a *Annotations) Permit() *Policy { - p := Permit() - p.annotations = a.nodes - return p + return newPolicy(effectPermit, a.nodes) } func (a *Annotations) Forbid() *Policy { - p := Forbid() - p.annotations = a.nodes - return p + return newPolicy(effectForbid, a.nodes) } func (p *Policy) Annotate(name, value types.String) *Policy { diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 22bea5d7..71088898 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -7,6 +7,7 @@ type nodeType uint8 const ( nodeTypeAccess nodeType = iota nodeTypeAdd + nodeTypeAll nodeTypeAnd nodeTypeAnnotation nodeTypeBoolean diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index dab6048d..ab1645f9 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -9,12 +9,22 @@ type Policy struct { conditions []Node } +func newPolicy(effect effect, annotations []Node) *Policy { + return &Policy{ + effect: effect, + annotations: annotations, + principal: Node{nodeType: nodeTypeAll}, + action: Node{nodeType: nodeTypeAll}, + resource: Node{nodeType: nodeTypeAll}, + } +} + func Permit() *Policy { - return &Policy{effect: effectPermit} + return newPolicy(effectPermit, nil) } func Forbid() *Policy { - return &Policy{effect: effectForbid} + return newPolicy(effectForbid, nil) } func (p *Policy) When(node Node) *Policy { From ac059115e42e8168a1b1011bb526f84868bc3cc2 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 31 Jul 2024 16:47:54 -0600 Subject: [PATCH 034/216] x/exp/ast: add JSON marshaller Addresses IDX-55 Signed-off-by: philhassey --- types/json.go | 2 +- types/json_test.go | 4 +- types/value.go | 2 +- x/exp/ast/json.go | 76 ++++---- x/exp/ast/json_marshal.go | 337 +++++++++++++++++++++++++++++++++++- x/exp/ast/json_test.go | 11 +- x/exp/ast/json_unmarshal.go | 48 +---- x/exp/ast/node.go | 30 +++- x/exp/ast/operator.go | 68 ++++---- x/exp/ast/policy.go | 4 +- x/exp/ast/value.go | 5 +- x/exp/ast/variable.go | 6 + 12 files changed, 458 insertions(+), 135 deletions(-) diff --git a/types/json.go b/types/json.go index 1df5e0f8..8dd29cd7 100644 --- a/types/json.go +++ b/types/json.go @@ -40,7 +40,7 @@ type explicitValue struct { Value Value } -func unmarshalJSON(b []byte, v *Value) error { +func UnmarshalJSON(b []byte, v *Value) error { // TODO: make this faster if it matters { var res EntityUID diff --git a/types/json_test.go b/types/json_test.go index c4e2d834..50d18cb0 100644 --- a/types/json_test.go +++ b/types/json_test.go @@ -49,7 +49,7 @@ func TestJSON_Value(t *testing.T) { t.Parallel() var got Value ptr := &got - err := unmarshalJSON([]byte(tt.in), ptr) + err := UnmarshalJSON([]byte(tt.in), ptr) testutil.AssertError(t, err, tt.err) AssertValue(t, got, tt.want) if tt.err != nil { @@ -62,7 +62,7 @@ func TestJSON_Value(t *testing.T) { testutil.OK(t, err) var gotRetry Value ptr = &gotRetry - err = unmarshalJSON(gotJSON, ptr) + err = UnmarshalJSON(gotJSON, ptr) testutil.OK(t, err) testutil.Equals(t, gotRetry, tt.want) }) diff --git a/types/value.go b/types/value.go index 27304ddf..433021ee 100644 --- a/types/value.go +++ b/types/value.go @@ -159,7 +159,7 @@ func (as Set) Equal(bi Value) bool { } func (v *explicitValue) UnmarshalJSON(b []byte) error { - return unmarshalJSON(b, &v.Value) + return UnmarshalJSON(b, &v.Value) } func (v *Set) UnmarshalJSON(b []byte) error { diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index 46863ede..465d70dc 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -57,65 +57,67 @@ type arrayJSON []nodeJSON type recordJSON map[string]nodeJSON type nodeJSON struct { - // Value - Value *json.RawMessage `json:"Value"` // could be any + Value *json.RawMessage `json:"Value,omitempty"` // could be any // Var - Var *string `json:"Var"` + Var *string `json:"Var,omitempty"` // Slot // Unknown // ! or neg operators - Not *unaryJSON `json:"!"` - Negate *unaryJSON `json:"neg"` + Not *unaryJSON `json:"!,omitempty"` + Negate *unaryJSON `json:"neg,omitempty"` // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny - Equals *binaryJSON `json:"=="` - NotEquals *binaryJSON `json:"!="` - In *binaryJSON `json:"in"` - LessThan *binaryJSON `json:"<"` - LessThanOrEqual *binaryJSON `json:"<="` - GreaterThan *binaryJSON `json:">"` - GreaterThanOrEqual *binaryJSON `json:">="` - And *binaryJSON `json:"&&"` - Or *binaryJSON `json:"||"` - Plus *binaryJSON `json:"+"` - Minus *binaryJSON `json:"-"` - Times *binaryJSON `json:"*"` - Contains *binaryJSON `json:"contains"` - ContainsAll *binaryJSON `json:"containsAll"` - ContainsAny *binaryJSON `json:"containsAny"` + Equals *binaryJSON `json:"==,omitempty"` + NotEquals *binaryJSON `json:"!=,omitempty"` + In *binaryJSON `json:"in,omitempty"` + LessThan *binaryJSON `json:"<,omitempty"` + LessThanOrEqual *binaryJSON `json:"<=,omitempty"` + GreaterThan *binaryJSON `json:">,omitempty"` + GreaterThanOrEqual *binaryJSON `json:">=,omitempty"` + And *binaryJSON `json:"&&,omitempty"` + Or *binaryJSON `json:"||,omitempty"` + Plus *binaryJSON `json:"+,omitempty"` + Minus *binaryJSON `json:"-,omitempty"` + Times *binaryJSON `json:"*,omitempty"` + Contains *binaryJSON `json:"contains,omitempty"` + ContainsAll *binaryJSON `json:"containsAll,omitempty"` + ContainsAny *binaryJSON `json:"containsAny,omitempty"` // ., has - Access *strJSON `json:"."` - Has *strJSON `json:"has"` + Access *strJSON `json:".,omitempty"` + Has *strJSON `json:"has,omitempty"` + + // is + // TODO: https://docs.cedarpolicy.com/policies/json-format.html#JsonExpr-is // like - Like *strJSON `json:"like"` + Like *strJSON `json:"like,omitempty"` // if-then-else - IfThenElse *ifThenElseJSON `json:"if-then-else"` + IfThenElse *ifThenElseJSON `json:"if-then-else,omitempty"` // Set - Set arrayJSON `json:"Set"` + Set arrayJSON `json:"Set,omitempty"` // Record - Record recordJSON `json:"Record"` + Record recordJSON `json:"Record,omitempty"` // Any other function: decimal, ip - Decimal arrayJSON `json:"decimal"` - IP arrayJSON `json:"ip"` + Decimal arrayJSON `json:"decimal,omitempty"` + IP arrayJSON `json:"ip,omitempty"` // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange - LessThanExt arrayJSON `json:"lessThan"` - LessThanOrEqualExt arrayJSON `json:"lessThanOrEqual"` - GreaterThanExt arrayJSON `json:"greaterThan"` - GreaterThanOrEqualExt arrayJSON `json:"greaterThanOrEqual"` - IsIpv4Ext arrayJSON `json:"isIpv4"` - IsIpv6Ext arrayJSON `json:"isIpv6"` - IsLoopbackExt arrayJSON `json:"isLoopback"` - IsMulticastExt arrayJSON `json:"isMulticast"` - IsInRangeExt arrayJSON `json:"isInRange"` + LessThanExt arrayJSON `json:"lessThan,omitempty"` + LessThanOrEqualExt arrayJSON `json:"lessThanOrEqual,omitempty"` + GreaterThanExt arrayJSON `json:"greaterThan,omitempty"` + GreaterThanOrEqualExt arrayJSON `json:"greaterThanOrEqual,omitempty"` + IsIpv4Ext arrayJSON `json:"isIpv4,omitempty"` + IsIpv6Ext arrayJSON `json:"isIpv6,omitempty"` + IsLoopbackExt arrayJSON `json:"isLoopback,omitempty"` + IsMulticastExt arrayJSON `json:"isMulticast,omitempty"` + IsInRangeExt arrayJSON `json:"isInRange,omitempty"` } diff --git a/x/exp/ast/json_marshal.go b/x/exp/ast/json_marshal.go index 7843ad3c..78f6bfe5 100644 --- a/x/exp/ast/json_marshal.go +++ b/x/exp/ast/json_marshal.go @@ -7,7 +7,7 @@ import ( func (s *scopeJSON) FromNode(src Node) error { switch src.nodeType { - case nodeTypeBoolean: + case nodeTypeNone: s.Op = "All" return nil case nodeTypeEquals: @@ -42,9 +42,331 @@ func (s *scopeJSON) FromNode(src Node) error { } return fmt.Errorf("unexpected scope node: %v", src.nodeType) } -func (j nodeJSON) FromNode(src Node) error { - // TODO: all this - return nil +func (j *nodeJSON) FromNode(src Node) error { + switch src.nodeType { + // Value + // Value *json.RawMessage `json:"Value"` // could be any + case nodeTypeBoolean, nodeTypeLong, nodeTypeString, nodeTypeEntity: + b, err := src.value.ExplicitMarshalJSON() + j.Value = (*json.RawMessage)(&b) + return err + + // Var + // Var *string `json:"Var"` + case nodeTypeVariable: + n := variableNode(src) + val := string(n.String()) + j.Var = &val + return nil + + // ! or neg operators + // Not *unaryJSON `json:"!"` + // Negate *unaryJSON `json:"neg"` + case nodeTypeNot: + n := unaryNode(src) + j.Not = &unaryJSON{} + j.Not.Arg.FromNode(n.Arg()) + return nil + case nodeTypeNegate: + n := unaryNode(src) + j.Negate = &unaryJSON{} + j.Negate.Arg.FromNode(n.Arg()) + return nil + + // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny + case nodeTypeAdd: + n := binaryNode(src) + j.Plus = &binaryJSON{} + j.Plus.Left.FromNode(n.Left()) + j.Plus.Right.FromNode(n.Right()) + return nil + case nodeTypeAnd: + n := binaryNode(src) + j.And = &binaryJSON{} + j.And.Left.FromNode(n.Left()) + j.And.Right.FromNode(n.Right()) + return nil + case nodeTypeContains: + n := binaryNode(src) + j.Contains = &binaryJSON{} + j.Contains.Left.FromNode(n.Left()) + j.Contains.Right.FromNode(n.Right()) + return nil + case nodeTypeContainsAll: + n := binaryNode(src) + j.ContainsAll = &binaryJSON{} + j.ContainsAll.Left.FromNode(n.Left()) + j.ContainsAll.Right.FromNode(n.Right()) + return nil + case nodeTypeContainsAny: + n := binaryNode(src) + j.ContainsAny = &binaryJSON{} + j.ContainsAny.Left.FromNode(n.Left()) + j.ContainsAny.Right.FromNode(n.Right()) + return nil + case nodeTypeEquals: + n := binaryNode(src) + j.Equals = &binaryJSON{} + j.Equals.Left.FromNode(n.Left()) + j.Equals.Right.FromNode(n.Right()) + return nil + case nodeTypeGreater: + n := binaryNode(src) + j.GreaterThan = &binaryJSON{} + j.GreaterThan.Left.FromNode(n.Left()) + j.GreaterThan.Right.FromNode(n.Right()) + return nil + case nodeTypeGreaterEqual: + n := binaryNode(src) + j.GreaterThanOrEqual = &binaryJSON{} + j.GreaterThanOrEqual.Left.FromNode(n.Left()) + j.GreaterThanOrEqual.Right.FromNode(n.Right()) + return nil + case nodeTypeIn: + n := binaryNode(src) + j.In = &binaryJSON{} + j.In.Left.FromNode(n.Left()) + j.In.Right.FromNode(n.Right()) + return nil + case nodeTypeLess: + n := binaryNode(src) + j.LessThan = &binaryJSON{} + j.LessThan.Left.FromNode(n.Left()) + j.LessThan.Right.FromNode(n.Right()) + return nil + case nodeTypeLessEqual: + n := binaryNode(src) + j.LessThanOrEqual = &binaryJSON{} + j.LessThanOrEqual.Left.FromNode(n.Left()) + j.LessThanOrEqual.Right.FromNode(n.Right()) + return nil + case nodeTypeMult: + n := binaryNode(src) + j.Times = &binaryJSON{} + j.Times.Left.FromNode(n.Left()) + j.Times.Right.FromNode(n.Right()) + return nil + case nodeTypeNotEquals: + n := binaryNode(src) + j.NotEquals = &binaryJSON{} + j.NotEquals.Left.FromNode(n.Left()) + j.NotEquals.Right.FromNode(n.Right()) + return nil + case nodeTypeOr: + n := binaryNode(src) + j.Or = &binaryJSON{} + j.Or.Left.FromNode(n.Left()) + j.Or.Right.FromNode(n.Right()) + return nil + case nodeTypeSub: + n := binaryNode(src) + j.Minus = &binaryJSON{} + j.Minus.Left.FromNode(n.Left()) + j.Minus.Right.FromNode(n.Right()) + return nil + + // ., has + // Access *strJSON `json:"."` + // Has *strJSON `json:"has"` + case nodeTypeAccess: + n := binaryNode(src) + j.Access = &strJSON{} + j.Access.Left.FromNode(n.Left()) + j.Access.Attr = n.Right().value.String() // TODO: make this nicer + return nil + case nodeTypeHas: + n := binaryNode(src) + j.Has = &strJSON{} + j.Has.Left.FromNode(n.Left()) + j.Has.Attr = n.Right().value.String() // TODO: make this nicer + return nil + // is + case nodeTypeIs: + case nodeTypeIsIn: // TODO + + // like + // Like *strJSON `json:"like"` + case nodeTypeLike: + n := binaryNode(src) + j.Like = &strJSON{} + j.Like.Left.FromNode(n.Left()) + j.Like.Attr = n.Right().value.String() // TODO: make this nicer + + // if-then-else + // IfThenElse *ifThenElseJSON `json:"if-then-else"` + case nodeTypeIf: + n := trinaryNode(src) + j.IfThenElse = &ifThenElseJSON{} + j.IfThenElse.If.FromNode(n.A()) + j.IfThenElse.Then.FromNode(n.B()) + j.IfThenElse.Else.FromNode(n.C()) + + // Set + // Set arrayJSON `json:"Set"` + case nodeTypeSet: + j.Set = arrayJSON{} + for _, v := range src.args { + var nn nodeJSON + if err := nn.FromNode(v); err != nil { + return err + } + j.Set = append(j.Set, nn) + } + return nil + + // Record + // Record recordJSON `json:"Record"` + case nodeTypeRecord: + j.Record = recordJSON{} + for _, kv := range src.args { + n := binaryNode(kv) + // TODO: make this nicer + var nn nodeJSON + if err := nn.FromNode(n.Right()); err != nil { + return err + } + j.Record[n.Left().value.String()] = nn + } + return nil + + // Any other function: decimal, ip + // Decimal arrayJSON `json:"decimal"` + // IP arrayJSON `json:"ip"` + case nodeTypeDecimal: + j.Decimal = arrayJSON{} + str := src.value.String() // TODO: make this nicer + b := []byte(str) + j.Decimal = append(j.Decimal, nodeJSON{ + Value: (*json.RawMessage)(&b), + }, + ) + return nil + + case nodeTypeIpAddr: + j.IP = arrayJSON{} + str := src.value.String() // TODO: make this nicer + b := []byte(str) + j.IP = append(j.IP, nodeJSON{ + Value: (*json.RawMessage)(&b), + }, + ) + return nil + + // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange + // LessThanExt arrayJSON `json:"lessThan"` + // LessThanOrEqualExt arrayJSON `json:"lessThanOrEqual"` + // GreaterThanExt arrayJSON `json:"greaterThan"` + // GreaterThanOrEqualExt arrayJSON `json:"greaterThanOrEqual"` + // IsIpv4Ext arrayJSON `json:"isIpv4"` + // IsIpv6Ext arrayJSON `json:"isIpv6"` + // IsLoopbackExt arrayJSON `json:"isLoopback"` + // IsMulticastExt arrayJSON `json:"isMulticast"` + // IsInRangeExt arrayJSON `json:"isInRange"` + case nodeTypeLessExt: + n := binaryNode(src) + j.LessThanExt = arrayJSON{} + var left, right nodeJSON + if err := left.FromNode(n.Left()); err != nil { + return err + } + if err := right.FromNode(n.Right()); err != nil { + return err + } + j.LessThanExt = append(j.LessThanExt, left, right) + return nil + case nodeTypeLessEqualExt: + n := binaryNode(src) + j.LessThanOrEqualExt = arrayJSON{} + var left, right nodeJSON + if err := left.FromNode(n.Left()); err != nil { + return err + } + if err := right.FromNode(n.Right()); err != nil { + return err + } + j.LessThanOrEqualExt = append(j.LessThanOrEqualExt, left, right) + return nil + case nodeTypeGreaterExt: + n := binaryNode(src) + j.GreaterThanExt = arrayJSON{} + var left, right nodeJSON + if err := left.FromNode(n.Left()); err != nil { + return err + } + if err := right.FromNode(n.Right()); err != nil { + return err + } + j.GreaterThanExt = append(j.GreaterThanExt, left, right) + return nil + case nodeTypeGreaterEqualExt: + n := binaryNode(src) + j.GreaterThanOrEqualExt = arrayJSON{} + var left, right nodeJSON + if err := left.FromNode(n.Left()); err != nil { + return err + } + if err := right.FromNode(n.Right()); err != nil { + return err + } + j.GreaterThanOrEqualExt = append(j.GreaterThanOrEqualExt, left, right) + return nil + case nodeTypeIsInRange: + n := binaryNode(src) + j.IsInRangeExt = arrayJSON{} + var left, right nodeJSON + if err := left.FromNode(n.Left()); err != nil { + return err + } + if err := right.FromNode(n.Right()); err != nil { + return err + } + j.IsInRangeExt = append(j.IsInRangeExt, left, right) + return nil + + case nodeTypeIsIpv4: + n := unaryNode(src) + j.IsIpv4Ext = arrayJSON{} + var arg nodeJSON + if err := arg.FromNode(n.Arg()); err != nil { + return err + } + j.IsIpv4Ext = append(j.IsIpv4Ext, arg) + return nil + case nodeTypeIsIpv6: + n := unaryNode(src) + j.IsIpv6Ext = arrayJSON{} + var arg nodeJSON + if err := arg.FromNode(n.Arg()); err != nil { + return err + } + j.IsIpv6Ext = append(j.IsIpv6Ext, arg) + return nil + case nodeTypeIsLoopback: + n := unaryNode(src) + j.IsLoopbackExt = arrayJSON{} + var arg nodeJSON + if err := arg.FromNode(n.Arg()); err != nil { + return err + } + j.IsLoopbackExt = append(j.IsLoopbackExt, arg) + return nil + case nodeTypeIsMulticast: + n := unaryNode(src) + j.IsMulticastExt = arrayJSON{} + var arg nodeJSON + if err := arg.FromNode(n.Arg()); err != nil { + return err + } + j.IsMulticastExt = append(j.IsMulticastExt, arg) + return nil + + } + // case nodeTypeRecordEntry: + // case nodeTypeEntityType: + // case nodeTypeAnnotation: + // case nodeTypeWhen: + // case nodeTypeUnless: + return fmt.Errorf("unknown node type: %v", src.nodeType) } func (p *Policy) MarshalJSON() ([]byte, error) { var j policyJSON @@ -69,8 +391,13 @@ func (p *Policy) MarshalJSON() ([]byte, error) { return nil, fmt.Errorf("error in resource: %w", err) } for _, c := range p.conditions { + n := unaryNode(c) var cond conditionJSON - if err := cond.Body.FromNode(c); err != nil { + cond.Kind = "when" + if c.nodeType == nodeTypeUnless { + cond.Kind = "unless" + } + if err := cond.Body.FromNode(n.Arg()); err != nil { return nil, fmt.Errorf("error in condition: %w", err) } j.Conditions = append(j.Conditions, cond) diff --git a/x/exp/ast/json_test.go b/x/exp/ast/json_test.go index bf2b75d2..710f5ef1 100644 --- a/x/exp/ast/json_test.go +++ b/x/exp/ast/json_test.go @@ -3,8 +3,10 @@ package ast_test import ( "encoding/json" "reflect" + "strings" "testing" + "github.com/cedar-policy/cedar-go/testutil" "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/ast" ) @@ -91,7 +93,8 @@ func TestUnmarshalJSON(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() var p ast.Policy - err := json.Unmarshal([]byte(tt.input), &p) + fixedInput := strings.ReplaceAll(tt.input, "\t", " ") + err := json.Unmarshal([]byte(fixedInput), &p) if (err != nil) != tt.wantErr { t.Errorf("error got: %v want: %v", err, tt.wantErr) } @@ -99,9 +102,9 @@ func TestUnmarshalJSON(t *testing.T) { t.Errorf("policy mismatch: got: %+v want: %+v", p, *tt.want) } - // b, err := json.MarshalIndent(&p, "", " ") - // testutil.OK(t, err) - // testutil.Equals(t, string(b), tt.input) + b, err := json.MarshalIndent(&p, "", " ") + testutil.OK(t, err) + testutil.Equals(t, string(b), fixedInput) }) } diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go index e313a92a..705becdf 100644 --- a/x/exp/ast/json_unmarshal.go +++ b/x/exp/ast/json_unmarshal.go @@ -1,7 +1,6 @@ package ast import ( - "bytes" "encoding/json" "fmt" @@ -11,7 +10,7 @@ import ( func (s *scopeJSON) ToNode(variable Node) (Node, error) { switch s.Op { case "All": - return True(), nil + return Node{}, nil case "==": if s.Entity == nil { return Node{}, fmt.Errorf("missing entity") @@ -162,50 +161,15 @@ func (j recordJSON) ToNode() (Node, error) { return RecordNodes(nodes), nil } -var ( // TODO: de-dupe from types? - errJSONDecode = fmt.Errorf("error decoding json") - errJSONLongOutOfRange = fmt.Errorf("long out of range") - errJSONUnsupportedType = fmt.Errorf("unsupported type") -) - -func parseRawMessage(j *json.RawMessage) (Node, error) { - // TODO: de-dupe from types? though it's not 100% compat, because of extensions :( - // TODO: make this faster if it matters - { - var res types.EntityUID - ptr := &res - if err := ptr.UnmarshalJSON(*j); err == nil { - return Entity(res), nil - } - } - - var res interface{} - dec := json.NewDecoder(bytes.NewBuffer(*j)) - dec.UseNumber() - if err := dec.Decode(&res); err != nil { - return Node{}, fmt.Errorf("%w: %w", errJSONDecode, err) - } - switch vv := res.(type) { - case string: - return String(types.String(vv)), nil - case bool: - return Boolean(types.Boolean(vv)), nil - case json.Number: - l, err := vv.Int64() - if err != nil { - return Node{}, fmt.Errorf("%w: %w", errJSONLongOutOfRange, err) - } - return Long(types.Long(l)), nil - } - return Node{}, errJSONUnsupportedType - -} - func (j nodeJSON) ToNode() (Node, error) { switch { // Value case j.Value != nil: - return parseRawMessage(j.Value) + var v types.Value + if err := types.UnmarshalJSON(*j.Value, &v); err != nil { + return Node{}, fmt.Errorf("error unmarshalling value: %w", err) + } + return valueToNode(v), nil // Var case j.Var != nil: diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 71088898..25ed44c3 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -5,7 +5,8 @@ import "github.com/cedar-policy/cedar-go/types" type nodeType uint8 const ( - nodeTypeAccess nodeType = iota + nodeTypeNone nodeType = iota + nodeTypeAccess nodeTypeAdd nodeTypeAll nodeTypeAnd @@ -59,3 +60,30 @@ type Node struct { args []Node // For inner nodes like operators, records, etc value types.Value // For leaf nodes like String, Long, EntityUID } + +func newUnaryNode(op nodeType, arg Node) Node { + return Node{nodeType: op, args: []Node{arg}} +} + +type unaryNode Node + +func (n unaryNode) Arg() Node { return n.args[0] } + +func newBinaryNode(op nodeType, arg1, arg2 Node) Node { + return Node{nodeType: op, args: []Node{arg1, arg2}} +} + +type binaryNode Node + +func (n binaryNode) Left() Node { return n.args[0] } +func (n binaryNode) Right() Node { return n.args[1] } + +func newTrinaryNode(op nodeType, arg1, arg2, arg3 Node) Node { + return Node{nodeType: op, args: []Node{arg1, arg2, arg3}} +} + +type trinaryNode Node + +func (n trinaryNode) A() Node { return n.args[0] } +func (n trinaryNode) B() Node { return n.args[1] } +func (n trinaryNode) C() Node { return n.args[2] } diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index b20be4cb..d3f96990 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -10,47 +10,47 @@ import "github.com/cedar-policy/cedar-go/types" // |_| func (lhs Node) Equals(rhs Node) Node { - return newOpNode(nodeTypeEquals, lhs, rhs) + return newBinaryNode(nodeTypeEquals, lhs, rhs) } func (lhs Node) NotEquals(rhs Node) Node { - return newOpNode(nodeTypeNotEquals, lhs, rhs) + return newBinaryNode(nodeTypeNotEquals, lhs, rhs) } func (lhs Node) LessThan(rhs Node) Node { - return newOpNode(nodeTypeLess, lhs, rhs) + return newBinaryNode(nodeTypeLess, lhs, rhs) } func (lhs Node) LessThanOrEqual(rhs Node) Node { - return newOpNode(nodeTypeLessEqual, lhs, rhs) + return newBinaryNode(nodeTypeLessEqual, lhs, rhs) } func (lhs Node) GreaterThan(rhs Node) Node { - return newOpNode(nodeTypeGreater, lhs, rhs) + return newBinaryNode(nodeTypeGreater, lhs, rhs) } func (lhs Node) GreaterThanOrEqual(rhs Node) Node { - return newOpNode(nodeTypeGreaterEqual, lhs, rhs) + return newBinaryNode(nodeTypeGreaterEqual, lhs, rhs) } func (lhs Node) LessThanExt(rhs Node) Node { - return newOpNode(nodeTypeLessExt, lhs, rhs) + return newBinaryNode(nodeTypeLessExt, lhs, rhs) } func (lhs Node) LessThanOrEqualExt(rhs Node) Node { - return newOpNode(nodeTypeLessEqualExt, lhs, rhs) + return newBinaryNode(nodeTypeLessEqualExt, lhs, rhs) } func (lhs Node) GreaterThanExt(rhs Node) Node { - return newOpNode(nodeTypeGreaterExt, lhs, rhs) + return newBinaryNode(nodeTypeGreaterExt, lhs, rhs) } func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { - return newOpNode(nodeTypeGreaterEqualExt, lhs, rhs) + return newBinaryNode(nodeTypeGreaterEqualExt, lhs, rhs) } func (lhs Node) Like(patt string) Node { - return newOpNode(nodeTypeLike, lhs, String(types.String(patt))) + return newBinaryNode(nodeTypeLike, lhs, String(types.String(patt))) } // _ _ _ @@ -61,23 +61,23 @@ func (lhs Node) Like(patt string) Node { // |___/ func (lhs Node) And(rhs Node) Node { - return newOpNode(nodeTypeAnd, lhs, rhs) + return newBinaryNode(nodeTypeAnd, lhs, rhs) } func (lhs Node) Or(rhs Node) Node { - return newOpNode(nodeTypeOr, lhs, rhs) + return newBinaryNode(nodeTypeOr, lhs, rhs) } func Not(rhs Node) Node { - return newOpNode(nodeTypeNot, rhs) + return newUnaryNode(nodeTypeNot, rhs) } func Negate(rhs Node) Node { - return newOpNode(nodeTypeNegate, rhs) + return newUnaryNode(nodeTypeNegate, rhs) } func If(condition Node, ifTrue Node, ifFalse Node) Node { - return newOpNode(nodeTypeIf, condition, ifTrue, ifFalse) + return newTrinaryNode(nodeTypeIf, condition, ifTrue, ifFalse) } // _ _ _ _ _ _ @@ -87,15 +87,15 @@ func If(condition Node, ifTrue Node, ifFalse Node) Node { // /_/ \_\_| |_|\__|_| |_|_| |_| |_|\___|\__|_|\___| func (lhs Node) Plus(rhs Node) Node { - return newOpNode(nodeTypeAdd, lhs, rhs) + return newBinaryNode(nodeTypeAdd, lhs, rhs) } func (lhs Node) Minus(rhs Node) Node { - return newOpNode(nodeTypeSub, lhs, rhs) + return newBinaryNode(nodeTypeSub, lhs, rhs) } func (lhs Node) Times(rhs Node) Node { - return newOpNode(nodeTypeMult, lhs, rhs) + return newBinaryNode(nodeTypeMult, lhs, rhs) } // _ _ _ _ @@ -106,27 +106,27 @@ func (lhs Node) Times(rhs Node) Node { // |___/ func (lhs Node) In(rhs Node) Node { - return newOpNode(nodeTypeIn, lhs, rhs) + return newBinaryNode(nodeTypeIn, lhs, rhs) } func (lhs Node) Is(entityType types.String) Node { - return newOpNode(nodeTypeIs, lhs, String(entityType)) + return newBinaryNode(nodeTypeIs, lhs, String(entityType)) } func (lhs Node) IsIn(entityType types.String, rhs Node) Node { - return newOpNode(nodeTypeIsIn, lhs, String(entityType), rhs) + return newTrinaryNode(nodeTypeIsIn, lhs, String(entityType), rhs) } func (lhs Node) Contains(rhs Node) Node { - return newOpNode(nodeTypeContains, lhs, rhs) + return newBinaryNode(nodeTypeContains, lhs, rhs) } func (lhs Node) ContainsAll(rhs Node) Node { - return newOpNode(nodeTypeContainsAll, lhs, rhs) + return newBinaryNode(nodeTypeContainsAll, lhs, rhs) } func (lhs Node) ContainsAny(rhs Node) Node { - return newOpNode(nodeTypeContainsAny, lhs, rhs) + return newBinaryNode(nodeTypeContainsAny, lhs, rhs) } // Access is a convenience function that wraps a simple string @@ -147,11 +147,11 @@ func (lhs Node) Access(attr string) Node { // ast.Context().Access("resourceAttribute") // ).Equals(ast.String("foo")) func (lhs Node) AccessNode(rhs Node) Node { - return newOpNode(nodeTypeAccess, lhs, rhs) + return newBinaryNode(nodeTypeAccess, lhs, rhs) } func (lhs Node) Has(attr string) Node { - return newOpNode(nodeTypeHas, lhs, String(types.String(attr))) + return newBinaryNode(nodeTypeHas, lhs, String(types.String(attr))) } // ___ ____ _ _ _ @@ -161,25 +161,21 @@ func (lhs Node) Has(attr string) Node { // |___|_| /_/ \_\__,_|\__,_|_| \___||___/___/ func (lhs Node) IsIpv4() Node { - return newOpNode(nodeTypeIsIpv4, lhs) + return newUnaryNode(nodeTypeIsIpv4, lhs) } func (lhs Node) IsIpv6() Node { - return newOpNode(nodeTypeIsIpv6, lhs) + return newUnaryNode(nodeTypeIsIpv6, lhs) } func (lhs Node) IsMulticast() Node { - return newOpNode(nodeTypeIsMulticast, lhs) + return newUnaryNode(nodeTypeIsMulticast, lhs) } func (lhs Node) IsLoopback() Node { - return newOpNode(nodeTypeIsLoopback, lhs) + return newUnaryNode(nodeTypeIsLoopback, lhs) } func (lhs Node) IsInRange(rhs Node) Node { - return newOpNode(nodeTypeIsInRange, lhs, rhs) -} - -func newOpNode(op nodeType, args ...Node) Node { - return Node{nodeType: op, args: args} + return newBinaryNode(nodeTypeIsInRange, lhs, rhs) } diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index ab1645f9..34b408df 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -28,12 +28,12 @@ func Forbid() *Policy { } func (p *Policy) When(node Node) *Policy { - p.conditions = append(p.conditions, Node{nodeType: nodeTypeUnless, args: []Node{node}}) + p.conditions = append(p.conditions, Node{nodeType: nodeTypeWhen, args: []Node{node}}) return p } func (p *Policy) Unless(node Node) *Policy { - p.conditions = append(p.conditions, Node{nodeType: nodeTypeWhen, args: []Node{node}}) + p.conditions = append(p.conditions, Node{nodeType: nodeTypeUnless, args: []Node{node}}) return p } diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index 46bf7134..9adca6cb 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -77,10 +77,7 @@ func RecordNodes(entries map[types.String]Node) Node { for k, v := range entries { nodes = append( nodes, - Node{ - nodeType: nodeTypeRecordEntry, - args: []Node{String(k), v}, - }, + newBinaryNode(nodeTypeRecordEntry, String(k), v), ) } return Node{nodeType: nodeTypeRecord, args: nodes} diff --git a/x/exp/ast/variable.go b/x/exp/ast/variable.go index 9bf750a7..a45467b1 100644 --- a/x/exp/ast/variable.go +++ b/x/exp/ast/variable.go @@ -33,3 +33,9 @@ func newResourceNode() Node { func newContextNode() Node { return newValueNode(nodeTypeVariable, types.String("context")) } + +type variableNode Node + +func (v variableNode) String() types.String { + return v.value.(types.String) +} From 854aa97932b6509abfd31c94e973093735b4b5c5 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 31 Jul 2024 17:09:41 -0600 Subject: [PATCH 035/216] x/exp/ast: add is / is in handling for JSON marshaller Addresses IDX-55 Signed-off-by: philhassey --- x/exp/ast/json.go | 8 +++++++- x/exp/ast/json_marshal.go | 20 ++++++++++++++++++-- x/exp/ast/json_unmarshal.go | 14 ++++++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index 465d70dc..18f851e6 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -46,6 +46,12 @@ type strJSON struct { Attr string `json:"attr"` } +type isJSON struct { + Left nodeJSON `json:"left"` + EntityType string `json:"entity_type"` + In *inJSON `json:"in,omitempty"` +} + type ifThenElseJSON struct { If nodeJSON `json:"if"` Then nodeJSON `json:"then"` @@ -92,7 +98,7 @@ type nodeJSON struct { Has *strJSON `json:"has,omitempty"` // is - // TODO: https://docs.cedarpolicy.com/policies/json-format.html#JsonExpr-is + Is *isJSON `json:"is,omitempty"` // like Like *strJSON `json:"like,omitempty"` diff --git a/x/exp/ast/json_marshal.go b/x/exp/ast/json_marshal.go index 78f6bfe5..f01c804e 100644 --- a/x/exp/ast/json_marshal.go +++ b/x/exp/ast/json_marshal.go @@ -3,6 +3,8 @@ package ast import ( "encoding/json" "fmt" + + "github.com/cedar-policy/cedar-go/types" ) func (s *scopeJSON) FromNode(src Node) error { @@ -161,7 +163,7 @@ func (j *nodeJSON) FromNode(src Node) error { case nodeTypeSub: n := binaryNode(src) j.Minus = &binaryJSON{} - j.Minus.Left.FromNode(n.Left()) + j.Minus.Left.FromNode(n.Left()) // TODO: in all these cases, check for an error, handle it ... j.Minus.Right.FromNode(n.Right()) return nil @@ -182,7 +184,21 @@ func (j *nodeJSON) FromNode(src Node) error { return nil // is case nodeTypeIs: - case nodeTypeIsIn: // TODO + n := binaryNode(src) + j.Is = &isJSON{ + EntityType: string(n.Right().value.(types.String)), // TODO: make this nicer + } + j.Is.Left.FromNode(n.Left()) + return nil + case nodeTypeIsIn: + n := trinaryNode(src) + j.Is = &isJSON{ + EntityType: string(n.B().value.(types.String)), // TODO: make this nicer + In: &inJSON{}, + } + j.Is.Left.FromNode(n.A()) + j.Is.In.Entity = n.C().value.(types.EntityUID) + return nil // like // Like *strJSON `json:"like"` diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go index 705becdf..b0b81f25 100644 --- a/x/exp/ast/json_unmarshal.go +++ b/x/exp/ast/json_unmarshal.go @@ -59,6 +59,16 @@ func (j strJSON) ToNode(f func(a Node, k string) Node) (Node, error) { } return f(left, j.Attr), nil } +func (j isJSON) ToNode() (Node, error) { + left, err := j.Left.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in left: %w", err) + } + if j.In != nil { + return left.IsIn(types.String(j.EntityType), Entity(j.In.Entity)), nil + } + return left.Is(types.String(j.EntityType)), nil +} func (j ifThenElseJSON) ToNode() (Node, error) { if_, err := j.If.ToNode() if err != nil { @@ -232,6 +242,10 @@ func (j nodeJSON) ToNode() (Node, error) { case j.Has != nil: return j.Has.ToNode(Node.Has) + // is + case j.Is != nil: + return j.Is.ToNode() + // like case j.Like != nil: return j.Like.ToNode(Node.Like) From 3508e40653694f74a6be92b05fbff885b2d1953f Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 31 Jul 2024 17:42:53 -0600 Subject: [PATCH 036/216] x/exp/ast: improve JSON marshaller error handling and code dryness Addresses IDX-55 Signed-off-by: philhassey --- x/exp/ast/json_marshal.go | 403 +++++++++++++++----------------------- 1 file changed, 160 insertions(+), 243 deletions(-) diff --git a/x/exp/ast/json_marshal.go b/x/exp/ast/json_marshal.go index f01c804e..8228f983 100644 --- a/x/exp/ast/json_marshal.go +++ b/x/exp/ast/json_marshal.go @@ -44,6 +44,130 @@ func (s *scopeJSON) FromNode(src Node) error { } return fmt.Errorf("unexpected scope node: %v", src.nodeType) } + +func unaryToJSON(dest **unaryJSON, src Node) error { + n := unaryNode(src) + res := &unaryJSON{} + if err := res.Arg.FromNode(n.Arg()); err != nil { + return fmt.Errorf("error in arg: %w", err) + } + *dest = res + return nil +} + +func binaryToJSON(dest **binaryJSON, src Node) error { + n := binaryNode(src) + res := &binaryJSON{} + if err := res.Left.FromNode(n.Left()); err != nil { + return fmt.Errorf("error in left: %w", err) + } + if err := res.Right.FromNode(n.Right()); err != nil { + return fmt.Errorf("error in right: %w", err) + } + *dest = res + return nil +} + +func arrayToJSON(dest *arrayJSON, src Node) error { + res := arrayJSON{} + for _, n := range src.args { + var nn nodeJSON + if err := nn.FromNode(n); err != nil { + return fmt.Errorf("error in array: %w", err) + } + res = append(res, nn) + } + *dest = res + return nil +} + +func extToJSON(dest *arrayJSON, src Node) error { + res := arrayJSON{} + if src.value == nil { + return fmt.Errorf("missing value") + } + str := src.value.String() // TODO: is this the correct string? + b := []byte(str) + res = append(res, nodeJSON{ + Value: (*json.RawMessage)(&b), + }) + *dest = res + return nil +} + +func strToJSON(dest **strJSON, src Node) error { + n := binaryNode(src) + res := &strJSON{} + if err := res.Left.FromNode(n.Left()); err != nil { + return fmt.Errorf("error in left: %w", err) + } + str, ok := n.Right().value.(types.String) + if !ok { + return fmt.Errorf("right not string") + } + res.Attr = string(str) + *dest = res + return nil +} + +func recordToJSON(dest *recordJSON, src Node) error { + res := recordJSON{} + for _, kv := range src.args { + n := binaryNode(kv) + var nn nodeJSON + if err := nn.FromNode(n.Right()); err != nil { + return err + } + str, ok := n.Left().value.(types.String) + if !ok { + return fmt.Errorf("left not string") + } + res[string(str)] = nn + } + *dest = res + return nil +} + +func ifToJSON(dest **ifThenElseJSON, src Node) error { + n := trinaryNode(src) + res := &ifThenElseJSON{} + if err := res.If.FromNode(n.A()); err != nil { + return fmt.Errorf("error in if: %w", err) + } + if err := res.Then.FromNode(n.B()); err != nil { + return fmt.Errorf("error in then: %w", err) + } + if err := res.Else.FromNode(n.C()); err != nil { + return fmt.Errorf("error in else: %w", err) + } + *dest = res + return nil +} + +func isToJSON(dest **isJSON, src Node) error { + n := binaryNode(src) + res := &isJSON{} + if err := res.Left.FromNode(n.Left()); err != nil { + return fmt.Errorf("error in left: %w", err) + } + str, ok := n.Right().value.(types.String) + if !ok { + return fmt.Errorf("right not a string") + } + res.EntityType = string(str) + if len(src.args) == 3 { + ent, ok := src.args[2].value.(types.EntityUID) + if !ok { + return fmt.Errorf("in not an entity") + } + res.In = &inJSON{ + Entity: ent, + } + } + *dest = res + return nil +} + func (j *nodeJSON) FromNode(src Node) error { switch src.nodeType { // Value @@ -65,208 +189,81 @@ func (j *nodeJSON) FromNode(src Node) error { // Not *unaryJSON `json:"!"` // Negate *unaryJSON `json:"neg"` case nodeTypeNot: - n := unaryNode(src) - j.Not = &unaryJSON{} - j.Not.Arg.FromNode(n.Arg()) - return nil + return unaryToJSON(&j.Not, src) case nodeTypeNegate: - n := unaryNode(src) - j.Negate = &unaryJSON{} - j.Negate.Arg.FromNode(n.Arg()) - return nil + return unaryToJSON(&j.Not, src) // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny case nodeTypeAdd: - n := binaryNode(src) - j.Plus = &binaryJSON{} - j.Plus.Left.FromNode(n.Left()) - j.Plus.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.Plus, src) case nodeTypeAnd: - n := binaryNode(src) - j.And = &binaryJSON{} - j.And.Left.FromNode(n.Left()) - j.And.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.And, src) case nodeTypeContains: - n := binaryNode(src) - j.Contains = &binaryJSON{} - j.Contains.Left.FromNode(n.Left()) - j.Contains.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.Contains, src) case nodeTypeContainsAll: - n := binaryNode(src) - j.ContainsAll = &binaryJSON{} - j.ContainsAll.Left.FromNode(n.Left()) - j.ContainsAll.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.ContainsAll, src) case nodeTypeContainsAny: - n := binaryNode(src) - j.ContainsAny = &binaryJSON{} - j.ContainsAny.Left.FromNode(n.Left()) - j.ContainsAny.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.ContainsAny, src) case nodeTypeEquals: - n := binaryNode(src) - j.Equals = &binaryJSON{} - j.Equals.Left.FromNode(n.Left()) - j.Equals.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.Equals, src) case nodeTypeGreater: - n := binaryNode(src) - j.GreaterThan = &binaryJSON{} - j.GreaterThan.Left.FromNode(n.Left()) - j.GreaterThan.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.GreaterThan, src) case nodeTypeGreaterEqual: - n := binaryNode(src) - j.GreaterThanOrEqual = &binaryJSON{} - j.GreaterThanOrEqual.Left.FromNode(n.Left()) - j.GreaterThanOrEqual.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.GreaterThanOrEqual, src) case nodeTypeIn: - n := binaryNode(src) - j.In = &binaryJSON{} - j.In.Left.FromNode(n.Left()) - j.In.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.In, src) case nodeTypeLess: - n := binaryNode(src) - j.LessThan = &binaryJSON{} - j.LessThan.Left.FromNode(n.Left()) - j.LessThan.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.LessThan, src) case nodeTypeLessEqual: - n := binaryNode(src) - j.LessThanOrEqual = &binaryJSON{} - j.LessThanOrEqual.Left.FromNode(n.Left()) - j.LessThanOrEqual.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.LessThanOrEqual, src) case nodeTypeMult: - n := binaryNode(src) - j.Times = &binaryJSON{} - j.Times.Left.FromNode(n.Left()) - j.Times.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.Times, src) case nodeTypeNotEquals: - n := binaryNode(src) - j.NotEquals = &binaryJSON{} - j.NotEquals.Left.FromNode(n.Left()) - j.NotEquals.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.NotEquals, src) case nodeTypeOr: - n := binaryNode(src) - j.Or = &binaryJSON{} - j.Or.Left.FromNode(n.Left()) - j.Or.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.Or, src) case nodeTypeSub: - n := binaryNode(src) - j.Minus = &binaryJSON{} - j.Minus.Left.FromNode(n.Left()) // TODO: in all these cases, check for an error, handle it ... - j.Minus.Right.FromNode(n.Right()) - return nil + return binaryToJSON(&j.Minus, src) // ., has // Access *strJSON `json:"."` // Has *strJSON `json:"has"` case nodeTypeAccess: - n := binaryNode(src) - j.Access = &strJSON{} - j.Access.Left.FromNode(n.Left()) - j.Access.Attr = n.Right().value.String() // TODO: make this nicer - return nil + return strToJSON(&j.Access, src) case nodeTypeHas: - n := binaryNode(src) - j.Has = &strJSON{} - j.Has.Left.FromNode(n.Left()) - j.Has.Attr = n.Right().value.String() // TODO: make this nicer - return nil + return strToJSON(&j.Access, src) // is - case nodeTypeIs: - n := binaryNode(src) - j.Is = &isJSON{ - EntityType: string(n.Right().value.(types.String)), // TODO: make this nicer - } - j.Is.Left.FromNode(n.Left()) - return nil - case nodeTypeIsIn: - n := trinaryNode(src) - j.Is = &isJSON{ - EntityType: string(n.B().value.(types.String)), // TODO: make this nicer - In: &inJSON{}, - } - j.Is.Left.FromNode(n.A()) - j.Is.In.Entity = n.C().value.(types.EntityUID) - return nil + case nodeTypeIs, nodeTypeIsIn: + return isToJSON(&j.Is, src) // like // Like *strJSON `json:"like"` case nodeTypeLike: - n := binaryNode(src) - j.Like = &strJSON{} - j.Like.Left.FromNode(n.Left()) - j.Like.Attr = n.Right().value.String() // TODO: make this nicer + return strToJSON(&j.Access, src) // if-then-else // IfThenElse *ifThenElseJSON `json:"if-then-else"` case nodeTypeIf: - n := trinaryNode(src) - j.IfThenElse = &ifThenElseJSON{} - j.IfThenElse.If.FromNode(n.A()) - j.IfThenElse.Then.FromNode(n.B()) - j.IfThenElse.Else.FromNode(n.C()) + return ifToJSON(&j.IfThenElse, src) // Set // Set arrayJSON `json:"Set"` case nodeTypeSet: - j.Set = arrayJSON{} - for _, v := range src.args { - var nn nodeJSON - if err := nn.FromNode(v); err != nil { - return err - } - j.Set = append(j.Set, nn) - } - return nil + return arrayToJSON(&j.Set, src) // Record // Record recordJSON `json:"Record"` case nodeTypeRecord: - j.Record = recordJSON{} - for _, kv := range src.args { - n := binaryNode(kv) - // TODO: make this nicer - var nn nodeJSON - if err := nn.FromNode(n.Right()); err != nil { - return err - } - j.Record[n.Left().value.String()] = nn - } - return nil + return recordToJSON(&j.Record, src) // Any other function: decimal, ip // Decimal arrayJSON `json:"decimal"` // IP arrayJSON `json:"ip"` case nodeTypeDecimal: - j.Decimal = arrayJSON{} - str := src.value.String() // TODO: make this nicer - b := []byte(str) - j.Decimal = append(j.Decimal, nodeJSON{ - Value: (*json.RawMessage)(&b), - }, - ) - return nil + return extToJSON(&j.Decimal, src) case nodeTypeIpAddr: - j.IP = arrayJSON{} - str := src.value.String() // TODO: make this nicer - b := []byte(str) - j.IP = append(j.IP, nodeJSON{ - Value: (*json.RawMessage)(&b), - }, - ) - return nil + return extToJSON(&j.IP, src) // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange // LessThanExt arrayJSON `json:"lessThan"` @@ -279,103 +276,23 @@ func (j *nodeJSON) FromNode(src Node) error { // IsMulticastExt arrayJSON `json:"isMulticast"` // IsInRangeExt arrayJSON `json:"isInRange"` case nodeTypeLessExt: - n := binaryNode(src) - j.LessThanExt = arrayJSON{} - var left, right nodeJSON - if err := left.FromNode(n.Left()); err != nil { - return err - } - if err := right.FromNode(n.Right()); err != nil { - return err - } - j.LessThanExt = append(j.LessThanExt, left, right) - return nil + return arrayToJSON(&j.LessThanExt, src) case nodeTypeLessEqualExt: - n := binaryNode(src) - j.LessThanOrEqualExt = arrayJSON{} - var left, right nodeJSON - if err := left.FromNode(n.Left()); err != nil { - return err - } - if err := right.FromNode(n.Right()); err != nil { - return err - } - j.LessThanOrEqualExt = append(j.LessThanOrEqualExt, left, right) - return nil + return arrayToJSON(&j.LessThanOrEqualExt, src) case nodeTypeGreaterExt: - n := binaryNode(src) - j.GreaterThanExt = arrayJSON{} - var left, right nodeJSON - if err := left.FromNode(n.Left()); err != nil { - return err - } - if err := right.FromNode(n.Right()); err != nil { - return err - } - j.GreaterThanExt = append(j.GreaterThanExt, left, right) - return nil + return arrayToJSON(&j.GreaterThanExt, src) case nodeTypeGreaterEqualExt: - n := binaryNode(src) - j.GreaterThanOrEqualExt = arrayJSON{} - var left, right nodeJSON - if err := left.FromNode(n.Left()); err != nil { - return err - } - if err := right.FromNode(n.Right()); err != nil { - return err - } - j.GreaterThanOrEqualExt = append(j.GreaterThanOrEqualExt, left, right) - return nil + return arrayToJSON(&j.GreaterThanOrEqualExt, src) case nodeTypeIsInRange: - n := binaryNode(src) - j.IsInRangeExt = arrayJSON{} - var left, right nodeJSON - if err := left.FromNode(n.Left()); err != nil { - return err - } - if err := right.FromNode(n.Right()); err != nil { - return err - } - j.IsInRangeExt = append(j.IsInRangeExt, left, right) - return nil - + return arrayToJSON(&j.IsInRangeExt, src) case nodeTypeIsIpv4: - n := unaryNode(src) - j.IsIpv4Ext = arrayJSON{} - var arg nodeJSON - if err := arg.FromNode(n.Arg()); err != nil { - return err - } - j.IsIpv4Ext = append(j.IsIpv4Ext, arg) - return nil + return arrayToJSON(&j.IsIpv4Ext, src) case nodeTypeIsIpv6: - n := unaryNode(src) - j.IsIpv6Ext = arrayJSON{} - var arg nodeJSON - if err := arg.FromNode(n.Arg()); err != nil { - return err - } - j.IsIpv6Ext = append(j.IsIpv6Ext, arg) - return nil + return arrayToJSON(&j.IsIpv6Ext, src) case nodeTypeIsLoopback: - n := unaryNode(src) - j.IsLoopbackExt = arrayJSON{} - var arg nodeJSON - if err := arg.FromNode(n.Arg()); err != nil { - return err - } - j.IsLoopbackExt = append(j.IsLoopbackExt, arg) - return nil + return arrayToJSON(&j.IsLoopbackExt, src) case nodeTypeIsMulticast: - n := unaryNode(src) - j.IsMulticastExt = arrayJSON{} - var arg nodeJSON - if err := arg.FromNode(n.Arg()); err != nil { - return err - } - j.IsMulticastExt = append(j.IsMulticastExt, arg) - return nil - + return arrayToJSON(&j.IsMulticastExt, src) } // case nodeTypeRecordEntry: // case nodeTypeEntityType: From 876e6c261a64c59fb8bbcb8b73ee4dd35e519cd0 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 1 Aug 2024 10:50:17 -0600 Subject: [PATCH 037/216] x/exp/ast: improve adoption of All node type Addresses IDX-55 Signed-off-by: philhassey --- x/exp/ast/json_marshal.go | 2 +- x/exp/ast/json_unmarshal.go | 24 ++++++++++-------------- x/exp/ast/policy.go | 6 +++--- x/exp/ast/scope.go | 4 ++++ 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/x/exp/ast/json_marshal.go b/x/exp/ast/json_marshal.go index 8228f983..17874108 100644 --- a/x/exp/ast/json_marshal.go +++ b/x/exp/ast/json_marshal.go @@ -9,7 +9,7 @@ import ( func (s *scopeJSON) FromNode(src Node) error { switch src.nodeType { - case nodeTypeNone: + case nodeTypeAll: s.Op = "All" return nil case nodeTypeEquals: diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go index b0b81f25..0879f6cb 100644 --- a/x/exp/ast/json_unmarshal.go +++ b/x/exp/ast/json_unmarshal.go @@ -7,29 +7,25 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func (s *scopeJSON) ToNode(variable Node) (Node, error) { +func (s *scopeJSON) ToNode(variable scope) (Node, error) { switch s.Op { case "All": - return Node{}, nil + return variable.All(), nil case "==": if s.Entity == nil { return Node{}, fmt.Errorf("missing entity") } - return variable.Equals(Entity(*s.Entity)), nil + return variable.Eq(*s.Entity), nil case "in": if s.Entity != nil { - return variable.In(Entity(*s.Entity)), nil // TODO: review shape, maybe .In vs .InNode + return variable.In(*s.Entity), nil } - var set types.Set - for _, e := range s.Entities { - set = append(set, e) - } - return variable.In(Set(set)), nil // TODO: maybe there is an In and an InSet Node? + return variable.InSet(s.Entities), nil case "is": if s.In == nil { - return variable.Is(types.String(s.EntityType)), nil // TODO: hmmm, I'm not sure can this be Stronger-typed? + return variable.Is(types.String(s.EntityType)), nil } - return variable.IsIn(types.String(s.EntityType), Entity(s.In.Entity)), nil + return variable.IsIn(types.String(s.EntityType), s.In.Entity), nil } return Node{}, fmt.Errorf("unknown op: %v", s.Op) } @@ -309,15 +305,15 @@ func (p *Policy) UnmarshalJSON(b []byte) error { p.Annotate(types.String(k), types.String(v)) } var err error - p.principal, err = j.Principal.ToNode(Principal()) + p.principal, err = j.Principal.ToNode(scope(Principal())) if err != nil { return fmt.Errorf("error in principal: %w", err) } - p.action, err = j.Action.ToNode(Action()) + p.action, err = j.Action.ToNode(scope(Action())) if err != nil { return fmt.Errorf("error in action: %w", err) } - p.resource, err = j.Resource.ToNode(Resource()) + p.resource, err = j.Resource.ToNode(scope(Resource())) if err != nil { return fmt.Errorf("error in resource: %w", err) } diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index 34b408df..250a625c 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -13,9 +13,9 @@ func newPolicy(effect effect, annotations []Node) *Policy { return &Policy{ effect: effect, annotations: annotations, - principal: Node{nodeType: nodeTypeAll}, - action: Node{nodeType: nodeTypeAll}, - resource: Node{nodeType: nodeTypeAll}, + principal: scope(Principal()).All(), + action: scope(Action()).All(), + resource: scope(Resource()).All(), } } diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index e628c73f..a4fd5dbf 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -4,6 +4,10 @@ import "github.com/cedar-policy/cedar-go/types" type scope Node +func (s scope) All() Node { + return Node{nodeType: nodeTypeAll, args: []Node{Node(s)}} +} + func (s scope) Eq(entity types.EntityUID) Node { return Node(s).Equals(Entity(entity)) } From 9f164a7f2d77e8978815788f3928cc248b8a15de Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 1 Aug 2024 14:20:20 -0600 Subject: [PATCH 038/216] x/exp/ast: add happy path coverage for JSON marshal/unmarshal Addresses IDX-55 Signed-off-by: philhassey --- x/exp/ast/json.go | 7 +- x/exp/ast/json_marshal.go | 25 ++- x/exp/ast/json_test.go | 389 ++++++++++++++++++++++++++++++++++-- x/exp/ast/json_unmarshal.go | 7 + 4 files changed, 408 insertions(+), 20 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index 18f851e6..322ee448 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -46,6 +46,11 @@ type strJSON struct { Attr string `json:"attr"` } +type patternJSON struct { + Left nodeJSON `json:"left"` + Pattern string `json:"pattern"` +} + type isJSON struct { Left nodeJSON `json:"left"` EntityType string `json:"entity_type"` @@ -101,7 +106,7 @@ type nodeJSON struct { Is *isJSON `json:"is,omitempty"` // like - Like *strJSON `json:"like,omitempty"` + Like *patternJSON `json:"like,omitempty"` // if-then-else IfThenElse *ifThenElseJSON `json:"if-then-else,omitempty"` diff --git a/x/exp/ast/json_marshal.go b/x/exp/ast/json_marshal.go index 17874108..33d9a422 100644 --- a/x/exp/ast/json_marshal.go +++ b/x/exp/ast/json_marshal.go @@ -86,8 +86,8 @@ func extToJSON(dest *arrayJSON, src Node) error { if src.value == nil { return fmt.Errorf("missing value") } - str := src.value.String() // TODO: is this the correct string? - b := []byte(str) + str := src.value.String() // TODO: is this the correct string? + b, _ := json.Marshal(string(str)) // error impossible res = append(res, nodeJSON{ Value: (*json.RawMessage)(&b), }) @@ -110,6 +110,21 @@ func strToJSON(dest **strJSON, src Node) error { return nil } +func patternToJSON(dest **patternJSON, src Node) error { + n := binaryNode(src) + res := &patternJSON{} + if err := res.Left.FromNode(n.Left()); err != nil { + return fmt.Errorf("error in left: %w", err) + } + str, ok := n.Right().value.(types.String) + if !ok { + return fmt.Errorf("right not string") + } + res.Pattern = string(str) + *dest = res + return nil +} + func recordToJSON(dest *recordJSON, src Node) error { res := recordJSON{} for _, kv := range src.args { @@ -191,7 +206,7 @@ func (j *nodeJSON) FromNode(src Node) error { case nodeTypeNot: return unaryToJSON(&j.Not, src) case nodeTypeNegate: - return unaryToJSON(&j.Not, src) + return unaryToJSON(&j.Negate, src) // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny case nodeTypeAdd: @@ -231,7 +246,7 @@ func (j *nodeJSON) FromNode(src Node) error { case nodeTypeAccess: return strToJSON(&j.Access, src) case nodeTypeHas: - return strToJSON(&j.Access, src) + return strToJSON(&j.Has, src) // is case nodeTypeIs, nodeTypeIsIn: return isToJSON(&j.Is, src) @@ -239,7 +254,7 @@ func (j *nodeJSON) FromNode(src Node) error { // like // Like *strJSON `json:"like"` case nodeTypeLike: - return strToJSON(&j.Access, src) + return patternToJSON(&j.Like, src) // if-then-else // IfThenElse *ifThenElseJSON `json:"if-then-else"` diff --git a/x/exp/ast/json_test.go b/x/exp/ast/json_test.go index 710f5ef1..107bb9c8 100644 --- a/x/exp/ast/json_test.go +++ b/x/exp/ast/json_test.go @@ -2,8 +2,6 @@ package ast_test import ( "encoding/json" - "reflect" - "strings" "testing" "github.com/cedar-policy/cedar-go/testutil" @@ -17,7 +15,7 @@ func TestUnmarshalJSON(t *testing.T) { name string input string want *ast.Policy - wantErr bool + errFunc func(testing.TB, error) }{ /* @key("value") @@ -84,7 +82,354 @@ func TestUnmarshalJSON(t *testing.T) { When( ast.Context().Access("tls_version").Equals(ast.String("1.3")), ), - false, + testutil.OK, + }, + + { + "permit", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`, + ast.Permit(), + testutil.OK, + }, + { + "forbid", + `{"effect":"forbid","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`, + ast.Forbid(), + testutil.OK, + }, + { + "annotations", + `{"annotations":{"key":"value"},"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`, + ast.Permit().Annotate("key", "value"), + testutil.OK, + }, + { + "principalEq", + `{"effect":"permit","principal":{"op":"==","entity":{"type":"T","id":"42"}},"action":{"op":"All"},"resource":{"op":"All"}}`, + ast.Permit().PrincipalEq(types.NewEntityUID("T", "42")), + testutil.OK, + }, + { + "principalIn", + `{"effect":"permit","principal":{"op":"in","entity":{"type":"T","id":"42"}},"action":{"op":"All"},"resource":{"op":"All"}}`, + ast.Permit().PrincipalIn(types.NewEntityUID("T", "42")), + testutil.OK, + }, + { + "principalIs", + `{"effect":"permit","principal":{"op":"is","entity_type":"T"},"action":{"op":"All"},"resource":{"op":"All"}}`, + ast.Permit().PrincipalIs(types.String("T")), + testutil.OK, + }, + { + "principalIsIn", + `{"effect":"permit","principal":{"op":"is","entity_type":"T","in":{"entity":{"type":"P","id":"42"}}},"action":{"op":"All"},"resource":{"op":"All"}}`, + ast.Permit().PrincipalIsIn(types.String("T"), types.NewEntityUID("P", "42")), + testutil.OK, + }, + { + "actionEq", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"==","entity":{"type":"T","id":"42"}},"resource":{"op":"All"}}`, + ast.Permit().ActionEq(types.NewEntityUID("T", "42")), + testutil.OK, + }, + { + "actionIn", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"in","entity":{"type":"T","id":"42"}},"resource":{"op":"All"}}`, + ast.Permit().ActionIn(types.NewEntityUID("T", "42")), + testutil.OK, + }, + { + "actionInSet", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"in","entities":[{"type":"T","id":"42"},{"type":"T","id":"43"}]},"resource":{"op":"All"}}`, + ast.Permit().ActionInSet(types.NewEntityUID("T", "42"), types.NewEntityUID("T", "43")), + testutil.OK, + }, + { + "resourceEq", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"==","entity":{"type":"T","id":"42"}}}`, + ast.Permit().ResourceEq(types.NewEntityUID("T", "42")), + testutil.OK, + }, + { + "resourceIn", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"in","entity":{"type":"T","id":"42"}}}`, + ast.Permit().ResourceIn(types.NewEntityUID("T", "42")), + testutil.OK, + }, + { + "resourceIs", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is","entity_type":"T"}}`, + ast.Permit().ResourceIs(types.String("T")), + testutil.OK, + }, + { + "resourceIsIn", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is","entity_type":"T","in":{"entity":{"type":"P","id":"42"}}}}`, + ast.Permit().ResourceIsIn(types.String("T"), types.NewEntityUID("P", "42")), + testutil.OK, + }, + { + "when", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Value":true}}]}`, + ast.Permit().When(ast.True()), + testutil.OK, + }, + { + "unless", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"unless","body":{"Value":false}}]}`, + ast.Permit().Unless(ast.False()), + testutil.OK, + }, + { + "long", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Value":42}}]}`, + ast.Permit().When(ast.Long(42)), + testutil.OK, + }, + { + "string", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Value":"bananas"}}]}`, + ast.Permit().When(ast.String("bananas")), + testutil.OK, + }, + { + "entity", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Value":{"__entity":{"type":"T","id":"42"}}}}]}`, + ast.Permit().When(ast.Entity(types.NewEntityUID("T", "42"))), + testutil.OK, + }, + { + "set", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Set":[{"Value":42},{"Value":"bananas"}]}}]}`, + ast.Permit().When(ast.Set(types.Set{types.Long(42), types.String("bananas")})), + testutil.OK, + }, + { + "record", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Record":{"key":{"Value":42}}}}]}`, + ast.Permit().When(ast.Record(types.Record{"key": types.Long(42)})), + testutil.OK, + }, + { + "decimal", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"decimal":[{"Value":"42.24"}]}}]}`, + ast.Permit().When(ast.Decimal(mustParseDecimal("42.24"))), + testutil.OK, + }, + { + "ip", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"ip":[{"Value":"10.0.0.42/8"}]}}]}`, + ast.Permit().When(ast.IPAddr(mustParseIPAddr("10.0.0.42/8"))), + testutil.OK, + }, + { + "principalVar", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Var":"principal"}}]}`, + ast.Permit().When(ast.Principal()), + testutil.OK, + }, + { + "actionVar", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Var":"action"}}]}`, + ast.Permit().When(ast.Action()), + testutil.OK, + }, + { + "resourceVar", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Var":"resource"}}]}`, + ast.Permit().When(ast.Resource()), + testutil.OK, + }, + { + "contextVar", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Var":"context"}}]}`, + ast.Permit().When(ast.Context()), + testutil.OK, + }, + { + "not", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"!":{"arg":{"Value":true}}}}]}`, + ast.Permit().When(ast.Not(ast.True())), + testutil.OK, + }, + { + "negate", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"neg":{"arg":{"Value":42}}}}]}`, + ast.Permit().When(ast.Negate(ast.Long(42))), + testutil.OK, + }, + { + "equals", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"==":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).Equals(ast.Long(24))), + testutil.OK, + }, + { + "notEquals", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"!=":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).NotEquals(ast.Long(24))), + testutil.OK, + }, + { + "in", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"in":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).In(ast.Long(24))), + testutil.OK, + }, + { + "lessThan", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"<":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).LessThan(ast.Long(24))), + testutil.OK, + }, + { + "lessThanEquals", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"<=":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).LessThanOrEqual(ast.Long(24))), + testutil.OK, + }, + { + "greaterThan", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{">":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).GreaterThan(ast.Long(24))), + testutil.OK, + }, + { + "greaterThanEquals", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{">=":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).GreaterThanOrEqual(ast.Long(24))), + testutil.OK, + }, + { + "and", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"&&":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).And(ast.Long(24))), + testutil.OK, + }, + { + "or", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"||":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).Or(ast.Long(24))), + testutil.OK, + }, + { + "plus", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"+":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).Plus(ast.Long(24))), + testutil.OK, + }, + { + "minus", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"-":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).Minus(ast.Long(24))), + testutil.OK, + }, + { + "times", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"*":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).Times(ast.Long(24))), + testutil.OK, + }, + { + "contains", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"contains":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).Contains(ast.Long(24))), + testutil.OK, + }, + { + "containsAll", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"containsAll":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).ContainsAll(ast.Long(24))), + testutil.OK, + }, + { + "containsAny", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"containsAny":{"left":{"Value":42},"right":{"Value":24}}}}]}`, + ast.Permit().When(ast.Long(42).ContainsAny(ast.Long(24))), + testutil.OK, + }, + { + "access", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{".":{"left":{"Var":"context"},"attr":"key"}}}]}`, + ast.Permit().When(ast.Context().Access("key")), + testutil.OK, + }, + { + "has", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"has":{"left":{"Var":"context"},"attr":"key"}}}]}`, + ast.Permit().When(ast.Context().Has("key")), + testutil.OK, + }, + { + "is", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"is":{"left":{"Var":"resource"},"entity_type":"T"}}}]}`, + ast.Permit().When(ast.Resource().Is("T")), + testutil.OK, + }, + { + "isIn", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"is":{"left":{"Var":"resource"},"entity_type":"T","in":{"entity":{"type":"P","id":"42"}}}}}]}`, + ast.Permit().When(ast.Resource().IsIn("T", ast.Entity(types.NewEntityUID("P", "42")))), + testutil.OK, + }, + { + "like", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":"*"}}}]}`, + ast.Permit().When(ast.String("text").Like("*")), + testutil.OK, + }, + { + "ifThenElse", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"if-then-else":{"if":{"Value":true},"then":{"Value":42},"else":{"Value":24}}}}]}`, + ast.Permit().When(ast.If(ast.True(), ast.Long(42), ast.Long(24))), + testutil.OK, + }, + { + "isInRange", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"isInRange":[ + {"ip":[{"Value":"10.0.0.43"}]}, + {"ip":[{"Value":"10.0.0.42/8"}]} + ]}}]}`, + ast.Permit().When(ast.IPAddr(mustParseIPAddr("10.0.0.43")).IsInRange(ast.IPAddr(mustParseIPAddr("10.0.0.42/8")))), + testutil.OK, }, } @@ -93,19 +438,35 @@ func TestUnmarshalJSON(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() var p ast.Policy - fixedInput := strings.ReplaceAll(tt.input, "\t", " ") - err := json.Unmarshal([]byte(fixedInput), &p) - if (err != nil) != tt.wantErr { - t.Errorf("error got: %v want: %v", err, tt.wantErr) - } - if !reflect.DeepEqual(&p, tt.want) { - t.Errorf("policy mismatch: got: %+v want: %+v", p, *tt.want) + err := json.Unmarshal([]byte(tt.input), &p) + tt.errFunc(t, err) + if err != nil { + return } - - b, err := json.MarshalIndent(&p, "", " ") + testutil.Equals(t, p, *tt.want) + b, err := json.Marshal(&p) testutil.OK(t, err) - testutil.Equals(t, string(b), fixedInput) + normInput := testNormalizeJSON(t, tt.input) + normOutput := testNormalizeJSON(t, string(b)) + testutil.Equals(t, normOutput, normInput) }) } +} +func testNormalizeJSON(t testing.TB, in string) string { + var x any + err := json.Unmarshal([]byte(in), &x) + testutil.OK(t, err) + out, err := json.MarshalIndent(x, "", " ") + testutil.OK(t, err) + return string(out) +} + +func mustParseDecimal(v string) types.Decimal { + res, _ := types.ParseDecimal(v) + return res +} +func mustParseIPAddr(v string) types.IPAddr { + res, _ := types.ParseIPAddr(v) + return res } diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go index 0879f6cb..f390cd7c 100644 --- a/x/exp/ast/json_unmarshal.go +++ b/x/exp/ast/json_unmarshal.go @@ -55,6 +55,13 @@ func (j strJSON) ToNode(f func(a Node, k string) Node) (Node, error) { } return f(left, j.Attr), nil } +func (j patternJSON) ToNode(f func(a Node, k string) Node) (Node, error) { + left, err := j.Left.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in left: %w", err) + } + return f(left, j.Pattern), nil +} func (j isJSON) ToNode() (Node, error) { left, err := j.Left.ToNode() if err != nil { From 9461a08828b338022ee40dddc56a3201a527e5ba Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 1 Aug 2024 14:10:16 -0700 Subject: [PATCH 039/216] cedar-go/x/exp/ast: fold all extension methods into a single node type Signed-off-by: philhassey --- x/exp/ast/json.go | 12 ++--- x/exp/ast/json_marshal.go | 61 ++++++++++++++------------ x/exp/ast/json_unmarshal.go | 87 +++++++++++++++++++------------------ x/exp/ast/node.go | 47 ++++++++++++-------- x/exp/ast/operator.go | 18 ++++---- 5 files changed, 120 insertions(+), 105 deletions(-) diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index 322ee448..76786190 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -67,6 +67,8 @@ type arrayJSON []nodeJSON type recordJSON map[string]nodeJSON +type extMethodCallJSON map[string]arrayJSON + type nodeJSON struct { // Value Value *json.RawMessage `json:"Value,omitempty"` // could be any @@ -122,13 +124,5 @@ type nodeJSON struct { IP arrayJSON `json:"ip,omitempty"` // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange - LessThanExt arrayJSON `json:"lessThan,omitempty"` - LessThanOrEqualExt arrayJSON `json:"lessThanOrEqual,omitempty"` - GreaterThanExt arrayJSON `json:"greaterThan,omitempty"` - GreaterThanOrEqualExt arrayJSON `json:"greaterThanOrEqual,omitempty"` - IsIpv4Ext arrayJSON `json:"isIpv4,omitempty"` - IsIpv6Ext arrayJSON `json:"isIpv6,omitempty"` - IsLoopbackExt arrayJSON `json:"isLoopback,omitempty"` - IsMulticastExt arrayJSON `json:"isMulticast,omitempty"` - IsInRangeExt arrayJSON `json:"isInRange,omitempty"` + ExtensionMethod extMethodCallJSON `json:"-"` } diff --git a/x/exp/ast/json_marshal.go b/x/exp/ast/json_marshal.go index 33d9a422..87439bb6 100644 --- a/x/exp/ast/json_marshal.go +++ b/x/exp/ast/json_marshal.go @@ -95,6 +95,26 @@ func extToJSON(dest *arrayJSON, src Node) error { return nil } +func extMethodToJSON(dest extMethodCallJSON, src Node) error { + n := extMethodCallNode(src) + objectNode := &nodeJSON{} + err := objectNode.FromNode(n.Object()) + if err != nil { + return err + } + jsonArgs := arrayJSON{*objectNode} + for _, n := range n.Args() { + argNode := &nodeJSON{} + err := argNode.FromNode(n) + if err != nil { + return err + } + jsonArgs = append(jsonArgs, *argNode) + } + dest[n.Name()] = jsonArgs + return nil +} + func strToJSON(dest **strJSON, src Node) error { n := binaryNode(src) res := &strJSON{} @@ -281,33 +301,10 @@ func (j *nodeJSON) FromNode(src Node) error { return extToJSON(&j.IP, src) // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange - // LessThanExt arrayJSON `json:"lessThan"` - // LessThanOrEqualExt arrayJSON `json:"lessThanOrEqual"` - // GreaterThanExt arrayJSON `json:"greaterThan"` - // GreaterThanOrEqualExt arrayJSON `json:"greaterThanOrEqual"` - // IsIpv4Ext arrayJSON `json:"isIpv4"` - // IsIpv6Ext arrayJSON `json:"isIpv6"` - // IsLoopbackExt arrayJSON `json:"isLoopback"` - // IsMulticastExt arrayJSON `json:"isMulticast"` - // IsInRangeExt arrayJSON `json:"isInRange"` - case nodeTypeLessExt: - return arrayToJSON(&j.LessThanExt, src) - case nodeTypeLessEqualExt: - return arrayToJSON(&j.LessThanOrEqualExt, src) - case nodeTypeGreaterExt: - return arrayToJSON(&j.GreaterThanExt, src) - case nodeTypeGreaterEqualExt: - return arrayToJSON(&j.GreaterThanOrEqualExt, src) - case nodeTypeIsInRange: - return arrayToJSON(&j.IsInRangeExt, src) - case nodeTypeIsIpv4: - return arrayToJSON(&j.IsIpv4Ext, src) - case nodeTypeIsIpv6: - return arrayToJSON(&j.IsIpv6Ext, src) - case nodeTypeIsLoopback: - return arrayToJSON(&j.IsLoopbackExt, src) - case nodeTypeIsMulticast: - return arrayToJSON(&j.IsMulticastExt, src) + // ExtensionMethod map[string]arrayJSON `json:"-"` + case nodeTypeExtMethodCall: + j.ExtensionMethod = extMethodCallJSON{} + return extMethodToJSON(j.ExtensionMethod, src) } // case nodeTypeRecordEntry: // case nodeTypeEntityType: @@ -316,6 +313,16 @@ func (j *nodeJSON) FromNode(src Node) error { // case nodeTypeUnless: return fmt.Errorf("unknown node type: %v", src.nodeType) } + +func (j *nodeJSON) MarshalJSON() ([]byte, error) { + if len(j.ExtensionMethod) > 0 { + return json.Marshal(j.ExtensionMethod) + } + + type nodeJSONAlias nodeJSON + return json.Marshal((*nodeJSONAlias)(j)) +} + func (p *Policy) MarshalJSON() ([]byte, error) { var j policyJSON j.Effect = "forbid" diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go index f390cd7c..ce4ef2ed 100644 --- a/x/exp/ast/json_unmarshal.go +++ b/x/exp/ast/json_unmarshal.go @@ -1,8 +1,10 @@ package ast import ( + "bytes" "encoding/json" "fmt" + "strings" "github.com/cedar-policy/cedar-go/types" ) @@ -99,17 +101,6 @@ func (j arrayJSON) ToNode() (Node, error) { return SetNodes(nodes), nil } -func (j arrayJSON) ToExt1(f func(Node) Node) (Node, error) { - if len(j) != 1 { - return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) - } - arg, err := j[0].ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in extension: %w", err) - } - return f(arg), nil -} - func (j arrayJSON) ToDecimalNode() (Node, error) { if len(j) != 1 { return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) @@ -148,20 +139,6 @@ func (j arrayJSON) ToIPAddrNode() (Node, error) { return IPAddr(v), nil } -func (j arrayJSON) ToExt2(f func(Node, Node) Node) (Node, error) { - if len(j) != 2 { - return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) - } - left, err := j[0].ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in argument 0: %w", err) - } - right, err := j[1].ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in argument 1: %w", err) - } - return f(left, right), nil -} func (j recordJSON) ToNode() (Node, error) { nodes := map[types.String]Node{} for k, v := range j { @@ -174,6 +151,27 @@ func (j recordJSON) ToNode() (Node, error) { return RecordNodes(nodes), nil } +func (e extMethodCallJSON) ToNode() (Node, error) { + if len(e) != 1 { + return Node{}, fmt.Errorf("unexpected number of extension methods in node: %v", len(e)) + } + for k, v := range e { + if len(v) == 0 { + return Node{}, fmt.Errorf("extension method '%v' must have at least one argument", k) + } + var argNodes []Node + for _, n := range v { + node, err := n.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in extension method argument: %w", err) + } + argNodes = append(argNodes, node) + } + return newExtMethodCallNode(argNodes[0], k, argNodes[1:]...), nil + } + panic("unreachable code") +} + func (j nodeJSON) ToNode() (Node, error) { switch { // Value @@ -272,29 +270,32 @@ func (j nodeJSON) ToNode() (Node, error) { return j.IP.ToIPAddrNode() // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange - case j.LessThanExt != nil: - return j.LessThanExt.ToExt2(Node.LessThanExt) - case j.LessThanOrEqualExt != nil: - return j.LessThanOrEqualExt.ToExt2(Node.LessThanOrEqualExt) - case j.GreaterThanExt != nil: - return j.GreaterThanExt.ToExt2(Node.GreaterThanExt) - case j.GreaterThanOrEqualExt != nil: - return j.GreaterThanOrEqualExt.ToExt2(Node.GreaterThanOrEqualExt) - case j.IsIpv4Ext != nil: - return j.IsIpv4Ext.ToExt1(Node.IsIpv4) - case j.IsIpv6Ext != nil: - return j.IsIpv6Ext.ToExt1(Node.IsIpv6) - case j.IsLoopbackExt != nil: - return j.IsLoopbackExt.ToExt1(Node.IsLoopback) - case j.IsMulticastExt != nil: - return j.IsMulticastExt.ToExt1(Node.IsMulticast) - case j.IsInRangeExt != nil: - return j.IsInRangeExt.ToExt2(Node.IsInRange) + case j.ExtensionMethod != nil: + return j.ExtensionMethod.ToNode() } return Node{}, fmt.Errorf("unknown node") } +func (n *nodeJSON) UnmarshalJSON(b []byte) error { + decoder := json.NewDecoder(bytes.NewReader(b)) + decoder.DisallowUnknownFields() + + type nodeJSONAlias nodeJSON + if err := decoder.Decode((*nodeJSONAlias)(n)); err == nil { + return nil + } else if !strings.HasPrefix(err.Error(), "json: unknown field") { + return err + } + + // If an unknown field was parsed, the spec tells us to treat it as an extension method: + // > Any other key + // > This key is treated as the name of an extension function or method. The value must + // > be a JSON array of values, each of which is itself an JsonExpr object. Note that for + // > method calls, the method receiver is the first argument. + return json.Unmarshal(b, &n.ExtensionMethod) +} + func (p *Policy) UnmarshalJSON(b []byte) error { var j policyJSON if err := json.Unmarshal(b, &j); err != nil { diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 25ed44c3..6d1cb70f 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -5,8 +5,7 @@ import "github.com/cedar-policy/cedar-go/types" type nodeType uint8 const ( - nodeTypeNone nodeType = iota - nodeTypeAccess + nodeTypeAccess = iota nodeTypeAdd nodeTypeAll nodeTypeAnd @@ -15,44 +14,36 @@ const ( nodeTypeContains nodeTypeContainsAll nodeTypeContainsAny + nodeTypeDecimal nodeTypeEntity nodeTypeEntityType nodeTypeEquals nodeTypeGreater nodeTypeGreaterEqual - nodeTypeLike nodeTypeHas nodeTypeIf nodeTypeIn nodeTypeIpAddr - nodeTypeDecimal nodeTypeIs - nodeTypeIsInRange - nodeTypeIsIpv4 - nodeTypeIsIpv6 - nodeTypeIsLoopback - nodeTypeIsMulticast + nodeTypeIsIn nodeTypeLess nodeTypeLessEqual + nodeTypeLike nodeTypeLong + nodeTypeExtMethodCall nodeTypeMult - nodeTypeNot nodeTypeNegate + nodeTypeNot nodeTypeNotEquals nodeTypeOr nodeTypeRecord nodeTypeRecordEntry nodeTypeSet - nodeTypeSub nodeTypeString + nodeTypeSub + nodeTypeUnless nodeTypeVariable - nodeTypeLessExt - nodeTypeLessEqualExt - nodeTypeGreaterExt - nodeTypeGreaterEqualExt nodeTypeWhen - nodeTypeUnless - nodeTypeIsIn ) type Node struct { @@ -87,3 +78,25 @@ type trinaryNode Node func (n trinaryNode) A() Node { return n.args[0] } func (n trinaryNode) B() Node { return n.args[1] } func (n trinaryNode) C() Node { return n.args[2] } + +func newExtMethodCallNode(object Node, methodName string, args ...Node) Node { + nodes := []Node{object, String(types.String(methodName))} + return Node{ + nodeType: nodeTypeExtMethodCall, + args: append(nodes, args...), + } +} + +type extMethodCallNode Node + +func (n extMethodCallNode) Object() Node { + return n.args[0] +} + +func (n extMethodCallNode) Name() string { + return string(n.args[1].value.(types.String)) +} + +func (n extMethodCallNode) Args() []Node { + return n.args[2:] +} diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index d3f96990..63bea91a 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -34,19 +34,19 @@ func (lhs Node) GreaterThanOrEqual(rhs Node) Node { } func (lhs Node) LessThanExt(rhs Node) Node { - return newBinaryNode(nodeTypeLessExt, lhs, rhs) + return newExtMethodCallNode(lhs, "lessThan", rhs) } func (lhs Node) LessThanOrEqualExt(rhs Node) Node { - return newBinaryNode(nodeTypeLessEqualExt, lhs, rhs) + return newExtMethodCallNode(lhs, "lessThanOrEqual", rhs) } func (lhs Node) GreaterThanExt(rhs Node) Node { - return newBinaryNode(nodeTypeGreaterExt, lhs, rhs) + return newExtMethodCallNode(lhs, "greaterThan", rhs) } func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { - return newBinaryNode(nodeTypeGreaterEqualExt, lhs, rhs) + return newExtMethodCallNode(lhs, "greaterThanOrEqual", rhs) } func (lhs Node) Like(patt string) Node { @@ -161,21 +161,21 @@ func (lhs Node) Has(attr string) Node { // |___|_| /_/ \_\__,_|\__,_|_| \___||___/___/ func (lhs Node) IsIpv4() Node { - return newUnaryNode(nodeTypeIsIpv4, lhs) + return newExtMethodCallNode(lhs, "isIpv4") } func (lhs Node) IsIpv6() Node { - return newUnaryNode(nodeTypeIsIpv6, lhs) + return newExtMethodCallNode(lhs, "isIpv6") } func (lhs Node) IsMulticast() Node { - return newUnaryNode(nodeTypeIsMulticast, lhs) + return newExtMethodCallNode(lhs, "isMulticast") } func (lhs Node) IsLoopback() Node { - return newUnaryNode(nodeTypeIsLoopback, lhs) + return newExtMethodCallNode(lhs, "isLoopback") } func (lhs Node) IsInRange(rhs Node) Node { - return newBinaryNode(nodeTypeIsInRange, lhs, rhs) + return newExtMethodCallNode(lhs, "isInRange", rhs) } From 4a45c0eea6a1044436b8d50c1a3e9c4e4868aa8b Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 1 Aug 2024 17:45:14 -0700 Subject: [PATCH 040/216] cedar-go/x/exp/ast: mega-patch containing most of a Cedar text to AST parser Missing functionality: * Parsing support for patterns. * Parsing sugar for negative numbers. Right now, they're expressed as Negate(Long(x)) rather than Long(-x). * Parsing error tests * Parsing fuzz tests There are tests for the first two cases that are commented out. Signed-off-by: philhassey --- x/exp/ast/ast_test.go | 4 +- x/exp/ast/json_unmarshal.go | 2 +- x/exp/ast/operator.go | 8 +- x/exp/ast/parser.go | 497 +++++++++++++++++++++++++++++++++++- x/exp/ast/parser_test.go | 165 +++++++++++- x/exp/ast/tokenize.go | 4 + x/exp/ast/value.go | 16 +- 7 files changed, 675 insertions(+), 21 deletions(-) diff --git a/x/exp/ast/ast_test.go b/x/exp/ast/ast_test.go index d07d5745..cac3b458 100644 --- a/x/exp/ast/ast_test.go +++ b/x/exp/ast/ast_test.go @@ -67,10 +67,10 @@ func TestAst(t *testing.T) { }).Access("x").Equals(ast.Long(3)), ). When( - ast.SetNodes([]ast.Node{ + ast.SetNodes( ast.Long(1), ast.Long(2).Plus(ast.Long(3)), ast.Context().Access("fooCount"), - }).Contains(ast.Long(1)), + ).Contains(ast.Long(1)), ) } diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go index ce4ef2ed..683d12e6 100644 --- a/x/exp/ast/json_unmarshal.go +++ b/x/exp/ast/json_unmarshal.go @@ -98,7 +98,7 @@ func (j arrayJSON) ToNode() (Node, error) { } nodes = append(nodes, n) } - return SetNodes(nodes), nil + return SetNodes(nodes...), nil } func (j arrayJSON) ToDecimalNode() (Node, error) { diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index 63bea91a..8f31f562 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -72,10 +72,6 @@ func Not(rhs Node) Node { return newUnaryNode(nodeTypeNot, rhs) } -func Negate(rhs Node) Node { - return newUnaryNode(nodeTypeNegate, rhs) -} - func If(condition Node, ifTrue Node, ifFalse Node) Node { return newTrinaryNode(nodeTypeIf, condition, ifTrue, ifFalse) } @@ -98,6 +94,10 @@ func (lhs Node) Times(rhs Node) Node { return newBinaryNode(nodeTypeMult, lhs, rhs) } +func Negate(rhs Node) Node { + return newUnaryNode(nodeTypeNegate, rhs) +} + // _ _ _ _ // | | | (_) ___ _ __ __ _ _ __ ___| |__ _ _ // | |_| | |/ _ \ '__/ _` | '__/ __| '_ \| | | | diff --git a/x/exp/ast/parser.go b/x/exp/ast/parser.go index 56fdd76f..59719007 100644 --- a/x/exp/ast/parser.go +++ b/x/exp/ast/parser.go @@ -2,6 +2,8 @@ package ast import ( "fmt" + "net/netip" + "strconv" "github.com/cedar-policy/cedar-go/types" ) @@ -38,9 +40,9 @@ func policyFromCedar(p *parser) (*Policy, error) { if err = p.exact(")"); err != nil { return nil, err } - // if res.Conditions, err = p.conditions(); err != nil { - // return res, err - // } + if err = p.conditions(policy); err != nil { + return nil, err + } if err = p.exact(";"); err != nil { return nil, err } @@ -182,12 +184,17 @@ func (p *parser) principal(policy *Policy) error { func (p *parser) entity() (types.EntityUID, error) { var res types.EntityUID - var err error t := p.advance() if !t.isIdent() { return res, p.errorf("expected ident") } - res.Type = t.Text + return p.entityFirstPathPreread(t.Text) +} + +func (p *parser) entityFirstPathPreread(firstPath string) (types.EntityUID, error) { + var res types.EntityUID + var err error + res.Type = firstPath for { if err := p.exact("::"); err != nil { return res, err @@ -327,3 +334,483 @@ func (p *parser) resource(policy *Policy) error { return nil } + +func (p *parser) conditions(policy *Policy) error { + for { + switch p.peek().Text { + case "when": + p.advance() + expr, err := p.condition() + if err != nil { + return err + } + policy.When(expr) + case "unless": + p.advance() + expr, err := p.condition() + if err != nil { + return err + } + policy.Unless(expr) + default: + return nil + } + } +} + +func (p *parser) condition() (Node, error) { + var res Node + var err error + if err := p.exact("{"); err != nil { + return res, err + } + if res, err = p.expression(); err != nil { + return res, err + } + if err := p.exact("}"); err != nil { + return res, err + } + return res, nil +} + +func (p *parser) expression() (Node, error) { + t := p.peek() + if t.Text == "if" { + p.advance() + + condition, err := p.expression() + if err != nil { + return Node{}, err + } + + if err = p.exact("then"); err != nil { + return Node{}, err + } + ifTrue, err := p.expression() + if err != nil { + return Node{}, err + } + + if err = p.exact("else"); err != nil { + return Node{}, err + } + ifFalse, err := p.expression() + if err != nil { + return Node{}, err + } + + return If(condition, ifTrue, ifFalse), nil + } + + return p.or() +} + +func (p *parser) or() (Node, error) { + lhs, err := p.and() + if err != nil { + return Node{}, err + } + + t := p.peek() + if t.Text != "||" { + return lhs, nil + } + + p.advance() + rhs, err := p.and() + if err != nil { + return Node{}, err + } + return lhs.Or(rhs), nil +} + +func (p *parser) and() (Node, error) { + lhs, err := p.relation() + if err != nil { + return Node{}, err + } + + t := p.peek() + if t.Text != "&&" { + return lhs, nil + } + + p.advance() + rhs, err := p.relation() + if err != nil { + return Node{}, err + } + return lhs.And(rhs), nil +} + +func (p *parser) relation() (Node, error) { + lhs, err := p.add() + if err != nil { + return Node{}, err + } + + t := p.peek() + operators := map[string]func(Node) Node{ + "<": lhs.LessThan, + "<=": lhs.LessThanOrEqual, + ">": lhs.GreaterThan, + ">=": lhs.GreaterThanOrEqual, + "!=": lhs.NotEquals, + "==": lhs.Equals, + "in": lhs.In, + } + if f, ok := operators[t.Text]; ok { + p.advance() + rhs, err := p.add() + if err != nil { + return Node{}, err + } + return f(rhs), nil + } + + if t.Text == "has" { + p.advance() + t = p.advance() + if t.isIdent() { + return lhs.Has(t.Text), nil + } else if t.isString() { + str, err := t.stringValue() + if err != nil { + return Node{}, err + } + return lhs.Has(str), nil + } + return Node{}, p.errorf("expected ident or string") + } else if t.Text == "like" { + // TODO: Deal with pattern matching + return Node{}, p.errorf("unimplemented") + } else if t.Text == "is" { + p.advance() + entityType, err := p.path() + if err != nil { + return Node{}, err + } + if p.peek().Text == "in" { + p.advance() + inEntity, err := p.add() + if err != nil { + return Node{}, err + } + return lhs.IsIn(entityType, inEntity), nil + } + return lhs.Is(entityType), nil + } + + return lhs, err +} + +func (p *parser) add() (Node, error) { + lhs, err := p.mult() + if err != nil { + return Node{}, err + } + + t := p.peek().Text + operators := map[string]func(Node) Node{ + "+": lhs.Plus, + "-": lhs.Minus, + } + if f, ok := operators[t]; ok { + p.advance() + rhs, err := p.mult() + if err != nil { + return Node{}, err + } + return f(rhs), nil + } + + return lhs, nil +} + +func (p *parser) mult() (Node, error) { + lhs, err := p.unary() + if err != nil { + return Node{}, err + } + + if p.peek().Text != "*" { + return lhs, nil + } + + p.advance() + rhs, err := p.unary() + if err != nil { + return Node{}, err + } + return lhs.Times(rhs), nil +} + +func (p *parser) unary() (Node, error) { + var res Node + var ops [](func(Node) Node) + for { + op := p.peek().Text + switch op { + case "!": + p.advance() + ops = append(ops, Not) + case "-": + p.advance() + ops = append(ops, Negate) + default: + var err error + res, err = p.member() + if err != nil { + return res, err + } + + // TODO: add support for parsing -1 into a negative Long rather than a Negate(Long) + for i := len(ops) - 1; i >= 0; i-- { + res = ops[i](res) + } + return res, nil + } + } +} + +func (p *parser) member() (Node, error) { + res, err := p.primary() + if err != nil { + return res, err + } + for { + var ok bool + res, ok, err = p.access(res) + if !ok { + return res, err + } + } +} + +func (p *parser) primary() (Node, error) { + var res Node + t := p.advance() + switch { + case t.isInt(): + i, err := t.intValue() + if err != nil { + return res, err + } + res = Long(types.Long(i)) + case t.isString(): + str, err := t.stringValue() + if err != nil { + return res, err + } + res = String(types.String(str)) + case t.Text == "true": + res = True() + case t.Text == "false": + res = False() + case t.Text == "principal": + res = Principal() + case t.Text == "action": + res = Action() + case t.Text == "resource": + res = Resource() + case t.Text == "context": + res = Context() + case t.isIdent(): + return p.entityOrExtFun(t.Text) + case t.Text == "(": + expr, err := p.expression() + if err != nil { + return res, err + } + if err := p.exact(")"); err != nil { + return res, err + } + res = expr + case t.Text == "[": + set, err := p.expressions("]") + if err != nil { + return res, err + } + p.advance() // expressions guarantees "]" + res = SetNodes(set...) + case t.Text == "{": + record, err := p.record() + if err != nil { + return res, err + } + res = record + default: + return res, p.errorf("invalid primary") + } + return res, nil +} + +func (p *parser) entityOrExtFun(ident string) (Node, error) { + // Technically, according to the grammar, both entities and extension functions + // can have path prefixes and so parsing here is not trivial. In practice, there + // are only two extension functions: `ip()` and `decimal()`, neither of which + // have a path prefix. We'll just handle those two cases specially and treat + // everything else as an entity. + var res Node + switch ident { + case "ip", "decimal": + if err := p.exact("("); err != nil { + return res, err + } + t := p.advance() + if !t.isString() { + return res, p.errorf("expected string") + } + str, err := t.stringValue() + if err != nil { + return res, err + } + if err := p.exact(")"); err != nil { + return res, err + } + + if ident == "ip" { + ipaddr, err := netip.ParsePrefix(str) + if err != nil { + return res, err + } + res = IPAddr(types.IPAddr(ipaddr)) + } else { + dec, err := strconv.ParseFloat(str, 64) + if err != nil { + return res, err + } + res = Decimal(types.Decimal(dec)) + } + default: + entity, err := p.entityFirstPathPreread(ident) + if err != nil { + return res, err + } + res = Entity(entity) + } + + return res, nil +} + +func (p *parser) expressions(endOfListMarker string) ([]Node, error) { + var res []Node + for p.peek().Text != endOfListMarker { + if len(res) > 0 { + if err := p.exact(","); err != nil { + return res, err + } + } + e, err := p.expression() + if err != nil { + return res, err + } + res = append(res, e) + } + return res, nil +} + +func (p *parser) record() (Node, error) { + var res Node + entries := map[types.String]Node{} + for { + t := p.peek() + if t.Text == "}" { + p.advance() + return RecordNodes(entries), nil + } + if len(entries) > 0 { + if err := p.exact(","); err != nil { + return res, err + } + } + k, v, err := p.recordEntry() + if err != nil { + return res, err + } + + if _, ok := entries[k]; ok { + return res, p.errorf("duplicate key: %v", k) + } + entries[k] = v + } +} + +func (p *parser) recordEntry() (types.String, Node, error) { + var key types.String + var value Node + var err error + t := p.advance() + switch { + case t.isIdent(): + key = types.String(t.Text) + case t.isString(): + str, err := t.stringValue() + if err != nil { + return key, value, err + } + key = types.String(str) + default: + return key, value, p.errorf("unexpected token") + } + if err := p.exact(":"); err != nil { + return key, value, err + } + value, err = p.expression() + if err != nil { + return key, value, err + } + return key, value, nil +} + +func (p *parser) access(lhs Node) (Node, bool, error) { + t := p.peek() + switch t.Text { + case ".": + p.advance() + t := p.advance() + if !t.isIdent() { + return Node{}, false, p.errorf("unexpected token") + } + if p.peek().Text == "(" { + methodName := t.Text + p.advance() + exprs, err := p.expressions(")") + if err != nil { + return Node{}, false, err + } + p.advance() // expressions guarantees ")" + + knownMethods := map[string]func(Node) Node{ + "contains": lhs.Contains, + "containsAll": lhs.ContainsAll, + "containsAny": lhs.ContainsAny, + } + if f, ok := knownMethods[methodName]; ok { + if len(exprs) != 1 { + return Node{}, false, p.errorf("%v expects one argument", methodName) + } + return f(exprs[0]), true, nil + } + return newExtMethodCallNode(lhs, methodName, exprs...), true, nil + } else { + return lhs.Access(t.Text), true, nil + } + case "[": + p.advance() + t := p.advance() + if !t.isString() { + return Node{}, false, p.errorf("unexpected token") + } + name, err := t.stringValue() + if err != nil { + return Node{}, false, err + } + if err := p.exact("]"); err != nil { + return Node{}, false, err + } + return lhs.Access(name), true, nil + default: + return lhs, false, nil + } +} diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/parser_test.go index 8df84e3c..11a92293 100644 --- a/x/exp/ast/parser_test.go +++ b/x/exp/ast/parser_test.go @@ -127,10 +127,173 @@ func TestParse(t *testing.T) { );`, Permit().ActionInSet(farming, forestry), }, + { + "trivial conditions", + `permit (principal, action, resource) + when { true } + unless { false };`, + Permit().When(Boolean(true)).Unless(Boolean(false)), + }, + { + "not operator", + `permit (principal, action, resource) + when { !true };`, + Permit().When(Not(Boolean(true))), + }, + //{ + // "negate operator", + // `permit (principal, action, resource) + // when { -1 };`, + // Permit().When(Long(-1)), + //}, + { + "variable member", + `permit (principal, action, resource) + when { context.boolValue };`, + Permit().When(Context().Access("boolValue")), + }, + { + "contains method call", + `permit (principal, action, resource) + when { context.strings.contains("foo") };`, + Permit().When(Context().Access("strings").Contains(String("foo"))), + }, + { + "containsAll method call", + `permit (principal, action, resource) + when { context.strings.containsAll(["foo"]) };`, + Permit().When(Context().Access("strings").ContainsAll(SetNodes(String("foo")))), + }, + { + "containsAny method call", + `permit (principal, action, resource) + when { context.strings.containsAny(["foo"]) };`, + Permit().When(Context().Access("strings").ContainsAny(SetNodes(String("foo")))), + }, + { + "extension method call", + `permit (principal, action, resource) + when { context.sourceIP.isIpv4() };`, + Permit().When(Context().Access("sourceIP").IsIpv4()), + }, + { + "multiplication", + `permit (principal, action, resource) + when { 42 * 2 };`, + Permit().When(Long(42).Times(Long(2))), + }, + { + "addition", + `permit (principal, action, resource) + when { 42 + 2 };`, + Permit().When(Long(42).Plus(Long(2))), + }, + { + "subtraction", + `permit (principal, action, resource) + when { 42 - 2 };`, + Permit().When(Long(42).Minus(Long(2))), + }, + { + "less than", + `permit (principal, action, resource) + when { 2 < 42 };`, + Permit().When(Long(2).LessThan(Long(42))), + }, + { + "less than or equal", + `permit (principal, action, resource) + when { 2 <= 42 };`, + Permit().When(Long(2).LessThanOrEqual(Long(42))), + }, + { + "greater than", + `permit (principal, action, resource) + when { 2 > 42 };`, + Permit().When(Long(2).GreaterThan(Long(42))), + }, + { + "greater than or equal", + `permit (principal, action, resource) + when { 2 >= 42 };`, + Permit().When(Long(2).GreaterThanOrEqual(Long(42))), + }, + { + "equal", + `permit (principal, action, resource) + when { 2 == 42 };`, + Permit().When(Long(2).Equals(Long(42))), + }, + { + "not equal", + `permit (principal, action, resource) + when { 2 != 42 };`, + Permit().When(Long(2).NotEquals(Long(42))), + }, + { + "in", + `permit (principal, action, resource) + when { principal in Group::"folkHeroes" };`, + Permit().When(Principal().In(Entity(folkHeroes))), + }, + { + "has ident", + `permit (principal, action, resource) + when { principal has firstName };`, + Permit().When(Principal().Has("firstName")), + }, + { + "has string", + `permit (principal, action, resource) + when { principal has "firstName" };`, + Permit().When(Principal().Has("firstName")), + }, + //{ + // "like no wildcards", + // `permit (principal, action, resource) + // when { principal.firstName like "johnny" };`, + // Permit().When(Principal().Has("firstName")), + //}, + { + "is", + `permit (principal, action, resource) + when { principal is User };`, + Permit().When(Principal().Is("User")), + }, + { + "is in", + `permit (principal, action, resource) + when { principal is User in Group::"folkHeroes" };`, + Permit().When(Principal().IsIn("User", Entity(folkHeroes))), + }, + { + "is in", + `permit (principal, action, resource) + when { principal is User in Group::"folkHeroes" };`, + Permit().When(Principal().IsIn("User", Entity(folkHeroes))), + }, + { + "and", + `permit (principal, action, resource) + when { true && false };`, + Permit().When(True().And(False())), + }, + { + "or", + `permit (principal, action, resource) + when { true || false };`, + Permit().When(True().Or(False())), + }, + { + "if then else", + `permit (principal, action, resource) + when { if true then true else false };`, + Permit().When(If(True(), True(), False())), + }, } for _, tt := range parseTests { - t.Run(tt.Text, func(t *testing.T) { + t.Run(tt.Name, func(t *testing.T) { t.Parallel() tokens, err := Tokenize([]byte(tt.Text)) diff --git a/x/exp/ast/tokenize.go b/x/exp/ast/tokenize.go index 5ecea604..f4c07355 100644 --- a/x/exp/ast/tokenize.go +++ b/x/exp/ast/tokenize.go @@ -36,6 +36,10 @@ func (t Token) isIdent() bool { return t.Type == TokenIdent } +func (t Token) isInt() bool { + return t.Type == TokenInt +} + func (t Token) isString() bool { return t.Type == TokenString } diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index 9adca6cb..bf3050f5 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -33,7 +33,7 @@ func Set(s types.Set) Node { for _, v := range s { nodes = append(nodes, valueToNode(v)) } - return SetNodes(nodes) + return SetNodes(nodes...) } // SetNodes allows for a complex set definition with values potentially @@ -43,12 +43,12 @@ func Set(s types.Set) Node { // // could be expressed in Golang as: // -// ast.SetNodes([]ast.Node{ +// ast.SetNodes( // ast.Long(1), // ast.Long(2).Plus(ast.Long(3)), // ast.Context().Access("fooCount"), -// }) -func SetNodes(nodes []Node) Node { +// ) +func SetNodes(nodes ...Node) Node { return Node{nodeType: nodeTypeSet, args: nodes} } @@ -59,7 +59,7 @@ func Record(r types.Record) Node { for k, v := range r { recordNodes[types.String(k)] = valueToNode(v) } - return RecordNodes(recordNodes) // TODO: maybe inline this to avoid the double conversion + return RecordNodes(recordNodes) } // RecordNodes allows for a complex record definition with values potentially @@ -69,9 +69,9 @@ func Record(r types.Record) Node { // // could be expressed in Golang as: // -// ast.RecordNodes([]ast.RecordNode{ -// {Key: "x", Value: ast.Long(1).Plus(ast.Context().Access("resourceField"))}, -// }) +// ast.RecordNodes(map[types.String]Node{ +// "x": ast.Long(1).Plus(ast.Context().Access("fooCount"))}, +// }) func RecordNodes(entries map[types.String]Node) Node { var nodes []Node for k, v := range entries { From 38b07de00633fa38eba4fbdebb62576c73da82bf Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 1 Aug 2024 19:00:29 -0600 Subject: [PATCH 041/216] x/exp/ast: use strongly typed AST Addresses IDX-55 Signed-off-by: philhassey --- x/exp/ast/annotation.go | 22 +--- x/exp/ast/ast_test.go | 6 - x/exp/ast/json.go | 10 +- x/exp/ast/json_marshal.go | 247 ++++++++++++++++-------------------- x/exp/ast/json_test.go | 2 +- x/exp/ast/json_unmarshal.go | 29 +++-- x/exp/ast/node.go | 185 ++++++++++++++++----------- x/exp/ast/operator.go | 61 ++++----- x/exp/ast/parser.go | 2 +- x/exp/ast/policy.go | 22 ++-- x/exp/ast/scope.go | 131 +++++++++---------- x/exp/ast/value.go | 35 +++-- x/exp/ast/variable.go | 24 +++- 13 files changed, 383 insertions(+), 393 deletions(-) diff --git a/x/exp/ast/annotation.go b/x/exp/ast/annotation.go index d2a3d59e..c79bf94a 100644 --- a/x/exp/ast/annotation.go +++ b/x/exp/ast/annotation.go @@ -3,7 +3,7 @@ package ast import "github.com/cedar-policy/cedar-go/types" type Annotations struct { - nodes []Node + nodes []nodeTypeAnnotation } // Annotation allows AST constructors to make policy in a similar shape to textual Cedar with @@ -14,24 +14,14 @@ type Annotations struct { // Permit(). // PrincipalEq(superUser) func Annotation(name, value types.String) *Annotations { - return &Annotations{nodes: []Node{newAnnotationNode(name, value)}} + return &Annotations{nodes: []nodeTypeAnnotation{newAnnotation(name, value)}} } func (a *Annotations) Annotation(name, value types.String) *Annotations { - a.nodes = append(a.nodes, newAnnotationNode(name, value)) + a.nodes = append(a.nodes, newAnnotation(name, value)) return a } -type annotationNode Node - -func (n annotationNode) Key() types.String { - return n.args[0].value.(types.String) -} - -func (n annotationNode) Value() types.String { - return n.args[1].value.(types.String) -} - func (a *Annotations) Permit() *Policy { return newPolicy(effectPermit, a.nodes) } @@ -41,10 +31,10 @@ func (a *Annotations) Forbid() *Policy { } func (p *Policy) Annotate(name, value types.String) *Policy { - p.annotations = append(p.annotations, newAnnotationNode(name, value)) + p.annotations = append(p.annotations, nodeTypeAnnotation{Key: name, Value: value}) return p } -func newAnnotationNode(name, value types.String) Node { - return Node{nodeType: nodeTypeAnnotation, args: []Node{String(name), String(value)}} +func newAnnotation(name, value types.String) nodeTypeAnnotation { + return nodeTypeAnnotation{Key: name, Value: value} } diff --git a/x/exp/ast/ast_test.go b/x/exp/ast/ast_test.go index cac3b458..b75b1a5d 100644 --- a/x/exp/ast/ast_test.go +++ b/x/exp/ast/ast_test.go @@ -45,7 +45,6 @@ func TestAst(t *testing.T) { ) // forbid (principal, action, resource) - // when { resource[context.resourceField] == "specialValue" } // when { {x: "value"}.x == "value" } // when { {x: 1 + context.fooCount}.x == 3 } // when { [1, 2 + 3, context.fooCount].contains(1) }; @@ -53,11 +52,6 @@ func TestAst(t *testing.T) { "x": types.String("value"), } _ = ast.Forbid(). - When( - ast.Resource().AccessNode( - ast.Context().Access("resourceField"), - ).Equals(ast.String("specialValue")), - ). When( ast.Record(simpleRecord).Access("x").Equals(ast.String("value")), ). diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index 76786190..0000474e 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -15,7 +15,7 @@ type policyJSON struct { Conditions []conditionJSON `json:"conditions,omitempty"` } -type inJSON struct { +type scopeInJSON struct { Entity types.EntityUID `json:"entity"` } @@ -24,7 +24,7 @@ type scopeJSON struct { Entity *types.EntityUID `json:"entity,omitempty"` Entities []types.EntityUID `json:"entities,omitempty"` EntityType string `json:"entity_type,omitempty"` - In *inJSON `json:"in,omitempty"` + In *scopeInJSON `json:"in,omitempty"` } type conditionJSON struct { @@ -52,9 +52,9 @@ type patternJSON struct { } type isJSON struct { - Left nodeJSON `json:"left"` - EntityType string `json:"entity_type"` - In *inJSON `json:"in,omitempty"` + Left nodeJSON `json:"left"` + EntityType string `json:"entity_type"` + In *nodeJSON `json:"in,omitempty"` } type ifThenElseJSON struct { diff --git a/x/exp/ast/json_marshal.go b/x/exp/ast/json_marshal.go index 87439bb6..86fd01a5 100644 --- a/x/exp/ast/json_marshal.go +++ b/x/exp/ast/json_marshal.go @@ -7,70 +7,66 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func (s *scopeJSON) FromNode(src Node) error { - switch src.nodeType { - case nodeTypeAll: +func (s *scopeJSON) FromNode(src isScopeNode) error { + switch t := src.(type) { + case scopeTypeAll: s.Op = "All" return nil - case nodeTypeEquals: - n := scopeEqNode(src) + case scopeTypeEq: s.Op = "==" - e := n.Entity() + e := t.Entity s.Entity = &e return nil - case nodeTypeIn: - n := scopeInNode(src) + case scopeTypeIn: s.Op = "in" - if n.IsSet() { - s.Entities = n.Set() - } else { - e := n.Entity() - s.Entity = &e - } + e := t.Entity + s.Entity = &e return nil - case nodeTypeIs: - n := scopeIsNode(src) + case scopeTypeInSet: + s.Op = "in" + s.Entities = t.Entities + return nil + case scopeTypeIs: s.Op = "is" - s.EntityType = string(n.EntityType()) + s.EntityType = string(t.Type) return nil - case nodeTypeIsIn: // is in - n := scopeIsInNode(src) + case scopeTypeIsIn: s.Op = "is" - s.EntityType = string(n.EntityType()) - s.In = &inJSON{ - Entity: n.Entity(), + s.EntityType = string(t.Type) + s.In = &scopeInJSON{ + Entity: t.Entity, } return nil } - return fmt.Errorf("unexpected scope node: %v", src.nodeType) + return fmt.Errorf("unexpected scope node: %T", src) } -func unaryToJSON(dest **unaryJSON, src Node) error { +func unaryToJSON(dest **unaryJSON, src unaryNode) error { n := unaryNode(src) res := &unaryJSON{} - if err := res.Arg.FromNode(n.Arg()); err != nil { + if err := res.Arg.FromNode(n.Arg); err != nil { return fmt.Errorf("error in arg: %w", err) } *dest = res return nil } -func binaryToJSON(dest **binaryJSON, src Node) error { +func binaryToJSON(dest **binaryJSON, src binaryNode) error { n := binaryNode(src) res := &binaryJSON{} - if err := res.Left.FromNode(n.Left()); err != nil { + if err := res.Left.FromNode(n.Left); err != nil { return fmt.Errorf("error in left: %w", err) } - if err := res.Right.FromNode(n.Right()); err != nil { + if err := res.Right.FromNode(n.Right); err != nil { return fmt.Errorf("error in right: %w", err) } *dest = res return nil } -func arrayToJSON(dest *arrayJSON, src Node) error { +func arrayToJSON(dest *arrayJSON, args []node) error { res := arrayJSON{} - for _, n := range src.args { + for _, n := range args { var nn nodeJSON if err := nn.FromNode(n); err != nil { return fmt.Errorf("error in array: %w", err) @@ -81,12 +77,9 @@ func arrayToJSON(dest *arrayJSON, src Node) error { return nil } -func extToJSON(dest *arrayJSON, src Node) error { +func extToJSON(dest *arrayJSON, src types.Value) error { res := arrayJSON{} - if src.value == nil { - return fmt.Errorf("missing value") - } - str := src.value.String() // TODO: is this the correct string? + str := src.String() // TODO: is this the correct string? b, _ := json.Marshal(string(str)) // error impossible res = append(res, nodeJSON{ Value: (*json.RawMessage)(&b), @@ -95,15 +88,14 @@ func extToJSON(dest *arrayJSON, src Node) error { return nil } -func extMethodToJSON(dest extMethodCallJSON, src Node) error { - n := extMethodCallNode(src) +func extMethodToJSON(dest extMethodCallJSON, src nodeTypeExtMethodCall) error { objectNode := &nodeJSON{} - err := objectNode.FromNode(n.Object()) + err := objectNode.FromNode(src.Left) if err != nil { return err } jsonArgs := arrayJSON{*objectNode} - for _, n := range n.Args() { + for _, n := range src.Args { argNode := &nodeJSON{} err := argNode.FromNode(n) if err != nil { @@ -111,112 +103,104 @@ func extMethodToJSON(dest extMethodCallJSON, src Node) error { } jsonArgs = append(jsonArgs, *argNode) } - dest[n.Name()] = jsonArgs + dest[string(src.Method)] = jsonArgs return nil } -func strToJSON(dest **strJSON, src Node) error { - n := binaryNode(src) +func strToJSON(dest **strJSON, src strOpNode) error { res := &strJSON{} - if err := res.Left.FromNode(n.Left()); err != nil { + if err := res.Left.FromNode(src.Arg); err != nil { return fmt.Errorf("error in left: %w", err) } - str, ok := n.Right().value.(types.String) - if !ok { - return fmt.Errorf("right not string") - } - res.Attr = string(str) + res.Attr = string(src.Value) *dest = res return nil } -func patternToJSON(dest **patternJSON, src Node) error { - n := binaryNode(src) +func patternToJSON(dest **patternJSON, src strOpNode) error { res := &patternJSON{} - if err := res.Left.FromNode(n.Left()); err != nil { + if err := res.Left.FromNode(src.Arg); err != nil { return fmt.Errorf("error in left: %w", err) } - str, ok := n.Right().value.(types.String) - if !ok { - return fmt.Errorf("right not string") - } - res.Pattern = string(str) + res.Pattern = string(src.Value) *dest = res return nil } -func recordToJSON(dest *recordJSON, src Node) error { +func recordToJSON(dest *recordJSON, src nodeTypeRecord) error { res := recordJSON{} - for _, kv := range src.args { - n := binaryNode(kv) + for _, kv := range src.Elements { var nn nodeJSON - if err := nn.FromNode(n.Right()); err != nil { + if err := nn.FromNode(kv.Value); err != nil { return err } - str, ok := n.Left().value.(types.String) - if !ok { - return fmt.Errorf("left not string") - } - res[string(str)] = nn + res[string(kv.Key)] = nn } *dest = res return nil } -func ifToJSON(dest **ifThenElseJSON, src Node) error { - n := trinaryNode(src) +func ifToJSON(dest **ifThenElseJSON, src nodeTypeIf) error { res := &ifThenElseJSON{} - if err := res.If.FromNode(n.A()); err != nil { + if err := res.If.FromNode(src.If); err != nil { return fmt.Errorf("error in if: %w", err) } - if err := res.Then.FromNode(n.B()); err != nil { + if err := res.Then.FromNode(src.Then); err != nil { return fmt.Errorf("error in then: %w", err) } - if err := res.Else.FromNode(n.C()); err != nil { + if err := res.Else.FromNode(src.Else); err != nil { return fmt.Errorf("error in else: %w", err) } *dest = res return nil } -func isToJSON(dest **isJSON, src Node) error { - n := binaryNode(src) +func isToJSON(dest **isJSON, src nodeTypeIs) error { res := &isJSON{} - if err := res.Left.FromNode(n.Left()); err != nil { + if err := res.Left.FromNode(src.Left); err != nil { return fmt.Errorf("error in left: %w", err) } - str, ok := n.Right().value.(types.String) - if !ok { - return fmt.Errorf("right not a string") + res.EntityType = string(src.EntityType) + *dest = res + return nil +} + +func isInToJSON(dest **isJSON, src nodeTypeIsIn) error { + res := &isJSON{} + if err := res.Left.FromNode(src.Left); err != nil { + return fmt.Errorf("error in left: %w", err) } - res.EntityType = string(str) - if len(src.args) == 3 { - ent, ok := src.args[2].value.(types.EntityUID) - if !ok { - return fmt.Errorf("in not an entity") - } - res.In = &inJSON{ - Entity: ent, - } + res.EntityType = string(src.EntityType) + res.In = &nodeJSON{} + if err := res.In.FromNode(src.Entity); err != nil { + return fmt.Errorf("error in entity: %w", err) } *dest = res return nil } -func (j *nodeJSON) FromNode(src Node) error { - switch src.nodeType { +func (j *nodeJSON) FromNode(src node) error { + switch t := src.(type) { // Value // Value *json.RawMessage `json:"Value"` // could be any - case nodeTypeBoolean, nodeTypeLong, nodeTypeString, nodeTypeEntity: - b, err := src.value.ExplicitMarshalJSON() + case nodeValue: + // Any other function: decimal, ip + // Decimal arrayJSON `json:"decimal"` + // IP arrayJSON `json:"ip"` + switch tt := t.Value.(type) { + case types.Decimal: + return extToJSON(&j.Decimal, tt) + case types.IPAddr: + return extToJSON(&j.IP, tt) + } + b, err := t.Value.ExplicitMarshalJSON() j.Value = (*json.RawMessage)(&b) return err // Var // Var *string `json:"Var"` case nodeTypeVariable: - n := variableNode(src) - val := string(n.String()) + val := string(t.Name) j.Var = &val return nil @@ -224,94 +208,87 @@ func (j *nodeJSON) FromNode(src Node) error { // Not *unaryJSON `json:"!"` // Negate *unaryJSON `json:"neg"` case nodeTypeNot: - return unaryToJSON(&j.Not, src) + return unaryToJSON(&j.Not, t.unaryNode) case nodeTypeNegate: - return unaryToJSON(&j.Negate, src) + return unaryToJSON(&j.Negate, t.unaryNode) // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny case nodeTypeAdd: - return binaryToJSON(&j.Plus, src) + return binaryToJSON(&j.Plus, t.binaryNode) case nodeTypeAnd: - return binaryToJSON(&j.And, src) + return binaryToJSON(&j.And, t.binaryNode) case nodeTypeContains: - return binaryToJSON(&j.Contains, src) + return binaryToJSON(&j.Contains, t.binaryNode) case nodeTypeContainsAll: - return binaryToJSON(&j.ContainsAll, src) + return binaryToJSON(&j.ContainsAll, t.binaryNode) case nodeTypeContainsAny: - return binaryToJSON(&j.ContainsAny, src) + return binaryToJSON(&j.ContainsAny, t.binaryNode) case nodeTypeEquals: - return binaryToJSON(&j.Equals, src) - case nodeTypeGreater: - return binaryToJSON(&j.GreaterThan, src) - case nodeTypeGreaterEqual: - return binaryToJSON(&j.GreaterThanOrEqual, src) + return binaryToJSON(&j.Equals, t.binaryNode) + case nodeTypeGreaterThan: + return binaryToJSON(&j.GreaterThan, t.binaryNode) + case nodeTypeGreaterThanOrEqual: + return binaryToJSON(&j.GreaterThanOrEqual, t.binaryNode) case nodeTypeIn: - return binaryToJSON(&j.In, src) - case nodeTypeLess: - return binaryToJSON(&j.LessThan, src) - case nodeTypeLessEqual: - return binaryToJSON(&j.LessThanOrEqual, src) + return binaryToJSON(&j.In, t.binaryNode) + case nodeTypeLessThan: + return binaryToJSON(&j.LessThan, t.binaryNode) + case nodeTypeLessThanOrEqual: + return binaryToJSON(&j.LessThanOrEqual, t.binaryNode) case nodeTypeMult: - return binaryToJSON(&j.Times, src) + return binaryToJSON(&j.Times, t.binaryNode) case nodeTypeNotEquals: - return binaryToJSON(&j.NotEquals, src) + return binaryToJSON(&j.NotEquals, t.binaryNode) case nodeTypeOr: - return binaryToJSON(&j.Or, src) + return binaryToJSON(&j.Or, t.binaryNode) case nodeTypeSub: - return binaryToJSON(&j.Minus, src) + return binaryToJSON(&j.Minus, t.binaryNode) // ., has // Access *strJSON `json:"."` // Has *strJSON `json:"has"` case nodeTypeAccess: - return strToJSON(&j.Access, src) + return strToJSON(&j.Access, t.strOpNode) case nodeTypeHas: - return strToJSON(&j.Has, src) + return strToJSON(&j.Has, t.strOpNode) // is - case nodeTypeIs, nodeTypeIsIn: - return isToJSON(&j.Is, src) + case nodeTypeIs: + return isToJSON(&j.Is, t) + case nodeTypeIsIn: + return isInToJSON(&j.Is, t) // like // Like *strJSON `json:"like"` case nodeTypeLike: - return patternToJSON(&j.Like, src) + return patternToJSON(&j.Like, t.strOpNode) // if-then-else // IfThenElse *ifThenElseJSON `json:"if-then-else"` case nodeTypeIf: - return ifToJSON(&j.IfThenElse, src) + return ifToJSON(&j.IfThenElse, t) // Set // Set arrayJSON `json:"Set"` case nodeTypeSet: - return arrayToJSON(&j.Set, src) + return arrayToJSON(&j.Set, t.Elements) // Record // Record recordJSON `json:"Record"` case nodeTypeRecord: - return recordToJSON(&j.Record, src) - - // Any other function: decimal, ip - // Decimal arrayJSON `json:"decimal"` - // IP arrayJSON `json:"ip"` - case nodeTypeDecimal: - return extToJSON(&j.Decimal, src) - - case nodeTypeIpAddr: - return extToJSON(&j.IP, src) + return recordToJSON(&j.Record, t) // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange // ExtensionMethod map[string]arrayJSON `json:"-"` case nodeTypeExtMethodCall: j.ExtensionMethod = extMethodCallJSON{} - return extMethodToJSON(j.ExtensionMethod, src) + return extMethodToJSON(j.ExtensionMethod, t) } // case nodeTypeRecordEntry: // case nodeTypeEntityType: // case nodeTypeAnnotation: // case nodeTypeWhen: // case nodeTypeUnless: - return fmt.Errorf("unknown node type: %v", src.nodeType) + return fmt.Errorf("unknown node type: %T", src) } func (j *nodeJSON) MarshalJSON() ([]byte, error) { @@ -333,8 +310,7 @@ func (p *Policy) MarshalJSON() ([]byte, error) { j.Annotations = map[string]string{} } for _, a := range p.annotations { - n := annotationNode(a) - j.Annotations[string(n.Key())] = string(n.Value()) + j.Annotations[string(a.Key)] = string(a.Value) } if err := j.Principal.FromNode(p.principal); err != nil { return nil, fmt.Errorf("error in principal: %w", err) @@ -346,13 +322,12 @@ func (p *Policy) MarshalJSON() ([]byte, error) { return nil, fmt.Errorf("error in resource: %w", err) } for _, c := range p.conditions { - n := unaryNode(c) var cond conditionJSON cond.Kind = "when" - if c.nodeType == nodeTypeUnless { + if c.Condition == conditionUnless { cond.Kind = "unless" } - if err := cond.Body.FromNode(n.Arg()); err != nil { + if err := cond.Body.FromNode(c.Body); err != nil { return nil, fmt.Errorf("error in condition: %w", err) } j.Conditions = append(j.Conditions, cond) diff --git a/x/exp/ast/json_test.go b/x/exp/ast/json_test.go index 107bb9c8..efe03090 100644 --- a/x/exp/ast/json_test.go +++ b/x/exp/ast/json_test.go @@ -403,7 +403,7 @@ func TestUnmarshalJSON(t *testing.T) { { "isIn", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, - "conditions":[{"kind":"when","body":{"is":{"left":{"Var":"resource"},"entity_type":"T","in":{"entity":{"type":"P","id":"42"}}}}}]}`, + "conditions":[{"kind":"when","body":{"is":{"left":{"Var":"resource"},"entity_type":"T","in":{"Value":{"__entity":{"type":"P","id":"42"}}}}}}]}`, ast.Permit().When(ast.Resource().IsIn("T", ast.Entity(types.NewEntityUID("P", "42")))), testutil.OK, }, diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go index 683d12e6..13f13bad 100644 --- a/x/exp/ast/json_unmarshal.go +++ b/x/exp/ast/json_unmarshal.go @@ -9,13 +9,14 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func (s *scopeJSON) ToNode(variable scope) (Node, error) { +func (s *scopeJSON) ToNode(variable scope) (isScopeNode, error) { + // TODO: should we be careful to be more strict about what is allowed here? switch s.Op { case "All": return variable.All(), nil case "==": if s.Entity == nil { - return Node{}, fmt.Errorf("missing entity") + return nil, fmt.Errorf("missing entity") } return variable.Eq(*s.Entity), nil case "in": @@ -29,7 +30,7 @@ func (s *scopeJSON) ToNode(variable scope) (Node, error) { } return variable.IsIn(types.String(s.EntityType), s.In.Entity), nil } - return Node{}, fmt.Errorf("unknown op: %v", s.Op) + return nil, fmt.Errorf("unknown op: %v", s.Op) } func (j binaryJSON) ToNode(f func(a, b Node) Node) (Node, error) { @@ -70,7 +71,11 @@ func (j isJSON) ToNode() (Node, error) { return Node{}, fmt.Errorf("error in left: %w", err) } if j.In != nil { - return left.IsIn(types.String(j.EntityType), Entity(j.In.Entity)), nil + right, err := j.In.ToNode() + if err != nil { + return Node{}, fmt.Errorf("error in entity: %w", err) + } + return left.IsIn(types.String(j.EntityType), right), nil } return left.Is(types.String(j.EntityType)), nil } @@ -109,11 +114,11 @@ func (j arrayJSON) ToDecimalNode() (Node, error) { if err != nil { return Node{}, fmt.Errorf("error in extension: %w", err) } - s, ok := arg.value.(types.String) + s, ok := arg.v.(nodeValue) if !ok { return Node{}, fmt.Errorf("unexpected type for decimal") } - v, err := types.ParseDecimal(string(s)) + v, err := types.ParseDecimal(s.Value.String()) // TODO: this maybe isn't correct if err != nil { return Node{}, fmt.Errorf("error parsing decimal: %w", err) } @@ -128,11 +133,11 @@ func (j arrayJSON) ToIPAddrNode() (Node, error) { if err != nil { return Node{}, fmt.Errorf("error in extension: %w", err) } - s, ok := arg.value.(types.String) + s, ok := arg.v.(nodeValue) if !ok { return Node{}, fmt.Errorf("unexpected type for ipaddr") } - v, err := types.ParseIPAddr(string(s)) + v, err := types.ParseIPAddr(s.Value.String()) if err != nil { return Node{}, fmt.Errorf("error parsing ipaddr: %w", err) } @@ -167,7 +172,7 @@ func (e extMethodCallJSON) ToNode() (Node, error) { } argNodes = append(argNodes, node) } - return newExtMethodCallNode(argNodes[0], k, argNodes[1:]...), nil + return newExtMethodCallNode(argNodes[0], types.String(k), argNodes[1:]...), nil } panic("unreachable code") } @@ -313,15 +318,15 @@ func (p *Policy) UnmarshalJSON(b []byte) error { p.Annotate(types.String(k), types.String(v)) } var err error - p.principal, err = j.Principal.ToNode(scope(Principal())) + p.principal, err = j.Principal.ToNode(scope(rawPrincipalNode())) if err != nil { return fmt.Errorf("error in principal: %w", err) } - p.action, err = j.Action.ToNode(scope(Action())) + p.action, err = j.Action.ToNode(scope(rawActionNode())) if err != nil { return fmt.Errorf("error in action: %w", err) } - p.resource, err = j.Resource.ToNode(scope(Resource())) + p.resource, err = j.Resource.ToNode(scope(rawResourceNode())) if err != nil { return fmt.Errorf("error in resource: %w", err) } diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 6d1cb70f..a7411c69 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -2,101 +2,140 @@ package ast import "github.com/cedar-policy/cedar-go/types" -type nodeType uint8 +type strOpNode struct { + node + Arg node + Value types.String +} -const ( - nodeTypeAccess = iota - nodeTypeAdd - nodeTypeAll - nodeTypeAnd - nodeTypeAnnotation - nodeTypeBoolean - nodeTypeContains - nodeTypeContainsAll - nodeTypeContainsAny - nodeTypeDecimal - nodeTypeEntity - nodeTypeEntityType - nodeTypeEquals - nodeTypeGreater - nodeTypeGreaterEqual - nodeTypeHas - nodeTypeIf - nodeTypeIn - nodeTypeIpAddr +type nodeTypeAccess struct{ strOpNode } +type nodeTypeHas struct{ strOpNode } +type nodeTypeLike struct{ strOpNode } + +type nodeTypeAnnotation struct { + node + Key types.String // TODO: review type + Value types.String +} + +type nodeTypeIf struct { + node + If, Then, Else node +} + +type nodeTypeIs struct { + node + Left node + EntityType types.String // TODO: review type +} + +type nodeTypeIsIn struct { nodeTypeIs - nodeTypeIsIn - nodeTypeLess - nodeTypeLessEqual - nodeTypeLike - nodeTypeLong - nodeTypeExtMethodCall - nodeTypeMult - nodeTypeNegate - nodeTypeNot - nodeTypeNotEquals - nodeTypeOr - nodeTypeRecord - nodeTypeRecordEntry - nodeTypeSet - nodeTypeString - nodeTypeSub - nodeTypeUnless - nodeTypeVariable - nodeTypeWhen -) + Entity node +} -type Node struct { - nodeType nodeType - args []Node // For inner nodes like operators, records, etc - value types.Value // For leaf nodes like String, Long, EntityUID +type nodeTypeScopeIsIn struct { + nodeTypeIs + Entity types.EntityUID } -func newUnaryNode(op nodeType, arg Node) Node { - return Node{nodeType: op, args: []Node{arg}} +type nodeTypeExtMethodCall struct { + node + Left node + Method types.String // TODO: review type + Args []node } -type unaryNode Node +func stripNodes(args []Node) []node { + res := make([]node, len(args)) + for i, v := range args { + res[i] = v.v + } + return res +} -func (n unaryNode) Arg() Node { return n.args[0] } +func newExtMethodCallNode(left Node, method types.String, args ...Node) Node { + return newNode(nodeTypeExtMethodCall{ + Left: left.v, + Method: method, + Args: stripNodes(args), + }) +} -func newBinaryNode(op nodeType, arg1, arg2 Node) Node { - return Node{nodeType: op, args: []Node{arg1, arg2}} +type nodeValue struct { + node + Value types.Value } -type binaryNode Node +type recordElement struct { + Key types.String + Value node +} +type nodeTypeRecord struct { + node + Elements []recordElement +} -func (n binaryNode) Left() Node { return n.args[0] } -func (n binaryNode) Right() Node { return n.args[1] } +type nodeTypeSet struct { + node + Elements []node +} -func newTrinaryNode(op nodeType, arg1, arg2, arg3 Node) Node { - return Node{nodeType: op, args: []Node{arg1, arg2, arg3}} +type unaryNode struct { + node + Arg node } -type trinaryNode Node +type nodeTypeNegate struct{ unaryNode } +type nodeTypeNot struct{ unaryNode } -func (n trinaryNode) A() Node { return n.args[0] } -func (n trinaryNode) B() Node { return n.args[1] } -func (n trinaryNode) C() Node { return n.args[2] } +type condition bool -func newExtMethodCallNode(object Node, methodName string, args ...Node) Node { - nodes := []Node{object, String(types.String(methodName))} - return Node{ - nodeType: nodeTypeExtMethodCall, - args: append(nodes, args...), - } +const ( + conditionWhen = true + conditionUnless = false +) + +type nodeTypeCondition struct { + node + Condition condition + Body node +} + +type nodeTypeVariable struct { + node + Name types.String // TODO: Review type } -type extMethodCallNode Node +type binaryNode struct { + node + Left, Right node +} -func (n extMethodCallNode) Object() Node { - return n.args[0] +type nodeTypeIn struct{ binaryNode } +type nodeTypeAnd struct{ binaryNode } +type nodeTypeEquals struct{ binaryNode } +type nodeTypeGreaterThan struct{ binaryNode } +type nodeTypeGreaterThanOrEqual struct{ binaryNode } +type nodeTypeLessThan struct{ binaryNode } +type nodeTypeLessThanOrEqual struct{ binaryNode } +type nodeTypeSub struct{ binaryNode } +type nodeTypeAdd struct{ binaryNode } +type nodeTypeContains struct{ binaryNode } +type nodeTypeContainsAll struct{ binaryNode } +type nodeTypeContainsAny struct{ binaryNode } +type nodeTypeMult struct{ binaryNode } +type nodeTypeNotEquals struct{ binaryNode } +type nodeTypeOr struct{ binaryNode } + +type node interface { + isNode() } -func (n extMethodCallNode) Name() string { - return string(n.args[1].value.(types.String)) +type Node struct { + v node // NOTE: not an embed because a `Node` is not a `node` } -func (n extMethodCallNode) Args() []Node { - return n.args[2:] +func newNode(v node) Node { + return Node{v: v} } diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index 8f31f562..e27bf01d 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -10,27 +10,27 @@ import "github.com/cedar-policy/cedar-go/types" // |_| func (lhs Node) Equals(rhs Node) Node { - return newBinaryNode(nodeTypeEquals, lhs, rhs) + return newNode(nodeTypeEquals{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) NotEquals(rhs Node) Node { - return newBinaryNode(nodeTypeNotEquals, lhs, rhs) + return newNode(nodeTypeNotEquals{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) LessThan(rhs Node) Node { - return newBinaryNode(nodeTypeLess, lhs, rhs) + return newNode(nodeTypeLessThan{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) LessThanOrEqual(rhs Node) Node { - return newBinaryNode(nodeTypeLessEqual, lhs, rhs) + return newNode(nodeTypeLessThanOrEqual{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) GreaterThan(rhs Node) Node { - return newBinaryNode(nodeTypeGreater, lhs, rhs) + return newNode(nodeTypeGreaterThan{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) GreaterThanOrEqual(rhs Node) Node { - return newBinaryNode(nodeTypeGreaterEqual, lhs, rhs) + return newNode(nodeTypeGreaterThanOrEqual{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) LessThanExt(rhs Node) Node { @@ -50,7 +50,7 @@ func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { } func (lhs Node) Like(patt string) Node { - return newBinaryNode(nodeTypeLike, lhs, String(types.String(patt))) + return newNode(nodeTypeLike{strOpNode: strOpNode{Arg: lhs.v, Value: types.String(patt)}}) } // _ _ _ @@ -61,19 +61,19 @@ func (lhs Node) Like(patt string) Node { // |___/ func (lhs Node) And(rhs Node) Node { - return newBinaryNode(nodeTypeAnd, lhs, rhs) + return newNode(nodeTypeAnd{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Or(rhs Node) Node { - return newBinaryNode(nodeTypeOr, lhs, rhs) + return newNode(nodeTypeOr{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func Not(rhs Node) Node { - return newUnaryNode(nodeTypeNot, rhs) + return newNode(nodeTypeNot{unaryNode: unaryNode{Arg: rhs.v}}) } func If(condition Node, ifTrue Node, ifFalse Node) Node { - return newTrinaryNode(nodeTypeIf, condition, ifTrue, ifFalse) + return newNode(nodeTypeIf{If: condition.v, Then: ifTrue.v, Else: ifFalse.v}) } // _ _ _ _ _ _ @@ -83,19 +83,19 @@ func If(condition Node, ifTrue Node, ifFalse Node) Node { // /_/ \_\_| |_|\__|_| |_|_| |_| |_|\___|\__|_|\___| func (lhs Node) Plus(rhs Node) Node { - return newBinaryNode(nodeTypeAdd, lhs, rhs) + return newNode(nodeTypeAdd{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Minus(rhs Node) Node { - return newBinaryNode(nodeTypeSub, lhs, rhs) + return newNode(nodeTypeSub{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Times(rhs Node) Node { - return newBinaryNode(nodeTypeMult, lhs, rhs) + return newNode(nodeTypeMult{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func Negate(rhs Node) Node { - return newUnaryNode(nodeTypeNegate, rhs) + return newNode(nodeTypeNegate{unaryNode: unaryNode{Arg: rhs.v}}) } // _ _ _ _ @@ -106,52 +106,37 @@ func Negate(rhs Node) Node { // |___/ func (lhs Node) In(rhs Node) Node { - return newBinaryNode(nodeTypeIn, lhs, rhs) + return newNode(nodeTypeIn{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Is(entityType types.String) Node { - return newBinaryNode(nodeTypeIs, lhs, String(entityType)) + return newNode(nodeTypeIs{Left: lhs.v, EntityType: entityType}) } func (lhs Node) IsIn(entityType types.String, rhs Node) Node { - return newTrinaryNode(nodeTypeIsIn, lhs, String(entityType), rhs) + return newNode(nodeTypeIsIn{nodeTypeIs: nodeTypeIs{Left: lhs.v, EntityType: entityType}, Entity: rhs.v}) } func (lhs Node) Contains(rhs Node) Node { - return newBinaryNode(nodeTypeContains, lhs, rhs) + return newNode(nodeTypeContains{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) ContainsAll(rhs Node) Node { - return newBinaryNode(nodeTypeContainsAll, lhs, rhs) + return newNode(nodeTypeContainsAll{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) ContainsAny(rhs Node) Node { - return newBinaryNode(nodeTypeContainsAny, lhs, rhs) + return newNode(nodeTypeContainsAny{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } // Access is a convenience function that wraps a simple string // in an ast.String() and passes it along to AccessNode. func (lhs Node) Access(attr string) Node { - return lhs.AccessNode(String(types.String(attr))) -} - -// AccessNode is a version of the access operator which allows -// more complex access of attributes, such as might be expressed -// by this Cedar text: -// -// resource[context.resourceAttribute] == "foo" -// -// In Golang, this could be expressed as: -// -// ast.Resource().AccessNode( -// ast.Context().Access("resourceAttribute") -// ).Equals(ast.String("foo")) -func (lhs Node) AccessNode(rhs Node) Node { - return newBinaryNode(nodeTypeAccess, lhs, rhs) + return newNode(nodeTypeAccess{strOpNode: strOpNode{Arg: lhs.v, Value: types.String(attr)}}) } func (lhs Node) Has(attr string) Node { - return newBinaryNode(nodeTypeHas, lhs, String(types.String(attr))) + return newNode(nodeTypeHas{strOpNode: strOpNode{Arg: lhs.v, Value: types.String(attr)}}) } // ___ ____ _ _ _ diff --git a/x/exp/ast/parser.go b/x/exp/ast/parser.go index 59719007..08d46ac3 100644 --- a/x/exp/ast/parser.go +++ b/x/exp/ast/parser.go @@ -792,7 +792,7 @@ func (p *parser) access(lhs Node) (Node, bool, error) { } return f(exprs[0]), true, nil } - return newExtMethodCallNode(lhs, methodName, exprs...), true, nil + return newExtMethodCallNode(lhs, types.String(methodName), exprs...), true, nil } else { return lhs.Access(t.Text), true, nil } diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index 250a625c..29305c6b 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -2,20 +2,20 @@ package ast type Policy struct { effect effect - annotations []Node - principal Node - action Node - resource Node - conditions []Node + annotations []nodeTypeAnnotation + principal isScopeNode + action isScopeNode + resource isScopeNode + conditions []nodeTypeCondition } -func newPolicy(effect effect, annotations []Node) *Policy { +func newPolicy(effect effect, annotations []nodeTypeAnnotation) *Policy { return &Policy{ effect: effect, annotations: annotations, - principal: scope(Principal()).All(), - action: scope(Action()).All(), - resource: scope(Resource()).All(), + principal: scope(rawPrincipalNode()).All(), + action: scope(rawActionNode()).All(), + resource: scope(rawResourceNode()).All(), } } @@ -28,12 +28,12 @@ func Forbid() *Policy { } func (p *Policy) When(node Node) *Policy { - p.conditions = append(p.conditions, Node{nodeType: nodeTypeWhen, args: []Node{node}}) + p.conditions = append(p.conditions, nodeTypeCondition{Condition: conditionWhen, Body: node.v}) return p } func (p *Policy) Unless(node Node) *Policy { - p.conditions = append(p.conditions, Node{nodeType: nodeTypeUnless, args: []Node{node}}) + p.conditions = append(p.conditions, nodeTypeCondition{Condition: conditionUnless, Body: node.v}) return p } diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index a4fd5dbf..9d13bc37 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -2,127 +2,122 @@ package ast import "github.com/cedar-policy/cedar-go/types" -type scope Node +type scope nodeTypeVariable -func (s scope) All() Node { - return Node{nodeType: nodeTypeAll, args: []Node{Node(s)}} +func (s scope) All() isScopeNode { + return scopeTypeAll{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}} } -func (s scope) Eq(entity types.EntityUID) Node { - return Node(s).Equals(Entity(entity)) +func (s scope) Eq(entity types.EntityUID) isScopeNode { + return scopeTypeEq{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Entity: entity} } -type scopeEqNode Node - -func (n scopeEqNode) Entity() types.EntityUID { - return n.args[1].value.(types.EntityUID) -} - -func (s scope) In(entity types.EntityUID) Node { - return Node(s).In(Entity(entity)) -} - -func (s scope) InSet(entities []types.EntityUID) Node { - var entityValues []types.Value - for _, e := range entities { - entityValues = append(entityValues, e) - } - return Node(s).In(Set(entityValues)) -} - -type scopeInNode Node - -func (n scopeInNode) IsSet() bool { - return Node(n).args[1].nodeType == nodeTypeSet -} - -func (n scopeInNode) Entity() types.EntityUID { - return n.args[1].value.(types.EntityUID) -} - -func (n scopeInNode) Set() []types.EntityUID { - var res []types.EntityUID - for _, a := range n.args[1].args { - res = append(res, a.value.(types.EntityUID)) - } - return res -} - -func (s scope) Is(entityType types.String) Node { - return Node(s).Is(entityType) -} - -type scopeIsNode Node - -func (n scopeIsNode) EntityType() types.String { - return n.args[1].value.(types.String) +func (s scope) In(entity types.EntityUID) isScopeNode { + return scopeTypeIn{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Entity: entity} } -func (s scope) IsIn(entityType types.String, entity types.EntityUID) Node { - return Node(s).IsIn(entityType, Entity(entity)) +func (s scope) InSet(entities []types.EntityUID) isScopeNode { + return scopeTypeInSet{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Entities: entities} } -type scopeIsInNode Node - -func (n scopeIsInNode) EntityType() types.String { - return n.args[1].value.(types.String) +func (s scope) Is(entityType types.String) isScopeNode { + return scopeTypeIs{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Type: entityType} } -func (n scopeIsInNode) Entity() types.EntityUID { - return n.args[2].value.(types.EntityUID) +func (s scope) IsIn(entityType types.String, entity types.EntityUID) isScopeNode { + return scopeTypeIsIn{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Type: entityType, Entity: entity} } func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { - p.principal = scope(Principal()).Eq(entity) + p.principal = scope(rawPrincipalNode()).Eq(entity) return p } func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { - p.principal = scope(Principal()).In(entity) + p.principal = scope(rawPrincipalNode()).In(entity) return p } func (p *Policy) PrincipalIs(entityType types.String) *Policy { - p.principal = scope(Principal()).Is(entityType) + p.principal = scope(rawPrincipalNode()).Is(entityType) return p } func (p *Policy) PrincipalIsIn(entityType types.String, entity types.EntityUID) *Policy { - p.principal = scope(Principal()).IsIn(entityType, entity) + p.principal = scope(rawPrincipalNode()).IsIn(entityType, entity) return p } func (p *Policy) ActionEq(entity types.EntityUID) *Policy { - p.action = scope(Action()).Eq(entity) + p.action = scope(rawActionNode()).Eq(entity) return p } func (p *Policy) ActionIn(entity types.EntityUID) *Policy { - p.action = scope(Action()).In(entity) + p.action = scope(rawActionNode()).In(entity) return p } func (p *Policy) ActionInSet(entities ...types.EntityUID) *Policy { - p.action = scope(Action()).InSet(entities) + p.action = scope(rawActionNode()).InSet(entities) return p } func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { - p.resource = scope(Resource()).Eq(entity) + p.resource = scope(rawResourceNode()).Eq(entity) return p } func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { - p.resource = scope(Resource()).In(entity) + p.resource = scope(rawResourceNode()).In(entity) return p } func (p *Policy) ResourceIs(entityType types.String) *Policy { - p.resource = scope(Resource()).Is(entityType) + p.resource = scope(rawResourceNode()).Is(entityType) return p } func (p *Policy) ResourceIsIn(entityType types.String, entity types.EntityUID) *Policy { - p.resource = scope(Resource()).IsIn(entityType, entity) + p.resource = scope(rawResourceNode()).IsIn(entityType, entity) return p } + +type isScopeNode interface { + isScope() +} + +type scopeNode struct { + isScopeNode + Variable nodeTypeVariable +} + +type scopeTypeAll struct { + scopeNode +} + +type scopeTypeEq struct { + scopeNode + Entity types.EntityUID +} + +type scopeTypeIn struct { + scopeNode + Entity types.EntityUID +} + +type scopeTypeInSet struct { + scopeNode + Entities []types.EntityUID +} + +type scopeTypeIs struct { + scopeNode + Type types.String +} + +type scopeTypeIsIn struct { + scopeNode + Type types.String + Entity types.EntityUID +} diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index bf3050f5..4e52ca54 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -7,7 +7,7 @@ import ( ) func Boolean(b types.Boolean) Node { - return newValueNode(nodeTypeBoolean, b) + return newValueNode(b) } func True() Node { @@ -19,21 +19,21 @@ func False() Node { } func String(s types.String) Node { - return newValueNode(nodeTypeString, s) + return newValueNode(s) } func Long(l types.Long) Node { - return newValueNode(nodeTypeLong, l) + return newValueNode(l) } // Set is a convenience function that wraps concrete instances of a Cedar Set type // types in AST value nodes and passes them along to SetNodes. func Set(s types.Set) Node { - var nodes []Node + var nodes []node for _, v := range s { - nodes = append(nodes, valueToNode(v)) + nodes = append(nodes, valueToNode(v).v) } - return SetNodes(nodes...) + return newNode(nodeTypeSet{Elements: nodes}) } // SetNodes allows for a complex set definition with values potentially @@ -49,7 +49,7 @@ func Set(s types.Set) Node { // ast.Context().Access("fooCount"), // ) func SetNodes(nodes ...Node) Node { - return Node{nodeType: nodeTypeSet, args: nodes} + return newNode(nodeTypeSet{Elements: stripNodes(nodes)}) } // Record is a convenience function that wraps concrete instances of a Cedar Record type @@ -73,34 +73,31 @@ func Record(r types.Record) Node { // "x": ast.Long(1).Plus(ast.Context().Access("fooCount"))}, // }) func RecordNodes(entries map[types.String]Node) Node { - var nodes []Node + var res nodeTypeRecord for k, v := range entries { - nodes = append( - nodes, - newBinaryNode(nodeTypeRecordEntry, String(k), v), - ) + res.Elements = append(res.Elements, recordElement{Key: k, Value: v.v}) } - return Node{nodeType: nodeTypeRecord, args: nodes} + return newNode(res) } func EntityType(e types.String) Node { - return newValueNode(nodeTypeEntityType, e) + return newValueNode(e) } func Entity(e types.EntityUID) Node { - return newValueNode(nodeTypeEntity, e) + return newValueNode(e) } func Decimal(d types.Decimal) Node { - return newValueNode(nodeTypeDecimal, d) + return newValueNode(d) } func IPAddr(i types.IPAddr) Node { - return newValueNode(nodeTypeIpAddr, i) + return newValueNode(i) } -func newValueNode(nodeType nodeType, v types.Value) Node { - return Node{nodeType: nodeType, value: v} +func newValueNode(v types.Value) Node { + return newNode(nodeValue{Value: v}) } func valueToNode(v types.Value) Node { diff --git a/x/exp/ast/variable.go b/x/exp/ast/variable.go index a45467b1..e14cc783 100644 --- a/x/exp/ast/variable.go +++ b/x/exp/ast/variable.go @@ -19,23 +19,33 @@ func Context() Node { } func newPrincipalNode() Node { - return newValueNode(nodeTypeVariable, types.String("principal")) + return newNode(rawPrincipalNode()) } func newActionNode() Node { - return newValueNode(nodeTypeVariable, types.String("action")) + return newNode(rawActionNode()) } func newResourceNode() Node { - return newValueNode(nodeTypeVariable, types.String("resource")) + return newNode(rawResourceNode()) } func newContextNode() Node { - return newValueNode(nodeTypeVariable, types.String("context")) + return newNode(rawContextNode()) } -type variableNode Node +func rawPrincipalNode() nodeTypeVariable { + return nodeTypeVariable{Name: types.String("principal")} +} + +func rawActionNode() nodeTypeVariable { + return nodeTypeVariable{Name: types.String("action")} +} + +func rawResourceNode() nodeTypeVariable { + return nodeTypeVariable{Name: types.String("resource")} +} -func (v variableNode) String() types.String { - return v.value.(types.String) +func rawContextNode() nodeTypeVariable { + return nodeTypeVariable{Name: types.String("context")} } From 6f8768d9fa128a8b8d7a7c0ddf4e78b63496dd8f Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 2 Aug 2024 10:15:17 -0700 Subject: [PATCH 042/216] cedar-go/x/exp/ast: add a few tests of operator precedence Signed-off-by: philhassey --- x/exp/ast/parser_test.go | 48 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/parser_test.go index 11a92293..d035b6d1 100644 --- a/x/exp/ast/parser_test.go +++ b/x/exp/ast/parser_test.go @@ -290,6 +290,54 @@ func TestParse(t *testing.T) { when { if true then true else false };`, Permit().When(If(True(), True(), False())), }, + { + "and over or precedence", + `permit (principal, action, resource) + when { true && false || true && true };`, + Permit().When(True().And(False()).Or(True().And(True()))), + }, + { + "rel over and precedence", + `permit (principal, action, resource) + when { 1 < 2 && true };`, + Permit().When(Long(1).LessThan(Long(2)).And(True())), + }, + { + "add over rel precedence", + `permit (principal, action, resource) + when { 1 + 1 < 3 };`, + Permit().When(Long(1).Plus(Long(1)).LessThan(Long(3))), + }, + { + "mult over add precedence", + `permit (principal, action, resource) + when { 2 * 3 + 4 == 10 };`, + Permit().When(Long(2).Times(Long(3)).Plus(Long(4)).Equals(Long(10))), + }, + { + "unary over mult precedence", + `permit (principal, action, resource) + when { -2 * 3 == -6 };`, + Permit().When(Negate(Long(2)).Times(Long(3)).Equals(Negate(Long(6)))), + }, + { + "member over unary precedence", + `permit (principal, action, resource) + when { -context.num };`, + Permit().When(Negate(Context().Access("num"))), + }, + { + "member over unary precedence", + `permit (principal, action, resource) + when { -context.num };`, + Permit().When(Negate(Context().Access("num"))), + }, + { + "parens over unary precedence", + `permit (principal, action, resource) + when { -(2 + 3) == -5 };`, + Permit().When(Negate(Long(2).Plus(Long(3))).Equals(Negate(Long(5)))), + }, } for _, tt := range parseTests { From 2ea7c8e3f7e976a1daa0752a82c23bf6e9eeb37d Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 2 Aug 2024 10:36:42 -0700 Subject: [PATCH 043/216] cedar-go/x/exp/ast: fix lint violations Signed-off-by: philhassey --- x/exp/ast/node.go | 5 ----- x/exp/ast/parser_test.go | 24 ++++++++++++------------ 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index a7411c69..6c8a0ac2 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -34,11 +34,6 @@ type nodeTypeIsIn struct { Entity node } -type nodeTypeScopeIsIn struct { - nodeTypeIs - Entity types.EntityUID -} - type nodeTypeExtMethodCall struct { node Left node diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/parser_test.go index d035b6d1..8e23e014 100644 --- a/x/exp/ast/parser_test.go +++ b/x/exp/ast/parser_test.go @@ -140,12 +140,12 @@ func TestParse(t *testing.T) { when { !true };`, Permit().When(Not(Boolean(true))), }, - //{ - // "negate operator", - // `permit (principal, action, resource) - // when { -1 };`, - // Permit().When(Long(-1)), - //}, + // { + // "negate operator", + // `permit (principal, action, resource) + // when { -1 };`, + // Permit().When(Long(-1)), + // }, { "variable member", `permit (principal, action, resource) @@ -248,12 +248,12 @@ func TestParse(t *testing.T) { when { principal has "firstName" };`, Permit().When(Principal().Has("firstName")), }, - //{ - // "like no wildcards", - // `permit (principal, action, resource) - // when { principal.firstName like "johnny" };`, - // Permit().When(Principal().Has("firstName")), - //}, + // { + // "like no wildcards", + // `permit (principal, action, resource) + // when { principal.firstName like "johnny" };`, + // Permit().When(Principal().Has("firstName")), + // }, { "is", `permit (principal, action, resource) From e85e9fc938ecda5b36e4495661ae56542132c830 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 2 Aug 2024 10:43:09 -0700 Subject: [PATCH 044/216] cedar-go/x/exp/ast: split Pattern out of tokenize.go and into its own file Signed-off-by: philhassey --- x/exp/ast/pattern.go | 45 ++++++++++++++++++++++++++ x/exp/ast/patttern_test.go | 65 ++++++++++++++++++++++++++++++++++++++ x/exp/ast/tokenize.go | 42 ------------------------ x/exp/ast/tokenize_test.go | 58 ---------------------------------- 4 files changed, 110 insertions(+), 100 deletions(-) create mode 100644 x/exp/ast/pattern.go create mode 100644 x/exp/ast/patttern_test.go diff --git a/x/exp/ast/pattern.go b/x/exp/ast/pattern.go new file mode 100644 index 00000000..9be16a54 --- /dev/null +++ b/x/exp/ast/pattern.go @@ -0,0 +1,45 @@ +package ast + +import "strings" + +type PatternComponent struct { + Star bool + Chunk string +} + +type Pattern struct { + Comps []PatternComponent + Raw string +} + +func (p Pattern) String() string { + return p.Raw +} + +func NewPattern(literal string) (Pattern, error) { + rawPat := literal + + literal = strings.TrimPrefix(literal, "\"") + literal = strings.TrimSuffix(literal, "\"") + + b := []byte(literal) + + var comps []PatternComponent + for len(b) > 0 { + var comp PatternComponent + var err error + for len(b) > 0 && b[0] == '*' { + b = b[1:] + comp.Star = true + } + comp.Chunk, b, err = rustUnquote(b, true) + if err != nil { + return Pattern{}, err + } + comps = append(comps, comp) + } + return Pattern{ + Comps: comps, + Raw: rawPat, + }, nil +} diff --git a/x/exp/ast/patttern_test.go b/x/exp/ast/patttern_test.go new file mode 100644 index 00000000..3fb03c99 --- /dev/null +++ b/x/exp/ast/patttern_test.go @@ -0,0 +1,65 @@ +package ast + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/testutil" +) + +func TestPatternFromStringLiteral(t *testing.T) { + t.Parallel() + tests := []struct { + input string + wantOk bool + want []PatternComponent + wantErr string + }{ + {`""`, true, nil, ""}, + {`"a"`, true, []PatternComponent{{false, "a"}}, ""}, + {`"*"`, true, []PatternComponent{{true, ""}}, ""}, + {`"*a"`, true, []PatternComponent{{true, "a"}}, ""}, + {`"a*"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, + {`"**"`, true, []PatternComponent{{true, ""}}, ""}, + {`"**a"`, true, []PatternComponent{{true, "a"}}, ""}, + {`"a**"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, + {`"*a*"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, + {`"**a**"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, + {`"abra*ca"`, true, []PatternComponent{ + {false, "abra"}, {true, "ca"}, + }, ""}, + {`"abra**ca"`, true, []PatternComponent{ + {false, "abra"}, {true, "ca"}, + }, ""}, + {`"*abra*ca"`, true, []PatternComponent{ + {true, "abra"}, {true, "ca"}, + }, ""}, + {`"abra*ca*"`, true, []PatternComponent{ + {false, "abra"}, {true, "ca"}, {true, ""}, + }, ""}, + {`"*abra*ca*"`, true, []PatternComponent{ + {true, "abra"}, {true, "ca"}, {true, ""}, + }, ""}, + {`"*abra*ca*dabra"`, true, []PatternComponent{ + {true, "abra"}, {true, "ca"}, {true, "dabra"}, + }, ""}, + {`"*abra*c\**da\*ra"`, true, []PatternComponent{ + {true, "abra"}, {true, "c*"}, {true, "da*ra"}, + }, ""}, + {`"\u"`, false, nil, "bad unicode rune"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, err := NewPattern(tt.input) + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got.Comps, tt.want) + testutil.Equals(t, got.String(), tt.input) + } + }) + } +} diff --git a/x/exp/ast/tokenize.go b/x/exp/ast/tokenize.go index f4c07355..6d2cfebb 100644 --- a/x/exp/ast/tokenize.go +++ b/x/exp/ast/tokenize.go @@ -190,48 +190,6 @@ func rustUnquote(b []byte, star bool) (string, []byte, error) { return sb.String(), b[i:], nil } -type PatternComponent struct { - Star bool - Chunk string -} - -type Pattern struct { - Comps []PatternComponent - Raw string -} - -func (p Pattern) String() string { - return p.Raw -} - -func NewPattern(literal string) (Pattern, error) { - rawPat := literal - - literal = strings.TrimPrefix(literal, "\"") - literal = strings.TrimSuffix(literal, "\"") - - b := []byte(literal) - - var comps []PatternComponent - for len(b) > 0 { - var comp PatternComponent - var err error - for len(b) > 0 && b[0] == '*' { - b = b[1:] - comp.Star = true - } - comp.Chunk, b, err = rustUnquote(b, true) - if err != nil { - return Pattern{}, err - } - comps = append(comps, comp) - } - return Pattern{ - Comps: comps, - Raw: rawPat, - }, nil -} - func isHexadecimal(ch rune) bool { return isDecimal(ch) || ('a' <= lower(ch) && lower(ch) <= 'f') } diff --git a/x/exp/ast/tokenize_test.go b/x/exp/ast/tokenize_test.go index cb9292f3..0250b3ad 100644 --- a/x/exp/ast/tokenize_test.go +++ b/x/exp/ast/tokenize_test.go @@ -411,64 +411,6 @@ func TestFakeRustQuote(t *testing.T) { testutil.Equals(t, out, `"hello"`) } -func TestPatternFromStringLiteral(t *testing.T) { - t.Parallel() - tests := []struct { - input string - wantOk bool - want []PatternComponent - wantErr string - }{ - {`""`, true, nil, ""}, - {`"a"`, true, []PatternComponent{{false, "a"}}, ""}, - {`"*"`, true, []PatternComponent{{true, ""}}, ""}, - {`"*a"`, true, []PatternComponent{{true, "a"}}, ""}, - {`"a*"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, - {`"**"`, true, []PatternComponent{{true, ""}}, ""}, - {`"**a"`, true, []PatternComponent{{true, "a"}}, ""}, - {`"a**"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, - {`"*a*"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, - {`"**a**"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, - {`"abra*ca"`, true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, - }, ""}, - {`"abra**ca"`, true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, - }, ""}, - {`"*abra*ca"`, true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, - }, ""}, - {`"abra*ca*"`, true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, {true, ""}, - }, ""}, - {`"*abra*ca*"`, true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, {true, ""}, - }, ""}, - {`"*abra*ca*dabra"`, true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, {true, "dabra"}, - }, ""}, - {`"*abra*c\**da\*ra"`, true, []PatternComponent{ - {true, "abra"}, {true, "c*"}, {true, "da*ra"}, - }, ""}, - {`"\u"`, false, nil, "bad unicode rune"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, err := NewPattern(tt.input) - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got.Comps, tt.want) - testutil.Equals(t, got.String(), tt.input) - } - }) - } -} - func TestScanner(t *testing.T) { t.Parallel() t.Run("SrcError", func(t *testing.T) { From bfb957b34af07b7dbb54749e955c891fdedbd5e1 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 2 Aug 2024 12:39:34 -0700 Subject: [PATCH 045/216] cedar-go/x/exp/ast: fix Cedar and JSON pattern parsing The documentation for the JSON pattern parsing format is incorrect. In order to facilitate relatively straightforward parsing of the JSON format, we introduce a builder for the Pattern struct that lets us add pattern components iteratively. Signed-off-by: philhassey --- testutil/testutil.go | 7 ++++ x/exp/ast/json.go | 13 ++++++-- x/exp/ast/json_marshal.go | 20 ++++++++++-- x/exp/ast/json_test.go | 48 +++++++++++++++++++++++++-- x/exp/ast/json_unmarshal.go | 27 +++++++++++++-- x/exp/ast/node.go | 7 +++- x/exp/ast/operator.go | 4 +-- x/exp/ast/parser.go | 16 +++++++-- x/exp/ast/parser_test.go | 25 ++++++++++---- x/exp/ast/pattern.go | 43 +++++++++++++++--------- x/exp/ast/patttern_test.go | 65 +++++++++++++++++++++++++------------ 11 files changed, 218 insertions(+), 57 deletions(-) diff --git a/testutil/testutil.go b/testutil/testutil.go index 6d897b67..16e407d9 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -42,3 +42,10 @@ func AssertError(t *testing.T, got, want error) { t.Helper() FatalIf(t, !errors.Is(got, want), "err got %v want %v", got, want) } + +func Must[T any](obj T, err error) T { + if err != nil { + panic(err) + } + return obj +} diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index 0000474e..226cba27 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -46,9 +46,18 @@ type strJSON struct { Attr string `json:"attr"` } +type patternComponentLiteralJSON struct { + Literal string `json:"Literal,omitempty"` +} + +type patternComponentJSON struct { + Wildcard bool + Literal patternComponentLiteralJSON +} + type patternJSON struct { - Left nodeJSON `json:"left"` - Pattern string `json:"pattern"` + Left nodeJSON `json:"left"` + Pattern []patternComponentJSON `json:"pattern"` } type isJSON struct { diff --git a/x/exp/ast/json_marshal.go b/x/exp/ast/json_marshal.go index 86fd01a5..795116d8 100644 --- a/x/exp/ast/json_marshal.go +++ b/x/exp/ast/json_marshal.go @@ -117,12 +117,19 @@ func strToJSON(dest **strJSON, src strOpNode) error { return nil } -func patternToJSON(dest **patternJSON, src strOpNode) error { +func patternToJSON(dest **patternJSON, src nodeTypeLike) error { res := &patternJSON{} if err := res.Left.FromNode(src.Arg); err != nil { return fmt.Errorf("error in left: %w", err) } - res.Pattern = string(src.Value) + for _, comp := range src.Value.Comps { + if comp.Star { + res.Pattern = append(res.Pattern, patternComponentJSON{Wildcard: true}) + } + if comp.Chunk != "" { + res.Pattern = append(res.Pattern, patternComponentJSON{Literal: patternComponentLiteralJSON{Literal: comp.Chunk}}) + } + } *dest = res return nil } @@ -260,7 +267,7 @@ func (j *nodeJSON) FromNode(src node) error { // like // Like *strJSON `json:"like"` case nodeTypeLike: - return patternToJSON(&j.Like, t.strOpNode) + return patternToJSON(&j.Like, t) // if-then-else // IfThenElse *ifThenElseJSON `json:"if-then-else"` @@ -300,6 +307,13 @@ func (j *nodeJSON) MarshalJSON() ([]byte, error) { return json.Marshal((*nodeJSONAlias)(j)) } +func (p *patternComponentJSON) MarshalJSON() ([]byte, error) { + if p.Wildcard { + return json.Marshal("Wildcard") + } + return json.Marshal(p.Literal) +} + func (p *Policy) MarshalJSON() ([]byte, error) { var j policyJSON j.Effect = "forbid" diff --git a/x/exp/ast/json_test.go b/x/exp/ast/json_test.go index efe03090..a22265a9 100644 --- a/x/exp/ast/json_test.go +++ b/x/exp/ast/json_test.go @@ -408,10 +408,52 @@ func TestUnmarshalJSON(t *testing.T) { testutil.OK, }, { - "like", + "like single wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, - "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":"*"}}}]}`, - ast.Permit().When(ast.String("text").Like("*")), + "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard"]}}}]}`, + ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar("*")))), + testutil.OK, + }, + { + "like single literal", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}]}}}]}`, + ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar("foo")))), + testutil.OK, + }, + { + "like wildcard then literal", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard", {"Literal":"foo"}]}}}]}`, + ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar("*foo")))), + testutil.OK, + }, + { + "like literal then wildcard", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}, "Wildcard"]}}}]}`, + ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar("foo*")))), + testutil.OK, + }, + { + "like literal with asterisk then wildcard", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"f*oo"}, "Wildcard"]}}}]}`, + ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar(`f\*oo*`)))), + testutil.OK, + }, + { + "like literal sandwich", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}, "Wildcard", {"Literal":"bar"}]}}}]}`, + ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar(`foo*bar`)))), + testutil.OK, + }, + { + "like wildcard sandwich", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard", {"Literal":"foo"}, "Wildcard"]}}}]}`, + ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar(`*foo*`)))), testutil.OK, }, { diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go index 13f13bad..0cd5e8f3 100644 --- a/x/exp/ast/json_unmarshal.go +++ b/x/exp/ast/json_unmarshal.go @@ -58,12 +58,21 @@ func (j strJSON) ToNode(f func(a Node, k string) Node) (Node, error) { } return f(left, j.Attr), nil } -func (j patternJSON) ToNode(f func(a Node, k string) Node) (Node, error) { +func (j patternJSON) ToNode(f func(a Node, k Pattern) Node) (Node, error) { left, err := j.Left.ToNode() if err != nil { return Node{}, fmt.Errorf("error in left: %w", err) } - return f(left, j.Pattern), nil + pattern := &Pattern{} + for _, compJSON := range j.Pattern { + if compJSON.Wildcard { + pattern = pattern.AddWildcard() + } else { + pattern = pattern.AddLiteral(compJSON.Literal.Literal) + } + } + + return f(left, *pattern), nil } func (j isJSON) ToNode() (Node, error) { left, err := j.Left.ToNode() @@ -301,6 +310,20 @@ func (n *nodeJSON) UnmarshalJSON(b []byte) error { return json.Unmarshal(b, &n.ExtensionMethod) } +func (p *patternComponentJSON) UnmarshalJSON(b []byte) error { + var wildcard string + err := json.Unmarshal(b, &wildcard) + if err == nil { + if wildcard != "Wildcard" { + return fmt.Errorf("unknown pattern component: \"%v\"", wildcard) + } + p.Wildcard = true + return nil + } + + return json.Unmarshal(b, &p.Literal) +} + func (p *Policy) UnmarshalJSON(b []byte) error { var j policyJSON if err := json.Unmarshal(b, &j); err != nil { diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 6c8a0ac2..5a2e2c89 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -10,7 +10,12 @@ type strOpNode struct { type nodeTypeAccess struct{ strOpNode } type nodeTypeHas struct{ strOpNode } -type nodeTypeLike struct{ strOpNode } + +type nodeTypeLike struct { + node + Arg node + Value Pattern +} type nodeTypeAnnotation struct { node diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index e27bf01d..52b3584e 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -49,8 +49,8 @@ func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { return newExtMethodCallNode(lhs, "greaterThanOrEqual", rhs) } -func (lhs Node) Like(patt string) Node { - return newNode(nodeTypeLike{strOpNode: strOpNode{Arg: lhs.v, Value: types.String(patt)}}) +func (lhs Node) Like(pattern Pattern) Node { + return newNode(nodeTypeLike{Arg: lhs.v, Value: pattern}) } // _ _ _ diff --git a/x/exp/ast/parser.go b/x/exp/ast/parser.go index 08d46ac3..75da1566 100644 --- a/x/exp/ast/parser.go +++ b/x/exp/ast/parser.go @@ -4,6 +4,7 @@ import ( "fmt" "net/netip" "strconv" + "strings" "github.com/cedar-policy/cedar-go/types" ) @@ -482,8 +483,19 @@ func (p *parser) relation() (Node, error) { } return Node{}, p.errorf("expected ident or string") } else if t.Text == "like" { - // TODO: Deal with pattern matching - return Node{}, p.errorf("unimplemented") + p.advance() + t = p.advance() + if !t.isString() { + return Node{}, p.errorf("expected string literal") + } + patternRaw := t.Text + patternRaw = strings.TrimPrefix(patternRaw, "\"") + patternRaw = strings.TrimSuffix(patternRaw, "\"") + pattern, err := PatternFromCedar(patternRaw) + if err != nil { + return Node{}, err + } + return lhs.Like(pattern), nil } else if t.Text == "is" { p.advance() entityType, err := p.path() diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/parser_test.go index 8e23e014..b7b03c49 100644 --- a/x/exp/ast/parser_test.go +++ b/x/exp/ast/parser_test.go @@ -248,12 +248,25 @@ func TestParse(t *testing.T) { when { principal has "firstName" };`, Permit().When(Principal().Has("firstName")), }, - // { - // "like no wildcards", - // `permit (principal, action, resource) - // when { principal.firstName like "johnny" };`, - // Permit().When(Principal().Has("firstName")), - // }, + // N.B. Most pattern parsing tests can be found in pattern_test.go + { + "like no wildcards", + `permit (principal, action, resource) + when { principal.firstName like "johnny" };`, + Permit().When(Principal().Access("firstName").Like(testutil.Must(PatternFromCedar("johnny")))), + }, + { + "like escaped asterisk", + `permit (principal, action, resource) + when { principal.firstName like "joh\*nny" };`, + Permit().When(Principal().Access("firstName").Like(testutil.Must(PatternFromCedar(`joh\*nny`)))), + }, + { + "like wildcard", + `permit (principal, action, resource) + when { principal.firstName like "*" };`, + Permit().When(Principal().Access("firstName").Like(testutil.Must(PatternFromCedar("*")))), + }, { "is", `permit (principal, action, resource) diff --git a/x/exp/ast/pattern.go b/x/exp/ast/pattern.go index 9be16a54..1f19f5a5 100644 --- a/x/exp/ast/pattern.go +++ b/x/exp/ast/pattern.go @@ -1,7 +1,5 @@ package ast -import "strings" - type PatternComponent struct { Star bool Chunk string @@ -9,20 +7,10 @@ type PatternComponent struct { type Pattern struct { Comps []PatternComponent - Raw string -} - -func (p Pattern) String() string { - return p.Raw } -func NewPattern(literal string) (Pattern, error) { - rawPat := literal - - literal = strings.TrimPrefix(literal, "\"") - literal = strings.TrimSuffix(literal, "\"") - - b := []byte(literal) +func PatternFromCedar(cedar string) (Pattern, error) { + b := []byte(cedar) var comps []PatternComponent for len(b) > 0 { @@ -40,6 +28,31 @@ func NewPattern(literal string) (Pattern, error) { } return Pattern{ Comps: comps, - Raw: rawPat, }, nil } + +func (p *Pattern) AddWildcard() *Pattern { + star := PatternComponent{Star: true} + if len(p.Comps) == 0 { + p.Comps = []PatternComponent{star} + return p + } + + lastComp := p.Comps[len(p.Comps)-1] + if lastComp.Star && lastComp.Chunk == "" { + return p + } + + p.Comps = append(p.Comps, star) + return p +} + +func (p *Pattern) AddLiteral(s string) *Pattern { + if len(p.Comps) == 0 { + p.Comps = []PatternComponent{{}} + } + + lastComp := &p.Comps[len(p.Comps)-1] + lastComp.Chunk = lastComp.Chunk + s + return p +} diff --git a/x/exp/ast/patttern_test.go b/x/exp/ast/patttern_test.go index 3fb03c99..2eeecbe4 100644 --- a/x/exp/ast/patttern_test.go +++ b/x/exp/ast/patttern_test.go @@ -6,7 +6,7 @@ import ( "github.com/cedar-policy/cedar-go/testutil" ) -func TestPatternFromStringLiteral(t *testing.T) { +func TestPatternFromCedar(t *testing.T) { t.Parallel() tests := []struct { input string @@ -14,52 +14,75 @@ func TestPatternFromStringLiteral(t *testing.T) { want []PatternComponent wantErr string }{ - {`""`, true, nil, ""}, - {`"a"`, true, []PatternComponent{{false, "a"}}, ""}, - {`"*"`, true, []PatternComponent{{true, ""}}, ""}, - {`"*a"`, true, []PatternComponent{{true, "a"}}, ""}, - {`"a*"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, - {`"**"`, true, []PatternComponent{{true, ""}}, ""}, - {`"**a"`, true, []PatternComponent{{true, "a"}}, ""}, - {`"a**"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, - {`"*a*"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, - {`"**a**"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, - {`"abra*ca"`, true, []PatternComponent{ + {"", true, nil, ""}, + {"a", true, []PatternComponent{{false, "a"}}, ""}, + {"*", true, []PatternComponent{{true, ""}}, ""}, + {"*a", true, []PatternComponent{{true, "a"}}, ""}, + {"a*", true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, + {"**", true, []PatternComponent{{true, ""}}, ""}, + {"**a", true, []PatternComponent{{true, "a"}}, ""}, + {"a**", true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, + {"*a*", true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, + {"**a**", true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, + {"abra*ca", true, []PatternComponent{ {false, "abra"}, {true, "ca"}, }, ""}, - {`"abra**ca"`, true, []PatternComponent{ + {"abra**ca", true, []PatternComponent{ {false, "abra"}, {true, "ca"}, }, ""}, - {`"*abra*ca"`, true, []PatternComponent{ + {"*abra*ca", true, []PatternComponent{ {true, "abra"}, {true, "ca"}, }, ""}, - {`"abra*ca*"`, true, []PatternComponent{ + {"abra*ca*", true, []PatternComponent{ {false, "abra"}, {true, "ca"}, {true, ""}, }, ""}, - {`"*abra*ca*"`, true, []PatternComponent{ + {"*abra*ca*", true, []PatternComponent{ {true, "abra"}, {true, "ca"}, {true, ""}, }, ""}, - {`"*abra*ca*dabra"`, true, []PatternComponent{ + {"*abra*ca*dabra", true, []PatternComponent{ {true, "abra"}, {true, "ca"}, {true, "dabra"}, }, ""}, - {`"*abra*c\**da\*ra"`, true, []PatternComponent{ + {`*abra*c\**da\*ra`, true, []PatternComponent{ {true, "abra"}, {true, "c*"}, {true, "da*ra"}, }, ""}, - {`"\u"`, false, nil, "bad unicode rune"}, + {`\u`, false, nil, "bad unicode rune"}, } for _, tt := range tests { tt := tt t.Run(tt.input, func(t *testing.T) { t.Parallel() - got, err := NewPattern(tt.input) + got, err := PatternFromCedar(tt.input) if err != nil { testutil.Equals(t, tt.wantOk, false) testutil.Equals(t, err.Error(), tt.wantErr) } else { testutil.Equals(t, tt.wantOk, true) testutil.Equals(t, got.Comps, tt.want) - testutil.Equals(t, got.String(), tt.input) } }) } } + +func TestPatternFromBuilder(t *testing.T) { + tests := []struct { + name string + Pattern *Pattern + want []PatternComponent + }{ + {"empty", &Pattern{}, nil}, + {"wildcard", (&Pattern{}).AddWildcard(), []PatternComponent{{Star: true}}}, + {"saturate two wildcards", (&Pattern{}).AddWildcard().AddWildcard(), []PatternComponent{{Star: true}}}, + {"literal", (&Pattern{}).AddLiteral("foo"), []PatternComponent{{Chunk: "foo"}}}, + {"saturate two literals", (&Pattern{}).AddLiteral("foo").AddLiteral("bar"), []PatternComponent{{Chunk: "foobar"}}}, + {"literal with asterisk", (&Pattern{}).AddLiteral("fo*o"), []PatternComponent{{Chunk: "fo*o"}}}, + {"wildcard sandwich", (&Pattern{}).AddLiteral("foo").AddWildcard().AddLiteral("bar"), []PatternComponent{{Chunk: "foo"}, {Star: true, Chunk: "bar"}}}, + {"literal sandwich", (&Pattern{}).AddWildcard().AddLiteral("foo").AddWildcard(), []PatternComponent{{Star: true, Chunk: "foo"}, {Star: true}}}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + testutil.Equals(t, tt.Pattern.Comps, tt.want) + }) + } +} From 058f76cd0d7143d24a83217596e8eb6b0fe65eaf Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 2 Aug 2024 13:35:03 -0700 Subject: [PATCH 046/216] cedar-go/x/exp/ast: enable test of negate operator Let's punt the optimization to convert -1 to Long(-1) instead of Negate(Long(1)). We can potentially use similar code to accomplish the same thing in both the JSON and Cedar parsers. Signed-off-by: philhassey --- x/exp/ast/parser.go | 5 ++--- x/exp/ast/parser_test.go | 12 ++++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/x/exp/ast/parser.go b/x/exp/ast/parser.go index 75da1566..c91757ff 100644 --- a/x/exp/ast/parser.go +++ b/x/exp/ast/parser.go @@ -561,8 +561,8 @@ func (p *parser) unary() (Node, error) { var res Node var ops [](func(Node) Node) for { - op := p.peek().Text - switch op { + op := p.peek() + switch op.Text { case "!": p.advance() ops = append(ops, Not) @@ -576,7 +576,6 @@ func (p *parser) unary() (Node, error) { return res, err } - // TODO: add support for parsing -1 into a negative Long rather than a Negate(Long) for i := len(ops) - 1; i >= 0; i-- { res = ops[i](res) } diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/parser_test.go index b7b03c49..d69a0ad9 100644 --- a/x/exp/ast/parser_test.go +++ b/x/exp/ast/parser_test.go @@ -140,12 +140,12 @@ func TestParse(t *testing.T) { when { !true };`, Permit().When(Not(Boolean(true))), }, - // { - // "negate operator", - // `permit (principal, action, resource) - // when { -1 };`, - // Permit().When(Long(-1)), - // }, + { + "negate operator", + `permit (principal, action, resource) + when { -1 };`, + Permit().When(Negate(Long(1))), + }, { "variable member", `permit (principal, action, resource) From 368c991fd390dd59776b6700930b713d21e7d15d Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 2 Aug 2024 13:37:50 -0700 Subject: [PATCH 047/216] cedar-go/x/exp/ast: remove the unused FakeRustQuote This will be replaced when we start doing AST -> Cedar conversions. Signed-off-by: philhassey --- x/exp/ast/tokenize.go | 5 ----- x/exp/ast/tokenize_test.go | 6 ------ 2 files changed, 11 deletions(-) diff --git a/x/exp/ast/tokenize.go b/x/exp/ast/tokenize.go index 6d2cfebb..8f8fa319 100644 --- a/x/exp/ast/tokenize.go +++ b/x/exp/ast/tokenize.go @@ -194,11 +194,6 @@ func isHexadecimal(ch rune) bool { return isDecimal(ch) || ('a' <= lower(ch) && lower(ch) <= 'f') } -// TODO: make FakeRustQuote actually accurate in all cases -func FakeRustQuote(s string) string { - return strconv.Quote(s) -} - func (t Token) intValue() (int64, error) { return strconv.ParseInt(t.Text, 10, 64) } diff --git a/x/exp/ast/tokenize_test.go b/x/exp/ast/tokenize_test.go index 0250b3ad..fbf9db5a 100644 --- a/x/exp/ast/tokenize_test.go +++ b/x/exp/ast/tokenize_test.go @@ -405,12 +405,6 @@ func TestRustUnquote(t *testing.T) { } } -func TestFakeRustQuote(t *testing.T) { - t.Parallel() - out := FakeRustQuote("hello") - testutil.Equals(t, out, `"hello"`) -} - func TestScanner(t *testing.T) { t.Parallel() t.Run("SrcError", func(t *testing.T) { From d729c1eb912bd6524528997843b6f70574c23130 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 2 Aug 2024 13:43:04 -0700 Subject: [PATCH 048/216] cedar-go/x/exp/ast: port over fuzz tests from old parser Signed-off-by: philhassey --- x/exp/ast/fuzz_test.go | 106 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 x/exp/ast/fuzz_test.go diff --git a/x/exp/ast/fuzz_test.go b/x/exp/ast/fuzz_test.go new file mode 100644 index 00000000..cc4d9e11 --- /dev/null +++ b/x/exp/ast/fuzz_test.go @@ -0,0 +1,106 @@ +package ast + +import ( + "testing" +) + +// https://go.dev/doc/tutorial/fuzz +// mkdir testdata +// go test -fuzz=FuzzTokenize -fuzztime 60s +// go test -fuzz=FuzzParse -fuzztime 60s + +func FuzzTokenize(f *testing.F) { + tests := []string{ + `These are some identifiers`, + `0 1 1234`, + `-1 9223372036854775807 -9223372036854775808`, + `"" "string" "\"\'\n\r\t\\\0" "\x123" "\u{0}\u{10fFfF}"`, + `"*" "\*" "*\**"`, + `@.,;(){}[]+-*`, + `:::`, + `!!=<<=>>=`, + `||&&`, + `// single line comment`, + `/*`, + `multiline comment`, + `// embedded comment does nothing`, + `*/`, + `'/%|&=`, + } + for _, tt := range tests { + f.Add(tt) + } + f.Fuzz(func(t *testing.T, orig string) { + toks, err := Tokenize([]byte(orig)) + if err != nil { + if toks != nil { + t.Errorf("toks != nil on err") + } + } + }) +} + +func FuzzParse(f *testing.F) { + tests := []string{ + `permit(principal,action,resource);`, + `forbid(principal,action,resource);`, + `permit(principal,action,resource in asdf::"1234");`, + `permit(principal,action,resource) when { resource in "foo" };`, + `permit(principal,action,resource) when { context.x == 42 };`, + `permit(principal,action,resource) when { context.x == 42 };`, + `permit(principal,action,resource) when { principal.x == 42 };`, + `permit(principal,action,resource) when { principal.x == 42 };`, + `permit(principal,action,resource) when { principal in parent::"bob" };`, + `permit(principal == coder::"cuzco",action,resource);`, + `permit(principal in team::"osiris",action,resource);`, + `permit(principal,action == table::"drop",resource);`, + `permit(principal,action in scary::"stuff",resource);`, + `permit(principal,action in [scary::"stuff"],resource);`, + `permit(principal,action,resource == table::"whatever");`, + `permit(principal,action,resource) unless { false };`, + `permit(principal,action,resource) when { (if true then true else true) };`, + `permit(principal,action,resource) when { (true || false) };`, + `permit(principal,action,resource) when { (true && true) };`, + `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, + `permit(principal,action,resource) when { principal in principal };`, + `permit(principal,action,resource) when { principal has name };`, + `permit(principal,action,resource) when { 40+3-1==42 };`, + `permit(principal,action,resource) when { 6*7==42 };`, + `permit(principal,action,resource) when { -42==-42 };`, + `permit(principal,action,resource) when { !(1+1==42) };`, + `permit(principal,action,resource) when { [1,2,3].contains(2) };`, + `permit(principal,action,resource) when { {name:"bob"} has name };`, + `permit(principal,action,resource) when { action in action };`, + `permit(principal,action,resource) when { [1,2,3].contains(2) };`, + `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, + `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, + `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, + `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, + `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, + `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, + `permit(principal,action,resource) when { [1,2,3].shuffle() };`, + `permit(principal,action,resource) when { "bananas" like "*nan*" };`, + `permit(principal,action,resource) when { fooBar("10") };`, + `permit(principal,action,resource) when { decimal(1, 2) };`, + `permit(principal,action,resource) when { ip() };`, + `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, + `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, + `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, + `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, + `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, + } + for _, tt := range tests { + f.Add(tt) + } + f.Fuzz(func(_ *testing.T, orig string) { + tokens, err := Tokenize([]byte(orig)) + if err != nil { + return + } + + parser := newParser(tokens) + + // intentionally ignore parse errors + _, _ = policyFromCedar(&parser) + }) +} From 515f04703c7f2e23d5f2594215ac128234552ff8 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 2 Aug 2024 16:15:29 -0700 Subject: [PATCH 049/216] cedar-go/x/exp/ast: refactor parser.unary() a bit for brevity Signed-off-by: philhassey --- x/exp/ast/parser.go | 43 ++++++++++++++++++++-------------------- x/exp/ast/parser_test.go | 6 ++++++ 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/x/exp/ast/parser.go b/x/exp/ast/parser.go index c91757ff..f8011a58 100644 --- a/x/exp/ast/parser.go +++ b/x/exp/ast/parser.go @@ -558,30 +558,31 @@ func (p *parser) mult() (Node, error) { } func (p *parser) unary() (Node, error) { - var res Node - var ops [](func(Node) Node) - for { - op := p.peek() - switch op.Text { - case "!": - p.advance() - ops = append(ops, Not) - case "-": - p.advance() - ops = append(ops, Negate) - default: - var err error - res, err = p.member() - if err != nil { - return res, err - } + opMap := map[string]func(Node) Node{ + "-": Negate, + "!": Not, + } - for i := len(ops) - 1; i >= 0; i-- { - res = ops[i](res) - } - return res, nil + var ops []func(Node) Node + for { + opToken := p.peek() + op, ok := opMap[opToken.Text] + if !ok { + break } + p.advance() + ops = append(ops, op) + } + + res, err := p.member() + if err != nil { + return res, err } + + for i := len(ops) - 1; i >= 0; i-- { + res = ops[i](res) + } + return res, nil } func (p *parser) member() (Node, error) { diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/parser_test.go index d69a0ad9..d8170433 100644 --- a/x/exp/ast/parser_test.go +++ b/x/exp/ast/parser_test.go @@ -146,6 +146,12 @@ func TestParse(t *testing.T) { when { -1 };`, Permit().When(Negate(Long(1))), }, + { + "mutliple negate operators", + `permit (principal, action, resource) + when { !--1 };`, + Permit().When(Not(Negate(Negate(Long(1))))), + }, { "variable member", `permit (principal, action, resource) From 235d91b8c2eefb9aa0c1ce12d6f7d0420c8e9998 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 2 Aug 2024 16:47:15 -0700 Subject: [PATCH 050/216] cedar-go/x/exp/ast: introduce a way to parse a whole policy document Signed-off-by: philhassey --- x/exp/ast/fuzz_test.go | 10 +---- x/exp/ast/parser.go | 80 ++++++++++++++++++++++++++++------------ x/exp/ast/parser_test.go | 52 ++++++++++++++++++++++---- x/exp/ast/policy.go | 2 + x/exp/ast/tokenize.go | 4 ++ 5 files changed, 108 insertions(+), 40 deletions(-) diff --git a/x/exp/ast/fuzz_test.go b/x/exp/ast/fuzz_test.go index cc4d9e11..eccb1251 100644 --- a/x/exp/ast/fuzz_test.go +++ b/x/exp/ast/fuzz_test.go @@ -93,14 +93,8 @@ func FuzzParse(f *testing.F) { f.Add(tt) } f.Fuzz(func(_ *testing.T, orig string) { - tokens, err := Tokenize([]byte(orig)) - if err != nil { - return - } - - parser := newParser(tokens) - // intentionally ignore parse errors - _, _ = policyFromCedar(&parser) + var policy Policy + _ = policy.FromCedar([]byte(orig)) }) } diff --git a/x/exp/ast/parser.go b/x/exp/ast/parser.go index f8011a58..50a8270c 100644 --- a/x/exp/ast/parser.go +++ b/x/exp/ast/parser.go @@ -9,46 +9,78 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func policyFromCedar(p *parser) (*Policy, error) { - annotations, err := p.annotations() +func (p *PolicySet) FromCedar(b []byte) error { + tokens, err := Tokenize(b) if err != nil { - return nil, err + return err + } + + var policySet PolicySet + parser := newParser(tokens) + for !parser.peek().isEOF() { + var policy Policy + if err = policy.fromCedarWithParser(&parser); err != nil { + return err + } + + policySet = append(policySet, policy) } - policy, err := p.effect(&annotations) + *p = policySet + return nil +} + +func (p *Policy) FromCedar(b []byte) error { + tokens, err := Tokenize(b) if err != nil { - return nil, err + return err } - if err = p.exact("("); err != nil { - return nil, err + parser := newParser(tokens) + return p.fromCedarWithParser(&parser) +} + +func (p *Policy) fromCedarWithParser(parser *parser) error { + annotations, err := parser.annotations() + if err != nil { + return err + } + + newPolicy, err := parser.effect(&annotations) + if err != nil { + return err + } + + if err = parser.exact("("); err != nil { + return err } - if err = p.principal(policy); err != nil { - return nil, err + if err = parser.principal(newPolicy); err != nil { + return err } - if err = p.exact(","); err != nil { - return nil, err + if err = parser.exact(","); err != nil { + return err } - if err = p.action(policy); err != nil { - return nil, err + if err = parser.action(newPolicy); err != nil { + return err } - if err = p.exact(","); err != nil { - return nil, err + if err = parser.exact(","); err != nil { + return err } - if err = p.resource(policy); err != nil { - return nil, err + if err = parser.resource(newPolicy); err != nil { + return err } - if err = p.exact(")"); err != nil { - return nil, err + if err = parser.exact(")"); err != nil { + return err } - if err = p.conditions(policy); err != nil { - return nil, err + if err = parser.conditions(newPolicy); err != nil { + return err } - if err = p.exact(";"); err != nil { - return nil, err + if err = parser.exact(";"); err != nil { + return err } - return policy, nil + *p = *newPolicy + return nil } type parser struct { diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/parser_test.go index d8170433..50001cbd 100644 --- a/x/exp/ast/parser_test.go +++ b/x/exp/ast/parser_test.go @@ -36,7 +36,7 @@ var malus = types.EntityUID{ ID: "malus", } -func TestParse(t *testing.T) { +func TestParsePolicy(t *testing.T) { t.Parallel() parseTests := []struct { Name string @@ -363,15 +363,51 @@ func TestParse(t *testing.T) { t.Run(tt.Name, func(t *testing.T) { t.Parallel() - tokens, err := Tokenize([]byte(tt.Text)) - testutil.OK(t, err) - - parser := newParser(tokens) + var policy Policy + testutil.OK(t, policy.FromCedar([]byte(tt.Text))) + testutil.Equals(t, policy, *tt.ExpectedPolicy) + }) + } +} - policy, err := policyFromCedar(&parser) - testutil.OK(t, err) +func TestParsePolicySet(t *testing.T) { + t.Parallel() + parseTests := []struct { + Name string + Text string + ExpectedPolicies PolicySet + }{ + { + "single policy", + `permit ( + principal, + action, + resource + );`, + PolicySet{*Permit()}, + }, + { + "two policies", + `permit ( + principal, + action, + resource + ); + forbid ( + principal, + action, + resource + );`, + PolicySet{*Permit(), *Forbid()}, + }, + } + for _, tt := range parseTests { + t.Run(tt.Name, func(t *testing.T) { + t.Parallel() - testutil.Equals(t, policy, tt.ExpectedPolicy) + var policies PolicySet + testutil.OK(t, policies.FromCedar([]byte(tt.Text))) + testutil.Equals(t, policies, tt.ExpectedPolicies) }) } } diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index 29305c6b..e1347a31 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -1,5 +1,7 @@ package ast +type PolicySet []Policy + type Policy struct { effect effect annotations []nodeTypeAnnotation diff --git a/x/exp/ast/tokenize.go b/x/exp/ast/tokenize.go index 8f8fa319..97101b34 100644 --- a/x/exp/ast/tokenize.go +++ b/x/exp/ast/tokenize.go @@ -32,6 +32,10 @@ type Token struct { Text string } +func (t Token) isEOF() bool { + return t.Type == TokenEOF +} + func (t Token) isIdent() bool { return t.Type == TokenIdent } From ee3c44422e0bcc9f5104b5193bd2876806383046 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 2 Aug 2024 16:49:14 -0700 Subject: [PATCH 051/216] cedar-go/x/exp/ast: rename FromCedar to UnmarshalCedar, parser.go to cedar_unmarshal.go, and tokenize.go to cedar_tokenize.go Signed-off-by: philhassey --- x/exp/ast/{fuzz_test.go => cedar_fuzz_test.go} | 2 +- x/exp/ast/{tokenize.go => cedar_tokenize.go} | 0 .../{tokenize_mocks_test.go => cedar_tokenize_mocks_test.go} | 0 x/exp/ast/{tokenize_test.go => cedar_tokenize_test.go} | 0 x/exp/ast/{parser.go => cedar_unmarshal.go} | 4 ++-- x/exp/ast/{parser_test.go => cedar_unmarshal_test.go} | 4 ++-- 6 files changed, 5 insertions(+), 5 deletions(-) rename x/exp/ast/{fuzz_test.go => cedar_fuzz_test.go} (98%) rename x/exp/ast/{tokenize.go => cedar_tokenize.go} (100%) rename x/exp/ast/{tokenize_mocks_test.go => cedar_tokenize_mocks_test.go} (100%) rename x/exp/ast/{tokenize_test.go => cedar_tokenize_test.go} (100%) rename x/exp/ast/{parser.go => cedar_unmarshal.go} (99%) rename x/exp/ast/{parser_test.go => cedar_unmarshal_test.go} (98%) diff --git a/x/exp/ast/fuzz_test.go b/x/exp/ast/cedar_fuzz_test.go similarity index 98% rename from x/exp/ast/fuzz_test.go rename to x/exp/ast/cedar_fuzz_test.go index eccb1251..bc6bb523 100644 --- a/x/exp/ast/fuzz_test.go +++ b/x/exp/ast/cedar_fuzz_test.go @@ -95,6 +95,6 @@ func FuzzParse(f *testing.F) { f.Fuzz(func(_ *testing.T, orig string) { // intentionally ignore parse errors var policy Policy - _ = policy.FromCedar([]byte(orig)) + _ = policy.UnmarshalCedar([]byte(orig)) }) } diff --git a/x/exp/ast/tokenize.go b/x/exp/ast/cedar_tokenize.go similarity index 100% rename from x/exp/ast/tokenize.go rename to x/exp/ast/cedar_tokenize.go diff --git a/x/exp/ast/tokenize_mocks_test.go b/x/exp/ast/cedar_tokenize_mocks_test.go similarity index 100% rename from x/exp/ast/tokenize_mocks_test.go rename to x/exp/ast/cedar_tokenize_mocks_test.go diff --git a/x/exp/ast/tokenize_test.go b/x/exp/ast/cedar_tokenize_test.go similarity index 100% rename from x/exp/ast/tokenize_test.go rename to x/exp/ast/cedar_tokenize_test.go diff --git a/x/exp/ast/parser.go b/x/exp/ast/cedar_unmarshal.go similarity index 99% rename from x/exp/ast/parser.go rename to x/exp/ast/cedar_unmarshal.go index 50a8270c..a55eae91 100644 --- a/x/exp/ast/parser.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -9,7 +9,7 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func (p *PolicySet) FromCedar(b []byte) error { +func (p *PolicySet) UnmarshalCedar(b []byte) error { tokens, err := Tokenize(b) if err != nil { return err @@ -30,7 +30,7 @@ func (p *PolicySet) FromCedar(b []byte) error { return nil } -func (p *Policy) FromCedar(b []byte) error { +func (p *Policy) UnmarshalCedar(b []byte) error { tokens, err := Tokenize(b) if err != nil { return err diff --git a/x/exp/ast/parser_test.go b/x/exp/ast/cedar_unmarshal_test.go similarity index 98% rename from x/exp/ast/parser_test.go rename to x/exp/ast/cedar_unmarshal_test.go index 50001cbd..34052ca6 100644 --- a/x/exp/ast/parser_test.go +++ b/x/exp/ast/cedar_unmarshal_test.go @@ -364,7 +364,7 @@ func TestParsePolicy(t *testing.T) { t.Parallel() var policy Policy - testutil.OK(t, policy.FromCedar([]byte(tt.Text))) + testutil.OK(t, policy.UnmarshalCedar([]byte(tt.Text))) testutil.Equals(t, policy, *tt.ExpectedPolicy) }) } @@ -406,7 +406,7 @@ func TestParsePolicySet(t *testing.T) { t.Parallel() var policies PolicySet - testutil.OK(t, policies.FromCedar([]byte(tt.Text))) + testutil.OK(t, policies.UnmarshalCedar([]byte(tt.Text))) testutil.Equals(t, policies, tt.ExpectedPolicies) }) } From 8e2b81a212b35aaf5f635ef8309995581be6358e Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 2 Aug 2024 16:58:44 -0700 Subject: [PATCH 052/216] cedar-go/x/exp/ast: move cedar_unmarshal_test.go out of ast and into ast_test Signed-off-by: philhassey --- x/exp/ast/cedar_unmarshal_test.go | 109 +++++++++++++++--------------- 1 file changed, 55 insertions(+), 54 deletions(-) diff --git a/x/exp/ast/cedar_unmarshal_test.go b/x/exp/ast/cedar_unmarshal_test.go index 34052ca6..0676d409 100644 --- a/x/exp/ast/cedar_unmarshal_test.go +++ b/x/exp/ast/cedar_unmarshal_test.go @@ -1,10 +1,11 @@ -package ast +package ast_test import ( "testing" "github.com/cedar-policy/cedar-go/testutil" "github.com/cedar-policy/cedar-go/types" + "github.com/cedar-policy/cedar-go/x/exp/ast" ) var johnny = types.EntityUID{ @@ -41,7 +42,7 @@ func TestParsePolicy(t *testing.T) { parseTests := []struct { Name string Text string - ExpectedPolicy *Policy + ExpectedPolicy *ast.Policy }{ { "permit any scope", @@ -50,7 +51,7 @@ func TestParsePolicy(t *testing.T) { action, resource );`, - Permit(), + ast.Permit(), }, { "forbid any scope", @@ -59,7 +60,7 @@ func TestParsePolicy(t *testing.T) { action, resource );`, - Forbid(), + ast.Forbid(), }, { "one annotation", @@ -69,7 +70,7 @@ func TestParsePolicy(t *testing.T) { action, resource );`, - Annotation("foo", "bar").Permit(), + ast.Annotation("foo", "bar").Permit(), }, { "two annotations", @@ -80,7 +81,7 @@ func TestParsePolicy(t *testing.T) { action, resource );`, - Annotation("foo", "bar").Annotation("baz", "quux").Permit(), + ast.Annotation("foo", "bar").Annotation("baz", "quux").Permit(), }, { "scope eq", @@ -89,7 +90,7 @@ func TestParsePolicy(t *testing.T) { action == Action::"sow", resource == Crop::"apple" );`, - Permit().PrincipalEq(johnny).ActionEq(sow).ResourceEq(apple), + ast.Permit().PrincipalEq(johnny).ActionEq(sow).ResourceEq(apple), }, { "scope is", @@ -98,7 +99,7 @@ func TestParsePolicy(t *testing.T) { action, resource is Crop );`, - Permit().PrincipalIs("User").ResourceIs("Crop"), + ast.Permit().PrincipalIs("User").ResourceIs("Crop"), }, { "scope is in", @@ -107,7 +108,7 @@ func TestParsePolicy(t *testing.T) { action, resource is Crop in Genus::"malus" );`, - Permit().PrincipalIsIn("User", folkHeroes).ResourceIsIn("Crop", malus), + ast.Permit().PrincipalIsIn("User", folkHeroes).ResourceIsIn("Crop", malus), }, { "scope in", @@ -116,7 +117,7 @@ func TestParsePolicy(t *testing.T) { action in ActionType::"farming", resource in Genus::"malus" );`, - Permit().PrincipalIn(folkHeroes).ActionIn(farming).ResourceIn(malus), + ast.Permit().PrincipalIn(folkHeroes).ActionIn(farming).ResourceIn(malus), }, { "scope action in entities", @@ -125,237 +126,237 @@ func TestParsePolicy(t *testing.T) { action in [ActionType::"farming", ActionType::"forestry"], resource );`, - Permit().ActionInSet(farming, forestry), + ast.Permit().ActionInSet(farming, forestry), }, { "trivial conditions", `permit (principal, action, resource) when { true } unless { false };`, - Permit().When(Boolean(true)).Unless(Boolean(false)), + ast.Permit().When(ast.Boolean(true)).Unless(ast.Boolean(false)), }, { "not operator", `permit (principal, action, resource) when { !true };`, - Permit().When(Not(Boolean(true))), + ast.Permit().When(ast.Not(ast.Boolean(true))), }, { "negate operator", `permit (principal, action, resource) when { -1 };`, - Permit().When(Negate(Long(1))), + ast.Permit().When(ast.Negate(ast.Long(1))), }, { "mutliple negate operators", `permit (principal, action, resource) when { !--1 };`, - Permit().When(Not(Negate(Negate(Long(1))))), + ast.Permit().When(ast.Not(ast.Negate(ast.Negate(ast.Long(1))))), }, { "variable member", `permit (principal, action, resource) when { context.boolValue };`, - Permit().When(Context().Access("boolValue")), + ast.Permit().When(ast.Context().Access("boolValue")), }, { "contains method call", `permit (principal, action, resource) when { context.strings.contains("foo") };`, - Permit().When(Context().Access("strings").Contains(String("foo"))), + ast.Permit().When(ast.Context().Access("strings").Contains(ast.String("foo"))), }, { "containsAll method call", `permit (principal, action, resource) when { context.strings.containsAll(["foo"]) };`, - Permit().When(Context().Access("strings").ContainsAll(SetNodes(String("foo")))), + ast.Permit().When(ast.Context().Access("strings").ContainsAll(ast.SetNodes(ast.String("foo")))), }, { "containsAny method call", `permit (principal, action, resource) when { context.strings.containsAny(["foo"]) };`, - Permit().When(Context().Access("strings").ContainsAny(SetNodes(String("foo")))), + ast.Permit().When(ast.Context().Access("strings").ContainsAny(ast.SetNodes(ast.String("foo")))), }, { "extension method call", `permit (principal, action, resource) when { context.sourceIP.isIpv4() };`, - Permit().When(Context().Access("sourceIP").IsIpv4()), + ast.Permit().When(ast.Context().Access("sourceIP").IsIpv4()), }, { "multiplication", `permit (principal, action, resource) when { 42 * 2 };`, - Permit().When(Long(42).Times(Long(2))), + ast.Permit().When(ast.Long(42).Times(ast.Long(2))), }, { "addition", `permit (principal, action, resource) when { 42 + 2 };`, - Permit().When(Long(42).Plus(Long(2))), + ast.Permit().When(ast.Long(42).Plus(ast.Long(2))), }, { "subtraction", `permit (principal, action, resource) when { 42 - 2 };`, - Permit().When(Long(42).Minus(Long(2))), + ast.Permit().When(ast.Long(42).Minus(ast.Long(2))), }, { "less than", `permit (principal, action, resource) when { 2 < 42 };`, - Permit().When(Long(2).LessThan(Long(42))), + ast.Permit().When(ast.Long(2).LessThan(ast.Long(42))), }, { "less than or equal", `permit (principal, action, resource) when { 2 <= 42 };`, - Permit().When(Long(2).LessThanOrEqual(Long(42))), + ast.Permit().When(ast.Long(2).LessThanOrEqual(ast.Long(42))), }, { "greater than", `permit (principal, action, resource) when { 2 > 42 };`, - Permit().When(Long(2).GreaterThan(Long(42))), + ast.Permit().When(ast.Long(2).GreaterThan(ast.Long(42))), }, { "greater than or equal", `permit (principal, action, resource) when { 2 >= 42 };`, - Permit().When(Long(2).GreaterThanOrEqual(Long(42))), + ast.Permit().When(ast.Long(2).GreaterThanOrEqual(ast.Long(42))), }, { "equal", `permit (principal, action, resource) when { 2 == 42 };`, - Permit().When(Long(2).Equals(Long(42))), + ast.Permit().When(ast.Long(2).Equals(ast.Long(42))), }, { "not equal", `permit (principal, action, resource) when { 2 != 42 };`, - Permit().When(Long(2).NotEquals(Long(42))), + ast.Permit().When(ast.Long(2).NotEquals(ast.Long(42))), }, { "in", `permit (principal, action, resource) when { principal in Group::"folkHeroes" };`, - Permit().When(Principal().In(Entity(folkHeroes))), + ast.Permit().When(ast.Principal().In(ast.Entity(folkHeroes))), }, { "has ident", `permit (principal, action, resource) when { principal has firstName };`, - Permit().When(Principal().Has("firstName")), + ast.Permit().When(ast.Principal().Has("firstName")), }, { "has string", `permit (principal, action, resource) when { principal has "firstName" };`, - Permit().When(Principal().Has("firstName")), + ast.Permit().When(ast.Principal().Has("firstName")), }, // N.B. Most pattern parsing tests can be found in pattern_test.go { "like no wildcards", `permit (principal, action, resource) when { principal.firstName like "johnny" };`, - Permit().When(Principal().Access("firstName").Like(testutil.Must(PatternFromCedar("johnny")))), + ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(ast.PatternFromCedar("johnny")))), }, { "like escaped asterisk", `permit (principal, action, resource) when { principal.firstName like "joh\*nny" };`, - Permit().When(Principal().Access("firstName").Like(testutil.Must(PatternFromCedar(`joh\*nny`)))), + ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(ast.PatternFromCedar(`joh\*nny`)))), }, { "like wildcard", `permit (principal, action, resource) when { principal.firstName like "*" };`, - Permit().When(Principal().Access("firstName").Like(testutil.Must(PatternFromCedar("*")))), + ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(ast.PatternFromCedar("*")))), }, { "is", `permit (principal, action, resource) when { principal is User };`, - Permit().When(Principal().Is("User")), + ast.Permit().When(ast.Principal().Is("User")), }, { "is in", `permit (principal, action, resource) when { principal is User in Group::"folkHeroes" };`, - Permit().When(Principal().IsIn("User", Entity(folkHeroes))), + ast.Permit().When(ast.Principal().IsIn("User", ast.Entity(folkHeroes))), }, { "is in", `permit (principal, action, resource) when { principal is User in Group::"folkHeroes" };`, - Permit().When(Principal().IsIn("User", Entity(folkHeroes))), + ast.Permit().When(ast.Principal().IsIn("User", ast.Entity(folkHeroes))), }, { "and", `permit (principal, action, resource) when { true && false };`, - Permit().When(True().And(False())), + ast.Permit().When(ast.True().And(ast.False())), }, { "or", `permit (principal, action, resource) when { true || false };`, - Permit().When(True().Or(False())), + ast.Permit().When(ast.True().Or(ast.False())), }, { "if then else", `permit (principal, action, resource) when { if true then true else false };`, - Permit().When(If(True(), True(), False())), + ast.Permit().When(ast.If(ast.True(), ast.True(), ast.False())), }, { "and over or precedence", `permit (principal, action, resource) when { true && false || true && true };`, - Permit().When(True().And(False()).Or(True().And(True()))), + ast.Permit().When(ast.True().And(ast.False()).Or(ast.True().And(ast.True()))), }, { "rel over and precedence", `permit (principal, action, resource) when { 1 < 2 && true };`, - Permit().When(Long(1).LessThan(Long(2)).And(True())), + ast.Permit().When(ast.Long(1).LessThan(ast.Long(2)).And(ast.True())), }, { "add over rel precedence", `permit (principal, action, resource) when { 1 + 1 < 3 };`, - Permit().When(Long(1).Plus(Long(1)).LessThan(Long(3))), + ast.Permit().When(ast.Long(1).Plus(ast.Long(1)).LessThan(ast.Long(3))), }, { "mult over add precedence", `permit (principal, action, resource) when { 2 * 3 + 4 == 10 };`, - Permit().When(Long(2).Times(Long(3)).Plus(Long(4)).Equals(Long(10))), + ast.Permit().When(ast.Long(2).Times(ast.Long(3)).Plus(ast.Long(4)).Equals(ast.Long(10))), }, { "unary over mult precedence", `permit (principal, action, resource) when { -2 * 3 == -6 };`, - Permit().When(Negate(Long(2)).Times(Long(3)).Equals(Negate(Long(6)))), + ast.Permit().When(ast.Negate(ast.Long(2)).Times(ast.Long(3)).Equals(ast.Negate(ast.Long(6)))), }, { "member over unary precedence", `permit (principal, action, resource) when { -context.num };`, - Permit().When(Negate(Context().Access("num"))), + ast.Permit().When(ast.Negate(ast.Context().Access("num"))), }, { "member over unary precedence", `permit (principal, action, resource) when { -context.num };`, - Permit().When(Negate(Context().Access("num"))), + ast.Permit().When(ast.Negate(ast.Context().Access("num"))), }, { "parens over unary precedence", `permit (principal, action, resource) when { -(2 + 3) == -5 };`, - Permit().When(Negate(Long(2).Plus(Long(3))).Equals(Negate(Long(5)))), + ast.Permit().When(ast.Negate(ast.Long(2).Plus(ast.Long(3))).Equals(ast.Negate(ast.Long(5)))), }, } @@ -363,7 +364,7 @@ func TestParsePolicy(t *testing.T) { t.Run(tt.Name, func(t *testing.T) { t.Parallel() - var policy Policy + var policy ast.Policy testutil.OK(t, policy.UnmarshalCedar([]byte(tt.Text))) testutil.Equals(t, policy, *tt.ExpectedPolicy) }) @@ -375,7 +376,7 @@ func TestParsePolicySet(t *testing.T) { parseTests := []struct { Name string Text string - ExpectedPolicies PolicySet + ExpectedPolicies ast.PolicySet }{ { "single policy", @@ -384,7 +385,7 @@ func TestParsePolicySet(t *testing.T) { action, resource );`, - PolicySet{*Permit()}, + ast.PolicySet{*ast.Permit()}, }, { "two policies", @@ -398,14 +399,14 @@ func TestParsePolicySet(t *testing.T) { action, resource );`, - PolicySet{*Permit(), *Forbid()}, + ast.PolicySet{*ast.Permit(), *ast.Forbid()}, }, } for _, tt := range parseTests { t.Run(tt.Name, func(t *testing.T) { t.Parallel() - var policies PolicySet + var policies ast.PolicySet testutil.OK(t, policies.UnmarshalCedar([]byte(tt.Text))) testutil.Equals(t, policies, tt.ExpectedPolicies) }) From 3d3cfbfedc311711c81378e385c9c97d0f6e345c Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Mon, 5 Aug 2024 10:57:34 -0700 Subject: [PATCH 053/216] cedar-go/x/exp/ast: remove expensive map allocations in favor of switch statements Signed-off-by: philhassey --- x/exp/ast/cedar_unmarshal.go | 172 ++++++++++++++++++++--------------- 1 file changed, 100 insertions(+), 72 deletions(-) diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index a55eae91..fa77c5bb 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -483,92 +483,116 @@ func (p *parser) relation() (Node, error) { } t := p.peek() - operators := map[string]func(Node) Node{ - "<": lhs.LessThan, - "<=": lhs.LessThanOrEqual, - ">": lhs.GreaterThan, - ">=": lhs.GreaterThanOrEqual, - "!=": lhs.NotEquals, - "==": lhs.Equals, - "in": lhs.In, - } - if f, ok := operators[t.Text]; ok { - p.advance() - rhs, err := p.add() - if err != nil { - return Node{}, err - } - return f(rhs), nil - } if t.Text == "has" { p.advance() - t = p.advance() - if t.isIdent() { - return lhs.Has(t.Text), nil - } else if t.isString() { - str, err := t.stringValue() - if err != nil { - return Node{}, err - } - return lhs.Has(str), nil - } - return Node{}, p.errorf("expected ident or string") + return p.has(lhs) } else if t.Text == "like" { p.advance() - t = p.advance() - if !t.isString() { - return Node{}, p.errorf("expected string literal") - } - patternRaw := t.Text - patternRaw = strings.TrimPrefix(patternRaw, "\"") - patternRaw = strings.TrimSuffix(patternRaw, "\"") - pattern, err := PatternFromCedar(patternRaw) - if err != nil { - return Node{}, err - } - return lhs.Like(pattern), nil + return p.like(lhs) } else if t.Text == "is" { p.advance() - entityType, err := p.path() + return p.is(lhs) + } + + // RELOP + var operator func(Node, Node) Node + switch t.Text { + case "<": + operator = Node.LessThan + case "<=": + operator = Node.LessThanOrEqual + case ">": + operator = Node.GreaterThan + case ">=": + operator = Node.GreaterThanOrEqual + case "!=": + operator = Node.NotEquals + case "==": + operator = Node.Equals + case "in": + operator = Node.In + default: + return lhs, nil + + } + + p.advance() + rhs, err := p.add() + if err != nil { + return Node{}, err + } + return operator(lhs, rhs), nil +} + +func (p *parser) has(lhs Node) (Node, error) { + t := p.advance() + if t.isIdent() { + return lhs.Has(t.Text), nil + } else if t.isString() { + str, err := t.stringValue() if err != nil { return Node{}, err } - if p.peek().Text == "in" { - p.advance() - inEntity, err := p.add() - if err != nil { - return Node{}, err - } - return lhs.IsIn(entityType, inEntity), nil - } - return lhs.Is(entityType), nil + return lhs.Has(str), nil } - - return lhs, err + return Node{}, p.errorf("expected ident or string") } -func (p *parser) add() (Node, error) { - lhs, err := p.mult() +func (p *parser) like(lhs Node) (Node, error) { + t := p.advance() + if !t.isString() { + return Node{}, p.errorf("expected string literal") + } + patternRaw := t.Text + patternRaw = strings.TrimPrefix(patternRaw, "\"") + patternRaw = strings.TrimSuffix(patternRaw, "\"") + pattern, err := PatternFromCedar(patternRaw) if err != nil { return Node{}, err } + return lhs.Like(pattern), nil +} - t := p.peek().Text - operators := map[string]func(Node) Node{ - "+": lhs.Plus, - "-": lhs.Minus, +func (p *parser) is(lhs Node) (Node, error) { + entityType, err := p.path() + if err != nil { + return Node{}, err } - if f, ok := operators[t]; ok { + if p.peek().Text == "in" { p.advance() - rhs, err := p.mult() + inEntity, err := p.add() if err != nil { return Node{}, err } - return f(rhs), nil + return lhs.IsIn(entityType, inEntity), nil } + return lhs.Is(entityType), nil +} - return lhs, nil +func (p *parser) add() (Node, error) { + lhs, err := p.mult() + if err != nil { + return Node{}, err + } + + t := p.peek() + var operator func(Node, Node) Node + switch t.Text { + case "+": + operator = Node.Plus + case "-": + operator = Node.Minus + default: + return lhs, nil + } + + p.advance() + rhs, err := p.mult() + if err != nil { + return Node{}, err + } + return operator(lhs, rhs), nil } func (p *parser) mult() (Node, error) { @@ -825,18 +849,22 @@ func (p *parser) access(lhs Node) (Node, bool, error) { } p.advance() // expressions guarantees ")" - knownMethods := map[string]func(Node) Node{ - "contains": lhs.Contains, - "containsAll": lhs.ContainsAll, - "containsAny": lhs.ContainsAny, + var knownMethod func(Node, Node) Node + switch methodName { + case "contains": + knownMethod = Node.Contains + case "containsAll": + knownMethod = Node.ContainsAll + case "containsAny": + knownMethod = Node.ContainsAny + default: + return newExtMethodCallNode(lhs, types.String(methodName), exprs...), true, nil } - if f, ok := knownMethods[methodName]; ok { - if len(exprs) != 1 { - return Node{}, false, p.errorf("%v expects one argument", methodName) - } - return f(exprs[0]), true, nil + + if len(exprs) != 1 { + return Node{}, false, p.errorf("%v expects one argument", methodName) } - return newExtMethodCallNode(lhs, types.String(methodName), exprs...), true, nil + return knownMethod(lhs, exprs[0]), true, nil } else { return lhs.Access(t.Text), true, nil } From 581924cb1f2c21b0a4552f598dd19b93199c50ec Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Mon, 5 Aug 2024 11:06:32 -0700 Subject: [PATCH 054/216] cedar-go/x/exp/ast: stuff the position of the policy into the Policy struct It's unused for now, but we'll need it eventually. We'll figure out how "public" we want to make it at that time as well. Signed-off-by: philhassey --- x/exp/ast/cedar_tokenize.go | 10 +++++----- x/exp/ast/cedar_unmarshal.go | 4 ++++ x/exp/ast/policy.go | 1 + 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/x/exp/ast/cedar_tokenize.go b/x/exp/ast/cedar_tokenize.go index 97101b34..089ab16e 100644 --- a/x/exp/ast/cedar_tokenize.go +++ b/x/exp/ast/cedar_tokenize.go @@ -28,7 +28,7 @@ const ( type Token struct { Type TokenType - Pos Position + Pos position Text string } @@ -216,15 +216,15 @@ func Tokenize(src []byte) ([]Token, error) { return res, nil } -// Position is a value that represents a source position. +// position is a value that represents a source position. // A position is valid if Line > 0. -type Position struct { +type position struct { Offset int // byte offset, starting at 0 Line int // line number, starting at 1 Column int // column number, starting at 1 (character count per line) } -func (pos Position) String() string { +func (pos position) String() string { return fmt.Sprintf(":%d:%d", pos.Line, pos.Column) } @@ -272,7 +272,7 @@ type scanner struct { // the scanner is not inside a token. Call Pos to obtain an error // position in that case, or to obtain the position immediately // after the most recently scanned token. - position Position + position position } // Init initializes a Scanner with a new source and returns s. diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index fa77c5bb..7d1ca6a3 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -41,6 +41,8 @@ func (p *Policy) UnmarshalCedar(b []byte) error { } func (p *Policy) fromCedarWithParser(parser *parser) error { + pos := parser.peek().Pos + annotations, err := parser.annotations() if err != nil { return err @@ -51,6 +53,8 @@ func (p *Policy) fromCedarWithParser(parser *parser) error { return err } + newPolicy.pos = pos + if err = parser.exact("("); err != nil { return err } diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index e1347a31..48857b05 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -9,6 +9,7 @@ type Policy struct { action isScopeNode resource isScopeNode conditions []nodeTypeCondition + pos position } func newPolicy(effect effect, annotations []nodeTypeAnnotation) *Policy { From 7327f7883c123be036820cccd4a9eae482fd6591 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Mon, 5 Aug 2024 12:02:40 -0700 Subject: [PATCH 055/216] cedar-go/x/exp/ast: remove pos for now while we figure out how to write tests for it Signed-off-by: philhassey --- x/exp/ast/cedar_tokenize_test.go | 152 +++++++++++++++---------------- x/exp/ast/cedar_unmarshal.go | 4 - x/exp/ast/policy.go | 1 - 3 files changed, 76 insertions(+), 81 deletions(-) diff --git a/x/exp/ast/cedar_tokenize_test.go b/x/exp/ast/cedar_tokenize_test.go index fbf9db5a..16eb74ac 100644 --- a/x/exp/ast/cedar_tokenize_test.go +++ b/x/exp/ast/cedar_tokenize_test.go @@ -29,66 +29,66 @@ multiline comment */ '/%|&=` want := []Token{ - {Type: TokenIdent, Text: "These", Pos: Position{Offset: 1, Line: 2, Column: 1}}, - {Type: TokenIdent, Text: "are", Pos: Position{Offset: 7, Line: 2, Column: 7}}, - {Type: TokenIdent, Text: "some", Pos: Position{Offset: 11, Line: 2, Column: 11}}, - {Type: TokenIdent, Text: "identifiers", Pos: Position{Offset: 16, Line: 2, Column: 16}}, - - {Type: TokenInt, Text: "0", Pos: Position{Offset: 28, Line: 3, Column: 1}}, - {Type: TokenInt, Text: "1", Pos: Position{Offset: 30, Line: 3, Column: 3}}, - {Type: TokenInt, Text: "1234", Pos: Position{Offset: 32, Line: 3, Column: 5}}, - - {Type: TokenOperator, Text: "-", Pos: Position{Offset: 37, Line: 4, Column: 1}}, - {Type: TokenInt, Text: "1", Pos: Position{Offset: 38, Line: 4, Column: 2}}, - {Type: TokenInt, Text: "9223372036854775807", Pos: Position{Offset: 40, Line: 4, Column: 4}}, - {Type: TokenOperator, Text: "-", Pos: Position{Offset: 60, Line: 4, Column: 24}}, - {Type: TokenInt, Text: "9223372036854775808", Pos: Position{Offset: 61, Line: 4, Column: 25}}, - - {Type: TokenString, Text: `""`, Pos: Position{Offset: 81, Line: 5, Column: 1}}, - {Type: TokenString, Text: `"string"`, Pos: Position{Offset: 84, Line: 5, Column: 4}}, - {Type: TokenString, Text: `"\"\'\n\r\t\\\0"`, Pos: Position{Offset: 93, Line: 5, Column: 13}}, - {Type: TokenString, Text: `"\x123"`, Pos: Position{Offset: 110, Line: 5, Column: 30}}, - {Type: TokenString, Text: `"\u{0}\u{10fFfF}"`, Pos: Position{Offset: 118, Line: 5, Column: 38}}, - - {Type: TokenString, Text: `"*"`, Pos: Position{Offset: 136, Line: 6, Column: 1}}, - {Type: TokenString, Text: `"\*"`, Pos: Position{Offset: 140, Line: 6, Column: 5}}, - {Type: TokenString, Text: `"*\**"`, Pos: Position{Offset: 145, Line: 6, Column: 10}}, - - {Type: TokenOperator, Text: "@", Pos: Position{Offset: 152, Line: 7, Column: 1}}, - {Type: TokenOperator, Text: ".", Pos: Position{Offset: 153, Line: 7, Column: 2}}, - {Type: TokenOperator, Text: ",", Pos: Position{Offset: 154, Line: 7, Column: 3}}, - {Type: TokenOperator, Text: ";", Pos: Position{Offset: 155, Line: 7, Column: 4}}, - {Type: TokenOperator, Text: "(", Pos: Position{Offset: 156, Line: 7, Column: 5}}, - {Type: TokenOperator, Text: ")", Pos: Position{Offset: 157, Line: 7, Column: 6}}, - {Type: TokenOperator, Text: "{", Pos: Position{Offset: 158, Line: 7, Column: 7}}, - {Type: TokenOperator, Text: "}", Pos: Position{Offset: 159, Line: 7, Column: 8}}, - {Type: TokenOperator, Text: "[", Pos: Position{Offset: 160, Line: 7, Column: 9}}, - {Type: TokenOperator, Text: "]", Pos: Position{Offset: 161, Line: 7, Column: 10}}, - {Type: TokenOperator, Text: "+", Pos: Position{Offset: 162, Line: 7, Column: 11}}, - {Type: TokenOperator, Text: "-", Pos: Position{Offset: 163, Line: 7, Column: 12}}, - {Type: TokenOperator, Text: "*", Pos: Position{Offset: 164, Line: 7, Column: 13}}, - - {Type: TokenOperator, Text: "::", Pos: Position{Offset: 166, Line: 8, Column: 1}}, - {Type: TokenOperator, Text: ":", Pos: Position{Offset: 168, Line: 8, Column: 3}}, - - {Type: TokenOperator, Text: "!", Pos: Position{Offset: 170, Line: 9, Column: 1}}, - {Type: TokenOperator, Text: "!=", Pos: Position{Offset: 171, Line: 9, Column: 2}}, - {Type: TokenOperator, Text: "<", Pos: Position{Offset: 173, Line: 9, Column: 4}}, - {Type: TokenOperator, Text: "<=", Pos: Position{Offset: 174, Line: 9, Column: 5}}, - {Type: TokenOperator, Text: ">", Pos: Position{Offset: 176, Line: 9, Column: 7}}, - {Type: TokenOperator, Text: ">=", Pos: Position{Offset: 177, Line: 9, Column: 8}}, - - {Type: TokenOperator, Text: "||", Pos: Position{Offset: 180, Line: 10, Column: 1}}, - {Type: TokenOperator, Text: "&&", Pos: Position{Offset: 182, Line: 10, Column: 3}}, - - {Type: TokenUnknown, Text: "'", Pos: Position{Offset: 265, Line: 16, Column: 1}}, - {Type: TokenUnknown, Text: "/", Pos: Position{Offset: 266, Line: 16, Column: 2}}, - {Type: TokenUnknown, Text: "%", Pos: Position{Offset: 267, Line: 16, Column: 3}}, - {Type: TokenUnknown, Text: "|", Pos: Position{Offset: 268, Line: 16, Column: 4}}, - {Type: TokenUnknown, Text: "&", Pos: Position{Offset: 269, Line: 16, Column: 5}}, - {Type: TokenUnknown, Text: "=", Pos: Position{Offset: 270, Line: 16, Column: 6}}, - - {Type: TokenEOF, Text: "", Pos: Position{Offset: 271, Line: 16, Column: 7}}, + {Type: TokenIdent, Text: "These", Pos: position{Offset: 1, Line: 2, Column: 1}}, + {Type: TokenIdent, Text: "are", Pos: position{Offset: 7, Line: 2, Column: 7}}, + {Type: TokenIdent, Text: "some", Pos: position{Offset: 11, Line: 2, Column: 11}}, + {Type: TokenIdent, Text: "identifiers", Pos: position{Offset: 16, Line: 2, Column: 16}}, + + {Type: TokenInt, Text: "0", Pos: position{Offset: 28, Line: 3, Column: 1}}, + {Type: TokenInt, Text: "1", Pos: position{Offset: 30, Line: 3, Column: 3}}, + {Type: TokenInt, Text: "1234", Pos: position{Offset: 32, Line: 3, Column: 5}}, + + {Type: TokenOperator, Text: "-", Pos: position{Offset: 37, Line: 4, Column: 1}}, + {Type: TokenInt, Text: "1", Pos: position{Offset: 38, Line: 4, Column: 2}}, + {Type: TokenInt, Text: "9223372036854775807", Pos: position{Offset: 40, Line: 4, Column: 4}}, + {Type: TokenOperator, Text: "-", Pos: position{Offset: 60, Line: 4, Column: 24}}, + {Type: TokenInt, Text: "9223372036854775808", Pos: position{Offset: 61, Line: 4, Column: 25}}, + + {Type: TokenString, Text: `""`, Pos: position{Offset: 81, Line: 5, Column: 1}}, + {Type: TokenString, Text: `"string"`, Pos: position{Offset: 84, Line: 5, Column: 4}}, + {Type: TokenString, Text: `"\"\'\n\r\t\\\0"`, Pos: position{Offset: 93, Line: 5, Column: 13}}, + {Type: TokenString, Text: `"\x123"`, Pos: position{Offset: 110, Line: 5, Column: 30}}, + {Type: TokenString, Text: `"\u{0}\u{10fFfF}"`, Pos: position{Offset: 118, Line: 5, Column: 38}}, + + {Type: TokenString, Text: `"*"`, Pos: position{Offset: 136, Line: 6, Column: 1}}, + {Type: TokenString, Text: `"\*"`, Pos: position{Offset: 140, Line: 6, Column: 5}}, + {Type: TokenString, Text: `"*\**"`, Pos: position{Offset: 145, Line: 6, Column: 10}}, + + {Type: TokenOperator, Text: "@", Pos: position{Offset: 152, Line: 7, Column: 1}}, + {Type: TokenOperator, Text: ".", Pos: position{Offset: 153, Line: 7, Column: 2}}, + {Type: TokenOperator, Text: ",", Pos: position{Offset: 154, Line: 7, Column: 3}}, + {Type: TokenOperator, Text: ";", Pos: position{Offset: 155, Line: 7, Column: 4}}, + {Type: TokenOperator, Text: "(", Pos: position{Offset: 156, Line: 7, Column: 5}}, + {Type: TokenOperator, Text: ")", Pos: position{Offset: 157, Line: 7, Column: 6}}, + {Type: TokenOperator, Text: "{", Pos: position{Offset: 158, Line: 7, Column: 7}}, + {Type: TokenOperator, Text: "}", Pos: position{Offset: 159, Line: 7, Column: 8}}, + {Type: TokenOperator, Text: "[", Pos: position{Offset: 160, Line: 7, Column: 9}}, + {Type: TokenOperator, Text: "]", Pos: position{Offset: 161, Line: 7, Column: 10}}, + {Type: TokenOperator, Text: "+", Pos: position{Offset: 162, Line: 7, Column: 11}}, + {Type: TokenOperator, Text: "-", Pos: position{Offset: 163, Line: 7, Column: 12}}, + {Type: TokenOperator, Text: "*", Pos: position{Offset: 164, Line: 7, Column: 13}}, + + {Type: TokenOperator, Text: "::", Pos: position{Offset: 166, Line: 8, Column: 1}}, + {Type: TokenOperator, Text: ":", Pos: position{Offset: 168, Line: 8, Column: 3}}, + + {Type: TokenOperator, Text: "!", Pos: position{Offset: 170, Line: 9, Column: 1}}, + {Type: TokenOperator, Text: "!=", Pos: position{Offset: 171, Line: 9, Column: 2}}, + {Type: TokenOperator, Text: "<", Pos: position{Offset: 173, Line: 9, Column: 4}}, + {Type: TokenOperator, Text: "<=", Pos: position{Offset: 174, Line: 9, Column: 5}}, + {Type: TokenOperator, Text: ">", Pos: position{Offset: 176, Line: 9, Column: 7}}, + {Type: TokenOperator, Text: ">=", Pos: position{Offset: 177, Line: 9, Column: 8}}, + + {Type: TokenOperator, Text: "||", Pos: position{Offset: 180, Line: 10, Column: 1}}, + {Type: TokenOperator, Text: "&&", Pos: position{Offset: 182, Line: 10, Column: 3}}, + + {Type: TokenUnknown, Text: "'", Pos: position{Offset: 265, Line: 16, Column: 1}}, + {Type: TokenUnknown, Text: "/", Pos: position{Offset: 266, Line: 16, Column: 2}}, + {Type: TokenUnknown, Text: "%", Pos: position{Offset: 267, Line: 16, Column: 3}}, + {Type: TokenUnknown, Text: "|", Pos: position{Offset: 268, Line: 16, Column: 4}}, + {Type: TokenUnknown, Text: "&", Pos: position{Offset: 269, Line: 16, Column: 5}}, + {Type: TokenUnknown, Text: "=", Pos: position{Offset: 270, Line: 16, Column: 6}}, + + {Type: TokenEOF, Text: "", Pos: position{Offset: 271, Line: 16, Column: 7}}, } got, err := Tokenize([]byte(input)) testutil.OK(t, err) @@ -100,26 +100,26 @@ func TestTokenizeErrors(t *testing.T) { tests := []struct { input string wantErrStr string - wantErrPos Position + wantErrPos position }{ - {"okay\x00not okay", "invalid character NUL", Position{Line: 1, Column: 1}}, + {"okay\x00not okay", "invalid character NUL", position{Line: 1, Column: 1}}, {`okay /* stuff - `, "comment not terminated", Position{Line: 1, Column: 6}}, + `, "comment not terminated", position{Line: 1, Column: 6}}, {`okay " - " foo bar`, "literal not terminated", Position{Line: 1, Column: 6}}, - {`"okay" "\a"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\b"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\f"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\v"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\1"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\x"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\x1"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\ubadf"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\U0000badf"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\u{}"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\u{0000000}"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\u{z"`, "invalid char escape", Position{Line: 1, Column: 8}}, + " foo bar`, "literal not terminated", position{Line: 1, Column: 6}}, + {`"okay" "\a"`, "invalid char escape", position{Line: 1, Column: 8}}, + {`"okay" "\b"`, "invalid char escape", position{Line: 1, Column: 8}}, + {`"okay" "\f"`, "invalid char escape", position{Line: 1, Column: 8}}, + {`"okay" "\v"`, "invalid char escape", position{Line: 1, Column: 8}}, + {`"okay" "\1"`, "invalid char escape", position{Line: 1, Column: 8}}, + {`"okay" "\x"`, "invalid char escape", position{Line: 1, Column: 8}}, + {`"okay" "\x1"`, "invalid char escape", position{Line: 1, Column: 8}}, + {`"okay" "\ubadf"`, "invalid char escape", position{Line: 1, Column: 8}}, + {`"okay" "\U0000badf"`, "invalid char escape", position{Line: 1, Column: 8}}, + {`"okay" "\u{}"`, "invalid char escape", position{Line: 1, Column: 8}}, + {`"okay" "\u{0000000}"`, "invalid char escape", position{Line: 1, Column: 8}}, + {`"okay" "\u{z"`, "invalid char escape", position{Line: 1, Column: 8}}, } for i, tt := range tests { tt := tt diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index 7d1ca6a3..fa77c5bb 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -41,8 +41,6 @@ func (p *Policy) UnmarshalCedar(b []byte) error { } func (p *Policy) fromCedarWithParser(parser *parser) error { - pos := parser.peek().Pos - annotations, err := parser.annotations() if err != nil { return err @@ -53,8 +51,6 @@ func (p *Policy) fromCedarWithParser(parser *parser) error { return err } - newPolicy.pos = pos - if err = parser.exact("("); err != nil { return err } diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index 48857b05..e1347a31 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -9,7 +9,6 @@ type Policy struct { action isScopeNode resource isScopeNode conditions []nodeTypeCondition - pos position } func newPolicy(effect effect, annotations []nodeTypeAnnotation) *Policy { From 9ae4b0596ab46e52b6503c287e8070e0e938d552 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 5 Aug 2024 15:32:24 -0600 Subject: [PATCH 056/216] x/exp/ast: make non-node types no longer be nodes Addresses IDX-49 Signed-off-by: philhassey --- x/exp/ast/annotation.go | 10 +++++----- x/exp/ast/node.go | 19 ------------------- x/exp/ast/policy.go | 42 +++++++++++++++++++++++++++++------------ 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/x/exp/ast/annotation.go b/x/exp/ast/annotation.go index c79bf94a..d17a7434 100644 --- a/x/exp/ast/annotation.go +++ b/x/exp/ast/annotation.go @@ -3,7 +3,7 @@ package ast import "github.com/cedar-policy/cedar-go/types" type Annotations struct { - nodes []nodeTypeAnnotation + nodes []annotationType } // Annotation allows AST constructors to make policy in a similar shape to textual Cedar with @@ -14,7 +14,7 @@ type Annotations struct { // Permit(). // PrincipalEq(superUser) func Annotation(name, value types.String) *Annotations { - return &Annotations{nodes: []nodeTypeAnnotation{newAnnotation(name, value)}} + return &Annotations{nodes: []annotationType{newAnnotation(name, value)}} } func (a *Annotations) Annotation(name, value types.String) *Annotations { @@ -31,10 +31,10 @@ func (a *Annotations) Forbid() *Policy { } func (p *Policy) Annotate(name, value types.String) *Policy { - p.annotations = append(p.annotations, nodeTypeAnnotation{Key: name, Value: value}) + p.annotations = append(p.annotations, annotationType{Key: name, Value: value}) return p } -func newAnnotation(name, value types.String) nodeTypeAnnotation { - return nodeTypeAnnotation{Key: name, Value: value} +func newAnnotation(name, value types.String) annotationType { + return annotationType{Key: name, Value: value} } diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 5a2e2c89..1e3a913e 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -17,12 +17,6 @@ type nodeTypeLike struct { Value Pattern } -type nodeTypeAnnotation struct { - node - Key types.String // TODO: review type - Value types.String -} - type nodeTypeIf struct { node If, Then, Else node @@ -89,19 +83,6 @@ type unaryNode struct { type nodeTypeNegate struct{ unaryNode } type nodeTypeNot struct{ unaryNode } -type condition bool - -const ( - conditionWhen = true - conditionUnless = false -) - -type nodeTypeCondition struct { - node - Condition condition - Body node -} - type nodeTypeVariable struct { node Name types.String // TODO: Review type diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index e1347a31..e6c10177 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -1,17 +1,42 @@ package ast +import "github.com/cedar-policy/cedar-go/types" + type PolicySet []Policy +type annotationType struct { + Key types.String // TODO: review type + Value types.String +} +type condition bool + +const ( + conditionWhen = true + conditionUnless = false +) + +type conditionType struct { + Condition condition + Body node +} + +type effect bool + +const ( + effectPermit effect = true + effectForbid effect = false +) + type Policy struct { effect effect - annotations []nodeTypeAnnotation + annotations []annotationType principal isScopeNode action isScopeNode resource isScopeNode - conditions []nodeTypeCondition + conditions []conditionType } -func newPolicy(effect effect, annotations []nodeTypeAnnotation) *Policy { +func newPolicy(effect effect, annotations []annotationType) *Policy { return &Policy{ effect: effect, annotations: annotations, @@ -30,18 +55,11 @@ func Forbid() *Policy { } func (p *Policy) When(node Node) *Policy { - p.conditions = append(p.conditions, nodeTypeCondition{Condition: conditionWhen, Body: node.v}) + p.conditions = append(p.conditions, conditionType{Condition: conditionWhen, Body: node.v}) return p } func (p *Policy) Unless(node Node) *Policy { - p.conditions = append(p.conditions, nodeTypeCondition{Condition: conditionUnless, Body: node.v}) + p.conditions = append(p.conditions, conditionType{Condition: conditionUnless, Body: node.v}) return p } - -type effect bool - -const ( - effectPermit effect = true - effectForbid effect = false -) From d55f6ac5383bd0abd9fd406a689b6ae533e16719 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 5 Aug 2024 15:39:36 -0600 Subject: [PATCH 057/216] x/exp/ast: make usage of Entity vs EntityUID be more consistent with the rest of the package Addresses IDX-49 Signed-off-by: philhassey --- x/exp/ast/cedar_unmarshal.go | 2 +- x/exp/ast/cedar_unmarshal_test.go | 6 +++--- x/exp/ast/json_test.go | 4 ++-- x/exp/ast/value.go | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index fa77c5bb..36c59c59 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -755,7 +755,7 @@ func (p *parser) entityOrExtFun(ident string) (Node, error) { if err != nil { return res, err } - res = Entity(entity) + res = EntityUID(entity) } return res, nil diff --git a/x/exp/ast/cedar_unmarshal_test.go b/x/exp/ast/cedar_unmarshal_test.go index 0676d409..0b45cb26 100644 --- a/x/exp/ast/cedar_unmarshal_test.go +++ b/x/exp/ast/cedar_unmarshal_test.go @@ -241,7 +241,7 @@ func TestParsePolicy(t *testing.T) { "in", `permit (principal, action, resource) when { principal in Group::"folkHeroes" };`, - ast.Permit().When(ast.Principal().In(ast.Entity(folkHeroes))), + ast.Permit().When(ast.Principal().In(ast.EntityUID(folkHeroes))), }, { "has ident", @@ -284,13 +284,13 @@ func TestParsePolicy(t *testing.T) { "is in", `permit (principal, action, resource) when { principal is User in Group::"folkHeroes" };`, - ast.Permit().When(ast.Principal().IsIn("User", ast.Entity(folkHeroes))), + ast.Permit().When(ast.Principal().IsIn("User", ast.EntityUID(folkHeroes))), }, { "is in", `permit (principal, action, resource) when { principal is User in Group::"folkHeroes" };`, - ast.Permit().When(ast.Principal().IsIn("User", ast.Entity(folkHeroes))), + ast.Permit().When(ast.Principal().IsIn("User", ast.EntityUID(folkHeroes))), }, { "and", diff --git a/x/exp/ast/json_test.go b/x/exp/ast/json_test.go index a22265a9..cd5bbf73 100644 --- a/x/exp/ast/json_test.go +++ b/x/exp/ast/json_test.go @@ -201,7 +201,7 @@ func TestUnmarshalJSON(t *testing.T) { "entity", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"Value":{"__entity":{"type":"T","id":"42"}}}}]}`, - ast.Permit().When(ast.Entity(types.NewEntityUID("T", "42"))), + ast.Permit().When(ast.EntityUID(types.NewEntityUID("T", "42"))), testutil.OK, }, { @@ -404,7 +404,7 @@ func TestUnmarshalJSON(t *testing.T) { "isIn", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"is":{"left":{"Var":"resource"},"entity_type":"T","in":{"Value":{"__entity":{"type":"P","id":"42"}}}}}}]}`, - ast.Permit().When(ast.Resource().IsIn("T", ast.Entity(types.NewEntityUID("P", "42")))), + ast.Permit().When(ast.Resource().IsIn("T", ast.EntityUID(types.NewEntityUID("P", "42")))), testutil.OK, }, { diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index 4e52ca54..9548301b 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -84,7 +84,7 @@ func EntityType(e types.String) Node { return newValueNode(e) } -func Entity(e types.EntityUID) Node { +func EntityUID(e types.EntityUID) Node { return newValueNode(e) } @@ -113,7 +113,7 @@ func valueToNode(v types.Value) Node { case types.Record: return Record(x) case types.EntityUID: - return Entity(x) + return EntityUID(x) case types.Decimal: return Decimal(x) case types.IPAddr: From 9582bcfa764c3d1e08392e629e6e2287f2550245 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Mon, 5 Aug 2024 14:38:10 -0700 Subject: [PATCH 058/216] cedar-go/x/exp/ast: change PolicySet to a map of policy ID to Policy Signed-off-by: philhassey --- x/exp/ast/cedar_unmarshal.go | 8 ++++++-- x/exp/ast/cedar_unmarshal_test.go | 4 ++-- x/exp/ast/policy.go | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index 36c59c59..01220c1f 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -15,7 +15,9 @@ func (p *PolicySet) UnmarshalCedar(b []byte) error { return err } - var policySet PolicySet + i := 0 + + policySet := PolicySet{} parser := newParser(tokens) for !parser.peek().isEOF() { var policy Policy @@ -23,7 +25,9 @@ func (p *PolicySet) UnmarshalCedar(b []byte) error { return err } - policySet = append(policySet, policy) + policyName := fmt.Sprintf("policy%v", i) + policySet[policyName] = policy + i++ } *p = policySet diff --git a/x/exp/ast/cedar_unmarshal_test.go b/x/exp/ast/cedar_unmarshal_test.go index 0b45cb26..3e1312ac 100644 --- a/x/exp/ast/cedar_unmarshal_test.go +++ b/x/exp/ast/cedar_unmarshal_test.go @@ -385,7 +385,7 @@ func TestParsePolicySet(t *testing.T) { action, resource );`, - ast.PolicySet{*ast.Permit()}, + ast.PolicySet{"policy0": *ast.Permit()}, }, { "two policies", @@ -399,7 +399,7 @@ func TestParsePolicySet(t *testing.T) { action, resource );`, - ast.PolicySet{*ast.Permit(), *ast.Forbid()}, + ast.PolicySet{"policy0": *ast.Permit(), "policy1": *ast.Forbid()}, }, } for _, tt := range parseTests { diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index e6c10177..8d02e7fb 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -2,7 +2,7 @@ package ast import "github.com/cedar-policy/cedar-go/types" -type PolicySet []Policy +type PolicySet map[string]Policy type annotationType struct { Key types.String // TODO: review type From 83ef0dd97109efbf5bb1810093663a531773828b Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Mon, 5 Aug 2024 14:47:09 -0700 Subject: [PATCH 059/216] cedar-go/x/exp/ast: add position information to the entries in the PolicySet map when parsing from Cedar Signed-off-by: philhassey --- x/exp/ast/cedar_tokenize.go | 22 ++--- x/exp/ast/cedar_tokenize_test.go | 152 +++++++++++++++--------------- x/exp/ast/cedar_unmarshal.go | 3 +- x/exp/ast/cedar_unmarshal_test.go | 18 +++- x/exp/ast/policy.go | 7 +- 5 files changed, 111 insertions(+), 91 deletions(-) diff --git a/x/exp/ast/cedar_tokenize.go b/x/exp/ast/cedar_tokenize.go index 089ab16e..9aee4db5 100644 --- a/x/exp/ast/cedar_tokenize.go +++ b/x/exp/ast/cedar_tokenize.go @@ -28,7 +28,7 @@ const ( type Token struct { Type TokenType - Pos position + Pos Position Text string } @@ -216,15 +216,15 @@ func Tokenize(src []byte) ([]Token, error) { return res, nil } -// position is a value that represents a source position. -// A position is valid if Line > 0. -type position struct { +// Position is a value that represents a source Position. +// A Position is valid if Line > 0. +type Position struct { Offset int // byte offset, starting at 0 Line int // line number, starting at 1 Column int // column number, starting at 1 (character count per line) } -func (pos position) String() string { +func (pos Position) String() string { return fmt.Sprintf(":%d:%d", pos.Line, pos.Column) } @@ -272,7 +272,7 @@ type scanner struct { // the scanner is not inside a token. Call Pos to obtain an error // position in that case, or to obtain the position immediately // after the most recently scanned token. - position position + position Position } // Init initializes a Scanner with a new source and returns s. @@ -285,7 +285,7 @@ func (s *scanner) Init(src io.Reader) *scanner { s.srcPos = 0 s.srcEnd = 0 - // initialize source position + // initialize source Position s.srcBufOffset = 0 s.line = 1 s.column = 0 @@ -300,7 +300,7 @@ func (s *scanner) Init(src io.Reader) *scanner { s.ch = specialRuneBOF // no char read yet, not EOF // initialize public fields - s.position.Line = 0 // invalidate token position + s.position.Line = 0 // invalidate token Position return s } @@ -360,7 +360,7 @@ func (s *scanner) next() rune { // uncommon case: not ASCII ch, width = utf8.DecodeRune(s.srcBuf[s.srcPos:s.srcEnd]) if ch == utf8.RuneError && width == 1 { - // advance for correct error position + // advance for correct error Position s.srcPos += width s.lastCharLen = width s.column++ @@ -562,7 +562,7 @@ func (s *scanner) nextToken() Token { ch := s.ch - // reset token text position + // reset token text Position s.tokPos = -1 s.position.Line = 0 @@ -576,7 +576,7 @@ redo: s.tokBuf.Reset() s.tokPos = s.srcPos - s.lastCharLen - // set token position + // set token Position s.position.Offset = s.srcBufOffset + s.tokPos if s.column > 0 { // common case: last character was not a '\n' diff --git a/x/exp/ast/cedar_tokenize_test.go b/x/exp/ast/cedar_tokenize_test.go index 16eb74ac..fbf9db5a 100644 --- a/x/exp/ast/cedar_tokenize_test.go +++ b/x/exp/ast/cedar_tokenize_test.go @@ -29,66 +29,66 @@ multiline comment */ '/%|&=` want := []Token{ - {Type: TokenIdent, Text: "These", Pos: position{Offset: 1, Line: 2, Column: 1}}, - {Type: TokenIdent, Text: "are", Pos: position{Offset: 7, Line: 2, Column: 7}}, - {Type: TokenIdent, Text: "some", Pos: position{Offset: 11, Line: 2, Column: 11}}, - {Type: TokenIdent, Text: "identifiers", Pos: position{Offset: 16, Line: 2, Column: 16}}, - - {Type: TokenInt, Text: "0", Pos: position{Offset: 28, Line: 3, Column: 1}}, - {Type: TokenInt, Text: "1", Pos: position{Offset: 30, Line: 3, Column: 3}}, - {Type: TokenInt, Text: "1234", Pos: position{Offset: 32, Line: 3, Column: 5}}, - - {Type: TokenOperator, Text: "-", Pos: position{Offset: 37, Line: 4, Column: 1}}, - {Type: TokenInt, Text: "1", Pos: position{Offset: 38, Line: 4, Column: 2}}, - {Type: TokenInt, Text: "9223372036854775807", Pos: position{Offset: 40, Line: 4, Column: 4}}, - {Type: TokenOperator, Text: "-", Pos: position{Offset: 60, Line: 4, Column: 24}}, - {Type: TokenInt, Text: "9223372036854775808", Pos: position{Offset: 61, Line: 4, Column: 25}}, - - {Type: TokenString, Text: `""`, Pos: position{Offset: 81, Line: 5, Column: 1}}, - {Type: TokenString, Text: `"string"`, Pos: position{Offset: 84, Line: 5, Column: 4}}, - {Type: TokenString, Text: `"\"\'\n\r\t\\\0"`, Pos: position{Offset: 93, Line: 5, Column: 13}}, - {Type: TokenString, Text: `"\x123"`, Pos: position{Offset: 110, Line: 5, Column: 30}}, - {Type: TokenString, Text: `"\u{0}\u{10fFfF}"`, Pos: position{Offset: 118, Line: 5, Column: 38}}, - - {Type: TokenString, Text: `"*"`, Pos: position{Offset: 136, Line: 6, Column: 1}}, - {Type: TokenString, Text: `"\*"`, Pos: position{Offset: 140, Line: 6, Column: 5}}, - {Type: TokenString, Text: `"*\**"`, Pos: position{Offset: 145, Line: 6, Column: 10}}, - - {Type: TokenOperator, Text: "@", Pos: position{Offset: 152, Line: 7, Column: 1}}, - {Type: TokenOperator, Text: ".", Pos: position{Offset: 153, Line: 7, Column: 2}}, - {Type: TokenOperator, Text: ",", Pos: position{Offset: 154, Line: 7, Column: 3}}, - {Type: TokenOperator, Text: ";", Pos: position{Offset: 155, Line: 7, Column: 4}}, - {Type: TokenOperator, Text: "(", Pos: position{Offset: 156, Line: 7, Column: 5}}, - {Type: TokenOperator, Text: ")", Pos: position{Offset: 157, Line: 7, Column: 6}}, - {Type: TokenOperator, Text: "{", Pos: position{Offset: 158, Line: 7, Column: 7}}, - {Type: TokenOperator, Text: "}", Pos: position{Offset: 159, Line: 7, Column: 8}}, - {Type: TokenOperator, Text: "[", Pos: position{Offset: 160, Line: 7, Column: 9}}, - {Type: TokenOperator, Text: "]", Pos: position{Offset: 161, Line: 7, Column: 10}}, - {Type: TokenOperator, Text: "+", Pos: position{Offset: 162, Line: 7, Column: 11}}, - {Type: TokenOperator, Text: "-", Pos: position{Offset: 163, Line: 7, Column: 12}}, - {Type: TokenOperator, Text: "*", Pos: position{Offset: 164, Line: 7, Column: 13}}, - - {Type: TokenOperator, Text: "::", Pos: position{Offset: 166, Line: 8, Column: 1}}, - {Type: TokenOperator, Text: ":", Pos: position{Offset: 168, Line: 8, Column: 3}}, - - {Type: TokenOperator, Text: "!", Pos: position{Offset: 170, Line: 9, Column: 1}}, - {Type: TokenOperator, Text: "!=", Pos: position{Offset: 171, Line: 9, Column: 2}}, - {Type: TokenOperator, Text: "<", Pos: position{Offset: 173, Line: 9, Column: 4}}, - {Type: TokenOperator, Text: "<=", Pos: position{Offset: 174, Line: 9, Column: 5}}, - {Type: TokenOperator, Text: ">", Pos: position{Offset: 176, Line: 9, Column: 7}}, - {Type: TokenOperator, Text: ">=", Pos: position{Offset: 177, Line: 9, Column: 8}}, - - {Type: TokenOperator, Text: "||", Pos: position{Offset: 180, Line: 10, Column: 1}}, - {Type: TokenOperator, Text: "&&", Pos: position{Offset: 182, Line: 10, Column: 3}}, - - {Type: TokenUnknown, Text: "'", Pos: position{Offset: 265, Line: 16, Column: 1}}, - {Type: TokenUnknown, Text: "/", Pos: position{Offset: 266, Line: 16, Column: 2}}, - {Type: TokenUnknown, Text: "%", Pos: position{Offset: 267, Line: 16, Column: 3}}, - {Type: TokenUnknown, Text: "|", Pos: position{Offset: 268, Line: 16, Column: 4}}, - {Type: TokenUnknown, Text: "&", Pos: position{Offset: 269, Line: 16, Column: 5}}, - {Type: TokenUnknown, Text: "=", Pos: position{Offset: 270, Line: 16, Column: 6}}, - - {Type: TokenEOF, Text: "", Pos: position{Offset: 271, Line: 16, Column: 7}}, + {Type: TokenIdent, Text: "These", Pos: Position{Offset: 1, Line: 2, Column: 1}}, + {Type: TokenIdent, Text: "are", Pos: Position{Offset: 7, Line: 2, Column: 7}}, + {Type: TokenIdent, Text: "some", Pos: Position{Offset: 11, Line: 2, Column: 11}}, + {Type: TokenIdent, Text: "identifiers", Pos: Position{Offset: 16, Line: 2, Column: 16}}, + + {Type: TokenInt, Text: "0", Pos: Position{Offset: 28, Line: 3, Column: 1}}, + {Type: TokenInt, Text: "1", Pos: Position{Offset: 30, Line: 3, Column: 3}}, + {Type: TokenInt, Text: "1234", Pos: Position{Offset: 32, Line: 3, Column: 5}}, + + {Type: TokenOperator, Text: "-", Pos: Position{Offset: 37, Line: 4, Column: 1}}, + {Type: TokenInt, Text: "1", Pos: Position{Offset: 38, Line: 4, Column: 2}}, + {Type: TokenInt, Text: "9223372036854775807", Pos: Position{Offset: 40, Line: 4, Column: 4}}, + {Type: TokenOperator, Text: "-", Pos: Position{Offset: 60, Line: 4, Column: 24}}, + {Type: TokenInt, Text: "9223372036854775808", Pos: Position{Offset: 61, Line: 4, Column: 25}}, + + {Type: TokenString, Text: `""`, Pos: Position{Offset: 81, Line: 5, Column: 1}}, + {Type: TokenString, Text: `"string"`, Pos: Position{Offset: 84, Line: 5, Column: 4}}, + {Type: TokenString, Text: `"\"\'\n\r\t\\\0"`, Pos: Position{Offset: 93, Line: 5, Column: 13}}, + {Type: TokenString, Text: `"\x123"`, Pos: Position{Offset: 110, Line: 5, Column: 30}}, + {Type: TokenString, Text: `"\u{0}\u{10fFfF}"`, Pos: Position{Offset: 118, Line: 5, Column: 38}}, + + {Type: TokenString, Text: `"*"`, Pos: Position{Offset: 136, Line: 6, Column: 1}}, + {Type: TokenString, Text: `"\*"`, Pos: Position{Offset: 140, Line: 6, Column: 5}}, + {Type: TokenString, Text: `"*\**"`, Pos: Position{Offset: 145, Line: 6, Column: 10}}, + + {Type: TokenOperator, Text: "@", Pos: Position{Offset: 152, Line: 7, Column: 1}}, + {Type: TokenOperator, Text: ".", Pos: Position{Offset: 153, Line: 7, Column: 2}}, + {Type: TokenOperator, Text: ",", Pos: Position{Offset: 154, Line: 7, Column: 3}}, + {Type: TokenOperator, Text: ";", Pos: Position{Offset: 155, Line: 7, Column: 4}}, + {Type: TokenOperator, Text: "(", Pos: Position{Offset: 156, Line: 7, Column: 5}}, + {Type: TokenOperator, Text: ")", Pos: Position{Offset: 157, Line: 7, Column: 6}}, + {Type: TokenOperator, Text: "{", Pos: Position{Offset: 158, Line: 7, Column: 7}}, + {Type: TokenOperator, Text: "}", Pos: Position{Offset: 159, Line: 7, Column: 8}}, + {Type: TokenOperator, Text: "[", Pos: Position{Offset: 160, Line: 7, Column: 9}}, + {Type: TokenOperator, Text: "]", Pos: Position{Offset: 161, Line: 7, Column: 10}}, + {Type: TokenOperator, Text: "+", Pos: Position{Offset: 162, Line: 7, Column: 11}}, + {Type: TokenOperator, Text: "-", Pos: Position{Offset: 163, Line: 7, Column: 12}}, + {Type: TokenOperator, Text: "*", Pos: Position{Offset: 164, Line: 7, Column: 13}}, + + {Type: TokenOperator, Text: "::", Pos: Position{Offset: 166, Line: 8, Column: 1}}, + {Type: TokenOperator, Text: ":", Pos: Position{Offset: 168, Line: 8, Column: 3}}, + + {Type: TokenOperator, Text: "!", Pos: Position{Offset: 170, Line: 9, Column: 1}}, + {Type: TokenOperator, Text: "!=", Pos: Position{Offset: 171, Line: 9, Column: 2}}, + {Type: TokenOperator, Text: "<", Pos: Position{Offset: 173, Line: 9, Column: 4}}, + {Type: TokenOperator, Text: "<=", Pos: Position{Offset: 174, Line: 9, Column: 5}}, + {Type: TokenOperator, Text: ">", Pos: Position{Offset: 176, Line: 9, Column: 7}}, + {Type: TokenOperator, Text: ">=", Pos: Position{Offset: 177, Line: 9, Column: 8}}, + + {Type: TokenOperator, Text: "||", Pos: Position{Offset: 180, Line: 10, Column: 1}}, + {Type: TokenOperator, Text: "&&", Pos: Position{Offset: 182, Line: 10, Column: 3}}, + + {Type: TokenUnknown, Text: "'", Pos: Position{Offset: 265, Line: 16, Column: 1}}, + {Type: TokenUnknown, Text: "/", Pos: Position{Offset: 266, Line: 16, Column: 2}}, + {Type: TokenUnknown, Text: "%", Pos: Position{Offset: 267, Line: 16, Column: 3}}, + {Type: TokenUnknown, Text: "|", Pos: Position{Offset: 268, Line: 16, Column: 4}}, + {Type: TokenUnknown, Text: "&", Pos: Position{Offset: 269, Line: 16, Column: 5}}, + {Type: TokenUnknown, Text: "=", Pos: Position{Offset: 270, Line: 16, Column: 6}}, + + {Type: TokenEOF, Text: "", Pos: Position{Offset: 271, Line: 16, Column: 7}}, } got, err := Tokenize([]byte(input)) testutil.OK(t, err) @@ -100,26 +100,26 @@ func TestTokenizeErrors(t *testing.T) { tests := []struct { input string wantErrStr string - wantErrPos position + wantErrPos Position }{ - {"okay\x00not okay", "invalid character NUL", position{Line: 1, Column: 1}}, + {"okay\x00not okay", "invalid character NUL", Position{Line: 1, Column: 1}}, {`okay /* stuff - `, "comment not terminated", position{Line: 1, Column: 6}}, + `, "comment not terminated", Position{Line: 1, Column: 6}}, {`okay " - " foo bar`, "literal not terminated", position{Line: 1, Column: 6}}, - {`"okay" "\a"`, "invalid char escape", position{Line: 1, Column: 8}}, - {`"okay" "\b"`, "invalid char escape", position{Line: 1, Column: 8}}, - {`"okay" "\f"`, "invalid char escape", position{Line: 1, Column: 8}}, - {`"okay" "\v"`, "invalid char escape", position{Line: 1, Column: 8}}, - {`"okay" "\1"`, "invalid char escape", position{Line: 1, Column: 8}}, - {`"okay" "\x"`, "invalid char escape", position{Line: 1, Column: 8}}, - {`"okay" "\x1"`, "invalid char escape", position{Line: 1, Column: 8}}, - {`"okay" "\ubadf"`, "invalid char escape", position{Line: 1, Column: 8}}, - {`"okay" "\U0000badf"`, "invalid char escape", position{Line: 1, Column: 8}}, - {`"okay" "\u{}"`, "invalid char escape", position{Line: 1, Column: 8}}, - {`"okay" "\u{0000000}"`, "invalid char escape", position{Line: 1, Column: 8}}, - {`"okay" "\u{z"`, "invalid char escape", position{Line: 1, Column: 8}}, + " foo bar`, "literal not terminated", Position{Line: 1, Column: 6}}, + {`"okay" "\a"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\b"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\f"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\v"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\1"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\x"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\x1"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\ubadf"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\U0000badf"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\u{}"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\u{0000000}"`, "invalid char escape", Position{Line: 1, Column: 8}}, + {`"okay" "\u{z"`, "invalid char escape", Position{Line: 1, Column: 8}}, } for i, tt := range tests { tt := tt diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index 01220c1f..ce537a85 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -20,13 +20,14 @@ func (p *PolicySet) UnmarshalCedar(b []byte) error { policySet := PolicySet{} parser := newParser(tokens) for !parser.peek().isEOF() { + pos := parser.peek().Pos var policy Policy if err = policy.fromCedarWithParser(&parser); err != nil { return err } policyName := fmt.Sprintf("policy%v", i) - policySet[policyName] = policy + policySet[policyName] = PolicySetEntry{Policy: policy, Position: pos} i++ } diff --git a/x/exp/ast/cedar_unmarshal_test.go b/x/exp/ast/cedar_unmarshal_test.go index 3e1312ac..cb25beb9 100644 --- a/x/exp/ast/cedar_unmarshal_test.go +++ b/x/exp/ast/cedar_unmarshal_test.go @@ -385,7 +385,12 @@ func TestParsePolicySet(t *testing.T) { action, resource );`, - ast.PolicySet{"policy0": *ast.Permit()}, + ast.PolicySet{ + "policy0": ast.PolicySetEntry{ + *ast.Permit(), + ast.Position{Offset: 0, Line: 1, Column: 1}, + }, + }, }, { "two policies", @@ -399,7 +404,16 @@ func TestParsePolicySet(t *testing.T) { action, resource );`, - ast.PolicySet{"policy0": *ast.Permit(), "policy1": *ast.Forbid()}, + ast.PolicySet{ + "policy0": ast.PolicySetEntry{ + *ast.Permit(), + ast.Position{Offset: 0, Line: 1, Column: 1}, + }, + "policy1": ast.PolicySetEntry{ + *ast.Forbid(), + ast.Position{Offset: 53, Line: 6, Column: 3}, + }, + }, }, } for _, tt := range parseTests { diff --git a/x/exp/ast/policy.go b/x/exp/ast/policy.go index 8d02e7fb..038d47d1 100644 --- a/x/exp/ast/policy.go +++ b/x/exp/ast/policy.go @@ -2,7 +2,12 @@ package ast import "github.com/cedar-policy/cedar-go/types" -type PolicySet map[string]Policy +type PolicySet map[string]PolicySetEntry + +type PolicySetEntry struct { + Policy Policy + Position Position +} type annotationType struct { Key types.String // TODO: review type From a1b879d03cd327af73bfc009233efbd0449f4cfd Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 6 Aug 2024 07:39:31 -0700 Subject: [PATCH 060/216] cedar-go/x/exp/ast: support multiple operator applications for several operators Signed-off-by: philhassey --- x/exp/ast/cedar_unmarshal.go | 85 ++++++++++++++++--------------- x/exp/ast/cedar_unmarshal_test.go | 48 +++++++++++++++++ 2 files changed, 92 insertions(+), 41 deletions(-) diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index ce537a85..f20d94d3 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -449,17 +449,16 @@ func (p *parser) or() (Node, error) { return Node{}, err } - t := p.peek() - if t.Text != "||" { - return lhs, nil + for p.peek().Text == "||" { + p.advance() + rhs, err := p.and() + if err != nil { + return Node{}, err + } + lhs = lhs.Or(rhs) } - p.advance() - rhs, err := p.and() - if err != nil { - return Node{}, err - } - return lhs.Or(rhs), nil + return lhs, nil } func (p *parser) and() (Node, error) { @@ -468,17 +467,16 @@ func (p *parser) and() (Node, error) { return Node{}, err } - t := p.peek() - if t.Text != "&&" { - return lhs, nil + for p.peek().Text == "&&" { + p.advance() + rhs, err := p.relation() + if err != nil { + return Node{}, err + } + lhs = lhs.And(rhs) } - p.advance() - rhs, err := p.relation() - if err != nil { - return Node{}, err - } - return lhs.And(rhs), nil + return lhs, nil } func (p *parser) relation() (Node, error) { @@ -581,23 +579,28 @@ func (p *parser) add() (Node, error) { return Node{}, err } - t := p.peek() - var operator func(Node, Node) Node - switch t.Text { - case "+": - operator = Node.Plus - case "-": - operator = Node.Minus - default: - return lhs, nil - } +NotAdd: + for { + t := p.peek() + var operator func(Node, Node) Node + switch t.Text { + case "+": + operator = Node.Plus + case "-": + operator = Node.Minus + default: + break NotAdd + } - p.advance() - rhs, err := p.mult() - if err != nil { - return Node{}, err + p.advance() + rhs, err := p.mult() + if err != nil { + return Node{}, err + } + lhs = operator(lhs, rhs) } - return operator(lhs, rhs), nil + + return lhs, nil } func (p *parser) mult() (Node, error) { @@ -606,16 +609,16 @@ func (p *parser) mult() (Node, error) { return Node{}, err } - if p.peek().Text != "*" { - return lhs, nil + for p.peek().Text == "*" { + p.advance() + rhs, err := p.unary() + if err != nil { + return Node{}, err + } + lhs = lhs.Times(rhs) } - p.advance() - rhs, err := p.unary() - if err != nil { - return Node{}, err - } - return lhs.Times(rhs), nil + return lhs, nil } func (p *parser) unary() (Node, error) { diff --git a/x/exp/ast/cedar_unmarshal_test.go b/x/exp/ast/cedar_unmarshal_test.go index cb25beb9..2e0260ea 100644 --- a/x/exp/ast/cedar_unmarshal_test.go +++ b/x/exp/ast/cedar_unmarshal_test.go @@ -141,6 +141,12 @@ func TestParsePolicy(t *testing.T) { when { !true };`, ast.Permit().When(ast.Not(ast.Boolean(true))), }, + { + "multiple not operators", + `permit (principal, action, resource) + when { !!true };`, + ast.Permit().When(ast.Not(ast.Not(ast.Boolean(true)))), + }, { "negate operator", `permit (principal, action, resource) @@ -189,18 +195,42 @@ func TestParsePolicy(t *testing.T) { when { 42 * 2 };`, ast.Permit().When(ast.Long(42).Times(ast.Long(2))), }, + { + "multiple multiplication", + `permit (principal, action, resource) + when { 42 * 2 * 1};`, + ast.Permit().When(ast.Long(42).Times(ast.Long(2)).Times(ast.Long(1))), + }, { "addition", `permit (principal, action, resource) when { 42 + 2 };`, ast.Permit().When(ast.Long(42).Plus(ast.Long(2))), }, + { + "multiple addition", + `permit (principal, action, resource) + when { 42 + 2 + 1 };`, + ast.Permit().When(ast.Long(42).Plus(ast.Long(2)).Plus(ast.Long(1))), + }, { "subtraction", `permit (principal, action, resource) when { 42 - 2 };`, ast.Permit().When(ast.Long(42).Minus(ast.Long(2))), }, + { + "multiple subtraction", + `permit (principal, action, resource) + when { 42 - 2 - 1 };`, + ast.Permit().When(ast.Long(42).Minus(ast.Long(2)).Minus(ast.Long(1))), + }, + { + "mixed addition and subtraction", + `permit (principal, action, resource) + when { 42 - 2 + 1 };`, + ast.Permit().When(ast.Long(42).Minus(ast.Long(2)).Plus(ast.Long(1))), + }, { "less than", `permit (principal, action, resource) @@ -298,12 +328,24 @@ func TestParsePolicy(t *testing.T) { when { true && false };`, ast.Permit().When(ast.True().And(ast.False())), }, + { + "multiple and", + `permit (principal, action, resource) + when { true && false && true };`, + ast.Permit().When(ast.True().And(ast.False()).And(ast.True())), + }, { "or", `permit (principal, action, resource) when { true || false };`, ast.Permit().When(ast.True().Or(ast.False())), }, + { + "multiple or", + `permit (principal, action, resource) + when { true || false || true };`, + ast.Permit().When(ast.True().Or(ast.False()).Or(ast.True())), + }, { "if then else", `permit (principal, action, resource) @@ -358,6 +400,12 @@ func TestParsePolicy(t *testing.T) { when { -(2 + 3) == -5 };`, ast.Permit().When(ast.Negate(ast.Long(2).Plus(ast.Long(3))).Equals(ast.Negate(ast.Long(5)))), }, + { + "multiple parenthesized operations", + `permit ( principal, action, resource ) +when { (2 + 3 + 4) * 5 == 18 };`, + ast.Permit().When(ast.Long(2).Plus(ast.Long(3)).Plus(ast.Long(4)).Times(ast.Long(5)).Equals(ast.Long(18))), + }, } for _, tt := range parseTests { From 0a576e02f37cfa46f108a513ce40f57927e9333f Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 6 Aug 2024 11:20:50 -0700 Subject: [PATCH 061/216] cedar-go/exp/x/ast: remove janky "switch to label" in add() in favor of just checking for nil Signed-off-by: philhassey --- x/exp/ast/cedar_unmarshal.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index f20d94d3..22d35713 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -579,7 +579,6 @@ func (p *parser) add() (Node, error) { return Node{}, err } -NotAdd: for { t := p.peek() var operator func(Node, Node) Node @@ -588,8 +587,10 @@ NotAdd: operator = Node.Plus case "-": operator = Node.Minus - default: - break NotAdd + } + + if operator == nil { + break } p.advance() From 4a0debbc9cc38ef6f7439f82f3586c59fcc4fa4b Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 6 Aug 2024 12:48:57 -0600 Subject: [PATCH 062/216] x/exp/ast: make all function and method calls become extension calls Addresses IDX-49 Signed-off-by: philhassey --- x/exp/ast/cedar_unmarshal.go | 2 +- x/exp/ast/json.go | 10 ++--- x/exp/ast/json_marshal.go | 33 +++++++-------- x/exp/ast/json_test.go | 82 +++++++++++++++++++++++++++++------- x/exp/ast/json_unmarshal.go | 54 +++--------------------- x/exp/ast/node.go | 28 ++++++++---- x/exp/ast/operator.go | 18 ++++---- x/exp/ast/value.go | 4 ++ 8 files changed, 123 insertions(+), 108 deletions(-) diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index 22d35713..591a99bc 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -867,7 +867,7 @@ func (p *parser) access(lhs Node) (Node, bool, error) { case "containsAny": knownMethod = Node.ContainsAny default: - return newExtMethodCallNode(lhs, types.String(methodName), exprs...), true, nil + return newMethodCall(lhs, types.String(methodName), exprs...), true, nil } if len(exprs) != 1 { diff --git a/x/exp/ast/json.go b/x/exp/ast/json.go index 226cba27..fe023e10 100644 --- a/x/exp/ast/json.go +++ b/x/exp/ast/json.go @@ -76,7 +76,7 @@ type arrayJSON []nodeJSON type recordJSON map[string]nodeJSON -type extMethodCallJSON map[string]arrayJSON +type extensionCallJSON map[string]arrayJSON type nodeJSON struct { // Value @@ -128,10 +128,6 @@ type nodeJSON struct { // Record Record recordJSON `json:"Record,omitempty"` - // Any other function: decimal, ip - Decimal arrayJSON `json:"decimal,omitempty"` - IP arrayJSON `json:"ip,omitempty"` - - // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange - ExtensionMethod extMethodCallJSON `json:"-"` + // Any other method: decimal, ip, lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange + ExtensionCall extensionCallJSON `json:"-"` } diff --git a/x/exp/ast/json_marshal.go b/x/exp/ast/json_marshal.go index 795116d8..99e47cc4 100644 --- a/x/exp/ast/json_marshal.go +++ b/x/exp/ast/json_marshal.go @@ -77,24 +77,21 @@ func arrayToJSON(dest *arrayJSON, args []node) error { return nil } -func extToJSON(dest *arrayJSON, src types.Value) error { +func extToJSON(dest *extensionCallJSON, name string, src types.Value) error { res := arrayJSON{} str := src.String() // TODO: is this the correct string? b, _ := json.Marshal(string(str)) // error impossible res = append(res, nodeJSON{ Value: (*json.RawMessage)(&b), }) - *dest = res + *dest = extensionCallJSON{ + name: res, + } return nil } -func extMethodToJSON(dest extMethodCallJSON, src nodeTypeExtMethodCall) error { - objectNode := &nodeJSON{} - err := objectNode.FromNode(src.Left) - if err != nil { - return err - } - jsonArgs := arrayJSON{*objectNode} +func extCallToJSON(dest extensionCallJSON, src nodeTypeExtensionCall) error { + jsonArgs := arrayJSON{} for _, n := range src.Args { argNode := &nodeJSON{} err := argNode.FromNode(n) @@ -103,7 +100,7 @@ func extMethodToJSON(dest extMethodCallJSON, src nodeTypeExtMethodCall) error { } jsonArgs = append(jsonArgs, *argNode) } - dest[string(src.Method)] = jsonArgs + dest[string(src.Name)] = jsonArgs return nil } @@ -196,9 +193,9 @@ func (j *nodeJSON) FromNode(src node) error { // IP arrayJSON `json:"ip"` switch tt := t.Value.(type) { case types.Decimal: - return extToJSON(&j.Decimal, tt) + return extToJSON(&j.ExtensionCall, "decimal", tt) case types.IPAddr: - return extToJSON(&j.IP, tt) + return extToJSON(&j.ExtensionCall, "ip", tt) } b, err := t.Value.ExplicitMarshalJSON() j.Value = (*json.RawMessage)(&b) @@ -284,11 +281,11 @@ func (j *nodeJSON) FromNode(src node) error { case nodeTypeRecord: return recordToJSON(&j.Record, t) - // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange + // Any other method: ip, decimal, lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange // ExtensionMethod map[string]arrayJSON `json:"-"` - case nodeTypeExtMethodCall: - j.ExtensionMethod = extMethodCallJSON{} - return extMethodToJSON(j.ExtensionMethod, t) + case nodeTypeExtensionCall: + j.ExtensionCall = extensionCallJSON{} + return extCallToJSON(j.ExtensionCall, t) } // case nodeTypeRecordEntry: // case nodeTypeEntityType: @@ -299,8 +296,8 @@ func (j *nodeJSON) FromNode(src node) error { } func (j *nodeJSON) MarshalJSON() ([]byte, error) { - if len(j.ExtensionMethod) > 0 { - return json.Marshal(j.ExtensionMethod) + if len(j.ExtensionCall) > 0 { + return json.Marshal(j.ExtensionCall) } type nodeJSONAlias nodeJSON diff --git a/x/exp/ast/json_test.go b/x/exp/ast/json_test.go index cd5bbf73..47bc68a7 100644 --- a/x/exp/ast/json_test.go +++ b/x/exp/ast/json_test.go @@ -218,20 +218,6 @@ func TestUnmarshalJSON(t *testing.T) { ast.Permit().When(ast.Record(types.Record{"key": types.Long(42)})), testutil.OK, }, - { - "decimal", - `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, - "conditions":[{"kind":"when","body":{"decimal":[{"Value":"42.24"}]}}]}`, - ast.Permit().When(ast.Decimal(mustParseDecimal("42.24"))), - testutil.OK, - }, - { - "ip", - `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, - "conditions":[{"kind":"when","body":{"ip":[{"Value":"10.0.0.42/8"}]}}]}`, - ast.Permit().When(ast.IPAddr(mustParseIPAddr("10.0.0.42/8"))), - testutil.OK, - }, { "principalVar", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, @@ -463,6 +449,20 @@ func TestUnmarshalJSON(t *testing.T) { ast.Permit().When(ast.If(ast.True(), ast.Long(42), ast.Long(24))), testutil.OK, }, + { + "decimal", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"decimal":[{"Value":"42.24"}]}}]}`, + ast.Permit().When(ast.ExtensionCall("decimal", ast.String("42.24"))), + testutil.OK, + }, + { + "ip", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"ip":[{"Value":"10.0.0.42/8"}]}}]}`, + ast.Permit().When(ast.ExtensionCall("ip", ast.String("10.0.0.42/8"))), + testutil.OK, + }, { "isInRange", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, @@ -470,7 +470,7 @@ func TestUnmarshalJSON(t *testing.T) { {"ip":[{"Value":"10.0.0.43"}]}, {"ip":[{"Value":"10.0.0.42/8"}]} ]}}]}`, - ast.Permit().When(ast.IPAddr(mustParseIPAddr("10.0.0.43")).IsInRange(ast.IPAddr(mustParseIPAddr("10.0.0.42/8")))), + ast.Permit().When(ast.ExtensionCall("ip", ast.String("10.0.0.43")).IsInRange(ast.ExtensionCall("ip", ast.String("10.0.0.42/8")))), testutil.OK, }, } @@ -495,6 +495,58 @@ func TestUnmarshalJSON(t *testing.T) { } } +func TestMarshalJSON(t *testing.T) { + // most cases are covered in the TestUnmarshalJSON round-trip. + // this covers some cases that aren't 1:1 round-tripppable, such as hard-coded IP/Decimal values. + + t.Parallel() + tests := []struct { + name string + input *ast.Policy + want string + errFunc func(testing.TB, error) + }{ + { + "decimal", + ast.Permit().When(ast.Decimal(mustParseDecimal("42.24"))), + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"decimal":[{"Value":"42.24"}]}}]}`, + testutil.OK, + }, + { + "ip", + ast.Permit().When(ast.IPAddr(mustParseIPAddr("10.0.0.42/8"))), + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"ip":[{"Value":"10.0.0.42/8"}]}}]}`, + testutil.OK, + }, + { + "isInRange", + ast.Permit().When(ast.IPAddr(mustParseIPAddr("10.0.0.43")).IsInRange(ast.IPAddr(mustParseIPAddr("10.0.0.42/8")))), + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"isInRange":[ + {"ip":[{"Value":"10.0.0.43"}]}, + {"ip":[{"Value":"10.0.0.42/8"}]} + ]}}]}`, + testutil.OK, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + b, err := json.Marshal(tt.input) + tt.errFunc(t, err) + if err != nil { + return + } + normGot := testNormalizeJSON(t, string(b)) + normWant := testNormalizeJSON(t, tt.want) + testutil.Equals(t, normGot, normWant) + }) + } +} + func testNormalizeJSON(t testing.TB, in string) string { var x any err := json.Unmarshal([]byte(in), &x) diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go index 0cd5e8f3..e4a9b76a 100644 --- a/x/exp/ast/json_unmarshal.go +++ b/x/exp/ast/json_unmarshal.go @@ -115,44 +115,6 @@ func (j arrayJSON) ToNode() (Node, error) { return SetNodes(nodes...), nil } -func (j arrayJSON) ToDecimalNode() (Node, error) { - if len(j) != 1 { - return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) - } - arg, err := j[0].ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in extension: %w", err) - } - s, ok := arg.v.(nodeValue) - if !ok { - return Node{}, fmt.Errorf("unexpected type for decimal") - } - v, err := types.ParseDecimal(s.Value.String()) // TODO: this maybe isn't correct - if err != nil { - return Node{}, fmt.Errorf("error parsing decimal: %w", err) - } - return Decimal(v), nil -} - -func (j arrayJSON) ToIPAddrNode() (Node, error) { - if len(j) != 1 { - return Node{}, fmt.Errorf("unexpected number of arguments for extension: %v", len(j)) - } - arg, err := j[0].ToNode() - if err != nil { - return Node{}, fmt.Errorf("error in extension: %w", err) - } - s, ok := arg.v.(nodeValue) - if !ok { - return Node{}, fmt.Errorf("unexpected type for ipaddr") - } - v, err := types.ParseIPAddr(s.Value.String()) - if err != nil { - return Node{}, fmt.Errorf("error parsing ipaddr: %w", err) - } - return IPAddr(v), nil -} - func (j recordJSON) ToNode() (Node, error) { nodes := map[types.String]Node{} for k, v := range j { @@ -165,7 +127,7 @@ func (j recordJSON) ToNode() (Node, error) { return RecordNodes(nodes), nil } -func (e extMethodCallJSON) ToNode() (Node, error) { +func (e extensionCallJSON) ToNode() (Node, error) { if len(e) != 1 { return Node{}, fmt.Errorf("unexpected number of extension methods in node: %v", len(e)) } @@ -181,7 +143,7 @@ func (e extMethodCallJSON) ToNode() (Node, error) { } argNodes = append(argNodes, node) } - return newExtMethodCallNode(argNodes[0], types.String(k), argNodes[1:]...), nil + return newExtensionCall(types.String(k), argNodes...), nil } panic("unreachable code") } @@ -277,15 +239,9 @@ func (j nodeJSON) ToNode() (Node, error) { case j.Record != nil: return j.Record.ToNode() - // Any other function: decimal, ip - case j.Decimal != nil: - return j.Decimal.ToDecimalNode() - case j.IP != nil: - return j.IP.ToIPAddrNode() - // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange - case j.ExtensionMethod != nil: - return j.ExtensionMethod.ToNode() + case j.ExtensionCall != nil: + return j.ExtensionCall.ToNode() } return Node{}, fmt.Errorf("unknown node") @@ -307,7 +263,7 @@ func (n *nodeJSON) UnmarshalJSON(b []byte) error { // > This key is treated as the name of an extension function or method. The value must // > be a JSON array of values, each of which is itself an JsonExpr object. Note that for // > method calls, the method receiver is the first argument. - return json.Unmarshal(b, &n.ExtensionMethod) + return json.Unmarshal(b, &n.ExtensionCall) } func (p *patternComponentJSON) UnmarshalJSON(b []byte) error { diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 1e3a913e..a1eb4b4a 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -33,11 +33,10 @@ type nodeTypeIsIn struct { Entity node } -type nodeTypeExtMethodCall struct { +type nodeTypeExtensionCall struct { node - Left node - Method types.String // TODO: review type - Args []node + Name types.String // TODO: review type + Args []node } func stripNodes(args []Node) []node { @@ -48,11 +47,22 @@ func stripNodes(args []Node) []node { return res } -func newExtMethodCallNode(left Node, method types.String, args ...Node) Node { - return newNode(nodeTypeExtMethodCall{ - Left: left.v, - Method: method, - Args: stripNodes(args), +func newExtensionCall(method types.String, args ...Node) Node { + return newNode(nodeTypeExtensionCall{ + Name: method, + Args: stripNodes(args), + }) +} + +func newMethodCall(lhs Node, method types.String, args ...Node) Node { + res := make([]node, 1+len(args)) + res[0] = lhs.v + for i, v := range args { + res[i+1] = v.v + } + return newNode(nodeTypeExtensionCall{ + Name: method, + Args: res, }) } diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index 52b3584e..95ed737c 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -34,19 +34,19 @@ func (lhs Node) GreaterThanOrEqual(rhs Node) Node { } func (lhs Node) LessThanExt(rhs Node) Node { - return newExtMethodCallNode(lhs, "lessThan", rhs) + return newMethodCall(lhs, "lessThan", rhs) } func (lhs Node) LessThanOrEqualExt(rhs Node) Node { - return newExtMethodCallNode(lhs, "lessThanOrEqual", rhs) + return newMethodCall(lhs, "lessThanOrEqual", rhs) } func (lhs Node) GreaterThanExt(rhs Node) Node { - return newExtMethodCallNode(lhs, "greaterThan", rhs) + return newMethodCall(lhs, "greaterThan", rhs) } func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { - return newExtMethodCallNode(lhs, "greaterThanOrEqual", rhs) + return newMethodCall(lhs, "greaterThanOrEqual", rhs) } func (lhs Node) Like(pattern Pattern) Node { @@ -146,21 +146,21 @@ func (lhs Node) Has(attr string) Node { // |___|_| /_/ \_\__,_|\__,_|_| \___||___/___/ func (lhs Node) IsIpv4() Node { - return newExtMethodCallNode(lhs, "isIpv4") + return newMethodCall(lhs, "isIpv4") } func (lhs Node) IsIpv6() Node { - return newExtMethodCallNode(lhs, "isIpv6") + return newMethodCall(lhs, "isIpv6") } func (lhs Node) IsMulticast() Node { - return newExtMethodCallNode(lhs, "isMulticast") + return newMethodCall(lhs, "isMulticast") } func (lhs Node) IsLoopback() Node { - return newExtMethodCallNode(lhs, "isLoopback") + return newMethodCall(lhs, "isLoopback") } func (lhs Node) IsInRange(rhs Node) Node { - return newExtMethodCallNode(lhs, "isInRange", rhs) + return newMethodCall(lhs, "isInRange", rhs) } diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index 9548301b..4a47ae97 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -96,6 +96,10 @@ func IPAddr(i types.IPAddr) Node { return newValueNode(i) } +func ExtensionCall(name types.String, args ...Node) Node { + return newExtensionCall(name, args...) +} + func newValueNode(v types.Value) Node { return newNode(nodeValue{Value: v}) } From dd03e592fa07fe37059fddc9ae40c99b50f113e3 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 6 Aug 2024 11:51:12 -0700 Subject: [PATCH 063/216] cedar-go/x/exp/ast: fix a bug in a isIdentRune() where non-ASCII digits and letters were erroneously allowed Signed-off-by: philhassey --- x/exp/ast/cedar_tokenize.go | 11 +++++++++-- x/exp/ast/cedar_tokenize_test.go | 6 ++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/x/exp/ast/cedar_tokenize.go b/x/exp/ast/cedar_tokenize.go index 9aee4db5..13974de6 100644 --- a/x/exp/ast/cedar_tokenize.go +++ b/x/exp/ast/cedar_tokenize.go @@ -6,7 +6,6 @@ import ( "io" "strconv" "strings" - "unicode" "unicode/utf8" ) @@ -394,8 +393,16 @@ func (s *scanner) error(msg string) { s.err = fmt.Errorf("%v: %v", s.position, msg) } +func isASCIILetter(ch rune) bool { + return (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') +} + +func isASCIINumber(ch rune) bool { + return ch >= '0' && ch <= '9' +} + func isIdentRune(ch rune, first bool) bool { - return ch == '_' || unicode.IsLetter(ch) || unicode.IsDigit(ch) && !first + return ch == '_' || isASCIILetter(ch) || isASCIINumber(ch) && !first } func (s *scanner) scanIdentifier() rune { diff --git a/x/exp/ast/cedar_tokenize_test.go b/x/exp/ast/cedar_tokenize_test.go index fbf9db5a..02377584 100644 --- a/x/exp/ast/cedar_tokenize_test.go +++ b/x/exp/ast/cedar_tokenize_test.go @@ -27,7 +27,7 @@ These are some identifiers multiline comment // embedded comment does nothing */ -'/%|&=` +'/%|&=ë٩` want := []Token{ {Type: TokenIdent, Text: "These", Pos: Position{Offset: 1, Line: 2, Column: 1}}, {Type: TokenIdent, Text: "are", Pos: Position{Offset: 7, Line: 2, Column: 7}}, @@ -87,8 +87,10 @@ multiline comment {Type: TokenUnknown, Text: "|", Pos: Position{Offset: 268, Line: 16, Column: 4}}, {Type: TokenUnknown, Text: "&", Pos: Position{Offset: 269, Line: 16, Column: 5}}, {Type: TokenUnknown, Text: "=", Pos: Position{Offset: 270, Line: 16, Column: 6}}, + {Type: TokenUnknown, Text: "ë", Pos: Position{Offset: 271, Line: 16, Column: 7}}, + {Type: TokenUnknown, Text: "٩", Pos: Position{Offset: 273, Line: 16, Column: 8}}, - {Type: TokenEOF, Text: "", Pos: Position{Offset: 271, Line: 16, Column: 7}}, + {Type: TokenEOF, Text: "", Pos: Position{Offset: 275, Line: 16, Column: 9}}, } got, err := Tokenize([]byte(input)) testutil.OK(t, err) From 0be4bc82dae8f167aacc23db7acb916eb4a9b282 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 6 Aug 2024 16:42:35 -0700 Subject: [PATCH 064/216] cedar-go/x/exp/ast: add support for marshaling an AST to Cedar text Signed-off-by: philhassey --- x/exp/ast/cedar_marshal.go | 310 ++++++++++++++++++++++++++++++ x/exp/ast/cedar_unmarshal.go | 15 +- x/exp/ast/cedar_unmarshal_test.go | 304 +++++++++++++++-------------- x/exp/ast/node.go | 227 +++++++++++++++++----- x/exp/ast/pattern.go | 21 ++ x/exp/ast/scope.go | 7 +- 6 files changed, 688 insertions(+), 196 deletions(-) create mode 100644 x/exp/ast/cedar_marshal.go diff --git a/x/exp/ast/cedar_marshal.go b/x/exp/ast/cedar_marshal.go new file mode 100644 index 00000000..686ca998 --- /dev/null +++ b/x/exp/ast/cedar_marshal.go @@ -0,0 +1,310 @@ +package ast + +import ( + "bytes" +) + +// TODO: Add errors to all of this! +func (p *Policy) MarshalCedar(buf *bytes.Buffer) { + for _, a := range p.annotations { + a.MarshalCedar(buf) + buf.WriteRune('\n') + } + p.effect.MarshalCedar(buf) + buf.WriteRune(' ') + p.marshalScope(buf) + + for _, c := range p.conditions { + buf.WriteRune('\n') + c.MarshalCedar(buf) + } + + buf.WriteRune(';') +} + +func (p *Policy) marshalScope(buf *bytes.Buffer) { + _, principalAll := p.principal.(scopeTypeAll) + _, actionAll := p.action.(scopeTypeAll) + _, resourceAll := p.resource.(scopeTypeAll) + if principalAll && actionAll && resourceAll { + buf.WriteString("( principal, action, resource )") + return + } + + buf.WriteString("(\n ") + p.principal.MarshalCedar(buf) + buf.WriteString(",\n ") + p.action.MarshalCedar(buf) + buf.WriteString(",\n ") + p.resource.MarshalCedar(buf) + buf.WriteString("\n)") +} + +func (n annotationType) MarshalCedar(buf *bytes.Buffer) { + buf.WriteRune('@') + buf.WriteString(string(n.Key)) + buf.WriteRune('(') + buf.WriteString(n.Value.Cedar()) + buf.WriteString(")") +} + +func (e effect) MarshalCedar(buf *bytes.Buffer) { + if e == effectPermit { + buf.WriteString("permit") + } else { + buf.WriteString("forbid") + } +} + +func (n nodeTypeVariable) marshalCedar(buf *bytes.Buffer) { + buf.WriteString(string(n.Name)) +} + +func (n scopeTypeAll) MarshalCedar(buf *bytes.Buffer) { + n.Variable.marshalCedar(buf) +} + +func (n scopeTypeEq) MarshalCedar(buf *bytes.Buffer) { + n.Variable.marshalCedar(buf) + buf.WriteString(" == ") + buf.WriteString(n.Entity.Cedar()) +} + +func (n scopeTypeIn) MarshalCedar(buf *bytes.Buffer) { + n.Variable.marshalCedar(buf) + buf.WriteString(" in ") + buf.WriteString(n.Entity.Cedar()) +} + +func (n scopeTypeInSet) MarshalCedar(buf *bytes.Buffer) { + n.Variable.marshalCedar(buf) + buf.WriteString(" in ") + buf.WriteRune('[') + for i := range n.Entities { + buf.WriteString(n.Entities[i].Cedar()) + if i != len(n.Entities)-1 { + buf.WriteString(", ") + } + } + buf.WriteRune(']') +} + +func (n scopeTypeIs) MarshalCedar(buf *bytes.Buffer) { + n.Variable.marshalCedar(buf) + buf.WriteString(" is ") + buf.WriteString(string(n.Type)) +} + +func (n scopeTypeIsIn) MarshalCedar(buf *bytes.Buffer) { + n.Variable.marshalCedar(buf) + buf.WriteString(" is ") + buf.WriteString(string(n.Type)) + buf.WriteString(" in ") + buf.WriteString(n.Entity.Cedar()) +} + +func (c conditionType) MarshalCedar(buf *bytes.Buffer) { + if c.Condition == conditionWhen { + buf.WriteString("when") + } else { + buf.WriteString("unless") + } + + buf.WriteString(" { ") + c.Body.marshalCedar(buf) + buf.WriteString(" }") +} + +func (n nodeValue) marshalCedar(buf *bytes.Buffer) { + buf.WriteString(n.Value.Cedar()) +} + +func marshalChildNode(thisNodePrecedence nodePrecedenceLevel, childNode node, buf *bytes.Buffer) { + if thisNodePrecedence > childNode.precedenceLevel() { + buf.WriteRune('(') + childNode.marshalCedar(buf) + buf.WriteRune(')') + } else { + childNode.marshalCedar(buf) + } +} + +func (n nodeTypeNot) marshalCedar(buf *bytes.Buffer) { + buf.WriteRune('!') + marshalChildNode(n.precedenceLevel(), n.Arg, buf) +} + +func (n nodeTypeNegate) marshalCedar(buf *bytes.Buffer) { + buf.WriteRune('-') + marshalChildNode(n.precedenceLevel(), n.Arg, buf) +} + +func canMarshalAsIdent(s string) bool { + for i, r := range s { + if !isIdentRune(r, i == 0) { + return false + } + } + return true +} + +func (n nodeTypeAccess) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.Arg, buf) + + if canMarshalAsIdent(string(n.Value)) { + buf.WriteRune('.') + buf.WriteString(string(n.Value)) + } else { + buf.WriteRune('[') + buf.WriteString(n.Value.Cedar()) + buf.WriteRune(']') + } +} + +func (n nodeTypeExtensionCall) marshalCedar(buf *bytes.Buffer) { + var args []node + if n.Name != "ip" && n.Name != "decimal" { + marshalChildNode(n.precedenceLevel(), n.Args[0], buf) + buf.WriteRune('.') + args = n.Args[1:] + } else { + args = n.Args + } + + buf.WriteString(string(n.Name)) + buf.WriteRune('(') + for i := range args { + marshalChildNode(n.precedenceLevel(), n.Args[i], buf) + if i != len(n.Args)-1 { + buf.WriteString(", ") + } + } + buf.WriteRune(')') +} + +func (n nodeTypeContains) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.Left, buf) + buf.WriteString(".contains(") + marshalChildNode(n.precedenceLevel(), n.Right, buf) + buf.WriteRune(')') +} + +func (n nodeTypeContainsAll) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.Left, buf) + buf.WriteString(".containsAll(") + marshalChildNode(n.precedenceLevel(), n.Right, buf) + buf.WriteRune(')') +} + +func (n nodeTypeContainsAny) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.Left, buf) + buf.WriteString(".containsAny(") + marshalChildNode(n.precedenceLevel(), n.Right, buf) + buf.WriteRune(')') +} + +func (n nodeTypeSet) marshalCedar(buf *bytes.Buffer) { + buf.WriteRune('[') + for i := range n.Elements { + marshalChildNode(n.precedenceLevel(), n.Elements[i], buf) + if i != len(n.Elements)-1 { + buf.WriteString(", ") + } + } + buf.WriteRune(']') +} + +func marshalInfixBinaryOp(n binaryNode, precedence nodePrecedenceLevel, op string, buf *bytes.Buffer) { + marshalChildNode(precedence, n.Left, buf) + buf.WriteRune(' ') + buf.WriteString(op) + buf.WriteRune(' ') + marshalChildNode(precedence, n.Right, buf) +} + +func (n nodeTypeMult) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "*", buf) +} + +func (n nodeTypeAdd) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "+", buf) +} + +func (n nodeTypeSub) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "-", buf) +} + +func (n nodeTypeLessThan) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "<", buf) +} + +func (n nodeTypeLessThanOrEqual) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "<=", buf) +} + +func (n nodeTypeGreaterThan) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), ">", buf) +} + +func (n nodeTypeGreaterThanOrEqual) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), ">=", buf) +} + +func (n nodeTypeEquals) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "==", buf) +} + +func (n nodeTypeNotEquals) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "!=", buf) +} + +func (n nodeTypeIn) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "in", buf) +} + +func (n nodeTypeAnd) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "&&", buf) +} + +func (n nodeTypeOr) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "||", buf) +} + +func (n nodeTypeHas) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.Arg, buf) + buf.WriteString(" has ") + if canMarshalAsIdent(string(n.Value)) { + buf.WriteString(string(n.Value)) + } else { + buf.WriteString(n.Value.Cedar()) + } +} + +func (n nodeTypeIs) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.Left, buf) + buf.WriteString(" is ") + buf.WriteString(string(n.EntityType)) +} + +func (n nodeTypeIsIn) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.Left, buf) + buf.WriteString(" is ") + buf.WriteString(string(n.EntityType)) + buf.WriteString(" in ") + n.Entity.marshalCedar(buf) +} + +func (n nodeTypeLike) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.Arg, buf) + buf.WriteString(" like ") + n.Value.MarshalCedar(buf) +} + +func (n nodeTypeIf) marshalCedar(buf *bytes.Buffer) { + buf.WriteString("if ") + marshalChildNode(n.precedenceLevel(), n.If, buf) + buf.WriteString(" then ") + marshalChildNode(n.precedenceLevel(), n.Then, buf) + buf.WriteString(" else ") + marshalChildNode(n.precedenceLevel(), n.Else, buf) +} diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index 591a99bc..b2c03036 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -3,7 +3,6 @@ package ast import ( "fmt" "net/netip" - "strconv" "strings" "github.com/cedar-policy/cedar-go/types" @@ -747,17 +746,21 @@ func (p *parser) entityOrExtFun(ident string) (Node, error) { } if ident == "ip" { - ipaddr, err := netip.ParsePrefix(str) + prefix, err := netip.ParsePrefix(str) if err != nil { - return res, err + ipaddr, err := netip.ParseAddr(str) + if err != nil { + return Node{}, err + } + prefix = netip.PrefixFrom(ipaddr, 32) } - res = IPAddr(types.IPAddr(ipaddr)) + res = IPAddr(types.IPAddr(prefix)) } else { - dec, err := strconv.ParseFloat(str, 64) + dec, err := types.ParseDecimal(str) if err != nil { return res, err } - res = Decimal(types.Decimal(dec)) + res = Decimal(dec) } default: entity, err := p.entityFirstPathPreread(ident) diff --git a/x/exp/ast/cedar_unmarshal_test.go b/x/exp/ast/cedar_unmarshal_test.go index 2e0260ea..32fbf22a 100644 --- a/x/exp/ast/cedar_unmarshal_test.go +++ b/x/exp/ast/cedar_unmarshal_test.go @@ -1,6 +1,8 @@ package ast_test import ( + "bytes" + "net/netip" "testing" "github.com/cedar-policy/cedar-go/testutil" @@ -46,358 +48,360 @@ func TestParsePolicy(t *testing.T) { }{ { "permit any scope", - `permit ( - principal, - action, - resource - );`, + `permit ( principal, action, resource );`, ast.Permit(), }, { "forbid any scope", - `forbid ( - principal, - action, - resource - );`, + `forbid ( principal, action, resource );`, ast.Forbid(), }, { "one annotation", `@foo("bar") - permit ( - principal, - action, - resource - );`, +permit ( principal, action, resource );`, ast.Annotation("foo", "bar").Permit(), }, { "two annotations", `@foo("bar") - @baz("quux") - permit ( - principal, - action, - resource - );`, +@baz("quux") +permit ( principal, action, resource );`, ast.Annotation("foo", "bar").Annotation("baz", "quux").Permit(), }, { "scope eq", `permit ( - principal == User::"johnny", - action == Action::"sow", - resource == Crop::"apple" - );`, + principal == User::"johnny", + action == Action::"sow", + resource == Crop::"apple" +);`, ast.Permit().PrincipalEq(johnny).ActionEq(sow).ResourceEq(apple), }, { "scope is", `permit ( - principal is User, - action, - resource is Crop - );`, + principal is User, + action, + resource is Crop +);`, ast.Permit().PrincipalIs("User").ResourceIs("Crop"), }, { "scope is in", `permit ( - principal is User in Group::"folkHeroes", - action, - resource is Crop in Genus::"malus" - );`, + principal is User in Group::"folkHeroes", + action, + resource is Crop in Genus::"malus" +);`, ast.Permit().PrincipalIsIn("User", folkHeroes).ResourceIsIn("Crop", malus), }, { "scope in", `permit ( - principal in Group::"folkHeroes", - action in ActionType::"farming", - resource in Genus::"malus" - );`, + principal in Group::"folkHeroes", + action in ActionType::"farming", + resource in Genus::"malus" +);`, ast.Permit().PrincipalIn(folkHeroes).ActionIn(farming).ResourceIn(malus), }, { "scope action in entities", `permit ( - principal, - action in [ActionType::"farming", ActionType::"forestry"], - resource - );`, + principal, + action in [ActionType::"farming", ActionType::"forestry"], + resource +);`, ast.Permit().ActionInSet(farming, forestry), }, { "trivial conditions", - `permit (principal, action, resource) - when { true } - unless { false };`, + `permit ( principal, action, resource ) +when { true } +unless { false };`, ast.Permit().When(ast.Boolean(true)).Unless(ast.Boolean(false)), }, { "not operator", - `permit (principal, action, resource) - when { !true };`, + `permit ( principal, action, resource ) +when { !true };`, ast.Permit().When(ast.Not(ast.Boolean(true))), }, { "multiple not operators", - `permit (principal, action, resource) - when { !!true };`, + `permit ( principal, action, resource ) +when { !!true };`, ast.Permit().When(ast.Not(ast.Not(ast.Boolean(true)))), }, { "negate operator", - `permit (principal, action, resource) - when { -1 };`, + `permit ( principal, action, resource ) +when { -1 };`, ast.Permit().When(ast.Negate(ast.Long(1))), }, { "mutliple negate operators", - `permit (principal, action, resource) - when { !--1 };`, + `permit ( principal, action, resource ) +when { !--1 };`, ast.Permit().When(ast.Not(ast.Negate(ast.Negate(ast.Long(1))))), }, { "variable member", - `permit (principal, action, resource) - when { context.boolValue };`, + `permit ( principal, action, resource ) +when { context.boolValue };`, ast.Permit().When(ast.Context().Access("boolValue")), }, + { + "variable member via []", + `permit ( principal, action, resource ) +when { context["2legit2quit"] };`, + ast.Permit().When(ast.Context().Access("2legit2quit")), + }, { "contains method call", - `permit (principal, action, resource) - when { context.strings.contains("foo") };`, + `permit ( principal, action, resource ) +when { context.strings.contains("foo") };`, ast.Permit().When(ast.Context().Access("strings").Contains(ast.String("foo"))), }, { "containsAll method call", - `permit (principal, action, resource) - when { context.strings.containsAll(["foo"]) };`, + `permit ( principal, action, resource ) +when { context.strings.containsAll(["foo"]) };`, ast.Permit().When(ast.Context().Access("strings").ContainsAll(ast.SetNodes(ast.String("foo")))), }, { "containsAny method call", - `permit (principal, action, resource) - when { context.strings.containsAny(["foo"]) };`, + `permit ( principal, action, resource ) +when { context.strings.containsAny(["foo"]) };`, ast.Permit().When(ast.Context().Access("strings").ContainsAny(ast.SetNodes(ast.String("foo")))), }, { "extension method call", - `permit (principal, action, resource) - when { context.sourceIP.isIpv4() };`, + `permit ( principal, action, resource ) +when { context.sourceIP.isIpv4() };`, ast.Permit().When(ast.Context().Access("sourceIP").IsIpv4()), }, { "multiplication", - `permit (principal, action, resource) - when { 42 * 2 };`, + `permit ( principal, action, resource ) +when { 42 * 2 };`, ast.Permit().When(ast.Long(42).Times(ast.Long(2))), }, { "multiple multiplication", - `permit (principal, action, resource) - when { 42 * 2 * 1};`, + `permit ( principal, action, resource ) +when { 42 * 2 * 1 };`, ast.Permit().When(ast.Long(42).Times(ast.Long(2)).Times(ast.Long(1))), }, { "addition", - `permit (principal, action, resource) - when { 42 + 2 };`, + `permit ( principal, action, resource ) +when { 42 + 2 };`, ast.Permit().When(ast.Long(42).Plus(ast.Long(2))), }, { "multiple addition", - `permit (principal, action, resource) - when { 42 + 2 + 1 };`, + `permit ( principal, action, resource ) +when { 42 + 2 + 1 };`, ast.Permit().When(ast.Long(42).Plus(ast.Long(2)).Plus(ast.Long(1))), }, { "subtraction", - `permit (principal, action, resource) - when { 42 - 2 };`, + `permit ( principal, action, resource ) +when { 42 - 2 };`, ast.Permit().When(ast.Long(42).Minus(ast.Long(2))), }, { "multiple subtraction", - `permit (principal, action, resource) - when { 42 - 2 - 1 };`, + `permit ( principal, action, resource ) +when { 42 - 2 - 1 };`, ast.Permit().When(ast.Long(42).Minus(ast.Long(2)).Minus(ast.Long(1))), }, { "mixed addition and subtraction", - `permit (principal, action, resource) - when { 42 - 2 + 1 };`, + `permit ( principal, action, resource ) +when { 42 - 2 + 1 };`, ast.Permit().When(ast.Long(42).Minus(ast.Long(2)).Plus(ast.Long(1))), }, { "less than", - `permit (principal, action, resource) - when { 2 < 42 };`, + `permit ( principal, action, resource ) +when { 2 < 42 };`, ast.Permit().When(ast.Long(2).LessThan(ast.Long(42))), }, { "less than or equal", - `permit (principal, action, resource) - when { 2 <= 42 };`, + `permit ( principal, action, resource ) +when { 2 <= 42 };`, ast.Permit().When(ast.Long(2).LessThanOrEqual(ast.Long(42))), }, { "greater than", - `permit (principal, action, resource) - when { 2 > 42 };`, + `permit ( principal, action, resource ) +when { 2 > 42 };`, ast.Permit().When(ast.Long(2).GreaterThan(ast.Long(42))), }, { "greater than or equal", - `permit (principal, action, resource) - when { 2 >= 42 };`, + `permit ( principal, action, resource ) +when { 2 >= 42 };`, ast.Permit().When(ast.Long(2).GreaterThanOrEqual(ast.Long(42))), }, { "equal", - `permit (principal, action, resource) - when { 2 == 42 };`, + `permit ( principal, action, resource ) +when { 2 == 42 };`, ast.Permit().When(ast.Long(2).Equals(ast.Long(42))), }, { "not equal", - `permit (principal, action, resource) - when { 2 != 42 };`, + `permit ( principal, action, resource ) +when { 2 != 42 };`, ast.Permit().When(ast.Long(2).NotEquals(ast.Long(42))), }, { "in", - `permit (principal, action, resource) - when { principal in Group::"folkHeroes" };`, + `permit ( principal, action, resource ) +when { principal in Group::"folkHeroes" };`, ast.Permit().When(ast.Principal().In(ast.EntityUID(folkHeroes))), }, { "has ident", - `permit (principal, action, resource) - when { principal has firstName };`, + `permit ( principal, action, resource ) +when { principal has firstName };`, ast.Permit().When(ast.Principal().Has("firstName")), }, { "has string", - `permit (principal, action, resource) - when { principal has "firstName" };`, - ast.Permit().When(ast.Principal().Has("firstName")), + `permit ( principal, action, resource ) +when { principal has "1stName" };`, + ast.Permit().When(ast.Principal().Has("1stName")), }, // N.B. Most pattern parsing tests can be found in pattern_test.go { "like no wildcards", - `permit (principal, action, resource) - when { principal.firstName like "johnny" };`, + `permit ( principal, action, resource ) +when { principal.firstName like "johnny" };`, ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(ast.PatternFromCedar("johnny")))), }, { "like escaped asterisk", - `permit (principal, action, resource) - when { principal.firstName like "joh\*nny" };`, + `permit ( principal, action, resource ) +when { principal.firstName like "joh\*nny" };`, ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(ast.PatternFromCedar(`joh\*nny`)))), }, { "like wildcard", - `permit (principal, action, resource) - when { principal.firstName like "*" };`, + `permit ( principal, action, resource ) +when { principal.firstName like "*" };`, ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(ast.PatternFromCedar("*")))), }, { "is", - `permit (principal, action, resource) - when { principal is User };`, + `permit ( principal, action, resource ) +when { principal is User };`, ast.Permit().When(ast.Principal().Is("User")), }, { "is in", - `permit (principal, action, resource) - when { principal is User in Group::"folkHeroes" };`, - ast.Permit().When(ast.Principal().IsIn("User", ast.EntityUID(folkHeroes))), - }, - { - "is in", - `permit (principal, action, resource) - when { principal is User in Group::"folkHeroes" };`, + `permit ( principal, action, resource ) +when { principal is User in Group::"folkHeroes" };`, ast.Permit().When(ast.Principal().IsIn("User", ast.EntityUID(folkHeroes))), }, { "and", - `permit (principal, action, resource) - when { true && false };`, + `permit ( principal, action, resource ) +when { true && false };`, ast.Permit().When(ast.True().And(ast.False())), }, { "multiple and", - `permit (principal, action, resource) - when { true && false && true };`, + `permit ( principal, action, resource ) +when { true && false && true };`, ast.Permit().When(ast.True().And(ast.False()).And(ast.True())), }, { "or", - `permit (principal, action, resource) - when { true || false };`, + `permit ( principal, action, resource ) +when { true || false };`, ast.Permit().When(ast.True().Or(ast.False())), }, { "multiple or", - `permit (principal, action, resource) - when { true || false || true };`, + `permit ( principal, action, resource ) +when { true || false || true };`, ast.Permit().When(ast.True().Or(ast.False()).Or(ast.True())), }, { "if then else", - `permit (principal, action, resource) - when { if true then true else false };`, + `permit ( principal, action, resource ) +when { if true then true else false };`, ast.Permit().When(ast.If(ast.True(), ast.True(), ast.False())), }, + { + "ip extension function", + `permit ( principal, action, resource ) +when { ip("1.2.3.4") == ip("2.3.4.5") };`, + ast.Permit().When( + ast.IPAddr(types.IPAddr(netip.MustParsePrefix("1.2.3.4/32"))).Equals( + ast.IPAddr(types.IPAddr(netip.MustParsePrefix("2.3.4.5/32"))), + ), + ), + }, + { + "decimal extension function", + `permit ( principal, action, resource ) +when { decimal("12.34") == decimal("23.45") };`, + ast.Permit().When( + ast.Decimal(types.Decimal(123400)).Equals(ast.Decimal(types.Decimal(234500))), + ), + }, { "and over or precedence", - `permit (principal, action, resource) - when { true && false || true && true };`, + `permit ( principal, action, resource ) +when { true && false || true && true };`, ast.Permit().When(ast.True().And(ast.False()).Or(ast.True().And(ast.True()))), }, { "rel over and precedence", - `permit (principal, action, resource) - when { 1 < 2 && true };`, + `permit ( principal, action, resource ) +when { 1 < 2 && true };`, ast.Permit().When(ast.Long(1).LessThan(ast.Long(2)).And(ast.True())), }, { "add over rel precedence", - `permit (principal, action, resource) - when { 1 + 1 < 3 };`, + `permit ( principal, action, resource ) +when { 1 + 1 < 3 };`, ast.Permit().When(ast.Long(1).Plus(ast.Long(1)).LessThan(ast.Long(3))), }, { - "mult over add precedence", - `permit (principal, action, resource) - when { 2 * 3 + 4 == 10 };`, + "mult over add precedence (rhs add)", + `permit ( principal, action, resource ) +when { 2 * 3 + 4 == 10 };`, ast.Permit().When(ast.Long(2).Times(ast.Long(3)).Plus(ast.Long(4)).Equals(ast.Long(10))), }, { - "unary over mult precedence", - `permit (principal, action, resource) - when { -2 * 3 == -6 };`, - ast.Permit().When(ast.Negate(ast.Long(2)).Times(ast.Long(3)).Equals(ast.Negate(ast.Long(6)))), + "mult over add precedence (lhs add)", + `permit ( principal, action, resource ) +when { 2 + 3 * 4 == 14 };`, + ast.Permit().When(ast.Long(2).Plus(ast.Long(3).Times(ast.Long(4))).Equals(ast.Long(14))), }, { - "member over unary precedence", - `permit (principal, action, resource) - when { -context.num };`, - ast.Permit().When(ast.Negate(ast.Context().Access("num"))), + "unary over mult precedence", + `permit ( principal, action, resource ) +when { -2 * 3 == -6 };`, + ast.Permit().When(ast.Negate(ast.Long(2)).Times(ast.Long(3)).Equals(ast.Negate(ast.Long(6)))), }, { "member over unary precedence", - `permit (principal, action, resource) - when { -context.num };`, + `permit ( principal, action, resource ) +when { -context.num };`, ast.Permit().When(ast.Negate(ast.Context().Access("num"))), }, { "parens over unary precedence", - `permit (principal, action, resource) - when { -(2 + 3) == -5 };`, + `permit ( principal, action, resource ) +when { -(2 + 3) == -5 };`, ast.Permit().When(ast.Negate(ast.Long(2).Plus(ast.Long(3))).Equals(ast.Negate(ast.Long(5)))), }, { @@ -406,6 +410,18 @@ func TestParsePolicy(t *testing.T) { when { (2 + 3 + 4) * 5 == 18 };`, ast.Permit().When(ast.Long(2).Plus(ast.Long(3)).Plus(ast.Long(4)).Times(ast.Long(5)).Equals(ast.Long(18))), }, + { + "parenthesized if", + `permit ( principal, action, resource ) +when { (if true then 2 else 3 * 4) == 2 };`, + ast.Permit().When(ast.If(ast.True(), ast.Long(2), ast.Long(3).Times(ast.Long(4))).Equals(ast.Long(2))), + }, + { + "parenthesized if with trailing mult", + `permit ( principal, action, resource ) +when { (if true then 2 else 3) * 4 == 8 };`, + ast.Permit().When(ast.If(ast.True(), ast.Long(2), ast.Long(3)).Times(ast.Long(4)).Equals(ast.Long(8))), + }, } for _, tt := range parseTests { @@ -415,6 +431,10 @@ when { (2 + 3 + 4) * 5 == 18 };`, var policy ast.Policy testutil.OK(t, policy.UnmarshalCedar([]byte(tt.Text))) testutil.Equals(t, policy, *tt.ExpectedPolicy) + + var buf bytes.Buffer + policy.MarshalCedar(&buf) + testutil.Equals(t, buf.String(), tt.Text) }) } } diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index a1eb4b4a..e2d67dfa 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -1,6 +1,18 @@ package ast -import "github.com/cedar-policy/cedar-go/types" +import ( + "bytes" + + "github.com/cedar-policy/cedar-go/types" +) + +type Node struct { + v node // NOTE: not an embed because a `Node` is not a `node` +} + +func newNode(v node) Node { + return Node{v: v} +} type strOpNode struct { node @@ -8,8 +20,87 @@ type strOpNode struct { Value types.String } -type nodeTypeAccess struct{ strOpNode } -type nodeTypeHas struct{ strOpNode } +type binaryNode struct { + node + Left, Right node +} + +type nodePrecedenceLevel uint8 + +const ( + ifPrecedence nodePrecedenceLevel = 0 + orPrecedence nodePrecedenceLevel = 1 + andPrecedence nodePrecedenceLevel = 2 + relationPrecedence nodePrecedenceLevel = 3 + addPrecedence nodePrecedenceLevel = 4 + multPrecedence nodePrecedenceLevel = 5 + unaryPrecedence nodePrecedenceLevel = 6 + accessPrecedence nodePrecedenceLevel = 7 + primaryPrecedence nodePrecedenceLevel = 8 +) + +type nodeTypeIf struct { + node + If, Then, Else node +} + +func (n nodeTypeIf) precedenceLevel() nodePrecedenceLevel { + return ifPrecedence +} + +type nodeTypeOr struct{ binaryNode } + +func (n nodeTypeOr) precedenceLevel() nodePrecedenceLevel { + return orPrecedence +} + +type nodeTypeAnd struct { + binaryNode +} + +func (n nodeTypeAnd) precedenceLevel() nodePrecedenceLevel { + return andPrecedence +} + +type relationNode struct{} + +func (n relationNode) precedenceLevel() nodePrecedenceLevel { + return relationPrecedence +} + +type nodeTypeLessThan struct { + binaryNode + relationNode +} +type nodeTypeLessThanOrEqual struct { + binaryNode + relationNode +} +type nodeTypeGreaterThan struct { + binaryNode + relationNode +} +type nodeTypeGreaterThanOrEqual struct { + binaryNode + relationNode +} +type nodeTypeNotEquals struct { + binaryNode + relationNode +} +type nodeTypeEquals struct { + binaryNode + relationNode +} +type nodeTypeIn struct { + binaryNode + relationNode +} + +type nodeTypeHas struct { + strOpNode + relationNode +} type nodeTypeLike struct { node @@ -17,9 +108,8 @@ type nodeTypeLike struct { Value Pattern } -type nodeTypeIf struct { - node - If, Then, Else node +func (n nodeTypeLike) precedenceLevel() nodePrecedenceLevel { + return relationPrecedence } type nodeTypeIs struct { @@ -28,17 +118,69 @@ type nodeTypeIs struct { EntityType types.String // TODO: review type } +func (n nodeTypeIs) precedenceLevel() nodePrecedenceLevel { + return relationPrecedence +} + type nodeTypeIsIn struct { nodeTypeIs Entity node } +func (n nodeTypeIsIn) precedenceLevel() nodePrecedenceLevel { + return relationPrecedence +} + +type addNode struct{} + +func (n addNode) precedenceLevel() nodePrecedenceLevel { + return addPrecedence +} + +type nodeTypeSub struct { + binaryNode + addNode +} + +type nodeTypeAdd struct { + binaryNode + addNode +} + +type nodeTypeMult struct{ binaryNode } + +func (n nodeTypeMult) precedenceLevel() nodePrecedenceLevel { + return multPrecedence +} + +type unaryNode struct { + node + Arg node +} + +func (n unaryNode) precedenceLevel() nodePrecedenceLevel { + return unaryPrecedence +} + +type nodeTypeNegate struct{ unaryNode } +type nodeTypeNot struct{ unaryNode } + +type nodeTypeAccess struct{ strOpNode } + +func (n nodeTypeAccess) precedenceLevel() nodePrecedenceLevel { + return accessPrecedence +} + type nodeTypeExtensionCall struct { node Name types.String // TODO: review type Args []node } +func (n nodeTypeExtensionCall) precedenceLevel() nodePrecedenceLevel { + return accessPrecedence +} + func stripNodes(args []Node) []node { res := make([]node, len(args)) for i, v := range args { @@ -66,8 +208,33 @@ func newMethodCall(lhs Node, method types.String, args ...Node) Node { }) } +type containsNode struct{} + +func (n containsNode) precedenceLevel() nodePrecedenceLevel { + return accessPrecedence +} + +type nodeTypeContains struct { + binaryNode + containsNode +} +type nodeTypeContainsAll struct { + binaryNode + containsNode +} +type nodeTypeContainsAny struct { + binaryNode + containsNode +} + +type primaryNode struct{ node } + +func (n primaryNode) precedenceLevel() nodePrecedenceLevel { + return primaryPrecedence +} + type nodeValue struct { - node + primaryNode Value types.Value } @@ -75,58 +242,24 @@ type recordElement struct { Key types.String Value node } + type nodeTypeRecord struct { - node + primaryNode Elements []recordElement } type nodeTypeSet struct { - node + primaryNode Elements []node } -type unaryNode struct { - node - Arg node -} - -type nodeTypeNegate struct{ unaryNode } -type nodeTypeNot struct{ unaryNode } - type nodeTypeVariable struct { - node + primaryNode Name types.String // TODO: Review type } -type binaryNode struct { - node - Left, Right node -} - -type nodeTypeIn struct{ binaryNode } -type nodeTypeAnd struct{ binaryNode } -type nodeTypeEquals struct{ binaryNode } -type nodeTypeGreaterThan struct{ binaryNode } -type nodeTypeGreaterThanOrEqual struct{ binaryNode } -type nodeTypeLessThan struct{ binaryNode } -type nodeTypeLessThanOrEqual struct{ binaryNode } -type nodeTypeSub struct{ binaryNode } -type nodeTypeAdd struct{ binaryNode } -type nodeTypeContains struct{ binaryNode } -type nodeTypeContainsAll struct{ binaryNode } -type nodeTypeContainsAny struct{ binaryNode } -type nodeTypeMult struct{ binaryNode } -type nodeTypeNotEquals struct{ binaryNode } -type nodeTypeOr struct{ binaryNode } - type node interface { isNode() -} - -type Node struct { - v node // NOTE: not an embed because a `Node` is not a `node` -} - -func newNode(v node) Node { - return Node{v: v} + marshalCedar(*bytes.Buffer) + precedenceLevel() nodePrecedenceLevel } diff --git a/x/exp/ast/pattern.go b/x/exp/ast/pattern.go index 1f19f5a5..6f96f0e4 100644 --- a/x/exp/ast/pattern.go +++ b/x/exp/ast/pattern.go @@ -1,5 +1,11 @@ package ast +import ( + "bytes" + "strconv" + "strings" +) + type PatternComponent struct { Star bool Chunk string @@ -31,6 +37,21 @@ func PatternFromCedar(cedar string) (Pattern, error) { }, nil } +func (p Pattern) MarshalCedar(buf *bytes.Buffer) { + buf.WriteRune('"') + for _, comp := range p.Comps { + if comp.Star { + buf.WriteRune('*') + } + // TODO: This is wrong. It needs to escape unicode the Rustic way. + quotedString := strconv.Quote(comp.Chunk) + quotedString = quotedString[1 : len(quotedString)-1] + quotedString = strings.Replace(quotedString, "*", "\\*", -1) + buf.WriteString(quotedString) + } + buf.WriteRune('"') +} + func (p *Pattern) AddWildcard() *Pattern { star := PatternComponent{Star: true} if len(p.Comps) == 0 { diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index 9d13bc37..927402e2 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -1,6 +1,10 @@ package ast -import "github.com/cedar-policy/cedar-go/types" +import ( + "bytes" + + "github.com/cedar-policy/cedar-go/types" +) type scope nodeTypeVariable @@ -85,6 +89,7 @@ func (p *Policy) ResourceIsIn(entityType types.String, entity types.EntityUID) * type isScopeNode interface { isScope() + MarshalCedar(*bytes.Buffer) } type scopeNode struct { From 100db7e4aa886bd4e94db1717bc89e75fe639fe8 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 7 Aug 2024 15:15:07 -0600 Subject: [PATCH 065/216] x/exp/ast: add compile / eval feature and ensure existing Cedar corpus tests pass Addresses IDX-49 Signed-off-by: philhassey --- cedar.go | 57 +- cedar_test.go | 22 +- eval.go | 21 +- x/exp/ast/cedar.go | 59 ++ x/exp/ast/cedar_unmarshal.go | 97 ++- x/exp/ast/cedar_unmarshal_test.go | 7 +- x/exp/ast/eval_compile.go | 43 ++ x/exp/ast/eval_convert.go | 124 ++++ x/exp/ast/eval_impl.go | 1086 +++++++++++++++++++++++++++++ x/exp/ast/eval_match.go | 66 ++ x/exp/ast/extensions.go | 24 + x/exp/ast/json_test.go | 8 +- x/exp/ast/json_unmarshal.go | 8 +- x/exp/ast/node.go | 24 +- x/exp/ast/operator.go | 4 +- x/exp/ast/scope.go | 45 +- 16 files changed, 1554 insertions(+), 141 deletions(-) create mode 100644 x/exp/ast/cedar.go create mode 100644 x/exp/ast/eval_compile.go create mode 100644 x/exp/ast/eval_convert.go create mode 100644 x/exp/ast/eval_impl.go create mode 100644 x/exp/ast/eval_match.go create mode 100644 x/exp/ast/extensions.go diff --git a/cedar.go b/cedar.go index 6340a9df..a5ab3653 100644 --- a/cedar.go +++ b/cedar.go @@ -2,14 +2,13 @@ package cedar import ( - "encoding/json" "fmt" + "slices" "strings" "github.com/cedar-policy/cedar-go/types" - "github.com/cedar-policy/cedar-go/x/exp/parser" + "github.com/cedar-policy/cedar-go/x/exp/ast" "golang.org/x/exp/maps" - "golang.org/x/exp/slices" ) // A PolicySet is a slice of policies. @@ -63,20 +62,13 @@ func (a *Effect) UnmarshalJSON(b []byte) error { // given file name used in Position data. If there is an error parsing the // document, it will be returned. func NewPolicySet(fileName string, document []byte) (PolicySet, error) { - var policies PolicySet - tokens, err := parser.Tokenize(document) - if err != nil { - return nil, fmt.Errorf("tokenize error: %w", err) - } - res, err := parser.Parse(tokens) - if err != nil { + var res ast.PolicySet + if err := res.UnmarshalCedar(document); err != nil { return nil, fmt.Errorf("parser error: %w", err) } + var policies PolicySet for _, p := range res { - ann := Annotations{} - for _, a := range p.Annotations { - ann[a.Key] = a.Value - } + ann := Annotations(p.TmpGetAnnotations()) policies = append(policies, Policy{ Position: Position{ Filename: fileName, @@ -85,38 +77,15 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { Column: p.Position.Column, }, Annotations: ann, - Effect: Effect(p.Effect == parser.EffectPermit), - eval: toEval(p), + Effect: Effect(p.TmpGetEffect()), + eval: ast.Compile(p.Policy), }) } return policies, nil } -// An Entities is a collection of all the Entities that are needed to evaluate -// authorization requests. The key is an EntityUID which uniquely identifies -// the Entity (it must be the same as the UID within the Entity itself.) -type Entities map[types.EntityUID]Entity - -// An Entity defines the parents and attributes for an EntityUID. -type Entity struct { - UID types.EntityUID `json:"uid"` - Parents []types.EntityUID `json:"parents,omitempty"` - Attributes types.Record `json:"attrs"` -} - -func (e Entities) MarshalJSON() ([]byte, error) { - s := e.toSlice() - return json.Marshal(s) -} - -func (e *Entities) UnmarshalJSON(b []byte) error { - var s []Entity - if err := json.Unmarshal(b, &s); err != nil { - return err - } - *e = entitiesFromSlice(s) - return nil -} +type Entities = ast.Entities +type Entity = ast.Entity func entitiesFromSlice(s []Entity) Entities { var res = Entities{} @@ -126,7 +95,7 @@ func entitiesFromSlice(s []Entity) Entities { return res } -func (e Entities) toSlice() []Entity { +func entitiesToSlice(e Entities) []Entity { s := maps.Values(e) slices.SortFunc(s, func(a, b Entity) int { return strings.Compare(a.UID.String(), b.UID.String()) @@ -134,10 +103,6 @@ func (e Entities) toSlice() []Entity { return s } -func (e Entities) Clone() Entities { - return maps.Clone(e) -} - // A Decision is the result of the authorization. type Decision bool diff --git a/cedar_test.go b/cedar_test.go index 0a0a61ee..290feea7 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -60,6 +60,7 @@ func TestIsAuthorized(t *testing.T) { Context types.Record Want Decision DiagErr int + ParseErr bool }{ { Name: "simple-permit", @@ -448,7 +449,8 @@ func TestIsAuthorized(t *testing.T) { Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, - DiagErr: 1, + DiagErr: 0, + ParseErr: true, }, { Name: "permit-when-set-containsAll-ok", @@ -470,7 +472,8 @@ func TestIsAuthorized(t *testing.T) { Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, - DiagErr: 1, + DiagErr: 0, + ParseErr: true, }, { Name: "permit-when-set-containsAny-ok", @@ -492,7 +495,8 @@ func TestIsAuthorized(t *testing.T) { Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, - DiagErr: 1, + DiagErr: 0, + ParseErr: true, }, { Name: "permit-when-record-attr", @@ -514,7 +518,8 @@ func TestIsAuthorized(t *testing.T) { Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, - DiagErr: 1, + DiagErr: 0, + ParseErr: true, }, { Name: "permit-when-like", @@ -536,7 +541,8 @@ func TestIsAuthorized(t *testing.T) { Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, - DiagErr: 1, + DiagErr: 0, + ParseErr: true, }, { Name: "permit-when-decimal", @@ -742,15 +748,15 @@ func TestIsAuthorized(t *testing.T) { t.Run(tt.Name, func(t *testing.T) { t.Parallel() ps, err := NewPolicySet("policy.cedar", []byte(tt.Policy)) - testutil.OK(t, err) + testutil.Equals(t, (err != nil), tt.ParseErr) ok, diag := ps.IsAuthorized(tt.Entities, Request{ Principal: tt.Principal, Action: tt.Action, Resource: tt.Resource, Context: tt.Context, }) - testutil.Equals(t, ok, tt.Want) testutil.Equals(t, len(diag.Errors), tt.DiagErr) + testutil.Equals(t, ok, tt.Want) }) } } @@ -774,7 +780,7 @@ func TestEntities(t *testing.T) { }, } entities := entitiesFromSlice(s) - s2 := entities.toSlice() + s2 := entitiesToSlice(entities) testutil.Equals(t, s2, s) }) t.Run("Clone", func(t *testing.T) { diff --git a/eval.go b/eval.go index b38159c5..eae3b660 100644 --- a/eval.go +++ b/eval.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/cedar-policy/cedar-go/types" + "github.com/cedar-policy/cedar-go/x/exp/ast" "github.com/cedar-policy/cedar-go/x/exp/parser" ) @@ -15,15 +16,19 @@ var errAttributeAccess = fmt.Errorf("does not have the attribute") var errEntityNotExist = fmt.Errorf("does not exist") var errUnspecifiedEntity = fmt.Errorf("unspecified entity") -type evalContext struct { - Entities Entities - Principal, Action, Resource types.Value - Context types.Value -} +// type evalContext struct { +// Entities Entities +// Principal, Action, Resource types.Value +// Context types.Value +// } -type evaler interface { - Eval(*evalContext) (types.Value, error) -} +type evalContext = ast.EvalContext + +// type evaler interface { +// Eval(*evalContext) (types.Value, error) +// } + +type evaler = ast.Evaler func evalBool(n evaler, ctx *evalContext) (types.Boolean, error) { v, err := n.Eval(ctx) diff --git a/x/exp/ast/cedar.go b/x/exp/ast/cedar.go new file mode 100644 index 00000000..e5f94e5c --- /dev/null +++ b/x/exp/ast/cedar.go @@ -0,0 +1,59 @@ +package ast + +// TODO: this is a partial cut-and-paste from the main cedar package +// and will need completion / review + +import ( + "encoding/json" + "slices" + "strings" + + "github.com/cedar-policy/cedar-go/types" + "golang.org/x/exp/maps" +) + +// An Entities is a collection of all the Entities that are needed to evaluate +// authorization requests. The key is an EntityUID which uniquely identifies +// the Entity (it must be the same as the UID within the Entity itself.) +type Entities map[types.EntityUID]Entity + +// An Entity defines the parents and attributes for an EntityUID. +type Entity struct { + UID types.EntityUID `json:"uid"` + Parents []types.EntityUID `json:"parents,omitempty"` + Attributes types.Record `json:"attrs"` +} + +func (e Entities) MarshalJSON() ([]byte, error) { + s := e.toSlice() + return json.Marshal(s) +} + +func (e *Entities) UnmarshalJSON(b []byte) error { + var s []Entity + if err := json.Unmarshal(b, &s); err != nil { + return err + } + *e = entitiesFromSlice(s) + return nil +} + +func entitiesFromSlice(s []Entity) Entities { + var res = Entities{} + for _, e := range s { + res[e.UID] = e + } + return res +} + +func (e Entities) toSlice() []Entity { + s := maps.Values(e) + slices.SortFunc(s, func(a, b Entity) int { + return strings.Compare(a.UID.String(), b.UID.String()) + }) + return s +} + +func (e Entities) Clone() Entities { + return maps.Clone(e) +} diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index b2c03036..723ec8a5 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -2,7 +2,6 @@ package ast import ( "fmt" - "net/netip" "strings" "github.com/cedar-policy/cedar-go/types" @@ -252,13 +251,8 @@ func (p *parser) entityFirstPathPreread(firstPath string) (types.EntityUID, erro } } -func (p *parser) path() (types.String, error) { - var res types.String - t := p.advance() - if !t.isIdent() { - return res, p.errorf("expected ident") - } - res = types.String(t.Text) +func (p *parser) pathFirstPathPreread(firstPath string) (types.Path, error) { + res := types.Path(firstPath) for { if p.peek().Text != "::" { return res, nil @@ -267,13 +261,21 @@ func (p *parser) path() (types.String, error) { t := p.advance() switch { case t.isIdent(): - res = types.String(fmt.Sprintf("%v::%v", res, t.Text)) + res = types.Path(fmt.Sprintf("%v::%v", res, t.Text)) default: return res, p.errorf("unexpected token") } } } +func (p *parser) path() (types.Path, error) { + t := p.advance() + if !t.isIdent() { + return "", p.errorf("expected ident") + } + return p.pathFirstPathPreread(t.Text) +} + func (p *parser) action(policy *Policy) error { if err := p.exact("action"); err != nil { return err @@ -722,55 +724,45 @@ func (p *parser) primary() (Node, error) { } func (p *parser) entityOrExtFun(ident string) (Node, error) { - // Technically, according to the grammar, both entities and extension functions - // can have path prefixes and so parsing here is not trivial. In practice, there - // are only two extension functions: `ip()` and `decimal()`, neither of which - // have a path prefix. We'll just handle those two cases specially and treat - // everything else as an entity. - var res Node - switch ident { - case "ip", "decimal": - if err := p.exact("("); err != nil { - return res, err - } + var res types.EntityUID + var err error + res.Type = ident + for { t := p.advance() - if !t.isString() { - return res, p.errorf("expected string") - } - str, err := t.stringValue() - if err != nil { - return res, err - } - if err := p.exact(")"); err != nil { - return res, err - } - - if ident == "ip" { - prefix, err := netip.ParsePrefix(str) - if err != nil { - ipaddr, err := netip.ParseAddr(str) + switch t.Text { + case "::": + t := p.advance() + switch { + case t.isIdent(): + res.Type = fmt.Sprintf("%v::%v", res.Type, t.Text) + case t.isString(): + res.ID, err = t.stringValue() if err != nil { return Node{}, err } - prefix = netip.PrefixFrom(ipaddr, 32) + return EntityUID(res), nil + default: + return Node{}, p.errorf("unexpected token") } - res = IPAddr(types.IPAddr(prefix)) - } else { - dec, err := types.ParseDecimal(str) + case "(": + args, err := p.expressions(")") if err != nil { - return res, err + return Node{}, err } - res = Decimal(dec) - } - default: - entity, err := p.entityFirstPathPreread(ident) - if err != nil { - return res, err + p.advance() + + i, ok := extMap[types.String(res.Type)] + if !ok { + return Node{}, p.errorf("`%v` is not a function", res.Type) + } + if i.IsMethod { + return Node{}, p.errorf("`%v` is a method, not a function", res.Type) + } + return ExtensionCall(types.String(res.Type), args...), nil + default: + return Node{}, p.errorf("unexpected token") } - res = EntityUID(entity) } - - return res, nil } func (p *parser) expressions(endOfListMarker string) ([]Node, error) { @@ -870,6 +862,13 @@ func (p *parser) access(lhs Node) (Node, bool, error) { case "containsAny": knownMethod = Node.ContainsAny default: + i, ok := extMap[types.String(methodName)] + if !ok { + return Node{}, false, p.errorf("not a valid method name: `%v`", methodName) + } + if !i.IsMethod { + return Node{}, false, p.errorf("`%v` is a function, not a method", methodName) + } return newMethodCall(lhs, types.String(methodName), exprs...), true, nil } diff --git a/x/exp/ast/cedar_unmarshal_test.go b/x/exp/ast/cedar_unmarshal_test.go index 32fbf22a..6b70b9d8 100644 --- a/x/exp/ast/cedar_unmarshal_test.go +++ b/x/exp/ast/cedar_unmarshal_test.go @@ -2,7 +2,6 @@ package ast_test import ( "bytes" - "net/netip" "testing" "github.com/cedar-policy/cedar-go/testutil" @@ -343,8 +342,8 @@ when { if true then true else false };`, `permit ( principal, action, resource ) when { ip("1.2.3.4") == ip("2.3.4.5") };`, ast.Permit().When( - ast.IPAddr(types.IPAddr(netip.MustParsePrefix("1.2.3.4/32"))).Equals( - ast.IPAddr(types.IPAddr(netip.MustParsePrefix("2.3.4.5/32"))), + ast.ExtensionCall("ip", ast.String("1.2.3.4")).Equals( + ast.ExtensionCall("ip", ast.String("2.3.4.5")), ), ), }, @@ -353,7 +352,7 @@ when { ip("1.2.3.4") == ip("2.3.4.5") };`, `permit ( principal, action, resource ) when { decimal("12.34") == decimal("23.45") };`, ast.Permit().When( - ast.Decimal(types.Decimal(123400)).Equals(ast.Decimal(types.Decimal(234500))), + ast.ExtensionCall("decimal", ast.String("12.34")).Equals(ast.ExtensionCall("decimal", ast.String("23.45"))), ), }, { diff --git a/x/exp/ast/eval_compile.go b/x/exp/ast/eval_compile.go new file mode 100644 index 00000000..10f48969 --- /dev/null +++ b/x/exp/ast/eval_compile.go @@ -0,0 +1,43 @@ +package ast + +type CompiledPolicySet map[string]CompiledPolicy + +type CompiledPolicy struct { + PolicySetEntry + eval Evaler +} + +func (p PolicySetEntry) TmpGetAnnotations() map[string]string { + res := make(map[string]string, len(p.Policy.annotations)) + for _, e := range p.Policy.annotations { + res[string(e.Key)] = string(e.Value) + } + return res +} +func (p PolicySetEntry) TmpGetEffect() bool { + return bool(p.Policy.effect) +} + +func Compile(p Policy) Evaler { + node := policyToNode(p).v + return toEval(node) +} + +func policyToNode(p Policy) Node { + nodes := make([]Node, 3+len(p.conditions)) + nodes[0] = p.principal.toNode() + nodes[1] = p.action.toNode() + nodes[2] = p.resource.toNode() + for i, c := range p.conditions { + if c.Condition == conditionUnless { + nodes[i+3] = Not(newNode(c.Body)) + continue + } + nodes[i+3] = newNode(c.Body) + } + res := nodes[len(nodes)-1] + for i := len(nodes) - 2; i >= 0; i-- { + res = nodes[i].And(res) + } + return res +} diff --git a/x/exp/ast/eval_convert.go b/x/exp/ast/eval_convert.go new file mode 100644 index 00000000..cf6dba26 --- /dev/null +++ b/x/exp/ast/eval_convert.go @@ -0,0 +1,124 @@ +package ast + +import ( + "fmt" +) + +func toEval(n node) Evaler { + switch v := n.(type) { + case nodeTypeAccess: + return newAttributeAccessEval(toEval(v.Arg), string(v.Value)) + case nodeTypeHas: + return newHasEval(toEval(v.Arg), string(v.Value)) + case nodeTypeLike: + return newLikeEval(toEval(v.Arg), v.Value) + case nodeTypeIf: + return newIfThenElseEval(toEval(v.If), toEval(v.Then), toEval(v.Else)) + case nodeTypeIs: + return newIsEval(toEval(v.Left), newLiteralEval(v.EntityType)) + case nodeTypeIsIn: + obj := toEval(v.Left) + lhs := newIsEval(obj, newLiteralEval(v.EntityType)) + rhs := newInEval(obj, toEval(v.Entity)) + return newAndEval(lhs, rhs) + case nodeTypeExtensionCall: + i, ok := extMap[v.Name] + if !ok { + return newErrorEval(fmt.Errorf("%w: %s", errUnknownMethod, v.Name)) + } + if i.Args != len(v.Args) { + return newErrorEval(fmt.Errorf("%w: %s takes 1 parameter", errArity, v.Name)) + } + switch { + case v.Name == "ip": + return newIPLiteralEval(toEval(v.Args[0])) + case v.Name == "decimal": + return newDecimalLiteralEval(toEval(v.Args[0])) + + case v.Name == "lessThan": + return newDecimalLessThanEval(toEval(v.Args[0]), toEval(v.Args[1])) + case v.Name == "lessThanOrEqual": + return newDecimalLessThanOrEqualEval(toEval(v.Args[0]), toEval(v.Args[1])) + case v.Name == "greaterThan": + return newDecimalGreaterThanEval(toEval(v.Args[0]), toEval(v.Args[1])) + case v.Name == "greaterThanOrEqual": + return newDecimalGreaterThanOrEqualEval(toEval(v.Args[0]), toEval(v.Args[1])) + + case v.Name == "isIpv4": + return newIPTestEval(toEval(v.Args[0]), ipTestIPv4) + case v.Name == "isIpv6": + return newIPTestEval(toEval(v.Args[0]), ipTestIPv6) + case v.Name == "isLoopback": + return newIPTestEval(toEval(v.Args[0]), ipTestLoopback) + case v.Name == "isMulticast": + return newIPTestEval(toEval(v.Args[0]), ipTestMulticast) + case v.Name == "isInRange": + return newIPIsInRangeEval(toEval(v.Args[0]), toEval(v.Args[1])) + default: + panic(fmt.Errorf("unknown extension: %v", v.Name)) + } + case nodeValue: + return newLiteralEval(v.Value) + case nodeTypeRecord: + m := make(map[string]Evaler, len(v.Elements)) + for _, e := range v.Elements { + m[string(e.Key)] = toEval(e.Value) + } + return newRecordLiteralEval(m) + case nodeTypeSet: + s := make([]Evaler, len(v.Elements)) + for i, e := range v.Elements { + s[i] = toEval(e) + } + return newSetLiteralEval(s) + case nodeTypeNegate: + return newNegateEval(toEval(v.Arg)) + case nodeTypeNot: + return newNotEval(toEval(v.Arg)) + case nodeTypeVariable: + switch v.Name { + case "principal": + return newVariableEval(variableNamePrincipal) + case "action": + return newVariableEval(variableNameAction) + case "resource": + return newVariableEval(variableNameResource) + case "context": + return newVariableEval(variableNameContext) + default: + panic(fmt.Errorf("unknown variable: %v", v.Name)) + } + case nodeTypeIn: + return newInEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeAnd: + return newAndEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeEquals: + return newEqualEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeGreaterThan: + return newLongGreaterThanEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeGreaterThanOrEqual: + return newLongGreaterThanOrEqualEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeLessThan: + return newLongLessThanEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeLessThanOrEqual: + return newLongLessThanOrEqualEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeSub: + return newSubtractEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeAdd: + return newAddEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeContains: + return newContainsEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeContainsAll: + return newContainsAllEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeContainsAny: + return newContainsAnyEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeMult: + return newMultiplyEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeNotEquals: + return newNotEqualEval(toEval(v.Left), toEval(v.Right)) + case nodeTypeOr: + return newOrNode(toEval(v.Left), toEval(v.Right)) + default: + panic(fmt.Sprintf("unknown node type %T", v)) + } +} diff --git a/x/exp/ast/eval_impl.go b/x/exp/ast/eval_impl.go new file mode 100644 index 00000000..e2b6ab74 --- /dev/null +++ b/x/exp/ast/eval_impl.go @@ -0,0 +1,1086 @@ +package ast + +import ( + "fmt" + + "github.com/cedar-policy/cedar-go/types" +) + +var errOverflow = fmt.Errorf("integer overflow") +var errUnknownMethod = fmt.Errorf("unknown method") +var errUnknownExtensionFunction = fmt.Errorf("function does not exist") +var errArity = fmt.Errorf("wrong number of arguments provided to extension function") +var errAttributeAccess = fmt.Errorf("does not have the attribute") +var errEntityNotExist = fmt.Errorf("does not exist") +var errUnspecifiedEntity = fmt.Errorf("unspecified entity") + +// TODO: make private again +type EvalContext struct { + Entities Entities + Principal, Action, Resource types.Value + Context types.Value +} + +type Evaler interface { + Eval(*EvalContext) (types.Value, error) +} + +func evalBool(n Evaler, ctx *EvalContext) (types.Boolean, error) { + v, err := n.Eval(ctx) + if err != nil { + return false, err + } + b, err := types.ValueToBool(v) + if err != nil { + return false, err + } + return b, nil +} + +func evalLong(n Evaler, ctx *EvalContext) (types.Long, error) { + v, err := n.Eval(ctx) + if err != nil { + return 0, err + } + l, err := types.ValueToLong(v) + if err != nil { + return 0, err + } + return l, nil +} + +func evalString(n Evaler, ctx *EvalContext) (types.String, error) { + v, err := n.Eval(ctx) + if err != nil { + return "", err + } + s, err := types.ValueToString(v) + if err != nil { + return "", err + } + return s, nil +} + +func evalSet(n Evaler, ctx *EvalContext) (types.Set, error) { + v, err := n.Eval(ctx) + if err != nil { + return nil, err + } + s, err := types.ValueToSet(v) + if err != nil { + return nil, err + } + return s, nil +} + +func evalEntity(n Evaler, ctx *EvalContext) (types.EntityUID, error) { + v, err := n.Eval(ctx) + if err != nil { + return types.EntityUID{}, err + } + e, err := types.ValueToEntity(v) + if err != nil { + return types.EntityUID{}, err + } + return e, nil +} + +func evalPath(n Evaler, ctx *EvalContext) (types.Path, error) { + v, err := n.Eval(ctx) + if err != nil { + return "", err + } + e, err := types.ValueToPath(v) + if err != nil { + return "", err + } + return e, nil +} + +func evalDecimal(n Evaler, ctx *EvalContext) (types.Decimal, error) { + v, err := n.Eval(ctx) + if err != nil { + return types.Decimal(0), err + } + d, err := types.ValueToDecimal(v) + if err != nil { + return types.Decimal(0), err + } + return d, nil +} + +func evalIP(n Evaler, ctx *EvalContext) (types.IPAddr, error) { + v, err := n.Eval(ctx) + if err != nil { + return types.IPAddr{}, err + } + i, err := types.ValueToIP(v) + if err != nil { + return types.IPAddr{}, err + } + return i, nil +} + +// errorEval +type errorEval struct { + err error +} + +func newErrorEval(err error) *errorEval { + return &errorEval{ + err: err, + } +} + +func (n *errorEval) Eval(_ *EvalContext) (types.Value, error) { + return types.ZeroValue(), n.err +} + +// literalEval +type literalEval struct { + value types.Value +} + +func newLiteralEval(value types.Value) *literalEval { + return &literalEval{value: value} +} + +func (n *literalEval) Eval(_ *EvalContext) (types.Value, error) { + return n.value, nil +} + +// orEval +type orEval struct { + lhs Evaler + rhs Evaler +} + +func newOrNode(lhs Evaler, rhs Evaler) *orEval { + return &orEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *orEval) Eval(ctx *EvalContext) (types.Value, error) { + v, err := n.lhs.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + b, err := types.ValueToBool(v) + if err != nil { + return types.ZeroValue(), err + } + if b { + return v, nil + } + v, err = n.rhs.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + _, err = types.ValueToBool(v) + if err != nil { + return types.ZeroValue(), err + } + return v, nil +} + +// andEval +type andEval struct { + lhs Evaler + rhs Evaler +} + +func newAndEval(lhs Evaler, rhs Evaler) *andEval { + return &andEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *andEval) Eval(ctx *EvalContext) (types.Value, error) { + v, err := n.lhs.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + b, err := types.ValueToBool(v) + if err != nil { + return types.ZeroValue(), err + } + if !b { + return v, nil + } + v, err = n.rhs.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + _, err = types.ValueToBool(v) + if err != nil { + return types.ZeroValue(), err + } + return v, nil +} + +// notEval +type notEval struct { + inner Evaler +} + +func newNotEval(inner Evaler) *notEval { + return ¬Eval{ + inner: inner, + } +} + +func (n *notEval) Eval(ctx *EvalContext) (types.Value, error) { + v, err := n.inner.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + b, err := types.ValueToBool(v) + if err != nil { + return types.ZeroValue(), err + } + return !b, nil +} + +// Overflow +// The Go spec specifies that overflow results in defined and deterministic +// behavior (https://go.dev/ref/spec#Integer_overflow), so we can go ahead and +// do the operations and then check for overflow ex post facto. + +func checkedAddI64(lhs, rhs types.Long) (types.Long, bool) { + result := lhs + rhs + if (result > lhs) != (rhs > 0) { + return result, false + } + return result, true +} + +func checkedSubI64(lhs, rhs types.Long) (types.Long, bool) { + result := lhs - rhs + if (result > lhs) != (rhs < 0) { + return result, false + } + return result, true +} + +func checkedMulI64(lhs, rhs types.Long) (types.Long, bool) { + if lhs == 0 || rhs == 0 { + return 0, true + } + result := lhs * rhs + if (result < 0) != ((lhs < 0) != (rhs < 0)) { + // If the result doesn't have the correct sign, then we overflowed. + return result, false + } + if result/lhs != rhs { + // If division doesn't yield the original value, then we overflowed. + return result, false + } + return result, true +} + +func checkedNegI64(a types.Long) (types.Long, bool) { + if a == -9_223_372_036_854_775_808 { + return 0, false + } + return -a, true +} + +// addEval +type addEval struct { + lhs Evaler + rhs Evaler +} + +func newAddEval(lhs Evaler, rhs Evaler) *addEval { + return &addEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *addEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalLong(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalLong(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + res, ok := checkedAddI64(lhs, rhs) + if !ok { + return types.ZeroValue(), fmt.Errorf("%w while attempting to add `%d` with `%d`", errOverflow, lhs, rhs) + } + return res, nil +} + +// subtractEval +type subtractEval struct { + lhs Evaler + rhs Evaler +} + +func newSubtractEval(lhs Evaler, rhs Evaler) *subtractEval { + return &subtractEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *subtractEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalLong(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalLong(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + res, ok := checkedSubI64(lhs, rhs) + if !ok { + return types.ZeroValue(), fmt.Errorf("%w while attempting to subtract `%d` from `%d`", errOverflow, rhs, lhs) + } + return res, nil +} + +// multiplyEval +type multiplyEval struct { + lhs Evaler + rhs Evaler +} + +func newMultiplyEval(lhs Evaler, rhs Evaler) *multiplyEval { + return &multiplyEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *multiplyEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalLong(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalLong(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + res, ok := checkedMulI64(lhs, rhs) + if !ok { + return types.ZeroValue(), fmt.Errorf("%w while attempting to multiply `%d` by `%d`", errOverflow, lhs, rhs) + } + return res, nil +} + +// negateEval +type negateEval struct { + inner Evaler +} + +func newNegateEval(inner Evaler) *negateEval { + return &negateEval{ + inner: inner, + } +} + +func (n *negateEval) Eval(ctx *EvalContext) (types.Value, error) { + inner, err := evalLong(n.inner, ctx) + if err != nil { + return types.ZeroValue(), err + } + res, ok := checkedNegI64(inner) + if !ok { + return types.ZeroValue(), fmt.Errorf("%w while attempting to negate `%d`", errOverflow, inner) + } + return res, nil +} + +// longLessThanEval +type longLessThanEval struct { + lhs Evaler + rhs Evaler +} + +func newLongLessThanEval(lhs Evaler, rhs Evaler) *longLessThanEval { + return &longLessThanEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *longLessThanEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalLong(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalLong(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(lhs < rhs), nil +} + +// longLessThanOrEqualEval +type longLessThanOrEqualEval struct { + lhs Evaler + rhs Evaler +} + +func newLongLessThanOrEqualEval(lhs Evaler, rhs Evaler) *longLessThanOrEqualEval { + return &longLessThanOrEqualEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *longLessThanOrEqualEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalLong(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalLong(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(lhs <= rhs), nil +} + +// longGreaterThanEval +type longGreaterThanEval struct { + lhs Evaler + rhs Evaler +} + +func newLongGreaterThanEval(lhs Evaler, rhs Evaler) *longGreaterThanEval { + return &longGreaterThanEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *longGreaterThanEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalLong(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalLong(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(lhs > rhs), nil +} + +// longGreaterThanOrEqualEval +type longGreaterThanOrEqualEval struct { + lhs Evaler + rhs Evaler +} + +func newLongGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) *longGreaterThanOrEqualEval { + return &longGreaterThanOrEqualEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *longGreaterThanOrEqualEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalLong(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalLong(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(lhs >= rhs), nil +} + +// decimalLessThanEval +type decimalLessThanEval struct { + lhs Evaler + rhs Evaler +} + +func newDecimalLessThanEval(lhs Evaler, rhs Evaler) *decimalLessThanEval { + return &decimalLessThanEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *decimalLessThanEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalDecimal(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalDecimal(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(lhs < rhs), nil +} + +// decimalLessThanOrEqualEval +type decimalLessThanOrEqualEval struct { + lhs Evaler + rhs Evaler +} + +func newDecimalLessThanOrEqualEval(lhs Evaler, rhs Evaler) *decimalLessThanOrEqualEval { + return &decimalLessThanOrEqualEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *decimalLessThanOrEqualEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalDecimal(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalDecimal(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(lhs <= rhs), nil +} + +// decimalGreaterThanEval +type decimalGreaterThanEval struct { + lhs Evaler + rhs Evaler +} + +func newDecimalGreaterThanEval(lhs Evaler, rhs Evaler) *decimalGreaterThanEval { + return &decimalGreaterThanEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *decimalGreaterThanEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalDecimal(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalDecimal(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(lhs > rhs), nil +} + +// decimalGreaterThanOrEqualEval +type decimalGreaterThanOrEqualEval struct { + lhs Evaler + rhs Evaler +} + +func newDecimalGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) *decimalGreaterThanOrEqualEval { + return &decimalGreaterThanOrEqualEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *decimalGreaterThanOrEqualEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalDecimal(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalDecimal(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(lhs >= rhs), nil +} + +// ifThenElseEval +type ifThenElseEval struct { + if_ Evaler + then Evaler + else_ Evaler +} + +func newIfThenElseEval(if_, then, else_ Evaler) *ifThenElseEval { + return &ifThenElseEval{ + if_: if_, + then: then, + else_: else_, + } +} + +func (n *ifThenElseEval) Eval(ctx *EvalContext) (types.Value, error) { + cond, err := evalBool(n.if_, ctx) + if err != nil { + return types.ZeroValue(), err + } + if cond { + return n.then.Eval(ctx) + } + return n.else_.Eval(ctx) +} + +// notEqualNode +type equalEval struct { + lhs, rhs Evaler +} + +func newEqualEval(lhs, rhs Evaler) *equalEval { + return &equalEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *equalEval) Eval(ctx *EvalContext) (types.Value, error) { + lv, err := n.lhs.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + rv, err := n.rhs.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(lv.Equal(rv)), nil +} + +// notEqualEval +type notEqualEval struct { + lhs, rhs Evaler +} + +func newNotEqualEval(lhs, rhs Evaler) *notEqualEval { + return ¬EqualEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *notEqualEval) Eval(ctx *EvalContext) (types.Value, error) { + lv, err := n.lhs.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + rv, err := n.rhs.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(!lv.Equal(rv)), nil +} + +// setLiteralEval +type setLiteralEval struct { + elements []Evaler +} + +func newSetLiteralEval(elements []Evaler) *setLiteralEval { + return &setLiteralEval{elements: elements} +} + +func (n *setLiteralEval) Eval(ctx *EvalContext) (types.Value, error) { + var vals types.Set + for _, e := range n.elements { + v, err := e.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + vals = append(vals, v) + } + return vals, nil +} + +// containsEval +type containsEval struct { + lhs, rhs Evaler +} + +func newContainsEval(lhs, rhs Evaler) *containsEval { + return &containsEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *containsEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalSet(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := n.rhs.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(lhs.Contains(rhs)), nil +} + +// containsAllEval +type containsAllEval struct { + lhs, rhs Evaler +} + +func newContainsAllEval(lhs, rhs Evaler) *containsAllEval { + return &containsAllEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *containsAllEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalSet(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalSet(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + result := true + for _, e := range rhs { + if !lhs.Contains(e) { + result = false + break + } + } + return types.Boolean(result), nil +} + +// containsAnyEval +type containsAnyEval struct { + lhs, rhs Evaler +} + +func newContainsAnyEval(lhs, rhs Evaler) *containsAnyEval { + return &containsAnyEval{ + lhs: lhs, + rhs: rhs, + } +} + +func (n *containsAnyEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalSet(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalSet(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + result := false + for _, e := range rhs { + if lhs.Contains(e) { + result = true + break + } + } + return types.Boolean(result), nil +} + +// recordLiteralEval +type recordLiteralEval struct { + elements map[string]Evaler +} + +func newRecordLiteralEval(elements map[string]Evaler) *recordLiteralEval { + return &recordLiteralEval{elements: elements} +} + +func (n *recordLiteralEval) Eval(ctx *EvalContext) (types.Value, error) { + vals := types.Record{} + for k, en := range n.elements { + v, err := en.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + vals[k] = v + } + return vals, nil +} + +// attributeAccessEval +type attributeAccessEval struct { + object Evaler + attribute string +} + +func newAttributeAccessEval(record Evaler, attribute string) *attributeAccessEval { + return &attributeAccessEval{object: record, attribute: attribute} +} + +func (n *attributeAccessEval) Eval(ctx *EvalContext) (types.Value, error) { + v, err := n.object.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + var record types.Record + key := "record" + switch vv := v.(type) { + case types.EntityUID: + key = "`" + vv.String() + "`" + var unspecified types.EntityUID + if vv == unspecified { + return types.ZeroValue(), fmt.Errorf("cannot access attribute `%s` of %w", n.attribute, errUnspecifiedEntity) + } + rec, ok := ctx.Entities[vv] + if !ok { + return types.ZeroValue(), fmt.Errorf("entity `%v` %w", vv.String(), errEntityNotExist) + } else { + record = rec.Attributes + } + case types.Record: + record = vv + default: + return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", types.ErrType, v.TypeName()) + } + val, ok := record[n.attribute] + if !ok { + return types.ZeroValue(), fmt.Errorf("%s %w `%s`", key, errAttributeAccess, n.attribute) + } + return val, nil +} + +// hasEval +type hasEval struct { + object Evaler + attribute string +} + +func newHasEval(record Evaler, attribute string) *hasEval { + return &hasEval{object: record, attribute: attribute} +} + +func (n *hasEval) Eval(ctx *EvalContext) (types.Value, error) { + v, err := n.object.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + var record types.Record + switch vv := v.(type) { + case types.EntityUID: + rec, ok := ctx.Entities[vv] + if !ok { + record = types.Record{} + } else { + record = rec.Attributes + } + case types.Record: + record = vv + default: + return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", types.ErrType, v.TypeName()) + } + _, ok := record[n.attribute] + return types.Boolean(ok), nil +} + +// likeEval +type likeEval struct { + lhs Evaler + pattern Pattern +} + +func newLikeEval(lhs Evaler, pattern Pattern) *likeEval { + return &likeEval{lhs: lhs, pattern: pattern} +} + +func (l *likeEval) Eval(ctx *EvalContext) (types.Value, error) { + v, err := evalString(l.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(match(l.pattern, string(v))), nil +} + +type variableName func(ctx *EvalContext) types.Value + +func variableNamePrincipal(ctx *EvalContext) types.Value { return ctx.Principal } +func variableNameAction(ctx *EvalContext) types.Value { return ctx.Action } +func variableNameResource(ctx *EvalContext) types.Value { return ctx.Resource } +func variableNameContext(ctx *EvalContext) types.Value { return ctx.Context } + +// variableEval +type variableEval struct { + variableName variableName +} + +func newVariableEval(variableName variableName) *variableEval { + return &variableEval{variableName: variableName} +} + +func (n *variableEval) Eval(ctx *EvalContext) (types.Value, error) { + return n.variableName(ctx), nil +} + +// inEval +type inEval struct { + lhs, rhs Evaler +} + +func newInEval(lhs, rhs Evaler) *inEval { + return &inEval{lhs: lhs, rhs: rhs} +} + +func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entities Entities) bool { + checked := map[types.EntityUID]struct{}{} + toCheck := []types.EntityUID{entity} + for len(toCheck) > 0 { + var candidate types.EntityUID + candidate, toCheck = toCheck[len(toCheck)-1], toCheck[:len(toCheck)-1] + if _, ok := checked[candidate]; ok { + continue + } + if _, ok := query[candidate]; ok { + return true + } + toCheck = append(toCheck, entities[candidate].Parents...) + checked[candidate] = struct{}{} + } + return false +} + +func (n *inEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalEntity(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + + rhs, err := n.rhs.Eval(ctx) + if err != nil { + return types.ZeroValue(), err + } + + query := map[types.EntityUID]struct{}{} + switch rhsv := rhs.(type) { + case types.EntityUID: + query[rhsv] = struct{}{} + case types.Set: + for _, rhv := range rhsv { + e, err := types.ValueToEntity(rhv) + if err != nil { + return types.ZeroValue(), err + } + query[e] = struct{}{} + } + default: + return types.ZeroValue(), fmt.Errorf( + "%w: expected one of [set, (entity of type `any_entity_type`)], got %v", types.ErrType, rhs.TypeName()) + } + return types.Boolean(entityIn(lhs, query, ctx.Entities)), nil +} + +// isEval +type isEval struct { + lhs, rhs Evaler +} + +func newIsEval(lhs, rhs Evaler) *isEval { + return &isEval{lhs: lhs, rhs: rhs} +} + +func (n *isEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalEntity(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + + rhs, err := evalPath(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + + return types.Boolean(types.Path(lhs.Type) == rhs), nil +} + +// decimalLiteralEval +type decimalLiteralEval struct { + literal Evaler +} + +func newDecimalLiteralEval(literal Evaler) *decimalLiteralEval { + return &decimalLiteralEval{literal: literal} +} + +func (n *decimalLiteralEval) Eval(ctx *EvalContext) (types.Value, error) { + literal, err := evalString(n.literal, ctx) + if err != nil { + return types.ZeroValue(), err + } + + d, err := types.ParseDecimal(string(literal)) + if err != nil { + return types.ZeroValue(), err + } + + return d, nil +} + +type ipLiteralEval struct { + literal Evaler +} + +func newIPLiteralEval(literal Evaler) *ipLiteralEval { + return &ipLiteralEval{literal: literal} +} + +func (n *ipLiteralEval) Eval(ctx *EvalContext) (types.Value, error) { + literal, err := evalString(n.literal, ctx) + if err != nil { + return types.ZeroValue(), err + } + + i, err := types.ParseIPAddr(string(literal)) + if err != nil { + return types.ZeroValue(), err + } + + return i, nil +} + +type ipTestType func(v types.IPAddr) bool + +func ipTestIPv4(v types.IPAddr) bool { return v.IsIPv4() } +func ipTestIPv6(v types.IPAddr) bool { return v.IsIPv6() } +func ipTestLoopback(v types.IPAddr) bool { return v.IsLoopback() } +func ipTestMulticast(v types.IPAddr) bool { return v.IsMulticast() } + +// ipTestEval +type ipTestEval struct { + object Evaler + test ipTestType +} + +func newIPTestEval(object Evaler, test ipTestType) *ipTestEval { + return &ipTestEval{object: object, test: test} +} + +func (n *ipTestEval) Eval(ctx *EvalContext) (types.Value, error) { + i, err := evalIP(n.object, ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(n.test(i)), nil +} + +// ipIsInRangeEval + +type ipIsInRangeEval struct { + lhs, rhs Evaler +} + +func newIPIsInRangeEval(lhs, rhs Evaler) *ipIsInRangeEval { + return &ipIsInRangeEval{lhs: lhs, rhs: rhs} +} + +func (n *ipIsInRangeEval) Eval(ctx *EvalContext) (types.Value, error) { + lhs, err := evalIP(n.lhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + rhs, err := evalIP(n.rhs, ctx) + if err != nil { + return types.ZeroValue(), err + } + return types.Boolean(rhs.Contains(lhs)), nil +} diff --git a/x/exp/ast/eval_match.go b/x/exp/ast/eval_match.go new file mode 100644 index 00000000..1f84082c --- /dev/null +++ b/x/exp/ast/eval_match.go @@ -0,0 +1,66 @@ +package ast + +// TODO: move this into the types package + +// ported from Go's stdlib and reduced to our scope. +// https://golang.org/src/path/filepath/match.go?s=1226:1284#L34 + +// Match reports whether name matches the shell file name pattern. +// The pattern syntax is: +// +// pattern: +// { term } +// term: +// '*' matches any sequence of non-Separator characters +// c matches character c (c != '*') +func match(p Pattern, name string) (matched bool) { +Pattern: + for i, comp := range p.Comps { + lastChunk := i == len(p.Comps)-1 + if comp.Star && comp.Chunk == "" { + return true + } + // Look for Match at current position. + t, ok := matchChunk(comp.Chunk, name) + // if we're the last chunk, make sure we've exhausted the name + // otherwise we'll give a false result even if we could still Match + // using the star + if ok && (len(t) == 0 || !lastChunk) { + name = t + continue + } + if comp.Star { + // Look for Match skipping i+1 bytes. + for i := 0; i < len(name); i++ { + t, ok := matchChunk(comp.Chunk, name[i+1:]) + if ok { + // if we're the last chunk, make sure we exhausted the name + if lastChunk && len(t) > 0 { + continue + } + name = t + continue Pattern + } + } + } + return false + } + return len(name) == 0 +} + +// matchChunk checks whether chunk matches the beginning of s. +// If so, it returns the remainder of s (after the Match). +// Chunk is all single-character operators: literals, char classes, and ?. +func matchChunk(chunk, s string) (rest string, ok bool) { + for len(chunk) > 0 { + if len(s) == 0 { + return + } + if chunk[0] != s[0] { + return + } + s = s[1:] + chunk = chunk[1:] + } + return s, true +} diff --git a/x/exp/ast/extensions.go b/x/exp/ast/extensions.go new file mode 100644 index 00000000..340e1206 --- /dev/null +++ b/x/exp/ast/extensions.go @@ -0,0 +1,24 @@ +package ast + +import "github.com/cedar-policy/cedar-go/types" + +type extInfo struct { + Args int + IsMethod bool +} + +var extMap = map[types.String]extInfo{ + "ip": {Args: 1, IsMethod: false}, + "decimal": {Args: 1, IsMethod: false}, + + "lessThan": {Args: 2, IsMethod: true}, + "lessThanOrEqual": {Args: 2, IsMethod: true}, + "greaterThan": {Args: 2, IsMethod: true}, + "greaterThanOrEqual": {Args: 2, IsMethod: true}, + + "isIpv4": {Args: 1, IsMethod: true}, + "isIpv6": {Args: 1, IsMethod: true}, + "isLoopback": {Args: 1, IsMethod: true}, + "isMulticast": {Args: 1, IsMethod: true}, + "isInRange": {Args: 2, IsMethod: true}, +} diff --git a/x/exp/ast/json_test.go b/x/exp/ast/json_test.go index 47bc68a7..bd194f09 100644 --- a/x/exp/ast/json_test.go +++ b/x/exp/ast/json_test.go @@ -118,13 +118,13 @@ func TestUnmarshalJSON(t *testing.T) { { "principalIs", `{"effect":"permit","principal":{"op":"is","entity_type":"T"},"action":{"op":"All"},"resource":{"op":"All"}}`, - ast.Permit().PrincipalIs(types.String("T")), + ast.Permit().PrincipalIs(types.Path("T")), testutil.OK, }, { "principalIsIn", `{"effect":"permit","principal":{"op":"is","entity_type":"T","in":{"entity":{"type":"P","id":"42"}}},"action":{"op":"All"},"resource":{"op":"All"}}`, - ast.Permit().PrincipalIsIn(types.String("T"), types.NewEntityUID("P", "42")), + ast.Permit().PrincipalIsIn(types.Path("T"), types.NewEntityUID("P", "42")), testutil.OK, }, { @@ -160,13 +160,13 @@ func TestUnmarshalJSON(t *testing.T) { { "resourceIs", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is","entity_type":"T"}}`, - ast.Permit().ResourceIs(types.String("T")), + ast.Permit().ResourceIs(types.Path("T")), testutil.OK, }, { "resourceIsIn", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is","entity_type":"T","in":{"entity":{"type":"P","id":"42"}}}}`, - ast.Permit().ResourceIsIn(types.String("T"), types.NewEntityUID("P", "42")), + ast.Permit().ResourceIsIn(types.Path("T"), types.NewEntityUID("P", "42")), testutil.OK, }, { diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go index e4a9b76a..0dcdbf20 100644 --- a/x/exp/ast/json_unmarshal.go +++ b/x/exp/ast/json_unmarshal.go @@ -26,9 +26,9 @@ func (s *scopeJSON) ToNode(variable scope) (isScopeNode, error) { return variable.InSet(s.Entities), nil case "is": if s.In == nil { - return variable.Is(types.String(s.EntityType)), nil + return variable.Is(types.Path(s.EntityType)), nil } - return variable.IsIn(types.String(s.EntityType), s.In.Entity), nil + return variable.IsIn(types.Path(s.EntityType), s.In.Entity), nil } return nil, fmt.Errorf("unknown op: %v", s.Op) } @@ -84,9 +84,9 @@ func (j isJSON) ToNode() (Node, error) { if err != nil { return Node{}, fmt.Errorf("error in entity: %w", err) } - return left.IsIn(types.String(j.EntityType), right), nil + return left.IsIn(types.Path(j.EntityType), right), nil } - return left.Is(types.String(j.EntityType)), nil + return left.Is(types.Path(j.EntityType)), nil } func (j ifThenElseJSON) ToNode() (Node, error) { if_, err := j.If.ToNode() diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index e2d67dfa..37d86b5f 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -15,16 +15,18 @@ func newNode(v node) Node { } type strOpNode struct { - node Arg node Value types.String } +func (n strOpNode) isNode() {} + type binaryNode struct { - node Left, Right node } +func (n binaryNode) isNode() {} + type nodePrecedenceLevel uint8 const ( @@ -40,7 +42,6 @@ const ( ) type nodeTypeIf struct { - node If, Then, Else node } @@ -48,6 +49,8 @@ func (n nodeTypeIf) precedenceLevel() nodePrecedenceLevel { return ifPrecedence } +func (n nodeTypeIf) isNode() {} + type nodeTypeOr struct{ binaryNode } func (n nodeTypeOr) precedenceLevel() nodePrecedenceLevel { @@ -103,7 +106,6 @@ type nodeTypeHas struct { } type nodeTypeLike struct { - node Arg node Value Pattern } @@ -111,16 +113,17 @@ type nodeTypeLike struct { func (n nodeTypeLike) precedenceLevel() nodePrecedenceLevel { return relationPrecedence } +func (n nodeTypeLike) isNode() {} type nodeTypeIs struct { - node Left node - EntityType types.String // TODO: review type + EntityType types.Path } func (n nodeTypeIs) precedenceLevel() nodePrecedenceLevel { return relationPrecedence } +func (n nodeTypeIs) isNode() {} type nodeTypeIsIn struct { nodeTypeIs @@ -154,7 +157,6 @@ func (n nodeTypeMult) precedenceLevel() nodePrecedenceLevel { } type unaryNode struct { - node Arg node } @@ -162,6 +164,8 @@ func (n unaryNode) precedenceLevel() nodePrecedenceLevel { return unaryPrecedence } +func (n unaryNode) isNode() {} + type nodeTypeNegate struct{ unaryNode } type nodeTypeNot struct{ unaryNode } @@ -172,7 +176,6 @@ func (n nodeTypeAccess) precedenceLevel() nodePrecedenceLevel { } type nodeTypeExtensionCall struct { - node Name types.String // TODO: review type Args []node } @@ -180,6 +183,7 @@ type nodeTypeExtensionCall struct { func (n nodeTypeExtensionCall) precedenceLevel() nodePrecedenceLevel { return accessPrecedence } +func (n nodeTypeExtensionCall) isNode() {} func stripNodes(args []Node) []node { res := make([]node, len(args)) @@ -238,6 +242,8 @@ type nodeValue struct { Value types.Value } +func (n nodeValue) isNode() {} + type recordElement struct { Key types.String Value node @@ -248,6 +254,8 @@ type nodeTypeRecord struct { Elements []recordElement } +func (n nodeTypeRecord) isNode() {} + type nodeTypeSet struct { primaryNode Elements []node diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index 95ed737c..815f9963 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -109,11 +109,11 @@ func (lhs Node) In(rhs Node) Node { return newNode(nodeTypeIn{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } -func (lhs Node) Is(entityType types.String) Node { +func (lhs Node) Is(entityType types.Path) Node { return newNode(nodeTypeIs{Left: lhs.v, EntityType: entityType}) } -func (lhs Node) IsIn(entityType types.String, rhs Node) Node { +func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { return newNode(nodeTypeIsIn{nodeTypeIs: nodeTypeIs{Left: lhs.v, EntityType: entityType}, Entity: rhs.v}) } diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index 927402e2..50d259c5 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -24,11 +24,11 @@ func (s scope) InSet(entities []types.EntityUID) isScopeNode { return scopeTypeInSet{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Entities: entities} } -func (s scope) Is(entityType types.String) isScopeNode { +func (s scope) Is(entityType types.Path) isScopeNode { return scopeTypeIs{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Type: entityType} } -func (s scope) IsIn(entityType types.String, entity types.EntityUID) isScopeNode { +func (s scope) IsIn(entityType types.Path, entity types.EntityUID) isScopeNode { return scopeTypeIsIn{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Type: entityType, Entity: entity} } @@ -42,12 +42,12 @@ func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { return p } -func (p *Policy) PrincipalIs(entityType types.String) *Policy { +func (p *Policy) PrincipalIs(entityType types.Path) *Policy { p.principal = scope(rawPrincipalNode()).Is(entityType) return p } -func (p *Policy) PrincipalIsIn(entityType types.String, entity types.EntityUID) *Policy { +func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { p.principal = scope(rawPrincipalNode()).IsIn(entityType, entity) return p } @@ -77,12 +77,12 @@ func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { return p } -func (p *Policy) ResourceIs(entityType types.String) *Policy { +func (p *Policy) ResourceIs(entityType types.Path) *Policy { p.resource = scope(rawResourceNode()).Is(entityType) return p } -func (p *Policy) ResourceIsIn(entityType types.String, entity types.EntityUID) *Policy { +func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { p.resource = scope(rawResourceNode()).IsIn(entityType, entity) return p } @@ -90,6 +90,7 @@ func (p *Policy) ResourceIsIn(entityType types.String, entity types.EntityUID) * type isScopeNode interface { isScope() MarshalCedar(*bytes.Buffer) + toNode() Node } type scopeNode struct { @@ -101,28 +102,56 @@ type scopeTypeAll struct { scopeNode } +func (n scopeTypeAll) toNode() Node { + return newNode(True().v) +} + type scopeTypeEq struct { scopeNode Entity types.EntityUID } +func (n scopeTypeEq) toNode() Node { + return newNode(newNode(n.Variable).Equals(EntityUID(n.Entity)).v) +} + type scopeTypeIn struct { scopeNode Entity types.EntityUID } +func (n scopeTypeIn) toNode() Node { + return newNode(newNode(n.Variable).In(EntityUID(n.Entity)).v) +} + type scopeTypeInSet struct { scopeNode Entities []types.EntityUID } +func (n scopeTypeInSet) toNode() Node { + set := make([]types.Value, len(n.Entities)) + for i, e := range n.Entities { + set[i] = e + } + return newNode(newNode(n.Variable).In(Set(set)).v) +} + type scopeTypeIs struct { scopeNode - Type types.String + Type types.Path +} + +func (n scopeTypeIs) toNode() Node { + return newNode(newNode(n.Variable).Is(n.Type).v) } type scopeTypeIsIn struct { scopeNode - Type types.String + Type types.Path Entity types.EntityUID } + +func (n scopeTypeIsIn) toNode() Node { + return newNode(newNode(n.Variable).IsIn(n.Type, EntityUID(n.Entity)).v) +} From f01cd27b24858e5a89aad25e091e8aaea360dd29 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 7 Aug 2024 15:20:57 -0600 Subject: [PATCH 066/216] x/exp/ast: appease linter Addresses IDX-49 Signed-off-by: philhassey --- x/exp/ast/eval_compile.go | 1 - x/exp/ast/eval_convert.go | 4 ++-- x/exp/ast/eval_impl.go | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/x/exp/ast/eval_compile.go b/x/exp/ast/eval_compile.go index 10f48969..b28ce1bf 100644 --- a/x/exp/ast/eval_compile.go +++ b/x/exp/ast/eval_compile.go @@ -4,7 +4,6 @@ type CompiledPolicySet map[string]CompiledPolicy type CompiledPolicy struct { PolicySetEntry - eval Evaler } func (p PolicySetEntry) TmpGetAnnotations() map[string]string { diff --git a/x/exp/ast/eval_convert.go b/x/exp/ast/eval_convert.go index cf6dba26..4fcd452f 100644 --- a/x/exp/ast/eval_convert.go +++ b/x/exp/ast/eval_convert.go @@ -24,10 +24,10 @@ func toEval(n node) Evaler { case nodeTypeExtensionCall: i, ok := extMap[v.Name] if !ok { - return newErrorEval(fmt.Errorf("%w: %s", errUnknownMethod, v.Name)) + return newErrorEval(fmt.Errorf("%w: %s", errUnknownExtensionFunction, v.Name)) } if i.Args != len(v.Args) { - return newErrorEval(fmt.Errorf("%w: %s takes 1 parameter", errArity, v.Name)) + return newErrorEval(fmt.Errorf("%w: %s takes %d parameter(s)", errArity, v.Name, i.Args)) } switch { case v.Name == "ip": diff --git a/x/exp/ast/eval_impl.go b/x/exp/ast/eval_impl.go index e2b6ab74..87d62e21 100644 --- a/x/exp/ast/eval_impl.go +++ b/x/exp/ast/eval_impl.go @@ -7,7 +7,6 @@ import ( ) var errOverflow = fmt.Errorf("integer overflow") -var errUnknownMethod = fmt.Errorf("unknown method") var errUnknownExtensionFunction = fmt.Errorf("function does not exist") var errArity = fmt.Errorf("wrong number of arguments provided to extension function") var errAttributeAccess = fmt.Errorf("does not have the attribute") From 1f38d1652d5dc8a41207d1ff7b4a7783540473ef Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 7 Aug 2024 16:00:49 -0600 Subject: [PATCH 067/216] cedar: quiet the linter for long table test Addresses IDX-49 Signed-off-by: philhassey --- cedar_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cedar_test.go b/cedar_test.go index 290feea7..c6ddbb3a 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -50,6 +50,7 @@ func TestNewPolicySet(t *testing.T) { }) } +//nolint:revive // due to table test function-length func TestIsAuthorized(t *testing.T) { t.Parallel() tests := []struct { From 9d5016e659fe9bbc57283c4a104affe931666836 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 12:44:46 -0600 Subject: [PATCH 068/216] types: move pattern into types package, move rust specific string parsing into the internal package Addresses IDX-142 Signed-off-by: philhassey --- internal/rust.go | 161 +++++++++++++++++++ internal/rust_test.go | 216 ++++++++++++++++++++++++++ types/pattern.go | 150 ++++++++++++++++++ {x/exp/ast => types}/patttern_test.go | 75 ++++++--- x/exp/ast/cedar_marshal.go | 2 +- x/exp/ast/cedar_tokenize.go | 164 +------------------ x/exp/ast/cedar_tokenize_test.go | 208 ------------------------- x/exp/ast/cedar_unmarshal.go | 2 +- x/exp/ast/cedar_unmarshal_test.go | 6 +- x/exp/ast/eval_impl.go | 6 +- x/exp/ast/eval_match.go | 66 -------- x/exp/ast/json_marshal.go | 8 +- x/exp/ast/json_test.go | 14 +- x/exp/ast/json_unmarshal.go | 4 +- x/exp/ast/node.go | 2 +- x/exp/ast/operator.go | 2 +- x/exp/ast/pattern.go | 79 ---------- 17 files changed, 613 insertions(+), 552 deletions(-) create mode 100644 internal/rust.go create mode 100644 internal/rust_test.go create mode 100644 types/pattern.go rename {x/exp/ast => types}/patttern_test.go (60%) delete mode 100644 x/exp/ast/eval_match.go delete mode 100644 x/exp/ast/pattern.go diff --git a/internal/rust.go b/internal/rust.go new file mode 100644 index 00000000..5b756fb1 --- /dev/null +++ b/internal/rust.go @@ -0,0 +1,161 @@ +package internal + +import ( + "fmt" + "strings" + "unicode/utf8" +) + +func nextRune(b []byte, i int) (rune, int, error) { + ch, size := utf8.DecodeRune(b[i:]) + if ch == utf8.RuneError { + return ch, i, fmt.Errorf("bad unicode rune") + } + return ch, i + size, nil +} + +func parseHexEscape(b []byte, i int) (rune, int, error) { + var ch rune + var err error + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if !IsHexadecimal(ch) { + return 0, i, fmt.Errorf("bad hex escape sequence") + } + res := DigitVal(ch) + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if !IsHexadecimal(ch) { + return 0, i, fmt.Errorf("bad hex escape sequence") + } + res = 16*res + DigitVal(ch) + if res > 127 { + return 0, i, fmt.Errorf("bad hex escape sequence") + } + return rune(res), i, nil +} + +func ParseUnicodeEscape(b []byte, i int) (rune, int, error) { + var ch rune + var err error + + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if ch != '{' { + return 0, i, fmt.Errorf("bad unicode escape sequence") + } + + digits := 0 + res := 0 + for { + ch, i, err = nextRune(b, i) + if err != nil { + return 0, i, err + } + if ch == '}' { + break + } + if !IsHexadecimal(ch) { + return 0, i, fmt.Errorf("bad unicode escape sequence") + } + res = 16*res + DigitVal(ch) + digits++ + } + + if digits == 0 || digits > 6 || !utf8.ValidRune(rune(res)) { + return 0, i, fmt.Errorf("bad unicode escape sequence") + } + + return rune(res), i, nil +} + +func Unquote(s string) (string, error) { + s = strings.TrimPrefix(s, "\"") + s = strings.TrimSuffix(s, "\"") + res, _, err := RustUnquote([]byte(s), false) + return res, err +} + +func RustUnquote(b []byte, star bool) (string, []byte, error) { + var sb strings.Builder + var ch rune + var err error + i := 0 + for i < len(b) { + ch, i, err = nextRune(b, i) + if err != nil { + return "", nil, err + } + if star && ch == '*' { + i-- + return sb.String(), b[i:], nil + } + if ch != '\\' { + sb.WriteRune(ch) + continue + } + ch, i, err = nextRune(b, i) + if err != nil { + return "", nil, err + } + switch ch { + case 'n': + sb.WriteRune('\n') + case 'r': + sb.WriteRune('\r') + case 't': + sb.WriteRune('\t') + case '\\': + sb.WriteRune('\\') + case '0': + sb.WriteRune('\x00') + case '\'': + sb.WriteRune('\'') + case '"': + sb.WriteRune('"') + case 'x': + ch, i, err = parseHexEscape(b, i) + if err != nil { + return "", nil, err + } + sb.WriteRune(ch) + case 'u': + ch, i, err = ParseUnicodeEscape(b, i) + if err != nil { + return "", nil, err + } + sb.WriteRune(ch) + case '*': + if !star { + return "", nil, fmt.Errorf("bad char escape") + } + sb.WriteRune('*') + default: + return "", nil, fmt.Errorf("bad char escape") + } + } + return sb.String(), b[i:], nil +} + +func IsHexadecimal(ch rune) bool { + return IsDecimal(ch) || ('a' <= lower(ch) && lower(ch) <= 'f') +} + +func lower(ch rune) rune { return ('a' - 'A') | ch } // returns lower-case ch iff ch is ASCII letter +func IsDecimal(ch rune) bool { return '0' <= ch && ch <= '9' } + +func DigitVal(ch rune) int { + switch { + case '0' <= ch && ch <= '9': + return int(ch - '0') + case 'a' <= lower(ch) && lower(ch) <= 'f': + return int(lower(ch) - 'a' + 10) + } + return 16 // larger than any legal digit val +} diff --git a/internal/rust_test.go b/internal/rust_test.go new file mode 100644 index 00000000..24efd067 --- /dev/null +++ b/internal/rust_test.go @@ -0,0 +1,216 @@ +package internal_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal" + "github.com/cedar-policy/cedar-go/testutil" +) + +func TestParseUnicodeEscape(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in []byte + out rune + outN int + err func(t testing.TB, err error) + }{ + {"happy", []byte{'{', '4', '2', '}'}, 0x42, 4, testutil.OK}, + {"badRune", []byte{'{', 0x80, 0x81}, 0, 1, testutil.Error}, + {"notHex", []byte{'{', 'g'}, 0, 2, testutil.Error}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out, n, err := internal.ParseUnicodeEscape(tt.in, 0) + testutil.Equals(t, out, tt.out) + testutil.Equals(t, n, tt.outN) + tt.err(t, err) + }) + } +} + +func TestUnquote(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + out string + err func(t testing.TB, err error) + }{ + {"happy", `"test"`, `test`, testutil.OK}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out, err := internal.Unquote(tt.in) + testutil.Equals(t, out, tt.out) + tt.err(t, err) + }) + } +} + +func TestRustUnquote(t *testing.T) { + t.Parallel() + // star == false + { + tests := []struct { + input string + wantOk bool + want string + wantErr string + }{ + {``, true, "", ""}, + {`hello`, true, "hello", ""}, + {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", ""}, + {`a\"b`, true, "a\"b", ""}, + {`a\'b`, true, "a'b", ""}, + + {`a\x00b`, true, "a\x00b", ""}, + {`a\x7fb`, true, "a\x7fb", ""}, + {`a\x80b`, false, "", "bad hex escape sequence"}, + + {string([]byte{0x80, 0x81}), false, "", "bad unicode rune"}, + {`a\u`, false, "", "bad unicode rune"}, + {`a\uz`, false, "", "bad unicode escape sequence"}, + {`a\u{}b`, false, "", "bad unicode escape sequence"}, + {`a\u{A}b`, true, "a\u000ab", ""}, + {`a\u{aB}b`, true, "a\u00abb", ""}, + {`a\u{AbC}b`, true, "a\u0abcb", ""}, + {`a\u{aBcD}b`, true, "a\uabcdb", ""}, + {`a\u{AbCdE}b`, true, "a\U000abcdeb", ""}, + {`a\u{10cDeF}b`, true, "a\U0010cdefb", ""}, + {`a\u{ffffff}b`, false, "", "bad unicode escape sequence"}, + {`a\u{0000000}b`, false, "", "bad unicode escape sequence"}, + {`a\u{d7ff}b`, true, "a\ud7ffb", ""}, + {`a\u{d800}b`, false, "", "bad unicode escape sequence"}, + {`a\u{dfff}b`, false, "", "bad unicode escape sequence"}, + {`a\u{e000}b`, true, "a\ue000b", ""}, + {`a\u{10ffff}b`, true, "a\U0010ffffb", ""}, + {`a\u{110000}b`, false, "", "bad unicode escape sequence"}, + + {`\`, false, "", "bad unicode rune"}, + {`\a`, false, "", "bad char escape"}, + {`\*`, false, "", "bad char escape"}, + {`\x`, false, "", "bad unicode rune"}, + {`\xz`, false, "", "bad hex escape sequence"}, + {`\xa`, false, "", "bad unicode rune"}, + {`\xaz`, false, "", "bad hex escape sequence"}, + {`\{`, false, "", "bad char escape"}, + {`\{z`, false, "", "bad char escape"}, + {`\{0`, false, "", "bad char escape"}, + {`\{0z`, false, "", "bad char escape"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, rem, err := internal.RustUnquote([]byte(tt.input), false) + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + testutil.Equals(t, got, tt.want) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got, tt.want) + testutil.Equals(t, rem, []byte("")) + } + }) + } + } + + // star == true + { + tests := []struct { + input string + wantOk bool + want string + wantRem string + wantErr string + }{ + {``, true, "", "", ""}, + {`hello`, true, "hello", "", ""}, + {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", "", ""}, + {`a\"b`, true, "a\"b", "", ""}, + {`a\'b`, true, "a'b", "", ""}, + + {`a\x00b`, true, "a\x00b", "", ""}, + {`a\x7fb`, true, "a\x7fb", "", ""}, + {`a\x80b`, false, "", "", "bad hex escape sequence"}, + + {`a\u`, false, "", "", "bad unicode rune"}, + {`a\uz`, false, "", "", "bad unicode escape sequence"}, + {`a\u{}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{A}b`, true, "a\u000ab", "", ""}, + {`a\u{aB}b`, true, "a\u00abb", "", ""}, + {`a\u{AbC}b`, true, "a\u0abcb", "", ""}, + {`a\u{aBcD}b`, true, "a\uabcdb", "", ""}, + {`a\u{AbCdE}b`, true, "a\U000abcdeb", "", ""}, + {`a\u{10cDeF}b`, true, "a\U0010cdefb", "", ""}, + {`a\u{ffffff}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{0000000}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{d7ff}b`, true, "a\ud7ffb", "", ""}, + {`a\u{d800}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{dfff}b`, false, "", "", "bad unicode escape sequence"}, + {`a\u{e000}b`, true, "a\ue000b", "", ""}, + {`a\u{10ffff}b`, true, "a\U0010ffffb", "", ""}, + {`a\u{110000}b`, false, "", "", "bad unicode escape sequence"}, + + {`*`, true, "", "*", ""}, + {`*hello*how*are*you`, true, "", "*hello*how*are*you", ""}, + {`hello*how*are*you`, true, "hello", "*how*are*you", ""}, + {`\**`, true, "*", "*", ""}, + + {`\`, false, "", "", "bad unicode rune"}, + {`\a`, false, "", "", "bad char escape"}, + {`\x`, false, "", "", "bad unicode rune"}, + {`\xz`, false, "", "", "bad hex escape sequence"}, + {`\xa`, false, "", "", "bad unicode rune"}, + {`\xaz`, false, "", "", "bad hex escape sequence"}, + {`\{`, false, "", "", "bad char escape"}, + {`\{z`, false, "", "", "bad char escape"}, + {`\{0`, false, "", "", "bad char escape"}, + {`\{0z`, false, "", "", "bad char escape"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, rem, err := internal.RustUnquote([]byte(tt.input), true) + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + testutil.Equals(t, got, tt.want) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got, tt.want) + testutil.Equals(t, string(rem), tt.wantRem) + } + }) + } + } +} + +func TestDigitVal(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in rune + out int + }{ + {"happy", '0', 0}, + {"hex", 'f', 15}, + {"sad", 'g', 16}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out := internal.DigitVal(tt.in) + testutil.Equals(t, out, tt.out) + }) + } +} diff --git a/types/pattern.go b/types/pattern.go new file mode 100644 index 00000000..aa32df32 --- /dev/null +++ b/types/pattern.go @@ -0,0 +1,150 @@ +package types + +import ( + "bytes" + "strconv" + "strings" + + "github.com/cedar-policy/cedar-go/internal" +) + +type PatternComponent struct { + Wildcard bool + Literal string +} + +// Pattern is used to define a string used for the like operator. It does not +// conform to the Value interface, as it is not one of the Cedar types. +type Pattern struct { + Components []PatternComponent +} + +func (p Pattern) Cedar() string { + var buf bytes.Buffer + buf.WriteRune('"') + for _, comp := range p.Components { + if comp.Wildcard { + buf.WriteRune('*') + } + // TODO: This is wrong. It needs to escape unicode the Rustic way. + quotedString := strconv.Quote(comp.Literal) + quotedString = quotedString[1 : len(quotedString)-1] + quotedString = strings.Replace(quotedString, "*", "\\*", -1) + buf.WriteString(quotedString) + } + buf.WriteRune('"') + return buf.String() +} + +func (p *Pattern) AddWildcard() *Pattern { + star := PatternComponent{Wildcard: true} + if len(p.Components) == 0 { + p.Components = []PatternComponent{star} + return p + } + + lastComp := p.Components[len(p.Components)-1] + if lastComp.Wildcard && lastComp.Literal == "" { + return p + } + + p.Components = append(p.Components, star) + return p +} + +func (p *Pattern) AddLiteral(s string) *Pattern { + if len(p.Components) == 0 { + p.Components = []PatternComponent{{}} + } + + lastComp := &p.Components[len(p.Components)-1] + lastComp.Literal = lastComp.Literal + s + return p +} + +// TODO: move this into the types package + +// ported from Go's stdlib and reduced to our scope. +// https://golang.org/src/path/filepath/match.go?s=1226:1284#L34 + +// Match reports whether name matches the shell file name pattern. +// The pattern syntax is: +// +// pattern: +// { term } +// term: +// '*' matches any sequence of non-Separator characters +// c matches character c (c != '*') +func (p Pattern) Match(arg string) (matched bool) { +Pattern: + for i, comp := range p.Components { + lastChunk := i == len(p.Components)-1 + if comp.Wildcard && comp.Literal == "" { + return true + } + // Look for Match at current position. + t, ok := matchChunk(comp.Literal, arg) + // if we're the last chunk, make sure we've exhausted the name + // otherwise we'll give a false result even if we could still Match + // using the star + if ok && (len(t) == 0 || !lastChunk) { + arg = t + continue + } + if comp.Wildcard { + // Look for Match skipping i+1 bytes. + for i := 0; i < len(arg); i++ { + t, ok := matchChunk(comp.Literal, arg[i+1:]) + if ok { + // if we're the last chunk, make sure we exhausted the name + if lastChunk && len(t) > 0 { + continue + } + arg = t + continue Pattern + } + } + } + return false + } + return len(arg) == 0 +} + +// matchChunk checks whether chunk matches the beginning of s. +// If so, it returns the remainder of s (after the Match). +// Chunk is all single-character operators: literals, char classes, and ?. +func matchChunk(chunk, s string) (rest string, ok bool) { + for len(chunk) > 0 { + if len(s) == 0 { + return + } + if chunk[0] != s[0] { + return + } + s = s[1:] + chunk = chunk[1:] + } + return s, true +} + +func ParsePattern(s string) (Pattern, error) { + b := []byte(s) + + var comps []PatternComponent + for len(b) > 0 { + var comp PatternComponent + var err error + for len(b) > 0 && b[0] == '*' { + b = b[1:] + comp.Wildcard = true + } + comp.Literal, b, err = internal.RustUnquote(b, true) + if err != nil { + return Pattern{}, err + } + comps = append(comps, comp) + } + return Pattern{ + Components: comps, + }, nil +} diff --git a/x/exp/ast/patttern_test.go b/types/patttern_test.go similarity index 60% rename from x/exp/ast/patttern_test.go rename to types/patttern_test.go index 2eeecbe4..8eec28a4 100644 --- a/x/exp/ast/patttern_test.go +++ b/types/patttern_test.go @@ -1,4 +1,4 @@ -package ast +package types import ( "testing" @@ -6,7 +6,31 @@ import ( "github.com/cedar-policy/cedar-go/testutil" ) -func TestPatternFromCedar(t *testing.T) { +func TestPatternFromBuilder(t *testing.T) { + tests := []struct { + name string + Pattern *Pattern + want []PatternComponent + }{ + {"empty", &Pattern{}, nil}, + {"wildcard", (&Pattern{}).AddWildcard(), []PatternComponent{{Wildcard: true}}}, + {"saturate two wildcards", (&Pattern{}).AddWildcard().AddWildcard(), []PatternComponent{{Wildcard: true}}}, + {"literal", (&Pattern{}).AddLiteral("foo"), []PatternComponent{{Literal: "foo"}}}, + {"saturate two literals", (&Pattern{}).AddLiteral("foo").AddLiteral("bar"), []PatternComponent{{Literal: "foobar"}}}, + {"literal with asterisk", (&Pattern{}).AddLiteral("fo*o"), []PatternComponent{{Literal: "fo*o"}}}, + {"wildcard sandwich", (&Pattern{}).AddLiteral("foo").AddWildcard().AddLiteral("bar"), []PatternComponent{{Literal: "foo"}, {Wildcard: true, Literal: "bar"}}}, + {"literal sandwich", (&Pattern{}).AddWildcard().AddLiteral("foo").AddWildcard(), []PatternComponent{{Wildcard: true, Literal: "foo"}, {Wildcard: true}}}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + testutil.Equals(t, tt.Pattern.Components, tt.want) + }) + } +} + +func TestParsePattern(t *testing.T) { t.Parallel() tests := []struct { input string @@ -51,38 +75,53 @@ func TestPatternFromCedar(t *testing.T) { tt := tt t.Run(tt.input, func(t *testing.T) { t.Parallel() - got, err := PatternFromCedar(tt.input) + got, err := ParsePattern(tt.input) if err != nil { testutil.Equals(t, tt.wantOk, false) testutil.Equals(t, err.Error(), tt.wantErr) } else { testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got.Comps, tt.want) + testutil.Equals(t, got.Components, tt.want) } }) } } -func TestPatternFromBuilder(t *testing.T) { +func TestMatch(t *testing.T) { + t.Parallel() tests := []struct { - name string - Pattern *Pattern - want []PatternComponent + pattern string + target string + want bool }{ - {"empty", &Pattern{}, nil}, - {"wildcard", (&Pattern{}).AddWildcard(), []PatternComponent{{Star: true}}}, - {"saturate two wildcards", (&Pattern{}).AddWildcard().AddWildcard(), []PatternComponent{{Star: true}}}, - {"literal", (&Pattern{}).AddLiteral("foo"), []PatternComponent{{Chunk: "foo"}}}, - {"saturate two literals", (&Pattern{}).AddLiteral("foo").AddLiteral("bar"), []PatternComponent{{Chunk: "foobar"}}}, - {"literal with asterisk", (&Pattern{}).AddLiteral("fo*o"), []PatternComponent{{Chunk: "fo*o"}}}, - {"wildcard sandwich", (&Pattern{}).AddLiteral("foo").AddWildcard().AddLiteral("bar"), []PatternComponent{{Chunk: "foo"}, {Star: true, Chunk: "bar"}}}, - {"literal sandwich", (&Pattern{}).AddWildcard().AddLiteral("foo").AddWildcard(), []PatternComponent{{Star: true, Chunk: "foo"}, {Star: true}}}, + {`""`, "", true}, + {`""`, "hello", false}, + {`"*"`, "hello", true}, + {`"e"`, "hello", false}, + {`"*e"`, "hello", false}, + {`"*e*"`, "hello", true}, + {`"hello"`, "hello", true}, + {`"hello*"`, "hello", true}, + {`"*h*llo*"`, "hello", true}, + {`"h*e*o"`, "hello", true}, + {`"h*e**o"`, "hello", true}, + {`"h*z*o"`, "hello", false}, + + {`"\u{210d}*"`, "ℍello", true}, + {`"\u{210d}*"`, "Hello", false}, + + {`"\*\**\*\*"`, "**foo**", true}, + {`"\*\**\*\*"`, "**bar**", true}, + {`"\*\**\*\*"`, "*bar*", false}, } for _, tt := range tests { tt := tt - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.pattern+":"+tt.target, func(t *testing.T) { t.Parallel() - testutil.Equals(t, tt.Pattern.Comps, tt.want) + pat, err := ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) + testutil.OK(t, err) + got := pat.Match(tt.target) + testutil.Equals(t, got, tt.want) }) } } diff --git a/x/exp/ast/cedar_marshal.go b/x/exp/ast/cedar_marshal.go index 686ca998..82e02e43 100644 --- a/x/exp/ast/cedar_marshal.go +++ b/x/exp/ast/cedar_marshal.go @@ -297,7 +297,7 @@ func (n nodeTypeIsIn) marshalCedar(buf *bytes.Buffer) { func (n nodeTypeLike) marshalCedar(buf *bytes.Buffer) { marshalChildNode(n.precedenceLevel(), n.Arg, buf) buf.WriteString(" like ") - n.Value.MarshalCedar(buf) + buf.WriteString(n.Value.Cedar()) } func (n nodeTypeIf) marshalCedar(buf *bytes.Buffer) { diff --git a/x/exp/ast/cedar_tokenize.go b/x/exp/ast/cedar_tokenize.go index 13974de6..5c9b4cc5 100644 --- a/x/exp/ast/cedar_tokenize.go +++ b/x/exp/ast/cedar_tokenize.go @@ -7,6 +7,8 @@ import ( "strconv" "strings" "unicode/utf8" + + "github.com/cedar-policy/cedar-go/internal" ) //go:generate moq -pkg parser -fmt goimports -out tokenize_mocks_test.go . reader @@ -52,151 +54,10 @@ func (t Token) stringValue() (string, error) { s = strings.TrimPrefix(s, "\"") s = strings.TrimSuffix(s, "\"") b := []byte(s) - res, _, err := rustUnquote(b, false) + res, _, err := internal.RustUnquote(b, false) return res, err } -func nextRune(b []byte, i int) (rune, int, error) { - ch, size := utf8.DecodeRune(b[i:]) - if ch == utf8.RuneError { - return ch, i, fmt.Errorf("bad unicode rune") - } - return ch, i + size, nil -} - -func parseHexEscape(b []byte, i int) (rune, int, error) { - var ch rune - var err error - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if !isHexadecimal(ch) { - return 0, i, fmt.Errorf("bad hex escape sequence") - } - res := digitVal(ch) - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if !isHexadecimal(ch) { - return 0, i, fmt.Errorf("bad hex escape sequence") - } - res = 16*res + digitVal(ch) - if res > 127 { - return 0, i, fmt.Errorf("bad hex escape sequence") - } - return rune(res), i, nil -} - -func parseUnicodeEscape(b []byte, i int) (rune, int, error) { - var ch rune - var err error - - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if ch != '{' { - return 0, i, fmt.Errorf("bad unicode escape sequence") - } - - digits := 0 - res := 0 - for { - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if ch == '}' { - break - } - if !isHexadecimal(ch) { - return 0, i, fmt.Errorf("bad unicode escape sequence") - } - res = 16*res + digitVal(ch) - digits++ - } - - if digits == 0 || digits > 6 || !utf8.ValidRune(rune(res)) { - return 0, i, fmt.Errorf("bad unicode escape sequence") - } - - return rune(res), i, nil -} - -func Unquote(s string) (string, error) { - s = strings.TrimPrefix(s, "\"") - s = strings.TrimSuffix(s, "\"") - res, _, err := rustUnquote([]byte(s), false) - return res, err -} - -func rustUnquote(b []byte, star bool) (string, []byte, error) { - var sb strings.Builder - var ch rune - var err error - i := 0 - for i < len(b) { - ch, i, err = nextRune(b, i) - if err != nil { - return "", nil, err - } - if star && ch == '*' { - i-- - return sb.String(), b[i:], nil - } - if ch != '\\' { - sb.WriteRune(ch) - continue - } - ch, i, err = nextRune(b, i) - if err != nil { - return "", nil, err - } - switch ch { - case 'n': - sb.WriteRune('\n') - case 'r': - sb.WriteRune('\r') - case 't': - sb.WriteRune('\t') - case '\\': - sb.WriteRune('\\') - case '0': - sb.WriteRune('\x00') - case '\'': - sb.WriteRune('\'') - case '"': - sb.WriteRune('"') - case 'x': - ch, i, err = parseHexEscape(b, i) - if err != nil { - return "", nil, err - } - sb.WriteRune(ch) - case 'u': - ch, i, err = parseUnicodeEscape(b, i) - if err != nil { - return "", nil, err - } - sb.WriteRune(ch) - case '*': - if !star { - return "", nil, fmt.Errorf("bad char escape") - } - sb.WriteRune('*') - default: - return "", nil, fmt.Errorf("bad char escape") - } - } - return sb.String(), b[i:], nil -} - -func isHexadecimal(ch rune) bool { - return isDecimal(ch) || ('a' <= lower(ch) && lower(ch) <= 'f') -} - func (t Token) intValue() (int64, error) { return strconv.ParseInt(t.Text, 10, 64) } @@ -414,29 +275,16 @@ func (s *scanner) scanIdentifier() rune { return ch } -func lower(ch rune) rune { return ('a' - 'A') | ch } // returns lower-case ch iff ch is ASCII letter -func isDecimal(ch rune) bool { return '0' <= ch && ch <= '9' } - func (s *scanner) scanInteger(ch rune) rune { - for isDecimal(ch) { + for internal.IsDecimal(ch) { ch = s.next() } return ch } -func digitVal(ch rune) int { - switch { - case '0' <= ch && ch <= '9': - return int(ch - '0') - case 'a' <= lower(ch) && lower(ch) <= 'f': - return int(lower(ch) - 'a' + 10) - } - return 16 // larger than any legal digit val -} - func (s *scanner) scanHexDigits(ch rune, min, max int) rune { n := 0 - for n < max && isHexadecimal(ch) { + for n < max && internal.IsHexadecimal(ch) { ch = s.next() n++ } @@ -605,7 +453,7 @@ redo: case isIdentRune(ch, true): ch = s.scanIdentifier() tt = TokenIdent - case isDecimal(ch): + case internal.IsDecimal(ch): ch = s.scanInteger(ch) tt = TokenInt case ch == '"': diff --git a/x/exp/ast/cedar_tokenize_test.go b/x/exp/ast/cedar_tokenize_test.go index 02377584..9ff3d900 100644 --- a/x/exp/ast/cedar_tokenize_test.go +++ b/x/exp/ast/cedar_tokenize_test.go @@ -220,193 +220,6 @@ func TestStringTokenValues(t *testing.T) { } } -func TestParseUnicodeEscape(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in []byte - out rune - outN int - err func(t testing.TB, err error) - }{ - {"happy", []byte{'{', '4', '2', '}'}, 0x42, 4, testutil.OK}, - {"badRune", []byte{'{', 0x80, 0x81}, 0, 1, testutil.Error}, - {"notHex", []byte{'{', 'g'}, 0, 2, testutil.Error}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - out, n, err := parseUnicodeEscape(tt.in, 0) - testutil.Equals(t, out, tt.out) - testutil.Equals(t, n, tt.outN) - tt.err(t, err) - }) - } -} - -func TestUnquote(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in string - out string - err func(t testing.TB, err error) - }{ - {"happy", `"test"`, `test`, testutil.OK}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - out, err := Unquote(tt.in) - testutil.Equals(t, out, tt.out) - tt.err(t, err) - }) - } -} - -func TestRustUnquote(t *testing.T) { - t.Parallel() - // star == false - { - tests := []struct { - input string - wantOk bool - want string - wantErr string - }{ - {``, true, "", ""}, - {`hello`, true, "hello", ""}, - {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", ""}, - {`a\"b`, true, "a\"b", ""}, - {`a\'b`, true, "a'b", ""}, - - {`a\x00b`, true, "a\x00b", ""}, - {`a\x7fb`, true, "a\x7fb", ""}, - {`a\x80b`, false, "", "bad hex escape sequence"}, - - {string([]byte{0x80, 0x81}), false, "", "bad unicode rune"}, - {`a\u`, false, "", "bad unicode rune"}, - {`a\uz`, false, "", "bad unicode escape sequence"}, - {`a\u{}b`, false, "", "bad unicode escape sequence"}, - {`a\u{A}b`, true, "a\u000ab", ""}, - {`a\u{aB}b`, true, "a\u00abb", ""}, - {`a\u{AbC}b`, true, "a\u0abcb", ""}, - {`a\u{aBcD}b`, true, "a\uabcdb", ""}, - {`a\u{AbCdE}b`, true, "a\U000abcdeb", ""}, - {`a\u{10cDeF}b`, true, "a\U0010cdefb", ""}, - {`a\u{ffffff}b`, false, "", "bad unicode escape sequence"}, - {`a\u{0000000}b`, false, "", "bad unicode escape sequence"}, - {`a\u{d7ff}b`, true, "a\ud7ffb", ""}, - {`a\u{d800}b`, false, "", "bad unicode escape sequence"}, - {`a\u{dfff}b`, false, "", "bad unicode escape sequence"}, - {`a\u{e000}b`, true, "a\ue000b", ""}, - {`a\u{10ffff}b`, true, "a\U0010ffffb", ""}, - {`a\u{110000}b`, false, "", "bad unicode escape sequence"}, - - {`\`, false, "", "bad unicode rune"}, - {`\a`, false, "", "bad char escape"}, - {`\*`, false, "", "bad char escape"}, - {`\x`, false, "", "bad unicode rune"}, - {`\xz`, false, "", "bad hex escape sequence"}, - {`\xa`, false, "", "bad unicode rune"}, - {`\xaz`, false, "", "bad hex escape sequence"}, - {`\{`, false, "", "bad char escape"}, - {`\{z`, false, "", "bad char escape"}, - {`\{0`, false, "", "bad char escape"}, - {`\{0z`, false, "", "bad char escape"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, rem, err := rustUnquote([]byte(tt.input), false) - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - testutil.Equals(t, got, tt.want) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got, tt.want) - testutil.Equals(t, rem, []byte("")) - } - }) - } - } - - // star == true - { - tests := []struct { - input string - wantOk bool - want string - wantRem string - wantErr string - }{ - {``, true, "", "", ""}, - {`hello`, true, "hello", "", ""}, - {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", "", ""}, - {`a\"b`, true, "a\"b", "", ""}, - {`a\'b`, true, "a'b", "", ""}, - - {`a\x00b`, true, "a\x00b", "", ""}, - {`a\x7fb`, true, "a\x7fb", "", ""}, - {`a\x80b`, false, "", "", "bad hex escape sequence"}, - - {`a\u`, false, "", "", "bad unicode rune"}, - {`a\uz`, false, "", "", "bad unicode escape sequence"}, - {`a\u{}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{A}b`, true, "a\u000ab", "", ""}, - {`a\u{aB}b`, true, "a\u00abb", "", ""}, - {`a\u{AbC}b`, true, "a\u0abcb", "", ""}, - {`a\u{aBcD}b`, true, "a\uabcdb", "", ""}, - {`a\u{AbCdE}b`, true, "a\U000abcdeb", "", ""}, - {`a\u{10cDeF}b`, true, "a\U0010cdefb", "", ""}, - {`a\u{ffffff}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{0000000}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{d7ff}b`, true, "a\ud7ffb", "", ""}, - {`a\u{d800}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{dfff}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{e000}b`, true, "a\ue000b", "", ""}, - {`a\u{10ffff}b`, true, "a\U0010ffffb", "", ""}, - {`a\u{110000}b`, false, "", "", "bad unicode escape sequence"}, - - {`*`, true, "", "*", ""}, - {`*hello*how*are*you`, true, "", "*hello*how*are*you", ""}, - {`hello*how*are*you`, true, "hello", "*how*are*you", ""}, - {`\**`, true, "*", "*", ""}, - - {`\`, false, "", "", "bad unicode rune"}, - {`\a`, false, "", "", "bad char escape"}, - {`\x`, false, "", "", "bad unicode rune"}, - {`\xz`, false, "", "", "bad hex escape sequence"}, - {`\xa`, false, "", "", "bad unicode rune"}, - {`\xaz`, false, "", "", "bad hex escape sequence"}, - {`\{`, false, "", "", "bad char escape"}, - {`\{z`, false, "", "", "bad char escape"}, - {`\{0`, false, "", "", "bad char escape"}, - {`\{0z`, false, "", "", "bad char escape"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, rem, err := rustUnquote([]byte(tt.input), true) - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - testutil.Equals(t, got, tt.want) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got, tt.want) - testutil.Equals(t, string(rem), tt.wantRem) - } - }) - } - } -} - func TestScanner(t *testing.T) { t.Parallel() t.Run("SrcError", func(t *testing.T) { @@ -469,24 +282,3 @@ func TestScanner(t *testing.T) { testutil.Equals(t, out, "") }) } - -func TestDigitVal(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in rune - out int - }{ - {"happy", '0', 0}, - {"hex", 'f', 15}, - {"sad", 'g', 16}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - out := digitVal(tt.in) - testutil.Equals(t, out, tt.out) - }) - } -} diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index 723ec8a5..e63ed804 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -551,7 +551,7 @@ func (p *parser) like(lhs Node) (Node, error) { patternRaw := t.Text patternRaw = strings.TrimPrefix(patternRaw, "\"") patternRaw = strings.TrimSuffix(patternRaw, "\"") - pattern, err := PatternFromCedar(patternRaw) + pattern, err := types.ParsePattern(patternRaw) if err != nil { return Node{}, err } diff --git a/x/exp/ast/cedar_unmarshal_test.go b/x/exp/ast/cedar_unmarshal_test.go index 6b70b9d8..9fcf5c6c 100644 --- a/x/exp/ast/cedar_unmarshal_test.go +++ b/x/exp/ast/cedar_unmarshal_test.go @@ -281,19 +281,19 @@ when { principal has "1stName" };`, "like no wildcards", `permit ( principal, action, resource ) when { principal.firstName like "johnny" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(ast.PatternFromCedar("johnny")))), + ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(types.ParsePattern("johnny")))), }, { "like escaped asterisk", `permit ( principal, action, resource ) when { principal.firstName like "joh\*nny" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(ast.PatternFromCedar(`joh\*nny`)))), + ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(types.ParsePattern(`joh\*nny`)))), }, { "like wildcard", `permit ( principal, action, resource ) when { principal.firstName like "*" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(ast.PatternFromCedar("*")))), + ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(types.ParsePattern("*")))), }, { "is", diff --git a/x/exp/ast/eval_impl.go b/x/exp/ast/eval_impl.go index 87d62e21..af9f1308 100644 --- a/x/exp/ast/eval_impl.go +++ b/x/exp/ast/eval_impl.go @@ -877,10 +877,10 @@ func (n *hasEval) Eval(ctx *EvalContext) (types.Value, error) { // likeEval type likeEval struct { lhs Evaler - pattern Pattern + pattern types.Pattern } -func newLikeEval(lhs Evaler, pattern Pattern) *likeEval { +func newLikeEval(lhs Evaler, pattern types.Pattern) *likeEval { return &likeEval{lhs: lhs, pattern: pattern} } @@ -889,7 +889,7 @@ func (l *likeEval) Eval(ctx *EvalContext) (types.Value, error) { if err != nil { return types.ZeroValue(), err } - return types.Boolean(match(l.pattern, string(v))), nil + return types.Boolean(l.pattern.Match(string(v))), nil } type variableName func(ctx *EvalContext) types.Value diff --git a/x/exp/ast/eval_match.go b/x/exp/ast/eval_match.go deleted file mode 100644 index 1f84082c..00000000 --- a/x/exp/ast/eval_match.go +++ /dev/null @@ -1,66 +0,0 @@ -package ast - -// TODO: move this into the types package - -// ported from Go's stdlib and reduced to our scope. -// https://golang.org/src/path/filepath/match.go?s=1226:1284#L34 - -// Match reports whether name matches the shell file name pattern. -// The pattern syntax is: -// -// pattern: -// { term } -// term: -// '*' matches any sequence of non-Separator characters -// c matches character c (c != '*') -func match(p Pattern, name string) (matched bool) { -Pattern: - for i, comp := range p.Comps { - lastChunk := i == len(p.Comps)-1 - if comp.Star && comp.Chunk == "" { - return true - } - // Look for Match at current position. - t, ok := matchChunk(comp.Chunk, name) - // if we're the last chunk, make sure we've exhausted the name - // otherwise we'll give a false result even if we could still Match - // using the star - if ok && (len(t) == 0 || !lastChunk) { - name = t - continue - } - if comp.Star { - // Look for Match skipping i+1 bytes. - for i := 0; i < len(name); i++ { - t, ok := matchChunk(comp.Chunk, name[i+1:]) - if ok { - // if we're the last chunk, make sure we exhausted the name - if lastChunk && len(t) > 0 { - continue - } - name = t - continue Pattern - } - } - } - return false - } - return len(name) == 0 -} - -// matchChunk checks whether chunk matches the beginning of s. -// If so, it returns the remainder of s (after the Match). -// Chunk is all single-character operators: literals, char classes, and ?. -func matchChunk(chunk, s string) (rest string, ok bool) { - for len(chunk) > 0 { - if len(s) == 0 { - return - } - if chunk[0] != s[0] { - return - } - s = s[1:] - chunk = chunk[1:] - } - return s, true -} diff --git a/x/exp/ast/json_marshal.go b/x/exp/ast/json_marshal.go index 99e47cc4..ee38ca24 100644 --- a/x/exp/ast/json_marshal.go +++ b/x/exp/ast/json_marshal.go @@ -119,12 +119,12 @@ func patternToJSON(dest **patternJSON, src nodeTypeLike) error { if err := res.Left.FromNode(src.Arg); err != nil { return fmt.Errorf("error in left: %w", err) } - for _, comp := range src.Value.Comps { - if comp.Star { + for _, comp := range src.Value.Components { + if comp.Wildcard { res.Pattern = append(res.Pattern, patternComponentJSON{Wildcard: true}) } - if comp.Chunk != "" { - res.Pattern = append(res.Pattern, patternComponentJSON{Literal: patternComponentLiteralJSON{Literal: comp.Chunk}}) + if comp.Literal != "" { + res.Pattern = append(res.Pattern, patternComponentJSON{Literal: patternComponentLiteralJSON{Literal: comp.Literal}}) } } *dest = res diff --git a/x/exp/ast/json_test.go b/x/exp/ast/json_test.go index bd194f09..050a5c0d 100644 --- a/x/exp/ast/json_test.go +++ b/x/exp/ast/json_test.go @@ -397,49 +397,49 @@ func TestUnmarshalJSON(t *testing.T) { "like single wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar("*")))), + ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern("*")))), testutil.OK, }, { "like single literal", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar("foo")))), + ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern("foo")))), testutil.OK, }, { "like wildcard then literal", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard", {"Literal":"foo"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar("*foo")))), + ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern("*foo")))), testutil.OK, }, { "like literal then wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar("foo*")))), + ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern("foo*")))), testutil.OK, }, { "like literal with asterisk then wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"f*oo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar(`f\*oo*`)))), + ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern(`f\*oo*`)))), testutil.OK, }, { "like literal sandwich", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}, "Wildcard", {"Literal":"bar"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar(`foo*bar`)))), + ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern(`foo*bar`)))), testutil.OK, }, { "like wildcard sandwich", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard", {"Literal":"foo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(ast.PatternFromCedar(`*foo*`)))), + ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern(`*foo*`)))), testutil.OK, }, { diff --git a/x/exp/ast/json_unmarshal.go b/x/exp/ast/json_unmarshal.go index 0dcdbf20..cdb0e868 100644 --- a/x/exp/ast/json_unmarshal.go +++ b/x/exp/ast/json_unmarshal.go @@ -58,12 +58,12 @@ func (j strJSON) ToNode(f func(a Node, k string) Node) (Node, error) { } return f(left, j.Attr), nil } -func (j patternJSON) ToNode(f func(a Node, k Pattern) Node) (Node, error) { +func (j patternJSON) ToNode(f func(a Node, k types.Pattern) Node) (Node, error) { left, err := j.Left.ToNode() if err != nil { return Node{}, fmt.Errorf("error in left: %w", err) } - pattern := &Pattern{} + pattern := &types.Pattern{} for _, compJSON := range j.Pattern { if compJSON.Wildcard { pattern = pattern.AddWildcard() diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 37d86b5f..39387a08 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -107,7 +107,7 @@ type nodeTypeHas struct { type nodeTypeLike struct { Arg node - Value Pattern + Value types.Pattern } func (n nodeTypeLike) precedenceLevel() nodePrecedenceLevel { diff --git a/x/exp/ast/operator.go b/x/exp/ast/operator.go index 815f9963..a9b78ca1 100644 --- a/x/exp/ast/operator.go +++ b/x/exp/ast/operator.go @@ -49,7 +49,7 @@ func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { return newMethodCall(lhs, "greaterThanOrEqual", rhs) } -func (lhs Node) Like(pattern Pattern) Node { +func (lhs Node) Like(pattern types.Pattern) Node { return newNode(nodeTypeLike{Arg: lhs.v, Value: pattern}) } diff --git a/x/exp/ast/pattern.go b/x/exp/ast/pattern.go deleted file mode 100644 index 6f96f0e4..00000000 --- a/x/exp/ast/pattern.go +++ /dev/null @@ -1,79 +0,0 @@ -package ast - -import ( - "bytes" - "strconv" - "strings" -) - -type PatternComponent struct { - Star bool - Chunk string -} - -type Pattern struct { - Comps []PatternComponent -} - -func PatternFromCedar(cedar string) (Pattern, error) { - b := []byte(cedar) - - var comps []PatternComponent - for len(b) > 0 { - var comp PatternComponent - var err error - for len(b) > 0 && b[0] == '*' { - b = b[1:] - comp.Star = true - } - comp.Chunk, b, err = rustUnquote(b, true) - if err != nil { - return Pattern{}, err - } - comps = append(comps, comp) - } - return Pattern{ - Comps: comps, - }, nil -} - -func (p Pattern) MarshalCedar(buf *bytes.Buffer) { - buf.WriteRune('"') - for _, comp := range p.Comps { - if comp.Star { - buf.WriteRune('*') - } - // TODO: This is wrong. It needs to escape unicode the Rustic way. - quotedString := strconv.Quote(comp.Chunk) - quotedString = quotedString[1 : len(quotedString)-1] - quotedString = strings.Replace(quotedString, "*", "\\*", -1) - buf.WriteString(quotedString) - } - buf.WriteRune('"') -} - -func (p *Pattern) AddWildcard() *Pattern { - star := PatternComponent{Star: true} - if len(p.Comps) == 0 { - p.Comps = []PatternComponent{star} - return p - } - - lastComp := p.Comps[len(p.Comps)-1] - if lastComp.Star && lastComp.Chunk == "" { - return p - } - - p.Comps = append(p.Comps, star) - return p -} - -func (p *Pattern) AddLiteral(s string) *Pattern { - if len(p.Comps) == 0 { - p.Comps = []PatternComponent{{}} - } - - lastComp := &p.Comps[len(p.Comps)-1] - lastComp.Chunk = lastComp.Chunk + s - return p -} From 8941531975ddb525cc3ea9366b74a2767c6cbb29 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 15:57:29 -0600 Subject: [PATCH 069/216] x/exp/ast: move parser tests into ast package Addresses IDX-142 Signed-off-by: philhassey --- x/exp/ast/cedar_marshal.go | 17 +- .../parse_test.go => ast/cedar_parse_test.go} | 166 +++--------------- x/exp/ast/cedar_unmarshal.go | 85 +++++---- x/exp/ast/cedar_unmarshal_test.go | 14 +- x/exp/ast/node.go | 4 +- x/exp/ast/scope.go | 3 +- 6 files changed, 107 insertions(+), 182 deletions(-) rename x/exp/{parser/parse_test.go => ast/cedar_parse_test.go} (78%) diff --git a/x/exp/ast/cedar_marshal.go b/x/exp/ast/cedar_marshal.go index 82e02e43..e16d82b9 100644 --- a/x/exp/ast/cedar_marshal.go +++ b/x/exp/ast/cedar_marshal.go @@ -163,14 +163,14 @@ func (n nodeTypeAccess) marshalCedar(buf *bytes.Buffer) { func (n nodeTypeExtensionCall) marshalCedar(buf *bytes.Buffer) { var args []node - if n.Name != "ip" && n.Name != "decimal" { + info := extMap[n.Name] + if info.IsMethod { marshalChildNode(n.precedenceLevel(), n.Args[0], buf) buf.WriteRune('.') args = n.Args[1:] } else { args = n.Args } - buf.WriteString(string(n.Name)) buf.WriteRune('(') for i := range args { @@ -214,6 +214,19 @@ func (n nodeTypeSet) marshalCedar(buf *bytes.Buffer) { buf.WriteRune(']') } +func (n nodeTypeRecord) marshalCedar(buf *bytes.Buffer) { + buf.WriteRune('{') + for i := range n.Elements { + buf.WriteString(n.Elements[i].Key.Cedar()) + buf.WriteString(":") + marshalChildNode(n.precedenceLevel(), n.Elements[i].Value, buf) + if i != len(n.Elements)-1 { + buf.WriteString(", ") + } + } + buf.WriteRune('}') +} + func marshalInfixBinaryOp(n binaryNode, precedence nodePrecedenceLevel, op string, buf *bytes.Buffer) { marshalChildNode(precedence, n.Left, buf) buf.WriteRune(' ') diff --git a/x/exp/parser/parse_test.go b/x/exp/ast/cedar_parse_test.go similarity index 78% rename from x/exp/parser/parse_test.go rename to x/exp/ast/cedar_parse_test.go index 7a6c0bb4..930b623a 100644 --- a/x/exp/parser/parse_test.go +++ b/x/exp/ast/cedar_parse_test.go @@ -1,9 +1,11 @@ -package parser +package ast_test import ( + "bytes" "testing" "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/x/exp/ast" ) func TestParse(t *testing.T) { @@ -292,151 +294,29 @@ func TestParse(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - tokens, err := Tokenize([]byte(tt.in)) - testutil.OK(t, err) - got, err := Parse(tokens) + + var policies ast.PolicySet + err := policies.UnmarshalCedar([]byte(tt.in)) testutil.Equals(t, err != nil, tt.err) if err != nil { - testutil.Equals(t, got, nil) return } - - gotTokens, err := Tokenize([]byte(got.String())) - testutil.OK(t, err) - - var tokenStrs []string - for _, t := range tokens { - tokenStrs = append(tokenStrs, t.toString()) - } - - var gotTokenStrs []string - for _, t := range gotTokens { - gotTokenStrs = append(gotTokenStrs, t.toString()) + if len(policies) != 1 { + // TODO: handle 0, > 1 + return } - testutil.Equals(t, gotTokenStrs, tokenStrs) - }) - } -} + var buf bytes.Buffer + pp := policies["policy0"].Policy + pp.MarshalCedar(&buf) -func TestParseTypes(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in string - out Policies - }{ - { - "first", - "permit(principal, action, resource) when { 3 * 2 > 5 };", - Policies{ - Policy{ - Position: Position{Offset: 0, Line: 1, Column: 1}, - Annotations: []Annotation(nil), - Effect: "permit", - Conditions: []Condition{ - { - Type: "when", - Expression: Expression{ - Type: ExpressionOr, - Or: Or{ - Ands: []And{ - { - Relations: []Relation{ - { - Add: Add{ - Mults: []Mult{ - { - Unaries: []Unary{ - { - Ops: []UnaryOp(nil), - Member: Member{ - Primary: Primary{ - Type: PrimaryLiteral, - Literal: Literal{Type: LiteralInt, Long: 3}, - }, - Accesses: []Access(nil), - }, - }, - { - Ops: []UnaryOp(nil), - Member: Member{ - Primary: Primary{ - Type: PrimaryLiteral, - Literal: Literal{Type: LiteralInt, Long: 2}, - }, - Accesses: []Access(nil), - }, - }, - }, - }, - }, - }, - Type: "relop", - RelOp: ">", - RelOpRhs: Add{ - Mults: []Mult{ - { - Unaries: []Unary{ - { - Ops: []UnaryOp(nil), - Member: Member{ - Primary: Primary{ - Type: PrimaryLiteral, - Literal: Literal{Type: LiteralInt, Long: 5}, - }, - Accesses: []Access(nil), - }, - }, - }, - }, - }, - }, - Str: "", - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - tokens, err := Tokenize([]byte(tt.in)) - testutil.OK(t, err) - got, err := Parse(tokens) + var p2 ast.PolicySet + err = p2.UnmarshalCedar(buf.Bytes()) testutil.OK(t, err) - testutil.Equals(t, got, tt.out) - }) - } -} -func TestParseEntity(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in string - out Entity - err func(testing.TB, error) - }{ - {"happy", `Action::"test"`, Entity{Path: []string{"Action", "test"}}, testutil.OK}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - toks, err := Tokenize([]byte(tt.in)) - testutil.OK(t, err) - out, err := ParseEntity(toks) - testutil.Equals(t, out, tt.out) - tt.err(t, err) + // TODO: support 0, > 1 + testutil.Equals(t, p2["policy0"].Policy, policies["policy0"].Policy) + }) } } @@ -454,12 +334,12 @@ permit( principal, action, resource ); // annotation indent @test("1234") permit (principal, action, resource ); ` - toks, err := Tokenize([]byte(in)) - testutil.OK(t, err) - out, err := Parse(toks) + + var out ast.PolicySet + err := out.UnmarshalCedar([]byte(in)) testutil.OK(t, err) testutil.Equals(t, len(out), 3) - testutil.Equals(t, out[0].Position, Position{Offset: 17, Line: 2, Column: 1}) - testutil.Equals(t, out[1].Position, Position{Offset: 86, Line: 7, Column: 3}) - testutil.Equals(t, out[2].Position, Position{Offset: 148, Line: 10, Column: 2}) + testutil.Equals(t, out["policy0"].Position, ast.Position{Offset: 17, Line: 2, Column: 1}) + testutil.Equals(t, out["policy1"].Position, ast.Position{Offset: 86, Line: 7, Column: 3}) + testutil.Equals(t, out["policy2"].Position, ast.Position{Offset: 148, Line: 10, Column: 2}) } diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index e63ed804..8abbfca2 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -2,6 +2,7 @@ package ast import ( "fmt" + "strconv" "strings" "github.com/cedar-policy/cedar-go/types" @@ -19,7 +20,12 @@ func (p *PolicySet) UnmarshalCedar(b []byte) error { parser := newParser(tokens) for !parser.peek().isEOF() { pos := parser.peek().Pos - var policy Policy + policy := Policy{ + principal: scopeTypeAll{}, + action: scopeTypeAll{}, + resource: scopeTypeAll{}, + } + if err = policy.fromCedarWithParser(&parser); err != nil { return err } @@ -126,9 +132,10 @@ func (p *parser) errorf(s string, args ...interface{}) error { func (p *parser) annotations() (Annotations, error) { var res Annotations + known := map[types.String]struct{}{} for p.peek().Text == "@" { p.advance() - err := p.annotation(&res) + err := p.annotation(&res, known) if err != nil { return res, err } @@ -137,7 +144,7 @@ func (p *parser) annotations() (Annotations, error) { } -func (p *parser) annotation(a *Annotations) error { +func (p *parser) annotation(a *Annotations, known map[types.String]struct{}) error { var err error t := p.advance() if !t.isIdent() { @@ -147,6 +154,10 @@ func (p *parser) annotation(a *Annotations) error { if err = p.exact("("); err != nil { return err } + if _, ok := known[name]; ok { + return p.errorf("duplicate annotation: @%s", name) + } + known[name] = struct{}{} t = p.advance() if !t.isString() { return p.errorf("expected string") @@ -624,29 +635,42 @@ func (p *parser) mult() (Node, error) { } func (p *parser) unary() (Node, error) { - opMap := map[string]func(Node) Node{ - "-": Negate, - "!": Not, - } - - var ops []func(Node) Node + var ops []bool for { opToken := p.peek() - op, ok := opMap[opToken.Text] - if !ok { + if opToken.Text != "-" && opToken.Text != "!" { break } p.advance() - ops = append(ops, op) + ops = append(ops, opToken.Text == "-") } - res, err := p.member() - if err != nil { - return res, err + var res Node + + // special case for max negative long + tok := p.peek() + if len(ops) > 0 && ops[len(ops)-1] && tok.isInt() { + p.advance() + i, err := strconv.ParseInt("-"+tok.Text, 10, 64) + if err != nil { + return Node{}, err + } + res = Long(types.Long(i)) + ops = ops[:len(ops)-1] + } else { + var err error + res, err = p.member() + if err != nil { + return res, err + } } for i := len(ops) - 1; i >= 0; i-- { - res = ops[i](res) + if ops[i] { + res = Negate(res) + } else { + res = Not(res) + } } return res, nil } @@ -750,14 +774,13 @@ func (p *parser) entityOrExtFun(ident string) (Node, error) { return Node{}, err } p.advance() - - i, ok := extMap[types.String(res.Type)] - if !ok { - return Node{}, p.errorf("`%v` is not a function", res.Type) - } - if i.IsMethod { - return Node{}, p.errorf("`%v` is a method, not a function", res.Type) - } + // i, ok := extMap[types.String(res.Type)] + // if !ok { + // return Node{}, p.errorf("`%v` is not a function", res.Type) + // } + // if i.IsMethod { + // return Node{}, p.errorf("`%v` is a method, not a function", res.Type) + // } return ExtensionCall(types.String(res.Type), args...), nil default: return Node{}, p.errorf("unexpected token") @@ -862,13 +885,13 @@ func (p *parser) access(lhs Node) (Node, bool, error) { case "containsAny": knownMethod = Node.ContainsAny default: - i, ok := extMap[types.String(methodName)] - if !ok { - return Node{}, false, p.errorf("not a valid method name: `%v`", methodName) - } - if !i.IsMethod { - return Node{}, false, p.errorf("`%v` is a function, not a method", methodName) - } + // i, ok := extMap[types.String(methodName)] + // if !ok { + // return Node{}, false, p.errorf("not a valid method name: `%v`", methodName) + // } + // if !i.IsMethod { + // return Node{}, false, p.errorf("`%v` is a function, not a method", methodName) + // } return newMethodCall(lhs, types.String(methodName), exprs...), true, nil } diff --git a/x/exp/ast/cedar_unmarshal_test.go b/x/exp/ast/cedar_unmarshal_test.go index 9fcf5c6c..47972c3d 100644 --- a/x/exp/ast/cedar_unmarshal_test.go +++ b/x/exp/ast/cedar_unmarshal_test.go @@ -136,13 +136,19 @@ when { !!true };`, "negate operator", `permit ( principal, action, resource ) when { -1 };`, - ast.Permit().When(ast.Negate(ast.Long(1))), + ast.Permit().When(ast.Long(-1)), + }, + { + "negate operator context", + `permit ( principal, action, resource ) +when { -context };`, + ast.Permit().When(ast.Negate(ast.Context())), }, { "mutliple negate operators", `permit ( principal, action, resource ) when { !--1 };`, - ast.Permit().When(ast.Not(ast.Negate(ast.Negate(ast.Long(1))))), + ast.Permit().When(ast.Not(ast.Negate(ast.Long(-1)))), }, { "variable member", @@ -389,7 +395,7 @@ when { 2 + 3 * 4 == 14 };`, "unary over mult precedence", `permit ( principal, action, resource ) when { -2 * 3 == -6 };`, - ast.Permit().When(ast.Negate(ast.Long(2)).Times(ast.Long(3)).Equals(ast.Negate(ast.Long(6)))), + ast.Permit().When(ast.Long(-2).Times(ast.Long(3)).Equals(ast.Long(-6))), }, { "member over unary precedence", @@ -401,7 +407,7 @@ when { -context.num };`, "parens over unary precedence", `permit ( principal, action, resource ) when { -(2 + 3) == -5 };`, - ast.Permit().When(ast.Negate(ast.Long(2).Plus(ast.Long(3))).Equals(ast.Negate(ast.Long(5)))), + ast.Permit().When(ast.Negate(ast.Long(2).Plus(ast.Long(3))).Equals(ast.Long(-5))), }, { "multiple parenthesized operations", diff --git a/x/exp/ast/node.go b/x/exp/ast/node.go index 39387a08..850562a3 100644 --- a/x/exp/ast/node.go +++ b/x/exp/ast/node.go @@ -231,7 +231,9 @@ type nodeTypeContainsAny struct { containsNode } -type primaryNode struct{ node } +type primaryNode struct{} + +func (n primaryNode) isNode() {} func (n primaryNode) precedenceLevel() nodePrecedenceLevel { return primaryPrecedence diff --git a/x/exp/ast/scope.go b/x/exp/ast/scope.go index 50d259c5..fdbfc565 100644 --- a/x/exp/ast/scope.go +++ b/x/exp/ast/scope.go @@ -94,10 +94,11 @@ type isScopeNode interface { } type scopeNode struct { - isScopeNode Variable nodeTypeVariable } +func (n scopeNode) isScope() {} + type scopeTypeAll struct { scopeNode } From fa29294372e1633e38295ab9b71384ff61d3333d Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 16:07:18 -0600 Subject: [PATCH 070/216] cedar: fix broken tests Addresses IDX-142 Signed-off-by: philhassey --- cedar_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cedar_test.go b/cedar_test.go index c6ddbb3a..44eef948 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -519,8 +519,8 @@ func TestIsAuthorized(t *testing.T) { Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, - DiagErr: 0, - ParseErr: true, + DiagErr: 1, + ParseErr: false, }, { Name: "permit-when-like", @@ -542,8 +542,8 @@ func TestIsAuthorized(t *testing.T) { Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, - DiagErr: 0, - ParseErr: true, + DiagErr: 1, + ParseErr: false, }, { Name: "permit-when-decimal", From 1bc1538d59ee40288efe643dd7799c237c399d84 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 16:12:12 -0600 Subject: [PATCH 071/216] x/exp/ast: move eval tests into ast Addresses IDX-142 Signed-off-by: philhassey --- match_test.go | 47 ------ eval_test.go => x/exp/ast/eval_test.go | 203 ++++++++++++------------- 2 files changed, 101 insertions(+), 149 deletions(-) delete mode 100644 match_test.go rename eval_test.go => x/exp/ast/eval_test.go (94%) diff --git a/match_test.go b/match_test.go deleted file mode 100644 index 783ed174..00000000 --- a/match_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package cedar - -import ( - "testing" - - "github.com/cedar-policy/cedar-go/testutil" - "github.com/cedar-policy/cedar-go/x/exp/parser" -) - -func TestMatch(t *testing.T) { - t.Parallel() - tests := []struct { - pattern string - target string - want bool - }{ - {`""`, "", true}, - {`""`, "hello", false}, - {`"*"`, "hello", true}, - {`"e"`, "hello", false}, - {`"*e"`, "hello", false}, - {`"*e*"`, "hello", true}, - {`"hello"`, "hello", true}, - {`"hello*"`, "hello", true}, - {`"*h*llo*"`, "hello", true}, - {`"h*e*o"`, "hello", true}, - {`"h*e**o"`, "hello", true}, - {`"h*z*o"`, "hello", false}, - - {`"\u{210d}*"`, "ℍello", true}, - {`"\u{210d}*"`, "Hello", false}, - - {`"\*\**\*\*"`, "**foo**", true}, - {`"\*\**\*\*"`, "**bar**", true}, - {`"\*\**\*\*"`, "*bar*", false}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.pattern+":"+tt.target, func(t *testing.T) { - t.Parallel() - pat, err := parser.NewPattern(tt.pattern) - testutil.OK(t, err) - got := match(pat, tt.target) - testutil.Equals(t, got, tt.want) - }) - } -} diff --git a/eval_test.go b/x/exp/ast/eval_test.go similarity index 94% rename from eval_test.go rename to x/exp/ast/eval_test.go index 3ecdd2b8..875f7008 100644 --- a/eval_test.go +++ b/x/exp/ast/eval_test.go @@ -1,4 +1,4 @@ -package cedar +package ast import ( "fmt" @@ -8,7 +8,6 @@ import ( "github.com/cedar-policy/cedar-go/testutil" "github.com/cedar-policy/cedar-go/types" - "github.com/cedar-policy/cedar-go/x/exp/parser" ) var errTest = fmt.Errorf("test error") @@ -35,7 +34,7 @@ func TestOrNode(t *testing.T) { t.Run(fmt.Sprintf("%v%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newOrNode(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -46,7 +45,7 @@ func TestOrNode(t *testing.T) { t.Parallel() n := newOrNode( newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(1))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, true) }) @@ -54,7 +53,7 @@ func TestOrNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Boolean(true)), errTest}, @@ -67,7 +66,7 @@ func TestOrNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newOrNode(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -90,7 +89,7 @@ func TestAndNode(t *testing.T) { t.Run(fmt.Sprintf("%v%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newAndEval(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -101,7 +100,7 @@ func TestAndNode(t *testing.T) { t.Parallel() n := newAndEval( newLiteralEval(types.Boolean(false)), newLiteralEval(types.Long(1))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, false) }) @@ -109,7 +108,7 @@ func TestAndNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Boolean(true)), errTest}, @@ -122,7 +121,7 @@ func TestAndNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newAndEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -143,7 +142,7 @@ func TestNotNode(t *testing.T) { t.Run(fmt.Sprintf("%v", tt.arg), func(t *testing.T) { t.Parallel() n := newNotEval(newLiteralEval(types.Boolean(tt.arg))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -153,7 +152,7 @@ func TestNotNode(t *testing.T) { { tests := []struct { name string - arg evaler + arg Evaler err error }{ {"Error", newErrorEval(errTest), errTest}, @@ -164,7 +163,7 @@ func TestNotNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newNotEval(tt.arg) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -344,14 +343,14 @@ func TestAddNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newAddEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertLongValue(t, v, 3) }) tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, @@ -372,7 +371,7 @@ func TestAddNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newAddEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -383,14 +382,14 @@ func TestSubtractNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newSubtractEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertLongValue(t, v, -1) }) tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, @@ -411,7 +410,7 @@ func TestSubtractNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newSubtractEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -422,14 +421,14 @@ func TestMultiplyNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newMultiplyEval(newLiteralEval(types.Long(-3)), newLiteralEval(types.Long(2))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertLongValue(t, v, -6) }) tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, @@ -450,7 +449,7 @@ func TestMultiplyNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newMultiplyEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -461,14 +460,14 @@ func TestNegateNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newNegateEval(newLiteralEval(types.Long(-3))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertLongValue(t, v, 3) }) tests := []struct { name string - arg evaler + arg Evaler err error }{ {"Error", newErrorEval(errTest), errTest}, @@ -480,7 +479,7 @@ func TestNegateNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newNegateEval(tt.arg) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -509,7 +508,7 @@ func TestLongLessThanNode(t *testing.T) { t.Parallel() n := newLongLessThanEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -518,7 +517,7 @@ func TestLongLessThanNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, @@ -531,7 +530,7 @@ func TestLongLessThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongLessThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -561,7 +560,7 @@ func TestLongLessThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongLessThanOrEqualEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -570,7 +569,7 @@ func TestLongLessThanOrEqualNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, @@ -583,7 +582,7 @@ func TestLongLessThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongLessThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -613,7 +612,7 @@ func TestLongGreaterThanNode(t *testing.T) { t.Parallel() n := newLongGreaterThanEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -622,7 +621,7 @@ func TestLongGreaterThanNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, @@ -635,7 +634,7 @@ func TestLongGreaterThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongGreaterThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -665,7 +664,7 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongGreaterThanOrEqualEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -674,7 +673,7 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, @@ -687,7 +686,7 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongGreaterThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -722,7 +721,7 @@ func TestDecimalLessThanNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalLessThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -731,7 +730,7 @@ func TestDecimalLessThanNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, @@ -744,7 +743,7 @@ func TestDecimalLessThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalLessThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -779,7 +778,7 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalLessThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -788,7 +787,7 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, @@ -801,7 +800,7 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalLessThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -836,7 +835,7 @@ func TestDecimalGreaterThanNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalGreaterThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -845,7 +844,7 @@ func TestDecimalGreaterThanNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, @@ -858,7 +857,7 @@ func TestDecimalGreaterThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalGreaterThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -893,7 +892,7 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalGreaterThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -902,7 +901,7 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, @@ -915,7 +914,7 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalGreaterThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(&evalContext{}) + _, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) }) } @@ -926,7 +925,7 @@ func TestIfThenElseNode(t *testing.T) { t.Parallel() tests := []struct { name string - if_, then, else_ evaler + if_, then, else_ Evaler result types.Value err error }{ @@ -946,7 +945,7 @@ func TestIfThenElseNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIfThenElseEval(tt.if_, tt.then, tt.else_) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) testutil.Equals(t, v, tt.result) }) @@ -957,7 +956,7 @@ func TestEqualNode(t *testing.T) { t.Parallel() tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler result types.Value err error }{ @@ -972,7 +971,7 @@ func TestEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newEqualEval(tt.lhs, tt.rhs) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -983,7 +982,7 @@ func TestNotEqualNode(t *testing.T) { t.Parallel() tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler result types.Value err error }{ @@ -998,7 +997,7 @@ func TestNotEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newNotEqualEval(tt.lhs, tt.rhs) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1009,14 +1008,14 @@ func TestSetLiteralNode(t *testing.T) { t.Parallel() tests := []struct { name string - elems []evaler + elems []Evaler result types.Value err error }{ - {"empty", []evaler{}, types.Set{}, nil}, - {"errorNode", []evaler{newErrorEval(errTest)}, types.ZeroValue(), errTest}, + {"empty", []Evaler{}, types.Set{}, nil}, + {"errorNode", []Evaler{newErrorEval(errTest)}, types.ZeroValue(), errTest}, {"nested", - []evaler{ + []Evaler{ newLiteralEval(types.Boolean(true)), newLiteralEval(types.Set{ types.Boolean(false), @@ -1039,7 +1038,7 @@ func TestSetLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newSetLiteralEval(tt.elems) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1051,7 +1050,7 @@ func TestContainsNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, @@ -1063,7 +1062,7 @@ func TestContainsNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsEval(tt.lhs, tt.rhs) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertZeroValue(t, v) }) @@ -1076,7 +1075,7 @@ func TestContainsNode(t *testing.T) { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler result bool }{ {"empty", newLiteralEval(empty), newLiteralEval(types.Boolean(true)), false}, @@ -1092,7 +1091,7 @@ func TestContainsNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsEval(tt.lhs, tt.rhs) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -1105,7 +1104,7 @@ func TestContainsAllNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Set{}), errTest}, @@ -1118,7 +1117,7 @@ func TestContainsAllNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAllEval(tt.lhs, tt.rhs) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertZeroValue(t, v) }) @@ -1132,7 +1131,7 @@ func TestContainsAllNode(t *testing.T) { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler result bool }{ {"emptyEmpty", newLiteralEval(empty), newLiteralEval(empty), true}, @@ -1146,7 +1145,7 @@ func TestContainsAllNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAllEval(tt.lhs, tt.rhs) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -1159,7 +1158,7 @@ func TestContainsAnyNode(t *testing.T) { { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Set{}), errTest}, @@ -1172,7 +1171,7 @@ func TestContainsAnyNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAnyEval(tt.lhs, tt.rhs) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertZeroValue(t, v) }) @@ -1187,7 +1186,7 @@ func TestContainsAnyNode(t *testing.T) { tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler result bool }{ {"emptyEmpty", newLiteralEval(empty), newLiteralEval(empty), false}, @@ -1203,7 +1202,7 @@ func TestContainsAnyNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAnyEval(tt.lhs, tt.rhs) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -1215,14 +1214,14 @@ func TestRecordLiteralNode(t *testing.T) { t.Parallel() tests := []struct { name string - elems map[string]evaler + elems map[string]Evaler result types.Value err error }{ - {"empty", map[string]evaler{}, types.Record{}, nil}, - {"errorNode", map[string]evaler{"foo": newErrorEval(errTest)}, types.ZeroValue(), errTest}, + {"empty", map[string]Evaler{}, types.Record{}, nil}, + {"errorNode", map[string]Evaler{"foo": newErrorEval(errTest)}, types.ZeroValue(), errTest}, {"ok", - map[string]evaler{ + map[string]Evaler{ "foo": newLiteralEval(types.Boolean(true)), "bar": newLiteralEval(types.String("baz")), }, types.Record{ @@ -1235,7 +1234,7 @@ func TestRecordLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newRecordLiteralEval(tt.elems) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1246,7 +1245,7 @@ func TestAttributeAccessNode(t *testing.T) { t.Parallel() tests := []struct { name string - object evaler + object Evaler attribute string result types.Value err error @@ -1284,7 +1283,7 @@ func TestAttributeAccessNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newAttributeAccessEval(tt.object, tt.attribute) - v, err := n.Eval(&evalContext{ + v, err := n.Eval(&EvalContext{ Entities: entitiesFromSlice([]Entity{ { UID: types.NewEntityUID("knownType", "knownID"), @@ -1302,7 +1301,7 @@ func TestHasNode(t *testing.T) { t.Parallel() tests := []struct { name string - record evaler + record Evaler attribute string result types.Value err error @@ -1340,7 +1339,7 @@ func TestHasNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newHasEval(tt.record, tt.attribute) - v, err := n.Eval(&evalContext{ + v, err := n.Eval(&EvalContext{ Entities: entitiesFromSlice([]Entity{ { UID: types.NewEntityUID("knownType", "knownID"), @@ -1358,7 +1357,7 @@ func TestLikeNode(t *testing.T) { t.Parallel() tests := []struct { name string - str evaler + str Evaler pattern string result types.Value err error @@ -1398,10 +1397,10 @@ func TestLikeNode(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - pat, err := parser.NewPattern(tt.pattern) + pat, err := types.ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) testutil.OK(t, err) n := newLikeEval(tt.str, pat) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1412,24 +1411,24 @@ func TestVariableNode(t *testing.T) { t.Parallel() tests := []struct { name string - context evalContext + context EvalContext variable variableName result types.Value }{ {"principal", - evalContext{Principal: types.String("foo")}, + EvalContext{Principal: types.String("foo")}, variableNamePrincipal, types.String("foo")}, {"action", - evalContext{Action: types.String("bar")}, + EvalContext{Action: types.String("bar")}, variableNameAction, types.String("bar")}, {"resource", - evalContext{Resource: types.String("baz")}, + EvalContext{Resource: types.String("baz")}, variableNameResource, types.String("baz")}, {"context", - evalContext{Context: types.String("frob")}, + EvalContext{Context: types.String("frob")}, variableNameContext, types.String("frob")}, } @@ -1584,7 +1583,7 @@ func TestIsNode(t *testing.T) { t.Parallel() tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler result types.Value err error }{ @@ -1599,7 +1598,7 @@ func TestIsNode(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - got, err := newIsEval(tt.lhs, tt.rhs).Eval(&evalContext{}) + got, err := newIsEval(tt.lhs, tt.rhs).Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, got, tt.result) }) @@ -1610,7 +1609,7 @@ func TestInNode(t *testing.T) { t.Parallel() tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler parents map[string][]string result types.Value err error @@ -1715,8 +1714,8 @@ func TestInNode(t *testing.T) { Parents: ps, } } - evalContext := evalContext{Entities: entities} - v, err := n.Eval(&evalContext) + EvalContext := EvalContext{Entities: entities} + v, err := n.Eval(&EvalContext) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1727,7 +1726,7 @@ func TestDecimalLiteralNode(t *testing.T) { t.Parallel() tests := []struct { name string - arg evaler + arg Evaler result types.Value err error }{ @@ -1741,7 +1740,7 @@ func TestDecimalLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalLiteralEval(tt.arg) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1754,7 +1753,7 @@ func TestIPLiteralNode(t *testing.T) { testutil.OK(t, err) tests := []struct { name string - arg evaler + arg Evaler result types.Value err error }{ @@ -1768,7 +1767,7 @@ func TestIPLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIPLiteralEval(tt.arg) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1785,7 +1784,7 @@ func TestIPTestNode(t *testing.T) { testutil.OK(t, err) tests := []struct { name string - lhs evaler + lhs Evaler rhs ipTestType result types.Value err error @@ -1806,7 +1805,7 @@ func TestIPTestNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIPTestEval(tt.lhs, tt.rhs) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1823,7 +1822,7 @@ func TestIPIsInRangeNode(t *testing.T) { testutil.OK(t, err) tests := []struct { name string - lhs, rhs evaler + lhs, rhs Evaler result types.Value err error }{ @@ -1844,7 +1843,7 @@ func TestIPIsInRangeNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIPIsInRangeEval(tt.lhs, tt.rhs) - v, err := n.Eval(&evalContext{}) + v, err := n.Eval(&EvalContext{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) From 17de84a5fec7b163f0ac561d347fde816a5a7bfc Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 16:19:25 -0600 Subject: [PATCH 072/216] cedar: eliminate dead code Addresses IDX-142 Signed-off-by: philhassey --- cedar.go | 4 + cedar_test.go | 40 +- eval.go | 1091 ------------------------------------------------ toeval.go | 307 -------------- toeval_test.go | 170 -------- 5 files changed, 24 insertions(+), 1588 deletions(-) delete mode 100644 eval.go delete mode 100644 toeval.go delete mode 100644 toeval_test.go diff --git a/cedar.go b/cedar.go index a5ab3653..58cd3ec9 100644 --- a/cedar.go +++ b/cedar.go @@ -160,6 +160,10 @@ type Request struct { Context types.Record `json:"context"` } +type evalContext = ast.EvalContext + +type evaler = ast.Evaler + // IsAuthorized uses the combination of the PolicySet and Entities to determine // if the given Request to determine Decision and Diagnostic. func (p PolicySet) IsAuthorized(entities Entities, req Request) (Decision, Diagnostic) { diff --git a/cedar_test.go b/cedar_test.go index 44eef948..350d77ae 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -863,26 +863,26 @@ func TestError(t *testing.T) { testutil.Equals(t, e.String(), "while evaluating policy `policy42`: bad error") } -func TestInvalidPolicy(t *testing.T) { - t.Parallel() - // This case is very fabricated, it can't really happen - ps := PolicySet{ - { - Effect: Forbid, - eval: newLiteralEval(types.Long(42)), - }, - } - ok, diag := ps.IsAuthorized(Entities{}, Request{}) - testutil.Equals(t, ok, Deny) - testutil.Equals(t, diag, Diagnostic{ - Errors: []Error{ - { - Policy: 0, - Message: "type error: expected bool, got long", - }, - }, - }) -} +// func TestInvalidPolicy(t *testing.T) { +// t.Parallel() +// // This case is very fabricated, it can't really happen +// ps := PolicySet{ +// { +// Effect: Forbid, +// eval: newLiteralEval(types.Long(42)), +// }, +// } +// ok, diag := ps.IsAuthorized(Entities{}, Request{}) +// testutil.Equals(t, ok, Deny) +// testutil.Equals(t, diag, Diagnostic{ +// Errors: []Error{ +// { +// Policy: 0, +// Message: "type error: expected bool, got long", +// }, +// }, +// }) +// } func TestCorpusRelated(t *testing.T) { t.Parallel() diff --git a/eval.go b/eval.go deleted file mode 100644 index eae3b660..00000000 --- a/eval.go +++ /dev/null @@ -1,1091 +0,0 @@ -package cedar - -import ( - "fmt" - - "github.com/cedar-policy/cedar-go/types" - "github.com/cedar-policy/cedar-go/x/exp/ast" - "github.com/cedar-policy/cedar-go/x/exp/parser" -) - -var errOverflow = fmt.Errorf("integer overflow") -var errUnknownMethod = fmt.Errorf("unknown method") -var errUnknownExtensionFunction = fmt.Errorf("function does not exist") -var errArity = fmt.Errorf("wrong number of arguments provided to extension function") -var errAttributeAccess = fmt.Errorf("does not have the attribute") -var errEntityNotExist = fmt.Errorf("does not exist") -var errUnspecifiedEntity = fmt.Errorf("unspecified entity") - -// type evalContext struct { -// Entities Entities -// Principal, Action, Resource types.Value -// Context types.Value -// } - -type evalContext = ast.EvalContext - -// type evaler interface { -// Eval(*evalContext) (types.Value, error) -// } - -type evaler = ast.Evaler - -func evalBool(n evaler, ctx *evalContext) (types.Boolean, error) { - v, err := n.Eval(ctx) - if err != nil { - return false, err - } - b, err := types.ValueToBool(v) - if err != nil { - return false, err - } - return b, nil -} - -func evalLong(n evaler, ctx *evalContext) (types.Long, error) { - v, err := n.Eval(ctx) - if err != nil { - return 0, err - } - l, err := types.ValueToLong(v) - if err != nil { - return 0, err - } - return l, nil -} - -func evalString(n evaler, ctx *evalContext) (types.String, error) { - v, err := n.Eval(ctx) - if err != nil { - return "", err - } - s, err := types.ValueToString(v) - if err != nil { - return "", err - } - return s, nil -} - -func evalSet(n evaler, ctx *evalContext) (types.Set, error) { - v, err := n.Eval(ctx) - if err != nil { - return nil, err - } - s, err := types.ValueToSet(v) - if err != nil { - return nil, err - } - return s, nil -} - -func evalEntity(n evaler, ctx *evalContext) (types.EntityUID, error) { - v, err := n.Eval(ctx) - if err != nil { - return types.EntityUID{}, err - } - e, err := types.ValueToEntity(v) - if err != nil { - return types.EntityUID{}, err - } - return e, nil -} - -func evalPath(n evaler, ctx *evalContext) (types.Path, error) { - v, err := n.Eval(ctx) - if err != nil { - return "", err - } - e, err := types.ValueToPath(v) - if err != nil { - return "", err - } - return e, nil -} - -func evalDecimal(n evaler, ctx *evalContext) (types.Decimal, error) { - v, err := n.Eval(ctx) - if err != nil { - return types.Decimal(0), err - } - d, err := types.ValueToDecimal(v) - if err != nil { - return types.Decimal(0), err - } - return d, nil -} - -func evalIP(n evaler, ctx *evalContext) (types.IPAddr, error) { - v, err := n.Eval(ctx) - if err != nil { - return types.IPAddr{}, err - } - i, err := types.ValueToIP(v) - if err != nil { - return types.IPAddr{}, err - } - return i, nil -} - -// errorEval -type errorEval struct { - err error -} - -func newErrorEval(err error) *errorEval { - return &errorEval{ - err: err, - } -} - -func (n *errorEval) Eval(_ *evalContext) (types.Value, error) { - return types.ZeroValue(), n.err -} - -// literalEval -type literalEval struct { - value types.Value -} - -func newLiteralEval(value types.Value) *literalEval { - return &literalEval{value: value} -} - -func (n *literalEval) Eval(_ *evalContext) (types.Value, error) { - return n.value, nil -} - -// orEval -type orEval struct { - lhs evaler - rhs evaler -} - -func newOrNode(lhs evaler, rhs evaler) *orEval { - return &orEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *orEval) Eval(ctx *evalContext) (types.Value, error) { - v, err := n.lhs.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - b, err := types.ValueToBool(v) - if err != nil { - return types.ZeroValue(), err - } - if b { - return v, nil - } - v, err = n.rhs.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - _, err = types.ValueToBool(v) - if err != nil { - return types.ZeroValue(), err - } - return v, nil -} - -// andEval -type andEval struct { - lhs evaler - rhs evaler -} - -func newAndEval(lhs evaler, rhs evaler) *andEval { - return &andEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *andEval) Eval(ctx *evalContext) (types.Value, error) { - v, err := n.lhs.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - b, err := types.ValueToBool(v) - if err != nil { - return types.ZeroValue(), err - } - if !b { - return v, nil - } - v, err = n.rhs.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - _, err = types.ValueToBool(v) - if err != nil { - return types.ZeroValue(), err - } - return v, nil -} - -// notEval -type notEval struct { - inner evaler -} - -func newNotEval(inner evaler) *notEval { - return ¬Eval{ - inner: inner, - } -} - -func (n *notEval) Eval(ctx *evalContext) (types.Value, error) { - v, err := n.inner.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - b, err := types.ValueToBool(v) - if err != nil { - return types.ZeroValue(), err - } - return !b, nil -} - -// Overflow -// The Go spec specifies that overflow results in defined and deterministic -// behavior (https://go.dev/ref/spec#Integer_overflow), so we can go ahead and -// do the operations and then check for overflow ex post facto. - -func checkedAddI64(lhs, rhs types.Long) (types.Long, bool) { - result := lhs + rhs - if (result > lhs) != (rhs > 0) { - return result, false - } - return result, true -} - -func checkedSubI64(lhs, rhs types.Long) (types.Long, bool) { - result := lhs - rhs - if (result > lhs) != (rhs < 0) { - return result, false - } - return result, true -} - -func checkedMulI64(lhs, rhs types.Long) (types.Long, bool) { - if lhs == 0 || rhs == 0 { - return 0, true - } - result := lhs * rhs - if (result < 0) != ((lhs < 0) != (rhs < 0)) { - // If the result doesn't have the correct sign, then we overflowed. - return result, false - } - if result/lhs != rhs { - // If division doesn't yield the original value, then we overflowed. - return result, false - } - return result, true -} - -func checkedNegI64(a types.Long) (types.Long, bool) { - if a == -9_223_372_036_854_775_808 { - return 0, false - } - return -a, true -} - -// addEval -type addEval struct { - lhs evaler - rhs evaler -} - -func newAddEval(lhs evaler, rhs evaler) *addEval { - return &addEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *addEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalLong(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalLong(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - res, ok := checkedAddI64(lhs, rhs) - if !ok { - return types.ZeroValue(), fmt.Errorf("%w while attempting to add `%d` with `%d`", errOverflow, lhs, rhs) - } - return res, nil -} - -// subtractEval -type subtractEval struct { - lhs evaler - rhs evaler -} - -func newSubtractEval(lhs evaler, rhs evaler) *subtractEval { - return &subtractEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *subtractEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalLong(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalLong(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - res, ok := checkedSubI64(lhs, rhs) - if !ok { - return types.ZeroValue(), fmt.Errorf("%w while attempting to subtract `%d` from `%d`", errOverflow, rhs, lhs) - } - return res, nil -} - -// multiplyEval -type multiplyEval struct { - lhs evaler - rhs evaler -} - -func newMultiplyEval(lhs evaler, rhs evaler) *multiplyEval { - return &multiplyEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *multiplyEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalLong(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalLong(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - res, ok := checkedMulI64(lhs, rhs) - if !ok { - return types.ZeroValue(), fmt.Errorf("%w while attempting to multiply `%d` by `%d`", errOverflow, lhs, rhs) - } - return res, nil -} - -// negateEval -type negateEval struct { - inner evaler -} - -func newNegateEval(inner evaler) *negateEval { - return &negateEval{ - inner: inner, - } -} - -func (n *negateEval) Eval(ctx *evalContext) (types.Value, error) { - inner, err := evalLong(n.inner, ctx) - if err != nil { - return types.ZeroValue(), err - } - res, ok := checkedNegI64(inner) - if !ok { - return types.ZeroValue(), fmt.Errorf("%w while attempting to negate `%d`", errOverflow, inner) - } - return res, nil -} - -// longLessThanEval -type longLessThanEval struct { - lhs evaler - rhs evaler -} - -func newLongLessThanEval(lhs evaler, rhs evaler) *longLessThanEval { - return &longLessThanEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *longLessThanEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalLong(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalLong(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(lhs < rhs), nil -} - -// longLessThanOrEqualEval -type longLessThanOrEqualEval struct { - lhs evaler - rhs evaler -} - -func newLongLessThanOrEqualEval(lhs evaler, rhs evaler) *longLessThanOrEqualEval { - return &longLessThanOrEqualEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *longLessThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalLong(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalLong(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(lhs <= rhs), nil -} - -// longGreaterThanEval -type longGreaterThanEval struct { - lhs evaler - rhs evaler -} - -func newLongGreaterThanEval(lhs evaler, rhs evaler) *longGreaterThanEval { - return &longGreaterThanEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *longGreaterThanEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalLong(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalLong(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(lhs > rhs), nil -} - -// longGreaterThanOrEqualEval -type longGreaterThanOrEqualEval struct { - lhs evaler - rhs evaler -} - -func newLongGreaterThanOrEqualEval(lhs evaler, rhs evaler) *longGreaterThanOrEqualEval { - return &longGreaterThanOrEqualEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *longGreaterThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalLong(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalLong(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(lhs >= rhs), nil -} - -// decimalLessThanEval -type decimalLessThanEval struct { - lhs evaler - rhs evaler -} - -func newDecimalLessThanEval(lhs evaler, rhs evaler) *decimalLessThanEval { - return &decimalLessThanEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *decimalLessThanEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalDecimal(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalDecimal(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(lhs < rhs), nil -} - -// decimalLessThanOrEqualEval -type decimalLessThanOrEqualEval struct { - lhs evaler - rhs evaler -} - -func newDecimalLessThanOrEqualEval(lhs evaler, rhs evaler) *decimalLessThanOrEqualEval { - return &decimalLessThanOrEqualEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *decimalLessThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalDecimal(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalDecimal(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(lhs <= rhs), nil -} - -// decimalGreaterThanEval -type decimalGreaterThanEval struct { - lhs evaler - rhs evaler -} - -func newDecimalGreaterThanEval(lhs evaler, rhs evaler) *decimalGreaterThanEval { - return &decimalGreaterThanEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *decimalGreaterThanEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalDecimal(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalDecimal(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(lhs > rhs), nil -} - -// decimalGreaterThanOrEqualEval -type decimalGreaterThanOrEqualEval struct { - lhs evaler - rhs evaler -} - -func newDecimalGreaterThanOrEqualEval(lhs evaler, rhs evaler) *decimalGreaterThanOrEqualEval { - return &decimalGreaterThanOrEqualEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *decimalGreaterThanOrEqualEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalDecimal(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalDecimal(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(lhs >= rhs), nil -} - -// ifThenElseEval -type ifThenElseEval struct { - if_ evaler - then evaler - else_ evaler -} - -func newIfThenElseEval(if_, then, else_ evaler) *ifThenElseEval { - return &ifThenElseEval{ - if_: if_, - then: then, - else_: else_, - } -} - -func (n *ifThenElseEval) Eval(ctx *evalContext) (types.Value, error) { - cond, err := evalBool(n.if_, ctx) - if err != nil { - return types.ZeroValue(), err - } - if cond { - return n.then.Eval(ctx) - } - return n.else_.Eval(ctx) -} - -// notEqualNode -type equalEval struct { - lhs, rhs evaler -} - -func newEqualEval(lhs, rhs evaler) *equalEval { - return &equalEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *equalEval) Eval(ctx *evalContext) (types.Value, error) { - lv, err := n.lhs.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - rv, err := n.rhs.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(lv.Equal(rv)), nil -} - -// notEqualEval -type notEqualEval struct { - lhs, rhs evaler -} - -func newNotEqualEval(lhs, rhs evaler) *notEqualEval { - return ¬EqualEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *notEqualEval) Eval(ctx *evalContext) (types.Value, error) { - lv, err := n.lhs.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - rv, err := n.rhs.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(!lv.Equal(rv)), nil -} - -// setLiteralEval -type setLiteralEval struct { - elements []evaler -} - -func newSetLiteralEval(elements []evaler) *setLiteralEval { - return &setLiteralEval{elements: elements} -} - -func (n *setLiteralEval) Eval(ctx *evalContext) (types.Value, error) { - var vals types.Set - for _, e := range n.elements { - v, err := e.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - vals = append(vals, v) - } - return vals, nil -} - -// containsEval -type containsEval struct { - lhs, rhs evaler -} - -func newContainsEval(lhs, rhs evaler) *containsEval { - return &containsEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *containsEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalSet(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := n.rhs.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(lhs.Contains(rhs)), nil -} - -// containsAllEval -type containsAllEval struct { - lhs, rhs evaler -} - -func newContainsAllEval(lhs, rhs evaler) *containsAllEval { - return &containsAllEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *containsAllEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalSet(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalSet(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - result := true - for _, e := range rhs { - if !lhs.Contains(e) { - result = false - break - } - } - return types.Boolean(result), nil -} - -// containsAnyEval -type containsAnyEval struct { - lhs, rhs evaler -} - -func newContainsAnyEval(lhs, rhs evaler) *containsAnyEval { - return &containsAnyEval{ - lhs: lhs, - rhs: rhs, - } -} - -func (n *containsAnyEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalSet(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalSet(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - result := false - for _, e := range rhs { - if lhs.Contains(e) { - result = true - break - } - } - return types.Boolean(result), nil -} - -// recordLiteralEval -type recordLiteralEval struct { - elements map[string]evaler -} - -func newRecordLiteralEval(elements map[string]evaler) *recordLiteralEval { - return &recordLiteralEval{elements: elements} -} - -func (n *recordLiteralEval) Eval(ctx *evalContext) (types.Value, error) { - vals := types.Record{} - for k, en := range n.elements { - v, err := en.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - vals[k] = v - } - return vals, nil -} - -// attributeAccessEval -type attributeAccessEval struct { - object evaler - attribute string -} - -func newAttributeAccessEval(record evaler, attribute string) *attributeAccessEval { - return &attributeAccessEval{object: record, attribute: attribute} -} - -func (n *attributeAccessEval) Eval(ctx *evalContext) (types.Value, error) { - v, err := n.object.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - var record types.Record - key := "record" - switch vv := v.(type) { - case types.EntityUID: - key = "`" + vv.String() + "`" - var unspecified types.EntityUID - if vv == unspecified { - return types.ZeroValue(), fmt.Errorf("cannot access attribute `%s` of %w", n.attribute, errUnspecifiedEntity) - } - rec, ok := ctx.Entities[vv] - if !ok { - return types.ZeroValue(), fmt.Errorf("entity `%v` %w", vv.String(), errEntityNotExist) - } else { - record = rec.Attributes - } - case types.Record: - record = vv - default: - return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", types.ErrType, v.TypeName()) - } - val, ok := record[n.attribute] - if !ok { - return types.ZeroValue(), fmt.Errorf("%s %w `%s`", key, errAttributeAccess, n.attribute) - } - return val, nil -} - -// hasEval -type hasEval struct { - object evaler - attribute string -} - -func newHasEval(record evaler, attribute string) *hasEval { - return &hasEval{object: record, attribute: attribute} -} - -func (n *hasEval) Eval(ctx *evalContext) (types.Value, error) { - v, err := n.object.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - var record types.Record - switch vv := v.(type) { - case types.EntityUID: - rec, ok := ctx.Entities[vv] - if !ok { - record = types.Record{} - } else { - record = rec.Attributes - } - case types.Record: - record = vv - default: - return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", types.ErrType, v.TypeName()) - } - _, ok := record[n.attribute] - return types.Boolean(ok), nil -} - -// likeEval -type likeEval struct { - lhs evaler - pattern parser.Pattern -} - -func newLikeEval(lhs evaler, pattern parser.Pattern) *likeEval { - return &likeEval{lhs: lhs, pattern: pattern} -} - -func (l *likeEval) Eval(ctx *evalContext) (types.Value, error) { - v, err := evalString(l.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(match(l.pattern, string(v))), nil -} - -type variableName func(ctx *evalContext) types.Value - -func variableNamePrincipal(ctx *evalContext) types.Value { return ctx.Principal } -func variableNameAction(ctx *evalContext) types.Value { return ctx.Action } -func variableNameResource(ctx *evalContext) types.Value { return ctx.Resource } -func variableNameContext(ctx *evalContext) types.Value { return ctx.Context } - -// variableEval -type variableEval struct { - variableName variableName -} - -func newVariableEval(variableName variableName) *variableEval { - return &variableEval{variableName: variableName} -} - -func (n *variableEval) Eval(ctx *evalContext) (types.Value, error) { - return n.variableName(ctx), nil -} - -// inEval -type inEval struct { - lhs, rhs evaler -} - -func newInEval(lhs, rhs evaler) *inEval { - return &inEval{lhs: lhs, rhs: rhs} -} - -func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entities Entities) bool { - checked := map[types.EntityUID]struct{}{} - toCheck := []types.EntityUID{entity} - for len(toCheck) > 0 { - var candidate types.EntityUID - candidate, toCheck = toCheck[len(toCheck)-1], toCheck[:len(toCheck)-1] - if _, ok := checked[candidate]; ok { - continue - } - if _, ok := query[candidate]; ok { - return true - } - toCheck = append(toCheck, entities[candidate].Parents...) - checked[candidate] = struct{}{} - } - return false -} - -func (n *inEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalEntity(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - - rhs, err := n.rhs.Eval(ctx) - if err != nil { - return types.ZeroValue(), err - } - - query := map[types.EntityUID]struct{}{} - switch rhsv := rhs.(type) { - case types.EntityUID: - query[rhsv] = struct{}{} - case types.Set: - for _, rhv := range rhsv { - e, err := types.ValueToEntity(rhv) - if err != nil { - return types.ZeroValue(), err - } - query[e] = struct{}{} - } - default: - return types.ZeroValue(), fmt.Errorf( - "%w: expected one of [set, (entity of type `any_entity_type`)], got %v", types.ErrType, rhs.TypeName()) - } - return types.Boolean(entityIn(lhs, query, ctx.Entities)), nil -} - -// isEval -type isEval struct { - lhs, rhs evaler -} - -func newIsEval(lhs, rhs evaler) *isEval { - return &isEval{lhs: lhs, rhs: rhs} -} - -func (n *isEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalEntity(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - - rhs, err := evalPath(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - - return types.Boolean(types.Path(lhs.Type) == rhs), nil -} - -// decimalLiteralEval -type decimalLiteralEval struct { - literal evaler -} - -func newDecimalLiteralEval(literal evaler) *decimalLiteralEval { - return &decimalLiteralEval{literal: literal} -} - -func (n *decimalLiteralEval) Eval(ctx *evalContext) (types.Value, error) { - literal, err := evalString(n.literal, ctx) - if err != nil { - return types.ZeroValue(), err - } - - d, err := types.ParseDecimal(string(literal)) - if err != nil { - return types.ZeroValue(), err - } - - return d, nil -} - -type ipLiteralEval struct { - literal evaler -} - -func newIPLiteralEval(literal evaler) *ipLiteralEval { - return &ipLiteralEval{literal: literal} -} - -func (n *ipLiteralEval) Eval(ctx *evalContext) (types.Value, error) { - literal, err := evalString(n.literal, ctx) - if err != nil { - return types.ZeroValue(), err - } - - i, err := types.ParseIPAddr(string(literal)) - if err != nil { - return types.ZeroValue(), err - } - - return i, nil -} - -type ipTestType func(v types.IPAddr) bool - -func ipTestIPv4(v types.IPAddr) bool { return v.IsIPv4() } -func ipTestIPv6(v types.IPAddr) bool { return v.IsIPv6() } -func ipTestLoopback(v types.IPAddr) bool { return v.IsLoopback() } -func ipTestMulticast(v types.IPAddr) bool { return v.IsMulticast() } - -// ipTestEval -type ipTestEval struct { - object evaler - test ipTestType -} - -func newIPTestEval(object evaler, test ipTestType) *ipTestEval { - return &ipTestEval{object: object, test: test} -} - -func (n *ipTestEval) Eval(ctx *evalContext) (types.Value, error) { - i, err := evalIP(n.object, ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(n.test(i)), nil -} - -// ipIsInRangeEval - -type ipIsInRangeEval struct { - lhs, rhs evaler -} - -func newIPIsInRangeEval(lhs, rhs evaler) *ipIsInRangeEval { - return &ipIsInRangeEval{lhs: lhs, rhs: rhs} -} - -func (n *ipIsInRangeEval) Eval(ctx *evalContext) (types.Value, error) { - lhs, err := evalIP(n.lhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - rhs, err := evalIP(n.rhs, ctx) - if err != nil { - return types.ZeroValue(), err - } - return types.Boolean(rhs.Contains(lhs)), nil -} diff --git a/toeval.go b/toeval.go deleted file mode 100644 index 9988178e..00000000 --- a/toeval.go +++ /dev/null @@ -1,307 +0,0 @@ -package cedar - -import ( - "fmt" - "strings" - - "github.com/cedar-policy/cedar-go/types" - "github.com/cedar-policy/cedar-go/x/exp/parser" -) - -func toEval(n any) evaler { - switch v := n.(type) { - case parser.Policy: - res := toEval(v.Principal) - res = newAndEval(res, toEval(v.Action)) - res = newAndEval(res, toEval(v.Resource)) - for _, c := range v.Conditions { - res = newAndEval(res, toEval(c)) - } - return res - case parser.Principal: - var res evaler - switch v.Type { - case parser.MatchAny: - res = newLiteralEval(types.Boolean(true)) - case parser.MatchEquals: - res = newEqualEval(newVariableEval(variableNamePrincipal), toEval(v.Entity)) - case parser.MatchIn: - res = newInEval(newVariableEval(variableNamePrincipal), toEval(v.Entity)) - case parser.MatchIs: - res = newIsEval(newVariableEval(variableNamePrincipal), toEval(v.Path)) - case parser.MatchIsIn: - lhs := newIsEval(newVariableEval(variableNamePrincipal), toEval(v.Path)) - rhs := newInEval(newVariableEval(variableNamePrincipal), toEval(v.Entity)) - res = newAndEval(lhs, rhs) - } - return res - case parser.Action: - var res evaler - switch v.Type { - case parser.MatchAny: - res = newLiteralEval(types.Boolean(true)) - case parser.MatchEquals: - res = newEqualEval(newVariableEval(variableNameAction), toEval(v.Entities[0])) - case parser.MatchIn: - res = newInEval(newVariableEval(variableNameAction), toEval(v.Entities[0])) - case parser.MatchInList: - var vals []evaler - for _, e := range v.Entities { - vals = append(vals, toEval(e)) - } - sl := newSetLiteralEval(vals) - res = newInEval(newVariableEval(variableNameAction), sl) - } - return res - case parser.Resource: - var res evaler - switch v.Type { - case parser.MatchAny: - res = newLiteralEval(types.Boolean(true)) - case parser.MatchEquals: - res = newEqualEval(newVariableEval(variableNameResource), toEval(v.Entity)) - case parser.MatchIn: - res = newInEval(newVariableEval(variableNameResource), toEval(v.Entity)) - case parser.MatchIs: - res = newIsEval(newVariableEval(variableNameResource), toEval(v.Path)) - case parser.MatchIsIn: - lhs := newIsEval(newVariableEval(variableNameResource), toEval(v.Path)) - rhs := newInEval(newVariableEval(variableNameResource), toEval(v.Entity)) - res = newAndEval(lhs, rhs) - } - return res - case parser.Entity: - return newLiteralEval(types.EntityValueFromSlice(v.Path)) - case parser.Path: - return newLiteralEval(types.PathFromSlice(v.Path)) - case parser.Condition: - var res evaler - switch v.Type { - case parser.ConditionWhen: - res = toEval(v.Expression) - case parser.ConditionUnless: - res = newNotEval(toEval(v.Expression)) - } - return res - case parser.Expression: - var res evaler - switch v.Type { - case parser.ExpressionOr: - res = toEval(v.Or) - case parser.ExpressionIf: - res = toEval(*v.If) - } - return res - case parser.If: - return newIfThenElseEval(toEval(v.If), toEval(v.Then), toEval(v.Else)) - case parser.Or: - res := toEval(v.Ands[len(v.Ands)-1]) - for i := len(v.Ands) - 2; i >= 0; i-- { - res = newOrNode(toEval(v.Ands[i]), res) - } - return res - case parser.And: - res := toEval(v.Relations[len(v.Relations)-1]) - for i := len(v.Relations) - 2; i >= 0; i-- { - res = newAndEval(toEval(v.Relations[i]), res) - } - return res - case parser.Relation: - lhs := toEval(v.Add) - switch v.Type { - case parser.RelationNone: - return lhs - case parser.RelationRelOp: - rhs := toEval(v.RelOpRhs) - switch v.RelOp { - case parser.RelOpLt: - return newLongLessThanEval(lhs, rhs) - case parser.RelOpLe: - return newLongLessThanOrEqualEval(lhs, rhs) - case parser.RelOpGe: - return newLongGreaterThanOrEqualEval(lhs, rhs) - case parser.RelOpGt: - return newLongGreaterThanEval(lhs, rhs) - case parser.RelOpNe: - return newNotEqualEval(lhs, rhs) - case parser.RelOpEq: - return newEqualEval(lhs, rhs) - case parser.RelOpIn: - return newInEval(lhs, rhs) - default: - panic("missing RelOp case") - } - case parser.RelationHasIdent, parser.RelationHasLiteral: - return newHasEval(lhs, v.Str) - case parser.RelationLike: - return newLikeEval(lhs, v.Pat) - case parser.RelationIs: - return newIsEval(lhs, toEval(v.Path)) - case parser.RelationIsIn: - lhs2 := newIsEval(lhs, toEval(v.Path)) - rhs2 := newInEval(lhs, toEval(v.Entity)) - return newAndEval(lhs2, rhs2) - default: - panic("missing RelationType case") - } - case parser.Add: - res := toEval(v.Mults[len(v.Mults)-1]) - for i := len(v.AddOps) - 1; i >= 0; i-- { - switch v.AddOps[i] { - case parser.AddOpAdd: - res = newAddEval(toEval(v.Mults[i]), res) - case parser.AddOpSub: - res = newSubtractEval(toEval(v.Mults[i]), res) - default: - panic("unknown AddOp") - } - } - return res - case parser.Mult: - res := toEval(v.Unaries[len(v.Unaries)-1]) - for i := len(v.Unaries) - 2; i >= 0; i-- { - res = newMultiplyEval(toEval(v.Unaries[i]), res) - } - return res - - case parser.Unary: - res := toEval(v.Member) - for i := len(v.Ops) - 1; i >= 0; i-- { - switch v.Ops[i] { - case parser.UnaryOpMinus: - res = newNegateEval(res) - case parser.UnaryOpNot: - res = newNotEval(res) - } - } - return res - case parser.Member: - res := toEval(v.Primary) - for _, a := range v.Accesses { - res = toAccess(a, res) - } - return res - case parser.Primary: - switch v.Type { - case parser.PrimaryLiteral: - return toEval(v.Literal) - case parser.PrimaryVar: - return toEval(v.Var) - case parser.PrimaryEntity: - return toEval(v.Entity) - case parser.PrimaryExtFun: - return toEval(v.ExtFun) - case parser.PrimaryExpr: - return toEval(v.Expression) - case parser.PrimaryExprList: - var nodes []evaler - for _, e := range v.Expressions { - nodes = append(nodes, toEval(e)) - } - return newSetLiteralEval(nodes) - case parser.PrimaryRecInits: - nodes := map[string]evaler{} - for _, r := range v.RecInits { - nodes[r.Key] = toEval(r.Value) - } - return newRecordLiteralEval(nodes) - default: - panic("missing PrimaryType case") - } - case parser.Literal: - switch v.Type { - case parser.LiteralBool: - return newLiteralEval(types.Boolean(v.Bool)) - case parser.LiteralInt: - return newLiteralEval(types.Long(v.Long)) - case parser.LiteralString: - return newLiteralEval(types.String(v.Str)) - default: - panic("missing LiteralType case") - } - case parser.Var: - switch v.Type { - case parser.VarPrincipal: - return newVariableEval(variableNamePrincipal) - case parser.VarAction: - return newVariableEval(variableNameAction) - case parser.VarResource: - return newVariableEval(variableNameResource) - case parser.VarContext: - return newVariableEval(variableNameContext) - default: - panic("missing VarType case") - } - case parser.ExtFun: - funName := strings.Join(v.Path, "::") - switch funName { - case "decimal": - if len(v.Expressions) != 1 { - return newErrorEval(fmt.Errorf("%w: %s takes 1 parameter", errArity, funName)) - } - return newDecimalLiteralEval(toEval(v.Expressions[0])) - case "ip": - if len(v.Expressions) != 1 { - return newErrorEval(fmt.Errorf("%w: %s takes 1 parameter", errArity, funName)) - } - return newIPLiteralEval(toEval(v.Expressions[0])) - default: - return newErrorEval(fmt.Errorf("%w: %s", errUnknownExtensionFunction, funName)) - } - - default: - panic(fmt.Sprintf("unknown node type %T", v)) - } -} - -func toAccess(v parser.Access, lhs evaler) evaler { - switch v.Type { - case parser.AccessField: - return newAttributeAccessEval(lhs, v.Name) - case parser.AccessCall: - var ctor1 func(evaler, evaler) evaler - var ctor0 func(evaler) evaler - switch v.Name { - case "contains": - ctor1 = func(lhs, rhs evaler) evaler { return newContainsEval(lhs, rhs) } - case "containsAll": - ctor1 = func(lhs, rhs evaler) evaler { return newContainsAllEval(lhs, rhs) } - case "containsAny": - ctor1 = func(lhs, rhs evaler) evaler { return newContainsAnyEval(lhs, rhs) } - case "lessThan": - ctor1 = func(lhs, rhs evaler) evaler { return newDecimalLessThanEval(lhs, rhs) } - case "lessThanOrEqual": - ctor1 = func(lhs, rhs evaler) evaler { return newDecimalLessThanOrEqualEval(lhs, rhs) } - case "greaterThan": - ctor1 = func(lhs, rhs evaler) evaler { return newDecimalGreaterThanEval(lhs, rhs) } - case "greaterThanOrEqual": - ctor1 = func(lhs, rhs evaler) evaler { return newDecimalGreaterThanOrEqualEval(lhs, rhs) } - case "isIpv4": - ctor0 = func(lhs evaler) evaler { return newIPTestEval(lhs, ipTestIPv4) } - case "isIpv6": - ctor0 = func(lhs evaler) evaler { return newIPTestEval(lhs, ipTestIPv6) } - case "isLoopback": - ctor0 = func(lhs evaler) evaler { return newIPTestEval(lhs, ipTestLoopback) } - case "isMulticast": - ctor0 = func(lhs evaler) evaler { return newIPTestEval(lhs, ipTestMulticast) } - case "isInRange": - ctor1 = func(lhs, rhs evaler) evaler { return newIPIsInRangeEval(lhs, rhs) } - default: - return newErrorEval(fmt.Errorf("%w: %s", errUnknownMethod, v.Name)) - } - if ctor0 != nil { - if len(v.Expressions) != 0 { - return newErrorEval(fmt.Errorf("%w `%s`: expected 1, got %d", errArity, v.Name, len(v.Expressions)+1)) - } - return ctor0(lhs) - } - if len(v.Expressions) != 1 { - return newErrorEval(fmt.Errorf("%w `%s`: expected 2, got %d", errArity, v.Name, len(v.Expressions)+1)) - } - return ctor1(lhs, toEval(v.Expressions[0])) - case parser.AccessIndex: - return newAttributeAccessEval(lhs, v.Name) - default: - panic("missing AccessType case") - } -} diff --git a/toeval_test.go b/toeval_test.go deleted file mode 100644 index 037e58ef..00000000 --- a/toeval_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package cedar - -import ( - "fmt" - "strings" - "testing" - - "github.com/cedar-policy/cedar-go/testutil" - "github.com/cedar-policy/cedar-go/types" - "github.com/cedar-policy/cedar-go/x/exp/parser" -) - -func safeDoErr(f func() error) (err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("%v", r) - } - }() - return f() -} - -func TestToEval(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in any - out evaler - panic string - }{ - {"happy", parser.Entity{ - Path: []string{"Action", "test"}, - }, - newLiteralEval(types.EntityValueFromSlice([]string{"Action", "test"})), ""}, - {"missingRelOp", parser.Relation{ - Add: parser.Add{ - Mults: []parser.Mult{ - { - Unaries: []parser.Unary{ - { - Member: parser.Member{ - Primary: parser.Primary{ - Type: parser.PrimaryEntity, - Entity: parser.Entity{ - Path: []string{"Action", "test"}, - }, - }, - }, - }, - }, - }, - }, - }, - RelOpRhs: parser.Add{ - Mults: []parser.Mult{ - { - Unaries: []parser.Unary{ - { - Member: parser.Member{ - Primary: parser.Primary{ - Type: parser.PrimaryEntity, - Entity: parser.Entity{ - Path: []string{"Action", "test"}, - }, - }, - }, - }, - }, - }, - }, - }, - Type: parser.RelationRelOp, - RelOp: "invalid", - }, - nil, "missing RelOp case"}, - - {"missingRelationType", parser.Relation{ - Add: parser.Add{ - Mults: []parser.Mult{ - { - Unaries: []parser.Unary{ - { - Member: parser.Member{ - Primary: parser.Primary{ - Type: parser.PrimaryEntity, - Entity: parser.Entity{ - Path: []string{"Action", "test"}, - }, - }, - }, - }, - }, - }, - }, - }, - Type: "invalid", - }, - nil, "missing RelationType case"}, - - {"unknownAddOp", parser.Add{ - Mults: []parser.Mult{ - { - Unaries: []parser.Unary{ - { - Member: parser.Member{ - Primary: parser.Primary{ - Type: parser.PrimaryEntity, - Entity: parser.Entity{ - Path: []string{"Action", "test"}, - }, - }, - }, - }, - }, - }, - }, - AddOps: []parser.AddOp{"invalid"}, - }, - nil, "unknown AddOp"}, - - {"missingPrimaryType", parser.Primary{ - Type: parser.PrimaryType(-42), - }, - nil, "missing PrimaryType case"}, - - {"missingLiteralType", parser.Literal{ - Type: parser.LiteralType(-42), - }, - nil, "missing LiteralType case"}, - - {"missingVarType", parser.Var{ - Type: "invalid", - }, - nil, "missing VarType case"}, - - {"unknownNodeType", true, - nil, "unknown node type bool"}, - - {"missingAccessType", parser.Member{ - Primary: parser.Primary{ - Type: parser.PrimaryEntity, - Entity: parser.Entity{ - Path: []string{"Action", "test"}, - }, - }, - Accesses: []parser.Access{ - { - Type: parser.AccessType(-42), - }, - }, - }, - nil, "missing AccessType case"}, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - var out evaler - err := safeDoErr(func() error { - out = toEval(tt.in) - return nil - }) - testutil.Equals(t, out, tt.out) - testutil.Equals(t, err != nil, tt.panic != "") - if tt.panic != "" { - testutil.FatalIf(t, !strings.Contains(err.Error(), tt.panic), "panic got %v want %v", err.Error(), tt.panic) - } - }) - } -} From ea01fa3617f03d133812a11bd4bdea07c9288c13 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 16:22:19 -0600 Subject: [PATCH 073/216] internal/testutil: move test util into internal package Addresses IDX-142 Signed-off-by: philhassey --- cedar_test.go | 2 +- internal/rust_test.go | 2 +- {testutil => internal/testutil}/testutil.go | 0 types/json_test.go | 2 +- types/patttern_test.go | 2 +- types/testutil.go | 4 +++- types/value_test.go | 2 +- x/exp/ast/cedar_parse_test.go | 2 +- x/exp/ast/cedar_tokenize_test.go | 2 +- x/exp/ast/cedar_unmarshal_test.go | 2 +- x/exp/ast/eval_test.go | 2 +- x/exp/ast/json_test.go | 2 +- x/exp/parser/tokenize_test.go | 2 +- 13 files changed, 14 insertions(+), 12 deletions(-) rename {testutil => internal/testutil}/testutil.go (100%) diff --git a/cedar_test.go b/cedar_test.go index 350d77ae..c443cff6 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -5,7 +5,7 @@ import ( "net/netip" "testing" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) diff --git a/internal/rust_test.go b/internal/rust_test.go index 24efd067..da3611e4 100644 --- a/internal/rust_test.go +++ b/internal/rust_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/cedar-policy/cedar-go/internal" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" ) func TestParseUnicodeEscape(t *testing.T) { diff --git a/testutil/testutil.go b/internal/testutil/testutil.go similarity index 100% rename from testutil/testutil.go rename to internal/testutil/testutil.go diff --git a/types/json_test.go b/types/json_test.go index 50d18cb0..eb58bc37 100644 --- a/types/json_test.go +++ b/types/json_test.go @@ -5,7 +5,7 @@ import ( "fmt" "testing" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" ) func mustDecimalValue(v string) Decimal { diff --git a/types/patttern_test.go b/types/patttern_test.go index 8eec28a4..f488f25f 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -3,7 +3,7 @@ package types import ( "testing" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" ) func TestPatternFromBuilder(t *testing.T) { diff --git a/types/testutil.go b/types/testutil.go index 787f96b9..64df358c 100644 --- a/types/testutil.go +++ b/types/testutil.go @@ -3,9 +3,11 @@ package types import ( "testing" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" ) +// TODO: this file should not be public, it should be moved into the eval code + func AssertValue(t *testing.T, got, want Value) { t.Helper() testutil.FatalIf( diff --git a/types/value_test.go b/types/value_test.go index 9381e78e..9cf2ffbb 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" ) func TestBool(t *testing.T) { diff --git a/x/exp/ast/cedar_parse_test.go b/x/exp/ast/cedar_parse_test.go index 930b623a..10d151cc 100644 --- a/x/exp/ast/cedar_parse_test.go +++ b/x/exp/ast/cedar_parse_test.go @@ -4,7 +4,7 @@ import ( "bytes" "testing" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/x/exp/ast" ) diff --git a/x/exp/ast/cedar_tokenize_test.go b/x/exp/ast/cedar_tokenize_test.go index 9ff3d900..5606a642 100644 --- a/x/exp/ast/cedar_tokenize_test.go +++ b/x/exp/ast/cedar_tokenize_test.go @@ -7,7 +7,7 @@ import ( "testing" "unicode/utf8" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" ) func TestTokenize(t *testing.T) { diff --git a/x/exp/ast/cedar_unmarshal_test.go b/x/exp/ast/cedar_unmarshal_test.go index 47972c3d..bc118953 100644 --- a/x/exp/ast/cedar_unmarshal_test.go +++ b/x/exp/ast/cedar_unmarshal_test.go @@ -4,7 +4,7 @@ import ( "bytes" "testing" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/ast" ) diff --git a/x/exp/ast/eval_test.go b/x/exp/ast/eval_test.go index 875f7008..b9d00e4b 100644 --- a/x/exp/ast/eval_test.go +++ b/x/exp/ast/eval_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) diff --git a/x/exp/ast/json_test.go b/x/exp/ast/json_test.go index 050a5c0d..a61b7d0f 100644 --- a/x/exp/ast/json_test.go +++ b/x/exp/ast/json_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" "github.com/cedar-policy/cedar-go/x/exp/ast" ) diff --git a/x/exp/parser/tokenize_test.go b/x/exp/parser/tokenize_test.go index 42d9911a..926d99ed 100644 --- a/x/exp/parser/tokenize_test.go +++ b/x/exp/parser/tokenize_test.go @@ -7,7 +7,7 @@ import ( "testing" "unicode/utf8" - "github.com/cedar-policy/cedar-go/testutil" + "github.com/cedar-policy/cedar-go/internal/testutil" ) func TestTokenize(t *testing.T) { From 6d2f0f0f7e6137aac674ac369a7696bfbf1be150 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 16:23:02 -0600 Subject: [PATCH 074/216] cedar: remove dead match code Addresses IDX-142 Signed-off-by: philhassey --- match.go | 66 -------------------------------------------------------- 1 file changed, 66 deletions(-) delete mode 100644 match.go diff --git a/match.go b/match.go deleted file mode 100644 index 5a6c9d27..00000000 --- a/match.go +++ /dev/null @@ -1,66 +0,0 @@ -package cedar - -import "github.com/cedar-policy/cedar-go/x/exp/parser" - -// ported from Go's stdlib and reduced to our scope. -// https://golang.org/src/path/filepath/match.go?s=1226:1284#L34 - -// Match reports whether name matches the shell file name pattern. -// The pattern syntax is: -// -// pattern: -// { term } -// term: -// '*' matches any sequence of non-Separator characters -// c matches character c (c != '*') -func match(p parser.Pattern, name string) (matched bool) { -Pattern: - for i, comp := range p.Comps { - lastChunk := i == len(p.Comps)-1 - if comp.Star && comp.Chunk == "" { - return true - } - // Look for Match at current position. - t, ok := matchChunk(comp.Chunk, name) - // if we're the last chunk, make sure we've exhausted the name - // otherwise we'll give a false result even if we could still Match - // using the star - if ok && (len(t) == 0 || !lastChunk) { - name = t - continue - } - if comp.Star { - // Look for Match skipping i+1 bytes. - for i := 0; i < len(name); i++ { - t, ok := matchChunk(comp.Chunk, name[i+1:]) - if ok { - // if we're the last chunk, make sure we exhausted the name - if lastChunk && len(t) > 0 { - continue - } - name = t - continue Pattern - } - } - } - return false - } - return len(name) == 0 -} - -// matchChunk checks whether chunk matches the beginning of s. -// If so, it returns the remainder of s (after the Match). -// Chunk is all single-character operators: literals, char classes, and ?. -func matchChunk(chunk, s string) (rest string, ok bool) { - for len(chunk) > 0 { - if len(s) == 0 { - return - } - if chunk[0] != s[0] { - return - } - s = s[1:] - chunk = chunk[1:] - } - return s, true -} From 730a598cfe29624ddcf3268dba63daefb5233fb8 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 16:30:23 -0600 Subject: [PATCH 075/216] cedar: remove dead parser code Addresses IDX-142 Signed-off-by: philhassey --- x/exp/parser/fuzz_test.go | 103 -- x/exp/parser/parse.go | 1462 --------------------------- x/exp/parser/tokenize.go | 705 ------------- x/exp/parser/tokenize_mocks_test.go | 74 -- x/exp/parser/tokenize_test.go | 554 ---------- 5 files changed, 2898 deletions(-) delete mode 100644 x/exp/parser/fuzz_test.go delete mode 100644 x/exp/parser/parse.go delete mode 100644 x/exp/parser/tokenize.go delete mode 100644 x/exp/parser/tokenize_mocks_test.go delete mode 100644 x/exp/parser/tokenize_test.go diff --git a/x/exp/parser/fuzz_test.go b/x/exp/parser/fuzz_test.go deleted file mode 100644 index c6f89606..00000000 --- a/x/exp/parser/fuzz_test.go +++ /dev/null @@ -1,103 +0,0 @@ -package parser - -import ( - "testing" -) - -// https://go.dev/doc/tutorial/fuzz -// mkdir testdata -// go test -fuzz=FuzzTokenize -fuzztime 60s -// go test -fuzz=FuzzParse -fuzztime 60s - -func FuzzTokenize(f *testing.F) { - tests := []string{ - `These are some identifiers`, - `0 1 1234`, - `-1 9223372036854775807 -9223372036854775808`, - `"" "string" "\"\'\n\r\t\\\0" "\x123" "\u{0}\u{10fFfF}"`, - `"*" "\*" "*\**"`, - `@.,;(){}[]+-*`, - `:::`, - `!!=<<=>>=`, - `||&&`, - `// single line comment`, - `/*`, - `multiline comment`, - `// embedded comment does nothing`, - `*/`, - `'/%|&=`, - } - for _, tt := range tests { - f.Add(tt) - } - f.Fuzz(func(t *testing.T, orig string) { - toks, err := Tokenize([]byte(orig)) - if err != nil { - if toks != nil { - t.Errorf("toks != nil on err") - } - } - }) -} - -func FuzzParse(f *testing.F) { - tests := []string{ - `permit(principal,action,resource);`, - `forbid(principal,action,resource);`, - `permit(principal,action,resource in asdf::"1234");`, - `permit(principal,action,resource) when { resource in "foo" };`, - `permit(principal,action,resource) when { context.x == 42 };`, - `permit(principal,action,resource) when { context.x == 42 };`, - `permit(principal,action,resource) when { principal.x == 42 };`, - `permit(principal,action,resource) when { principal.x == 42 };`, - `permit(principal,action,resource) when { principal in parent::"bob" };`, - `permit(principal == coder::"cuzco",action,resource);`, - `permit(principal in team::"osiris",action,resource);`, - `permit(principal,action == table::"drop",resource);`, - `permit(principal,action in scary::"stuff",resource);`, - `permit(principal,action in [scary::"stuff"],resource);`, - `permit(principal,action,resource == table::"whatever");`, - `permit(principal,action,resource) unless { false };`, - `permit(principal,action,resource) when { (if true then true else true) };`, - `permit(principal,action,resource) when { (true || false) };`, - `permit(principal,action,resource) when { (true && true) };`, - `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, - `permit(principal,action,resource) when { principal in principal };`, - `permit(principal,action,resource) when { principal has name };`, - `permit(principal,action,resource) when { 40+3-1==42 };`, - `permit(principal,action,resource) when { 6*7==42 };`, - `permit(principal,action,resource) when { -42==-42 };`, - `permit(principal,action,resource) when { !(1+1==42) };`, - `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - `permit(principal,action,resource) when { {name:"bob"} has name };`, - `permit(principal,action,resource) when { action in action };`, - `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, - `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, - `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, - `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, - `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, - `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, - `permit(principal,action,resource) when { [1,2,3].shuffle() };`, - `permit(principal,action,resource) when { "bananas" like "*nan*" };`, - `permit(principal,action,resource) when { fooBar("10") };`, - `permit(principal,action,resource) when { decimal(1, 2) };`, - `permit(principal,action,resource) when { ip() };`, - `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, - `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, - `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, - `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, - `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, - } - for _, tt := range tests { - f.Add(tt) - } - f.Fuzz(func(_ *testing.T, orig string) { - toks, err := Tokenize([]byte(orig)) - if err != nil { - return - } - // intentionally ignore parse errors - _, _ = Parse(toks) - }) -} diff --git a/x/exp/parser/parse.go b/x/exp/parser/parse.go deleted file mode 100644 index a6ca0690..00000000 --- a/x/exp/parser/parse.go +++ /dev/null @@ -1,1462 +0,0 @@ -package parser - -import ( - "fmt" - "strconv" - "strings" -) - -func Parse(tokens []Token) (Policies, error) { - p := &parser{Tokens: tokens} - return p.Policies() -} - -func ParseEntity(tokens []Token) (Entity, error) { - p := &parser{Tokens: tokens} - return p.Entity() -} - -type parser struct { - Tokens []Token - Pos int -} - -func (p *parser) advance() Token { - t := p.peek() - if p.Pos < len(p.Tokens)-1 { - p.Pos++ - } - return t -} - -func (p *parser) peek() Token { - return p.Tokens[p.Pos] -} - -func (p *parser) exact(tok string) error { - t := p.advance() - if t.Text != tok { - return p.errorf("exact got %v want %v", t.Text, tok) - } - return nil -} - -func (p *parser) errorf(s string, args ...interface{}) error { - var t Token - if p.Pos < len(p.Tokens) { - t = p.Tokens[p.Pos] - } - err := fmt.Errorf(s, args...) - return fmt.Errorf("parse error at %v %q: %w", t.Pos, t.Text, err) -} - -// Policies := {Policy} - -type Policies []Policy - -func (c Policies) String() string { - var sb strings.Builder - for i, p := range c { - if i > 0 { - sb.WriteRune('\n') - } - sb.WriteString(p.String()) - } - return sb.String() -} - -func (p *parser) Policies() (Policies, error) { - var res Policies - for !p.peek().isEOF() { - policy, err := p.policy() - if err != nil { - return nil, err - } - res = append(res, policy) - } - return res, nil -} - -// Policy := {Annotation} Effect '(' Scope ')' {Conditions} ';' -// Scope := Principal ',' Action ',' Resource - -type Policy struct { - Position Position - Annotations []Annotation - Effect Effect - Principal Principal - Action Action - Resource Resource - Conditions []Condition -} - -func (p Policy) String() string { - var sb strings.Builder - for i, a := range p.Annotations { - if i > 0 { - sb.WriteRune('\n') - } - sb.WriteString(a.String()) - } - sb.WriteString(fmt.Sprintf("%s(\n%s,\n%s,\n%s\n)", - p.Effect, p.Principal, p.Action, p.Resource, - )) - for _, c := range p.Conditions { - sb.WriteRune('\n') - sb.WriteString(c.String()) - } - sb.WriteString(";") - return sb.String() -} - -func (p *parser) policy() (Policy, error) { - var res Policy - res.Position = p.peek().Pos - var err error - if res.Annotations, err = p.annotations(); err != nil { - return res, err - } - if res.Effect, err = p.effect(); err != nil { - return res, err - } - if err := p.exact("("); err != nil { - return res, err - } - if res.Principal, err = p.principal(); err != nil { - return res, err - } - if err := p.exact(","); err != nil { - return res, err - } - if res.Action, err = p.action(); err != nil { - return res, err - } - if err := p.exact(","); err != nil { - return res, err - } - if res.Resource, err = p.resource(); err != nil { - return res, err - } - if err := p.exact(")"); err != nil { - return res, err - } - if res.Conditions, err = p.conditions(); err != nil { - return res, err - } - if err := p.exact(";"); err != nil { - return res, err - } - return res, nil -} - -// Annotation := '@'IDENT'('STR')' - -type Annotation struct { - Key string - Value string -} - -func (a Annotation) String() string { - return fmt.Sprintf("@%s(%q)", a.Key, a.Value) -} - -func (p *parser) annotation() (Annotation, error) { - var res Annotation - var err error - t := p.advance() - if !t.isIdent() { - return res, p.errorf("expected ident") - } - res.Key = t.Text - if err := p.exact("("); err != nil { - return res, err - } - t = p.advance() - if !t.isString() { - return res, p.errorf("expected string") - } - if res.Value, err = t.stringValue(); err != nil { - return res, err - } - if err := p.exact(")"); err != nil { - return res, err - } - return res, nil -} - -func (p *parser) annotations() ([]Annotation, error) { - var res []Annotation - for p.peek().Text == "@" { - p.advance() - a, err := p.annotation() - if err != nil { - return res, err - } - for _, aa := range res { - if aa.Key == a.Key { - return res, p.errorf("duplicate annotation") - } - } - res = append(res, a) - } - return res, nil -} - -// Effect := 'permit' | 'forbid' - -type Effect string - -const ( - EffectPermit = Effect("permit") - EffectForbid = Effect("forbid") -) - -func (p *parser) effect() (Effect, error) { - next := p.advance() - res := Effect(next.Text) - switch res { - case EffectForbid: - case EffectPermit: - default: - return res, p.errorf("unexpected effect: %v", res) - } - return res, nil -} - -// MatchType - -type MatchType int - -const ( - MatchAny = MatchType(iota) - MatchEquals - MatchIn - MatchInList - MatchIs - MatchIsIn -) - -// Principal := 'principal' [('in' | '==') Entity] - -type Principal struct { - Type MatchType - Path Path - Entity Entity -} - -func (p Principal) String() string { - var res string - switch p.Type { - case MatchAny: - res = "principal" - case MatchEquals: - res = fmt.Sprintf("principal == %s", p.Entity) - case MatchIs: - res = fmt.Sprintf("principal is %s", p.Path) - case MatchIsIn: - res = fmt.Sprintf("principal is %s in %s", p.Path, p.Entity) - case MatchIn: - res = fmt.Sprintf("principal in %s", p.Entity) - } - return res -} - -func (p *parser) principal() (Principal, error) { - var res Principal - if err := p.exact("principal"); err != nil { - return res, err - } - switch p.peek().Text { - case "==": - p.advance() - var err error - res.Type = MatchEquals - res.Entity, err = p.Entity() - return res, err - case "is": - p.advance() - var err error - res.Type = MatchIs - res.Path, err = p.Path() - if err == nil && p.peek().Text == "in" { - p.advance() - res.Type = MatchIsIn - res.Entity, err = p.Entity() - return res, err - } - return res, err - case "in": - p.advance() - var err error - res.Type = MatchIn - res.Entity, err = p.Entity() - return res, err - default: - return Principal{ - Type: MatchAny, - }, nil - } -} - -// Action := 'action' [( '==' Entity | 'in' ('[' EntList ']' | Entity) )] - -type Action struct { - Type MatchType - Entities []Entity -} - -func (a Action) String() string { - var sb strings.Builder - switch a.Type { - case MatchAny: - sb.WriteString("action") - case MatchEquals: - sb.WriteString(fmt.Sprintf("action == %s", a.Entities[0])) - case MatchIn: - sb.WriteString(fmt.Sprintf("action in %s", a.Entities[0])) - case MatchInList: - sb.WriteString("action in [") - for i, e := range a.Entities { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(e.String()) - } - sb.WriteRune(']') - } - return sb.String() -} - -func (p *parser) action() (Action, error) { - var res Action - var err error - if err := p.exact("action"); err != nil { - return res, err - } - switch p.peek().Text { - case "==": - p.advance() - res.Type = MatchEquals - e, err := p.Entity() - if err != nil { - return res, err - } - res.Entities = append(res.Entities, e) - return res, nil - case "in": - p.advance() - if p.peek().Text == "[" { - res.Type = MatchInList - p.advance() - res.Entities, err = p.entlist() - if err != nil { - return res, err - } - p.advance() // entlist guarantees "]" - return res, nil - } else { - res.Type = MatchIn - e, err := p.Entity() - if err != nil { - return res, err - } - res.Entities = append(res.Entities, e) - return res, nil - } - default: - return Action{ - Type: MatchAny, - }, nil - } -} - -// Resource := 'resource' [('in' | '==') Entity)] - -type Resource struct { - Type MatchType - Path Path - Entity Entity -} - -func (r Resource) String() string { - var res string - switch r.Type { - case MatchAny: - res = "resource" - case MatchEquals: - res = fmt.Sprintf("resource == %s", r.Entity) - case MatchIs: - res = fmt.Sprintf("resource is %s", r.Path) - case MatchIsIn: - res = fmt.Sprintf("resource is %s in %s", r.Path, r.Entity) - case MatchIn: - res = fmt.Sprintf("resource in %s", r.Entity) - } - return res -} - -func (p *parser) resource() (Resource, error) { - var res Resource - if err := p.exact("resource"); err != nil { - return res, err - } - switch p.peek().Text { - case "==": - p.advance() - var err error - res.Type = MatchEquals - res.Entity, err = p.Entity() - return res, err - case "is": - p.advance() - var err error - res.Type = MatchIs - res.Path, err = p.Path() - if err == nil && p.peek().Text == "in" { - p.advance() - res.Type = MatchIsIn - res.Entity, err = p.Entity() - return res, err - } - return res, err - case "in": - p.advance() - var err error - res.Type = MatchIn - res.Entity, err = p.Entity() - return res, err - default: - return Resource{ - Type: MatchAny, - }, nil - } -} - -// Entity := Path '::' STR - -type Entity struct { - Path []string -} - -func (e Entity) String() string { - return fmt.Sprintf( - "%s::%q", - strings.Join(e.Path[0:len(e.Path)-1], "::"), - e.Path[len(e.Path)-1], - ) -} - -func (p *parser) Entity() (Entity, error) { - var res Entity - t := p.advance() - if !t.isIdent() { - return res, p.errorf("expected ident") - } - res.Path = append(res.Path, t.Text) - for { - if err := p.exact("::"); err != nil { - return res, err - } - t := p.advance() - switch { - case t.isIdent(): - res.Path = append(res.Path, t.Text) - case t.isString(): - component, err := t.stringValue() - if err != nil { - return res, err - } - res.Path = append(res.Path, component) - return res, nil - default: - return res, p.errorf("unexpected token") - } - } -} - -// Path ::= IDENT {'::' IDENT} - -type Path struct { - Path []string -} - -func (e Path) String() string { - return strings.Join(e.Path, "::") -} - -func (p *parser) Path() (Path, error) { - var res Path - t := p.advance() - if !t.isIdent() { - return res, p.errorf("expected ident") - } - res.Path = append(res.Path, t.Text) - for { - if p.peek().Text != "::" { - return res, nil - } - p.advance() - t := p.advance() - switch { - case t.isIdent(): - res.Path = append(res.Path, t.Text) - default: - return res, p.errorf("unexpected token") - } - } -} - -// EntList := Entity {',' Entity} - -func (p *parser) entlist() ([]Entity, error) { - var res []Entity - for p.peek().Text != "]" { - if len(res) > 0 { - if err := p.exact(","); err != nil { - return res, err - } - } - e, err := p.Entity() - if err != nil { - return res, err - } - res = append(res, e) - } - return res, nil -} - -// Condition := ('when' | 'unless') '{' Expr '}' - -type ConditionType string - -const ( - ConditionWhen ConditionType = "when" - ConditionUnless ConditionType = "unless" -) - -type Condition struct { - Type ConditionType - Expression Expression -} - -func (c Condition) String() string { - var res string - switch c.Type { - case ConditionWhen: - res = fmt.Sprintf("when {\n%s\n}", c.Expression) - case ConditionUnless: - res = fmt.Sprintf("unless {\n%s\n}", c.Expression) - } - return res -} - -func (p *parser) condition() (Condition, error) { - var res Condition - var err error - res.Type = ConditionType(p.advance().Text) - if err := p.exact("{"); err != nil { - return res, err - } - if res.Expression, err = p.expression(); err != nil { - return res, err - } - if err := p.exact("}"); err != nil { - return res, err - } - return res, nil -} - -func (p *parser) conditions() ([]Condition, error) { - var res []Condition - for { - switch p.peek().Text { - case "when", "unless": - c, err := p.condition() - if err != nil { - return res, err - } - res = append(res, c) - default: - return res, nil - } - } -} - -// Expr := Or | If - -type ExpressionType int - -const ( - ExpressionOr ExpressionType = iota - ExpressionIf -) - -type Expression struct { - Type ExpressionType - Or Or - If *If -} - -func (e Expression) String() string { - var res string - switch e.Type { - case ExpressionOr: - res = e.Or.String() - case ExpressionIf: - res = e.If.String() - } - return res -} - -func (p *parser) expression() (Expression, error) { - var res Expression - var err error - if p.peek().Text == "if" { - p.advance() - res.Type = ExpressionIf - i, err := p.ifExpr() - if err != nil { - return res, err - } - res.If = &i - return res, nil - } else { - res.Type = ExpressionOr - if res.Or, err = p.or(); err != nil { - return res, err - } - return res, nil - } -} - -// If := 'if' Expr 'then' Expr 'else' Expr - -type If struct { - If Expression - Then Expression - Else Expression -} - -func (i If) String() string { - return fmt.Sprintf("if %s then %s else %s", i.If, i.Then, i.Else) -} - -func (p *parser) ifExpr() (If, error) { - var res If - var err error - if res.If, err = p.expression(); err != nil { - return res, err - } - if err = p.exact("then"); err != nil { - return res, err - } - if res.Then, err = p.expression(); err != nil { - return res, err - } - if err = p.exact("else"); err != nil { - return res, err - } - if res.Else, err = p.expression(); err != nil { - return res, err - } - return res, err -} - -// Or := And {'||' And} - -type Or struct { - Ands []And -} - -func (o Or) String() string { - var sb strings.Builder - for i, and := range o.Ands { - if i > 0 { - sb.WriteString(" || ") - } - sb.WriteString(and.String()) - } - return sb.String() -} - -func (p *parser) or() (Or, error) { - var res Or - for { - a, err := p.and() - if err != nil { - return res, err - } - res.Ands = append(res.Ands, a) - if p.peek().Text != "||" { - return res, nil - } - p.advance() - } -} - -// And := Relation {'&&' Relation} - -type And struct { - Relations []Relation -} - -func (a And) String() string { - var sb strings.Builder - for i, rel := range a.Relations { - if i > 0 { - sb.WriteString(" && ") - } - sb.WriteString(rel.String()) - } - return sb.String() -} - -func (p *parser) and() (And, error) { - var res And - for { - r, err := p.relation() - if err != nil { - return res, err - } - res.Relations = append(res.Relations, r) - if p.peek().Text != "&&" { - return res, nil - } - p.advance() - } -} - -// Relation := Add [RELOP Add] | Add 'has' (IDENT | STR) | Add 'like' PAT - -type RelationType string - -const ( - RelationNone RelationType = "none" - RelationRelOp RelationType = "relop" - RelationHasIdent RelationType = "hasident" - RelationHasLiteral RelationType = "hasliteral" - RelationLike RelationType = "like" - RelationIs RelationType = "is" - RelationIsIn RelationType = "isIn" -) - -type Relation struct { - Add Add - Type RelationType - RelOp RelOp - RelOpRhs Add - Str string - Pat Pattern - Path Path - Entity Add -} - -func (r Relation) String() string { - var sb strings.Builder - sb.WriteString(r.Add.String()) - switch r.Type { - case RelationNone: - case RelationRelOp: - sb.WriteString(" ") - sb.WriteString(string(r.RelOp)) - sb.WriteString(" ") - sb.WriteString(r.RelOpRhs.String()) - case RelationHasIdent: - sb.WriteString(" has ") - sb.WriteString(r.Str) - case RelationHasLiteral: - sb.WriteString(" has ") - sb.WriteString(strconv.Quote(r.Str)) - case RelationLike: - sb.WriteString(" like ") - sb.WriteString(r.Pat.String()) - case RelationIs: - sb.WriteString(" is ") - sb.WriteString(r.Path.String()) - case RelationIsIn: - sb.WriteString(" is ") - sb.WriteString(r.Path.String()) - sb.WriteString(" in ") - sb.WriteString(r.Entity.String()) - } - return sb.String() -} - -func (p *parser) relation() (Relation, error) { - var res Relation - var err error - if res.Add, err = p.add(); err != nil { - return res, err - } - - t := p.peek() - switch t.Text { - case "<", "<=", ">=", ">", "!=", "==", "in": - p.advance() - res.Type = RelationRelOp - res.RelOp = RelOp(t.Text) - if res.RelOpRhs, err = p.add(); err != nil { - return res, err - } - case "has": - p.advance() - t := p.advance() - switch { - case t.isIdent(): - res.Type = RelationHasIdent - res.Str = t.Text - case t.isString(): - res.Type = RelationHasLiteral - if res.Str, err = t.stringValue(); err != nil { - return res, err - } - default: - return res, p.errorf("unexpected token") - } - case "like": - p.advance() - res.Type = RelationLike - t := p.advance() - if !t.isString() { - return res, p.errorf("unexpected token") - } - if res.Pat, err = t.patternValue(); err != nil { - return res, err - } - case "is": - p.advance() - var err error - res.Type = RelationIs - res.Path, err = p.Path() - if err == nil && p.peek().Text == "in" { - p.advance() - res.Type = RelationIsIn - res.Entity, err = p.add() - return res, err - } - return res, err - default: - res.Type = RelationNone - } - return res, nil -} - -// RELOP := '<' | '<=' | '>=' | '>' | '!=' | '==' | 'in' - -type RelOp string - -const ( - RelOpLt RelOp = "<" - RelOpLe RelOp = "<=" - RelOpGe RelOp = ">=" - RelOpGt RelOp = ">" - RelOpNe RelOp = "!=" - RelOpEq RelOp = "==" - RelOpIn RelOp = "in" -) - -// Add := Mult {ADDOP Mult} - -type Add struct { - Mults []Mult - AddOps []AddOp -} - -func (a Add) String() string { - var sb strings.Builder - sb.WriteString(a.Mults[0].String()) - for i, op := range a.AddOps { - sb.WriteString(fmt.Sprintf(" %s %s", op, a.Mults[i+1].String())) - } - return sb.String() -} - -func (p *parser) add() (Add, error) { - var res Add - var err error - mult, err := p.mult() - if err != nil { - return res, err - } - res.Mults = append(res.Mults, mult) - for { - op := AddOp(p.peek().Text) - switch op { - case AddOpAdd, AddOpSub: - default: - return res, nil - } - p.advance() - mult, err := p.mult() - if err != nil { - return res, err - } - res.AddOps = append(res.AddOps, op) - res.Mults = append(res.Mults, mult) - } -} - -// ADDOP := '+' | '-' - -type AddOp string - -const ( - AddOpAdd AddOp = "+" - AddOpSub AddOp = "-" -) - -// Mult := Unary { '*' Unary} - -type Mult struct { - Unaries []Unary -} - -func (m Mult) String() string { - var sb strings.Builder - for i, u := range m.Unaries { - if i > 0 { - sb.WriteString(" * ") - } - sb.WriteString(u.String()) - } - return sb.String() -} - -func (p *parser) mult() (Mult, error) { - var res Mult - for { - u, err := p.unary() - if err != nil { - return res, err - } - res.Unaries = append(res.Unaries, u) - if p.peek().Text != "*" { - return res, nil - } - p.advance() - } -} - -// Unary := [UNARYOP]x4 Member - -type Unary struct { - Ops []UnaryOp - Member Member -} - -func (u Unary) String() string { - var sb strings.Builder - for _, o := range u.Ops { - sb.WriteString(string(o)) - } - sb.WriteString(u.Member.String()) - return sb.String() -} - -func (p *parser) unary() (Unary, error) { - var res Unary - for { - o := UnaryOp(p.peek().Text) - switch o { - case UnaryOpNot: - p.advance() - res.Ops = append(res.Ops, o) - case UnaryOpMinus: - p.advance() - if p.peek().isInt() { - t := p.advance() - i, err := strconv.ParseInt("-"+t.Text, 10, 64) - if err != nil { - return res, err - } - res.Member = Member{ - Primary: Primary{ - Type: PrimaryLiteral, - Literal: Literal{ - Type: LiteralInt, - Long: i, - }, - }, - } - return res, nil - } - res.Ops = append(res.Ops, o) - default: - var err error - res.Member, err = p.member() - if err != nil { - return res, err - } - return res, nil - } - } -} - -// UNARYOP := '!' | '-' - -type UnaryOp string - -const ( - UnaryOpNot UnaryOp = "!" - UnaryOpMinus UnaryOp = "-" -) - -// Member := Primary {Access} - -type Member struct { - Primary Primary - Accesses []Access -} - -func (m Member) String() string { - var sb strings.Builder - sb.WriteString(m.Primary.String()) - for _, a := range m.Accesses { - sb.WriteString(a.String()) - } - return sb.String() -} - -func (p *parser) member() (Member, error) { - var res Member - var err error - if res.Primary, err = p.primary(); err != nil { - return res, err - } - for { - a, ok, err := p.access() - if !ok { - return res, err - } else { - res.Accesses = append(res.Accesses, a) - } - } -} - -// Primary := LITERAL -// | VAR -// | Entity -// | ExtFun '(' [ExprList] ')' -// | '(' Expr ')' -// | '[' [ExprList] ']' -// | '{' [RecInits] '}' - -type PrimaryType int - -const ( - PrimaryLiteral PrimaryType = iota - PrimaryVar - PrimaryEntity - PrimaryExtFun - PrimaryExpr - PrimaryExprList - PrimaryRecInits -) - -type Primary struct { - Type PrimaryType - Literal Literal - Var Var - Entity Entity - ExtFun ExtFun - Expression Expression - Expressions []Expression - RecInits []RecInit -} - -func (p Primary) String() string { - var res string - switch p.Type { - case PrimaryLiteral: - res = p.Literal.String() - case PrimaryVar: - res = p.Var.String() - case PrimaryEntity: - res = p.Entity.String() - case PrimaryExtFun: - res = p.ExtFun.String() - case PrimaryExpr: - res = fmt.Sprintf("(%s)", p.Expression) - case PrimaryExprList: - var sb strings.Builder - sb.WriteRune('[') - for i, e := range p.Expressions { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(e.String()) - } - sb.WriteRune(']') - res = sb.String() - case PrimaryRecInits: - var sb strings.Builder - sb.WriteRune('{') - for i, r := range p.RecInits { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(r.String()) - } - sb.WriteRune('}') - res = sb.String() - } - return res -} - -func (p *parser) primary() (Primary, error) { - var res Primary - var err error - t := p.advance() - switch { - case t.isInt(): - i, err := t.intValue() - if err != nil { - return res, err - } - res.Type = PrimaryLiteral - res.Literal = Literal{ - Type: LiteralInt, - Long: i, - } - case t.isString(): - res.Type = PrimaryLiteral - res.Literal.Type = LiteralString - if res.Literal.Str, err = t.stringValue(); err != nil { - return res, err - } - case t.Text == "true", t.Text == "false": - res.Type = PrimaryLiteral - res.Literal = Literal{ - Type: LiteralBool, - Bool: t.Text == "true", - } - case t.Text == string(VarPrincipal), - t.Text == string(VarAction), - t.Text == string(VarResource), - t.Text == string(VarContext): - res.Type = PrimaryVar - res.Var = Var{ - Type: VarType(t.Text), - } - case t.isIdent(): - e, f, err := p.entityOrExtFun(t.Text) - switch { - case e != nil: - res.Type = PrimaryEntity - res.Entity = *e - case f != nil: - res.Type = PrimaryExtFun - res.ExtFun = *f - default: - return res, err - } - case t.Text == "(": - res.Type = PrimaryExpr - if res.Expression, err = p.expression(); err != nil { - return res, err - } - if err := p.exact(")"); err != nil { - return res, err - } - case t.Text == "[": - res.Type = PrimaryExprList - if res.Expressions, err = p.expressions("]"); err != nil { - return res, err - } - p.advance() // expressions guarantees "]" - return res, err - case t.Text == "{": - res.Type = PrimaryRecInits - if res.RecInits, err = p.recInits(); err != nil { - return res, err - } - return res, err - default: - return res, p.errorf("invalid primary") - } - return res, nil -} - -func (p *parser) entityOrExtFun(first string) (*Entity, *ExtFun, error) { - path := []string{first} - for { - if p.peek().Text != "::" { - f, err := p.extFun(path) - if err != nil { - return nil, nil, err - } - return nil, &f, err - } - p.advance() - t := p.advance() - switch { - case t.isIdent(): - path = append(path, t.Text) - case t.isString(): - component, err := t.stringValue() - if err != nil { - return nil, nil, err - } - path = append(path, component) - return &Entity{Path: path}, nil, nil - default: - return nil, nil, p.errorf("unexpected token") - } - } -} - -func (p *parser) expressions(endOfListMarker string) ([]Expression, error) { - var res []Expression - for p.peek().Text != endOfListMarker { - if len(res) > 0 { - if err := p.exact(","); err != nil { - return res, err - } - } - e, err := p.expression() - if err != nil { - return res, err - } - res = append(res, e) - } - return res, nil -} - -func (p *parser) recInits() ([]RecInit, error) { - var res []RecInit - for { - t := p.peek() - if t.Text == "}" { - p.advance() - return res, nil - } - if len(res) > 0 { - if err := p.exact(","); err != nil { - return res, err - } - } - e, err := p.recInit() - if err != nil { - return res, err - } - res = append(res, e) - } -} - -// LITERAL := BOOL | INT | STR - -type LiteralType int - -const ( - LiteralBool LiteralType = iota - LiteralInt - LiteralString -) - -type Literal struct { - Type LiteralType - Bool bool - Long int64 - Str string -} - -func (l Literal) String() string { - var res string - switch l.Type { - case LiteralBool: - res = strconv.FormatBool(l.Bool) - case LiteralInt: - res = strconv.FormatInt(l.Long, 10) - case LiteralString: - res = strconv.Quote(l.Str) - } - return res -} - -// VAR := 'principal' | 'action' | 'resource' | 'context' - -type VarType string - -const ( - VarPrincipal VarType = "principal" - VarAction VarType = "action" - VarResource VarType = "resource" - VarContext VarType = "context" -) - -type Var struct { - Type VarType -} - -func (v Var) String() string { - return string(v.Type) -} - -// ExtFun := [Path '::'] IDENT - -type ExtFun struct { - Path []string - Expressions []Expression -} - -func (f ExtFun) String() string { - var sb strings.Builder - sb.WriteString(strings.Join(f.Path, "::")) - sb.WriteRune('(') - for i, e := range f.Expressions { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(e.String()) - } - sb.WriteRune(')') - return sb.String() -} - -func (p *parser) extFun(path []string) (ExtFun, error) { - res := ExtFun{Path: path} - if err := p.exact("("); err != nil { - return res, err - } - var err error - if res.Expressions, err = p.expressions(")"); err != nil { - return res, err - } - p.advance() // expressions guarantees ")" - return res, err -} - -// Access := '.' IDENT ['(' [ExprList] ')'] | '[' STR ']' - -type AccessType int - -const ( - AccessField AccessType = iota - AccessCall - AccessIndex -) - -type Access struct { - Type AccessType - Name string - Expressions []Expression -} - -func (a Access) String() string { - var sb strings.Builder - switch a.Type { - case AccessField: - sb.WriteRune('.') - sb.WriteString(a.Name) - case AccessCall: - sb.WriteRune('.') - sb.WriteString(a.Name) - sb.WriteRune('(') - for i, e := range a.Expressions { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(e.String()) - } - sb.WriteRune(')') - case AccessIndex: - sb.WriteRune('[') - sb.WriteString(strconv.Quote(a.Name)) - sb.WriteRune(']') - } - return sb.String() -} - -func (p *parser) access() (Access, bool, error) { - var res Access - var err error - t := p.peek() - switch t.Text { - case ".": - p.advance() - t := p.advance() - if !t.isIdent() { - return res, false, p.errorf("unexpected token") - } - res.Name = t.Text - if p.peek().Text == "(" { - p.advance() - res.Type = AccessCall - if res.Expressions, err = p.expressions(")"); err != nil { - return res, false, err - } - p.advance() // expressions guarantees ")" - } else { - res.Type = AccessField - } - case "[": - p.advance() - res.Type = AccessIndex - t := p.advance() - if !t.isString() { - return res, false, p.errorf("unexpected token") - } - if res.Name, err = t.stringValue(); err != nil { - return res, false, err - } - if err := p.exact("]"); err != nil { - return res, false, err - } - default: - return res, false, nil - } - return res, true, nil -} - -// RecInits := (IDENT | STR) ':' Expr {',' (IDENT | STR) ':' Expr} - -type RecKeyType int - -const ( - RecKeyIdent RecKeyType = iota - RecKeyString -) - -type RecInit struct { - KeyType RecKeyType - Key string - Value Expression -} - -func (r RecInit) String() string { - var sb strings.Builder - switch r.KeyType { - case RecKeyIdent: - sb.WriteString(r.Key) - case RecKeyString: - sb.WriteString(strconv.Quote(r.Key)) - } - sb.WriteString(": ") - sb.WriteString(r.Value.String()) - return sb.String() -} - -func (p *parser) recInit() (RecInit, error) { - var res RecInit - var err error - t := p.advance() - switch { - case t.isIdent(): - res.KeyType = RecKeyIdent - res.Key = t.Text - case t.isString(): - res.KeyType = RecKeyString - if res.Key, err = t.stringValue(); err != nil { - return res, err - } - default: - return res, p.errorf("unexpected token") - } - if err := p.exact(":"); err != nil { - return res, err - } - if res.Value, err = p.expression(); err != nil { - return res, err - } - return res, nil -} diff --git a/x/exp/parser/tokenize.go b/x/exp/parser/tokenize.go deleted file mode 100644 index e2e41d65..00000000 --- a/x/exp/parser/tokenize.go +++ /dev/null @@ -1,705 +0,0 @@ -package parser - -import ( - "bytes" - "fmt" - "io" - "strconv" - "strings" - "unicode" - "unicode/utf8" -) - -//go:generate moq -pkg parser -fmt goimports -out tokenize_mocks_test.go . reader - -// This type alias is for test purposes only. -type reader = io.Reader - -type TokenType int - -const ( - TokenEOF = TokenType(iota) - TokenIdent - TokenInt - TokenString - TokenOperator - TokenUnknown -) - -type Token struct { - Type TokenType - Pos Position - Text string -} - -func (t Token) isEOF() bool { - return t.Type == TokenEOF -} - -func (t Token) isIdent() bool { - return t.Type == TokenIdent -} - -func (t Token) isInt() bool { - return t.Type == TokenInt -} - -func (t Token) isString() bool { - return t.Type == TokenString -} - -func (t Token) toString() string { - return t.Text -} - -func (t Token) stringValue() (string, error) { - s := t.Text - s = strings.TrimPrefix(s, "\"") - s = strings.TrimSuffix(s, "\"") - b := []byte(s) - res, _, err := rustUnquote(b, false) - return res, err -} - -func (t Token) patternValue() (Pattern, error) { - return NewPattern(t.Text) -} - -func nextRune(b []byte, i int) (rune, int, error) { - ch, size := utf8.DecodeRune(b[i:]) - if ch == utf8.RuneError { - return ch, i, fmt.Errorf("bad unicode rune") - } - return ch, i + size, nil -} - -func parseHexEscape(b []byte, i int) (rune, int, error) { - var ch rune - var err error - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if !isHexadecimal(ch) { - return 0, i, fmt.Errorf("bad hex escape sequence") - } - res := digitVal(ch) - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if !isHexadecimal(ch) { - return 0, i, fmt.Errorf("bad hex escape sequence") - } - res = 16*res + digitVal(ch) - if res > 127 { - return 0, i, fmt.Errorf("bad hex escape sequence") - } - return rune(res), i, nil -} - -func parseUnicodeEscape(b []byte, i int) (rune, int, error) { - var ch rune - var err error - - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if ch != '{' { - return 0, i, fmt.Errorf("bad unicode escape sequence") - } - - digits := 0 - res := 0 - for { - ch, i, err = nextRune(b, i) - if err != nil { - return 0, i, err - } - if ch == '}' { - break - } - if !isHexadecimal(ch) { - return 0, i, fmt.Errorf("bad unicode escape sequence") - } - res = 16*res + digitVal(ch) - digits++ - } - - if digits == 0 || digits > 6 || !utf8.ValidRune(rune(res)) { - return 0, i, fmt.Errorf("bad unicode escape sequence") - } - - return rune(res), i, nil -} - -func Unquote(s string) (string, error) { - s = strings.TrimPrefix(s, "\"") - s = strings.TrimSuffix(s, "\"") - res, _, err := rustUnquote([]byte(s), false) - return res, err -} - -func rustUnquote(b []byte, star bool) (string, []byte, error) { - var sb strings.Builder - var ch rune - var err error - i := 0 - for i < len(b) { - ch, i, err = nextRune(b, i) - if err != nil { - return "", nil, err - } - if star && ch == '*' { - i-- - return sb.String(), b[i:], nil - } - if ch != '\\' { - sb.WriteRune(ch) - continue - } - ch, i, err = nextRune(b, i) - if err != nil { - return "", nil, err - } - switch ch { - case 'n': - sb.WriteRune('\n') - case 'r': - sb.WriteRune('\r') - case 't': - sb.WriteRune('\t') - case '\\': - sb.WriteRune('\\') - case '0': - sb.WriteRune('\x00') - case '\'': - sb.WriteRune('\'') - case '"': - sb.WriteRune('"') - case 'x': - ch, i, err = parseHexEscape(b, i) - if err != nil { - return "", nil, err - } - sb.WriteRune(ch) - case 'u': - ch, i, err = parseUnicodeEscape(b, i) - if err != nil { - return "", nil, err - } - sb.WriteRune(ch) - case '*': - if !star { - return "", nil, fmt.Errorf("bad char escape") - } - sb.WriteRune('*') - default: - return "", nil, fmt.Errorf("bad char escape") - } - } - return sb.String(), b[i:], nil -} - -type PatternComponent struct { - Star bool - Chunk string -} - -type Pattern struct { - Comps []PatternComponent - Raw string -} - -func (p Pattern) String() string { - return p.Raw -} - -func NewPattern(literal string) (Pattern, error) { - rawPat := literal - - literal = strings.TrimPrefix(literal, "\"") - literal = strings.TrimSuffix(literal, "\"") - - b := []byte(literal) - - var comps []PatternComponent - for len(b) > 0 { - var comp PatternComponent - var err error - for len(b) > 0 && b[0] == '*' { - b = b[1:] - comp.Star = true - } - comp.Chunk, b, err = rustUnquote(b, true) - if err != nil { - return Pattern{}, err - } - comps = append(comps, comp) - } - return Pattern{ - Comps: comps, - Raw: rawPat, - }, nil -} - -func isHexadecimal(ch rune) bool { - return isDecimal(ch) || ('a' <= lower(ch) && lower(ch) <= 'f') -} - -// TODO: make FakeRustQuote actually accurate in all cases -func FakeRustQuote(s string) string { - return strconv.Quote(s) -} - -func (t Token) intValue() (int64, error) { - return strconv.ParseInt(t.Text, 10, 64) -} - -func Tokenize(src []byte) ([]Token, error) { - var res []Token - var s scanner - s.Init(bytes.NewBuffer(src)) - for tok := s.nextToken(); s.err == nil && tok.Type != TokenEOF; tok = s.nextToken() { - res = append(res, tok) - } - if s.err != nil { - return nil, s.err - } - res = append(res, Token{Type: TokenEOF, Pos: s.position}) - return res, nil -} - -// Position is a value that represents a source position. -// A position is valid if Line > 0. -type Position struct { - Offset int // byte offset, starting at 0 - Line int // line number, starting at 1 - Column int // column number, starting at 1 (character count per line) -} - -func (pos Position) String() string { - return fmt.Sprintf(":%d:%d", pos.Line, pos.Column) -} - -const ( - specialRuneEOF = rune(-(iota + 1)) - specialRuneBOF -) - -const bufLen = 1024 // at least utf8.UTFMax - -// A scanner implements reading of Unicode characters and tokens from an io.Reader. -type scanner struct { - // Input - src io.Reader - - // Source buffer - srcBuf [bufLen + 1]byte // +1 for sentinel for common case of s.next() - srcPos int // reading position (srcBuf index) - srcEnd int // source end (srcBuf index) - - // Source position - srcBufOffset int // byte offset of srcBuf[0] in source - line int // line count - column int // character count - lastLineLen int // length of last line in characters (for correct column reporting) - lastCharLen int // length of last character in bytes - - // Token text buffer - // Typically, token text is stored completely in srcBuf, but in general - // the token text's head may be buffered in tokBuf while the token text's - // tail is stored in srcBuf. - tokBuf bytes.Buffer // token text head that is not in srcBuf anymore - tokPos int // token text tail position (srcBuf index); valid if >= 0 - tokEnd int // token text tail end (srcBuf index) - - // One character look-ahead - ch rune // character before current srcPos - - // Last error encountered by nextToken. - err error - - // Start position of most recently scanned token; set by nextToken. - // Calling Init or Next invalidates the position (Line == 0). - // If an error is reported (via Error) and position is invalid, - // the scanner is not inside a token. Call Pos to obtain an error - // position in that case, or to obtain the position immediately - // after the most recently scanned token. - position Position -} - -// Init initializes a Scanner with a new source and returns s. -func (s *scanner) Init(src io.Reader) *scanner { - s.src = src - - // initialize source buffer - // (the first call to next() will fill it by calling src.Read) - s.srcBuf[0] = utf8.RuneSelf // sentinel - s.srcPos = 0 - s.srcEnd = 0 - - // initialize source position - s.srcBufOffset = 0 - s.line = 1 - s.column = 0 - s.lastLineLen = 0 - s.lastCharLen = 0 - - // initialize token text buffer - // (required for first call to next()). - s.tokPos = -1 - - // initialize one character look-ahead - s.ch = specialRuneBOF // no char read yet, not EOF - - // initialize public fields - s.position.Line = 0 // invalidate token position - - return s -} - -// next reads and returns the next Unicode character. It is designed such -// that only a minimal amount of work needs to be done in the common ASCII -// case (one test to check for both ASCII and end-of-buffer, and one test -// to check for newlines). -func (s *scanner) next() rune { - ch, width := rune(s.srcBuf[s.srcPos]), 1 - - if ch >= utf8.RuneSelf { - // uncommon case: not ASCII or not enough bytes - for s.srcPos+utf8.UTFMax > s.srcEnd && !utf8.FullRune(s.srcBuf[s.srcPos:s.srcEnd]) { - // not enough bytes: read some more, but first - // save away token text if any - if s.tokPos >= 0 { - s.tokBuf.Write(s.srcBuf[s.tokPos:s.srcPos]) - s.tokPos = 0 - // s.tokEnd is set by nextToken() - } - // move unread bytes to beginning of buffer - copy(s.srcBuf[0:], s.srcBuf[s.srcPos:s.srcEnd]) - s.srcBufOffset += s.srcPos - // read more bytes - // (an io.Reader must return io.EOF when it reaches - // the end of what it is reading - simply returning - // n == 0 will make this loop retry forever; but the - // error is in the reader implementation in that case) - i := s.srcEnd - s.srcPos - n, err := s.src.Read(s.srcBuf[i:bufLen]) - s.srcPos = 0 - s.srcEnd = i + n - s.srcBuf[s.srcEnd] = utf8.RuneSelf // sentinel - if err != nil { - if err != io.EOF { - s.error(err.Error()) - } - if s.srcEnd == 0 { - if s.lastCharLen > 0 { - // previous character was not EOF - s.column++ - } - s.lastCharLen = 0 - return specialRuneEOF - } - // If err == EOF, we won't be getting more - // bytes; break to avoid infinite loop. If - // err is something else, we don't know if - // we can get more bytes; thus also break. - break - } - } - // at least one byte - ch = rune(s.srcBuf[s.srcPos]) - if ch >= utf8.RuneSelf { - // uncommon case: not ASCII - ch, width = utf8.DecodeRune(s.srcBuf[s.srcPos:s.srcEnd]) - if ch == utf8.RuneError && width == 1 { - // advance for correct error position - s.srcPos += width - s.lastCharLen = width - s.column++ - s.error("invalid UTF-8 encoding") - return ch - } - } - } - - // advance - s.srcPos += width - s.lastCharLen = width - s.column++ - - // special situations - switch ch { - case 0: - // for compatibility with other tools - s.error("invalid character NUL") - case '\n': - s.line++ - s.lastLineLen = s.column - s.column = 0 - } - - return ch -} - -func (s *scanner) error(msg string) { - s.tokEnd = s.srcPos - s.lastCharLen // make sure token text is terminated - s.err = fmt.Errorf("%v: %v", s.position, msg) -} - -func isIdentRune(ch rune, first bool) bool { - return ch == '_' || unicode.IsLetter(ch) || unicode.IsDigit(ch) && !first -} - -func (s *scanner) scanIdentifier() rune { - // we know the zeroth rune is OK; start scanning at the next one - ch := s.next() - for isIdentRune(ch, false) { - ch = s.next() - } - return ch -} - -func lower(ch rune) rune { return ('a' - 'A') | ch } // returns lower-case ch iff ch is ASCII letter -func isDecimal(ch rune) bool { return '0' <= ch && ch <= '9' } - -func (s *scanner) scanInteger(ch rune) rune { - for isDecimal(ch) { - ch = s.next() - } - return ch -} - -func digitVal(ch rune) int { - switch { - case '0' <= ch && ch <= '9': - return int(ch - '0') - case 'a' <= lower(ch) && lower(ch) <= 'f': - return int(lower(ch) - 'a' + 10) - } - return 16 // larger than any legal digit val -} - -func (s *scanner) scanHexDigits(ch rune, min, max int) rune { - n := 0 - for n < max && isHexadecimal(ch) { - ch = s.next() - n++ - } - if n < min || n > max { - s.error("invalid char escape") - } - return ch -} - -func (s *scanner) scanEscape() rune { - ch := s.next() // read character after '/' - switch ch { - case 'n', 'r', 't', '\\', '0', '\'', '"', '*': - // nothing to do - ch = s.next() - case 'x': - ch = s.scanHexDigits(s.next(), 2, 2) - case 'u': - ch = s.next() - if ch != '{' { - s.error("invalid char escape") - return ch - } - ch = s.scanHexDigits(s.next(), 1, 6) - if ch != '}' { - s.error("invalid char escape") - return ch - } - ch = s.next() - default: - s.error("invalid char escape") - } - return ch -} - -func (s *scanner) scanString() (n int) { - ch := s.next() // read character after quote - for ch != '"' { - if ch == '\n' || ch < 0 { - s.error("literal not terminated") - return - } - if ch == '\\' { - ch = s.scanEscape() - } else { - ch = s.next() - } - n++ - } - return -} - -func (s *scanner) scanComment(ch rune) rune { - // ch == '/' || ch == '*' - if ch == '/' { - // line comment - ch = s.next() // read character after "//" - for ch != '\n' && ch >= 0 { - ch = s.next() - } - return ch - } - - // general comment - ch = s.next() // read character after "/*" - for { - if ch < 0 { - s.error("comment not terminated") - break - } - ch0 := ch - ch = s.next() - if ch0 == '*' && ch == '/' { - ch = s.next() - break - } - } - return ch -} - -func (s *scanner) scanOperator(ch0, ch rune) (TokenType, rune) { - switch ch0 { - case '@', '.', ',', ';', '(', ')', '{', '}', '[', ']', '+', '-', '*': - case ':': - if ch == ':' { - ch = s.next() - } - case '!', '<', '>': - if ch == '=' { - ch = s.next() - } - case '=': - if ch != '=' { - return TokenUnknown, ch - } - ch = s.next() - case '|': - if ch != '|' { - return TokenUnknown, ch - } - ch = s.next() - case '&': - if ch != '&' { - return TokenUnknown, ch - } - ch = s.next() - default: - return TokenUnknown, ch - } - return TokenOperator, ch -} - -func isWhitespace(c rune) bool { - switch c { - case '\t', '\n', '\r', ' ': - return true - default: - return false - } -} - -// nextToken reads the next token or Unicode character from source and returns -// it. It returns specialRuneEOF at the end of the source. It reports scanner -// errors (read and token errors) by calling s.Error, if not nil; otherwise it -// prints an error message to os.Stderr. -func (s *scanner) nextToken() Token { - if s.ch == specialRuneBOF { - s.ch = s.next() - } - - ch := s.ch - - // reset token text position - s.tokPos = -1 - s.position.Line = 0 - -redo: - // skip white space - for isWhitespace(ch) { - ch = s.next() - } - - // start collecting token text - s.tokBuf.Reset() - s.tokPos = s.srcPos - s.lastCharLen - - // set token position - s.position.Offset = s.srcBufOffset + s.tokPos - if s.column > 0 { - // common case: last character was not a '\n' - s.position.Line = s.line - s.position.Column = s.column - } else { - // last character was a '\n' - // (we cannot be at the beginning of the source - // since we have called next() at least once) - s.position.Line = s.line - 1 - s.position.Column = s.lastLineLen - } - - // determine token value - var tt TokenType - switch { - case ch == specialRuneEOF: - tt = TokenEOF - case isIdentRune(ch, true): - ch = s.scanIdentifier() - tt = TokenIdent - case isDecimal(ch): - ch = s.scanInteger(ch) - tt = TokenInt - case ch == '"': - s.scanString() - ch = s.next() - tt = TokenString - case ch == '/': - ch0 := ch - ch = s.next() - if ch == '/' || ch == '*' { - s.tokPos = -1 // don't collect token text - ch = s.scanComment(ch) - goto redo - } - tt, ch = s.scanOperator(ch0, ch) - default: - tt, ch = s.scanOperator(ch, s.next()) - } - - // end of token text - s.tokEnd = s.srcPos - s.lastCharLen - s.ch = ch - - return Token{ - Type: tt, - Pos: s.position, - Text: s.tokenText(), - } -} - -// tokenText returns the string corresponding to the most recently scanned token. -// Valid after calling nextToken and in calls of Scanner.Error. -func (s *scanner) tokenText() string { - if s.tokPos < 0 { - // no token text - return "" - } - - if s.tokBuf.Len() == 0 { - // common case: the entire token text is still in srcBuf - return string(s.srcBuf[s.tokPos:s.tokEnd]) - } - - // part of the token text was saved in tokBuf: save the rest in - // tokBuf as well and return its content - s.tokBuf.Write(s.srcBuf[s.tokPos:s.tokEnd]) - s.tokPos = s.tokEnd // ensure idempotency of TokenText() call - return s.tokBuf.String() -} diff --git a/x/exp/parser/tokenize_mocks_test.go b/x/exp/parser/tokenize_mocks_test.go deleted file mode 100644 index ff5a98fc..00000000 --- a/x/exp/parser/tokenize_mocks_test.go +++ /dev/null @@ -1,74 +0,0 @@ -// Code generated by moq; DO NOT EDIT. -// github.com/matryer/moq - -package parser - -import ( - "sync" -) - -// Ensure, that readerMock does implement reader. -// If this is not the case, regenerate this file with moq. -var _ reader = &readerMock{} - -// readerMock is a mock implementation of reader. -// -// func TestSomethingThatUsesreader(t *testing.T) { -// -// // make and configure a mocked reader -// mockedreader := &readerMock{ -// ReadFunc: func(p []byte) (int, error) { -// panic("mock out the Read method") -// }, -// } -// -// // use mockedreader in code that requires reader -// // and then make assertions. -// -// } -type readerMock struct { - // ReadFunc mocks the Read method. - ReadFunc func(p []byte) (int, error) - - // calls tracks calls to the methods. - calls struct { - // Read holds details about calls to the Read method. - Read []struct { - // P is the p argument value. - P []byte - } - } - lockRead sync.RWMutex -} - -// Read calls ReadFunc. -func (mock *readerMock) Read(p []byte) (int, error) { - if mock.ReadFunc == nil { - panic("readerMock.ReadFunc: method is nil but reader.Read was just called") - } - callInfo := struct { - P []byte - }{ - P: p, - } - mock.lockRead.Lock() - mock.calls.Read = append(mock.calls.Read, callInfo) - mock.lockRead.Unlock() - return mock.ReadFunc(p) -} - -// ReadCalls gets all the calls that were made to Read. -// Check the length with: -// -// len(mockedreader.ReadCalls()) -func (mock *readerMock) ReadCalls() []struct { - P []byte -} { - var calls []struct { - P []byte - } - mock.lockRead.RLock() - calls = mock.calls.Read - mock.lockRead.RUnlock() - return calls -} diff --git a/x/exp/parser/tokenize_test.go b/x/exp/parser/tokenize_test.go deleted file mode 100644 index 926d99ed..00000000 --- a/x/exp/parser/tokenize_test.go +++ /dev/null @@ -1,554 +0,0 @@ -package parser - -import ( - "fmt" - "io" - "strings" - "testing" - "unicode/utf8" - - "github.com/cedar-policy/cedar-go/internal/testutil" -) - -func TestTokenize(t *testing.T) { - t.Parallel() - input := ` -These are some identifiers -0 1 1234 --1 9223372036854775807 -9223372036854775808 -"" "string" "\"\'\n\r\t\\\0" "\x123" "\u{0}\u{10fFfF}" -"*" "\*" "*\**" -@.,;(){}[]+-* -::: -!!=<<=>>= -||&& -// single line comment -/* -multiline comment -// embedded comment does nothing -*/ -'/%|&=` - want := []Token{ - {Type: TokenIdent, Text: "These", Pos: Position{Offset: 1, Line: 2, Column: 1}}, - {Type: TokenIdent, Text: "are", Pos: Position{Offset: 7, Line: 2, Column: 7}}, - {Type: TokenIdent, Text: "some", Pos: Position{Offset: 11, Line: 2, Column: 11}}, - {Type: TokenIdent, Text: "identifiers", Pos: Position{Offset: 16, Line: 2, Column: 16}}, - - {Type: TokenInt, Text: "0", Pos: Position{Offset: 28, Line: 3, Column: 1}}, - {Type: TokenInt, Text: "1", Pos: Position{Offset: 30, Line: 3, Column: 3}}, - {Type: TokenInt, Text: "1234", Pos: Position{Offset: 32, Line: 3, Column: 5}}, - - {Type: TokenOperator, Text: "-", Pos: Position{Offset: 37, Line: 4, Column: 1}}, - {Type: TokenInt, Text: "1", Pos: Position{Offset: 38, Line: 4, Column: 2}}, - {Type: TokenInt, Text: "9223372036854775807", Pos: Position{Offset: 40, Line: 4, Column: 4}}, - {Type: TokenOperator, Text: "-", Pos: Position{Offset: 60, Line: 4, Column: 24}}, - {Type: TokenInt, Text: "9223372036854775808", Pos: Position{Offset: 61, Line: 4, Column: 25}}, - - {Type: TokenString, Text: `""`, Pos: Position{Offset: 81, Line: 5, Column: 1}}, - {Type: TokenString, Text: `"string"`, Pos: Position{Offset: 84, Line: 5, Column: 4}}, - {Type: TokenString, Text: `"\"\'\n\r\t\\\0"`, Pos: Position{Offset: 93, Line: 5, Column: 13}}, - {Type: TokenString, Text: `"\x123"`, Pos: Position{Offset: 110, Line: 5, Column: 30}}, - {Type: TokenString, Text: `"\u{0}\u{10fFfF}"`, Pos: Position{Offset: 118, Line: 5, Column: 38}}, - - {Type: TokenString, Text: `"*"`, Pos: Position{Offset: 136, Line: 6, Column: 1}}, - {Type: TokenString, Text: `"\*"`, Pos: Position{Offset: 140, Line: 6, Column: 5}}, - {Type: TokenString, Text: `"*\**"`, Pos: Position{Offset: 145, Line: 6, Column: 10}}, - - {Type: TokenOperator, Text: "@", Pos: Position{Offset: 152, Line: 7, Column: 1}}, - {Type: TokenOperator, Text: ".", Pos: Position{Offset: 153, Line: 7, Column: 2}}, - {Type: TokenOperator, Text: ",", Pos: Position{Offset: 154, Line: 7, Column: 3}}, - {Type: TokenOperator, Text: ";", Pos: Position{Offset: 155, Line: 7, Column: 4}}, - {Type: TokenOperator, Text: "(", Pos: Position{Offset: 156, Line: 7, Column: 5}}, - {Type: TokenOperator, Text: ")", Pos: Position{Offset: 157, Line: 7, Column: 6}}, - {Type: TokenOperator, Text: "{", Pos: Position{Offset: 158, Line: 7, Column: 7}}, - {Type: TokenOperator, Text: "}", Pos: Position{Offset: 159, Line: 7, Column: 8}}, - {Type: TokenOperator, Text: "[", Pos: Position{Offset: 160, Line: 7, Column: 9}}, - {Type: TokenOperator, Text: "]", Pos: Position{Offset: 161, Line: 7, Column: 10}}, - {Type: TokenOperator, Text: "+", Pos: Position{Offset: 162, Line: 7, Column: 11}}, - {Type: TokenOperator, Text: "-", Pos: Position{Offset: 163, Line: 7, Column: 12}}, - {Type: TokenOperator, Text: "*", Pos: Position{Offset: 164, Line: 7, Column: 13}}, - - {Type: TokenOperator, Text: "::", Pos: Position{Offset: 166, Line: 8, Column: 1}}, - {Type: TokenOperator, Text: ":", Pos: Position{Offset: 168, Line: 8, Column: 3}}, - - {Type: TokenOperator, Text: "!", Pos: Position{Offset: 170, Line: 9, Column: 1}}, - {Type: TokenOperator, Text: "!=", Pos: Position{Offset: 171, Line: 9, Column: 2}}, - {Type: TokenOperator, Text: "<", Pos: Position{Offset: 173, Line: 9, Column: 4}}, - {Type: TokenOperator, Text: "<=", Pos: Position{Offset: 174, Line: 9, Column: 5}}, - {Type: TokenOperator, Text: ">", Pos: Position{Offset: 176, Line: 9, Column: 7}}, - {Type: TokenOperator, Text: ">=", Pos: Position{Offset: 177, Line: 9, Column: 8}}, - - {Type: TokenOperator, Text: "||", Pos: Position{Offset: 180, Line: 10, Column: 1}}, - {Type: TokenOperator, Text: "&&", Pos: Position{Offset: 182, Line: 10, Column: 3}}, - - {Type: TokenUnknown, Text: "'", Pos: Position{Offset: 265, Line: 16, Column: 1}}, - {Type: TokenUnknown, Text: "/", Pos: Position{Offset: 266, Line: 16, Column: 2}}, - {Type: TokenUnknown, Text: "%", Pos: Position{Offset: 267, Line: 16, Column: 3}}, - {Type: TokenUnknown, Text: "|", Pos: Position{Offset: 268, Line: 16, Column: 4}}, - {Type: TokenUnknown, Text: "&", Pos: Position{Offset: 269, Line: 16, Column: 5}}, - {Type: TokenUnknown, Text: "=", Pos: Position{Offset: 270, Line: 16, Column: 6}}, - - {Type: TokenEOF, Text: "", Pos: Position{Offset: 271, Line: 16, Column: 7}}, - } - got, err := Tokenize([]byte(input)) - testutil.OK(t, err) - testutil.Equals(t, got, want) -} - -func TestTokenizeErrors(t *testing.T) { - t.Parallel() - tests := []struct { - input string - wantErrStr string - wantErrPos Position - }{ - {"okay\x00not okay", "invalid character NUL", Position{Line: 1, Column: 1}}, - {`okay /* - stuff - `, "comment not terminated", Position{Line: 1, Column: 6}}, - {`okay " - " foo bar`, "literal not terminated", Position{Line: 1, Column: 6}}, - {`"okay" "\a"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\b"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\f"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\v"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\1"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\x"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\x1"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\ubadf"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\U0000badf"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\u{}"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\u{0000000}"`, "invalid char escape", Position{Line: 1, Column: 8}}, - {`"okay" "\u{z"`, "invalid char escape", Position{Line: 1, Column: 8}}, - } - for i, tt := range tests { - tt := tt - t.Run(fmt.Sprintf("%02d", i), func(t *testing.T) { - t.Parallel() - got, gotErr := Tokenize([]byte(tt.input)) - wantErrStr := fmt.Sprintf("%v: %s", tt.wantErrPos, tt.wantErrStr) - testutil.Error(t, gotErr) - testutil.Equals(t, gotErr.Error(), wantErrStr) - testutil.Equals(t, got, nil) - }) - } -} - -func TestIntTokenValues(t *testing.T) { - t.Parallel() - tests := []struct { - input string - wantOk bool - want int64 - wantErr string - }{ - {"0", true, 0, ""}, - {"9223372036854775807", true, 9223372036854775807, ""}, - {"9223372036854775808", false, 0, `strconv.ParseInt: parsing "9223372036854775808": value out of range`}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, err := Tokenize([]byte(tt.input)) - testutil.OK(t, err) - testutil.Equals(t, len(got), 2) - testutil.Equals(t, got[0].Type, TokenInt) - gotInt, err := got[0].intValue() - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, gotInt, tt.want) - } - }) - } -} - -func TestStringTokenValues(t *testing.T) { - t.Parallel() - tests := []struct { - input string - wantOk bool - want string - wantErr string - }{ - {`""`, true, "", ""}, - {`"hello"`, true, "hello", ""}, - {`"a\n\r\t\\\0b"`, true, "a\n\r\t\\\x00b", ""}, - {`"a\"b"`, true, "a\"b", ""}, - {`"a\'b"`, true, "a'b", ""}, - - {`"a\x00b"`, true, "a\x00b", ""}, - {`"a\x7fb"`, true, "a\x7fb", ""}, - {`"a\x80b"`, false, "", "bad hex escape sequence"}, - - {`"a\u{A}b"`, true, "a\u000ab", ""}, - {`"a\u{aB}b"`, true, "a\u00abb", ""}, - {`"a\u{AbC}b"`, true, "a\u0abcb", ""}, - {`"a\u{aBcD}b"`, true, "a\uabcdb", ""}, - {`"a\u{AbCdE}b"`, true, "a\U000abcdeb", ""}, - {`"a\u{10cDeF}b"`, true, "a\U0010cdefb", ""}, - {`"a\u{ffffff}b"`, false, "", "bad unicode escape sequence"}, - {`"a\u{d7ff}b"`, true, "a\ud7ffb", ""}, - {`"a\u{d800}b"`, false, "", "bad unicode escape sequence"}, - {`"a\u{dfff}b"`, false, "", "bad unicode escape sequence"}, - {`"a\u{e000}b"`, true, "a\ue000b", ""}, - {`"a\u{10ffff}b"`, true, "a\U0010ffffb", ""}, - {`"a\u{110000}b"`, false, "", "bad unicode escape sequence"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, err := Tokenize([]byte(tt.input)) - testutil.OK(t, err) - testutil.Equals(t, len(got), 2) - testutil.Equals(t, got[0].Type, TokenString) - gotStr, err := got[0].stringValue() - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, gotStr, tt.want) - } - }) - } -} - -func TestParseUnicodeEscape(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in []byte - out rune - outN int - err func(t testing.TB, err error) - }{ - {"happy", []byte{'{', '4', '2', '}'}, 0x42, 4, testutil.OK}, - {"badRune", []byte{'{', 0x80, 0x81}, 0, 1, testutil.Error}, - {"notHex", []byte{'{', 'g'}, 0, 2, testutil.Error}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - out, n, err := parseUnicodeEscape(tt.in, 0) - testutil.Equals(t, out, tt.out) - testutil.Equals(t, n, tt.outN) - tt.err(t, err) - }) - } -} - -func TestUnquote(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in string - out string - err func(t testing.TB, err error) - }{ - {"happy", `"test"`, `test`, testutil.OK}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - out, err := Unquote(tt.in) - testutil.Equals(t, out, tt.out) - tt.err(t, err) - }) - } -} - -func TestRustUnquote(t *testing.T) { - t.Parallel() - // star == false - { - tests := []struct { - input string - wantOk bool - want string - wantErr string - }{ - {``, true, "", ""}, - {`hello`, true, "hello", ""}, - {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", ""}, - {`a\"b`, true, "a\"b", ""}, - {`a\'b`, true, "a'b", ""}, - - {`a\x00b`, true, "a\x00b", ""}, - {`a\x7fb`, true, "a\x7fb", ""}, - {`a\x80b`, false, "", "bad hex escape sequence"}, - - {string([]byte{0x80, 0x81}), false, "", "bad unicode rune"}, - {`a\u`, false, "", "bad unicode rune"}, - {`a\uz`, false, "", "bad unicode escape sequence"}, - {`a\u{}b`, false, "", "bad unicode escape sequence"}, - {`a\u{A}b`, true, "a\u000ab", ""}, - {`a\u{aB}b`, true, "a\u00abb", ""}, - {`a\u{AbC}b`, true, "a\u0abcb", ""}, - {`a\u{aBcD}b`, true, "a\uabcdb", ""}, - {`a\u{AbCdE}b`, true, "a\U000abcdeb", ""}, - {`a\u{10cDeF}b`, true, "a\U0010cdefb", ""}, - {`a\u{ffffff}b`, false, "", "bad unicode escape sequence"}, - {`a\u{0000000}b`, false, "", "bad unicode escape sequence"}, - {`a\u{d7ff}b`, true, "a\ud7ffb", ""}, - {`a\u{d800}b`, false, "", "bad unicode escape sequence"}, - {`a\u{dfff}b`, false, "", "bad unicode escape sequence"}, - {`a\u{e000}b`, true, "a\ue000b", ""}, - {`a\u{10ffff}b`, true, "a\U0010ffffb", ""}, - {`a\u{110000}b`, false, "", "bad unicode escape sequence"}, - - {`\`, false, "", "bad unicode rune"}, - {`\a`, false, "", "bad char escape"}, - {`\*`, false, "", "bad char escape"}, - {`\x`, false, "", "bad unicode rune"}, - {`\xz`, false, "", "bad hex escape sequence"}, - {`\xa`, false, "", "bad unicode rune"}, - {`\xaz`, false, "", "bad hex escape sequence"}, - {`\{`, false, "", "bad char escape"}, - {`\{z`, false, "", "bad char escape"}, - {`\{0`, false, "", "bad char escape"}, - {`\{0z`, false, "", "bad char escape"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, rem, err := rustUnquote([]byte(tt.input), false) - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - testutil.Equals(t, got, tt.want) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got, tt.want) - testutil.Equals(t, rem, []byte("")) - } - }) - } - } - - // star == true - { - tests := []struct { - input string - wantOk bool - want string - wantRem string - wantErr string - }{ - {``, true, "", "", ""}, - {`hello`, true, "hello", "", ""}, - {`a\n\r\t\\\0b`, true, "a\n\r\t\\\x00b", "", ""}, - {`a\"b`, true, "a\"b", "", ""}, - {`a\'b`, true, "a'b", "", ""}, - - {`a\x00b`, true, "a\x00b", "", ""}, - {`a\x7fb`, true, "a\x7fb", "", ""}, - {`a\x80b`, false, "", "", "bad hex escape sequence"}, - - {`a\u`, false, "", "", "bad unicode rune"}, - {`a\uz`, false, "", "", "bad unicode escape sequence"}, - {`a\u{}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{A}b`, true, "a\u000ab", "", ""}, - {`a\u{aB}b`, true, "a\u00abb", "", ""}, - {`a\u{AbC}b`, true, "a\u0abcb", "", ""}, - {`a\u{aBcD}b`, true, "a\uabcdb", "", ""}, - {`a\u{AbCdE}b`, true, "a\U000abcdeb", "", ""}, - {`a\u{10cDeF}b`, true, "a\U0010cdefb", "", ""}, - {`a\u{ffffff}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{0000000}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{d7ff}b`, true, "a\ud7ffb", "", ""}, - {`a\u{d800}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{dfff}b`, false, "", "", "bad unicode escape sequence"}, - {`a\u{e000}b`, true, "a\ue000b", "", ""}, - {`a\u{10ffff}b`, true, "a\U0010ffffb", "", ""}, - {`a\u{110000}b`, false, "", "", "bad unicode escape sequence"}, - - {`*`, true, "", "*", ""}, - {`*hello*how*are*you`, true, "", "*hello*how*are*you", ""}, - {`hello*how*are*you`, true, "hello", "*how*are*you", ""}, - {`\**`, true, "*", "*", ""}, - - {`\`, false, "", "", "bad unicode rune"}, - {`\a`, false, "", "", "bad char escape"}, - {`\x`, false, "", "", "bad unicode rune"}, - {`\xz`, false, "", "", "bad hex escape sequence"}, - {`\xa`, false, "", "", "bad unicode rune"}, - {`\xaz`, false, "", "", "bad hex escape sequence"}, - {`\{`, false, "", "", "bad char escape"}, - {`\{z`, false, "", "", "bad char escape"}, - {`\{0`, false, "", "", "bad char escape"}, - {`\{0z`, false, "", "", "bad char escape"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, rem, err := rustUnquote([]byte(tt.input), true) - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - testutil.Equals(t, got, tt.want) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got, tt.want) - testutil.Equals(t, string(rem), tt.wantRem) - } - }) - } - } -} - -func TestFakeRustQuote(t *testing.T) { - t.Parallel() - out := FakeRustQuote("hello") - testutil.Equals(t, out, `"hello"`) -} - -func TestPatternFromStringLiteral(t *testing.T) { - t.Parallel() - tests := []struct { - input string - wantOk bool - want []PatternComponent - wantErr string - }{ - {`""`, true, nil, ""}, - {`"a"`, true, []PatternComponent{{false, "a"}}, ""}, - {`"*"`, true, []PatternComponent{{true, ""}}, ""}, - {`"*a"`, true, []PatternComponent{{true, "a"}}, ""}, - {`"a*"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, - {`"**"`, true, []PatternComponent{{true, ""}}, ""}, - {`"**a"`, true, []PatternComponent{{true, "a"}}, ""}, - {`"a**"`, true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, - {`"*a*"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, - {`"**a**"`, true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, - {`"abra*ca"`, true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, - }, ""}, - {`"abra**ca"`, true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, - }, ""}, - {`"*abra*ca"`, true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, - }, ""}, - {`"abra*ca*"`, true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, {true, ""}, - }, ""}, - {`"*abra*ca*"`, true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, {true, ""}, - }, ""}, - {`"*abra*ca*dabra"`, true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, {true, "dabra"}, - }, ""}, - {`"*abra*c\**da\*ra"`, true, []PatternComponent{ - {true, "abra"}, {true, "c*"}, {true, "da*ra"}, - }, ""}, - {`"\u"`, false, nil, "bad unicode rune"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, err := NewPattern(tt.input) - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got.Comps, tt.want) - testutil.Equals(t, got.String(), tt.input) - } - }) - } -} - -func TestScanner(t *testing.T) { - t.Parallel() - t.Run("SrcError", func(t *testing.T) { - t.Parallel() - wantErr := fmt.Errorf("wantErr") - r := &readerMock{ - ReadFunc: func(_ []byte) (int, error) { - return 0, wantErr - }, - } - var s scanner - s.Init(r) - out := s.next() - testutil.Equals(t, out, specialRuneEOF) - }) - - t.Run("MidEmojiEOF", func(t *testing.T) { - t.Parallel() - var s scanner - var eof bool - str := []byte(string(`🐐`)) - r := &readerMock{ - ReadFunc: func(p []byte) (int, error) { - if eof { - return 0, io.EOF - } - p[0] = str[0] - eof = true - return 1, nil - }, - } - s.Init(r) - out := s.next() - testutil.Equals(t, out, utf8.RuneError) - out = s.next() - testutil.Equals(t, out, specialRuneEOF) - }) - - t.Run("NotAsciiEmoji", func(t *testing.T) { - t.Parallel() - var s scanner - s.Init(strings.NewReader(`🐐`)) - out := s.next() - testutil.Equals(t, out, '🐐') - }) - - t.Run("InvalidUTF8", func(t *testing.T) { - t.Parallel() - var s scanner - s.Init(strings.NewReader(string([]byte{0x80, 0x81}))) - out := s.next() - testutil.Equals(t, out, utf8.RuneError) - }) - - t.Run("tokenTextNone", func(t *testing.T) { - t.Parallel() - var s scanner - s.Init(strings.NewReader("")) - out := s.tokenText() - testutil.Equals(t, out, "") - }) -} - -func TestDigitVal(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in rune - out int - }{ - {"happy", '0', 0}, - {"hex", 'f', 15}, - {"sad", 'g', 16}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - out := digitVal(tt.in) - testutil.Equals(t, out, tt.out) - }) - } -} From 69b29c5b0faa261fd189cde1f9e6e9d00974cfd4 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 16:43:50 -0600 Subject: [PATCH 076/216] x/exp/ast: appease linter Addresses IDX-142 Signed-off-by: philhassey --- x/exp/ast/eval_test.go | 4 ++-- x/exp/ast/value.go | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/x/exp/ast/eval_test.go b/x/exp/ast/eval_test.go index b9d00e4b..59559fef 100644 --- a/x/exp/ast/eval_test.go +++ b/x/exp/ast/eval_test.go @@ -1714,8 +1714,8 @@ func TestInNode(t *testing.T) { Parents: ps, } } - EvalContext := EvalContext{Entities: entities} - v, err := n.Eval(&EvalContext) + ec := EvalContext{Entities: entities} + v, err := n.Eval(&ec) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index 4a47ae97..45dfba0a 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -55,6 +55,7 @@ func SetNodes(nodes ...Node) Node { // Record is a convenience function that wraps concrete instances of a Cedar Record type // types in AST value nodes and passes them along to RecordNodes. func Record(r types.Record) Node { + // TODO: this results in a double allocation, fix that recordNodes := map[types.String]Node{} for k, v := range r { recordNodes[types.String(k)] = valueToNode(v) From 098e8891bee406d8d2c008fc7f5e8769f6241aa6 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 16:49:17 -0600 Subject: [PATCH 077/216] x/exp/ast: fix record creation being non-deterministic Addresses IDX-142 Signed-off-by: philhassey --- x/exp/ast/cedar_unmarshal.go | 12 +++++++----- x/exp/ast/value.go | 13 +++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/x/exp/ast/cedar_unmarshal.go b/x/exp/ast/cedar_unmarshal.go index 8abbfca2..243310f8 100644 --- a/x/exp/ast/cedar_unmarshal.go +++ b/x/exp/ast/cedar_unmarshal.go @@ -807,14 +807,15 @@ func (p *parser) expressions(endOfListMarker string) ([]Node, error) { func (p *parser) record() (Node, error) { var res Node - entries := map[types.String]Node{} + var elements []RecordElement + known := map[types.String]struct{}{} for { t := p.peek() if t.Text == "}" { p.advance() - return RecordNodes(entries), nil + return RecordElements(elements...), nil } - if len(entries) > 0 { + if len(elements) > 0 { if err := p.exact(","); err != nil { return res, err } @@ -824,10 +825,11 @@ func (p *parser) record() (Node, error) { return res, err } - if _, ok := entries[k]; ok { + if _, ok := known[k]; ok { return res, p.errorf("duplicate key: %v", k) } - entries[k] = v + known[k] = struct{}{} + elements = append(elements, RecordElement{Key: k, Value: v}) } } diff --git a/x/exp/ast/value.go b/x/exp/ast/value.go index 45dfba0a..ea221e03 100644 --- a/x/exp/ast/value.go +++ b/x/exp/ast/value.go @@ -81,6 +81,19 @@ func RecordNodes(entries map[types.String]Node) Node { return newNode(res) } +type RecordElement struct { + Key types.String + Value Node +} + +func RecordElements(elements ...RecordElement) Node { + var res nodeTypeRecord + for _, e := range elements { + res.Elements = append(res.Elements, recordElement{Key: e.Key, Value: e.Value.v}) + } + return newNode(res) +} + func EntityType(e types.String) Node { return newValueNode(e) } From c19eee0f94aed0d657cd8cf03da5331fcf121cde Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 16:52:19 -0600 Subject: [PATCH 078/216] x/exp/ast: add note to TODO about errors Addresses IDX-142 Signed-off-by: philhassey --- x/exp/ast/cedar_marshal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/exp/ast/cedar_marshal.go b/x/exp/ast/cedar_marshal.go index e16d82b9..2b83551d 100644 --- a/x/exp/ast/cedar_marshal.go +++ b/x/exp/ast/cedar_marshal.go @@ -4,7 +4,7 @@ import ( "bytes" ) -// TODO: Add errors to all of this! +// TODO: Add errors to all of this! TODO: review this ask, I'm not sure any real errors are possible. All buf errors are panics. func (p *Policy) MarshalCedar(buf *bytes.Buffer) { for _, a := range p.annotations { a.MarshalCedar(buf) From 7edf69cc699489bdf4c0a88f698c68213473d943 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 8 Aug 2024 16:54:16 -0600 Subject: [PATCH 079/216] ast: move x/exp/ast to ast Addresses IDX-142 Signed-off-by: philhassey --- {x/exp/ast => ast}/annotation.go | 0 {x/exp/ast => ast}/ast_test.go | 2 +- {x/exp/ast => ast}/cedar.go | 0 {x/exp/ast => ast}/cedar_fuzz_test.go | 0 {x/exp/ast => ast}/cedar_marshal.go | 0 {x/exp/ast => ast}/cedar_parse_test.go | 2 +- {x/exp/ast => ast}/cedar_tokenize.go | 0 {x/exp/ast => ast}/cedar_tokenize_mocks_test.go | 0 {x/exp/ast => ast}/cedar_tokenize_test.go | 0 {x/exp/ast => ast}/cedar_unmarshal.go | 0 {x/exp/ast => ast}/cedar_unmarshal_test.go | 2 +- {x/exp/ast => ast}/eval_compile.go | 0 {x/exp/ast => ast}/eval_convert.go | 0 {x/exp/ast => ast}/eval_impl.go | 0 {x/exp/ast => ast}/eval_test.go | 0 {x/exp/ast => ast}/extensions.go | 0 {x/exp/ast => ast}/json.go | 0 {x/exp/ast => ast}/json_marshal.go | 0 {x/exp/ast => ast}/json_test.go | 2 +- {x/exp/ast => ast}/json_unmarshal.go | 0 {x/exp/ast => ast}/node.go | 0 {x/exp/ast => ast}/operator.go | 0 {x/exp/ast => ast}/policy.go | 0 {x/exp/ast => ast}/scope.go | 0 {x/exp/ast => ast}/value.go | 0 {x/exp/ast => ast}/variable.go | 0 cedar.go | 2 +- 27 files changed, 5 insertions(+), 5 deletions(-) rename {x/exp/ast => ast}/annotation.go (100%) rename {x/exp/ast => ast}/ast_test.go (97%) rename {x/exp/ast => ast}/cedar.go (100%) rename {x/exp/ast => ast}/cedar_fuzz_test.go (100%) rename {x/exp/ast => ast}/cedar_marshal.go (100%) rename {x/exp/ast => ast}/cedar_parse_test.go (99%) rename {x/exp/ast => ast}/cedar_tokenize.go (100%) rename {x/exp/ast => ast}/cedar_tokenize_mocks_test.go (100%) rename {x/exp/ast => ast}/cedar_tokenize_test.go (100%) rename {x/exp/ast => ast}/cedar_unmarshal.go (100%) rename {x/exp/ast => ast}/cedar_unmarshal_test.go (99%) rename {x/exp/ast => ast}/eval_compile.go (100%) rename {x/exp/ast => ast}/eval_convert.go (100%) rename {x/exp/ast => ast}/eval_impl.go (100%) rename {x/exp/ast => ast}/eval_test.go (100%) rename {x/exp/ast => ast}/extensions.go (100%) rename {x/exp/ast => ast}/json.go (100%) rename {x/exp/ast => ast}/json_marshal.go (100%) rename {x/exp/ast => ast}/json_test.go (99%) rename {x/exp/ast => ast}/json_unmarshal.go (100%) rename {x/exp/ast => ast}/node.go (100%) rename {x/exp/ast => ast}/operator.go (100%) rename {x/exp/ast => ast}/policy.go (100%) rename {x/exp/ast => ast}/scope.go (100%) rename {x/exp/ast => ast}/value.go (100%) rename {x/exp/ast => ast}/variable.go (100%) diff --git a/x/exp/ast/annotation.go b/ast/annotation.go similarity index 100% rename from x/exp/ast/annotation.go rename to ast/annotation.go diff --git a/x/exp/ast/ast_test.go b/ast/ast_test.go similarity index 97% rename from x/exp/ast/ast_test.go rename to ast/ast_test.go index b75b1a5d..966cdb0f 100644 --- a/x/exp/ast/ast_test.go +++ b/ast/ast_test.go @@ -3,8 +3,8 @@ package ast_test import ( "testing" + "github.com/cedar-policy/cedar-go/ast" "github.com/cedar-policy/cedar-go/types" - "github.com/cedar-policy/cedar-go/x/exp/ast" ) // These tests mostly verify that policy ASTs compile diff --git a/x/exp/ast/cedar.go b/ast/cedar.go similarity index 100% rename from x/exp/ast/cedar.go rename to ast/cedar.go diff --git a/x/exp/ast/cedar_fuzz_test.go b/ast/cedar_fuzz_test.go similarity index 100% rename from x/exp/ast/cedar_fuzz_test.go rename to ast/cedar_fuzz_test.go diff --git a/x/exp/ast/cedar_marshal.go b/ast/cedar_marshal.go similarity index 100% rename from x/exp/ast/cedar_marshal.go rename to ast/cedar_marshal.go diff --git a/x/exp/ast/cedar_parse_test.go b/ast/cedar_parse_test.go similarity index 99% rename from x/exp/ast/cedar_parse_test.go rename to ast/cedar_parse_test.go index 10d151cc..1599a612 100644 --- a/x/exp/ast/cedar_parse_test.go +++ b/ast/cedar_parse_test.go @@ -4,8 +4,8 @@ import ( "bytes" "testing" + "github.com/cedar-policy/cedar-go/ast" "github.com/cedar-policy/cedar-go/internal/testutil" - "github.com/cedar-policy/cedar-go/x/exp/ast" ) func TestParse(t *testing.T) { diff --git a/x/exp/ast/cedar_tokenize.go b/ast/cedar_tokenize.go similarity index 100% rename from x/exp/ast/cedar_tokenize.go rename to ast/cedar_tokenize.go diff --git a/x/exp/ast/cedar_tokenize_mocks_test.go b/ast/cedar_tokenize_mocks_test.go similarity index 100% rename from x/exp/ast/cedar_tokenize_mocks_test.go rename to ast/cedar_tokenize_mocks_test.go diff --git a/x/exp/ast/cedar_tokenize_test.go b/ast/cedar_tokenize_test.go similarity index 100% rename from x/exp/ast/cedar_tokenize_test.go rename to ast/cedar_tokenize_test.go diff --git a/x/exp/ast/cedar_unmarshal.go b/ast/cedar_unmarshal.go similarity index 100% rename from x/exp/ast/cedar_unmarshal.go rename to ast/cedar_unmarshal.go diff --git a/x/exp/ast/cedar_unmarshal_test.go b/ast/cedar_unmarshal_test.go similarity index 99% rename from x/exp/ast/cedar_unmarshal_test.go rename to ast/cedar_unmarshal_test.go index bc118953..e6068dc9 100644 --- a/x/exp/ast/cedar_unmarshal_test.go +++ b/ast/cedar_unmarshal_test.go @@ -4,9 +4,9 @@ import ( "bytes" "testing" + "github.com/cedar-policy/cedar-go/ast" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" - "github.com/cedar-policy/cedar-go/x/exp/ast" ) var johnny = types.EntityUID{ diff --git a/x/exp/ast/eval_compile.go b/ast/eval_compile.go similarity index 100% rename from x/exp/ast/eval_compile.go rename to ast/eval_compile.go diff --git a/x/exp/ast/eval_convert.go b/ast/eval_convert.go similarity index 100% rename from x/exp/ast/eval_convert.go rename to ast/eval_convert.go diff --git a/x/exp/ast/eval_impl.go b/ast/eval_impl.go similarity index 100% rename from x/exp/ast/eval_impl.go rename to ast/eval_impl.go diff --git a/x/exp/ast/eval_test.go b/ast/eval_test.go similarity index 100% rename from x/exp/ast/eval_test.go rename to ast/eval_test.go diff --git a/x/exp/ast/extensions.go b/ast/extensions.go similarity index 100% rename from x/exp/ast/extensions.go rename to ast/extensions.go diff --git a/x/exp/ast/json.go b/ast/json.go similarity index 100% rename from x/exp/ast/json.go rename to ast/json.go diff --git a/x/exp/ast/json_marshal.go b/ast/json_marshal.go similarity index 100% rename from x/exp/ast/json_marshal.go rename to ast/json_marshal.go diff --git a/x/exp/ast/json_test.go b/ast/json_test.go similarity index 99% rename from x/exp/ast/json_test.go rename to ast/json_test.go index a61b7d0f..8701ed1b 100644 --- a/x/exp/ast/json_test.go +++ b/ast/json_test.go @@ -4,9 +4,9 @@ import ( "encoding/json" "testing" + "github.com/cedar-policy/cedar-go/ast" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" - "github.com/cedar-policy/cedar-go/x/exp/ast" ) func TestUnmarshalJSON(t *testing.T) { diff --git a/x/exp/ast/json_unmarshal.go b/ast/json_unmarshal.go similarity index 100% rename from x/exp/ast/json_unmarshal.go rename to ast/json_unmarshal.go diff --git a/x/exp/ast/node.go b/ast/node.go similarity index 100% rename from x/exp/ast/node.go rename to ast/node.go diff --git a/x/exp/ast/operator.go b/ast/operator.go similarity index 100% rename from x/exp/ast/operator.go rename to ast/operator.go diff --git a/x/exp/ast/policy.go b/ast/policy.go similarity index 100% rename from x/exp/ast/policy.go rename to ast/policy.go diff --git a/x/exp/ast/scope.go b/ast/scope.go similarity index 100% rename from x/exp/ast/scope.go rename to ast/scope.go diff --git a/x/exp/ast/value.go b/ast/value.go similarity index 100% rename from x/exp/ast/value.go rename to ast/value.go diff --git a/x/exp/ast/variable.go b/ast/variable.go similarity index 100% rename from x/exp/ast/variable.go rename to ast/variable.go diff --git a/cedar.go b/cedar.go index 58cd3ec9..7cb2011a 100644 --- a/cedar.go +++ b/cedar.go @@ -6,8 +6,8 @@ import ( "slices" "strings" + "github.com/cedar-policy/cedar-go/ast" "github.com/cedar-policy/cedar-go/types" - "github.com/cedar-policy/cedar-go/x/exp/ast" "golang.org/x/exp/maps" ) From a2fd1244d2b70b57ba3e6dadac7277f8903aa5fa Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Mon, 12 Aug 2024 16:39:22 -0700 Subject: [PATCH 080/216] cedar-go/internal/ast: move public ast files into internal package From here, we'll start promoting things that we want to be public. Signed-off-by: philhassey --- cedar.go | 2 +- {ast => internal/ast}/annotation.go | 0 {ast => internal/ast}/ast_test.go | 2 +- {ast => internal/ast}/cedar.go | 0 {ast => internal/ast}/cedar_fuzz_test.go | 0 {ast => internal/ast}/cedar_marshal.go | 0 {ast => internal/ast}/cedar_parse_test.go | 2 +- {ast => internal/ast}/cedar_tokenize.go | 0 {ast => internal/ast}/cedar_tokenize_mocks_test.go | 0 {ast => internal/ast}/cedar_tokenize_test.go | 0 {ast => internal/ast}/cedar_unmarshal.go | 0 {ast => internal/ast}/cedar_unmarshal_test.go | 2 +- {ast => internal/ast}/eval_compile.go | 0 {ast => internal/ast}/eval_convert.go | 0 {ast => internal/ast}/eval_impl.go | 0 {ast => internal/ast}/eval_test.go | 0 {ast => internal/ast}/extensions.go | 0 {ast => internal/ast}/json.go | 0 {ast => internal/ast}/json_marshal.go | 0 {ast => internal/ast}/json_test.go | 2 +- {ast => internal/ast}/json_unmarshal.go | 0 {ast => internal/ast}/node.go | 0 {ast => internal/ast}/operator.go | 0 {ast => internal/ast}/policy.go | 0 {ast => internal/ast}/scope.go | 0 {ast => internal/ast}/value.go | 0 {ast => internal/ast}/variable.go | 0 27 files changed, 5 insertions(+), 5 deletions(-) rename {ast => internal/ast}/annotation.go (100%) rename {ast => internal/ast}/ast_test.go (97%) rename {ast => internal/ast}/cedar.go (100%) rename {ast => internal/ast}/cedar_fuzz_test.go (100%) rename {ast => internal/ast}/cedar_marshal.go (100%) rename {ast => internal/ast}/cedar_parse_test.go (99%) rename {ast => internal/ast}/cedar_tokenize.go (100%) rename {ast => internal/ast}/cedar_tokenize_mocks_test.go (100%) rename {ast => internal/ast}/cedar_tokenize_test.go (100%) rename {ast => internal/ast}/cedar_unmarshal.go (100%) rename {ast => internal/ast}/cedar_unmarshal_test.go (99%) rename {ast => internal/ast}/eval_compile.go (100%) rename {ast => internal/ast}/eval_convert.go (100%) rename {ast => internal/ast}/eval_impl.go (100%) rename {ast => internal/ast}/eval_test.go (100%) rename {ast => internal/ast}/extensions.go (100%) rename {ast => internal/ast}/json.go (100%) rename {ast => internal/ast}/json_marshal.go (100%) rename {ast => internal/ast}/json_test.go (99%) rename {ast => internal/ast}/json_unmarshal.go (100%) rename {ast => internal/ast}/node.go (100%) rename {ast => internal/ast}/operator.go (100%) rename {ast => internal/ast}/policy.go (100%) rename {ast => internal/ast}/scope.go (100%) rename {ast => internal/ast}/value.go (100%) rename {ast => internal/ast}/variable.go (100%) diff --git a/cedar.go b/cedar.go index 7cb2011a..512c3793 100644 --- a/cedar.go +++ b/cedar.go @@ -6,7 +6,7 @@ import ( "slices" "strings" - "github.com/cedar-policy/cedar-go/ast" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/types" "golang.org/x/exp/maps" ) diff --git a/ast/annotation.go b/internal/ast/annotation.go similarity index 100% rename from ast/annotation.go rename to internal/ast/annotation.go diff --git a/ast/ast_test.go b/internal/ast/ast_test.go similarity index 97% rename from ast/ast_test.go rename to internal/ast/ast_test.go index 966cdb0f..c765abab 100644 --- a/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -3,7 +3,7 @@ package ast_test import ( "testing" - "github.com/cedar-policy/cedar-go/ast" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/types" ) diff --git a/ast/cedar.go b/internal/ast/cedar.go similarity index 100% rename from ast/cedar.go rename to internal/ast/cedar.go diff --git a/ast/cedar_fuzz_test.go b/internal/ast/cedar_fuzz_test.go similarity index 100% rename from ast/cedar_fuzz_test.go rename to internal/ast/cedar_fuzz_test.go diff --git a/ast/cedar_marshal.go b/internal/ast/cedar_marshal.go similarity index 100% rename from ast/cedar_marshal.go rename to internal/ast/cedar_marshal.go diff --git a/ast/cedar_parse_test.go b/internal/ast/cedar_parse_test.go similarity index 99% rename from ast/cedar_parse_test.go rename to internal/ast/cedar_parse_test.go index 1599a612..4339abe2 100644 --- a/ast/cedar_parse_test.go +++ b/internal/ast/cedar_parse_test.go @@ -4,7 +4,7 @@ import ( "bytes" "testing" - "github.com/cedar-policy/cedar-go/ast" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/testutil" ) diff --git a/ast/cedar_tokenize.go b/internal/ast/cedar_tokenize.go similarity index 100% rename from ast/cedar_tokenize.go rename to internal/ast/cedar_tokenize.go diff --git a/ast/cedar_tokenize_mocks_test.go b/internal/ast/cedar_tokenize_mocks_test.go similarity index 100% rename from ast/cedar_tokenize_mocks_test.go rename to internal/ast/cedar_tokenize_mocks_test.go diff --git a/ast/cedar_tokenize_test.go b/internal/ast/cedar_tokenize_test.go similarity index 100% rename from ast/cedar_tokenize_test.go rename to internal/ast/cedar_tokenize_test.go diff --git a/ast/cedar_unmarshal.go b/internal/ast/cedar_unmarshal.go similarity index 100% rename from ast/cedar_unmarshal.go rename to internal/ast/cedar_unmarshal.go diff --git a/ast/cedar_unmarshal_test.go b/internal/ast/cedar_unmarshal_test.go similarity index 99% rename from ast/cedar_unmarshal_test.go rename to internal/ast/cedar_unmarshal_test.go index e6068dc9..72d1a00e 100644 --- a/ast/cedar_unmarshal_test.go +++ b/internal/ast/cedar_unmarshal_test.go @@ -4,7 +4,7 @@ import ( "bytes" "testing" - "github.com/cedar-policy/cedar-go/ast" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) diff --git a/ast/eval_compile.go b/internal/ast/eval_compile.go similarity index 100% rename from ast/eval_compile.go rename to internal/ast/eval_compile.go diff --git a/ast/eval_convert.go b/internal/ast/eval_convert.go similarity index 100% rename from ast/eval_convert.go rename to internal/ast/eval_convert.go diff --git a/ast/eval_impl.go b/internal/ast/eval_impl.go similarity index 100% rename from ast/eval_impl.go rename to internal/ast/eval_impl.go diff --git a/ast/eval_test.go b/internal/ast/eval_test.go similarity index 100% rename from ast/eval_test.go rename to internal/ast/eval_test.go diff --git a/ast/extensions.go b/internal/ast/extensions.go similarity index 100% rename from ast/extensions.go rename to internal/ast/extensions.go diff --git a/ast/json.go b/internal/ast/json.go similarity index 100% rename from ast/json.go rename to internal/ast/json.go diff --git a/ast/json_marshal.go b/internal/ast/json_marshal.go similarity index 100% rename from ast/json_marshal.go rename to internal/ast/json_marshal.go diff --git a/ast/json_test.go b/internal/ast/json_test.go similarity index 99% rename from ast/json_test.go rename to internal/ast/json_test.go index 8701ed1b..feb3bb76 100644 --- a/ast/json_test.go +++ b/internal/ast/json_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/cedar-policy/cedar-go/ast" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) diff --git a/ast/json_unmarshal.go b/internal/ast/json_unmarshal.go similarity index 100% rename from ast/json_unmarshal.go rename to internal/ast/json_unmarshal.go diff --git a/ast/node.go b/internal/ast/node.go similarity index 100% rename from ast/node.go rename to internal/ast/node.go diff --git a/ast/operator.go b/internal/ast/operator.go similarity index 100% rename from ast/operator.go rename to internal/ast/operator.go diff --git a/ast/policy.go b/internal/ast/policy.go similarity index 100% rename from ast/policy.go rename to internal/ast/policy.go diff --git a/ast/scope.go b/internal/ast/scope.go similarity index 100% rename from ast/scope.go rename to internal/ast/scope.go diff --git a/ast/value.go b/internal/ast/value.go similarity index 100% rename from ast/value.go rename to internal/ast/value.go diff --git a/ast/variable.go b/internal/ast/variable.go similarity index 100% rename from ast/variable.go rename to internal/ast/variable.go From 5759171a1fab4e0edbf317d2715a6ee1193b4909 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Mon, 12 Aug 2024 17:35:53 -0700 Subject: [PATCH 081/216] cedar-go/ast: add projections for all public methods into a new public AST package Signed-off-by: philhassey --- ast/annotation.go | 37 +++++++++ ast/ast_test.go | 70 ++++++++++++++++ ast/node.go | 7 ++ ast/operator.go | 167 +++++++++++++++++++++++++++++++++++++++ ast/policy.go | 23 ++++++ ast/scope.go | 49 ++++++++++++ ast/value.go | 109 +++++++++++++++++++++++++ ast/variable.go | 19 +++++ internal/ast/node.go | 4 + internal/ast/operator.go | 2 - internal/ast/value.go | 4 - 11 files changed, 485 insertions(+), 6 deletions(-) create mode 100644 ast/annotation.go create mode 100644 ast/ast_test.go create mode 100644 ast/node.go create mode 100644 ast/operator.go create mode 100644 ast/policy.go create mode 100644 ast/scope.go create mode 100644 ast/value.go create mode 100644 ast/variable.go diff --git a/ast/annotation.go b/ast/annotation.go new file mode 100644 index 00000000..d1083823 --- /dev/null +++ b/ast/annotation.go @@ -0,0 +1,37 @@ +package ast + +import ( + "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/types" +) + +type Annotations struct { + *ast.Annotations +} + +// Annotation allows AST constructors to make policy in a similar shape to textual Cedar with +// annotations appearing before the actual policy scope: +// +// ast := Annotation("foo", "bar"). +// Annotation("baz", "quux"). +// Permit(). +// PrincipalEq(superUser) +func Annotation(name, value types.String) *Annotations { + return &Annotations{ast.Annotation(name, value)} +} + +func (a *Annotations) Annotation(name, value types.String) *Annotations { + return &Annotations{a.Annotations.Annotation(name, value)} +} + +func (a *Annotations) Permit() *Policy { + return &Policy{a.Annotations.Permit()} +} + +func (a *Annotations) Forbid() *Policy { + return &Policy{a.Annotations.Forbid()} +} + +func (p *Policy) Annotate(name, value types.String) *Policy { + return &Policy{p.Policy.Annotate(name, value)} +} diff --git a/ast/ast_test.go b/ast/ast_test.go new file mode 100644 index 00000000..966cdb0f --- /dev/null +++ b/ast/ast_test.go @@ -0,0 +1,70 @@ +package ast_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/ast" + "github.com/cedar-policy/cedar-go/types" +) + +// These tests mostly verify that policy ASTs compile +func TestAst(t *testing.T) { + t.Parallel() + + johnny := types.NewEntityUID("User", "johnny") + sow := types.NewEntityUID("Action", "sow") + cast := types.NewEntityUID("Action", "cast") + + // @example("one") + // permit ( + // principal == User::"johnny" + // action in [Action::"sow", Action::"cast"] + // resource + // ) + // when { true } + // unless { false }; + _ = ast.Annotation("example", "one"). + Permit(). + PrincipalIsIn("User", johnny). + ActionInSet(sow, cast). + When(ast.True()). + Unless(ast.False()) + + // @example("two") + // forbid (principal, action, resource) + // when { resource.tags.contains("private") } + // unless { resource in principal.allowed_resources }; + private := types.String("private") + _ = ast.Annotation("example", "two"). + Forbid(). + When( + ast.Resource().Access("tags").Contains(ast.String(private)), + ). + Unless( + ast.Resource().In(ast.Principal().Access("allowed_resources")), + ) + + // forbid (principal, action, resource) + // when { {x: "value"}.x == "value" } + // when { {x: 1 + context.fooCount}.x == 3 } + // when { [1, 2 + 3, context.fooCount].contains(1) }; + simpleRecord := types.Record{ + "x": types.String("value"), + } + _ = ast.Forbid(). + When( + ast.Record(simpleRecord).Access("x").Equals(ast.String("value")), + ). + When( + ast.RecordNodes(map[types.String]ast.Node{ + "x": ast.Long(1).Plus(ast.Context().Access("fooCount")), + }).Access("x").Equals(ast.Long(3)), + ). + When( + ast.SetNodes( + ast.Long(1), + ast.Long(2).Plus(ast.Long(3)), + ast.Context().Access("fooCount"), + ).Contains(ast.Long(1)), + ) +} diff --git a/ast/node.go b/ast/node.go new file mode 100644 index 00000000..ed30ec13 --- /dev/null +++ b/ast/node.go @@ -0,0 +1,7 @@ +package ast + +import "github.com/cedar-policy/cedar-go/internal/ast" + +type Node struct { + ast.Node +} diff --git a/ast/operator.go b/ast/operator.go new file mode 100644 index 00000000..98c3f786 --- /dev/null +++ b/ast/operator.go @@ -0,0 +1,167 @@ +package ast + +import ( + "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/types" +) + +// ____ _ +// / ___|___ _ __ ___ _ __ __ _ _ __(_)___ ___ _ __ +// | | / _ \| '_ ` _ \| '_ \ / _` | '__| / __|/ _ \| '_ \ +// | |__| (_) | | | | | | |_) | (_| | | | \__ \ (_) | | | | +// \____\___/|_| |_| |_| .__/ \__,_|_| |_|___/\___/|_| |_| +// |_| + +func (lhs Node) Equals(rhs Node) Node { + return Node{lhs.Node.Equals(rhs.Node)} +} + +func (lhs Node) NotEquals(rhs Node) Node { + return Node{lhs.Node.NotEquals(rhs.Node)} +} + +func (lhs Node) LessThan(rhs Node) Node { + return Node{lhs.Node.LessThan(rhs.Node)} +} + +func (lhs Node) LessThanOrEqual(rhs Node) Node { + return Node{lhs.Node.LessThanOrEqual(rhs.Node)} +} + +func (lhs Node) GreaterThan(rhs Node) Node { + return Node{lhs.Node.GreaterThan(rhs.Node)} +} + +func (lhs Node) GreaterThanOrEqual(rhs Node) Node { + return Node{lhs.Node.GreaterThanOrEqual(rhs.Node)} +} + +func (lhs Node) LessThanExt(rhs Node) Node { + return Node{lhs.Node.LessThanExt(rhs.Node)} +} + +func (lhs Node) LessThanOrEqualExt(rhs Node) Node { + return Node{lhs.Node.LessThanOrEqualExt(rhs.Node)} +} + +func (lhs Node) GreaterThanExt(rhs Node) Node { + return Node{lhs.Node.GreaterThanExt(rhs.Node)} +} + +func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { + return Node{lhs.Node.GreaterThanOrEqualExt(rhs.Node)} +} + +func (lhs Node) Like(pattern types.Pattern) Node { + return Node{lhs.Node.Like(pattern)} +} + +// _ _ _ +// | | ___ __ _(_) ___ __ _| | +// | | / _ \ / _` | |/ __/ _` | | +// | |__| (_) | (_| | | (_| (_| | | +// |_____\___/ \__, |_|\___\__,_|_| +// |___/ + +func (lhs Node) And(rhs Node) Node { + return Node{lhs.Node.And(rhs.Node)} +} + +func (lhs Node) Or(rhs Node) Node { + return Node{lhs.Node.Or(rhs.Node)} +} + +func Not(rhs Node) Node { + return Node{ast.Not(rhs.Node)} +} + +func If(condition Node, ifTrue Node, ifFalse Node) Node { + return Node{ast.If(condition.Node, ifTrue.Node, ifFalse.Node)} +} + +// _ _ _ _ _ _ +// / \ _ __(_) |_| |__ _ __ ___ ___| |_(_) ___ +// / _ \ | '__| | __| '_ \| '_ ` _ \ / _ \ __| |/ __| +// / ___ \| | | | |_| | | | | | | | | __/ |_| | (__ +// /_/ \_\_| |_|\__|_| |_|_| |_| |_|\___|\__|_|\___| + +func (lhs Node) Plus(rhs Node) Node { + return Node{lhs.Node.Plus(rhs.Node)} +} + +func (lhs Node) Minus(rhs Node) Node { + return Node{lhs.Node.Minus(rhs.Node)} +} + +func (lhs Node) Times(rhs Node) Node { + return Node{lhs.Node.Times(rhs.Node)} +} + +func Negate(rhs Node) Node { + return Node{ast.Negate(rhs.Node)} +} + +// _ _ _ _ +// | | | (_) ___ _ __ __ _ _ __ ___| |__ _ _ +// | |_| | |/ _ \ '__/ _` | '__/ __| '_ \| | | | +// | _ | | __/ | | (_| | | | (__| | | | |_| | +// |_| |_|_|\___|_| \__,_|_| \___|_| |_|\__, | +// |___/ + +func (lhs Node) In(rhs Node) Node { + return Node{lhs.Node.In(rhs.Node)} +} + +func (lhs Node) Is(entityType types.Path) Node { + return Node{lhs.Node.Is(entityType)} +} + +func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { + return Node{lhs.Node.IsIn(entityType, rhs.Node)} +} + +func (lhs Node) Contains(rhs Node) Node { + return Node{lhs.Node.Contains(rhs.Node)} +} + +func (lhs Node) ContainsAll(rhs Node) Node { + return Node{lhs.Node.ContainsAll(rhs.Node)} +} + +func (lhs Node) ContainsAny(rhs Node) Node { + return Node{lhs.Node.ContainsAny(rhs.Node)} +} + +func (lhs Node) Access(attr string) Node { + return Node{lhs.Node.Access(attr)} +} + +func (lhs Node) Has(attr string) Node { + return Node{lhs.Node.Has(attr)} +} + +// ___ ____ _ _ _ +// |_ _| _ \ / \ __| | __| |_ __ ___ ___ ___ +// | || |_) / _ \ / _` |/ _` | '__/ _ \/ __/ __| +// | || __/ ___ \ (_| | (_| | | | __/\__ \__ \ +// |___|_| /_/ \_\__,_|\__,_|_| \___||___/___/ + +func (lhs Node) IsIpv4() Node { + return Node{lhs.Node.IsIpv4()} +} + +func (lhs Node) IsIpv6() Node { + return Node{lhs.Node.IsIpv6()} +} + +func (lhs Node) IsMulticast() Node { + return Node{lhs.Node.IsMulticast()} +} + +func (lhs Node) IsLoopback() Node { + return Node{lhs.Node.IsLoopback()} +} + +func (lhs Node) IsInRange(rhs Node) Node { + return Node{lhs.Node.IsInRange(rhs.Node)} +} diff --git a/ast/policy.go b/ast/policy.go new file mode 100644 index 00000000..07b9dbb7 --- /dev/null +++ b/ast/policy.go @@ -0,0 +1,23 @@ +package ast + +import "github.com/cedar-policy/cedar-go/internal/ast" + +type Policy struct { + *ast.Policy +} + +func Permit() *Policy { + return &Policy{ast.Permit()} +} + +func Forbid() *Policy { + return &Policy{ast.Forbid()} +} + +func (p *Policy) When(node Node) *Policy { + return &Policy{p.Policy.When(node.Node)} +} + +func (p *Policy) Unless(node Node) *Policy { + return &Policy{p.Policy.Unless(node.Node)} +} diff --git a/ast/scope.go b/ast/scope.go new file mode 100644 index 00000000..40d98444 --- /dev/null +++ b/ast/scope.go @@ -0,0 +1,49 @@ +package ast + +import ( + "github.com/cedar-policy/cedar-go/types" +) + +func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { + return &Policy{p.Policy.PrincipalEq(entity)} +} + +func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { + return &Policy{p.Policy.PrincipalIn(entity)} +} + +func (p *Policy) PrincipalIs(entityType types.Path) *Policy { + return &Policy{p.Policy.PrincipalIs(entityType)} +} + +func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { + return &Policy{p.Policy.PrincipalIsIn(entityType, entity)} +} + +func (p *Policy) ActionEq(entity types.EntityUID) *Policy { + return &Policy{p.Policy.ActionEq(entity)} +} + +func (p *Policy) ActionIn(entity types.EntityUID) *Policy { + return &Policy{p.Policy.ActionIn(entity)} +} + +func (p *Policy) ActionInSet(entities ...types.EntityUID) *Policy { + return &Policy{p.Policy.ActionInSet(entities...)} +} + +func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { + return &Policy{p.Policy.ResourceEq(entity)} +} + +func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { + return &Policy{p.Policy.ResourceIn(entity)} +} + +func (p *Policy) ResourceIs(entityType types.Path) *Policy { + return &Policy{p.Policy.ResourceIs(entityType)} +} + +func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { + return &Policy{p.Policy.ResourceIsIn(entityType, entity)} +} diff --git a/ast/value.go b/ast/value.go new file mode 100644 index 00000000..31377199 --- /dev/null +++ b/ast/value.go @@ -0,0 +1,109 @@ +package ast + +import ( + "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/types" +) + +func Boolean(b types.Boolean) Node { + return Node{ast.Boolean(b)} +} + +func True() Node { + return Boolean(true) +} + +func False() Node { + return Boolean(false) +} + +func String(s types.String) Node { + return Node{ast.String(s)} +} + +func Long(l types.Long) Node { + return Node{ast.Long(l)} +} + +// Set is a convenience function that wraps concrete instances of a Cedar Set type +// types in AST value nodes and passes them along to SetNodes. +func Set(s types.Set) Node { + return Node{ast.Set(s)} +} + +// SetNodes allows for a complex set definition with values potentially +// being Cedar expressions of their own. For example, this Cedar text: +// +// [1, 2 + 3, context.fooCount] +// +// could be expressed in Golang as: +// +// ast.SetNodes( +// ast.Long(1), +// ast.Long(2).Plus(ast.Long(3)), +// ast.Context().Access("fooCount"), +// ) +func SetNodes(nodes ...Node) Node { + var astNodes []ast.Node + for _, n := range nodes { + astNodes = append(astNodes, n.Node) + } + return Node{ast.SetNodes(astNodes...)} +} + +// Record is a convenience function that wraps concrete instances of a Cedar Record type +// types in AST value nodes and passes them along to RecordNodes. +func Record(r types.Record) Node { + return Node{ast.Record(r)} +} + +// RecordNodes allows for a complex record definition with values potentially +// being Cedar expressions of their own. For example, this Cedar text: +// +// {"x": 1 + context.fooCount} +// +// could be expressed in Golang as: +// +// ast.RecordNodes(map[types.String]Node{ +// "x": ast.Long(1).Plus(ast.Context().Access("fooCount"))}, +// }) +func RecordNodes(entries map[types.String]Node) Node { + astNodes := map[types.String]ast.Node{} + for k, v := range entries { + astNodes[k] = v.Node + } + return Node{ast.RecordNodes(astNodes)} +} + +type RecordElement struct { + Key types.String + Value Node +} + +func RecordElements(elements ...RecordElement) Node { + var astNodes []ast.RecordElement + for _, v := range elements { + astNodes = append(astNodes, ast.RecordElement{Key: v.Key, Value: v.Value.Node}) + } + return Node{ast.RecordElements(astNodes...)} +} + +func EntityUID(e types.EntityUID) Node { + return Node{ast.EntityUID(e)} +} + +func Decimal(d types.Decimal) Node { + return Node{ast.Decimal(d)} +} + +func IPAddr(i types.IPAddr) Node { + return Node{ast.IPAddr(i)} +} + +func ExtensionCall(name types.String, args ...Node) Node { + var astNodes []ast.Node + for _, v := range args { + astNodes = append(astNodes, v.Node) + } + return Node{ast.ExtensionCall(name, astNodes...)} +} diff --git a/ast/variable.go b/ast/variable.go new file mode 100644 index 00000000..66f16284 --- /dev/null +++ b/ast/variable.go @@ -0,0 +1,19 @@ +package ast + +import "github.com/cedar-policy/cedar-go/internal/ast" + +func Principal() Node { + return Node{ast.Principal()} +} + +func Action() Node { + return Node{ast.Action()} +} + +func Resource() Node { + return Node{ast.Resource()} +} + +func Context() Node { + return Node{ast.Context()} +} diff --git a/internal/ast/node.go b/internal/ast/node.go index 850562a3..7ba210aa 100644 --- a/internal/ast/node.go +++ b/internal/ast/node.go @@ -14,6 +14,10 @@ func newNode(v node) Node { return Node{v: v} } +func NewNode(v node) Node { + return Node{v: v} +} + type strOpNode struct { Arg node Value types.String diff --git a/internal/ast/operator.go b/internal/ast/operator.go index a9b78ca1..c3a8de58 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -129,8 +129,6 @@ func (lhs Node) ContainsAny(rhs Node) Node { return newNode(nodeTypeContainsAny{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) } -// Access is a convenience function that wraps a simple string -// in an ast.String() and passes it along to AccessNode. func (lhs Node) Access(attr string) Node { return newNode(nodeTypeAccess{strOpNode: strOpNode{Arg: lhs.v, Value: types.String(attr)}}) } diff --git a/internal/ast/value.go b/internal/ast/value.go index ea221e03..becebbd7 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -94,10 +94,6 @@ func RecordElements(elements ...RecordElement) Node { return newNode(res) } -func EntityType(e types.String) Node { - return newValueNode(e) -} - func EntityUID(e types.EntityUID) Node { return newValueNode(e) } From dab7fe908a6ce706b738895c5f380eec0bba8a58 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 10:09:25 -0700 Subject: [PATCH 082/216] cedar-go/ast: add constructor functions for wrapper structs to reduce some noise Signed-off-by: philhassey --- ast/annotation.go | 14 +++++++---- ast/node.go | 4 +++ ast/operator.go | 64 +++++++++++++++++++++++------------------------ ast/policy.go | 12 ++++++--- ast/scope.go | 22 ++++++++-------- ast/value.go | 24 +++++++++--------- ast/variable.go | 8 +++--- 7 files changed, 80 insertions(+), 68 deletions(-) diff --git a/ast/annotation.go b/ast/annotation.go index d1083823..e3860354 100644 --- a/ast/annotation.go +++ b/ast/annotation.go @@ -9,6 +9,10 @@ type Annotations struct { *ast.Annotations } +func newAnnotations(a *ast.Annotations) *Annotations { + return &Annotations{a} +} + // Annotation allows AST constructors to make policy in a similar shape to textual Cedar with // annotations appearing before the actual policy scope: // @@ -17,21 +21,21 @@ type Annotations struct { // Permit(). // PrincipalEq(superUser) func Annotation(name, value types.String) *Annotations { - return &Annotations{ast.Annotation(name, value)} + return newAnnotations(ast.Annotation(name, value)) } func (a *Annotations) Annotation(name, value types.String) *Annotations { - return &Annotations{a.Annotations.Annotation(name, value)} + return newAnnotations(a.Annotations.Annotation(name, value)) } func (a *Annotations) Permit() *Policy { - return &Policy{a.Annotations.Permit()} + return newPolicy(a.Annotations.Permit()) } func (a *Annotations) Forbid() *Policy { - return &Policy{a.Annotations.Forbid()} + return newPolicy(a.Annotations.Forbid()) } func (p *Policy) Annotate(name, value types.String) *Policy { - return &Policy{p.Policy.Annotate(name, value)} + return newPolicy(p.Policy.Annotate(name, value)) } diff --git a/ast/node.go b/ast/node.go index ed30ec13..0b99e3e8 100644 --- a/ast/node.go +++ b/ast/node.go @@ -5,3 +5,7 @@ import "github.com/cedar-policy/cedar-go/internal/ast" type Node struct { ast.Node } + +func newNode(n ast.Node) Node { + return Node{n} +} diff --git a/ast/operator.go b/ast/operator.go index 98c3f786..2ca9c52a 100644 --- a/ast/operator.go +++ b/ast/operator.go @@ -13,47 +13,47 @@ import ( // |_| func (lhs Node) Equals(rhs Node) Node { - return Node{lhs.Node.Equals(rhs.Node)} + return newNode(lhs.Node.Equals(rhs.Node)) } func (lhs Node) NotEquals(rhs Node) Node { - return Node{lhs.Node.NotEquals(rhs.Node)} + return newNode(lhs.Node.NotEquals(rhs.Node)) } func (lhs Node) LessThan(rhs Node) Node { - return Node{lhs.Node.LessThan(rhs.Node)} + return newNode(lhs.Node.LessThan(rhs.Node)) } func (lhs Node) LessThanOrEqual(rhs Node) Node { - return Node{lhs.Node.LessThanOrEqual(rhs.Node)} + return newNode(lhs.Node.LessThanOrEqual(rhs.Node)) } func (lhs Node) GreaterThan(rhs Node) Node { - return Node{lhs.Node.GreaterThan(rhs.Node)} + return newNode(lhs.Node.GreaterThan(rhs.Node)) } func (lhs Node) GreaterThanOrEqual(rhs Node) Node { - return Node{lhs.Node.GreaterThanOrEqual(rhs.Node)} + return newNode(lhs.Node.GreaterThanOrEqual(rhs.Node)) } func (lhs Node) LessThanExt(rhs Node) Node { - return Node{lhs.Node.LessThanExt(rhs.Node)} + return newNode(lhs.Node.LessThanExt(rhs.Node)) } func (lhs Node) LessThanOrEqualExt(rhs Node) Node { - return Node{lhs.Node.LessThanOrEqualExt(rhs.Node)} + return newNode(lhs.Node.LessThanOrEqualExt(rhs.Node)) } func (lhs Node) GreaterThanExt(rhs Node) Node { - return Node{lhs.Node.GreaterThanExt(rhs.Node)} + return newNode(lhs.Node.GreaterThanExt(rhs.Node)) } func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { - return Node{lhs.Node.GreaterThanOrEqualExt(rhs.Node)} + return newNode(lhs.Node.GreaterThanOrEqualExt(rhs.Node)) } func (lhs Node) Like(pattern types.Pattern) Node { - return Node{lhs.Node.Like(pattern)} + return newNode(lhs.Node.Like(pattern)) } // _ _ _ @@ -64,19 +64,19 @@ func (lhs Node) Like(pattern types.Pattern) Node { // |___/ func (lhs Node) And(rhs Node) Node { - return Node{lhs.Node.And(rhs.Node)} + return newNode(lhs.Node.And(rhs.Node)) } func (lhs Node) Or(rhs Node) Node { - return Node{lhs.Node.Or(rhs.Node)} + return newNode(lhs.Node.Or(rhs.Node)) } func Not(rhs Node) Node { - return Node{ast.Not(rhs.Node)} + return newNode(ast.Not(rhs.Node)) } func If(condition Node, ifTrue Node, ifFalse Node) Node { - return Node{ast.If(condition.Node, ifTrue.Node, ifFalse.Node)} + return newNode(ast.If(condition.Node, ifTrue.Node, ifFalse.Node)) } // _ _ _ _ _ _ @@ -86,19 +86,19 @@ func If(condition Node, ifTrue Node, ifFalse Node) Node { // /_/ \_\_| |_|\__|_| |_|_| |_| |_|\___|\__|_|\___| func (lhs Node) Plus(rhs Node) Node { - return Node{lhs.Node.Plus(rhs.Node)} + return newNode(lhs.Node.Plus(rhs.Node)) } func (lhs Node) Minus(rhs Node) Node { - return Node{lhs.Node.Minus(rhs.Node)} + return newNode(lhs.Node.Minus(rhs.Node)) } func (lhs Node) Times(rhs Node) Node { - return Node{lhs.Node.Times(rhs.Node)} + return newNode(lhs.Node.Times(rhs.Node)) } func Negate(rhs Node) Node { - return Node{ast.Negate(rhs.Node)} + return newNode(ast.Negate(rhs.Node)) } // _ _ _ _ @@ -109,35 +109,35 @@ func Negate(rhs Node) Node { // |___/ func (lhs Node) In(rhs Node) Node { - return Node{lhs.Node.In(rhs.Node)} + return newNode(lhs.Node.In(rhs.Node)) } func (lhs Node) Is(entityType types.Path) Node { - return Node{lhs.Node.Is(entityType)} + return newNode(lhs.Node.Is(entityType)) } func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { - return Node{lhs.Node.IsIn(entityType, rhs.Node)} + return newNode(lhs.Node.IsIn(entityType, rhs.Node)) } func (lhs Node) Contains(rhs Node) Node { - return Node{lhs.Node.Contains(rhs.Node)} + return newNode(lhs.Node.Contains(rhs.Node)) } func (lhs Node) ContainsAll(rhs Node) Node { - return Node{lhs.Node.ContainsAll(rhs.Node)} + return newNode(lhs.Node.ContainsAll(rhs.Node)) } func (lhs Node) ContainsAny(rhs Node) Node { - return Node{lhs.Node.ContainsAny(rhs.Node)} + return newNode(lhs.Node.ContainsAny(rhs.Node)) } func (lhs Node) Access(attr string) Node { - return Node{lhs.Node.Access(attr)} + return newNode(lhs.Node.Access(attr)) } func (lhs Node) Has(attr string) Node { - return Node{lhs.Node.Has(attr)} + return newNode(lhs.Node.Has(attr)) } // ___ ____ _ _ _ @@ -147,21 +147,21 @@ func (lhs Node) Has(attr string) Node { // |___|_| /_/ \_\__,_|\__,_|_| \___||___/___/ func (lhs Node) IsIpv4() Node { - return Node{lhs.Node.IsIpv4()} + return newNode(lhs.Node.IsIpv4()) } func (lhs Node) IsIpv6() Node { - return Node{lhs.Node.IsIpv6()} + return newNode(lhs.Node.IsIpv6()) } func (lhs Node) IsMulticast() Node { - return Node{lhs.Node.IsMulticast()} + return newNode(lhs.Node.IsMulticast()) } func (lhs Node) IsLoopback() Node { - return Node{lhs.Node.IsLoopback()} + return newNode(lhs.Node.IsLoopback()) } func (lhs Node) IsInRange(rhs Node) Node { - return Node{lhs.Node.IsInRange(rhs.Node)} + return newNode(lhs.Node.IsInRange(rhs.Node)) } diff --git a/ast/policy.go b/ast/policy.go index 07b9dbb7..bd5c6985 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -6,18 +6,22 @@ type Policy struct { *ast.Policy } +func newPolicy(p *ast.Policy) *Policy { + return &Policy{p} +} + func Permit() *Policy { - return &Policy{ast.Permit()} + return newPolicy(ast.Permit()) } func Forbid() *Policy { - return &Policy{ast.Forbid()} + return newPolicy(ast.Forbid()) } func (p *Policy) When(node Node) *Policy { - return &Policy{p.Policy.When(node.Node)} + return newPolicy(p.Policy.When(node.Node)) } func (p *Policy) Unless(node Node) *Policy { - return &Policy{p.Policy.Unless(node.Node)} + return newPolicy(p.Policy.Unless(node.Node)) } diff --git a/ast/scope.go b/ast/scope.go index 40d98444..37b4e2ac 100644 --- a/ast/scope.go +++ b/ast/scope.go @@ -5,45 +5,45 @@ import ( ) func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { - return &Policy{p.Policy.PrincipalEq(entity)} + return newPolicy(p.Policy.PrincipalEq(entity)) } func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { - return &Policy{p.Policy.PrincipalIn(entity)} + return newPolicy(p.Policy.PrincipalIn(entity)) } func (p *Policy) PrincipalIs(entityType types.Path) *Policy { - return &Policy{p.Policy.PrincipalIs(entityType)} + return newPolicy(p.Policy.PrincipalIs(entityType)) } func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { - return &Policy{p.Policy.PrincipalIsIn(entityType, entity)} + return newPolicy(p.Policy.PrincipalIsIn(entityType, entity)) } func (p *Policy) ActionEq(entity types.EntityUID) *Policy { - return &Policy{p.Policy.ActionEq(entity)} + return newPolicy(p.Policy.ActionEq(entity)) } func (p *Policy) ActionIn(entity types.EntityUID) *Policy { - return &Policy{p.Policy.ActionIn(entity)} + return newPolicy(p.Policy.ActionIn(entity)) } func (p *Policy) ActionInSet(entities ...types.EntityUID) *Policy { - return &Policy{p.Policy.ActionInSet(entities...)} + return newPolicy(p.Policy.ActionInSet(entities...)) } func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { - return &Policy{p.Policy.ResourceEq(entity)} + return newPolicy(p.Policy.ResourceEq(entity)) } func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { - return &Policy{p.Policy.ResourceIn(entity)} + return newPolicy(p.Policy.ResourceIn(entity)) } func (p *Policy) ResourceIs(entityType types.Path) *Policy { - return &Policy{p.Policy.ResourceIs(entityType)} + return newPolicy(p.Policy.ResourceIs(entityType)) } func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { - return &Policy{p.Policy.ResourceIsIn(entityType, entity)} + return newPolicy(p.Policy.ResourceIsIn(entityType, entity)) } diff --git a/ast/value.go b/ast/value.go index 31377199..b2adb5c5 100644 --- a/ast/value.go +++ b/ast/value.go @@ -6,7 +6,7 @@ import ( ) func Boolean(b types.Boolean) Node { - return Node{ast.Boolean(b)} + return newNode(ast.Boolean(b)) } func True() Node { @@ -18,17 +18,17 @@ func False() Node { } func String(s types.String) Node { - return Node{ast.String(s)} + return newNode(ast.String(s)) } func Long(l types.Long) Node { - return Node{ast.Long(l)} + return newNode(ast.Long(l)) } // Set is a convenience function that wraps concrete instances of a Cedar Set type // types in AST value nodes and passes them along to SetNodes. func Set(s types.Set) Node { - return Node{ast.Set(s)} + return newNode(ast.Set(s)) } // SetNodes allows for a complex set definition with values potentially @@ -48,13 +48,13 @@ func SetNodes(nodes ...Node) Node { for _, n := range nodes { astNodes = append(astNodes, n.Node) } - return Node{ast.SetNodes(astNodes...)} + return newNode(ast.SetNodes(astNodes...)) } // Record is a convenience function that wraps concrete instances of a Cedar Record type // types in AST value nodes and passes them along to RecordNodes. func Record(r types.Record) Node { - return Node{ast.Record(r)} + return newNode(ast.Record(r)) } // RecordNodes allows for a complex record definition with values potentially @@ -72,7 +72,7 @@ func RecordNodes(entries map[types.String]Node) Node { for k, v := range entries { astNodes[k] = v.Node } - return Node{ast.RecordNodes(astNodes)} + return newNode(ast.RecordNodes(astNodes)) } type RecordElement struct { @@ -85,19 +85,19 @@ func RecordElements(elements ...RecordElement) Node { for _, v := range elements { astNodes = append(astNodes, ast.RecordElement{Key: v.Key, Value: v.Value.Node}) } - return Node{ast.RecordElements(astNodes...)} + return newNode(ast.RecordElements(astNodes...)) } func EntityUID(e types.EntityUID) Node { - return Node{ast.EntityUID(e)} + return newNode(ast.EntityUID(e)) } func Decimal(d types.Decimal) Node { - return Node{ast.Decimal(d)} + return newNode(ast.Decimal(d)) } func IPAddr(i types.IPAddr) Node { - return Node{ast.IPAddr(i)} + return newNode(ast.IPAddr(i)) } func ExtensionCall(name types.String, args ...Node) Node { @@ -105,5 +105,5 @@ func ExtensionCall(name types.String, args ...Node) Node { for _, v := range args { astNodes = append(astNodes, v.Node) } - return Node{ast.ExtensionCall(name, astNodes...)} + return newNode(ast.ExtensionCall(name, astNodes...)) } diff --git a/ast/variable.go b/ast/variable.go index 66f16284..7d724bf3 100644 --- a/ast/variable.go +++ b/ast/variable.go @@ -3,17 +3,17 @@ package ast import "github.com/cedar-policy/cedar-go/internal/ast" func Principal() Node { - return Node{ast.Principal()} + return newNode(ast.Principal()) } func Action() Node { - return Node{ast.Action()} + return newNode(ast.Action()) } func Resource() Node { - return Node{ast.Resource()} + return newNode(ast.Resource()) } func Context() Node { - return Node{ast.Context()} + return newNode(ast.Context()) } From 800f71086b8d4e5e102979fb9ff2395762764b33 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 10:48:00 -0700 Subject: [PATCH 083/216] internal/entities: break out entities from internal/ast package Signed-off-by: philhassey --- cedar.go | 27 +- cedar_test.go | 418 ++++++++---------- corpus_test.go | 3 +- internal/ast/eval_impl.go | 7 +- internal/ast/eval_test.go | 47 +- .../{ast/cedar.go => entities/entities.go} | 5 +- 6 files changed, 231 insertions(+), 276 deletions(-) rename internal/{ast/cedar.go => entities/entities.go} (91%) diff --git a/cedar.go b/cedar.go index 512c3793..6e048303 100644 --- a/cedar.go +++ b/cedar.go @@ -3,12 +3,10 @@ package cedar import ( "fmt" - "slices" - "strings" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/types" - "golang.org/x/exp/maps" ) // A PolicySet is a slice of policies. @@ -84,25 +82,6 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { return policies, nil } -type Entities = ast.Entities -type Entity = ast.Entity - -func entitiesFromSlice(s []Entity) Entities { - var res = Entities{} - for _, e := range s { - res[e.UID] = e - } - return res -} - -func entitiesToSlice(e Entities) []Entity { - s := maps.Values(e) - slices.SortFunc(s, func(a, b Entity) int { - return strings.Compare(a.UID.String(), b.UID.String()) - }) - return s -} - // A Decision is the result of the authorization. type Decision bool @@ -166,9 +145,9 @@ type evaler = ast.Evaler // IsAuthorized uses the combination of the PolicySet and Entities to determine // if the given Request to determine Decision and Diagnostic. -func (p PolicySet) IsAuthorized(entities Entities, req Request) (Decision, Diagnostic) { +func (p PolicySet) IsAuthorized(entityMap entities.Entities, req Request) (Decision, Diagnostic) { c := &evalContext{ - Entities: entities, + Entities: entityMap, Principal: req.Principal, Action: req.Action, Resource: req.Resource, diff --git a/cedar_test.go b/cedar_test.go index c443cff6..cd4f8fcd 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -5,6 +5,7 @@ import ( "net/netip" "testing" + "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -53,10 +54,12 @@ func TestNewPolicySet(t *testing.T) { //nolint:revive // due to table test function-length func TestIsAuthorized(t *testing.T) { t.Parallel() + cuzco := types.NewEntityUID("coder", "cuzco") + dropTable := types.NewEntityUID("table", "drop") tests := []struct { Name string Policy string - Entities Entities + Entities entities.Entities Principal, Action, Resource types.EntityUID Context types.Record Want Decision @@ -66,9 +69,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "simple-permit", Policy: `permit(principal,action,resource);`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -77,9 +80,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "simple-forbid", Policy: `forbid(principal,action,resource);`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -88,9 +91,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "no-permit", Policy: `permit(principal,action,resource in asdf::"1234");`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -99,9 +102,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "error-in-policy", Policy: `permit(principal,action,resource) when { resource in "foo" };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -112,9 +115,9 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { resource in "foo" }; permit(principal,action,resource); `, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -123,9 +126,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-context-success", Policy: `permit(principal,action,resource) when { context.x == 42 };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{"x": types.Long(42)}, Want: true, @@ -134,9 +137,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-context-fail", Policy: `permit(principal,action,resource) when { context.x == 42 };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{"x": types.Long(43)}, Want: false, @@ -145,14 +148,14 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-entities-success", Policy: `permit(principal,action,resource) when { principal.x == 42 };`, - Entities: entitiesFromSlice([]Entity{ - { - UID: types.NewEntityUID("coder", "cuzco"), + Entities: entities.Entities{ + cuzco: entities.Entity{ + UID: cuzco, Attributes: types.Record{"x": types.Long(42)}, }, - }), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + }, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -161,14 +164,14 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-entities-fail", Policy: `permit(principal,action,resource) when { principal.x == 42 };`, - Entities: entitiesFromSlice([]Entity{ - { - UID: types.NewEntityUID("coder", "cuzco"), + Entities: entities.Entities{ + cuzco: entities.Entity{ + UID: cuzco, Attributes: types.Record{"x": types.Long(43)}, }, - }), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + }, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -177,14 +180,14 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-entities-parent-success", Policy: `permit(principal,action,resource) when { principal in parent::"bob" };`, - Entities: entitiesFromSlice([]Entity{ - { - UID: types.NewEntityUID("coder", "cuzco"), + Entities: entities.Entities{ + cuzco: entities.Entity{ + UID: cuzco, Parents: []types.EntityUID{types.NewEntityUID("parent", "bob")}, }, - }), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + }, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -193,9 +196,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-principal-equals", Policy: `permit(principal == coder::"cuzco",action,resource);`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -204,14 +207,14 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-principal-in", Policy: `permit(principal in team::"osiris",action,resource);`, - Entities: entitiesFromSlice([]Entity{ - { - UID: types.NewEntityUID("coder", "cuzco"), + Entities: entities.Entities{ + cuzco: entities.Entity{ + UID: cuzco, Parents: []types.EntityUID{types.NewEntityUID("team", "osiris")}, }, - }), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + }, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -220,9 +223,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-action-equals", Policy: `permit(principal,action == table::"drop",resource);`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -231,14 +234,14 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-action-in", Policy: `permit(principal,action in scary::"stuff",resource);`, - Entities: entitiesFromSlice([]Entity{ - { - UID: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{ + dropTable: entities.Entity{ + UID: dropTable, Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, }, - }), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + }, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -247,14 +250,14 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-action-in-set", Policy: `permit(principal,action in [scary::"stuff"],resource);`, - Entities: entitiesFromSlice([]Entity{ - { - UID: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{ + dropTable: entities.Entity{ + UID: dropTable, Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, }, - }), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + }, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -263,9 +266,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-resource-equals", Policy: `permit(principal,action,resource == table::"whatever");`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -274,9 +277,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-unless", Policy: `permit(principal,action,resource) unless { false };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -285,9 +288,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-if", Policy: `permit(principal,action,resource) when { (if true then true else true) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -296,9 +299,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-or", Policy: `permit(principal,action,resource) when { (true || false) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -307,9 +310,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-and", Policy: `permit(principal,action,resource) when { (true && true) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -318,9 +321,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-relations", Policy: `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -329,9 +332,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-relations-in", Policy: `permit(principal,action,resource) when { principal in principal };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -340,14 +343,14 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-relations-has", Policy: `permit(principal,action,resource) when { principal has name };`, - Entities: entitiesFromSlice([]Entity{ - { - UID: types.NewEntityUID("coder", "cuzco"), + Entities: entities.Entities{ + cuzco: entities.Entity{ + UID: cuzco, Attributes: types.Record{"name": types.String("bob")}, }, - }), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + }, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -356,9 +359,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-add-sub", Policy: `permit(principal,action,resource) when { 40+3-1==42 };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -367,9 +370,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-mul", Policy: `permit(principal,action,resource) when { 6*7==42 };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -378,9 +381,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-negate", Policy: `permit(principal,action,resource) when { -42==-42 };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -389,9 +392,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-not", Policy: `permit(principal,action,resource) when { !(1+1==42) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -400,9 +403,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -411,9 +414,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-record", Policy: `permit(principal,action,resource) when { {name:"bob"} has name };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -422,9 +425,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-action", Policy: `permit(principal,action,resource) when { action in action };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -433,9 +436,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-contains-ok", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -444,9 +447,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-contains-error", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -456,9 +459,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAll-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -467,9 +470,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAll-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -479,9 +482,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAny-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -490,9 +493,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAny-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -502,9 +505,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-record-attr", Policy: `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -513,9 +516,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-unknown-method", Policy: `permit(principal,action,resource) when { [1,2,3].shuffle() };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -525,9 +528,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-like", Policy: `permit(principal,action,resource) when { "bananas" like "*nan*" };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -536,9 +539,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-unknown-ext-fun", Policy: `permit(principal,action,resource) when { fooBar("10") };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -552,9 +555,9 @@ func TestIsAuthorized(t *testing.T) { decimal("10.0").lessThanOrEqual(decimal("11.0")) && decimal("10.0").greaterThan(decimal("9.0")) && decimal("10.0").greaterThanOrEqual(decimal("9.0")) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -563,9 +566,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-decimal-fun-wrong-arity", Policy: `permit(principal,action,resource) when { decimal(1, 2) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -579,9 +582,9 @@ func TestIsAuthorized(t *testing.T) { ip("::1").isLoopback() && ip("224.1.2.3").isMulticast() && ip("127.0.0.1").isInRange(ip("127.0.0.0/16"))};`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: true, @@ -590,9 +593,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-ip-fun-wrong-arity", Policy: `permit(principal,action,resource) when { ip() };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -601,9 +604,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isIpv4-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -612,9 +615,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isIpv6-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -623,9 +626,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isLoopback-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -634,9 +637,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isMulticast-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -645,9 +648,9 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isInRange-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, - Entities: entitiesFromSlice(nil), - Principal: types.NewEntityUID("coder", "cuzco"), - Action: types.NewEntityUID("table", "drop"), + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, @@ -656,7 +659,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "negative-unary-op", Policy: `permit(principal,action,resource) when { -context.value > 0 };`, - Entities: entitiesFromSlice(nil), + Entities: entities.Entities{}, Context: types.Record{"value": types.Long(-42)}, Want: true, DiagErr: 0, @@ -664,7 +667,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "principal-is", Policy: `permit(principal is Actor,action,resource);`, - Entities: entitiesFromSlice(nil), + Entities: entities.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -675,7 +678,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "principal-is-in", Policy: `permit(principal is Actor in Actor::"cuzco",action,resource);`, - Entities: entitiesFromSlice(nil), + Entities: entities.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -686,7 +689,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "resource-is", Policy: `permit(principal,action,resource is Resource);`, - Entities: entitiesFromSlice(nil), + Entities: entities.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -697,7 +700,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "resource-is-in", Policy: `permit(principal,action,resource is Resource in Resource::"table");`, - Entities: entitiesFromSlice(nil), + Entities: entities.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -708,7 +711,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "when-is", Policy: `permit(principal,action,resource) when { resource is Resource };`, - Entities: entitiesFromSlice(nil), + Entities: entities.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -719,7 +722,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Resource::"table" };`, - Entities: entitiesFromSlice(nil), + Entities: entities.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -730,12 +733,12 @@ func TestIsAuthorized(t *testing.T) { { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Parent::"id" };`, - Entities: entitiesFromSlice([]Entity{ - { + Entities: entities.Entities{ + types.NewEntityUID("Resource", "table"): entities.Entity{ UID: types.NewEntityUID("Resource", "table"), Parents: []types.EntityUID{types.NewEntityUID("Parent", "id")}, }, - }), + }, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -764,45 +767,16 @@ func TestIsAuthorized(t *testing.T) { func TestEntities(t *testing.T) { t.Parallel() - t.Run("ToSlice", func(t *testing.T) { - t.Parallel() - s := []Entity{ - { - UID: types.EntityUID{Type: "A", ID: "A"}, - }, - { - UID: types.EntityUID{Type: "A", ID: "B"}, - }, - { - UID: types.EntityUID{Type: "B", ID: "A"}, - }, - { - UID: types.EntityUID{Type: "B", ID: "B"}, - }, - } - entities := entitiesFromSlice(s) - s2 := entitiesToSlice(entities) - testutil.Equals(t, s2, s) - }) t.Run("Clone", func(t *testing.T) { t.Parallel() - s := []Entity{ - { - UID: types.EntityUID{Type: "A", ID: "A"}, - }, - { - UID: types.EntityUID{Type: "A", ID: "B"}, - }, - { - UID: types.EntityUID{Type: "B", ID: "A"}, - }, - { - UID: types.EntityUID{Type: "B", ID: "B"}, - }, + e := entities.Entities{ + types.EntityUID{Type: "A", ID: "A"}: {}, + types.EntityUID{Type: "A", ID: "B"}: {}, + types.EntityUID{Type: "B", ID: "A"}: {}, + types.EntityUID{Type: "B", ID: "B"}: {}, } - entities := entitiesFromSlice(s) - clone := entities.Clone() - testutil.Equals(t, clone, entities) + clone := e.Clone() + testutil.Equals(t, clone, e) }) } @@ -1031,7 +1005,7 @@ func TestCorpusRelated(t *testing.T) { t.Parallel() policy, err := NewPolicySet("", []byte(tt.policy)) testutil.OK(t, err) - ok, diag := policy.IsAuthorized(Entities{}, tt.request) + ok, diag := policy.IsAuthorized(entities.Entities{}, tt.request) testutil.Equals(t, ok, tt.decision) var reasons []int for _, n := range diag.Reasons { @@ -1051,8 +1025,8 @@ func TestEntitiesJSON(t *testing.T) { t.Parallel() t.Run("Marshal", func(t *testing.T) { t.Parallel() - e := Entities{} - ent := Entity{ + e := entities.Entities{} + ent := entities.Entity{ UID: types.NewEntityUID("Type", "id"), Parents: []types.EntityUID{}, Attributes: types.Record{"key": types.Long(42)}, @@ -1066,11 +1040,11 @@ func TestEntitiesJSON(t *testing.T) { t.Run("Unmarshal", func(t *testing.T) { t.Parallel() b := []byte(`[{"uid":{"type":"Type","id":"id"},"parents":[],"attrs":{"key":42}}]`) - var e Entities + var e entities.Entities err := json.Unmarshal(b, &e) testutil.OK(t, err) - want := Entities{} - ent := Entity{ + want := entities.Entities{} + ent := entities.Entity{ UID: types.NewEntityUID("Type", "id"), Parents: []types.EntityUID{}, Attributes: types.Record{"key": types.Long(42)}, @@ -1081,7 +1055,7 @@ func TestEntitiesJSON(t *testing.T) { t.Run("UnmarshalErr", func(t *testing.T) { t.Parallel() - var e Entities + var e entities.Entities err := e.UnmarshalJSON([]byte(`!@#$`)) testutil.Error(t, err) }) diff --git a/corpus_test.go b/corpus_test.go index 305c54f9..b544570e 100644 --- a/corpus_test.go +++ b/corpus_test.go @@ -12,6 +12,7 @@ import ( "strings" "testing" + entities2 "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/types" ) @@ -138,7 +139,7 @@ func TestCorpus(t *testing.T) { t.Fatal("error reading entities content", err) } - var entities Entities + var entities entities2.Entities if err := json.Unmarshal(entitiesContent, &entities); err != nil { t.Fatal("error unmarshalling test", err) } diff --git a/internal/ast/eval_impl.go b/internal/ast/eval_impl.go index af9f1308..edcb8ee6 100644 --- a/internal/ast/eval_impl.go +++ b/internal/ast/eval_impl.go @@ -3,6 +3,7 @@ package ast import ( "fmt" + "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/types" ) @@ -15,7 +16,7 @@ var errUnspecifiedEntity = fmt.Errorf("unspecified entity") // TODO: make private again type EvalContext struct { - Entities Entities + Entities entities.Entities Principal, Action, Resource types.Value Context types.Value } @@ -921,7 +922,7 @@ func newInEval(lhs, rhs Evaler) *inEval { return &inEval{lhs: lhs, rhs: rhs} } -func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entities Entities) bool { +func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entityMap entities.Entities) bool { checked := map[types.EntityUID]struct{}{} toCheck := []types.EntityUID{entity} for len(toCheck) > 0 { @@ -933,7 +934,7 @@ func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entiti if _, ok := query[candidate]; ok { return true } - toCheck = append(toCheck, entities[candidate].Parents...) + toCheck = append(toCheck, entityMap[candidate].Parents...) checked[candidate] = struct{}{} } return false diff --git a/internal/ast/eval_test.go b/internal/ast/eval_test.go index 59559fef..c2c6af7b 100644 --- a/internal/ast/eval_test.go +++ b/internal/ast/eval_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -1283,13 +1284,14 @@ func TestAttributeAccessNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newAttributeAccessEval(tt.object, tt.attribute) + entity := entities.Entity{ + UID: types.NewEntityUID("knownType", "knownID"), + Attributes: types.Record{"knownAttr": types.Long(42)}, + } v, err := n.Eval(&EvalContext{ - Entities: entitiesFromSlice([]Entity{ - { - UID: types.NewEntityUID("knownType", "knownID"), - Attributes: types.Record{"knownAttr": types.Long(42)}, - }, - }), + Entities: entities.Entities{ + entity.UID: entity, + }, }) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) @@ -1339,13 +1341,14 @@ func TestHasNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newHasEval(tt.record, tt.attribute) + entity := entities.Entity{ + UID: types.NewEntityUID("knownType", "knownID"), + Attributes: types.Record{"knownAttr": types.Long(42)}, + } v, err := n.Eval(&EvalContext{ - Entities: entitiesFromSlice([]Entity{ - { - UID: types.NewEntityUID("knownType", "knownID"), - Attributes: types.Record{"knownAttr": types.Long(42)}, - }, - }), + Entities: entities.Entities{ + entity.UID: entity, + }, }) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) @@ -1535,19 +1538,19 @@ func TestEntityIn(t *testing.T) { for _, v := range tt.rhs { rhs[strEnt(v)] = struct{}{} } - entities := Entities{} + entityMap := entities.Entities{} for k, p := range tt.parents { var ps []types.EntityUID for _, pp := range p { ps = append(ps, strEnt(pp)) } uid := strEnt(k) - entities[uid] = Entity{ + entityMap[uid] = entities.Entity{ UID: uid, Parents: ps, } } - res := entityIn(strEnt(tt.lhs), rhs, entities) + res := entityIn(strEnt(tt.lhs), rhs, entityMap) testutil.Equals(t, res, tt.result) }) } @@ -1556,25 +1559,25 @@ func TestEntityIn(t *testing.T) { // This test will run for a very long time (O(2^100)) if there isn't caching. ) - entities := Entities{} + entityMap := entities.Entities{} for i := 0; i < 100; i++ { p := []types.EntityUID{ types.NewEntityUID(fmt.Sprint(i+1), "1"), types.NewEntityUID(fmt.Sprint(i+1), "2"), } uid1 := types.NewEntityUID(fmt.Sprint(i), "1") - entities[uid1] = Entity{ + entityMap[uid1] = entities.Entity{ UID: uid1, Parents: p, } uid2 := types.NewEntityUID(fmt.Sprint(i), "2") - entities[uid2] = Entity{ + entityMap[uid2] = entities.Entity{ UID: uid2, Parents: p, } } - res := entityIn(types.NewEntityUID("0", "1"), map[types.EntityUID]struct{}{types.NewEntityUID("0", "3"): {}}, entities) + res := entityIn(types.NewEntityUID("0", "1"), map[types.EntityUID]struct{}{types.NewEntityUID("0", "3"): {}}, entityMap) testutil.Equals(t, res, false) }) } @@ -1702,19 +1705,19 @@ func TestInNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newInEval(tt.lhs, tt.rhs) - entities := Entities{} + entityMap := entities.Entities{} for k, p := range tt.parents { var ps []types.EntityUID for _, pp := range p { ps = append(ps, strEnt(pp)) } uid := strEnt(k) - entities[uid] = Entity{ + entityMap[uid] = entities.Entity{ UID: uid, Parents: ps, } } - ec := EvalContext{Entities: entities} + ec := EvalContext{Entities: entityMap} v, err := n.Eval(&ec) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) diff --git a/internal/ast/cedar.go b/internal/entities/entities.go similarity index 91% rename from internal/ast/cedar.go rename to internal/entities/entities.go index e5f94e5c..45d7757c 100644 --- a/internal/ast/cedar.go +++ b/internal/entities/entities.go @@ -1,7 +1,4 @@ -package ast - -// TODO: this is a partial cut-and-paste from the main cedar package -// and will need completion / review +package entities import ( "encoding/json" From bdec39574a5bbcfee637b30bc54ef84aad50681c Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 13 Aug 2024 11:59:17 -0600 Subject: [PATCH 084/216] internal/ast: publicized the policy/node/etc shapes Addresses IDX-142 Signed-off-by: philhassey --- internal/ast/annotation.go | 14 +- internal/ast/cedar_marshal.go | 126 +++++++++--------- internal/ast/cedar_unmarshal.go | 6 +- internal/ast/eval_compile.go | 18 +-- internal/ast/eval_convert.go | 58 ++++----- internal/ast/json_marshal.go | 150 ++++++++++----------- internal/ast/json_unmarshal.go | 8 +- internal/ast/node.go | 224 ++++++++++++++++---------------- internal/ast/operator.go | 46 +++---- internal/ast/policy.go | 52 ++++---- internal/ast/scope.go | 92 ++++++------- internal/ast/value.go | 16 +-- internal/ast/variable.go | 40 ++---- 13 files changed, 417 insertions(+), 433 deletions(-) diff --git a/internal/ast/annotation.go b/internal/ast/annotation.go index d17a7434..895cd1ec 100644 --- a/internal/ast/annotation.go +++ b/internal/ast/annotation.go @@ -3,7 +3,7 @@ package ast import "github.com/cedar-policy/cedar-go/types" type Annotations struct { - nodes []annotationType + nodes []AnnotationType } // Annotation allows AST constructors to make policy in a similar shape to textual Cedar with @@ -14,7 +14,7 @@ type Annotations struct { // Permit(). // PrincipalEq(superUser) func Annotation(name, value types.String) *Annotations { - return &Annotations{nodes: []annotationType{newAnnotation(name, value)}} + return &Annotations{nodes: []AnnotationType{newAnnotation(name, value)}} } func (a *Annotations) Annotation(name, value types.String) *Annotations { @@ -23,18 +23,18 @@ func (a *Annotations) Annotation(name, value types.String) *Annotations { } func (a *Annotations) Permit() *Policy { - return newPolicy(effectPermit, a.nodes) + return newPolicy(EffectPermit, a.nodes) } func (a *Annotations) Forbid() *Policy { - return newPolicy(effectForbid, a.nodes) + return newPolicy(EffectForbid, a.nodes) } func (p *Policy) Annotate(name, value types.String) *Policy { - p.annotations = append(p.annotations, annotationType{Key: name, Value: value}) + p.Annotations = append(p.Annotations, AnnotationType{Key: name, Value: value}) return p } -func newAnnotation(name, value types.String) annotationType { - return annotationType{Key: name, Value: value} +func newAnnotation(name, value types.String) AnnotationType { + return AnnotationType{Key: name, Value: value} } diff --git a/internal/ast/cedar_marshal.go b/internal/ast/cedar_marshal.go index 2b83551d..1a6db61e 100644 --- a/internal/ast/cedar_marshal.go +++ b/internal/ast/cedar_marshal.go @@ -6,15 +6,15 @@ import ( // TODO: Add errors to all of this! TODO: review this ask, I'm not sure any real errors are possible. All buf errors are panics. func (p *Policy) MarshalCedar(buf *bytes.Buffer) { - for _, a := range p.annotations { + for _, a := range p.Annotations { a.MarshalCedar(buf) buf.WriteRune('\n') } - p.effect.MarshalCedar(buf) + p.Effect.MarshalCedar(buf) buf.WriteRune(' ') p.marshalScope(buf) - for _, c := range p.conditions { + for _, c := range p.Conditions { buf.WriteRune('\n') c.MarshalCedar(buf) } @@ -23,24 +23,24 @@ func (p *Policy) MarshalCedar(buf *bytes.Buffer) { } func (p *Policy) marshalScope(buf *bytes.Buffer) { - _, principalAll := p.principal.(scopeTypeAll) - _, actionAll := p.action.(scopeTypeAll) - _, resourceAll := p.resource.(scopeTypeAll) + _, principalAll := p.Principal.(ScopeTypeAll) + _, actionAll := p.Action.(ScopeTypeAll) + _, resourceAll := p.Resource.(ScopeTypeAll) if principalAll && actionAll && resourceAll { buf.WriteString("( principal, action, resource )") return } buf.WriteString("(\n ") - p.principal.MarshalCedar(buf) + p.Principal.MarshalCedar(buf) buf.WriteString(",\n ") - p.action.MarshalCedar(buf) + p.Action.MarshalCedar(buf) buf.WriteString(",\n ") - p.resource.MarshalCedar(buf) + p.Resource.MarshalCedar(buf) buf.WriteString("\n)") } -func (n annotationType) MarshalCedar(buf *bytes.Buffer) { +func (n AnnotationType) MarshalCedar(buf *bytes.Buffer) { buf.WriteRune('@') buf.WriteString(string(n.Key)) buf.WriteRune('(') @@ -48,35 +48,35 @@ func (n annotationType) MarshalCedar(buf *bytes.Buffer) { buf.WriteString(")") } -func (e effect) MarshalCedar(buf *bytes.Buffer) { - if e == effectPermit { +func (e Effect) MarshalCedar(buf *bytes.Buffer) { + if e == EffectPermit { buf.WriteString("permit") } else { buf.WriteString("forbid") } } -func (n nodeTypeVariable) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeVariable) marshalCedar(buf *bytes.Buffer) { buf.WriteString(string(n.Name)) } -func (n scopeTypeAll) MarshalCedar(buf *bytes.Buffer) { +func (n ScopeTypeAll) MarshalCedar(buf *bytes.Buffer) { n.Variable.marshalCedar(buf) } -func (n scopeTypeEq) MarshalCedar(buf *bytes.Buffer) { +func (n ScopeTypeEq) MarshalCedar(buf *bytes.Buffer) { n.Variable.marshalCedar(buf) buf.WriteString(" == ") buf.WriteString(n.Entity.Cedar()) } -func (n scopeTypeIn) MarshalCedar(buf *bytes.Buffer) { +func (n ScopeTypeIn) MarshalCedar(buf *bytes.Buffer) { n.Variable.marshalCedar(buf) buf.WriteString(" in ") buf.WriteString(n.Entity.Cedar()) } -func (n scopeTypeInSet) MarshalCedar(buf *bytes.Buffer) { +func (n ScopeTypeInSet) MarshalCedar(buf *bytes.Buffer) { n.Variable.marshalCedar(buf) buf.WriteString(" in ") buf.WriteRune('[') @@ -89,13 +89,13 @@ func (n scopeTypeInSet) MarshalCedar(buf *bytes.Buffer) { buf.WriteRune(']') } -func (n scopeTypeIs) MarshalCedar(buf *bytes.Buffer) { +func (n ScopeTypeIs) MarshalCedar(buf *bytes.Buffer) { n.Variable.marshalCedar(buf) buf.WriteString(" is ") buf.WriteString(string(n.Type)) } -func (n scopeTypeIsIn) MarshalCedar(buf *bytes.Buffer) { +func (n ScopeTypeIsIn) MarshalCedar(buf *bytes.Buffer) { n.Variable.marshalCedar(buf) buf.WriteString(" is ") buf.WriteString(string(n.Type)) @@ -103,8 +103,8 @@ func (n scopeTypeIsIn) MarshalCedar(buf *bytes.Buffer) { buf.WriteString(n.Entity.Cedar()) } -func (c conditionType) MarshalCedar(buf *bytes.Buffer) { - if c.Condition == conditionWhen { +func (c ConditionType) MarshalCedar(buf *bytes.Buffer) { + if c.Condition == ConditionWhen { buf.WriteString("when") } else { buf.WriteString("unless") @@ -115,11 +115,11 @@ func (c conditionType) MarshalCedar(buf *bytes.Buffer) { buf.WriteString(" }") } -func (n nodeValue) marshalCedar(buf *bytes.Buffer) { +func (n NodeValue) marshalCedar(buf *bytes.Buffer) { buf.WriteString(n.Value.Cedar()) } -func marshalChildNode(thisNodePrecedence nodePrecedenceLevel, childNode node, buf *bytes.Buffer) { +func marshalChildNode(thisNodePrecedence nodePrecedenceLevel, childNode IsNode, buf *bytes.Buffer) { if thisNodePrecedence > childNode.precedenceLevel() { buf.WriteRune('(') childNode.marshalCedar(buf) @@ -129,12 +129,12 @@ func marshalChildNode(thisNodePrecedence nodePrecedenceLevel, childNode node, bu } } -func (n nodeTypeNot) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeNot) marshalCedar(buf *bytes.Buffer) { buf.WriteRune('!') marshalChildNode(n.precedenceLevel(), n.Arg, buf) } -func (n nodeTypeNegate) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeNegate) marshalCedar(buf *bytes.Buffer) { buf.WriteRune('-') marshalChildNode(n.precedenceLevel(), n.Arg, buf) } @@ -148,7 +148,7 @@ func canMarshalAsIdent(s string) bool { return true } -func (n nodeTypeAccess) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeAccess) marshalCedar(buf *bytes.Buffer) { marshalChildNode(n.precedenceLevel(), n.Arg, buf) if canMarshalAsIdent(string(n.Value)) { @@ -161,8 +161,8 @@ func (n nodeTypeAccess) marshalCedar(buf *bytes.Buffer) { } } -func (n nodeTypeExtensionCall) marshalCedar(buf *bytes.Buffer) { - var args []node +func (n NodeTypeExtensionCall) marshalCedar(buf *bytes.Buffer) { + var args []IsNode info := extMap[n.Name] if info.IsMethod { marshalChildNode(n.precedenceLevel(), n.Args[0], buf) @@ -182,28 +182,28 @@ func (n nodeTypeExtensionCall) marshalCedar(buf *bytes.Buffer) { buf.WriteRune(')') } -func (n nodeTypeContains) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeContains) marshalCedar(buf *bytes.Buffer) { marshalChildNode(n.precedenceLevel(), n.Left, buf) buf.WriteString(".contains(") marshalChildNode(n.precedenceLevel(), n.Right, buf) buf.WriteRune(')') } -func (n nodeTypeContainsAll) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeContainsAll) marshalCedar(buf *bytes.Buffer) { marshalChildNode(n.precedenceLevel(), n.Left, buf) buf.WriteString(".containsAll(") marshalChildNode(n.precedenceLevel(), n.Right, buf) buf.WriteRune(')') } -func (n nodeTypeContainsAny) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeContainsAny) marshalCedar(buf *bytes.Buffer) { marshalChildNode(n.precedenceLevel(), n.Left, buf) buf.WriteString(".containsAny(") marshalChildNode(n.precedenceLevel(), n.Right, buf) buf.WriteRune(')') } -func (n nodeTypeSet) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeSet) marshalCedar(buf *bytes.Buffer) { buf.WriteRune('[') for i := range n.Elements { marshalChildNode(n.precedenceLevel(), n.Elements[i], buf) @@ -214,7 +214,7 @@ func (n nodeTypeSet) marshalCedar(buf *bytes.Buffer) { buf.WriteRune(']') } -func (n nodeTypeRecord) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeRecord) marshalCedar(buf *bytes.Buffer) { buf.WriteRune('{') for i := range n.Elements { buf.WriteString(n.Elements[i].Key.Cedar()) @@ -227,7 +227,7 @@ func (n nodeTypeRecord) marshalCedar(buf *bytes.Buffer) { buf.WriteRune('}') } -func marshalInfixBinaryOp(n binaryNode, precedence nodePrecedenceLevel, op string, buf *bytes.Buffer) { +func marshalInfixBinaryOp(n BinaryNode, precedence nodePrecedenceLevel, op string, buf *bytes.Buffer) { marshalChildNode(precedence, n.Left, buf) buf.WriteRune(' ') buf.WriteString(op) @@ -235,55 +235,55 @@ func marshalInfixBinaryOp(n binaryNode, precedence nodePrecedenceLevel, op strin marshalChildNode(precedence, n.Right, buf) } -func (n nodeTypeMult) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "*", buf) +func (n NodeTypeMult) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "*", buf) } -func (n nodeTypeAdd) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "+", buf) +func (n NodeTypeAdd) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "+", buf) } -func (n nodeTypeSub) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "-", buf) +func (n NodeTypeSub) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "-", buf) } -func (n nodeTypeLessThan) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "<", buf) +func (n NodeTypeLessThan) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "<", buf) } -func (n nodeTypeLessThanOrEqual) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "<=", buf) +func (n NodeTypeLessThanOrEqual) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "<=", buf) } -func (n nodeTypeGreaterThan) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), ">", buf) +func (n NodeTypeGreaterThan) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), ">", buf) } -func (n nodeTypeGreaterThanOrEqual) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), ">=", buf) +func (n NodeTypeGreaterThanOrEqual) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), ">=", buf) } -func (n nodeTypeEquals) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "==", buf) +func (n NodeTypeEquals) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "==", buf) } -func (n nodeTypeNotEquals) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "!=", buf) +func (n NodeTypeNotEquals) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "!=", buf) } -func (n nodeTypeIn) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "in", buf) +func (n NodeTypeIn) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "in", buf) } -func (n nodeTypeAnd) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "&&", buf) +func (n NodeTypeAnd) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "&&", buf) } -func (n nodeTypeOr) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.binaryNode, n.precedenceLevel(), "||", buf) +func (n NodeTypeOr) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "||", buf) } -func (n nodeTypeHas) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeHas) marshalCedar(buf *bytes.Buffer) { marshalChildNode(n.precedenceLevel(), n.Arg, buf) buf.WriteString(" has ") if canMarshalAsIdent(string(n.Value)) { @@ -293,13 +293,13 @@ func (n nodeTypeHas) marshalCedar(buf *bytes.Buffer) { } } -func (n nodeTypeIs) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeIs) marshalCedar(buf *bytes.Buffer) { marshalChildNode(n.precedenceLevel(), n.Left, buf) buf.WriteString(" is ") buf.WriteString(string(n.EntityType)) } -func (n nodeTypeIsIn) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeIsIn) marshalCedar(buf *bytes.Buffer) { marshalChildNode(n.precedenceLevel(), n.Left, buf) buf.WriteString(" is ") buf.WriteString(string(n.EntityType)) @@ -307,13 +307,13 @@ func (n nodeTypeIsIn) marshalCedar(buf *bytes.Buffer) { n.Entity.marshalCedar(buf) } -func (n nodeTypeLike) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeLike) marshalCedar(buf *bytes.Buffer) { marshalChildNode(n.precedenceLevel(), n.Arg, buf) buf.WriteString(" like ") buf.WriteString(n.Value.Cedar()) } -func (n nodeTypeIf) marshalCedar(buf *bytes.Buffer) { +func (n NodeTypeIf) marshalCedar(buf *bytes.Buffer) { buf.WriteString("if ") marshalChildNode(n.precedenceLevel(), n.If, buf) buf.WriteString(" then ") diff --git a/internal/ast/cedar_unmarshal.go b/internal/ast/cedar_unmarshal.go index 243310f8..f05e080e 100644 --- a/internal/ast/cedar_unmarshal.go +++ b/internal/ast/cedar_unmarshal.go @@ -21,9 +21,9 @@ func (p *PolicySet) UnmarshalCedar(b []byte) error { for !parser.peek().isEOF() { pos := parser.peek().Pos policy := Policy{ - principal: scopeTypeAll{}, - action: scopeTypeAll{}, - resource: scopeTypeAll{}, + Principal: ScopeTypeAll{}, + Action: ScopeTypeAll{}, + Resource: ScopeTypeAll{}, } if err = policy.fromCedarWithParser(&parser); err != nil { diff --git a/internal/ast/eval_compile.go b/internal/ast/eval_compile.go index b28ce1bf..9d3ff8f7 100644 --- a/internal/ast/eval_compile.go +++ b/internal/ast/eval_compile.go @@ -7,14 +7,14 @@ type CompiledPolicy struct { } func (p PolicySetEntry) TmpGetAnnotations() map[string]string { - res := make(map[string]string, len(p.Policy.annotations)) - for _, e := range p.Policy.annotations { + res := make(map[string]string, len(p.Policy.Annotations)) + for _, e := range p.Policy.Annotations { res[string(e.Key)] = string(e.Value) } return res } func (p PolicySetEntry) TmpGetEffect() bool { - return bool(p.Policy.effect) + return bool(p.Policy.Effect) } func Compile(p Policy) Evaler { @@ -23,12 +23,12 @@ func Compile(p Policy) Evaler { } func policyToNode(p Policy) Node { - nodes := make([]Node, 3+len(p.conditions)) - nodes[0] = p.principal.toNode() - nodes[1] = p.action.toNode() - nodes[2] = p.resource.toNode() - for i, c := range p.conditions { - if c.Condition == conditionUnless { + nodes := make([]Node, 3+len(p.Conditions)) + nodes[0] = p.Principal.toNode() + nodes[1] = p.Action.toNode() + nodes[2] = p.Resource.toNode() + for i, c := range p.Conditions { + if c.Condition == ConditionUnless { nodes[i+3] = Not(newNode(c.Body)) continue } diff --git a/internal/ast/eval_convert.go b/internal/ast/eval_convert.go index 4fcd452f..bce95d9f 100644 --- a/internal/ast/eval_convert.go +++ b/internal/ast/eval_convert.go @@ -4,24 +4,24 @@ import ( "fmt" ) -func toEval(n node) Evaler { +func toEval(n IsNode) Evaler { switch v := n.(type) { - case nodeTypeAccess: + case NodeTypeAccess: return newAttributeAccessEval(toEval(v.Arg), string(v.Value)) - case nodeTypeHas: + case NodeTypeHas: return newHasEval(toEval(v.Arg), string(v.Value)) - case nodeTypeLike: + case NodeTypeLike: return newLikeEval(toEval(v.Arg), v.Value) - case nodeTypeIf: + case NodeTypeIf: return newIfThenElseEval(toEval(v.If), toEval(v.Then), toEval(v.Else)) - case nodeTypeIs: + case NodeTypeIs: return newIsEval(toEval(v.Left), newLiteralEval(v.EntityType)) - case nodeTypeIsIn: + case NodeTypeIsIn: obj := toEval(v.Left) lhs := newIsEval(obj, newLiteralEval(v.EntityType)) rhs := newInEval(obj, toEval(v.Entity)) return newAndEval(lhs, rhs) - case nodeTypeExtensionCall: + case NodeTypeExtensionCall: i, ok := extMap[v.Name] if !ok { return newErrorEval(fmt.Errorf("%w: %s", errUnknownExtensionFunction, v.Name)) @@ -57,25 +57,25 @@ func toEval(n node) Evaler { default: panic(fmt.Errorf("unknown extension: %v", v.Name)) } - case nodeValue: + case NodeValue: return newLiteralEval(v.Value) - case nodeTypeRecord: + case NodeTypeRecord: m := make(map[string]Evaler, len(v.Elements)) for _, e := range v.Elements { m[string(e.Key)] = toEval(e.Value) } return newRecordLiteralEval(m) - case nodeTypeSet: + case NodeTypeSet: s := make([]Evaler, len(v.Elements)) for i, e := range v.Elements { s[i] = toEval(e) } return newSetLiteralEval(s) - case nodeTypeNegate: + case NodeTypeNegate: return newNegateEval(toEval(v.Arg)) - case nodeTypeNot: + case NodeTypeNot: return newNotEval(toEval(v.Arg)) - case nodeTypeVariable: + case NodeTypeVariable: switch v.Name { case "principal": return newVariableEval(variableNamePrincipal) @@ -88,35 +88,35 @@ func toEval(n node) Evaler { default: panic(fmt.Errorf("unknown variable: %v", v.Name)) } - case nodeTypeIn: + case NodeTypeIn: return newInEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeAnd: + case NodeTypeAnd: return newAndEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeEquals: + case NodeTypeEquals: return newEqualEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeGreaterThan: + case NodeTypeGreaterThan: return newLongGreaterThanEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeGreaterThanOrEqual: + case NodeTypeGreaterThanOrEqual: return newLongGreaterThanOrEqualEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeLessThan: + case NodeTypeLessThan: return newLongLessThanEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeLessThanOrEqual: + case NodeTypeLessThanOrEqual: return newLongLessThanOrEqualEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeSub: + case NodeTypeSub: return newSubtractEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeAdd: + case NodeTypeAdd: return newAddEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeContains: + case NodeTypeContains: return newContainsEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeContainsAll: + case NodeTypeContainsAll: return newContainsAllEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeContainsAny: + case NodeTypeContainsAny: return newContainsAnyEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeMult: + case NodeTypeMult: return newMultiplyEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeNotEquals: + case NodeTypeNotEquals: return newNotEqualEval(toEval(v.Left), toEval(v.Right)) - case nodeTypeOr: + case NodeTypeOr: return newOrNode(toEval(v.Left), toEval(v.Right)) default: panic(fmt.Sprintf("unknown node type %T", v)) diff --git a/internal/ast/json_marshal.go b/internal/ast/json_marshal.go index ee38ca24..8d228280 100644 --- a/internal/ast/json_marshal.go +++ b/internal/ast/json_marshal.go @@ -7,30 +7,30 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func (s *scopeJSON) FromNode(src isScopeNode) error { +func (s *scopeJSON) FromNode(src IsScopeNode) error { switch t := src.(type) { - case scopeTypeAll: + case ScopeTypeAll: s.Op = "All" return nil - case scopeTypeEq: + case ScopeTypeEq: s.Op = "==" e := t.Entity s.Entity = &e return nil - case scopeTypeIn: + case ScopeTypeIn: s.Op = "in" e := t.Entity s.Entity = &e return nil - case scopeTypeInSet: + case ScopeTypeInSet: s.Op = "in" s.Entities = t.Entities return nil - case scopeTypeIs: + case ScopeTypeIs: s.Op = "is" s.EntityType = string(t.Type) return nil - case scopeTypeIsIn: + case ScopeTypeIsIn: s.Op = "is" s.EntityType = string(t.Type) s.In = &scopeInJSON{ @@ -41,8 +41,8 @@ func (s *scopeJSON) FromNode(src isScopeNode) error { return fmt.Errorf("unexpected scope node: %T", src) } -func unaryToJSON(dest **unaryJSON, src unaryNode) error { - n := unaryNode(src) +func unaryToJSON(dest **unaryJSON, src UnaryNode) error { + n := UnaryNode(src) res := &unaryJSON{} if err := res.Arg.FromNode(n.Arg); err != nil { return fmt.Errorf("error in arg: %w", err) @@ -51,8 +51,8 @@ func unaryToJSON(dest **unaryJSON, src unaryNode) error { return nil } -func binaryToJSON(dest **binaryJSON, src binaryNode) error { - n := binaryNode(src) +func binaryToJSON(dest **binaryJSON, src BinaryNode) error { + n := BinaryNode(src) res := &binaryJSON{} if err := res.Left.FromNode(n.Left); err != nil { return fmt.Errorf("error in left: %w", err) @@ -64,7 +64,7 @@ func binaryToJSON(dest **binaryJSON, src binaryNode) error { return nil } -func arrayToJSON(dest *arrayJSON, args []node) error { +func arrayToJSON(dest *arrayJSON, args []IsNode) error { res := arrayJSON{} for _, n := range args { var nn nodeJSON @@ -90,7 +90,7 @@ func extToJSON(dest *extensionCallJSON, name string, src types.Value) error { return nil } -func extCallToJSON(dest extensionCallJSON, src nodeTypeExtensionCall) error { +func extCallToJSON(dest extensionCallJSON, src NodeTypeExtensionCall) error { jsonArgs := arrayJSON{} for _, n := range src.Args { argNode := &nodeJSON{} @@ -104,7 +104,7 @@ func extCallToJSON(dest extensionCallJSON, src nodeTypeExtensionCall) error { return nil } -func strToJSON(dest **strJSON, src strOpNode) error { +func strToJSON(dest **strJSON, src StrOpNode) error { res := &strJSON{} if err := res.Left.FromNode(src.Arg); err != nil { return fmt.Errorf("error in left: %w", err) @@ -114,7 +114,7 @@ func strToJSON(dest **strJSON, src strOpNode) error { return nil } -func patternToJSON(dest **patternJSON, src nodeTypeLike) error { +func patternToJSON(dest **patternJSON, src NodeTypeLike) error { res := &patternJSON{} if err := res.Left.FromNode(src.Arg); err != nil { return fmt.Errorf("error in left: %w", err) @@ -131,7 +131,7 @@ func patternToJSON(dest **patternJSON, src nodeTypeLike) error { return nil } -func recordToJSON(dest *recordJSON, src nodeTypeRecord) error { +func recordToJSON(dest *recordJSON, src NodeTypeRecord) error { res := recordJSON{} for _, kv := range src.Elements { var nn nodeJSON @@ -144,7 +144,7 @@ func recordToJSON(dest *recordJSON, src nodeTypeRecord) error { return nil } -func ifToJSON(dest **ifThenElseJSON, src nodeTypeIf) error { +func ifToJSON(dest **ifThenElseJSON, src NodeTypeIf) error { res := &ifThenElseJSON{} if err := res.If.FromNode(src.If); err != nil { return fmt.Errorf("error in if: %w", err) @@ -159,7 +159,7 @@ func ifToJSON(dest **ifThenElseJSON, src nodeTypeIf) error { return nil } -func isToJSON(dest **isJSON, src nodeTypeIs) error { +func isToJSON(dest **isJSON, src NodeTypeIs) error { res := &isJSON{} if err := res.Left.FromNode(src.Left); err != nil { return fmt.Errorf("error in left: %w", err) @@ -169,7 +169,7 @@ func isToJSON(dest **isJSON, src nodeTypeIs) error { return nil } -func isInToJSON(dest **isJSON, src nodeTypeIsIn) error { +func isInToJSON(dest **isJSON, src NodeTypeIsIn) error { res := &isJSON{} if err := res.Left.FromNode(src.Left); err != nil { return fmt.Errorf("error in left: %w", err) @@ -183,11 +183,11 @@ func isInToJSON(dest **isJSON, src nodeTypeIsIn) error { return nil } -func (j *nodeJSON) FromNode(src node) error { +func (j *nodeJSON) FromNode(src IsNode) error { switch t := src.(type) { // Value // Value *json.RawMessage `json:"Value"` // could be any - case nodeValue: + case NodeValue: // Any other function: decimal, ip // Decimal arrayJSON `json:"decimal"` // IP arrayJSON `json:"ip"` @@ -203,7 +203,7 @@ func (j *nodeJSON) FromNode(src node) error { // Var // Var *string `json:"Var"` - case nodeTypeVariable: + case NodeTypeVariable: val := string(t.Name) j.Var = &val return nil @@ -211,79 +211,79 @@ func (j *nodeJSON) FromNode(src node) error { // ! or neg operators // Not *unaryJSON `json:"!"` // Negate *unaryJSON `json:"neg"` - case nodeTypeNot: - return unaryToJSON(&j.Not, t.unaryNode) - case nodeTypeNegate: - return unaryToJSON(&j.Negate, t.unaryNode) + case NodeTypeNot: + return unaryToJSON(&j.Not, t.UnaryNode) + case NodeTypeNegate: + return unaryToJSON(&j.Negate, t.UnaryNode) // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny - case nodeTypeAdd: - return binaryToJSON(&j.Plus, t.binaryNode) - case nodeTypeAnd: - return binaryToJSON(&j.And, t.binaryNode) - case nodeTypeContains: - return binaryToJSON(&j.Contains, t.binaryNode) - case nodeTypeContainsAll: - return binaryToJSON(&j.ContainsAll, t.binaryNode) - case nodeTypeContainsAny: - return binaryToJSON(&j.ContainsAny, t.binaryNode) - case nodeTypeEquals: - return binaryToJSON(&j.Equals, t.binaryNode) - case nodeTypeGreaterThan: - return binaryToJSON(&j.GreaterThan, t.binaryNode) - case nodeTypeGreaterThanOrEqual: - return binaryToJSON(&j.GreaterThanOrEqual, t.binaryNode) - case nodeTypeIn: - return binaryToJSON(&j.In, t.binaryNode) - case nodeTypeLessThan: - return binaryToJSON(&j.LessThan, t.binaryNode) - case nodeTypeLessThanOrEqual: - return binaryToJSON(&j.LessThanOrEqual, t.binaryNode) - case nodeTypeMult: - return binaryToJSON(&j.Times, t.binaryNode) - case nodeTypeNotEquals: - return binaryToJSON(&j.NotEquals, t.binaryNode) - case nodeTypeOr: - return binaryToJSON(&j.Or, t.binaryNode) - case nodeTypeSub: - return binaryToJSON(&j.Minus, t.binaryNode) + case NodeTypeAdd: + return binaryToJSON(&j.Plus, t.BinaryNode) + case NodeTypeAnd: + return binaryToJSON(&j.And, t.BinaryNode) + case NodeTypeContains: + return binaryToJSON(&j.Contains, t.BinaryNode) + case NodeTypeContainsAll: + return binaryToJSON(&j.ContainsAll, t.BinaryNode) + case NodeTypeContainsAny: + return binaryToJSON(&j.ContainsAny, t.BinaryNode) + case NodeTypeEquals: + return binaryToJSON(&j.Equals, t.BinaryNode) + case NodeTypeGreaterThan: + return binaryToJSON(&j.GreaterThan, t.BinaryNode) + case NodeTypeGreaterThanOrEqual: + return binaryToJSON(&j.GreaterThanOrEqual, t.BinaryNode) + case NodeTypeIn: + return binaryToJSON(&j.In, t.BinaryNode) + case NodeTypeLessThan: + return binaryToJSON(&j.LessThan, t.BinaryNode) + case NodeTypeLessThanOrEqual: + return binaryToJSON(&j.LessThanOrEqual, t.BinaryNode) + case NodeTypeMult: + return binaryToJSON(&j.Times, t.BinaryNode) + case NodeTypeNotEquals: + return binaryToJSON(&j.NotEquals, t.BinaryNode) + case NodeTypeOr: + return binaryToJSON(&j.Or, t.BinaryNode) + case NodeTypeSub: + return binaryToJSON(&j.Minus, t.BinaryNode) // ., has // Access *strJSON `json:"."` // Has *strJSON `json:"has"` - case nodeTypeAccess: - return strToJSON(&j.Access, t.strOpNode) - case nodeTypeHas: - return strToJSON(&j.Has, t.strOpNode) + case NodeTypeAccess: + return strToJSON(&j.Access, t.StrOpNode) + case NodeTypeHas: + return strToJSON(&j.Has, t.StrOpNode) // is - case nodeTypeIs: + case NodeTypeIs: return isToJSON(&j.Is, t) - case nodeTypeIsIn: + case NodeTypeIsIn: return isInToJSON(&j.Is, t) // like // Like *strJSON `json:"like"` - case nodeTypeLike: + case NodeTypeLike: return patternToJSON(&j.Like, t) // if-then-else // IfThenElse *ifThenElseJSON `json:"if-then-else"` - case nodeTypeIf: + case NodeTypeIf: return ifToJSON(&j.IfThenElse, t) // Set // Set arrayJSON `json:"Set"` - case nodeTypeSet: + case NodeTypeSet: return arrayToJSON(&j.Set, t.Elements) // Record // Record recordJSON `json:"Record"` - case nodeTypeRecord: + case NodeTypeRecord: return recordToJSON(&j.Record, t) // Any other method: ip, decimal, lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange // ExtensionMethod map[string]arrayJSON `json:"-"` - case nodeTypeExtensionCall: + case NodeTypeExtensionCall: j.ExtensionCall = extensionCallJSON{} return extCallToJSON(j.ExtensionCall, t) } @@ -314,28 +314,28 @@ func (p *patternComponentJSON) MarshalJSON() ([]byte, error) { func (p *Policy) MarshalJSON() ([]byte, error) { var j policyJSON j.Effect = "forbid" - if p.effect { + if p.Effect { j.Effect = "permit" } - if len(p.annotations) > 0 { + if len(p.Annotations) > 0 { j.Annotations = map[string]string{} } - for _, a := range p.annotations { + for _, a := range p.Annotations { j.Annotations[string(a.Key)] = string(a.Value) } - if err := j.Principal.FromNode(p.principal); err != nil { + if err := j.Principal.FromNode(p.Principal); err != nil { return nil, fmt.Errorf("error in principal: %w", err) } - if err := j.Action.FromNode(p.action); err != nil { + if err := j.Action.FromNode(p.Action); err != nil { return nil, fmt.Errorf("error in action: %w", err) } - if err := j.Resource.FromNode(p.resource); err != nil { + if err := j.Resource.FromNode(p.Resource); err != nil { return nil, fmt.Errorf("error in resource: %w", err) } - for _, c := range p.conditions { + for _, c := range p.Conditions { var cond conditionJSON cond.Kind = "when" - if c.Condition == conditionUnless { + if c.Condition == ConditionUnless { cond.Kind = "unless" } if err := cond.Body.FromNode(c.Body); err != nil { diff --git a/internal/ast/json_unmarshal.go b/internal/ast/json_unmarshal.go index cdb0e868..4076051b 100644 --- a/internal/ast/json_unmarshal.go +++ b/internal/ast/json_unmarshal.go @@ -9,7 +9,7 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func (s *scopeJSON) ToNode(variable scope) (isScopeNode, error) { +func (s *scopeJSON) ToNode(variable Scope) (IsScopeNode, error) { // TODO: should we be careful to be more strict about what is allowed here? switch s.Op { case "All": @@ -297,15 +297,15 @@ func (p *Policy) UnmarshalJSON(b []byte) error { p.Annotate(types.String(k), types.String(v)) } var err error - p.principal, err = j.Principal.ToNode(scope(rawPrincipalNode())) + p.Principal, err = j.Principal.ToNode(Scope(newPrincipalNode())) if err != nil { return fmt.Errorf("error in principal: %w", err) } - p.action, err = j.Action.ToNode(scope(rawActionNode())) + p.Action, err = j.Action.ToNode(Scope(newActionNode())) if err != nil { return fmt.Errorf("error in action: %w", err) } - p.resource, err = j.Resource.ToNode(scope(rawResourceNode())) + p.Resource, err = j.Resource.ToNode(Scope(newResourceNode())) if err != nil { return fmt.Errorf("error in resource: %w", err) } diff --git a/internal/ast/node.go b/internal/ast/node.go index 7ba210aa..0981be55 100644 --- a/internal/ast/node.go +++ b/internal/ast/node.go @@ -7,29 +7,29 @@ import ( ) type Node struct { - v node // NOTE: not an embed because a `Node` is not a `node` + v IsNode // NOTE: not an embed because a `Node` is not a `node` } -func newNode(v node) Node { +func newNode(v IsNode) Node { return Node{v: v} } -func NewNode(v node) Node { +func NewNode(v IsNode) Node { return Node{v: v} } -type strOpNode struct { - Arg node +type StrOpNode struct { + Arg IsNode Value types.String } -func (n strOpNode) isNode() {} +func (n StrOpNode) isNode() {} -type binaryNode struct { - Left, Right node +type BinaryNode struct { + Left, Right IsNode } -func (n binaryNode) isNode() {} +func (n BinaryNode) isNode() {} type nodePrecedenceLevel uint8 @@ -45,152 +45,152 @@ const ( primaryPrecedence nodePrecedenceLevel = 8 ) -type nodeTypeIf struct { - If, Then, Else node +type NodeTypeIf struct { + If, Then, Else IsNode } -func (n nodeTypeIf) precedenceLevel() nodePrecedenceLevel { +func (n NodeTypeIf) precedenceLevel() nodePrecedenceLevel { return ifPrecedence } -func (n nodeTypeIf) isNode() {} +func (n NodeTypeIf) isNode() {} -type nodeTypeOr struct{ binaryNode } +type NodeTypeOr struct{ BinaryNode } -func (n nodeTypeOr) precedenceLevel() nodePrecedenceLevel { +func (n NodeTypeOr) precedenceLevel() nodePrecedenceLevel { return orPrecedence } -type nodeTypeAnd struct { - binaryNode +type NodeTypeAnd struct { + BinaryNode } -func (n nodeTypeAnd) precedenceLevel() nodePrecedenceLevel { +func (n NodeTypeAnd) precedenceLevel() nodePrecedenceLevel { return andPrecedence } -type relationNode struct{} +type RelationNode struct{} -func (n relationNode) precedenceLevel() nodePrecedenceLevel { +func (n RelationNode) precedenceLevel() nodePrecedenceLevel { return relationPrecedence } -type nodeTypeLessThan struct { - binaryNode - relationNode +type NodeTypeLessThan struct { + BinaryNode + RelationNode } -type nodeTypeLessThanOrEqual struct { - binaryNode - relationNode +type NodeTypeLessThanOrEqual struct { + BinaryNode + RelationNode } -type nodeTypeGreaterThan struct { - binaryNode - relationNode +type NodeTypeGreaterThan struct { + BinaryNode + RelationNode } -type nodeTypeGreaterThanOrEqual struct { - binaryNode - relationNode +type NodeTypeGreaterThanOrEqual struct { + BinaryNode + RelationNode } -type nodeTypeNotEquals struct { - binaryNode - relationNode +type NodeTypeNotEquals struct { + BinaryNode + RelationNode } -type nodeTypeEquals struct { - binaryNode - relationNode +type NodeTypeEquals struct { + BinaryNode + RelationNode } -type nodeTypeIn struct { - binaryNode - relationNode +type NodeTypeIn struct { + BinaryNode + RelationNode } -type nodeTypeHas struct { - strOpNode - relationNode +type NodeTypeHas struct { + StrOpNode + RelationNode } -type nodeTypeLike struct { - Arg node +type NodeTypeLike struct { + Arg IsNode Value types.Pattern } -func (n nodeTypeLike) precedenceLevel() nodePrecedenceLevel { +func (n NodeTypeLike) precedenceLevel() nodePrecedenceLevel { return relationPrecedence } -func (n nodeTypeLike) isNode() {} +func (n NodeTypeLike) isNode() {} -type nodeTypeIs struct { - Left node +type NodeTypeIs struct { + Left IsNode EntityType types.Path } -func (n nodeTypeIs) precedenceLevel() nodePrecedenceLevel { +func (n NodeTypeIs) precedenceLevel() nodePrecedenceLevel { return relationPrecedence } -func (n nodeTypeIs) isNode() {} +func (n NodeTypeIs) isNode() {} -type nodeTypeIsIn struct { - nodeTypeIs - Entity node +type NodeTypeIsIn struct { + NodeTypeIs + Entity IsNode } -func (n nodeTypeIsIn) precedenceLevel() nodePrecedenceLevel { +func (n NodeTypeIsIn) precedenceLevel() nodePrecedenceLevel { return relationPrecedence } -type addNode struct{} +type AddNode struct{} -func (n addNode) precedenceLevel() nodePrecedenceLevel { +func (n AddNode) precedenceLevel() nodePrecedenceLevel { return addPrecedence } -type nodeTypeSub struct { - binaryNode - addNode +type NodeTypeSub struct { + BinaryNode + AddNode } -type nodeTypeAdd struct { - binaryNode - addNode +type NodeTypeAdd struct { + BinaryNode + AddNode } -type nodeTypeMult struct{ binaryNode } +type NodeTypeMult struct{ BinaryNode } -func (n nodeTypeMult) precedenceLevel() nodePrecedenceLevel { +func (n NodeTypeMult) precedenceLevel() nodePrecedenceLevel { return multPrecedence } -type unaryNode struct { - Arg node +type UnaryNode struct { + Arg IsNode } -func (n unaryNode) precedenceLevel() nodePrecedenceLevel { +func (n UnaryNode) precedenceLevel() nodePrecedenceLevel { return unaryPrecedence } -func (n unaryNode) isNode() {} +func (n UnaryNode) isNode() {} -type nodeTypeNegate struct{ unaryNode } -type nodeTypeNot struct{ unaryNode } +type NodeTypeNegate struct{ UnaryNode } +type NodeTypeNot struct{ UnaryNode } -type nodeTypeAccess struct{ strOpNode } +type NodeTypeAccess struct{ StrOpNode } -func (n nodeTypeAccess) precedenceLevel() nodePrecedenceLevel { +func (n NodeTypeAccess) precedenceLevel() nodePrecedenceLevel { return accessPrecedence } -type nodeTypeExtensionCall struct { +type NodeTypeExtensionCall struct { Name types.String // TODO: review type - Args []node + Args []IsNode } -func (n nodeTypeExtensionCall) precedenceLevel() nodePrecedenceLevel { +func (n NodeTypeExtensionCall) precedenceLevel() nodePrecedenceLevel { return accessPrecedence } -func (n nodeTypeExtensionCall) isNode() {} +func (n NodeTypeExtensionCall) isNode() {} -func stripNodes(args []Node) []node { - res := make([]node, len(args)) +func stripNodes(args []Node) []IsNode { + res := make([]IsNode, len(args)) for i, v := range args { res[i] = v.v } @@ -198,81 +198,81 @@ func stripNodes(args []Node) []node { } func newExtensionCall(method types.String, args ...Node) Node { - return newNode(nodeTypeExtensionCall{ + return newNode(NodeTypeExtensionCall{ Name: method, Args: stripNodes(args), }) } func newMethodCall(lhs Node, method types.String, args ...Node) Node { - res := make([]node, 1+len(args)) + res := make([]IsNode, 1+len(args)) res[0] = lhs.v for i, v := range args { res[i+1] = v.v } - return newNode(nodeTypeExtensionCall{ + return newNode(NodeTypeExtensionCall{ Name: method, Args: res, }) } -type containsNode struct{} +type ContainsNode struct{} -func (n containsNode) precedenceLevel() nodePrecedenceLevel { +func (n ContainsNode) precedenceLevel() nodePrecedenceLevel { return accessPrecedence } -type nodeTypeContains struct { - binaryNode - containsNode +type NodeTypeContains struct { + BinaryNode + ContainsNode } -type nodeTypeContainsAll struct { - binaryNode - containsNode +type NodeTypeContainsAll struct { + BinaryNode + ContainsNode } -type nodeTypeContainsAny struct { - binaryNode - containsNode +type NodeTypeContainsAny struct { + BinaryNode + ContainsNode } -type primaryNode struct{} +type PrimaryNode struct{} -func (n primaryNode) isNode() {} +func (n PrimaryNode) isNode() {} -func (n primaryNode) precedenceLevel() nodePrecedenceLevel { +func (n PrimaryNode) precedenceLevel() nodePrecedenceLevel { return primaryPrecedence } -type nodeValue struct { - primaryNode +type NodeValue struct { + PrimaryNode Value types.Value } -func (n nodeValue) isNode() {} +func (n NodeValue) isNode() {} -type recordElement struct { +type RecordElementNode struct { Key types.String - Value node + Value IsNode } -type nodeTypeRecord struct { - primaryNode - Elements []recordElement +type NodeTypeRecord struct { + PrimaryNode + Elements []RecordElementNode } -func (n nodeTypeRecord) isNode() {} +func (n NodeTypeRecord) isNode() {} -type nodeTypeSet struct { - primaryNode - Elements []node +type NodeTypeSet struct { + PrimaryNode + Elements []IsNode } -type nodeTypeVariable struct { - primaryNode +type NodeTypeVariable struct { + PrimaryNode Name types.String // TODO: Review type } -type node interface { +type IsNode interface { isNode() marshalCedar(*bytes.Buffer) precedenceLevel() nodePrecedenceLevel diff --git a/internal/ast/operator.go b/internal/ast/operator.go index c3a8de58..63fc0b50 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -10,27 +10,27 @@ import "github.com/cedar-policy/cedar-go/types" // |_| func (lhs Node) Equals(rhs Node) Node { - return newNode(nodeTypeEquals{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeEquals{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) NotEquals(rhs Node) Node { - return newNode(nodeTypeNotEquals{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeNotEquals{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) LessThan(rhs Node) Node { - return newNode(nodeTypeLessThan{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeLessThan{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) LessThanOrEqual(rhs Node) Node { - return newNode(nodeTypeLessThanOrEqual{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeLessThanOrEqual{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) GreaterThan(rhs Node) Node { - return newNode(nodeTypeGreaterThan{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeGreaterThan{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) GreaterThanOrEqual(rhs Node) Node { - return newNode(nodeTypeGreaterThanOrEqual{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeGreaterThanOrEqual{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) LessThanExt(rhs Node) Node { @@ -50,7 +50,7 @@ func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { } func (lhs Node) Like(pattern types.Pattern) Node { - return newNode(nodeTypeLike{Arg: lhs.v, Value: pattern}) + return newNode(NodeTypeLike{Arg: lhs.v, Value: pattern}) } // _ _ _ @@ -61,19 +61,19 @@ func (lhs Node) Like(pattern types.Pattern) Node { // |___/ func (lhs Node) And(rhs Node) Node { - return newNode(nodeTypeAnd{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeAnd{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Or(rhs Node) Node { - return newNode(nodeTypeOr{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeOr{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func Not(rhs Node) Node { - return newNode(nodeTypeNot{unaryNode: unaryNode{Arg: rhs.v}}) + return newNode(NodeTypeNot{UnaryNode: UnaryNode{Arg: rhs.v}}) } func If(condition Node, ifTrue Node, ifFalse Node) Node { - return newNode(nodeTypeIf{If: condition.v, Then: ifTrue.v, Else: ifFalse.v}) + return newNode(NodeTypeIf{If: condition.v, Then: ifTrue.v, Else: ifFalse.v}) } // _ _ _ _ _ _ @@ -83,19 +83,19 @@ func If(condition Node, ifTrue Node, ifFalse Node) Node { // /_/ \_\_| |_|\__|_| |_|_| |_| |_|\___|\__|_|\___| func (lhs Node) Plus(rhs Node) Node { - return newNode(nodeTypeAdd{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeAdd{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Minus(rhs Node) Node { - return newNode(nodeTypeSub{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeSub{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Times(rhs Node) Node { - return newNode(nodeTypeMult{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeMult{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func Negate(rhs Node) Node { - return newNode(nodeTypeNegate{unaryNode: unaryNode{Arg: rhs.v}}) + return newNode(NodeTypeNegate{UnaryNode: UnaryNode{Arg: rhs.v}}) } // _ _ _ _ @@ -106,35 +106,35 @@ func Negate(rhs Node) Node { // |___/ func (lhs Node) In(rhs Node) Node { - return newNode(nodeTypeIn{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeIn{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Is(entityType types.Path) Node { - return newNode(nodeTypeIs{Left: lhs.v, EntityType: entityType}) + return newNode(NodeTypeIs{Left: lhs.v, EntityType: entityType}) } func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { - return newNode(nodeTypeIsIn{nodeTypeIs: nodeTypeIs{Left: lhs.v, EntityType: entityType}, Entity: rhs.v}) + return newNode(NodeTypeIsIn{NodeTypeIs: NodeTypeIs{Left: lhs.v, EntityType: entityType}, Entity: rhs.v}) } func (lhs Node) Contains(rhs Node) Node { - return newNode(nodeTypeContains{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeContains{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) ContainsAll(rhs Node) Node { - return newNode(nodeTypeContainsAll{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeContainsAll{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) ContainsAny(rhs Node) Node { - return newNode(nodeTypeContainsAny{binaryNode: binaryNode{Left: lhs.v, Right: rhs.v}}) + return newNode(NodeTypeContainsAny{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Access(attr string) Node { - return newNode(nodeTypeAccess{strOpNode: strOpNode{Arg: lhs.v, Value: types.String(attr)}}) + return newNode(NodeTypeAccess{StrOpNode: StrOpNode{Arg: lhs.v, Value: types.String(attr)}}) } func (lhs Node) Has(attr string) Node { - return newNode(nodeTypeHas{strOpNode: strOpNode{Arg: lhs.v, Value: types.String(attr)}}) + return newNode(NodeTypeHas{StrOpNode: StrOpNode{Arg: lhs.v, Value: types.String(attr)}}) } // ___ ____ _ _ _ diff --git a/internal/ast/policy.go b/internal/ast/policy.go index 038d47d1..249505f0 100644 --- a/internal/ast/policy.go +++ b/internal/ast/policy.go @@ -9,62 +9,62 @@ type PolicySetEntry struct { Position Position } -type annotationType struct { +type AnnotationType struct { Key types.String // TODO: review type Value types.String } -type condition bool +type Condition bool const ( - conditionWhen = true - conditionUnless = false + ConditionWhen = true + ConditionUnless = false ) -type conditionType struct { - Condition condition - Body node +type ConditionType struct { + Condition Condition + Body IsNode } -type effect bool +type Effect bool const ( - effectPermit effect = true - effectForbid effect = false + EffectPermit Effect = true + EffectForbid Effect = false ) type Policy struct { - effect effect - annotations []annotationType - principal isScopeNode - action isScopeNode - resource isScopeNode - conditions []conditionType + Effect Effect + Annotations []AnnotationType + Principal IsScopeNode + Action IsScopeNode + Resource IsScopeNode + Conditions []ConditionType } -func newPolicy(effect effect, annotations []annotationType) *Policy { +func newPolicy(effect Effect, annotations []AnnotationType) *Policy { return &Policy{ - effect: effect, - annotations: annotations, - principal: scope(rawPrincipalNode()).All(), - action: scope(rawActionNode()).All(), - resource: scope(rawResourceNode()).All(), + Effect: effect, + Annotations: annotations, + Principal: Scope(newPrincipalNode()).All(), + Action: Scope(newActionNode()).All(), + Resource: Scope(newResourceNode()).All(), } } func Permit() *Policy { - return newPolicy(effectPermit, nil) + return newPolicy(EffectPermit, nil) } func Forbid() *Policy { - return newPolicy(effectForbid, nil) + return newPolicy(EffectForbid, nil) } func (p *Policy) When(node Node) *Policy { - p.conditions = append(p.conditions, conditionType{Condition: conditionWhen, Body: node.v}) + p.Conditions = append(p.Conditions, ConditionType{Condition: ConditionWhen, Body: node.v}) return p } func (p *Policy) Unless(node Node) *Policy { - p.conditions = append(p.conditions, conditionType{Condition: conditionUnless, Body: node.v}) + p.Conditions = append(p.Conditions, ConditionType{Condition: ConditionUnless, Body: node.v}) return p } diff --git a/internal/ast/scope.go b/internal/ast/scope.go index fdbfc565..5720bab6 100644 --- a/internal/ast/scope.go +++ b/internal/ast/scope.go @@ -6,131 +6,131 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -type scope nodeTypeVariable +type Scope NodeTypeVariable -func (s scope) All() isScopeNode { - return scopeTypeAll{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}} +func (s Scope) All() IsScopeNode { + return ScopeTypeAll{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}} } -func (s scope) Eq(entity types.EntityUID) isScopeNode { - return scopeTypeEq{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Entity: entity} +func (s Scope) Eq(entity types.EntityUID) IsScopeNode { + return ScopeTypeEq{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}, Entity: entity} } -func (s scope) In(entity types.EntityUID) isScopeNode { - return scopeTypeIn{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Entity: entity} +func (s Scope) In(entity types.EntityUID) IsScopeNode { + return ScopeTypeIn{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}, Entity: entity} } -func (s scope) InSet(entities []types.EntityUID) isScopeNode { - return scopeTypeInSet{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Entities: entities} +func (s Scope) InSet(entities []types.EntityUID) IsScopeNode { + return ScopeTypeInSet{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}, Entities: entities} } -func (s scope) Is(entityType types.Path) isScopeNode { - return scopeTypeIs{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Type: entityType} +func (s Scope) Is(entityType types.Path) IsScopeNode { + return ScopeTypeIs{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}, Type: entityType} } -func (s scope) IsIn(entityType types.Path, entity types.EntityUID) isScopeNode { - return scopeTypeIsIn{scopeNode: scopeNode{Variable: nodeTypeVariable(s)}, Type: entityType, Entity: entity} +func (s Scope) IsIn(entityType types.Path, entity types.EntityUID) IsScopeNode { + return ScopeTypeIsIn{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}, Type: entityType, Entity: entity} } func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { - p.principal = scope(rawPrincipalNode()).Eq(entity) + p.Principal = Scope(newPrincipalNode()).Eq(entity) return p } func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { - p.principal = scope(rawPrincipalNode()).In(entity) + p.Principal = Scope(newPrincipalNode()).In(entity) return p } func (p *Policy) PrincipalIs(entityType types.Path) *Policy { - p.principal = scope(rawPrincipalNode()).Is(entityType) + p.Principal = Scope(newPrincipalNode()).Is(entityType) return p } func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { - p.principal = scope(rawPrincipalNode()).IsIn(entityType, entity) + p.Principal = Scope(newPrincipalNode()).IsIn(entityType, entity) return p } func (p *Policy) ActionEq(entity types.EntityUID) *Policy { - p.action = scope(rawActionNode()).Eq(entity) + p.Action = Scope(newActionNode()).Eq(entity) return p } func (p *Policy) ActionIn(entity types.EntityUID) *Policy { - p.action = scope(rawActionNode()).In(entity) + p.Action = Scope(newActionNode()).In(entity) return p } func (p *Policy) ActionInSet(entities ...types.EntityUID) *Policy { - p.action = scope(rawActionNode()).InSet(entities) + p.Action = Scope(newActionNode()).InSet(entities) return p } func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { - p.resource = scope(rawResourceNode()).Eq(entity) + p.Resource = Scope(newResourceNode()).Eq(entity) return p } func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { - p.resource = scope(rawResourceNode()).In(entity) + p.Resource = Scope(newResourceNode()).In(entity) return p } func (p *Policy) ResourceIs(entityType types.Path) *Policy { - p.resource = scope(rawResourceNode()).Is(entityType) + p.Resource = Scope(newResourceNode()).Is(entityType) return p } func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { - p.resource = scope(rawResourceNode()).IsIn(entityType, entity) + p.Resource = Scope(newResourceNode()).IsIn(entityType, entity) return p } -type isScopeNode interface { +type IsScopeNode interface { isScope() MarshalCedar(*bytes.Buffer) toNode() Node } -type scopeNode struct { - Variable nodeTypeVariable +type ScopeNode struct { + Variable NodeTypeVariable } -func (n scopeNode) isScope() {} +func (n ScopeNode) isScope() {} -type scopeTypeAll struct { - scopeNode +type ScopeTypeAll struct { + ScopeNode } -func (n scopeTypeAll) toNode() Node { +func (n ScopeTypeAll) toNode() Node { return newNode(True().v) } -type scopeTypeEq struct { - scopeNode +type ScopeTypeEq struct { + ScopeNode Entity types.EntityUID } -func (n scopeTypeEq) toNode() Node { +func (n ScopeTypeEq) toNode() Node { return newNode(newNode(n.Variable).Equals(EntityUID(n.Entity)).v) } -type scopeTypeIn struct { - scopeNode +type ScopeTypeIn struct { + ScopeNode Entity types.EntityUID } -func (n scopeTypeIn) toNode() Node { +func (n ScopeTypeIn) toNode() Node { return newNode(newNode(n.Variable).In(EntityUID(n.Entity)).v) } -type scopeTypeInSet struct { - scopeNode +type ScopeTypeInSet struct { + ScopeNode Entities []types.EntityUID } -func (n scopeTypeInSet) toNode() Node { +func (n ScopeTypeInSet) toNode() Node { set := make([]types.Value, len(n.Entities)) for i, e := range n.Entities { set[i] = e @@ -138,21 +138,21 @@ func (n scopeTypeInSet) toNode() Node { return newNode(newNode(n.Variable).In(Set(set)).v) } -type scopeTypeIs struct { - scopeNode +type ScopeTypeIs struct { + ScopeNode Type types.Path } -func (n scopeTypeIs) toNode() Node { +func (n ScopeTypeIs) toNode() Node { return newNode(newNode(n.Variable).Is(n.Type).v) } -type scopeTypeIsIn struct { - scopeNode +type ScopeTypeIsIn struct { + ScopeNode Type types.Path Entity types.EntityUID } -func (n scopeTypeIsIn) toNode() Node { +func (n ScopeTypeIsIn) toNode() Node { return newNode(newNode(n.Variable).IsIn(n.Type, EntityUID(n.Entity)).v) } diff --git a/internal/ast/value.go b/internal/ast/value.go index becebbd7..726e2a90 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -29,11 +29,11 @@ func Long(l types.Long) Node { // Set is a convenience function that wraps concrete instances of a Cedar Set type // types in AST value nodes and passes them along to SetNodes. func Set(s types.Set) Node { - var nodes []node + var nodes []IsNode for _, v := range s { nodes = append(nodes, valueToNode(v).v) } - return newNode(nodeTypeSet{Elements: nodes}) + return newNode(NodeTypeSet{Elements: nodes}) } // SetNodes allows for a complex set definition with values potentially @@ -49,7 +49,7 @@ func Set(s types.Set) Node { // ast.Context().Access("fooCount"), // ) func SetNodes(nodes ...Node) Node { - return newNode(nodeTypeSet{Elements: stripNodes(nodes)}) + return newNode(NodeTypeSet{Elements: stripNodes(nodes)}) } // Record is a convenience function that wraps concrete instances of a Cedar Record type @@ -74,9 +74,9 @@ func Record(r types.Record) Node { // "x": ast.Long(1).Plus(ast.Context().Access("fooCount"))}, // }) func RecordNodes(entries map[types.String]Node) Node { - var res nodeTypeRecord + var res NodeTypeRecord for k, v := range entries { - res.Elements = append(res.Elements, recordElement{Key: k, Value: v.v}) + res.Elements = append(res.Elements, RecordElementNode{Key: k, Value: v.v}) } return newNode(res) } @@ -87,9 +87,9 @@ type RecordElement struct { } func RecordElements(elements ...RecordElement) Node { - var res nodeTypeRecord + var res NodeTypeRecord for _, e := range elements { - res.Elements = append(res.Elements, recordElement{Key: e.Key, Value: e.Value.v}) + res.Elements = append(res.Elements, RecordElementNode{Key: e.Key, Value: e.Value.v}) } return newNode(res) } @@ -111,7 +111,7 @@ func ExtensionCall(name types.String, args ...Node) Node { } func newValueNode(v types.Value) Node { - return newNode(nodeValue{Value: v}) + return newNode(NodeValue{Value: v}) } func valueToNode(v types.Value) Node { diff --git a/internal/ast/variable.go b/internal/ast/variable.go index e14cc783..4c68b06d 100644 --- a/internal/ast/variable.go +++ b/internal/ast/variable.go @@ -3,49 +3,33 @@ package ast import "github.com/cedar-policy/cedar-go/types" func Principal() Node { - return newPrincipalNode() + return newNode(newPrincipalNode()) } func Action() Node { - return newActionNode() + return newNode(newActionNode()) } func Resource() Node { - return newResourceNode() + return newNode(newResourceNode()) } func Context() Node { - return newContextNode() + return newNode(newContextNode()) } -func newPrincipalNode() Node { - return newNode(rawPrincipalNode()) +func newPrincipalNode() NodeTypeVariable { + return NodeTypeVariable{Name: types.String("principal")} } -func newActionNode() Node { - return newNode(rawActionNode()) +func newActionNode() NodeTypeVariable { + return NodeTypeVariable{Name: types.String("action")} } -func newResourceNode() Node { - return newNode(rawResourceNode()) +func newResourceNode() NodeTypeVariable { + return NodeTypeVariable{Name: types.String("resource")} } -func newContextNode() Node { - return newNode(rawContextNode()) -} - -func rawPrincipalNode() nodeTypeVariable { - return nodeTypeVariable{Name: types.String("principal")} -} - -func rawActionNode() nodeTypeVariable { - return nodeTypeVariable{Name: types.String("action")} -} - -func rawResourceNode() nodeTypeVariable { - return nodeTypeVariable{Name: types.String("resource")} -} - -func rawContextNode() nodeTypeVariable { - return nodeTypeVariable{Name: types.String("context")} +func newContextNode() NodeTypeVariable { + return NodeTypeVariable{Name: types.String("context")} } From 8d030f6324463718cac076a337e4abb8bcf386b4 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 11:15:39 -0700 Subject: [PATCH 085/216] cedar-go/internal/ast: make newNode public The eval component is going to want to be able to make nodes itself Signed-off-by: philhassey --- internal/ast/eval_compile.go | 4 ++-- internal/ast/node.go | 8 ++----- internal/ast/operator.go | 46 ++++++++++++++++++------------------ internal/ast/scope.go | 12 +++++----- internal/ast/value.go | 10 ++++---- internal/ast/variable.go | 8 +++---- 6 files changed, 42 insertions(+), 46 deletions(-) diff --git a/internal/ast/eval_compile.go b/internal/ast/eval_compile.go index 9d3ff8f7..c55b0b9e 100644 --- a/internal/ast/eval_compile.go +++ b/internal/ast/eval_compile.go @@ -29,10 +29,10 @@ func policyToNode(p Policy) Node { nodes[2] = p.Resource.toNode() for i, c := range p.Conditions { if c.Condition == ConditionUnless { - nodes[i+3] = Not(newNode(c.Body)) + nodes[i+3] = Not(NewNode(c.Body)) continue } - nodes[i+3] = newNode(c.Body) + nodes[i+3] = NewNode(c.Body) } res := nodes[len(nodes)-1] for i := len(nodes) - 2; i >= 0; i-- { diff --git a/internal/ast/node.go b/internal/ast/node.go index 0981be55..e77fbea0 100644 --- a/internal/ast/node.go +++ b/internal/ast/node.go @@ -10,10 +10,6 @@ type Node struct { v IsNode // NOTE: not an embed because a `Node` is not a `node` } -func newNode(v IsNode) Node { - return Node{v: v} -} - func NewNode(v IsNode) Node { return Node{v: v} } @@ -198,7 +194,7 @@ func stripNodes(args []Node) []IsNode { } func newExtensionCall(method types.String, args ...Node) Node { - return newNode(NodeTypeExtensionCall{ + return NewNode(NodeTypeExtensionCall{ Name: method, Args: stripNodes(args), }) @@ -210,7 +206,7 @@ func newMethodCall(lhs Node, method types.String, args ...Node) Node { for i, v := range args { res[i+1] = v.v } - return newNode(NodeTypeExtensionCall{ + return NewNode(NodeTypeExtensionCall{ Name: method, Args: res, }) diff --git a/internal/ast/operator.go b/internal/ast/operator.go index 63fc0b50..20721193 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -10,27 +10,27 @@ import "github.com/cedar-policy/cedar-go/types" // |_| func (lhs Node) Equals(rhs Node) Node { - return newNode(NodeTypeEquals{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeEquals{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) NotEquals(rhs Node) Node { - return newNode(NodeTypeNotEquals{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeNotEquals{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) LessThan(rhs Node) Node { - return newNode(NodeTypeLessThan{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeLessThan{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) LessThanOrEqual(rhs Node) Node { - return newNode(NodeTypeLessThanOrEqual{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeLessThanOrEqual{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) GreaterThan(rhs Node) Node { - return newNode(NodeTypeGreaterThan{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeGreaterThan{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) GreaterThanOrEqual(rhs Node) Node { - return newNode(NodeTypeGreaterThanOrEqual{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeGreaterThanOrEqual{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) LessThanExt(rhs Node) Node { @@ -50,7 +50,7 @@ func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { } func (lhs Node) Like(pattern types.Pattern) Node { - return newNode(NodeTypeLike{Arg: lhs.v, Value: pattern}) + return NewNode(NodeTypeLike{Arg: lhs.v, Value: pattern}) } // _ _ _ @@ -61,19 +61,19 @@ func (lhs Node) Like(pattern types.Pattern) Node { // |___/ func (lhs Node) And(rhs Node) Node { - return newNode(NodeTypeAnd{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeAnd{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Or(rhs Node) Node { - return newNode(NodeTypeOr{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeOr{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func Not(rhs Node) Node { - return newNode(NodeTypeNot{UnaryNode: UnaryNode{Arg: rhs.v}}) + return NewNode(NodeTypeNot{UnaryNode: UnaryNode{Arg: rhs.v}}) } func If(condition Node, ifTrue Node, ifFalse Node) Node { - return newNode(NodeTypeIf{If: condition.v, Then: ifTrue.v, Else: ifFalse.v}) + return NewNode(NodeTypeIf{If: condition.v, Then: ifTrue.v, Else: ifFalse.v}) } // _ _ _ _ _ _ @@ -83,19 +83,19 @@ func If(condition Node, ifTrue Node, ifFalse Node) Node { // /_/ \_\_| |_|\__|_| |_|_| |_| |_|\___|\__|_|\___| func (lhs Node) Plus(rhs Node) Node { - return newNode(NodeTypeAdd{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeAdd{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Minus(rhs Node) Node { - return newNode(NodeTypeSub{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeSub{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Times(rhs Node) Node { - return newNode(NodeTypeMult{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeMult{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func Negate(rhs Node) Node { - return newNode(NodeTypeNegate{UnaryNode: UnaryNode{Arg: rhs.v}}) + return NewNode(NodeTypeNegate{UnaryNode: UnaryNode{Arg: rhs.v}}) } // _ _ _ _ @@ -106,35 +106,35 @@ func Negate(rhs Node) Node { // |___/ func (lhs Node) In(rhs Node) Node { - return newNode(NodeTypeIn{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeIn{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Is(entityType types.Path) Node { - return newNode(NodeTypeIs{Left: lhs.v, EntityType: entityType}) + return NewNode(NodeTypeIs{Left: lhs.v, EntityType: entityType}) } func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { - return newNode(NodeTypeIsIn{NodeTypeIs: NodeTypeIs{Left: lhs.v, EntityType: entityType}, Entity: rhs.v}) + return NewNode(NodeTypeIsIn{NodeTypeIs: NodeTypeIs{Left: lhs.v, EntityType: entityType}, Entity: rhs.v}) } func (lhs Node) Contains(rhs Node) Node { - return newNode(NodeTypeContains{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeContains{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) ContainsAll(rhs Node) Node { - return newNode(NodeTypeContainsAll{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeContainsAll{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) ContainsAny(rhs Node) Node { - return newNode(NodeTypeContainsAny{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) + return NewNode(NodeTypeContainsAny{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } func (lhs Node) Access(attr string) Node { - return newNode(NodeTypeAccess{StrOpNode: StrOpNode{Arg: lhs.v, Value: types.String(attr)}}) + return NewNode(NodeTypeAccess{StrOpNode: StrOpNode{Arg: lhs.v, Value: types.String(attr)}}) } func (lhs Node) Has(attr string) Node { - return newNode(NodeTypeHas{StrOpNode: StrOpNode{Arg: lhs.v, Value: types.String(attr)}}) + return NewNode(NodeTypeHas{StrOpNode: StrOpNode{Arg: lhs.v, Value: types.String(attr)}}) } // ___ ____ _ _ _ diff --git a/internal/ast/scope.go b/internal/ast/scope.go index 5720bab6..145a5277 100644 --- a/internal/ast/scope.go +++ b/internal/ast/scope.go @@ -104,7 +104,7 @@ type ScopeTypeAll struct { } func (n ScopeTypeAll) toNode() Node { - return newNode(True().v) + return NewNode(True().v) } type ScopeTypeEq struct { @@ -113,7 +113,7 @@ type ScopeTypeEq struct { } func (n ScopeTypeEq) toNode() Node { - return newNode(newNode(n.Variable).Equals(EntityUID(n.Entity)).v) + return NewNode(NewNode(n.Variable).Equals(EntityUID(n.Entity)).v) } type ScopeTypeIn struct { @@ -122,7 +122,7 @@ type ScopeTypeIn struct { } func (n ScopeTypeIn) toNode() Node { - return newNode(newNode(n.Variable).In(EntityUID(n.Entity)).v) + return NewNode(NewNode(n.Variable).In(EntityUID(n.Entity)).v) } type ScopeTypeInSet struct { @@ -135,7 +135,7 @@ func (n ScopeTypeInSet) toNode() Node { for i, e := range n.Entities { set[i] = e } - return newNode(newNode(n.Variable).In(Set(set)).v) + return NewNode(NewNode(n.Variable).In(Set(set)).v) } type ScopeTypeIs struct { @@ -144,7 +144,7 @@ type ScopeTypeIs struct { } func (n ScopeTypeIs) toNode() Node { - return newNode(newNode(n.Variable).Is(n.Type).v) + return NewNode(NewNode(n.Variable).Is(n.Type).v) } type ScopeTypeIsIn struct { @@ -154,5 +154,5 @@ type ScopeTypeIsIn struct { } func (n ScopeTypeIsIn) toNode() Node { - return newNode(newNode(n.Variable).IsIn(n.Type, EntityUID(n.Entity)).v) + return NewNode(NewNode(n.Variable).IsIn(n.Type, EntityUID(n.Entity)).v) } diff --git a/internal/ast/value.go b/internal/ast/value.go index 726e2a90..6882d161 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -33,7 +33,7 @@ func Set(s types.Set) Node { for _, v := range s { nodes = append(nodes, valueToNode(v).v) } - return newNode(NodeTypeSet{Elements: nodes}) + return NewNode(NodeTypeSet{Elements: nodes}) } // SetNodes allows for a complex set definition with values potentially @@ -49,7 +49,7 @@ func Set(s types.Set) Node { // ast.Context().Access("fooCount"), // ) func SetNodes(nodes ...Node) Node { - return newNode(NodeTypeSet{Elements: stripNodes(nodes)}) + return NewNode(NodeTypeSet{Elements: stripNodes(nodes)}) } // Record is a convenience function that wraps concrete instances of a Cedar Record type @@ -78,7 +78,7 @@ func RecordNodes(entries map[types.String]Node) Node { for k, v := range entries { res.Elements = append(res.Elements, RecordElementNode{Key: k, Value: v.v}) } - return newNode(res) + return NewNode(res) } type RecordElement struct { @@ -91,7 +91,7 @@ func RecordElements(elements ...RecordElement) Node { for _, e := range elements { res.Elements = append(res.Elements, RecordElementNode{Key: e.Key, Value: e.Value.v}) } - return newNode(res) + return NewNode(res) } func EntityUID(e types.EntityUID) Node { @@ -111,7 +111,7 @@ func ExtensionCall(name types.String, args ...Node) Node { } func newValueNode(v types.Value) Node { - return newNode(NodeValue{Value: v}) + return NewNode(NodeValue{Value: v}) } func valueToNode(v types.Value) Node { diff --git a/internal/ast/variable.go b/internal/ast/variable.go index 4c68b06d..7254b8c8 100644 --- a/internal/ast/variable.go +++ b/internal/ast/variable.go @@ -3,19 +3,19 @@ package ast import "github.com/cedar-policy/cedar-go/types" func Principal() Node { - return newNode(newPrincipalNode()) + return NewNode(newPrincipalNode()) } func Action() Node { - return newNode(newActionNode()) + return NewNode(newActionNode()) } func Resource() Node { - return newNode(newResourceNode()) + return NewNode(newResourceNode()) } func Context() Node { - return newNode(newContextNode()) + return NewNode(newContextNode()) } func newPrincipalNode() NodeTypeVariable { From 0f26254b862e6b1bc2e99100782882190c02dcf7 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 11:22:54 -0700 Subject: [PATCH 086/216] cedar-go/internal/eval: exfiltrate eval code into its own package Signed-off-by: philhassey --- cedar.go | 7 +- internal/ast/cedar_marshal.go | 2 +- internal/ast/eval_compile.go | 42 --------- internal/ast/extensions.go | 2 +- internal/ast/node.go | 4 + internal/ast/policy.go | 11 +++ internal/ast/scope.go | 14 +-- internal/eval/eval_compile.go | 33 +++++++ internal/{ast => eval}/eval_convert.go | 64 +++++++------- internal/{ast => eval}/eval_impl.go | 102 +++++++++++----------- internal/{ast => eval}/eval_test.go | 116 ++++++++++++------------- 11 files changed, 203 insertions(+), 194 deletions(-) delete mode 100644 internal/ast/eval_compile.go create mode 100644 internal/eval/eval_compile.go rename internal/{ast => eval}/eval_convert.go (80%) rename internal/{ast => eval}/eval_impl.go (85%) rename internal/{ast => eval}/eval_test.go (96%) diff --git a/cedar.go b/cedar.go index 6e048303..a5034e0f 100644 --- a/cedar.go +++ b/cedar.go @@ -6,6 +6,7 @@ import ( "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/entities" + "github.com/cedar-policy/cedar-go/internal/eval" "github.com/cedar-policy/cedar-go/types" ) @@ -76,7 +77,7 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { }, Annotations: ann, Effect: Effect(p.TmpGetEffect()), - eval: ast.Compile(p.Policy), + eval: eval.Compile(p.Policy), }) } return policies, nil @@ -139,9 +140,9 @@ type Request struct { Context types.Record `json:"context"` } -type evalContext = ast.EvalContext +type evalContext = eval.Context -type evaler = ast.Evaler +type evaler = eval.Evaler // IsAuthorized uses the combination of the PolicySet and Entities to determine // if the given Request to determine Decision and Diagnostic. diff --git a/internal/ast/cedar_marshal.go b/internal/ast/cedar_marshal.go index 1a6db61e..b097c123 100644 --- a/internal/ast/cedar_marshal.go +++ b/internal/ast/cedar_marshal.go @@ -163,7 +163,7 @@ func (n NodeTypeAccess) marshalCedar(buf *bytes.Buffer) { func (n NodeTypeExtensionCall) marshalCedar(buf *bytes.Buffer) { var args []IsNode - info := extMap[n.Name] + info := ExtMap[n.Name] if info.IsMethod { marshalChildNode(n.precedenceLevel(), n.Args[0], buf) buf.WriteRune('.') diff --git a/internal/ast/eval_compile.go b/internal/ast/eval_compile.go deleted file mode 100644 index c55b0b9e..00000000 --- a/internal/ast/eval_compile.go +++ /dev/null @@ -1,42 +0,0 @@ -package ast - -type CompiledPolicySet map[string]CompiledPolicy - -type CompiledPolicy struct { - PolicySetEntry -} - -func (p PolicySetEntry) TmpGetAnnotations() map[string]string { - res := make(map[string]string, len(p.Policy.Annotations)) - for _, e := range p.Policy.Annotations { - res[string(e.Key)] = string(e.Value) - } - return res -} -func (p PolicySetEntry) TmpGetEffect() bool { - return bool(p.Policy.Effect) -} - -func Compile(p Policy) Evaler { - node := policyToNode(p).v - return toEval(node) -} - -func policyToNode(p Policy) Node { - nodes := make([]Node, 3+len(p.Conditions)) - nodes[0] = p.Principal.toNode() - nodes[1] = p.Action.toNode() - nodes[2] = p.Resource.toNode() - for i, c := range p.Conditions { - if c.Condition == ConditionUnless { - nodes[i+3] = Not(NewNode(c.Body)) - continue - } - nodes[i+3] = NewNode(c.Body) - } - res := nodes[len(nodes)-1] - for i := len(nodes) - 2; i >= 0; i-- { - res = nodes[i].And(res) - } - return res -} diff --git a/internal/ast/extensions.go b/internal/ast/extensions.go index 340e1206..41f1602b 100644 --- a/internal/ast/extensions.go +++ b/internal/ast/extensions.go @@ -7,7 +7,7 @@ type extInfo struct { IsMethod bool } -var extMap = map[types.String]extInfo{ +var ExtMap = map[types.String]extInfo{ "ip": {Args: 1, IsMethod: false}, "decimal": {Args: 1, IsMethod: false}, diff --git a/internal/ast/node.go b/internal/ast/node.go index e77fbea0..c3ac6171 100644 --- a/internal/ast/node.go +++ b/internal/ast/node.go @@ -14,6 +14,10 @@ func NewNode(v IsNode) Node { return Node{v: v} } +func (n Node) AsIsNode() IsNode { + return n.v +} + type StrOpNode struct { Arg IsNode Value types.String diff --git a/internal/ast/policy.go b/internal/ast/policy.go index 249505f0..e77eb8db 100644 --- a/internal/ast/policy.go +++ b/internal/ast/policy.go @@ -9,6 +9,17 @@ type PolicySetEntry struct { Position Position } +func (p PolicySetEntry) TmpGetAnnotations() map[string]string { + res := make(map[string]string, len(p.Policy.Annotations)) + for _, e := range p.Policy.Annotations { + res[string(e.Key)] = string(e.Value) + } + return res +} +func (p PolicySetEntry) TmpGetEffect() bool { + return bool(p.Policy.Effect) +} + type AnnotationType struct { Key types.String // TODO: review type Value types.String diff --git a/internal/ast/scope.go b/internal/ast/scope.go index 145a5277..27c43f79 100644 --- a/internal/ast/scope.go +++ b/internal/ast/scope.go @@ -90,7 +90,7 @@ func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Po type IsScopeNode interface { isScope() MarshalCedar(*bytes.Buffer) - toNode() Node + ToNode() Node } type ScopeNode struct { @@ -103,7 +103,7 @@ type ScopeTypeAll struct { ScopeNode } -func (n ScopeTypeAll) toNode() Node { +func (n ScopeTypeAll) ToNode() Node { return NewNode(True().v) } @@ -112,7 +112,7 @@ type ScopeTypeEq struct { Entity types.EntityUID } -func (n ScopeTypeEq) toNode() Node { +func (n ScopeTypeEq) ToNode() Node { return NewNode(NewNode(n.Variable).Equals(EntityUID(n.Entity)).v) } @@ -121,7 +121,7 @@ type ScopeTypeIn struct { Entity types.EntityUID } -func (n ScopeTypeIn) toNode() Node { +func (n ScopeTypeIn) ToNode() Node { return NewNode(NewNode(n.Variable).In(EntityUID(n.Entity)).v) } @@ -130,7 +130,7 @@ type ScopeTypeInSet struct { Entities []types.EntityUID } -func (n ScopeTypeInSet) toNode() Node { +func (n ScopeTypeInSet) ToNode() Node { set := make([]types.Value, len(n.Entities)) for i, e := range n.Entities { set[i] = e @@ -143,7 +143,7 @@ type ScopeTypeIs struct { Type types.Path } -func (n ScopeTypeIs) toNode() Node { +func (n ScopeTypeIs) ToNode() Node { return NewNode(NewNode(n.Variable).Is(n.Type).v) } @@ -153,6 +153,6 @@ type ScopeTypeIsIn struct { Entity types.EntityUID } -func (n ScopeTypeIsIn) toNode() Node { +func (n ScopeTypeIsIn) ToNode() Node { return NewNode(NewNode(n.Variable).IsIn(n.Type, EntityUID(n.Entity)).v) } diff --git a/internal/eval/eval_compile.go b/internal/eval/eval_compile.go new file mode 100644 index 00000000..aabc341d --- /dev/null +++ b/internal/eval/eval_compile.go @@ -0,0 +1,33 @@ +package eval + +import "github.com/cedar-policy/cedar-go/internal/ast" + +type CompiledPolicySet map[string]CompiledPolicy + +type CompiledPolicy struct { + ast.PolicySetEntry +} + +func Compile(p ast.Policy) Evaler { + node := policyToNode(p).AsIsNode() + return toEval(node) +} + +func policyToNode(p ast.Policy) ast.Node { + nodes := make([]ast.Node, 3+len(p.Conditions)) + nodes[0] = p.Principal.ToNode() + nodes[1] = p.Action.ToNode() + nodes[2] = p.Resource.ToNode() + for i, c := range p.Conditions { + if c.Condition == ast.ConditionUnless { + nodes[i+3] = ast.Not(ast.NewNode(c.Body)) + continue + } + nodes[i+3] = ast.NewNode(c.Body) + } + res := nodes[len(nodes)-1] + for i := len(nodes) - 2; i >= 0; i-- { + res = nodes[i].And(res) + } + return res +} diff --git a/internal/ast/eval_convert.go b/internal/eval/eval_convert.go similarity index 80% rename from internal/ast/eval_convert.go rename to internal/eval/eval_convert.go index bce95d9f..d956555f 100644 --- a/internal/ast/eval_convert.go +++ b/internal/eval/eval_convert.go @@ -1,28 +1,30 @@ -package ast +package eval import ( "fmt" + + "github.com/cedar-policy/cedar-go/internal/ast" ) -func toEval(n IsNode) Evaler { +func toEval(n ast.IsNode) Evaler { switch v := n.(type) { - case NodeTypeAccess: + case ast.NodeTypeAccess: return newAttributeAccessEval(toEval(v.Arg), string(v.Value)) - case NodeTypeHas: + case ast.NodeTypeHas: return newHasEval(toEval(v.Arg), string(v.Value)) - case NodeTypeLike: + case ast.NodeTypeLike: return newLikeEval(toEval(v.Arg), v.Value) - case NodeTypeIf: + case ast.NodeTypeIf: return newIfThenElseEval(toEval(v.If), toEval(v.Then), toEval(v.Else)) - case NodeTypeIs: + case ast.NodeTypeIs: return newIsEval(toEval(v.Left), newLiteralEval(v.EntityType)) - case NodeTypeIsIn: + case ast.NodeTypeIsIn: obj := toEval(v.Left) lhs := newIsEval(obj, newLiteralEval(v.EntityType)) rhs := newInEval(obj, toEval(v.Entity)) return newAndEval(lhs, rhs) - case NodeTypeExtensionCall: - i, ok := extMap[v.Name] + case ast.NodeTypeExtensionCall: + i, ok := ast.ExtMap[v.Name] if !ok { return newErrorEval(fmt.Errorf("%w: %s", errUnknownExtensionFunction, v.Name)) } @@ -57,25 +59,25 @@ func toEval(n IsNode) Evaler { default: panic(fmt.Errorf("unknown extension: %v", v.Name)) } - case NodeValue: + case ast.NodeValue: return newLiteralEval(v.Value) - case NodeTypeRecord: + case ast.NodeTypeRecord: m := make(map[string]Evaler, len(v.Elements)) for _, e := range v.Elements { m[string(e.Key)] = toEval(e.Value) } return newRecordLiteralEval(m) - case NodeTypeSet: + case ast.NodeTypeSet: s := make([]Evaler, len(v.Elements)) for i, e := range v.Elements { s[i] = toEval(e) } return newSetLiteralEval(s) - case NodeTypeNegate: + case ast.NodeTypeNegate: return newNegateEval(toEval(v.Arg)) - case NodeTypeNot: + case ast.NodeTypeNot: return newNotEval(toEval(v.Arg)) - case NodeTypeVariable: + case ast.NodeTypeVariable: switch v.Name { case "principal": return newVariableEval(variableNamePrincipal) @@ -88,35 +90,35 @@ func toEval(n IsNode) Evaler { default: panic(fmt.Errorf("unknown variable: %v", v.Name)) } - case NodeTypeIn: + case ast.NodeTypeIn: return newInEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeAnd: + case ast.NodeTypeAnd: return newAndEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeEquals: + case ast.NodeTypeEquals: return newEqualEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeGreaterThan: + case ast.NodeTypeGreaterThan: return newLongGreaterThanEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeGreaterThanOrEqual: + case ast.NodeTypeGreaterThanOrEqual: return newLongGreaterThanOrEqualEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeLessThan: + case ast.NodeTypeLessThan: return newLongLessThanEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeLessThanOrEqual: + case ast.NodeTypeLessThanOrEqual: return newLongLessThanOrEqualEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeSub: + case ast.NodeTypeSub: return newSubtractEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeAdd: + case ast.NodeTypeAdd: return newAddEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeContains: + case ast.NodeTypeContains: return newContainsEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeContainsAll: + case ast.NodeTypeContainsAll: return newContainsAllEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeContainsAny: + case ast.NodeTypeContainsAny: return newContainsAnyEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeMult: + case ast.NodeTypeMult: return newMultiplyEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeNotEquals: + case ast.NodeTypeNotEquals: return newNotEqualEval(toEval(v.Left), toEval(v.Right)) - case NodeTypeOr: + case ast.NodeTypeOr: return newOrNode(toEval(v.Left), toEval(v.Right)) default: panic(fmt.Sprintf("unknown node type %T", v)) diff --git a/internal/ast/eval_impl.go b/internal/eval/eval_impl.go similarity index 85% rename from internal/ast/eval_impl.go rename to internal/eval/eval_impl.go index edcb8ee6..4924c354 100644 --- a/internal/ast/eval_impl.go +++ b/internal/eval/eval_impl.go @@ -1,4 +1,4 @@ -package ast +package eval import ( "fmt" @@ -15,17 +15,17 @@ var errEntityNotExist = fmt.Errorf("does not exist") var errUnspecifiedEntity = fmt.Errorf("unspecified entity") // TODO: make private again -type EvalContext struct { +type Context struct { Entities entities.Entities Principal, Action, Resource types.Value Context types.Value } type Evaler interface { - Eval(*EvalContext) (types.Value, error) + Eval(*Context) (types.Value, error) } -func evalBool(n Evaler, ctx *EvalContext) (types.Boolean, error) { +func evalBool(n Evaler, ctx *Context) (types.Boolean, error) { v, err := n.Eval(ctx) if err != nil { return false, err @@ -37,7 +37,7 @@ func evalBool(n Evaler, ctx *EvalContext) (types.Boolean, error) { return b, nil } -func evalLong(n Evaler, ctx *EvalContext) (types.Long, error) { +func evalLong(n Evaler, ctx *Context) (types.Long, error) { v, err := n.Eval(ctx) if err != nil { return 0, err @@ -49,7 +49,7 @@ func evalLong(n Evaler, ctx *EvalContext) (types.Long, error) { return l, nil } -func evalString(n Evaler, ctx *EvalContext) (types.String, error) { +func evalString(n Evaler, ctx *Context) (types.String, error) { v, err := n.Eval(ctx) if err != nil { return "", err @@ -61,7 +61,7 @@ func evalString(n Evaler, ctx *EvalContext) (types.String, error) { return s, nil } -func evalSet(n Evaler, ctx *EvalContext) (types.Set, error) { +func evalSet(n Evaler, ctx *Context) (types.Set, error) { v, err := n.Eval(ctx) if err != nil { return nil, err @@ -73,7 +73,7 @@ func evalSet(n Evaler, ctx *EvalContext) (types.Set, error) { return s, nil } -func evalEntity(n Evaler, ctx *EvalContext) (types.EntityUID, error) { +func evalEntity(n Evaler, ctx *Context) (types.EntityUID, error) { v, err := n.Eval(ctx) if err != nil { return types.EntityUID{}, err @@ -85,7 +85,7 @@ func evalEntity(n Evaler, ctx *EvalContext) (types.EntityUID, error) { return e, nil } -func evalPath(n Evaler, ctx *EvalContext) (types.Path, error) { +func evalPath(n Evaler, ctx *Context) (types.Path, error) { v, err := n.Eval(ctx) if err != nil { return "", err @@ -97,7 +97,7 @@ func evalPath(n Evaler, ctx *EvalContext) (types.Path, error) { return e, nil } -func evalDecimal(n Evaler, ctx *EvalContext) (types.Decimal, error) { +func evalDecimal(n Evaler, ctx *Context) (types.Decimal, error) { v, err := n.Eval(ctx) if err != nil { return types.Decimal(0), err @@ -109,7 +109,7 @@ func evalDecimal(n Evaler, ctx *EvalContext) (types.Decimal, error) { return d, nil } -func evalIP(n Evaler, ctx *EvalContext) (types.IPAddr, error) { +func evalIP(n Evaler, ctx *Context) (types.IPAddr, error) { v, err := n.Eval(ctx) if err != nil { return types.IPAddr{}, err @@ -132,7 +132,7 @@ func newErrorEval(err error) *errorEval { } } -func (n *errorEval) Eval(_ *EvalContext) (types.Value, error) { +func (n *errorEval) Eval(_ *Context) (types.Value, error) { return types.ZeroValue(), n.err } @@ -145,7 +145,7 @@ func newLiteralEval(value types.Value) *literalEval { return &literalEval{value: value} } -func (n *literalEval) Eval(_ *EvalContext) (types.Value, error) { +func (n *literalEval) Eval(_ *Context) (types.Value, error) { return n.value, nil } @@ -162,7 +162,7 @@ func newOrNode(lhs Evaler, rhs Evaler) *orEval { } } -func (n *orEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *orEval) Eval(ctx *Context) (types.Value, error) { v, err := n.lhs.Eval(ctx) if err != nil { return types.ZeroValue(), err @@ -198,7 +198,7 @@ func newAndEval(lhs Evaler, rhs Evaler) *andEval { } } -func (n *andEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *andEval) Eval(ctx *Context) (types.Value, error) { v, err := n.lhs.Eval(ctx) if err != nil { return types.ZeroValue(), err @@ -232,7 +232,7 @@ func newNotEval(inner Evaler) *notEval { } } -func (n *notEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *notEval) Eval(ctx *Context) (types.Value, error) { v, err := n.inner.Eval(ctx) if err != nil { return types.ZeroValue(), err @@ -301,7 +301,7 @@ func newAddEval(lhs Evaler, rhs Evaler) *addEval { } } -func (n *addEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *addEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -330,7 +330,7 @@ func newSubtractEval(lhs Evaler, rhs Evaler) *subtractEval { } } -func (n *subtractEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *subtractEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -359,7 +359,7 @@ func newMultiplyEval(lhs Evaler, rhs Evaler) *multiplyEval { } } -func (n *multiplyEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *multiplyEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -386,7 +386,7 @@ func newNegateEval(inner Evaler) *negateEval { } } -func (n *negateEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *negateEval) Eval(ctx *Context) (types.Value, error) { inner, err := evalLong(n.inner, ctx) if err != nil { return types.ZeroValue(), err @@ -411,7 +411,7 @@ func newLongLessThanEval(lhs Evaler, rhs Evaler) *longLessThanEval { } } -func (n *longLessThanEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *longLessThanEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -436,7 +436,7 @@ func newLongLessThanOrEqualEval(lhs Evaler, rhs Evaler) *longLessThanOrEqualEval } } -func (n *longLessThanOrEqualEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *longLessThanOrEqualEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -461,7 +461,7 @@ func newLongGreaterThanEval(lhs Evaler, rhs Evaler) *longGreaterThanEval { } } -func (n *longGreaterThanEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *longGreaterThanEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -486,7 +486,7 @@ func newLongGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) *longGreaterThanOrEqu } } -func (n *longGreaterThanOrEqualEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *longGreaterThanOrEqualEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -511,7 +511,7 @@ func newDecimalLessThanEval(lhs Evaler, rhs Evaler) *decimalLessThanEval { } } -func (n *decimalLessThanEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *decimalLessThanEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -536,7 +536,7 @@ func newDecimalLessThanOrEqualEval(lhs Evaler, rhs Evaler) *decimalLessThanOrEqu } } -func (n *decimalLessThanOrEqualEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *decimalLessThanOrEqualEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -561,7 +561,7 @@ func newDecimalGreaterThanEval(lhs Evaler, rhs Evaler) *decimalGreaterThanEval { } } -func (n *decimalGreaterThanEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *decimalGreaterThanEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -586,7 +586,7 @@ func newDecimalGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) *decimalGreaterTha } } -func (n *decimalGreaterThanOrEqualEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *decimalGreaterThanOrEqualEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -613,7 +613,7 @@ func newIfThenElseEval(if_, then, else_ Evaler) *ifThenElseEval { } } -func (n *ifThenElseEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *ifThenElseEval) Eval(ctx *Context) (types.Value, error) { cond, err := evalBool(n.if_, ctx) if err != nil { return types.ZeroValue(), err @@ -636,7 +636,7 @@ func newEqualEval(lhs, rhs Evaler) *equalEval { } } -func (n *equalEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *equalEval) Eval(ctx *Context) (types.Value, error) { lv, err := n.lhs.Eval(ctx) if err != nil { return types.ZeroValue(), err @@ -660,7 +660,7 @@ func newNotEqualEval(lhs, rhs Evaler) *notEqualEval { } } -func (n *notEqualEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *notEqualEval) Eval(ctx *Context) (types.Value, error) { lv, err := n.lhs.Eval(ctx) if err != nil { return types.ZeroValue(), err @@ -681,7 +681,7 @@ func newSetLiteralEval(elements []Evaler) *setLiteralEval { return &setLiteralEval{elements: elements} } -func (n *setLiteralEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *setLiteralEval) Eval(ctx *Context) (types.Value, error) { var vals types.Set for _, e := range n.elements { v, err := e.Eval(ctx) @@ -705,7 +705,7 @@ func newContainsEval(lhs, rhs Evaler) *containsEval { } } -func (n *containsEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *containsEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -729,7 +729,7 @@ func newContainsAllEval(lhs, rhs Evaler) *containsAllEval { } } -func (n *containsAllEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *containsAllEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -760,7 +760,7 @@ func newContainsAnyEval(lhs, rhs Evaler) *containsAnyEval { } } -func (n *containsAnyEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *containsAnyEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -788,7 +788,7 @@ func newRecordLiteralEval(elements map[string]Evaler) *recordLiteralEval { return &recordLiteralEval{elements: elements} } -func (n *recordLiteralEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *recordLiteralEval) Eval(ctx *Context) (types.Value, error) { vals := types.Record{} for k, en := range n.elements { v, err := en.Eval(ctx) @@ -810,7 +810,7 @@ func newAttributeAccessEval(record Evaler, attribute string) *attributeAccessEva return &attributeAccessEval{object: record, attribute: attribute} } -func (n *attributeAccessEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *attributeAccessEval) Eval(ctx *Context) (types.Value, error) { v, err := n.object.Eval(ctx) if err != nil { return types.ZeroValue(), err @@ -852,7 +852,7 @@ func newHasEval(record Evaler, attribute string) *hasEval { return &hasEval{object: record, attribute: attribute} } -func (n *hasEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *hasEval) Eval(ctx *Context) (types.Value, error) { v, err := n.object.Eval(ctx) if err != nil { return types.ZeroValue(), err @@ -885,7 +885,7 @@ func newLikeEval(lhs Evaler, pattern types.Pattern) *likeEval { return &likeEval{lhs: lhs, pattern: pattern} } -func (l *likeEval) Eval(ctx *EvalContext) (types.Value, error) { +func (l *likeEval) Eval(ctx *Context) (types.Value, error) { v, err := evalString(l.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -893,12 +893,12 @@ func (l *likeEval) Eval(ctx *EvalContext) (types.Value, error) { return types.Boolean(l.pattern.Match(string(v))), nil } -type variableName func(ctx *EvalContext) types.Value +type variableName func(ctx *Context) types.Value -func variableNamePrincipal(ctx *EvalContext) types.Value { return ctx.Principal } -func variableNameAction(ctx *EvalContext) types.Value { return ctx.Action } -func variableNameResource(ctx *EvalContext) types.Value { return ctx.Resource } -func variableNameContext(ctx *EvalContext) types.Value { return ctx.Context } +func variableNamePrincipal(ctx *Context) types.Value { return ctx.Principal } +func variableNameAction(ctx *Context) types.Value { return ctx.Action } +func variableNameResource(ctx *Context) types.Value { return ctx.Resource } +func variableNameContext(ctx *Context) types.Value { return ctx.Context } // variableEval type variableEval struct { @@ -909,7 +909,7 @@ func newVariableEval(variableName variableName) *variableEval { return &variableEval{variableName: variableName} } -func (n *variableEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *variableEval) Eval(ctx *Context) (types.Value, error) { return n.variableName(ctx), nil } @@ -940,7 +940,7 @@ func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entity return false } -func (n *inEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *inEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalEntity(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -979,7 +979,7 @@ func newIsEval(lhs, rhs Evaler) *isEval { return &isEval{lhs: lhs, rhs: rhs} } -func (n *isEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *isEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalEntity(n.lhs, ctx) if err != nil { return types.ZeroValue(), err @@ -1002,7 +1002,7 @@ func newDecimalLiteralEval(literal Evaler) *decimalLiteralEval { return &decimalLiteralEval{literal: literal} } -func (n *decimalLiteralEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *decimalLiteralEval) Eval(ctx *Context) (types.Value, error) { literal, err := evalString(n.literal, ctx) if err != nil { return types.ZeroValue(), err @@ -1024,7 +1024,7 @@ func newIPLiteralEval(literal Evaler) *ipLiteralEval { return &ipLiteralEval{literal: literal} } -func (n *ipLiteralEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *ipLiteralEval) Eval(ctx *Context) (types.Value, error) { literal, err := evalString(n.literal, ctx) if err != nil { return types.ZeroValue(), err @@ -1055,7 +1055,7 @@ func newIPTestEval(object Evaler, test ipTestType) *ipTestEval { return &ipTestEval{object: object, test: test} } -func (n *ipTestEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *ipTestEval) Eval(ctx *Context) (types.Value, error) { i, err := evalIP(n.object, ctx) if err != nil { return types.ZeroValue(), err @@ -1073,7 +1073,7 @@ func newIPIsInRangeEval(lhs, rhs Evaler) *ipIsInRangeEval { return &ipIsInRangeEval{lhs: lhs, rhs: rhs} } -func (n *ipIsInRangeEval) Eval(ctx *EvalContext) (types.Value, error) { +func (n *ipIsInRangeEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalIP(n.lhs, ctx) if err != nil { return types.ZeroValue(), err diff --git a/internal/ast/eval_test.go b/internal/eval/eval_test.go similarity index 96% rename from internal/ast/eval_test.go rename to internal/eval/eval_test.go index c2c6af7b..2f2b7ea0 100644 --- a/internal/ast/eval_test.go +++ b/internal/eval/eval_test.go @@ -1,4 +1,4 @@ -package ast +package eval import ( "fmt" @@ -35,7 +35,7 @@ func TestOrNode(t *testing.T) { t.Run(fmt.Sprintf("%v%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newOrNode(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -46,7 +46,7 @@ func TestOrNode(t *testing.T) { t.Parallel() n := newOrNode( newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(1))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, true) }) @@ -67,7 +67,7 @@ func TestOrNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newOrNode(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -90,7 +90,7 @@ func TestAndNode(t *testing.T) { t.Run(fmt.Sprintf("%v%v", tt.lhs, tt.rhs), func(t *testing.T) { t.Parallel() n := newAndEval(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -101,7 +101,7 @@ func TestAndNode(t *testing.T) { t.Parallel() n := newAndEval( newLiteralEval(types.Boolean(false)), newLiteralEval(types.Long(1))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, false) }) @@ -122,7 +122,7 @@ func TestAndNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newAndEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -143,7 +143,7 @@ func TestNotNode(t *testing.T) { t.Run(fmt.Sprintf("%v", tt.arg), func(t *testing.T) { t.Parallel() n := newNotEval(newLiteralEval(types.Boolean(tt.arg))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -164,7 +164,7 @@ func TestNotNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newNotEval(tt.arg) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -344,7 +344,7 @@ func TestAddNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newAddEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertLongValue(t, v, 3) }) @@ -372,7 +372,7 @@ func TestAddNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newAddEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -383,7 +383,7 @@ func TestSubtractNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newSubtractEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertLongValue(t, v, -1) }) @@ -411,7 +411,7 @@ func TestSubtractNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newSubtractEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -422,7 +422,7 @@ func TestMultiplyNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newMultiplyEval(newLiteralEval(types.Long(-3)), newLiteralEval(types.Long(2))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertLongValue(t, v, -6) }) @@ -450,7 +450,7 @@ func TestMultiplyNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newMultiplyEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -461,7 +461,7 @@ func TestNegateNode(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() n := newNegateEval(newLiteralEval(types.Long(-3))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertLongValue(t, v, 3) }) @@ -480,7 +480,7 @@ func TestNegateNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newNegateEval(tt.arg) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -509,7 +509,7 @@ func TestLongLessThanNode(t *testing.T) { t.Parallel() n := newLongLessThanEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -531,7 +531,7 @@ func TestLongLessThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongLessThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -561,7 +561,7 @@ func TestLongLessThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongLessThanOrEqualEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -583,7 +583,7 @@ func TestLongLessThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongLessThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -613,7 +613,7 @@ func TestLongGreaterThanNode(t *testing.T) { t.Parallel() n := newLongGreaterThanEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -635,7 +635,7 @@ func TestLongGreaterThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongGreaterThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -665,7 +665,7 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongGreaterThanOrEqualEval( newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -687,7 +687,7 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newLongGreaterThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -722,7 +722,7 @@ func TestDecimalLessThanNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalLessThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -744,7 +744,7 @@ func TestDecimalLessThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalLessThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -779,7 +779,7 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalLessThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -801,7 +801,7 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalLessThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -836,7 +836,7 @@ func TestDecimalGreaterThanNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalGreaterThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -858,7 +858,7 @@ func TestDecimalGreaterThanNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalGreaterThanEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -893,7 +893,7 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { testutil.OK(t, err) rhsv := rhsd n := newDecimalGreaterThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -915,7 +915,7 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalGreaterThanOrEqualEval(tt.lhs, tt.rhs) - _, err := n.Eval(&EvalContext{}) + _, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) }) } @@ -946,7 +946,7 @@ func TestIfThenElseNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIfThenElseEval(tt.if_, tt.then, tt.else_) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) testutil.Equals(t, v, tt.result) }) @@ -972,7 +972,7 @@ func TestEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newEqualEval(tt.lhs, tt.rhs) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -998,7 +998,7 @@ func TestNotEqualNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newNotEqualEval(tt.lhs, tt.rhs) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1039,7 +1039,7 @@ func TestSetLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newSetLiteralEval(tt.elems) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1063,7 +1063,7 @@ func TestContainsNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsEval(tt.lhs, tt.rhs) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertZeroValue(t, v) }) @@ -1092,7 +1092,7 @@ func TestContainsNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsEval(tt.lhs, tt.rhs) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -1118,7 +1118,7 @@ func TestContainsAllNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAllEval(tt.lhs, tt.rhs) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertZeroValue(t, v) }) @@ -1146,7 +1146,7 @@ func TestContainsAllNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAllEval(tt.lhs, tt.rhs) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -1172,7 +1172,7 @@ func TestContainsAnyNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAnyEval(tt.lhs, tt.rhs) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertZeroValue(t, v) }) @@ -1203,7 +1203,7 @@ func TestContainsAnyNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newContainsAnyEval(tt.lhs, tt.rhs) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, tt.result) }) @@ -1235,7 +1235,7 @@ func TestRecordLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newRecordLiteralEval(tt.elems) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1288,7 +1288,7 @@ func TestAttributeAccessNode(t *testing.T) { UID: types.NewEntityUID("knownType", "knownID"), Attributes: types.Record{"knownAttr": types.Long(42)}, } - v, err := n.Eval(&EvalContext{ + v, err := n.Eval(&Context{ Entities: entities.Entities{ entity.UID: entity, }, @@ -1345,7 +1345,7 @@ func TestHasNode(t *testing.T) { UID: types.NewEntityUID("knownType", "knownID"), Attributes: types.Record{"knownAttr": types.Long(42)}, } - v, err := n.Eval(&EvalContext{ + v, err := n.Eval(&Context{ Entities: entities.Entities{ entity.UID: entity, }, @@ -1403,7 +1403,7 @@ func TestLikeNode(t *testing.T) { pat, err := types.ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) testutil.OK(t, err) n := newLikeEval(tt.str, pat) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1414,24 +1414,24 @@ func TestVariableNode(t *testing.T) { t.Parallel() tests := []struct { name string - context EvalContext + context Context variable variableName result types.Value }{ {"principal", - EvalContext{Principal: types.String("foo")}, + Context{Principal: types.String("foo")}, variableNamePrincipal, types.String("foo")}, {"action", - EvalContext{Action: types.String("bar")}, + Context{Action: types.String("bar")}, variableNameAction, types.String("bar")}, {"resource", - EvalContext{Resource: types.String("baz")}, + Context{Resource: types.String("baz")}, variableNameResource, types.String("baz")}, {"context", - EvalContext{Context: types.String("frob")}, + Context{Context: types.String("frob")}, variableNameContext, types.String("frob")}, } @@ -1601,7 +1601,7 @@ func TestIsNode(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - got, err := newIsEval(tt.lhs, tt.rhs).Eval(&EvalContext{}) + got, err := newIsEval(tt.lhs, tt.rhs).Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, got, tt.result) }) @@ -1717,7 +1717,7 @@ func TestInNode(t *testing.T) { Parents: ps, } } - ec := EvalContext{Entities: entityMap} + ec := Context{Entities: entityMap} v, err := n.Eval(&ec) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) @@ -1743,7 +1743,7 @@ func TestDecimalLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newDecimalLiteralEval(tt.arg) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1770,7 +1770,7 @@ func TestIPLiteralNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIPLiteralEval(tt.arg) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1808,7 +1808,7 @@ func TestIPTestNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIPTestEval(tt.lhs, tt.rhs) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) @@ -1846,7 +1846,7 @@ func TestIPIsInRangeNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newIPIsInRangeEval(tt.lhs, tt.rhs) - v, err := n.Eval(&EvalContext{}) + v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) types.AssertValue(t, v, tt.result) }) From 27804837004a41cc6d0805c5c6377df6e52b6eb6 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 11:23:32 -0700 Subject: [PATCH 087/216] cedar-go/internal/ast: remove redundant test file Signed-off-by: philhassey --- internal/ast/ast_test.go | 70 ---------------------------------------- 1 file changed, 70 deletions(-) delete mode 100644 internal/ast/ast_test.go diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go deleted file mode 100644 index c765abab..00000000 --- a/internal/ast/ast_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package ast_test - -import ( - "testing" - - "github.com/cedar-policy/cedar-go/internal/ast" - "github.com/cedar-policy/cedar-go/types" -) - -// These tests mostly verify that policy ASTs compile -func TestAst(t *testing.T) { - t.Parallel() - - johnny := types.NewEntityUID("User", "johnny") - sow := types.NewEntityUID("Action", "sow") - cast := types.NewEntityUID("Action", "cast") - - // @example("one") - // permit ( - // principal == User::"johnny" - // action in [Action::"sow", Action::"cast"] - // resource - // ) - // when { true } - // unless { false }; - _ = ast.Annotation("example", "one"). - Permit(). - PrincipalIsIn("User", johnny). - ActionInSet(sow, cast). - When(ast.True()). - Unless(ast.False()) - - // @example("two") - // forbid (principal, action, resource) - // when { resource.tags.contains("private") } - // unless { resource in principal.allowed_resources }; - private := types.String("private") - _ = ast.Annotation("example", "two"). - Forbid(). - When( - ast.Resource().Access("tags").Contains(ast.String(private)), - ). - Unless( - ast.Resource().In(ast.Principal().Access("allowed_resources")), - ) - - // forbid (principal, action, resource) - // when { {x: "value"}.x == "value" } - // when { {x: 1 + context.fooCount}.x == 3 } - // when { [1, 2 + 3, context.fooCount].contains(1) }; - simpleRecord := types.Record{ - "x": types.String("value"), - } - _ = ast.Forbid(). - When( - ast.Record(simpleRecord).Access("x").Equals(ast.String("value")), - ). - When( - ast.RecordNodes(map[types.String]ast.Node{ - "x": ast.Long(1).Plus(ast.Context().Access("fooCount")), - }).Access("x").Equals(ast.Long(3)), - ). - When( - ast.SetNodes( - ast.Long(1), - ast.Long(2).Plus(ast.Long(3)), - ast.Context().Access("fooCount"), - ).Contains(ast.Long(1)), - ) -} From bca47fdb2de599cfc958508d4b9e23bfdc5fbfe1 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 13 Aug 2024 12:21:51 -0600 Subject: [PATCH 088/216] internal/json: separate json handling out Addresses IDX-142 Signed-off-by: philhassey --- internal/ast/node.go | 2 +- internal/ast/policy.go | 6 +- internal/ast/scope.go | 22 ++-- internal/ast/value.go | 16 +-- internal/ast/variable.go | 16 +-- internal/{ast => json}/json.go | 2 +- internal/{ast => json}/json_marshal.go | 115 ++++++++++--------- internal/{ast => json}/json_test.go | 9 +- internal/{ast => json}/json_unmarshal.go | 135 ++++++++++++----------- 9 files changed, 165 insertions(+), 158 deletions(-) rename internal/{ast => json}/json.go (99%) rename internal/{ast => json}/json_marshal.go (78%) rename internal/{ast => json}/json_test.go (99%) rename internal/{ast => json}/json_unmarshal.go (60%) diff --git a/internal/ast/node.go b/internal/ast/node.go index c3ac6171..8c52b190 100644 --- a/internal/ast/node.go +++ b/internal/ast/node.go @@ -197,7 +197,7 @@ func stripNodes(args []Node) []IsNode { return res } -func newExtensionCall(method types.String, args ...Node) Node { +func NewExtensionCall(method types.String, args ...Node) Node { return NewNode(NodeTypeExtensionCall{ Name: method, Args: stripNodes(args), diff --git a/internal/ast/policy.go b/internal/ast/policy.go index e77eb8db..24b92cef 100644 --- a/internal/ast/policy.go +++ b/internal/ast/policy.go @@ -56,9 +56,9 @@ func newPolicy(effect Effect, annotations []AnnotationType) *Policy { return &Policy{ Effect: effect, Annotations: annotations, - Principal: Scope(newPrincipalNode()).All(), - Action: Scope(newActionNode()).All(), - Resource: Scope(newResourceNode()).All(), + Principal: Scope(NewPrincipalNode()).All(), + Action: Scope(NewActionNode()).All(), + Resource: Scope(NewResourceNode()).All(), } } diff --git a/internal/ast/scope.go b/internal/ast/scope.go index 27c43f79..454cd0ce 100644 --- a/internal/ast/scope.go +++ b/internal/ast/scope.go @@ -33,57 +33,57 @@ func (s Scope) IsIn(entityType types.Path, entity types.EntityUID) IsScopeNode { } func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { - p.Principal = Scope(newPrincipalNode()).Eq(entity) + p.Principal = Scope(NewPrincipalNode()).Eq(entity) return p } func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { - p.Principal = Scope(newPrincipalNode()).In(entity) + p.Principal = Scope(NewPrincipalNode()).In(entity) return p } func (p *Policy) PrincipalIs(entityType types.Path) *Policy { - p.Principal = Scope(newPrincipalNode()).Is(entityType) + p.Principal = Scope(NewPrincipalNode()).Is(entityType) return p } func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { - p.Principal = Scope(newPrincipalNode()).IsIn(entityType, entity) + p.Principal = Scope(NewPrincipalNode()).IsIn(entityType, entity) return p } func (p *Policy) ActionEq(entity types.EntityUID) *Policy { - p.Action = Scope(newActionNode()).Eq(entity) + p.Action = Scope(NewActionNode()).Eq(entity) return p } func (p *Policy) ActionIn(entity types.EntityUID) *Policy { - p.Action = Scope(newActionNode()).In(entity) + p.Action = Scope(NewActionNode()).In(entity) return p } func (p *Policy) ActionInSet(entities ...types.EntityUID) *Policy { - p.Action = Scope(newActionNode()).InSet(entities) + p.Action = Scope(NewActionNode()).InSet(entities) return p } func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { - p.Resource = Scope(newResourceNode()).Eq(entity) + p.Resource = Scope(NewResourceNode()).Eq(entity) return p } func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { - p.Resource = Scope(newResourceNode()).In(entity) + p.Resource = Scope(NewResourceNode()).In(entity) return p } func (p *Policy) ResourceIs(entityType types.Path) *Policy { - p.Resource = Scope(newResourceNode()).Is(entityType) + p.Resource = Scope(NewResourceNode()).Is(entityType) return p } func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { - p.Resource = Scope(newResourceNode()).IsIn(entityType, entity) + p.Resource = Scope(NewResourceNode()).IsIn(entityType, entity) return p } diff --git a/internal/ast/value.go b/internal/ast/value.go index 6882d161..ab685ff7 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -7,7 +7,7 @@ import ( ) func Boolean(b types.Boolean) Node { - return newValueNode(b) + return NewValueNode(b) } func True() Node { @@ -19,11 +19,11 @@ func False() Node { } func String(s types.String) Node { - return newValueNode(s) + return NewValueNode(s) } func Long(l types.Long) Node { - return newValueNode(l) + return NewValueNode(l) } // Set is a convenience function that wraps concrete instances of a Cedar Set type @@ -95,22 +95,22 @@ func RecordElements(elements ...RecordElement) Node { } func EntityUID(e types.EntityUID) Node { - return newValueNode(e) + return NewValueNode(e) } func Decimal(d types.Decimal) Node { - return newValueNode(d) + return NewValueNode(d) } func IPAddr(i types.IPAddr) Node { - return newValueNode(i) + return NewValueNode(i) } func ExtensionCall(name types.String, args ...Node) Node { - return newExtensionCall(name, args...) + return NewExtensionCall(name, args...) } -func newValueNode(v types.Value) Node { +func NewValueNode(v types.Value) Node { return NewNode(NodeValue{Value: v}) } diff --git a/internal/ast/variable.go b/internal/ast/variable.go index 7254b8c8..be398a8d 100644 --- a/internal/ast/variable.go +++ b/internal/ast/variable.go @@ -3,33 +3,33 @@ package ast import "github.com/cedar-policy/cedar-go/types" func Principal() Node { - return NewNode(newPrincipalNode()) + return NewNode(NewPrincipalNode()) } func Action() Node { - return NewNode(newActionNode()) + return NewNode(NewActionNode()) } func Resource() Node { - return NewNode(newResourceNode()) + return NewNode(NewResourceNode()) } func Context() Node { - return NewNode(newContextNode()) + return NewNode(NewContextNode()) } -func newPrincipalNode() NodeTypeVariable { +func NewPrincipalNode() NodeTypeVariable { return NodeTypeVariable{Name: types.String("principal")} } -func newActionNode() NodeTypeVariable { +func NewActionNode() NodeTypeVariable { return NodeTypeVariable{Name: types.String("action")} } -func newResourceNode() NodeTypeVariable { +func NewResourceNode() NodeTypeVariable { return NodeTypeVariable{Name: types.String("resource")} } -func newContextNode() NodeTypeVariable { +func NewContextNode() NodeTypeVariable { return NodeTypeVariable{Name: types.String("context")} } diff --git a/internal/ast/json.go b/internal/json/json.go similarity index 99% rename from internal/ast/json.go rename to internal/json/json.go index fe023e10..a639e1c7 100644 --- a/internal/ast/json.go +++ b/internal/json/json.go @@ -1,4 +1,4 @@ -package ast +package json import ( "encoding/json" diff --git a/internal/ast/json_marshal.go b/internal/json/json_marshal.go similarity index 78% rename from internal/ast/json_marshal.go rename to internal/json/json_marshal.go index 8d228280..cf6899af 100644 --- a/internal/ast/json_marshal.go +++ b/internal/json/json_marshal.go @@ -1,36 +1,37 @@ -package ast +package json import ( "encoding/json" "fmt" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/types" ) -func (s *scopeJSON) FromNode(src IsScopeNode) error { +func (s *scopeJSON) FromNode(src ast.IsScopeNode) error { switch t := src.(type) { - case ScopeTypeAll: + case ast.ScopeTypeAll: s.Op = "All" return nil - case ScopeTypeEq: + case ast.ScopeTypeEq: s.Op = "==" e := t.Entity s.Entity = &e return nil - case ScopeTypeIn: + case ast.ScopeTypeIn: s.Op = "in" e := t.Entity s.Entity = &e return nil - case ScopeTypeInSet: + case ast.ScopeTypeInSet: s.Op = "in" s.Entities = t.Entities return nil - case ScopeTypeIs: + case ast.ScopeTypeIs: s.Op = "is" s.EntityType = string(t.Type) return nil - case ScopeTypeIsIn: + case ast.ScopeTypeIsIn: s.Op = "is" s.EntityType = string(t.Type) s.In = &scopeInJSON{ @@ -41,8 +42,8 @@ func (s *scopeJSON) FromNode(src IsScopeNode) error { return fmt.Errorf("unexpected scope node: %T", src) } -func unaryToJSON(dest **unaryJSON, src UnaryNode) error { - n := UnaryNode(src) +func unaryToJSON(dest **unaryJSON, src ast.UnaryNode) error { + n := ast.UnaryNode(src) res := &unaryJSON{} if err := res.Arg.FromNode(n.Arg); err != nil { return fmt.Errorf("error in arg: %w", err) @@ -51,8 +52,8 @@ func unaryToJSON(dest **unaryJSON, src UnaryNode) error { return nil } -func binaryToJSON(dest **binaryJSON, src BinaryNode) error { - n := BinaryNode(src) +func binaryToJSON(dest **binaryJSON, src ast.BinaryNode) error { + n := ast.BinaryNode(src) res := &binaryJSON{} if err := res.Left.FromNode(n.Left); err != nil { return fmt.Errorf("error in left: %w", err) @@ -64,7 +65,7 @@ func binaryToJSON(dest **binaryJSON, src BinaryNode) error { return nil } -func arrayToJSON(dest *arrayJSON, args []IsNode) error { +func arrayToJSON(dest *arrayJSON, args []ast.IsNode) error { res := arrayJSON{} for _, n := range args { var nn nodeJSON @@ -90,7 +91,7 @@ func extToJSON(dest *extensionCallJSON, name string, src types.Value) error { return nil } -func extCallToJSON(dest extensionCallJSON, src NodeTypeExtensionCall) error { +func extCallToJSON(dest extensionCallJSON, src ast.NodeTypeExtensionCall) error { jsonArgs := arrayJSON{} for _, n := range src.Args { argNode := &nodeJSON{} @@ -104,7 +105,7 @@ func extCallToJSON(dest extensionCallJSON, src NodeTypeExtensionCall) error { return nil } -func strToJSON(dest **strJSON, src StrOpNode) error { +func strToJSON(dest **strJSON, src ast.StrOpNode) error { res := &strJSON{} if err := res.Left.FromNode(src.Arg); err != nil { return fmt.Errorf("error in left: %w", err) @@ -114,7 +115,7 @@ func strToJSON(dest **strJSON, src StrOpNode) error { return nil } -func patternToJSON(dest **patternJSON, src NodeTypeLike) error { +func patternToJSON(dest **patternJSON, src ast.NodeTypeLike) error { res := &patternJSON{} if err := res.Left.FromNode(src.Arg); err != nil { return fmt.Errorf("error in left: %w", err) @@ -131,7 +132,7 @@ func patternToJSON(dest **patternJSON, src NodeTypeLike) error { return nil } -func recordToJSON(dest *recordJSON, src NodeTypeRecord) error { +func recordToJSON(dest *recordJSON, src ast.NodeTypeRecord) error { res := recordJSON{} for _, kv := range src.Elements { var nn nodeJSON @@ -144,7 +145,7 @@ func recordToJSON(dest *recordJSON, src NodeTypeRecord) error { return nil } -func ifToJSON(dest **ifThenElseJSON, src NodeTypeIf) error { +func ifToJSON(dest **ifThenElseJSON, src ast.NodeTypeIf) error { res := &ifThenElseJSON{} if err := res.If.FromNode(src.If); err != nil { return fmt.Errorf("error in if: %w", err) @@ -159,7 +160,7 @@ func ifToJSON(dest **ifThenElseJSON, src NodeTypeIf) error { return nil } -func isToJSON(dest **isJSON, src NodeTypeIs) error { +func isToJSON(dest **isJSON, src ast.NodeTypeIs) error { res := &isJSON{} if err := res.Left.FromNode(src.Left); err != nil { return fmt.Errorf("error in left: %w", err) @@ -169,7 +170,7 @@ func isToJSON(dest **isJSON, src NodeTypeIs) error { return nil } -func isInToJSON(dest **isJSON, src NodeTypeIsIn) error { +func isInToJSON(dest **isJSON, src ast.NodeTypeIsIn) error { res := &isJSON{} if err := res.Left.FromNode(src.Left); err != nil { return fmt.Errorf("error in left: %w", err) @@ -183,11 +184,11 @@ func isInToJSON(dest **isJSON, src NodeTypeIsIn) error { return nil } -func (j *nodeJSON) FromNode(src IsNode) error { +func (j *nodeJSON) FromNode(src ast.IsNode) error { switch t := src.(type) { // Value // Value *json.RawMessage `json:"Value"` // could be any - case NodeValue: + case ast.NodeValue: // Any other function: decimal, ip // Decimal arrayJSON `json:"decimal"` // IP arrayJSON `json:"ip"` @@ -203,7 +204,7 @@ func (j *nodeJSON) FromNode(src IsNode) error { // Var // Var *string `json:"Var"` - case NodeTypeVariable: + case ast.NodeTypeVariable: val := string(t.Name) j.Var = &val return nil @@ -211,87 +212,87 @@ func (j *nodeJSON) FromNode(src IsNode) error { // ! or neg operators // Not *unaryJSON `json:"!"` // Negate *unaryJSON `json:"neg"` - case NodeTypeNot: + case ast.NodeTypeNot: return unaryToJSON(&j.Not, t.UnaryNode) - case NodeTypeNegate: + case ast.NodeTypeNegate: return unaryToJSON(&j.Negate, t.UnaryNode) // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny - case NodeTypeAdd: + case ast.NodeTypeAdd: return binaryToJSON(&j.Plus, t.BinaryNode) - case NodeTypeAnd: + case ast.NodeTypeAnd: return binaryToJSON(&j.And, t.BinaryNode) - case NodeTypeContains: + case ast.NodeTypeContains: return binaryToJSON(&j.Contains, t.BinaryNode) - case NodeTypeContainsAll: + case ast.NodeTypeContainsAll: return binaryToJSON(&j.ContainsAll, t.BinaryNode) - case NodeTypeContainsAny: + case ast.NodeTypeContainsAny: return binaryToJSON(&j.ContainsAny, t.BinaryNode) - case NodeTypeEquals: + case ast.NodeTypeEquals: return binaryToJSON(&j.Equals, t.BinaryNode) - case NodeTypeGreaterThan: + case ast.NodeTypeGreaterThan: return binaryToJSON(&j.GreaterThan, t.BinaryNode) - case NodeTypeGreaterThanOrEqual: + case ast.NodeTypeGreaterThanOrEqual: return binaryToJSON(&j.GreaterThanOrEqual, t.BinaryNode) - case NodeTypeIn: + case ast.NodeTypeIn: return binaryToJSON(&j.In, t.BinaryNode) - case NodeTypeLessThan: + case ast.NodeTypeLessThan: return binaryToJSON(&j.LessThan, t.BinaryNode) - case NodeTypeLessThanOrEqual: + case ast.NodeTypeLessThanOrEqual: return binaryToJSON(&j.LessThanOrEqual, t.BinaryNode) - case NodeTypeMult: + case ast.NodeTypeMult: return binaryToJSON(&j.Times, t.BinaryNode) - case NodeTypeNotEquals: + case ast.NodeTypeNotEquals: return binaryToJSON(&j.NotEquals, t.BinaryNode) - case NodeTypeOr: + case ast.NodeTypeOr: return binaryToJSON(&j.Or, t.BinaryNode) - case NodeTypeSub: + case ast.NodeTypeSub: return binaryToJSON(&j.Minus, t.BinaryNode) // ., has // Access *strJSON `json:"."` // Has *strJSON `json:"has"` - case NodeTypeAccess: + case ast.NodeTypeAccess: return strToJSON(&j.Access, t.StrOpNode) - case NodeTypeHas: + case ast.NodeTypeHas: return strToJSON(&j.Has, t.StrOpNode) // is - case NodeTypeIs: + case ast.NodeTypeIs: return isToJSON(&j.Is, t) - case NodeTypeIsIn: + case ast.NodeTypeIsIn: return isInToJSON(&j.Is, t) // like // Like *strJSON `json:"like"` - case NodeTypeLike: + case ast.NodeTypeLike: return patternToJSON(&j.Like, t) // if-then-else // IfThenElse *ifThenElseJSON `json:"if-then-else"` - case NodeTypeIf: + case ast.NodeTypeIf: return ifToJSON(&j.IfThenElse, t) // Set // Set arrayJSON `json:"Set"` - case NodeTypeSet: + case ast.NodeTypeSet: return arrayToJSON(&j.Set, t.Elements) // Record // Record recordJSON `json:"Record"` - case NodeTypeRecord: + case ast.NodeTypeRecord: return recordToJSON(&j.Record, t) // Any other method: ip, decimal, lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange // ExtensionMethod map[string]arrayJSON `json:"-"` - case NodeTypeExtensionCall: + case ast.NodeTypeExtensionCall: j.ExtensionCall = extensionCallJSON{} return extCallToJSON(j.ExtensionCall, t) } - // case nodeTypeRecordEntry: - // case nodeTypeEntityType: - // case nodeTypeAnnotation: - // case nodeTypeWhen: - // case nodeTypeUnless: + // case ast.nodeTypeRecordEntry: + // case ast.nodeTypeEntityType: + // case ast.nodeTypeAnnotation: + // case ast.nodeTypeWhen: + // case ast.nodeTypeUnless: return fmt.Errorf("unknown node type: %T", src) } @@ -311,6 +312,10 @@ func (p *patternComponentJSON) MarshalJSON() ([]byte, error) { return json.Marshal(p.Literal) } +type Policy struct { + ast.Policy +} + func (p *Policy) MarshalJSON() ([]byte, error) { var j policyJSON j.Effect = "forbid" @@ -335,7 +340,7 @@ func (p *Policy) MarshalJSON() ([]byte, error) { for _, c := range p.Conditions { var cond conditionJSON cond.Kind = "when" - if c.Condition == ConditionUnless { + if c.Condition == ast.ConditionUnless { cond.Kind = "unless" } if err := cond.Body.FromNode(c.Body); err != nil { diff --git a/internal/ast/json_test.go b/internal/json/json_test.go similarity index 99% rename from internal/ast/json_test.go rename to internal/json/json_test.go index feb3bb76..13a7ce2f 100644 --- a/internal/ast/json_test.go +++ b/internal/json/json_test.go @@ -1,4 +1,4 @@ -package ast_test +package json import ( "encoding/json" @@ -479,13 +479,13 @@ func TestUnmarshalJSON(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - var p ast.Policy + var p Policy err := json.Unmarshal([]byte(tt.input), &p) tt.errFunc(t, err) if err != nil { return } - testutil.Equals(t, p, *tt.want) + testutil.Equals(t, p.Policy, *tt.want) b, err := json.Marshal(&p) testutil.OK(t, err) normInput := testNormalizeJSON(t, tt.input) @@ -535,7 +535,8 @@ func TestMarshalJSON(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - b, err := json.Marshal(tt.input) + pp := &Policy{Policy: *tt.input} + b, err := json.Marshal(pp) tt.errFunc(t, err) if err != nil { return diff --git a/internal/ast/json_unmarshal.go b/internal/json/json_unmarshal.go similarity index 60% rename from internal/ast/json_unmarshal.go rename to internal/json/json_unmarshal.go index 4076051b..47eb9892 100644 --- a/internal/ast/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -1,4 +1,4 @@ -package ast +package json import ( "bytes" @@ -6,10 +6,11 @@ import ( "fmt" "strings" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/types" ) -func (s *scopeJSON) ToNode(variable Scope) (IsScopeNode, error) { +func (s *scopeJSON) ToNode(variable ast.Scope) (ast.IsScopeNode, error) { // TODO: should we be careful to be more strict about what is allowed here? switch s.Op { case "All": @@ -33,35 +34,35 @@ func (s *scopeJSON) ToNode(variable Scope) (IsScopeNode, error) { return nil, fmt.Errorf("unknown op: %v", s.Op) } -func (j binaryJSON) ToNode(f func(a, b Node) Node) (Node, error) { +func (j binaryJSON) ToNode(f func(a, b ast.Node) ast.Node) (ast.Node, error) { left, err := j.Left.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in left: %w", err) + return ast.Node{}, fmt.Errorf("error in left: %w", err) } right, err := j.Right.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in right: %w", err) + return ast.Node{}, fmt.Errorf("error in right: %w", err) } return f(left, right), nil } -func (j unaryJSON) ToNode(f func(a Node) Node) (Node, error) { +func (j unaryJSON) ToNode(f func(a ast.Node) ast.Node) (ast.Node, error) { arg, err := j.Arg.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in arg: %w", err) + return ast.Node{}, fmt.Errorf("error in arg: %w", err) } return f(arg), nil } -func (j strJSON) ToNode(f func(a Node, k string) Node) (Node, error) { +func (j strJSON) ToNode(f func(a ast.Node, k string) ast.Node) (ast.Node, error) { left, err := j.Left.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in left: %w", err) + return ast.Node{}, fmt.Errorf("error in left: %w", err) } return f(left, j.Attr), nil } -func (j patternJSON) ToNode(f func(a Node, k types.Pattern) Node) (Node, error) { +func (j patternJSON) ToNode(f func(a ast.Node, k types.Pattern) ast.Node) (ast.Node, error) { left, err := j.Left.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in left: %w", err) + return ast.Node{}, fmt.Errorf("error in left: %w", err) } pattern := &types.Pattern{} for _, compJSON := range j.Pattern { @@ -74,150 +75,150 @@ func (j patternJSON) ToNode(f func(a Node, k types.Pattern) Node) (Node, error) return f(left, *pattern), nil } -func (j isJSON) ToNode() (Node, error) { +func (j isJSON) ToNode() (ast.Node, error) { left, err := j.Left.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in left: %w", err) + return ast.Node{}, fmt.Errorf("error in left: %w", err) } if j.In != nil { right, err := j.In.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in entity: %w", err) + return ast.Node{}, fmt.Errorf("error in entity: %w", err) } return left.IsIn(types.Path(j.EntityType), right), nil } return left.Is(types.Path(j.EntityType)), nil } -func (j ifThenElseJSON) ToNode() (Node, error) { +func (j ifThenElseJSON) ToNode() (ast.Node, error) { if_, err := j.If.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in if: %w", err) + return ast.Node{}, fmt.Errorf("error in if: %w", err) } then, err := j.Then.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in then: %w", err) + return ast.Node{}, fmt.Errorf("error in then: %w", err) } else_, err := j.Else.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in else: %w", err) + return ast.Node{}, fmt.Errorf("error in else: %w", err) } - return If(if_, then, else_), nil + return ast.If(if_, then, else_), nil } -func (j arrayJSON) ToNode() (Node, error) { - var nodes []Node +func (j arrayJSON) ToNode() (ast.Node, error) { + var nodes []ast.Node for _, jj := range j { n, err := jj.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in set: %w", err) + return ast.Node{}, fmt.Errorf("error in set: %w", err) } nodes = append(nodes, n) } - return SetNodes(nodes...), nil + return ast.SetNodes(nodes...), nil } -func (j recordJSON) ToNode() (Node, error) { - nodes := map[types.String]Node{} +func (j recordJSON) ToNode() (ast.Node, error) { + nodes := map[types.String]ast.Node{} for k, v := range j { n, err := v.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in record: %w", err) + return ast.Node{}, fmt.Errorf("error in record: %w", err) } nodes[types.String(k)] = n } - return RecordNodes(nodes), nil + return ast.RecordNodes(nodes), nil } -func (e extensionCallJSON) ToNode() (Node, error) { +func (e extensionCallJSON) ToNode() (ast.Node, error) { if len(e) != 1 { - return Node{}, fmt.Errorf("unexpected number of extension methods in node: %v", len(e)) + return ast.Node{}, fmt.Errorf("unexpected number of extension methods in node: %v", len(e)) } for k, v := range e { if len(v) == 0 { - return Node{}, fmt.Errorf("extension method '%v' must have at least one argument", k) + return ast.Node{}, fmt.Errorf("extension method '%v' must have at least one argument", k) } - var argNodes []Node + var argNodes []ast.Node for _, n := range v { node, err := n.ToNode() if err != nil { - return Node{}, fmt.Errorf("error in extension method argument: %w", err) + return ast.Node{}, fmt.Errorf("error in extension method argument: %w", err) } argNodes = append(argNodes, node) } - return newExtensionCall(types.String(k), argNodes...), nil + return ast.NewExtensionCall(types.String(k), argNodes...), nil } panic("unreachable code") } -func (j nodeJSON) ToNode() (Node, error) { +func (j nodeJSON) ToNode() (ast.Node, error) { switch { // Value case j.Value != nil: var v types.Value if err := types.UnmarshalJSON(*j.Value, &v); err != nil { - return Node{}, fmt.Errorf("error unmarshalling value: %w", err) + return ast.Node{}, fmt.Errorf("error unmarshalling value: %w", err) } - return valueToNode(v), nil + return ast.NewValueNode(v), nil // Var case j.Var != nil: switch *j.Var { case "principal": - return Principal(), nil + return ast.Principal(), nil case "action": - return Action(), nil + return ast.Action(), nil case "resource": - return Resource(), nil + return ast.Resource(), nil case "context": - return Context(), nil + return ast.Context(), nil } - return Node{}, fmt.Errorf("unknown var: %v", j.Var) + return ast.Node{}, fmt.Errorf("unknown var: %v", j.Var) // Slot // Unknown // ! or neg operators case j.Not != nil: - return j.Not.ToNode(Not) + return j.Not.ToNode(ast.Not) case j.Negate != nil: - return j.Negate.ToNode(Negate) + return j.Negate.ToNode(ast.Negate) // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny case j.Equals != nil: - return j.Equals.ToNode(Node.Equals) + return j.Equals.ToNode(ast.Node.Equals) case j.NotEquals != nil: - return j.NotEquals.ToNode(Node.NotEquals) + return j.NotEquals.ToNode(ast.Node.NotEquals) case j.In != nil: - return j.In.ToNode(Node.In) + return j.In.ToNode(ast.Node.In) case j.LessThan != nil: - return j.LessThan.ToNode(Node.LessThan) + return j.LessThan.ToNode(ast.Node.LessThan) case j.LessThanOrEqual != nil: - return j.LessThanOrEqual.ToNode(Node.LessThanOrEqual) + return j.LessThanOrEqual.ToNode(ast.Node.LessThanOrEqual) case j.GreaterThan != nil: - return j.GreaterThan.ToNode(Node.GreaterThan) + return j.GreaterThan.ToNode(ast.Node.GreaterThan) case j.GreaterThanOrEqual != nil: - return j.GreaterThanOrEqual.ToNode(Node.GreaterThanOrEqual) + return j.GreaterThanOrEqual.ToNode(ast.Node.GreaterThanOrEqual) case j.And != nil: - return j.And.ToNode(Node.And) + return j.And.ToNode(ast.Node.And) case j.Or != nil: - return j.Or.ToNode(Node.Or) + return j.Or.ToNode(ast.Node.Or) case j.Plus != nil: - return j.Plus.ToNode(Node.Plus) + return j.Plus.ToNode(ast.Node.Plus) case j.Minus != nil: - return j.Minus.ToNode(Node.Minus) + return j.Minus.ToNode(ast.Node.Minus) case j.Times != nil: - return j.Times.ToNode(Node.Times) + return j.Times.ToNode(ast.Node.Times) case j.Contains != nil: - return j.Contains.ToNode(Node.Contains) + return j.Contains.ToNode(ast.Node.Contains) case j.ContainsAll != nil: - return j.ContainsAll.ToNode(Node.ContainsAll) + return j.ContainsAll.ToNode(ast.Node.ContainsAll) case j.ContainsAny != nil: - return j.ContainsAny.ToNode(Node.ContainsAny) + return j.ContainsAny.ToNode(ast.Node.ContainsAny) // ., has case j.Access != nil: - return j.Access.ToNode(Node.Access) + return j.Access.ToNode(ast.Node.Access) case j.Has != nil: - return j.Has.ToNode(Node.Has) + return j.Has.ToNode(ast.Node.Has) // is case j.Is != nil: @@ -225,7 +226,7 @@ func (j nodeJSON) ToNode() (Node, error) { // like case j.Like != nil: - return j.Like.ToNode(Node.Like) + return j.Like.ToNode(ast.Node.Like) // if-then-else case j.IfThenElse != nil: @@ -244,7 +245,7 @@ func (j nodeJSON) ToNode() (Node, error) { return j.ExtensionCall.ToNode() } - return Node{}, fmt.Errorf("unknown node") + return ast.Node{}, fmt.Errorf("unknown node") } func (n *nodeJSON) UnmarshalJSON(b []byte) error { @@ -287,9 +288,9 @@ func (p *Policy) UnmarshalJSON(b []byte) error { } switch j.Effect { case "permit": - *p = *Permit() + p.Policy = *ast.Permit() case "forbid": - *p = *Forbid() + p.Policy = *ast.Forbid() default: return fmt.Errorf("unknown effect: %v", j.Effect) } @@ -297,15 +298,15 @@ func (p *Policy) UnmarshalJSON(b []byte) error { p.Annotate(types.String(k), types.String(v)) } var err error - p.Principal, err = j.Principal.ToNode(Scope(newPrincipalNode())) + p.Principal, err = j.Principal.ToNode(ast.Scope(ast.NewPrincipalNode())) if err != nil { return fmt.Errorf("error in principal: %w", err) } - p.Action, err = j.Action.ToNode(Scope(newActionNode())) + p.Action, err = j.Action.ToNode(ast.Scope(ast.NewActionNode())) if err != nil { return fmt.Errorf("error in action: %w", err) } - p.Resource, err = j.Resource.ToNode(Scope(newResourceNode())) + p.Resource, err = j.Resource.ToNode(ast.Scope(ast.NewResourceNode())) if err != nil { return fmt.Errorf("error in resource: %w", err) } From 88dfff2d39e98eed2da14f6b58272f95b336e3c4 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 13 Aug 2024 12:37:37 -0600 Subject: [PATCH 089/216] internal/eval: move ToNode to eval package Addresses IDX-142 Signed-off-by: philhassey --- internal/ast/scope.go | 29 ----------------------------- internal/eval/eval_compile.go | 6 +++--- internal/eval/eval_convert.go | 25 +++++++++++++++++++++++++ 3 files changed, 28 insertions(+), 32 deletions(-) diff --git a/internal/ast/scope.go b/internal/ast/scope.go index 454cd0ce..cad280d7 100644 --- a/internal/ast/scope.go +++ b/internal/ast/scope.go @@ -90,7 +90,6 @@ func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Po type IsScopeNode interface { isScope() MarshalCedar(*bytes.Buffer) - ToNode() Node } type ScopeNode struct { @@ -103,56 +102,28 @@ type ScopeTypeAll struct { ScopeNode } -func (n ScopeTypeAll) ToNode() Node { - return NewNode(True().v) -} - type ScopeTypeEq struct { ScopeNode Entity types.EntityUID } -func (n ScopeTypeEq) ToNode() Node { - return NewNode(NewNode(n.Variable).Equals(EntityUID(n.Entity)).v) -} - type ScopeTypeIn struct { ScopeNode Entity types.EntityUID } -func (n ScopeTypeIn) ToNode() Node { - return NewNode(NewNode(n.Variable).In(EntityUID(n.Entity)).v) -} - type ScopeTypeInSet struct { ScopeNode Entities []types.EntityUID } -func (n ScopeTypeInSet) ToNode() Node { - set := make([]types.Value, len(n.Entities)) - for i, e := range n.Entities { - set[i] = e - } - return NewNode(NewNode(n.Variable).In(Set(set)).v) -} - type ScopeTypeIs struct { ScopeNode Type types.Path } -func (n ScopeTypeIs) ToNode() Node { - return NewNode(NewNode(n.Variable).Is(n.Type).v) -} - type ScopeTypeIsIn struct { ScopeNode Type types.Path Entity types.EntityUID } - -func (n ScopeTypeIsIn) ToNode() Node { - return NewNode(NewNode(n.Variable).IsIn(n.Type, EntityUID(n.Entity)).v) -} diff --git a/internal/eval/eval_compile.go b/internal/eval/eval_compile.go index aabc341d..0cb65032 100644 --- a/internal/eval/eval_compile.go +++ b/internal/eval/eval_compile.go @@ -15,9 +15,9 @@ func Compile(p ast.Policy) Evaler { func policyToNode(p ast.Policy) ast.Node { nodes := make([]ast.Node, 3+len(p.Conditions)) - nodes[0] = p.Principal.ToNode() - nodes[1] = p.Action.ToNode() - nodes[2] = p.Resource.ToNode() + nodes[0] = scopeToNode(p.Principal) + nodes[1] = scopeToNode(p.Action) + nodes[2] = scopeToNode(p.Resource) for i, c := range p.Conditions { if c.Condition == ast.ConditionUnless { nodes[i+3] = ast.Not(ast.NewNode(c.Body)) diff --git a/internal/eval/eval_convert.go b/internal/eval/eval_convert.go index d956555f..bec793eb 100644 --- a/internal/eval/eval_convert.go +++ b/internal/eval/eval_convert.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/types" ) func toEval(n ast.IsNode) Evaler { @@ -124,3 +125,27 @@ func toEval(n ast.IsNode) Evaler { panic(fmt.Sprintf("unknown node type %T", v)) } } + +func scopeToNode(in ast.IsScopeNode) ast.Node { + switch t := in.(type) { + case ast.ScopeTypeAll: + return ast.True() + case ast.ScopeTypeEq: + return ast.NewNode(t.Variable).Equals(ast.EntityUID(t.Entity)) + case ast.ScopeTypeIn: + return ast.NewNode(t.Variable).In(ast.EntityUID(t.Entity)) + case ast.ScopeTypeInSet: + set := make([]types.Value, len(t.Entities)) + for i, e := range t.Entities { + set[i] = e + } + return ast.NewNode(t.Variable).In(ast.Set(set)) + case ast.ScopeTypeIs: + return ast.NewNode(t.Variable).Is(t.Type) + + case ast.ScopeTypeIsIn: + return ast.NewNode(t.Variable).IsIn(t.Type, ast.EntityUID(t.Entity)) + default: + panic(fmt.Sprintf("unknown scope type %T", t)) + } +} From 2d1db20c73a28e2ac29913b252d33cf224fba309 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 13 Aug 2024 14:26:36 -0600 Subject: [PATCH 090/216] internal/parser: move parser out of ast package Addresses IDX-142 Signed-off-by: philhassey --- cedar.go | 6 +- internal/ast/cedar_marshal.go | 323 --------------- internal/ast/node.go | 100 +---- internal/ast/operator.go | 18 +- internal/ast/policy.go | 22 +- internal/ast/scope.go | 3 - internal/eval/eval_compile.go | 10 +- internal/{ast => parser}/cedar_fuzz_test.go | 2 +- internal/parser/cedar_marshal.go | 385 ++++++++++++++++++ internal/{ast => parser}/cedar_parse_test.go | 16 +- internal/{ast => parser}/cedar_tokenize.go | 2 +- .../cedar_tokenize_mocks_test.go | 2 +- .../{ast => parser}/cedar_tokenize_test.go | 2 +- internal/{ast => parser}/cedar_unmarshal.go | 209 +++++----- .../{ast => parser}/cedar_unmarshal_test.go | 33 +- internal/parser/node.go | 194 +++++++++ internal/parser/policy.go | 25 ++ 17 files changed, 762 insertions(+), 590 deletions(-) delete mode 100644 internal/ast/cedar_marshal.go rename internal/{ast => parser}/cedar_fuzz_test.go (99%) create mode 100644 internal/parser/cedar_marshal.go rename internal/{ast => parser}/cedar_parse_test.go (96%) rename internal/{ast => parser}/cedar_tokenize.go (99%) rename internal/{ast => parser}/cedar_tokenize_mocks_test.go (99%) rename internal/{ast => parser}/cedar_tokenize_test.go (99%) rename internal/{ast => parser}/cedar_unmarshal.go (77%) rename internal/{ast => parser}/cedar_unmarshal_test.go (95%) create mode 100644 internal/parser/node.go create mode 100644 internal/parser/policy.go diff --git a/cedar.go b/cedar.go index a5034e0f..a37685e2 100644 --- a/cedar.go +++ b/cedar.go @@ -4,9 +4,9 @@ package cedar import ( "fmt" - "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/eval" + "github.com/cedar-policy/cedar-go/internal/parser" "github.com/cedar-policy/cedar-go/types" ) @@ -61,7 +61,7 @@ func (a *Effect) UnmarshalJSON(b []byte) error { // given file name used in Position data. If there is an error parsing the // document, it will be returned. func NewPolicySet(fileName string, document []byte) (PolicySet, error) { - var res ast.PolicySet + var res parser.PolicySet if err := res.UnmarshalCedar(document); err != nil { return nil, fmt.Errorf("parser error: %w", err) } @@ -77,7 +77,7 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { }, Annotations: ann, Effect: Effect(p.TmpGetEffect()), - eval: eval.Compile(p.Policy), + eval: eval.Compile(p.Policy.Policy), }) } return policies, nil diff --git a/internal/ast/cedar_marshal.go b/internal/ast/cedar_marshal.go deleted file mode 100644 index b097c123..00000000 --- a/internal/ast/cedar_marshal.go +++ /dev/null @@ -1,323 +0,0 @@ -package ast - -import ( - "bytes" -) - -// TODO: Add errors to all of this! TODO: review this ask, I'm not sure any real errors are possible. All buf errors are panics. -func (p *Policy) MarshalCedar(buf *bytes.Buffer) { - for _, a := range p.Annotations { - a.MarshalCedar(buf) - buf.WriteRune('\n') - } - p.Effect.MarshalCedar(buf) - buf.WriteRune(' ') - p.marshalScope(buf) - - for _, c := range p.Conditions { - buf.WriteRune('\n') - c.MarshalCedar(buf) - } - - buf.WriteRune(';') -} - -func (p *Policy) marshalScope(buf *bytes.Buffer) { - _, principalAll := p.Principal.(ScopeTypeAll) - _, actionAll := p.Action.(ScopeTypeAll) - _, resourceAll := p.Resource.(ScopeTypeAll) - if principalAll && actionAll && resourceAll { - buf.WriteString("( principal, action, resource )") - return - } - - buf.WriteString("(\n ") - p.Principal.MarshalCedar(buf) - buf.WriteString(",\n ") - p.Action.MarshalCedar(buf) - buf.WriteString(",\n ") - p.Resource.MarshalCedar(buf) - buf.WriteString("\n)") -} - -func (n AnnotationType) MarshalCedar(buf *bytes.Buffer) { - buf.WriteRune('@') - buf.WriteString(string(n.Key)) - buf.WriteRune('(') - buf.WriteString(n.Value.Cedar()) - buf.WriteString(")") -} - -func (e Effect) MarshalCedar(buf *bytes.Buffer) { - if e == EffectPermit { - buf.WriteString("permit") - } else { - buf.WriteString("forbid") - } -} - -func (n NodeTypeVariable) marshalCedar(buf *bytes.Buffer) { - buf.WriteString(string(n.Name)) -} - -func (n ScopeTypeAll) MarshalCedar(buf *bytes.Buffer) { - n.Variable.marshalCedar(buf) -} - -func (n ScopeTypeEq) MarshalCedar(buf *bytes.Buffer) { - n.Variable.marshalCedar(buf) - buf.WriteString(" == ") - buf.WriteString(n.Entity.Cedar()) -} - -func (n ScopeTypeIn) MarshalCedar(buf *bytes.Buffer) { - n.Variable.marshalCedar(buf) - buf.WriteString(" in ") - buf.WriteString(n.Entity.Cedar()) -} - -func (n ScopeTypeInSet) MarshalCedar(buf *bytes.Buffer) { - n.Variable.marshalCedar(buf) - buf.WriteString(" in ") - buf.WriteRune('[') - for i := range n.Entities { - buf.WriteString(n.Entities[i].Cedar()) - if i != len(n.Entities)-1 { - buf.WriteString(", ") - } - } - buf.WriteRune(']') -} - -func (n ScopeTypeIs) MarshalCedar(buf *bytes.Buffer) { - n.Variable.marshalCedar(buf) - buf.WriteString(" is ") - buf.WriteString(string(n.Type)) -} - -func (n ScopeTypeIsIn) MarshalCedar(buf *bytes.Buffer) { - n.Variable.marshalCedar(buf) - buf.WriteString(" is ") - buf.WriteString(string(n.Type)) - buf.WriteString(" in ") - buf.WriteString(n.Entity.Cedar()) -} - -func (c ConditionType) MarshalCedar(buf *bytes.Buffer) { - if c.Condition == ConditionWhen { - buf.WriteString("when") - } else { - buf.WriteString("unless") - } - - buf.WriteString(" { ") - c.Body.marshalCedar(buf) - buf.WriteString(" }") -} - -func (n NodeValue) marshalCedar(buf *bytes.Buffer) { - buf.WriteString(n.Value.Cedar()) -} - -func marshalChildNode(thisNodePrecedence nodePrecedenceLevel, childNode IsNode, buf *bytes.Buffer) { - if thisNodePrecedence > childNode.precedenceLevel() { - buf.WriteRune('(') - childNode.marshalCedar(buf) - buf.WriteRune(')') - } else { - childNode.marshalCedar(buf) - } -} - -func (n NodeTypeNot) marshalCedar(buf *bytes.Buffer) { - buf.WriteRune('!') - marshalChildNode(n.precedenceLevel(), n.Arg, buf) -} - -func (n NodeTypeNegate) marshalCedar(buf *bytes.Buffer) { - buf.WriteRune('-') - marshalChildNode(n.precedenceLevel(), n.Arg, buf) -} - -func canMarshalAsIdent(s string) bool { - for i, r := range s { - if !isIdentRune(r, i == 0) { - return false - } - } - return true -} - -func (n NodeTypeAccess) marshalCedar(buf *bytes.Buffer) { - marshalChildNode(n.precedenceLevel(), n.Arg, buf) - - if canMarshalAsIdent(string(n.Value)) { - buf.WriteRune('.') - buf.WriteString(string(n.Value)) - } else { - buf.WriteRune('[') - buf.WriteString(n.Value.Cedar()) - buf.WriteRune(']') - } -} - -func (n NodeTypeExtensionCall) marshalCedar(buf *bytes.Buffer) { - var args []IsNode - info := ExtMap[n.Name] - if info.IsMethod { - marshalChildNode(n.precedenceLevel(), n.Args[0], buf) - buf.WriteRune('.') - args = n.Args[1:] - } else { - args = n.Args - } - buf.WriteString(string(n.Name)) - buf.WriteRune('(') - for i := range args { - marshalChildNode(n.precedenceLevel(), n.Args[i], buf) - if i != len(n.Args)-1 { - buf.WriteString(", ") - } - } - buf.WriteRune(')') -} - -func (n NodeTypeContains) marshalCedar(buf *bytes.Buffer) { - marshalChildNode(n.precedenceLevel(), n.Left, buf) - buf.WriteString(".contains(") - marshalChildNode(n.precedenceLevel(), n.Right, buf) - buf.WriteRune(')') -} - -func (n NodeTypeContainsAll) marshalCedar(buf *bytes.Buffer) { - marshalChildNode(n.precedenceLevel(), n.Left, buf) - buf.WriteString(".containsAll(") - marshalChildNode(n.precedenceLevel(), n.Right, buf) - buf.WriteRune(')') -} - -func (n NodeTypeContainsAny) marshalCedar(buf *bytes.Buffer) { - marshalChildNode(n.precedenceLevel(), n.Left, buf) - buf.WriteString(".containsAny(") - marshalChildNode(n.precedenceLevel(), n.Right, buf) - buf.WriteRune(')') -} - -func (n NodeTypeSet) marshalCedar(buf *bytes.Buffer) { - buf.WriteRune('[') - for i := range n.Elements { - marshalChildNode(n.precedenceLevel(), n.Elements[i], buf) - if i != len(n.Elements)-1 { - buf.WriteString(", ") - } - } - buf.WriteRune(']') -} - -func (n NodeTypeRecord) marshalCedar(buf *bytes.Buffer) { - buf.WriteRune('{') - for i := range n.Elements { - buf.WriteString(n.Elements[i].Key.Cedar()) - buf.WriteString(":") - marshalChildNode(n.precedenceLevel(), n.Elements[i].Value, buf) - if i != len(n.Elements)-1 { - buf.WriteString(", ") - } - } - buf.WriteRune('}') -} - -func marshalInfixBinaryOp(n BinaryNode, precedence nodePrecedenceLevel, op string, buf *bytes.Buffer) { - marshalChildNode(precedence, n.Left, buf) - buf.WriteRune(' ') - buf.WriteString(op) - buf.WriteRune(' ') - marshalChildNode(precedence, n.Right, buf) -} - -func (n NodeTypeMult) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "*", buf) -} - -func (n NodeTypeAdd) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "+", buf) -} - -func (n NodeTypeSub) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "-", buf) -} - -func (n NodeTypeLessThan) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "<", buf) -} - -func (n NodeTypeLessThanOrEqual) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "<=", buf) -} - -func (n NodeTypeGreaterThan) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), ">", buf) -} - -func (n NodeTypeGreaterThanOrEqual) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), ">=", buf) -} - -func (n NodeTypeEquals) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "==", buf) -} - -func (n NodeTypeNotEquals) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "!=", buf) -} - -func (n NodeTypeIn) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "in", buf) -} - -func (n NodeTypeAnd) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "&&", buf) -} - -func (n NodeTypeOr) marshalCedar(buf *bytes.Buffer) { - marshalInfixBinaryOp(n.BinaryNode, n.precedenceLevel(), "||", buf) -} - -func (n NodeTypeHas) marshalCedar(buf *bytes.Buffer) { - marshalChildNode(n.precedenceLevel(), n.Arg, buf) - buf.WriteString(" has ") - if canMarshalAsIdent(string(n.Value)) { - buf.WriteString(string(n.Value)) - } else { - buf.WriteString(n.Value.Cedar()) - } -} - -func (n NodeTypeIs) marshalCedar(buf *bytes.Buffer) { - marshalChildNode(n.precedenceLevel(), n.Left, buf) - buf.WriteString(" is ") - buf.WriteString(string(n.EntityType)) -} - -func (n NodeTypeIsIn) marshalCedar(buf *bytes.Buffer) { - marshalChildNode(n.precedenceLevel(), n.Left, buf) - buf.WriteString(" is ") - buf.WriteString(string(n.EntityType)) - buf.WriteString(" in ") - n.Entity.marshalCedar(buf) -} - -func (n NodeTypeLike) marshalCedar(buf *bytes.Buffer) { - marshalChildNode(n.precedenceLevel(), n.Arg, buf) - buf.WriteString(" like ") - buf.WriteString(n.Value.Cedar()) -} - -func (n NodeTypeIf) marshalCedar(buf *bytes.Buffer) { - buf.WriteString("if ") - marshalChildNode(n.precedenceLevel(), n.If, buf) - buf.WriteString(" then ") - marshalChildNode(n.precedenceLevel(), n.Then, buf) - buf.WriteString(" else ") - marshalChildNode(n.precedenceLevel(), n.Else, buf) -} diff --git a/internal/ast/node.go b/internal/ast/node.go index 8c52b190..7845c4ff 100644 --- a/internal/ast/node.go +++ b/internal/ast/node.go @@ -1,8 +1,6 @@ package ast import ( - "bytes" - "github.com/cedar-policy/cedar-go/types" ) @@ -31,82 +29,42 @@ type BinaryNode struct { func (n BinaryNode) isNode() {} -type nodePrecedenceLevel uint8 - -const ( - ifPrecedence nodePrecedenceLevel = 0 - orPrecedence nodePrecedenceLevel = 1 - andPrecedence nodePrecedenceLevel = 2 - relationPrecedence nodePrecedenceLevel = 3 - addPrecedence nodePrecedenceLevel = 4 - multPrecedence nodePrecedenceLevel = 5 - unaryPrecedence nodePrecedenceLevel = 6 - accessPrecedence nodePrecedenceLevel = 7 - primaryPrecedence nodePrecedenceLevel = 8 -) - type NodeTypeIf struct { If, Then, Else IsNode } -func (n NodeTypeIf) precedenceLevel() nodePrecedenceLevel { - return ifPrecedence -} - func (n NodeTypeIf) isNode() {} type NodeTypeOr struct{ BinaryNode } -func (n NodeTypeOr) precedenceLevel() nodePrecedenceLevel { - return orPrecedence -} - type NodeTypeAnd struct { BinaryNode } -func (n NodeTypeAnd) precedenceLevel() nodePrecedenceLevel { - return andPrecedence -} - -type RelationNode struct{} - -func (n RelationNode) precedenceLevel() nodePrecedenceLevel { - return relationPrecedence -} - type NodeTypeLessThan struct { BinaryNode - RelationNode } type NodeTypeLessThanOrEqual struct { BinaryNode - RelationNode } type NodeTypeGreaterThan struct { BinaryNode - RelationNode } type NodeTypeGreaterThanOrEqual struct { BinaryNode - RelationNode } type NodeTypeNotEquals struct { BinaryNode - RelationNode } type NodeTypeEquals struct { BinaryNode - RelationNode } type NodeTypeIn struct { BinaryNode - RelationNode } type NodeTypeHas struct { StrOpNode - RelationNode } type NodeTypeLike struct { @@ -114,9 +72,6 @@ type NodeTypeLike struct { Value types.Pattern } -func (n NodeTypeLike) precedenceLevel() nodePrecedenceLevel { - return relationPrecedence -} func (n NodeTypeLike) isNode() {} type NodeTypeIs struct { @@ -124,9 +79,6 @@ type NodeTypeIs struct { EntityType types.Path } -func (n NodeTypeIs) precedenceLevel() nodePrecedenceLevel { - return relationPrecedence -} func (n NodeTypeIs) isNode() {} type NodeTypeIsIn struct { @@ -134,16 +86,8 @@ type NodeTypeIsIn struct { Entity IsNode } -func (n NodeTypeIsIn) precedenceLevel() nodePrecedenceLevel { - return relationPrecedence -} - type AddNode struct{} -func (n AddNode) precedenceLevel() nodePrecedenceLevel { - return addPrecedence -} - type NodeTypeSub struct { BinaryNode AddNode @@ -156,18 +100,10 @@ type NodeTypeAdd struct { type NodeTypeMult struct{ BinaryNode } -func (n NodeTypeMult) precedenceLevel() nodePrecedenceLevel { - return multPrecedence -} - type UnaryNode struct { Arg IsNode } -func (n UnaryNode) precedenceLevel() nodePrecedenceLevel { - return unaryPrecedence -} - func (n UnaryNode) isNode() {} type NodeTypeNegate struct{ UnaryNode } @@ -175,18 +111,11 @@ type NodeTypeNot struct{ UnaryNode } type NodeTypeAccess struct{ StrOpNode } -func (n NodeTypeAccess) precedenceLevel() nodePrecedenceLevel { - return accessPrecedence -} - type NodeTypeExtensionCall struct { Name types.String // TODO: review type Args []IsNode } -func (n NodeTypeExtensionCall) precedenceLevel() nodePrecedenceLevel { - return accessPrecedence -} func (n NodeTypeExtensionCall) isNode() {} func stripNodes(args []Node) []IsNode { @@ -204,7 +133,7 @@ func NewExtensionCall(method types.String, args ...Node) Node { }) } -func newMethodCall(lhs Node, method types.String, args ...Node) Node { +func NewMethodCall(lhs Node, method types.String, args ...Node) Node { res := make([]IsNode, 1+len(args)) res[0] = lhs.v for i, v := range args { @@ -216,35 +145,17 @@ func newMethodCall(lhs Node, method types.String, args ...Node) Node { }) } -type ContainsNode struct{} - -func (n ContainsNode) precedenceLevel() nodePrecedenceLevel { - return accessPrecedence -} - type NodeTypeContains struct { BinaryNode - ContainsNode } type NodeTypeContainsAll struct { BinaryNode - ContainsNode } type NodeTypeContainsAny struct { BinaryNode - ContainsNode -} - -type PrimaryNode struct{} - -func (n PrimaryNode) isNode() {} - -func (n PrimaryNode) precedenceLevel() nodePrecedenceLevel { - return primaryPrecedence } type NodeValue struct { - PrimaryNode Value types.Value } @@ -256,24 +167,23 @@ type RecordElementNode struct { } type NodeTypeRecord struct { - PrimaryNode Elements []RecordElementNode } func (n NodeTypeRecord) isNode() {} type NodeTypeSet struct { - PrimaryNode Elements []IsNode } +func (n NodeTypeSet) isNode() {} + type NodeTypeVariable struct { - PrimaryNode Name types.String // TODO: Review type } +func (n NodeTypeVariable) isNode() {} + type IsNode interface { isNode() - marshalCedar(*bytes.Buffer) - precedenceLevel() nodePrecedenceLevel } diff --git a/internal/ast/operator.go b/internal/ast/operator.go index 20721193..0bc18c05 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -34,19 +34,19 @@ func (lhs Node) GreaterThanOrEqual(rhs Node) Node { } func (lhs Node) LessThanExt(rhs Node) Node { - return newMethodCall(lhs, "lessThan", rhs) + return NewMethodCall(lhs, "lessThan", rhs) } func (lhs Node) LessThanOrEqualExt(rhs Node) Node { - return newMethodCall(lhs, "lessThanOrEqual", rhs) + return NewMethodCall(lhs, "lessThanOrEqual", rhs) } func (lhs Node) GreaterThanExt(rhs Node) Node { - return newMethodCall(lhs, "greaterThan", rhs) + return NewMethodCall(lhs, "greaterThan", rhs) } func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { - return newMethodCall(lhs, "greaterThanOrEqual", rhs) + return NewMethodCall(lhs, "greaterThanOrEqual", rhs) } func (lhs Node) Like(pattern types.Pattern) Node { @@ -144,21 +144,21 @@ func (lhs Node) Has(attr string) Node { // |___|_| /_/ \_\__,_|\__,_|_| \___||___/___/ func (lhs Node) IsIpv4() Node { - return newMethodCall(lhs, "isIpv4") + return NewMethodCall(lhs, "isIpv4") } func (lhs Node) IsIpv6() Node { - return newMethodCall(lhs, "isIpv6") + return NewMethodCall(lhs, "isIpv6") } func (lhs Node) IsMulticast() Node { - return newMethodCall(lhs, "isMulticast") + return NewMethodCall(lhs, "isMulticast") } func (lhs Node) IsLoopback() Node { - return newMethodCall(lhs, "isLoopback") + return NewMethodCall(lhs, "isLoopback") } func (lhs Node) IsInRange(rhs Node) Node { - return newMethodCall(lhs, "isInRange", rhs) + return NewMethodCall(lhs, "isInRange", rhs) } diff --git a/internal/ast/policy.go b/internal/ast/policy.go index 24b92cef..419849e7 100644 --- a/internal/ast/policy.go +++ b/internal/ast/policy.go @@ -1,24 +1,8 @@ package ast -import "github.com/cedar-policy/cedar-go/types" - -type PolicySet map[string]PolicySetEntry - -type PolicySetEntry struct { - Policy Policy - Position Position -} - -func (p PolicySetEntry) TmpGetAnnotations() map[string]string { - res := make(map[string]string, len(p.Policy.Annotations)) - for _, e := range p.Policy.Annotations { - res[string(e.Key)] = string(e.Value) - } - return res -} -func (p PolicySetEntry) TmpGetEffect() bool { - return bool(p.Policy.Effect) -} +import ( + "github.com/cedar-policy/cedar-go/types" +) type AnnotationType struct { Key types.String // TODO: review type diff --git a/internal/ast/scope.go b/internal/ast/scope.go index cad280d7..2da1ab79 100644 --- a/internal/ast/scope.go +++ b/internal/ast/scope.go @@ -1,8 +1,6 @@ package ast import ( - "bytes" - "github.com/cedar-policy/cedar-go/types" ) @@ -89,7 +87,6 @@ func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Po type IsScopeNode interface { isScope() - MarshalCedar(*bytes.Buffer) } type ScopeNode struct { diff --git a/internal/eval/eval_compile.go b/internal/eval/eval_compile.go index 0cb65032..b2406501 100644 --- a/internal/eval/eval_compile.go +++ b/internal/eval/eval_compile.go @@ -1,12 +1,8 @@ package eval -import "github.com/cedar-policy/cedar-go/internal/ast" - -type CompiledPolicySet map[string]CompiledPolicy - -type CompiledPolicy struct { - ast.PolicySetEntry -} +import ( + "github.com/cedar-policy/cedar-go/internal/ast" +) func Compile(p ast.Policy) Evaler { node := policyToNode(p).AsIsNode() diff --git a/internal/ast/cedar_fuzz_test.go b/internal/parser/cedar_fuzz_test.go similarity index 99% rename from internal/ast/cedar_fuzz_test.go rename to internal/parser/cedar_fuzz_test.go index bc6bb523..4088c275 100644 --- a/internal/ast/cedar_fuzz_test.go +++ b/internal/parser/cedar_fuzz_test.go @@ -1,4 +1,4 @@ -package ast +package parser import ( "testing" diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go new file mode 100644 index 00000000..11aaa051 --- /dev/null +++ b/internal/parser/cedar_marshal.go @@ -0,0 +1,385 @@ +package parser + +import ( + "bytes" + "fmt" + + "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/types" +) + +func (p *Policy) MarshalCedar(buf *bytes.Buffer) { + for _, a := range p.Policy.Annotations { + marshalAnnotation(a, buf) + buf.WriteRune('\n') + } + marshalEffect(p.Policy.Effect, buf) + buf.WriteRune(' ') + p.marshalScope(buf) + + for _, c := range p.Policy.Conditions { + buf.WriteRune('\n') + marshalCondition(c, buf) + } + + buf.WriteRune(';') +} + +// scopeToNode is copied in from eval, with the expectation that +// eval will not be using it in the future. +func scopeToNode(in ast.IsScopeNode) ast.Node { + switch t := in.(type) { + case ast.ScopeTypeAll: + return ast.True() + case ast.ScopeTypeEq: + return ast.NewNode(t.Variable).Equals(ast.EntityUID(t.Entity)) + case ast.ScopeTypeIn: + return ast.NewNode(t.Variable).In(ast.EntityUID(t.Entity)) + case ast.ScopeTypeInSet: + set := make([]types.Value, len(t.Entities)) + for i, e := range t.Entities { + set[i] = e + } + return ast.NewNode(t.Variable).In(ast.Set(set)) + case ast.ScopeTypeIs: + return ast.NewNode(t.Variable).Is(t.Type) + + case ast.ScopeTypeIsIn: + return ast.NewNode(t.Variable).IsIn(t.Type, ast.EntityUID(t.Entity)) + default: + panic(fmt.Sprintf("unknown scope type %T", t)) + } +} + +func (p *Policy) marshalScope(buf *bytes.Buffer) { + _, principalAll := p.Policy.Principal.(ast.ScopeTypeAll) + _, actionAll := p.Policy.Action.(ast.ScopeTypeAll) + _, resourceAll := p.Policy.Resource.(ast.ScopeTypeAll) + if principalAll && actionAll && resourceAll { + buf.WriteString("( principal, action, resource )") + return + } + + buf.WriteString("(\n ") + if principalAll { + buf.WriteString("principal") + } else { + astNodeToMarshalNode(scopeToNode(p.Policy.Principal).AsIsNode()).marshalCedar(buf) + } + buf.WriteString(",\n ") + if actionAll { + buf.WriteString("action") + } else { + astNodeToMarshalNode(scopeToNode(p.Policy.Action).AsIsNode()).marshalCedar(buf) + } + buf.WriteString(",\n ") + if resourceAll { + buf.WriteString("resource") + } else { + astNodeToMarshalNode(scopeToNode(p.Policy.Resource).AsIsNode()).marshalCedar(buf) + } + buf.WriteString("\n)") +} + +func marshalAnnotation(n ast.AnnotationType, buf *bytes.Buffer) { + buf.WriteRune('@') + buf.WriteString(string(n.Key)) + buf.WriteRune('(') + buf.WriteString(n.Value.Cedar()) + buf.WriteString(")") +} + +func marshalEffect(e ast.Effect, buf *bytes.Buffer) { + if e == ast.EffectPermit { + buf.WriteString("permit") + } else { + buf.WriteString("forbid") + } +} + +func (n NodeTypeVariable) marshalCedar(buf *bytes.Buffer) { + buf.WriteString(string(n.NodeTypeVariable.Name)) +} + +func marshalCondition(c ast.ConditionType, buf *bytes.Buffer) { + if c.Condition == ast.ConditionWhen { + buf.WriteString("when") + } else { + buf.WriteString("unless") + } + + buf.WriteString(" { ") + astNodeToMarshalNode(c.Body).marshalCedar(buf) + buf.WriteString(" }") +} + +func (n NodeValue) marshalCedar(buf *bytes.Buffer) { + buf.WriteString(n.NodeValue.Value.Cedar()) +} + +func marshalChildNode(thisNodePrecedence nodePrecedenceLevel, childAstNode ast.IsNode, buf *bytes.Buffer) { + childNode := astNodeToMarshalNode(childAstNode) + if thisNodePrecedence > childNode.precedenceLevel() { + buf.WriteRune('(') + childNode.marshalCedar(buf) + buf.WriteRune(')') + } else { + childNode.marshalCedar(buf) + } +} + +func (n NodeTypeNot) marshalCedar(buf *bytes.Buffer) { + buf.WriteRune('!') + marshalChildNode(n.precedenceLevel(), n.NodeTypeNot.Arg, buf) +} + +func (n NodeTypeNegate) marshalCedar(buf *bytes.Buffer) { + buf.WriteRune('-') + marshalChildNode(n.precedenceLevel(), n.NodeTypeNegate.Arg, buf) +} + +func canMarshalAsIdent(s string) bool { + for i, r := range s { + if !isIdentRune(r, i == 0) { + return false + } + } + return true +} + +func (n NodeTypeAccess) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.NodeTypeAccess.Arg, buf) + + if canMarshalAsIdent(string(n.NodeTypeAccess.Value)) { + buf.WriteRune('.') + buf.WriteString(string(n.NodeTypeAccess.Value)) + } else { + buf.WriteRune('[') + buf.WriteString(n.NodeTypeAccess.Value.Cedar()) + buf.WriteRune(']') + } +} + +func (n NodeTypeExtensionCall) marshalCedar(buf *bytes.Buffer) { + var args []ast.IsNode + info := ast.ExtMap[n.NodeTypeExtensionCall.Name] + if info.IsMethod { + marshalChildNode(n.precedenceLevel(), n.NodeTypeExtensionCall.Args[0], buf) + buf.WriteRune('.') + args = n.NodeTypeExtensionCall.Args[1:] + } else { + args = n.NodeTypeExtensionCall.Args + } + buf.WriteString(string(n.NodeTypeExtensionCall.Name)) + buf.WriteRune('(') + for i := range args { + marshalChildNode(n.precedenceLevel(), n.NodeTypeExtensionCall.Args[i], buf) + if i != len(n.NodeTypeExtensionCall.Args)-1 { + buf.WriteString(", ") + } + } + buf.WriteRune(')') +} + +func (n NodeTypeContains) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.NodeTypeContains.Left, buf) + buf.WriteString(".contains(") + marshalChildNode(n.precedenceLevel(), n.NodeTypeContains.Right, buf) + buf.WriteRune(')') +} + +func (n NodeTypeContainsAll) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.NodeTypeContainsAll.Left, buf) + buf.WriteString(".containsAll(") + marshalChildNode(n.precedenceLevel(), n.NodeTypeContainsAll.Right, buf) + buf.WriteRune(')') +} + +func (n NodeTypeContainsAny) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.NodeTypeContainsAny.Left, buf) + buf.WriteString(".containsAny(") + marshalChildNode(n.precedenceLevel(), n.NodeTypeContainsAny.Right, buf) + buf.WriteRune(')') +} + +func (n NodeTypeSet) marshalCedar(buf *bytes.Buffer) { + buf.WriteRune('[') + for i := range n.NodeTypeSet.Elements { + marshalChildNode(n.precedenceLevel(), n.NodeTypeSet.Elements[i], buf) + if i != len(n.NodeTypeSet.Elements)-1 { + buf.WriteString(", ") + } + } + buf.WriteRune(']') +} + +func (n NodeTypeRecord) marshalCedar(buf *bytes.Buffer) { + buf.WriteRune('{') + for i := range n.NodeTypeRecord.Elements { + buf.WriteString(n.NodeTypeRecord.Elements[i].Key.Cedar()) + buf.WriteString(":") + marshalChildNode(n.precedenceLevel(), n.NodeTypeRecord.Elements[i].Value, buf) + if i != len(n.NodeTypeRecord.Elements)-1 { + buf.WriteString(", ") + } + } + buf.WriteRune('}') +} + +func marshalInfixBinaryOp(n ast.BinaryNode, precedence nodePrecedenceLevel, op string, buf *bytes.Buffer) { + marshalChildNode(precedence, n.Left, buf) + buf.WriteRune(' ') + buf.WriteString(op) + buf.WriteRune(' ') + marshalChildNode(precedence, n.Right, buf) +} + +func (n NodeTypeMult) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeMult.BinaryNode, n.precedenceLevel(), "*", buf) +} + +func (n NodeTypeAdd) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeAdd.BinaryNode, n.precedenceLevel(), "+", buf) +} + +func (n NodeTypeSub) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeSub.BinaryNode, n.precedenceLevel(), "-", buf) +} + +func (n NodeTypeLessThan) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeLessThan.BinaryNode, n.precedenceLevel(), "<", buf) +} + +func (n NodeTypeLessThanOrEqual) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeLessThanOrEqual.BinaryNode, n.precedenceLevel(), "<=", buf) +} + +func (n NodeTypeGreaterThan) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeGreaterThan.BinaryNode, n.precedenceLevel(), ">", buf) +} + +func (n NodeTypeGreaterThanOrEqual) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeGreaterThanOrEqual.BinaryNode, n.precedenceLevel(), ">=", buf) +} + +func (n NodeTypeEquals) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeEquals.BinaryNode, n.precedenceLevel(), "==", buf) +} + +func (n NodeTypeNotEquals) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeNotEquals.BinaryNode, n.precedenceLevel(), "!=", buf) +} + +func (n NodeTypeIn) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeIn.BinaryNode, n.precedenceLevel(), "in", buf) +} + +func (n NodeTypeAnd) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeAnd.BinaryNode, n.precedenceLevel(), "&&", buf) +} + +func (n NodeTypeOr) marshalCedar(buf *bytes.Buffer) { + marshalInfixBinaryOp(n.NodeTypeOr.BinaryNode, n.precedenceLevel(), "||", buf) +} + +func (n NodeTypeHas) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.NodeTypeHas.Arg, buf) + buf.WriteString(" has ") + if canMarshalAsIdent(string(n.NodeTypeHas.Value)) { + buf.WriteString(string(n.NodeTypeHas.Value)) + } else { + buf.WriteString(n.NodeTypeHas.Value.Cedar()) + } +} + +func (n NodeTypeIs) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.NodeTypeIs.Left, buf) + buf.WriteString(" is ") + buf.WriteString(string(n.NodeTypeIs.EntityType)) +} + +func (n NodeTypeIsIn) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.NodeTypeIsIn.Left, buf) + buf.WriteString(" is ") + buf.WriteString(string(n.NodeTypeIsIn.EntityType)) + buf.WriteString(" in ") + marshalChildNode(n.precedenceLevel(), n.NodeTypeIsIn.Entity, buf) +} + +func (n NodeTypeLike) marshalCedar(buf *bytes.Buffer) { + marshalChildNode(n.precedenceLevel(), n.NodeTypeLike.Arg, buf) + buf.WriteString(" like ") + buf.WriteString(n.NodeTypeLike.Value.Cedar()) +} + +func (n NodeTypeIf) marshalCedar(buf *bytes.Buffer) { + buf.WriteString("if ") + marshalChildNode(n.precedenceLevel(), n.NodeTypeIf.If, buf) + buf.WriteString(" then ") + marshalChildNode(n.precedenceLevel(), n.NodeTypeIf.Then, buf) + buf.WriteString(" else ") + marshalChildNode(n.precedenceLevel(), n.NodeTypeIf.Else, buf) +} + +func astNodeToMarshalNode(astNode ast.IsNode) IsNode { + switch v := astNode.(type) { + case ast.NodeTypeIf: + return NodeTypeIf{v} + case ast.NodeTypeOr: + return NodeTypeOr{v} + case ast.NodeTypeAnd: + return NodeTypeAnd{v} + case ast.NodeTypeLessThan: + return NodeTypeLessThan{v, RelationNode{}} + case ast.NodeTypeLessThanOrEqual: + return NodeTypeLessThanOrEqual{v, RelationNode{}} + case ast.NodeTypeGreaterThan: + return NodeTypeGreaterThan{v, RelationNode{}} + case ast.NodeTypeGreaterThanOrEqual: + return NodeTypeGreaterThanOrEqual{v, RelationNode{}} + case ast.NodeTypeNotEquals: + return NodeTypeNotEquals{v, RelationNode{}} + case ast.NodeTypeEquals: + return NodeTypeEquals{v, RelationNode{}} + case ast.NodeTypeIn: + return NodeTypeIn{v, RelationNode{}} + case ast.NodeTypeHas: + return NodeTypeHas{v, RelationNode{}} + case ast.NodeTypeLike: + return NodeTypeLike{v, RelationNode{}} + case ast.NodeTypeIs: + return NodeTypeIs{v, RelationNode{}} + case ast.NodeTypeIsIn: + return NodeTypeIsIn{v, RelationNode{}} + case ast.NodeTypeSub: + return NodeTypeSub{v, AddNode{}} + case ast.NodeTypeAdd: + return NodeTypeAdd{v, AddNode{}} + case ast.NodeTypeMult: + return NodeTypeMult{v} + case ast.NodeTypeNegate: + return NodeTypeNegate{v, UnaryNode{}} + case ast.NodeTypeNot: + return NodeTypeNot{v, UnaryNode{}} + case ast.NodeTypeAccess: + return NodeTypeAccess{v} + case ast.NodeTypeExtensionCall: + return NodeTypeExtensionCall{v} + case ast.NodeTypeContains: + return NodeTypeContains{v, ContainsNode{}} + case ast.NodeTypeContainsAll: + return NodeTypeContainsAll{v, ContainsNode{}} + case ast.NodeTypeContainsAny: + return NodeTypeContainsAny{v, ContainsNode{}} + case ast.NodeValue: + return NodeValue{v, PrimaryNode{}} + case ast.NodeTypeRecord: + return NodeTypeRecord{v, PrimaryNode{}} + case ast.NodeTypeSet: + return NodeTypeSet{v, PrimaryNode{}} + case ast.NodeTypeVariable: + return NodeTypeVariable{v, PrimaryNode{}} + default: + panic(fmt.Sprintf("unknown node type %T", v)) + } +} diff --git a/internal/ast/cedar_parse_test.go b/internal/parser/cedar_parse_test.go similarity index 96% rename from internal/ast/cedar_parse_test.go rename to internal/parser/cedar_parse_test.go index 4339abe2..5e1f1277 100644 --- a/internal/ast/cedar_parse_test.go +++ b/internal/parser/cedar_parse_test.go @@ -1,10 +1,10 @@ -package ast_test +package parser_test import ( "bytes" "testing" - "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/parser" "github.com/cedar-policy/cedar-go/internal/testutil" ) @@ -295,7 +295,7 @@ func TestParse(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - var policies ast.PolicySet + var policies parser.PolicySet err := policies.UnmarshalCedar([]byte(tt.in)) testutil.Equals(t, err != nil, tt.err) if err != nil { @@ -310,7 +310,7 @@ func TestParse(t *testing.T) { pp := policies["policy0"].Policy pp.MarshalCedar(&buf) - var p2 ast.PolicySet + var p2 parser.PolicySet err = p2.UnmarshalCedar(buf.Bytes()) testutil.OK(t, err) @@ -335,11 +335,11 @@ permit( principal, action, resource ); @test("1234") permit (principal, action, resource ); ` - var out ast.PolicySet + var out parser.PolicySet err := out.UnmarshalCedar([]byte(in)) testutil.OK(t, err) testutil.Equals(t, len(out), 3) - testutil.Equals(t, out["policy0"].Position, ast.Position{Offset: 17, Line: 2, Column: 1}) - testutil.Equals(t, out["policy1"].Position, ast.Position{Offset: 86, Line: 7, Column: 3}) - testutil.Equals(t, out["policy2"].Position, ast.Position{Offset: 148, Line: 10, Column: 2}) + testutil.Equals(t, out["policy0"].Position, parser.Position{Offset: 17, Line: 2, Column: 1}) + testutil.Equals(t, out["policy1"].Position, parser.Position{Offset: 86, Line: 7, Column: 3}) + testutil.Equals(t, out["policy2"].Position, parser.Position{Offset: 148, Line: 10, Column: 2}) } diff --git a/internal/ast/cedar_tokenize.go b/internal/parser/cedar_tokenize.go similarity index 99% rename from internal/ast/cedar_tokenize.go rename to internal/parser/cedar_tokenize.go index 5c9b4cc5..3809047b 100644 --- a/internal/ast/cedar_tokenize.go +++ b/internal/parser/cedar_tokenize.go @@ -1,4 +1,4 @@ -package ast +package parser import ( "bytes" diff --git a/internal/ast/cedar_tokenize_mocks_test.go b/internal/parser/cedar_tokenize_mocks_test.go similarity index 99% rename from internal/ast/cedar_tokenize_mocks_test.go rename to internal/parser/cedar_tokenize_mocks_test.go index 21d98b9e..ff5a98fc 100644 --- a/internal/ast/cedar_tokenize_mocks_test.go +++ b/internal/parser/cedar_tokenize_mocks_test.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package ast +package parser import ( "sync" diff --git a/internal/ast/cedar_tokenize_test.go b/internal/parser/cedar_tokenize_test.go similarity index 99% rename from internal/ast/cedar_tokenize_test.go rename to internal/parser/cedar_tokenize_test.go index 5606a642..1fae624f 100644 --- a/internal/ast/cedar_tokenize_test.go +++ b/internal/parser/cedar_tokenize_test.go @@ -1,4 +1,4 @@ -package ast +package parser import ( "fmt" diff --git a/internal/ast/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go similarity index 77% rename from internal/ast/cedar_unmarshal.go rename to internal/parser/cedar_unmarshal.go index f05e080e..2bf619dd 100644 --- a/internal/ast/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -1,10 +1,11 @@ -package ast +package parser import ( "fmt" "strconv" "strings" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/types" ) @@ -21,9 +22,11 @@ func (p *PolicySet) UnmarshalCedar(b []byte) error { for !parser.peek().isEOF() { pos := parser.peek().Pos policy := Policy{ - Principal: ScopeTypeAll{}, - Action: ScopeTypeAll{}, - Resource: ScopeTypeAll{}, + ast.Policy{ + Principal: ast.ScopeTypeAll{}, + Action: ast.ScopeTypeAll{}, + Resource: ast.ScopeTypeAll{}, + }, } if err = policy.fromCedarWithParser(&parser); err != nil { @@ -88,7 +91,7 @@ func (p *Policy) fromCedarWithParser(parser *parser) error { return err } - *p = *newPolicy + *p = Policy{*newPolicy} return nil } @@ -130,8 +133,8 @@ func (p *parser) errorf(s string, args ...interface{}) error { return fmt.Errorf("parse error at %v %q: %w", t.Pos, t.Text, err) } -func (p *parser) annotations() (Annotations, error) { - var res Annotations +func (p *parser) annotations() (ast.Annotations, error) { + var res ast.Annotations known := map[types.String]struct{}{} for p.peek().Text == "@" { p.advance() @@ -144,7 +147,7 @@ func (p *parser) annotations() (Annotations, error) { } -func (p *parser) annotation(a *Annotations, known map[types.String]struct{}) error { +func (p *parser) annotation(a *ast.Annotations, known map[types.String]struct{}) error { var err error t := p.advance() if !t.isIdent() { @@ -174,7 +177,7 @@ func (p *parser) annotation(a *Annotations, known map[types.String]struct{}) err return nil } -func (p *parser) effect(a *Annotations) (*Policy, error) { +func (p *parser) effect(a *ast.Annotations) (*ast.Policy, error) { next := p.advance() if next.Text == "permit" { return a.Permit(), nil @@ -185,7 +188,7 @@ func (p *parser) effect(a *Annotations) (*Policy, error) { return nil, p.errorf("unexpected effect: %v", next.Text) } -func (p *parser) principal(policy *Policy) error { +func (p *parser) principal(policy *ast.Policy) error { if err := p.exact("principal"); err != nil { return err } @@ -287,7 +290,7 @@ func (p *parser) path() (types.Path, error) { return p.pathFirstPathPreread(t.Text) } -func (p *parser) action(policy *Policy) error { +func (p *parser) action(policy *ast.Policy) error { if err := p.exact("action"); err != nil { return err } @@ -341,7 +344,7 @@ func (p *parser) entlist() ([]types.EntityUID, error) { return res, nil } -func (p *parser) resource(policy *Policy) error { +func (p *parser) resource(policy *ast.Policy) error { if err := p.exact("resource"); err != nil { return err } @@ -385,7 +388,7 @@ func (p *parser) resource(policy *Policy) error { return nil } -func (p *parser) conditions(policy *Policy) error { +func (p *parser) conditions(policy *ast.Policy) error { for { switch p.peek().Text { case "when": @@ -408,8 +411,8 @@ func (p *parser) conditions(policy *Policy) error { } } -func (p *parser) condition() (Node, error) { - var res Node +func (p *parser) condition() (ast.Node, error) { + var res ast.Node var err error if err := p.exact("{"); err != nil { return res, err @@ -423,49 +426,49 @@ func (p *parser) condition() (Node, error) { return res, nil } -func (p *parser) expression() (Node, error) { +func (p *parser) expression() (ast.Node, error) { t := p.peek() if t.Text == "if" { p.advance() condition, err := p.expression() if err != nil { - return Node{}, err + return ast.Node{}, err } if err = p.exact("then"); err != nil { - return Node{}, err + return ast.Node{}, err } ifTrue, err := p.expression() if err != nil { - return Node{}, err + return ast.Node{}, err } if err = p.exact("else"); err != nil { - return Node{}, err + return ast.Node{}, err } ifFalse, err := p.expression() if err != nil { - return Node{}, err + return ast.Node{}, err } - return If(condition, ifTrue, ifFalse), nil + return ast.If(condition, ifTrue, ifFalse), nil } return p.or() } -func (p *parser) or() (Node, error) { +func (p *parser) or() (ast.Node, error) { lhs, err := p.and() if err != nil { - return Node{}, err + return ast.Node{}, err } for p.peek().Text == "||" { p.advance() rhs, err := p.and() if err != nil { - return Node{}, err + return ast.Node{}, err } lhs = lhs.Or(rhs) } @@ -473,17 +476,17 @@ func (p *parser) or() (Node, error) { return lhs, nil } -func (p *parser) and() (Node, error) { +func (p *parser) and() (ast.Node, error) { lhs, err := p.relation() if err != nil { - return Node{}, err + return ast.Node{}, err } for p.peek().Text == "&&" { p.advance() rhs, err := p.relation() if err != nil { - return Node{}, err + return ast.Node{}, err } lhs = lhs.And(rhs) } @@ -491,10 +494,10 @@ func (p *parser) and() (Node, error) { return lhs, nil } -func (p *parser) relation() (Node, error) { +func (p *parser) relation() (ast.Node, error) { lhs, err := p.add() if err != nil { - return Node{}, err + return ast.Node{}, err } t := p.peek() @@ -511,22 +514,22 @@ func (p *parser) relation() (Node, error) { } // RELOP - var operator func(Node, Node) Node + var operator func(ast.Node, ast.Node) ast.Node switch t.Text { case "<": - operator = Node.LessThan + operator = ast.Node.LessThan case "<=": - operator = Node.LessThanOrEqual + operator = ast.Node.LessThanOrEqual case ">": - operator = Node.GreaterThan + operator = ast.Node.GreaterThan case ">=": - operator = Node.GreaterThanOrEqual + operator = ast.Node.GreaterThanOrEqual case "!=": - operator = Node.NotEquals + operator = ast.Node.NotEquals case "==": - operator = Node.Equals + operator = ast.Node.Equals case "in": - operator = Node.In + operator = ast.Node.In default: return lhs, nil @@ -535,70 +538,70 @@ func (p *parser) relation() (Node, error) { p.advance() rhs, err := p.add() if err != nil { - return Node{}, err + return ast.Node{}, err } return operator(lhs, rhs), nil } -func (p *parser) has(lhs Node) (Node, error) { +func (p *parser) has(lhs ast.Node) (ast.Node, error) { t := p.advance() if t.isIdent() { return lhs.Has(t.Text), nil } else if t.isString() { str, err := t.stringValue() if err != nil { - return Node{}, err + return ast.Node{}, err } return lhs.Has(str), nil } - return Node{}, p.errorf("expected ident or string") + return ast.Node{}, p.errorf("expected ident or string") } -func (p *parser) like(lhs Node) (Node, error) { +func (p *parser) like(lhs ast.Node) (ast.Node, error) { t := p.advance() if !t.isString() { - return Node{}, p.errorf("expected string literal") + return ast.Node{}, p.errorf("expected string literal") } patternRaw := t.Text patternRaw = strings.TrimPrefix(patternRaw, "\"") patternRaw = strings.TrimSuffix(patternRaw, "\"") pattern, err := types.ParsePattern(patternRaw) if err != nil { - return Node{}, err + return ast.Node{}, err } return lhs.Like(pattern), nil } -func (p *parser) is(lhs Node) (Node, error) { +func (p *parser) is(lhs ast.Node) (ast.Node, error) { entityType, err := p.path() if err != nil { - return Node{}, err + return ast.Node{}, err } if p.peek().Text == "in" { p.advance() inEntity, err := p.add() if err != nil { - return Node{}, err + return ast.Node{}, err } return lhs.IsIn(entityType, inEntity), nil } return lhs.Is(entityType), nil } -func (p *parser) add() (Node, error) { +func (p *parser) add() (ast.Node, error) { lhs, err := p.mult() if err != nil { - return Node{}, err + return ast.Node{}, err } for { t := p.peek() - var operator func(Node, Node) Node + var operator func(ast.Node, ast.Node) ast.Node switch t.Text { case "+": - operator = Node.Plus + operator = ast.Node.Plus case "-": - operator = Node.Minus + operator = ast.Node.Minus } if operator == nil { @@ -608,7 +611,7 @@ func (p *parser) add() (Node, error) { p.advance() rhs, err := p.mult() if err != nil { - return Node{}, err + return ast.Node{}, err } lhs = operator(lhs, rhs) } @@ -616,17 +619,17 @@ func (p *parser) add() (Node, error) { return lhs, nil } -func (p *parser) mult() (Node, error) { +func (p *parser) mult() (ast.Node, error) { lhs, err := p.unary() if err != nil { - return Node{}, err + return ast.Node{}, err } for p.peek().Text == "*" { p.advance() rhs, err := p.unary() if err != nil { - return Node{}, err + return ast.Node{}, err } lhs = lhs.Times(rhs) } @@ -634,7 +637,7 @@ func (p *parser) mult() (Node, error) { return lhs, nil } -func (p *parser) unary() (Node, error) { +func (p *parser) unary() (ast.Node, error) { var ops []bool for { opToken := p.peek() @@ -645,7 +648,7 @@ func (p *parser) unary() (Node, error) { ops = append(ops, opToken.Text == "-") } - var res Node + var res ast.Node // special case for max negative long tok := p.peek() @@ -653,9 +656,9 @@ func (p *parser) unary() (Node, error) { p.advance() i, err := strconv.ParseInt("-"+tok.Text, 10, 64) if err != nil { - return Node{}, err + return ast.Node{}, err } - res = Long(types.Long(i)) + res = ast.Long(types.Long(i)) ops = ops[:len(ops)-1] } else { var err error @@ -667,15 +670,15 @@ func (p *parser) unary() (Node, error) { for i := len(ops) - 1; i >= 0; i-- { if ops[i] { - res = Negate(res) + res = ast.Negate(res) } else { - res = Not(res) + res = ast.Not(res) } } return res, nil } -func (p *parser) member() (Node, error) { +func (p *parser) member() (ast.Node, error) { res, err := p.primary() if err != nil { return res, err @@ -689,8 +692,8 @@ func (p *parser) member() (Node, error) { } } -func (p *parser) primary() (Node, error) { - var res Node +func (p *parser) primary() (ast.Node, error) { + var res ast.Node t := p.advance() switch { case t.isInt(): @@ -698,25 +701,25 @@ func (p *parser) primary() (Node, error) { if err != nil { return res, err } - res = Long(types.Long(i)) + res = ast.Long(types.Long(i)) case t.isString(): str, err := t.stringValue() if err != nil { return res, err } - res = String(types.String(str)) + res = ast.String(types.String(str)) case t.Text == "true": - res = True() + res = ast.True() case t.Text == "false": - res = False() + res = ast.False() case t.Text == "principal": - res = Principal() + res = ast.Principal() case t.Text == "action": - res = Action() + res = ast.Action() case t.Text == "resource": - res = Resource() + res = ast.Resource() case t.Text == "context": - res = Context() + res = ast.Context() case t.isIdent(): return p.entityOrExtFun(t.Text) case t.Text == "(": @@ -734,7 +737,7 @@ func (p *parser) primary() (Node, error) { return res, err } p.advance() // expressions guarantees "]" - res = SetNodes(set...) + res = ast.SetNodes(set...) case t.Text == "{": record, err := p.record() if err != nil { @@ -747,7 +750,7 @@ func (p *parser) primary() (Node, error) { return res, nil } -func (p *parser) entityOrExtFun(ident string) (Node, error) { +func (p *parser) entityOrExtFun(ident string) (ast.Node, error) { var res types.EntityUID var err error res.Type = ident @@ -762,16 +765,16 @@ func (p *parser) entityOrExtFun(ident string) (Node, error) { case t.isString(): res.ID, err = t.stringValue() if err != nil { - return Node{}, err + return ast.Node{}, err } - return EntityUID(res), nil + return ast.EntityUID(res), nil default: - return Node{}, p.errorf("unexpected token") + return ast.Node{}, p.errorf("unexpected token") } case "(": args, err := p.expressions(")") if err != nil { - return Node{}, err + return ast.Node{}, err } p.advance() // i, ok := extMap[types.String(res.Type)] @@ -781,15 +784,15 @@ func (p *parser) entityOrExtFun(ident string) (Node, error) { // if i.IsMethod { // return Node{}, p.errorf("`%v` is a method, not a function", res.Type) // } - return ExtensionCall(types.String(res.Type), args...), nil + return ast.ExtensionCall(types.String(res.Type), args...), nil default: - return Node{}, p.errorf("unexpected token") + return ast.Node{}, p.errorf("unexpected token") } } } -func (p *parser) expressions(endOfListMarker string) ([]Node, error) { - var res []Node +func (p *parser) expressions(endOfListMarker string) ([]ast.Node, error) { + var res []ast.Node for p.peek().Text != endOfListMarker { if len(res) > 0 { if err := p.exact(","); err != nil { @@ -805,15 +808,15 @@ func (p *parser) expressions(endOfListMarker string) ([]Node, error) { return res, nil } -func (p *parser) record() (Node, error) { - var res Node - var elements []RecordElement +func (p *parser) record() (ast.Node, error) { + var res ast.Node + var elements []ast.RecordElement known := map[types.String]struct{}{} for { t := p.peek() if t.Text == "}" { p.advance() - return RecordElements(elements...), nil + return ast.RecordElements(elements...), nil } if len(elements) > 0 { if err := p.exact(","); err != nil { @@ -829,13 +832,13 @@ func (p *parser) record() (Node, error) { return res, p.errorf("duplicate key: %v", k) } known[k] = struct{}{} - elements = append(elements, RecordElement{Key: k, Value: v}) + elements = append(elements, ast.RecordElement{Key: k, Value: v}) } } -func (p *parser) recordEntry() (types.String, Node, error) { +func (p *parser) recordEntry() (types.String, ast.Node, error) { var key types.String - var value Node + var value ast.Node var err error t := p.advance() switch { @@ -860,32 +863,32 @@ func (p *parser) recordEntry() (types.String, Node, error) { return key, value, nil } -func (p *parser) access(lhs Node) (Node, bool, error) { +func (p *parser) access(lhs ast.Node) (ast.Node, bool, error) { t := p.peek() switch t.Text { case ".": p.advance() t := p.advance() if !t.isIdent() { - return Node{}, false, p.errorf("unexpected token") + return ast.Node{}, false, p.errorf("unexpected token") } if p.peek().Text == "(" { methodName := t.Text p.advance() exprs, err := p.expressions(")") if err != nil { - return Node{}, false, err + return ast.Node{}, false, err } p.advance() // expressions guarantees ")" - var knownMethod func(Node, Node) Node + var knownMethod func(ast.Node, ast.Node) ast.Node switch methodName { case "contains": - knownMethod = Node.Contains + knownMethod = ast.Node.Contains case "containsAll": - knownMethod = Node.ContainsAll + knownMethod = ast.Node.ContainsAll case "containsAny": - knownMethod = Node.ContainsAny + knownMethod = ast.Node.ContainsAny default: // i, ok := extMap[types.String(methodName)] // if !ok { @@ -894,11 +897,11 @@ func (p *parser) access(lhs Node) (Node, bool, error) { // if !i.IsMethod { // return Node{}, false, p.errorf("`%v` is a function, not a method", methodName) // } - return newMethodCall(lhs, types.String(methodName), exprs...), true, nil + return ast.NewMethodCall(lhs, types.String(methodName), exprs...), true, nil } if len(exprs) != 1 { - return Node{}, false, p.errorf("%v expects one argument", methodName) + return ast.Node{}, false, p.errorf("%v expects one argument", methodName) } return knownMethod(lhs, exprs[0]), true, nil } else { @@ -908,14 +911,14 @@ func (p *parser) access(lhs Node) (Node, bool, error) { p.advance() t := p.advance() if !t.isString() { - return Node{}, false, p.errorf("unexpected token") + return ast.Node{}, false, p.errorf("unexpected token") } name, err := t.stringValue() if err != nil { - return Node{}, false, err + return ast.Node{}, false, err } if err := p.exact("]"); err != nil { - return Node{}, false, err + return ast.Node{}, false, err } return lhs.Access(name), true, nil default: diff --git a/internal/ast/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go similarity index 95% rename from internal/ast/cedar_unmarshal_test.go rename to internal/parser/cedar_unmarshal_test.go index 72d1a00e..91c3bb51 100644 --- a/internal/ast/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -1,10 +1,11 @@ -package ast_test +package parser_test import ( "bytes" "testing" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/parser" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -433,9 +434,9 @@ when { (if true then 2 else 3) * 4 == 8 };`, t.Run(tt.Name, func(t *testing.T) { t.Parallel() - var policy ast.Policy + var policy parser.Policy testutil.OK(t, policy.UnmarshalCedar([]byte(tt.Text))) - testutil.Equals(t, policy, *tt.ExpectedPolicy) + testutil.Equals(t, policy, parser.Policy{*tt.ExpectedPolicy}) var buf bytes.Buffer policy.MarshalCedar(&buf) @@ -449,7 +450,7 @@ func TestParsePolicySet(t *testing.T) { parseTests := []struct { Name string Text string - ExpectedPolicies ast.PolicySet + ExpectedPolicies parser.PolicySet }{ { "single policy", @@ -458,10 +459,10 @@ func TestParsePolicySet(t *testing.T) { action, resource );`, - ast.PolicySet{ - "policy0": ast.PolicySetEntry{ - *ast.Permit(), - ast.Position{Offset: 0, Line: 1, Column: 1}, + parser.PolicySet{ + "policy0": parser.PolicySetEntry{ + parser.Policy{*ast.Permit()}, + parser.Position{Offset: 0, Line: 1, Column: 1}, }, }, }, @@ -477,14 +478,14 @@ func TestParsePolicySet(t *testing.T) { action, resource );`, - ast.PolicySet{ - "policy0": ast.PolicySetEntry{ - *ast.Permit(), - ast.Position{Offset: 0, Line: 1, Column: 1}, + parser.PolicySet{ + "policy0": parser.PolicySetEntry{ + parser.Policy{*ast.Permit()}, + parser.Position{Offset: 0, Line: 1, Column: 1}, }, - "policy1": ast.PolicySetEntry{ - *ast.Forbid(), - ast.Position{Offset: 53, Line: 6, Column: 3}, + "policy1": parser.PolicySetEntry{ + parser.Policy{*ast.Forbid()}, + parser.Position{Offset: 53, Line: 6, Column: 3}, }, }, }, @@ -493,7 +494,7 @@ func TestParsePolicySet(t *testing.T) { t.Run(tt.Name, func(t *testing.T) { t.Parallel() - var policies ast.PolicySet + var policies parser.PolicySet testutil.OK(t, policies.UnmarshalCedar([]byte(tt.Text))) testutil.Equals(t, policies, tt.ExpectedPolicies) }) diff --git a/internal/parser/node.go b/internal/parser/node.go new file mode 100644 index 00000000..41da04cf --- /dev/null +++ b/internal/parser/node.go @@ -0,0 +1,194 @@ +package parser + +import ( + "bytes" + + "github.com/cedar-policy/cedar-go/internal/ast" +) + +type NodeTypeIf struct{ ast.NodeTypeIf } + +func (n NodeTypeIf) precedenceLevel() nodePrecedenceLevel { + return ifPrecedence +} + +type NodeTypeOr struct{ ast.NodeTypeOr } + +func (n NodeTypeOr) precedenceLevel() nodePrecedenceLevel { + return orPrecedence +} + +type NodeTypeAnd struct{ ast.NodeTypeAnd } + +func (n NodeTypeAnd) precedenceLevel() nodePrecedenceLevel { + return andPrecedence +} + +type RelationNode struct{} + +func (n RelationNode) precedenceLevel() nodePrecedenceLevel { + return relationPrecedence +} + +type NodeTypeLessThan struct { + ast.NodeTypeLessThan + RelationNode +} + +type NodeTypeLessThanOrEqual struct { + ast.NodeTypeLessThanOrEqual + RelationNode +} +type NodeTypeGreaterThan struct { + ast.NodeTypeGreaterThan + RelationNode +} +type NodeTypeGreaterThanOrEqual struct { + ast.NodeTypeGreaterThanOrEqual + RelationNode +} +type NodeTypeNotEquals struct { + ast.NodeTypeNotEquals + RelationNode +} +type NodeTypeEquals struct { + ast.NodeTypeEquals + RelationNode +} +type NodeTypeIn struct { + ast.NodeTypeIn + RelationNode +} + +type NodeTypeHas struct { + ast.NodeTypeHas + RelationNode +} + +type NodeTypeLike struct { + ast.NodeTypeLike + RelationNode +} + +type NodeTypeIs struct { + ast.NodeTypeIs + RelationNode +} + +type NodeTypeIsIn struct { + ast.NodeTypeIsIn + RelationNode +} + +type AddNode struct{} + +func (n AddNode) precedenceLevel() nodePrecedenceLevel { + return addPrecedence +} + +type NodeTypeSub struct { + ast.NodeTypeSub + AddNode +} + +type NodeTypeAdd struct { + ast.NodeTypeAdd + AddNode +} + +type NodeTypeMult struct{ ast.NodeTypeMult } + +func (n NodeTypeMult) precedenceLevel() nodePrecedenceLevel { + return multPrecedence +} + +type UnaryNode struct{ ast.UnaryNode } + +func (n UnaryNode) precedenceLevel() nodePrecedenceLevel { + return unaryPrecedence +} + +type NodeTypeNegate struct { + ast.NodeTypeNegate + UnaryNode +} +type NodeTypeNot struct { + ast.NodeTypeNot + UnaryNode +} + +type NodeTypeAccess struct{ ast.NodeTypeAccess } + +func (n NodeTypeAccess) precedenceLevel() nodePrecedenceLevel { + return accessPrecedence +} + +type NodeTypeExtensionCall struct{ ast.NodeTypeExtensionCall } + +func (n NodeTypeExtensionCall) precedenceLevel() nodePrecedenceLevel { + return accessPrecedence +} + +type ContainsNode struct{} + +func (n ContainsNode) precedenceLevel() nodePrecedenceLevel { + return accessPrecedence +} + +type NodeTypeContains struct { + ast.NodeTypeContains + ContainsNode +} +type NodeTypeContainsAll struct { + ast.NodeTypeContainsAll + ContainsNode +} +type NodeTypeContainsAny struct { + ast.NodeTypeContainsAny + ContainsNode +} + +type PrimaryNode struct{} + +func (n PrimaryNode) precedenceLevel() nodePrecedenceLevel { + return primaryPrecedence +} + +type NodeValue struct { + ast.NodeValue + PrimaryNode +} + +type NodeTypeRecord struct { + ast.NodeTypeRecord + PrimaryNode +} + +type NodeTypeSet struct { + ast.NodeTypeSet + PrimaryNode +} + +type NodeTypeVariable struct { + ast.NodeTypeVariable + PrimaryNode +} + +type nodePrecedenceLevel uint8 + +const ( + ifPrecedence nodePrecedenceLevel = 0 + orPrecedence nodePrecedenceLevel = 1 + andPrecedence nodePrecedenceLevel = 2 + relationPrecedence nodePrecedenceLevel = 3 + addPrecedence nodePrecedenceLevel = 4 + multPrecedence nodePrecedenceLevel = 5 + unaryPrecedence nodePrecedenceLevel = 6 + accessPrecedence nodePrecedenceLevel = 7 + primaryPrecedence nodePrecedenceLevel = 8 +) + +type IsNode interface { + precedenceLevel() nodePrecedenceLevel + marshalCedar(*bytes.Buffer) +} diff --git a/internal/parser/policy.go b/internal/parser/policy.go new file mode 100644 index 00000000..6a2ae318 --- /dev/null +++ b/internal/parser/policy.go @@ -0,0 +1,25 @@ +package parser + +import "github.com/cedar-policy/cedar-go/internal/ast" + +type PolicySet map[string]PolicySetEntry + +type PolicySetEntry struct { + Policy Policy + Position Position +} + +func (p PolicySetEntry) TmpGetAnnotations() map[string]string { + res := make(map[string]string, len(p.Policy.Annotations)) + for _, e := range p.Policy.Annotations { + res[string(e.Key)] = string(e.Value) + } + return res +} +func (p PolicySetEntry) TmpGetEffect() bool { + return bool(p.Policy.Effect) +} + +type Policy struct { + ast.Policy +} From 144fb6823cacb207ff75246734374b706d3d00a3 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 13 Aug 2024 14:44:02 -0600 Subject: [PATCH 091/216] internal/extensions: put extensions in separate package Addresses IDX-142 Signed-off-by: philhassey --- internal/eval/eval_convert.go | 3 ++- internal/{ast => extensions}/extensions.go | 2 +- internal/parser/cedar_marshal.go | 3 ++- internal/parser/cedar_unmarshal.go | 4 ++-- 4 files changed, 7 insertions(+), 5 deletions(-) rename internal/{ast => extensions}/extensions.go (97%) diff --git a/internal/eval/eval_convert.go b/internal/eval/eval_convert.go index bec793eb..734c9352 100644 --- a/internal/eval/eval_convert.go +++ b/internal/eval/eval_convert.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/extensions" "github.com/cedar-policy/cedar-go/types" ) @@ -25,7 +26,7 @@ func toEval(n ast.IsNode) Evaler { rhs := newInEval(obj, toEval(v.Entity)) return newAndEval(lhs, rhs) case ast.NodeTypeExtensionCall: - i, ok := ast.ExtMap[v.Name] + i, ok := extensions.ExtMap[v.Name] if !ok { return newErrorEval(fmt.Errorf("%w: %s", errUnknownExtensionFunction, v.Name)) } diff --git a/internal/ast/extensions.go b/internal/extensions/extensions.go similarity index 97% rename from internal/ast/extensions.go rename to internal/extensions/extensions.go index 41f1602b..fe5a4e2c 100644 --- a/internal/ast/extensions.go +++ b/internal/extensions/extensions.go @@ -1,4 +1,4 @@ -package ast +package extensions import "github.com/cedar-policy/cedar-go/types" diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index 11aaa051..d1ebf1b2 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/extensions" "github.com/cedar-policy/cedar-go/types" ) @@ -162,7 +163,7 @@ func (n NodeTypeAccess) marshalCedar(buf *bytes.Buffer) { func (n NodeTypeExtensionCall) marshalCedar(buf *bytes.Buffer) { var args []ast.IsNode - info := ast.ExtMap[n.NodeTypeExtensionCall.Name] + info := extensions.ExtMap[n.NodeTypeExtensionCall.Name] if info.IsMethod { marshalChildNode(n.precedenceLevel(), n.NodeTypeExtensionCall.Args[0], buf) buf.WriteRune('.') diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 2bf619dd..dcf0fa51 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -777,7 +777,7 @@ func (p *parser) entityOrExtFun(ident string) (ast.Node, error) { return ast.Node{}, err } p.advance() - // i, ok := extMap[types.String(res.Type)] + // i, ok := extensions.ExtMap[types.String(res.Type)] // if !ok { // return Node{}, p.errorf("`%v` is not a function", res.Type) // } @@ -890,7 +890,7 @@ func (p *parser) access(lhs ast.Node) (ast.Node, bool, error) { case "containsAny": knownMethod = ast.Node.ContainsAny default: - // i, ok := extMap[types.String(methodName)] + // i, ok := extensions.ExtMap[types.String(methodName)] // if !ok { // return Node{}, false, p.errorf("not a valid method name: `%v`", methodName) // } From 921bee9f4a275ed764239b16dc096854bfa0da16 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 13 Aug 2024 14:47:37 -0600 Subject: [PATCH 092/216] internal/ast: remove unnecessary function Addresses IDX-142 Signed-off-by: philhassey --- internal/ast/value.go | 29 ++--------------------------- 1 file changed, 2 insertions(+), 27 deletions(-) diff --git a/internal/ast/value.go b/internal/ast/value.go index ab685ff7..7efd17fc 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -1,8 +1,6 @@ package ast import ( - "fmt" - "github.com/cedar-policy/cedar-go/types" ) @@ -31,7 +29,7 @@ func Long(l types.Long) Node { func Set(s types.Set) Node { var nodes []IsNode for _, v := range s { - nodes = append(nodes, valueToNode(v).v) + nodes = append(nodes, NewValueNode(v).v) } return NewNode(NodeTypeSet{Elements: nodes}) } @@ -58,7 +56,7 @@ func Record(r types.Record) Node { // TODO: this results in a double allocation, fix that recordNodes := map[types.String]Node{} for k, v := range r { - recordNodes[types.String(k)] = valueToNode(v) + recordNodes[types.String(k)] = NewValueNode(v) } return RecordNodes(recordNodes) } @@ -113,26 +111,3 @@ func ExtensionCall(name types.String, args ...Node) Node { func NewValueNode(v types.Value) Node { return NewNode(NodeValue{Value: v}) } - -func valueToNode(v types.Value) Node { - switch x := v.(type) { - case types.Boolean: - return Boolean(x) - case types.String: - return String(x) - case types.Long: - return Long(x) - case types.Set: - return Set(x) - case types.Record: - return Record(x) - case types.EntityUID: - return EntityUID(x) - case types.Decimal: - return Decimal(x) - case types.IPAddr: - return IPAddr(x) - default: - panic(fmt.Sprintf("unexpected value type: %T(%v)", v, v)) - } -} From 7a9965923fedfc59bddfaebd8f7493f3df7987f0 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 13 Aug 2024 15:00:13 -0600 Subject: [PATCH 093/216] ast: reduce memory allocations Addresses IDX-142 Signed-off-by: philhassey --- ast/annotation.go | 20 ++++++++------- ast/node.go | 2 +- ast/operator.go | 64 +++++++++++++++++++++++------------------------ ast/policy.go | 18 +++++++------ ast/scope.go | 22 ++++++++-------- ast/value.go | 24 +++++++++--------- ast/variable.go | 8 +++--- 7 files changed, 81 insertions(+), 77 deletions(-) diff --git a/ast/annotation.go b/ast/annotation.go index e3860354..6fe9e7df 100644 --- a/ast/annotation.go +++ b/ast/annotation.go @@ -5,12 +5,14 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -type Annotations struct { - *ast.Annotations +type Annotations ast.Annotations + +func (a *Annotations) unwrap() *ast.Annotations { + return (*ast.Annotations)(a) } -func newAnnotations(a *ast.Annotations) *Annotations { - return &Annotations{a} +func wrapAnnotations(a *ast.Annotations) *Annotations { + return (*Annotations)(a) } // Annotation allows AST constructors to make policy in a similar shape to textual Cedar with @@ -21,21 +23,21 @@ func newAnnotations(a *ast.Annotations) *Annotations { // Permit(). // PrincipalEq(superUser) func Annotation(name, value types.String) *Annotations { - return newAnnotations(ast.Annotation(name, value)) + return wrapAnnotations(ast.Annotation(name, value)) } func (a *Annotations) Annotation(name, value types.String) *Annotations { - return newAnnotations(a.Annotations.Annotation(name, value)) + return wrapAnnotations(a.unwrap().Annotation(name, value)) } func (a *Annotations) Permit() *Policy { - return newPolicy(a.Annotations.Permit()) + return wrapPolicy(a.unwrap().Permit()) } func (a *Annotations) Forbid() *Policy { - return newPolicy(a.Annotations.Forbid()) + return wrapPolicy(a.unwrap().Forbid()) } func (p *Policy) Annotate(name, value types.String) *Policy { - return newPolicy(p.Policy.Annotate(name, value)) + return wrapPolicy(p.unwrap().Annotate(name, value)) } diff --git a/ast/node.go b/ast/node.go index 0b99e3e8..5f7a74da 100644 --- a/ast/node.go +++ b/ast/node.go @@ -6,6 +6,6 @@ type Node struct { ast.Node } -func newNode(n ast.Node) Node { +func wrapNode(n ast.Node) Node { return Node{n} } diff --git a/ast/operator.go b/ast/operator.go index 2ca9c52a..9c271ce3 100644 --- a/ast/operator.go +++ b/ast/operator.go @@ -13,47 +13,47 @@ import ( // |_| func (lhs Node) Equals(rhs Node) Node { - return newNode(lhs.Node.Equals(rhs.Node)) + return wrapNode(lhs.Node.Equals(rhs.Node)) } func (lhs Node) NotEquals(rhs Node) Node { - return newNode(lhs.Node.NotEquals(rhs.Node)) + return wrapNode(lhs.Node.NotEquals(rhs.Node)) } func (lhs Node) LessThan(rhs Node) Node { - return newNode(lhs.Node.LessThan(rhs.Node)) + return wrapNode(lhs.Node.LessThan(rhs.Node)) } func (lhs Node) LessThanOrEqual(rhs Node) Node { - return newNode(lhs.Node.LessThanOrEqual(rhs.Node)) + return wrapNode(lhs.Node.LessThanOrEqual(rhs.Node)) } func (lhs Node) GreaterThan(rhs Node) Node { - return newNode(lhs.Node.GreaterThan(rhs.Node)) + return wrapNode(lhs.Node.GreaterThan(rhs.Node)) } func (lhs Node) GreaterThanOrEqual(rhs Node) Node { - return newNode(lhs.Node.GreaterThanOrEqual(rhs.Node)) + return wrapNode(lhs.Node.GreaterThanOrEqual(rhs.Node)) } func (lhs Node) LessThanExt(rhs Node) Node { - return newNode(lhs.Node.LessThanExt(rhs.Node)) + return wrapNode(lhs.Node.LessThanExt(rhs.Node)) } func (lhs Node) LessThanOrEqualExt(rhs Node) Node { - return newNode(lhs.Node.LessThanOrEqualExt(rhs.Node)) + return wrapNode(lhs.Node.LessThanOrEqualExt(rhs.Node)) } func (lhs Node) GreaterThanExt(rhs Node) Node { - return newNode(lhs.Node.GreaterThanExt(rhs.Node)) + return wrapNode(lhs.Node.GreaterThanExt(rhs.Node)) } func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { - return newNode(lhs.Node.GreaterThanOrEqualExt(rhs.Node)) + return wrapNode(lhs.Node.GreaterThanOrEqualExt(rhs.Node)) } func (lhs Node) Like(pattern types.Pattern) Node { - return newNode(lhs.Node.Like(pattern)) + return wrapNode(lhs.Node.Like(pattern)) } // _ _ _ @@ -64,19 +64,19 @@ func (lhs Node) Like(pattern types.Pattern) Node { // |___/ func (lhs Node) And(rhs Node) Node { - return newNode(lhs.Node.And(rhs.Node)) + return wrapNode(lhs.Node.And(rhs.Node)) } func (lhs Node) Or(rhs Node) Node { - return newNode(lhs.Node.Or(rhs.Node)) + return wrapNode(lhs.Node.Or(rhs.Node)) } func Not(rhs Node) Node { - return newNode(ast.Not(rhs.Node)) + return wrapNode(ast.Not(rhs.Node)) } func If(condition Node, ifTrue Node, ifFalse Node) Node { - return newNode(ast.If(condition.Node, ifTrue.Node, ifFalse.Node)) + return wrapNode(ast.If(condition.Node, ifTrue.Node, ifFalse.Node)) } // _ _ _ _ _ _ @@ -86,19 +86,19 @@ func If(condition Node, ifTrue Node, ifFalse Node) Node { // /_/ \_\_| |_|\__|_| |_|_| |_| |_|\___|\__|_|\___| func (lhs Node) Plus(rhs Node) Node { - return newNode(lhs.Node.Plus(rhs.Node)) + return wrapNode(lhs.Node.Plus(rhs.Node)) } func (lhs Node) Minus(rhs Node) Node { - return newNode(lhs.Node.Minus(rhs.Node)) + return wrapNode(lhs.Node.Minus(rhs.Node)) } func (lhs Node) Times(rhs Node) Node { - return newNode(lhs.Node.Times(rhs.Node)) + return wrapNode(lhs.Node.Times(rhs.Node)) } func Negate(rhs Node) Node { - return newNode(ast.Negate(rhs.Node)) + return wrapNode(ast.Negate(rhs.Node)) } // _ _ _ _ @@ -109,35 +109,35 @@ func Negate(rhs Node) Node { // |___/ func (lhs Node) In(rhs Node) Node { - return newNode(lhs.Node.In(rhs.Node)) + return wrapNode(lhs.Node.In(rhs.Node)) } func (lhs Node) Is(entityType types.Path) Node { - return newNode(lhs.Node.Is(entityType)) + return wrapNode(lhs.Node.Is(entityType)) } func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { - return newNode(lhs.Node.IsIn(entityType, rhs.Node)) + return wrapNode(lhs.Node.IsIn(entityType, rhs.Node)) } func (lhs Node) Contains(rhs Node) Node { - return newNode(lhs.Node.Contains(rhs.Node)) + return wrapNode(lhs.Node.Contains(rhs.Node)) } func (lhs Node) ContainsAll(rhs Node) Node { - return newNode(lhs.Node.ContainsAll(rhs.Node)) + return wrapNode(lhs.Node.ContainsAll(rhs.Node)) } func (lhs Node) ContainsAny(rhs Node) Node { - return newNode(lhs.Node.ContainsAny(rhs.Node)) + return wrapNode(lhs.Node.ContainsAny(rhs.Node)) } func (lhs Node) Access(attr string) Node { - return newNode(lhs.Node.Access(attr)) + return wrapNode(lhs.Node.Access(attr)) } func (lhs Node) Has(attr string) Node { - return newNode(lhs.Node.Has(attr)) + return wrapNode(lhs.Node.Has(attr)) } // ___ ____ _ _ _ @@ -147,21 +147,21 @@ func (lhs Node) Has(attr string) Node { // |___|_| /_/ \_\__,_|\__,_|_| \___||___/___/ func (lhs Node) IsIpv4() Node { - return newNode(lhs.Node.IsIpv4()) + return wrapNode(lhs.Node.IsIpv4()) } func (lhs Node) IsIpv6() Node { - return newNode(lhs.Node.IsIpv6()) + return wrapNode(lhs.Node.IsIpv6()) } func (lhs Node) IsMulticast() Node { - return newNode(lhs.Node.IsMulticast()) + return wrapNode(lhs.Node.IsMulticast()) } func (lhs Node) IsLoopback() Node { - return newNode(lhs.Node.IsLoopback()) + return wrapNode(lhs.Node.IsLoopback()) } func (lhs Node) IsInRange(rhs Node) Node { - return newNode(lhs.Node.IsInRange(rhs.Node)) + return wrapNode(lhs.Node.IsInRange(rhs.Node)) } diff --git a/ast/policy.go b/ast/policy.go index bd5c6985..62184bf0 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -2,26 +2,28 @@ package ast import "github.com/cedar-policy/cedar-go/internal/ast" -type Policy struct { - *ast.Policy +type Policy ast.Policy + +func wrapPolicy(p *ast.Policy) *Policy { + return (*Policy)(p) } -func newPolicy(p *ast.Policy) *Policy { - return &Policy{p} +func (p *Policy) unwrap() *ast.Policy { + return (*ast.Policy)(p) } func Permit() *Policy { - return newPolicy(ast.Permit()) + return wrapPolicy(ast.Permit()) } func Forbid() *Policy { - return newPolicy(ast.Forbid()) + return wrapPolicy(ast.Forbid()) } func (p *Policy) When(node Node) *Policy { - return newPolicy(p.Policy.When(node.Node)) + return wrapPolicy(p.unwrap().When(node.Node)) } func (p *Policy) Unless(node Node) *Policy { - return newPolicy(p.Policy.Unless(node.Node)) + return wrapPolicy(p.unwrap().Unless(node.Node)) } diff --git a/ast/scope.go b/ast/scope.go index 37b4e2ac..d282101b 100644 --- a/ast/scope.go +++ b/ast/scope.go @@ -5,45 +5,45 @@ import ( ) func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { - return newPolicy(p.Policy.PrincipalEq(entity)) + return wrapPolicy(p.unwrap().PrincipalEq(entity)) } func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { - return newPolicy(p.Policy.PrincipalIn(entity)) + return wrapPolicy(p.unwrap().PrincipalIn(entity)) } func (p *Policy) PrincipalIs(entityType types.Path) *Policy { - return newPolicy(p.Policy.PrincipalIs(entityType)) + return wrapPolicy(p.unwrap().PrincipalIs(entityType)) } func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { - return newPolicy(p.Policy.PrincipalIsIn(entityType, entity)) + return wrapPolicy(p.unwrap().PrincipalIsIn(entityType, entity)) } func (p *Policy) ActionEq(entity types.EntityUID) *Policy { - return newPolicy(p.Policy.ActionEq(entity)) + return wrapPolicy(p.unwrap().ActionEq(entity)) } func (p *Policy) ActionIn(entity types.EntityUID) *Policy { - return newPolicy(p.Policy.ActionIn(entity)) + return wrapPolicy(p.unwrap().ActionIn(entity)) } func (p *Policy) ActionInSet(entities ...types.EntityUID) *Policy { - return newPolicy(p.Policy.ActionInSet(entities...)) + return wrapPolicy(p.unwrap().ActionInSet(entities...)) } func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { - return newPolicy(p.Policy.ResourceEq(entity)) + return wrapPolicy(p.unwrap().ResourceEq(entity)) } func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { - return newPolicy(p.Policy.ResourceIn(entity)) + return wrapPolicy(p.unwrap().ResourceIn(entity)) } func (p *Policy) ResourceIs(entityType types.Path) *Policy { - return newPolicy(p.Policy.ResourceIs(entityType)) + return wrapPolicy(p.unwrap().ResourceIs(entityType)) } func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { - return newPolicy(p.Policy.ResourceIsIn(entityType, entity)) + return wrapPolicy(p.unwrap().ResourceIsIn(entityType, entity)) } diff --git a/ast/value.go b/ast/value.go index b2adb5c5..e0184a10 100644 --- a/ast/value.go +++ b/ast/value.go @@ -6,7 +6,7 @@ import ( ) func Boolean(b types.Boolean) Node { - return newNode(ast.Boolean(b)) + return wrapNode(ast.Boolean(b)) } func True() Node { @@ -18,17 +18,17 @@ func False() Node { } func String(s types.String) Node { - return newNode(ast.String(s)) + return wrapNode(ast.String(s)) } func Long(l types.Long) Node { - return newNode(ast.Long(l)) + return wrapNode(ast.Long(l)) } // Set is a convenience function that wraps concrete instances of a Cedar Set type // types in AST value nodes and passes them along to SetNodes. func Set(s types.Set) Node { - return newNode(ast.Set(s)) + return wrapNode(ast.Set(s)) } // SetNodes allows for a complex set definition with values potentially @@ -48,13 +48,13 @@ func SetNodes(nodes ...Node) Node { for _, n := range nodes { astNodes = append(astNodes, n.Node) } - return newNode(ast.SetNodes(astNodes...)) + return wrapNode(ast.SetNodes(astNodes...)) } // Record is a convenience function that wraps concrete instances of a Cedar Record type // types in AST value nodes and passes them along to RecordNodes. func Record(r types.Record) Node { - return newNode(ast.Record(r)) + return wrapNode(ast.Record(r)) } // RecordNodes allows for a complex record definition with values potentially @@ -72,7 +72,7 @@ func RecordNodes(entries map[types.String]Node) Node { for k, v := range entries { astNodes[k] = v.Node } - return newNode(ast.RecordNodes(astNodes)) + return wrapNode(ast.RecordNodes(astNodes)) } type RecordElement struct { @@ -85,19 +85,19 @@ func RecordElements(elements ...RecordElement) Node { for _, v := range elements { astNodes = append(astNodes, ast.RecordElement{Key: v.Key, Value: v.Value.Node}) } - return newNode(ast.RecordElements(astNodes...)) + return wrapNode(ast.RecordElements(astNodes...)) } func EntityUID(e types.EntityUID) Node { - return newNode(ast.EntityUID(e)) + return wrapNode(ast.EntityUID(e)) } func Decimal(d types.Decimal) Node { - return newNode(ast.Decimal(d)) + return wrapNode(ast.Decimal(d)) } func IPAddr(i types.IPAddr) Node { - return newNode(ast.IPAddr(i)) + return wrapNode(ast.IPAddr(i)) } func ExtensionCall(name types.String, args ...Node) Node { @@ -105,5 +105,5 @@ func ExtensionCall(name types.String, args ...Node) Node { for _, v := range args { astNodes = append(astNodes, v.Node) } - return newNode(ast.ExtensionCall(name, astNodes...)) + return wrapNode(ast.ExtensionCall(name, astNodes...)) } diff --git a/ast/variable.go b/ast/variable.go index 7d724bf3..f12ce38b 100644 --- a/ast/variable.go +++ b/ast/variable.go @@ -3,17 +3,17 @@ package ast import "github.com/cedar-policy/cedar-go/internal/ast" func Principal() Node { - return newNode(ast.Principal()) + return wrapNode(ast.Principal()) } func Action() Node { - return newNode(ast.Action()) + return wrapNode(ast.Action()) } func Resource() Node { - return newNode(ast.Resource()) + return wrapNode(ast.Resource()) } func Context() Node { - return newNode(ast.Context()) + return wrapNode(ast.Context()) } From 536d32cd359496492d0f3643c60a18a9e13b07ff Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 13 Aug 2024 17:03:36 -0600 Subject: [PATCH 094/216] ast: add test coverage for the sugar ast Addresses IDX-142 Signed-off-by: philhassey --- ast/ast_test.go | 91 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/ast/ast_test.go b/ast/ast_test.go index 966cdb0f..44b09a1e 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -1,9 +1,12 @@ package ast_test import ( + "net/netip" "testing" "github.com/cedar-policy/cedar-go/ast" + internalast "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -68,3 +71,91 @@ func TestAst(t *testing.T) { ).Contains(ast.Long(1)), ) } + +func TestASTByTable(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in *ast.Policy + out *internalast.Policy + }{ + {"permit", ast.Permit(), internalast.Permit()}, + {"forbid", ast.Forbid(), internalast.Forbid()}, + {"annotationPermit", ast.Annotation("key", "value").Permit(), internalast.Annotation("key", "value").Permit()}, + {"annotationForbid", ast.Annotation("key", "value").Forbid(), internalast.Annotation("key", "value").Forbid()}, + {"annotations", ast.Annotation("key", "value").Annotation("abc", "xyz").Permit(), internalast.Annotation("key", "value").Annotation("abc", "xyz").Permit()}, + {"policyAnnotate", ast.Permit().Annotate("key", "value"), internalast.Permit().Annotate("key", "value")}, + {"when", ast.Permit().When(ast.True()), internalast.Permit().When(internalast.True())}, + {"unless", ast.Permit().Unless(ast.True()), internalast.Permit().Unless(internalast.True())}, + {"scopePrincipalEq", ast.Permit().PrincipalEq(types.NewEntityUID("T", "42")), internalast.Permit().PrincipalEq(types.NewEntityUID("T", "42"))}, + {"scopePrincipalIn", ast.Permit().PrincipalIn(types.NewEntityUID("T", "42")), internalast.Permit().PrincipalIn(types.NewEntityUID("T", "42"))}, + {"scopePrincipalIs", ast.Permit().PrincipalIs("T"), internalast.Permit().PrincipalIs("T")}, + {"scopePrincipalIsIn", ast.Permit().PrincipalIsIn("T", types.NewEntityUID("T", "42")), internalast.Permit().PrincipalIsIn("T", types.NewEntityUID("T", "42"))}, + {"scopeActionEq", ast.Permit().ActionEq(types.NewEntityUID("T", "42")), internalast.Permit().ActionEq(types.NewEntityUID("T", "42"))}, + {"scopeActionIn", ast.Permit().ActionIn(types.NewEntityUID("T", "42")), internalast.Permit().ActionIn(types.NewEntityUID("T", "42"))}, + {"scopeActionInSet", ast.Permit().ActionInSet(types.NewEntityUID("T", "42"), types.NewEntityUID("T", "43")), internalast.Permit().ActionInSet(types.NewEntityUID("T", "42"), types.NewEntityUID("T", "43"))}, + {"scopeResourceEq", ast.Permit().ResourceEq(types.NewEntityUID("T", "42")), internalast.Permit().ResourceEq(types.NewEntityUID("T", "42"))}, + {"scopeResourceIn", ast.Permit().ResourceIn(types.NewEntityUID("T", "42")), internalast.Permit().ResourceIn(types.NewEntityUID("T", "42"))}, + {"scopeResourceIs", ast.Permit().ResourceIs("T"), internalast.Permit().ResourceIs("T")}, + {"scopeResourceIsIn", ast.Permit().ResourceIsIn("T", types.NewEntityUID("T", "42")), internalast.Permit().ResourceIsIn("T", types.NewEntityUID("T", "42"))}, + {"variablePrincipal", ast.Permit().When(ast.Principal()), internalast.Permit().When(internalast.Principal())}, + {"variableAction", ast.Permit().When(ast.Action()), internalast.Permit().When(internalast.Action())}, + {"variableResource", ast.Permit().When(ast.Resource()), internalast.Permit().When(internalast.Resource())}, + {"variableContext", ast.Permit().When(ast.Context()), internalast.Permit().When(internalast.Context())}, + {"valueBoolFalse", ast.Permit().When(ast.Boolean(false)), internalast.Permit().When(internalast.Boolean(false))}, + {"valueBoolTrue", ast.Permit().When(ast.Boolean(true)), internalast.Permit().When(internalast.Boolean(true))}, + {"valueTrue", ast.Permit().When(ast.True()), internalast.Permit().When(internalast.True())}, + {"valueFalse", ast.Permit().When(ast.False()), internalast.Permit().When(internalast.False())}, + {"valueString", ast.Permit().When(ast.String("cedar")), internalast.Permit().When(internalast.String("cedar"))}, + {"valueLong", ast.Permit().When(ast.Long(42)), internalast.Permit().When(internalast.Long(42))}, + {"valueSet", ast.Permit().When(ast.Set(types.Set{types.Long(42), types.Long(43)})), internalast.Permit().When(internalast.Set(types.Set{types.Long(42), types.Long(43)}))}, + {"valueSetNodes", ast.Permit().When(ast.SetNodes(ast.Long(42), ast.Long(43))), internalast.Permit().When(internalast.SetNodes(internalast.Long(42), internalast.Long(43)))}, + {"valueRecord", ast.Permit().When(ast.Record(types.Record{"key": types.Long(43)})), internalast.Permit().When(internalast.Record(types.Record{"key": types.Long(43)}))}, + {"valueRecordNodes", ast.Permit().When(ast.RecordNodes(map[types.String]ast.Node{"key": ast.Long(42)})), internalast.Permit().When(internalast.RecordNodes(map[types.String]internalast.Node{"key": internalast.Long(42)}))}, + {"valueRecordElements", ast.Permit().When(ast.RecordElements(ast.RecordElement{Key: "key", Value: ast.Long(42)})), internalast.Permit().When(internalast.RecordElements(internalast.RecordElement{Key: "key", Value: internalast.Long(42)}))}, + {"valueEntityUID", ast.Permit().When(ast.EntityUID(types.NewEntityUID("T", "42"))), internalast.Permit().When(internalast.EntityUID(types.NewEntityUID("T", "42")))}, + {"valueDecimal", ast.Permit().When(ast.Decimal(420000)), internalast.Permit().When(internalast.Decimal(420000))}, + {"valueIPAddr", ast.Permit().When(ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16")))), internalast.Permit().When(internalast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16"))))}, + {"extensionCall", ast.Permit().When(ast.ExtensionCall("ip", ast.String("127.0.0.1"))), internalast.Permit().When(internalast.ExtensionCall("ip", internalast.String("127.0.0.1")))}, + {"opEquals", ast.Permit().When(ast.Long(42).Equals(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Equals(internalast.Long(43)))}, + {"opNotEquals", ast.Permit().When(ast.Long(42).NotEquals(ast.Long(43))), internalast.Permit().When(internalast.Long(42).NotEquals(internalast.Long(43)))}, + {"opLessThan", ast.Permit().When(ast.Long(42).LessThan(ast.Long(43))), internalast.Permit().When(internalast.Long(42).LessThan(internalast.Long(43)))}, + {"opLessThanOrEqual", ast.Permit().When(ast.Long(42).LessThanOrEqual(ast.Long(43))), internalast.Permit().When(internalast.Long(42).LessThanOrEqual(internalast.Long(43)))}, + {"opGreaterThan", ast.Permit().When(ast.Long(42).GreaterThan(ast.Long(43))), internalast.Permit().When(internalast.Long(42).GreaterThan(internalast.Long(43)))}, + {"opGreaterThanOrEqual", ast.Permit().When(ast.Long(42).GreaterThanOrEqual(ast.Long(43))), internalast.Permit().When(internalast.Long(42).GreaterThanOrEqual(internalast.Long(43)))}, + {"opLessThanExt", ast.Permit().When(ast.Long(42).LessThanExt(ast.Long(43))), internalast.Permit().When(internalast.Long(42).LessThanExt(internalast.Long(43)))}, + {"opLessThanOrEqualExt", ast.Permit().When(ast.Long(42).LessThanOrEqualExt(ast.Long(43))), internalast.Permit().When(internalast.Long(42).LessThanOrEqualExt(internalast.Long(43)))}, + {"opGreaterThanExt", ast.Permit().When(ast.Long(42).GreaterThanExt(ast.Long(43))), internalast.Permit().When(internalast.Long(42).GreaterThanExt(internalast.Long(43)))}, + {"opGreaterThanOrEqualExt", ast.Permit().When(ast.Long(42).GreaterThanOrEqualExt(ast.Long(43))), internalast.Permit().When(internalast.Long(42).GreaterThanOrEqualExt(internalast.Long(43)))}, + {"opLike", ast.Permit().When(ast.Long(42).Like(types.Pattern{})), internalast.Permit().When(internalast.Long(42).Like(types.Pattern{}))}, + {"opAnd", ast.Permit().When(ast.Long(42).And(ast.Long(43))), internalast.Permit().When(internalast.Long(42).And(internalast.Long(43)))}, + {"opOr", ast.Permit().When(ast.Long(42).Or(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Or(internalast.Long(43)))}, + {"opNot", ast.Permit().When(ast.Not(ast.True())), internalast.Permit().When(internalast.Not(internalast.True()))}, + {"opIf", ast.Permit().When(ast.If(ast.True(), ast.Long(42), ast.Long(43))), internalast.Permit().When(internalast.If(internalast.True(), internalast.Long(42), internalast.Long(43)))}, + {"opPlus", ast.Permit().When(ast.Long(42).Plus(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Plus(internalast.Long(43)))}, + {"opMinus", ast.Permit().When(ast.Long(42).Minus(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Minus(internalast.Long(43)))}, + {"opTimes", ast.Permit().When(ast.Long(42).Times(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Times(internalast.Long(43)))}, + {"opNegate", ast.Permit().When(ast.Negate(ast.True())), internalast.Permit().When(internalast.Negate(internalast.True()))}, + {"opIn", ast.Permit().When(ast.Long(42).In(ast.Long(43))), internalast.Permit().When(internalast.Long(42).In(internalast.Long(43)))}, + {"opIs", ast.Permit().When(ast.Long(42).Is(types.Path("T"))), internalast.Permit().When(internalast.Long(42).Is(types.Path("T")))}, + {"opIsIn", ast.Permit().When(ast.Long(42).IsIn(types.Path("T"), ast.Long(43))), internalast.Permit().When(internalast.Long(42).IsIn(types.Path("T"), internalast.Long(43)))}, + {"opContains", ast.Permit().When(ast.Long(42).Contains(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Contains(internalast.Long(43)))}, + {"opContainsAll", ast.Permit().When(ast.Long(42).ContainsAll(ast.Long(43))), internalast.Permit().When(internalast.Long(42).ContainsAll(internalast.Long(43)))}, + {"opContainsAny", ast.Permit().When(ast.Long(42).ContainsAny(ast.Long(43))), internalast.Permit().When(internalast.Long(42).ContainsAny(internalast.Long(43)))}, + {"opAccess", ast.Permit().When(ast.Long(42).Access("key")), internalast.Permit().When(internalast.Long(42).Access("key"))}, + {"opHas", ast.Permit().When(ast.Long(42).Has("key")), internalast.Permit().When(internalast.Long(42).Has("key"))}, + {"opIsIpv4", ast.Permit().When(ast.Long(42).IsIpv4()), internalast.Permit().When(internalast.Long(42).IsIpv4())}, + {"opIsIpv6", ast.Permit().When(ast.Long(42).IsIpv6()), internalast.Permit().When(internalast.Long(42).IsIpv6())}, + {"opIsMulticast", ast.Permit().When(ast.Long(42).IsMulticast()), internalast.Permit().When(internalast.Long(42).IsMulticast())}, + {"opIsLoopback", ast.Permit().When(ast.Long(42).IsLoopback()), internalast.Permit().When(internalast.Long(42).IsLoopback())}, + {"opIsInRange", ast.Permit().When(ast.Long(42).IsInRange(ast.Long(43))), internalast.Permit().When(internalast.Long(42).IsInRange(internalast.Long(43)))}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + testutil.Equals(t, (*internalast.Policy)(tt.in), tt.out) + }) + } +} From 0153fe90fda48d9abeaa112ca0aee671d0734169 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 16:15:12 -0700 Subject: [PATCH 095/216] cedar-go/ast: Change the commentary and naming around the AST example tests Signed-off-by: philhassey --- ast/ast_test.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index 44b09a1e..d188272c 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -10,8 +10,9 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -// These tests mostly verify that policy ASTs compile -func TestAst(t *testing.T) { +// These tests serve mostly as examples of how to translate from Cedar text into programmatic AST construction. They +// don't verify anything. +func TestAstExamples(t *testing.T) { t.Parallel() johnny := types.NewEntityUID("User", "johnny") @@ -50,7 +51,7 @@ func TestAst(t *testing.T) { // forbid (principal, action, resource) // when { {x: "value"}.x == "value" } // when { {x: 1 + context.fooCount}.x == 3 } - // when { [1, 2 + 3, context.fooCount].contains(1) }; + // when { [1, (2 + 3) * 4, context.fooCount].contains(1) }; simpleRecord := types.Record{ "x": types.String("value"), } @@ -66,7 +67,7 @@ func TestAst(t *testing.T) { When( ast.SetNodes( ast.Long(1), - ast.Long(2).Plus(ast.Long(3)), + ast.Long(2).Plus(ast.Long(3)).Times(ast.Long(4)), ast.Context().Access("fooCount"), ).Contains(ast.Long(1)), ) From 56909aa8114bba1600f0cd0a7e7a4a973c03039f Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 16:20:34 -0700 Subject: [PATCH 096/216] cedar-go/ast: format the table tests to increase legibility Signed-off-by: philhassey --- ast/ast_test.go | 420 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 350 insertions(+), 70 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index d188272c..057272ed 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -80,76 +80,356 @@ func TestASTByTable(t *testing.T) { in *ast.Policy out *internalast.Policy }{ - {"permit", ast.Permit(), internalast.Permit()}, - {"forbid", ast.Forbid(), internalast.Forbid()}, - {"annotationPermit", ast.Annotation("key", "value").Permit(), internalast.Annotation("key", "value").Permit()}, - {"annotationForbid", ast.Annotation("key", "value").Forbid(), internalast.Annotation("key", "value").Forbid()}, - {"annotations", ast.Annotation("key", "value").Annotation("abc", "xyz").Permit(), internalast.Annotation("key", "value").Annotation("abc", "xyz").Permit()}, - {"policyAnnotate", ast.Permit().Annotate("key", "value"), internalast.Permit().Annotate("key", "value")}, - {"when", ast.Permit().When(ast.True()), internalast.Permit().When(internalast.True())}, - {"unless", ast.Permit().Unless(ast.True()), internalast.Permit().Unless(internalast.True())}, - {"scopePrincipalEq", ast.Permit().PrincipalEq(types.NewEntityUID("T", "42")), internalast.Permit().PrincipalEq(types.NewEntityUID("T", "42"))}, - {"scopePrincipalIn", ast.Permit().PrincipalIn(types.NewEntityUID("T", "42")), internalast.Permit().PrincipalIn(types.NewEntityUID("T", "42"))}, - {"scopePrincipalIs", ast.Permit().PrincipalIs("T"), internalast.Permit().PrincipalIs("T")}, - {"scopePrincipalIsIn", ast.Permit().PrincipalIsIn("T", types.NewEntityUID("T", "42")), internalast.Permit().PrincipalIsIn("T", types.NewEntityUID("T", "42"))}, - {"scopeActionEq", ast.Permit().ActionEq(types.NewEntityUID("T", "42")), internalast.Permit().ActionEq(types.NewEntityUID("T", "42"))}, - {"scopeActionIn", ast.Permit().ActionIn(types.NewEntityUID("T", "42")), internalast.Permit().ActionIn(types.NewEntityUID("T", "42"))}, - {"scopeActionInSet", ast.Permit().ActionInSet(types.NewEntityUID("T", "42"), types.NewEntityUID("T", "43")), internalast.Permit().ActionInSet(types.NewEntityUID("T", "42"), types.NewEntityUID("T", "43"))}, - {"scopeResourceEq", ast.Permit().ResourceEq(types.NewEntityUID("T", "42")), internalast.Permit().ResourceEq(types.NewEntityUID("T", "42"))}, - {"scopeResourceIn", ast.Permit().ResourceIn(types.NewEntityUID("T", "42")), internalast.Permit().ResourceIn(types.NewEntityUID("T", "42"))}, - {"scopeResourceIs", ast.Permit().ResourceIs("T"), internalast.Permit().ResourceIs("T")}, - {"scopeResourceIsIn", ast.Permit().ResourceIsIn("T", types.NewEntityUID("T", "42")), internalast.Permit().ResourceIsIn("T", types.NewEntityUID("T", "42"))}, - {"variablePrincipal", ast.Permit().When(ast.Principal()), internalast.Permit().When(internalast.Principal())}, - {"variableAction", ast.Permit().When(ast.Action()), internalast.Permit().When(internalast.Action())}, - {"variableResource", ast.Permit().When(ast.Resource()), internalast.Permit().When(internalast.Resource())}, - {"variableContext", ast.Permit().When(ast.Context()), internalast.Permit().When(internalast.Context())}, - {"valueBoolFalse", ast.Permit().When(ast.Boolean(false)), internalast.Permit().When(internalast.Boolean(false))}, - {"valueBoolTrue", ast.Permit().When(ast.Boolean(true)), internalast.Permit().When(internalast.Boolean(true))}, - {"valueTrue", ast.Permit().When(ast.True()), internalast.Permit().When(internalast.True())}, - {"valueFalse", ast.Permit().When(ast.False()), internalast.Permit().When(internalast.False())}, - {"valueString", ast.Permit().When(ast.String("cedar")), internalast.Permit().When(internalast.String("cedar"))}, - {"valueLong", ast.Permit().When(ast.Long(42)), internalast.Permit().When(internalast.Long(42))}, - {"valueSet", ast.Permit().When(ast.Set(types.Set{types.Long(42), types.Long(43)})), internalast.Permit().When(internalast.Set(types.Set{types.Long(42), types.Long(43)}))}, - {"valueSetNodes", ast.Permit().When(ast.SetNodes(ast.Long(42), ast.Long(43))), internalast.Permit().When(internalast.SetNodes(internalast.Long(42), internalast.Long(43)))}, - {"valueRecord", ast.Permit().When(ast.Record(types.Record{"key": types.Long(43)})), internalast.Permit().When(internalast.Record(types.Record{"key": types.Long(43)}))}, - {"valueRecordNodes", ast.Permit().When(ast.RecordNodes(map[types.String]ast.Node{"key": ast.Long(42)})), internalast.Permit().When(internalast.RecordNodes(map[types.String]internalast.Node{"key": internalast.Long(42)}))}, - {"valueRecordElements", ast.Permit().When(ast.RecordElements(ast.RecordElement{Key: "key", Value: ast.Long(42)})), internalast.Permit().When(internalast.RecordElements(internalast.RecordElement{Key: "key", Value: internalast.Long(42)}))}, - {"valueEntityUID", ast.Permit().When(ast.EntityUID(types.NewEntityUID("T", "42"))), internalast.Permit().When(internalast.EntityUID(types.NewEntityUID("T", "42")))}, - {"valueDecimal", ast.Permit().When(ast.Decimal(420000)), internalast.Permit().When(internalast.Decimal(420000))}, - {"valueIPAddr", ast.Permit().When(ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16")))), internalast.Permit().When(internalast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16"))))}, - {"extensionCall", ast.Permit().When(ast.ExtensionCall("ip", ast.String("127.0.0.1"))), internalast.Permit().When(internalast.ExtensionCall("ip", internalast.String("127.0.0.1")))}, - {"opEquals", ast.Permit().When(ast.Long(42).Equals(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Equals(internalast.Long(43)))}, - {"opNotEquals", ast.Permit().When(ast.Long(42).NotEquals(ast.Long(43))), internalast.Permit().When(internalast.Long(42).NotEquals(internalast.Long(43)))}, - {"opLessThan", ast.Permit().When(ast.Long(42).LessThan(ast.Long(43))), internalast.Permit().When(internalast.Long(42).LessThan(internalast.Long(43)))}, - {"opLessThanOrEqual", ast.Permit().When(ast.Long(42).LessThanOrEqual(ast.Long(43))), internalast.Permit().When(internalast.Long(42).LessThanOrEqual(internalast.Long(43)))}, - {"opGreaterThan", ast.Permit().When(ast.Long(42).GreaterThan(ast.Long(43))), internalast.Permit().When(internalast.Long(42).GreaterThan(internalast.Long(43)))}, - {"opGreaterThanOrEqual", ast.Permit().When(ast.Long(42).GreaterThanOrEqual(ast.Long(43))), internalast.Permit().When(internalast.Long(42).GreaterThanOrEqual(internalast.Long(43)))}, - {"opLessThanExt", ast.Permit().When(ast.Long(42).LessThanExt(ast.Long(43))), internalast.Permit().When(internalast.Long(42).LessThanExt(internalast.Long(43)))}, - {"opLessThanOrEqualExt", ast.Permit().When(ast.Long(42).LessThanOrEqualExt(ast.Long(43))), internalast.Permit().When(internalast.Long(42).LessThanOrEqualExt(internalast.Long(43)))}, - {"opGreaterThanExt", ast.Permit().When(ast.Long(42).GreaterThanExt(ast.Long(43))), internalast.Permit().When(internalast.Long(42).GreaterThanExt(internalast.Long(43)))}, - {"opGreaterThanOrEqualExt", ast.Permit().When(ast.Long(42).GreaterThanOrEqualExt(ast.Long(43))), internalast.Permit().When(internalast.Long(42).GreaterThanOrEqualExt(internalast.Long(43)))}, - {"opLike", ast.Permit().When(ast.Long(42).Like(types.Pattern{})), internalast.Permit().When(internalast.Long(42).Like(types.Pattern{}))}, - {"opAnd", ast.Permit().When(ast.Long(42).And(ast.Long(43))), internalast.Permit().When(internalast.Long(42).And(internalast.Long(43)))}, - {"opOr", ast.Permit().When(ast.Long(42).Or(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Or(internalast.Long(43)))}, - {"opNot", ast.Permit().When(ast.Not(ast.True())), internalast.Permit().When(internalast.Not(internalast.True()))}, - {"opIf", ast.Permit().When(ast.If(ast.True(), ast.Long(42), ast.Long(43))), internalast.Permit().When(internalast.If(internalast.True(), internalast.Long(42), internalast.Long(43)))}, - {"opPlus", ast.Permit().When(ast.Long(42).Plus(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Plus(internalast.Long(43)))}, - {"opMinus", ast.Permit().When(ast.Long(42).Minus(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Minus(internalast.Long(43)))}, - {"opTimes", ast.Permit().When(ast.Long(42).Times(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Times(internalast.Long(43)))}, - {"opNegate", ast.Permit().When(ast.Negate(ast.True())), internalast.Permit().When(internalast.Negate(internalast.True()))}, - {"opIn", ast.Permit().When(ast.Long(42).In(ast.Long(43))), internalast.Permit().When(internalast.Long(42).In(internalast.Long(43)))}, - {"opIs", ast.Permit().When(ast.Long(42).Is(types.Path("T"))), internalast.Permit().When(internalast.Long(42).Is(types.Path("T")))}, - {"opIsIn", ast.Permit().When(ast.Long(42).IsIn(types.Path("T"), ast.Long(43))), internalast.Permit().When(internalast.Long(42).IsIn(types.Path("T"), internalast.Long(43)))}, - {"opContains", ast.Permit().When(ast.Long(42).Contains(ast.Long(43))), internalast.Permit().When(internalast.Long(42).Contains(internalast.Long(43)))}, - {"opContainsAll", ast.Permit().When(ast.Long(42).ContainsAll(ast.Long(43))), internalast.Permit().When(internalast.Long(42).ContainsAll(internalast.Long(43)))}, - {"opContainsAny", ast.Permit().When(ast.Long(42).ContainsAny(ast.Long(43))), internalast.Permit().When(internalast.Long(42).ContainsAny(internalast.Long(43)))}, - {"opAccess", ast.Permit().When(ast.Long(42).Access("key")), internalast.Permit().When(internalast.Long(42).Access("key"))}, - {"opHas", ast.Permit().When(ast.Long(42).Has("key")), internalast.Permit().When(internalast.Long(42).Has("key"))}, - {"opIsIpv4", ast.Permit().When(ast.Long(42).IsIpv4()), internalast.Permit().When(internalast.Long(42).IsIpv4())}, - {"opIsIpv6", ast.Permit().When(ast.Long(42).IsIpv6()), internalast.Permit().When(internalast.Long(42).IsIpv6())}, - {"opIsMulticast", ast.Permit().When(ast.Long(42).IsMulticast()), internalast.Permit().When(internalast.Long(42).IsMulticast())}, - {"opIsLoopback", ast.Permit().When(ast.Long(42).IsLoopback()), internalast.Permit().When(internalast.Long(42).IsLoopback())}, - {"opIsInRange", ast.Permit().When(ast.Long(42).IsInRange(ast.Long(43))), internalast.Permit().When(internalast.Long(42).IsInRange(internalast.Long(43)))}, + { + "permit", + ast.Permit(), + internalast.Permit(), + }, + { + "forbid", + ast.Forbid(), + internalast.Forbid(), + }, + { + "annotationPermit", + ast.Annotation("key", "value").Permit(), + internalast.Annotation("key", "value").Permit(), + }, + { + "annotationForbid", + ast.Annotation("key", "value").Forbid(), + internalast.Annotation("key", "value").Forbid(), + }, + { + "annotations", + ast.Annotation("key", "value").Annotation("abc", "xyz").Permit(), + internalast.Annotation("key", "value").Annotation("abc", "xyz").Permit(), + }, + { + "policyAnnotate", + ast.Permit().Annotate("key", "value"), + internalast.Permit().Annotate("key", "value"), + }, + { + "when", + ast.Permit().When(ast.True()), + internalast.Permit().When(internalast.True()), + }, + { + "unless", + ast.Permit().Unless(ast.True()), + internalast.Permit().Unless(internalast.True()), + }, + { + "scopePrincipalEq", + ast.Permit().PrincipalEq(types.NewEntityUID("T", "42")), + internalast.Permit().PrincipalEq(types.NewEntityUID("T", "42")), + }, + { + "scopePrincipalIn", + ast.Permit().PrincipalIn(types.NewEntityUID("T", "42")), + internalast.Permit().PrincipalIn(types.NewEntityUID("T", "42")), + }, + { + "scopePrincipalIs", + ast.Permit().PrincipalIs("T"), + internalast.Permit().PrincipalIs("T"), + }, + { + "scopePrincipalIsIn", + ast.Permit().PrincipalIsIn("T", types.NewEntityUID("T", "42")), + internalast.Permit().PrincipalIsIn("T", types.NewEntityUID("T", "42")), + }, + { + "scopeActionEq", + ast.Permit().ActionEq(types.NewEntityUID("T", "42")), + internalast.Permit().ActionEq(types.NewEntityUID("T", "42")), + }, + { + "scopeActionIn", + ast.Permit().ActionIn(types.NewEntityUID("T", "42")), + internalast.Permit().ActionIn(types.NewEntityUID("T", "42")), + }, + { + "scopeActionInSet", + ast.Permit().ActionInSet(types.NewEntityUID("T", "42"), types.NewEntityUID("T", "43")), + internalast.Permit().ActionInSet(types.NewEntityUID("T", "42"), types.NewEntityUID("T", "43")), + }, + { + "scopeResourceEq", + ast.Permit().ResourceEq(types.NewEntityUID("T", "42")), + internalast.Permit().ResourceEq(types.NewEntityUID("T", "42")), + }, + { + "scopeResourceIn", + ast.Permit().ResourceIn(types.NewEntityUID("T", "42")), + internalast.Permit().ResourceIn(types.NewEntityUID("T", "42")), + }, + { + "scopeResourceIs", + ast.Permit().ResourceIs("T"), + internalast.Permit().ResourceIs("T"), + }, + { + "scopeResourceIsIn", + ast.Permit().ResourceIsIn("T", types.NewEntityUID("T", "42")), + internalast.Permit().ResourceIsIn("T", types.NewEntityUID("T", "42")), + }, + { + "variablePrincipal", + ast.Permit().When(ast.Principal()), + internalast.Permit().When(internalast.Principal()), + }, + { + "variableAction", + ast.Permit().When(ast.Action()), + internalast.Permit().When(internalast.Action()), + }, + { + "variableResource", + ast.Permit().When(ast.Resource()), + internalast.Permit().When(internalast.Resource()), + }, + { + "variableContext", + ast.Permit().When(ast.Context()), + internalast.Permit().When(internalast.Context()), + }, + { + "valueBoolFalse", + ast.Permit().When(ast.Boolean(false)), + internalast.Permit().When(internalast.Boolean(false)), + }, + { + "valueBoolTrue", + ast.Permit().When(ast.Boolean(true)), + internalast.Permit().When(internalast.Boolean(true)), + }, + { + "valueTrue", + ast.Permit().When(ast.True()), + internalast.Permit().When(internalast.True()), + }, + { + "valueFalse", + ast.Permit().When(ast.False()), + internalast.Permit().When(internalast.False()), + }, + { + "valueString", + ast.Permit().When(ast.String("cedar")), + internalast.Permit().When(internalast.String("cedar")), + }, + { + "valueLong", + ast.Permit().When(ast.Long(42)), + internalast.Permit().When(internalast.Long(42)), + }, + { + "valueSet", + ast.Permit().When(ast.Set(types.Set{types.Long(42), types.Long(43)})), + internalast.Permit().When(internalast.Set(types.Set{types.Long(42), types.Long(43)})), + }, + { + "valueSetNodes", + ast.Permit().When(ast.SetNodes(ast.Long(42), ast.Long(43))), + internalast.Permit().When(internalast.SetNodes(internalast.Long(42), internalast.Long(43))), + }, + { + "valueRecord", + ast.Permit().When(ast.Record(types.Record{"key": types.Long(43)})), + internalast.Permit().When(internalast.Record(types.Record{"key": types.Long(43)})), + }, + { + "valueRecordNodes", + ast.Permit().When(ast.RecordNodes(map[types.String]ast.Node{"key": ast.Long(42)})), + internalast.Permit().When(internalast.RecordNodes(map[types.String]internalast.Node{"key": internalast.Long(42)})), + }, + { + "valueRecordElements", + ast.Permit().When(ast.RecordElements(ast.RecordElement{Key: "key", Value: ast.Long(42)})), + internalast.Permit().When(internalast.RecordElements(internalast.RecordElement{Key: "key", Value: internalast.Long(42)})), + }, + { + "valueEntityUID", + ast.Permit().When(ast.EntityUID(types.NewEntityUID("T", "42"))), + internalast.Permit().When(internalast.EntityUID(types.NewEntityUID("T", "42"))), + }, + { + "valueDecimal", + ast.Permit().When(ast.Decimal(420000)), + internalast.Permit().When(internalast.Decimal(420000)), + }, + { + "valueIPAddr", + ast.Permit().When(ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16")))), + internalast.Permit().When(internalast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16")))), + }, + { + "extensionCall", + ast.Permit().When(ast.ExtensionCall("ip", ast.String("127.0.0.1"))), + internalast.Permit().When(internalast.ExtensionCall("ip", internalast.String("127.0.0.1"))), + }, + { + "opEquals", + ast.Permit().When(ast.Long(42).Equals(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).Equals(internalast.Long(43))), + }, + { + "opNotEquals", + ast.Permit().When(ast.Long(42).NotEquals(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).NotEquals(internalast.Long(43))), + }, + { + "opLessThan", + ast.Permit().When(ast.Long(42).LessThan(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).LessThan(internalast.Long(43))), + }, + { + "opLessThanOrEqual", + ast.Permit().When(ast.Long(42).LessThanOrEqual(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).LessThanOrEqual(internalast.Long(43))), + }, + { + "opGreaterThan", + ast.Permit().When(ast.Long(42).GreaterThan(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).GreaterThan(internalast.Long(43))), + }, + { + "opGreaterThanOrEqual", + ast.Permit().When(ast.Long(42).GreaterThanOrEqual(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).GreaterThanOrEqual(internalast.Long(43))), + }, + { + "opLessThanExt", + ast.Permit().When(ast.Long(42).LessThanExt(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).LessThanExt(internalast.Long(43))), + }, + { + "opLessThanOrEqualExt", + ast.Permit().When(ast.Long(42).LessThanOrEqualExt(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).LessThanOrEqualExt(internalast.Long(43))), + }, + { + "opGreaterThanExt", + ast.Permit().When(ast.Long(42).GreaterThanExt(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).GreaterThanExt(internalast.Long(43))), + }, + { + "opGreaterThanOrEqualExt", + ast.Permit().When(ast.Long(42).GreaterThanOrEqualExt(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).GreaterThanOrEqualExt(internalast.Long(43))), + }, + { + "opLike", + ast.Permit().When(ast.Long(42).Like(types.Pattern{})), + internalast.Permit().When(internalast.Long(42).Like(types.Pattern{})), + }, + { + "opAnd", + ast.Permit().When(ast.Long(42).And(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).And(internalast.Long(43))), + }, + { + "opOr", + ast.Permit().When(ast.Long(42).Or(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).Or(internalast.Long(43))), + }, + { + "opNot", + ast.Permit().When(ast.Not(ast.True())), + internalast.Permit().When(internalast.Not(internalast.True())), + }, + { + "opIf", + ast.Permit().When(ast.If(ast.True(), ast.Long(42), ast.Long(43))), + internalast.Permit().When(internalast.If(internalast.True(), internalast.Long(42), internalast.Long(43))), + }, + { + "opPlus", + ast.Permit().When(ast.Long(42).Plus(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).Plus(internalast.Long(43))), + }, + { + "opMinus", + ast.Permit().When(ast.Long(42).Minus(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).Minus(internalast.Long(43))), + }, + { + "opTimes", + ast.Permit().When(ast.Long(42).Times(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).Times(internalast.Long(43))), + }, + { + "opNegate", + ast.Permit().When(ast.Negate(ast.True())), + internalast.Permit().When(internalast.Negate(internalast.True())), + }, + { + "opIn", + ast.Permit().When(ast.Long(42).In(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).In(internalast.Long(43))), + }, + { + "opIs", + ast.Permit().When(ast.Long(42).Is(types.Path("T"))), + internalast.Permit().When(internalast.Long(42).Is(types.Path("T"))), + }, + { + "opIsIn", + ast.Permit().When(ast.Long(42).IsIn(types.Path("T"), ast.Long(43))), + internalast.Permit().When(internalast.Long(42).IsIn(types.Path("T"), internalast.Long(43))), + }, + { + "opContains", + ast.Permit().When(ast.Long(42).Contains(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).Contains(internalast.Long(43))), + }, + { + "opContainsAll", + ast.Permit().When(ast.Long(42).ContainsAll(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).ContainsAll(internalast.Long(43))), + }, + { + "opContainsAny", + ast.Permit().When(ast.Long(42).ContainsAny(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).ContainsAny(internalast.Long(43))), + }, + { + "opAccess", + ast.Permit().When(ast.Long(42).Access("key")), + internalast.Permit().When(internalast.Long(42).Access("key")), + }, + { + "opHas", + ast.Permit().When(ast.Long(42).Has("key")), + internalast.Permit().When(internalast.Long(42).Has("key")), + }, + { + "opIsIpv4", + ast.Permit().When(ast.Long(42).IsIpv4()), + internalast.Permit().When(internalast.Long(42).IsIpv4()), + }, + { + "opIsIpv6", + ast.Permit().When(ast.Long(42).IsIpv6()), + internalast.Permit().When(internalast.Long(42).IsIpv6()), + }, + { + "opIsMulticast", + ast.Permit().When(ast.Long(42).IsMulticast()), + internalast.Permit().When(internalast.Long(42).IsMulticast()), + }, + { + "opIsLoopback", + ast.Permit().When(ast.Long(42).IsLoopback()), + internalast.Permit().When(internalast.Long(42).IsLoopback()), + }, + { + "opIsInRange", + ast.Permit().When(ast.Long(42).IsInRange(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).IsInRange(internalast.Long(43))), + }, } for _, tt := range tests { From eebce6552f32e69102d48f0654b3e86b5343b80d Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:07:45 -0700 Subject: [PATCH 097/216] cedar-go: add JSON marshaling and unmarshaling to the top-level Policy struct Signed-off-by: philhassey --- cedar.go | 58 ++++++++++++++------ cedar_test.go | 111 ++++++++++++++++++++------------------ internal/parser/policy.go | 11 ---- 3 files changed, 99 insertions(+), 81 deletions(-) diff --git a/cedar.go b/cedar.go index a37685e2..97833f4a 100644 --- a/cedar.go +++ b/cedar.go @@ -4,22 +4,21 @@ package cedar import ( "fmt" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/eval" + "github.com/cedar-policy/cedar-go/internal/json" "github.com/cedar-policy/cedar-go/internal/parser" "github.com/cedar-policy/cedar-go/types" ) -// A PolicySet is a slice of policies. -type PolicySet []Policy - -// A Policy is the parsed form of a single Cedar language policy statement. It -// includes the following elements, a Position, Annotations, and an Effect. +// A Policy is the parsed form of a single Cedar language policy statement. type Policy struct { Position Position // location within the policy text document Annotations Annotations // annotations found for this policy Effect Effect // the effect of this policy eval evaler // determines if a policy matches a request. + ast ast.Policy } // A Position describes an arbitrary source position including the file, line, and column location. @@ -34,9 +33,18 @@ type Position struct { // have no impact on policy evaluation. type Annotations map[string]string +// TODO: Is this where we should deal with duplicate keys? +func newAnnotationsFromSlice(annotations []ast.AnnotationType) Annotations { + res := make(map[string]string, len(annotations)) + for _, e := range annotations { + res[string(e.Key)] = string(e.Value) + } + return res +} + // An Effect specifies the intent of the policy, to either permit or forbid any // request that matches the scope and conditions specified in the policy. -type Effect bool +type Effect ast.Effect // Each Policy has a Permit or Forbid effect that is determined during parsing. const ( @@ -44,19 +52,35 @@ const ( Forbid = Effect(false) ) -func (a Effect) String() string { - if a { - return "permit" - } - return "forbid" +// MarshalJSON encodes a single Policy statement in the JSON format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html +func (p *Policy) MarshalJSON() ([]byte, error) { + jsonPolicy := &json.Policy{Policy: p.ast} + return jsonPolicy.MarshalJSON() } -func (a Effect) MarshalJSON() ([]byte, error) { return []byte(`"` + a.String() + `"`), nil } -func (a *Effect) UnmarshalJSON(b []byte) error { - *a = string(b) == `"permit"` +// UnmarshalJSON parses and compiles a single Policy statement in the JSON format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html +func (p *Policy) UnmarshalJSON(b []byte) error { + var jsonPolicy json.Policy + if err := jsonPolicy.UnmarshalJSON(b); err != nil { + return err + } + *p = Policy{ + Position: Position{}, + Annotations: newAnnotationsFromSlice(jsonPolicy.Annotations), + Effect: Effect(jsonPolicy.Effect), + eval: eval.Compile(jsonPolicy.Policy), + ast: jsonPolicy.Policy, + } return nil } +// A PolicySet is a slice of policies. +type PolicySet []Policy + // NewPolicySet will create a PolicySet from the given text document with the // given file name used in Position data. If there is an error parsing the // document, it will be returned. @@ -67,7 +91,6 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { } var policies PolicySet for _, p := range res { - ann := Annotations(p.TmpGetAnnotations()) policies = append(policies, Policy{ Position: Position{ Filename: fileName, @@ -75,9 +98,10 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { Line: p.Position.Line, Column: p.Position.Column, }, - Annotations: ann, - Effect: Effect(p.TmpGetEffect()), + Annotations: newAnnotationsFromSlice(p.Policy.Annotations), + Effect: Effect(p.Policy.Effect), eval: eval.Compile(p.Policy.Policy), + ast: p.Policy.Policy, }) } return policies, nil diff --git a/cedar_test.go b/cedar_test.go index cd4f8fcd..14f0bd01 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -1,6 +1,7 @@ package cedar import ( + "bytes" "encoding/json" "net/netip" "testing" @@ -837,27 +838,6 @@ func TestError(t *testing.T) { testutil.Equals(t, e.String(), "while evaluating policy `policy42`: bad error") } -// func TestInvalidPolicy(t *testing.T) { -// t.Parallel() -// // This case is very fabricated, it can't really happen -// ps := PolicySet{ -// { -// Effect: Forbid, -// eval: newLiteralEval(types.Long(42)), -// }, -// } -// ok, diag := ps.IsAuthorized(Entities{}, Request{}) -// testutil.Equals(t, ok, Deny) -// testutil.Equals(t, diag, Diagnostic{ -// Errors: []Error{ -// { -// Policy: 0, -// Message: "type error: expected bool, got long", -// }, -// }, -// }) -// } - func TestCorpusRelated(t *testing.T) { t.Parallel() tests := []struct { @@ -1021,6 +1001,63 @@ func TestCorpusRelated(t *testing.T) { } } +func prettifyJson(in []byte) []byte { + var buf bytes.Buffer + _ = json.Indent(&buf, in, "", " ") + return buf.Bytes() +} + +func TestPolicyJSON(t *testing.T) { + t.Parallel() + + // Taken from https://docs.cedarpolicy.com/policies/json-format.html + jsonEncodedPolicy := prettifyJson([]byte(` + { + "effect": "permit", + "principal": { + "op": "==", + "entity": { "type": "User", "id": "12UA45" } + }, + "action": { + "op": "==", + "entity": { "type": "Action", "id": "view" } + }, + "resource": { + "op": "in", + "entity": { "type": "Folder", "id": "abc" } + }, + "conditions": [ + { + "kind": "when", + "body": { + "==": { + "left": { + ".": { + "left": { + "Var": "context" + }, + "attr": "tls_version" + } + }, + "right": { + "Value": "1.3" + } + } + } + } + ] + }`, + )) + + var policy Policy + testutil.OK(t, policy.UnmarshalJSON(jsonEncodedPolicy)) + + output, err := policy.MarshalJSON() + testutil.OK(t, err) + + testutil.Equals(t, string(prettifyJson(output)), string(jsonEncodedPolicy)) +} + func TestEntitiesJSON(t *testing.T) { t.Parallel() t.Run("Marshal", func(t *testing.T) { @@ -1061,38 +1098,6 @@ func TestEntitiesJSON(t *testing.T) { }) } -func TestJSONEffect(t *testing.T) { - t.Parallel() - t.Run("MarshalPermit", func(t *testing.T) { - t.Parallel() - e := Permit - b, err := e.MarshalJSON() - testutil.OK(t, err) - testutil.Equals(t, string(b), `"permit"`) - }) - t.Run("MarshalForbid", func(t *testing.T) { - t.Parallel() - e := Forbid - b, err := e.MarshalJSON() - testutil.OK(t, err) - testutil.Equals(t, string(b), `"forbid"`) - }) - t.Run("UnmarshalPermit", func(t *testing.T) { - t.Parallel() - var e Effect - err := json.Unmarshal([]byte(`"permit"`), &e) - testutil.OK(t, err) - testutil.Equals(t, e, Permit) - }) - t.Run("UnmarshalForbid", func(t *testing.T) { - t.Parallel() - var e Effect - err := json.Unmarshal([]byte(`"forbid"`), &e) - testutil.OK(t, err) - testutil.Equals(t, e, Forbid) - }) -} - func TestJSONDecision(t *testing.T) { t.Parallel() t.Run("MarshalAllow", func(t *testing.T) { diff --git a/internal/parser/policy.go b/internal/parser/policy.go index 6a2ae318..0e5c84b5 100644 --- a/internal/parser/policy.go +++ b/internal/parser/policy.go @@ -9,17 +9,6 @@ type PolicySetEntry struct { Position Position } -func (p PolicySetEntry) TmpGetAnnotations() map[string]string { - res := make(map[string]string, len(p.Policy.Annotations)) - for _, e := range p.Policy.Annotations { - res[string(e.Key)] = string(e.Value) - } - return res -} -func (p PolicySetEntry) TmpGetEffect() bool { - return bool(p.Policy.Effect) -} - type Policy struct { ast.Policy } From 114e628d524d0f4f0069cda540b808ac03175c80 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:09:09 -0700 Subject: [PATCH 098/216] cedar-go: move all Policy-related code to policy.go Signed-off-by: philhassey --- cedar.go | 68 --------------------------------------------------- policy.go | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 68 deletions(-) create mode 100644 policy.go diff --git a/cedar.go b/cedar.go index 97833f4a..019064fe 100644 --- a/cedar.go +++ b/cedar.go @@ -4,80 +4,12 @@ package cedar import ( "fmt" - "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/eval" - "github.com/cedar-policy/cedar-go/internal/json" "github.com/cedar-policy/cedar-go/internal/parser" "github.com/cedar-policy/cedar-go/types" ) -// A Policy is the parsed form of a single Cedar language policy statement. -type Policy struct { - Position Position // location within the policy text document - Annotations Annotations // annotations found for this policy - Effect Effect // the effect of this policy - eval evaler // determines if a policy matches a request. - ast ast.Policy -} - -// A Position describes an arbitrary source position including the file, line, and column location. -type Position struct { - Filename string // filename, if any - Offset int // byte offset, starting at 0 - Line int // line number, starting at 1 - Column int // column number, starting at 1 (character count per line) -} - -// An Annotations is a map of key, value pairs found in the policy. Annotations -// have no impact on policy evaluation. -type Annotations map[string]string - -// TODO: Is this where we should deal with duplicate keys? -func newAnnotationsFromSlice(annotations []ast.AnnotationType) Annotations { - res := make(map[string]string, len(annotations)) - for _, e := range annotations { - res[string(e.Key)] = string(e.Value) - } - return res -} - -// An Effect specifies the intent of the policy, to either permit or forbid any -// request that matches the scope and conditions specified in the policy. -type Effect ast.Effect - -// Each Policy has a Permit or Forbid effect that is determined during parsing. -const ( - Permit = Effect(true) - Forbid = Effect(false) -) - -// MarshalJSON encodes a single Policy statement in the JSON format specified by the [Cedar documentation]. -// -// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html -func (p *Policy) MarshalJSON() ([]byte, error) { - jsonPolicy := &json.Policy{Policy: p.ast} - return jsonPolicy.MarshalJSON() -} - -// UnmarshalJSON parses and compiles a single Policy statement in the JSON format specified by the [Cedar documentation]. -// -// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html -func (p *Policy) UnmarshalJSON(b []byte) error { - var jsonPolicy json.Policy - if err := jsonPolicy.UnmarshalJSON(b); err != nil { - return err - } - *p = Policy{ - Position: Position{}, - Annotations: newAnnotationsFromSlice(jsonPolicy.Annotations), - Effect: Effect(jsonPolicy.Effect), - eval: eval.Compile(jsonPolicy.Policy), - ast: jsonPolicy.Policy, - } - return nil -} - // A PolicySet is a slice of policies. type PolicySet []Policy diff --git a/policy.go b/policy.go new file mode 100644 index 00000000..0b3feb06 --- /dev/null +++ b/policy.go @@ -0,0 +1,73 @@ +package cedar + +import ( + "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/eval" + "github.com/cedar-policy/cedar-go/internal/json" +) + +// A Policy is the parsed form of a single Cedar language policy statement. +type Policy struct { + Position Position // location within the policy text document + Annotations Annotations // annotations found for this policy + Effect Effect // the effect of this policy + eval evaler // determines if a policy matches a request. + ast ast.Policy +} + +// A Position describes an arbitrary source position including the file, line, and column location. +type Position struct { + Filename string // filename, if any + Offset int // byte offset, starting at 0 + Line int // line number, starting at 1 + Column int // column number, starting at 1 (character count per line) +} + +// An Annotations is a map of key, value pairs found in the policy. Annotations +// have no impact on policy evaluation. +type Annotations map[string]string + +// TODO: Is this where we should deal with duplicate keys? +func newAnnotationsFromSlice(annotations []ast.AnnotationType) Annotations { + res := make(map[string]string, len(annotations)) + for _, e := range annotations { + res[string(e.Key)] = string(e.Value) + } + return res +} + +// An Effect specifies the intent of the policy, to either permit or forbid any +// request that matches the scope and conditions specified in the policy. +type Effect ast.Effect + +// Each Policy has a Permit or Forbid effect that is determined during parsing. +const ( + Permit = Effect(true) + Forbid = Effect(false) +) + +// MarshalJSON encodes a single Policy statement in the JSON format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html +func (p *Policy) MarshalJSON() ([]byte, error) { + jsonPolicy := &json.Policy{Policy: p.ast} + return jsonPolicy.MarshalJSON() +} + +// UnmarshalJSON parses and compiles a single Policy statement in the JSON format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html +func (p *Policy) UnmarshalJSON(b []byte) error { + var jsonPolicy json.Policy + if err := jsonPolicy.UnmarshalJSON(b); err != nil { + return err + } + *p = Policy{ + Position: Position{}, + Annotations: newAnnotationsFromSlice(jsonPolicy.Annotations), + Effect: Effect(jsonPolicy.Effect), + eval: eval.Compile(jsonPolicy.Policy), + ast: jsonPolicy.Policy, + } + return nil +} From 333c39d8ff6d8d45278893fe129a4cf9927f7836 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:10:38 -0700 Subject: [PATCH 099/216] cedar-go: move authorization code to authorize.go Signed-off-by: philhassey --- authorize.go | 121 ++++++++ authorize_test.go | 723 ++++++++++++++++++++++++++++++++++++++++++++++ cedar.go | 114 -------- cedar_test.go | 714 --------------------------------------------- 4 files changed, 844 insertions(+), 828 deletions(-) create mode 100644 authorize.go create mode 100644 authorize_test.go diff --git a/authorize.go b/authorize.go new file mode 100644 index 00000000..1862c2cf --- /dev/null +++ b/authorize.go @@ -0,0 +1,121 @@ +package cedar + +import ( + "fmt" + + "github.com/cedar-policy/cedar-go/internal/entities" + "github.com/cedar-policy/cedar-go/internal/eval" + "github.com/cedar-policy/cedar-go/types" +) + +// A Decision is the result of the authorization. +type Decision bool + +// Each authorization results in one of these Decisions. +const ( + Allow = Decision(true) + Deny = Decision(false) +) + +func (a Decision) String() string { + if a { + return "allow" + } + return "deny" +} + +func (a Decision) MarshalJSON() ([]byte, error) { return []byte(`"` + a.String() + `"`), nil } + +func (a *Decision) UnmarshalJSON(b []byte) error { + *a = string(b) == `"allow"` + return nil +} + +// A Diagnostic details the errors and reasons for an authorization decision. +type Diagnostic struct { + Reasons []Reason `json:"reasons,omitempty"` + Errors []Error `json:"errors,omitempty"` +} + +// An Error details the Policy index within a PolicySet, the Position within the +// text document, and the resulting error message. +type Error struct { + Policy int `json:"policy"` + Position Position `json:"position"` + Message string `json:"message"` +} + +func (e Error) String() string { + return fmt.Sprintf("while evaluating policy `policy%d`: %v", e.Policy, e.Message) +} + +// A Reason details the Policy index within a PolicySet, and the Position within +// the text document. +type Reason struct { + Policy int `json:"policy"` + Position Position `json:"position"` +} + +// A Request is the Principal, Action, Resource, and Context portion of an +// authorization request. +type Request struct { + Principal types.EntityUID `json:"principal"` + Action types.EntityUID `json:"action"` + Resource types.EntityUID `json:"resource"` + Context types.Record `json:"context"` +} + +type evalContext = eval.Context + +type evaler = eval.Evaler + +// IsAuthorized uses the combination of the PolicySet and Entities to determine +// if the given Request to determine Decision and Diagnostic. +func (p PolicySet) IsAuthorized(entityMap entities.Entities, req Request) (Decision, Diagnostic) { + c := &evalContext{ + Entities: entityMap, + Principal: req.Principal, + Action: req.Action, + Resource: req.Resource, + Context: req.Context, + } + var diag Diagnostic + var gotForbid bool + var forbidReasons []Reason + var gotPermit bool + var permitReasons []Reason + // Don't try to short circuit this. + // - Even though single forbid means forbid + // - All policy should be run to collect errors + // - For permit, all permits must be run to collect annotations + // - For forbid, forbids must be run to collect annotations + for n, po := range p { + v, err := po.eval.Eval(c) + if err != nil { + diag.Errors = append(diag.Errors, Error{Policy: n, Position: po.Position, Message: err.Error()}) + continue + } + vb, err := types.ValueToBool(v) + if err != nil { + // should never happen, maybe remove this case + diag.Errors = append(diag.Errors, Error{Policy: n, Position: po.Position, Message: err.Error()}) + continue + } + if !vb { + continue + } + if po.Effect == Forbid { + forbidReasons = append(forbidReasons, Reason{Policy: n, Position: po.Position}) + gotForbid = true + } else { + permitReasons = append(permitReasons, Reason{Policy: n, Position: po.Position}) + gotPermit = true + } + } + if gotForbid { + diag.Reasons = forbidReasons + } else if gotPermit { + diag.Reasons = permitReasons + } + return Decision(gotPermit && !gotForbid), diag +} diff --git a/authorize_test.go b/authorize_test.go new file mode 100644 index 00000000..4ea47287 --- /dev/null +++ b/authorize_test.go @@ -0,0 +1,723 @@ +package cedar + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/entities" + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +//nolint:revive // due to table test function-length +func TestIsAuthorized(t *testing.T) { + t.Parallel() + cuzco := types.NewEntityUID("coder", "cuzco") + dropTable := types.NewEntityUID("table", "drop") + tests := []struct { + Name string + Policy string + Entities entities.Entities + Principal, Action, Resource types.EntityUID + Context types.Record + Want Decision + DiagErr int + ParseErr bool + }{ + { + Name: "simple-permit", + Policy: `permit(principal,action,resource);`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "simple-forbid", + Policy: `forbid(principal,action,resource);`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 0, + }, + { + Name: "no-permit", + Policy: `permit(principal,action,resource in asdf::"1234");`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 0, + }, + { + Name: "error-in-policy", + Policy: `permit(principal,action,resource) when { resource in "foo" };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 1, + }, + { + Name: "error-in-policy-continues", + Policy: `permit(principal,action,resource) when { resource in "foo" }; + permit(principal,action,resource); + `, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 1, + }, + { + Name: "permit-requires-context-success", + Policy: `permit(principal,action,resource) when { context.x == 42 };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{"x": types.Long(42)}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-requires-context-fail", + Policy: `permit(principal,action,resource) when { context.x == 42 };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{"x": types.Long(43)}, + Want: false, + DiagErr: 0, + }, + { + Name: "permit-requires-entities-success", + Policy: `permit(principal,action,resource) when { principal.x == 42 };`, + Entities: entities.Entities{ + cuzco: entities.Entity{ + UID: cuzco, + Attributes: types.Record{"x": types.Long(42)}, + }, + }, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-requires-entities-fail", + Policy: `permit(principal,action,resource) when { principal.x == 42 };`, + Entities: entities.Entities{ + cuzco: entities.Entity{ + UID: cuzco, + Attributes: types.Record{"x": types.Long(43)}, + }, + }, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 0, + }, + { + Name: "permit-requires-entities-parent-success", + Policy: `permit(principal,action,resource) when { principal in parent::"bob" };`, + Entities: entities.Entities{ + cuzco: entities.Entity{ + UID: cuzco, + Parents: []types.EntityUID{types.NewEntityUID("parent", "bob")}, + }, + }, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-principal-equals", + Policy: `permit(principal == coder::"cuzco",action,resource);`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-principal-in", + Policy: `permit(principal in team::"osiris",action,resource);`, + Entities: entities.Entities{ + cuzco: entities.Entity{ + UID: cuzco, + Parents: []types.EntityUID{types.NewEntityUID("team", "osiris")}, + }, + }, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-action-equals", + Policy: `permit(principal,action == table::"drop",resource);`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-action-in", + Policy: `permit(principal,action in scary::"stuff",resource);`, + Entities: entities.Entities{ + dropTable: entities.Entity{ + UID: dropTable, + Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, + }, + }, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-action-in-set", + Policy: `permit(principal,action in [scary::"stuff"],resource);`, + Entities: entities.Entities{ + dropTable: entities.Entity{ + UID: dropTable, + Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, + }, + }, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-resource-equals", + Policy: `permit(principal,action,resource == table::"whatever");`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-unless", + Policy: `permit(principal,action,resource) unless { false };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-if", + Policy: `permit(principal,action,resource) when { (if true then true else true) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-or", + Policy: `permit(principal,action,resource) when { (true || false) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-and", + Policy: `permit(principal,action,resource) when { (true && true) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-relations", + Policy: `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-relations-in", + Policy: `permit(principal,action,resource) when { principal in principal };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-relations-has", + Policy: `permit(principal,action,resource) when { principal has name };`, + Entities: entities.Entities{ + cuzco: entities.Entity{ + UID: cuzco, + Attributes: types.Record{"name": types.String("bob")}, + }, + }, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-add-sub", + Policy: `permit(principal,action,resource) when { 40+3-1==42 };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-mul", + Policy: `permit(principal,action,resource) when { 6*7==42 };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-negate", + Policy: `permit(principal,action,resource) when { -42==-42 };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-not", + Policy: `permit(principal,action,resource) when { !(1+1==42) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-set", + Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-record", + Policy: `permit(principal,action,resource) when { {name:"bob"} has name };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-action", + Policy: `permit(principal,action,resource) when { action in action };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-set-contains-ok", + Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-set-contains-error", + Policy: `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 0, + ParseErr: true, + }, + { + Name: "permit-when-set-containsAll-ok", + Policy: `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-set-containsAll-error", + Policy: `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 0, + ParseErr: true, + }, + { + Name: "permit-when-set-containsAny-ok", + Policy: `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-set-containsAny-error", + Policy: `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 0, + ParseErr: true, + }, + { + Name: "permit-when-record-attr", + Policy: `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-unknown-method", + Policy: `permit(principal,action,resource) when { [1,2,3].shuffle() };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 1, + ParseErr: false, + }, + { + Name: "permit-when-like", + Policy: `permit(principal,action,resource) when { "bananas" like "*nan*" };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-unknown-ext-fun", + Policy: `permit(principal,action,resource) when { fooBar("10") };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 1, + ParseErr: false, + }, + { + Name: "permit-when-decimal", + Policy: `permit(principal,action,resource) when { + decimal("10.0").lessThan(decimal("11.0")) && + decimal("10.0").lessThanOrEqual(decimal("11.0")) && + decimal("10.0").greaterThan(decimal("9.0")) && + decimal("10.0").greaterThanOrEqual(decimal("9.0")) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-decimal-fun-wrong-arity", + Policy: `permit(principal,action,resource) when { decimal(1, 2) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 1, + }, + { + Name: "permit-when-ip", + Policy: `permit(principal,action,resource) when { + ip("1.2.3.4").isIpv4() && + ip("a:b:c:d::/16").isIpv6() && + ip("::1").isLoopback() && + ip("224.1.2.3").isMulticast() && + ip("127.0.0.1").isInRange(ip("127.0.0.0/16"))};`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "permit-when-ip-fun-wrong-arity", + Policy: `permit(principal,action,resource) when { ip() };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 1, + }, + { + Name: "permit-when-isIpv4-wrong-arity", + Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 1, + }, + { + Name: "permit-when-isIpv6-wrong-arity", + Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 1, + }, + { + Name: "permit-when-isLoopback-wrong-arity", + Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 1, + }, + { + Name: "permit-when-isMulticast-wrong-arity", + Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 1, + }, + { + Name: "permit-when-isInRange-wrong-arity", + Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, + Entities: entities.Entities{}, + Principal: cuzco, + Action: dropTable, + Resource: types.NewEntityUID("table", "whatever"), + Context: types.Record{}, + Want: false, + DiagErr: 1, + }, + { + Name: "negative-unary-op", + Policy: `permit(principal,action,resource) when { -context.value > 0 };`, + Entities: entities.Entities{}, + Context: types.Record{"value": types.Long(-42)}, + Want: true, + DiagErr: 0, + }, + { + Name: "principal-is", + Policy: `permit(principal is Actor,action,resource);`, + Entities: entities.Entities{}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "principal-is-in", + Policy: `permit(principal is Actor in Actor::"cuzco",action,resource);`, + Entities: entities.Entities{}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "resource-is", + Policy: `permit(principal,action,resource is Resource);`, + Entities: entities.Entities{}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "resource-is-in", + Policy: `permit(principal,action,resource is Resource in Resource::"table");`, + Entities: entities.Entities{}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "when-is", + Policy: `permit(principal,action,resource) when { resource is Resource };`, + Entities: entities.Entities{}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "when-is-in", + Policy: `permit(principal,action,resource) when { resource is Resource in Resource::"table" };`, + Entities: entities.Entities{}, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + { + Name: "when-is-in", + Policy: `permit(principal,action,resource) when { resource is Resource in Parent::"id" };`, + Entities: entities.Entities{ + types.NewEntityUID("Resource", "table"): entities.Entity{ + UID: types.NewEntityUID("Resource", "table"), + Parents: []types.EntityUID{types.NewEntityUID("Parent", "id")}, + }, + }, + Principal: types.NewEntityUID("Actor", "cuzco"), + Action: types.NewEntityUID("Action", "drop"), + Resource: types.NewEntityUID("Resource", "table"), + Context: types.Record{}, + Want: true, + DiagErr: 0, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.Name, func(t *testing.T) { + t.Parallel() + ps, err := NewPolicySet("policy.cedar", []byte(tt.Policy)) + testutil.Equals(t, (err != nil), tt.ParseErr) + ok, diag := ps.IsAuthorized(tt.Entities, Request{ + Principal: tt.Principal, + Action: tt.Action, + Resource: tt.Resource, + Context: tt.Context, + }) + testutil.Equals(t, len(diag.Errors), tt.DiagErr) + testutil.Equals(t, ok, tt.Want) + }) + } +} diff --git a/cedar.go b/cedar.go index 019064fe..ec7e269a 100644 --- a/cedar.go +++ b/cedar.go @@ -4,10 +4,8 @@ package cedar import ( "fmt" - "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/eval" "github.com/cedar-policy/cedar-go/internal/parser" - "github.com/cedar-policy/cedar-go/types" ) // A PolicySet is a slice of policies. @@ -38,115 +36,3 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { } return policies, nil } - -// A Decision is the result of the authorization. -type Decision bool - -// Each authorization results in one of these Decisions. -const ( - Allow = Decision(true) - Deny = Decision(false) -) - -func (a Decision) String() string { - if a { - return "allow" - } - return "deny" -} - -func (a Decision) MarshalJSON() ([]byte, error) { return []byte(`"` + a.String() + `"`), nil } - -func (a *Decision) UnmarshalJSON(b []byte) error { - *a = string(b) == `"allow"` - return nil -} - -// A Diagnostic details the errors and reasons for an authorization decision. -type Diagnostic struct { - Reasons []Reason `json:"reasons,omitempty"` - Errors []Error `json:"errors,omitempty"` -} - -// An Error details the Policy index within a PolicySet, the Position within the -// text document, and the resulting error message. -type Error struct { - Policy int `json:"policy"` - Position Position `json:"position"` - Message string `json:"message"` -} - -func (e Error) String() string { - return fmt.Sprintf("while evaluating policy `policy%d`: %v", e.Policy, e.Message) -} - -// A Reason details the Policy index within a PolicySet, and the Position within -// the text document. -type Reason struct { - Policy int `json:"policy"` - Position Position `json:"position"` -} - -// A Request is the Principal, Action, Resource, and Context portion of an -// authorization request. -type Request struct { - Principal types.EntityUID `json:"principal"` - Action types.EntityUID `json:"action"` - Resource types.EntityUID `json:"resource"` - Context types.Record `json:"context"` -} - -type evalContext = eval.Context - -type evaler = eval.Evaler - -// IsAuthorized uses the combination of the PolicySet and Entities to determine -// if the given Request to determine Decision and Diagnostic. -func (p PolicySet) IsAuthorized(entityMap entities.Entities, req Request) (Decision, Diagnostic) { - c := &evalContext{ - Entities: entityMap, - Principal: req.Principal, - Action: req.Action, - Resource: req.Resource, - Context: req.Context, - } - var diag Diagnostic - var gotForbid bool - var forbidReasons []Reason - var gotPermit bool - var permitReasons []Reason - // Don't try to short circuit this. - // - Even though single forbid means forbid - // - All policy should be run to collect errors - // - For permit, all permits must be run to collect annotations - // - For forbid, forbids must be run to collect annotations - for n, po := range p { - v, err := po.eval.Eval(c) - if err != nil { - diag.Errors = append(diag.Errors, Error{Policy: n, Position: po.Position, Message: err.Error()}) - continue - } - vb, err := types.ValueToBool(v) - if err != nil { - // should never happen, maybe remove this case - diag.Errors = append(diag.Errors, Error{Policy: n, Position: po.Position, Message: err.Error()}) - continue - } - if !vb { - continue - } - if po.Effect == Forbid { - forbidReasons = append(forbidReasons, Reason{Policy: n, Position: po.Position}) - gotForbid = true - } else { - permitReasons = append(permitReasons, Reason{Policy: n, Position: po.Position}) - gotPermit = true - } - } - if gotForbid { - diag.Reasons = forbidReasons - } else if gotPermit { - diag.Reasons = permitReasons - } - return Decision(gotPermit && !gotForbid), diag -} diff --git a/cedar_test.go b/cedar_test.go index 14f0bd01..20f19775 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -52,720 +52,6 @@ func TestNewPolicySet(t *testing.T) { }) } -//nolint:revive // due to table test function-length -func TestIsAuthorized(t *testing.T) { - t.Parallel() - cuzco := types.NewEntityUID("coder", "cuzco") - dropTable := types.NewEntityUID("table", "drop") - tests := []struct { - Name string - Policy string - Entities entities.Entities - Principal, Action, Resource types.EntityUID - Context types.Record - Want Decision - DiagErr int - ParseErr bool - }{ - { - Name: "simple-permit", - Policy: `permit(principal,action,resource);`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "simple-forbid", - Policy: `forbid(principal,action,resource);`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 0, - }, - { - Name: "no-permit", - Policy: `permit(principal,action,resource in asdf::"1234");`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 0, - }, - { - Name: "error-in-policy", - Policy: `permit(principal,action,resource) when { resource in "foo" };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 1, - }, - { - Name: "error-in-policy-continues", - Policy: `permit(principal,action,resource) when { resource in "foo" }; - permit(principal,action,resource); - `, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 1, - }, - { - Name: "permit-requires-context-success", - Policy: `permit(principal,action,resource) when { context.x == 42 };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{"x": types.Long(42)}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-requires-context-fail", - Policy: `permit(principal,action,resource) when { context.x == 42 };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{"x": types.Long(43)}, - Want: false, - DiagErr: 0, - }, - { - Name: "permit-requires-entities-success", - Policy: `permit(principal,action,resource) when { principal.x == 42 };`, - Entities: entities.Entities{ - cuzco: entities.Entity{ - UID: cuzco, - Attributes: types.Record{"x": types.Long(42)}, - }, - }, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-requires-entities-fail", - Policy: `permit(principal,action,resource) when { principal.x == 42 };`, - Entities: entities.Entities{ - cuzco: entities.Entity{ - UID: cuzco, - Attributes: types.Record{"x": types.Long(43)}, - }, - }, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 0, - }, - { - Name: "permit-requires-entities-parent-success", - Policy: `permit(principal,action,resource) when { principal in parent::"bob" };`, - Entities: entities.Entities{ - cuzco: entities.Entity{ - UID: cuzco, - Parents: []types.EntityUID{types.NewEntityUID("parent", "bob")}, - }, - }, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-principal-equals", - Policy: `permit(principal == coder::"cuzco",action,resource);`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-principal-in", - Policy: `permit(principal in team::"osiris",action,resource);`, - Entities: entities.Entities{ - cuzco: entities.Entity{ - UID: cuzco, - Parents: []types.EntityUID{types.NewEntityUID("team", "osiris")}, - }, - }, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-action-equals", - Policy: `permit(principal,action == table::"drop",resource);`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-action-in", - Policy: `permit(principal,action in scary::"stuff",resource);`, - Entities: entities.Entities{ - dropTable: entities.Entity{ - UID: dropTable, - Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, - }, - }, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-action-in-set", - Policy: `permit(principal,action in [scary::"stuff"],resource);`, - Entities: entities.Entities{ - dropTable: entities.Entity{ - UID: dropTable, - Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, - }, - }, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-resource-equals", - Policy: `permit(principal,action,resource == table::"whatever");`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-unless", - Policy: `permit(principal,action,resource) unless { false };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-if", - Policy: `permit(principal,action,resource) when { (if true then true else true) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-or", - Policy: `permit(principal,action,resource) when { (true || false) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-and", - Policy: `permit(principal,action,resource) when { (true && true) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-relations", - Policy: `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-relations-in", - Policy: `permit(principal,action,resource) when { principal in principal };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-relations-has", - Policy: `permit(principal,action,resource) when { principal has name };`, - Entities: entities.Entities{ - cuzco: entities.Entity{ - UID: cuzco, - Attributes: types.Record{"name": types.String("bob")}, - }, - }, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-add-sub", - Policy: `permit(principal,action,resource) when { 40+3-1==42 };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-mul", - Policy: `permit(principal,action,resource) when { 6*7==42 };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-negate", - Policy: `permit(principal,action,resource) when { -42==-42 };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-not", - Policy: `permit(principal,action,resource) when { !(1+1==42) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-set", - Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-record", - Policy: `permit(principal,action,resource) when { {name:"bob"} has name };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-action", - Policy: `permit(principal,action,resource) when { action in action };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-set-contains-ok", - Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-set-contains-error", - Policy: `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 0, - ParseErr: true, - }, - { - Name: "permit-when-set-containsAll-ok", - Policy: `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-set-containsAll-error", - Policy: `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 0, - ParseErr: true, - }, - { - Name: "permit-when-set-containsAny-ok", - Policy: `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-set-containsAny-error", - Policy: `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 0, - ParseErr: true, - }, - { - Name: "permit-when-record-attr", - Policy: `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-unknown-method", - Policy: `permit(principal,action,resource) when { [1,2,3].shuffle() };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 1, - ParseErr: false, - }, - { - Name: "permit-when-like", - Policy: `permit(principal,action,resource) when { "bananas" like "*nan*" };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-unknown-ext-fun", - Policy: `permit(principal,action,resource) when { fooBar("10") };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 1, - ParseErr: false, - }, - { - Name: "permit-when-decimal", - Policy: `permit(principal,action,resource) when { - decimal("10.0").lessThan(decimal("11.0")) && - decimal("10.0").lessThanOrEqual(decimal("11.0")) && - decimal("10.0").greaterThan(decimal("9.0")) && - decimal("10.0").greaterThanOrEqual(decimal("9.0")) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-decimal-fun-wrong-arity", - Policy: `permit(principal,action,resource) when { decimal(1, 2) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 1, - }, - { - Name: "permit-when-ip", - Policy: `permit(principal,action,resource) when { - ip("1.2.3.4").isIpv4() && - ip("a:b:c:d::/16").isIpv6() && - ip("::1").isLoopback() && - ip("224.1.2.3").isMulticast() && - ip("127.0.0.1").isInRange(ip("127.0.0.0/16"))};`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "permit-when-ip-fun-wrong-arity", - Policy: `permit(principal,action,resource) when { ip() };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 1, - }, - { - Name: "permit-when-isIpv4-wrong-arity", - Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 1, - }, - { - Name: "permit-when-isIpv6-wrong-arity", - Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 1, - }, - { - Name: "permit-when-isLoopback-wrong-arity", - Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 1, - }, - { - Name: "permit-when-isMulticast-wrong-arity", - Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 1, - }, - { - Name: "permit-when-isInRange-wrong-arity", - Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, - Entities: entities.Entities{}, - Principal: cuzco, - Action: dropTable, - Resource: types.NewEntityUID("table", "whatever"), - Context: types.Record{}, - Want: false, - DiagErr: 1, - }, - { - Name: "negative-unary-op", - Policy: `permit(principal,action,resource) when { -context.value > 0 };`, - Entities: entities.Entities{}, - Context: types.Record{"value": types.Long(-42)}, - Want: true, - DiagErr: 0, - }, - { - Name: "principal-is", - Policy: `permit(principal is Actor,action,resource);`, - Entities: entities.Entities{}, - Principal: types.NewEntityUID("Actor", "cuzco"), - Action: types.NewEntityUID("Action", "drop"), - Resource: types.NewEntityUID("Resource", "table"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "principal-is-in", - Policy: `permit(principal is Actor in Actor::"cuzco",action,resource);`, - Entities: entities.Entities{}, - Principal: types.NewEntityUID("Actor", "cuzco"), - Action: types.NewEntityUID("Action", "drop"), - Resource: types.NewEntityUID("Resource", "table"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "resource-is", - Policy: `permit(principal,action,resource is Resource);`, - Entities: entities.Entities{}, - Principal: types.NewEntityUID("Actor", "cuzco"), - Action: types.NewEntityUID("Action", "drop"), - Resource: types.NewEntityUID("Resource", "table"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "resource-is-in", - Policy: `permit(principal,action,resource is Resource in Resource::"table");`, - Entities: entities.Entities{}, - Principal: types.NewEntityUID("Actor", "cuzco"), - Action: types.NewEntityUID("Action", "drop"), - Resource: types.NewEntityUID("Resource", "table"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "when-is", - Policy: `permit(principal,action,resource) when { resource is Resource };`, - Entities: entities.Entities{}, - Principal: types.NewEntityUID("Actor", "cuzco"), - Action: types.NewEntityUID("Action", "drop"), - Resource: types.NewEntityUID("Resource", "table"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "when-is-in", - Policy: `permit(principal,action,resource) when { resource is Resource in Resource::"table" };`, - Entities: entities.Entities{}, - Principal: types.NewEntityUID("Actor", "cuzco"), - Action: types.NewEntityUID("Action", "drop"), - Resource: types.NewEntityUID("Resource", "table"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - { - Name: "when-is-in", - Policy: `permit(principal,action,resource) when { resource is Resource in Parent::"id" };`, - Entities: entities.Entities{ - types.NewEntityUID("Resource", "table"): entities.Entity{ - UID: types.NewEntityUID("Resource", "table"), - Parents: []types.EntityUID{types.NewEntityUID("Parent", "id")}, - }, - }, - Principal: types.NewEntityUID("Actor", "cuzco"), - Action: types.NewEntityUID("Action", "drop"), - Resource: types.NewEntityUID("Resource", "table"), - Context: types.Record{}, - Want: true, - DiagErr: 0, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.Name, func(t *testing.T) { - t.Parallel() - ps, err := NewPolicySet("policy.cedar", []byte(tt.Policy)) - testutil.Equals(t, (err != nil), tt.ParseErr) - ok, diag := ps.IsAuthorized(tt.Entities, Request{ - Principal: tt.Principal, - Action: tt.Action, - Resource: tt.Resource, - Context: tt.Context, - }) - testutil.Equals(t, len(diag.Errors), tt.DiagErr) - testutil.Equals(t, ok, tt.Want) - }) - } -} - func TestEntities(t *testing.T) { t.Parallel() t.Run("Clone", func(t *testing.T) { From 86751e2a075b12f5ce51cfc37dcc70cea0118611 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:13:28 -0700 Subject: [PATCH 100/216] cedar-go: move Policy-related tests to policy_test.go Signed-off-by: philhassey --- cedar_test.go | 58 -------------------------------------------- policy_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 58 deletions(-) create mode 100644 policy_test.go diff --git a/cedar_test.go b/cedar_test.go index 20f19775..74735e95 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -1,7 +1,6 @@ package cedar import ( - "bytes" "encoding/json" "net/netip" "testing" @@ -287,63 +286,6 @@ func TestCorpusRelated(t *testing.T) { } } -func prettifyJson(in []byte) []byte { - var buf bytes.Buffer - _ = json.Indent(&buf, in, "", " ") - return buf.Bytes() -} - -func TestPolicyJSON(t *testing.T) { - t.Parallel() - - // Taken from https://docs.cedarpolicy.com/policies/json-format.html - jsonEncodedPolicy := prettifyJson([]byte(` - { - "effect": "permit", - "principal": { - "op": "==", - "entity": { "type": "User", "id": "12UA45" } - }, - "action": { - "op": "==", - "entity": { "type": "Action", "id": "view" } - }, - "resource": { - "op": "in", - "entity": { "type": "Folder", "id": "abc" } - }, - "conditions": [ - { - "kind": "when", - "body": { - "==": { - "left": { - ".": { - "left": { - "Var": "context" - }, - "attr": "tls_version" - } - }, - "right": { - "Value": "1.3" - } - } - } - } - ] - }`, - )) - - var policy Policy - testutil.OK(t, policy.UnmarshalJSON(jsonEncodedPolicy)) - - output, err := policy.MarshalJSON() - testutil.OK(t, err) - - testutil.Equals(t, string(prettifyJson(output)), string(jsonEncodedPolicy)) -} - func TestEntitiesJSON(t *testing.T) { t.Parallel() t.Run("Marshal", func(t *testing.T) { diff --git a/policy_test.go b/policy_test.go new file mode 100644 index 00000000..072f08d4 --- /dev/null +++ b/policy_test.go @@ -0,0 +1,66 @@ +package cedar + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" +) + +func prettifyJson(in []byte) []byte { + var buf bytes.Buffer + _ = json.Indent(&buf, in, "", " ") + return buf.Bytes() +} + +func TestPolicyJSON(t *testing.T) { + t.Parallel() + + // Taken from https://docs.cedarpolicy.com/policies/json-format.html + jsonEncodedPolicy := prettifyJson([]byte(` + { + "effect": "permit", + "principal": { + "op": "==", + "entity": { "type": "User", "id": "12UA45" } + }, + "action": { + "op": "==", + "entity": { "type": "Action", "id": "view" } + }, + "resource": { + "op": "in", + "entity": { "type": "Folder", "id": "abc" } + }, + "conditions": [ + { + "kind": "when", + "body": { + "==": { + "left": { + ".": { + "left": { + "Var": "context" + }, + "attr": "tls_version" + } + }, + "right": { + "Value": "1.3" + } + } + } + } + ] + }`, + )) + + var policy Policy + testutil.OK(t, policy.UnmarshalJSON(jsonEncodedPolicy)) + + output, err := policy.MarshalJSON() + testutil.OK(t, err) + + testutil.Equals(t, string(prettifyJson(output)), string(jsonEncodedPolicy)) +} From be3a815a14e0e8f9b362f63c26ee98e92459318c Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:14:45 -0700 Subject: [PATCH 101/216] cedar-go: remove superfluous test Signed-off-by: philhassey --- cedar_test.go | 51 --------------------------------------------------- 1 file changed, 51 deletions(-) diff --git a/cedar_test.go b/cedar_test.go index 74735e95..0a3c6017 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -2,7 +2,6 @@ package cedar import ( "encoding/json" - "net/netip" "testing" "github.com/cedar-policy/cedar-go/internal/entities" @@ -67,56 +66,6 @@ func TestEntities(t *testing.T) { } -func TestValueFrom(t *testing.T) { - t.Parallel() - tests := []struct { - name string - in types.Value - outJSON string - }{ - { - name: "string", - in: types.String("hello"), - outJSON: `"hello"`, - }, - { - name: "bool", - in: types.Boolean(true), - outJSON: `true`, - }, - { - name: "int64", - in: types.Long(42), - outJSON: `42`, - }, - { - name: "int64", - in: types.EntityUID{Type: "T", ID: "0"}, - outJSON: `{"__entity":{"type":"T","id":"0"}}`, - }, - { - name: "record", - in: types.Record{"K": types.Boolean(true)}, - outJSON: `{"K":true}`, - }, - { - name: "netipPrefix", - in: types.IPAddr(netip.MustParsePrefix("192.168.0.42/32")), - outJSON: `{"__extn":{"fn":"ip","arg":"192.168.0.42"}}`, - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - out, err := tt.in.ExplicitMarshalJSON() - testutil.OK(t, err) - testutil.Equals(t, string(out), tt.outJSON) - }) - } -} - func TestError(t *testing.T) { t.Parallel() e := Error{Policy: 42, Message: "bad error"} From ec792965fdebc0ff9d4eb7eab497fd3d64e71ee2 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:16:01 -0700 Subject: [PATCH 102/216] cedar-go: move TestEntities into value_test.go Signed-off-by: philhassey --- cedar_test.go | 16 ---------------- types/value_test.go | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/cedar_test.go b/cedar_test.go index 0a3c6017..4725c2d3 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -50,22 +50,6 @@ func TestNewPolicySet(t *testing.T) { }) } -func TestEntities(t *testing.T) { - t.Parallel() - t.Run("Clone", func(t *testing.T) { - t.Parallel() - e := entities.Entities{ - types.EntityUID{Type: "A", ID: "A"}: {}, - types.EntityUID{Type: "A", ID: "B"}: {}, - types.EntityUID{Type: "B", ID: "A"}: {}, - types.EntityUID{Type: "B", ID: "B"}: {}, - } - clone := e.Clone() - testutil.Equals(t, clone, e) - }) - -} - func TestError(t *testing.T) { t.Parallel() e := Error{Policy: 42, Message: "bad error"} diff --git a/types/value_test.go b/types/value_test.go index 9cf2ffbb..51e03048 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/testutil" ) @@ -877,3 +878,19 @@ func TestPath(t *testing.T) { }) } + +func TestEntities(t *testing.T) { + t.Parallel() + t.Run("Clone", func(t *testing.T) { + t.Parallel() + e := entities.Entities{ + EntityUID{Type: "A", ID: "A"}: {}, + EntityUID{Type: "A", ID: "B"}: {}, + EntityUID{Type: "B", ID: "A"}: {}, + EntityUID{Type: "B", ID: "B"}: {}, + } + clone := e.Clone() + testutil.Equals(t, clone, e) + }) + +} From 681676a21a11db1e93a300f170b018913e8d6fdb Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:16:47 -0700 Subject: [PATCH 103/216] cedar-go: move error test to authorize_test.go Signed-off-by: philhassey --- authorize_test.go | 6 ++++++ cedar_test.go | 6 ------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/authorize_test.go b/authorize_test.go index 4ea47287..1bdd2f9c 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -721,3 +721,9 @@ func TestIsAuthorized(t *testing.T) { }) } } + +func TestError(t *testing.T) { + t.Parallel() + e := Error{Policy: 42, Message: "bad error"} + testutil.Equals(t, e.String(), "while evaluating policy `policy42`: bad error") +} diff --git a/cedar_test.go b/cedar_test.go index 4725c2d3..c50802b4 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -50,12 +50,6 @@ func TestNewPolicySet(t *testing.T) { }) } -func TestError(t *testing.T) { - t.Parallel() - e := Error{Policy: 42, Message: "bad error"} - testutil.Equals(t, e.String(), "while evaluating policy `policy42`: bad error") -} - func TestCorpusRelated(t *testing.T) { t.Parallel() tests := []struct { From 59b6802ca629a157f64f1d0e733f9fb18aaa1fc4 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:17:47 -0700 Subject: [PATCH 104/216] cedar-go: move TestCorpusRelated into corpus_test.go Signed-off-by: philhassey --- cedar_test.go | 163 ------------------------------------------------ corpus_test.go | 165 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 163 deletions(-) diff --git a/cedar_test.go b/cedar_test.go index c50802b4..a837e62e 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -50,169 +50,6 @@ func TestNewPolicySet(t *testing.T) { }) } -func TestCorpusRelated(t *testing.T) { - t.Parallel() - tests := []struct { - name string - policy string - request Request - decision Decision - reasons []int - errors []int - }{ - { - "0cb1ad7042508e708f1999284b634ed0f334bc00", - `forbid( - principal in a::"\0\0", - action == Action::"action", - resource - ) when { - (true && (((!870985681610) == principal) == principal)) && principal - };`, - Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, - Deny, - nil, - []int{0}, - }, - - { - "0cb1ad7042508e708f1999284b634ed0f334bc00/partial1", - `forbid( - principal in a::"\0\0", - action == Action::"action", - resource - ) when { - (((!870985681610) == principal) == principal) - };`, - Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, - Deny, - nil, - []int{0}, - }, - { - "0cb1ad7042508e708f1999284b634ed0f334bc00/partial2", - `forbid( - principal in a::"\0\0", - action == Action::"action", - resource - ) when { - ((!870985681610) == principal) - };`, - Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, - Deny, - nil, - []int{0}, - }, - - { - "0cb1ad7042508e708f1999284b634ed0f334bc00/partial3", - `forbid( - principal in a::"\0\0", - action == Action::"action", - resource - ) when { - (!870985681610) - };`, - Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, - Deny, - nil, - []int{0}, - }, - - { - "0cb1ad7042508e708f1999284b634ed0f334bc00/partial2/simplified", - `forbid( - principal, - action, - resource - ) when { - ((!42) == principal) - };`, - Request{}, - Deny, - nil, - []int{0}, - }, - - { - "0cb1ad7042508e708f1999284b634ed0f334bc00/partial2/simplified2", - `forbid( - principal, - action, - resource - ) when { - (!42 == principal) - };`, - Request{}, - Deny, - nil, - []int{0}, - }, - - {"48d0ba6537a3efe02112ba0f5a3daabdcad27b04", - `forbid( - principal, - action in [Action::"action"], - resource is a in a::"\0\u{8}\u{11}\0R" - ) when { - true && ((if (principal in action) then (ip("")) else (if true then (ip("6b6b:f00::32ff:ffff:6368/00")) else (ip("7265:6c69:706d:6f43:5f74:6f70:7374:6f68")))).isMulticast()) - };`, - Request{Principal: types.NewEntityUID("a", "\u0000\b\u0011\u0000R"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\b\u0011\u0000R")}, - Deny, - nil, - []int{0}, - }, - - {"48d0ba6537a3efe02112ba0f5a3daabdcad27b04/simplified", - `forbid( - principal, - action, - resource - ) when { - true && ip("6b6b:f00::32ff:ffff:6368/00").isMulticast() - };`, - Request{}, - Deny, - nil, - []int{0}, - }, - - {name: "e91da4e6af5c73e27f5fb610d723dfa21635d10b", - policy: `forbid( - principal is a in a::"\0\0(W\0\0\0", - action, - resource - ) when { - true && (([ip("c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5/68")].containsAll([ip("c5c5:c5c5:c5c5:c5c5:c5c5:5cc5:c5c5:c5c5/68")])) || ((ip("")) == (ip("")))) - };`, - request: Request{Principal: types.NewEntityUID("a", "\u0000\u0000(W\u0000\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "")}, - decision: Deny, - reasons: nil, - errors: []int{0}, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - policy, err := NewPolicySet("", []byte(tt.policy)) - testutil.OK(t, err) - ok, diag := policy.IsAuthorized(entities.Entities{}, tt.request) - testutil.Equals(t, ok, tt.decision) - var reasons []int - for _, n := range diag.Reasons { - reasons = append(reasons, n.Policy) - } - testutil.Equals(t, reasons, tt.reasons) - var errors []int - for _, n := range diag.Errors { - errors = append(errors, n.Policy) - } - testutil.Equals(t, errors, tt.errors) - }) - } -} - func TestEntitiesJSON(t *testing.T) { t.Parallel() t.Run("Marshal", func(t *testing.T) { diff --git a/corpus_test.go b/corpus_test.go index b544570e..7edf9712 100644 --- a/corpus_test.go +++ b/corpus_test.go @@ -13,6 +13,7 @@ import ( "testing" entities2 "github.com/cedar-policy/cedar-go/internal/entities" + "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -188,3 +189,167 @@ func TestCorpus(t *testing.T) { }) } } + +// Specific corpus tests that have been extracted for easy regression testing purposes +func TestCorpusRelated(t *testing.T) { + t.Parallel() + tests := []struct { + name string + policy string + request Request + decision Decision + reasons []int + errors []int + }{ + { + "0cb1ad7042508e708f1999284b634ed0f334bc00", + `forbid( + principal in a::"\0\0", + action == Action::"action", + resource + ) when { + (true && (((!870985681610) == principal) == principal)) && principal + };`, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, + Deny, + nil, + []int{0}, + }, + + { + "0cb1ad7042508e708f1999284b634ed0f334bc00/partial1", + `forbid( + principal in a::"\0\0", + action == Action::"action", + resource + ) when { + (((!870985681610) == principal) == principal) + };`, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, + Deny, + nil, + []int{0}, + }, + { + "0cb1ad7042508e708f1999284b634ed0f334bc00/partial2", + `forbid( + principal in a::"\0\0", + action == Action::"action", + resource + ) when { + ((!870985681610) == principal) + };`, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, + Deny, + nil, + []int{0}, + }, + + { + "0cb1ad7042508e708f1999284b634ed0f334bc00/partial3", + `forbid( + principal in a::"\0\0", + action == Action::"action", + resource + ) when { + (!870985681610) + };`, + Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, + Deny, + nil, + []int{0}, + }, + + { + "0cb1ad7042508e708f1999284b634ed0f334bc00/partial2/simplified", + `forbid( + principal, + action, + resource + ) when { + ((!42) == principal) + };`, + Request{}, + Deny, + nil, + []int{0}, + }, + + { + "0cb1ad7042508e708f1999284b634ed0f334bc00/partial2/simplified2", + `forbid( + principal, + action, + resource + ) when { + (!42 == principal) + };`, + Request{}, + Deny, + nil, + []int{0}, + }, + + {"48d0ba6537a3efe02112ba0f5a3daabdcad27b04", + `forbid( + principal, + action in [Action::"action"], + resource is a in a::"\0\u{8}\u{11}\0R" + ) when { + true && ((if (principal in action) then (ip("")) else (if true then (ip("6b6b:f00::32ff:ffff:6368/00")) else (ip("7265:6c69:706d:6f43:5f74:6f70:7374:6f68")))).isMulticast()) + };`, + Request{Principal: types.NewEntityUID("a", "\u0000\b\u0011\u0000R"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\b\u0011\u0000R")}, + Deny, + nil, + []int{0}, + }, + + {"48d0ba6537a3efe02112ba0f5a3daabdcad27b04/simplified", + `forbid( + principal, + action, + resource + ) when { + true && ip("6b6b:f00::32ff:ffff:6368/00").isMulticast() + };`, + Request{}, + Deny, + nil, + []int{0}, + }, + + {name: "e91da4e6af5c73e27f5fb610d723dfa21635d10b", + policy: `forbid( + principal is a in a::"\0\0(W\0\0\0", + action, + resource + ) when { + true && (([ip("c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5/68")].containsAll([ip("c5c5:c5c5:c5c5:c5c5:c5c5:5cc5:c5c5:c5c5/68")])) || ((ip("")) == (ip("")))) + };`, + request: Request{Principal: types.NewEntityUID("a", "\u0000\u0000(W\u0000\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "")}, + decision: Deny, + reasons: nil, + errors: []int{0}, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + policy, err := NewPolicySet("", []byte(tt.policy)) + testutil.OK(t, err) + ok, diag := policy.IsAuthorized(entities2.Entities{}, tt.request) + testutil.Equals(t, ok, tt.decision) + var reasons []int + for _, n := range diag.Reasons { + reasons = append(reasons, n.Policy) + } + testutil.Equals(t, reasons, tt.reasons) + var errors []int + for _, n := range diag.Errors { + errors = append(errors, n.Policy) + } + testutil.Equals(t, errors, tt.errors) + }) + } +} From 875ac6bf1a5c19fd1630e3f6c65e1df38fd7dab6 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:23:29 -0700 Subject: [PATCH 105/216] cedar-go/internal/entities: move entities test into entities package Signed-off-by: philhassey --- internal/entities/entities_test.go | 24 ++++++++++++++++++++++++ types/value_test.go | 17 ----------------- 2 files changed, 24 insertions(+), 17 deletions(-) create mode 100644 internal/entities/entities_test.go diff --git a/internal/entities/entities_test.go b/internal/entities/entities_test.go new file mode 100644 index 00000000..b0636fa8 --- /dev/null +++ b/internal/entities/entities_test.go @@ -0,0 +1,24 @@ +package entities + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestEntities(t *testing.T) { + t.Parallel() + t.Run("Clone", func(t *testing.T) { + t.Parallel() + e := Entities{ + types.EntityUID{Type: "A", ID: "A"}: {}, + types.EntityUID{Type: "A", ID: "B"}: {}, + types.EntityUID{Type: "B", ID: "A"}: {}, + types.EntityUID{Type: "B", ID: "B"}: {}, + } + clone := e.Clone() + testutil.Equals(t, clone, e) + }) + +} diff --git a/types/value_test.go b/types/value_test.go index 51e03048..9cf2ffbb 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -4,7 +4,6 @@ import ( "fmt" "testing" - "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/testutil" ) @@ -878,19 +877,3 @@ func TestPath(t *testing.T) { }) } - -func TestEntities(t *testing.T) { - t.Parallel() - t.Run("Clone", func(t *testing.T) { - t.Parallel() - e := entities.Entities{ - EntityUID{Type: "A", ID: "A"}: {}, - EntityUID{Type: "A", ID: "B"}: {}, - EntityUID{Type: "B", ID: "A"}: {}, - EntityUID{Type: "B", ID: "B"}: {}, - } - clone := e.Clone() - testutil.Equals(t, clone, e) - }) - -} From 83dcdcb8928b508cd16d66892bb39a0b8ea70ea6 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:24:30 -0700 Subject: [PATCH 106/216] cedar-go: move TestJSONDecision into authorize_test.go Signed-off-by: philhassey --- authorize_test.go | 33 +++++++++++++++++++++++++++++++++ cedar_test.go | 32 -------------------------------- 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/authorize_test.go b/authorize_test.go index 1bdd2f9c..201ee083 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -1,6 +1,7 @@ package cedar import ( + "encoding/json" "testing" "github.com/cedar-policy/cedar-go/internal/entities" @@ -727,3 +728,35 @@ func TestError(t *testing.T) { e := Error{Policy: 42, Message: "bad error"} testutil.Equals(t, e.String(), "while evaluating policy `policy42`: bad error") } + +func TestJSONDecision(t *testing.T) { + t.Parallel() + t.Run("MarshalAllow", func(t *testing.T) { + t.Parallel() + d := Allow + b, err := d.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `"allow"`) + }) + t.Run("MarshalDeny", func(t *testing.T) { + t.Parallel() + d := Deny + b, err := d.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `"deny"`) + }) + t.Run("UnmarshalAllow", func(t *testing.T) { + t.Parallel() + var d Decision + err := json.Unmarshal([]byte(`"allow"`), &d) + testutil.OK(t, err) + testutil.Equals(t, d, Allow) + }) + t.Run("UnmarshalDeny", func(t *testing.T) { + t.Parallel() + var d Decision + err := json.Unmarshal([]byte(`"deny"`), &d) + testutil.OK(t, err) + testutil.Equals(t, d, Deny) + }) +} diff --git a/cedar_test.go b/cedar_test.go index a837e62e..a5f26c0f 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -89,35 +89,3 @@ func TestEntitiesJSON(t *testing.T) { testutil.Error(t, err) }) } - -func TestJSONDecision(t *testing.T) { - t.Parallel() - t.Run("MarshalAllow", func(t *testing.T) { - t.Parallel() - d := Allow - b, err := d.MarshalJSON() - testutil.OK(t, err) - testutil.Equals(t, string(b), `"allow"`) - }) - t.Run("MarshalDeny", func(t *testing.T) { - t.Parallel() - d := Deny - b, err := d.MarshalJSON() - testutil.OK(t, err) - testutil.Equals(t, string(b), `"deny"`) - }) - t.Run("UnmarshalAllow", func(t *testing.T) { - t.Parallel() - var d Decision - err := json.Unmarshal([]byte(`"allow"`), &d) - testutil.OK(t, err) - testutil.Equals(t, d, Allow) - }) - t.Run("UnmarshalDeny", func(t *testing.T) { - t.Parallel() - var d Decision - err := json.Unmarshal([]byte(`"deny"`), &d) - testutil.OK(t, err) - testutil.Equals(t, d, Deny) - }) -} From 8bd3ecc18b03f224531098802c504811d2c9bc68 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:25:37 -0700 Subject: [PATCH 107/216] cedar-go/internal/entities: move TestEntitiesJSON to entities package Signed-off-by: philhassey --- cedar_test.go | 42 ------------------------------ internal/entities/entities_test.go | 41 +++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 42 deletions(-) diff --git a/cedar_test.go b/cedar_test.go index a5f26c0f..18feb7a1 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -1,10 +1,8 @@ package cedar import ( - "encoding/json" "testing" - "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -49,43 +47,3 @@ func TestNewPolicySet(t *testing.T) { testutil.Equals(t, ps[0].Annotations, Annotations{"key": "value"}) }) } - -func TestEntitiesJSON(t *testing.T) { - t.Parallel() - t.Run("Marshal", func(t *testing.T) { - t.Parallel() - e := entities.Entities{} - ent := entities.Entity{ - UID: types.NewEntityUID("Type", "id"), - Parents: []types.EntityUID{}, - Attributes: types.Record{"key": types.Long(42)}, - } - e[ent.UID] = ent - b, err := e.MarshalJSON() - testutil.OK(t, err) - testutil.Equals(t, string(b), `[{"uid":{"type":"Type","id":"id"},"attrs":{"key":42}}]`) - }) - - t.Run("Unmarshal", func(t *testing.T) { - t.Parallel() - b := []byte(`[{"uid":{"type":"Type","id":"id"},"parents":[],"attrs":{"key":42}}]`) - var e entities.Entities - err := json.Unmarshal(b, &e) - testutil.OK(t, err) - want := entities.Entities{} - ent := entities.Entity{ - UID: types.NewEntityUID("Type", "id"), - Parents: []types.EntityUID{}, - Attributes: types.Record{"key": types.Long(42)}, - } - want[ent.UID] = ent - testutil.Equals(t, e, want) - }) - - t.Run("UnmarshalErr", func(t *testing.T) { - t.Parallel() - var e entities.Entities - err := e.UnmarshalJSON([]byte(`!@#$`)) - testutil.Error(t, err) - }) -} diff --git a/internal/entities/entities_test.go b/internal/entities/entities_test.go index b0636fa8..16a24618 100644 --- a/internal/entities/entities_test.go +++ b/internal/entities/entities_test.go @@ -1,6 +1,7 @@ package entities import ( + "encoding/json" "testing" "github.com/cedar-policy/cedar-go/internal/testutil" @@ -22,3 +23,43 @@ func TestEntities(t *testing.T) { }) } + +func TestEntitiesJSON(t *testing.T) { + t.Parallel() + t.Run("Marshal", func(t *testing.T) { + t.Parallel() + e := Entities{} + ent := Entity{ + UID: types.NewEntityUID("Type", "id"), + Parents: []types.EntityUID{}, + Attributes: types.Record{"key": types.Long(42)}, + } + e[ent.UID] = ent + b, err := e.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(b), `[{"uid":{"type":"Type","id":"id"},"attrs":{"key":42}}]`) + }) + + t.Run("Unmarshal", func(t *testing.T) { + t.Parallel() + b := []byte(`[{"uid":{"type":"Type","id":"id"},"parents":[],"attrs":{"key":42}}]`) + var e Entities + err := json.Unmarshal(b, &e) + testutil.OK(t, err) + want := Entities{} + ent := Entity{ + UID: types.NewEntityUID("Type", "id"), + Parents: []types.EntityUID{}, + Attributes: types.Record{"key": types.Long(42)}, + } + want[ent.UID] = ent + testutil.Equals(t, e, want) + }) + + t.Run("UnmarshalErr", func(t *testing.T) { + t.Parallel() + var e Entities + err := e.UnmarshalJSON([]byte(`!@#$`)) + testutil.Error(t, err) + }) +} From 958f195843646caa5e85f040e0c74f6ce7320df8 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:26:19 -0700 Subject: [PATCH 108/216] cedar-go/internal/entities: move TestEntityIsZero into entities package Signed-off-by: philhassey --- cedar_test.go | 22 ---------------------- internal/entities/entities_test.go | 21 +++++++++++++++++++++ 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/cedar_test.go b/cedar_test.go index 18feb7a1..e008afdb 100644 --- a/cedar_test.go +++ b/cedar_test.go @@ -4,30 +4,8 @@ import ( "testing" "github.com/cedar-policy/cedar-go/internal/testutil" - "github.com/cedar-policy/cedar-go/types" ) -func TestEntityIsZero(t *testing.T) { - t.Parallel() - tests := []struct { - name string - uid types.EntityUID - want bool - }{ - {"empty", types.EntityUID{}, true}, - {"empty-type", types.NewEntityUID("one", ""), false}, - {"empty-id", types.NewEntityUID("", "one"), false}, - {"not-empty", types.NewEntityUID("one", "two"), false}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - testutil.Equals(t, tt.uid.IsZero(), tt.want) - }) - } -} - func TestNewPolicySet(t *testing.T) { t.Parallel() t.Run("err-in-tokenize", func(t *testing.T) { diff --git a/internal/entities/entities_test.go b/internal/entities/entities_test.go index 16a24618..1ccaceaa 100644 --- a/internal/entities/entities_test.go +++ b/internal/entities/entities_test.go @@ -63,3 +63,24 @@ func TestEntitiesJSON(t *testing.T) { testutil.Error(t, err) }) } + +func TestEntityIsZero(t *testing.T) { + t.Parallel() + tests := []struct { + name string + uid types.EntityUID + want bool + }{ + {"empty", types.EntityUID{}, true}, + {"empty-type", types.NewEntityUID("one", ""), false}, + {"empty-id", types.NewEntityUID("", "one"), false}, + {"not-empty", types.NewEntityUID("one", "two"), false}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + testutil.Equals(t, tt.uid.IsZero(), tt.want) + }) + } +} From e5d5ad26e1e1054064062242e4c452d9d3bccbbf Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:27:16 -0700 Subject: [PATCH 109/216] cedar-go: rename cedar.go to policy_set.go Signed-off-by: philhassey --- cedar.go => policy_set.go | 0 cedar_test.go => policy_set_test.go | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename cedar.go => policy_set.go (100%) rename cedar_test.go => policy_set_test.go (100%) diff --git a/cedar.go b/policy_set.go similarity index 100% rename from cedar.go rename to policy_set.go diff --git a/cedar_test.go b/policy_set_test.go similarity index 100% rename from cedar_test.go rename to policy_set_test.go From 3850049a0032458457405a127a2a28c94eac2d94 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:37:03 -0700 Subject: [PATCH 110/216] cedar-go: add Cedar marshaling and unmarshaling to Policy Signed-off-by: philhassey --- policy.go | 24 ++++++++++++++++++++++++ policy_test.go | 20 ++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/policy.go b/policy.go index 0b3feb06..fdeb25da 100644 --- a/policy.go +++ b/policy.go @@ -1,9 +1,12 @@ package cedar import ( + "bytes" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/eval" "github.com/cedar-policy/cedar-go/internal/json" + "github.com/cedar-policy/cedar-go/internal/parser" ) // A Policy is the parsed form of a single Cedar language policy statement. @@ -71,3 +74,24 @@ func (p *Policy) UnmarshalJSON(b []byte) error { } return nil } + +func (p *Policy) MarshalCedar(buf *bytes.Buffer) { + cedarPolicy := &parser.Policy{Policy: p.ast} + cedarPolicy.MarshalCedar(buf) +} + +func (p *Policy) UnmarshalCedar(b []byte) error { + var cedarPolicy parser.Policy + if err := cedarPolicy.UnmarshalCedar(b); err != nil { + return err + } + + *p = Policy{ + Position: Position{}, + Annotations: newAnnotationsFromSlice(cedarPolicy.Annotations), + Effect: Effect(cedarPolicy.Effect), + eval: eval.Compile(cedarPolicy.Policy), + ast: cedarPolicy.Policy, + } + return nil +} diff --git a/policy_test.go b/policy_test.go index 072f08d4..07ad5f0f 100644 --- a/policy_test.go +++ b/policy_test.go @@ -64,3 +64,23 @@ func TestPolicyJSON(t *testing.T) { testutil.Equals(t, string(prettifyJson(output)), string(jsonEncodedPolicy)) } + +func TestPolicyCedar(t *testing.T) { + t.Parallel() + + // Taken from https://docs.cedarpolicy.com/policies/syntax-policy.html + policyStr := `permit ( + principal, + action == Action::"editPhoto", + resource +) +when { resource.owner == principal };` + + var policy Policy + testutil.OK(t, policy.UnmarshalCedar([]byte(policyStr))) + + var buf bytes.Buffer + policy.MarshalCedar(&buf) + + testutil.Equals(t, buf.String(), policyStr) +} From f419c39d9774f5ac732fb874ee45994d3a2c261a Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 13 Aug 2024 17:54:21 -0700 Subject: [PATCH 111/216] cedar-go: add a method for constructing a Policy from an AST Signed-off-by: philhassey --- policy.go | 28 ++++++++++++++++++++-------- policy_set.go | 2 +- policy_test.go | 12 ++++++++++++ 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/policy.go b/policy.go index fdeb25da..ac91ddf6 100644 --- a/policy.go +++ b/policy.go @@ -3,7 +3,8 @@ package cedar import ( "bytes" - "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/ast" + internalast "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/eval" "github.com/cedar-policy/cedar-go/internal/json" "github.com/cedar-policy/cedar-go/internal/parser" @@ -15,7 +16,7 @@ type Policy struct { Annotations Annotations // annotations found for this policy Effect Effect // the effect of this policy eval evaler // determines if a policy matches a request. - ast ast.Policy + ast *internalast.Policy } // A Position describes an arbitrary source position including the file, line, and column location. @@ -31,7 +32,7 @@ type Position struct { type Annotations map[string]string // TODO: Is this where we should deal with duplicate keys? -func newAnnotationsFromSlice(annotations []ast.AnnotationType) Annotations { +func newAnnotationsFromSlice(annotations []internalast.AnnotationType) Annotations { res := make(map[string]string, len(annotations)) for _, e := range annotations { res[string(e.Key)] = string(e.Value) @@ -41,7 +42,7 @@ func newAnnotationsFromSlice(annotations []ast.AnnotationType) Annotations { // An Effect specifies the intent of the policy, to either permit or forbid any // request that matches the scope and conditions specified in the policy. -type Effect ast.Effect +type Effect internalast.Effect // Each Policy has a Permit or Forbid effect that is determined during parsing. const ( @@ -53,7 +54,7 @@ const ( // // [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html func (p *Policy) MarshalJSON() ([]byte, error) { - jsonPolicy := &json.Policy{Policy: p.ast} + jsonPolicy := &json.Policy{Policy: *p.ast} return jsonPolicy.MarshalJSON() } @@ -70,13 +71,13 @@ func (p *Policy) UnmarshalJSON(b []byte) error { Annotations: newAnnotationsFromSlice(jsonPolicy.Annotations), Effect: Effect(jsonPolicy.Effect), eval: eval.Compile(jsonPolicy.Policy), - ast: jsonPolicy.Policy, + ast: &jsonPolicy.Policy, } return nil } func (p *Policy) MarshalCedar(buf *bytes.Buffer) { - cedarPolicy := &parser.Policy{Policy: p.ast} + cedarPolicy := &parser.Policy{Policy: *p.ast} cedarPolicy.MarshalCedar(buf) } @@ -91,7 +92,18 @@ func (p *Policy) UnmarshalCedar(b []byte) error { Annotations: newAnnotationsFromSlice(cedarPolicy.Annotations), Effect: Effect(cedarPolicy.Effect), eval: eval.Compile(cedarPolicy.Policy), - ast: cedarPolicy.Policy, + ast: &cedarPolicy.Policy, } return nil } + +func NewPolicyFromAST(astIn *ast.Policy) *Policy { + pp := (*internalast.Policy)(astIn) + return &Policy{ + Position: Position{}, + Annotations: newAnnotationsFromSlice(astIn.Annotations), + Effect: Effect(astIn.Effect), + eval: eval.Compile(*pp), + ast: pp, + } +} diff --git a/policy_set.go b/policy_set.go index ec7e269a..c0b613d9 100644 --- a/policy_set.go +++ b/policy_set.go @@ -31,7 +31,7 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { Annotations: newAnnotationsFromSlice(p.Policy.Annotations), Effect: Effect(p.Policy.Effect), eval: eval.Compile(p.Policy.Policy), - ast: p.Policy.Policy, + ast: &p.Policy.Policy, }) } return policies, nil diff --git a/policy_test.go b/policy_test.go index 07ad5f0f..1a1a84ef 100644 --- a/policy_test.go +++ b/policy_test.go @@ -5,7 +5,9 @@ import ( "encoding/json" "testing" + "github.com/cedar-policy/cedar-go/ast" "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" ) func prettifyJson(in []byte) []byte { @@ -84,3 +86,13 @@ when { resource.owner == principal };` testutil.Equals(t, buf.String(), policyStr) } + +func TestPolicyAST(t *testing.T) { + t.Parallel() + + astExample := ast.Permit(). + ActionEq(types.NewEntityUID("Action", "editPhoto")). + When(ast.Resource().Access("owner").Equals(ast.Principal())) + + _ = NewPolicyFromAST(astExample) +} From 25fca15d008507b559cef75d28ae5ca087036e03 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 14 Aug 2024 11:56:22 -0600 Subject: [PATCH 112/216] internal/rust: move rust string helpers into their own package Addresses IDX-142 Signed-off-by: philhassey --- internal/parser/cedar_tokenize.go | 10 +++++----- internal/{ => rust}/rust.go | 16 ++++++++-------- internal/{ => rust}/rust_test.go | 13 ++++++------- types/pattern.go | 4 ++-- 4 files changed, 21 insertions(+), 22 deletions(-) rename internal/{ => rust}/rust.go (91%) rename internal/{ => rust}/rust_test.go (94%) diff --git a/internal/parser/cedar_tokenize.go b/internal/parser/cedar_tokenize.go index 3809047b..773fef13 100644 --- a/internal/parser/cedar_tokenize.go +++ b/internal/parser/cedar_tokenize.go @@ -8,7 +8,7 @@ import ( "strings" "unicode/utf8" - "github.com/cedar-policy/cedar-go/internal" + "github.com/cedar-policy/cedar-go/internal/rust" ) //go:generate moq -pkg parser -fmt goimports -out tokenize_mocks_test.go . reader @@ -54,7 +54,7 @@ func (t Token) stringValue() (string, error) { s = strings.TrimPrefix(s, "\"") s = strings.TrimSuffix(s, "\"") b := []byte(s) - res, _, err := internal.RustUnquote(b, false) + res, _, err := rust.RustUnquote(b, false) return res, err } @@ -276,7 +276,7 @@ func (s *scanner) scanIdentifier() rune { } func (s *scanner) scanInteger(ch rune) rune { - for internal.IsDecimal(ch) { + for rust.IsDecimal(ch) { ch = s.next() } return ch @@ -284,7 +284,7 @@ func (s *scanner) scanInteger(ch rune) rune { func (s *scanner) scanHexDigits(ch rune, min, max int) rune { n := 0 - for n < max && internal.IsHexadecimal(ch) { + for n < max && rust.IsHexadecimal(ch) { ch = s.next() n++ } @@ -453,7 +453,7 @@ redo: case isIdentRune(ch, true): ch = s.scanIdentifier() tt = TokenIdent - case internal.IsDecimal(ch): + case rust.IsDecimal(ch): ch = s.scanInteger(ch) tt = TokenInt case ch == '"': diff --git a/internal/rust.go b/internal/rust/rust.go similarity index 91% rename from internal/rust.go rename to internal/rust/rust.go index 5b756fb1..a514c56b 100644 --- a/internal/rust.go +++ b/internal/rust/rust.go @@ -1,4 +1,4 @@ -package internal +package rust import ( "fmt" @@ -24,7 +24,7 @@ func parseHexEscape(b []byte, i int) (rune, int, error) { if !IsHexadecimal(ch) { return 0, i, fmt.Errorf("bad hex escape sequence") } - res := DigitVal(ch) + res := digitVal(ch) ch, i, err = nextRune(b, i) if err != nil { return 0, i, err @@ -32,14 +32,14 @@ func parseHexEscape(b []byte, i int) (rune, int, error) { if !IsHexadecimal(ch) { return 0, i, fmt.Errorf("bad hex escape sequence") } - res = 16*res + DigitVal(ch) + res = 16*res + digitVal(ch) if res > 127 { return 0, i, fmt.Errorf("bad hex escape sequence") } return rune(res), i, nil } -func ParseUnicodeEscape(b []byte, i int) (rune, int, error) { +func parseUnicodeEscape(b []byte, i int) (rune, int, error) { var ch rune var err error @@ -64,7 +64,7 @@ func ParseUnicodeEscape(b []byte, i int) (rune, int, error) { if !IsHexadecimal(ch) { return 0, i, fmt.Errorf("bad unicode escape sequence") } - res = 16*res + DigitVal(ch) + res = 16*res + digitVal(ch) digits++ } @@ -75,7 +75,7 @@ func ParseUnicodeEscape(b []byte, i int) (rune, int, error) { return rune(res), i, nil } -func Unquote(s string) (string, error) { +func unquote(s string) (string, error) { s = strings.TrimPrefix(s, "\"") s = strings.TrimSuffix(s, "\"") res, _, err := RustUnquote([]byte(s), false) @@ -126,7 +126,7 @@ func RustUnquote(b []byte, star bool) (string, []byte, error) { } sb.WriteRune(ch) case 'u': - ch, i, err = ParseUnicodeEscape(b, i) + ch, i, err = parseUnicodeEscape(b, i) if err != nil { return "", nil, err } @@ -150,7 +150,7 @@ func IsHexadecimal(ch rune) bool { func lower(ch rune) rune { return ('a' - 'A') | ch } // returns lower-case ch iff ch is ASCII letter func IsDecimal(ch rune) bool { return '0' <= ch && ch <= '9' } -func DigitVal(ch rune) int { +func digitVal(ch rune) int { switch { case '0' <= ch && ch <= '9': return int(ch - '0') diff --git a/internal/rust_test.go b/internal/rust/rust_test.go similarity index 94% rename from internal/rust_test.go rename to internal/rust/rust_test.go index da3611e4..a109a55a 100644 --- a/internal/rust_test.go +++ b/internal/rust/rust_test.go @@ -1,9 +1,8 @@ -package internal_test +package rust import ( "testing" - "github.com/cedar-policy/cedar-go/internal" "github.com/cedar-policy/cedar-go/internal/testutil" ) @@ -24,7 +23,7 @@ func TestParseUnicodeEscape(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - out, n, err := internal.ParseUnicodeEscape(tt.in, 0) + out, n, err := parseUnicodeEscape(tt.in, 0) testutil.Equals(t, out, tt.out) testutil.Equals(t, n, tt.outN) tt.err(t, err) @@ -46,7 +45,7 @@ func TestUnquote(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - out, err := internal.Unquote(tt.in) + out, err := unquote(tt.in) testutil.Equals(t, out, tt.out) tt.err(t, err) }) @@ -108,7 +107,7 @@ func TestRustUnquote(t *testing.T) { tt := tt t.Run(tt.input, func(t *testing.T) { t.Parallel() - got, rem, err := internal.RustUnquote([]byte(tt.input), false) + got, rem, err := RustUnquote([]byte(tt.input), false) if err != nil { testutil.Equals(t, tt.wantOk, false) testutil.Equals(t, err.Error(), tt.wantErr) @@ -179,7 +178,7 @@ func TestRustUnquote(t *testing.T) { tt := tt t.Run(tt.input, func(t *testing.T) { t.Parallel() - got, rem, err := internal.RustUnquote([]byte(tt.input), true) + got, rem, err := RustUnquote([]byte(tt.input), true) if err != nil { testutil.Equals(t, tt.wantOk, false) testutil.Equals(t, err.Error(), tt.wantErr) @@ -209,7 +208,7 @@ func TestDigitVal(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - out := internal.DigitVal(tt.in) + out := digitVal(tt.in) testutil.Equals(t, out, tt.out) }) } diff --git a/types/pattern.go b/types/pattern.go index aa32df32..26d812d6 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -5,7 +5,7 @@ import ( "strconv" "strings" - "github.com/cedar-policy/cedar-go/internal" + "github.com/cedar-policy/cedar-go/internal/rust" ) type PatternComponent struct { @@ -138,7 +138,7 @@ func ParsePattern(s string) (Pattern, error) { b = b[1:] comp.Wildcard = true } - comp.Literal, b, err = internal.RustUnquote(b, true) + comp.Literal, b, err = rust.RustUnquote(b, true) if err != nil { return Pattern{}, err } From dd37c581cf448643fbe27618c79e8ba5fc3db640 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 14 Aug 2024 12:06:18 -0600 Subject: [PATCH 113/216] internal/rust: remove stutter to appease linter Addresses IDX-142 Signed-off-by: philhassey --- internal/parser/cedar_tokenize.go | 2 +- internal/rust/rust.go | 4 ++-- internal/rust/rust_test.go | 4 ++-- types/pattern.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/parser/cedar_tokenize.go b/internal/parser/cedar_tokenize.go index 773fef13..7a034df0 100644 --- a/internal/parser/cedar_tokenize.go +++ b/internal/parser/cedar_tokenize.go @@ -54,7 +54,7 @@ func (t Token) stringValue() (string, error) { s = strings.TrimPrefix(s, "\"") s = strings.TrimSuffix(s, "\"") b := []byte(s) - res, _, err := rust.RustUnquote(b, false) + res, _, err := rust.Unquote(b, false) return res, err } diff --git a/internal/rust/rust.go b/internal/rust/rust.go index a514c56b..8203ca53 100644 --- a/internal/rust/rust.go +++ b/internal/rust/rust.go @@ -78,11 +78,11 @@ func parseUnicodeEscape(b []byte, i int) (rune, int, error) { func unquote(s string) (string, error) { s = strings.TrimPrefix(s, "\"") s = strings.TrimSuffix(s, "\"") - res, _, err := RustUnquote([]byte(s), false) + res, _, err := Unquote([]byte(s), false) return res, err } -func RustUnquote(b []byte, star bool) (string, []byte, error) { +func Unquote(b []byte, star bool) (string, []byte, error) { var sb strings.Builder var ch rune var err error diff --git a/internal/rust/rust_test.go b/internal/rust/rust_test.go index a109a55a..99b60cb1 100644 --- a/internal/rust/rust_test.go +++ b/internal/rust/rust_test.go @@ -107,7 +107,7 @@ func TestRustUnquote(t *testing.T) { tt := tt t.Run(tt.input, func(t *testing.T) { t.Parallel() - got, rem, err := RustUnquote([]byte(tt.input), false) + got, rem, err := Unquote([]byte(tt.input), false) if err != nil { testutil.Equals(t, tt.wantOk, false) testutil.Equals(t, err.Error(), tt.wantErr) @@ -178,7 +178,7 @@ func TestRustUnquote(t *testing.T) { tt := tt t.Run(tt.input, func(t *testing.T) { t.Parallel() - got, rem, err := RustUnquote([]byte(tt.input), true) + got, rem, err := Unquote([]byte(tt.input), true) if err != nil { testutil.Equals(t, tt.wantOk, false) testutil.Equals(t, err.Error(), tt.wantErr) diff --git a/types/pattern.go b/types/pattern.go index 26d812d6..4d5b05fd 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -138,7 +138,7 @@ func ParsePattern(s string) (Pattern, error) { b = b[1:] comp.Wildcard = true } - comp.Literal, b, err = rust.RustUnquote(b, true) + comp.Literal, b, err = rust.Unquote(b, true) if err != nil { return Pattern{}, err } From 40d779c54e18be2a6798205b7a6457bb2b67eb8b Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 14 Aug 2024 12:38:29 -0600 Subject: [PATCH 114/216] internal/json: simplified policy embed Addresses IDX-142 Signed-off-by: philhassey --- internal/json/json_marshal.go | 19 +++++++------- internal/json/json_test.go | 4 +-- internal/json/json_unmarshal.go | 45 ++++++++++++++++----------------- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/internal/json/json_marshal.go b/internal/json/json_marshal.go index cf6899af..abe68e89 100644 --- a/internal/json/json_marshal.go +++ b/internal/json/json_marshal.go @@ -93,13 +93,8 @@ func extToJSON(dest *extensionCallJSON, name string, src types.Value) error { func extCallToJSON(dest extensionCallJSON, src ast.NodeTypeExtensionCall) error { jsonArgs := arrayJSON{} - for _, n := range src.Args { - argNode := &nodeJSON{} - err := argNode.FromNode(n) - if err != nil { - return err - } - jsonArgs = append(jsonArgs, *argNode) + if err := arrayToJSON(&jsonArgs, src.Args); err != nil { + return err } dest[string(src.Name)] = jsonArgs return nil @@ -312,8 +307,14 @@ func (p *patternComponentJSON) MarshalJSON() ([]byte, error) { return json.Marshal(p.Literal) } -type Policy struct { - ast.Policy +type Policy ast.Policy + +func wrapPolicy(p *ast.Policy) *Policy { + return (*Policy)(p) +} + +func (p *Policy) unwrap() *ast.Policy { + return (*ast.Policy)(p) } func (p *Policy) MarshalJSON() ([]byte, error) { diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 13a7ce2f..38cd5c05 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -485,7 +485,7 @@ func TestUnmarshalJSON(t *testing.T) { if err != nil { return } - testutil.Equals(t, p.Policy, *tt.want) + testutil.Equals(t, p.unwrap(), tt.want) b, err := json.Marshal(&p) testutil.OK(t, err) normInput := testNormalizeJSON(t, tt.input) @@ -535,7 +535,7 @@ func TestMarshalJSON(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - pp := &Policy{Policy: *tt.input} + pp := wrapPolicy(tt.input) b, err := json.Marshal(pp) tt.errFunc(t, err) if err != nil { diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 47eb9892..36919915 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -130,23 +130,24 @@ func (j recordJSON) ToNode() (ast.Node, error) { func (e extensionCallJSON) ToNode() (ast.Node, error) { if len(e) != 1 { - return ast.Node{}, fmt.Errorf("unexpected number of extension methods in node: %v", len(e)) + return ast.Node{}, fmt.Errorf("unexpected number of extensions in node: %v", len(e)) } - for k, v := range e { - if len(v) == 0 { - return ast.Node{}, fmt.Errorf("extension method '%v' must have at least one argument", k) - } - var argNodes []ast.Node - for _, n := range v { - node, err := n.ToNode() - if err != nil { - return ast.Node{}, fmt.Errorf("error in extension method argument: %w", err) - } - argNodes = append(argNodes, node) + var k string + var v arrayJSON + for k, v = range e { + } + if len(v) == 0 { + return ast.Node{}, fmt.Errorf("extension '%v' must have at least one argument", k) + } + var argNodes []ast.Node + for _, n := range v { + node, err := n.ToNode() + if err != nil { + return ast.Node{}, fmt.Errorf("error in extension arg: %w", err) } - return ast.NewExtensionCall(types.String(k), argNodes...), nil + argNodes = append(argNodes, node) } - panic("unreachable code") + return ast.NewExtensionCall(types.String(k), argNodes...), nil } func (j nodeJSON) ToNode() (ast.Node, error) { @@ -171,7 +172,7 @@ func (j nodeJSON) ToNode() (ast.Node, error) { case "context": return ast.Context(), nil } - return ast.Node{}, fmt.Errorf("unknown var: %v", j.Var) + return ast.Node{}, fmt.Errorf("unknown variable: %v", j.Var) // Slot // Unknown @@ -241,11 +242,9 @@ func (j nodeJSON) ToNode() (ast.Node, error) { return j.Record.ToNode() // Any other method: lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange - case j.ExtensionCall != nil: + default: return j.ExtensionCall.ToNode() } - - return ast.Node{}, fmt.Errorf("unknown node") } func (n *nodeJSON) UnmarshalJSON(b []byte) error { @@ -288,14 +287,14 @@ func (p *Policy) UnmarshalJSON(b []byte) error { } switch j.Effect { case "permit": - p.Policy = *ast.Permit() + *(p.unwrap()) = *ast.Permit() case "forbid": - p.Policy = *ast.Forbid() + *(p.unwrap()) = *ast.Forbid() default: return fmt.Errorf("unknown effect: %v", j.Effect) } for k, v := range j.Annotations { - p.Annotate(types.String(k), types.String(v)) + p.unwrap().Annotate(types.String(k), types.String(v)) } var err error p.Principal, err = j.Principal.ToNode(ast.Scope(ast.NewPrincipalNode())) @@ -317,9 +316,9 @@ func (p *Policy) UnmarshalJSON(b []byte) error { } switch c.Kind { case "when": - p.When(n) + p.unwrap().When(n) case "unless": - p.Unless(n) + p.unwrap().Unless(n) default: return fmt.Errorf("unknown condition kind: %v", c.Kind) } From 248253e1f3781ea0b66688f76af4920ff10c1367 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 14 Aug 2024 12:59:51 -0600 Subject: [PATCH 115/216] internal/json: restructure code to be less error-prone Addresses IDX-142 Signed-off-by: philhassey --- internal/json/json.go | 15 ++- internal/json/json_marshal.go | 229 +++++++++++++++----------------- internal/json/json_unmarshal.go | 6 +- 3 files changed, 118 insertions(+), 132 deletions(-) diff --git a/internal/json/json.go b/internal/json/json.go index a639e1c7..7de5a794 100644 --- a/internal/json/json.go +++ b/internal/json/json.go @@ -1,8 +1,6 @@ package json import ( - "encoding/json" - "github.com/cedar-policy/cedar-go/types" ) @@ -78,9 +76,20 @@ type recordJSON map[string]nodeJSON type extensionCallJSON map[string]arrayJSON +type valueJSON struct { + v types.Value +} + +func (e *valueJSON) MarshalJSON() ([]byte, error) { + return e.v.ExplicitMarshalJSON() +} +func (e *valueJSON) UnmarshalJSON(b []byte) error { + return types.UnmarshalJSON(b, &e.v) +} + type nodeJSON struct { // Value - Value *json.RawMessage `json:"Value,omitempty"` // could be any + Value *valueJSON `json:"Value,omitempty"` // could be any // Var Var *string `json:"Var,omitempty"` diff --git a/internal/json/json_marshal.go b/internal/json/json_marshal.go index abe68e89..d94c5555 100644 --- a/internal/json/json_marshal.go +++ b/internal/json/json_marshal.go @@ -8,113 +8,95 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func (s *scopeJSON) FromNode(src ast.IsScopeNode) error { +func (s *scopeJSON) FromNode(src ast.IsScopeNode) { switch t := src.(type) { case ast.ScopeTypeAll: s.Op = "All" - return nil + return case ast.ScopeTypeEq: s.Op = "==" e := t.Entity s.Entity = &e - return nil + return case ast.ScopeTypeIn: s.Op = "in" e := t.Entity s.Entity = &e - return nil + return case ast.ScopeTypeInSet: s.Op = "in" s.Entities = t.Entities - return nil + return case ast.ScopeTypeIs: s.Op = "is" s.EntityType = string(t.Type) - return nil + return case ast.ScopeTypeIsIn: s.Op = "is" s.EntityType = string(t.Type) s.In = &scopeInJSON{ Entity: t.Entity, } - return nil + return + default: + panic(fmt.Sprintf("unknown scope type %T", t)) } - return fmt.Errorf("unexpected scope node: %T", src) } func unaryToJSON(dest **unaryJSON, src ast.UnaryNode) error { n := ast.UnaryNode(src) res := &unaryJSON{} - if err := res.Arg.FromNode(n.Arg); err != nil { - return fmt.Errorf("error in arg: %w", err) - } + res.Arg.FromNode(n.Arg) *dest = res return nil } -func binaryToJSON(dest **binaryJSON, src ast.BinaryNode) error { +func binaryToJSON(dest **binaryJSON, src ast.BinaryNode) { n := ast.BinaryNode(src) res := &binaryJSON{} - if err := res.Left.FromNode(n.Left); err != nil { - return fmt.Errorf("error in left: %w", err) - } - if err := res.Right.FromNode(n.Right); err != nil { - return fmt.Errorf("error in right: %w", err) - } + res.Left.FromNode(n.Left) + res.Right.FromNode(n.Right) *dest = res - return nil } -func arrayToJSON(dest *arrayJSON, args []ast.IsNode) error { +func arrayToJSON(dest *arrayJSON, args []ast.IsNode) { res := arrayJSON{} for _, n := range args { var nn nodeJSON - if err := nn.FromNode(n); err != nil { - return fmt.Errorf("error in array: %w", err) - } + nn.FromNode(n) res = append(res, nn) } *dest = res - return nil } -func extToJSON(dest *extensionCallJSON, name string, src types.Value) error { +func extToJSON(dest *extensionCallJSON, name string, src types.Value) { res := arrayJSON{} - str := src.String() // TODO: is this the correct string? - b, _ := json.Marshal(string(str)) // error impossible + str := src.String() // TODO: is this the correct string? + val := valueJSON{v: types.String(str)} res = append(res, nodeJSON{ - Value: (*json.RawMessage)(&b), + Value: &val, }) *dest = extensionCallJSON{ name: res, } - return nil } -func extCallToJSON(dest extensionCallJSON, src ast.NodeTypeExtensionCall) error { +func extCallToJSON(dest extensionCallJSON, src ast.NodeTypeExtensionCall) { jsonArgs := arrayJSON{} - if err := arrayToJSON(&jsonArgs, src.Args); err != nil { - return err - } + arrayToJSON(&jsonArgs, src.Args) dest[string(src.Name)] = jsonArgs - return nil } -func strToJSON(dest **strJSON, src ast.StrOpNode) error { +func strToJSON(dest **strJSON, src ast.StrOpNode) { res := &strJSON{} - if err := res.Left.FromNode(src.Arg); err != nil { - return fmt.Errorf("error in left: %w", err) - } + res.Left.FromNode(src.Arg) res.Attr = string(src.Value) *dest = res - return nil } -func patternToJSON(dest **patternJSON, src ast.NodeTypeLike) error { +func patternToJSON(dest **patternJSON, src ast.NodeTypeLike){ res := &patternJSON{} - if err := res.Left.FromNode(src.Arg); err != nil { - return fmt.Errorf("error in left: %w", err) - } + res.Left.FromNode(src.Arg) for _, comp := range src.Value.Components { if comp.Wildcard { res.Pattern = append(res.Pattern, patternComponentJSON{Wildcard: true}) @@ -124,62 +106,43 @@ func patternToJSON(dest **patternJSON, src ast.NodeTypeLike) error { } } *dest = res - return nil } -func recordToJSON(dest *recordJSON, src ast.NodeTypeRecord) error { +func recordToJSON(dest *recordJSON, src ast.NodeTypeRecord) { res := recordJSON{} for _, kv := range src.Elements { var nn nodeJSON - if err := nn.FromNode(kv.Value); err != nil { - return err - } + nn.FromNode(kv.Value) res[string(kv.Key)] = nn } *dest = res - return nil } -func ifToJSON(dest **ifThenElseJSON, src ast.NodeTypeIf) error { +func ifToJSON(dest **ifThenElseJSON, src ast.NodeTypeIf) { res := &ifThenElseJSON{} - if err := res.If.FromNode(src.If); err != nil { - return fmt.Errorf("error in if: %w", err) - } - if err := res.Then.FromNode(src.Then); err != nil { - return fmt.Errorf("error in then: %w", err) - } - if err := res.Else.FromNode(src.Else); err != nil { - return fmt.Errorf("error in else: %w", err) - } + res.If.FromNode(src.If) + res.Then.FromNode(src.Then) + res.Else.FromNode(src.Else) *dest = res - return nil } -func isToJSON(dest **isJSON, src ast.NodeTypeIs) error { +func isToJSON(dest **isJSON, src ast.NodeTypeIs) { res := &isJSON{} - if err := res.Left.FromNode(src.Left); err != nil { - return fmt.Errorf("error in left: %w", err) - } + res.Left.FromNode(src.Left) res.EntityType = string(src.EntityType) *dest = res - return nil } -func isInToJSON(dest **isJSON, src ast.NodeTypeIsIn) error { +func isInToJSON(dest **isJSON, src ast.NodeTypeIsIn) { res := &isJSON{} - if err := res.Left.FromNode(src.Left); err != nil { - return fmt.Errorf("error in left: %w", err) - } + res.Left.FromNode(src.Left) res.EntityType = string(src.EntityType) res.In = &nodeJSON{} - if err := res.In.FromNode(src.Entity); err != nil { - return fmt.Errorf("error in entity: %w", err) - } + res.In.FromNode(src.Entity) *dest = res - return nil } -func (j *nodeJSON) FromNode(src ast.IsNode) error { +func (j *nodeJSON) FromNode(src ast.IsNode) { switch t := src.(type) { // Value // Value *json.RawMessage `json:"Value"` // could be any @@ -189,106 +152,132 @@ func (j *nodeJSON) FromNode(src ast.IsNode) error { // IP arrayJSON `json:"ip"` switch tt := t.Value.(type) { case types.Decimal: - return extToJSON(&j.ExtensionCall, "decimal", tt) + extToJSON(&j.ExtensionCall, "decimal", tt) + return case types.IPAddr: - return extToJSON(&j.ExtensionCall, "ip", tt) + extToJSON(&j.ExtensionCall, "ip", tt) + return } - b, err := t.Value.ExplicitMarshalJSON() - j.Value = (*json.RawMessage)(&b) - return err + val := valueJSON{v: t.Value} + j.Value = &val + return // Var // Var *string `json:"Var"` case ast.NodeTypeVariable: val := string(t.Name) j.Var = &val - return nil + return // ! or neg operators // Not *unaryJSON `json:"!"` // Negate *unaryJSON `json:"neg"` case ast.NodeTypeNot: - return unaryToJSON(&j.Not, t.UnaryNode) + unaryToJSON(&j.Not, t.UnaryNode) + return case ast.NodeTypeNegate: - return unaryToJSON(&j.Negate, t.UnaryNode) + unaryToJSON(&j.Negate, t.UnaryNode) + return // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny case ast.NodeTypeAdd: - return binaryToJSON(&j.Plus, t.BinaryNode) + binaryToJSON(&j.Plus, t.BinaryNode) + return case ast.NodeTypeAnd: - return binaryToJSON(&j.And, t.BinaryNode) + binaryToJSON(&j.And, t.BinaryNode) + return case ast.NodeTypeContains: - return binaryToJSON(&j.Contains, t.BinaryNode) + binaryToJSON(&j.Contains, t.BinaryNode) + return case ast.NodeTypeContainsAll: - return binaryToJSON(&j.ContainsAll, t.BinaryNode) + binaryToJSON(&j.ContainsAll, t.BinaryNode) + return case ast.NodeTypeContainsAny: - return binaryToJSON(&j.ContainsAny, t.BinaryNode) + binaryToJSON(&j.ContainsAny, t.BinaryNode) + return case ast.NodeTypeEquals: - return binaryToJSON(&j.Equals, t.BinaryNode) + binaryToJSON(&j.Equals, t.BinaryNode) + return case ast.NodeTypeGreaterThan: - return binaryToJSON(&j.GreaterThan, t.BinaryNode) + binaryToJSON(&j.GreaterThan, t.BinaryNode) + return case ast.NodeTypeGreaterThanOrEqual: - return binaryToJSON(&j.GreaterThanOrEqual, t.BinaryNode) + binaryToJSON(&j.GreaterThanOrEqual, t.BinaryNode) + return case ast.NodeTypeIn: - return binaryToJSON(&j.In, t.BinaryNode) + binaryToJSON(&j.In, t.BinaryNode) + return case ast.NodeTypeLessThan: - return binaryToJSON(&j.LessThan, t.BinaryNode) + binaryToJSON(&j.LessThan, t.BinaryNode) + return case ast.NodeTypeLessThanOrEqual: - return binaryToJSON(&j.LessThanOrEqual, t.BinaryNode) + binaryToJSON(&j.LessThanOrEqual, t.BinaryNode) + return case ast.NodeTypeMult: - return binaryToJSON(&j.Times, t.BinaryNode) + binaryToJSON(&j.Times, t.BinaryNode) + return case ast.NodeTypeNotEquals: - return binaryToJSON(&j.NotEquals, t.BinaryNode) + binaryToJSON(&j.NotEquals, t.BinaryNode) + return case ast.NodeTypeOr: - return binaryToJSON(&j.Or, t.BinaryNode) + binaryToJSON(&j.Or, t.BinaryNode) + return case ast.NodeTypeSub: - return binaryToJSON(&j.Minus, t.BinaryNode) + binaryToJSON(&j.Minus, t.BinaryNode) + return // ., has // Access *strJSON `json:"."` // Has *strJSON `json:"has"` case ast.NodeTypeAccess: - return strToJSON(&j.Access, t.StrOpNode) + strToJSON(&j.Access, t.StrOpNode) + return case ast.NodeTypeHas: - return strToJSON(&j.Has, t.StrOpNode) + strToJSON(&j.Has, t.StrOpNode) + return // is case ast.NodeTypeIs: - return isToJSON(&j.Is, t) + isToJSON(&j.Is, t) + return case ast.NodeTypeIsIn: - return isInToJSON(&j.Is, t) + isInToJSON(&j.Is, t) + return // like // Like *strJSON `json:"like"` case ast.NodeTypeLike: - return patternToJSON(&j.Like, t) + patternToJSON(&j.Like, t) + return // if-then-else // IfThenElse *ifThenElseJSON `json:"if-then-else"` case ast.NodeTypeIf: - return ifToJSON(&j.IfThenElse, t) + ifToJSON(&j.IfThenElse, t) + return // Set // Set arrayJSON `json:"Set"` case ast.NodeTypeSet: - return arrayToJSON(&j.Set, t.Elements) + arrayToJSON(&j.Set, t.Elements) + return // Record // Record recordJSON `json:"Record"` case ast.NodeTypeRecord: - return recordToJSON(&j.Record, t) + recordToJSON(&j.Record, t) + return // Any other method: ip, decimal, lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange // ExtensionMethod map[string]arrayJSON `json:"-"` case ast.NodeTypeExtensionCall: j.ExtensionCall = extensionCallJSON{} - return extCallToJSON(j.ExtensionCall, t) + extCallToJSON(j.ExtensionCall, t) + return + default: + panic(fmt.Sprintf("unknown node type %T", t)) + } - // case ast.nodeTypeRecordEntry: - // case ast.nodeTypeEntityType: - // case ast.nodeTypeAnnotation: - // case ast.nodeTypeWhen: - // case ast.nodeTypeUnless: - return fmt.Errorf("unknown node type: %T", src) + } func (j *nodeJSON) MarshalJSON() ([]byte, error) { @@ -329,24 +318,16 @@ func (p *Policy) MarshalJSON() ([]byte, error) { for _, a := range p.Annotations { j.Annotations[string(a.Key)] = string(a.Value) } - if err := j.Principal.FromNode(p.Principal); err != nil { - return nil, fmt.Errorf("error in principal: %w", err) - } - if err := j.Action.FromNode(p.Action); err != nil { - return nil, fmt.Errorf("error in action: %w", err) - } - if err := j.Resource.FromNode(p.Resource); err != nil { - return nil, fmt.Errorf("error in resource: %w", err) - } + j.Principal.FromNode(p.Principal) + j.Action.FromNode(p.Action) + j.Resource.FromNode(p.Resource) for _, c := range p.Conditions { var cond conditionJSON cond.Kind = "when" if c.Condition == ast.ConditionUnless { cond.Kind = "unless" } - if err := cond.Body.FromNode(c.Body); err != nil { - return nil, fmt.Errorf("error in condition: %w", err) - } + cond.Body.FromNode(c.Body) j.Conditions = append(j.Conditions, cond) } return json.Marshal(j) diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 36919915..46c87bc4 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -154,11 +154,7 @@ func (j nodeJSON) ToNode() (ast.Node, error) { switch { // Value case j.Value != nil: - var v types.Value - if err := types.UnmarshalJSON(*j.Value, &v); err != nil { - return ast.Node{}, fmt.Errorf("error unmarshalling value: %w", err) - } - return ast.NewValueNode(v), nil + return ast.NewValueNode(j.Value.v), nil // Var case j.Var != nil: From 163afe22fbc311aeabd8c60a074ab33af67315bc Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 14 Aug 2024 13:33:45 -0600 Subject: [PATCH 116/216] cedar: fix type casting Addresses IDX-142 Signed-off-by: philhassey --- policy.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/policy.go b/policy.go index ac91ddf6..44d424f4 100644 --- a/policy.go +++ b/policy.go @@ -54,7 +54,7 @@ const ( // // [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html func (p *Policy) MarshalJSON() ([]byte, error) { - jsonPolicy := &json.Policy{Policy: *p.ast} + jsonPolicy := (*json.Policy)(p.ast) return jsonPolicy.MarshalJSON() } @@ -70,8 +70,8 @@ func (p *Policy) UnmarshalJSON(b []byte) error { Position: Position{}, Annotations: newAnnotationsFromSlice(jsonPolicy.Annotations), Effect: Effect(jsonPolicy.Effect), - eval: eval.Compile(jsonPolicy.Policy), - ast: &jsonPolicy.Policy, + eval: eval.Compile((internalast.Policy)(jsonPolicy)), + ast: (*internalast.Policy)(&jsonPolicy), } return nil } From d7266c1c97ccd6c6ae96519eaa2fa56361dabedd Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 14 Aug 2024 13:50:36 -0600 Subject: [PATCH 117/216] internal/json: fix spacing Addresses IDX-142 Signed-off-by: philhassey --- internal/json/json_marshal.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/json/json_marshal.go b/internal/json/json_marshal.go index d94c5555..9dfe19f2 100644 --- a/internal/json/json_marshal.go +++ b/internal/json/json_marshal.go @@ -94,7 +94,7 @@ func strToJSON(dest **strJSON, src ast.StrOpNode) { *dest = res } -func patternToJSON(dest **patternJSON, src ast.NodeTypeLike){ +func patternToJSON(dest **patternJSON, src ast.NodeTypeLike) { res := &patternJSON{} res.Left.FromNode(src.Arg) for _, comp := range src.Value.Components { @@ -108,7 +108,7 @@ func patternToJSON(dest **patternJSON, src ast.NodeTypeLike){ *dest = res } -func recordToJSON(dest *recordJSON, src ast.NodeTypeRecord) { +func recordToJSON(dest *recordJSON, src ast.NodeTypeRecord) { res := recordJSON{} for _, kv := range src.Elements { var nn nodeJSON @@ -118,7 +118,7 @@ func recordToJSON(dest *recordJSON, src ast.NodeTypeRecord) { *dest = res } -func ifToJSON(dest **ifThenElseJSON, src ast.NodeTypeIf) { +func ifToJSON(dest **ifThenElseJSON, src ast.NodeTypeIf) { res := &ifThenElseJSON{} res.If.FromNode(src.If) res.Then.FromNode(src.Then) @@ -126,14 +126,14 @@ func ifToJSON(dest **ifThenElseJSON, src ast.NodeTypeIf) { *dest = res } -func isToJSON(dest **isJSON, src ast.NodeTypeIs) { +func isToJSON(dest **isJSON, src ast.NodeTypeIs) { res := &isJSON{} res.Left.FromNode(src.Left) res.EntityType = string(src.EntityType) *dest = res } -func isInToJSON(dest **isJSON, src ast.NodeTypeIsIn) { +func isInToJSON(dest **isJSON, src ast.NodeTypeIsIn) { res := &isJSON{} res.Left.FromNode(src.Left) res.EntityType = string(src.EntityType) From a1482632eb331ff7a9a0c9b243e796a74145b2ce Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 14 Aug 2024 13:52:12 -0600 Subject: [PATCH 118/216] internal/json: appease linter Addresses IDX-142 Signed-off-by: philhassey --- internal/json/json_unmarshal.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 46c87bc4..6f30f09c 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -135,6 +135,7 @@ func (e extensionCallJSON) ToNode() (ast.Node, error) { var k string var v arrayJSON for k, v = range e { + _, _ = k, v } if len(v) == 0 { return ast.Node{}, fmt.Errorf("extension '%v' must have at least one argument", k) From 231cd1ab00b62f9baad435d44cb6c07aec35078d Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 14 Aug 2024 13:59:23 -0600 Subject: [PATCH 119/216] internal/json: add coverage for marshal type panics Addresses IDX-142 Signed-off-by: philhassey --- internal/json/json_test.go | 20 ++++++++++++++++++++ internal/testutil/testutil.go | 9 +++++++++ 2 files changed, 29 insertions(+) diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 38cd5c05..dbf49b39 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -565,3 +565,23 @@ func mustParseIPAddr(v string) types.IPAddr { res, _ := types.ParseIPAddr(v) return res } + +func TestMarshalPanics(t *testing.T) { + t.Parallel() + t.Run("nilScope", func(t *testing.T) { + t.Parallel() + testutil.AssertPanic(t, func() { + s := scopeJSON{} + var v ast.IsScopeNode + s.FromNode(v) + }) + }) + t.Run("nilNode", func(t *testing.T) { + t.Parallel() + testutil.AssertPanic(t, func() { + s := nodeJSON{} + var v ast.IsNode + s.FromNode(v) + }) + }) +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index 16e407d9..73ae253f 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -49,3 +49,12 @@ func Must[T any](obj T, err error) T { } return obj } + +func AssertPanic(t *testing.T, f func()) { + defer func() { + if e := recover(); e == nil { + t.Fatal("expected panic, got nil") + } + }() + f() +} From 0a258694e879f5074c2215e10b35cf7bd9b0faf7 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 14 Aug 2024 15:15:04 -0600 Subject: [PATCH 120/216] internal/json: hit full coverage in JSON Addresses IDX-142 Signed-off-by: philhassey --- internal/json/json.go | 4 +- internal/json/json_marshal.go | 8 +- internal/json/json_test.go | 153 ++++++++++++++++++++++++++++++++ internal/json/json_unmarshal.go | 5 +- 4 files changed, 160 insertions(+), 10 deletions(-) diff --git a/internal/json/json.go b/internal/json/json.go index 7de5a794..35b5c812 100644 --- a/internal/json/json.go +++ b/internal/json/json.go @@ -74,7 +74,7 @@ type arrayJSON []nodeJSON type recordJSON map[string]nodeJSON -type extensionCallJSON map[string]arrayJSON +type extensionJSON map[string]arrayJSON type valueJSON struct { v types.Value @@ -138,5 +138,5 @@ type nodeJSON struct { Record recordJSON `json:"Record,omitempty"` // Any other method: decimal, ip, lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange - ExtensionCall extensionCallJSON `json:"-"` + ExtensionCall extensionJSON `json:"-"` } diff --git a/internal/json/json_marshal.go b/internal/json/json_marshal.go index 9dfe19f2..d148ccce 100644 --- a/internal/json/json_marshal.go +++ b/internal/json/json_marshal.go @@ -69,19 +69,19 @@ func arrayToJSON(dest *arrayJSON, args []ast.IsNode) { *dest = res } -func extToJSON(dest *extensionCallJSON, name string, src types.Value) { +func extToJSON(dest *extensionJSON, name string, src types.Value) { res := arrayJSON{} str := src.String() // TODO: is this the correct string? val := valueJSON{v: types.String(str)} res = append(res, nodeJSON{ Value: &val, }) - *dest = extensionCallJSON{ + *dest = extensionJSON{ name: res, } } -func extCallToJSON(dest extensionCallJSON, src ast.NodeTypeExtensionCall) { +func extCallToJSON(dest extensionJSON, src ast.NodeTypeExtensionCall) { jsonArgs := arrayJSON{} arrayToJSON(&jsonArgs, src.Args) dest[string(src.Name)] = jsonArgs @@ -270,7 +270,7 @@ func (j *nodeJSON) FromNode(src ast.IsNode) { // Any other method: ip, decimal, lessThan, lessThanOrEqual, greaterThan, greaterThanOrEqual, isIpv4, isIpv6, isLoopback, isMulticast, isInRange // ExtensionMethod map[string]arrayJSON `json:"-"` case ast.NodeTypeExtensionCall: - j.ExtensionCall = extensionCallJSON{} + j.ExtensionCall = extensionJSON{} extCallToJSON(j.ExtensionCall, t) return default: diff --git a/internal/json/json_test.go b/internal/json/json_test.go index dbf49b39..3cdceca9 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -585,3 +585,156 @@ func TestMarshalPanics(t *testing.T) { }) }) } + +func TestUnmarshalErrors(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input string + errFunc func(testing.TB, error) + }{ + { + "effect", + `{"effect":"unknown","principal":{"op":"=="},"action":{"op":"All"},"resource":{"op":"All"}}`, + testutil.Error, + }, + { + "scopeEqMissingEntity", + `{"effect":"permit","principal":{"op":"=="},"action":{"op":"All"},"resource":{"op":"All"}}`, + testutil.Error, + }, + { + "scopeUnknownOp", + `{"effect":"permit","principal":{"op":"???"},"action":{"op":"All"},"resource":{"op":"All"}}`, + testutil.Error, + }, + { + "actionUnknownOp", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"???"},"resource":{"op":"All"}}`, + testutil.Error, + }, + { + "resourceUnknownOp", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"???"}}`, + testutil.Error, + }, + { + "conditionUnknown", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"unknown","body":{"Value":24}}]}`, + testutil.Error, + }, + { + "binaryLeft", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"&&":{"left":null,"right":{"Value":24}}}}]}`, + testutil.Error, + }, + { + "binaryRight", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"&&":{"left":{"Value":24},"right":null}}}]}`, + testutil.Error, + }, + { + "unaryArg", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"!":{"arg":null}}}]}`, + testutil.Error, + }, + { + "accessLeft", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{".":{"left":null,"attr":"key"}}}]}`, + testutil.Error, + }, + { + "patternLeft", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"like":{"left":null,"pattern":["Wildcard"]}}}]}`, + testutil.Error, + }, + { + "patternWildcard", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["invalid"]}}}]}`, + testutil.Error, + }, + { + "isLeft", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"is":{"left":null,"entity_type":"T"}}}]}`, + testutil.Error, + }, + { + "isIn", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"is":{"left":{"Var":"resource"},"entity_type":"T","in":{"Value":null}}}}]}`, + testutil.Error, + }, + { + "ifErrThenElse", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"if-then-else":{"if":{"Value":null},"then":{"Value":42},"else":{"Value":24}}}}]}`, + testutil.Error, + }, + { + "ifThenErrElse", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"if-then-else":{"if":{"Value":true},"then":{"Value":null},"else":{"Value":24}}}}]}`, + testutil.Error, + }, + { + "ifThenElseErr", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"if-then-else":{"if":{"Value":true},"then":{"Value":42},"else":{"Value":null}}}}]}`, + testutil.Error, + }, + { + "setErr", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Set":[{"Value":null},{"Value":"bananas"}]}}]}`, + testutil.Error, + }, + { + "recordErr", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Record":{"key":{"Value":null}}}}]}`, + testutil.Error, + }, + { + "extensionTooMany", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"ip":[{"Value":"10.0.0.42/8"}],"pi":[{"Value":"3.14"}]}}]}`, + testutil.Error, + }, + { + "extensionArg", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"ip":[{"Value":null}]}}]}`, + testutil.Error, + }, + { + "var", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"Var":"unknown"}}]}`, + testutil.Error, + }, + { + "otherJSONerror", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":42}]}`, + testutil.Error, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var p Policy + err := json.Unmarshal([]byte(tt.input), &p) + tt.errFunc(t, err) + }) + } +} diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 6f30f09c..6023001b 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -128,7 +128,7 @@ func (j recordJSON) ToNode() (ast.Node, error) { return ast.RecordNodes(nodes), nil } -func (e extensionCallJSON) ToNode() (ast.Node, error) { +func (e extensionJSON) ToNode() (ast.Node, error) { if len(e) != 1 { return ast.Node{}, fmt.Errorf("unexpected number of extensions in node: %v", len(e)) } @@ -137,9 +137,6 @@ func (e extensionCallJSON) ToNode() (ast.Node, error) { for k, v = range e { _, _ = k, v } - if len(v) == 0 { - return ast.Node{}, fmt.Errorf("extension '%v' must have at least one argument", k) - } var argNodes []ast.Node for _, n := range v { node, err := n.ToNode() From ce7d04311cb436fc88fbf0f665ba1dd755774b31 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 14 Aug 2024 14:41:05 -0700 Subject: [PATCH 121/216] cedar-go: turn PolicySet into a struct that contains a map from PolicyID to *Policy Signed-off-by: philhassey --- authorize.go | 16 +++++++------- authorize_test.go | 2 +- corpus_test.go | 34 +++++++++++++++--------------- internal/parser/cedar_unmarshal.go | 2 +- internal/parser/policy.go | 4 +++- policy_set.go | 24 +++++++++++++++------ policy_set_test.go | 2 +- 7 files changed, 48 insertions(+), 36 deletions(-) diff --git a/authorize.go b/authorize.go index 1862c2cf..538a8b58 100644 --- a/authorize.go +++ b/authorize.go @@ -40,19 +40,19 @@ type Diagnostic struct { // An Error details the Policy index within a PolicySet, the Position within the // text document, and the resulting error message. type Error struct { - Policy int `json:"policy"` + PolicyID PolicyID `json:"policy"` Position Position `json:"position"` Message string `json:"message"` } func (e Error) String() string { - return fmt.Sprintf("while evaluating policy `policy%d`: %v", e.Policy, e.Message) + return fmt.Sprintf("while evaluating policy `%v`: %v", e.PolicyID, e.Message) } // A Reason details the Policy index within a PolicySet, and the Position within // the text document. type Reason struct { - Policy int `json:"policy"` + PolicyID PolicyID `json:"policy"` Position Position `json:"position"` } @@ -89,26 +89,26 @@ func (p PolicySet) IsAuthorized(entityMap entities.Entities, req Request) (Decis // - All policy should be run to collect errors // - For permit, all permits must be run to collect annotations // - For forbid, forbids must be run to collect annotations - for n, po := range p { + for id, po := range p.policies { v, err := po.eval.Eval(c) if err != nil { - diag.Errors = append(diag.Errors, Error{Policy: n, Position: po.Position, Message: err.Error()}) + diag.Errors = append(diag.Errors, Error{PolicyID: id, Position: po.Position, Message: err.Error()}) continue } vb, err := types.ValueToBool(v) if err != nil { // should never happen, maybe remove this case - diag.Errors = append(diag.Errors, Error{Policy: n, Position: po.Position, Message: err.Error()}) + diag.Errors = append(diag.Errors, Error{PolicyID: id, Position: po.Position, Message: err.Error()}) continue } if !vb { continue } if po.Effect == Forbid { - forbidReasons = append(forbidReasons, Reason{Policy: n, Position: po.Position}) + forbidReasons = append(forbidReasons, Reason{PolicyID: id, Position: po.Position}) gotForbid = true } else { - permitReasons = append(permitReasons, Reason{Policy: n, Position: po.Position}) + permitReasons = append(permitReasons, Reason{PolicyID: id, Position: po.Position}) gotPermit = true } } diff --git a/authorize_test.go b/authorize_test.go index 201ee083..3622d6be 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -725,7 +725,7 @@ func TestIsAuthorized(t *testing.T) { func TestError(t *testing.T) { t.Parallel() - e := Error{Policy: 42, Message: "bad error"} + e := Error{PolicyID: "policy42", Message: "bad error"} testutil.Equals(t, e.String(), "while evaluating policy `policy42`: bad error") } diff --git a/corpus_test.go b/corpus_test.go index 7edf9712..93ddef7c 100644 --- a/corpus_test.go +++ b/corpus_test.go @@ -172,14 +172,14 @@ func TestCorpus(t *testing.T) { } var errors []string for _, n := range diag.Errors { - errors = append(errors, fmt.Sprintf("policy%d", n.Policy)) + errors = append(errors, string(n.PolicyID)) } if !slices.Equal(errors, request.Errors) { t.Errorf("errors got %v want %v", errors, request.Errors) } var reasons []string for _, n := range diag.Reasons { - reasons = append(reasons, fmt.Sprintf("policy%d", n.Policy)) + reasons = append(reasons, string(n.PolicyID)) } if !slices.Equal(reasons, request.Reasons) { t.Errorf("reasons got %v want %v", reasons, request.Reasons) @@ -198,8 +198,8 @@ func TestCorpusRelated(t *testing.T) { policy string request Request decision Decision - reasons []int - errors []int + reasons []PolicyID + errors []PolicyID }{ { "0cb1ad7042508e708f1999284b634ed0f334bc00", @@ -213,7 +213,7 @@ func TestCorpusRelated(t *testing.T) { Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, - []int{0}, + []PolicyID{"policy0"}, }, { @@ -228,7 +228,7 @@ func TestCorpusRelated(t *testing.T) { Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, - []int{0}, + []PolicyID{"policy0"}, }, { "0cb1ad7042508e708f1999284b634ed0f334bc00/partial2", @@ -242,7 +242,7 @@ func TestCorpusRelated(t *testing.T) { Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, - []int{0}, + []PolicyID{"policy0"}, }, { @@ -257,7 +257,7 @@ func TestCorpusRelated(t *testing.T) { Request{Principal: types.NewEntityUID("a", "\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\u0000")}, Deny, nil, - []int{0}, + []PolicyID{"policy0"}, }, { @@ -272,7 +272,7 @@ func TestCorpusRelated(t *testing.T) { Request{}, Deny, nil, - []int{0}, + []PolicyID{"policy0"}, }, { @@ -287,7 +287,7 @@ func TestCorpusRelated(t *testing.T) { Request{}, Deny, nil, - []int{0}, + []PolicyID{"policy0"}, }, {"48d0ba6537a3efe02112ba0f5a3daabdcad27b04", @@ -301,7 +301,7 @@ func TestCorpusRelated(t *testing.T) { Request{Principal: types.NewEntityUID("a", "\u0000\b\u0011\u0000R"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "\u0000\b\u0011\u0000R")}, Deny, nil, - []int{0}, + []PolicyID{"policy0"}, }, {"48d0ba6537a3efe02112ba0f5a3daabdcad27b04/simplified", @@ -315,7 +315,7 @@ func TestCorpusRelated(t *testing.T) { Request{}, Deny, nil, - []int{0}, + []PolicyID{"policy0"}, }, {name: "e91da4e6af5c73e27f5fb610d723dfa21635d10b", @@ -329,7 +329,7 @@ func TestCorpusRelated(t *testing.T) { request: Request{Principal: types.NewEntityUID("a", "\u0000\u0000(W\u0000\u0000\u0000"), Action: types.NewEntityUID("Action", "action"), Resource: types.NewEntityUID("a", "")}, decision: Deny, reasons: nil, - errors: []int{0}, + errors: []PolicyID{"policy0"}, }, } for _, tt := range tests { @@ -340,14 +340,14 @@ func TestCorpusRelated(t *testing.T) { testutil.OK(t, err) ok, diag := policy.IsAuthorized(entities2.Entities{}, tt.request) testutil.Equals(t, ok, tt.decision) - var reasons []int + var reasons []PolicyID for _, n := range diag.Reasons { - reasons = append(reasons, n.Policy) + reasons = append(reasons, n.PolicyID) } testutil.Equals(t, reasons, tt.reasons) - var errors []int + var errors []PolicyID for _, n := range diag.Errors { - errors = append(errors, n.Policy) + errors = append(errors, n.PolicyID) } testutil.Equals(t, errors, tt.errors) }) diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index dcf0fa51..2454d3b4 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -33,7 +33,7 @@ func (p *PolicySet) UnmarshalCedar(b []byte) error { return err } - policyName := fmt.Sprintf("policy%v", i) + policyName := PolicyID(fmt.Sprintf("policy%v", i)) policySet[policyName] = PolicySetEntry{Policy: policy, Position: pos} i++ } diff --git a/internal/parser/policy.go b/internal/parser/policy.go index 0e5c84b5..cc5b4418 100644 --- a/internal/parser/policy.go +++ b/internal/parser/policy.go @@ -2,7 +2,9 @@ package parser import "github.com/cedar-policy/cedar-go/internal/ast" -type PolicySet map[string]PolicySetEntry +type PolicyID string + +type PolicySet map[PolicyID]PolicySetEntry type PolicySetEntry struct { Policy Policy diff --git a/policy_set.go b/policy_set.go index c0b613d9..f61b8fe9 100644 --- a/policy_set.go +++ b/policy_set.go @@ -8,8 +8,12 @@ import ( "github.com/cedar-policy/cedar-go/internal/parser" ) +type PolicyID parser.PolicyID + // A PolicySet is a slice of policies. -type PolicySet []Policy +type PolicySet struct { + policies map[PolicyID]*Policy +} // NewPolicySet will create a PolicySet from the given text document with the // given file name used in Position data. If there is an error parsing the @@ -17,11 +21,11 @@ type PolicySet []Policy func NewPolicySet(fileName string, document []byte) (PolicySet, error) { var res parser.PolicySet if err := res.UnmarshalCedar(document); err != nil { - return nil, fmt.Errorf("parser error: %w", err) + return PolicySet{}, fmt.Errorf("parser error: %w", err) } - var policies PolicySet - for _, p := range res { - policies = append(policies, Policy{ + policyMap := make(map[PolicyID]*Policy, len(res)) + for name, p := range res { + policyMap[PolicyID(name)] = &Policy{ Position: Position{ Filename: fileName, Offset: p.Position.Offset, @@ -32,7 +36,13 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { Effect: Effect(p.Policy.Effect), eval: eval.Compile(p.Policy.Policy), ast: &p.Policy.Policy, - }) + } } - return policies, nil + return PolicySet{policies: policyMap}, nil +} + +// GetPolicy returns a pointer to the Policy with the given ID. If a policy with the given ID does not exist, nil is +// returned. +func (p PolicySet) GetPolicy(policyID PolicyID) *Policy { + return p.policies[policyID] } diff --git a/policy_set_test.go b/policy_set_test.go index e008afdb..fbdb4eb8 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -22,6 +22,6 @@ func TestNewPolicySet(t *testing.T) { t.Parallel() ps, err := NewPolicySet("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) testutil.OK(t, err) - testutil.Equals(t, ps[0].Annotations, Annotations{"key": "value"}) + testutil.Equals(t, ps.GetPolicy("policy0").Annotations, Annotations{"key": "value"}) }) } From baedc99253dd290b0ba6fec759888807e809b4e4 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 14 Aug 2024 15:18:22 -0700 Subject: [PATCH 122/216] cedar-go: add a way to create a PolicySet from a set of existing Policys Signed-off-by: philhassey --- policy_set.go | 20 ++++++++++++++++++++ policy_set_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/policy_set.go b/policy_set.go index f61b8fe9..27d73667 100644 --- a/policy_set.go +++ b/policy_set.go @@ -41,6 +41,26 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { return PolicySet{policies: policyMap}, nil } +// NewPolicySetFromPolicies will create a PolicySet from a slice of existing Policys. This constructor can be used to +// support the creation of a PolicySet from JSON-encoded policies or policies created via the ast package, like so: +// +// policy0 := NewPolicyFromAST(ast.Forbid()) +// +// var policy1 Policy +// _ = policy1.UnmarshalJSON( +// []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), +// )) +// +// ps := NewPolicySetFromPolicies([]*Policy{policy0, &policy1}) +func NewPolicySetFromPolicies(policies []*Policy) PolicySet { + policyMap := make(map[PolicyID]*Policy, len(policies)) + for i, p := range policies { + policyID := PolicyID(fmt.Sprintf("policy%d", i)) + policyMap[policyID] = p + } + return PolicySet{policies: policyMap} +} + // GetPolicy returns a pointer to the Policy with the given ID. If a policy with the given ID does not exist, nil is // returned. func (p PolicySet) GetPolicy(policyID PolicyID) *Policy { diff --git a/policy_set_test.go b/policy_set_test.go index fbdb4eb8..3de5bdb9 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -3,6 +3,7 @@ package cedar import ( "testing" + "github.com/cedar-policy/cedar-go/ast" "github.com/cedar-policy/cedar-go/internal/testutil" ) @@ -25,3 +26,31 @@ func TestNewPolicySet(t *testing.T) { testutil.Equals(t, ps.GetPolicy("policy0").Annotations, Annotations{"key": "value"}) }) } + +func TestNewPolicySetFromPolicies(t *testing.T) { + t.Parallel() + t.Run("empty slice", func(t *testing.T) { + t.Parallel() + + var policies []*Policy + ps := NewPolicySetFromPolicies(policies) + + testutil.Equals(t, ps.GetPolicy("policy0"), nil) + }) + t.Run("non-empty slice", func(t *testing.T) { + t.Parallel() + + policy0 := NewPolicyFromAST(ast.Forbid()) + + var policy1 Policy + testutil.OK(t, policy1.UnmarshalJSON( + []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), + )) + + ps := NewPolicySetFromPolicies([]*Policy{policy0, &policy1}) + + testutil.Equals(t, ps.GetPolicy("policy0"), policy0) + testutil.Equals(t, ps.GetPolicy("policy1"), &policy1) + testutil.Equals(t, ps.GetPolicy("policy2"), nil) + }) +} From 3938400a04d44fd93a76eb2845443b3b8b8943f1 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 14 Aug 2024 16:36:08 -0600 Subject: [PATCH 123/216] internal/ast: add coverage for internal ast package Addresses IDX-142 Signed-off-by: philhassey --- internal/ast/ast_test.go | 514 +++++++++++++++++++++++++++++++ internal/ast/internal_test.go | 33 ++ internal/ast/scope.go | 13 +- internal/eval/eval_compile.go | 6 +- internal/eval/eval_convert.go | 12 +- internal/parser/cedar_marshal.go | 18 +- 6 files changed, 571 insertions(+), 25 deletions(-) create mode 100644 internal/ast/ast_test.go create mode 100644 internal/ast/internal_test.go diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go new file mode 100644 index 00000000..52f2bfaa --- /dev/null +++ b/internal/ast/ast_test.go @@ -0,0 +1,514 @@ +package ast_test + +import ( + "net/netip" + "testing" + + "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +// These tests serve mostly as examples of how to translate from Cedar text into programmatic AST construction. They +// don't verify anything. +func TestAstExamples(t *testing.T) { + t.Parallel() + + johnny := types.NewEntityUID("User", "johnny") + sow := types.NewEntityUID("Action", "sow") + cast := types.NewEntityUID("Action", "cast") + + // @example("one") + // permit ( + // principal == User::"johnny" + // action in [Action::"sow", Action::"cast"] + // resource + // ) + // when { true } + // unless { false }; + _ = ast.Annotation("example", "one"). + Permit(). + PrincipalIsIn("User", johnny). + ActionInSet(sow, cast). + When(ast.True()). + Unless(ast.False()) + + // @example("two") + // forbid (principal, action, resource) + // when { resource.tags.contains("private") } + // unless { resource in principal.allowed_resources }; + private := types.String("private") + _ = ast.Annotation("example", "two"). + Forbid(). + When( + ast.Resource().Access("tags").Contains(ast.String(private)), + ). + Unless( + ast.Resource().In(ast.Principal().Access("allowed_resources")), + ) + + // forbid (principal, action, resource) + // when { {x: "value"}.x == "value" } + // when { {x: 1 + context.fooCount}.x == 3 } + // when { [1, (2 + 3) * 4, context.fooCount].contains(1) }; + simpleRecord := types.Record{ + "x": types.String("value"), + } + _ = ast.Forbid(). + When( + ast.Record(simpleRecord).Access("x").Equals(ast.String("value")), + ). + When( + ast.RecordNodes(map[types.String]ast.Node{ + "x": ast.Long(1).Plus(ast.Context().Access("fooCount")), + }).Access("x").Equals(ast.Long(3)), + ). + When( + ast.SetNodes( + ast.Long(1), + ast.Long(2).Plus(ast.Long(3)).Times(ast.Long(4)), + ast.Context().Access("fooCount"), + ).Contains(ast.Long(1)), + ) +} + +func TestASTByTable(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in *ast.Policy + out ast.Policy + }{ + { + "permit", + ast.Permit(), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + }, + { + "forbid", + ast.Forbid(), + ast.Policy{Effect: ast.EffectForbid, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + }, + { + "annotationPermit", + ast.Annotation("key", "value").Permit(), + ast.Policy{Annotations: []ast.AnnotationType{{Key: "key", Value: "value"}}, Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + }, + { + "annotationForbid", + ast.Annotation("key", "value").Forbid(), + ast.Policy{Annotations: []ast.AnnotationType{{Key: "key", Value: "value"}}, Effect: ast.EffectForbid, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + }, + { + "annotations", + ast.Annotation("key", "value").Annotation("abc", "xyz").Permit(), + ast.Policy{Annotations: []ast.AnnotationType{{Key: "key", Value: "value"}, {Key: "abc", Value: "xyz"}}, Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + }, + { + "policyAnnotate", + ast.Permit().Annotate("key", "value"), + ast.Policy{Annotations: []ast.AnnotationType{{Key: "key", Value: "value"}}, Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + }, + { + "when", + ast.Permit().When(ast.True()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Boolean(true)}}}, + }, + }, + { + "unless", + ast.Permit().Unless(ast.True()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionUnless, Body: ast.NodeValue{Value: types.Boolean(true)}}}, + }, + }, + { + "scopePrincipalEq", + ast.Permit().PrincipalEq(types.NewEntityUID("T", "42")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeEq{Entity: types.NewEntityUID("T", "42")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + }, + { + "scopePrincipalIn", + ast.Permit().PrincipalIn(types.NewEntityUID("T", "42")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIn{Entity: types.NewEntityUID("T", "42")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + }, + { + "scopePrincipalIs", + ast.Permit().PrincipalIs("T"), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIs{Type: types.Path("T")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + }, + { + "scopePrincipalIsIn", + ast.Permit().PrincipalIsIn("T", types.NewEntityUID("T", "42")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIsIn{Type: types.Path("T"), Entity: types.NewEntityUID("T", "42")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + }, + { + "scopeActionEq", + ast.Permit().ActionEq(types.NewEntityUID("T", "42")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeEq{Entity: types.NewEntityUID("T", "42")}, Resource: ast.ScopeTypeAll{}}, + }, + { + "scopeActionIn", + ast.Permit().ActionIn(types.NewEntityUID("T", "42")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeIn{Entity: types.NewEntityUID("T", "42")}, Resource: ast.ScopeTypeAll{}}, + }, + { + "scopeActionInSet", + ast.Permit().ActionInSet(types.NewEntityUID("T", "42"), types.NewEntityUID("T", "43")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeInSet{Entities: []types.EntityUID{types.NewEntityUID("T", "42"), types.NewEntityUID("T", "43")}}, Resource: ast.ScopeTypeAll{}}, + }, + { + "scopeResourceEq", + ast.Permit().ResourceEq(types.NewEntityUID("T", "42")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeEq{Entity: types.NewEntityUID("T", "42")}}, + }, + { + "scopeResourceIn", + ast.Permit().ResourceIn(types.NewEntityUID("T", "42")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIn{Entity: types.NewEntityUID("T", "42")}}, + }, + { + "scopeResourceIs", + ast.Permit().ResourceIs("T"), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIs{Type: types.Path("T")}}, + }, + { + "scopeResourceIsIn", + ast.Permit().ResourceIsIn("T", types.NewEntityUID("T", "42")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIsIn{Type: types.Path("T"), Entity: types.NewEntityUID("T", "42")}}, + }, + { + "variablePrincipal", + ast.Permit().When(ast.Principal()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeVariable{Name: "principal"}}}, + }, + }, + { + "variableAction", + ast.Permit().When(ast.Action()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeVariable{Name: "action"}}}, + }, + }, + { + "variableResource", + ast.Permit().When(ast.Resource()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeVariable{Name: "resource"}}}, + }, + }, + { + "variableContext", + ast.Permit().When(ast.Context()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeVariable{Name: "context"}}}, + }, + }, + { + "valueBoolFalse", + ast.Permit().When(ast.Boolean(false)), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Boolean(false)}}}, + }, + }, + { + "valueBoolTrue", + ast.Permit().When(ast.Boolean(true)), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Boolean(true)}}}, + }, + }, + { + "valueTrue", + ast.Permit().When(ast.True()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Boolean(true)}}}, + }, + }, + { + "valueFalse", + ast.Permit().When(ast.False()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Boolean(false)}}}, + }, + }, + { + "valueString", + ast.Permit().When(ast.String("cedar")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.String("cedar")}}}, + }, + }, + { + "valueLong", + ast.Permit().When(ast.Long(42)), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Long(42)}}}, + }, + }, + { + "valueSet", + ast.Permit().When(ast.Set(types.Set{types.Long(42), types.Long(43)})), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeSet{Elements: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}, + }, + }, + { + "valueSetNodes", + ast.Permit().When(ast.SetNodes(ast.Long(42), ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeSet{Elements: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}, + }, + }, + { + "valueRecord", + ast.Permit().When(ast.Record(types.Record{"key": types.Long(43)})), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeRecord{Elements: []ast.RecordElementNode{{Key: "key", Value: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + }, + { + "valueRecordNodes", + ast.Permit().When(ast.RecordNodes(map[types.String]ast.Node{"key": ast.Long(42)})), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeRecord{Elements: []ast.RecordElementNode{{Key: "key", Value: ast.NodeValue{Value: types.Long(42)}}}}}}, + }, + }, + { + "valueRecordElements", + ast.Permit().When(ast.RecordElements(ast.RecordElement{Key: "key", Value: ast.Long(42)})), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeRecord{Elements: []ast.RecordElementNode{{Key: "key", Value: ast.NodeValue{Value: types.Long(42)}}}}}}, + }, + }, + { + "valueEntityUID", + ast.Permit().When(ast.EntityUID(types.NewEntityUID("T", "42"))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.NewEntityUID("T", "42")}}}, + }, + }, + { + "valueDecimal", + ast.Permit().When(ast.Decimal(420000)), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Decimal(420000)}}}, + }, + }, + { + "valueIPAddr", + ast.Permit().When(ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16")))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.IPAddr(netip.MustParsePrefix("127.0.0.1/16"))}}}, + }, + }, + { + "extensionCall", + ast.Permit().When(ast.ExtensionCall("ip", ast.String("127.0.0.1"))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "ip", Args: []ast.IsNode{ast.NodeValue{Value: types.String("127.0.0.1")}}}}}, + }}, + { + "opEquals", + ast.Permit().When(ast.Long(42).Equals(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeEquals{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opNotEquals", + ast.Permit().When(ast.Long(42).NotEquals(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeNotEquals{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opLessThan", + ast.Permit().When(ast.Long(42).LessThan(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLessThan{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opLessThanOrEqual", + ast.Permit().When(ast.Long(42).LessThanOrEqual(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLessThanOrEqual{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opGreaterThan", + ast.Permit().When(ast.Long(42).GreaterThan(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeGreaterThan{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opGreaterThanOrEqual", + ast.Permit().When(ast.Long(42).GreaterThanOrEqual(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeGreaterThanOrEqual{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opLessThanExt", + ast.Permit().When(ast.Long(42).LessThanExt(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "lessThan", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opLessThanOrEqualExt", + ast.Permit().When(ast.Long(42).LessThanOrEqualExt(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "lessThanOrEqual", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opGreaterThanExt", + ast.Permit().When(ast.Long(42).GreaterThanExt(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "greaterThan", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opGreaterThanOrEqualExt", + ast.Permit().When(ast.Long(42).GreaterThanOrEqualExt(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "greaterThanOrEqual", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opLike", + ast.Permit().When(ast.Long(42).Like(types.Pattern{})), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLike{Arg: ast.NodeValue{Value: types.Long(42)}, Value: types.Pattern{}}}}}, + }, + { + "opAnd", + ast.Permit().When(ast.Long(42).And(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeAnd{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opOr", + ast.Permit().When(ast.Long(42).Or(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeOr{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opNot", + ast.Permit().When(ast.Not(ast.True())), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeNot{UnaryNode: ast.UnaryNode{Arg: ast.NodeValue{Value: types.Boolean(true)}}}}}}, + }, + { + "opIf", + ast.Permit().When(ast.If(ast.True(), ast.Long(42), ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIf{If: ast.NodeValue{Value: types.Boolean(true)}, Then: ast.NodeValue{Value: types.Long(42)}, Else: ast.NodeValue{Value: types.Long(43)}}}}}, + }, + { + "opPlus", + ast.Permit().When(ast.Long(42).Plus(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeAdd{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opMinus", + ast.Permit().When(ast.Long(42).Minus(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeSub{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opTimes", + ast.Permit().When(ast.Long(42).Times(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeMult{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opNegate", + ast.Permit().When(ast.Negate(ast.True())), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeNegate{UnaryNode: ast.UnaryNode{Arg: ast.NodeValue{Value: types.Boolean(true)}}}}}}, + }, + { + "opIn", + ast.Permit().When(ast.Long(42).In(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIn{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opIs", + ast.Permit().When(ast.Long(42).Is(types.Path("T"))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.Path("T")}}}}, + }, + { + "opIsIn", + ast.Permit().When(ast.Long(42).IsIn(types.Path("T"), ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIsIn{NodeTypeIs: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.Path("T")}, Entity: ast.NodeValue{Value: types.Long(43)}}}}}, + }, + { + "opContains", + ast.Permit().When(ast.Long(42).Contains(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeContains{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opContainsAll", + ast.Permit().When(ast.Long(42).ContainsAll(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeContainsAll{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opContainsAny", + ast.Permit().When(ast.Long(42).ContainsAny(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeContainsAny{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + { + "opAccess", + ast.Permit().When(ast.Long(42).Access("key")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeAccess{StrOpNode: ast.StrOpNode{Arg: ast.NodeValue{Value: types.Long(42)}, Value: "key"}}}}}, + }, + { + "opHas", + ast.Permit().When(ast.Long(42).Has("key")), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeHas{StrOpNode: ast.StrOpNode{Arg: ast.NodeValue{Value: types.Long(42)}, Value: "key"}}}}}, + }, + { + "opIsIpv4", + ast.Permit().When(ast.Long(42).IsIpv4()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "isIpv4", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}}}}}}, + }, + { + "opIsIpv6", + ast.Permit().When(ast.Long(42).IsIpv6()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "isIpv6", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}}}}}}, + }, + { + "opIsMulticast", + ast.Permit().When(ast.Long(42).IsMulticast()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "isMulticast", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}}}}}}, + }, + { + "opIsLoopback", + ast.Permit().When(ast.Long(42).IsLoopback()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "isLoopback", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}}}}}}, + }, + { + "opIsInRange", + ast.Permit().When(ast.Long(42).IsInRange(ast.Long(43))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "isInRange", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + testutil.Equals(t, tt.in, &tt.out) + }) + } +} diff --git a/internal/ast/internal_test.go b/internal/ast/internal_test.go new file mode 100644 index 00000000..02d21c86 --- /dev/null +++ b/internal/ast/internal_test.go @@ -0,0 +1,33 @@ +package ast + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestIsNode(t *testing.T) { + t.Parallel() + ScopeNode{}.isScope() + + StrOpNode{}.isNode() + BinaryNode{}.isNode() + NodeTypeIf{}.isNode() + NodeTypeLike{}.isNode() + NodeTypeIs{}.isNode() + UnaryNode{}.isNode() + NodeTypeExtensionCall{}.isNode() + NodeValue{}.isNode() + NodeTypeRecord{}.isNode() + NodeTypeSet{}.isNode() + NodeTypeVariable{}.isNode() + +} + +func TestAsNode(t *testing.T) { + t.Parallel() + n := NewNode(NodeValue{Value: types.Long(42)}) + v := n.AsIsNode() + testutil.Equals(t, v, (IsNode)(NodeValue{Value: types.Long(42)})) +} diff --git a/internal/ast/scope.go b/internal/ast/scope.go index 2da1ab79..521912a0 100644 --- a/internal/ast/scope.go +++ b/internal/ast/scope.go @@ -7,27 +7,27 @@ import ( type Scope NodeTypeVariable func (s Scope) All() IsScopeNode { - return ScopeTypeAll{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}} + return ScopeTypeAll{} } func (s Scope) Eq(entity types.EntityUID) IsScopeNode { - return ScopeTypeEq{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}, Entity: entity} + return ScopeTypeEq{Entity: entity} } func (s Scope) In(entity types.EntityUID) IsScopeNode { - return ScopeTypeIn{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}, Entity: entity} + return ScopeTypeIn{Entity: entity} } func (s Scope) InSet(entities []types.EntityUID) IsScopeNode { - return ScopeTypeInSet{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}, Entities: entities} + return ScopeTypeInSet{Entities: entities} } func (s Scope) Is(entityType types.Path) IsScopeNode { - return ScopeTypeIs{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}, Type: entityType} + return ScopeTypeIs{Type: entityType} } func (s Scope) IsIn(entityType types.Path, entity types.EntityUID) IsScopeNode { - return ScopeTypeIsIn{ScopeNode: ScopeNode{Variable: NodeTypeVariable(s)}, Type: entityType, Entity: entity} + return ScopeTypeIsIn{Type: entityType, Entity: entity} } func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { @@ -90,7 +90,6 @@ type IsScopeNode interface { } type ScopeNode struct { - Variable NodeTypeVariable } func (n ScopeNode) isScope() {} diff --git a/internal/eval/eval_compile.go b/internal/eval/eval_compile.go index b2406501..dd1d4847 100644 --- a/internal/eval/eval_compile.go +++ b/internal/eval/eval_compile.go @@ -11,9 +11,9 @@ func Compile(p ast.Policy) Evaler { func policyToNode(p ast.Policy) ast.Node { nodes := make([]ast.Node, 3+len(p.Conditions)) - nodes[0] = scopeToNode(p.Principal) - nodes[1] = scopeToNode(p.Action) - nodes[2] = scopeToNode(p.Resource) + nodes[0] = scopeToNode(ast.NewPrincipalNode(), p.Principal) + nodes[1] = scopeToNode(ast.NewActionNode(), p.Action) + nodes[2] = scopeToNode(ast.NewResourceNode(), p.Resource) for i, c := range p.Conditions { if c.Condition == ast.ConditionUnless { nodes[i+3] = ast.Not(ast.NewNode(c.Body)) diff --git a/internal/eval/eval_convert.go b/internal/eval/eval_convert.go index 734c9352..006c51fd 100644 --- a/internal/eval/eval_convert.go +++ b/internal/eval/eval_convert.go @@ -127,25 +127,25 @@ func toEval(n ast.IsNode) Evaler { } } -func scopeToNode(in ast.IsScopeNode) ast.Node { +func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { switch t := in.(type) { case ast.ScopeTypeAll: return ast.True() case ast.ScopeTypeEq: - return ast.NewNode(t.Variable).Equals(ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).Equals(ast.EntityUID(t.Entity)) case ast.ScopeTypeIn: - return ast.NewNode(t.Variable).In(ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).In(ast.EntityUID(t.Entity)) case ast.ScopeTypeInSet: set := make([]types.Value, len(t.Entities)) for i, e := range t.Entities { set[i] = e } - return ast.NewNode(t.Variable).In(ast.Set(set)) + return ast.NewNode(varNode).In(ast.Set(set)) case ast.ScopeTypeIs: - return ast.NewNode(t.Variable).Is(t.Type) + return ast.NewNode(varNode).Is(t.Type) case ast.ScopeTypeIsIn: - return ast.NewNode(t.Variable).IsIn(t.Type, ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).IsIn(t.Type, ast.EntityUID(t.Entity)) default: panic(fmt.Sprintf("unknown scope type %T", t)) } diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index d1ebf1b2..51032d7e 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -28,25 +28,25 @@ func (p *Policy) MarshalCedar(buf *bytes.Buffer) { // scopeToNode is copied in from eval, with the expectation that // eval will not be using it in the future. -func scopeToNode(in ast.IsScopeNode) ast.Node { +func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { switch t := in.(type) { case ast.ScopeTypeAll: return ast.True() case ast.ScopeTypeEq: - return ast.NewNode(t.Variable).Equals(ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).Equals(ast.EntityUID(t.Entity)) case ast.ScopeTypeIn: - return ast.NewNode(t.Variable).In(ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).In(ast.EntityUID(t.Entity)) case ast.ScopeTypeInSet: set := make([]types.Value, len(t.Entities)) for i, e := range t.Entities { set[i] = e } - return ast.NewNode(t.Variable).In(ast.Set(set)) + return ast.NewNode(varNode).In(ast.Set(set)) case ast.ScopeTypeIs: - return ast.NewNode(t.Variable).Is(t.Type) + return ast.NewNode(varNode).Is(t.Type) case ast.ScopeTypeIsIn: - return ast.NewNode(t.Variable).IsIn(t.Type, ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).IsIn(t.Type, ast.EntityUID(t.Entity)) default: panic(fmt.Sprintf("unknown scope type %T", t)) } @@ -65,19 +65,19 @@ func (p *Policy) marshalScope(buf *bytes.Buffer) { if principalAll { buf.WriteString("principal") } else { - astNodeToMarshalNode(scopeToNode(p.Policy.Principal).AsIsNode()).marshalCedar(buf) + astNodeToMarshalNode(scopeToNode(ast.NewPrincipalNode(), p.Policy.Principal).AsIsNode()).marshalCedar(buf) } buf.WriteString(",\n ") if actionAll { buf.WriteString("action") } else { - astNodeToMarshalNode(scopeToNode(p.Policy.Action).AsIsNode()).marshalCedar(buf) + astNodeToMarshalNode(scopeToNode(ast.NewActionNode(), p.Policy.Action).AsIsNode()).marshalCedar(buf) } buf.WriteString(",\n ") if resourceAll { buf.WriteString("resource") } else { - astNodeToMarshalNode(scopeToNode(p.Policy.Resource).AsIsNode()).marshalCedar(buf) + astNodeToMarshalNode(scopeToNode(ast.NewResourceNode(), p.Policy.Resource).AsIsNode()).marshalCedar(buf) } buf.WriteString("\n)") } From 82ce28b687d95c1d33fa4af59fc523e0408a72bb Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 14 Aug 2024 15:40:56 -0700 Subject: [PATCH 124/216] cedar-go/internal/ast: eliminate the notion of policy IDs at the AST level Signed-off-by: philhassey --- internal/parser/cedar_parse_test.go | 10 +++++----- internal/parser/cedar_unmarshal.go | 8 ++------ internal/parser/cedar_unmarshal_test.go | 6 +++--- internal/parser/policy.go | 4 +--- policy_set.go | 7 ++++--- 5 files changed, 15 insertions(+), 20 deletions(-) diff --git a/internal/parser/cedar_parse_test.go b/internal/parser/cedar_parse_test.go index 5e1f1277..5e8e25c6 100644 --- a/internal/parser/cedar_parse_test.go +++ b/internal/parser/cedar_parse_test.go @@ -307,7 +307,7 @@ func TestParse(t *testing.T) { } var buf bytes.Buffer - pp := policies["policy0"].Policy + pp := policies[0].Policy pp.MarshalCedar(&buf) var p2 parser.PolicySet @@ -315,7 +315,7 @@ func TestParse(t *testing.T) { testutil.OK(t, err) // TODO: support 0, > 1 - testutil.Equals(t, p2["policy0"].Policy, policies["policy0"].Policy) + testutil.Equals(t, p2[0].Policy, policies[0].Policy) }) } @@ -339,7 +339,7 @@ permit( principal, action, resource ); err := out.UnmarshalCedar([]byte(in)) testutil.OK(t, err) testutil.Equals(t, len(out), 3) - testutil.Equals(t, out["policy0"].Position, parser.Position{Offset: 17, Line: 2, Column: 1}) - testutil.Equals(t, out["policy1"].Position, parser.Position{Offset: 86, Line: 7, Column: 3}) - testutil.Equals(t, out["policy2"].Position, parser.Position{Offset: 148, Line: 10, Column: 2}) + testutil.Equals(t, out[0].Position, parser.Position{Offset: 17, Line: 2, Column: 1}) + testutil.Equals(t, out[1].Position, parser.Position{Offset: 86, Line: 7, Column: 3}) + testutil.Equals(t, out[2].Position, parser.Position{Offset: 148, Line: 10, Column: 2}) } diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 2454d3b4..1e6f8b5f 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -15,9 +15,7 @@ func (p *PolicySet) UnmarshalCedar(b []byte) error { return err } - i := 0 - - policySet := PolicySet{} + var policySet PolicySet parser := newParser(tokens) for !parser.peek().isEOF() { pos := parser.peek().Pos @@ -33,9 +31,7 @@ func (p *PolicySet) UnmarshalCedar(b []byte) error { return err } - policyName := PolicyID(fmt.Sprintf("policy%v", i)) - policySet[policyName] = PolicySetEntry{Policy: policy, Position: pos} - i++ + policySet = append(policySet, PolicySetEntry{Policy: policy, Position: pos}) } *p = policySet diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 91c3bb51..c97d845d 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -460,7 +460,7 @@ func TestParsePolicySet(t *testing.T) { resource );`, parser.PolicySet{ - "policy0": parser.PolicySetEntry{ + parser.PolicySetEntry{ parser.Policy{*ast.Permit()}, parser.Position{Offset: 0, Line: 1, Column: 1}, }, @@ -479,11 +479,11 @@ func TestParsePolicySet(t *testing.T) { resource );`, parser.PolicySet{ - "policy0": parser.PolicySetEntry{ + parser.PolicySetEntry{ parser.Policy{*ast.Permit()}, parser.Position{Offset: 0, Line: 1, Column: 1}, }, - "policy1": parser.PolicySetEntry{ + parser.PolicySetEntry{ parser.Policy{*ast.Forbid()}, parser.Position{Offset: 53, Line: 6, Column: 3}, }, diff --git a/internal/parser/policy.go b/internal/parser/policy.go index cc5b4418..775754fb 100644 --- a/internal/parser/policy.go +++ b/internal/parser/policy.go @@ -2,9 +2,7 @@ package parser import "github.com/cedar-policy/cedar-go/internal/ast" -type PolicyID string - -type PolicySet map[PolicyID]PolicySetEntry +type PolicySet []PolicySetEntry type PolicySetEntry struct { Policy Policy diff --git a/policy_set.go b/policy_set.go index 27d73667..425d73ae 100644 --- a/policy_set.go +++ b/policy_set.go @@ -8,7 +8,7 @@ import ( "github.com/cedar-policy/cedar-go/internal/parser" ) -type PolicyID parser.PolicyID +type PolicyID string // A PolicySet is a slice of policies. type PolicySet struct { @@ -24,8 +24,9 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { return PolicySet{}, fmt.Errorf("parser error: %w", err) } policyMap := make(map[PolicyID]*Policy, len(res)) - for name, p := range res { - policyMap[PolicyID(name)] = &Policy{ + for i, p := range res { + policyID := PolicyID(fmt.Sprintf("policy%d", i)) + policyMap[policyID] = &Policy{ Position: Position{ Filename: fileName, Offset: p.Position.Offset, From d60e1cad97b6a257c16a04cf84a3050559be1343 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 14 Aug 2024 15:51:02 -0700 Subject: [PATCH 125/216] cedar-go: add the ability to upsert a Policy into a PolicySet Signed-off-by: philhassey --- policy_set.go | 9 +++++++++ policy_set_test.go | 30 +++++++++++++++++++++++++++--- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/policy_set.go b/policy_set.go index 425d73ae..3663c09c 100644 --- a/policy_set.go +++ b/policy_set.go @@ -53,6 +53,10 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { // )) // // ps := NewPolicySetFromPolicies([]*Policy{policy0, &policy1}) +// +// NewPolicySetFromPolicies assigns default PolicyIDs to the policies that are passed. If you would like to assign your +// own PolicyIDs, call NewPolicySetFromPolicies() with an empty slice and use PolicySet.UpsertPolicy() to add the +// policies individually with the desired PolicyID. func NewPolicySetFromPolicies(policies []*Policy) PolicySet { policyMap := make(map[PolicyID]*Policy, len(policies)) for i, p := range policies { @@ -67,3 +71,8 @@ func NewPolicySetFromPolicies(policies []*Policy) PolicySet { func (p PolicySet) GetPolicy(policyID PolicyID) *Policy { return p.policies[policyID] } + +// UpsertPolicy inserts or updates a policy with the given ID. +func (p *PolicySet) UpsertPolicy(policyID PolicyID, policy *Policy) { + p.policies[policyID] = policy +} diff --git a/policy_set_test.go b/policy_set_test.go index 3de5bdb9..9185ecdf 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -32,9 +32,7 @@ func TestNewPolicySetFromPolicies(t *testing.T) { t.Run("empty slice", func(t *testing.T) { t.Parallel() - var policies []*Policy - ps := NewPolicySetFromPolicies(policies) - + ps := NewPolicySetFromPolicies(nil) testutil.Equals(t, ps.GetPolicy("policy0"), nil) }) t.Run("non-empty slice", func(t *testing.T) { @@ -54,3 +52,29 @@ func TestNewPolicySetFromPolicies(t *testing.T) { testutil.Equals(t, ps.GetPolicy("policy2"), nil) }) } + +func TestUpsertPolicy(t *testing.T) { + t.Parallel() + t.Run("insert", func(t *testing.T) { + t.Parallel() + + ps := NewPolicySetFromPolicies(nil) + p := NewPolicyFromAST(ast.Forbid()) + ps.UpsertPolicy("a very strict policy", p) + + testutil.Equals(t, ps.GetPolicy("a very strict policy"), p) + }) + t.Run("upsert", func(t *testing.T) { + t.Parallel() + + ps := NewPolicySetFromPolicies(nil) + + p1 := NewPolicyFromAST(ast.Forbid()) + ps.UpsertPolicy("a wavering policy", p1) + + p2 := NewPolicyFromAST(ast.Permit()) + ps.UpsertPolicy("a wavering policy", p2) + + testutil.Equals(t, ps.GetPolicy("a wavering policy"), p2) + }) +} From 3e899031833e5b502bfcca6b3a71232e2fb3de45 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 14 Aug 2024 16:19:16 -0700 Subject: [PATCH 126/216] cedar-go/internal/ast: move Position into the ast Node This could theoretically be usable in other instances such as policies parsed from JSON. Also, in the future, we're probably going to want addition Position information on individual nodes as well. Signed-off-by: philhassey --- internal/ast/policy.go | 9 ++++ internal/parser/cedar_parse_test.go | 13 +++-- internal/parser/cedar_tokenize.go | 9 +--- internal/parser/cedar_unmarshal.go | 20 +++----- internal/parser/cedar_unmarshal_test.go | 65 ++++++++++--------------- internal/parser/policy.go | 7 +-- policy_set.go | 4 +- 7 files changed, 55 insertions(+), 72 deletions(-) diff --git a/internal/ast/policy.go b/internal/ast/policy.go index 419849e7..24382b25 100644 --- a/internal/ast/policy.go +++ b/internal/ast/policy.go @@ -27,6 +27,14 @@ const ( EffectForbid Effect = false ) +// Position is a value that represents a source Position. +// A Position is valid if Line > 0. +type Position struct { + Offset int // byte offset, starting at 0 + Line int // line number, starting at 1 + Column int // column number, starting at 1 (character count per line) +} + type Policy struct { Effect Effect Annotations []AnnotationType @@ -34,6 +42,7 @@ type Policy struct { Action IsScopeNode Resource IsScopeNode Conditions []ConditionType + Position Position } func newPolicy(effect Effect, annotations []AnnotationType) *Policy { diff --git a/internal/parser/cedar_parse_test.go b/internal/parser/cedar_parse_test.go index 5e8e25c6..9ba2a97d 100644 --- a/internal/parser/cedar_parse_test.go +++ b/internal/parser/cedar_parse_test.go @@ -4,6 +4,7 @@ import ( "bytes" "testing" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/parser" "github.com/cedar-policy/cedar-go/internal/testutil" ) @@ -306,8 +307,12 @@ func TestParse(t *testing.T) { return } + // N.B. Until we support the re-rendering of comments, we have to ignore the position for the purposes of + // these tests (see test "ex1") + policies[0].Position = ast.Position{Offset: 0, Line: 1, Column: 1} + var buf bytes.Buffer - pp := policies[0].Policy + pp := policies[0] pp.MarshalCedar(&buf) var p2 parser.PolicySet @@ -339,7 +344,7 @@ permit( principal, action, resource ); err := out.UnmarshalCedar([]byte(in)) testutil.OK(t, err) testutil.Equals(t, len(out), 3) - testutil.Equals(t, out[0].Position, parser.Position{Offset: 17, Line: 2, Column: 1}) - testutil.Equals(t, out[1].Position, parser.Position{Offset: 86, Line: 7, Column: 3}) - testutil.Equals(t, out[2].Position, parser.Position{Offset: 148, Line: 10, Column: 2}) + testutil.Equals(t, out[0].Position, ast.Position{Offset: 17, Line: 2, Column: 1}) + testutil.Equals(t, out[1].Position, ast.Position{Offset: 86, Line: 7, Column: 3}) + testutil.Equals(t, out[2].Position, ast.Position{Offset: 148, Line: 10, Column: 2}) } diff --git a/internal/parser/cedar_tokenize.go b/internal/parser/cedar_tokenize.go index 7a034df0..67b8501f 100644 --- a/internal/parser/cedar_tokenize.go +++ b/internal/parser/cedar_tokenize.go @@ -8,6 +8,7 @@ import ( "strings" "unicode/utf8" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/rust" ) @@ -76,13 +77,7 @@ func Tokenize(src []byte) ([]Token, error) { return res, nil } -// Position is a value that represents a source Position. -// A Position is valid if Line > 0. -type Position struct { - Offset int // byte offset, starting at 0 - Line int // line number, starting at 1 - Column int // column number, starting at 1 (character count per line) -} +type Position ast.Position func (pos Position) String() string { return fmt.Sprintf(":%d:%d", pos.Line, pos.Column) diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 1e6f8b5f..9dc57011 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -18,20 +18,12 @@ func (p *PolicySet) UnmarshalCedar(b []byte) error { var policySet PolicySet parser := newParser(tokens) for !parser.peek().isEOF() { - pos := parser.peek().Pos - policy := Policy{ - ast.Policy{ - Principal: ast.ScopeTypeAll{}, - Action: ast.ScopeTypeAll{}, - Resource: ast.ScopeTypeAll{}, - }, - } - - if err = policy.fromCedarWithParser(&parser); err != nil { + var policy Policy + if err = policy.fromCedar(&parser); err != nil { return err } - policySet = append(policySet, PolicySetEntry{Policy: policy, Position: pos}) + policySet = append(policySet, policy) } *p = policySet @@ -45,10 +37,11 @@ func (p *Policy) UnmarshalCedar(b []byte) error { } parser := newParser(tokens) - return p.fromCedarWithParser(&parser) + return p.fromCedar(&parser) } -func (p *Policy) fromCedarWithParser(parser *parser) error { +func (p *Policy) fromCedar(parser *parser) error { + pos := parser.peek().Pos annotations, err := parser.annotations() if err != nil { return err @@ -58,6 +51,7 @@ func (p *Policy) fromCedarWithParser(parser *parser) error { if err != nil { return err } + newPolicy.Position = (ast.Position)(pos) if err = parser.exact("("); err != nil { return err diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index c97d845d..63ffd49f 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -436,6 +436,7 @@ when { (if true then 2 else 3) * 4 == 8 };`, var policy parser.Policy testutil.OK(t, policy.UnmarshalCedar([]byte(tt.Text))) + policy.Position = ast.Position{} testutil.Equals(t, policy, parser.Policy{*tt.ExpectedPolicy}) var buf bytes.Buffer @@ -447,28 +448,22 @@ when { (if true then 2 else 3) * 4 == 8 };`, func TestParsePolicySet(t *testing.T) { t.Parallel() - parseTests := []struct { - Name string - Text string - ExpectedPolicies parser.PolicySet - }{ - { - "single policy", - `permit ( + t.Run("single policy", func(t *testing.T) { + policyStr := []byte(`permit ( principal, action, resource - );`, - parser.PolicySet{ - parser.PolicySetEntry{ - parser.Policy{*ast.Permit()}, - parser.Position{Offset: 0, Line: 1, Column: 1}, - }, - }, - }, - { - "two policies", - `permit ( + );`) + + var policies parser.PolicySet + testutil.OK(t, policies.UnmarshalCedar(policyStr)) + + expectedPolicy := ast.Permit() + expectedPolicy.Position = ast.Position{Offset: 0, Line: 1, Column: 1} + testutil.Equals(t, &policies[0].Policy, expectedPolicy) + }) + t.Run("two policies", func(t *testing.T) { + policyStr := []byte(`permit ( principal, action, resource @@ -477,26 +472,16 @@ func TestParsePolicySet(t *testing.T) { principal, action, resource - );`, - parser.PolicySet{ - parser.PolicySetEntry{ - parser.Policy{*ast.Permit()}, - parser.Position{Offset: 0, Line: 1, Column: 1}, - }, - parser.PolicySetEntry{ - parser.Policy{*ast.Forbid()}, - parser.Position{Offset: 53, Line: 6, Column: 3}, - }, - }, - }, - } - for _, tt := range parseTests { - t.Run(tt.Name, func(t *testing.T) { - t.Parallel() + );`) + var policies parser.PolicySet + testutil.OK(t, policies.UnmarshalCedar(policyStr)) - var policies parser.PolicySet - testutil.OK(t, policies.UnmarshalCedar([]byte(tt.Text))) - testutil.Equals(t, policies, tt.ExpectedPolicies) - }) - } + expectedPolicy0 := ast.Permit() + expectedPolicy0.Position = ast.Position{Offset: 0, Line: 1, Column: 1} + testutil.Equals(t, &policies[0].Policy, expectedPolicy0) + + expectedPolicy1 := ast.Forbid() + expectedPolicy1.Position = ast.Position{Offset: 53, Line: 6, Column: 3} + testutil.Equals(t, &policies[1].Policy, expectedPolicy1) + }) } diff --git a/internal/parser/policy.go b/internal/parser/policy.go index 775754fb..d4b17c23 100644 --- a/internal/parser/policy.go +++ b/internal/parser/policy.go @@ -2,12 +2,7 @@ package parser import "github.com/cedar-policy/cedar-go/internal/ast" -type PolicySet []PolicySetEntry - -type PolicySetEntry struct { - Policy Policy - Position Position -} +type PolicySet []Policy type Policy struct { ast.Policy diff --git a/policy_set.go b/policy_set.go index 3663c09c..f68166b4 100644 --- a/policy_set.go +++ b/policy_set.go @@ -35,8 +35,8 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { }, Annotations: newAnnotationsFromSlice(p.Policy.Annotations), Effect: Effect(p.Policy.Effect), - eval: eval.Compile(p.Policy.Policy), - ast: &p.Policy.Policy, + eval: eval.Compile(p.Policy), + ast: &p.Policy, } } return PolicySet{policies: policyMap}, nil From f4a5105fd1551ce3b112e3dc44a7937de4b1866e Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 14 Aug 2024 16:33:24 -0700 Subject: [PATCH 127/216] cedar-go/internal/parser: convert parser.Policy into a newtype for ast.Policy rather than using embedding Signed-off-by: philhassey --- internal/parser/cedar_marshal.go | 18 +++++++++--------- internal/parser/cedar_parse_test.go | 2 +- internal/parser/cedar_unmarshal.go | 4 ++-- internal/parser/cedar_unmarshal_test.go | 8 ++++---- internal/parser/policy.go | 7 ++----- policy.go | 6 +++--- policy_set.go | 9 +++++---- 7 files changed, 26 insertions(+), 28 deletions(-) diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index 51032d7e..06dd0648 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -10,15 +10,15 @@ import ( ) func (p *Policy) MarshalCedar(buf *bytes.Buffer) { - for _, a := range p.Policy.Annotations { + for _, a := range p.Annotations { marshalAnnotation(a, buf) buf.WriteRune('\n') } - marshalEffect(p.Policy.Effect, buf) + marshalEffect(p.Effect, buf) buf.WriteRune(' ') p.marshalScope(buf) - for _, c := range p.Policy.Conditions { + for _, c := range p.Conditions { buf.WriteRune('\n') marshalCondition(c, buf) } @@ -53,9 +53,9 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { } func (p *Policy) marshalScope(buf *bytes.Buffer) { - _, principalAll := p.Policy.Principal.(ast.ScopeTypeAll) - _, actionAll := p.Policy.Action.(ast.ScopeTypeAll) - _, resourceAll := p.Policy.Resource.(ast.ScopeTypeAll) + _, principalAll := p.Principal.(ast.ScopeTypeAll) + _, actionAll := p.Action.(ast.ScopeTypeAll) + _, resourceAll := p.Resource.(ast.ScopeTypeAll) if principalAll && actionAll && resourceAll { buf.WriteString("( principal, action, resource )") return @@ -65,19 +65,19 @@ func (p *Policy) marshalScope(buf *bytes.Buffer) { if principalAll { buf.WriteString("principal") } else { - astNodeToMarshalNode(scopeToNode(ast.NewPrincipalNode(), p.Policy.Principal).AsIsNode()).marshalCedar(buf) + astNodeToMarshalNode(scopeToNode(ast.NewPrincipalNode(), p.Principal).AsIsNode()).marshalCedar(buf) } buf.WriteString(",\n ") if actionAll { buf.WriteString("action") } else { - astNodeToMarshalNode(scopeToNode(ast.NewActionNode(), p.Policy.Action).AsIsNode()).marshalCedar(buf) + astNodeToMarshalNode(scopeToNode(ast.NewActionNode(), p.Action).AsIsNode()).marshalCedar(buf) } buf.WriteString(",\n ") if resourceAll { buf.WriteString("resource") } else { - astNodeToMarshalNode(scopeToNode(ast.NewResourceNode(), p.Policy.Resource).AsIsNode()).marshalCedar(buf) + astNodeToMarshalNode(scopeToNode(ast.NewResourceNode(), p.Resource).AsIsNode()).marshalCedar(buf) } buf.WriteString("\n)") } diff --git a/internal/parser/cedar_parse_test.go b/internal/parser/cedar_parse_test.go index 9ba2a97d..6737dbbe 100644 --- a/internal/parser/cedar_parse_test.go +++ b/internal/parser/cedar_parse_test.go @@ -320,7 +320,7 @@ func TestParse(t *testing.T) { testutil.OK(t, err) // TODO: support 0, > 1 - testutil.Equals(t, p2[0].Policy, policies[0].Policy) + testutil.Equals(t, p2[0], policies[0]) }) } diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 9dc57011..120c4302 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -23,7 +23,7 @@ func (p *PolicySet) UnmarshalCedar(b []byte) error { return err } - policySet = append(policySet, policy) + policySet = append(policySet, &policy) } *p = policySet @@ -81,7 +81,7 @@ func (p *Policy) fromCedar(parser *parser) error { return err } - *p = Policy{*newPolicy} + *p = *(*Policy)(newPolicy) return nil } diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 63ffd49f..e02b96f9 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -437,7 +437,7 @@ when { (if true then 2 else 3) * 4 == 8 };`, var policy parser.Policy testutil.OK(t, policy.UnmarshalCedar([]byte(tt.Text))) policy.Position = ast.Position{} - testutil.Equals(t, policy, parser.Policy{*tt.ExpectedPolicy}) + testutil.Equals(t, &policy, (*parser.Policy)(tt.ExpectedPolicy)) var buf bytes.Buffer policy.MarshalCedar(&buf) @@ -460,7 +460,7 @@ func TestParsePolicySet(t *testing.T) { expectedPolicy := ast.Permit() expectedPolicy.Position = ast.Position{Offset: 0, Line: 1, Column: 1} - testutil.Equals(t, &policies[0].Policy, expectedPolicy) + testutil.Equals(t, policies[0], (*parser.Policy)(expectedPolicy)) }) t.Run("two policies", func(t *testing.T) { policyStr := []byte(`permit ( @@ -478,10 +478,10 @@ func TestParsePolicySet(t *testing.T) { expectedPolicy0 := ast.Permit() expectedPolicy0.Position = ast.Position{Offset: 0, Line: 1, Column: 1} - testutil.Equals(t, &policies[0].Policy, expectedPolicy0) + testutil.Equals(t, policies[0], (*parser.Policy)(expectedPolicy0)) expectedPolicy1 := ast.Forbid() expectedPolicy1.Position = ast.Position{Offset: 53, Line: 6, Column: 3} - testutil.Equals(t, &policies[1].Policy, expectedPolicy1) + testutil.Equals(t, policies[1], (*parser.Policy)(expectedPolicy1)) }) } diff --git a/internal/parser/policy.go b/internal/parser/policy.go index d4b17c23..ba489f58 100644 --- a/internal/parser/policy.go +++ b/internal/parser/policy.go @@ -2,8 +2,5 @@ package parser import "github.com/cedar-policy/cedar-go/internal/ast" -type PolicySet []Policy - -type Policy struct { - ast.Policy -} +type PolicySet []*Policy +type Policy ast.Policy diff --git a/policy.go b/policy.go index 44d424f4..18631dc1 100644 --- a/policy.go +++ b/policy.go @@ -77,7 +77,7 @@ func (p *Policy) UnmarshalJSON(b []byte) error { } func (p *Policy) MarshalCedar(buf *bytes.Buffer) { - cedarPolicy := &parser.Policy{Policy: *p.ast} + cedarPolicy := (*parser.Policy)(p.ast) cedarPolicy.MarshalCedar(buf) } @@ -91,8 +91,8 @@ func (p *Policy) UnmarshalCedar(b []byte) error { Position: Position{}, Annotations: newAnnotationsFromSlice(cedarPolicy.Annotations), Effect: Effect(cedarPolicy.Effect), - eval: eval.Compile(cedarPolicy.Policy), - ast: &cedarPolicy.Policy, + eval: eval.Compile((internalast.Policy)(cedarPolicy)), + ast: (*internalast.Policy)(&cedarPolicy), } return nil } diff --git a/policy_set.go b/policy_set.go index f68166b4..32c35431 100644 --- a/policy_set.go +++ b/policy_set.go @@ -4,6 +4,7 @@ package cedar import ( "fmt" + internalast "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/eval" "github.com/cedar-policy/cedar-go/internal/parser" ) @@ -33,10 +34,10 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { Line: p.Position.Line, Column: p.Position.Column, }, - Annotations: newAnnotationsFromSlice(p.Policy.Annotations), - Effect: Effect(p.Policy.Effect), - eval: eval.Compile(p.Policy), - ast: &p.Policy, + Annotations: newAnnotationsFromSlice(p.Annotations), + Effect: Effect(p.Effect), + eval: eval.Compile((internalast.Policy)(*p)), + ast: (*internalast.Policy)(p), } } return PolicySet{policies: policyMap}, nil From 976347f3a66716cb9767f45867fb35fb79549990 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Wed, 14 Aug 2024 16:37:00 -0700 Subject: [PATCH 128/216] cedar-go/internal/eval: change argument type to *ast.Policy to avoid unnecessary dereferencing Signed-off-by: philhassey --- internal/eval/eval_compile.go | 4 ++-- policy.go | 6 +++--- policy_set.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/eval/eval_compile.go b/internal/eval/eval_compile.go index dd1d4847..0a708637 100644 --- a/internal/eval/eval_compile.go +++ b/internal/eval/eval_compile.go @@ -4,12 +4,12 @@ import ( "github.com/cedar-policy/cedar-go/internal/ast" ) -func Compile(p ast.Policy) Evaler { +func Compile(p *ast.Policy) Evaler { node := policyToNode(p).AsIsNode() return toEval(node) } -func policyToNode(p ast.Policy) ast.Node { +func policyToNode(p *ast.Policy) ast.Node { nodes := make([]ast.Node, 3+len(p.Conditions)) nodes[0] = scopeToNode(ast.NewPrincipalNode(), p.Principal) nodes[1] = scopeToNode(ast.NewActionNode(), p.Action) diff --git a/policy.go b/policy.go index 18631dc1..9bfb771b 100644 --- a/policy.go +++ b/policy.go @@ -70,7 +70,7 @@ func (p *Policy) UnmarshalJSON(b []byte) error { Position: Position{}, Annotations: newAnnotationsFromSlice(jsonPolicy.Annotations), Effect: Effect(jsonPolicy.Effect), - eval: eval.Compile((internalast.Policy)(jsonPolicy)), + eval: eval.Compile((*internalast.Policy)(&jsonPolicy)), ast: (*internalast.Policy)(&jsonPolicy), } return nil @@ -91,7 +91,7 @@ func (p *Policy) UnmarshalCedar(b []byte) error { Position: Position{}, Annotations: newAnnotationsFromSlice(cedarPolicy.Annotations), Effect: Effect(cedarPolicy.Effect), - eval: eval.Compile((internalast.Policy)(cedarPolicy)), + eval: eval.Compile((*internalast.Policy)(&cedarPolicy)), ast: (*internalast.Policy)(&cedarPolicy), } return nil @@ -103,7 +103,7 @@ func NewPolicyFromAST(astIn *ast.Policy) *Policy { Position: Position{}, Annotations: newAnnotationsFromSlice(astIn.Annotations), Effect: Effect(astIn.Effect), - eval: eval.Compile(*pp), + eval: eval.Compile(pp), ast: pp, } } diff --git a/policy_set.go b/policy_set.go index 32c35431..87f2c8d4 100644 --- a/policy_set.go +++ b/policy_set.go @@ -36,7 +36,7 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { }, Annotations: newAnnotationsFromSlice(p.Annotations), Effect: Effect(p.Effect), - eval: eval.Compile((internalast.Policy)(*p)), + eval: eval.Compile((*internalast.Policy)(p)), ast: (*internalast.Policy)(p), } } From ac1a46b405fb127c95351663d3e30c84e2037e2a Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 10:33:12 -0700 Subject: [PATCH 129/216] cedar-go: add a comment about the default naming of policies in NewPolicySet Signed-off-by: philhassey --- policy_set.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/policy_set.go b/policy_set.go index 87f2c8d4..cb3f89c6 100644 --- a/policy_set.go +++ b/policy_set.go @@ -19,6 +19,8 @@ type PolicySet struct { // NewPolicySet will create a PolicySet from the given text document with the // given file name used in Position data. If there is an error parsing the // document, it will be returned. +// +// NewPolicySet assigns default PolicyIDs to the policies contained in fileName. func NewPolicySet(fileName string, document []byte) (PolicySet, error) { var res parser.PolicySet if err := res.UnmarshalCedar(document); err != nil { From 5877001f342cd5caa19eb246f9ca4fbcb9b1e2bf Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 11:14:08 -0700 Subject: [PATCH 130/216] cedar-go/ast: remove ExtensionCall sugar - this should never be needed by external users of the AST Signed-off-by: philhassey --- ast/ast_test.go | 5 ----- ast/value.go | 8 -------- policy_set.go | 2 +- 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index 057272ed..b8aed222 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -265,11 +265,6 @@ func TestASTByTable(t *testing.T) { ast.Permit().When(ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16")))), internalast.Permit().When(internalast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16")))), }, - { - "extensionCall", - ast.Permit().When(ast.ExtensionCall("ip", ast.String("127.0.0.1"))), - internalast.Permit().When(internalast.ExtensionCall("ip", internalast.String("127.0.0.1"))), - }, { "opEquals", ast.Permit().When(ast.Long(42).Equals(ast.Long(43))), diff --git a/ast/value.go b/ast/value.go index e0184a10..b5bccd53 100644 --- a/ast/value.go +++ b/ast/value.go @@ -99,11 +99,3 @@ func Decimal(d types.Decimal) Node { func IPAddr(i types.IPAddr) Node { return wrapNode(ast.IPAddr(i)) } - -func ExtensionCall(name types.String, args ...Node) Node { - var astNodes []ast.Node - for _, v := range args { - astNodes = append(astNodes, v.Node) - } - return wrapNode(ast.ExtensionCall(name, astNodes...)) -} diff --git a/policy_set.go b/policy_set.go index cb3f89c6..e5454fbf 100644 --- a/policy_set.go +++ b/policy_set.go @@ -11,7 +11,7 @@ import ( type PolicyID string -// A PolicySet is a slice of policies. +// TODO: Put a better comment here type PolicySet struct { policies map[PolicyID]*Policy } From efa4ed9a5eb49f7276327ab30b2e81c046551b8e Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 12:41:27 -0700 Subject: [PATCH 131/216] cedar-go/internal/parser: check for known extension methods and functions as in the reference implementation Signed-off-by: philhassey --- authorize_test.go | 4 +-- internal/parser/cedar_marshal.go | 4 +-- internal/parser/cedar_parse_test.go | 48 ++++++++++++------------- internal/parser/cedar_unmarshal.go | 47 ++++++++++++------------ internal/parser/cedar_unmarshal_test.go | 40 +++++++++++++++++++++ 5 files changed, 93 insertions(+), 50 deletions(-) diff --git a/authorize_test.go b/authorize_test.go index 3622d6be..e7337a3a 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -481,7 +481,7 @@ func TestIsAuthorized(t *testing.T) { Context: types.Record{}, Want: false, DiagErr: 1, - ParseErr: false, + ParseErr: true, }, { Name: "permit-when-like", @@ -504,7 +504,7 @@ func TestIsAuthorized(t *testing.T) { Context: types.Record{}, Want: false, DiagErr: 1, - ParseErr: false, + ParseErr: true, }, { Name: "permit-when-decimal", diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index 06dd0648..29110aa6 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -174,8 +174,8 @@ func (n NodeTypeExtensionCall) marshalCedar(buf *bytes.Buffer) { buf.WriteString(string(n.NodeTypeExtensionCall.Name)) buf.WriteRune('(') for i := range args { - marshalChildNode(n.precedenceLevel(), n.NodeTypeExtensionCall.Args[i], buf) - if i != len(n.NodeTypeExtensionCall.Args)-1 { + marshalChildNode(n.precedenceLevel(), args[i], buf) + if i != len(args)-1 { buf.WriteString(", ") } } diff --git a/internal/parser/cedar_parse_test.go b/internal/parser/cedar_parse_test.go index 6737dbbe..854a9d04 100644 --- a/internal/parser/cedar_parse_test.go +++ b/internal/parser/cedar_parse_test.go @@ -97,20 +97,19 @@ func TestParse(t *testing.T) { when { Org::User::"alice" }; `, false}, {"primaryExtFun", `permit(principal, action, resource) - when { foo() } - unless { foo::bar::as() } - when { foo("hello") } - unless { foo::bar(true, 42, "forty two") }; + when { ip() } + when { ip("hello") } + when { ip(context.someString) }; `, false}, {"ifElseThen", `permit(principal, action, resource) when { if false then principal else principal };`, false}, {"access", `permit(principal, action, resource) when { resource.foo } unless { resource.foo.bar } - when { principal.foo() } - unless { principal.bar(false) } - when { action.foo["bar"].baz() } - unless { principal.bar(false, 123, "foo") } + when { principal.isIpv4() } + unless { principal.isIpv4(false) } + when { action.foo["bar"].isIpv4() } + unless { principal.isIpv4(false, 123, "foo") } when { principal["foo"] };`, false}, {"unary", `permit(principal, action, resource) when { !resource.foo } @@ -129,26 +128,26 @@ func TestParse(t *testing.T) { when { 42 + resource.bar - 43 } when { resource.foo + principal.bar };`, false}, {"relations", `permit(principal, action, resource) - when { foo() } - unless { foo() < 3 } - unless { foo() <= 3 } - unless { foo() > 3 } - unless { foo() >= 3 } - unless { foo() != 3 } - unless { foo() == 3 } - unless { foo() in Domain::"value" } - unless { foo() has blah } - when { foo() has "bar" } - when { foo() like "h*ll*" };`, false}, + when { ip() } + unless { ip() < 3 } + unless { ip() <= 3 } + unless { ip() > 3 } + unless { ip() >= 3 } + unless { ip() != 3 } + unless { ip() == 3 } + unless { ip() in Domain::"value" } + unless { ip() has blah } + when { ip() has "bar" } + when { ip() like "h*ll*" };`, false}, {"foo-like-foo", `permit(principal, action, resource) when { "f*o" like "f\*o" };`, false}, {"ands", `permit(principal, action, resource) - when { foo() && bar() && 3};`, false}, + when { ip() && decimal() && 3};`, false}, {"ors_and_ands", `permit(principal, action, resource) - when { foo() && bar() || baz() || 1 < 2 && 2 < 3};`, false}, + when { ip() && decimal() || ip() || 1 < 2 && 2 < 3};`, false}, {"primaryExpression", `permit(principal, action, resource) when { (true) } - unless { ((if (foo() <= 234) then principal else principal) like "") };`, false}, + unless { ((if (ip() <= 234) then principal else principal) like "") };`, false}, {"primaryExprList", `permit(principal, action, resource) when { [] } unless { [true] } @@ -298,10 +297,11 @@ func TestParse(t *testing.T) { var policies parser.PolicySet err := policies.UnmarshalCedar([]byte(tt.in)) - testutil.Equals(t, err != nil, tt.err) - if err != nil { + if tt.err { + testutil.Error(t, err) return } + testutil.OK(t, err) if len(policies) != 1 { // TODO: handle 0, > 1 return diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 120c4302..654f1f27 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/extensions" "github.com/cedar-policy/cedar-go/types" ) @@ -740,10 +741,7 @@ func (p *parser) primary() (ast.Node, error) { return res, nil } -func (p *parser) entityOrExtFun(ident string) (ast.Node, error) { - var res types.EntityUID - var err error - res.Type = ident +func (p *parser) entityOrExtFun(prefix string) (ast.Node, error) { for { t := p.advance() switch t.Text { @@ -751,30 +749,33 @@ func (p *parser) entityOrExtFun(ident string) (ast.Node, error) { t := p.advance() switch { case t.isIdent(): - res.Type = fmt.Sprintf("%v::%v", res.Type, t.Text) + prefix = prefix + "::" + t.Text case t.isString(): - res.ID, err = t.stringValue() + id, err := t.stringValue() if err != nil { return ast.Node{}, err } - return ast.EntityUID(res), nil + return ast.EntityUID(types.NewEntityUID(prefix, id)), nil default: return ast.Node{}, p.errorf("unexpected token") } case "(": + // Although the Cedar grammar says that any name can be provided here, the reference implementation actually + // checks at parse time whether the name corresponds to a known extension function. + i, ok := extensions.ExtMap[types.String(prefix)] + if !ok { + return ast.Node{}, p.errorf("`%v` is not a function", prefix) + } + if i.IsMethod { + return ast.Node{}, p.errorf("`%v` is a method, not a function", prefix) + } + args, err := p.expressions(")") if err != nil { return ast.Node{}, err } p.advance() - // i, ok := extensions.ExtMap[types.String(res.Type)] - // if !ok { - // return Node{}, p.errorf("`%v` is not a function", res.Type) - // } - // if i.IsMethod { - // return Node{}, p.errorf("`%v` is a method, not a function", res.Type) - // } - return ast.ExtensionCall(types.String(res.Type), args...), nil + return ast.ExtensionCall(types.String(prefix), args...), nil default: return ast.Node{}, p.errorf("unexpected token") } @@ -880,13 +881,15 @@ func (p *parser) access(lhs ast.Node) (ast.Node, bool, error) { case "containsAny": knownMethod = ast.Node.ContainsAny default: - // i, ok := extensions.ExtMap[types.String(methodName)] - // if !ok { - // return Node{}, false, p.errorf("not a valid method name: `%v`", methodName) - // } - // if !i.IsMethod { - // return Node{}, false, p.errorf("`%v` is a function, not a method", methodName) - // } + // Although the Cedar grammar says that any name can be provided here, the reference implementation + // actually checks at parse time whether the name corresponds to a known extension method. + i, ok := extensions.ExtMap[types.String(methodName)] + if !ok { + return ast.Node{}, false, p.errorf("`%v` is not a method", methodName) + } + if !i.IsMethod { + return ast.Node{}, false, p.errorf("`%v` is a function, not a method", methodName) + } return ast.NewMethodCall(lhs, types.String(methodName), exprs...), true, nil } diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index e02b96f9..616f4cae 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -446,6 +446,46 @@ when { (if true then 2 else 3) * 4 == 8 };`, } } +func TestParsePolicySetErrors(t *testing.T) { + t.Parallel() + parseTests := []struct { + Name string + Text string + ExpectedError string + }{ + { + "not-extension-function", + "permit ( principal, action, resource ) when { not_an_extension_fn() };", + "parse error at :1:67 \")\": `not_an_extension_fn` is not a function", + }, + { + "extension-function-is-method", + "permit ( principal, action, resource ) when { isIpv4() };", + "parse error at :1:54 \")\": `isIpv4` is a method, not a function", + }, + { + "not-extension-method", + "permit ( principal, action, resource ) when { context.not_an_extension_method() };", + "parse error at :1:81 \"}\": `not_an_extension_method` is not a method", + }, + { + "extension-method-is-function", + "permit ( principal, action, resource ) when { context.ip() };", + "parse error at :1:60 \"}\": `ip` is a function, not a method", + }, + } + + for _, tt := range parseTests { + t.Run(tt.Name, func(t *testing.T) { + t.Parallel() + var policy parser.Policy + err := policy.UnmarshalCedar([]byte(tt.Text)) + testutil.Error(t, err) + testutil.Equals(t, err.Error(), tt.ExpectedError) + }) + } +} + func TestParsePolicySet(t *testing.T) { t.Parallel() t.Run("single policy", func(t *testing.T) { From e552be77249d7f3d4e50b497e38c596d35738d0c Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 12:59:39 -0700 Subject: [PATCH 132/216] cedar-go: fix two broken tests in authorize_test.go Signed-off-by: philhassey --- authorize_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/authorize_test.go b/authorize_test.go index e7337a3a..8eda2181 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -480,7 +480,7 @@ func TestIsAuthorized(t *testing.T) { Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, - DiagErr: 1, + DiagErr: 0, ParseErr: true, }, { @@ -503,7 +503,7 @@ func TestIsAuthorized(t *testing.T) { Resource: types.NewEntityUID("table", "whatever"), Context: types.Record{}, Want: false, - DiagErr: 1, + DiagErr: 0, ParseErr: true, }, { From c1a7b0afa1cee465b335bfef27304565df421053 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 13:00:09 -0700 Subject: [PATCH 133/216] cedar-go/internal/json: implement checking for invalid extension functions and methods in the JSON parser Signed-off-by: philhassey --- internal/json/json_test.go | 35 ++++++++------------------------- internal/json/json_unmarshal.go | 5 +++++ 2 files changed, 13 insertions(+), 27 deletions(-) diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 3cdceca9..b89115e6 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -589,142 +589,123 @@ func TestMarshalPanics(t *testing.T) { func TestUnmarshalErrors(t *testing.T) { t.Parallel() tests := []struct { - name string - input string - errFunc func(testing.TB, error) + name string + input string }{ { "effect", `{"effect":"unknown","principal":{"op":"=="},"action":{"op":"All"},"resource":{"op":"All"}}`, - testutil.Error, }, { "scopeEqMissingEntity", `{"effect":"permit","principal":{"op":"=="},"action":{"op":"All"},"resource":{"op":"All"}}`, - testutil.Error, }, { "scopeUnknownOp", `{"effect":"permit","principal":{"op":"???"},"action":{"op":"All"},"resource":{"op":"All"}}`, - testutil.Error, }, { "actionUnknownOp", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"???"},"resource":{"op":"All"}}`, - testutil.Error, }, { "resourceUnknownOp", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"???"}}`, - testutil.Error, }, { "conditionUnknown", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"unknown","body":{"Value":24}}]}`, - testutil.Error, }, { "binaryLeft", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"&&":{"left":null,"right":{"Value":24}}}}]}`, - testutil.Error, }, { "binaryRight", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"&&":{"left":{"Value":24},"right":null}}}]}`, - testutil.Error, }, { "unaryArg", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"!":{"arg":null}}}]}`, - testutil.Error, }, { "accessLeft", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{".":{"left":null,"attr":"key"}}}]}`, - testutil.Error, }, { "patternLeft", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":null,"pattern":["Wildcard"]}}}]}`, - testutil.Error, }, { "patternWildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["invalid"]}}}]}`, - testutil.Error, }, { "isLeft", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"is":{"left":null,"entity_type":"T"}}}]}`, - testutil.Error, }, { "isIn", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"is":{"left":{"Var":"resource"},"entity_type":"T","in":{"Value":null}}}}]}`, - testutil.Error, }, { "ifErrThenElse", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"if-then-else":{"if":{"Value":null},"then":{"Value":42},"else":{"Value":24}}}}]}`, - testutil.Error, }, { "ifThenErrElse", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"if-then-else":{"if":{"Value":true},"then":{"Value":null},"else":{"Value":24}}}}]}`, - testutil.Error, }, { "ifThenElseErr", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"if-then-else":{"if":{"Value":true},"then":{"Value":42},"else":{"Value":null}}}}]}`, - testutil.Error, }, { "setErr", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"Set":[{"Value":null},{"Value":"bananas"}]}}]}`, - testutil.Error, }, { "recordErr", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"Record":{"key":{"Value":null}}}}]}`, - testutil.Error, }, { "extensionTooMany", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"ip":[{"Value":"10.0.0.42/8"}],"pi":[{"Value":"3.14"}]}}]}`, - testutil.Error, }, { "extensionArg", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"ip":[{"Value":null}]}}]}`, - testutil.Error, }, { "var", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"Var":"unknown"}}]}`, - testutil.Error, }, { "otherJSONerror", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":42}]}`, - testutil.Error, + }, + { + "unknown-extension-function", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, + "conditions":[{"kind":"when","body":{"not_an_extension_function":[]}}]}`, }, } @@ -734,7 +715,7 @@ func TestUnmarshalErrors(t *testing.T) { t.Parallel() var p Policy err := json.Unmarshal([]byte(tt.input), &p) - tt.errFunc(t, err) + testutil.Error(t, err) }) } } diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 6023001b..f5566e60 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/extensions" "github.com/cedar-policy/cedar-go/types" ) @@ -137,6 +138,10 @@ func (e extensionJSON) ToNode() (ast.Node, error) { for k, v = range e { _, _ = k, v } + _, ok := extensions.ExtMap[types.String(k)] + if !ok { + return ast.Node{}, fmt.Errorf("`%v` is not a known extension function or method", k) + } var argNodes []ast.Node for _, n := range v { node, err := n.ToNode() From bc2cf7f31d85c17028e5509b1821fcbc00d4a87f Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 15 Aug 2024 14:27:41 -0600 Subject: [PATCH 134/216] internal/eval: add test coverage for toEval Addresses IDX-142 Signed-off-by: philhassey --- internal/eval/compile.go | 56 +++ internal/eval/{eval_convert.go => convert.go} | 98 ++---- internal/eval/convert_test.go | 319 ++++++++++++++++++ internal/eval/eval_compile.go | 29 -- internal/eval/{eval_impl.go => evalers.go} | 0 .../eval/{eval_test.go => evalers_test.go} | 0 6 files changed, 410 insertions(+), 92 deletions(-) create mode 100644 internal/eval/compile.go rename internal/eval/{eval_convert.go => convert.go} (59%) create mode 100644 internal/eval/convert_test.go delete mode 100644 internal/eval/eval_compile.go rename internal/eval/{eval_impl.go => evalers.go} (100%) rename internal/eval/{eval_test.go => evalers_test.go} (100%) diff --git a/internal/eval/compile.go b/internal/eval/compile.go new file mode 100644 index 00000000..25c928c7 --- /dev/null +++ b/internal/eval/compile.go @@ -0,0 +1,56 @@ +package eval + +import ( + "fmt" + + "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/types" +) + +func Compile(p *ast.Policy) Evaler { + node := policyToNode(p).AsIsNode() + return toEval(node) +} + +func policyToNode(p *ast.Policy) ast.Node { + nodes := make([]ast.Node, 3+len(p.Conditions)) + nodes[0] = scopeToNode(ast.NewPrincipalNode(), p.Principal) + nodes[1] = scopeToNode(ast.NewActionNode(), p.Action) + nodes[2] = scopeToNode(ast.NewResourceNode(), p.Resource) + for i, c := range p.Conditions { + if c.Condition == ast.ConditionUnless { + nodes[i+3] = ast.Not(ast.NewNode(c.Body)) + continue + } + nodes[i+3] = ast.NewNode(c.Body) + } + res := nodes[len(nodes)-1] + for i := len(nodes) - 2; i >= 0; i-- { + res = nodes[i].And(res) + } + return res +} + +func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { + switch t := in.(type) { + case ast.ScopeTypeAll: + return ast.True() + case ast.ScopeTypeEq: + return ast.NewNode(varNode).Equals(ast.EntityUID(t.Entity)) + case ast.ScopeTypeIn: + return ast.NewNode(varNode).In(ast.EntityUID(t.Entity)) + case ast.ScopeTypeInSet: + set := make([]types.Value, len(t.Entities)) + for i, e := range t.Entities { + set[i] = e + } + return ast.NewNode(varNode).In(ast.Set(set)) + case ast.ScopeTypeIs: + return ast.NewNode(varNode).Is(t.Type) + + case ast.ScopeTypeIsIn: + return ast.NewNode(varNode).IsIn(t.Type, ast.EntityUID(t.Entity)) + default: + panic(fmt.Sprintf("unknown scope type %T", t)) + } +} diff --git a/internal/eval/eval_convert.go b/internal/eval/convert.go similarity index 59% rename from internal/eval/eval_convert.go rename to internal/eval/convert.go index 006c51fd..a5433d5a 100644 --- a/internal/eval/eval_convert.go +++ b/internal/eval/convert.go @@ -5,7 +5,6 @@ import ( "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/extensions" - "github.com/cedar-policy/cedar-go/types" ) func toEval(n ast.IsNode) Evaler { @@ -26,41 +25,38 @@ func toEval(n ast.IsNode) Evaler { rhs := newInEval(obj, toEval(v.Entity)) return newAndEval(lhs, rhs) case ast.NodeTypeExtensionCall: - i, ok := extensions.ExtMap[v.Name] - if !ok { - return newErrorEval(fmt.Errorf("%w: %s", errUnknownExtensionFunction, v.Name)) - } - if i.Args != len(v.Args) { - return newErrorEval(fmt.Errorf("%w: %s takes %d parameter(s)", errArity, v.Name, i.Args)) - } - switch { - case v.Name == "ip": - return newIPLiteralEval(toEval(v.Args[0])) - case v.Name == "decimal": - return newDecimalLiteralEval(toEval(v.Args[0])) + if i, ok := extensions.ExtMap[v.Name]; ok { + if i.Args != len(v.Args) { + return newErrorEval(fmt.Errorf("%w: %s takes %d parameter(s)", errArity, v.Name, i.Args)) + } + switch { + case v.Name == "ip": + return newIPLiteralEval(toEval(v.Args[0])) + case v.Name == "decimal": + return newDecimalLiteralEval(toEval(v.Args[0])) - case v.Name == "lessThan": - return newDecimalLessThanEval(toEval(v.Args[0]), toEval(v.Args[1])) - case v.Name == "lessThanOrEqual": - return newDecimalLessThanOrEqualEval(toEval(v.Args[0]), toEval(v.Args[1])) - case v.Name == "greaterThan": - return newDecimalGreaterThanEval(toEval(v.Args[0]), toEval(v.Args[1])) - case v.Name == "greaterThanOrEqual": - return newDecimalGreaterThanOrEqualEval(toEval(v.Args[0]), toEval(v.Args[1])) + case v.Name == "lessThan": + return newDecimalLessThanEval(toEval(v.Args[0]), toEval(v.Args[1])) + case v.Name == "lessThanOrEqual": + return newDecimalLessThanOrEqualEval(toEval(v.Args[0]), toEval(v.Args[1])) + case v.Name == "greaterThan": + return newDecimalGreaterThanEval(toEval(v.Args[0]), toEval(v.Args[1])) + case v.Name == "greaterThanOrEqual": + return newDecimalGreaterThanOrEqualEval(toEval(v.Args[0]), toEval(v.Args[1])) - case v.Name == "isIpv4": - return newIPTestEval(toEval(v.Args[0]), ipTestIPv4) - case v.Name == "isIpv6": - return newIPTestEval(toEval(v.Args[0]), ipTestIPv6) - case v.Name == "isLoopback": - return newIPTestEval(toEval(v.Args[0]), ipTestLoopback) - case v.Name == "isMulticast": - return newIPTestEval(toEval(v.Args[0]), ipTestMulticast) - case v.Name == "isInRange": - return newIPIsInRangeEval(toEval(v.Args[0]), toEval(v.Args[1])) - default: - panic(fmt.Errorf("unknown extension: %v", v.Name)) + case v.Name == "isIpv4": + return newIPTestEval(toEval(v.Args[0]), ipTestIPv4) + case v.Name == "isIpv6": + return newIPTestEval(toEval(v.Args[0]), ipTestIPv6) + case v.Name == "isLoopback": + return newIPTestEval(toEval(v.Args[0]), ipTestLoopback) + case v.Name == "isMulticast": + return newIPTestEval(toEval(v.Args[0]), ipTestMulticast) + case v.Name == "isInRange": + return newIPIsInRangeEval(toEval(v.Args[0]), toEval(v.Args[1])) + } } + return newErrorEval(fmt.Errorf("%w: %s", errUnknownExtensionFunction, v.Name)) case ast.NodeValue: return newLiteralEval(v.Value) case ast.NodeTypeRecord: @@ -96,8 +92,12 @@ func toEval(n ast.IsNode) Evaler { return newInEval(toEval(v.Left), toEval(v.Right)) case ast.NodeTypeAnd: return newAndEval(toEval(v.Left), toEval(v.Right)) + case ast.NodeTypeOr: + return newOrNode(toEval(v.Left), toEval(v.Right)) case ast.NodeTypeEquals: return newEqualEval(toEval(v.Left), toEval(v.Right)) + case ast.NodeTypeNotEquals: + return newNotEqualEval(toEval(v.Left), toEval(v.Right)) case ast.NodeTypeGreaterThan: return newLongGreaterThanEval(toEval(v.Left), toEval(v.Right)) case ast.NodeTypeGreaterThanOrEqual: @@ -110,43 +110,15 @@ func toEval(n ast.IsNode) Evaler { return newSubtractEval(toEval(v.Left), toEval(v.Right)) case ast.NodeTypeAdd: return newAddEval(toEval(v.Left), toEval(v.Right)) + case ast.NodeTypeMult: + return newMultiplyEval(toEval(v.Left), toEval(v.Right)) case ast.NodeTypeContains: return newContainsEval(toEval(v.Left), toEval(v.Right)) case ast.NodeTypeContainsAll: return newContainsAllEval(toEval(v.Left), toEval(v.Right)) case ast.NodeTypeContainsAny: return newContainsAnyEval(toEval(v.Left), toEval(v.Right)) - case ast.NodeTypeMult: - return newMultiplyEval(toEval(v.Left), toEval(v.Right)) - case ast.NodeTypeNotEquals: - return newNotEqualEval(toEval(v.Left), toEval(v.Right)) - case ast.NodeTypeOr: - return newOrNode(toEval(v.Left), toEval(v.Right)) default: panic(fmt.Sprintf("unknown node type %T", v)) } } - -func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { - switch t := in.(type) { - case ast.ScopeTypeAll: - return ast.True() - case ast.ScopeTypeEq: - return ast.NewNode(varNode).Equals(ast.EntityUID(t.Entity)) - case ast.ScopeTypeIn: - return ast.NewNode(varNode).In(ast.EntityUID(t.Entity)) - case ast.ScopeTypeInSet: - set := make([]types.Value, len(t.Entities)) - for i, e := range t.Entities { - set[i] = e - } - return ast.NewNode(varNode).In(ast.Set(set)) - case ast.ScopeTypeIs: - return ast.NewNode(varNode).Is(t.Type) - - case ast.ScopeTypeIsIn: - return ast.NewNode(varNode).IsIn(t.Type, ast.EntityUID(t.Entity)) - default: - panic(fmt.Sprintf("unknown scope type %T", t)) - } -} diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go new file mode 100644 index 00000000..9106a4e6 --- /dev/null +++ b/internal/eval/convert_test.go @@ -0,0 +1,319 @@ +package eval + +import ( + "net/netip" + "testing" + + "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestToEval(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in ast.Node + out types.Value + err func(testing.TB, error) + }{ + { + "access", + ast.Record(types.Record{"key": types.Long(42)}).Access("key"), + types.Long(42), + testutil.OK, + }, + { + "has", + ast.Record(types.Record{"key": types.Long(42)}).Has("key"), + types.Boolean(true), + testutil.OK, + }, + { + "like", + ast.String("test").Like(types.Pattern{}), + types.Boolean(false), + testutil.OK, + }, + { + "if", + ast.If(ast.True(), ast.Long(42), ast.Long(43)), + types.Long(42), + testutil.OK, + }, + { + "is", + ast.EntityUID(types.NewEntityUID("T", "42")).Is("T"), + types.Boolean(true), + testutil.OK, + }, + { + "isIn", + ast.EntityUID(types.NewEntityUID("T", "42")).IsIn("T", ast.EntityUID(types.NewEntityUID("T", "42"))), + types.Boolean(true), + testutil.OK, + }, + { + "value", + ast.Long(42), + types.Long(42), + testutil.OK, + }, + { + "record", + ast.RecordElements(ast.RecordElement{Key: "key", Value: ast.Long(42)}), + types.Record{"key": types.Long(42)}, + testutil.OK, + }, + { + "set", + ast.SetNodes(ast.Long(42)), + types.Set{types.Long(42)}, + testutil.OK, + }, + { + "negate", + ast.Negate(ast.Long(42)), + types.Long(-42), + testutil.OK, + }, + { + "not", + ast.Not(ast.True()), + types.Boolean(false), + testutil.OK, + }, + { + "principal", + ast.Principal(), + types.NewEntityUID("Actor", "principal"), + testutil.OK, + }, + { + "action", + ast.Action(), + types.NewEntityUID("Action", "test"), + testutil.OK, + }, + { + "resource", + ast.Resource(), + types.NewEntityUID("Resource", "database"), + testutil.OK, + }, + { + "context", + ast.Context(), + types.Record{}, + testutil.OK, + }, + { + "in", + ast.EntityUID(types.NewEntityUID("T", "42")).In(ast.EntityUID(types.NewEntityUID("T", "43"))), + types.Boolean(false), + testutil.OK, + }, + { + "and", + ast.True().And(ast.False()), + types.Boolean(false), + testutil.OK, + }, + { + "or", + ast.True().Or(ast.False()), + types.Boolean(true), + testutil.OK, + }, + { + "equals", + ast.Long(42).Equals(ast.Long(43)), + types.Boolean(false), + testutil.OK, + }, + { + "notEquals", + ast.Long(42).NotEquals(ast.Long(43)), + types.Boolean(true), + testutil.OK, + }, + { + "greaterThan", + ast.Long(42).GreaterThan(ast.Long(43)), + types.Boolean(false), + testutil.OK, + }, + { + "greaterThanOrEqual", + ast.Long(42).GreaterThanOrEqual(ast.Long(43)), + types.Boolean(false), + testutil.OK, + }, + { + "lessThan", + ast.Long(42).LessThan(ast.Long(43)), + types.Boolean(true), + testutil.OK, + }, + { + "lessThanOrEqual", + ast.Long(42).LessThanOrEqual(ast.Long(43)), + types.Boolean(true), + testutil.OK, + }, + { + "sub", + ast.Long(42).Minus(ast.Long(2)), + types.Long(40), + testutil.OK, + }, + { + "add", + ast.Long(40).Plus(ast.Long(2)), + types.Long(42), + testutil.OK, + }, + { + "mult", + ast.Long(6).Times(ast.Long(7)), + types.Long(42), + testutil.OK, + }, + { + "contains", + ast.Set(types.Set{types.Long(42)}).Contains(ast.Long(42)), + types.Boolean(true), + testutil.OK, + }, + { + "containsAll", + ast.Set(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAll(ast.Set(types.Set{types.Long(42), types.Long(43)})), + types.Boolean(true), + testutil.OK, + }, + { + "containsAny", + ast.Set(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAny(ast.Set(types.Set{types.Long(1), types.Long(42)})), + types.Boolean(true), + testutil.OK, + }, + { + "ip", + ast.ExtensionCall("ip", ast.String("127.0.0.42/16")), + types.IPAddr(netip.MustParsePrefix("127.0.0.42/16")), + testutil.OK, + }, + { + "decimal", + ast.ExtensionCall("decimal", ast.String("42.42")), + types.Decimal(424200), + testutil.OK, + }, + { + "lessThan", + ast.ExtensionCall("lessThan", ast.Decimal(420000), ast.Decimal(430000)), + types.Boolean(true), + testutil.OK, + }, + { + "lessThanOrEqual", + ast.ExtensionCall("lessThanOrEqual", ast.Decimal(420000), ast.Decimal(430000)), + types.Boolean(true), + testutil.OK, + }, + { + "greaterThan", + ast.ExtensionCall("greaterThan", ast.Decimal(420000), ast.Decimal(430000)), + types.Boolean(false), + testutil.OK, + }, + { + "greaterThanOrEqual", + ast.ExtensionCall("greaterThanOrEqual", ast.Decimal(420000), ast.Decimal(430000)), + types.Boolean(false), + testutil.OK, + }, + { + "isIpv4", + ast.ExtensionCall("isIpv4", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.42/16")))), + types.Boolean(true), + testutil.OK, + }, + { + "isIpv6", + ast.ExtensionCall("isIpv6", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("::1/16")))), + types.Boolean(true), + testutil.OK, + }, + { + "isLoopback", + ast.ExtensionCall("isLoopback", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/32")))), + types.Boolean(true), + testutil.OK, + }, + { + "isMulticast", + ast.ExtensionCall("isMulticast", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("239.255.255.255/32")))), + types.Boolean(true), + testutil.OK, + }, + { + "isInRange", + ast.ExtensionCall("isInRange", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.42/32"))), ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.0/16")))), + types.Boolean(true), + testutil.OK, + }, + { + "extUnknown", + ast.ExtensionCall("unknown", ast.String("hello")), + nil, + testutil.Error, + }, + { + "extArgs", + ast.ExtensionCall("ip", ast.String("1"), ast.String("2")), + nil, + testutil.Error, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + e := toEval(tt.in.AsIsNode()) + out, err := e.Eval(&Context{ + Principal: types.NewEntityUID("Actor", "principal"), + Action: types.NewEntityUID("Action", "test"), + Resource: types.NewEntityUID("Resource", "database"), + Context: types.Record{}, + }) + tt.err(t, err) + testutil.Equals(t, out, tt.out) + }) + } + +} + +func TestToEvalPanics(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in ast.Node + }{ + { + "unknownNode", + ast.Node{}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + testutil.AssertPanic(t, func() { + _ = toEval(tt.in.AsIsNode()) + }) + }) + } +} diff --git a/internal/eval/eval_compile.go b/internal/eval/eval_compile.go deleted file mode 100644 index 0a708637..00000000 --- a/internal/eval/eval_compile.go +++ /dev/null @@ -1,29 +0,0 @@ -package eval - -import ( - "github.com/cedar-policy/cedar-go/internal/ast" -) - -func Compile(p *ast.Policy) Evaler { - node := policyToNode(p).AsIsNode() - return toEval(node) -} - -func policyToNode(p *ast.Policy) ast.Node { - nodes := make([]ast.Node, 3+len(p.Conditions)) - nodes[0] = scopeToNode(ast.NewPrincipalNode(), p.Principal) - nodes[1] = scopeToNode(ast.NewActionNode(), p.Action) - nodes[2] = scopeToNode(ast.NewResourceNode(), p.Resource) - for i, c := range p.Conditions { - if c.Condition == ast.ConditionUnless { - nodes[i+3] = ast.Not(ast.NewNode(c.Body)) - continue - } - nodes[i+3] = ast.NewNode(c.Body) - } - res := nodes[len(nodes)-1] - for i := len(nodes) - 2; i >= 0; i-- { - res = nodes[i].And(res) - } - return res -} diff --git a/internal/eval/eval_impl.go b/internal/eval/evalers.go similarity index 100% rename from internal/eval/eval_impl.go rename to internal/eval/evalers.go diff --git a/internal/eval/eval_test.go b/internal/eval/evalers_test.go similarity index 100% rename from internal/eval/eval_test.go rename to internal/eval/evalers_test.go From 0e6df3d53fb19d03bcbfbcf515280455535af50f Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 15 Aug 2024 15:09:14 -0600 Subject: [PATCH 135/216] internal/eval: add more tests Addresses IDX-142 Signed-off-by: philhassey --- internal/eval/compile_test.go | 126 ++++++++++++++++++++++++++++++++++ internal/eval/convert_test.go | 24 ++----- 2 files changed, 130 insertions(+), 20 deletions(-) create mode 100644 internal/eval/compile_test.go diff --git a/internal/eval/compile_test.go b/internal/eval/compile_test.go new file mode 100644 index 00000000..7b716de1 --- /dev/null +++ b/internal/eval/compile_test.go @@ -0,0 +1,126 @@ +package eval + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestCompile(t *testing.T) { + t.Parallel() + e := Compile(ast.Permit()) + res, err := e.Eval(nil) + testutil.OK(t, err) + testutil.Equals(t, res, types.Value(types.Boolean(true))) +} + +func TestPolicyToNode(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in *ast.Policy + out ast.Node + }{ + { + "permit", + ast.Permit(), + ast.True().And(ast.True().And(ast.True())), + }, + { + "eqs", + + ast.Permit(). + PrincipalEq(types.NewEntityUID("Account", "principal")). + ActionEq(types.NewEntityUID("Action", "test")). + ResourceEq(types.NewEntityUID("Resource", "database")), + + ast.Principal().Equals(ast.EntityUID(types.NewEntityUID("Account", "principal"))).And( + ast.Action().Equals(ast.EntityUID(types.NewEntityUID("Action", "test"))).And( + ast.Resource().Equals(ast.EntityUID(types.NewEntityUID("Resource", "database"))), + ), + ), + }, + + { + "conds", + + ast.Permit(). + When(ast.Long(123)). + Unless(ast.Long(456)), + + ast.True().And(ast.True().And(ast.True().And(ast.Long(123).And(ast.Not(ast.Long(456)))))), + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out := policyToNode(tt.in) + testutil.Equals(t, out, tt.out) + }) + } +} + +func TestScopeToNode(t *testing.T) { + t.Parallel() + tests := []struct { + name string + scope ast.NodeTypeVariable + in ast.IsScopeNode + out ast.Node + }{ + { + "all", + ast.NewPrincipalNode(), + ast.ScopeTypeAll{}, + ast.True(), + }, + { + "eq", + ast.NewPrincipalNode(), + ast.ScopeTypeEq{Entity: types.NewEntityUID("T", "42")}, + ast.Principal().Equals(ast.EntityUID(types.NewEntityUID("T", "42"))), + }, + { + "in", + ast.NewPrincipalNode(), + ast.ScopeTypeIn{Entity: types.NewEntityUID("T", "42")}, + ast.Principal().In(ast.EntityUID(types.NewEntityUID("T", "42"))), + }, + { + "inSet", + ast.NewActionNode(), + ast.ScopeTypeInSet{Entities: []types.EntityUID{types.NewEntityUID("T", "42")}}, + ast.Action().In(ast.Set(types.Set{types.NewEntityUID("T", "42")})), + }, + { + "is", + ast.NewResourceNode(), + ast.ScopeTypeIs{Type: "T"}, + ast.Resource().Is("T"), + }, + { + "isIn", + ast.NewResourceNode(), + ast.ScopeTypeIsIn{Type: "T", Entity: types.NewEntityUID("T", "42")}, + ast.Resource().IsIn("T", ast.EntityUID(types.NewEntityUID("T", "42"))), + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out := scopeToNode(tt.scope, tt.in) + testutil.Equals(t, out, tt.out) + }) + } +} + +func TestScopeToNodePanic(t *testing.T) { + t.Parallel() + testutil.AssertPanic(t, func() { + _ = scopeToNode(ast.NewPrincipalNode(), ast.ScopeNode{}) + }) +} diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index 9106a4e6..62216ac8 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -295,25 +295,9 @@ func TestToEval(t *testing.T) { } -func TestToEvalPanics(t *testing.T) { +func TestToEvalPanic(t *testing.T) { t.Parallel() - tests := []struct { - name string - in ast.Node - }{ - { - "unknownNode", - ast.Node{}, - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - testutil.AssertPanic(t, func() { - _ = toEval(tt.in.AsIsNode()) - }) - }) - } + testutil.AssertPanic(t, func() { + _ = toEval(ast.Node{}.AsIsNode()) + }) } From 33e367901e9c390e6c4a8a0a2d4603e1ec66323b Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 15 Aug 2024 15:14:31 -0600 Subject: [PATCH 136/216] types: add True and False constants Addresses IDX-142 Signed-off-by: philhassey --- internal/ast/ast_test.go | 18 +-- internal/eval/compile_test.go | 2 +- internal/eval/convert_test.go | 52 ++++---- internal/eval/evalers_test.go | 242 +++++++++++++++++----------------- types/value.go | 5 + 5 files changed, 162 insertions(+), 157 deletions(-) diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index 52f2bfaa..b15242d7 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -113,14 +113,14 @@ func TestASTByTable(t *testing.T) { "when", ast.Permit().When(ast.True()), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Boolean(true)}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.True}}}, }, }, { "unless", ast.Permit().Unless(ast.True()), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionUnless, Body: ast.NodeValue{Value: types.Boolean(true)}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionUnless, Body: ast.NodeValue{Value: types.True}}}, }, }, { @@ -210,28 +210,28 @@ func TestASTByTable(t *testing.T) { "valueBoolFalse", ast.Permit().When(ast.Boolean(false)), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Boolean(false)}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.False}}}, }, }, { "valueBoolTrue", ast.Permit().When(ast.Boolean(true)), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Boolean(true)}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.True}}}, }, }, { "valueTrue", ast.Permit().When(ast.True()), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Boolean(true)}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.True}}}, }, }, { "valueFalse", ast.Permit().When(ast.False()), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Boolean(false)}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.False}}}, }, }, { @@ -392,13 +392,13 @@ func TestASTByTable(t *testing.T) { "opNot", ast.Permit().When(ast.Not(ast.True())), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeNot{UnaryNode: ast.UnaryNode{Arg: ast.NodeValue{Value: types.Boolean(true)}}}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeNot{UnaryNode: ast.UnaryNode{Arg: ast.NodeValue{Value: types.True}}}}}}, }, { "opIf", ast.Permit().When(ast.If(ast.True(), ast.Long(42), ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIf{If: ast.NodeValue{Value: types.Boolean(true)}, Then: ast.NodeValue{Value: types.Long(42)}, Else: ast.NodeValue{Value: types.Long(43)}}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIf{If: ast.NodeValue{Value: types.True}, Then: ast.NodeValue{Value: types.Long(42)}, Else: ast.NodeValue{Value: types.Long(43)}}}}}, }, { "opPlus", @@ -422,7 +422,7 @@ func TestASTByTable(t *testing.T) { "opNegate", ast.Permit().When(ast.Negate(ast.True())), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeNegate{UnaryNode: ast.UnaryNode{Arg: ast.NodeValue{Value: types.Boolean(true)}}}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeNegate{UnaryNode: ast.UnaryNode{Arg: ast.NodeValue{Value: types.True}}}}}}, }, { "opIn", diff --git a/internal/eval/compile_test.go b/internal/eval/compile_test.go index 7b716de1..94e79d6c 100644 --- a/internal/eval/compile_test.go +++ b/internal/eval/compile_test.go @@ -13,7 +13,7 @@ func TestCompile(t *testing.T) { e := Compile(ast.Permit()) res, err := e.Eval(nil) testutil.OK(t, err) - testutil.Equals(t, res, types.Value(types.Boolean(true))) + testutil.Equals(t, res, types.Value(types.True)) } func TestPolicyToNode(t *testing.T) { diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index 62216ac8..a6fe1a23 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -26,13 +26,13 @@ func TestToEval(t *testing.T) { { "has", ast.Record(types.Record{"key": types.Long(42)}).Has("key"), - types.Boolean(true), + types.True, testutil.OK, }, { "like", ast.String("test").Like(types.Pattern{}), - types.Boolean(false), + types.False, testutil.OK, }, { @@ -44,13 +44,13 @@ func TestToEval(t *testing.T) { { "is", ast.EntityUID(types.NewEntityUID("T", "42")).Is("T"), - types.Boolean(true), + types.True, testutil.OK, }, { "isIn", ast.EntityUID(types.NewEntityUID("T", "42")).IsIn("T", ast.EntityUID(types.NewEntityUID("T", "42"))), - types.Boolean(true), + types.True, testutil.OK, }, { @@ -80,7 +80,7 @@ func TestToEval(t *testing.T) { { "not", ast.Not(ast.True()), - types.Boolean(false), + types.False, testutil.OK, }, { @@ -110,55 +110,55 @@ func TestToEval(t *testing.T) { { "in", ast.EntityUID(types.NewEntityUID("T", "42")).In(ast.EntityUID(types.NewEntityUID("T", "43"))), - types.Boolean(false), + types.False, testutil.OK, }, { "and", ast.True().And(ast.False()), - types.Boolean(false), + types.False, testutil.OK, }, { "or", ast.True().Or(ast.False()), - types.Boolean(true), + types.True, testutil.OK, }, { "equals", ast.Long(42).Equals(ast.Long(43)), - types.Boolean(false), + types.False, testutil.OK, }, { "notEquals", ast.Long(42).NotEquals(ast.Long(43)), - types.Boolean(true), + types.True, testutil.OK, }, { "greaterThan", ast.Long(42).GreaterThan(ast.Long(43)), - types.Boolean(false), + types.False, testutil.OK, }, { "greaterThanOrEqual", ast.Long(42).GreaterThanOrEqual(ast.Long(43)), - types.Boolean(false), + types.False, testutil.OK, }, { "lessThan", ast.Long(42).LessThan(ast.Long(43)), - types.Boolean(true), + types.True, testutil.OK, }, { "lessThanOrEqual", ast.Long(42).LessThanOrEqual(ast.Long(43)), - types.Boolean(true), + types.True, testutil.OK, }, { @@ -182,19 +182,19 @@ func TestToEval(t *testing.T) { { "contains", ast.Set(types.Set{types.Long(42)}).Contains(ast.Long(42)), - types.Boolean(true), + types.True, testutil.OK, }, { "containsAll", ast.Set(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAll(ast.Set(types.Set{types.Long(42), types.Long(43)})), - types.Boolean(true), + types.True, testutil.OK, }, { "containsAny", ast.Set(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAny(ast.Set(types.Set{types.Long(1), types.Long(42)})), - types.Boolean(true), + types.True, testutil.OK, }, { @@ -212,55 +212,55 @@ func TestToEval(t *testing.T) { { "lessThan", ast.ExtensionCall("lessThan", ast.Decimal(420000), ast.Decimal(430000)), - types.Boolean(true), + types.True, testutil.OK, }, { "lessThanOrEqual", ast.ExtensionCall("lessThanOrEqual", ast.Decimal(420000), ast.Decimal(430000)), - types.Boolean(true), + types.True, testutil.OK, }, { "greaterThan", ast.ExtensionCall("greaterThan", ast.Decimal(420000), ast.Decimal(430000)), - types.Boolean(false), + types.False, testutil.OK, }, { "greaterThanOrEqual", ast.ExtensionCall("greaterThanOrEqual", ast.Decimal(420000), ast.Decimal(430000)), - types.Boolean(false), + types.False, testutil.OK, }, { "isIpv4", ast.ExtensionCall("isIpv4", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.42/16")))), - types.Boolean(true), + types.True, testutil.OK, }, { "isIpv6", ast.ExtensionCall("isIpv6", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("::1/16")))), - types.Boolean(true), + types.True, testutil.OK, }, { "isLoopback", ast.ExtensionCall("isLoopback", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/32")))), - types.Boolean(true), + types.True, testutil.OK, }, { "isMulticast", ast.ExtensionCall("isMulticast", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("239.255.255.255/32")))), - types.Boolean(true), + types.True, testutil.OK, }, { "isInRange", ast.ExtensionCall("isInRange", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.42/32"))), ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.0/16")))), - types.Boolean(true), + types.True, testutil.OK, }, { diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 2f2b7ea0..aab60924 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -45,7 +45,7 @@ func TestOrNode(t *testing.T) { t.Run("TrueXShortCircuit", func(t *testing.T) { t.Parallel() n := newOrNode( - newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(1))) + newLiteralEval(types.True), newLiteralEval(types.Long(1))) v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, true) @@ -57,10 +57,10 @@ func TestOrNode(t *testing.T) { lhs, rhs Evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(types.Boolean(true)), errTest}, - {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.ErrType}, - {"RhsError", newLiteralEval(types.Boolean(false)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Boolean(false)), newLiteralEval(types.Long(1)), types.ErrType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.True), errTest}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.True), types.ErrType}, + {"RhsError", newLiteralEval(types.False), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.False), newLiteralEval(types.Long(1)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -100,7 +100,7 @@ func TestAndNode(t *testing.T) { t.Run("FalseXShortCircuit", func(t *testing.T) { t.Parallel() n := newAndEval( - newLiteralEval(types.Boolean(false)), newLiteralEval(types.Long(1))) + newLiteralEval(types.False), newLiteralEval(types.Long(1))) v, err := n.Eval(&Context{}) testutil.OK(t, err) types.AssertBoolValue(t, v, false) @@ -112,10 +112,10 @@ func TestAndNode(t *testing.T) { lhs, rhs Evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(types.Boolean(true)), errTest}, - {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.ErrType}, - {"RhsError", newLiteralEval(types.Boolean(true)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(1)), types.ErrType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.True), errTest}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.True), types.ErrType}, + {"RhsError", newLiteralEval(types.True), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(1)), types.ErrType}, } for _, tt := range tests { tt := tt @@ -355,9 +355,9 @@ func TestAddNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, {"PositiveOverflow", newLiteralEval(types.Long(9_223_372_036_854_775_807)), newLiteralEval(types.Long(1)), @@ -394,9 +394,9 @@ func TestSubtractNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, {"PositiveOverflow", newLiteralEval(types.Long(9_223_372_036_854_775_807)), newLiteralEval(types.Long(-1)), @@ -433,9 +433,9 @@ func TestMultiplyNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, {"PositiveOverflow", newLiteralEval(types.Long(9_223_372_036_854_775_807)), newLiteralEval(types.Long(2)), @@ -472,7 +472,7 @@ func TestNegateNode(t *testing.T) { err error }{ {"Error", newErrorEval(errTest), errTest}, - {"TypeError", newLiteralEval(types.Boolean(true)), types.ErrType}, + {"TypeError", newLiteralEval(types.True), types.ErrType}, {"Overflow", newLiteralEval(types.Long(-9_223_372_036_854_775_808)), errOverflow}, } for _, tt := range tests { @@ -522,9 +522,9 @@ func TestLongLessThanNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, } for _, tt := range tests { tt := tt @@ -574,9 +574,9 @@ func TestLongLessThanOrEqualNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, } for _, tt := range tests { tt := tt @@ -626,9 +626,9 @@ func TestLongGreaterThanNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, } for _, tt := range tests { tt := tt @@ -678,9 +678,9 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, } for _, tt := range tests { tt := tt @@ -735,9 +735,9 @@ func TestDecimalLessThanNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), types.ErrType}, } for _, tt := range tests { tt := tt @@ -792,9 +792,9 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), types.ErrType}, } for _, tt := range tests { tt := tt @@ -849,9 +849,9 @@ func TestDecimalGreaterThanNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), types.ErrType}, } for _, tt := range tests { tt := tt @@ -906,9 +906,9 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.Boolean(true)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), types.ErrType}, } for _, tt := range tests { tt := tt @@ -930,10 +930,10 @@ func TestIfThenElseNode(t *testing.T) { result types.Value err error }{ - {"Then", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(42)), + {"Then", newLiteralEval(types.True), newLiteralEval(types.Long(42)), newLiteralEval(types.Long(-1)), types.Long(42), nil}, - {"Else", newLiteralEval(types.Boolean(false)), newLiteralEval(types.Long(-1)), + {"Else", newLiteralEval(types.False), newLiteralEval(types.Long(-1)), newLiteralEval(types.Long(42)), types.Long(42), nil}, {"Err", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), newLiteralEval(types.ZeroValue()), types.ZeroValue(), @@ -961,11 +961,11 @@ func TestEqualNode(t *testing.T) { result types.Value err error }{ - {"equals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(42)), types.Boolean(true), nil}, - {"notEquals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(1234)), types.Boolean(false), nil}, + {"equals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(42)), types.True, nil}, + {"notEquals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(1234)), types.False, nil}, {"leftErr", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), types.ZeroValue(), errTest}, {"rightErr", newLiteralEval(types.ZeroValue()), newErrorEval(errTest), types.ZeroValue(), errTest}, - {"typesNotEqual", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.Boolean(false), nil}, + {"typesNotEqual", newLiteralEval(types.Long(1)), newLiteralEval(types.True), types.False, nil}, } for _, tt := range tests { tt := tt @@ -987,11 +987,11 @@ func TestNotEqualNode(t *testing.T) { result types.Value err error }{ - {"equals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(42)), types.Boolean(false), nil}, - {"notEquals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(1234)), types.Boolean(true), nil}, + {"equals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(42)), types.False, nil}, + {"notEquals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(1234)), types.True, nil}, {"leftErr", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), types.ZeroValue(), errTest}, {"rightErr", newLiteralEval(types.ZeroValue()), newErrorEval(errTest), types.ZeroValue(), errTest}, - {"typesNotEqual", newLiteralEval(types.Long(1)), newLiteralEval(types.Boolean(true)), types.Boolean(true), nil}, + {"typesNotEqual", newLiteralEval(types.Long(1)), newLiteralEval(types.True), types.True, nil}, } for _, tt := range tests { tt := tt @@ -1017,17 +1017,17 @@ func TestSetLiteralNode(t *testing.T) { {"errorNode", []Evaler{newErrorEval(errTest)}, types.ZeroValue(), errTest}, {"nested", []Evaler{ - newLiteralEval(types.Boolean(true)), + newLiteralEval(types.True), newLiteralEval(types.Set{ - types.Boolean(false), + types.False, types.Long(1), }), newLiteralEval(types.Long(10)), }, types.Set{ - types.Boolean(true), + types.True, types.Set{ - types.Boolean(false), + types.False, types.Long(1), }, types.Long(10), @@ -1055,7 +1055,7 @@ func TestContainsNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, } for _, tt := range tests { @@ -1071,21 +1071,21 @@ func TestContainsNode(t *testing.T) { } { empty := types.Set{} - trueAndOne := types.Set{types.Boolean(true), types.Long(1)} - nested := types.Set{trueAndOne, types.Boolean(false), types.Long(2)} + trueAndOne := types.Set{types.True, types.Long(1)} + nested := types.Set{trueAndOne, types.False, types.Long(2)} tests := []struct { name string lhs, rhs Evaler result bool }{ - {"empty", newLiteralEval(empty), newLiteralEval(types.Boolean(true)), false}, - {"trueAndOneContainsTrue", newLiteralEval(trueAndOne), newLiteralEval(types.Boolean(true)), true}, + {"empty", newLiteralEval(empty), newLiteralEval(types.True), false}, + {"trueAndOneContainsTrue", newLiteralEval(trueAndOne), newLiteralEval(types.True), true}, {"trueAndOneContainsOne", newLiteralEval(trueAndOne), newLiteralEval(types.Long(1)), true}, {"trueAndOneDoesNotContainTwo", newLiteralEval(trueAndOne), newLiteralEval(types.Long(2)), false}, - {"nestedContainsFalse", newLiteralEval(nested), newLiteralEval(types.Boolean(false)), true}, + {"nestedContainsFalse", newLiteralEval(nested), newLiteralEval(types.False), true}, {"nestedContainsSet", newLiteralEval(nested), newLiteralEval(trueAndOne), true}, - {"nestedDoesNotContainTrue", newLiteralEval(nested), newLiteralEval(types.Boolean(true)), false}, + {"nestedDoesNotContainTrue", newLiteralEval(nested), newLiteralEval(types.True), false}, } for _, tt := range tests { tt := tt @@ -1109,7 +1109,7 @@ func TestContainsAllNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Set{}), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Set{}), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Set{}), types.ErrType}, {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, {"RhsTypeError", newLiteralEval(types.Set{}), newLiteralEval(types.Long(0)), types.ErrType}, } @@ -1126,9 +1126,9 @@ func TestContainsAllNode(t *testing.T) { } { empty := types.Set{} - trueOnly := types.Set{types.Boolean(true)} - trueAndOne := types.Set{types.Boolean(true), types.Long(1)} - nested := types.Set{trueAndOne, types.Boolean(false), types.Long(2)} + trueOnly := types.Set{types.True} + trueAndOne := types.Set{types.True, types.Long(1)} + nested := types.Set{trueAndOne, types.False, types.Long(2)} tests := []struct { name string @@ -1163,7 +1163,7 @@ func TestContainsAnyNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Set{}), errTest}, - {"LhsTypeError", newLiteralEval(types.Boolean(true)), newLiteralEval(types.Set{}), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Set{}), types.ErrType}, {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, {"RhsTypeError", newLiteralEval(types.Set{}), newLiteralEval(types.Long(0)), types.ErrType}, } @@ -1180,10 +1180,10 @@ func TestContainsAnyNode(t *testing.T) { } { empty := types.Set{} - trueOnly := types.Set{types.Boolean(true)} - trueAndOne := types.Set{types.Boolean(true), types.Long(1)} - trueAndTwo := types.Set{types.Boolean(true), types.Long(2)} - nested := types.Set{trueAndOne, types.Boolean(false), types.Long(2)} + trueOnly := types.Set{types.True} + trueAndOne := types.Set{types.True, types.Long(1)} + trueAndTwo := types.Set{types.True, types.Long(2)} + nested := types.Set{trueAndOne, types.False, types.Long(2)} tests := []struct { name string @@ -1223,10 +1223,10 @@ func TestRecordLiteralNode(t *testing.T) { {"errorNode", map[string]Evaler{"foo": newErrorEval(errTest)}, types.ZeroValue(), errTest}, {"ok", map[string]Evaler{ - "foo": newLiteralEval(types.Boolean(true)), + "foo": newLiteralEval(types.True), "bar": newLiteralEval(types.String("baz")), }, types.Record{ - "foo": types.Boolean(true), + "foo": types.True, "bar": types.String("baz"), }, nil}, } @@ -1252,7 +1252,7 @@ func TestAttributeAccessNode(t *testing.T) { err error }{ {"RecordError", newErrorEval(errTest), "foo", types.ZeroValue(), errTest}, - {"RecordTypeError", newLiteralEval(types.Boolean(true)), "foo", types.ZeroValue(), types.ErrType}, + {"RecordTypeError", newLiteralEval(types.True), "foo", types.ZeroValue(), types.ErrType}, {"UnknownAttribute", newLiteralEval(types.Record{}), "foo", @@ -1309,31 +1309,31 @@ func TestHasNode(t *testing.T) { err error }{ {"RecordError", newErrorEval(errTest), "foo", types.ZeroValue(), errTest}, - {"RecordTypeError", newLiteralEval(types.Boolean(true)), "foo", types.ZeroValue(), types.ErrType}, + {"RecordTypeError", newLiteralEval(types.True), "foo", types.ZeroValue(), types.ErrType}, {"UnknownAttribute", newLiteralEval(types.Record{}), "foo", - types.Boolean(false), + types.False, nil}, {"KnownAttribute", newLiteralEval(types.Record{"foo": types.Long(42)}), "foo", - types.Boolean(true), + types.True, nil}, {"KnownAttributeOnEntity", newLiteralEval(types.NewEntityUID("knownType", "knownID")), "knownAttr", - types.Boolean(true), + types.True, nil}, {"UnknownAttributeOnEntity", newLiteralEval(types.NewEntityUID("knownType", "knownID")), "unknownAttr", - types.Boolean(false), + types.False, nil}, {"UnknownEntity", newLiteralEval(types.NewEntityUID("unknownType", "unknownID")), "unknownAttr", - types.Boolean(false), + types.False, nil}, } for _, tt := range tests { @@ -1366,35 +1366,35 @@ func TestLikeNode(t *testing.T) { err error }{ {"leftError", newErrorEval(errTest), `"foo"`, types.ZeroValue(), errTest}, - {"leftTypeError", newLiteralEval(types.Boolean(true)), `"foo"`, types.ZeroValue(), types.ErrType}, - {"noMatch", newLiteralEval(types.String("test")), `"zebra"`, types.Boolean(false), nil}, - {"match", newLiteralEval(types.String("test")), `"*es*"`, types.Boolean(true), nil}, - - {"case-1", newLiteralEval(types.String("eggs")), `"ham*"`, types.Boolean(false), nil}, - {"case-2", newLiteralEval(types.String("eggs")), `"*ham"`, types.Boolean(false), nil}, - {"case-3", newLiteralEval(types.String("eggs")), `"*ham*"`, types.Boolean(false), nil}, - {"case-4", newLiteralEval(types.String("ham and eggs")), `"ham*"`, types.Boolean(true), nil}, - {"case-5", newLiteralEval(types.String("ham and eggs")), `"*ham"`, types.Boolean(false), nil}, - {"case-6", newLiteralEval(types.String("ham and eggs")), `"*ham*"`, types.Boolean(true), nil}, - {"case-7", newLiteralEval(types.String("ham and eggs")), `"*h*a*m*"`, types.Boolean(true), nil}, - {"case-8", newLiteralEval(types.String("eggs and ham")), `"ham*"`, types.Boolean(false), nil}, - {"case-9", newLiteralEval(types.String("eggs and ham")), `"*ham"`, types.Boolean(true), nil}, - {"case-10", newLiteralEval(types.String("eggs, ham, and spinach")), `"ham*"`, types.Boolean(false), nil}, - {"case-11", newLiteralEval(types.String("eggs, ham, and spinach")), `"*ham"`, types.Boolean(false), nil}, - {"case-12", newLiteralEval(types.String("eggs, ham, and spinach")), `"*ham*"`, types.Boolean(true), nil}, - {"case-13", newLiteralEval(types.String("Gotham")), `"ham*"`, types.Boolean(false), nil}, - {"case-14", newLiteralEval(types.String("Gotham")), `"*ham"`, types.Boolean(true), nil}, - {"case-15", newLiteralEval(types.String("ham")), `"ham"`, types.Boolean(true), nil}, - {"case-16", newLiteralEval(types.String("ham")), `"ham*"`, types.Boolean(true), nil}, - {"case-17", newLiteralEval(types.String("ham")), `"*ham"`, types.Boolean(true), nil}, - {"case-18", newLiteralEval(types.String("ham")), `"*h*a*m*"`, types.Boolean(true), nil}, - {"case-19", newLiteralEval(types.String("ham and ham")), `"ham*"`, types.Boolean(true), nil}, - {"case-20", newLiteralEval(types.String("ham and ham")), `"*ham"`, types.Boolean(true), nil}, - {"case-21", newLiteralEval(types.String("ham")), `"*ham and eggs*"`, types.Boolean(false), nil}, - {"case-22", newLiteralEval(types.String("\\afterslash")), `"\\*"`, types.Boolean(true), nil}, - {"case-23", newLiteralEval(types.String("string\\with\\backslashes")), `"string\\with\\backslashes"`, types.Boolean(true), nil}, - {"case-24", newLiteralEval(types.String("string\\with\\backslashes")), `"string*with*backslashes"`, types.Boolean(true), nil}, - {"case-25", newLiteralEval(types.String("string*with*stars")), `"string\*with\*stars"`, types.Boolean(true), nil}, + {"leftTypeError", newLiteralEval(types.True), `"foo"`, types.ZeroValue(), types.ErrType}, + {"noMatch", newLiteralEval(types.String("test")), `"zebra"`, types.False, nil}, + {"match", newLiteralEval(types.String("test")), `"*es*"`, types.True, nil}, + + {"case-1", newLiteralEval(types.String("eggs")), `"ham*"`, types.False, nil}, + {"case-2", newLiteralEval(types.String("eggs")), `"*ham"`, types.False, nil}, + {"case-3", newLiteralEval(types.String("eggs")), `"*ham*"`, types.False, nil}, + {"case-4", newLiteralEval(types.String("ham and eggs")), `"ham*"`, types.True, nil}, + {"case-5", newLiteralEval(types.String("ham and eggs")), `"*ham"`, types.False, nil}, + {"case-6", newLiteralEval(types.String("ham and eggs")), `"*ham*"`, types.True, nil}, + {"case-7", newLiteralEval(types.String("ham and eggs")), `"*h*a*m*"`, types.True, nil}, + {"case-8", newLiteralEval(types.String("eggs and ham")), `"ham*"`, types.False, nil}, + {"case-9", newLiteralEval(types.String("eggs and ham")), `"*ham"`, types.True, nil}, + {"case-10", newLiteralEval(types.String("eggs, ham, and spinach")), `"ham*"`, types.False, nil}, + {"case-11", newLiteralEval(types.String("eggs, ham, and spinach")), `"*ham"`, types.False, nil}, + {"case-12", newLiteralEval(types.String("eggs, ham, and spinach")), `"*ham*"`, types.True, nil}, + {"case-13", newLiteralEval(types.String("Gotham")), `"ham*"`, types.False, nil}, + {"case-14", newLiteralEval(types.String("Gotham")), `"*ham"`, types.True, nil}, + {"case-15", newLiteralEval(types.String("ham")), `"ham"`, types.True, nil}, + {"case-16", newLiteralEval(types.String("ham")), `"ham*"`, types.True, nil}, + {"case-17", newLiteralEval(types.String("ham")), `"*ham"`, types.True, nil}, + {"case-18", newLiteralEval(types.String("ham")), `"*h*a*m*"`, types.True, nil}, + {"case-19", newLiteralEval(types.String("ham and ham")), `"ham*"`, types.True, nil}, + {"case-20", newLiteralEval(types.String("ham and ham")), `"*ham"`, types.True, nil}, + {"case-21", newLiteralEval(types.String("ham")), `"*ham and eggs*"`, types.False, nil}, + {"case-22", newLiteralEval(types.String("\\afterslash")), `"\\*"`, types.True, nil}, + {"case-23", newLiteralEval(types.String("string\\with\\backslashes")), `"string\\with\\backslashes"`, types.True, nil}, + {"case-24", newLiteralEval(types.String("string\\with\\backslashes")), `"string*with*backslashes"`, types.True, nil}, + {"case-25", newLiteralEval(types.String("string*with*stars")), `"string\*with\*stars"`, types.True, nil}, } for _, tt := range tests { tt := tt @@ -1590,8 +1590,8 @@ func TestIsNode(t *testing.T) { result types.Value err error }{ - {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("X")), types.Boolean(true), nil}, - {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("Y")), types.Boolean(false), nil}, + {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("X")), types.True, nil}, + {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("Y")), types.False, nil}, {"badLhs", newLiteralEval(types.Long(42)), newLiteralEval(types.Path("X")), types.ZeroValue(), types.ErrType}, {"badRhs", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Long(42)), types.ZeroValue(), types.ErrType}, {"errLhs", newErrorEval(errTest), newLiteralEval(types.Path("X")), types.ZeroValue(), errTest}, @@ -1664,7 +1664,7 @@ func TestInNode(t *testing.T) { newLiteralEval(types.NewEntityUID("human", "joe")), newLiteralEval(types.NewEntityUID("human", "joe")), map[string][]string{}, - types.Boolean(true), + types.True, nil, }, { @@ -1674,7 +1674,7 @@ func TestInNode(t *testing.T) { types.NewEntityUID("human", "joe"), }), map[string][]string{}, - types.Boolean(true), + types.True, nil, }, { @@ -1685,7 +1685,7 @@ func TestInNode(t *testing.T) { `human::"joe"`: {`species::"human"`}, `species::"human"`: {`kingdom::"animal"`}, }, - types.Boolean(true), + types.True, nil, }, { @@ -1696,7 +1696,7 @@ func TestInNode(t *testing.T) { `human::"joe"`: {`species::"human"`}, `species::"human"`: {`kingdom::"animal"`}, }, - types.Boolean(false), + types.False, nil, }, } @@ -1794,14 +1794,14 @@ func TestIPTestNode(t *testing.T) { }{ {"Error", newErrorEval(errTest), ipTestIPv4, types.ZeroValue(), errTest}, {"TypeError", newLiteralEval(types.Long(1)), ipTestIPv4, types.ZeroValue(), types.ErrType}, - {"IPv4True", newLiteralEval(ipv4Loopback), ipTestIPv4, types.Boolean(true), nil}, - {"IPv4False", newLiteralEval(ipv6Loopback), ipTestIPv4, types.Boolean(false), nil}, - {"IPv6True", newLiteralEval(ipv6Loopback), ipTestIPv6, types.Boolean(true), nil}, - {"IPv6False", newLiteralEval(ipv4Loopback), ipTestIPv6, types.Boolean(false), nil}, - {"LoopbackTrue", newLiteralEval(ipv6Loopback), ipTestLoopback, types.Boolean(true), nil}, - {"LoopbackFalse", newLiteralEval(ipv4Multicast), ipTestLoopback, types.Boolean(false), nil}, - {"MulticastTrue", newLiteralEval(ipv4Multicast), ipTestMulticast, types.Boolean(true), nil}, - {"MulticastFalse", newLiteralEval(ipv6Loopback), ipTestMulticast, types.Boolean(false), nil}, + {"IPv4True", newLiteralEval(ipv4Loopback), ipTestIPv4, types.True, nil}, + {"IPv4False", newLiteralEval(ipv6Loopback), ipTestIPv4, types.False, nil}, + {"IPv6True", newLiteralEval(ipv6Loopback), ipTestIPv6, types.True, nil}, + {"IPv6False", newLiteralEval(ipv4Loopback), ipTestIPv6, types.False, nil}, + {"LoopbackTrue", newLiteralEval(ipv6Loopback), ipTestLoopback, types.True, nil}, + {"LoopbackFalse", newLiteralEval(ipv4Multicast), ipTestLoopback, types.False, nil}, + {"MulticastTrue", newLiteralEval(ipv4Multicast), ipTestMulticast, types.True, nil}, + {"MulticastFalse", newLiteralEval(ipv6Loopback), ipTestMulticast, types.False, nil}, } for _, tt := range tests { tt := tt @@ -1833,13 +1833,13 @@ func TestIPIsInRangeNode(t *testing.T) { {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(ipv4A), types.ZeroValue(), types.ErrType}, {"RhsError", newLiteralEval(ipv4A), newErrorEval(errTest), types.ZeroValue(), errTest}, {"RhsTypeError", newLiteralEval(ipv4A), newLiteralEval(types.Long(1)), types.ZeroValue(), types.ErrType}, - {"AA", newLiteralEval(ipv4A), newLiteralEval(ipv4A), types.Boolean(true), nil}, - {"AB", newLiteralEval(ipv4A), newLiteralEval(ipv4B), types.Boolean(true), nil}, - {"BA", newLiteralEval(ipv4B), newLiteralEval(ipv4A), types.Boolean(false), nil}, - {"AC", newLiteralEval(ipv4A), newLiteralEval(ipv4C), types.Boolean(false), nil}, - {"CA", newLiteralEval(ipv4C), newLiteralEval(ipv4A), types.Boolean(false), nil}, - {"BC", newLiteralEval(ipv4B), newLiteralEval(ipv4C), types.Boolean(false), nil}, - {"CB", newLiteralEval(ipv4C), newLiteralEval(ipv4B), types.Boolean(false), nil}, + {"AA", newLiteralEval(ipv4A), newLiteralEval(ipv4A), types.True, nil}, + {"AB", newLiteralEval(ipv4A), newLiteralEval(ipv4B), types.True, nil}, + {"BA", newLiteralEval(ipv4B), newLiteralEval(ipv4A), types.False, nil}, + {"AC", newLiteralEval(ipv4A), newLiteralEval(ipv4C), types.False, nil}, + {"CA", newLiteralEval(ipv4C), newLiteralEval(ipv4A), types.False, nil}, + {"BC", newLiteralEval(ipv4B), newLiteralEval(ipv4C), types.False, nil}, + {"CB", newLiteralEval(ipv4C), newLiteralEval(ipv4B), types.False, nil}, } for _, tt := range tests { tt := tt @@ -1863,7 +1863,7 @@ func TestCedarString(t *testing.T) { }{ {"string", types.String("hello"), `hello`, `"hello"`}, {"number", types.Long(42), `42`, `42`}, - {"bool", types.Boolean(true), `true`, `true`}, + {"bool", types.True, `true`, `true`}, {"record", types.Record{"a": types.Long(42), "b": types.Long(43)}, `{"a":42,"b":43}`, `{"a":42,"b":43}`}, {"set", types.Set{types.Long(42), types.Long(43)}, `[42,43]`, `[42,43]`}, {"singleIP", types.IPAddr(netip.MustParsePrefix("192.168.0.42/32")), `192.168.0.42`, `ip("192.168.0.42")`}, diff --git a/types/value.go b/types/value.go index 433021ee..1c21f2a5 100644 --- a/types/value.go +++ b/types/value.go @@ -39,6 +39,11 @@ func ZeroValue() Value { // A Boolean is a value that is either true or false. type Boolean bool +const ( + True = Boolean(true) + False = Boolean(false) +) + func (a Boolean) Equal(bi Value) bool { b, ok := bi.(Boolean) return ok && a == b From 5538fad268468d4b040a6a6e1761a338d48eb97e Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 15 Aug 2024 15:20:58 -0600 Subject: [PATCH 137/216] internal/consts: add PARC string constants Addresses IDX-142 Signed-off-by: philhassey --- internal/ast/variable.go | 13 ++++++++----- internal/consts/consts.go | 8 ++++++++ internal/eval/convert.go | 9 +++++---- internal/json/json_unmarshal.go | 9 +++++---- internal/parser/cedar_marshal.go | 9 +++++---- internal/parser/cedar_unmarshal.go | 15 ++++++++------- 6 files changed, 39 insertions(+), 24 deletions(-) create mode 100644 internal/consts/consts.go diff --git a/internal/ast/variable.go b/internal/ast/variable.go index be398a8d..bcd101ee 100644 --- a/internal/ast/variable.go +++ b/internal/ast/variable.go @@ -1,6 +1,9 @@ package ast -import "github.com/cedar-policy/cedar-go/types" +import ( + "github.com/cedar-policy/cedar-go/internal/consts" + "github.com/cedar-policy/cedar-go/types" +) func Principal() Node { return NewNode(NewPrincipalNode()) @@ -19,17 +22,17 @@ func Context() Node { } func NewPrincipalNode() NodeTypeVariable { - return NodeTypeVariable{Name: types.String("principal")} + return NodeTypeVariable{Name: types.String(consts.Principal)} } func NewActionNode() NodeTypeVariable { - return NodeTypeVariable{Name: types.String("action")} + return NodeTypeVariable{Name: types.String(consts.Action)} } func NewResourceNode() NodeTypeVariable { - return NodeTypeVariable{Name: types.String("resource")} + return NodeTypeVariable{Name: types.String(consts.Resource)} } func NewContextNode() NodeTypeVariable { - return NodeTypeVariable{Name: types.String("context")} + return NodeTypeVariable{Name: types.String(consts.Context)} } diff --git a/internal/consts/consts.go b/internal/consts/consts.go new file mode 100644 index 00000000..559d2a8f --- /dev/null +++ b/internal/consts/consts.go @@ -0,0 +1,8 @@ +package consts + +const ( + Principal = "principal" + Action = "action" + Resource = "resource" + Context = "context" +) diff --git a/internal/eval/convert.go b/internal/eval/convert.go index a5433d5a..b502cd4b 100644 --- a/internal/eval/convert.go +++ b/internal/eval/convert.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/extensions" ) @@ -77,13 +78,13 @@ func toEval(n ast.IsNode) Evaler { return newNotEval(toEval(v.Arg)) case ast.NodeTypeVariable: switch v.Name { - case "principal": + case consts.Principal: return newVariableEval(variableNamePrincipal) - case "action": + case consts.Action: return newVariableEval(variableNameAction) - case "resource": + case consts.Resource: return newVariableEval(variableNameResource) - case "context": + case consts.Context: return newVariableEval(variableNameContext) default: panic(fmt.Errorf("unknown variable: %v", v.Name)) diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index f5566e60..46135d8e 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/extensions" "github.com/cedar-policy/cedar-go/types" ) @@ -162,13 +163,13 @@ func (j nodeJSON) ToNode() (ast.Node, error) { // Var case j.Var != nil: switch *j.Var { - case "principal": + case consts.Principal: return ast.Principal(), nil - case "action": + case consts.Action: return ast.Action(), nil - case "resource": + case consts.Resource: return ast.Resource(), nil - case "context": + case consts.Context: return ast.Context(), nil } return ast.Node{}, fmt.Errorf("unknown variable: %v", j.Var) diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index 29110aa6..129800e9 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/extensions" "github.com/cedar-policy/cedar-go/types" ) @@ -57,25 +58,25 @@ func (p *Policy) marshalScope(buf *bytes.Buffer) { _, actionAll := p.Action.(ast.ScopeTypeAll) _, resourceAll := p.Resource.(ast.ScopeTypeAll) if principalAll && actionAll && resourceAll { - buf.WriteString("( principal, action, resource )") + buf.WriteString("( " + consts.Principal + ", " + consts.Action + ", " + consts.Resource + " )") return } buf.WriteString("(\n ") if principalAll { - buf.WriteString("principal") + buf.WriteString(consts.Principal) } else { astNodeToMarshalNode(scopeToNode(ast.NewPrincipalNode(), p.Principal).AsIsNode()).marshalCedar(buf) } buf.WriteString(",\n ") if actionAll { - buf.WriteString("action") + buf.WriteString(consts.Action) } else { astNodeToMarshalNode(scopeToNode(ast.NewActionNode(), p.Action).AsIsNode()).marshalCedar(buf) } buf.WriteString(",\n ") if resourceAll { - buf.WriteString("resource") + buf.WriteString(consts.Resource) } else { astNodeToMarshalNode(scopeToNode(ast.NewResourceNode(), p.Resource).AsIsNode()).marshalCedar(buf) } diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 654f1f27..02e8f703 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/extensions" "github.com/cedar-policy/cedar-go/types" ) @@ -180,7 +181,7 @@ func (p *parser) effect(a *ast.Annotations) (*ast.Policy, error) { } func (p *parser) principal(policy *ast.Policy) error { - if err := p.exact("principal"); err != nil { + if err := p.exact(consts.Principal); err != nil { return err } switch p.peek().Text { @@ -282,7 +283,7 @@ func (p *parser) path() (types.Path, error) { } func (p *parser) action(policy *ast.Policy) error { - if err := p.exact("action"); err != nil { + if err := p.exact(consts.Action); err != nil { return err } switch p.peek().Text { @@ -336,7 +337,7 @@ func (p *parser) entlist() ([]types.EntityUID, error) { } func (p *parser) resource(policy *ast.Policy) error { - if err := p.exact("resource"); err != nil { + if err := p.exact(consts.Resource); err != nil { return err } switch p.peek().Text { @@ -703,13 +704,13 @@ func (p *parser) primary() (ast.Node, error) { res = ast.True() case t.Text == "false": res = ast.False() - case t.Text == "principal": + case t.Text == consts.Principal: res = ast.Principal() - case t.Text == "action": + case t.Text == consts.Action: res = ast.Action() - case t.Text == "resource": + case t.Text == consts.Resource: res = ast.Resource() - case t.Text == "context": + case t.Text == consts.Context: res = ast.Context() case t.isIdent(): return p.entityOrExtFun(t.Text) From ff886084d9c22145eae4b2d8539d5dd6c580cefb Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 15 Aug 2024 15:44:59 -0600 Subject: [PATCH 138/216] ast: simplify the sugar for Set, Record Addresses IDX-142 Signed-off-by: philhassey --- ast/ast_test.go | 31 +++---------- ast/value.go | 54 ++++++---------------- internal/ast/ast_test.go | 31 ++++--------- internal/ast/value.go | 61 +++++++------------------ internal/eval/compile.go | 2 +- internal/eval/compile_test.go | 2 +- internal/eval/convert_test.go | 14 +++--- internal/json/json_test.go | 4 +- internal/json/json_unmarshal.go | 10 ++-- internal/parser/cedar_marshal.go | 2 +- internal/parser/cedar_unmarshal.go | 8 ++-- internal/parser/cedar_unmarshal_test.go | 4 +- 12 files changed, 70 insertions(+), 153 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index b8aed222..7b3b1e13 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -57,15 +57,13 @@ func TestAstExamples(t *testing.T) { } _ = ast.Forbid(). When( - ast.Record(simpleRecord).Access("x").Equals(ast.String("value")), + ast.Value(simpleRecord).Access("x").Equals(ast.String("value")), ). When( - ast.RecordNodes(map[types.String]ast.Node{ - "x": ast.Long(1).Plus(ast.Context().Access("fooCount")), - }).Access("x").Equals(ast.Long(3)), + ast.Record(ast.Pairs{{Key: "x", Value: ast.Long(1).Plus(ast.Context().Access("fooCount"))}}).Access("x").Equals(ast.Long(3)), ). When( - ast.SetNodes( + ast.Set( ast.Long(1), ast.Long(2).Plus(ast.Long(3)).Times(ast.Long(4)), ast.Context().Access("fooCount"), @@ -225,30 +223,15 @@ func TestASTByTable(t *testing.T) { ast.Permit().When(ast.Long(42)), internalast.Permit().When(internalast.Long(42)), }, - { - "valueSet", - ast.Permit().When(ast.Set(types.Set{types.Long(42), types.Long(43)})), - internalast.Permit().When(internalast.Set(types.Set{types.Long(42), types.Long(43)})), - }, { "valueSetNodes", - ast.Permit().When(ast.SetNodes(ast.Long(42), ast.Long(43))), - internalast.Permit().When(internalast.SetNodes(internalast.Long(42), internalast.Long(43))), - }, - { - "valueRecord", - ast.Permit().When(ast.Record(types.Record{"key": types.Long(43)})), - internalast.Permit().When(internalast.Record(types.Record{"key": types.Long(43)})), - }, - { - "valueRecordNodes", - ast.Permit().When(ast.RecordNodes(map[types.String]ast.Node{"key": ast.Long(42)})), - internalast.Permit().When(internalast.RecordNodes(map[types.String]internalast.Node{"key": internalast.Long(42)})), + ast.Permit().When(ast.Set(ast.Long(42), ast.Long(43))), + internalast.Permit().When(internalast.Set(internalast.Long(42), internalast.Long(43))), }, { "valueRecordElements", - ast.Permit().When(ast.RecordElements(ast.RecordElement{Key: "key", Value: ast.Long(42)})), - internalast.Permit().When(internalast.RecordElements(internalast.RecordElement{Key: "key", Value: internalast.Long(42)})), + ast.Permit().When(ast.Record(ast.Pairs{{Key: "key", Value: ast.Long(42)}})), + internalast.Permit().When(internalast.Record(internalast.Pairs{{Key: "key", Value: internalast.Long(42)}})), }, { "valueEntityUID", diff --git a/ast/value.go b/ast/value.go index b5bccd53..efa1213b 100644 --- a/ast/value.go +++ b/ast/value.go @@ -25,67 +25,39 @@ func Long(l types.Long) Node { return wrapNode(ast.Long(l)) } -// Set is a convenience function that wraps concrete instances of a Cedar Set type -// types in AST value nodes and passes them along to SetNodes. -func Set(s types.Set) Node { - return wrapNode(ast.Set(s)) -} - -// SetNodes allows for a complex set definition with values potentially +// Set allows for a complex set definition with values potentially // being Cedar expressions of their own. For example, this Cedar text: // // [1, 2 + 3, context.fooCount] // // could be expressed in Golang as: // -// ast.SetNodes( +// ast.Set( // ast.Long(1), // ast.Long(2).Plus(ast.Long(3)), // ast.Context().Access("fooCount"), // ) -func SetNodes(nodes ...Node) Node { +func Set(nodes ...Node) Node { var astNodes []ast.Node for _, n := range nodes { astNodes = append(astNodes, n.Node) } - return wrapNode(ast.SetNodes(astNodes...)) -} - -// Record is a convenience function that wraps concrete instances of a Cedar Record type -// types in AST value nodes and passes them along to RecordNodes. -func Record(r types.Record) Node { - return wrapNode(ast.Record(r)) + return wrapNode(ast.Set(astNodes...)) } -// RecordNodes allows for a complex record definition with values potentially -// being Cedar expressions of their own. For example, this Cedar text: -// -// {"x": 1 + context.fooCount} -// -// could be expressed in Golang as: -// -// ast.RecordNodes(map[types.String]Node{ -// "x": ast.Long(1).Plus(ast.Context().Access("fooCount"))}, -// }) -func RecordNodes(entries map[types.String]Node) Node { - astNodes := map[types.String]ast.Node{} - for k, v := range entries { - astNodes[k] = v.Node - } - return wrapNode(ast.RecordNodes(astNodes)) -} - -type RecordElement struct { +type Pair struct { Key types.String Value Node } -func RecordElements(elements ...RecordElement) Node { - var astNodes []ast.RecordElement +type Pairs []Pair + +func Record(elements Pairs) Node { + var astNodes []ast.Pair for _, v := range elements { - astNodes = append(astNodes, ast.RecordElement{Key: v.Key, Value: v.Value.Node}) + astNodes = append(astNodes, ast.Pair{Key: v.Key, Value: v.Value.Node}) } - return wrapNode(ast.RecordElements(astNodes...)) + return wrapNode(ast.Record(astNodes)) } func EntityUID(e types.EntityUID) Node { @@ -99,3 +71,7 @@ func Decimal(d types.Decimal) Node { func IPAddr(i types.IPAddr) Node { return wrapNode(ast.IPAddr(i)) } + +func Value(v types.Value) Node { + return wrapNode(ast.Value(v)) +} diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index b15242d7..994f25c9 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -56,15 +56,14 @@ func TestAstExamples(t *testing.T) { } _ = ast.Forbid(). When( - ast.Record(simpleRecord).Access("x").Equals(ast.String("value")), + ast.Value(simpleRecord).Access("x").Equals(ast.String("value")), ). When( - ast.RecordNodes(map[types.String]ast.Node{ - "x": ast.Long(1).Plus(ast.Context().Access("fooCount")), - }).Access("x").Equals(ast.Long(3)), + ast.Record(ast.Pairs{{Key: "x", Value: ast.Long(1).Plus(ast.Context().Access("fooCount"))}}). + Access("x").Equals(ast.Long(3)), ). When( - ast.SetNodes( + ast.Set( ast.Long(1), ast.Long(2).Plus(ast.Long(3)).Times(ast.Long(4)), ast.Context().Access("fooCount"), @@ -250,37 +249,23 @@ func TestASTByTable(t *testing.T) { }, { "valueSet", - ast.Permit().When(ast.Set(types.Set{types.Long(42), types.Long(43)})), + ast.Permit().When(ast.SetDeprecated(types.Set{types.Long(42), types.Long(43)})), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeSet{Elements: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}, }, }, { "valueSetNodes", - ast.Permit().When(ast.SetNodes(ast.Long(42), ast.Long(43))), + ast.Permit().When(ast.Set(ast.Long(42), ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeSet{Elements: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}, }, }, { "valueRecord", - ast.Permit().When(ast.Record(types.Record{"key": types.Long(43)})), + ast.Permit().When(ast.Value(types.Record{"key": types.Long(43)})), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeRecord{Elements: []ast.RecordElementNode{{Key: "key", Value: ast.NodeValue{Value: types.Long(43)}}}}}}, - }, - }, - { - "valueRecordNodes", - ast.Permit().When(ast.RecordNodes(map[types.String]ast.Node{"key": ast.Long(42)})), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeRecord{Elements: []ast.RecordElementNode{{Key: "key", Value: ast.NodeValue{Value: types.Long(42)}}}}}}, - }, - }, - { - "valueRecordElements", - ast.Permit().When(ast.RecordElements(ast.RecordElement{Key: "key", Value: ast.Long(42)})), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeRecord{Elements: []ast.RecordElementNode{{Key: "key", Value: ast.NodeValue{Value: types.Long(42)}}}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Record{"key": types.Long(43)}}}}, }, }, { diff --git a/internal/ast/value.go b/internal/ast/value.go index 7efd17fc..d78861c2 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -5,7 +5,7 @@ import ( ) func Boolean(b types.Boolean) Node { - return NewValueNode(b) + return Value(b) } func True() Node { @@ -17,74 +17,47 @@ func False() Node { } func String(s types.String) Node { - return NewValueNode(s) + return Value(s) } func Long(l types.Long) Node { - return NewValueNode(l) + return Value(l) } -// Set is a convenience function that wraps concrete instances of a Cedar Set type +// SetDeprecated is a convenience function that wraps concrete instances of a Cedar SetDeprecated type // types in AST value nodes and passes them along to SetNodes. -func Set(s types.Set) Node { +func SetDeprecated(s types.Set) Node { var nodes []IsNode for _, v := range s { - nodes = append(nodes, NewValueNode(v).v) + nodes = append(nodes, Value(v).v) } return NewNode(NodeTypeSet{Elements: nodes}) } -// SetNodes allows for a complex set definition with values potentially +// Set allows for a complex set definition with values potentially // being Cedar expressions of their own. For example, this Cedar text: // // [1, 2 + 3, context.fooCount] // // could be expressed in Golang as: // -// ast.SetNodes( +// ast.Set( // ast.Long(1), // ast.Long(2).Plus(ast.Long(3)), // ast.Context().Access("fooCount"), // ) -func SetNodes(nodes ...Node) Node { +func Set(nodes ...Node) Node { return NewNode(NodeTypeSet{Elements: stripNodes(nodes)}) } -// Record is a convenience function that wraps concrete instances of a Cedar Record type -// types in AST value nodes and passes them along to RecordNodes. -func Record(r types.Record) Node { - // TODO: this results in a double allocation, fix that - recordNodes := map[types.String]Node{} - for k, v := range r { - recordNodes[types.String(k)] = NewValueNode(v) - } - return RecordNodes(recordNodes) -} - -// RecordNodes allows for a complex record definition with values potentially -// being Cedar expressions of their own. For example, this Cedar text: -// -// {"x": 1 + context.fooCount} -// -// could be expressed in Golang as: -// -// ast.RecordNodes(map[types.String]Node{ -// "x": ast.Long(1).Plus(ast.Context().Access("fooCount"))}, -// }) -func RecordNodes(entries map[types.String]Node) Node { - var res NodeTypeRecord - for k, v := range entries { - res.Elements = append(res.Elements, RecordElementNode{Key: k, Value: v.v}) - } - return NewNode(res) -} - -type RecordElement struct { +type Pair struct { Key types.String Value Node } -func RecordElements(elements ...RecordElement) Node { +type Pairs []Pair + +func Record(elements Pairs) Node { var res NodeTypeRecord for _, e := range elements { res.Elements = append(res.Elements, RecordElementNode{Key: e.Key, Value: e.Value.v}) @@ -93,21 +66,21 @@ func RecordElements(elements ...RecordElement) Node { } func EntityUID(e types.EntityUID) Node { - return NewValueNode(e) + return Value(e) } func Decimal(d types.Decimal) Node { - return NewValueNode(d) + return Value(d) } func IPAddr(i types.IPAddr) Node { - return NewValueNode(i) + return Value(i) } func ExtensionCall(name types.String, args ...Node) Node { return NewExtensionCall(name, args...) } -func NewValueNode(v types.Value) Node { +func Value(v types.Value) Node { return NewNode(NodeValue{Value: v}) } diff --git a/internal/eval/compile.go b/internal/eval/compile.go index 25c928c7..932fb60c 100644 --- a/internal/eval/compile.go +++ b/internal/eval/compile.go @@ -44,7 +44,7 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { for i, e := range t.Entities { set[i] = e } - return ast.NewNode(varNode).In(ast.Set(set)) + return ast.NewNode(varNode).In(ast.SetDeprecated(set)) case ast.ScopeTypeIs: return ast.NewNode(varNode).Is(t.Type) diff --git a/internal/eval/compile_test.go b/internal/eval/compile_test.go index 94e79d6c..98df2077 100644 --- a/internal/eval/compile_test.go +++ b/internal/eval/compile_test.go @@ -93,7 +93,7 @@ func TestScopeToNode(t *testing.T) { "inSet", ast.NewActionNode(), ast.ScopeTypeInSet{Entities: []types.EntityUID{types.NewEntityUID("T", "42")}}, - ast.Action().In(ast.Set(types.Set{types.NewEntityUID("T", "42")})), + ast.Action().In(ast.SetDeprecated(types.Set{types.NewEntityUID("T", "42")})), }, { "is", diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index a6fe1a23..dffcbb26 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -19,13 +19,13 @@ func TestToEval(t *testing.T) { }{ { "access", - ast.Record(types.Record{"key": types.Long(42)}).Access("key"), + ast.Value(types.Record{"key": types.Long(42)}).Access("key"), types.Long(42), testutil.OK, }, { "has", - ast.Record(types.Record{"key": types.Long(42)}).Has("key"), + ast.Value(types.Record{"key": types.Long(42)}).Has("key"), types.True, testutil.OK, }, @@ -61,13 +61,13 @@ func TestToEval(t *testing.T) { }, { "record", - ast.RecordElements(ast.RecordElement{Key: "key", Value: ast.Long(42)}), + ast.Record(ast.Pairs{{Key: "key", Value: ast.Long(42)}}), types.Record{"key": types.Long(42)}, testutil.OK, }, { "set", - ast.SetNodes(ast.Long(42)), + ast.Set(ast.Long(42)), types.Set{types.Long(42)}, testutil.OK, }, @@ -181,19 +181,19 @@ func TestToEval(t *testing.T) { }, { "contains", - ast.Set(types.Set{types.Long(42)}).Contains(ast.Long(42)), + ast.SetDeprecated(types.Set{types.Long(42)}).Contains(ast.Long(42)), types.True, testutil.OK, }, { "containsAll", - ast.Set(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAll(ast.Set(types.Set{types.Long(42), types.Long(43)})), + ast.SetDeprecated(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAll(ast.SetDeprecated(types.Set{types.Long(42), types.Long(43)})), types.True, testutil.OK, }, { "containsAny", - ast.Set(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAny(ast.Set(types.Set{types.Long(1), types.Long(42)})), + ast.SetDeprecated(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAny(ast.SetDeprecated(types.Set{types.Long(1), types.Long(42)})), types.True, testutil.OK, }, diff --git a/internal/json/json_test.go b/internal/json/json_test.go index b89115e6..8eaefbab 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -208,14 +208,14 @@ func TestUnmarshalJSON(t *testing.T) { "set", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"Set":[{"Value":42},{"Value":"bananas"}]}}]}`, - ast.Permit().When(ast.Set(types.Set{types.Long(42), types.String("bananas")})), + ast.Permit().When(ast.SetDeprecated(types.Set{types.Long(42), types.String("bananas")})), testutil.OK, }, { "record", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"Record":{"key":{"Value":42}}}}]}`, - ast.Permit().When(ast.Record(types.Record{"key": types.Long(42)})), + ast.Permit().When(ast.Record(ast.Pairs{{Key: "key", Value: ast.Long(42)}})), testutil.OK, }, { diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 46135d8e..088a908f 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -115,19 +115,19 @@ func (j arrayJSON) ToNode() (ast.Node, error) { } nodes = append(nodes, n) } - return ast.SetNodes(nodes...), nil + return ast.Set(nodes...), nil } func (j recordJSON) ToNode() (ast.Node, error) { - nodes := map[types.String]ast.Node{} + var nodes ast.Pairs for k, v := range j { n, err := v.ToNode() if err != nil { return ast.Node{}, fmt.Errorf("error in record: %w", err) } - nodes[types.String(k)] = n + nodes = append(nodes, ast.Pair{Key: types.String(k), Value: n}) } - return ast.RecordNodes(nodes), nil + return ast.Record(nodes), nil } func (e extensionJSON) ToNode() (ast.Node, error) { @@ -158,7 +158,7 @@ func (j nodeJSON) ToNode() (ast.Node, error) { switch { // Value case j.Value != nil: - return ast.NewValueNode(j.Value.v), nil + return ast.Value(j.Value.v), nil // Var case j.Var != nil: diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index 129800e9..d230af73 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -42,7 +42,7 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { for i, e := range t.Entities { set[i] = e } - return ast.NewNode(varNode).In(ast.Set(set)) + return ast.NewNode(varNode).In(ast.SetDeprecated(set)) case ast.ScopeTypeIs: return ast.NewNode(varNode).Is(t.Type) diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 02e8f703..5e19e1fe 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -729,7 +729,7 @@ func (p *parser) primary() (ast.Node, error) { return res, err } p.advance() // expressions guarantees "]" - res = ast.SetNodes(set...) + res = ast.Set(set...) case t.Text == "{": record, err := p.record() if err != nil { @@ -802,13 +802,13 @@ func (p *parser) expressions(endOfListMarker string) ([]ast.Node, error) { func (p *parser) record() (ast.Node, error) { var res ast.Node - var elements []ast.RecordElement + var elements ast.Pairs known := map[types.String]struct{}{} for { t := p.peek() if t.Text == "}" { p.advance() - return ast.RecordElements(elements...), nil + return ast.Record(elements), nil } if len(elements) > 0 { if err := p.exact(","); err != nil { @@ -824,7 +824,7 @@ func (p *parser) record() (ast.Node, error) { return res, p.errorf("duplicate key: %v", k) } known[k] = struct{}{} - elements = append(elements, ast.RecordElement{Key: k, Value: v}) + elements = append(elements, ast.Pair{Key: k, Value: v}) } } diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 616f4cae..bc93ebf5 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -173,13 +173,13 @@ when { context.strings.contains("foo") };`, "containsAll method call", `permit ( principal, action, resource ) when { context.strings.containsAll(["foo"]) };`, - ast.Permit().When(ast.Context().Access("strings").ContainsAll(ast.SetNodes(ast.String("foo")))), + ast.Permit().When(ast.Context().Access("strings").ContainsAll(ast.Set(ast.String("foo")))), }, { "containsAny method call", `permit ( principal, action, resource ) when { context.strings.containsAny(["foo"]) };`, - ast.Permit().When(ast.Context().Access("strings").ContainsAny(ast.SetNodes(ast.String("foo")))), + ast.Permit().When(ast.Context().Access("strings").ContainsAny(ast.Set(ast.String("foo")))), }, { "extension method call", From c60a21317cc8d092ce372ded147177c077d7ab79 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 15 Aug 2024 17:02:05 -0600 Subject: [PATCH 139/216] ast: improve UX of sugar Addresses IDX-142 Signed-off-by: philhassey --- ast/ast_test.go | 15 +++++-------- ast/value.go | 21 +++++++++--------- internal/ast/ast_test.go | 13 +++-------- internal/ast/value.go | 29 ++++++++++++------------- internal/eval/compile.go | 6 ++--- internal/eval/compile_test.go | 12 +++++----- internal/eval/convert_test.go | 24 ++++++++++---------- internal/json/json_test.go | 11 +++++----- internal/json/json_unmarshal.go | 2 +- internal/parser/cedar_marshal.go | 6 ++--- internal/parser/cedar_unmarshal.go | 18 +++++++-------- internal/parser/cedar_unmarshal_test.go | 4 ++-- 12 files changed, 74 insertions(+), 87 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index 7b3b1e13..dadeb18a 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -38,7 +38,7 @@ func TestAstExamples(t *testing.T) { // forbid (principal, action, resource) // when { resource.tags.contains("private") } // unless { resource in principal.allowed_resources }; - private := types.String("private") + private := "private" _ = ast.Annotation("example", "two"). Forbid(). When( @@ -235,18 +235,13 @@ func TestASTByTable(t *testing.T) { }, { "valueEntityUID", - ast.Permit().When(ast.EntityUID(types.NewEntityUID("T", "42"))), - internalast.Permit().When(internalast.EntityUID(types.NewEntityUID("T", "42"))), - }, - { - "valueDecimal", - ast.Permit().When(ast.Decimal(420000)), - internalast.Permit().When(internalast.Decimal(420000)), + ast.Permit().When(ast.EntityUID("T", "42")), + internalast.Permit().When(internalast.EntityUID("T", "42")), }, { "valueIPAddr", - ast.Permit().When(ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16")))), - internalast.Permit().When(internalast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16")))), + ast.Permit().When(ast.IPAddr(netip.MustParsePrefix("127.0.0.1/16"))), + internalast.Permit().When(internalast.IPAddr(netip.MustParsePrefix("127.0.0.1/16"))), }, { "opEquals", diff --git a/ast/value.go b/ast/value.go index efa1213b..30f066fe 100644 --- a/ast/value.go +++ b/ast/value.go @@ -1,11 +1,13 @@ package ast import ( + "net/netip" + "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/types" ) -func Boolean(b types.Boolean) Node { +func Boolean(b bool) Node { return wrapNode(ast.Boolean(b)) } @@ -17,11 +19,11 @@ func False() Node { return Boolean(false) } -func String(s types.String) Node { +func String(s string) Node { return wrapNode(ast.String(s)) } -func Long(l types.Long) Node { +func Long(l int64) Node { return wrapNode(ast.Long(l)) } @@ -46,12 +48,13 @@ func Set(nodes ...Node) Node { } type Pair struct { - Key types.String + Key string Value Node } type Pairs []Pair +// Record, TODO: document how duplicate keys might not really get handled in a meaningful way func Record(elements Pairs) Node { var astNodes []ast.Pair for _, v := range elements { @@ -60,15 +63,11 @@ func Record(elements Pairs) Node { return wrapNode(ast.Record(astNodes)) } -func EntityUID(e types.EntityUID) Node { - return wrapNode(ast.EntityUID(e)) -} - -func Decimal(d types.Decimal) Node { - return wrapNode(ast.Decimal(d)) +func EntityUID(typ, id string) Node { + return wrapNode(ast.EntityUID(typ, id)) } -func IPAddr(i types.IPAddr) Node { +func IPAddr(i netip.Prefix) Node { return wrapNode(ast.IPAddr(i)) } diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index 994f25c9..264830da 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -37,7 +37,7 @@ func TestAstExamples(t *testing.T) { // forbid (principal, action, resource) // when { resource.tags.contains("private") } // unless { resource in principal.allowed_resources }; - private := types.String("private") + private := "private" _ = ast.Annotation("example", "two"). Forbid(). When( @@ -270,21 +270,14 @@ func TestASTByTable(t *testing.T) { }, { "valueEntityUID", - ast.Permit().When(ast.EntityUID(types.NewEntityUID("T", "42"))), + ast.Permit().When(ast.EntityUID("T", "42")), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.NewEntityUID("T", "42")}}}, }, }, - { - "valueDecimal", - ast.Permit().When(ast.Decimal(420000)), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Decimal(420000)}}}, - }, - }, { "valueIPAddr", - ast.Permit().When(ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16")))), + ast.Permit().When(ast.IPAddr(netip.MustParsePrefix("127.0.0.1/16"))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.IPAddr(netip.MustParsePrefix("127.0.0.1/16"))}}}, }, diff --git a/internal/ast/value.go b/internal/ast/value.go index d78861c2..d10b7674 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -1,11 +1,13 @@ package ast import ( + "net/netip" + "github.com/cedar-policy/cedar-go/types" ) -func Boolean(b types.Boolean) Node { - return Value(b) +func Boolean(b bool) Node { + return Value(types.Boolean(b)) } func True() Node { @@ -16,12 +18,12 @@ func False() Node { return Boolean(false) } -func String(s types.String) Node { - return Value(s) +func String(s string) Node { + return Value(types.String(s)) } -func Long(l types.Long) Node { - return Value(l) +func Long(l int64) Node { + return Value(types.Long(l)) } // SetDeprecated is a convenience function that wraps concrete instances of a Cedar SetDeprecated type @@ -51,7 +53,7 @@ func Set(nodes ...Node) Node { } type Pair struct { - Key types.String + Key string Value Node } @@ -60,21 +62,18 @@ type Pairs []Pair func Record(elements Pairs) Node { var res NodeTypeRecord for _, e := range elements { - res.Elements = append(res.Elements, RecordElementNode{Key: e.Key, Value: e.Value.v}) + res.Elements = append(res.Elements, RecordElementNode{Key: types.String(e.Key), Value: e.Value.v}) } return NewNode(res) } -func EntityUID(e types.EntityUID) Node { +func EntityUID(typ, id string) Node { + e := types.NewEntityUID(typ, id) return Value(e) } -func Decimal(d types.Decimal) Node { - return Value(d) -} - -func IPAddr(i types.IPAddr) Node { - return Value(i) +func IPAddr(i netip.Prefix) Node { + return Value(types.IPAddr(i)) } func ExtensionCall(name types.String, args ...Node) Node { diff --git a/internal/eval/compile.go b/internal/eval/compile.go index 932fb60c..4111fd68 100644 --- a/internal/eval/compile.go +++ b/internal/eval/compile.go @@ -36,9 +36,9 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { case ast.ScopeTypeAll: return ast.True() case ast.ScopeTypeEq: - return ast.NewNode(varNode).Equals(ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).Equals(ast.Value(t.Entity)) case ast.ScopeTypeIn: - return ast.NewNode(varNode).In(ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).In(ast.Value(t.Entity)) case ast.ScopeTypeInSet: set := make([]types.Value, len(t.Entities)) for i, e := range t.Entities { @@ -49,7 +49,7 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { return ast.NewNode(varNode).Is(t.Type) case ast.ScopeTypeIsIn: - return ast.NewNode(varNode).IsIn(t.Type, ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).IsIn(t.Type, ast.Value(t.Entity)) default: panic(fmt.Sprintf("unknown scope type %T", t)) } diff --git a/internal/eval/compile_test.go b/internal/eval/compile_test.go index 98df2077..7c5c6f2d 100644 --- a/internal/eval/compile_test.go +++ b/internal/eval/compile_test.go @@ -36,9 +36,9 @@ func TestPolicyToNode(t *testing.T) { ActionEq(types.NewEntityUID("Action", "test")). ResourceEq(types.NewEntityUID("Resource", "database")), - ast.Principal().Equals(ast.EntityUID(types.NewEntityUID("Account", "principal"))).And( - ast.Action().Equals(ast.EntityUID(types.NewEntityUID("Action", "test"))).And( - ast.Resource().Equals(ast.EntityUID(types.NewEntityUID("Resource", "database"))), + ast.Principal().Equals(ast.EntityUID("Account", "principal")).And( + ast.Action().Equals(ast.EntityUID("Action", "test")).And( + ast.Resource().Equals(ast.EntityUID("Resource", "database")), ), ), }, @@ -81,13 +81,13 @@ func TestScopeToNode(t *testing.T) { "eq", ast.NewPrincipalNode(), ast.ScopeTypeEq{Entity: types.NewEntityUID("T", "42")}, - ast.Principal().Equals(ast.EntityUID(types.NewEntityUID("T", "42"))), + ast.Principal().Equals(ast.EntityUID("T", "42")), }, { "in", ast.NewPrincipalNode(), ast.ScopeTypeIn{Entity: types.NewEntityUID("T", "42")}, - ast.Principal().In(ast.EntityUID(types.NewEntityUID("T", "42"))), + ast.Principal().In(ast.EntityUID("T", "42")), }, { "inSet", @@ -105,7 +105,7 @@ func TestScopeToNode(t *testing.T) { "isIn", ast.NewResourceNode(), ast.ScopeTypeIsIn{Type: "T", Entity: types.NewEntityUID("T", "42")}, - ast.Resource().IsIn("T", ast.EntityUID(types.NewEntityUID("T", "42"))), + ast.Resource().IsIn("T", ast.EntityUID("T", "42")), }, } for _, tt := range tests { diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index dffcbb26..9d2db8d3 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -43,13 +43,13 @@ func TestToEval(t *testing.T) { }, { "is", - ast.EntityUID(types.NewEntityUID("T", "42")).Is("T"), + ast.EntityUID("T", "42").Is("T"), types.True, testutil.OK, }, { "isIn", - ast.EntityUID(types.NewEntityUID("T", "42")).IsIn("T", ast.EntityUID(types.NewEntityUID("T", "42"))), + ast.EntityUID("T", "42").IsIn("T", ast.EntityUID("T", "42")), types.True, testutil.OK, }, @@ -109,7 +109,7 @@ func TestToEval(t *testing.T) { }, { "in", - ast.EntityUID(types.NewEntityUID("T", "42")).In(ast.EntityUID(types.NewEntityUID("T", "43"))), + ast.EntityUID("T", "42").In(ast.EntityUID("T", "43")), types.False, testutil.OK, }, @@ -211,55 +211,55 @@ func TestToEval(t *testing.T) { }, { "lessThan", - ast.ExtensionCall("lessThan", ast.Decimal(420000), ast.Decimal(430000)), + ast.ExtensionCall("lessThan", ast.Value(types.Decimal(420000)), ast.Value(types.Decimal(430000))), types.True, testutil.OK, }, { "lessThanOrEqual", - ast.ExtensionCall("lessThanOrEqual", ast.Decimal(420000), ast.Decimal(430000)), + ast.ExtensionCall("lessThanOrEqual", ast.Value(types.Decimal(420000)), ast.Value(types.Decimal(430000))), types.True, testutil.OK, }, { "greaterThan", - ast.ExtensionCall("greaterThan", ast.Decimal(420000), ast.Decimal(430000)), + ast.ExtensionCall("greaterThan", ast.Value(types.Decimal(420000)), ast.Value(types.Decimal(430000))), types.False, testutil.OK, }, { "greaterThanOrEqual", - ast.ExtensionCall("greaterThanOrEqual", ast.Decimal(420000), ast.Decimal(430000)), + ast.ExtensionCall("greaterThanOrEqual", ast.Value(types.Decimal(420000)), ast.Value(types.Decimal(430000))), types.False, testutil.OK, }, { "isIpv4", - ast.ExtensionCall("isIpv4", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.42/16")))), + ast.ExtensionCall("isIpv4", ast.IPAddr(netip.MustParsePrefix("127.0.0.42/16"))), types.True, testutil.OK, }, { "isIpv6", - ast.ExtensionCall("isIpv6", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("::1/16")))), + ast.ExtensionCall("isIpv6", ast.IPAddr(netip.MustParsePrefix("::1/16"))), types.True, testutil.OK, }, { "isLoopback", - ast.ExtensionCall("isLoopback", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.1/32")))), + ast.ExtensionCall("isLoopback", ast.IPAddr(netip.MustParsePrefix("127.0.0.1/32"))), types.True, testutil.OK, }, { "isMulticast", - ast.ExtensionCall("isMulticast", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("239.255.255.255/32")))), + ast.ExtensionCall("isMulticast", ast.IPAddr(netip.MustParsePrefix("239.255.255.255/32"))), types.True, testutil.OK, }, { "isInRange", - ast.ExtensionCall("isInRange", ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.42/32"))), ast.IPAddr(types.IPAddr(netip.MustParsePrefix("127.0.0.0/16")))), + ast.ExtensionCall("isInRange", ast.IPAddr(netip.MustParsePrefix("127.0.0.42/32")), ast.IPAddr(netip.MustParsePrefix("127.0.0.0/16"))), types.True, testutil.OK, }, diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 8eaefbab..461e0e8d 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -2,6 +2,7 @@ package json import ( "encoding/json" + "net/netip" "testing" "github.com/cedar-policy/cedar-go/internal/ast" @@ -201,7 +202,7 @@ func TestUnmarshalJSON(t *testing.T) { "entity", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"Value":{"__entity":{"type":"T","id":"42"}}}}]}`, - ast.Permit().When(ast.EntityUID(types.NewEntityUID("T", "42"))), + ast.Permit().When(ast.EntityUID("T", "42")), testutil.OK, }, { @@ -390,7 +391,7 @@ func TestUnmarshalJSON(t *testing.T) { "isIn", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"is":{"left":{"Var":"resource"},"entity_type":"T","in":{"Value":{"__entity":{"type":"P","id":"42"}}}}}}]}`, - ast.Permit().When(ast.Resource().IsIn("T", ast.EntityUID(types.NewEntityUID("P", "42")))), + ast.Permit().When(ast.Resource().IsIn("T", ast.EntityUID("P", "42"))), testutil.OK, }, { @@ -508,7 +509,7 @@ func TestMarshalJSON(t *testing.T) { }{ { "decimal", - ast.Permit().When(ast.Decimal(mustParseDecimal("42.24"))), + ast.Permit().When(ast.Value(mustParseDecimal("42.24"))), `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"decimal":[{"Value":"42.24"}]}}]}`, testutil.OK, @@ -561,9 +562,9 @@ func mustParseDecimal(v string) types.Decimal { res, _ := types.ParseDecimal(v) return res } -func mustParseIPAddr(v string) types.IPAddr { +func mustParseIPAddr(v string) netip.Prefix { res, _ := types.ParseIPAddr(v) - return res + return netip.Prefix(res) } func TestMarshalPanics(t *testing.T) { diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 088a908f..c345a25a 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -125,7 +125,7 @@ func (j recordJSON) ToNode() (ast.Node, error) { if err != nil { return ast.Node{}, fmt.Errorf("error in record: %w", err) } - nodes = append(nodes, ast.Pair{Key: types.String(k), Value: n}) + nodes = append(nodes, ast.Pair{Key: k, Value: n}) } return ast.Record(nodes), nil } diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index d230af73..14b4e0ef 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -34,9 +34,9 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { case ast.ScopeTypeAll: return ast.True() case ast.ScopeTypeEq: - return ast.NewNode(varNode).Equals(ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).Equals(ast.Value(t.Entity)) case ast.ScopeTypeIn: - return ast.NewNode(varNode).In(ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).In(ast.Value(t.Entity)) case ast.ScopeTypeInSet: set := make([]types.Value, len(t.Entities)) for i, e := range t.Entities { @@ -47,7 +47,7 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { return ast.NewNode(varNode).Is(t.Type) case ast.ScopeTypeIsIn: - return ast.NewNode(varNode).IsIn(t.Type, ast.EntityUID(t.Entity)) + return ast.NewNode(varNode).IsIn(t.Type, ast.Value(t.Entity)) default: panic(fmt.Sprintf("unknown scope type %T", t)) } diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 5e19e1fe..7ca92c44 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -650,7 +650,7 @@ func (p *parser) unary() (ast.Node, error) { if err != nil { return ast.Node{}, err } - res = ast.Long(types.Long(i)) + res = ast.Long(i) ops = ops[:len(ops)-1] } else { var err error @@ -693,13 +693,13 @@ func (p *parser) primary() (ast.Node, error) { if err != nil { return res, err } - res = ast.Long(types.Long(i)) + res = ast.Long(i) case t.isString(): str, err := t.stringValue() if err != nil { return res, err } - res = ast.String(types.String(str)) + res = ast.String(str) case t.Text == "true": res = ast.True() case t.Text == "false": @@ -756,7 +756,7 @@ func (p *parser) entityOrExtFun(prefix string) (ast.Node, error) { if err != nil { return ast.Node{}, err } - return ast.EntityUID(types.NewEntityUID(prefix, id)), nil + return ast.EntityUID(prefix, id), nil default: return ast.Node{}, p.errorf("unexpected token") } @@ -803,7 +803,7 @@ func (p *parser) expressions(endOfListMarker string) ([]ast.Node, error) { func (p *parser) record() (ast.Node, error) { var res ast.Node var elements ast.Pairs - known := map[types.String]struct{}{} + known := map[string]struct{}{} for { t := p.peek() if t.Text == "}" { @@ -828,20 +828,20 @@ func (p *parser) record() (ast.Node, error) { } } -func (p *parser) recordEntry() (types.String, ast.Node, error) { - var key types.String +func (p *parser) recordEntry() (string, ast.Node, error) { + var key string var value ast.Node var err error t := p.advance() switch { case t.isIdent(): - key = types.String(t.Text) + key = t.Text case t.isString(): str, err := t.stringValue() if err != nil { return key, value, err } - key = types.String(str) + key = str default: return key, value, p.errorf("unexpected token") } diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index bc93ebf5..deb46cd1 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -269,7 +269,7 @@ when { 2 != 42 };`, "in", `permit ( principal, action, resource ) when { principal in Group::"folkHeroes" };`, - ast.Permit().When(ast.Principal().In(ast.EntityUID(folkHeroes))), + ast.Permit().When(ast.Principal().In(ast.Value(folkHeroes))), }, { "has ident", @@ -312,7 +312,7 @@ when { principal is User };`, "is in", `permit ( principal, action, resource ) when { principal is User in Group::"folkHeroes" };`, - ast.Permit().When(ast.Principal().IsIn("User", ast.EntityUID(folkHeroes))), + ast.Permit().When(ast.Principal().IsIn("User", ast.Value(folkHeroes))), }, { "and", From e8ca09107d721deea5c012629fbf62a923baa6db Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 15 Aug 2024 17:35:55 -0600 Subject: [PATCH 140/216] ast: make sugar of types sweeter Addresses IDX-142 Signed-off-by: philhassey --- ast/value.go | 16 ++++++++-------- internal/ast/ast_test.go | 4 ++-- internal/ast/value.go | 8 ++++---- internal/eval/compile.go | 4 ++-- internal/eval/compile_test.go | 2 +- internal/eval/convert_test.go | 13 ++++++++++--- internal/eval/evalers_test.go | 4 ++-- internal/json/json_test.go | 2 +- internal/parser/cedar_marshal.go | 4 ++-- types/value.go | 6 +++--- types/value_test.go | 6 +++--- 11 files changed, 38 insertions(+), 31 deletions(-) diff --git a/ast/value.go b/ast/value.go index 30f066fe..a7039b7a 100644 --- a/ast/value.go +++ b/ast/value.go @@ -7,8 +7,8 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func Boolean(b bool) Node { - return wrapNode(ast.Boolean(b)) +func Boolean[T bool | types.Boolean](b T) Node { + return wrapNode(ast.Boolean(types.Boolean(b))) } func True() Node { @@ -19,12 +19,12 @@ func False() Node { return Boolean(false) } -func String(s string) Node { - return wrapNode(ast.String(s)) +func String[T string | types.String](s T) Node { + return wrapNode(ast.String(types.String(s))) } -func Long(l int64) Node { - return wrapNode(ast.Long(l)) +func Long[T int | int64 | types.Long](l T) Node { + return wrapNode(ast.Long(types.Long(l))) } // Set allows for a complex set definition with values potentially @@ -67,8 +67,8 @@ func EntityUID(typ, id string) Node { return wrapNode(ast.EntityUID(typ, id)) } -func IPAddr(i netip.Prefix) Node { - return wrapNode(ast.IPAddr(i)) +func IPAddr[T netip.Prefix | types.IPAddr](i T) Node { + return wrapNode(ast.IPAddr(types.IPAddr(i))) } func Value(v types.Value) Node { diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index 264830da..0473e82e 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -249,9 +249,9 @@ func TestASTByTable(t *testing.T) { }, { "valueSet", - ast.Permit().When(ast.SetDeprecated(types.Set{types.Long(42), types.Long(43)})), + ast.Permit().When(ast.Value(types.Set{types.Long(42), types.Long(43)})), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeSet{Elements: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Set{types.Long(42), types.Long(43)}}}}, }, }, { diff --git a/internal/ast/value.go b/internal/ast/value.go index d10b7674..1ed6e66b 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -6,7 +6,7 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func Boolean(b bool) Node { +func Boolean[T bool | types.Boolean](b T) Node { return Value(types.Boolean(b)) } @@ -18,11 +18,11 @@ func False() Node { return Boolean(false) } -func String(s string) Node { +func String[T string | types.String](s T) Node { return Value(types.String(s)) } -func Long(l int64) Node { +func Long[T int | int64 | types.Long](l T) Node { return Value(types.Long(l)) } @@ -72,7 +72,7 @@ func EntityUID(typ, id string) Node { return Value(e) } -func IPAddr(i netip.Prefix) Node { +func IPAddr[T netip.Prefix | types.IPAddr](i T) Node { return Value(types.IPAddr(i)) } diff --git a/internal/eval/compile.go b/internal/eval/compile.go index 4111fd68..7018bb97 100644 --- a/internal/eval/compile.go +++ b/internal/eval/compile.go @@ -40,11 +40,11 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { case ast.ScopeTypeIn: return ast.NewNode(varNode).In(ast.Value(t.Entity)) case ast.ScopeTypeInSet: - set := make([]types.Value, len(t.Entities)) + set := make(types.Set, len(t.Entities)) for i, e := range t.Entities { set[i] = e } - return ast.NewNode(varNode).In(ast.SetDeprecated(set)) + return ast.NewNode(varNode).In(ast.Value(set)) case ast.ScopeTypeIs: return ast.NewNode(varNode).Is(t.Type) diff --git a/internal/eval/compile_test.go b/internal/eval/compile_test.go index 7c5c6f2d..dad65f8c 100644 --- a/internal/eval/compile_test.go +++ b/internal/eval/compile_test.go @@ -93,7 +93,7 @@ func TestScopeToNode(t *testing.T) { "inSet", ast.NewActionNode(), ast.ScopeTypeInSet{Entities: []types.EntityUID{types.NewEntityUID("T", "42")}}, - ast.Action().In(ast.SetDeprecated(types.Set{types.NewEntityUID("T", "42")})), + ast.Action().In(ast.Value(types.Set{types.NewEntityUID("T", "42")})), }, { "is", diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index 9d2db8d3..1794b8c1 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -181,19 +181,19 @@ func TestToEval(t *testing.T) { }, { "contains", - ast.SetDeprecated(types.Set{types.Long(42)}).Contains(ast.Long(42)), + ast.Value(types.Set{types.Long(42)}).Contains(ast.Long(42)), types.True, testutil.OK, }, { "containsAll", - ast.SetDeprecated(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAll(ast.SetDeprecated(types.Set{types.Long(42), types.Long(43)})), + ast.Value(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAll(ast.Value(types.Set{types.Long(42), types.Long(43)})), types.True, testutil.OK, }, { "containsAny", - ast.SetDeprecated(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAny(ast.SetDeprecated(types.Set{types.Long(1), types.Long(42)})), + ast.Value(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAny(ast.Value(types.Set{types.Long(1), types.Long(42)})), types.True, testutil.OK, }, @@ -301,3 +301,10 @@ func TestToEvalPanic(t *testing.T) { _ = toEval(ast.Node{}.AsIsNode()) }) } + +func TestToEvalVariablePanic(t *testing.T) { + t.Parallel() + testutil.AssertPanic(t, func() { + _ = toEval(ast.NodeTypeVariable{Name: "bananas"}) + }) +} diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index aab60924..a3793473 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -1864,8 +1864,8 @@ func TestCedarString(t *testing.T) { {"string", types.String("hello"), `hello`, `"hello"`}, {"number", types.Long(42), `42`, `42`}, {"bool", types.True, `true`, `true`}, - {"record", types.Record{"a": types.Long(42), "b": types.Long(43)}, `{"a":42,"b":43}`, `{"a":42,"b":43}`}, - {"set", types.Set{types.Long(42), types.Long(43)}, `[42,43]`, `[42,43]`}, + {"record", types.Record{"a": types.Long(42), "b": types.Long(43)}, `{"a": 42, "b": 43}`, `{"a": 42, "b": 43}`}, + {"set", types.Set{types.Long(42), types.Long(43)}, `[42, 43]`, `[42, 43]`}, {"singleIP", types.IPAddr(netip.MustParsePrefix("192.168.0.42/32")), `192.168.0.42`, `ip("192.168.0.42")`}, {"ipPrefix", types.IPAddr(netip.MustParsePrefix("192.168.0.42/24")), `192.168.0.42/24`, `ip("192.168.0.42/24")`}, {"decimal", types.Decimal(12345678), `1234.5678`, `decimal("1234.5678")`}, diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 461e0e8d..8ccf5c53 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -209,7 +209,7 @@ func TestUnmarshalJSON(t *testing.T) { "set", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"Set":[{"Value":42},{"Value":"bananas"}]}}]}`, - ast.Permit().When(ast.SetDeprecated(types.Set{types.Long(42), types.String("bananas")})), + ast.Permit().When(ast.Set(ast.Long(42), ast.String("bananas"))), testutil.OK, }, { diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index 14b4e0ef..44bb02a0 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -38,11 +38,11 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { case ast.ScopeTypeIn: return ast.NewNode(varNode).In(ast.Value(t.Entity)) case ast.ScopeTypeInSet: - set := make([]types.Value, len(t.Entities)) + set := make(types.Set, len(t.Entities)) for i, e := range t.Entities { set[i] = e } - return ast.NewNode(varNode).In(ast.SetDeprecated(set)) + return ast.NewNode(varNode).In(ast.Value(set)) case ast.ScopeTypeIs: return ast.NewNode(varNode).Is(t.Type) diff --git a/types/value.go b/types/value.go index 1c21f2a5..f3549f6d 100644 --- a/types/value.go +++ b/types/value.go @@ -213,7 +213,7 @@ func (v Set) Cedar() string { sb.WriteRune('[') for i, elem := range v { if i > 0 { - sb.WriteString(",") + sb.WriteString(", ") } sb.WriteString(elem.Cedar()) } @@ -319,11 +319,11 @@ func (r Record) Cedar() string { for _, k := range keys { v := r[k] if !first { - sb.WriteString(",") + sb.WriteString(", ") } first = false sb.WriteString(strconv.Quote(k)) - sb.WriteString(":") + sb.WriteString(": ") sb.WriteString(v.Cedar()) } sb.WriteRune('}') diff --git a/types/value_test.go b/types/value_test.go index 9cf2ffbb..1f43e4e4 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -185,7 +185,7 @@ func TestSet(t *testing.T) { AssertValueString( t, Set{Boolean(true), Long(1)}, - "[true,1]") + "[true, 1]") }) t.Run("TypeName", func(t *testing.T) { @@ -267,14 +267,14 @@ func TestRecord(t *testing.T) { AssertValueString( t, Record{"foo": Boolean(true)}, - `{"foo":true}`) + `{"foo": true}`) AssertValueString( t, Record{ "foo": Boolean(true), "bar": String("blah"), }, - `{"bar":"blah","foo":true}`) + `{"bar": "blah", "foo": true}`) }) t.Run("TypeName", func(t *testing.T) { From b0a914d7abcc7291ca82997a20a083ee9933c770 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 15:14:00 -0700 Subject: [PATCH 141/216] cedar-go: add DeletePolicy to PolicySet Signed-off-by: philhassey --- policy_set.go | 5 +++++ policy_set_test.go | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/policy_set.go b/policy_set.go index e5454fbf..fa327fb3 100644 --- a/policy_set.go +++ b/policy_set.go @@ -79,3 +79,8 @@ func (p PolicySet) GetPolicy(policyID PolicyID) *Policy { func (p *PolicySet) UpsertPolicy(policyID PolicyID, policy *Policy) { p.policies[policyID] = policy } + +// DeletePolicy removes a policy from the PolicySet. Deleting a non-existent policy is a no-op. +func (p *PolicySet) DeletePolicy(policyID PolicyID) { + delete(p.policies, policyID) +} diff --git a/policy_set_test.go b/policy_set_test.go index 9185ecdf..9a44c47a 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -78,3 +78,26 @@ func TestUpsertPolicy(t *testing.T) { testutil.Equals(t, ps.GetPolicy("a wavering policy"), p2) }) } + +func TestDeletePolicy(t *testing.T) { + t.Parallel() + t.Run("delete non-existent", func(t *testing.T) { + t.Parallel() + + ps := NewPolicySetFromPolicies(nil) + + // Just verify that this doesn't crash + ps.DeletePolicy("not a policy") + }) + t.Run("delete existing", func(t *testing.T) { + t.Parallel() + + ps := NewPolicySetFromPolicies(nil) + + p1 := NewPolicyFromAST(ast.Forbid()) + ps.UpsertPolicy("a policy", p1) + ps.DeletePolicy("a policy") + + testutil.Equals(t, ps.GetPolicy("a policy"), nil) + }) +} From fbce820de231d22cafebd5e732e363cd63bbc7aa Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 15:38:14 -0700 Subject: [PATCH 142/216] cedar-go: rejigger the constructors for PolicySet Most importantly, NewPolicySet() now creates an empty PolicySet and what used to be called NewPolicySet() is now called NewPolicySetFromFile() Signed-off-by: philhassey --- authorize_test.go | 2 +- corpus_test.go | 4 ++-- policy_set.go | 40 ++++++++++------------------------------ policy_set_test.go | 41 ++++++++++++----------------------------- 4 files changed, 25 insertions(+), 62 deletions(-) diff --git a/authorize_test.go b/authorize_test.go index 8eda2181..70f74204 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -709,7 +709,7 @@ func TestIsAuthorized(t *testing.T) { tt := tt t.Run(tt.Name, func(t *testing.T) { t.Parallel() - ps, err := NewPolicySet("policy.cedar", []byte(tt.Policy)) + ps, err := NewPolicySetFromFile("policy.cedar", []byte(tt.Policy)) testutil.Equals(t, (err != nil), tt.ParseErr) ok, diag := ps.IsAuthorized(tt.Entities, Request{ Principal: tt.Principal, diff --git a/corpus_test.go b/corpus_test.go index 93ddef7c..509c7773 100644 --- a/corpus_test.go +++ b/corpus_test.go @@ -150,7 +150,7 @@ func TestCorpus(t *testing.T) { t.Fatal("error reading policy content", err) } - policySet, err := NewPolicySet("policy.cedar", policyContent) + policySet, err := NewPolicySetFromFile("policy.cedar", policyContent) if err != nil { t.Fatal("error parsing policy set", err) } @@ -336,7 +336,7 @@ func TestCorpusRelated(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - policy, err := NewPolicySet("", []byte(tt.policy)) + policy, err := NewPolicySetFromFile("", []byte(tt.policy)) testutil.OK(t, err) ok, diag := policy.IsAuthorized(entities2.Entities{}, tt.request) testutil.Equals(t, ok, tt.decision) diff --git a/policy_set.go b/policy_set.go index fa327fb3..886ae343 100644 --- a/policy_set.go +++ b/policy_set.go @@ -11,17 +11,21 @@ import ( type PolicyID string -// TODO: Put a better comment here +// A set of named policies against which a request can be authorized. type PolicySet struct { policies map[PolicyID]*Policy } -// NewPolicySet will create a PolicySet from the given text document with the -// given file name used in Position data. If there is an error parsing the -// document, it will be returned. +// Create a new, empty PolicySet +func NewPolicySet() PolicySet { + return PolicySet{policies: map[PolicyID]*Policy{}} +} + +// NewPolicySetFromFile will create a PolicySet from the given text document with the/ given file name used in Position +// data. If there is an error parsing the document, it will be returned. // -// NewPolicySet assigns default PolicyIDs to the policies contained in fileName. -func NewPolicySet(fileName string, document []byte) (PolicySet, error) { +// NewPolicySetFromFile assigns default PolicyIDs to the policies contained in fileName. +func NewPolicySetFromFile(fileName string, document []byte) (PolicySet, error) { var res parser.PolicySet if err := res.UnmarshalCedar(document); err != nil { return PolicySet{}, fmt.Errorf("parser error: %w", err) @@ -45,30 +49,6 @@ func NewPolicySet(fileName string, document []byte) (PolicySet, error) { return PolicySet{policies: policyMap}, nil } -// NewPolicySetFromPolicies will create a PolicySet from a slice of existing Policys. This constructor can be used to -// support the creation of a PolicySet from JSON-encoded policies or policies created via the ast package, like so: -// -// policy0 := NewPolicyFromAST(ast.Forbid()) -// -// var policy1 Policy -// _ = policy1.UnmarshalJSON( -// []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), -// )) -// -// ps := NewPolicySetFromPolicies([]*Policy{policy0, &policy1}) -// -// NewPolicySetFromPolicies assigns default PolicyIDs to the policies that are passed. If you would like to assign your -// own PolicyIDs, call NewPolicySetFromPolicies() with an empty slice and use PolicySet.UpsertPolicy() to add the -// policies individually with the desired PolicyID. -func NewPolicySetFromPolicies(policies []*Policy) PolicySet { - policyMap := make(map[PolicyID]*Policy, len(policies)) - for i, p := range policies { - policyID := PolicyID(fmt.Sprintf("policy%d", i)) - policyMap[policyID] = p - } - return PolicySet{policies: policyMap} -} - // GetPolicy returns a pointer to the Policy with the given ID. If a policy with the given ID does not exist, nil is // returned. func (p PolicySet) GetPolicy(policyID PolicyID) *Policy { diff --git a/policy_set_test.go b/policy_set_test.go index 9a44c47a..93f41c98 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -7,35 +7,29 @@ import ( "github.com/cedar-policy/cedar-go/internal/testutil" ) -func TestNewPolicySet(t *testing.T) { +func TestNewPolicySetFromFile(t *testing.T) { t.Parallel() t.Run("err-in-tokenize", func(t *testing.T) { t.Parallel() - _, err := NewPolicySet("policy.cedar", []byte(`"`)) + _, err := NewPolicySetFromFile("policy.cedar", []byte(`"`)) testutil.Error(t, err) }) t.Run("err-in-parse", func(t *testing.T) { t.Parallel() - _, err := NewPolicySet("policy.cedar", []byte(`err`)) + _, err := NewPolicySetFromFile("policy.cedar", []byte(`err`)) testutil.Error(t, err) }) t.Run("annotations", func(t *testing.T) { t.Parallel() - ps, err := NewPolicySet("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) + ps, err := NewPolicySetFromFile("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) testutil.OK(t, err) testutil.Equals(t, ps.GetPolicy("policy0").Annotations, Annotations{"key": "value"}) }) } -func TestNewPolicySetFromPolicies(t *testing.T) { +func TestUpsertPolicy(t *testing.T) { t.Parallel() - t.Run("empty slice", func(t *testing.T) { - t.Parallel() - - ps := NewPolicySetFromPolicies(nil) - testutil.Equals(t, ps.GetPolicy("policy0"), nil) - }) - t.Run("non-empty slice", func(t *testing.T) { + t.Run("insert", func(t *testing.T) { t.Parallel() policy0 := NewPolicyFromAST(ast.Forbid()) @@ -45,29 +39,18 @@ func TestNewPolicySetFromPolicies(t *testing.T) { []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), )) - ps := NewPolicySetFromPolicies([]*Policy{policy0, &policy1}) + ps := NewPolicySet() + ps.UpsertPolicy("policy0", policy0) + ps.UpsertPolicy("policy1", &policy1) testutil.Equals(t, ps.GetPolicy("policy0"), policy0) testutil.Equals(t, ps.GetPolicy("policy1"), &policy1) testutil.Equals(t, ps.GetPolicy("policy2"), nil) }) -} - -func TestUpsertPolicy(t *testing.T) { - t.Parallel() - t.Run("insert", func(t *testing.T) { - t.Parallel() - - ps := NewPolicySetFromPolicies(nil) - p := NewPolicyFromAST(ast.Forbid()) - ps.UpsertPolicy("a very strict policy", p) - - testutil.Equals(t, ps.GetPolicy("a very strict policy"), p) - }) t.Run("upsert", func(t *testing.T) { t.Parallel() - ps := NewPolicySetFromPolicies(nil) + ps := NewPolicySet() p1 := NewPolicyFromAST(ast.Forbid()) ps.UpsertPolicy("a wavering policy", p1) @@ -84,7 +67,7 @@ func TestDeletePolicy(t *testing.T) { t.Run("delete non-existent", func(t *testing.T) { t.Parallel() - ps := NewPolicySetFromPolicies(nil) + ps := NewPolicySet() // Just verify that this doesn't crash ps.DeletePolicy("not a policy") @@ -92,7 +75,7 @@ func TestDeletePolicy(t *testing.T) { t.Run("delete existing", func(t *testing.T) { t.Parallel() - ps := NewPolicySetFromPolicies(nil) + ps := NewPolicySet() p1 := NewPolicyFromAST(ast.Forbid()) ps.UpsertPolicy("a policy", p1) From 0981ee479a35521bed3aeb078f080b367f5c1faf Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 16:18:03 -0700 Subject: [PATCH 143/216] cedar-go: add the ability to parse and emit a slice of Policys from a concatenation of Cedar-encoded policies Signed-off-by: philhassey --- internal/parser/cedar_parse_test.go | 6 ++-- internal/parser/cedar_unmarshal.go | 4 +-- internal/parser/cedar_unmarshal_test.go | 4 +-- internal/parser/policy.go | 2 +- policy.go | 41 +++++++++++++++++++++++++ policy_set.go | 32 ++++++------------- policy_set_test.go | 30 ++++++++++++++++++ policy_test.go | 25 +++++++++++++++ 8 files changed, 114 insertions(+), 30 deletions(-) diff --git a/internal/parser/cedar_parse_test.go b/internal/parser/cedar_parse_test.go index 854a9d04..d3a2581f 100644 --- a/internal/parser/cedar_parse_test.go +++ b/internal/parser/cedar_parse_test.go @@ -295,7 +295,7 @@ func TestParse(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - var policies parser.PolicySet + var policies parser.PolicySlice err := policies.UnmarshalCedar([]byte(tt.in)) if tt.err { testutil.Error(t, err) @@ -315,7 +315,7 @@ func TestParse(t *testing.T) { pp := policies[0] pp.MarshalCedar(&buf) - var p2 parser.PolicySet + var p2 parser.PolicySlice err = p2.UnmarshalCedar(buf.Bytes()) testutil.OK(t, err) @@ -340,7 +340,7 @@ permit( principal, action, resource ); @test("1234") permit (principal, action, resource ); ` - var out parser.PolicySet + var out parser.PolicySlice err := out.UnmarshalCedar([]byte(in)) testutil.OK(t, err) testutil.Equals(t, len(out), 3) diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 7ca92c44..332b7d68 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -11,13 +11,13 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func (p *PolicySet) UnmarshalCedar(b []byte) error { +func (p *PolicySlice) UnmarshalCedar(b []byte) error { tokens, err := Tokenize(b) if err != nil { return err } - var policySet PolicySet + var policySet PolicySlice parser := newParser(tokens) for !parser.peek().isEOF() { var policy Policy diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index deb46cd1..8eb6ddab 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -495,7 +495,7 @@ func TestParsePolicySet(t *testing.T) { resource );`) - var policies parser.PolicySet + var policies parser.PolicySlice testutil.OK(t, policies.UnmarshalCedar(policyStr)) expectedPolicy := ast.Permit() @@ -513,7 +513,7 @@ func TestParsePolicySet(t *testing.T) { action, resource );`) - var policies parser.PolicySet + var policies parser.PolicySlice testutil.OK(t, policies.UnmarshalCedar(policyStr)) expectedPolicy0 := ast.Permit() diff --git a/internal/parser/policy.go b/internal/parser/policy.go index ba489f58..1f3d1b48 100644 --- a/internal/parser/policy.go +++ b/internal/parser/policy.go @@ -2,5 +2,5 @@ package parser import "github.com/cedar-policy/cedar-go/internal/ast" -type PolicySet []*Policy +type PolicySlice []*Policy type Policy ast.Policy diff --git a/policy.go b/policy.go index 9bfb771b..083a4c95 100644 --- a/policy.go +++ b/policy.go @@ -2,6 +2,7 @@ package cedar import ( "bytes" + "fmt" "github.com/cedar-policy/cedar-go/ast" internalast "github.com/cedar-policy/cedar-go/internal/ast" @@ -107,3 +108,43 @@ func NewPolicyFromAST(astIn *ast.Policy) *Policy { ast: pp, } } + +// PolicySlice represents a set of un-named Policy's. Cedar documents, unlike the JSON format, don't have a means of +// naming individual policies. +type PolicySlice []*Policy + +// UnmarshalCedar parses a concatenation of un-named Cedar policy statements. Names can be assigned to these policies +// when adding them to a PolicySet. +func (p *PolicySlice) UnmarshalCedar(b []byte) error { + var res parser.PolicySlice + if err := res.UnmarshalCedar(b); err != nil { + return fmt.Errorf("parser error: %w", err) + } + policySlice := make([]*Policy, 0, len(res)) + for _, p := range res { + policySlice = append(policySlice, &Policy{ + Position: Position{ + Offset: p.Position.Offset, + Line: p.Position.Line, + Column: p.Position.Column, + }, + Annotations: newAnnotationsFromSlice(p.Annotations), + Effect: Effect(p.Effect), + eval: eval.Compile((*internalast.Policy)(p)), + ast: (*internalast.Policy)(p), + }) + } + *p = policySlice + return nil +} + +// MarshalCedar emits a concatenated Cedar representation of a PolicySlice +func (p PolicySlice) MarshalCedar(buf *bytes.Buffer) { + for i, policy := range p { + policy.MarshalCedar(buf) + + if i < len(p)-1 { + buf.WriteString("\n\n") + } + } +} diff --git a/policy_set.go b/policy_set.go index 886ae343..7603243f 100644 --- a/policy_set.go +++ b/policy_set.go @@ -3,10 +3,6 @@ package cedar import ( "fmt" - - internalast "github.com/cedar-policy/cedar-go/internal/ast" - "github.com/cedar-policy/cedar-go/internal/eval" - "github.com/cedar-policy/cedar-go/internal/parser" ) type PolicyID string @@ -24,27 +20,19 @@ func NewPolicySet() PolicySet { // NewPolicySetFromFile will create a PolicySet from the given text document with the/ given file name used in Position // data. If there is an error parsing the document, it will be returned. // -// NewPolicySetFromFile assigns default PolicyIDs to the policies contained in fileName. +// NewPolicySetFromFile assigns default PolicyIDs to the policies contained in fileName in the format "policy" where +// is incremented for each new policy found in the file. func NewPolicySetFromFile(fileName string, document []byte) (PolicySet, error) { - var res parser.PolicySet - if err := res.UnmarshalCedar(document); err != nil { - return PolicySet{}, fmt.Errorf("parser error: %w", err) + var policySlice PolicySlice + if err := policySlice.UnmarshalCedar(document); err != nil { + return PolicySet{}, err } - policyMap := make(map[PolicyID]*Policy, len(res)) - for i, p := range res { + + policyMap := make(map[PolicyID]*Policy, len(policySlice)) + for i, p := range policySlice { policyID := PolicyID(fmt.Sprintf("policy%d", i)) - policyMap[policyID] = &Policy{ - Position: Position{ - Filename: fileName, - Offset: p.Position.Offset, - Line: p.Position.Line, - Column: p.Position.Column, - }, - Annotations: newAnnotationsFromSlice(p.Annotations), - Effect: Effect(p.Effect), - eval: eval.Compile((*internalast.Policy)(p)), - ast: (*internalast.Policy)(p), - } + p.Position.Filename = fileName + policyMap[policyID] = p } return PolicySet{policies: policyMap}, nil } diff --git a/policy_set_test.go b/policy_set_test.go index 93f41c98..854eed1e 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -1,6 +1,7 @@ package cedar import ( + "fmt" "testing" "github.com/cedar-policy/cedar-go/ast" @@ -84,3 +85,32 @@ func TestDeletePolicy(t *testing.T) { testutil.Equals(t, ps.GetPolicy("a policy"), nil) }) } + +func TestNewPolicySetFromSlice(t *testing.T) { + t.Parallel() + + policiesStr := `permit ( + principal, + action == Action::"editPhoto", + resource +) +when { resource.owner == principal }; + +forbid ( + principal in Groups::"bannedUsers", + action, + resource +);` + + var policies PolicySlice + testutil.OK(t, policies.UnmarshalCedar([]byte(policiesStr))) + + ps := NewPolicySet() + for i, p := range policies { + p.Position.Filename = "example.cedar" + ps.UpsertPolicy(PolicyID(fmt.Sprintf("policy%d", i)), p) + } + + testutil.Equals(t, ps.GetPolicy("policy0").Effect, Permit) + testutil.Equals(t, ps.GetPolicy("policy1").Effect, Forbid) +} diff --git a/policy_test.go b/policy_test.go index 1a1a84ef..d469254f 100644 --- a/policy_test.go +++ b/policy_test.go @@ -96,3 +96,28 @@ func TestPolicyAST(t *testing.T) { _ = NewPolicyFromAST(astExample) } + +func TestPolicySlice(t *testing.T) { + t.Parallel() + + policiesStr := `permit ( + principal, + action == Action::"editPhoto", + resource +) +when { resource.owner == principal }; + +forbid ( + principal in Groups::"bannedUsers", + action, + resource +);` + + var policies PolicySlice + testutil.OK(t, policies.UnmarshalCedar([]byte(policiesStr))) + + var buf bytes.Buffer + policies.MarshalCedar(&buf) + + testutil.Equals(t, buf.String(), policiesStr) +} From 1d3342ee3f1098cb23718685a4c445e70c816c7d Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 16:29:22 -0700 Subject: [PATCH 144/216] cedar-go: add the ability to marshal a PolicySet to Cedar text Note: there's no direct way to do the reverse because we really want to avoid generating policy IDs for users. The NewPolicySetFromFile() is the one exception for people who are looking for something super brain-dead. Signed-off-by: philhassey --- policy_set.go | 23 +++++++++++++++++++++++ policy_set_test.go | 6 ++++++ 2 files changed, 29 insertions(+) diff --git a/policy_set.go b/policy_set.go index 7603243f..9959e5ff 100644 --- a/policy_set.go +++ b/policy_set.go @@ -2,7 +2,9 @@ package cedar import ( + "bytes" "fmt" + "slices" ) type PolicyID string @@ -52,3 +54,24 @@ func (p *PolicySet) UpsertPolicy(policyID PolicyID, policy *Policy) { func (p *PolicySet) DeletePolicy(policyID PolicyID) { delete(p.policies, policyID) } + +// MarshalCedar emits a concatenated Cedar representation of a PolicySet. The policy names are stripped, but policies +// are emitted in lexicographical order by ID. +func (p PolicySet) MarshalCedar(buf *bytes.Buffer) { + ids := make([]PolicyID, 0, len(p.policies)) + for k := range p.policies { + ids = append(ids, k) + } + slices.Sort(ids) + + i := 0 + for _, id := range ids { + policy := p.policies[id] + policy.MarshalCedar(buf) + + if i < len(p.policies)-1 { + buf.WriteString("\n\n") + } + i++ + } +} diff --git a/policy_set_test.go b/policy_set_test.go index 854eed1e..65451b35 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -1,6 +1,7 @@ package cedar import ( + "bytes" "fmt" "testing" @@ -113,4 +114,9 @@ forbid ( testutil.Equals(t, ps.GetPolicy("policy0").Effect, Permit) testutil.Equals(t, ps.GetPolicy("policy1").Effect, Forbid) + + var buf bytes.Buffer + ps.MarshalCedar(&buf) + + testutil.Equals(t, buf.String(), policiesStr) } From 7e7f35418c92383bdb6676aec52f50a79569f3cf Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 15 Aug 2024 17:46:58 -0600 Subject: [PATCH 145/216] internal/eval: move ValueTo* functions to eval package Addresses IDX-142 Signed-off-by: philhassey --- authorize.go | 2 +- internal/eval/evalers.go | 34 ++++---- internal/eval/evalers_test.go | 94 ++++++++++----------- internal/eval/util.go | 81 ++++++++++++++++++ internal/eval/util_test.go | 153 ++++++++++++++++++++++++++++++++++ types/value.go | 73 ---------------- types/value_test.go | 107 ------------------------ 7 files changed, 299 insertions(+), 245 deletions(-) create mode 100644 internal/eval/util.go create mode 100644 internal/eval/util_test.go diff --git a/authorize.go b/authorize.go index 538a8b58..2da1abcb 100644 --- a/authorize.go +++ b/authorize.go @@ -95,7 +95,7 @@ func (p PolicySet) IsAuthorized(entityMap entities.Entities, req Request) (Decis diag.Errors = append(diag.Errors, Error{PolicyID: id, Position: po.Position, Message: err.Error()}) continue } - vb, err := types.ValueToBool(v) + vb, err := eval.ValueToBool(v) if err != nil { // should never happen, maybe remove this case diag.Errors = append(diag.Errors, Error{PolicyID: id, Position: po.Position, Message: err.Error()}) diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index 4924c354..ab79cc36 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -30,7 +30,7 @@ func evalBool(n Evaler, ctx *Context) (types.Boolean, error) { if err != nil { return false, err } - b, err := types.ValueToBool(v) + b, err := ValueToBool(v) if err != nil { return false, err } @@ -42,7 +42,7 @@ func evalLong(n Evaler, ctx *Context) (types.Long, error) { if err != nil { return 0, err } - l, err := types.ValueToLong(v) + l, err := ValueToLong(v) if err != nil { return 0, err } @@ -54,7 +54,7 @@ func evalString(n Evaler, ctx *Context) (types.String, error) { if err != nil { return "", err } - s, err := types.ValueToString(v) + s, err := ValueToString(v) if err != nil { return "", err } @@ -66,7 +66,7 @@ func evalSet(n Evaler, ctx *Context) (types.Set, error) { if err != nil { return nil, err } - s, err := types.ValueToSet(v) + s, err := ValueToSet(v) if err != nil { return nil, err } @@ -78,7 +78,7 @@ func evalEntity(n Evaler, ctx *Context) (types.EntityUID, error) { if err != nil { return types.EntityUID{}, err } - e, err := types.ValueToEntity(v) + e, err := ValueToEntity(v) if err != nil { return types.EntityUID{}, err } @@ -90,7 +90,7 @@ func evalPath(n Evaler, ctx *Context) (types.Path, error) { if err != nil { return "", err } - e, err := types.ValueToPath(v) + e, err := ValueToPath(v) if err != nil { return "", err } @@ -102,7 +102,7 @@ func evalDecimal(n Evaler, ctx *Context) (types.Decimal, error) { if err != nil { return types.Decimal(0), err } - d, err := types.ValueToDecimal(v) + d, err := ValueToDecimal(v) if err != nil { return types.Decimal(0), err } @@ -114,7 +114,7 @@ func evalIP(n Evaler, ctx *Context) (types.IPAddr, error) { if err != nil { return types.IPAddr{}, err } - i, err := types.ValueToIP(v) + i, err := ValueToIP(v) if err != nil { return types.IPAddr{}, err } @@ -167,7 +167,7 @@ func (n *orEval) Eval(ctx *Context) (types.Value, error) { if err != nil { return types.ZeroValue(), err } - b, err := types.ValueToBool(v) + b, err := ValueToBool(v) if err != nil { return types.ZeroValue(), err } @@ -178,7 +178,7 @@ func (n *orEval) Eval(ctx *Context) (types.Value, error) { if err != nil { return types.ZeroValue(), err } - _, err = types.ValueToBool(v) + _, err = ValueToBool(v) if err != nil { return types.ZeroValue(), err } @@ -203,7 +203,7 @@ func (n *andEval) Eval(ctx *Context) (types.Value, error) { if err != nil { return types.ZeroValue(), err } - b, err := types.ValueToBool(v) + b, err := ValueToBool(v) if err != nil { return types.ZeroValue(), err } @@ -214,7 +214,7 @@ func (n *andEval) Eval(ctx *Context) (types.Value, error) { if err != nil { return types.ZeroValue(), err } - _, err = types.ValueToBool(v) + _, err = ValueToBool(v) if err != nil { return types.ZeroValue(), err } @@ -237,7 +237,7 @@ func (n *notEval) Eval(ctx *Context) (types.Value, error) { if err != nil { return types.ZeroValue(), err } - b, err := types.ValueToBool(v) + b, err := ValueToBool(v) if err != nil { return types.ZeroValue(), err } @@ -833,7 +833,7 @@ func (n *attributeAccessEval) Eval(ctx *Context) (types.Value, error) { case types.Record: record = vv default: - return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", types.ErrType, v.TypeName()) + return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, v.TypeName()) } val, ok := record[n.attribute] if !ok { @@ -869,7 +869,7 @@ func (n *hasEval) Eval(ctx *Context) (types.Value, error) { case types.Record: record = vv default: - return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", types.ErrType, v.TypeName()) + return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, v.TypeName()) } _, ok := record[n.attribute] return types.Boolean(ok), nil @@ -957,7 +957,7 @@ func (n *inEval) Eval(ctx *Context) (types.Value, error) { query[rhsv] = struct{}{} case types.Set: for _, rhv := range rhsv { - e, err := types.ValueToEntity(rhv) + e, err := ValueToEntity(rhv) if err != nil { return types.ZeroValue(), err } @@ -965,7 +965,7 @@ func (n *inEval) Eval(ctx *Context) (types.Value, error) { } default: return types.ZeroValue(), fmt.Errorf( - "%w: expected one of [set, (entity of type `any_entity_type`)], got %v", types.ErrType, rhs.TypeName()) + "%w: expected one of [set, (entity of type `any_entity_type`)], got %v", ErrType, rhs.TypeName()) } return types.Boolean(entityIn(lhs, query, ctx.Entities)), nil } diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index a3793473..f530f88e 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -58,9 +58,9 @@ func TestOrNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.True), errTest}, - {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.True), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.True), ErrType}, {"RhsError", newLiteralEval(types.False), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.False), newLiteralEval(types.Long(1)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.False), newLiteralEval(types.Long(1)), ErrType}, } for _, tt := range tests { tt := tt @@ -113,9 +113,9 @@ func TestAndNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.True), errTest}, - {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.True), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(types.True), ErrType}, {"RhsError", newLiteralEval(types.True), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(1)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(1)), ErrType}, } for _, tt := range tests { tt := tt @@ -157,7 +157,7 @@ func TestNotNode(t *testing.T) { err error }{ {"Error", newErrorEval(errTest), errTest}, - {"TypeError", newLiteralEval(types.Long(1)), types.ErrType}, + {"TypeError", newLiteralEval(types.Long(1)), ErrType}, } for _, tt := range tests { tt := tt @@ -355,9 +355,9 @@ func TestAddNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), ErrType}, {"PositiveOverflow", newLiteralEval(types.Long(9_223_372_036_854_775_807)), newLiteralEval(types.Long(1)), @@ -394,9 +394,9 @@ func TestSubtractNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), ErrType}, {"PositiveOverflow", newLiteralEval(types.Long(9_223_372_036_854_775_807)), newLiteralEval(types.Long(-1)), @@ -433,9 +433,9 @@ func TestMultiplyNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), ErrType}, {"PositiveOverflow", newLiteralEval(types.Long(9_223_372_036_854_775_807)), newLiteralEval(types.Long(2)), @@ -472,7 +472,7 @@ func TestNegateNode(t *testing.T) { err error }{ {"Error", newErrorEval(errTest), errTest}, - {"TypeError", newLiteralEval(types.True), types.ErrType}, + {"TypeError", newLiteralEval(types.True), ErrType}, {"Overflow", newLiteralEval(types.Long(-9_223_372_036_854_775_808)), errOverflow}, } for _, tt := range tests { @@ -522,9 +522,9 @@ func TestLongLessThanNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -574,9 +574,9 @@ func TestLongLessThanOrEqualNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -626,9 +626,9 @@ func TestLongGreaterThanNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -678,9 +678,9 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType}, {"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -735,9 +735,9 @@ func TestDecimalLessThanNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), ErrType}, {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -792,9 +792,9 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), ErrType}, {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -849,9 +849,9 @@ func TestDecimalGreaterThanNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), ErrType}, {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -906,9 +906,9 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), ErrType}, {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -939,7 +939,7 @@ func TestIfThenElseNode(t *testing.T) { {"Err", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), newLiteralEval(types.ZeroValue()), types.ZeroValue(), errTest}, {"ErrType", newLiteralEval(types.Long(123)), newLiteralEval(types.ZeroValue()), newLiteralEval(types.ZeroValue()), types.ZeroValue(), - types.ErrType}, + ErrType}, } for _, tt := range tests { tt := tt @@ -1055,7 +1055,7 @@ func TestContainsNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType}, {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, } for _, tt := range tests { @@ -1109,9 +1109,9 @@ func TestContainsAllNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Set{}), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Set{}), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Set{}), ErrType}, {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Set{}), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Set{}), newLiteralEval(types.Long(0)), ErrType}, } for _, tt := range tests { tt := tt @@ -1163,9 +1163,9 @@ func TestContainsAnyNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(types.Set{}), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Set{}), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Set{}), ErrType}, {"RhsError", newLiteralEval(types.Set{}), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Set{}), newLiteralEval(types.Long(0)), types.ErrType}, + {"RhsTypeError", newLiteralEval(types.Set{}), newLiteralEval(types.Long(0)), ErrType}, } for _, tt := range tests { tt := tt @@ -1252,7 +1252,7 @@ func TestAttributeAccessNode(t *testing.T) { err error }{ {"RecordError", newErrorEval(errTest), "foo", types.ZeroValue(), errTest}, - {"RecordTypeError", newLiteralEval(types.True), "foo", types.ZeroValue(), types.ErrType}, + {"RecordTypeError", newLiteralEval(types.True), "foo", types.ZeroValue(), ErrType}, {"UnknownAttribute", newLiteralEval(types.Record{}), "foo", @@ -1309,7 +1309,7 @@ func TestHasNode(t *testing.T) { err error }{ {"RecordError", newErrorEval(errTest), "foo", types.ZeroValue(), errTest}, - {"RecordTypeError", newLiteralEval(types.True), "foo", types.ZeroValue(), types.ErrType}, + {"RecordTypeError", newLiteralEval(types.True), "foo", types.ZeroValue(), ErrType}, {"UnknownAttribute", newLiteralEval(types.Record{}), "foo", @@ -1366,7 +1366,7 @@ func TestLikeNode(t *testing.T) { err error }{ {"leftError", newErrorEval(errTest), `"foo"`, types.ZeroValue(), errTest}, - {"leftTypeError", newLiteralEval(types.True), `"foo"`, types.ZeroValue(), types.ErrType}, + {"leftTypeError", newLiteralEval(types.True), `"foo"`, types.ZeroValue(), ErrType}, {"noMatch", newLiteralEval(types.String("test")), `"zebra"`, types.False, nil}, {"match", newLiteralEval(types.String("test")), `"*es*"`, types.True, nil}, @@ -1592,8 +1592,8 @@ func TestIsNode(t *testing.T) { }{ {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("X")), types.True, nil}, {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("Y")), types.False, nil}, - {"badLhs", newLiteralEval(types.Long(42)), newLiteralEval(types.Path("X")), types.ZeroValue(), types.ErrType}, - {"badRhs", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Long(42)), types.ZeroValue(), types.ErrType}, + {"badLhs", newLiteralEval(types.Long(42)), newLiteralEval(types.Path("X")), types.ZeroValue(), ErrType}, + {"badRhs", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Long(42)), types.ZeroValue(), ErrType}, {"errLhs", newErrorEval(errTest), newLiteralEval(types.Path("X")), types.ZeroValue(), errTest}, {"errRhs", newLiteralEval(types.NewEntityUID("X", "z")), newErrorEval(errTest), types.ZeroValue(), errTest}, } @@ -1631,7 +1631,7 @@ func TestInNode(t *testing.T) { newLiteralEval(types.Set{}), map[string][]string{}, types.ZeroValue(), - types.ErrType, + ErrType, }, { "RhsError", @@ -1647,7 +1647,7 @@ func TestInNode(t *testing.T) { newLiteralEval(types.String("foo")), map[string][]string{}, types.ZeroValue(), - types.ErrType, + ErrType, }, { "RhsTypeError2", @@ -1657,7 +1657,7 @@ func TestInNode(t *testing.T) { }), map[string][]string{}, types.ZeroValue(), - types.ErrType, + ErrType, }, { "Reflexive1", @@ -1734,7 +1734,7 @@ func TestDecimalLiteralNode(t *testing.T) { err error }{ {"Error", newErrorEval(errTest), types.ZeroValue(), errTest}, - {"TypeError", newLiteralEval(types.Long(1)), types.ZeroValue(), types.ErrType}, + {"TypeError", newLiteralEval(types.Long(1)), types.ZeroValue(), ErrType}, {"DecimalError", newLiteralEval(types.String("frob")), types.ZeroValue(), types.ErrDecimal}, {"Success", newLiteralEval(types.String("1.0")), types.Decimal(10000), nil}, } @@ -1761,7 +1761,7 @@ func TestIPLiteralNode(t *testing.T) { err error }{ {"Error", newErrorEval(errTest), types.ZeroValue(), errTest}, - {"TypeError", newLiteralEval(types.Long(1)), types.ZeroValue(), types.ErrType}, + {"TypeError", newLiteralEval(types.Long(1)), types.ZeroValue(), ErrType}, {"IPError", newLiteralEval(types.String("not-an-IP-address")), types.ZeroValue(), types.ErrIP}, {"Success", newLiteralEval(types.String("::1/128")), ipv6Loopback, nil}, } @@ -1793,7 +1793,7 @@ func TestIPTestNode(t *testing.T) { err error }{ {"Error", newErrorEval(errTest), ipTestIPv4, types.ZeroValue(), errTest}, - {"TypeError", newLiteralEval(types.Long(1)), ipTestIPv4, types.ZeroValue(), types.ErrType}, + {"TypeError", newLiteralEval(types.Long(1)), ipTestIPv4, types.ZeroValue(), ErrType}, {"IPv4True", newLiteralEval(ipv4Loopback), ipTestIPv4, types.True, nil}, {"IPv4False", newLiteralEval(ipv6Loopback), ipTestIPv4, types.False, nil}, {"IPv6True", newLiteralEval(ipv6Loopback), ipTestIPv6, types.True, nil}, @@ -1830,9 +1830,9 @@ func TestIPIsInRangeNode(t *testing.T) { err error }{ {"LhsError", newErrorEval(errTest), newLiteralEval(ipv4A), types.ZeroValue(), errTest}, - {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(ipv4A), types.ZeroValue(), types.ErrType}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(ipv4A), types.ZeroValue(), ErrType}, {"RhsError", newLiteralEval(ipv4A), newErrorEval(errTest), types.ZeroValue(), errTest}, - {"RhsTypeError", newLiteralEval(ipv4A), newLiteralEval(types.Long(1)), types.ZeroValue(), types.ErrType}, + {"RhsTypeError", newLiteralEval(ipv4A), newLiteralEval(types.Long(1)), types.ZeroValue(), ErrType}, {"AA", newLiteralEval(ipv4A), newLiteralEval(ipv4A), types.True, nil}, {"AB", newLiteralEval(ipv4A), newLiteralEval(ipv4B), types.True, nil}, {"BA", newLiteralEval(ipv4B), newLiteralEval(ipv4A), types.False, nil}, diff --git a/internal/eval/util.go b/internal/eval/util.go new file mode 100644 index 00000000..cf140118 --- /dev/null +++ b/internal/eval/util.go @@ -0,0 +1,81 @@ +package eval + +import ( + "fmt" + + "github.com/cedar-policy/cedar-go/types" +) + +var ErrType = fmt.Errorf("type error") + +func ValueToBool(v types.Value) (types.Boolean, error) { + bv, ok := v.(types.Boolean) + if !ok { + return false, fmt.Errorf("%w: expected bool, got %v", ErrType, v.TypeName()) + } + return bv, nil +} + +func ValueToLong(v types.Value) (types.Long, error) { + lv, ok := v.(types.Long) + if !ok { + return 0, fmt.Errorf("%w: expected long, got %v", ErrType, v.TypeName()) + } + return lv, nil +} + +func ValueToString(v types.Value) (types.String, error) { + sv, ok := v.(types.String) + if !ok { + return "", fmt.Errorf("%w: expected string, got %v", ErrType, v.TypeName()) + } + return sv, nil +} + +func ValueToSet(v types.Value) (types.Set, error) { + sv, ok := v.(types.Set) + if !ok { + return nil, fmt.Errorf("%w: expected set, got %v", ErrType, v.TypeName()) + } + return sv, nil +} + +func ValueToRecord(v types.Value) (types.Record, error) { + rv, ok := v.(types.Record) + if !ok { + return nil, fmt.Errorf("%w: expected record got %v", ErrType, v.TypeName()) + } + return rv, nil +} + +func ValueToEntity(v types.Value) (types.EntityUID, error) { + ev, ok := v.(types.EntityUID) + if !ok { + return types.EntityUID{}, fmt.Errorf("%w: expected (entity of type `any_entity_type`), got %v", ErrType, v.TypeName()) + } + return ev, nil +} + +func ValueToPath(v types.Value) (types.Path, error) { + ev, ok := v.(types.Path) + if !ok { + return "", fmt.Errorf("%w: expected (Path of type `any_entity_type`), got %v", ErrType, v.TypeName()) + } + return ev, nil +} + +func ValueToDecimal(v types.Value) (types.Decimal, error) { + d, ok := v.(types.Decimal) + if !ok { + return 0, fmt.Errorf("%w: expected decimal, got %v", ErrType, v.TypeName()) + } + return d, nil +} + +func ValueToIP(v types.Value) (types.IPAddr, error) { + i, ok := v.(types.IPAddr) + if !ok { + return types.IPAddr{}, fmt.Errorf("%w: expected ipaddr, got %v", ErrType, v.TypeName()) + } + return i, nil +} diff --git a/internal/eval/util_test.go b/internal/eval/util_test.go new file mode 100644 index 00000000..7164da1f --- /dev/null +++ b/internal/eval/util_test.go @@ -0,0 +1,153 @@ +package eval + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestUtil(t *testing.T) { + t.Parallel() + t.Run("Boolean", func(t *testing.T) { + t.Parallel() + t.Run("roundTrip", func(t *testing.T) { + t.Parallel() + v, err := ValueToBool(types.Boolean(true)) + testutil.OK(t, err) + testutil.Equals(t, v, true) + }) + + t.Run("toBoolOnNonBool", func(t *testing.T) { + t.Parallel() + v, err := ValueToBool(types.Long(0)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, false) + }) + }) + + t.Run("Long", func(t *testing.T) { + t.Parallel() + t.Run("roundTrip", func(t *testing.T) { + t.Parallel() + v, err := ValueToLong(types.Long(42)) + testutil.OK(t, err) + testutil.Equals(t, v, 42) + }) + + t.Run("toLongOnNonLong", func(t *testing.T) { + t.Parallel() + v, err := ValueToLong(types.Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, 0) + }) + }) + + t.Run("String", func(t *testing.T) { + t.Parallel() + t.Run("roundTrip", func(t *testing.T) { + t.Parallel() + v, err := ValueToString(types.String("hello")) + testutil.OK(t, err) + testutil.Equals(t, v, "hello") + }) + + t.Run("toStringOnNonString", func(t *testing.T) { + t.Parallel() + v, err := ValueToString(types.Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, "") + }) + }) + + t.Run("Set", func(t *testing.T) { + t.Parallel() + t.Run("roundTrip", func(t *testing.T) { + t.Parallel() + v := types.Set{types.Boolean(true), types.Long(1)} + slice, err := ValueToSet(v) + testutil.OK(t, err) + v2 := slice + testutil.FatalIf(t, !v.Equal(v2), "got %v want %v", v, v2) + }) + + t.Run("ToSetOnNonSet", func(t *testing.T) { + t.Parallel() + v, err := ValueToSet(types.Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, nil) + }) + }) + + t.Run("Record", func(t *testing.T) { + t.Parallel() + t.Run("roundTrip", func(t *testing.T) { + t.Parallel() + v := types.Record{ + "foo": types.Boolean(true), + "bar": types.Long(1), + } + map_, err := ValueToRecord(v) + testutil.OK(t, err) + v2 := map_ + testutil.FatalIf(t, !v.Equal(v2), "got %v want %v", v, v2) + }) + + t.Run("toRecordOnNonRecord", func(t *testing.T) { + t.Parallel() + v, err := ValueToRecord(types.String("hello")) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, nil) + }) + }) + + t.Run("Entity", func(t *testing.T) { + t.Parallel() + t.Run("roundTrip", func(t *testing.T) { + t.Parallel() + want := types.EntityUID{Type: "User", ID: "bananas"} + v, err := ValueToEntity(want) + testutil.OK(t, err) + testutil.Equals(t, v, want) + }) + t.Run("ToEntityOnNonEntity", func(t *testing.T) { + t.Parallel() + v, err := ValueToEntity(types.String("hello")) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, types.EntityUID{}) + }) + + }) + + t.Run("Decimal", func(t *testing.T) { + t.Parallel() + t.Run("roundTrip", func(t *testing.T) { + t.Parallel() + dv, err := types.ParseDecimal("1.20") + testutil.OK(t, err) + v, err := ValueToDecimal(dv) + testutil.OK(t, err) + testutil.FatalIf(t, !v.Equal(dv), "got %v want %v", v, dv) + }) + + t.Run("toDecimalOnNonDecimal", func(t *testing.T) { + t.Parallel() + v, err := ValueToDecimal(types.Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, 0) + }) + + }) + + t.Run("IP", func(t *testing.T) { + t.Parallel() + + t.Run("toIPOnNonIP", func(t *testing.T) { + t.Parallel() + v, err := ValueToIP(types.Boolean(true)) + testutil.AssertError(t, err, ErrType) + testutil.Equals(t, v, types.IPAddr{}) + }) + }) + +} diff --git a/types/value.go b/types/value.go index f3549f6d..e73285b1 100644 --- a/types/value.go +++ b/types/value.go @@ -16,7 +16,6 @@ import ( var ErrDecimal = fmt.Errorf("error parsing decimal value") var ErrIP = fmt.Errorf("error parsing ip value") -var ErrType = fmt.Errorf("type error") type Value interface { // String produces a string representation of the Value. @@ -62,14 +61,6 @@ func (v Boolean) Cedar() string { func (v Boolean) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } func (v Boolean) deepClone() Value { return v } -func ValueToBool(v Value) (Boolean, error) { - bv, ok := v.(Boolean) - if !ok { - return false, fmt.Errorf("%w: expected bool, got %v", ErrType, v.TypeName()) - } - return bv, nil -} - // A Long is a whole number without decimals that can range from -9223372036854775808 to 9223372036854775807. type Long int64 @@ -91,14 +82,6 @@ func (v Long) Cedar() string { } func (v Long) deepClone() Value { return v } -func ValueToLong(v Value) (Long, error) { - lv, ok := v.(Long) - if !ok { - return 0, fmt.Errorf("%w: expected long, got %v", ErrType, v.TypeName()) - } - return lv, nil -} - // A String is a sequence of characters consisting of letters, numbers, or symbols. type String string @@ -122,14 +105,6 @@ func (v String) Cedar() string { } func (v String) deepClone() Value { return v } -func ValueToString(v Value) (String, error) { - sv, ok := v.(String) - if !ok { - return "", fmt.Errorf("%w: expected string, got %v", ErrType, v.TypeName()) - } - return sv, nil -} - // A Set is a collection of elements that can be of the same or different types. type Set []Value @@ -234,14 +209,6 @@ func (v Set) DeepClone() Set { return res } -func ValueToSet(v Value) (Set, error) { - sv, ok := v.(Set) - if !ok { - return nil, fmt.Errorf("%w: expected set, got %v", ErrType, v.TypeName()) - } - return sv, nil -} - // A Record is a collection of attributes. Each attribute consists of a name and // an associated value. Names are simple strings. Values can be of any type. type Record map[string]Value @@ -343,14 +310,6 @@ func (v Record) DeepClone() Record { return res } -func ValueToRecord(v Value) (Record, error) { - rv, ok := v.(Record) - if !ok { - return nil, fmt.Errorf("%w: expected record got %v", ErrType, v.TypeName()) - } - return rv, nil -} - // An EntityUID is the identifier for a principal, action, or resource. type EntityUID struct { Type string @@ -420,14 +379,6 @@ func (v EntityUID) ExplicitMarshalJSON() ([]byte, error) { } func (v EntityUID) deepClone() Value { return v } -func ValueToEntity(v Value) (EntityUID, error) { - ev, ok := v.(EntityUID) - if !ok { - return EntityUID{}, fmt.Errorf("%w: expected (entity of type `any_entity_type`), got %v", ErrType, v.TypeName()) - } - return ev, nil -} - func EntityValueFromSlice(v []string) EntityUID { return EntityUID{ Type: strings.Join(v[:len(v)-1], "::"), @@ -449,14 +400,6 @@ func (v Path) Cedar() string { return string(v) } func (v Path) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(string(v)) } func (v Path) deepClone() Value { return v } -func ValueToPath(v Value) (Path, error) { - ev, ok := v.(Path) - if !ok { - return "", fmt.Errorf("%w: expected (Path of type `any_entity_type`), got %v", ErrType, v.TypeName()) - } - return ev, nil -} - func PathFromSlice(v []string) Path { return Path(strings.Join(v, "::")) } @@ -637,14 +580,6 @@ func (v Decimal) ExplicitMarshalJSON() ([]byte, error) { } func (v Decimal) deepClone() Value { return v } -func ValueToDecimal(v Value) (Decimal, error) { - d, ok := v.(Decimal) - if !ok { - return 0, fmt.Errorf("%w: expected decimal, got %v", ErrType, v.TypeName()) - } - return d, nil -} - // An IPAddr is value that represents an IP address. It can be either IPv4 or IPv6. // The value can represent an individual address or a range of addresses. type IPAddr netip.Prefix @@ -793,11 +728,3 @@ func (v IPAddr) ExplicitMarshalJSON() ([]byte, error) { // in this case, netip.Prefix does contain a pointer, but // the interface given is immutable, so it is safe to return func (v IPAddr) deepClone() Value { return v } - -func ValueToIP(v Value) (IPAddr, error) { - i, ok := v.(IPAddr) - if !ok { - return IPAddr{}, fmt.Errorf("%w: expected ipaddr, got %v", ErrType, v.TypeName()) - } - return i, nil -} diff --git a/types/value_test.go b/types/value_test.go index 1f43e4e4..a89601fc 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -9,19 +9,6 @@ import ( func TestBool(t *testing.T) { t.Parallel() - t.Run("roundTrip", func(t *testing.T) { - t.Parallel() - v, err := ValueToBool(Boolean(true)) - testutil.OK(t, err) - testutil.Equals(t, v, true) - }) - - t.Run("toBoolOnNonBool", func(t *testing.T) { - t.Parallel() - v, err := ValueToBool(Long(0)) - testutil.AssertError(t, err, ErrType) - testutil.Equals(t, v, false) - }) t.Run("Equal", func(t *testing.T) { t.Parallel() @@ -50,19 +37,6 @@ func TestBool(t *testing.T) { func TestLong(t *testing.T) { t.Parallel() - t.Run("roundTrip", func(t *testing.T) { - t.Parallel() - v, err := ValueToLong(Long(42)) - testutil.OK(t, err) - testutil.Equals(t, v, 42) - }) - - t.Run("toLongOnNonLong", func(t *testing.T) { - t.Parallel() - v, err := ValueToLong(Boolean(true)) - testutil.AssertError(t, err, ErrType) - testutil.Equals(t, v, 0) - }) t.Run("Equal", func(t *testing.T) { t.Parallel() @@ -91,19 +65,6 @@ func TestLong(t *testing.T) { func TestString(t *testing.T) { t.Parallel() - t.Run("roundTrip", func(t *testing.T) { - t.Parallel() - v, err := ValueToString(String("hello")) - testutil.OK(t, err) - testutil.Equals(t, v, "hello") - }) - - t.Run("toStringOnNonString", func(t *testing.T) { - t.Parallel() - v, err := ValueToString(Boolean(true)) - testutil.AssertError(t, err, ErrType) - testutil.Equals(t, v, "") - }) t.Run("Equal", func(t *testing.T) { t.Parallel() @@ -130,21 +91,6 @@ func TestString(t *testing.T) { func TestSet(t *testing.T) { t.Parallel() - t.Run("roundTrip", func(t *testing.T) { - t.Parallel() - v := Set{Boolean(true), Long(1)} - slice, err := ValueToSet(v) - testutil.OK(t, err) - v2 := slice - testutil.FatalIf(t, !v.Equal(v2), "got %v want %v", v, v2) - }) - - t.Run("ToSetOnNonSet", func(t *testing.T) { - t.Parallel() - v, err := ValueToSet(Boolean(true)) - testutil.AssertError(t, err, ErrType) - testutil.Equals(t, v, nil) - }) t.Run("Equal", func(t *testing.T) { t.Parallel() @@ -197,24 +143,6 @@ func TestSet(t *testing.T) { func TestRecord(t *testing.T) { t.Parallel() - t.Run("roundTrip", func(t *testing.T) { - t.Parallel() - v := Record{ - "foo": Boolean(true), - "bar": Long(1), - } - map_, err := ValueToRecord(v) - testutil.OK(t, err) - v2 := map_ - testutil.FatalIf(t, !v.Equal(v2), "got %v want %v", v, v2) - }) - - t.Run("toRecordOnNonRecord", func(t *testing.T) { - t.Parallel() - v, err := ValueToRecord(String("hello")) - testutil.AssertError(t, err, ErrType) - testutil.Equals(t, v, nil) - }) t.Run("Equal", func(t *testing.T) { t.Parallel() @@ -286,19 +214,6 @@ func TestRecord(t *testing.T) { func TestEntity(t *testing.T) { t.Parallel() - t.Run("roundTrip", func(t *testing.T) { - t.Parallel() - want := EntityUID{Type: "User", ID: "bananas"} - v, err := ValueToEntity(want) - testutil.OK(t, err) - testutil.Equals(t, v, want) - }) - t.Run("ToEntityOnNonEntity", func(t *testing.T) { - t.Parallel() - v, err := ValueToEntity(String("hello")) - testutil.AssertError(t, err, ErrType) - testutil.Equals(t, v, EntityUID{}) - }) t.Run("Equal", func(t *testing.T) { t.Parallel() @@ -422,22 +337,6 @@ func TestDecimal(t *testing.T) { } } - t.Run("roundTrip", func(t *testing.T) { - t.Parallel() - dv, err := ParseDecimal("1.20") - testutil.OK(t, err) - v, err := ValueToDecimal(dv) - testutil.OK(t, err) - testutil.FatalIf(t, !v.Equal(dv), "got %v want %v", v, dv) - }) - - t.Run("toDecimalOnNonDecimal", func(t *testing.T) { - t.Parallel() - v, err := ValueToDecimal(Boolean(true)) - testutil.AssertError(t, err, ErrType) - testutil.Equals(t, v, 0) - }) - t.Run("Equal", func(t *testing.T) { t.Parallel() one := Decimal(10000) @@ -511,12 +410,6 @@ func TestIP(t *testing.T) { } }) - t.Run("toIPOnNonIP", func(t *testing.T) { - t.Parallel() - v, err := ValueToIP(Boolean(true)) - testutil.AssertError(t, err, ErrType) - testutil.Equals(t, v, IPAddr{}) - }) t.Run("Equal", func(t *testing.T) { t.Parallel() From f5b13c31966232c867213f2a80e581c46faf3732 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 15 Aug 2024 17:49:20 -0600 Subject: [PATCH 146/216] types: appease linter Addresses IDX-142 Signed-off-by: philhassey --- types/value_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/types/value_test.go b/types/value_test.go index a89601fc..d80bdf3c 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -410,7 +410,6 @@ func TestIP(t *testing.T) { } }) - t.Run("Equal", func(t *testing.T) { t.Parallel() tests := []struct { From 13b9ea0d563eda49d28a214177e5fd78229ebb14 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 15 Aug 2024 18:00:26 -0600 Subject: [PATCH 147/216] types: improve pattern ergonomics Addresses IDX-142 Signed-off-by: philhassey --- ast/ast_test.go | 4 ++-- internal/ast/ast_test.go | 4 ++-- internal/json/json_marshal.go | 2 +- internal/json/json_unmarshal.go | 8 ++++---- types/pattern.go | 34 ++++++++++++++------------------- types/patttern_test.go | 22 ++++++++++----------- 6 files changed, 34 insertions(+), 40 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index dadeb18a..e5f4992f 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -295,8 +295,8 @@ func TestASTByTable(t *testing.T) { }, { "opLike", - ast.Permit().When(ast.Long(42).Like(types.Pattern{})), - internalast.Permit().When(internalast.Long(42).Like(types.Pattern{})), + ast.Permit().When(ast.Long(42).Like(types.Pattern{}.Wildcard())), + internalast.Permit().When(internalast.Long(42).Like(types.Pattern{}.Wildcard())), }, { "opAnd", diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index 0473e82e..ee7f9056 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -350,9 +350,9 @@ func TestASTByTable(t *testing.T) { }, { "opLike", - ast.Permit().When(ast.Long(42).Like(types.Pattern{})), + ast.Permit().When(ast.Long(42).Like(types.Pattern{}.Wildcard())), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLike{Arg: ast.NodeValue{Value: types.Long(42)}, Value: types.Pattern{}}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLike{Arg: ast.NodeValue{Value: types.Long(42)}, Value: types.Pattern{}.Wildcard()}}}}, }, { "opAnd", diff --git a/internal/json/json_marshal.go b/internal/json/json_marshal.go index d148ccce..47815d31 100644 --- a/internal/json/json_marshal.go +++ b/internal/json/json_marshal.go @@ -97,7 +97,7 @@ func strToJSON(dest **strJSON, src ast.StrOpNode) { func patternToJSON(dest **patternJSON, src ast.NodeTypeLike) { res := &patternJSON{} res.Left.FromNode(src.Arg) - for _, comp := range src.Value.Components { + for _, comp := range src.Value { if comp.Wildcard { res.Pattern = append(res.Pattern, patternComponentJSON{Wildcard: true}) } diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index c345a25a..fc2676a6 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -66,16 +66,16 @@ func (j patternJSON) ToNode(f func(a ast.Node, k types.Pattern) ast.Node) (ast.N if err != nil { return ast.Node{}, fmt.Errorf("error in left: %w", err) } - pattern := &types.Pattern{} + pattern := types.Pattern{} for _, compJSON := range j.Pattern { if compJSON.Wildcard { - pattern = pattern.AddWildcard() + pattern = pattern.Wildcard() } else { - pattern = pattern.AddLiteral(compJSON.Literal.Literal) + pattern = pattern.Literal(compJSON.Literal.Literal) } } - return f(left, *pattern), nil + return f(left, pattern), nil } func (j isJSON) ToNode() (ast.Node, error) { left, err := j.Left.ToNode() diff --git a/types/pattern.go b/types/pattern.go index 4d5b05fd..3aac8335 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -15,14 +15,12 @@ type PatternComponent struct { // Pattern is used to define a string used for the like operator. It does not // conform to the Value interface, as it is not one of the Cedar types. -type Pattern struct { - Components []PatternComponent -} +type Pattern []PatternComponent func (p Pattern) Cedar() string { var buf bytes.Buffer buf.WriteRune('"') - for _, comp := range p.Components { + for _, comp := range p { if comp.Wildcard { buf.WriteRune('*') } @@ -36,29 +34,27 @@ func (p Pattern) Cedar() string { return buf.String() } -func (p *Pattern) AddWildcard() *Pattern { +func (p Pattern) Wildcard() Pattern { star := PatternComponent{Wildcard: true} - if len(p.Components) == 0 { - p.Components = []PatternComponent{star} + if len(p) == 0 { + p = Pattern{star} return p } - lastComp := p.Components[len(p.Components)-1] + lastComp := p[len(p)-1] if lastComp.Wildcard && lastComp.Literal == "" { return p } - p.Components = append(p.Components, star) + p = append(p, star) return p } -func (p *Pattern) AddLiteral(s string) *Pattern { - if len(p.Components) == 0 { - p.Components = []PatternComponent{{}} +func (p Pattern) Literal(s string) Pattern { + if len(p) == 0 { + p = Pattern{{}} } - - lastComp := &p.Components[len(p.Components)-1] - lastComp.Literal = lastComp.Literal + s + p[len(p)-1].Literal += s return p } @@ -77,8 +73,8 @@ func (p *Pattern) AddLiteral(s string) *Pattern { // c matches character c (c != '*') func (p Pattern) Match(arg string) (matched bool) { Pattern: - for i, comp := range p.Components { - lastChunk := i == len(p.Components)-1 + for i, comp := range p { + lastChunk := i == len(p)-1 if comp.Wildcard && comp.Literal == "" { return true } @@ -144,7 +140,5 @@ func ParsePattern(s string) (Pattern, error) { } comps = append(comps, comp) } - return Pattern{ - Components: comps, - }, nil + return comps, nil } diff --git a/types/patttern_test.go b/types/patttern_test.go index f488f25f..d76159e1 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -9,23 +9,23 @@ import ( func TestPatternFromBuilder(t *testing.T) { tests := []struct { name string - Pattern *Pattern + Pattern Pattern want []PatternComponent }{ - {"empty", &Pattern{}, nil}, - {"wildcard", (&Pattern{}).AddWildcard(), []PatternComponent{{Wildcard: true}}}, - {"saturate two wildcards", (&Pattern{}).AddWildcard().AddWildcard(), []PatternComponent{{Wildcard: true}}}, - {"literal", (&Pattern{}).AddLiteral("foo"), []PatternComponent{{Literal: "foo"}}}, - {"saturate two literals", (&Pattern{}).AddLiteral("foo").AddLiteral("bar"), []PatternComponent{{Literal: "foobar"}}}, - {"literal with asterisk", (&Pattern{}).AddLiteral("fo*o"), []PatternComponent{{Literal: "fo*o"}}}, - {"wildcard sandwich", (&Pattern{}).AddLiteral("foo").AddWildcard().AddLiteral("bar"), []PatternComponent{{Literal: "foo"}, {Wildcard: true, Literal: "bar"}}}, - {"literal sandwich", (&Pattern{}).AddWildcard().AddLiteral("foo").AddWildcard(), []PatternComponent{{Wildcard: true, Literal: "foo"}, {Wildcard: true}}}, + {"empty", Pattern{}, Pattern{}}, + {"wildcard", (Pattern{}).Wildcard(), Pattern{{Wildcard: true}}}, + {"saturate two wildcards", (Pattern{}).Wildcard().Wildcard(), Pattern{{Wildcard: true}}}, + {"literal", (Pattern{}).Literal("foo"), Pattern{{Literal: "foo"}}}, + {"saturate two literals", (Pattern{}).Literal("foo").Literal("bar"), Pattern{{Literal: "foobar"}}}, + {"literal with asterisk", (Pattern{}).Literal("fo*o"), Pattern{{Literal: "fo*o"}}}, + {"wildcard sandwich", (Pattern{}).Literal("foo").Wildcard().Literal("bar"), Pattern{{Literal: "foo"}, {Wildcard: true, Literal: "bar"}}}, + {"literal sandwich", (Pattern{}).Wildcard().Literal("foo").Wildcard(), Pattern{{Wildcard: true, Literal: "foo"}, {Wildcard: true}}}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - testutil.Equals(t, tt.Pattern.Components, tt.want) + testutil.Equals(t, tt.Pattern, tt.want) }) } } @@ -81,7 +81,7 @@ func TestParsePattern(t *testing.T) { testutil.Equals(t, err.Error(), tt.wantErr) } else { testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got.Components, tt.want) + testutil.Equals(t, got, tt.want) } }) } From 022e46dbd74056bc7da96b8e4205226d78008a07 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 16:51:42 -0700 Subject: [PATCH 148/216] cedar-go: hide Policy annotations inside an accessor method Signed-off-by: philhassey --- policy.go | 61 +++++++++++++++++++++------------------------- policy_set_test.go | 2 +- 2 files changed, 29 insertions(+), 34 deletions(-) diff --git a/policy.go b/policy.go index 083a4c95..8fe28feb 100644 --- a/policy.go +++ b/policy.go @@ -13,11 +13,10 @@ import ( // A Policy is the parsed form of a single Cedar language policy statement. type Policy struct { - Position Position // location within the policy text document - Annotations Annotations // annotations found for this policy - Effect Effect // the effect of this policy - eval evaler // determines if a policy matches a request. - ast *internalast.Policy + Position Position // location within the policy text document + Effect Effect // the effect of this policy + eval evaler // determines if a policy matches a request. + ast *internalast.Policy } // A Position describes an arbitrary source position including the file, line, and column location. @@ -32,15 +31,6 @@ type Position struct { // have no impact on policy evaluation. type Annotations map[string]string -// TODO: Is this where we should deal with duplicate keys? -func newAnnotationsFromSlice(annotations []internalast.AnnotationType) Annotations { - res := make(map[string]string, len(annotations)) - for _, e := range annotations { - res[string(e.Key)] = string(e.Value) - } - return res -} - // An Effect specifies the intent of the policy, to either permit or forbid any // request that matches the scope and conditions specified in the policy. type Effect internalast.Effect @@ -68,11 +58,10 @@ func (p *Policy) UnmarshalJSON(b []byte) error { return err } *p = Policy{ - Position: Position{}, - Annotations: newAnnotationsFromSlice(jsonPolicy.Annotations), - Effect: Effect(jsonPolicy.Effect), - eval: eval.Compile((*internalast.Policy)(&jsonPolicy)), - ast: (*internalast.Policy)(&jsonPolicy), + Position: Position{}, + Effect: Effect(jsonPolicy.Effect), + eval: eval.Compile((*internalast.Policy)(&jsonPolicy)), + ast: (*internalast.Policy)(&jsonPolicy), } return nil } @@ -89,11 +78,10 @@ func (p *Policy) UnmarshalCedar(b []byte) error { } *p = Policy{ - Position: Position{}, - Annotations: newAnnotationsFromSlice(cedarPolicy.Annotations), - Effect: Effect(cedarPolicy.Effect), - eval: eval.Compile((*internalast.Policy)(&cedarPolicy)), - ast: (*internalast.Policy)(&cedarPolicy), + Position: Position{}, + Effect: Effect(cedarPolicy.Effect), + eval: eval.Compile((*internalast.Policy)(&cedarPolicy)), + ast: (*internalast.Policy)(&cedarPolicy), } return nil } @@ -101,14 +89,22 @@ func (p *Policy) UnmarshalCedar(b []byte) error { func NewPolicyFromAST(astIn *ast.Policy) *Policy { pp := (*internalast.Policy)(astIn) return &Policy{ - Position: Position{}, - Annotations: newAnnotationsFromSlice(astIn.Annotations), - Effect: Effect(astIn.Effect), - eval: eval.Compile(pp), - ast: pp, + Position: Position{}, + Effect: Effect(astIn.Effect), + eval: eval.Compile(pp), + ast: pp, } } +func (p Policy) Annotations() Annotations { + // TODO: Where should we deal with duplicate keys? + res := make(map[string]string, len(p.ast.Annotations)) + for _, e := range p.ast.Annotations { + res[string(e.Key)] = string(e.Value) + } + return res +} + // PolicySlice represents a set of un-named Policy's. Cedar documents, unlike the JSON format, don't have a means of // naming individual policies. type PolicySlice []*Policy @@ -128,10 +124,9 @@ func (p *PolicySlice) UnmarshalCedar(b []byte) error { Line: p.Position.Line, Column: p.Position.Column, }, - Annotations: newAnnotationsFromSlice(p.Annotations), - Effect: Effect(p.Effect), - eval: eval.Compile((*internalast.Policy)(p)), - ast: (*internalast.Policy)(p), + Effect: Effect(p.Effect), + eval: eval.Compile((*internalast.Policy)(p)), + ast: (*internalast.Policy)(p), }) } *p = policySlice diff --git a/policy_set_test.go b/policy_set_test.go index 65451b35..114804bb 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -25,7 +25,7 @@ func TestNewPolicySetFromFile(t *testing.T) { t.Parallel() ps, err := NewPolicySetFromFile("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) testutil.OK(t, err) - testutil.Equals(t, ps.GetPolicy("policy0").Annotations, Annotations{"key": "value"}) + testutil.Equals(t, ps.GetPolicy("policy0").Annotations(), Annotations{"key": "value"}) }) } From 636f2b8dd9b20f3ddf0146debd1a3fbb768506d8 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 16:53:46 -0700 Subject: [PATCH 149/216] cedar-go: hide Policy.Effect inside an accessor method Signed-off-by: philhassey --- authorize.go | 2 +- policy.go | 13 ++++++------- policy_set_test.go | 4 ++-- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/authorize.go b/authorize.go index 2da1abcb..4bd3d960 100644 --- a/authorize.go +++ b/authorize.go @@ -104,7 +104,7 @@ func (p PolicySet) IsAuthorized(entityMap entities.Entities, req Request) (Decis if !vb { continue } - if po.Effect == Forbid { + if po.Effect() == Forbid { forbidReasons = append(forbidReasons, Reason{PolicyID: id, Position: po.Position}) gotForbid = true } else { diff --git a/policy.go b/policy.go index 8fe28feb..81f6957d 100644 --- a/policy.go +++ b/policy.go @@ -14,7 +14,6 @@ import ( // A Policy is the parsed form of a single Cedar language policy statement. type Policy struct { Position Position // location within the policy text document - Effect Effect // the effect of this policy eval evaler // determines if a policy matches a request. ast *internalast.Policy } @@ -59,7 +58,6 @@ func (p *Policy) UnmarshalJSON(b []byte) error { } *p = Policy{ Position: Position{}, - Effect: Effect(jsonPolicy.Effect), eval: eval.Compile((*internalast.Policy)(&jsonPolicy)), ast: (*internalast.Policy)(&jsonPolicy), } @@ -79,7 +77,6 @@ func (p *Policy) UnmarshalCedar(b []byte) error { *p = Policy{ Position: Position{}, - Effect: Effect(cedarPolicy.Effect), eval: eval.Compile((*internalast.Policy)(&cedarPolicy)), ast: (*internalast.Policy)(&cedarPolicy), } @@ -90,7 +87,6 @@ func NewPolicyFromAST(astIn *ast.Policy) *Policy { pp := (*internalast.Policy)(astIn) return &Policy{ Position: Position{}, - Effect: Effect(astIn.Effect), eval: eval.Compile(pp), ast: pp, } @@ -105,6 +101,10 @@ func (p Policy) Annotations() Annotations { return res } +func (p Policy) Effect() Effect { + return Effect(p.ast.Effect) +} + // PolicySlice represents a set of un-named Policy's. Cedar documents, unlike the JSON format, don't have a means of // naming individual policies. type PolicySlice []*Policy @@ -124,9 +124,8 @@ func (p *PolicySlice) UnmarshalCedar(b []byte) error { Line: p.Position.Line, Column: p.Position.Column, }, - Effect: Effect(p.Effect), - eval: eval.Compile((*internalast.Policy)(p)), - ast: (*internalast.Policy)(p), + eval: eval.Compile((*internalast.Policy)(p)), + ast: (*internalast.Policy)(p), }) } *p = policySlice diff --git a/policy_set_test.go b/policy_set_test.go index 114804bb..4f6d5c0a 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -112,8 +112,8 @@ forbid ( ps.UpsertPolicy(PolicyID(fmt.Sprintf("policy%d", i)), p) } - testutil.Equals(t, ps.GetPolicy("policy0").Effect, Permit) - testutil.Equals(t, ps.GetPolicy("policy1").Effect, Forbid) + testutil.Equals(t, ps.GetPolicy("policy0").Effect(), Permit) + testutil.Equals(t, ps.GetPolicy("policy1").Effect(), Forbid) var buf bytes.Buffer ps.MarshalCedar(&buf) From e5198d1554673059e7482bf27a575c7249a9fdc0 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 17:01:40 -0700 Subject: [PATCH 150/216] cedar-go: change Policy.Position into an accessor function, while still allowing callers to set the file name on the policy Signed-off-by: philhassey --- authorize.go | 8 ++++---- policy.go | 40 +++++++++++++++++++++++----------------- policy_set.go | 2 +- policy_set_test.go | 2 +- 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/authorize.go b/authorize.go index 4bd3d960..feecd70b 100644 --- a/authorize.go +++ b/authorize.go @@ -92,23 +92,23 @@ func (p PolicySet) IsAuthorized(entityMap entities.Entities, req Request) (Decis for id, po := range p.policies { v, err := po.eval.Eval(c) if err != nil { - diag.Errors = append(diag.Errors, Error{PolicyID: id, Position: po.Position, Message: err.Error()}) + diag.Errors = append(diag.Errors, Error{PolicyID: id, Position: po.Position(), Message: err.Error()}) continue } vb, err := eval.ValueToBool(v) if err != nil { // should never happen, maybe remove this case - diag.Errors = append(diag.Errors, Error{PolicyID: id, Position: po.Position, Message: err.Error()}) + diag.Errors = append(diag.Errors, Error{PolicyID: id, Position: po.Position(), Message: err.Error()}) continue } if !vb { continue } if po.Effect() == Forbid { - forbidReasons = append(forbidReasons, Reason{PolicyID: id, Position: po.Position}) + forbidReasons = append(forbidReasons, Reason{PolicyID: id, Position: po.Position()}) gotForbid = true } else { - permitReasons = append(permitReasons, Reason{PolicyID: id, Position: po.Position}) + permitReasons = append(permitReasons, Reason{PolicyID: id, Position: po.Position()}) gotPermit = true } } diff --git a/policy.go b/policy.go index 81f6957d..85d75122 100644 --- a/policy.go +++ b/policy.go @@ -13,9 +13,10 @@ import ( // A Policy is the parsed form of a single Cedar language policy statement. type Policy struct { - Position Position // location within the policy text document - eval evaler // determines if a policy matches a request. - ast *internalast.Policy + eval evaler // determines if a policy matches a request. + ast *internalast.Policy + // TODO: Remove this and just store source file information in the generated policy ID? + sourceFile string } // A Position describes an arbitrary source position including the file, line, and column location. @@ -57,9 +58,8 @@ func (p *Policy) UnmarshalJSON(b []byte) error { return err } *p = Policy{ - Position: Position{}, - eval: eval.Compile((*internalast.Policy)(&jsonPolicy)), - ast: (*internalast.Policy)(&jsonPolicy), + eval: eval.Compile((*internalast.Policy)(&jsonPolicy)), + ast: (*internalast.Policy)(&jsonPolicy), } return nil } @@ -76,9 +76,8 @@ func (p *Policy) UnmarshalCedar(b []byte) error { } *p = Policy{ - Position: Position{}, - eval: eval.Compile((*internalast.Policy)(&cedarPolicy)), - ast: (*internalast.Policy)(&cedarPolicy), + eval: eval.Compile((*internalast.Policy)(&cedarPolicy)), + ast: (*internalast.Policy)(&cedarPolicy), } return nil } @@ -86,9 +85,8 @@ func (p *Policy) UnmarshalCedar(b []byte) error { func NewPolicyFromAST(astIn *ast.Policy) *Policy { pp := (*internalast.Policy)(astIn) return &Policy{ - Position: Position{}, - eval: eval.Compile(pp), - ast: pp, + eval: eval.Compile(pp), + ast: pp, } } @@ -105,6 +103,19 @@ func (p Policy) Effect() Effect { return Effect(p.ast.Effect) } +func (p Policy) Position() Position { + return Position{ + Filename: p.sourceFile, + Offset: p.ast.Position.Offset, + Line: p.ast.Position.Line, + Column: p.ast.Position.Column, + } +} + +func (p *Policy) SetSourceFile(path string) { + p.sourceFile = path +} + // PolicySlice represents a set of un-named Policy's. Cedar documents, unlike the JSON format, don't have a means of // naming individual policies. type PolicySlice []*Policy @@ -119,11 +130,6 @@ func (p *PolicySlice) UnmarshalCedar(b []byte) error { policySlice := make([]*Policy, 0, len(res)) for _, p := range res { policySlice = append(policySlice, &Policy{ - Position: Position{ - Offset: p.Position.Offset, - Line: p.Position.Line, - Column: p.Position.Column, - }, eval: eval.Compile((*internalast.Policy)(p)), ast: (*internalast.Policy)(p), }) diff --git a/policy_set.go b/policy_set.go index 9959e5ff..3294cd44 100644 --- a/policy_set.go +++ b/policy_set.go @@ -33,7 +33,7 @@ func NewPolicySetFromFile(fileName string, document []byte) (PolicySet, error) { policyMap := make(map[PolicyID]*Policy, len(policySlice)) for i, p := range policySlice { policyID := PolicyID(fmt.Sprintf("policy%d", i)) - p.Position.Filename = fileName + p.SetSourceFile(fileName) policyMap[policyID] = p } return PolicySet{policies: policyMap}, nil diff --git a/policy_set_test.go b/policy_set_test.go index 4f6d5c0a..bbe89c43 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -108,7 +108,7 @@ forbid ( ps := NewPolicySet() for i, p := range policies { - p.Position.Filename = "example.cedar" + p.SetSourceFile("example.cedar") ps.UpsertPolicy(PolicyID(fmt.Sprintf("policy%d", i)), p) } From 6eb01d56484e51b83ce1bdcfa7d525718651fdcc Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 17:08:56 -0700 Subject: [PATCH 151/216] cedar-go: create a newPolicy constructor for polcy to simplify callsites Signed-off-by: philhassey --- policy.go | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/policy.go b/policy.go index 85d75122..04ec3d8d 100644 --- a/policy.go +++ b/policy.go @@ -41,6 +41,10 @@ const ( Forbid = Effect(false) ) +func newPolicy(astIn *internalast.Policy) Policy { + return Policy{eval: eval.Compile(astIn), ast: astIn, sourceFile: ""} +} + // MarshalJSON encodes a single Policy statement in the JSON format specified by the [Cedar documentation]. // // [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html @@ -57,10 +61,8 @@ func (p *Policy) UnmarshalJSON(b []byte) error { if err := jsonPolicy.UnmarshalJSON(b); err != nil { return err } - *p = Policy{ - eval: eval.Compile((*internalast.Policy)(&jsonPolicy)), - ast: (*internalast.Policy)(&jsonPolicy), - } + + *p = newPolicy((*internalast.Policy)(&jsonPolicy)) return nil } @@ -75,19 +77,13 @@ func (p *Policy) UnmarshalCedar(b []byte) error { return err } - *p = Policy{ - eval: eval.Compile((*internalast.Policy)(&cedarPolicy)), - ast: (*internalast.Policy)(&cedarPolicy), - } + *p = newPolicy((*internalast.Policy)(&cedarPolicy)) return nil } func NewPolicyFromAST(astIn *ast.Policy) *Policy { - pp := (*internalast.Policy)(astIn) - return &Policy{ - eval: eval.Compile(pp), - ast: pp, - } + p := newPolicy((*internalast.Policy)(astIn)) + return &p } func (p Policy) Annotations() Annotations { @@ -129,10 +125,8 @@ func (p *PolicySlice) UnmarshalCedar(b []byte) error { } policySlice := make([]*Policy, 0, len(res)) for _, p := range res { - policySlice = append(policySlice, &Policy{ - eval: eval.Compile((*internalast.Policy)(p)), - ast: (*internalast.Policy)(p), - }) + newPolicy := newPolicy((*internalast.Policy)(p)) + policySlice = append(policySlice, &newPolicy) } *p = policySlice return nil From 09aff40e0473bd7a42f8620b2258308b5a6d02b1 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 17:10:00 -0700 Subject: [PATCH 152/216] cedar-go: re-home type definitions next to the appropriate accessor function Signed-off-by: philhassey --- policy.go | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/policy.go b/policy.go index 04ec3d8d..02241e8f 100644 --- a/policy.go +++ b/policy.go @@ -19,28 +19,6 @@ type Policy struct { sourceFile string } -// A Position describes an arbitrary source position including the file, line, and column location. -type Position struct { - Filename string // filename, if any - Offset int // byte offset, starting at 0 - Line int // line number, starting at 1 - Column int // column number, starting at 1 (character count per line) -} - -// An Annotations is a map of key, value pairs found in the policy. Annotations -// have no impact on policy evaluation. -type Annotations map[string]string - -// An Effect specifies the intent of the policy, to either permit or forbid any -// request that matches the scope and conditions specified in the policy. -type Effect internalast.Effect - -// Each Policy has a Permit or Forbid effect that is determined during parsing. -const ( - Permit = Effect(true) - Forbid = Effect(false) -) - func newPolicy(astIn *internalast.Policy) Policy { return Policy{eval: eval.Compile(astIn), ast: astIn, sourceFile: ""} } @@ -86,6 +64,10 @@ func NewPolicyFromAST(astIn *ast.Policy) *Policy { return &p } +// An Annotations is a map of key, value pairs found in the policy. Annotations +// have no impact on policy evaluation. +type Annotations map[string]string + func (p Policy) Annotations() Annotations { // TODO: Where should we deal with duplicate keys? res := make(map[string]string, len(p.ast.Annotations)) @@ -95,10 +77,28 @@ func (p Policy) Annotations() Annotations { return res } +// An Effect specifies the intent of the policy, to either permit or forbid any +// request that matches the scope and conditions specified in the policy. +type Effect internalast.Effect + +// Each Policy has a Permit or Forbid effect that is determined during parsing. +const ( + Permit = Effect(true) + Forbid = Effect(false) +) + func (p Policy) Effect() Effect { return Effect(p.ast.Effect) } +// A Position describes an arbitrary source position including the file, line, and column location. +type Position struct { + Filename string // filename, if any + Offset int // byte offset, starting at 0 + Line int // line number, starting at 1 + Column int // column number, starting at 1 (character count per line) +} + func (p Policy) Position() Position { return Position{ Filename: p.sourceFile, From 8afc7365267029dd9476e8588f49a934c40873b0 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 17:21:12 -0700 Subject: [PATCH 153/216] cedar-go: comment hygiene Signed-off-by: philhassey --- policy_set.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/policy_set.go b/policy_set.go index 3294cd44..c5a67e0b 100644 --- a/policy_set.go +++ b/policy_set.go @@ -9,12 +9,12 @@ import ( type PolicyID string -// A set of named policies against which a request can be authorized. +// PolicySet is a set of named policies against which a request can be authorized. type PolicySet struct { policies map[PolicyID]*Policy } -// Create a new, empty PolicySet +// NewPolicySet creates a new, empty PolicySet func NewPolicySet() PolicySet { return PolicySet{policies: map[PolicyID]*Policy{}} } From 61a8b9c13471f74f225315c2daf429f824cf7b48 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Thu, 15 Aug 2024 17:28:28 -0700 Subject: [PATCH 154/216] cedar-go: more window dressing Signed-off-by: philhassey --- README.md | 11 ++++++++--- policy_set_test.go | 39 ++++++++++++++++++++------------------- policy_test.go | 11 ++++++----- 3 files changed, 34 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index cb41c7fa..efd8992b 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ import ( "fmt" "log" - "github.com/cedar-policy/cedar-go" + cedar "github.com/cedar-policy/cedar-go" ) const policyCedar = `permit ( @@ -80,10 +80,14 @@ const entitiesJSON = `[ ]` func main() { - ps, err := cedar.NewPolicySet("policy.cedar", []byte(policyCedar)) - if err != nil { + var policy cedar.Policy + if err := policy.UnmarshalCedar([]byte(policyCedar)); err != nil { log.Fatal(err) } + + ps := cedar.NewPolicySet() + ps.UpsertPolicy("policy0", &policy) + var entities cedar.Entities if err := json.Unmarshal([]byte(entitiesJSON), &entities); err != nil { log.Fatal(err) @@ -94,6 +98,7 @@ func main() { Resource: cedar.EntityUID{Type: "Photo", ID: "VacationPhoto94.jpg"}, Context: cedar.Record{}, } + ok, _ := ps.IsAuthorized(entities, req) fmt.Println(ok) } diff --git a/policy_set_test.go b/policy_set_test.go index bbe89c43..7d056280 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -1,10 +1,11 @@ -package cedar +package cedar_test import ( "bytes" "fmt" "testing" + "github.com/cedar-policy/cedar-go" "github.com/cedar-policy/cedar-go/ast" "github.com/cedar-policy/cedar-go/internal/testutil" ) @@ -13,19 +14,19 @@ func TestNewPolicySetFromFile(t *testing.T) { t.Parallel() t.Run("err-in-tokenize", func(t *testing.T) { t.Parallel() - _, err := NewPolicySetFromFile("policy.cedar", []byte(`"`)) + _, err := cedar.NewPolicySetFromFile("policy.cedar", []byte(`"`)) testutil.Error(t, err) }) t.Run("err-in-parse", func(t *testing.T) { t.Parallel() - _, err := NewPolicySetFromFile("policy.cedar", []byte(`err`)) + _, err := cedar.NewPolicySetFromFile("policy.cedar", []byte(`err`)) testutil.Error(t, err) }) t.Run("annotations", func(t *testing.T) { t.Parallel() - ps, err := NewPolicySetFromFile("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) + ps, err := cedar.NewPolicySetFromFile("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) testutil.OK(t, err) - testutil.Equals(t, ps.GetPolicy("policy0").Annotations(), Annotations{"key": "value"}) + testutil.Equals(t, ps.GetPolicy("policy0").Annotations(), cedar.Annotations{"key": "value"}) }) } @@ -34,14 +35,14 @@ func TestUpsertPolicy(t *testing.T) { t.Run("insert", func(t *testing.T) { t.Parallel() - policy0 := NewPolicyFromAST(ast.Forbid()) + policy0 := cedar.NewPolicyFromAST(ast.Forbid()) - var policy1 Policy + var policy1 cedar.Policy testutil.OK(t, policy1.UnmarshalJSON( []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), )) - ps := NewPolicySet() + ps := cedar.NewPolicySet() ps.UpsertPolicy("policy0", policy0) ps.UpsertPolicy("policy1", &policy1) @@ -52,12 +53,12 @@ func TestUpsertPolicy(t *testing.T) { t.Run("upsert", func(t *testing.T) { t.Parallel() - ps := NewPolicySet() + ps := cedar.NewPolicySet() - p1 := NewPolicyFromAST(ast.Forbid()) + p1 := cedar.NewPolicyFromAST(ast.Forbid()) ps.UpsertPolicy("a wavering policy", p1) - p2 := NewPolicyFromAST(ast.Permit()) + p2 := cedar.NewPolicyFromAST(ast.Permit()) ps.UpsertPolicy("a wavering policy", p2) testutil.Equals(t, ps.GetPolicy("a wavering policy"), p2) @@ -69,7 +70,7 @@ func TestDeletePolicy(t *testing.T) { t.Run("delete non-existent", func(t *testing.T) { t.Parallel() - ps := NewPolicySet() + ps := cedar.NewPolicySet() // Just verify that this doesn't crash ps.DeletePolicy("not a policy") @@ -77,9 +78,9 @@ func TestDeletePolicy(t *testing.T) { t.Run("delete existing", func(t *testing.T) { t.Parallel() - ps := NewPolicySet() + ps := cedar.NewPolicySet() - p1 := NewPolicyFromAST(ast.Forbid()) + p1 := cedar.NewPolicyFromAST(ast.Forbid()) ps.UpsertPolicy("a policy", p1) ps.DeletePolicy("a policy") @@ -103,17 +104,17 @@ forbid ( resource );` - var policies PolicySlice + var policies cedar.PolicySlice testutil.OK(t, policies.UnmarshalCedar([]byte(policiesStr))) - ps := NewPolicySet() + ps := cedar.NewPolicySet() for i, p := range policies { p.SetSourceFile("example.cedar") - ps.UpsertPolicy(PolicyID(fmt.Sprintf("policy%d", i)), p) + ps.UpsertPolicy(cedar.PolicyID(fmt.Sprintf("policy%d", i)), p) } - testutil.Equals(t, ps.GetPolicy("policy0").Effect(), Permit) - testutil.Equals(t, ps.GetPolicy("policy1").Effect(), Forbid) + testutil.Equals(t, ps.GetPolicy("policy0").Effect(), cedar.Permit) + testutil.Equals(t, ps.GetPolicy("policy1").Effect(), cedar.Forbid) var buf bytes.Buffer ps.MarshalCedar(&buf) diff --git a/policy_test.go b/policy_test.go index d469254f..36fc10d1 100644 --- a/policy_test.go +++ b/policy_test.go @@ -1,10 +1,11 @@ -package cedar +package cedar_test import ( "bytes" "encoding/json" "testing" + "github.com/cedar-policy/cedar-go" "github.com/cedar-policy/cedar-go/ast" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" @@ -58,7 +59,7 @@ func TestPolicyJSON(t *testing.T) { }`, )) - var policy Policy + var policy cedar.Policy testutil.OK(t, policy.UnmarshalJSON(jsonEncodedPolicy)) output, err := policy.MarshalJSON() @@ -78,7 +79,7 @@ func TestPolicyCedar(t *testing.T) { ) when { resource.owner == principal };` - var policy Policy + var policy cedar.Policy testutil.OK(t, policy.UnmarshalCedar([]byte(policyStr))) var buf bytes.Buffer @@ -94,7 +95,7 @@ func TestPolicyAST(t *testing.T) { ActionEq(types.NewEntityUID("Action", "editPhoto")). When(ast.Resource().Access("owner").Equals(ast.Principal())) - _ = NewPolicyFromAST(astExample) + _ = cedar.NewPolicyFromAST(astExample) } func TestPolicySlice(t *testing.T) { @@ -113,7 +114,7 @@ forbid ( resource );` - var policies PolicySlice + var policies cedar.PolicySlice testutil.OK(t, policies.UnmarshalCedar([]byte(policiesStr))) var buf bytes.Buffer From c8810d5b9f4f8f482622f1d53035fae7e4549b84 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 16 Aug 2024 10:49:25 -0700 Subject: [PATCH 155/216] cedar-go: add a way to upsert one PolicySet into another Signed-off-by: philhassey --- policy_set.go | 8 +++++++ policy_set_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/policy_set.go b/policy_set.go index c5a67e0b..97a4dddc 100644 --- a/policy_set.go +++ b/policy_set.go @@ -55,6 +55,14 @@ func (p *PolicySet) DeletePolicy(policyID PolicyID) { delete(p.policies, policyID) } +// UpsertPolicySet inserts or updates all the policies from src into this PolicySet. Policies in this PolicySet with +// identical IDs in src are clobbered by the policies from src. +func (p *PolicySet) UpsertPolicySet(src PolicySet) { + for id, policy := range src.policies { + p.policies[id] = policy + } +} + // MarshalCedar emits a concatenated Cedar representation of a PolicySet. The policy names are stripped, but policies // are emitted in lexicographical order by ID. func (p PolicySet) MarshalCedar(buf *bytes.Buffer) { diff --git a/policy_set_test.go b/policy_set_test.go index 7d056280..f94b1a85 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -65,6 +65,60 @@ func TestUpsertPolicy(t *testing.T) { }) } +func TestUpsertPolicySet(t *testing.T) { + t.Parallel() + t.Run("empty dst", func(t *testing.T) { + t.Parallel() + + policy0 := cedar.NewPolicyFromAST(ast.Forbid()) + + var policy1 cedar.Policy + testutil.OK(t, policy1.UnmarshalJSON( + []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), + )) + + ps1 := cedar.NewPolicySet() + ps1.UpsertPolicy("policy0", policy0) + ps1.UpsertPolicy("policy1", &policy1) + + ps2 := cedar.NewPolicySet() + ps2.UpsertPolicySet(ps1) + + testutil.Equals(t, ps2.GetPolicy("policy0"), policy0) + testutil.Equals(t, ps2.GetPolicy("policy1"), &policy1) + testutil.Equals(t, ps2.GetPolicy("policy2"), nil) + }) + t.Run("upsert", func(t *testing.T) { + t.Parallel() + + policyA := cedar.NewPolicyFromAST(ast.Forbid()) + + var policyB cedar.Policy + testutil.OK(t, policyB.UnmarshalJSON( + []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), + )) + + policyC := cedar.NewPolicyFromAST(ast.Permit()) + + // ps1 maps 0 -> A and 1 -> B + ps1 := cedar.NewPolicySet() + ps1.UpsertPolicy("policy0", policyA) + ps1.UpsertPolicy("policy1", &policyB) + + // ps1 maps 0 -> b and 2 -> C + ps2 := cedar.NewPolicySet() + ps2.UpsertPolicy("policy0", &policyB) + ps2.UpsertPolicy("policy2", policyC) + + // Upsert should clobber ps2's policy0, insert policy1, and leave policy2 untouched + ps2.UpsertPolicySet(ps1) + + testutil.Equals(t, ps2.GetPolicy("policy0"), policyA) + testutil.Equals(t, ps2.GetPolicy("policy1"), &policyB) + testutil.Equals(t, ps2.GetPolicy("policy2"), policyC) + }) +} + func TestDeletePolicy(t *testing.T) { t.Parallel() t.Run("delete non-existent", func(t *testing.T) { From da1155bab341665b1229290f40b4d9e3be3fef47 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 16 Aug 2024 11:02:55 -0700 Subject: [PATCH 156/216] cedar-go: push FileName down into the internal AST Position struct This avoids storing the file name as an appendage in cedar.Policy, which was kind of ugly. Arguably, the filename should live completely outside of the cedar library, but we can debate that later. Signed-off-by: philhassey --- internal/ast/policy.go | 7 ++++--- policy.go | 20 ++++---------------- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/internal/ast/policy.go b/internal/ast/policy.go index 24382b25..b27f3808 100644 --- a/internal/ast/policy.go +++ b/internal/ast/policy.go @@ -30,9 +30,10 @@ const ( // Position is a value that represents a source Position. // A Position is valid if Line > 0. type Position struct { - Offset int // byte offset, starting at 0 - Line int // line number, starting at 1 - Column int // column number, starting at 1 (character count per line) + FileName string // optional name of the source file for the enclosing policy, "" if the source is unknown or not a named file + Offset int // byte offset, starting at 0 + Line int // line number, starting at 1 + Column int // column number, starting at 1 (character count per line) } type Policy struct { diff --git a/policy.go b/policy.go index 02241e8f..5e018f10 100644 --- a/policy.go +++ b/policy.go @@ -15,12 +15,10 @@ import ( type Policy struct { eval evaler // determines if a policy matches a request. ast *internalast.Policy - // TODO: Remove this and just store source file information in the generated policy ID? - sourceFile string } func newPolicy(astIn *internalast.Policy) Policy { - return Policy{eval: eval.Compile(astIn), ast: astIn, sourceFile: ""} + return Policy{eval: eval.Compile(astIn), ast: astIn} } // MarshalJSON encodes a single Policy statement in the JSON format specified by the [Cedar documentation]. @@ -92,24 +90,14 @@ func (p Policy) Effect() Effect { } // A Position describes an arbitrary source position including the file, line, and column location. -type Position struct { - Filename string // filename, if any - Offset int // byte offset, starting at 0 - Line int // line number, starting at 1 - Column int // column number, starting at 1 (character count per line) -} +type Position internalast.Position func (p Policy) Position() Position { - return Position{ - Filename: p.sourceFile, - Offset: p.ast.Position.Offset, - Line: p.ast.Position.Line, - Column: p.ast.Position.Column, - } + return Position(p.ast.Position) } func (p *Policy) SetSourceFile(path string) { - p.sourceFile = path + p.ast.Position.FileName = path } // PolicySlice represents a set of un-named Policy's. Cedar documents, unlike the JSON format, don't have a means of From 409236993a28d9c732cc4132cf98a6b6f81a5ea0 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 16 Aug 2024 11:09:33 -0700 Subject: [PATCH 157/216] cedar-go: reshape the public MarshalCedar() implementations to conform to the normal Golang marshaing interface Signed-off-by: philhassey --- policy.go | 14 ++++++++++---- policy_set.go | 6 ++++-- policy_set_test.go | 6 +----- policy_test.go | 10 ++-------- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/policy.go b/policy.go index 5e018f10..205ccac7 100644 --- a/policy.go +++ b/policy.go @@ -42,9 +42,13 @@ func (p *Policy) UnmarshalJSON(b []byte) error { return nil } -func (p *Policy) MarshalCedar(buf *bytes.Buffer) { +func (p *Policy) MarshalCedar() []byte { cedarPolicy := (*parser.Policy)(p.ast) - cedarPolicy.MarshalCedar(buf) + + var buf bytes.Buffer + cedarPolicy.MarshalCedar(&buf) + + return buf.Bytes() } func (p *Policy) UnmarshalCedar(b []byte) error { @@ -121,12 +125,14 @@ func (p *PolicySlice) UnmarshalCedar(b []byte) error { } // MarshalCedar emits a concatenated Cedar representation of a PolicySlice -func (p PolicySlice) MarshalCedar(buf *bytes.Buffer) { +func (p PolicySlice) MarshalCedar() []byte { + var buf bytes.Buffer for i, policy := range p { - policy.MarshalCedar(buf) + buf.Write(policy.MarshalCedar()) if i < len(p)-1 { buf.WriteString("\n\n") } } + return buf.Bytes() } diff --git a/policy_set.go b/policy_set.go index 97a4dddc..d5228b6c 100644 --- a/policy_set.go +++ b/policy_set.go @@ -65,21 +65,23 @@ func (p *PolicySet) UpsertPolicySet(src PolicySet) { // MarshalCedar emits a concatenated Cedar representation of a PolicySet. The policy names are stripped, but policies // are emitted in lexicographical order by ID. -func (p PolicySet) MarshalCedar(buf *bytes.Buffer) { +func (p PolicySet) MarshalCedar() []byte { ids := make([]PolicyID, 0, len(p.policies)) for k := range p.policies { ids = append(ids, k) } slices.Sort(ids) + var buf bytes.Buffer i := 0 for _, id := range ids { policy := p.policies[id] - policy.MarshalCedar(buf) + buf.Write(policy.MarshalCedar()) if i < len(p.policies)-1 { buf.WriteString("\n\n") } i++ } + return buf.Bytes() } diff --git a/policy_set_test.go b/policy_set_test.go index f94b1a85..3ea5046d 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -1,7 +1,6 @@ package cedar_test import ( - "bytes" "fmt" "testing" @@ -170,8 +169,5 @@ forbid ( testutil.Equals(t, ps.GetPolicy("policy0").Effect(), cedar.Permit) testutil.Equals(t, ps.GetPolicy("policy1").Effect(), cedar.Forbid) - var buf bytes.Buffer - ps.MarshalCedar(&buf) - - testutil.Equals(t, buf.String(), policiesStr) + testutil.Equals(t, string(ps.MarshalCedar()), policiesStr) } diff --git a/policy_test.go b/policy_test.go index 36fc10d1..9f4e419f 100644 --- a/policy_test.go +++ b/policy_test.go @@ -82,10 +82,7 @@ when { resource.owner == principal };` var policy cedar.Policy testutil.OK(t, policy.UnmarshalCedar([]byte(policyStr))) - var buf bytes.Buffer - policy.MarshalCedar(&buf) - - testutil.Equals(t, buf.String(), policyStr) + testutil.Equals(t, string(policy.MarshalCedar()), policyStr) } func TestPolicyAST(t *testing.T) { @@ -117,8 +114,5 @@ forbid ( var policies cedar.PolicySlice testutil.OK(t, policies.UnmarshalCedar([]byte(policiesStr))) - var buf bytes.Buffer - policies.MarshalCedar(&buf) - - testutil.Equals(t, buf.String(), policiesStr) + testutil.Equals(t, string(policies.MarshalCedar()), policiesStr) } From fb91513bcc154c9176fc3866bd34c45dc9df47bd Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 16 Aug 2024 11:36:34 -0700 Subject: [PATCH 158/216] cedar-go/types: rename ParsePattern to UnmarshalCedar and follow the typical Unmarshaling interface shape Signed-off-by: philhassey --- internal/eval/evalers_test.go | 3 ++- internal/json/json_test.go | 14 +++++++------- internal/parser/cedar_unmarshal.go | 4 ++-- internal/parser/cedar_unmarshal_test.go | 6 +++--- types/pattern.go | 12 +++++------- types/patttern_test.go | 7 ++++--- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index f530f88e..c0a7b1a6 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -1400,7 +1400,8 @@ func TestLikeNode(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - pat, err := types.ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) + var pat types.Pattern + err := pat.UnmarshalCedar([]byte(tt.pattern[1 : len(tt.pattern)-1])) testutil.OK(t, err) n := newLikeEval(tt.str, pat) v, err := n.Eval(&Context{}) diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 8ccf5c53..8a06fde1 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -398,49 +398,49 @@ func TestUnmarshalJSON(t *testing.T) { "like single wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern("*")))), + ast.Permit().When(ast.String("text").Like(types.Pattern{}.Wildcard())), testutil.OK, }, { "like single literal", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern("foo")))), + ast.Permit().When(ast.String("text").Like(types.Pattern{}.Literal("foo"))), testutil.OK, }, { "like wildcard then literal", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard", {"Literal":"foo"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern("*foo")))), + ast.Permit().When(ast.String("text").Like(types.Pattern{}.Wildcard().Literal("foo"))), testutil.OK, }, { "like literal then wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern("foo*")))), + ast.Permit().When(ast.String("text").Like(types.Pattern{}.Literal("foo").Wildcard())), testutil.OK, }, { "like literal with asterisk then wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"f*oo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern(`f\*oo*`)))), + ast.Permit().When(ast.String("text").Like(types.Pattern{}.Literal("f*oo").Wildcard())), testutil.OK, }, { "like literal sandwich", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}, "Wildcard", {"Literal":"bar"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern(`foo*bar`)))), + ast.Permit().When(ast.String("text").Like(types.Pattern{}.Literal("foo").Wildcard().Literal("bar"))), testutil.OK, }, { "like wildcard sandwich", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard", {"Literal":"foo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(testutil.Must(types.ParsePattern(`*foo*`)))), + ast.Permit().When(ast.String("text").Like(types.Pattern{}.Wildcard().Literal("foo").Wildcard())), testutil.OK, }, { diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 332b7d68..0c7fbbc5 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -557,8 +557,8 @@ func (p *parser) like(lhs ast.Node) (ast.Node, error) { patternRaw := t.Text patternRaw = strings.TrimPrefix(patternRaw, "\"") patternRaw = strings.TrimSuffix(patternRaw, "\"") - pattern, err := types.ParsePattern(patternRaw) - if err != nil { + var pattern types.Pattern + if err := pattern.UnmarshalCedar([]byte(patternRaw)); err != nil { return ast.Node{}, err } return lhs.Like(pattern), nil diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 8eb6ddab..15d62c3e 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -288,19 +288,19 @@ when { principal has "1stName" };`, "like no wildcards", `permit ( principal, action, resource ) when { principal.firstName like "johnny" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(types.ParsePattern("johnny")))), + ast.Permit().When(ast.Principal().Access("firstName").Like(types.Pattern{}.Literal("johnny"))), }, { "like escaped asterisk", `permit ( principal, action, resource ) when { principal.firstName like "joh\*nny" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(types.ParsePattern(`joh\*nny`)))), + ast.Permit().When(ast.Principal().Access("firstName").Like(types.Pattern{}.Literal("joh*nny"))), }, { "like wildcard", `permit ( principal, action, resource ) when { principal.firstName like "*" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(testutil.Must(types.ParsePattern("*")))), + ast.Permit().When(ast.Principal().Access("firstName").Like(types.Pattern{}.Wildcard())), }, { "is", diff --git a/types/pattern.go b/types/pattern.go index 3aac8335..4ad7ba8d 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -58,8 +58,6 @@ func (p Pattern) Literal(s string) Pattern { return p } -// TODO: move this into the types package - // ported from Go's stdlib and reduced to our scope. // https://golang.org/src/path/filepath/match.go?s=1226:1284#L34 @@ -123,9 +121,7 @@ func matchChunk(chunk, s string) (rest string, ok bool) { return s, true } -func ParsePattern(s string) (Pattern, error) { - b := []byte(s) - +func (p *Pattern) UnmarshalCedar(b []byte) error { var comps []PatternComponent for len(b) > 0 { var comp PatternComponent @@ -136,9 +132,11 @@ func ParsePattern(s string) (Pattern, error) { } comp.Literal, b, err = rust.Unquote(b, true) if err != nil { - return Pattern{}, err + return err } comps = append(comps, comp) } - return comps, nil + + *p = comps + return nil } diff --git a/types/patttern_test.go b/types/patttern_test.go index d76159e1..000f37a4 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -75,8 +75,8 @@ func TestParsePattern(t *testing.T) { tt := tt t.Run(tt.input, func(t *testing.T) { t.Parallel() - got, err := ParsePattern(tt.input) - if err != nil { + var got Pattern + if err := got.UnmarshalCedar([]byte(tt.input)); err != nil { testutil.Equals(t, tt.wantOk, false) testutil.Equals(t, err.Error(), tt.wantErr) } else { @@ -118,7 +118,8 @@ func TestMatch(t *testing.T) { tt := tt t.Run(tt.pattern+":"+tt.target, func(t *testing.T) { t.Parallel() - pat, err := ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) + var pat Pattern + err := pat.UnmarshalCedar([]byte(tt.pattern[1 : len(tt.pattern)-1])) testutil.OK(t, err) got := pat.Match(tt.target) testutil.Equals(t, got, tt.want) From 66eb49e3ac7eba61767e6cb0359137877d091d56 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 16 Aug 2024 14:44:13 -0700 Subject: [PATCH 159/216] cedar-go/types: move Pattern JSON parsing into pattern.go Signed-off-by: philhassey --- internal/json/json.go | 17 +-- internal/json/json_marshal.go | 22 +--- internal/json/json_test.go | 22 +--- internal/json/json_unmarshal.go | 26 +--- internal/parser/cedar_unmarshal_test.go | 2 +- types/json.go | 15 +-- types/pattern.go | 71 +++++++++++ types/patttern_test.go | 157 ++++++++++++++++++++++++ 8 files changed, 248 insertions(+), 84 deletions(-) diff --git a/internal/json/json.go b/internal/json/json.go index 35b5c812..2c17cf10 100644 --- a/internal/json/json.go +++ b/internal/json/json.go @@ -44,18 +44,9 @@ type strJSON struct { Attr string `json:"attr"` } -type patternComponentLiteralJSON struct { - Literal string `json:"Literal,omitempty"` -} - -type patternComponentJSON struct { - Wildcard bool - Literal patternComponentLiteralJSON -} - -type patternJSON struct { - Left nodeJSON `json:"left"` - Pattern []patternComponentJSON `json:"pattern"` +type likeJSON struct { + Left nodeJSON `json:"left"` + Pattern types.Pattern `json:"pattern"` } type isJSON struct { @@ -126,7 +117,7 @@ type nodeJSON struct { Is *isJSON `json:"is,omitempty"` // like - Like *patternJSON `json:"like,omitempty"` + Like *likeJSON `json:"like,omitempty"` // if-then-else IfThenElse *ifThenElseJSON `json:"if-then-else,omitempty"` diff --git a/internal/json/json_marshal.go b/internal/json/json_marshal.go index 47815d31..ff0b2877 100644 --- a/internal/json/json_marshal.go +++ b/internal/json/json_marshal.go @@ -94,17 +94,10 @@ func strToJSON(dest **strJSON, src ast.StrOpNode) { *dest = res } -func patternToJSON(dest **patternJSON, src ast.NodeTypeLike) { - res := &patternJSON{} +func likeToJSON(dest **likeJSON, src ast.NodeTypeLike) { + res := &likeJSON{} res.Left.FromNode(src.Arg) - for _, comp := range src.Value { - if comp.Wildcard { - res.Pattern = append(res.Pattern, patternComponentJSON{Wildcard: true}) - } - if comp.Literal != "" { - res.Pattern = append(res.Pattern, patternComponentJSON{Literal: patternComponentLiteralJSON{Literal: comp.Literal}}) - } - } + res.Pattern = src.Value *dest = res } @@ -246,7 +239,7 @@ func (j *nodeJSON) FromNode(src ast.IsNode) { // like // Like *strJSON `json:"like"` case ast.NodeTypeLike: - patternToJSON(&j.Like, t) + likeToJSON(&j.Like, t) return // if-then-else @@ -289,13 +282,6 @@ func (j *nodeJSON) MarshalJSON() ([]byte, error) { return json.Marshal((*nodeJSONAlias)(j)) } -func (p *patternComponentJSON) MarshalJSON() ([]byte, error) { - if p.Wildcard { - return json.Marshal("Wildcard") - } - return json.Marshal(p.Literal) -} - type Policy ast.Policy func wrapPolicy(p *ast.Policy) *Policy { diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 8a06fde1..f425e686 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -394,6 +394,7 @@ func TestUnmarshalJSON(t *testing.T) { ast.Permit().When(ast.Resource().IsIn("T", ast.EntityUID("P", "42"))), testutil.OK, }, + // N.B. Most pattern parsing tests can be found in types/pattern_test.go { "like single wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, @@ -422,27 +423,6 @@ func TestUnmarshalJSON(t *testing.T) { ast.Permit().When(ast.String("text").Like(types.Pattern{}.Literal("foo").Wildcard())), testutil.OK, }, - { - "like literal with asterisk then wildcard", - `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, - "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"f*oo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.Pattern{}.Literal("f*oo").Wildcard())), - testutil.OK, - }, - { - "like literal sandwich", - `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, - "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}, "Wildcard", {"Literal":"bar"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.Pattern{}.Literal("foo").Wildcard().Literal("bar"))), - testutil.OK, - }, - { - "like wildcard sandwich", - `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, - "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard", {"Literal":"foo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.Pattern{}.Wildcard().Literal("foo").Wildcard())), - testutil.OK, - }, { "ifThenElse", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index fc2676a6..39bba7c7 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -61,21 +61,13 @@ func (j strJSON) ToNode(f func(a ast.Node, k string) ast.Node) (ast.Node, error) } return f(left, j.Attr), nil } -func (j patternJSON) ToNode(f func(a ast.Node, k types.Pattern) ast.Node) (ast.Node, error) { +func (j likeJSON) ToNode(f func(a ast.Node, k types.Pattern) ast.Node) (ast.Node, error) { left, err := j.Left.ToNode() if err != nil { return ast.Node{}, fmt.Errorf("error in left: %w", err) } - pattern := types.Pattern{} - for _, compJSON := range j.Pattern { - if compJSON.Wildcard { - pattern = pattern.Wildcard() - } else { - pattern = pattern.Literal(compJSON.Literal.Literal) - } - } - return f(left, pattern), nil + return f(left, j.Pattern), nil } func (j isJSON) ToNode() (ast.Node, error) { left, err := j.Left.ToNode() @@ -266,20 +258,6 @@ func (n *nodeJSON) UnmarshalJSON(b []byte) error { return json.Unmarshal(b, &n.ExtensionCall) } -func (p *patternComponentJSON) UnmarshalJSON(b []byte) error { - var wildcard string - err := json.Unmarshal(b, &wildcard) - if err == nil { - if wildcard != "Wildcard" { - return fmt.Errorf("unknown pattern component: \"%v\"", wildcard) - } - p.Wildcard = true - return nil - } - - return json.Unmarshal(b, &p.Literal) -} - func (p *Policy) UnmarshalJSON(b []byte) error { var j policyJSON if err := json.Unmarshal(b, &j); err != nil { diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 15d62c3e..8f2c72ca 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -283,7 +283,7 @@ when { principal has firstName };`, when { principal has "1stName" };`, ast.Permit().When(ast.Principal().Has("1stName")), }, - // N.B. Most pattern parsing tests can be found in pattern_test.go + // N.B. Most pattern parsing tests can be found in types/pattern_test.go { "like no wildcards", `permit ( principal, action, resource ) diff --git a/types/json.go b/types/json.go index 8dd29cd7..7e19cb24 100644 --- a/types/json.go +++ b/types/json.go @@ -7,13 +7,14 @@ import ( ) var ( - errJSONInvalidExtn = fmt.Errorf("invalid extension") - errJSONDecode = fmt.Errorf("error decoding json") - errJSONLongOutOfRange = fmt.Errorf("long out of range") - errJSONUnsupportedType = fmt.Errorf("unsupported type") - errJSONExtFnMatch = fmt.Errorf("json extn mismatch") - errJSONExtNotFound = fmt.Errorf("json extn not found") - errJSONEntityNotFound = fmt.Errorf("json entity not found") + errJSONInvalidExtn = fmt.Errorf("invalid extension") + errJSONDecode = fmt.Errorf("error decoding json") + errJSONLongOutOfRange = fmt.Errorf("long out of range") + errJSONUnsupportedType = fmt.Errorf("unsupported type") + errJSONExtFnMatch = fmt.Errorf("json extn mismatch") + errJSONExtNotFound = fmt.Errorf("json extn not found") + errJSONEntityNotFound = fmt.Errorf("json entity not found") + errJSONInvalidPatternComponent = fmt.Errorf("invalid pattern component") ) type extn struct { diff --git a/types/pattern.go b/types/pattern.go index 4ad7ba8d..74e48bfe 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -2,6 +2,8 @@ package types import ( "bytes" + "encoding/json" + "fmt" "strconv" "strings" @@ -140,3 +142,72 @@ func (p *Pattern) UnmarshalCedar(b []byte) error { *p = comps return nil } + +func (p Pattern) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + buf.WriteRune('[') + for i, comp := range p { + if comp.Wildcard { + buf.WriteString(`"Wildcard"`) + } + + if comp.Literal != "" { + if comp.Wildcard { + buf.WriteString(", ") + } + buf.WriteString(`{"Literal":"`) + buf.WriteString(comp.Literal) + buf.WriteString(`"}`) + } + + if i < len(p)-1 { + buf.WriteString(", ") + } + } + buf.WriteRune(']') + return buf.Bytes(), nil +} + +func (p *Pattern) UnmarshalJSON(b []byte) error { + dec := json.NewDecoder(bytes.NewReader(b)) + var comps []any + if err := dec.Decode(&comps); err != nil { + return err + } + + if len(comps) == 0 { + return fmt.Errorf(`%w: must provide at least one pattern component`, errJSONInvalidPatternComponent) + } + + newPattern := Pattern{} + for _, comp := range comps { + switch v := comp.(type) { + case string: + if v != "Wildcard" { + return fmt.Errorf(`%w: invalid component string "%v"`, errJSONInvalidPatternComponent, v) + } + newPattern = newPattern.Wildcard() + case map[string]any: + if len(v) != 1 { + return fmt.Errorf(`%w: too many keys in literal object`, errJSONInvalidPatternComponent) + } + + literal, ok := v["Literal"] + if !ok { + return fmt.Errorf(`%w: missing "Literal" key in literal object`, errJSONInvalidPatternComponent) + } + + literalStr, ok := literal.(string) + if !ok { + return fmt.Errorf(`%w: invalid "Literal" value "%v"`, errJSONInvalidPatternComponent, literal) + } + + newPattern = newPattern.Literal(literalStr) + default: + return fmt.Errorf(`%w: unknown component type`, errJSONInvalidPatternComponent) + } + } + + *p = newPattern + return nil +} diff --git a/types/patttern_test.go b/types/patttern_test.go index 000f37a4..0875a80f 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -126,3 +126,160 @@ func TestMatch(t *testing.T) { }) } } + +func TestJSON(t *testing.T) { + t.Parallel() + tests := []struct { + name string + pattern string + errFunc func(testing.TB, error) + target Pattern + shouldRoundTrip bool + }{{ + "like single wildcard", + `["Wildcard"]`, + testutil.OK, + Pattern{}.Wildcard(), + true, + }, + { + "like single literal", + `[{"Literal":"foo"}]`, + testutil.OK, + Pattern{}.Literal("foo"), + true, + }, + { + "like wildcard then literal", + `["Wildcard", {"Literal":"foo"}]`, + testutil.OK, + Pattern{}.Wildcard().Literal("foo"), + true, + }, + { + "like literal then wildcard", + `[{"Literal":"foo"}, "Wildcard"]`, + testutil.OK, + Pattern{}.Literal("foo").Wildcard(), + true, + }, + { + "like literal with asterisk then wildcard", + `[{"Literal":"f*oo"}, "Wildcard"]`, + testutil.OK, + Pattern{}.Literal("f*oo").Wildcard(), + true, + }, + { + "like literal sandwich", + `[{"Literal":"foo"}, "Wildcard", {"Literal":"bar"}]`, + testutil.OK, + Pattern{}.Literal("foo").Wildcard().Literal("bar"), + true, + }, + { + "like wildcard sandwich", + `["Wildcard", {"Literal":"foo"}, "Wildcard"]`, + testutil.OK, + Pattern{}.Wildcard().Literal("foo").Wildcard(), + true, + }, + { + "double wildcard", + `["Wildcard", "Wildcard", {"Literal":"foo"}]`, + testutil.OK, + Pattern{}.Wildcard().Literal("foo"), + false, + }, + { + "double literal", + `["Wildcard", {"Literal":"foo"}, {"Literal":"bar"}]`, + testutil.OK, + Pattern{}.Wildcard().Literal("foobar"), + false, + }, + { + "literal with asterisk", + `["Wildcard", {"Literal":"foo*"}, "Wildcard"]`, + testutil.OK, + Pattern{}.Wildcard().Literal("foo*").Wildcard(), + true, + }, + { + "not list", + `"Wildcard"`, + testutil.Error, + Pattern{}, + false, + }, + { + "lower case wildcard", + `["wildcard"]`, + testutil.Error, + Pattern{}, + false, + }, + { + "other string", + `["cardwild"]`, + testutil.Error, + Pattern{}, + false, + }, + { + "lowercase literal", + `[{"literal": "foo"}]`, + testutil.Error, + Pattern{}, + false, + }, + { + "missing literal", + `[{"figurative": "haha"}]`, + testutil.Error, + Pattern{}, + false, + }, + { + "two keys", + `[{"Literal": "foo", "Figurative": "haha"}]`, + testutil.Error, + Pattern{}, + false, + }, + { + "nonstring literal", + `[{"Literal": 2}]`, + testutil.Error, + Pattern{}, + false, + }, + { + "empty pattern", + `[]`, + testutil.Error, + Pattern{}, + false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var pat Pattern + err := pat.UnmarshalJSON([]byte(tt.pattern)) + tt.errFunc(t, err) + if err != nil { + return + } + + marshaled, err := pat.MarshalJSON() + testutil.OK(t, err) + + if tt.shouldRoundTrip { + testutil.Equals(t, string(marshaled), tt.pattern) + } + }) + } +} From 7dcf6829729777276bccbb8f53aa937e5a0a9482 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 16 Aug 2024 15:42:06 -0700 Subject: [PATCH 160/216] cedar-go/types: make patternComponent a private element of Pattern and make Pattern immutable Signed-off-by: philhassey --- ast/ast_test.go | 4 +- internal/ast/ast_test.go | 4 +- internal/json/json_test.go | 8 +- internal/parser/cedar_unmarshal_test.go | 6 +- types/pattern.go | 97 ++++++++++++-------- types/patttern_test.go | 115 ++++++++++-------------- 6 files changed, 115 insertions(+), 119 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index e5f4992f..96355eef 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -295,8 +295,8 @@ func TestASTByTable(t *testing.T) { }, { "opLike", - ast.Permit().When(ast.Long(42).Like(types.Pattern{}.Wildcard())), - internalast.Permit().When(internalast.Long(42).Like(types.Pattern{}.Wildcard())), + ast.Permit().When(ast.Long(42).Like(types.WildcardPattern)), + internalast.Permit().When(internalast.Long(42).Like(types.WildcardPattern)), }, { "opAnd", diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index ee7f9056..775df819 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -350,9 +350,9 @@ func TestASTByTable(t *testing.T) { }, { "opLike", - ast.Permit().When(ast.Long(42).Like(types.Pattern{}.Wildcard())), + ast.Permit().When(ast.Long(42).Like(types.WildcardPattern)), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLike{Arg: ast.NodeValue{Value: types.Long(42)}, Value: types.Pattern{}.Wildcard()}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLike{Arg: ast.NodeValue{Value: types.Long(42)}, Value: types.WildcardPattern}}}}, }, { "opAnd", diff --git a/internal/json/json_test.go b/internal/json/json_test.go index f425e686..a537b0a7 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -399,28 +399,28 @@ func TestUnmarshalJSON(t *testing.T) { "like single wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.Pattern{}.Wildcard())), + ast.Permit().When(ast.String("text").Like(types.WildcardPattern)), testutil.OK, }, { "like single literal", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.Pattern{}.Literal("foo"))), + ast.Permit().When(ast.String("text").Like(types.LiteralPattern("foo"))), testutil.OK, }, { "like wildcard then literal", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard", {"Literal":"foo"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.Pattern{}.Wildcard().Literal("foo"))), + ast.Permit().When(ast.String("text").Like(types.LiteralPattern("foo"))), testutil.OK, }, { "like literal then wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.Pattern{}.Literal("foo").Wildcard())), + ast.Permit().When(ast.String("text").Like(types.LiteralPattern("foo"))), testutil.OK, }, { diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 8f2c72ca..0630ccb3 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -288,19 +288,19 @@ when { principal has "1stName" };`, "like no wildcards", `permit ( principal, action, resource ) when { principal.firstName like "johnny" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(types.Pattern{}.Literal("johnny"))), + ast.Permit().When(ast.Principal().Access("firstName").Like(types.LiteralPattern("johnny"))), }, { "like escaped asterisk", `permit ( principal, action, resource ) when { principal.firstName like "joh\*nny" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(types.Pattern{}.Literal("joh*nny"))), + ast.Permit().When(ast.Principal().Access("firstName").Like(types.LiteralPattern("joh*nny"))), }, { "like wildcard", `permit ( principal, action, resource ) when { principal.firstName like "*" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(types.Pattern{}.Wildcard())), + ast.Permit().When(ast.Principal().Access("firstName").Like(types.WildcardPattern)), }, { "is", diff --git a/types/pattern.go b/types/pattern.go index 74e48bfe..65794596 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -10,19 +10,34 @@ import ( "github.com/cedar-policy/cedar-go/internal/rust" ) -type PatternComponent struct { +type patternComponent struct { Wildcard bool Literal string } // Pattern is used to define a string used for the like operator. It does not // conform to the Value interface, as it is not one of the Cedar types. -type Pattern []PatternComponent +type Pattern struct { + comps []patternComponent +} + +var WildcardPattern = newPattern([]patternComponent{{Wildcard: true}}) + +func newPattern(comps []patternComponent) Pattern { + return Pattern{comps: comps} +} + +func LiteralPattern(literal string) Pattern { + if literal == "" { + return newPattern(nil) + } + return newPattern([]patternComponent{{Wildcard: false, Literal: literal}}) +} func (p Pattern) Cedar() string { var buf bytes.Buffer buf.WriteRune('"') - for _, comp := range p { + for _, comp := range p.comps { if comp.Wildcard { buf.WriteRune('*') } @@ -36,30 +51,6 @@ func (p Pattern) Cedar() string { return buf.String() } -func (p Pattern) Wildcard() Pattern { - star := PatternComponent{Wildcard: true} - if len(p) == 0 { - p = Pattern{star} - return p - } - - lastComp := p[len(p)-1] - if lastComp.Wildcard && lastComp.Literal == "" { - return p - } - - p = append(p, star) - return p -} - -func (p Pattern) Literal(s string) Pattern { - if len(p) == 0 { - p = Pattern{{}} - } - p[len(p)-1].Literal += s - return p -} - // ported from Go's stdlib and reduced to our scope. // https://golang.org/src/path/filepath/match.go?s=1226:1284#L34 @@ -73,8 +64,8 @@ func (p Pattern) Literal(s string) Pattern { // c matches character c (c != '*') func (p Pattern) Match(arg string) (matched bool) { Pattern: - for i, comp := range p { - lastChunk := i == len(p)-1 + for i, comp := range p.comps { + lastChunk := i == len(p.comps)-1 if comp.Wildcard && comp.Literal == "" { return true } @@ -124,9 +115,9 @@ func matchChunk(chunk, s string) (rest string, ok bool) { } func (p *Pattern) UnmarshalCedar(b []byte) error { - var comps []PatternComponent + var comps []patternComponent for len(b) > 0 { - var comp PatternComponent + var comp patternComponent var err error for len(b) > 0 && b[0] == '*' { b = b[1:] @@ -139,14 +130,14 @@ func (p *Pattern) UnmarshalCedar(b []byte) error { comps = append(comps, comp) } - *p = comps + *p = Pattern{comps: comps} return nil } func (p Pattern) MarshalJSON() ([]byte, error) { var buf bytes.Buffer buf.WriteRune('[') - for i, comp := range p { + for i, comp := range p.comps { if comp.Wildcard { buf.WriteString(`"Wildcard"`) } @@ -160,7 +151,7 @@ func (p Pattern) MarshalJSON() ([]byte, error) { buf.WriteString(`"}`) } - if i < len(p)-1 { + if i < len(p.comps)-1 { buf.WriteString(", ") } } @@ -179,14 +170,14 @@ func (p *Pattern) UnmarshalJSON(b []byte) error { return fmt.Errorf(`%w: must provide at least one pattern component`, errJSONInvalidPatternComponent) } - newPattern := Pattern{} + pb := PatternBuilder{} for _, comp := range comps { switch v := comp.(type) { case string: if v != "Wildcard" { return fmt.Errorf(`%w: invalid component string "%v"`, errJSONInvalidPatternComponent, v) } - newPattern = newPattern.Wildcard() + pb = pb.AddWildcard() case map[string]any: if len(v) != 1 { return fmt.Errorf(`%w: too many keys in literal object`, errJSONInvalidPatternComponent) @@ -202,12 +193,42 @@ func (p *Pattern) UnmarshalJSON(b []byte) error { return fmt.Errorf(`%w: invalid "Literal" value "%v"`, errJSONInvalidPatternComponent, literal) } - newPattern = newPattern.Literal(literalStr) + pb = pb.AddLiteral(literalStr) default: return fmt.Errorf(`%w: unknown component type`, errJSONInvalidPatternComponent) } } - *p = newPattern + *p = pb.Build() return nil } + +// PatternBuilder can be used to programmatically build a Cedar pattern string, like so: +// PatternBuilder{}.AddWildcard().AddLiteral("foo").AddWildcard().Build() +type PatternBuilder []patternComponent + +func (p PatternBuilder) AddWildcard() PatternBuilder { + star := patternComponent{Wildcard: true} + if len(p) == 0 { + return PatternBuilder{star} + } + + lastComp := (p)[len(p)-1] + if lastComp.Wildcard && lastComp.Literal == "" { + return p + } + + return append(p, star) +} + +func (p PatternBuilder) AddLiteral(s string) PatternBuilder { + if len(p) == 0 { + p = PatternBuilder{patternComponent{}} + } + p[len(p)-1].Literal += s + return p +} + +func (p PatternBuilder) Build() Pattern { + return newPattern(p) +} diff --git a/types/patttern_test.go b/types/patttern_test.go index 0875a80f..a0a9376d 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -7,27 +7,15 @@ import ( ) func TestPatternFromBuilder(t *testing.T) { - tests := []struct { - name string - Pattern Pattern - want []PatternComponent - }{ - {"empty", Pattern{}, Pattern{}}, - {"wildcard", (Pattern{}).Wildcard(), Pattern{{Wildcard: true}}}, - {"saturate two wildcards", (Pattern{}).Wildcard().Wildcard(), Pattern{{Wildcard: true}}}, - {"literal", (Pattern{}).Literal("foo"), Pattern{{Literal: "foo"}}}, - {"saturate two literals", (Pattern{}).Literal("foo").Literal("bar"), Pattern{{Literal: "foobar"}}}, - {"literal with asterisk", (Pattern{}).Literal("fo*o"), Pattern{{Literal: "fo*o"}}}, - {"wildcard sandwich", (Pattern{}).Literal("foo").Wildcard().Literal("bar"), Pattern{{Literal: "foo"}, {Wildcard: true, Literal: "bar"}}}, - {"literal sandwich", (Pattern{}).Wildcard().Literal("foo").Wildcard(), Pattern{{Wildcard: true, Literal: "foo"}, {Wildcard: true}}}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - testutil.Equals(t, tt.Pattern, tt.want) - }) - } + t.Run("saturate two wildcards", func(t *testing.T) { + pattern1 := PatternBuilder{}.AddWildcard().AddWildcard().Build() + testutil.Equals(t, pattern1, WildcardPattern) + }) + t.Run("saturate two literals", func(t *testing.T) { + pattern1 := PatternBuilder{}.AddLiteral("foo").AddLiteral("bar").Build() + pattern2 := LiteralPattern("foobar") + testutil.Equals(t, pattern1, pattern2) + }) } func TestParsePattern(t *testing.T) { @@ -35,41 +23,27 @@ func TestParsePattern(t *testing.T) { tests := []struct { input string wantOk bool - want []PatternComponent + want Pattern wantErr string }{ - {"", true, nil, ""}, - {"a", true, []PatternComponent{{false, "a"}}, ""}, - {"*", true, []PatternComponent{{true, ""}}, ""}, - {"*a", true, []PatternComponent{{true, "a"}}, ""}, - {"a*", true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, - {"**", true, []PatternComponent{{true, ""}}, ""}, - {"**a", true, []PatternComponent{{true, "a"}}, ""}, - {"a**", true, []PatternComponent{{false, "a"}, {true, ""}}, ""}, - {"*a*", true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, - {"**a**", true, []PatternComponent{{true, "a"}, {true, ""}}, ""}, - {"abra*ca", true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, - }, ""}, - {"abra**ca", true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, - }, ""}, - {"*abra*ca", true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, - }, ""}, - {"abra*ca*", true, []PatternComponent{ - {false, "abra"}, {true, "ca"}, {true, ""}, - }, ""}, - {"*abra*ca*", true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, {true, ""}, - }, ""}, - {"*abra*ca*dabra", true, []PatternComponent{ - {true, "abra"}, {true, "ca"}, {true, "dabra"}, - }, ""}, - {`*abra*c\**da\*ra`, true, []PatternComponent{ - {true, "abra"}, {true, "c*"}, {true, "da*ra"}, - }, ""}, - {`\u`, false, nil, "bad unicode rune"}, + {"", true, LiteralPattern(""), ""}, + {"a", true, LiteralPattern("a"), ""}, + {"*", true, WildcardPattern, ""}, + {"*a", true, PatternBuilder{}.AddWildcard().AddLiteral("a").Build(), ""}, + {"a*", true, PatternBuilder{}.AddLiteral("a").AddWildcard().Build(), ""}, + {"**", true, WildcardPattern, ""}, + {"**a", true, PatternBuilder{}.AddWildcard().AddLiteral("a").Build(), ""}, + {"a**", true, PatternBuilder{}.AddLiteral("a").AddWildcard().Build(), ""}, + {"*a*", true, PatternBuilder{}.AddWildcard().AddLiteral("a").AddWildcard().Build(), ""}, + {"**a**", true, PatternBuilder{}.AddWildcard().AddLiteral("a").AddWildcard().Build(), ""}, + {"abra*ca", true, PatternBuilder{}.AddLiteral("abra").AddWildcard().AddLiteral("ca").Build(), ""}, + {"abra**ca", true, PatternBuilder{}.AddLiteral("abra").AddWildcard().AddLiteral("ca").Build(), ""}, + {"*abra*ca", true, PatternBuilder{}.AddWildcard().AddLiteral("abra").AddWildcard().AddLiteral("ca").Build(), ""}, + {"abra*ca*", true, PatternBuilder{}.AddLiteral("abra").AddWildcard().AddLiteral("ca").AddWildcard().Build(), ""}, + {"*abra*ca*", true, PatternBuilder{}.AddWildcard().AddLiteral("abra").AddWildcard().AddLiteral("ca").AddWildcard().Build(), ""}, + {"*abra*ca*dabra", true, PatternBuilder{}.AddWildcard().AddLiteral("abra").AddWildcard().AddLiteral("ca").AddWildcard().AddLiteral("dabra").Build(), ""}, + {`*abra*c\**da\*bra`, true, PatternBuilder{}.AddWildcard().AddLiteral("abra").AddWildcard().AddLiteral("c*").AddWildcard().AddLiteral("da*bra").Build(), ""}, + {`\u`, false, Pattern{}, "bad unicode rune"}, } for _, tt := range tests { tt := tt @@ -135,74 +109,75 @@ func TestJSON(t *testing.T) { errFunc func(testing.TB, error) target Pattern shouldRoundTrip bool - }{{ - "like single wildcard", - `["Wildcard"]`, - testutil.OK, - Pattern{}.Wildcard(), - true, - }, + }{ + { + "like single wildcard", + `["Wildcard"]`, + testutil.OK, + WildcardPattern, + true, + }, { "like single literal", `[{"Literal":"foo"}]`, testutil.OK, - Pattern{}.Literal("foo"), + LiteralPattern("foo"), true, }, { "like wildcard then literal", `["Wildcard", {"Literal":"foo"}]`, testutil.OK, - Pattern{}.Wildcard().Literal("foo"), + PatternBuilder{}.AddWildcard().AddLiteral("foo").Build(), true, }, { "like literal then wildcard", `[{"Literal":"foo"}, "Wildcard"]`, testutil.OK, - Pattern{}.Literal("foo").Wildcard(), + PatternBuilder{}.AddLiteral("foo").AddWildcard().Build(), true, }, { "like literal with asterisk then wildcard", `[{"Literal":"f*oo"}, "Wildcard"]`, testutil.OK, - Pattern{}.Literal("f*oo").Wildcard(), + PatternBuilder{}.AddLiteral("f*oo").AddWildcard().Build(), true, }, { "like literal sandwich", `[{"Literal":"foo"}, "Wildcard", {"Literal":"bar"}]`, testutil.OK, - Pattern{}.Literal("foo").Wildcard().Literal("bar"), + PatternBuilder{}.AddLiteral("foo").AddWildcard().AddLiteral("bar").Build(), true, }, { "like wildcard sandwich", `["Wildcard", {"Literal":"foo"}, "Wildcard"]`, testutil.OK, - Pattern{}.Wildcard().Literal("foo").Wildcard(), + PatternBuilder{}.AddWildcard().AddLiteral("foo").AddWildcard().Build(), true, }, { "double wildcard", `["Wildcard", "Wildcard", {"Literal":"foo"}]`, testutil.OK, - Pattern{}.Wildcard().Literal("foo"), + PatternBuilder{}.AddWildcard().AddLiteral("foo").Build(), false, }, { "double literal", `["Wildcard", {"Literal":"foo"}, {"Literal":"bar"}]`, testutil.OK, - Pattern{}.Wildcard().Literal("foobar"), + PatternBuilder{}.AddWildcard().AddLiteral("foobar").Build(), false, }, { "literal with asterisk", `["Wildcard", {"Literal":"foo*"}, "Wildcard"]`, testutil.OK, - Pattern{}.Wildcard().Literal("foo*").Wildcard(), + PatternBuilder{}.AddWildcard().AddLiteral("foo*").AddWildcard().Build(), true, }, { From 209239c295225c09d533e089f33fe7b415c12d2f Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 16 Aug 2024 16:35:26 -0700 Subject: [PATCH 161/216] cedar-go/types: give a new coat of paint to Pattern This removes the PatternBuilder interface in favor of a NewPattern constructor that takes components as a variadic argument. This greatly improves the terseness of creating a Pattern programmatically. Signed-off-by: philhassey --- ast/ast_test.go | 4 +- internal/ast/ast_test.go | 4 +- internal/json/json_test.go | 8 +- internal/parser/cedar_unmarshal_test.go | 6 +- types/pattern.go | 100 +++++++++++------------- types/patttern_test.go | 69 ++++++++-------- 6 files changed, 90 insertions(+), 101 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index 96355eef..45e7a66b 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -295,8 +295,8 @@ func TestASTByTable(t *testing.T) { }, { "opLike", - ast.Permit().When(ast.Long(42).Like(types.WildcardPattern)), - internalast.Permit().When(internalast.Long(42).Like(types.WildcardPattern)), + ast.Permit().When(ast.Long(42).Like(types.NewPattern(types.Wildcard))), + internalast.Permit().When(internalast.Long(42).Like(types.NewPattern(types.Wildcard))), }, { "opAnd", diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index 775df819..dced3c46 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -350,9 +350,9 @@ func TestASTByTable(t *testing.T) { }, { "opLike", - ast.Permit().When(ast.Long(42).Like(types.WildcardPattern)), + ast.Permit().When(ast.Long(42).Like(types.NewPattern(types.Wildcard))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLike{Arg: ast.NodeValue{Value: types.Long(42)}, Value: types.WildcardPattern}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLike{Arg: ast.NodeValue{Value: types.Long(42)}, Value: types.NewPattern(types.Wildcard)}}}}, }, { "opAnd", diff --git a/internal/json/json_test.go b/internal/json/json_test.go index a537b0a7..f60ad439 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -399,28 +399,28 @@ func TestUnmarshalJSON(t *testing.T) { "like single wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.WildcardPattern)), + ast.Permit().When(ast.String("text").Like(types.NewPattern(types.Wildcard))), testutil.OK, }, { "like single literal", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.LiteralPattern("foo"))), + ast.Permit().When(ast.String("text").Like(types.NewPattern(types.String("foo")))), testutil.OK, }, { "like wildcard then literal", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard", {"Literal":"foo"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.LiteralPattern("foo"))), + ast.Permit().When(ast.String("text").Like(types.NewPattern(types.Wildcard, types.String("foo")))), testutil.OK, }, { "like literal then wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.LiteralPattern("foo"))), + ast.Permit().When(ast.String("text").Like(types.NewPattern(types.String("foo"), types.Wildcard))), testutil.OK, }, { diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 0630ccb3..f4aa346c 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -288,19 +288,19 @@ when { principal has "1stName" };`, "like no wildcards", `permit ( principal, action, resource ) when { principal.firstName like "johnny" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(types.LiteralPattern("johnny"))), + ast.Permit().When(ast.Principal().Access("firstName").Like(types.NewPattern(types.String("johnny")))), }, { "like escaped asterisk", `permit ( principal, action, resource ) when { principal.firstName like "joh\*nny" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(types.LiteralPattern("joh*nny"))), + ast.Permit().When(ast.Principal().Access("firstName").Like(types.NewPattern(types.String("joh*nny")))), }, { "like wildcard", `permit ( principal, action, resource ) when { principal.firstName like "*" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(types.WildcardPattern)), + ast.Permit().When(ast.Principal().Access("firstName").Like(types.NewPattern(types.Wildcard))), }, { "is", diff --git a/types/pattern.go b/types/pattern.go index 65794596..3b7db9bb 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -21,17 +21,40 @@ type Pattern struct { comps []patternComponent } -var WildcardPattern = newPattern([]patternComponent{{Wildcard: true}}) - -func newPattern(comps []patternComponent) Pattern { - return Pattern{comps: comps} +// A PatternComponent is either a wildcard (represented as "*" in Cedar text) or a literal string. Note that * +// characters in literal strings are treated as literal asterisks rather than wildcards. +type PatternComponent interface { + isPatternComponent() } -func LiteralPattern(literal string) Pattern { - if literal == "" { - return newPattern(nil) +type WildcardPatternComponent struct{} + +func (WildcardPatternComponent) isPatternComponent() {} + +// Wildcard is a constant which can be used to conveniently construct an instance of WildcardPatternComponent +var Wildcard = WildcardPatternComponent{} + +func (String) isPatternComponent() {} + +// NewPattern permits for the programmatic construction of a Pattern out of a set of PatternComponents. +func NewPattern(components ...PatternComponent) Pattern { + var comps []patternComponent + for _, c := range components { + switch v := c.(type) { + case WildcardPatternComponent: + if len(comps) == 0 || comps[len(comps)-1].Literal != "" { + comps = append(comps, patternComponent{Wildcard: true, Literal: ""}) + } + case String: + if len(comps) == 0 { + comps = []patternComponent{{Wildcard: false, Literal: ""}} + } + comps[len(comps)-1].Literal += string(v) + default: + panic(fmt.Sprintf("unexpected component type: %T", v)) + } } - return newPattern([]patternComponent{{Wildcard: false, Literal: literal}}) + return Pattern{comps: comps} } func (p Pattern) Cedar() string { @@ -115,22 +138,23 @@ func matchChunk(chunk, s string) (rest string, ok bool) { } func (p *Pattern) UnmarshalCedar(b []byte) error { - var comps []patternComponent + var comps []PatternComponent for len(b) > 0 { - var comp patternComponent - var err error for len(b) > 0 && b[0] == '*' { b = b[1:] - comp.Wildcard = true + comps = append(comps, Wildcard) } - comp.Literal, b, err = rust.Unquote(b, true) + + var err error + var literal string + literal, b, err = rust.Unquote(b, true) if err != nil { return err } - comps = append(comps, comp) + comps = append(comps, String(literal)) } - *p = Pattern{comps: comps} + *p = NewPattern(comps...) return nil } @@ -161,23 +185,23 @@ func (p Pattern) MarshalJSON() ([]byte, error) { func (p *Pattern) UnmarshalJSON(b []byte) error { dec := json.NewDecoder(bytes.NewReader(b)) - var comps []any - if err := dec.Decode(&comps); err != nil { + var objs []any + if err := dec.Decode(&objs); err != nil { return err } - if len(comps) == 0 { + if len(objs) == 0 { return fmt.Errorf(`%w: must provide at least one pattern component`, errJSONInvalidPatternComponent) } - pb := PatternBuilder{} - for _, comp := range comps { + var comps []PatternComponent + for _, comp := range objs { switch v := comp.(type) { case string: if v != "Wildcard" { return fmt.Errorf(`%w: invalid component string "%v"`, errJSONInvalidPatternComponent, v) } - pb = pb.AddWildcard() + comps = append(comps, Wildcard) case map[string]any: if len(v) != 1 { return fmt.Errorf(`%w: too many keys in literal object`, errJSONInvalidPatternComponent) @@ -193,42 +217,12 @@ func (p *Pattern) UnmarshalJSON(b []byte) error { return fmt.Errorf(`%w: invalid "Literal" value "%v"`, errJSONInvalidPatternComponent, literal) } - pb = pb.AddLiteral(literalStr) + comps = append(comps, String(literalStr)) default: return fmt.Errorf(`%w: unknown component type`, errJSONInvalidPatternComponent) } } - *p = pb.Build() + *p = NewPattern(comps...) return nil } - -// PatternBuilder can be used to programmatically build a Cedar pattern string, like so: -// PatternBuilder{}.AddWildcard().AddLiteral("foo").AddWildcard().Build() -type PatternBuilder []patternComponent - -func (p PatternBuilder) AddWildcard() PatternBuilder { - star := patternComponent{Wildcard: true} - if len(p) == 0 { - return PatternBuilder{star} - } - - lastComp := (p)[len(p)-1] - if lastComp.Wildcard && lastComp.Literal == "" { - return p - } - - return append(p, star) -} - -func (p PatternBuilder) AddLiteral(s string) PatternBuilder { - if len(p) == 0 { - p = PatternBuilder{patternComponent{}} - } - p[len(p)-1].Literal += s - return p -} - -func (p PatternBuilder) Build() Pattern { - return newPattern(p) -} diff --git a/types/patttern_test.go b/types/patttern_test.go index a0a9376d..6ff9b132 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -8,41 +8,43 @@ import ( func TestPatternFromBuilder(t *testing.T) { t.Run("saturate two wildcards", func(t *testing.T) { - pattern1 := PatternBuilder{}.AddWildcard().AddWildcard().Build() - testutil.Equals(t, pattern1, WildcardPattern) + pattern1 := NewPattern(Wildcard, Wildcard) + pattern2 := NewPattern(Wildcard) + testutil.Equals(t, pattern1, pattern2) }) t.Run("saturate two literals", func(t *testing.T) { - pattern1 := PatternBuilder{}.AddLiteral("foo").AddLiteral("bar").Build() - pattern2 := LiteralPattern("foobar") + pattern1 := NewPattern(String("foo"), String("bar")) + pattern2 := NewPattern(String("foobar")) testutil.Equals(t, pattern1, pattern2) }) } func TestParsePattern(t *testing.T) { t.Parallel() + a := String("a") tests := []struct { input string wantOk bool want Pattern wantErr string }{ - {"", true, LiteralPattern(""), ""}, - {"a", true, LiteralPattern("a"), ""}, - {"*", true, WildcardPattern, ""}, - {"*a", true, PatternBuilder{}.AddWildcard().AddLiteral("a").Build(), ""}, - {"a*", true, PatternBuilder{}.AddLiteral("a").AddWildcard().Build(), ""}, - {"**", true, WildcardPattern, ""}, - {"**a", true, PatternBuilder{}.AddWildcard().AddLiteral("a").Build(), ""}, - {"a**", true, PatternBuilder{}.AddLiteral("a").AddWildcard().Build(), ""}, - {"*a*", true, PatternBuilder{}.AddWildcard().AddLiteral("a").AddWildcard().Build(), ""}, - {"**a**", true, PatternBuilder{}.AddWildcard().AddLiteral("a").AddWildcard().Build(), ""}, - {"abra*ca", true, PatternBuilder{}.AddLiteral("abra").AddWildcard().AddLiteral("ca").Build(), ""}, - {"abra**ca", true, PatternBuilder{}.AddLiteral("abra").AddWildcard().AddLiteral("ca").Build(), ""}, - {"*abra*ca", true, PatternBuilder{}.AddWildcard().AddLiteral("abra").AddWildcard().AddLiteral("ca").Build(), ""}, - {"abra*ca*", true, PatternBuilder{}.AddLiteral("abra").AddWildcard().AddLiteral("ca").AddWildcard().Build(), ""}, - {"*abra*ca*", true, PatternBuilder{}.AddWildcard().AddLiteral("abra").AddWildcard().AddLiteral("ca").AddWildcard().Build(), ""}, - {"*abra*ca*dabra", true, PatternBuilder{}.AddWildcard().AddLiteral("abra").AddWildcard().AddLiteral("ca").AddWildcard().AddLiteral("dabra").Build(), ""}, - {`*abra*c\**da\*bra`, true, PatternBuilder{}.AddWildcard().AddLiteral("abra").AddWildcard().AddLiteral("c*").AddWildcard().AddLiteral("da*bra").Build(), ""}, + {"", true, NewPattern(), ""}, + {"a", true, NewPattern(a), ""}, + {"*", true, NewPattern(Wildcard), ""}, + {"*a", true, NewPattern(Wildcard, a), ""}, + {"a*", true, NewPattern(a, Wildcard), ""}, + {"**", true, NewPattern(Wildcard), ""}, + {"**a", true, NewPattern(Wildcard, a), ""}, + {"a**", true, NewPattern(a, Wildcard), ""}, + {"*a*", true, NewPattern(Wildcard, a, Wildcard), ""}, + {"**a**", true, NewPattern(Wildcard, a, Wildcard), ""}, + {"abra*ca", true, NewPattern(String("abra"), Wildcard, String("ca")), ""}, + {"abra**ca", true, NewPattern(String("abra"), Wildcard, String("ca")), ""}, + {"*abra*ca", true, NewPattern(Wildcard, String("abra"), Wildcard, String("ca")), ""}, + {"abra*ca*", true, NewPattern(String("abra"), Wildcard, String("ca"), Wildcard), ""}, + {"*abra*ca*", true, NewPattern(Wildcard, String("abra"), Wildcard, String("ca"), Wildcard), ""}, + {"*abra*ca*dabra", true, NewPattern(Wildcard, String("abra"), Wildcard, String("ca"), Wildcard, String("dabra")), ""}, + {`*abra*c\**da\*bra`, true, NewPattern(Wildcard, String("abra"), Wildcard, String("c*"), Wildcard, String("da*bra")), ""}, {`\u`, false, Pattern{}, "bad unicode rune"}, } for _, tt := range tests { @@ -114,72 +116,65 @@ func TestJSON(t *testing.T) { "like single wildcard", `["Wildcard"]`, testutil.OK, - WildcardPattern, + NewPattern(Wildcard), true, }, { "like single literal", `[{"Literal":"foo"}]`, testutil.OK, - LiteralPattern("foo"), + NewPattern(String("foo")), true, }, { "like wildcard then literal", `["Wildcard", {"Literal":"foo"}]`, testutil.OK, - PatternBuilder{}.AddWildcard().AddLiteral("foo").Build(), + NewPattern(Wildcard, String("foo")), true, }, { "like literal then wildcard", `[{"Literal":"foo"}, "Wildcard"]`, testutil.OK, - PatternBuilder{}.AddLiteral("foo").AddWildcard().Build(), + NewPattern(String("foo"), Wildcard), true, }, { "like literal with asterisk then wildcard", `[{"Literal":"f*oo"}, "Wildcard"]`, testutil.OK, - PatternBuilder{}.AddLiteral("f*oo").AddWildcard().Build(), + NewPattern(String("f*oo"), Wildcard), true, }, { "like literal sandwich", `[{"Literal":"foo"}, "Wildcard", {"Literal":"bar"}]`, testutil.OK, - PatternBuilder{}.AddLiteral("foo").AddWildcard().AddLiteral("bar").Build(), + NewPattern(String("foo"), Wildcard, String("bar")), true, }, { "like wildcard sandwich", `["Wildcard", {"Literal":"foo"}, "Wildcard"]`, testutil.OK, - PatternBuilder{}.AddWildcard().AddLiteral("foo").AddWildcard().Build(), + NewPattern(Wildcard, String("foo"), Wildcard), true, }, { "double wildcard", `["Wildcard", "Wildcard", {"Literal":"foo"}]`, testutil.OK, - PatternBuilder{}.AddWildcard().AddLiteral("foo").Build(), + NewPattern(Wildcard, String("foo")), false, }, { "double literal", `["Wildcard", {"Literal":"foo"}, {"Literal":"bar"}]`, testutil.OK, - PatternBuilder{}.AddWildcard().AddLiteral("foobar").Build(), + NewPattern(Wildcard, String("foobar")), false, }, - { - "literal with asterisk", - `["Wildcard", {"Literal":"foo*"}, "Wildcard"]`, - testutil.OK, - PatternBuilder{}.AddWildcard().AddLiteral("foo*").AddWildcard().Build(), - true, - }, { "not list", `"Wildcard"`, From c4c3da57fb24aa2b9b4081e1f2c6a687b4af0687 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 16 Aug 2024 16:43:35 -0700 Subject: [PATCH 162/216] cedar-go/types: rename Path to EntityType for clarity The "path" concept only appears in the Cedar grammar. Everywhere else in the documentation (including the JSON grammar), this is referred to as the "entity type". Signed-off-by: philhassey --- ast/ast_test.go | 8 ++++---- ast/operator.go | 4 ++-- ast/scope.go | 8 ++++---- internal/ast/ast_test.go | 16 ++++++++-------- internal/ast/node.go | 2 +- internal/ast/operator.go | 4 ++-- internal/ast/scope.go | 16 ++++++++-------- internal/eval/evalers.go | 8 ++++---- internal/eval/evalers_test.go | 8 ++++---- internal/eval/util.go | 6 +++--- internal/json/json_test.go | 8 ++++---- internal/json/json_unmarshal.go | 8 ++++---- internal/parser/cedar_unmarshal.go | 8 ++++---- types/value.go | 22 +++++++++++----------- types/value_test.go | 26 +++++++++++++------------- 15 files changed, 76 insertions(+), 76 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index 45e7a66b..9a3cc821 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -345,13 +345,13 @@ func TestASTByTable(t *testing.T) { }, { "opIs", - ast.Permit().When(ast.Long(42).Is(types.Path("T"))), - internalast.Permit().When(internalast.Long(42).Is(types.Path("T"))), + ast.Permit().When(ast.Long(42).Is(types.EntityType("T"))), + internalast.Permit().When(internalast.Long(42).Is(types.EntityType("T"))), }, { "opIsIn", - ast.Permit().When(ast.Long(42).IsIn(types.Path("T"), ast.Long(43))), - internalast.Permit().When(internalast.Long(42).IsIn(types.Path("T"), internalast.Long(43))), + ast.Permit().When(ast.Long(42).IsIn(types.EntityType("T"), ast.Long(43))), + internalast.Permit().When(internalast.Long(42).IsIn(types.EntityType("T"), internalast.Long(43))), }, { "opContains", diff --git a/ast/operator.go b/ast/operator.go index 9c271ce3..a13bbf83 100644 --- a/ast/operator.go +++ b/ast/operator.go @@ -112,11 +112,11 @@ func (lhs Node) In(rhs Node) Node { return wrapNode(lhs.Node.In(rhs.Node)) } -func (lhs Node) Is(entityType types.Path) Node { +func (lhs Node) Is(entityType types.EntityType) Node { return wrapNode(lhs.Node.Is(entityType)) } -func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { +func (lhs Node) IsIn(entityType types.EntityType, rhs Node) Node { return wrapNode(lhs.Node.IsIn(entityType, rhs.Node)) } diff --git a/ast/scope.go b/ast/scope.go index d282101b..6db3cd25 100644 --- a/ast/scope.go +++ b/ast/scope.go @@ -12,11 +12,11 @@ func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().PrincipalIn(entity)) } -func (p *Policy) PrincipalIs(entityType types.Path) *Policy { +func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { return wrapPolicy(p.unwrap().PrincipalIs(entityType)) } -func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { +func (p *Policy) PrincipalIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().PrincipalIsIn(entityType, entity)) } @@ -40,10 +40,10 @@ func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ResourceIn(entity)) } -func (p *Policy) ResourceIs(entityType types.Path) *Policy { +func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { return wrapPolicy(p.unwrap().ResourceIs(entityType)) } -func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { +func (p *Policy) ResourceIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ResourceIsIn(entityType, entity)) } diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index dced3c46..ede28f73 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -135,12 +135,12 @@ func TestASTByTable(t *testing.T) { { "scopePrincipalIs", ast.Permit().PrincipalIs("T"), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIs{Type: types.Path("T")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIs{Type: types.EntityType("T")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, }, { "scopePrincipalIsIn", ast.Permit().PrincipalIsIn("T", types.NewEntityUID("T", "42")), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIsIn{Type: types.Path("T"), Entity: types.NewEntityUID("T", "42")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIsIn{Type: types.EntityType("T"), Entity: types.NewEntityUID("T", "42")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, }, { "scopeActionEq", @@ -170,12 +170,12 @@ func TestASTByTable(t *testing.T) { { "scopeResourceIs", ast.Permit().ResourceIs("T"), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIs{Type: types.Path("T")}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIs{Type: types.EntityType("T")}}, }, { "scopeResourceIsIn", ast.Permit().ResourceIsIn("T", types.NewEntityUID("T", "42")), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIsIn{Type: types.Path("T"), Entity: types.NewEntityUID("T", "42")}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIsIn{Type: types.EntityType("T"), Entity: types.NewEntityUID("T", "42")}}, }, { "variablePrincipal", @@ -410,15 +410,15 @@ func TestASTByTable(t *testing.T) { }, { "opIs", - ast.Permit().When(ast.Long(42).Is(types.Path("T"))), + ast.Permit().When(ast.Long(42).Is(types.EntityType("T"))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.Path("T")}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.EntityType("T")}}}}, }, { "opIsIn", - ast.Permit().When(ast.Long(42).IsIn(types.Path("T"), ast.Long(43))), + ast.Permit().When(ast.Long(42).IsIn(types.EntityType("T"), ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIsIn{NodeTypeIs: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.Path("T")}, Entity: ast.NodeValue{Value: types.Long(43)}}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIsIn{NodeTypeIs: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.EntityType("T")}, Entity: ast.NodeValue{Value: types.Long(43)}}}}}, }, { "opContains", diff --git a/internal/ast/node.go b/internal/ast/node.go index 7845c4ff..623b1400 100644 --- a/internal/ast/node.go +++ b/internal/ast/node.go @@ -76,7 +76,7 @@ func (n NodeTypeLike) isNode() {} type NodeTypeIs struct { Left IsNode - EntityType types.Path + EntityType types.EntityType } func (n NodeTypeIs) isNode() {} diff --git a/internal/ast/operator.go b/internal/ast/operator.go index 0bc18c05..62cdf3ab 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -109,11 +109,11 @@ func (lhs Node) In(rhs Node) Node { return NewNode(NodeTypeIn{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } -func (lhs Node) Is(entityType types.Path) Node { +func (lhs Node) Is(entityType types.EntityType) Node { return NewNode(NodeTypeIs{Left: lhs.v, EntityType: entityType}) } -func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { +func (lhs Node) IsIn(entityType types.EntityType, rhs Node) Node { return NewNode(NodeTypeIsIn{NodeTypeIs: NodeTypeIs{Left: lhs.v, EntityType: entityType}, Entity: rhs.v}) } diff --git a/internal/ast/scope.go b/internal/ast/scope.go index 521912a0..9b9cddbe 100644 --- a/internal/ast/scope.go +++ b/internal/ast/scope.go @@ -22,11 +22,11 @@ func (s Scope) InSet(entities []types.EntityUID) IsScopeNode { return ScopeTypeInSet{Entities: entities} } -func (s Scope) Is(entityType types.Path) IsScopeNode { +func (s Scope) Is(entityType types.EntityType) IsScopeNode { return ScopeTypeIs{Type: entityType} } -func (s Scope) IsIn(entityType types.Path, entity types.EntityUID) IsScopeNode { +func (s Scope) IsIn(entityType types.EntityType, entity types.EntityUID) IsScopeNode { return ScopeTypeIsIn{Type: entityType, Entity: entity} } @@ -40,12 +40,12 @@ func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { return p } -func (p *Policy) PrincipalIs(entityType types.Path) *Policy { +func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { p.Principal = Scope(NewPrincipalNode()).Is(entityType) return p } -func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { +func (p *Policy) PrincipalIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { p.Principal = Scope(NewPrincipalNode()).IsIn(entityType, entity) return p } @@ -75,12 +75,12 @@ func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { return p } -func (p *Policy) ResourceIs(entityType types.Path) *Policy { +func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { p.Resource = Scope(NewResourceNode()).Is(entityType) return p } -func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { +func (p *Policy) ResourceIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { p.Resource = Scope(NewResourceNode()).IsIn(entityType, entity) return p } @@ -115,11 +115,11 @@ type ScopeTypeInSet struct { type ScopeTypeIs struct { ScopeNode - Type types.Path + Type types.EntityType } type ScopeTypeIsIn struct { ScopeNode - Type types.Path + Type types.EntityType Entity types.EntityUID } diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index ab79cc36..b600453e 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -85,12 +85,12 @@ func evalEntity(n Evaler, ctx *Context) (types.EntityUID, error) { return e, nil } -func evalPath(n Evaler, ctx *Context) (types.Path, error) { +func evalEntityType(n Evaler, ctx *Context) (types.EntityType, error) { v, err := n.Eval(ctx) if err != nil { return "", err } - e, err := ValueToPath(v) + e, err := ValueToEntityType(v) if err != nil { return "", err } @@ -985,12 +985,12 @@ func (n *isEval) Eval(ctx *Context) (types.Value, error) { return types.ZeroValue(), err } - rhs, err := evalPath(n.rhs, ctx) + rhs, err := evalEntityType(n.rhs, ctx) if err != nil { return types.ZeroValue(), err } - return types.Boolean(types.Path(lhs.Type) == rhs), nil + return types.Boolean(types.EntityType(lhs.Type) == rhs), nil } // decimalLiteralEval diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index c0a7b1a6..da2fb857 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -1591,11 +1591,11 @@ func TestIsNode(t *testing.T) { result types.Value err error }{ - {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("X")), types.True, nil}, - {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Path("Y")), types.False, nil}, - {"badLhs", newLiteralEval(types.Long(42)), newLiteralEval(types.Path("X")), types.ZeroValue(), ErrType}, + {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.EntityType("X")), types.True, nil}, + {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.EntityType("Y")), types.False, nil}, + {"badLhs", newLiteralEval(types.Long(42)), newLiteralEval(types.EntityType("X")), types.ZeroValue(), ErrType}, {"badRhs", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Long(42)), types.ZeroValue(), ErrType}, - {"errLhs", newErrorEval(errTest), newLiteralEval(types.Path("X")), types.ZeroValue(), errTest}, + {"errLhs", newErrorEval(errTest), newLiteralEval(types.EntityType("X")), types.ZeroValue(), errTest}, {"errRhs", newLiteralEval(types.NewEntityUID("X", "z")), newErrorEval(errTest), types.ZeroValue(), errTest}, } for _, tt := range tests { diff --git a/internal/eval/util.go b/internal/eval/util.go index cf140118..ce4ba1f8 100644 --- a/internal/eval/util.go +++ b/internal/eval/util.go @@ -56,10 +56,10 @@ func ValueToEntity(v types.Value) (types.EntityUID, error) { return ev, nil } -func ValueToPath(v types.Value) (types.Path, error) { - ev, ok := v.(types.Path) +func ValueToEntityType(v types.Value) (types.EntityType, error) { + ev, ok := v.(types.EntityType) if !ok { - return "", fmt.Errorf("%w: expected (Path of type `any_entity_type`), got %v", ErrType, v.TypeName()) + return "", fmt.Errorf("%w: expected (EntityType of type `any_entity_type`), got %v", ErrType, v.TypeName()) } return ev, nil } diff --git a/internal/json/json_test.go b/internal/json/json_test.go index f60ad439..e4b0819c 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -119,13 +119,13 @@ func TestUnmarshalJSON(t *testing.T) { { "principalIs", `{"effect":"permit","principal":{"op":"is","entity_type":"T"},"action":{"op":"All"},"resource":{"op":"All"}}`, - ast.Permit().PrincipalIs(types.Path("T")), + ast.Permit().PrincipalIs(types.EntityType("T")), testutil.OK, }, { "principalIsIn", `{"effect":"permit","principal":{"op":"is","entity_type":"T","in":{"entity":{"type":"P","id":"42"}}},"action":{"op":"All"},"resource":{"op":"All"}}`, - ast.Permit().PrincipalIsIn(types.Path("T"), types.NewEntityUID("P", "42")), + ast.Permit().PrincipalIsIn(types.EntityType("T"), types.NewEntityUID("P", "42")), testutil.OK, }, { @@ -161,13 +161,13 @@ func TestUnmarshalJSON(t *testing.T) { { "resourceIs", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is","entity_type":"T"}}`, - ast.Permit().ResourceIs(types.Path("T")), + ast.Permit().ResourceIs(types.EntityType("T")), testutil.OK, }, { "resourceIsIn", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is","entity_type":"T","in":{"entity":{"type":"P","id":"42"}}}}`, - ast.Permit().ResourceIsIn(types.Path("T"), types.NewEntityUID("P", "42")), + ast.Permit().ResourceIsIn(types.EntityType("T"), types.NewEntityUID("P", "42")), testutil.OK, }, { diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 39bba7c7..44854ded 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -29,9 +29,9 @@ func (s *scopeJSON) ToNode(variable ast.Scope) (ast.IsScopeNode, error) { return variable.InSet(s.Entities), nil case "is": if s.In == nil { - return variable.Is(types.Path(s.EntityType)), nil + return variable.Is(types.EntityType(s.EntityType)), nil } - return variable.IsIn(types.Path(s.EntityType), s.In.Entity), nil + return variable.IsIn(types.EntityType(s.EntityType), s.In.Entity), nil } return nil, fmt.Errorf("unknown op: %v", s.Op) } @@ -79,9 +79,9 @@ func (j isJSON) ToNode() (ast.Node, error) { if err != nil { return ast.Node{}, fmt.Errorf("error in entity: %w", err) } - return left.IsIn(types.Path(j.EntityType), right), nil + return left.IsIn(types.EntityType(j.EntityType), right), nil } - return left.Is(types.Path(j.EntityType)), nil + return left.Is(types.EntityType(j.EntityType)), nil } func (j ifThenElseJSON) ToNode() (ast.Node, error) { if_, err := j.If.ToNode() diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 0c7fbbc5..71423ef5 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -257,8 +257,8 @@ func (p *parser) entityFirstPathPreread(firstPath string) (types.EntityUID, erro } } -func (p *parser) pathFirstPathPreread(firstPath string) (types.Path, error) { - res := types.Path(firstPath) +func (p *parser) pathFirstPathPreread(firstPath string) (types.EntityType, error) { + res := types.EntityType(firstPath) for { if p.peek().Text != "::" { return res, nil @@ -267,14 +267,14 @@ func (p *parser) pathFirstPathPreread(firstPath string) (types.Path, error) { t := p.advance() switch { case t.isIdent(): - res = types.Path(fmt.Sprintf("%v::%v", res, t.Text)) + res = types.EntityType(fmt.Sprintf("%v::%v", res, t.Text)) default: return res, p.errorf("unexpected token") } } } -func (p *parser) path() (types.Path, error) { +func (p *parser) path() (types.EntityType, error) { t := p.advance() if !t.isIdent() { return "", p.errorf("expected ident") diff --git a/types/value.go b/types/value.go index e73285b1..e05340e3 100644 --- a/types/value.go +++ b/types/value.go @@ -386,22 +386,22 @@ func EntityValueFromSlice(v []string) EntityUID { } } -// Path is the type portion of an EntityUID -type Path string +// EntityType is the type portion of an EntityUID +type EntityType string -func (a Path) Equal(bi Value) bool { - b, ok := bi.(Path) +func (a EntityType) Equal(bi Value) bool { + b, ok := bi.(EntityType) return ok && a == b } -func (v Path) TypeName() string { return fmt.Sprintf("(Path of type `%s`)", v) } +func (v EntityType) TypeName() string { return fmt.Sprintf("(EntityType of type `%s`)", v) } -func (v Path) String() string { return string(v) } -func (v Path) Cedar() string { return string(v) } -func (v Path) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(string(v)) } -func (v Path) deepClone() Value { return v } +func (v EntityType) String() string { return string(v) } +func (v EntityType) Cedar() string { return string(v) } +func (v EntityType) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(string(v)) } +func (v EntityType) deepClone() Value { return v } -func PathFromSlice(v []string) Path { - return Path(strings.Join(v, "::")) +func EntityTypeFromSlice(v []string) EntityType { + return EntityType(strings.Join(v, "::")) } // A Decimal is a value with both a whole number part and a decimal part of no diff --git a/types/value_test.go b/types/value_test.go index d80bdf3c..3d026d9f 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -719,13 +719,13 @@ func TestDeepClone(t *testing.T) { }) } -func TestPath(t *testing.T) { +func TestEntityType(t *testing.T) { t.Parallel() t.Run("Equal", func(t *testing.T) { t.Parallel() - a := Path("X") - b := Path("X") - c := Path("Y") + a := EntityType("X") + b := EntityType("X") + c := EntityType("Y") testutil.Equals(t, a.Equal(b), true) testutil.Equals(t, b.Equal(a), true) testutil.Equals(t, a.Equal(c), false) @@ -733,39 +733,39 @@ func TestPath(t *testing.T) { }) t.Run("TypeName", func(t *testing.T) { t.Parallel() - a := Path("X") - testutil.Equals(t, a.TypeName(), "(Path of type `X`)") + a := EntityType("X") + testutil.Equals(t, a.TypeName(), "(EntityType of type `X`)") }) t.Run("String", func(t *testing.T) { t.Parallel() - a := Path("X") + a := EntityType("X") testutil.Equals(t, a.String(), "X") }) t.Run("Cedar", func(t *testing.T) { t.Parallel() - a := Path("X") + a := EntityType("X") testutil.Equals(t, a.Cedar(), "X") }) t.Run("ExplicitMarshalJSON", func(t *testing.T) { t.Parallel() - a := Path("X") + a := EntityType("X") v, err := a.ExplicitMarshalJSON() testutil.OK(t, err) testutil.Equals(t, string(v), `"X"`) }) t.Run("deepClone", func(t *testing.T) { t.Parallel() - a := Path("X") + a := EntityType("X") b := a.deepClone() - c, ok := b.(Path) + c, ok := b.(EntityType) testutil.Equals(t, ok, true) testutil.Equals(t, c, a) }) t.Run("pathFromSlice", func(t *testing.T) { t.Parallel() - a := PathFromSlice([]string{"X", "Y"}) - testutil.Equals(t, a, Path("X::Y")) + a := EntityTypeFromSlice([]string{"X", "Y"}) + testutil.Equals(t, a, EntityType("X::Y")) }) } From 432ab3e0c0863edfeb06d2eb37f1f0d1cb5355dd Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 16 Aug 2024 16:54:45 -0700 Subject: [PATCH 163/216] cedar-go/types: Change type of EntityUID.Type to EntityType Signed-off-by: philhassey --- internal/eval/evalers_test.go | 2 +- internal/parser/cedar_unmarshal.go | 6 +++--- types/json.go | 10 +++++----- types/value.go | 13 +++---------- 4 files changed, 12 insertions(+), 19 deletions(-) diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index da2fb857..0db96f6a 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -16,7 +16,7 @@ var errTest = fmt.Errorf("test error") // not a real parser func strEnt(v string) types.EntityUID { p := strings.Split(v, "::\"") - return types.EntityUID{Type: p[0], ID: p[1][:len(p[1])-1]} + return types.EntityUID{Type: types.EntityType(p[0]), ID: p[1][:len(p[1])-1]} } func TestOrNode(t *testing.T) { diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 71423ef5..e8c3c936 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -230,10 +230,10 @@ func (p *parser) entity() (types.EntityUID, error) { if !t.isIdent() { return res, p.errorf("expected ident") } - return p.entityFirstPathPreread(t.Text) + return p.entityFirstPathPreread(types.EntityType(t.Text)) } -func (p *parser) entityFirstPathPreread(firstPath string) (types.EntityUID, error) { +func (p *parser) entityFirstPathPreread(firstPath types.EntityType) (types.EntityUID, error) { var res types.EntityUID var err error res.Type = firstPath @@ -244,7 +244,7 @@ func (p *parser) entityFirstPathPreread(firstPath string) (types.EntityUID, erro t := p.advance() switch { case t.isIdent(): - res.Type = fmt.Sprintf("%v::%v", res.Type, t.Text) + res.Type = types.EntityType(res.Type.String() + "::" + t.Text) case t.isString(): res.ID, err = t.stringValue() if err != nil { diff --git a/types/json.go b/types/json.go index 7e19cb24..9d31245e 100644 --- a/types/json.go +++ b/types/json.go @@ -27,14 +27,14 @@ type extValueJSON struct { } type extEntity struct { - Type string `json:"type"` - ID string `json:"id"` + Type EntityType `json:"type"` + ID string `json:"id"` } type entityValueJSON struct { - Type *string `json:"type,omitempty"` - ID *string `json:"id,omitempty"` - Entity *extEntity `json:"__entity,omitempty"` + Type *EntityType `json:"type,omitempty"` + ID *string `json:"id,omitempty"` + Entity *extEntity `json:"__entity,omitempty"` } type explicitValue struct { diff --git a/types/value.go b/types/value.go index e05340e3..24c784fd 100644 --- a/types/value.go +++ b/types/value.go @@ -312,13 +312,13 @@ func (v Record) DeepClone() Record { // An EntityUID is the identifier for a principal, action, or resource. type EntityUID struct { - Type string + Type EntityType ID string } func NewEntityUID(typ, id string) EntityUID { return EntityUID{ - Type: typ, + Type: EntityType(typ), ID: id, } } @@ -339,7 +339,7 @@ func (v EntityUID) String() string { return v.Cedar() } // Cedar produces a valid Cedar language representation of the EntityUID, e.g. `Type::"id"`. func (v EntityUID) Cedar() string { - return v.Type + "::" + strconv.Quote(v.ID) + return v.Type.String() + "::" + strconv.Quote(v.ID) } func (v *EntityUID) UnmarshalJSON(b []byte) error { @@ -379,13 +379,6 @@ func (v EntityUID) ExplicitMarshalJSON() ([]byte, error) { } func (v EntityUID) deepClone() Value { return v } -func EntityValueFromSlice(v []string) EntityUID { - return EntityUID{ - Type: strings.Join(v[:len(v)-1], "::"), - ID: v[len(v)-1], - } -} - // EntityType is the type portion of an EntityUID type EntityType string From 5f53c3c7c291988ecfae58b3514c01bb8ea64ec1 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 16 Aug 2024 16:57:58 -0700 Subject: [PATCH 164/216] cedar-go/types: change NewEntityUID to accept an EntityType as the type Signed-off-by: philhassey --- internal/ast/value.go | 2 +- internal/eval/evalers_test.go | 8 ++++---- types/value.go | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/ast/value.go b/internal/ast/value.go index 1ed6e66b..8f4ca163 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -68,7 +68,7 @@ func Record(elements Pairs) Node { } func EntityUID(typ, id string) Node { - e := types.NewEntityUID(typ, id) + e := types.NewEntityUID(types.EntityType(typ), id) return Value(e) } diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 0db96f6a..c4126d99 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -1563,15 +1563,15 @@ func TestEntityIn(t *testing.T) { entityMap := entities.Entities{} for i := 0; i < 100; i++ { p := []types.EntityUID{ - types.NewEntityUID(fmt.Sprint(i+1), "1"), - types.NewEntityUID(fmt.Sprint(i+1), "2"), + types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "1"), + types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "2"), } - uid1 := types.NewEntityUID(fmt.Sprint(i), "1") + uid1 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "1") entityMap[uid1] = entities.Entity{ UID: uid1, Parents: p, } - uid2 := types.NewEntityUID(fmt.Sprint(i), "2") + uid2 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "2") entityMap[uid2] = entities.Entity{ UID: uid2, Parents: p, diff --git a/types/value.go b/types/value.go index 24c784fd..288e0681 100644 --- a/types/value.go +++ b/types/value.go @@ -316,9 +316,9 @@ type EntityUID struct { ID string } -func NewEntityUID(typ, id string) EntityUID { +func NewEntityUID(typ EntityType, id string) EntityUID { return EntityUID{ - Type: EntityType(typ), + Type: typ, ID: id, } } From 8766f02c487047947aef90954a375bced7a6ed33 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Fri, 16 Aug 2024 17:20:03 -0700 Subject: [PATCH 165/216] cedar-go/types: split up each type into its own file for legibility's sake Signed-off-by: philhassey --- internal/eval/evalers_test.go | 98 +++-- types/boolean.go | 32 ++ types/boolean_test.go | 36 ++ types/decimal.go | 185 +++++++++ types/decimal_test.go | 128 +++++++ types/entity_type.go | 25 ++ types/entity_type_test.go | 50 +++ types/entity_uid.go | 76 ++++ types/entity_uid_test.go | 34 ++ types/ipaddr.go | 158 ++++++++ types/ipaddr_test.go | 278 ++++++++++++++ types/json_test.go | 9 + types/long.go | 27 ++ types/long_test.go | 36 ++ types/record.go | 112 ++++++ types/record_test.go | 79 ++++ types/set.go | 111 ++++++ types/set_test.go | 60 +++ types/string.go | 29 ++ types/string_test.go | 34 ++ types/testutil.go | 38 -- types/testutil_test.go | 15 + types/value.go | 697 ---------------------------------- types/value_test.go | 680 +-------------------------------- 24 files changed, 1583 insertions(+), 1444 deletions(-) create mode 100644 types/boolean.go create mode 100644 types/boolean_test.go create mode 100644 types/decimal.go create mode 100644 types/decimal_test.go create mode 100644 types/entity_type.go create mode 100644 types/entity_type_test.go create mode 100644 types/entity_uid.go create mode 100644 types/entity_uid_test.go create mode 100644 types/ipaddr.go create mode 100644 types/ipaddr_test.go create mode 100644 types/long.go create mode 100644 types/long_test.go create mode 100644 types/record.go create mode 100644 types/record_test.go create mode 100644 types/set.go create mode 100644 types/set_test.go create mode 100644 types/string.go create mode 100644 types/string_test.go delete mode 100644 types/testutil.go create mode 100644 types/testutil_test.go diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index c4126d99..2e6b15d7 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -19,6 +19,30 @@ func strEnt(v string) types.EntityUID { return types.EntityUID{Type: types.EntityType(p[0]), ID: p[1][:len(p[1])-1]} } +func AssertValue(t *testing.T, got, want types.Value) { + t.Helper() + testutil.FatalIf( + t, + !((got == types.ZeroValue() && want == types.ZeroValue()) || + (got != types.ZeroValue() && want != types.ZeroValue() && got.Equal(want))), + "got %v want %v", got, want) +} + +func AssertBoolValue(t *testing.T, got types.Value, want bool) { + t.Helper() + testutil.Equals[types.Value](t, got, types.Boolean(want)) +} + +func AssertLongValue(t *testing.T, got types.Value, want int64) { + t.Helper() + testutil.Equals[types.Value](t, got, types.Long(want)) +} + +func AssertZeroValue(t *testing.T, got types.Value) { + t.Helper() + testutil.Equals(t, got, types.ZeroValue()) +} + func TestOrNode(t *testing.T) { t.Parallel() { @@ -37,7 +61,7 @@ func TestOrNode(t *testing.T) { n := newOrNode(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -48,7 +72,7 @@ func TestOrNode(t *testing.T) { newLiteralEval(types.True), newLiteralEval(types.Long(1))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, true) + AssertBoolValue(t, v, true) }) { @@ -92,7 +116,7 @@ func TestAndNode(t *testing.T) { n := newAndEval(newLiteralEval(types.Boolean(tt.lhs)), newLiteralEval(types.Boolean(tt.rhs))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -103,7 +127,7 @@ func TestAndNode(t *testing.T) { newLiteralEval(types.False), newLiteralEval(types.Long(1))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, false) + AssertBoolValue(t, v, false) }) { @@ -145,7 +169,7 @@ func TestNotNode(t *testing.T) { n := newNotEval(newLiteralEval(types.Boolean(tt.arg))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -346,7 +370,7 @@ func TestAddNode(t *testing.T) { n := newAddEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertLongValue(t, v, 3) + AssertLongValue(t, v, 3) }) tests := []struct { @@ -385,7 +409,7 @@ func TestSubtractNode(t *testing.T) { n := newSubtractEval(newLiteralEval(types.Long(1)), newLiteralEval(types.Long(2))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertLongValue(t, v, -1) + AssertLongValue(t, v, -1) }) tests := []struct { @@ -424,7 +448,7 @@ func TestMultiplyNode(t *testing.T) { n := newMultiplyEval(newLiteralEval(types.Long(-3)), newLiteralEval(types.Long(2))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertLongValue(t, v, -6) + AssertLongValue(t, v, -6) }) tests := []struct { @@ -463,7 +487,7 @@ func TestNegateNode(t *testing.T) { n := newNegateEval(newLiteralEval(types.Long(-3))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertLongValue(t, v, 3) + AssertLongValue(t, v, 3) }) tests := []struct { @@ -511,7 +535,7 @@ func TestLongLessThanNode(t *testing.T) { newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -563,7 +587,7 @@ func TestLongLessThanOrEqualNode(t *testing.T) { newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -615,7 +639,7 @@ func TestLongGreaterThanNode(t *testing.T) { newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -667,7 +691,7 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs))) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -724,7 +748,7 @@ func TestDecimalLessThanNode(t *testing.T) { n := newDecimalLessThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -781,7 +805,7 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { n := newDecimalLessThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -838,7 +862,7 @@ func TestDecimalGreaterThanNode(t *testing.T) { n := newDecimalGreaterThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -895,7 +919,7 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { n := newDecimalGreaterThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv)) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -974,7 +998,7 @@ func TestEqualNode(t *testing.T) { n := newEqualEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1000,7 +1024,7 @@ func TestNotEqualNode(t *testing.T) { n := newNotEqualEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1041,7 +1065,7 @@ func TestSetLiteralNode(t *testing.T) { n := newSetLiteralEval(tt.elems) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1065,7 +1089,7 @@ func TestContainsNode(t *testing.T) { n := newContainsEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertZeroValue(t, v) + AssertZeroValue(t, v) }) } } @@ -1094,7 +1118,7 @@ func TestContainsNode(t *testing.T) { n := newContainsEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -1120,7 +1144,7 @@ func TestContainsAllNode(t *testing.T) { n := newContainsAllEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertZeroValue(t, v) + AssertZeroValue(t, v) }) } } @@ -1148,7 +1172,7 @@ func TestContainsAllNode(t *testing.T) { n := newContainsAllEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -1174,7 +1198,7 @@ func TestContainsAnyNode(t *testing.T) { n := newContainsAnyEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertZeroValue(t, v) + AssertZeroValue(t, v) }) } } @@ -1205,7 +1229,7 @@ func TestContainsAnyNode(t *testing.T) { n := newContainsAnyEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) testutil.OK(t, err) - types.AssertBoolValue(t, v, tt.result) + AssertBoolValue(t, v, tt.result) }) } } @@ -1237,7 +1261,7 @@ func TestRecordLiteralNode(t *testing.T) { n := newRecordLiteralEval(tt.elems) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1294,7 +1318,7 @@ func TestAttributeAccessNode(t *testing.T) { }, }) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1351,7 +1375,7 @@ func TestHasNode(t *testing.T) { }, }) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1406,7 +1430,7 @@ func TestLikeNode(t *testing.T) { n := newLikeEval(tt.str, pat) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1443,7 +1467,7 @@ func TestVariableNode(t *testing.T) { n := newVariableEval(tt.variable) v, err := n.Eval(&tt.context) testutil.OK(t, err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1604,7 +1628,7 @@ func TestIsNode(t *testing.T) { t.Parallel() got, err := newIsEval(tt.lhs, tt.rhs).Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, got, tt.result) + AssertValue(t, got, tt.result) }) } } @@ -1721,7 +1745,7 @@ func TestInNode(t *testing.T) { ec := Context{Entities: entityMap} v, err := n.Eval(&ec) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1746,7 +1770,7 @@ func TestDecimalLiteralNode(t *testing.T) { n := newDecimalLiteralEval(tt.arg) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1773,7 +1797,7 @@ func TestIPLiteralNode(t *testing.T) { n := newIPLiteralEval(tt.arg) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1811,7 +1835,7 @@ func TestIPTestNode(t *testing.T) { n := newIPTestEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } @@ -1849,7 +1873,7 @@ func TestIPIsInRangeNode(t *testing.T) { n := newIPIsInRangeEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) testutil.AssertError(t, err, tt.err) - types.AssertValue(t, v, tt.result) + AssertValue(t, v, tt.result) }) } } diff --git a/types/boolean.go b/types/boolean.go new file mode 100644 index 00000000..c89c0a97 --- /dev/null +++ b/types/boolean.go @@ -0,0 +1,32 @@ +package types + +import ( + "encoding/json" + "fmt" +) + +// A Boolean is a value that is either true or false. +type Boolean bool + +const ( + True = Boolean(true) + False = Boolean(false) +) + +func (a Boolean) Equal(bi Value) bool { + b, ok := bi.(Boolean) + return ok && a == b +} +func (v Boolean) TypeName() string { return "bool" } + +// String produces a string representation of the Boolean, e.g. `true`. +func (v Boolean) String() string { return v.Cedar() } + +// Cedar produces a valid Cedar language representation of the Boolean, e.g. `true`. +func (v Boolean) Cedar() string { + return fmt.Sprint(bool(v)) +} + +// ExplicitMarshalJSON marshals the Boolean into JSON. +func (v Boolean) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } +func (v Boolean) deepClone() Value { return v } diff --git a/types/boolean_test.go b/types/boolean_test.go new file mode 100644 index 00000000..fdfd2833 --- /dev/null +++ b/types/boolean_test.go @@ -0,0 +1,36 @@ +package types_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestBool(t *testing.T) { + t.Parallel() + + t.Run("Equal", func(t *testing.T) { + t.Parallel() + t1 := types.Boolean(true) + t2 := types.Boolean(true) + f := types.Boolean(false) + zero := types.Long(0) + testutil.FatalIf(t, !t1.Equal(t1), "%v not Equal to %v", t1, t1) + testutil.FatalIf(t, !t1.Equal(t2), "%v not Equal to %v", t1, t2) + testutil.FatalIf(t, t1.Equal(f), "%v Equal to %v", t1, f) + testutil.FatalIf(t, f.Equal(t1), "%v Equal to %v", f, t1) + testutil.FatalIf(t, f.Equal(zero), "%v Equal to %v", f, zero) + }) + + t.Run("string", func(t *testing.T) { + t.Parallel() + AssertValueString(t, types.Boolean(true), "true") + }) + + t.Run("TypeName", func(t *testing.T) { + t.Parallel() + tn := types.Boolean(true).TypeName() + testutil.Equals(t, tn, "bool") + }) +} diff --git a/types/decimal.go b/types/decimal.go new file mode 100644 index 00000000..21796602 --- /dev/null +++ b/types/decimal.go @@ -0,0 +1,185 @@ +package types + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + "unicode" +) + +// A Decimal is a value with both a whole number part and a decimal part of no +// more than four digits. In Go this is stored as an int64, the precision is +// defined by the constant DecimalPrecision. +type Decimal int64 + +// DecimalPrecision is the precision of a Decimal. +const DecimalPrecision = 10000 + +// ParseDecimal takes a string representation of a decimal number and converts it into a Decimal type. +func ParseDecimal(s string) (Decimal, error) { + // Check for empty string. + if len(s) == 0 { + return Decimal(0), fmt.Errorf("%w: string too short", ErrDecimal) + } + i := 0 + + // Parse an optional '-'. + negative := false + if s[i] == '-' { + negative = true + i++ + if i == len(s) { + return Decimal(0), fmt.Errorf("%w: string too short", ErrDecimal) + } + } + + // Parse the required first digit. + c := rune(s[i]) + if !unicode.IsDigit(c) { + return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) + } + integer := int64(c - '0') + i++ + + // Parse any other digits, ending with i pointing to '.'. + for ; ; i++ { + if i == len(s) { + return Decimal(0), fmt.Errorf("%w: string missing decimal point", ErrDecimal) + } + c = rune(s[i]) + if c == '.' { + break + } + if !unicode.IsDigit(c) { + return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) + } + integer = 10*integer + int64(c-'0') + if integer > 922337203685477 { + return Decimal(0), fmt.Errorf("%w: overflow", ErrDecimal) + } + } + + // Advance past the '.'. + i++ + + // Parse the fraction part + fraction := int64(0) + fractionDigits := 0 + for ; i < len(s); i++ { + c = rune(s[i]) + if !unicode.IsDigit(c) { + return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) + } + fraction = 10*fraction + int64(c-'0') + fractionDigits++ + } + + // Adjust the fraction part based on how many digits we parsed. + switch fractionDigits { + case 0: + return Decimal(0), fmt.Errorf("%w: missing digits after decimal point", ErrDecimal) + case 1: + fraction *= 1000 + case 2: + fraction *= 100 + case 3: + fraction *= 10 + case 4: + default: + return Decimal(0), fmt.Errorf("%w: too many digits after decimal point", ErrDecimal) + } + + // Check for overflow before we put the number together. + if integer >= 922337203685477 && (fraction > 5808 || (!negative && fraction == 5808)) { + return Decimal(0), fmt.Errorf("%w: overflow", ErrDecimal) + } + + // Put the number together. + if negative { + // Doing things in this order keeps us from overflowing when parsing + // -922337203685477.5808. This isn't technically necessary because the + // go spec defines arithmetic to be well-defined when overflowing. + // However, doing things this way doesn't hurt, so let's be pedantic. + return Decimal(DecimalPrecision*-integer - fraction), nil + } else { + return Decimal(DecimalPrecision*integer + fraction), nil + } +} + +func (a Decimal) Equal(bi Value) bool { + b, ok := bi.(Decimal) + return ok && a == b +} + +func (v Decimal) TypeName() string { return "decimal" } + +// Cedar produces a valid Cedar language representation of the Decimal, e.g. `decimal("12.34")`. +func (v Decimal) Cedar() string { return `decimal("` + v.String() + `")` } + +// String produces a string representation of the Decimal, e.g. `12.34`. +func (v Decimal) String() string { + var res string + if v < 0 { + // Make sure we don't overflow here. Also, go truncates towards zero. + integer := v / DecimalPrecision + decimal := integer*DecimalPrecision - v + res = fmt.Sprintf("-%d.%04d", -integer, decimal) + } else { + res = fmt.Sprintf("%d.%04d", v/DecimalPrecision, v%DecimalPrecision) + } + + // Trim off up to three trailing zeros. + right := len(res) + for trimmed := 0; right-1 >= 0 && trimmed < 3; right, trimmed = right-1, trimmed+1 { + if res[right-1] != '0' { + break + } + } + return res[:right] +} + +func (v *Decimal) UnmarshalJSON(b []byte) error { + var arg string + if len(b) > 0 && b[0] == '"' { + if err := json.Unmarshal(b, &arg); err != nil { + return errors.Join(errJSONDecode, err) + } + } else { + // NOTE: cedar supports two other forms, for now we're only supporting the smallest implicit and explicit form. + // The following are not supported: + // "decimal(\"1234.5678\")" + // {"fn":"decimal","arg":"1234.5678"} + var res extValueJSON + if err := json.Unmarshal(b, &res); err != nil { + return errors.Join(errJSONDecode, err) + } + if res.Extn == nil { + return errJSONExtNotFound + } + if res.Extn.Fn != "decimal" { + return errJSONExtFnMatch + } + arg = res.Extn.Arg + } + vv, err := ParseDecimal(arg) + if err != nil { + return err + } + *v = vv + return nil +} + +// ExplicitMarshalJSON marshals the Decimal into JSON using the implicit form. +func (v Decimal) MarshalJSON() ([]byte, error) { return []byte(`"` + v.String() + `"`), nil } + +// ExplicitMarshalJSON marshals the Decimal into JSON using the explicit form. +func (v Decimal) ExplicitMarshalJSON() ([]byte, error) { + return json.Marshal(extValueJSON{ + Extn: &extn{ + Fn: "decimal", + Arg: v.String(), + }, + }) +} +func (v Decimal) deepClone() Value { return v } diff --git a/types/decimal_test.go b/types/decimal_test.go new file mode 100644 index 00000000..839aba8f --- /dev/null +++ b/types/decimal_test.go @@ -0,0 +1,128 @@ +package types_test + +import ( + "fmt" + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestDecimal(t *testing.T) { + t.Parallel() + { + tests := []struct{ in, out string }{ + {"1.2345", "1.2345"}, + {"1.2340", "1.234"}, + {"1.2300", "1.23"}, + {"1.2000", "1.2"}, + {"1.0000", "1.0"}, + {"1.234", "1.234"}, + {"1.230", "1.23"}, + {"1.200", "1.2"}, + {"1.000", "1.0"}, + {"1.23", "1.23"}, + {"1.20", "1.2"}, + {"1.00", "1.0"}, + {"1.2", "1.2"}, + {"1.0", "1.0"}, + {"01.2345", "1.2345"}, + {"01.2340", "1.234"}, + {"01.2300", "1.23"}, + {"01.2000", "1.2"}, + {"01.0000", "1.0"}, + {"01.234", "1.234"}, + {"01.230", "1.23"}, + {"01.200", "1.2"}, + {"01.000", "1.0"}, + {"01.23", "1.23"}, + {"01.20", "1.2"}, + {"01.00", "1.0"}, + {"01.2", "1.2"}, + {"01.0", "1.0"}, + {"1234.5678", "1234.5678"}, + {"1234.5670", "1234.567"}, + {"1234.5600", "1234.56"}, + {"1234.5000", "1234.5"}, + {"1234.0000", "1234.0"}, + {"1234.567", "1234.567"}, + {"1234.560", "1234.56"}, + {"1234.500", "1234.5"}, + {"1234.000", "1234.0"}, + {"1234.56", "1234.56"}, + {"1234.50", "1234.5"}, + {"1234.00", "1234.0"}, + {"1234.5", "1234.5"}, + {"1234.0", "1234.0"}, + {"0.0", "0.0"}, + {"00.0", "0.0"}, + {"000000000000000000000000000000000000000000000000000000000000000000.0", "0.0"}, + {"922337203685477.5807", "922337203685477.5807"}, + {"-922337203685477.5808", "-922337203685477.5808"}, + } + for _, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("%s->%s", tt.in, tt.out), func(t *testing.T) { + t.Parallel() + d, err := types.ParseDecimal(tt.in) + testutil.OK(t, err) + testutil.Equals(t, d.String(), tt.out) + }) + } + } + + { + tests := []struct{ in, errStr string }{ + {"", "error parsing decimal value: string too short"}, + {"-", "error parsing decimal value: string too short"}, + {"a", "error parsing decimal value: unexpected character 'a'"}, + {"-a", "error parsing decimal value: unexpected character 'a'"}, + {"'", `error parsing decimal value: unexpected character '\''`}, + {`-\\`, `error parsing decimal value: unexpected character '\\'`}, + {"0", "error parsing decimal value: string missing decimal point"}, + {"1a", "error parsing decimal value: unexpected character 'a'"}, + {"1a", "error parsing decimal value: unexpected character 'a'"}, + {"1.", "error parsing decimal value: missing digits after decimal point"}, + {"1.00000", "error parsing decimal value: too many digits after decimal point"}, + {"1.a", "error parsing decimal value: unexpected character 'a'"}, + {"1.0a", "error parsing decimal value: unexpected character 'a'"}, + {"1.0000a", "error parsing decimal value: unexpected character 'a'"}, + {"1.0000a", "error parsing decimal value: unexpected character 'a'"}, + + {"1000000000000000.0", "error parsing decimal value: overflow"}, + {"-1000000000000000.0", "error parsing decimal value: overflow"}, + {"922337203685477.5808", "error parsing decimal value: overflow"}, + {"922337203685478.0", "error parsing decimal value: overflow"}, + {"-922337203685477.5809", "error parsing decimal value: overflow"}, + {"-922337203685478.0", "error parsing decimal value: overflow"}, + } + for _, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("%s->%s", tt.in, tt.errStr), func(t *testing.T) { + t.Parallel() + _, err := types.ParseDecimal(tt.in) + testutil.AssertError(t, err, types.ErrDecimal) + testutil.Equals(t, err.Error(), tt.errStr) + }) + } + } + + t.Run("Equal", func(t *testing.T) { + t.Parallel() + one := types.Decimal(10000) + one2 := types.Decimal(10000) + zero := types.Decimal(0) + f := types.Boolean(false) + testutil.FatalIf(t, !one.Equal(one), "%v not Equal to %v", one, one) + testutil.FatalIf(t, !one.Equal(one2), "%v not Equal to %v", one, one2) + testutil.FatalIf(t, one.Equal(zero), "%v Equal to %v", one, zero) + testutil.FatalIf(t, zero.Equal(one), "%v Equal to %v", zero, one) + testutil.FatalIf(t, zero.Equal(f), "%v Equal to %v", zero, f) + }) + + t.Run("TypeName", func(t *testing.T) { + t.Parallel() + tn := types.Decimal(0).TypeName() + testutil.Equals(t, tn, "decimal") + }) +} diff --git a/types/entity_type.go b/types/entity_type.go new file mode 100644 index 00000000..b1331676 --- /dev/null +++ b/types/entity_type.go @@ -0,0 +1,25 @@ +package types + +import ( + "encoding/json" + "fmt" + "strings" +) + +// EntityType is the type portion of an EntityUID +type EntityType string + +func (a EntityType) Equal(bi Value) bool { + b, ok := bi.(EntityType) + return ok && a == b +} +func (v EntityType) TypeName() string { return fmt.Sprintf("(EntityType of type `%s`)", v) } + +func (v EntityType) String() string { return string(v) } +func (v EntityType) Cedar() string { return string(v) } +func (v EntityType) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(string(v)) } +func (v EntityType) deepClone() Value { return v } + +func EntityTypeFromSlice(v []string) EntityType { + return EntityType(strings.Join(v, "::")) +} diff --git a/types/entity_type_test.go b/types/entity_type_test.go new file mode 100644 index 00000000..bef081ff --- /dev/null +++ b/types/entity_type_test.go @@ -0,0 +1,50 @@ +package types_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestEntityType(t *testing.T) { + t.Parallel() + t.Run("Equal", func(t *testing.T) { + t.Parallel() + a := types.EntityType("X") + b := types.EntityType("X") + c := types.EntityType("Y") + testutil.Equals(t, a.Equal(b), true) + testutil.Equals(t, b.Equal(a), true) + testutil.Equals(t, a.Equal(c), false) + testutil.Equals(t, c.Equal(a), false) + }) + t.Run("TypeName", func(t *testing.T) { + t.Parallel() + a := types.EntityType("X") + testutil.Equals(t, a.TypeName(), "(EntityType of type `X`)") + }) + t.Run("String", func(t *testing.T) { + t.Parallel() + a := types.EntityType("X") + testutil.Equals(t, a.String(), "X") + }) + t.Run("Cedar", func(t *testing.T) { + t.Parallel() + a := types.EntityType("X") + testutil.Equals(t, a.Cedar(), "X") + }) + t.Run("ExplicitMarshalJSON", func(t *testing.T) { + t.Parallel() + a := types.EntityType("X") + v, err := a.ExplicitMarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(v), `"X"`) + }) + t.Run("pathFromSlice", func(t *testing.T) { + t.Parallel() + a := types.EntityTypeFromSlice([]string{"X", "Y"}) + testutil.Equals(t, a, types.EntityType("X::Y")) + }) + +} diff --git a/types/entity_uid.go b/types/entity_uid.go new file mode 100644 index 00000000..e863317c --- /dev/null +++ b/types/entity_uid.go @@ -0,0 +1,76 @@ +package types + +import ( + "encoding/json" + "fmt" + "strconv" +) + +// An EntityUID is the identifier for a principal, action, or resource. +type EntityUID struct { + Type EntityType + ID string +} + +func NewEntityUID(typ EntityType, id string) EntityUID { + return EntityUID{ + Type: typ, + ID: id, + } +} + +// IsZero returns true if the EntityUID has an empty Type and ID. +func (a EntityUID) IsZero() bool { + return a.Type == "" && a.ID == "" +} + +func (a EntityUID) Equal(bi Value) bool { + b, ok := bi.(EntityUID) + return ok && a == b +} +func (v EntityUID) TypeName() string { return fmt.Sprintf("(entity of type `%s`)", v.Type) } + +// String produces a string representation of the EntityUID, e.g. `Type::"id"`. +func (v EntityUID) String() string { return v.Cedar() } + +// Cedar produces a valid Cedar language representation of the EntityUID, e.g. `Type::"id"`. +func (v EntityUID) Cedar() string { + return v.Type.String() + "::" + strconv.Quote(v.ID) +} + +func (v *EntityUID) UnmarshalJSON(b []byte) error { + // TODO: review after adding support for schemas + var res entityValueJSON + if err := json.Unmarshal(b, &res); err != nil { + return err + } + if res.Entity != nil { + v.Type = res.Entity.Type + v.ID = res.Entity.ID + return nil + } else if res.Type != nil && res.ID != nil { // require both Type and ID to parse "implicit" JSON + v.Type = *res.Type + v.ID = *res.ID + return nil + } + return errJSONEntityNotFound +} + +// ExplicitMarshalJSON marshals the EntityUID into JSON using the implicit form. +func (v EntityUID) MarshalJSON() ([]byte, error) { + return json.Marshal(entityValueJSON{ + Type: &v.Type, + ID: &v.ID, + }) +} + +// ExplicitMarshalJSON marshals the EntityUID into JSON using the explicit form. +func (v EntityUID) ExplicitMarshalJSON() ([]byte, error) { + return json.Marshal(entityValueJSON{ + Entity: &extEntity{ + Type: v.Type, + ID: v.ID, + }, + }) +} +func (v EntityUID) deepClone() Value { return v } diff --git a/types/entity_uid_test.go b/types/entity_uid_test.go new file mode 100644 index 00000000..826a2dcb --- /dev/null +++ b/types/entity_uid_test.go @@ -0,0 +1,34 @@ +package types_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestEntity(t *testing.T) { + t.Parallel() + + t.Run("Equal", func(t *testing.T) { + t.Parallel() + twoElems := types.EntityUID{"type", "id"} + twoElems2 := types.EntityUID{"type", "id"} + differentValues := types.EntityUID{"asdf", "vfds"} + testutil.FatalIf(t, !twoElems.Equal(twoElems), "%v not Equal to %v", twoElems, twoElems) + testutil.FatalIf(t, !twoElems.Equal(twoElems2), "%v not Equal to %v", twoElems, twoElems2) + testutil.FatalIf(t, twoElems.Equal(differentValues), "%v Equal to %v", twoElems, differentValues) + }) + + t.Run("string", func(t *testing.T) { + t.Parallel() + AssertValueString(t, types.EntityUID{Type: "type", ID: "id"}, `type::"id"`) + AssertValueString(t, types.EntityUID{Type: "namespace::type", ID: "id"}, `namespace::type::"id"`) + }) + + t.Run("TypeName", func(t *testing.T) { + t.Parallel() + tn := types.EntityUID{"T", "id"}.TypeName() + testutil.Equals(t, tn, "(entity of type `T`)") + }) +} diff --git a/types/ipaddr.go b/types/ipaddr.go new file mode 100644 index 00000000..af738c6b --- /dev/null +++ b/types/ipaddr.go @@ -0,0 +1,158 @@ +package types + +import ( + "encoding/json" + "errors" + "fmt" + "net/netip" + "strings" +) + +// An IPAddr is value that represents an IP address. It can be either IPv4 or IPv6. +// The value can represent an individual address or a range of addresses. +type IPAddr netip.Prefix + +// ParseIPAddr takes a string representation of an IP address and converts it into an IPAddr type. +func ParseIPAddr(s string) (IPAddr, error) { + // We disallow IPv4-mapped IPv6 addresses in dotted notation because Cedar does. + if strings.Count(s, ":") >= 2 && strings.Count(s, ".") >= 2 { + return IPAddr{}, fmt.Errorf("%w: cannot parse IPv4 addresses embedded in IPv6 addresses", ErrIP) + } else if net, err := netip.ParsePrefix(s); err == nil { + return IPAddr(net), nil + } else if addr, err := netip.ParseAddr(s); err == nil { + return IPAddr(netip.PrefixFrom(addr, addr.BitLen())), nil + } else { + return IPAddr{}, fmt.Errorf("%w: error parsing IP address %s", ErrIP, s) + } +} + +func (a IPAddr) Equal(bi Value) bool { + b, ok := bi.(IPAddr) + return ok && a == b +} + +func (v IPAddr) TypeName() string { return "IP" } + +// Cedar produces a valid Cedar language representation of the IPAddr, e.g. `ip("127.0.0.1")`. +func (v IPAddr) Cedar() string { return `ip("` + v.String() + `")` } + +// String produces a string representation of the IPAddr, e.g. `127.0.0.1`. +func (v IPAddr) String() string { + if v.Prefix().Bits() == v.Addr().BitLen() { + return v.Addr().String() + } + return v.Prefix().String() +} + +func (v IPAddr) Prefix() netip.Prefix { + return netip.Prefix(v) +} + +func (v IPAddr) IsIPv4() bool { + return v.Addr().Is4() +} + +func (v IPAddr) IsIPv6() bool { + return v.Addr().Is6() +} + +func (v IPAddr) IsLoopback() bool { + // This comment is in the Cedar Rust implementation: + // + // Loopback addresses are "127.0.0.0/8" for IpV4 and "::1" for IpV6 + // + // Unlike the implementation of `is_multicast`, we don't need to test prefix + // + // The reason for IpV6 is obvious: There's only one loopback address + // + // The reason for IpV4 is that provided the truncated ip address is a + // loopback address, its prefix cannot be less than 8 because + // otherwise its more significant byte cannot be 127 + return v.Prefix().Masked().Addr().IsLoopback() +} + +func (v IPAddr) Addr() netip.Addr { + return netip.Prefix(v).Addr() +} + +func (v IPAddr) IsMulticast() bool { + // This comment is in the Cedar Rust implementation: + // + // Multicast addresses are "224.0.0.0/4" for IpV4 and "ff00::/8" for + // IpV6 + // + // If an IpNet's addresses are multicast addresses, calling + // `is_in_range()` over it and its associated net above should + // evaluate to true + // + // The implementation uses the property that if `ip1/prefix1` is in + // range `ip2/prefix2`, then `ip1` is in `ip2/prefix2` and `prefix1 >= + // prefix2` + var min_prefix_len int + if v.IsIPv4() { + min_prefix_len = 4 + } else { + min_prefix_len = 8 + } + return v.Addr().IsMulticast() && v.Prefix().Bits() >= min_prefix_len +} + +func (c IPAddr) Contains(o IPAddr) bool { + return c.Prefix().Contains(o.Addr()) && c.Prefix().Bits() <= o.Prefix().Bits() +} + +func (v *IPAddr) UnmarshalJSON(b []byte) error { + var arg string + if len(b) > 0 && b[0] == '"' { + if err := json.Unmarshal(b, &arg); err != nil { + return errors.Join(errJSONDecode, err) + } + } else { + // NOTE: cedar supports two other forms, for now we're only supporting the smallest implicit explicit form. + // The following are not supported: + // "ip(\"192.168.0.42\")" + // {"fn":"ip","arg":"192.168.0.42"} + var res extValueJSON + if err := json.Unmarshal(b, &res); err != nil { + return errors.Join(errJSONDecode, err) + } + if res.Extn == nil { + return errJSONExtNotFound + } + if res.Extn.Fn != "ip" { + return errJSONExtFnMatch + } + arg = res.Extn.Arg + } + vv, err := ParseIPAddr(arg) + if err != nil { + return err + } + *v = vv + return nil +} + +// ExplicitMarshalJSON marshals the IPAddr into JSON using the implicit form. +func (v IPAddr) MarshalJSON() ([]byte, error) { return []byte(`"` + v.String() + `"`), nil } + +// ExplicitMarshalJSON marshals the IPAddr into JSON using the explicit form. +func (v IPAddr) ExplicitMarshalJSON() ([]byte, error) { + if v.Prefix().Bits() == v.Prefix().Addr().BitLen() { + return json.Marshal(extValueJSON{ + Extn: &extn{ + Fn: "ip", + Arg: v.Addr().String(), + }, + }) + } + return json.Marshal(extValueJSON{ + Extn: &extn{ + Fn: "ip", + Arg: v.String(), + }, + }) +} + +// in this case, netip.Prefix does contain a pointer, but +// the interface given is immutable, so it is safe to return +func (v IPAddr) deepClone() Value { return v } diff --git a/types/ipaddr_test.go b/types/ipaddr_test.go new file mode 100644 index 00000000..997a4466 --- /dev/null +++ b/types/ipaddr_test.go @@ -0,0 +1,278 @@ +package types_test + +import ( + "fmt" + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestIP(t *testing.T) { + t.Parallel() + t.Run("ParseAndString", func(t *testing.T) { + t.Parallel() + + tests := []struct { + in string + parses bool + out string + }{ + {"0.0.0.0", true, "0.0.0.0"}, + {"0.0.0.1", true, "0.0.0.1"}, + {"127.0.0.1", true, "127.0.0.1"}, + {"127.0.0.1/32", true, "127.0.0.1"}, + {"127.0.0.1/24", true, "127.0.0.1/24"}, + {"127.1.2.3/8", true, "127.1.2.3/8"}, + {"::/128", true, "::"}, + {"::1/128", true, "::1"}, + {"2001:db8::1", true, "2001:db8::1"}, + {"2001:db8::1:0:0:1", true, "2001:db8::1:0:0:1"}, + {"::ffff:192.0.2.128", false, ""}, + {"::ffff:c000:0280", true, "::ffff:192.0.2.128"}, + {"2001:db8::1/32", true, "2001:db8::1/32"}, + {"2001:db8::1:0:0:1/96", true, "2001:db8::1:0:0:1/96"}, + {"::ffff:192.0.2.128/24", false, ""}, + {"::ffff:192.0.2.128/120", false, ""}, + {"::ffff:c000:0280/24", true, "::ffff:192.0.2.128/24"}, + {"::ffff:c000:0280/120", true, "::ffff:192.0.2.128/120"}, + {"6b6b:f00::32ff:ffff:6368/00", false, ""}, // leading zero(s) + {"garbage", false, ""}, + {"c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5/68", true, "c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5/68"}, + } + for _, tt := range tests { + tt := tt + var testName string + if tt.parses { + testName = fmt.Sprintf("%s-parses-and-prints-as-%s", tt.in, tt.out) + } else { + testName = fmt.Sprintf("%s-does-not-parse", tt.in) + } + t.Run(testName, func(t *testing.T) { + t.Parallel() + i, err := types.ParseIPAddr(tt.in) + if tt.parses { + testutil.OK(t, err) + testutil.Equals(t, i.String(), tt.out) + } else { + testutil.Error(t, err) + } + }) + } + }) + + t.Run("Equal", func(t *testing.T) { + t.Parallel() + tests := []struct { + lhs, rhs string + equal bool + }{ + {"0.0.0.0", "0.0.0.0", true}, + {"0.0.0.0", "0.0.0.0/32", true}, + {"127.0.0.1", "127.0.0.1", true}, + {"127.0.0.1", "127.0.0.1/32", true}, + {"::", "::", true}, + {"::", "::/128", true}, + {"::1", "::1", true}, + {"::1", "::1/128", true}, + {"::", "0.0.0.0", false}, + {"::1", "127.0.0.1", false}, + {"::ffff:c000:0280", "192.0.2.128", false}, + {"1.2.3.4", "1.2.3.4", true}, + {"1.2.3.4", "1.2.3.4/32", true}, + {"1.2.3.4/32", "1.2.3.4/32", true}, + {"1.2.3.4/24", "1.2.3.4/24", true}, + {"1.2.3.0/24", "1.2.3.255/24", false}, + {"1.2.3.0/24", "1.2.3.0/25", false}, + {"::ffff:c000:0280/24", "::/24", false}, + {"::ffff:c000:0280/120", "192.0.2.0/24", false}, + {"2001:db8::1/32", "2001:db8::/32", false}, + {"2001:db8::1:0:0:1/96", "2001:db8:0:0:1::/96", false}, + {"c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5/68", "c5c5:c5c5:c5c5:c5c5:c5c5:5cc5:c5c5:c5c5/68", false}, + } + for _, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("ip(%v).Equal(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) { + t.Parallel() + lhs, err := types.ParseIPAddr(tt.lhs) + testutil.OK(t, err) + rhs, err := types.ParseIPAddr(tt.rhs) + testutil.OK(t, err) + equal := lhs.Equal(rhs) + if equal != tt.equal { + t.Fatalf("expected ip(%v).Equal(ip(%v)) to be %v instead of %v", tt.lhs, tt.rhs, tt.equal, equal) + } + if equal { + testutil.FatalIf( + t, + !lhs.Contains(rhs), + "ip(%v) and ip(%v) compare Equal but !ip(%v).contains(ip(%v))", tt.lhs, tt.rhs, tt.lhs, tt.rhs) + testutil.FatalIf( + t, + !rhs.Contains(lhs), + "ip(%v) and ip(%v) compare Equal but !ip(%v).contains(ip(%v))", tt.rhs, tt.lhs, tt.rhs, tt.lhs) + } + }) + } + }) + + t.Run("isIPv4", func(t *testing.T) { + t.Parallel() + tests := []struct { + val string + isIPv4, isIPv6 bool + }{ + {"0.0.0.0", true, false}, + {"0.0.0.0/32", true, false}, + {"127.0.0.1", true, false}, + {"127.0.0.1/32", true, false}, + {"::", false, true}, + {"::1", false, true}, + {"::/128", false, true}, + {"::1/128", false, true}, + {"::ffff:c000:0280", false, true}, + {"::ffff:c000:0280/128", false, true}, + {"::ffff:c000:0280/24", false, true}, + {"2001:db8::1", false, true}, + {"2001:db8::1:0:0:1", false, true}, + {"2001:db8::1/32", false, true}, + {"2001:db8::1:0:0:1/96", false, true}, + } + for _, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("ip(%v).isIPv{4,6}()", tt.val), func(t *testing.T) { + t.Parallel() + val, err := types.ParseIPAddr(tt.val) + testutil.OK(t, err) + isIPv4 := val.IsIPv4() + if isIPv4 != tt.isIPv4 { + t.Fatalf("expected ip(%v).isIPv4() to be %v instead of %v", tt.val, tt.isIPv4, isIPv4) + } + isIPv6 := val.IsIPv6() + if isIPv6 != tt.isIPv6 { + t.Fatalf("expected ip(%v).isIPv6() to be %v instead of %v", tt.val, tt.isIPv6, isIPv6) + } + }) + } + }) + + t.Run("isLoopback", func(t *testing.T) { + t.Parallel() + tests := []struct { + val string + isLoopback bool + }{ + {"0.0.0.0", false}, + {"127.0.0.1", true}, + {"127.0.0.2", true}, + {"127.0.0.1/32", true}, + {"127.0.0.1/24", true}, + {"127.0.0.1/8", true}, + {"127.0.0.1/7", false}, + {"::", false}, + {"::1", true}, + {"::/128", false}, + {"::1/128", true}, + {"::1/127", false}, + {"::ffff:8000:0001", false}, + {"::ffff:8000:0002", false}, + {"::ffff:8000:0001/128", false}, + {"::ffff:8000:0002/128", false}, + {"::ffff:8000:0001/104", false}, + {"::ffff:8000:0002/104", false}, + {"::ffff:8000:0001/100", false}, + {"::ffff:8000:0002/100", false}, + {"2001:db8::1", false}, + {"2001:db8::1:0:0:1", false}, + {"2001:db8::1/32", false}, + {"2001:db8::1:0:0:1/96", false}, + } + for _, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("ip(%v).isLoopback()", tt.val), func(t *testing.T) { + t.Parallel() + val, err := types.ParseIPAddr(tt.val) + testutil.OK(t, err) + isLoopback := val.IsLoopback() + if isLoopback != tt.isLoopback { + t.Fatalf("expected ip(%v).isLoopback() to be %v instead of %v", tt.val, tt.isLoopback, isLoopback) + } + }) + } + }) + + t.Run("isMulticast", func(t *testing.T) { + t.Parallel() + tests := []struct { + val string + isMulticast bool + }{ + {"0.0.0.0", false}, + {"127.0.0.1", false}, + {"223.255.255.255", false}, + {"224.0.0.0", true}, + {"239.255.255.255", true}, + {"240.0.0.0", false}, + {"feff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", false}, + {"ff00::", true}, + {"ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", true}, + {"ff00::/8", true}, + {"ff00::/7", false}, + {"224.0.0.0/4", true}, + {"224.0.0.0/3", false}, + } + for _, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("ip(%v).isMulticast()", tt.val), func(t *testing.T) { + t.Parallel() + val, err := types.ParseIPAddr(tt.val) + testutil.OK(t, err) + isMulticast := val.IsMulticast() + if isMulticast != tt.isMulticast { + t.Fatalf("expected ip(%v).isMulticast() to be %v instead of %v", tt.val, tt.isMulticast, isMulticast) + } + }) + } + }) + + t.Run("contains", func(t *testing.T) { + t.Parallel() + tests := []struct { + lhs, rhs string + contains bool + }{ + {"0.0.0.0/31", "0.0.0.0", true}, + {"0.0.0.0", "0.0.0.0/31", false}, + {"255.255.0.0/16", "255.255.255.255", true}, + {"255.255.0.0/16", "255.255.255.248/28", true}, + {"255.255.0.0/16", "255.255.255.0/24", true}, + {"255.255.0.0/16", "255.255.248.0/20", true}, + {"255.255.0.0/16", "255.255.0.0/16", true}, + {"255.255.0.0/16", "255.254.0.0/15", false}, + {"255.255.0.0/16", "255.254.255.0/24", false}, + {"::ffff:c000:0280", "192.0.2.128", false}, + {"2001:db8::/120", "2001:db8::2", true}, + {"2001:db8::/64", "2001:db8:0:0:dead:f00d::/96", true}, + } + for _, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("ip(%v).contains(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) { + t.Parallel() + lhs, err := types.ParseIPAddr(tt.lhs) + testutil.OK(t, err) + rhs, err := types.ParseIPAddr(tt.rhs) + testutil.OK(t, err) + contains := lhs.Contains(rhs) + if contains != tt.contains { + t.Fatalf("expected ip(%v).contains(ip(%v)) to be %v instead of %v", tt.lhs, tt.rhs, tt.contains, contains) + } + }) + } + }) + + t.Run("TypeName", func(t *testing.T) { + t.Parallel() + tn := types.IPAddr{}.TypeName() + testutil.Equals(t, tn, "IP") + }) +} diff --git a/types/json_test.go b/types/json_test.go index eb58bc37..93e58317 100644 --- a/types/json_test.go +++ b/types/json_test.go @@ -18,6 +18,15 @@ func mustIPValue(v string) IPAddr { return r } +func AssertValue(t *testing.T, got, want Value) { + t.Helper() + testutil.FatalIf( + t, + !((got == ZeroValue() && want == ZeroValue()) || + (got != ZeroValue() && want != ZeroValue() && got.Equal(want))), + "got %v want %v", got, want) +} + func TestJSON_Value(t *testing.T) { t.Parallel() tests := []struct { diff --git a/types/long.go b/types/long.go new file mode 100644 index 00000000..01751c1a --- /dev/null +++ b/types/long.go @@ -0,0 +1,27 @@ +package types + +import ( + "encoding/json" + "fmt" +) + +// A Long is a whole number without decimals that can range from -9223372036854775808 to 9223372036854775807. +type Long int64 + +func (a Long) Equal(bi Value) bool { + b, ok := bi.(Long) + return ok && a == b +} + +// ExplicitMarshalJSON marshals the Long into JSON. +func (v Long) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } +func (v Long) TypeName() string { return "long" } + +// String produces a string representation of the Long, e.g. `42`. +func (v Long) String() string { return v.Cedar() } + +// Cedar produces a valid Cedar language representation of the Long, e.g. `42`. +func (v Long) Cedar() string { + return fmt.Sprint(int64(v)) +} +func (v Long) deepClone() Value { return v } diff --git a/types/long_test.go b/types/long_test.go new file mode 100644 index 00000000..d4de2134 --- /dev/null +++ b/types/long_test.go @@ -0,0 +1,36 @@ +package types_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestLong(t *testing.T) { + t.Parallel() + + t.Run("Equal", func(t *testing.T) { + t.Parallel() + one := types.Long(1) + one2 := types.Long(1) + zero := types.Long(0) + f := types.Boolean(false) + testutil.FatalIf(t, !one.Equal(one), "%v not Equal to %v", one, one) + testutil.FatalIf(t, !one.Equal(one2), "%v not Equal to %v", one, one2) + testutil.FatalIf(t, one.Equal(zero), "%v Equal to %v", one, zero) + testutil.FatalIf(t, zero.Equal(one), "%v Equal to %v", zero, one) + testutil.FatalIf(t, zero.Equal(f), "%v Equal to %v", zero, f) + }) + + t.Run("string", func(t *testing.T) { + t.Parallel() + AssertValueString(t, types.Long(1), "1") + }) + + t.Run("TypeName", func(t *testing.T) { + t.Parallel() + tn := types.Long(1).TypeName() + testutil.Equals(t, tn, "long") + }) +} diff --git a/types/record.go b/types/record.go new file mode 100644 index 00000000..77992408 --- /dev/null +++ b/types/record.go @@ -0,0 +1,112 @@ +package types + +import ( + "bytes" + "encoding/json" + "slices" + "strconv" + "strings" + + "golang.org/x/exp/maps" +) + +// A Record is a collection of attributes. Each attribute consists of a name and +// an associated value. Names are simple strings. Values can be of any type. +type Record map[string]Value + +// Equals returns true if the records are Equal. +func (r Record) Equals(b Record) bool { return r.Equal(b) } + +func (a Record) Equal(bi Value) bool { + b, ok := bi.(Record) + if !ok || len(a) != len(b) { + return false + } + for k, av := range a { + bv, ok := b[k] + if !ok || !av.Equal(bv) { + return false + } + } + return true +} + +func (v *Record) UnmarshalJSON(b []byte) error { + var res map[string]explicitValue + err := json.Unmarshal(b, &res) + if err != nil { + return err + } + *v = Record{} + for kk, vv := range res { + (*v)[kk] = vv.Value + } + return nil +} + +// MarshalJSON marshals the Record into JSON, the marshaller uses the explicit +// JSON form for all the values in the Record. +func (v Record) MarshalJSON() ([]byte, error) { + w := &bytes.Buffer{} + w.WriteByte('{') + keys := maps.Keys(v) + slices.Sort(keys) + for i, kk := range keys { + if i > 0 { + w.WriteByte(',') + } + kb, _ := json.Marshal(kk) // json.Marshal cannot error on strings + w.Write(kb) + w.WriteByte(':') + vv := v[kk] + vb, err := vv.ExplicitMarshalJSON() + if err != nil { + return nil, err + } + w.Write(vb) + } + w.WriteByte('}') + return w.Bytes(), nil +} + +// ExplicitMarshalJSON marshals the Record into JSON, the marshaller uses the +// explicit JSON form for all the values in the Record. +func (v Record) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } +func (r Record) TypeName() string { return "record" } + +// String produces a string representation of the Record, e.g. `{"a":1,"b":2,"c":3}`. +func (r Record) String() string { return r.Cedar() } + +// Cedar produces a valid Cedar language representation of the Record, e.g. `{"a":1,"b":2,"c":3}`. +func (r Record) Cedar() string { + var sb strings.Builder + sb.WriteRune('{') + first := true + keys := maps.Keys(r) + slices.Sort(keys) + for _, k := range keys { + v := r[k] + if !first { + sb.WriteString(", ") + } + first = false + sb.WriteString(strconv.Quote(k)) + sb.WriteString(": ") + sb.WriteString(v.Cedar()) + } + sb.WriteRune('}') + return sb.String() +} +func (v Record) deepClone() Value { return v.DeepClone() } + +// DeepClone returns a deep clone of the Record. +func (v Record) DeepClone() Record { + if v == nil { + return v + } + res := make(Record, len(v)) + for k, vv := range v { + res[k] = vv.deepClone() + } + return res +} diff --git a/types/record_test.go b/types/record_test.go new file mode 100644 index 00000000..b69b4372 --- /dev/null +++ b/types/record_test.go @@ -0,0 +1,79 @@ +package types_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestRecord(t *testing.T) { + t.Parallel() + + t.Run("Equal", func(t *testing.T) { + t.Parallel() + empty := types.Record{} + empty2 := types.Record{} + twoElems := types.Record{ + "foo": types.Boolean(true), + "bar": types.String("blah"), + } + twoElems2 := types.Record{ + "foo": types.Boolean(true), + "bar": types.String("blah"), + } + differentValues := types.Record{ + "foo": types.Boolean(false), + "bar": types.String("blaz"), + } + differentKeys := types.Record{ + "foo": types.Boolean(false), + "bar": types.Long(1), + } + nested := types.Record{ + "one": types.Long(1), + "two": types.Long(2), + "nest": twoElems, + } + nested2 := types.Record{ + "one": types.Long(1), + "two": types.Long(2), + "nest": twoElems, + } + + testutil.FatalIf(t, !empty.Equals(empty), "%v not Equal to %v", empty, empty) + testutil.FatalIf(t, !empty.Equals(empty2), "%v not Equal to %v", empty, empty2) + + testutil.FatalIf(t, !twoElems.Equals(twoElems), "%v not Equal to %v", twoElems, twoElems) + testutil.FatalIf(t, !twoElems.Equals(twoElems2), "%v not Equal to %v", twoElems, twoElems2) + + testutil.FatalIf(t, !nested.Equals(nested), "%v not Equal to %v", nested, nested) + testutil.FatalIf(t, !nested.Equals(nested2), "%v not Equal to %v", nested, nested2) + + testutil.FatalIf(t, nested.Equals(twoElems), "%v Equal to %v", nested, twoElems) + testutil.FatalIf(t, twoElems.Equals(differentValues), "%v Equal to %v", twoElems, differentValues) + testutil.FatalIf(t, twoElems.Equals(differentKeys), "%v Equal to %v", twoElems, differentKeys) + }) + + t.Run("string", func(t *testing.T) { + t.Parallel() + AssertValueString(t, types.Record{}, "{}") + AssertValueString( + t, + types.Record{"foo": types.Boolean(true)}, + `{"foo": true}`) + AssertValueString( + t, + types.Record{ + "foo": types.Boolean(true), + "bar": types.String("blah"), + }, + `{"bar": "blah", "foo": true}`) + }) + + t.Run("TypeName", func(t *testing.T) { + t.Parallel() + tn := types.Record{}.TypeName() + testutil.Equals(t, tn, "record") + }) +} diff --git a/types/set.go b/types/set.go new file mode 100644 index 00000000..510b68e8 --- /dev/null +++ b/types/set.go @@ -0,0 +1,111 @@ +package types + +import ( + "bytes" + "encoding/json" + "strings" +) + +// A Set is a collection of elements that can be of the same or different types. +type Set []Value + +func (s Set) Contains(v Value) bool { + for _, e := range s { + if e.Equal(v) { + return true + } + } + return false +} + +// Equals returns true if the sets are Equal. +func (s Set) Equals(b Set) bool { return s.Equal(b) } + +func (as Set) Equal(bi Value) bool { + bs, ok := bi.(Set) + if !ok { + return false + } + for _, a := range as { + if !bs.Contains(a) { + return false + } + } + for _, b := range bs { + if !as.Contains(b) { + return false + } + } + return true +} + +func (v *explicitValue) UnmarshalJSON(b []byte) error { + return UnmarshalJSON(b, &v.Value) +} + +func (v *Set) UnmarshalJSON(b []byte) error { + var res []explicitValue + err := json.Unmarshal(b, &res) + if err != nil { + return err + } + for _, vv := range res { + *v = append(*v, vv.Value) + } + return nil +} + +// MarshalJSON marshals the Set into JSON, the marshaller uses the explicit JSON +// form for all the values in the Set. +func (v Set) MarshalJSON() ([]byte, error) { + w := &bytes.Buffer{} + w.WriteByte('[') + for i, vv := range v { + if i > 0 { + w.WriteByte(',') + } + b, err := vv.ExplicitMarshalJSON() + if err != nil { + return nil, err + } + w.Write(b) + } + w.WriteByte(']') + return w.Bytes(), nil +} + +// ExplicitMarshalJSON marshals the Set into JSON, the marshaller uses the +// explicit JSON form for all the values in the Set. +func (v Set) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } + +func (v Set) TypeName() string { return "set" } + +// String produces a string representation of the Set, e.g. `[1,2,3]`. +func (v Set) String() string { return v.Cedar() } + +// Cedar produces a valid Cedar language representation of the Set, e.g. `[1,2,3]`. +func (v Set) Cedar() string { + var sb strings.Builder + sb.WriteRune('[') + for i, elem := range v { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(elem.Cedar()) + } + sb.WriteRune(']') + return sb.String() +} +func (v Set) deepClone() Value { return v.DeepClone() } + +// DeepClone returns a deep clone of the Set. +func (v Set) DeepClone() Set { + if v == nil { + return v + } + res := make(Set, len(v)) + for i, vv := range v { + res[i] = vv.deepClone() + } + return res +} diff --git a/types/set_test.go b/types/set_test.go new file mode 100644 index 00000000..43adc5e8 --- /dev/null +++ b/types/set_test.go @@ -0,0 +1,60 @@ +package types_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestSet(t *testing.T) { + t.Parallel() + + t.Run("Equal", func(t *testing.T) { + t.Parallel() + empty := types.Set{} + empty2 := types.Set{} + oneTrue := types.Set{types.Boolean(true)} + oneTrue2 := types.Set{types.Boolean(true)} + oneFalse := types.Set{types.Boolean(false)} + nestedOnce := types.Set{empty, oneTrue, oneFalse} + nestedOnce2 := types.Set{empty, oneTrue, oneFalse} + nestedTwice := types.Set{empty, oneTrue, oneFalse, nestedOnce} + nestedTwice2 := types.Set{empty, oneTrue, oneFalse, nestedOnce} + oneTwoThree := types.Set{ + types.Long(1), types.Long(2), types.Long(3), + } + threeTwoTwoOne := types.Set{ + types.Long(3), types.Long(2), types.Long(2), types.Long(1), + } + + testutil.FatalIf(t, !empty.Equals(empty), "%v not Equal to %v", empty, empty) + testutil.FatalIf(t, !empty.Equals(empty2), "%v not Equal to %v", empty, empty2) + testutil.FatalIf(t, !oneTrue.Equals(oneTrue), "%v not Equal to %v", oneTrue, oneTrue) + testutil.FatalIf(t, !oneTrue.Equals(oneTrue2), "%v not Equal to %v", oneTrue, oneTrue2) + testutil.FatalIf(t, !nestedOnce.Equals(nestedOnce), "%v not Equal to %v", nestedOnce, nestedOnce) + testutil.FatalIf(t, !nestedOnce.Equals(nestedOnce2), "%v not Equal to %v", nestedOnce, nestedOnce2) + testutil.FatalIf(t, !nestedTwice.Equals(nestedTwice), "%v not Equal to %v", nestedTwice, nestedTwice) + testutil.FatalIf(t, !nestedTwice.Equals(nestedTwice2), "%v not Equal to %v", nestedTwice, nestedTwice2) + testutil.FatalIf(t, !oneTwoThree.Equals(threeTwoTwoOne), "%v not Equal to %v", oneTwoThree, threeTwoTwoOne) + + testutil.FatalIf(t, empty.Equals(oneFalse), "%v Equal to %v", empty, oneFalse) + testutil.FatalIf(t, oneTrue.Equals(oneFalse), "%v Equal to %v", oneTrue, oneFalse) + testutil.FatalIf(t, nestedOnce.Equals(nestedTwice), "%v Equal to %v", nestedOnce, nestedTwice) + }) + + t.Run("string", func(t *testing.T) { + t.Parallel() + AssertValueString(t, types.Set{}, "[]") + AssertValueString( + t, + types.Set{types.Boolean(true), types.Long(1)}, + "[true, 1]") + }) + + t.Run("TypeName", func(t *testing.T) { + t.Parallel() + tn := types.Set{}.TypeName() + testutil.Equals(t, tn, "set") + }) +} diff --git a/types/string.go b/types/string.go new file mode 100644 index 00000000..dd15a4a7 --- /dev/null +++ b/types/string.go @@ -0,0 +1,29 @@ +package types + +import ( + "encoding/json" + "strconv" +) + +// A String is a sequence of characters consisting of letters, numbers, or symbols. +type String string + +func (a String) Equal(bi Value) bool { + b, ok := bi.(String) + return ok && a == b +} + +// ExplicitMarshalJSON marshals the String into JSON. +func (v String) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } +func (v String) TypeName() string { return "string" } + +// String produces an unquoted string representation of the String, e.g. `hello`. +func (v String) String() string { + return string(v) +} + +// Cedar produces a valid Cedar language representation of the String, e.g. `"hello"`. +func (v String) Cedar() string { + return strconv.Quote(string(v)) +} +func (v String) deepClone() Value { return v } diff --git a/types/string_test.go b/types/string_test.go new file mode 100644 index 00000000..a995f64e --- /dev/null +++ b/types/string_test.go @@ -0,0 +1,34 @@ +package types_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestString(t *testing.T) { + t.Parallel() + + t.Run("Equal", func(t *testing.T) { + t.Parallel() + hello := types.String("hello") + hello2 := types.String("hello") + goodbye := types.String("goodbye") + testutil.FatalIf(t, !hello.Equal(hello), "%v not Equal to %v", hello, hello) + testutil.FatalIf(t, !hello.Equal(hello2), "%v not Equal to %v", hello, hello2) + testutil.FatalIf(t, hello.Equal(goodbye), "%v Equal to %v", hello, goodbye) + }) + + t.Run("string", func(t *testing.T) { + t.Parallel() + AssertValueString(t, types.String("hello"), `hello`) + AssertValueString(t, types.String("hello\ngoodbye"), "hello\ngoodbye") + }) + + t.Run("TypeName", func(t *testing.T) { + t.Parallel() + tn := types.String("hello").TypeName() + testutil.Equals(t, tn, "string") + }) +} diff --git a/types/testutil.go b/types/testutil.go deleted file mode 100644 index 64df358c..00000000 --- a/types/testutil.go +++ /dev/null @@ -1,38 +0,0 @@ -package types - -import ( - "testing" - - "github.com/cedar-policy/cedar-go/internal/testutil" -) - -// TODO: this file should not be public, it should be moved into the eval code - -func AssertValue(t *testing.T, got, want Value) { - t.Helper() - testutil.FatalIf( - t, - !((got == ZeroValue() && want == ZeroValue()) || - (got != ZeroValue() && want != ZeroValue() && got.Equal(want))), - "got %v want %v", got, want) -} - -func AssertBoolValue(t *testing.T, got Value, want bool) { - t.Helper() - testutil.Equals[Value](t, got, Boolean(want)) -} - -func AssertLongValue(t *testing.T, got Value, want int64) { - t.Helper() - testutil.Equals[Value](t, got, Long(want)) -} - -func AssertZeroValue(t *testing.T, got Value) { - t.Helper() - testutil.Equals(t, got, ZeroValue()) -} - -func AssertValueString(t *testing.T, v Value, want string) { - t.Helper() - testutil.Equals(t, v.String(), want) -} diff --git a/types/testutil_test.go b/types/testutil_test.go new file mode 100644 index 00000000..c5dbfa19 --- /dev/null +++ b/types/testutil_test.go @@ -0,0 +1,15 @@ +package types_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +// TODO: this file should not be public, it should be moved into the eval code + +func AssertValueString(t *testing.T, v types.Value, want string) { + t.Helper() + testutil.Equals(t, v.String(), want) +} diff --git a/types/value.go b/types/value.go index 288e0681..fc528ec8 100644 --- a/types/value.go +++ b/types/value.go @@ -1,17 +1,7 @@ package types import ( - "bytes" - "encoding/json" - "errors" "fmt" - "net/netip" - "strconv" - "strings" - "unicode" - - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" ) var ErrDecimal = fmt.Errorf("error parsing decimal value") @@ -34,690 +24,3 @@ type Value interface { func ZeroValue() Value { return nil } - -// A Boolean is a value that is either true or false. -type Boolean bool - -const ( - True = Boolean(true) - False = Boolean(false) -) - -func (a Boolean) Equal(bi Value) bool { - b, ok := bi.(Boolean) - return ok && a == b -} -func (v Boolean) TypeName() string { return "bool" } - -// String produces a string representation of the Boolean, e.g. `true`. -func (v Boolean) String() string { return v.Cedar() } - -// Cedar produces a valid Cedar language representation of the Boolean, e.g. `true`. -func (v Boolean) Cedar() string { - return fmt.Sprint(bool(v)) -} - -// ExplicitMarshalJSON marshals the Boolean into JSON. -func (v Boolean) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } -func (v Boolean) deepClone() Value { return v } - -// A Long is a whole number without decimals that can range from -9223372036854775808 to 9223372036854775807. -type Long int64 - -func (a Long) Equal(bi Value) bool { - b, ok := bi.(Long) - return ok && a == b -} - -// ExplicitMarshalJSON marshals the Long into JSON. -func (v Long) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } -func (v Long) TypeName() string { return "long" } - -// String produces a string representation of the Long, e.g. `42`. -func (v Long) String() string { return v.Cedar() } - -// Cedar produces a valid Cedar language representation of the Long, e.g. `42`. -func (v Long) Cedar() string { - return fmt.Sprint(int64(v)) -} -func (v Long) deepClone() Value { return v } - -// A String is a sequence of characters consisting of letters, numbers, or symbols. -type String string - -func (a String) Equal(bi Value) bool { - b, ok := bi.(String) - return ok && a == b -} - -// ExplicitMarshalJSON marshals the String into JSON. -func (v String) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } -func (v String) TypeName() string { return "string" } - -// String produces an unquoted string representation of the String, e.g. `hello`. -func (v String) String() string { - return string(v) -} - -// Cedar produces a valid Cedar language representation of the String, e.g. `"hello"`. -func (v String) Cedar() string { - return strconv.Quote(string(v)) -} -func (v String) deepClone() Value { return v } - -// A Set is a collection of elements that can be of the same or different types. -type Set []Value - -func (s Set) Contains(v Value) bool { - for _, e := range s { - if e.Equal(v) { - return true - } - } - return false -} - -// Equals returns true if the sets are Equal. -func (s Set) Equals(b Set) bool { return s.Equal(b) } - -func (as Set) Equal(bi Value) bool { - bs, ok := bi.(Set) - if !ok { - return false - } - for _, a := range as { - if !bs.Contains(a) { - return false - } - } - for _, b := range bs { - if !as.Contains(b) { - return false - } - } - return true -} - -func (v *explicitValue) UnmarshalJSON(b []byte) error { - return UnmarshalJSON(b, &v.Value) -} - -func (v *Set) UnmarshalJSON(b []byte) error { - var res []explicitValue - err := json.Unmarshal(b, &res) - if err != nil { - return err - } - for _, vv := range res { - *v = append(*v, vv.Value) - } - return nil -} - -// MarshalJSON marshals the Set into JSON, the marshaller uses the explicit JSON -// form for all the values in the Set. -func (v Set) MarshalJSON() ([]byte, error) { - w := &bytes.Buffer{} - w.WriteByte('[') - for i, vv := range v { - if i > 0 { - w.WriteByte(',') - } - b, err := vv.ExplicitMarshalJSON() - if err != nil { - return nil, err - } - w.Write(b) - } - w.WriteByte(']') - return w.Bytes(), nil -} - -// ExplicitMarshalJSON marshals the Set into JSON, the marshaller uses the -// explicit JSON form for all the values in the Set. -func (v Set) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } - -func (v Set) TypeName() string { return "set" } - -// String produces a string representation of the Set, e.g. `[1,2,3]`. -func (v Set) String() string { return v.Cedar() } - -// Cedar produces a valid Cedar language representation of the Set, e.g. `[1,2,3]`. -func (v Set) Cedar() string { - var sb strings.Builder - sb.WriteRune('[') - for i, elem := range v { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(elem.Cedar()) - } - sb.WriteRune(']') - return sb.String() -} -func (v Set) deepClone() Value { return v.DeepClone() } - -// DeepClone returns a deep clone of the Set. -func (v Set) DeepClone() Set { - if v == nil { - return v - } - res := make(Set, len(v)) - for i, vv := range v { - res[i] = vv.deepClone() - } - return res -} - -// A Record is a collection of attributes. Each attribute consists of a name and -// an associated value. Names are simple strings. Values can be of any type. -type Record map[string]Value - -// Equals returns true if the records are Equal. -func (r Record) Equals(b Record) bool { return r.Equal(b) } - -func (a Record) Equal(bi Value) bool { - b, ok := bi.(Record) - if !ok || len(a) != len(b) { - return false - } - for k, av := range a { - bv, ok := b[k] - if !ok || !av.Equal(bv) { - return false - } - } - return true -} - -func (v *Record) UnmarshalJSON(b []byte) error { - var res map[string]explicitValue - err := json.Unmarshal(b, &res) - if err != nil { - return err - } - *v = Record{} - for kk, vv := range res { - (*v)[kk] = vv.Value - } - return nil -} - -// MarshalJSON marshals the Record into JSON, the marshaller uses the explicit -// JSON form for all the values in the Record. -func (v Record) MarshalJSON() ([]byte, error) { - w := &bytes.Buffer{} - w.WriteByte('{') - keys := maps.Keys(v) - slices.Sort(keys) - for i, kk := range keys { - if i > 0 { - w.WriteByte(',') - } - kb, _ := json.Marshal(kk) // json.Marshal cannot error on strings - w.Write(kb) - w.WriteByte(':') - vv := v[kk] - vb, err := vv.ExplicitMarshalJSON() - if err != nil { - return nil, err - } - w.Write(vb) - } - w.WriteByte('}') - return w.Bytes(), nil -} - -// ExplicitMarshalJSON marshals the Record into JSON, the marshaller uses the -// explicit JSON form for all the values in the Record. -func (v Record) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } -func (r Record) TypeName() string { return "record" } - -// String produces a string representation of the Record, e.g. `{"a":1,"b":2,"c":3}`. -func (r Record) String() string { return r.Cedar() } - -// Cedar produces a valid Cedar language representation of the Record, e.g. `{"a":1,"b":2,"c":3}`. -func (r Record) Cedar() string { - var sb strings.Builder - sb.WriteRune('{') - first := true - keys := maps.Keys(r) - slices.Sort(keys) - for _, k := range keys { - v := r[k] - if !first { - sb.WriteString(", ") - } - first = false - sb.WriteString(strconv.Quote(k)) - sb.WriteString(": ") - sb.WriteString(v.Cedar()) - } - sb.WriteRune('}') - return sb.String() -} -func (v Record) deepClone() Value { return v.DeepClone() } - -// DeepClone returns a deep clone of the Record. -func (v Record) DeepClone() Record { - if v == nil { - return v - } - res := make(Record, len(v)) - for k, vv := range v { - res[k] = vv.deepClone() - } - return res -} - -// An EntityUID is the identifier for a principal, action, or resource. -type EntityUID struct { - Type EntityType - ID string -} - -func NewEntityUID(typ EntityType, id string) EntityUID { - return EntityUID{ - Type: typ, - ID: id, - } -} - -// IsZero returns true if the EntityUID has an empty Type and ID. -func (a EntityUID) IsZero() bool { - return a.Type == "" && a.ID == "" -} - -func (a EntityUID) Equal(bi Value) bool { - b, ok := bi.(EntityUID) - return ok && a == b -} -func (v EntityUID) TypeName() string { return fmt.Sprintf("(entity of type `%s`)", v.Type) } - -// String produces a string representation of the EntityUID, e.g. `Type::"id"`. -func (v EntityUID) String() string { return v.Cedar() } - -// Cedar produces a valid Cedar language representation of the EntityUID, e.g. `Type::"id"`. -func (v EntityUID) Cedar() string { - return v.Type.String() + "::" + strconv.Quote(v.ID) -} - -func (v *EntityUID) UnmarshalJSON(b []byte) error { - // TODO: review after adding support for schemas - var res entityValueJSON - if err := json.Unmarshal(b, &res); err != nil { - return err - } - if res.Entity != nil { - v.Type = res.Entity.Type - v.ID = res.Entity.ID - return nil - } else if res.Type != nil && res.ID != nil { // require both Type and ID to parse "implicit" JSON - v.Type = *res.Type - v.ID = *res.ID - return nil - } - return errJSONEntityNotFound -} - -// ExplicitMarshalJSON marshals the EntityUID into JSON using the implicit form. -func (v EntityUID) MarshalJSON() ([]byte, error) { - return json.Marshal(entityValueJSON{ - Type: &v.Type, - ID: &v.ID, - }) -} - -// ExplicitMarshalJSON marshals the EntityUID into JSON using the explicit form. -func (v EntityUID) ExplicitMarshalJSON() ([]byte, error) { - return json.Marshal(entityValueJSON{ - Entity: &extEntity{ - Type: v.Type, - ID: v.ID, - }, - }) -} -func (v EntityUID) deepClone() Value { return v } - -// EntityType is the type portion of an EntityUID -type EntityType string - -func (a EntityType) Equal(bi Value) bool { - b, ok := bi.(EntityType) - return ok && a == b -} -func (v EntityType) TypeName() string { return fmt.Sprintf("(EntityType of type `%s`)", v) } - -func (v EntityType) String() string { return string(v) } -func (v EntityType) Cedar() string { return string(v) } -func (v EntityType) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(string(v)) } -func (v EntityType) deepClone() Value { return v } - -func EntityTypeFromSlice(v []string) EntityType { - return EntityType(strings.Join(v, "::")) -} - -// A Decimal is a value with both a whole number part and a decimal part of no -// more than four digits. In Go this is stored as an int64, the precision is -// defined by the constant DecimalPrecision. -type Decimal int64 - -// DecimalPrecision is the precision of a Decimal. -const DecimalPrecision = 10000 - -// ParseDecimal takes a string representation of a decimal number and converts it into a Decimal type. -func ParseDecimal(s string) (Decimal, error) { - // Check for empty string. - if len(s) == 0 { - return Decimal(0), fmt.Errorf("%w: string too short", ErrDecimal) - } - i := 0 - - // Parse an optional '-'. - negative := false - if s[i] == '-' { - negative = true - i++ - if i == len(s) { - return Decimal(0), fmt.Errorf("%w: string too short", ErrDecimal) - } - } - - // Parse the required first digit. - c := rune(s[i]) - if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) - } - integer := int64(c - '0') - i++ - - // Parse any other digits, ending with i pointing to '.'. - for ; ; i++ { - if i == len(s) { - return Decimal(0), fmt.Errorf("%w: string missing decimal point", ErrDecimal) - } - c = rune(s[i]) - if c == '.' { - break - } - if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) - } - integer = 10*integer + int64(c-'0') - if integer > 922337203685477 { - return Decimal(0), fmt.Errorf("%w: overflow", ErrDecimal) - } - } - - // Advance past the '.'. - i++ - - // Parse the fraction part - fraction := int64(0) - fractionDigits := 0 - for ; i < len(s); i++ { - c = rune(s[i]) - if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) - } - fraction = 10*fraction + int64(c-'0') - fractionDigits++ - } - - // Adjust the fraction part based on how many digits we parsed. - switch fractionDigits { - case 0: - return Decimal(0), fmt.Errorf("%w: missing digits after decimal point", ErrDecimal) - case 1: - fraction *= 1000 - case 2: - fraction *= 100 - case 3: - fraction *= 10 - case 4: - default: - return Decimal(0), fmt.Errorf("%w: too many digits after decimal point", ErrDecimal) - } - - // Check for overflow before we put the number together. - if integer >= 922337203685477 && (fraction > 5808 || (!negative && fraction == 5808)) { - return Decimal(0), fmt.Errorf("%w: overflow", ErrDecimal) - } - - // Put the number together. - if negative { - // Doing things in this order keeps us from overflowing when parsing - // -922337203685477.5808. This isn't technically necessary because the - // go spec defines arithmetic to be well-defined when overflowing. - // However, doing things this way doesn't hurt, so let's be pedantic. - return Decimal(DecimalPrecision*-integer - fraction), nil - } else { - return Decimal(DecimalPrecision*integer + fraction), nil - } -} - -func (a Decimal) Equal(bi Value) bool { - b, ok := bi.(Decimal) - return ok && a == b -} - -func (v Decimal) TypeName() string { return "decimal" } - -// Cedar produces a valid Cedar language representation of the Decimal, e.g. `decimal("12.34")`. -func (v Decimal) Cedar() string { return `decimal("` + v.String() + `")` } - -// String produces a string representation of the Decimal, e.g. `12.34`. -func (v Decimal) String() string { - var res string - if v < 0 { - // Make sure we don't overflow here. Also, go truncates towards zero. - integer := v / DecimalPrecision - decimal := integer*DecimalPrecision - v - res = fmt.Sprintf("-%d.%04d", -integer, decimal) - } else { - res = fmt.Sprintf("%d.%04d", v/DecimalPrecision, v%DecimalPrecision) - } - - // Trim off up to three trailing zeros. - right := len(res) - for trimmed := 0; right-1 >= 0 && trimmed < 3; right, trimmed = right-1, trimmed+1 { - if res[right-1] != '0' { - break - } - } - return res[:right] -} - -func (v *Decimal) UnmarshalJSON(b []byte) error { - var arg string - if len(b) > 0 && b[0] == '"' { - if err := json.Unmarshal(b, &arg); err != nil { - return errors.Join(errJSONDecode, err) - } - } else { - // NOTE: cedar supports two other forms, for now we're only supporting the smallest implicit and explicit form. - // The following are not supported: - // "decimal(\"1234.5678\")" - // {"fn":"decimal","arg":"1234.5678"} - var res extValueJSON - if err := json.Unmarshal(b, &res); err != nil { - return errors.Join(errJSONDecode, err) - } - if res.Extn == nil { - return errJSONExtNotFound - } - if res.Extn.Fn != "decimal" { - return errJSONExtFnMatch - } - arg = res.Extn.Arg - } - vv, err := ParseDecimal(arg) - if err != nil { - return err - } - *v = vv - return nil -} - -// ExplicitMarshalJSON marshals the Decimal into JSON using the implicit form. -func (v Decimal) MarshalJSON() ([]byte, error) { return []byte(`"` + v.String() + `"`), nil } - -// ExplicitMarshalJSON marshals the Decimal into JSON using the explicit form. -func (v Decimal) ExplicitMarshalJSON() ([]byte, error) { - return json.Marshal(extValueJSON{ - Extn: &extn{ - Fn: "decimal", - Arg: v.String(), - }, - }) -} -func (v Decimal) deepClone() Value { return v } - -// An IPAddr is value that represents an IP address. It can be either IPv4 or IPv6. -// The value can represent an individual address or a range of addresses. -type IPAddr netip.Prefix - -// ParseIPAddr takes a string representation of an IP address and converts it into an IPAddr type. -func ParseIPAddr(s string) (IPAddr, error) { - // We disallow IPv4-mapped IPv6 addresses in dotted notation because Cedar does. - if strings.Count(s, ":") >= 2 && strings.Count(s, ".") >= 2 { - return IPAddr{}, fmt.Errorf("%w: cannot parse IPv4 addresses embedded in IPv6 addresses", ErrIP) - } else if net, err := netip.ParsePrefix(s); err == nil { - return IPAddr(net), nil - } else if addr, err := netip.ParseAddr(s); err == nil { - return IPAddr(netip.PrefixFrom(addr, addr.BitLen())), nil - } else { - return IPAddr{}, fmt.Errorf("%w: error parsing IP address %s", ErrIP, s) - } -} - -func (a IPAddr) Equal(bi Value) bool { - b, ok := bi.(IPAddr) - return ok && a == b -} - -func (v IPAddr) TypeName() string { return "IP" } - -// Cedar produces a valid Cedar language representation of the IPAddr, e.g. `ip("127.0.0.1")`. -func (v IPAddr) Cedar() string { return `ip("` + v.String() + `")` } - -// String produces a string representation of the IPAddr, e.g. `127.0.0.1`. -func (v IPAddr) String() string { - if v.Prefix().Bits() == v.Addr().BitLen() { - return v.Addr().String() - } - return v.Prefix().String() -} - -func (v IPAddr) Prefix() netip.Prefix { - return netip.Prefix(v) -} - -func (v IPAddr) IsIPv4() bool { - return v.Addr().Is4() -} - -func (v IPAddr) IsIPv6() bool { - return v.Addr().Is6() -} - -func (v IPAddr) IsLoopback() bool { - // This comment is in the Cedar Rust implementation: - // - // Loopback addresses are "127.0.0.0/8" for IpV4 and "::1" for IpV6 - // - // Unlike the implementation of `is_multicast`, we don't need to test prefix - // - // The reason for IpV6 is obvious: There's only one loopback address - // - // The reason for IpV4 is that provided the truncated ip address is a - // loopback address, its prefix cannot be less than 8 because - // otherwise its more significant byte cannot be 127 - return v.Prefix().Masked().Addr().IsLoopback() -} - -func (v IPAddr) Addr() netip.Addr { - return netip.Prefix(v).Addr() -} - -func (v IPAddr) IsMulticast() bool { - // This comment is in the Cedar Rust implementation: - // - // Multicast addresses are "224.0.0.0/4" for IpV4 and "ff00::/8" for - // IpV6 - // - // If an IpNet's addresses are multicast addresses, calling - // `is_in_range()` over it and its associated net above should - // evaluate to true - // - // The implementation uses the property that if `ip1/prefix1` is in - // range `ip2/prefix2`, then `ip1` is in `ip2/prefix2` and `prefix1 >= - // prefix2` - var min_prefix_len int - if v.IsIPv4() { - min_prefix_len = 4 - } else { - min_prefix_len = 8 - } - return v.Addr().IsMulticast() && v.Prefix().Bits() >= min_prefix_len -} - -func (c IPAddr) Contains(o IPAddr) bool { - return c.Prefix().Contains(o.Addr()) && c.Prefix().Bits() <= o.Prefix().Bits() -} - -func (v *IPAddr) UnmarshalJSON(b []byte) error { - var arg string - if len(b) > 0 && b[0] == '"' { - if err := json.Unmarshal(b, &arg); err != nil { - return errors.Join(errJSONDecode, err) - } - } else { - // NOTE: cedar supports two other forms, for now we're only supporting the smallest implicit explicit form. - // The following are not supported: - // "ip(\"192.168.0.42\")" - // {"fn":"ip","arg":"192.168.0.42"} - var res extValueJSON - if err := json.Unmarshal(b, &res); err != nil { - return errors.Join(errJSONDecode, err) - } - if res.Extn == nil { - return errJSONExtNotFound - } - if res.Extn.Fn != "ip" { - return errJSONExtFnMatch - } - arg = res.Extn.Arg - } - vv, err := ParseIPAddr(arg) - if err != nil { - return err - } - *v = vv - return nil -} - -// ExplicitMarshalJSON marshals the IPAddr into JSON using the implicit form. -func (v IPAddr) MarshalJSON() ([]byte, error) { return []byte(`"` + v.String() + `"`), nil } - -// ExplicitMarshalJSON marshals the IPAddr into JSON using the explicit form. -func (v IPAddr) ExplicitMarshalJSON() ([]byte, error) { - if v.Prefix().Bits() == v.Prefix().Addr().BitLen() { - return json.Marshal(extValueJSON{ - Extn: &extn{ - Fn: "ip", - Arg: v.Addr().String(), - }, - }) - } - return json.Marshal(extValueJSON{ - Extn: &extn{ - Fn: "ip", - Arg: v.String(), - }, - }) -} - -// in this case, netip.Prefix does contain a pointer, but -// the interface given is immutable, so it is safe to return -func (v IPAddr) deepClone() Value { return v } diff --git a/types/value_test.go b/types/value_test.go index 3d026d9f..2712195d 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -1,631 +1,11 @@ package types import ( - "fmt" "testing" "github.com/cedar-policy/cedar-go/internal/testutil" ) -func TestBool(t *testing.T) { - t.Parallel() - - t.Run("Equal", func(t *testing.T) { - t.Parallel() - t1 := Boolean(true) - t2 := Boolean(true) - f := Boolean(false) - zero := Long(0) - testutil.FatalIf(t, !t1.Equal(t1), "%v not Equal to %v", t1, t1) - testutil.FatalIf(t, !t1.Equal(t2), "%v not Equal to %v", t1, t2) - testutil.FatalIf(t, t1.Equal(f), "%v Equal to %v", t1, f) - testutil.FatalIf(t, f.Equal(t1), "%v Equal to %v", f, t1) - testutil.FatalIf(t, f.Equal(zero), "%v Equal to %v", f, zero) - }) - - t.Run("string", func(t *testing.T) { - t.Parallel() - AssertValueString(t, Boolean(true), "true") - }) - - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := Boolean(true).TypeName() - testutil.Equals(t, tn, "bool") - }) -} - -func TestLong(t *testing.T) { - t.Parallel() - - t.Run("Equal", func(t *testing.T) { - t.Parallel() - one := Long(1) - one2 := Long(1) - zero := Long(0) - f := Boolean(false) - testutil.FatalIf(t, !one.Equal(one), "%v not Equal to %v", one, one) - testutil.FatalIf(t, !one.Equal(one2), "%v not Equal to %v", one, one2) - testutil.FatalIf(t, one.Equal(zero), "%v Equal to %v", one, zero) - testutil.FatalIf(t, zero.Equal(one), "%v Equal to %v", zero, one) - testutil.FatalIf(t, zero.Equal(f), "%v Equal to %v", zero, f) - }) - - t.Run("string", func(t *testing.T) { - t.Parallel() - AssertValueString(t, Long(1), "1") - }) - - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := Long(1).TypeName() - testutil.Equals(t, tn, "long") - }) -} - -func TestString(t *testing.T) { - t.Parallel() - - t.Run("Equal", func(t *testing.T) { - t.Parallel() - hello := String("hello") - hello2 := String("hello") - goodbye := String("goodbye") - testutil.FatalIf(t, !hello.Equal(hello), "%v not Equal to %v", hello, hello) - testutil.FatalIf(t, !hello.Equal(hello2), "%v not Equal to %v", hello, hello2) - testutil.FatalIf(t, hello.Equal(goodbye), "%v Equal to %v", hello, goodbye) - }) - - t.Run("string", func(t *testing.T) { - t.Parallel() - AssertValueString(t, String("hello"), `hello`) - AssertValueString(t, String("hello\ngoodbye"), "hello\ngoodbye") - }) - - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := String("hello").TypeName() - testutil.Equals(t, tn, "string") - }) -} - -func TestSet(t *testing.T) { - t.Parallel() - - t.Run("Equal", func(t *testing.T) { - t.Parallel() - empty := Set{} - empty2 := Set{} - oneTrue := Set{Boolean(true)} - oneTrue2 := Set{Boolean(true)} - oneFalse := Set{Boolean(false)} - nestedOnce := Set{empty, oneTrue, oneFalse} - nestedOnce2 := Set{empty, oneTrue, oneFalse} - nestedTwice := Set{empty, oneTrue, oneFalse, nestedOnce} - nestedTwice2 := Set{empty, oneTrue, oneFalse, nestedOnce} - oneTwoThree := Set{ - Long(1), Long(2), Long(3), - } - threeTwoTwoOne := Set{ - Long(3), Long(2), Long(2), Long(1), - } - - testutil.FatalIf(t, !empty.Equals(empty), "%v not Equal to %v", empty, empty) - testutil.FatalIf(t, !empty.Equals(empty2), "%v not Equal to %v", empty, empty2) - testutil.FatalIf(t, !oneTrue.Equals(oneTrue), "%v not Equal to %v", oneTrue, oneTrue) - testutil.FatalIf(t, !oneTrue.Equals(oneTrue2), "%v not Equal to %v", oneTrue, oneTrue2) - testutil.FatalIf(t, !nestedOnce.Equals(nestedOnce), "%v not Equal to %v", nestedOnce, nestedOnce) - testutil.FatalIf(t, !nestedOnce.Equals(nestedOnce2), "%v not Equal to %v", nestedOnce, nestedOnce2) - testutil.FatalIf(t, !nestedTwice.Equals(nestedTwice), "%v not Equal to %v", nestedTwice, nestedTwice) - testutil.FatalIf(t, !nestedTwice.Equals(nestedTwice2), "%v not Equal to %v", nestedTwice, nestedTwice2) - testutil.FatalIf(t, !oneTwoThree.Equals(threeTwoTwoOne), "%v not Equal to %v", oneTwoThree, threeTwoTwoOne) - - testutil.FatalIf(t, empty.Equals(oneFalse), "%v Equal to %v", empty, oneFalse) - testutil.FatalIf(t, oneTrue.Equals(oneFalse), "%v Equal to %v", oneTrue, oneFalse) - testutil.FatalIf(t, nestedOnce.Equals(nestedTwice), "%v Equal to %v", nestedOnce, nestedTwice) - }) - - t.Run("string", func(t *testing.T) { - t.Parallel() - AssertValueString(t, Set{}, "[]") - AssertValueString( - t, - Set{Boolean(true), Long(1)}, - "[true, 1]") - }) - - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := Set{}.TypeName() - testutil.Equals(t, tn, "set") - }) -} - -func TestRecord(t *testing.T) { - t.Parallel() - - t.Run("Equal", func(t *testing.T) { - t.Parallel() - empty := Record{} - empty2 := Record{} - twoElems := Record{ - "foo": Boolean(true), - "bar": String("blah"), - } - twoElems2 := Record{ - "foo": Boolean(true), - "bar": String("blah"), - } - differentValues := Record{ - "foo": Boolean(false), - "bar": String("blaz"), - } - differentKeys := Record{ - "foo": Boolean(false), - "bar": Long(1), - } - nested := Record{ - "one": Long(1), - "two": Long(2), - "nest": twoElems, - } - nested2 := Record{ - "one": Long(1), - "two": Long(2), - "nest": twoElems, - } - - testutil.FatalIf(t, !empty.Equals(empty), "%v not Equal to %v", empty, empty) - testutil.FatalIf(t, !empty.Equals(empty2), "%v not Equal to %v", empty, empty2) - - testutil.FatalIf(t, !twoElems.Equals(twoElems), "%v not Equal to %v", twoElems, twoElems) - testutil.FatalIf(t, !twoElems.Equals(twoElems2), "%v not Equal to %v", twoElems, twoElems2) - - testutil.FatalIf(t, !nested.Equals(nested), "%v not Equal to %v", nested, nested) - testutil.FatalIf(t, !nested.Equals(nested2), "%v not Equal to %v", nested, nested2) - - testutil.FatalIf(t, nested.Equals(twoElems), "%v Equal to %v", nested, twoElems) - testutil.FatalIf(t, twoElems.Equals(differentValues), "%v Equal to %v", twoElems, differentValues) - testutil.FatalIf(t, twoElems.Equals(differentKeys), "%v Equal to %v", twoElems, differentKeys) - }) - - t.Run("string", func(t *testing.T) { - t.Parallel() - AssertValueString(t, Record{}, "{}") - AssertValueString( - t, - Record{"foo": Boolean(true)}, - `{"foo": true}`) - AssertValueString( - t, - Record{ - "foo": Boolean(true), - "bar": String("blah"), - }, - `{"bar": "blah", "foo": true}`) - }) - - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := Record{}.TypeName() - testutil.Equals(t, tn, "record") - }) -} - -func TestEntity(t *testing.T) { - t.Parallel() - - t.Run("Equal", func(t *testing.T) { - t.Parallel() - twoElems := EntityUID{"type", "id"} - twoElems2 := EntityUID{"type", "id"} - differentValues := EntityUID{"asdf", "vfds"} - testutil.FatalIf(t, !twoElems.Equal(twoElems), "%v not Equal to %v", twoElems, twoElems) - testutil.FatalIf(t, !twoElems.Equal(twoElems2), "%v not Equal to %v", twoElems, twoElems2) - testutil.FatalIf(t, twoElems.Equal(differentValues), "%v Equal to %v", twoElems, differentValues) - }) - - t.Run("string", func(t *testing.T) { - t.Parallel() - AssertValueString(t, EntityUID{Type: "type", ID: "id"}, `type::"id"`) - AssertValueString(t, EntityUID{Type: "namespace::type", ID: "id"}, `namespace::type::"id"`) - }) - - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := EntityUID{"T", "id"}.TypeName() - testutil.Equals(t, tn, "(entity of type `T`)") - }) -} - -func TestDecimal(t *testing.T) { - t.Parallel() - { - tests := []struct{ in, out string }{ - {"1.2345", "1.2345"}, - {"1.2340", "1.234"}, - {"1.2300", "1.23"}, - {"1.2000", "1.2"}, - {"1.0000", "1.0"}, - {"1.234", "1.234"}, - {"1.230", "1.23"}, - {"1.200", "1.2"}, - {"1.000", "1.0"}, - {"1.23", "1.23"}, - {"1.20", "1.2"}, - {"1.00", "1.0"}, - {"1.2", "1.2"}, - {"1.0", "1.0"}, - {"01.2345", "1.2345"}, - {"01.2340", "1.234"}, - {"01.2300", "1.23"}, - {"01.2000", "1.2"}, - {"01.0000", "1.0"}, - {"01.234", "1.234"}, - {"01.230", "1.23"}, - {"01.200", "1.2"}, - {"01.000", "1.0"}, - {"01.23", "1.23"}, - {"01.20", "1.2"}, - {"01.00", "1.0"}, - {"01.2", "1.2"}, - {"01.0", "1.0"}, - {"1234.5678", "1234.5678"}, - {"1234.5670", "1234.567"}, - {"1234.5600", "1234.56"}, - {"1234.5000", "1234.5"}, - {"1234.0000", "1234.0"}, - {"1234.567", "1234.567"}, - {"1234.560", "1234.56"}, - {"1234.500", "1234.5"}, - {"1234.000", "1234.0"}, - {"1234.56", "1234.56"}, - {"1234.50", "1234.5"}, - {"1234.00", "1234.0"}, - {"1234.5", "1234.5"}, - {"1234.0", "1234.0"}, - {"0.0", "0.0"}, - {"00.0", "0.0"}, - {"000000000000000000000000000000000000000000000000000000000000000000.0", "0.0"}, - {"922337203685477.5807", "922337203685477.5807"}, - {"-922337203685477.5808", "-922337203685477.5808"}, - } - for _, tt := range tests { - tt := tt - t.Run(fmt.Sprintf("%s->%s", tt.in, tt.out), func(t *testing.T) { - t.Parallel() - d, err := ParseDecimal(tt.in) - testutil.OK(t, err) - testutil.Equals(t, d.String(), tt.out) - }) - } - } - - { - tests := []struct{ in, errStr string }{ - {"", "error parsing decimal value: string too short"}, - {"-", "error parsing decimal value: string too short"}, - {"a", "error parsing decimal value: unexpected character 'a'"}, - {"-a", "error parsing decimal value: unexpected character 'a'"}, - {"'", `error parsing decimal value: unexpected character '\''`}, - {`-\\`, `error parsing decimal value: unexpected character '\\'`}, - {"0", "error parsing decimal value: string missing decimal point"}, - {"1a", "error parsing decimal value: unexpected character 'a'"}, - {"1a", "error parsing decimal value: unexpected character 'a'"}, - {"1.", "error parsing decimal value: missing digits after decimal point"}, - {"1.00000", "error parsing decimal value: too many digits after decimal point"}, - {"1.a", "error parsing decimal value: unexpected character 'a'"}, - {"1.0a", "error parsing decimal value: unexpected character 'a'"}, - {"1.0000a", "error parsing decimal value: unexpected character 'a'"}, - {"1.0000a", "error parsing decimal value: unexpected character 'a'"}, - - {"1000000000000000.0", "error parsing decimal value: overflow"}, - {"-1000000000000000.0", "error parsing decimal value: overflow"}, - {"922337203685477.5808", "error parsing decimal value: overflow"}, - {"922337203685478.0", "error parsing decimal value: overflow"}, - {"-922337203685477.5809", "error parsing decimal value: overflow"}, - {"-922337203685478.0", "error parsing decimal value: overflow"}, - } - for _, tt := range tests { - tt := tt - t.Run(fmt.Sprintf("%s->%s", tt.in, tt.errStr), func(t *testing.T) { - t.Parallel() - _, err := ParseDecimal(tt.in) - testutil.AssertError(t, err, ErrDecimal) - testutil.Equals(t, err.Error(), tt.errStr) - }) - } - } - - t.Run("Equal", func(t *testing.T) { - t.Parallel() - one := Decimal(10000) - one2 := Decimal(10000) - zero := Decimal(0) - f := Boolean(false) - testutil.FatalIf(t, !one.Equal(one), "%v not Equal to %v", one, one) - testutil.FatalIf(t, !one.Equal(one2), "%v not Equal to %v", one, one2) - testutil.FatalIf(t, one.Equal(zero), "%v Equal to %v", one, zero) - testutil.FatalIf(t, zero.Equal(one), "%v Equal to %v", zero, one) - testutil.FatalIf(t, zero.Equal(f), "%v Equal to %v", zero, f) - }) - - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := Decimal(0).TypeName() - testutil.Equals(t, tn, "decimal") - }) -} - -func TestIP(t *testing.T) { - t.Parallel() - t.Run("ParseAndString", func(t *testing.T) { - t.Parallel() - - tests := []struct { - in string - parses bool - out string - }{ - {"0.0.0.0", true, "0.0.0.0"}, - {"0.0.0.1", true, "0.0.0.1"}, - {"127.0.0.1", true, "127.0.0.1"}, - {"127.0.0.1/32", true, "127.0.0.1"}, - {"127.0.0.1/24", true, "127.0.0.1/24"}, - {"127.1.2.3/8", true, "127.1.2.3/8"}, - {"::/128", true, "::"}, - {"::1/128", true, "::1"}, - {"2001:db8::1", true, "2001:db8::1"}, - {"2001:db8::1:0:0:1", true, "2001:db8::1:0:0:1"}, - {"::ffff:192.0.2.128", false, ""}, - {"::ffff:c000:0280", true, "::ffff:192.0.2.128"}, - {"2001:db8::1/32", true, "2001:db8::1/32"}, - {"2001:db8::1:0:0:1/96", true, "2001:db8::1:0:0:1/96"}, - {"::ffff:192.0.2.128/24", false, ""}, - {"::ffff:192.0.2.128/120", false, ""}, - {"::ffff:c000:0280/24", true, "::ffff:192.0.2.128/24"}, - {"::ffff:c000:0280/120", true, "::ffff:192.0.2.128/120"}, - {"6b6b:f00::32ff:ffff:6368/00", false, ""}, // leading zero(s) - {"garbage", false, ""}, - {"c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5/68", true, "c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5/68"}, - } - for _, tt := range tests { - tt := tt - var testName string - if tt.parses { - testName = fmt.Sprintf("%s-parses-and-prints-as-%s", tt.in, tt.out) - } else { - testName = fmt.Sprintf("%s-does-not-parse", tt.in) - } - t.Run(testName, func(t *testing.T) { - t.Parallel() - i, err := ParseIPAddr(tt.in) - if tt.parses { - testutil.OK(t, err) - testutil.Equals(t, i.String(), tt.out) - } else { - testutil.Error(t, err) - } - }) - } - }) - - t.Run("Equal", func(t *testing.T) { - t.Parallel() - tests := []struct { - lhs, rhs string - equal bool - }{ - {"0.0.0.0", "0.0.0.0", true}, - {"0.0.0.0", "0.0.0.0/32", true}, - {"127.0.0.1", "127.0.0.1", true}, - {"127.0.0.1", "127.0.0.1/32", true}, - {"::", "::", true}, - {"::", "::/128", true}, - {"::1", "::1", true}, - {"::1", "::1/128", true}, - {"::", "0.0.0.0", false}, - {"::1", "127.0.0.1", false}, - {"::ffff:c000:0280", "192.0.2.128", false}, - {"1.2.3.4", "1.2.3.4", true}, - {"1.2.3.4", "1.2.3.4/32", true}, - {"1.2.3.4/32", "1.2.3.4/32", true}, - {"1.2.3.4/24", "1.2.3.4/24", true}, - {"1.2.3.0/24", "1.2.3.255/24", false}, - {"1.2.3.0/24", "1.2.3.0/25", false}, - {"::ffff:c000:0280/24", "::/24", false}, - {"::ffff:c000:0280/120", "192.0.2.0/24", false}, - {"2001:db8::1/32", "2001:db8::/32", false}, - {"2001:db8::1:0:0:1/96", "2001:db8:0:0:1::/96", false}, - {"c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5:c5c5/68", "c5c5:c5c5:c5c5:c5c5:c5c5:5cc5:c5c5:c5c5/68", false}, - } - for _, tt := range tests { - tt := tt - t.Run(fmt.Sprintf("ip(%v).Equal(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) { - t.Parallel() - lhs, err := ParseIPAddr(tt.lhs) - testutil.OK(t, err) - rhs, err := ParseIPAddr(tt.rhs) - testutil.OK(t, err) - equal := lhs.Equal(rhs) - if equal != tt.equal { - t.Fatalf("expected ip(%v).Equal(ip(%v)) to be %v instead of %v", tt.lhs, tt.rhs, tt.equal, equal) - } - if equal { - testutil.FatalIf( - t, - !lhs.Contains(rhs), - "ip(%v) and ip(%v) compare Equal but !ip(%v).contains(ip(%v))", tt.lhs, tt.rhs, tt.lhs, tt.rhs) - testutil.FatalIf( - t, - !rhs.Contains(lhs), - "ip(%v) and ip(%v) compare Equal but !ip(%v).contains(ip(%v))", tt.rhs, tt.lhs, tt.rhs, tt.lhs) - } - }) - } - }) - - t.Run("isIPv4", func(t *testing.T) { - t.Parallel() - tests := []struct { - val string - isIPv4, isIPv6 bool - }{ - {"0.0.0.0", true, false}, - {"0.0.0.0/32", true, false}, - {"127.0.0.1", true, false}, - {"127.0.0.1/32", true, false}, - {"::", false, true}, - {"::1", false, true}, - {"::/128", false, true}, - {"::1/128", false, true}, - {"::ffff:c000:0280", false, true}, - {"::ffff:c000:0280/128", false, true}, - {"::ffff:c000:0280/24", false, true}, - {"2001:db8::1", false, true}, - {"2001:db8::1:0:0:1", false, true}, - {"2001:db8::1/32", false, true}, - {"2001:db8::1:0:0:1/96", false, true}, - } - for _, tt := range tests { - tt := tt - t.Run(fmt.Sprintf("ip(%v).isIPv{4,6}()", tt.val), func(t *testing.T) { - t.Parallel() - val, err := ParseIPAddr(tt.val) - testutil.OK(t, err) - isIPv4 := val.IsIPv4() - if isIPv4 != tt.isIPv4 { - t.Fatalf("expected ip(%v).isIPv4() to be %v instead of %v", tt.val, tt.isIPv4, isIPv4) - } - isIPv6 := val.IsIPv6() - if isIPv6 != tt.isIPv6 { - t.Fatalf("expected ip(%v).isIPv6() to be %v instead of %v", tt.val, tt.isIPv6, isIPv6) - } - }) - } - }) - - t.Run("isLoopback", func(t *testing.T) { - t.Parallel() - tests := []struct { - val string - isLoopback bool - }{ - {"0.0.0.0", false}, - {"127.0.0.1", true}, - {"127.0.0.2", true}, - {"127.0.0.1/32", true}, - {"127.0.0.1/24", true}, - {"127.0.0.1/8", true}, - {"127.0.0.1/7", false}, - {"::", false}, - {"::1", true}, - {"::/128", false}, - {"::1/128", true}, - {"::1/127", false}, - {"::ffff:8000:0001", false}, - {"::ffff:8000:0002", false}, - {"::ffff:8000:0001/128", false}, - {"::ffff:8000:0002/128", false}, - {"::ffff:8000:0001/104", false}, - {"::ffff:8000:0002/104", false}, - {"::ffff:8000:0001/100", false}, - {"::ffff:8000:0002/100", false}, - {"2001:db8::1", false}, - {"2001:db8::1:0:0:1", false}, - {"2001:db8::1/32", false}, - {"2001:db8::1:0:0:1/96", false}, - } - for _, tt := range tests { - tt := tt - t.Run(fmt.Sprintf("ip(%v).isLoopback()", tt.val), func(t *testing.T) { - t.Parallel() - val, err := ParseIPAddr(tt.val) - testutil.OK(t, err) - isLoopback := val.IsLoopback() - if isLoopback != tt.isLoopback { - t.Fatalf("expected ip(%v).isLoopback() to be %v instead of %v", tt.val, tt.isLoopback, isLoopback) - } - }) - } - }) - - t.Run("isMulticast", func(t *testing.T) { - t.Parallel() - tests := []struct { - val string - isMulticast bool - }{ - {"0.0.0.0", false}, - {"127.0.0.1", false}, - {"223.255.255.255", false}, - {"224.0.0.0", true}, - {"239.255.255.255", true}, - {"240.0.0.0", false}, - {"feff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", false}, - {"ff00::", true}, - {"ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", true}, - {"ff00::/8", true}, - {"ff00::/7", false}, - {"224.0.0.0/4", true}, - {"224.0.0.0/3", false}, - } - for _, tt := range tests { - tt := tt - t.Run(fmt.Sprintf("ip(%v).isMulticast()", tt.val), func(t *testing.T) { - t.Parallel() - val, err := ParseIPAddr(tt.val) - testutil.OK(t, err) - isMulticast := val.IsMulticast() - if isMulticast != tt.isMulticast { - t.Fatalf("expected ip(%v).isMulticast() to be %v instead of %v", tt.val, tt.isMulticast, isMulticast) - } - }) - } - }) - - t.Run("contains", func(t *testing.T) { - t.Parallel() - tests := []struct { - lhs, rhs string - contains bool - }{ - {"0.0.0.0/31", "0.0.0.0", true}, - {"0.0.0.0", "0.0.0.0/31", false}, - {"255.255.0.0/16", "255.255.255.255", true}, - {"255.255.0.0/16", "255.255.255.248/28", true}, - {"255.255.0.0/16", "255.255.255.0/24", true}, - {"255.255.0.0/16", "255.255.248.0/20", true}, - {"255.255.0.0/16", "255.255.0.0/16", true}, - {"255.255.0.0/16", "255.254.0.0/15", false}, - {"255.255.0.0/16", "255.254.255.0/24", false}, - {"::ffff:c000:0280", "192.0.2.128", false}, - {"2001:db8::/120", "2001:db8::2", true}, - {"2001:db8::/64", "2001:db8:0:0:dead:f00d::/96", true}, - } - for _, tt := range tests { - tt := tt - t.Run(fmt.Sprintf("ip(%v).contains(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) { - t.Parallel() - lhs, err := ParseIPAddr(tt.lhs) - testutil.OK(t, err) - rhs, err := ParseIPAddr(tt.rhs) - testutil.OK(t, err) - contains := lhs.Contains(rhs) - if contains != tt.contains { - t.Fatalf("expected ip(%v).contains(ip(%v)) to be %v instead of %v", tt.lhs, tt.rhs, tt.contains, contains) - } - }) - } - }) - - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := IPAddr{}.TypeName() - testutil.Equals(t, tn, "IP") - }) -} - func TestDeepClone(t *testing.T) { t.Parallel() t.Run("Boolean", func(t *testing.T) { @@ -664,7 +44,14 @@ func TestDeepClone(t *testing.T) { testutil.Equals(t, a, NewEntityUID("Action", "bananas")) testutil.Equals(t, b, Value(NewEntityUID("Action", "test"))) }) - + t.Run("EntityType", func(t *testing.T) { + t.Parallel() + a := EntityType("X") + b := a.deepClone() + c, ok := b.(EntityType) + testutil.Equals(t, ok, true) + testutil.Equals(t, c, a) + }) t.Run("Set", func(t *testing.T) { t.Parallel() a := Set{Long(42)} @@ -718,54 +105,3 @@ func TestDeepClone(t *testing.T) { testutil.Equals(t, b.Cedar(), mustIPValue("127.0.0.42").Cedar()) }) } - -func TestEntityType(t *testing.T) { - t.Parallel() - t.Run("Equal", func(t *testing.T) { - t.Parallel() - a := EntityType("X") - b := EntityType("X") - c := EntityType("Y") - testutil.Equals(t, a.Equal(b), true) - testutil.Equals(t, b.Equal(a), true) - testutil.Equals(t, a.Equal(c), false) - testutil.Equals(t, c.Equal(a), false) - }) - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - a := EntityType("X") - testutil.Equals(t, a.TypeName(), "(EntityType of type `X`)") - }) - t.Run("String", func(t *testing.T) { - t.Parallel() - a := EntityType("X") - testutil.Equals(t, a.String(), "X") - }) - t.Run("Cedar", func(t *testing.T) { - t.Parallel() - a := EntityType("X") - testutil.Equals(t, a.Cedar(), "X") - }) - t.Run("ExplicitMarshalJSON", func(t *testing.T) { - t.Parallel() - a := EntityType("X") - v, err := a.ExplicitMarshalJSON() - testutil.OK(t, err) - testutil.Equals(t, string(v), `"X"`) - }) - t.Run("deepClone", func(t *testing.T) { - t.Parallel() - a := EntityType("X") - b := a.deepClone() - c, ok := b.(EntityType) - testutil.Equals(t, ok, true) - testutil.Equals(t, c, a) - }) - - t.Run("pathFromSlice", func(t *testing.T) { - t.Parallel() - a := EntityTypeFromSlice([]string{"X", "Y"}) - testutil.Equals(t, a, EntityType("X::Y")) - }) - -} From ace189dd94b7aedb255cd4ceeea9d9bc6d9d66eb Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 12:42:32 -0600 Subject: [PATCH 166/216] types: made record output the same as the rust cedar formatter Addresses IDX-142 Signed-off-by: philhassey --- internal/eval/evalers_test.go | 2 +- types/record.go | 2 +- types/record_test.go | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 2e6b15d7..f52ab8d7 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -1889,7 +1889,7 @@ func TestCedarString(t *testing.T) { {"string", types.String("hello"), `hello`, `"hello"`}, {"number", types.Long(42), `42`, `42`}, {"bool", types.True, `true`, `true`}, - {"record", types.Record{"a": types.Long(42), "b": types.Long(43)}, `{"a": 42, "b": 43}`, `{"a": 42, "b": 43}`}, + {"record", types.Record{"a": types.Long(42), "b": types.Long(43)}, `{"a":42, "b":43}`, `{"a":42, "b":43}`}, {"set", types.Set{types.Long(42), types.Long(43)}, `[42, 43]`, `[42, 43]`}, {"singleIP", types.IPAddr(netip.MustParsePrefix("192.168.0.42/32")), `192.168.0.42`, `ip("192.168.0.42")`}, {"ipPrefix", types.IPAddr(netip.MustParsePrefix("192.168.0.42/24")), `192.168.0.42/24`, `ip("192.168.0.42/24")`}, diff --git a/types/record.go b/types/record.go index 77992408..a5934b87 100644 --- a/types/record.go +++ b/types/record.go @@ -91,7 +91,7 @@ func (r Record) Cedar() string { } first = false sb.WriteString(strconv.Quote(k)) - sb.WriteString(": ") + sb.WriteString(":") sb.WriteString(v.Cedar()) } sb.WriteRune('}') diff --git a/types/record_test.go b/types/record_test.go index b69b4372..382aee5c 100644 --- a/types/record_test.go +++ b/types/record_test.go @@ -61,14 +61,14 @@ func TestRecord(t *testing.T) { AssertValueString( t, types.Record{"foo": types.Boolean(true)}, - `{"foo": true}`) + `{"foo":true}`) AssertValueString( t, types.Record{ "foo": types.Boolean(true), "bar": types.String("blah"), }, - `{"bar": "blah", "foo": true}`) + `{"bar":"blah", "foo":true}`) }) t.Run("TypeName", func(t *testing.T) { From a1e1d14a788a33f95a2eb2e835ba714528563390 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 12:57:46 -0600 Subject: [PATCH 167/216] internal/json: add extra tests proving the equivalence of ip and decimal for their type and extensioncall forms Addresses IDX-142 Signed-off-by: philhassey --- internal/json/json_test.go | 41 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/internal/json/json_test.go b/internal/json/json_test.go index e4b0819c..1d556972 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -700,3 +700,44 @@ func TestUnmarshalErrors(t *testing.T) { }) } } + +func TestMarshalExtensions(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in *ast.Policy + out string + }{ + { + "decimalType", + ast.Permit().When(ast.Value(types.Decimal(420000))), + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"},"conditions":[{"kind":"when","body":{"decimal":[{"Value":"42.0"}]}}]}`, + }, + { + "decimalExtension", + ast.Permit().When(ast.ExtensionCall("decimal", ast.String("42.0"))), + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"},"conditions":[{"kind":"when","body":{"decimal":[{"Value":"42.0"}]}}]}`, + }, + { + "ipType", + ast.Permit().When(ast.Value(types.IPAddr(netip.MustParsePrefix("127.0.0.1/16")))), + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"},"conditions":[{"kind":"when","body":{"ip":[{"Value":"127.0.0.1/16"}]}}]}`, + }, + { + "ipExtension", + ast.Permit().When(ast.ExtensionCall("ip", ast.String("127.0.0.1/16"))), + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"},"conditions":[{"kind":"when","body":{"ip":[{"Value":"127.0.0.1/16"}]}}]}`, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + p := (*Policy)(tt.in) + out, err := p.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(out), tt.out) + + }) + } +} From 59e403f7b746753ed82bfc8705bbf9ad7abd56c1 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 15:59:48 -0600 Subject: [PATCH 168/216] ast: improve consistency of ast add/mul/sub naming Addresses IDX-142 Signed-off-by: philhassey --- ast/ast_test.go | 16 ++++++------- ast/operator.go | 12 +++++----- internal/ast/ast_test.go | 10 ++++----- internal/ast/operator.go | 6 ++--- internal/eval/convert_test.go | 6 ++--- internal/json/json.go | 6 ++--- internal/json/json_marshal.go | 6 ++--- internal/json/json_test.go | 6 ++--- internal/json/json_unmarshal.go | 12 +++++----- internal/parser/cedar_unmarshal.go | 6 ++--- internal/parser/cedar_unmarshal_test.go | 30 ++++++++++++------------- 11 files changed, 58 insertions(+), 58 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index 9a3cc821..c7a32954 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -60,12 +60,12 @@ func TestAstExamples(t *testing.T) { ast.Value(simpleRecord).Access("x").Equals(ast.String("value")), ). When( - ast.Record(ast.Pairs{{Key: "x", Value: ast.Long(1).Plus(ast.Context().Access("fooCount"))}}).Access("x").Equals(ast.Long(3)), + ast.Record(ast.Pairs{{Key: "x", Value: ast.Long(1).Add(ast.Context().Access("fooCount"))}}).Access("x").Equals(ast.Long(3)), ). When( ast.Set( ast.Long(1), - ast.Long(2).Plus(ast.Long(3)).Times(ast.Long(4)), + ast.Long(2).Add(ast.Long(3)).Multiply(ast.Long(4)), ast.Context().Access("fooCount"), ).Contains(ast.Long(1)), ) @@ -320,18 +320,18 @@ func TestASTByTable(t *testing.T) { }, { "opPlus", - ast.Permit().When(ast.Long(42).Plus(ast.Long(43))), - internalast.Permit().When(internalast.Long(42).Plus(internalast.Long(43))), + ast.Permit().When(ast.Long(42).Add(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).Add(internalast.Long(43))), }, { "opMinus", - ast.Permit().When(ast.Long(42).Minus(ast.Long(43))), - internalast.Permit().When(internalast.Long(42).Minus(internalast.Long(43))), + ast.Permit().When(ast.Long(42).Subtract(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).Subtract(internalast.Long(43))), }, { "opTimes", - ast.Permit().When(ast.Long(42).Times(ast.Long(43))), - internalast.Permit().When(internalast.Long(42).Times(internalast.Long(43))), + ast.Permit().When(ast.Long(42).Multiply(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).Multiply(internalast.Long(43))), }, { "opNegate", diff --git a/ast/operator.go b/ast/operator.go index a13bbf83..185014de 100644 --- a/ast/operator.go +++ b/ast/operator.go @@ -85,16 +85,16 @@ func If(condition Node, ifTrue Node, ifFalse Node) Node { // / ___ \| | | | |_| | | | | | | | | __/ |_| | (__ // /_/ \_\_| |_|\__|_| |_|_| |_| |_|\___|\__|_|\___| -func (lhs Node) Plus(rhs Node) Node { - return wrapNode(lhs.Node.Plus(rhs.Node)) +func (lhs Node) Add(rhs Node) Node { + return wrapNode(lhs.Node.Add(rhs.Node)) } -func (lhs Node) Minus(rhs Node) Node { - return wrapNode(lhs.Node.Minus(rhs.Node)) +func (lhs Node) Subtract(rhs Node) Node { + return wrapNode(lhs.Node.Subtract(rhs.Node)) } -func (lhs Node) Times(rhs Node) Node { - return wrapNode(lhs.Node.Times(rhs.Node)) +func (lhs Node) Multiply(rhs Node) Node { + return wrapNode(lhs.Node.Multiply(rhs.Node)) } func Negate(rhs Node) Node { diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index ede28f73..4a147187 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -59,13 +59,13 @@ func TestAstExamples(t *testing.T) { ast.Value(simpleRecord).Access("x").Equals(ast.String("value")), ). When( - ast.Record(ast.Pairs{{Key: "x", Value: ast.Long(1).Plus(ast.Context().Access("fooCount"))}}). + ast.Record(ast.Pairs{{Key: "x", Value: ast.Long(1).Add(ast.Context().Access("fooCount"))}}). Access("x").Equals(ast.Long(3)), ). When( ast.Set( ast.Long(1), - ast.Long(2).Plus(ast.Long(3)).Times(ast.Long(4)), + ast.Long(2).Add(ast.Long(3)).Multiply(ast.Long(4)), ast.Context().Access("fooCount"), ).Contains(ast.Long(1)), ) @@ -380,19 +380,19 @@ func TestASTByTable(t *testing.T) { }, { "opPlus", - ast.Permit().When(ast.Long(42).Plus(ast.Long(43))), + ast.Permit().When(ast.Long(42).Add(ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeAdd{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, }, { "opMinus", - ast.Permit().When(ast.Long(42).Minus(ast.Long(43))), + ast.Permit().When(ast.Long(42).Subtract(ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeSub{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, }, { "opTimes", - ast.Permit().When(ast.Long(42).Times(ast.Long(43))), + ast.Permit().When(ast.Long(42).Multiply(ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeMult{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, }, diff --git a/internal/ast/operator.go b/internal/ast/operator.go index 62cdf3ab..036e5d7c 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -82,15 +82,15 @@ func If(condition Node, ifTrue Node, ifFalse Node) Node { // / ___ \| | | | |_| | | | | | | | | __/ |_| | (__ // /_/ \_\_| |_|\__|_| |_|_| |_| |_|\___|\__|_|\___| -func (lhs Node) Plus(rhs Node) Node { +func (lhs Node) Add(rhs Node) Node { return NewNode(NodeTypeAdd{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } -func (lhs Node) Minus(rhs Node) Node { +func (lhs Node) Subtract(rhs Node) Node { return NewNode(NodeTypeSub{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } -func (lhs Node) Times(rhs Node) Node { +func (lhs Node) Multiply(rhs Node) Node { return NewNode(NodeTypeMult{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index 1794b8c1..ec0c496e 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -163,19 +163,19 @@ func TestToEval(t *testing.T) { }, { "sub", - ast.Long(42).Minus(ast.Long(2)), + ast.Long(42).Subtract(ast.Long(2)), types.Long(40), testutil.OK, }, { "add", - ast.Long(40).Plus(ast.Long(2)), + ast.Long(40).Add(ast.Long(2)), types.Long(42), testutil.OK, }, { "mult", - ast.Long(6).Times(ast.Long(7)), + ast.Long(6).Multiply(ast.Long(7)), types.Long(42), testutil.OK, }, diff --git a/internal/json/json.go b/internal/json/json.go index 2c17cf10..d5755be2 100644 --- a/internal/json/json.go +++ b/internal/json/json.go @@ -102,9 +102,9 @@ type nodeJSON struct { GreaterThanOrEqual *binaryJSON `json:">=,omitempty"` And *binaryJSON `json:"&&,omitempty"` Or *binaryJSON `json:"||,omitempty"` - Plus *binaryJSON `json:"+,omitempty"` - Minus *binaryJSON `json:"-,omitempty"` - Times *binaryJSON `json:"*,omitempty"` + Add *binaryJSON `json:"+,omitempty"` + Subtract *binaryJSON `json:"-,omitempty"` + Multiply *binaryJSON `json:"*,omitempty"` Contains *binaryJSON `json:"contains,omitempty"` ContainsAll *binaryJSON `json:"containsAll,omitempty"` ContainsAny *binaryJSON `json:"containsAny,omitempty"` diff --git a/internal/json/json_marshal.go b/internal/json/json_marshal.go index ff0b2877..27874675 100644 --- a/internal/json/json_marshal.go +++ b/internal/json/json_marshal.go @@ -174,7 +174,7 @@ func (j *nodeJSON) FromNode(src ast.IsNode) { // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny case ast.NodeTypeAdd: - binaryToJSON(&j.Plus, t.BinaryNode) + binaryToJSON(&j.Add, t.BinaryNode) return case ast.NodeTypeAnd: binaryToJSON(&j.And, t.BinaryNode) @@ -207,7 +207,7 @@ func (j *nodeJSON) FromNode(src ast.IsNode) { binaryToJSON(&j.LessThanOrEqual, t.BinaryNode) return case ast.NodeTypeMult: - binaryToJSON(&j.Times, t.BinaryNode) + binaryToJSON(&j.Multiply, t.BinaryNode) return case ast.NodeTypeNotEquals: binaryToJSON(&j.NotEquals, t.BinaryNode) @@ -216,7 +216,7 @@ func (j *nodeJSON) FromNode(src ast.IsNode) { binaryToJSON(&j.Or, t.BinaryNode) return case ast.NodeTypeSub: - binaryToJSON(&j.Minus, t.BinaryNode) + binaryToJSON(&j.Subtract, t.BinaryNode) return // ., has diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 1d556972..de080834 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -328,21 +328,21 @@ func TestUnmarshalJSON(t *testing.T) { "plus", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"+":{"left":{"Value":42},"right":{"Value":24}}}}]}`, - ast.Permit().When(ast.Long(42).Plus(ast.Long(24))), + ast.Permit().When(ast.Long(42).Add(ast.Long(24))), testutil.OK, }, { "minus", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"-":{"left":{"Value":42},"right":{"Value":24}}}}]}`, - ast.Permit().When(ast.Long(42).Minus(ast.Long(24))), + ast.Permit().When(ast.Long(42).Subtract(ast.Long(24))), testutil.OK, }, { "times", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"*":{"left":{"Value":42},"right":{"Value":24}}}}]}`, - ast.Permit().When(ast.Long(42).Times(ast.Long(24))), + ast.Permit().When(ast.Long(42).Multiply(ast.Long(24))), testutil.OK, }, { diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 44854ded..0c9e1355 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -194,12 +194,12 @@ func (j nodeJSON) ToNode() (ast.Node, error) { return j.And.ToNode(ast.Node.And) case j.Or != nil: return j.Or.ToNode(ast.Node.Or) - case j.Plus != nil: - return j.Plus.ToNode(ast.Node.Plus) - case j.Minus != nil: - return j.Minus.ToNode(ast.Node.Minus) - case j.Times != nil: - return j.Times.ToNode(ast.Node.Times) + case j.Add != nil: + return j.Add.ToNode(ast.Node.Add) + case j.Subtract != nil: + return j.Subtract.ToNode(ast.Node.Subtract) + case j.Multiply != nil: + return j.Multiply.ToNode(ast.Node.Multiply) case j.Contains != nil: return j.Contains.ToNode(ast.Node.Contains) case j.ContainsAll != nil: diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index e8c3c936..b62e5260 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -591,9 +591,9 @@ func (p *parser) add() (ast.Node, error) { var operator func(ast.Node, ast.Node) ast.Node switch t.Text { case "+": - operator = ast.Node.Plus + operator = ast.Node.Add case "-": - operator = ast.Node.Minus + operator = ast.Node.Subtract } if operator == nil { @@ -623,7 +623,7 @@ func (p *parser) mult() (ast.Node, error) { if err != nil { return ast.Node{}, err } - lhs = lhs.Times(rhs) + lhs = lhs.Multiply(rhs) } return lhs, nil diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index f4aa346c..d79c8f23 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -191,43 +191,43 @@ when { context.sourceIP.isIpv4() };`, "multiplication", `permit ( principal, action, resource ) when { 42 * 2 };`, - ast.Permit().When(ast.Long(42).Times(ast.Long(2))), + ast.Permit().When(ast.Long(42).Multiply(ast.Long(2))), }, { "multiple multiplication", `permit ( principal, action, resource ) when { 42 * 2 * 1 };`, - ast.Permit().When(ast.Long(42).Times(ast.Long(2)).Times(ast.Long(1))), + ast.Permit().When(ast.Long(42).Multiply(ast.Long(2)).Multiply(ast.Long(1))), }, { "addition", `permit ( principal, action, resource ) when { 42 + 2 };`, - ast.Permit().When(ast.Long(42).Plus(ast.Long(2))), + ast.Permit().When(ast.Long(42).Add(ast.Long(2))), }, { "multiple addition", `permit ( principal, action, resource ) when { 42 + 2 + 1 };`, - ast.Permit().When(ast.Long(42).Plus(ast.Long(2)).Plus(ast.Long(1))), + ast.Permit().When(ast.Long(42).Add(ast.Long(2)).Add(ast.Long(1))), }, { "subtraction", `permit ( principal, action, resource ) when { 42 - 2 };`, - ast.Permit().When(ast.Long(42).Minus(ast.Long(2))), + ast.Permit().When(ast.Long(42).Subtract(ast.Long(2))), }, { "multiple subtraction", `permit ( principal, action, resource ) when { 42 - 2 - 1 };`, - ast.Permit().When(ast.Long(42).Minus(ast.Long(2)).Minus(ast.Long(1))), + ast.Permit().When(ast.Long(42).Subtract(ast.Long(2)).Subtract(ast.Long(1))), }, { "mixed addition and subtraction", `permit ( principal, action, resource ) when { 42 - 2 + 1 };`, - ast.Permit().When(ast.Long(42).Minus(ast.Long(2)).Plus(ast.Long(1))), + ast.Permit().When(ast.Long(42).Subtract(ast.Long(2)).Add(ast.Long(1))), }, { "less than", @@ -378,25 +378,25 @@ when { 1 < 2 && true };`, "add over rel precedence", `permit ( principal, action, resource ) when { 1 + 1 < 3 };`, - ast.Permit().When(ast.Long(1).Plus(ast.Long(1)).LessThan(ast.Long(3))), + ast.Permit().When(ast.Long(1).Add(ast.Long(1)).LessThan(ast.Long(3))), }, { "mult over add precedence (rhs add)", `permit ( principal, action, resource ) when { 2 * 3 + 4 == 10 };`, - ast.Permit().When(ast.Long(2).Times(ast.Long(3)).Plus(ast.Long(4)).Equals(ast.Long(10))), + ast.Permit().When(ast.Long(2).Multiply(ast.Long(3)).Add(ast.Long(4)).Equals(ast.Long(10))), }, { "mult over add precedence (lhs add)", `permit ( principal, action, resource ) when { 2 + 3 * 4 == 14 };`, - ast.Permit().When(ast.Long(2).Plus(ast.Long(3).Times(ast.Long(4))).Equals(ast.Long(14))), + ast.Permit().When(ast.Long(2).Add(ast.Long(3).Multiply(ast.Long(4))).Equals(ast.Long(14))), }, { "unary over mult precedence", `permit ( principal, action, resource ) when { -2 * 3 == -6 };`, - ast.Permit().When(ast.Long(-2).Times(ast.Long(3)).Equals(ast.Long(-6))), + ast.Permit().When(ast.Long(-2).Multiply(ast.Long(3)).Equals(ast.Long(-6))), }, { "member over unary precedence", @@ -408,25 +408,25 @@ when { -context.num };`, "parens over unary precedence", `permit ( principal, action, resource ) when { -(2 + 3) == -5 };`, - ast.Permit().When(ast.Negate(ast.Long(2).Plus(ast.Long(3))).Equals(ast.Long(-5))), + ast.Permit().When(ast.Negate(ast.Long(2).Add(ast.Long(3))).Equals(ast.Long(-5))), }, { "multiple parenthesized operations", `permit ( principal, action, resource ) when { (2 + 3 + 4) * 5 == 18 };`, - ast.Permit().When(ast.Long(2).Plus(ast.Long(3)).Plus(ast.Long(4)).Times(ast.Long(5)).Equals(ast.Long(18))), + ast.Permit().When(ast.Long(2).Add(ast.Long(3)).Add(ast.Long(4)).Multiply(ast.Long(5)).Equals(ast.Long(18))), }, { "parenthesized if", `permit ( principal, action, resource ) when { (if true then 2 else 3 * 4) == 2 };`, - ast.Permit().When(ast.If(ast.True(), ast.Long(2), ast.Long(3).Times(ast.Long(4))).Equals(ast.Long(2))), + ast.Permit().When(ast.If(ast.True(), ast.Long(2), ast.Long(3).Multiply(ast.Long(4))).Equals(ast.Long(2))), }, { "parenthesized if with trailing mult", `permit ( principal, action, resource ) when { (if true then 2 else 3) * 4 == 8 };`, - ast.Permit().When(ast.If(ast.True(), ast.Long(2), ast.Long(3)).Times(ast.Long(4)).Equals(ast.Long(8))), + ast.Permit().When(ast.If(ast.True(), ast.Long(2), ast.Long(3)).Multiply(ast.Long(4)).Equals(ast.Long(8))), }, } From dc5561736ca2c37bfbc013a50f7567c3acd68f71 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 16:01:41 -0600 Subject: [PATCH 169/216] ast: improving naming for ext functions Addresses IDX-142 Signed-off-by: philhassey --- ast/ast_test.go | 16 ++++++++-------- ast/operator.go | 16 ++++++++-------- internal/ast/ast_test.go | 8 ++++---- internal/ast/operator.go | 8 ++++---- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index c7a32954..d671ec27 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -275,23 +275,23 @@ func TestASTByTable(t *testing.T) { }, { "opLessThanExt", - ast.Permit().When(ast.Long(42).LessThanExt(ast.Long(43))), - internalast.Permit().When(internalast.Long(42).LessThanExt(internalast.Long(43))), + ast.Permit().When(ast.Long(42).DecimalLessThan(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).DecimalLessThan(internalast.Long(43))), }, { "opLessThanOrEqualExt", - ast.Permit().When(ast.Long(42).LessThanOrEqualExt(ast.Long(43))), - internalast.Permit().When(internalast.Long(42).LessThanOrEqualExt(internalast.Long(43))), + ast.Permit().When(ast.Long(42).DecimalLessThanOrEqual(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).DecimalLessThanOrEqual(internalast.Long(43))), }, { "opGreaterThanExt", - ast.Permit().When(ast.Long(42).GreaterThanExt(ast.Long(43))), - internalast.Permit().When(internalast.Long(42).GreaterThanExt(internalast.Long(43))), + ast.Permit().When(ast.Long(42).DecimalGreaterThan(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).DecimalGreaterThan(internalast.Long(43))), }, { "opGreaterThanOrEqualExt", - ast.Permit().When(ast.Long(42).GreaterThanOrEqualExt(ast.Long(43))), - internalast.Permit().When(internalast.Long(42).GreaterThanOrEqualExt(internalast.Long(43))), + ast.Permit().When(ast.Long(42).DecimalGreaterThanOrEqual(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).DecimalGreaterThanOrEqual(internalast.Long(43))), }, { "opLike", diff --git a/ast/operator.go b/ast/operator.go index 185014de..7d2d250b 100644 --- a/ast/operator.go +++ b/ast/operator.go @@ -36,20 +36,20 @@ func (lhs Node) GreaterThanOrEqual(rhs Node) Node { return wrapNode(lhs.Node.GreaterThanOrEqual(rhs.Node)) } -func (lhs Node) LessThanExt(rhs Node) Node { - return wrapNode(lhs.Node.LessThanExt(rhs.Node)) +func (lhs Node) DecimalLessThan(rhs Node) Node { + return wrapNode(lhs.Node.DecimalLessThan(rhs.Node)) } -func (lhs Node) LessThanOrEqualExt(rhs Node) Node { - return wrapNode(lhs.Node.LessThanOrEqualExt(rhs.Node)) +func (lhs Node) DecimalLessThanOrEqual(rhs Node) Node { + return wrapNode(lhs.Node.DecimalLessThanOrEqual(rhs.Node)) } -func (lhs Node) GreaterThanExt(rhs Node) Node { - return wrapNode(lhs.Node.GreaterThanExt(rhs.Node)) +func (lhs Node) DecimalGreaterThan(rhs Node) Node { + return wrapNode(lhs.Node.DecimalGreaterThan(rhs.Node)) } -func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { - return wrapNode(lhs.Node.GreaterThanOrEqualExt(rhs.Node)) +func (lhs Node) DecimalGreaterThanOrEqual(rhs Node) Node { + return wrapNode(lhs.Node.DecimalGreaterThanOrEqual(rhs.Node)) } func (lhs Node) Like(pattern types.Pattern) Node { diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index 4a147187..a81048b0 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -326,25 +326,25 @@ func TestASTByTable(t *testing.T) { }, { "opLessThanExt", - ast.Permit().When(ast.Long(42).LessThanExt(ast.Long(43))), + ast.Permit().When(ast.Long(42).DecimalLessThan(ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "lessThan", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}}, }, { "opLessThanOrEqualExt", - ast.Permit().When(ast.Long(42).LessThanOrEqualExt(ast.Long(43))), + ast.Permit().When(ast.Long(42).DecimalLessThanOrEqual(ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "lessThanOrEqual", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}}, }, { "opGreaterThanExt", - ast.Permit().When(ast.Long(42).GreaterThanExt(ast.Long(43))), + ast.Permit().When(ast.Long(42).DecimalGreaterThan(ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "greaterThan", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}}, }, { "opGreaterThanOrEqualExt", - ast.Permit().When(ast.Long(42).GreaterThanOrEqualExt(ast.Long(43))), + ast.Permit().When(ast.Long(42).DecimalGreaterThanOrEqual(ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "greaterThanOrEqual", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}}, }, diff --git a/internal/ast/operator.go b/internal/ast/operator.go index 036e5d7c..8cc17908 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -33,19 +33,19 @@ func (lhs Node) GreaterThanOrEqual(rhs Node) Node { return NewNode(NodeTypeGreaterThanOrEqual{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } -func (lhs Node) LessThanExt(rhs Node) Node { +func (lhs Node) DecimalLessThan(rhs Node) Node { return NewMethodCall(lhs, "lessThan", rhs) } -func (lhs Node) LessThanOrEqualExt(rhs Node) Node { +func (lhs Node) DecimalLessThanOrEqual(rhs Node) Node { return NewMethodCall(lhs, "lessThanOrEqual", rhs) } -func (lhs Node) GreaterThanExt(rhs Node) Node { +func (lhs Node) DecimalGreaterThan(rhs Node) Node { return NewMethodCall(lhs, "greaterThan", rhs) } -func (lhs Node) GreaterThanOrEqualExt(rhs Node) Node { +func (lhs Node) DecimalGreaterThanOrEqual(rhs Node) Node { return NewMethodCall(lhs, "greaterThanOrEqual", rhs) } From de80f0aef9283c661951b99f22d1763e8cc898f8 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 16:04:34 -0600 Subject: [PATCH 170/216] ast: made if then else more consistent Addresses IDX-142 Signed-off-by: philhassey --- ast/ast_test.go | 4 ++-- ast/operator.go | 4 ++-- internal/ast/ast_test.go | 4 ++-- internal/ast/internal_test.go | 2 +- internal/ast/node.go | 4 ++-- internal/ast/operator.go | 4 ++-- internal/eval/convert.go | 2 +- internal/eval/convert_test.go | 2 +- internal/json/json_marshal.go | 4 ++-- internal/json/json_test.go | 2 +- internal/json/json_unmarshal.go | 2 +- internal/parser/cedar_marshal.go | 8 ++++---- internal/parser/cedar_unmarshal.go | 2 +- internal/parser/cedar_unmarshal_test.go | 6 +++--- internal/parser/node.go | 2 +- 15 files changed, 26 insertions(+), 26 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index d671ec27..19d20969 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -315,8 +315,8 @@ func TestASTByTable(t *testing.T) { }, { "opIf", - ast.Permit().When(ast.If(ast.True(), ast.Long(42), ast.Long(43))), - internalast.Permit().When(internalast.If(internalast.True(), internalast.Long(42), internalast.Long(43))), + ast.Permit().When(ast.IfThenElse(ast.True(), ast.Long(42), ast.Long(43))), + internalast.Permit().When(internalast.IfThenElse(internalast.True(), internalast.Long(42), internalast.Long(43))), }, { "opPlus", diff --git a/ast/operator.go b/ast/operator.go index 7d2d250b..3dc649e4 100644 --- a/ast/operator.go +++ b/ast/operator.go @@ -75,8 +75,8 @@ func Not(rhs Node) Node { return wrapNode(ast.Not(rhs.Node)) } -func If(condition Node, ifTrue Node, ifFalse Node) Node { - return wrapNode(ast.If(condition.Node, ifTrue.Node, ifFalse.Node)) +func IfThenElse(condition Node, thenNode Node, elseNode Node) Node { + return wrapNode(ast.IfThenElse(condition.Node, thenNode.Node, elseNode.Node)) } // _ _ _ _ _ _ diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index a81048b0..87cded5b 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -374,9 +374,9 @@ func TestASTByTable(t *testing.T) { }, { "opIf", - ast.Permit().When(ast.If(ast.True(), ast.Long(42), ast.Long(43))), + ast.Permit().When(ast.IfThenElse(ast.True(), ast.Long(42), ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIf{If: ast.NodeValue{Value: types.True}, Then: ast.NodeValue{Value: types.Long(42)}, Else: ast.NodeValue{Value: types.Long(43)}}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIfThenElse{If: ast.NodeValue{Value: types.True}, Then: ast.NodeValue{Value: types.Long(42)}, Else: ast.NodeValue{Value: types.Long(43)}}}}}, }, { "opPlus", diff --git a/internal/ast/internal_test.go b/internal/ast/internal_test.go index 02d21c86..f5c9ac75 100644 --- a/internal/ast/internal_test.go +++ b/internal/ast/internal_test.go @@ -13,7 +13,7 @@ func TestIsNode(t *testing.T) { StrOpNode{}.isNode() BinaryNode{}.isNode() - NodeTypeIf{}.isNode() + NodeTypeIfThenElse{}.isNode() NodeTypeLike{}.isNode() NodeTypeIs{}.isNode() UnaryNode{}.isNode() diff --git a/internal/ast/node.go b/internal/ast/node.go index 623b1400..41dc7091 100644 --- a/internal/ast/node.go +++ b/internal/ast/node.go @@ -29,11 +29,11 @@ type BinaryNode struct { func (n BinaryNode) isNode() {} -type NodeTypeIf struct { +type NodeTypeIfThenElse struct { If, Then, Else IsNode } -func (n NodeTypeIf) isNode() {} +func (n NodeTypeIfThenElse) isNode() {} type NodeTypeOr struct{ BinaryNode } diff --git a/internal/ast/operator.go b/internal/ast/operator.go index 8cc17908..8d393805 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -72,8 +72,8 @@ func Not(rhs Node) Node { return NewNode(NodeTypeNot{UnaryNode: UnaryNode{Arg: rhs.v}}) } -func If(condition Node, ifTrue Node, ifFalse Node) Node { - return NewNode(NodeTypeIf{If: condition.v, Then: ifTrue.v, Else: ifFalse.v}) +func IfThenElse(condition Node, thenNode Node, elseNode Node) Node { + return NewNode(NodeTypeIfThenElse{If: condition.v, Then: thenNode.v, Else: elseNode.v}) } // _ _ _ _ _ _ diff --git a/internal/eval/convert.go b/internal/eval/convert.go index b502cd4b..c50b78bb 100644 --- a/internal/eval/convert.go +++ b/internal/eval/convert.go @@ -16,7 +16,7 @@ func toEval(n ast.IsNode) Evaler { return newHasEval(toEval(v.Arg), string(v.Value)) case ast.NodeTypeLike: return newLikeEval(toEval(v.Arg), v.Value) - case ast.NodeTypeIf: + case ast.NodeTypeIfThenElse: return newIfThenElseEval(toEval(v.If), toEval(v.Then), toEval(v.Else)) case ast.NodeTypeIs: return newIsEval(toEval(v.Left), newLiteralEval(v.EntityType)) diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index ec0c496e..70c8dc74 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -37,7 +37,7 @@ func TestToEval(t *testing.T) { }, { "if", - ast.If(ast.True(), ast.Long(42), ast.Long(43)), + ast.IfThenElse(ast.True(), ast.Long(42), ast.Long(43)), types.Long(42), testutil.OK, }, diff --git a/internal/json/json_marshal.go b/internal/json/json_marshal.go index 27874675..8794da7f 100644 --- a/internal/json/json_marshal.go +++ b/internal/json/json_marshal.go @@ -111,7 +111,7 @@ func recordToJSON(dest *recordJSON, src ast.NodeTypeRecord) { *dest = res } -func ifToJSON(dest **ifThenElseJSON, src ast.NodeTypeIf) { +func ifToJSON(dest **ifThenElseJSON, src ast.NodeTypeIfThenElse) { res := &ifThenElseJSON{} res.If.FromNode(src.If) res.Then.FromNode(src.Then) @@ -244,7 +244,7 @@ func (j *nodeJSON) FromNode(src ast.IsNode) { // if-then-else // IfThenElse *ifThenElseJSON `json:"if-then-else"` - case ast.NodeTypeIf: + case ast.NodeTypeIfThenElse: ifToJSON(&j.IfThenElse, t) return diff --git a/internal/json/json_test.go b/internal/json/json_test.go index de080834..65715d12 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -427,7 +427,7 @@ func TestUnmarshalJSON(t *testing.T) { "ifThenElse", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"if-then-else":{"if":{"Value":true},"then":{"Value":42},"else":{"Value":24}}}}]}`, - ast.Permit().When(ast.If(ast.True(), ast.Long(42), ast.Long(24))), + ast.Permit().When(ast.IfThenElse(ast.True(), ast.Long(42), ast.Long(24))), testutil.OK, }, { diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 0c9e1355..cfe9f8dc 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -96,7 +96,7 @@ func (j ifThenElseJSON) ToNode() (ast.Node, error) { if err != nil { return ast.Node{}, fmt.Errorf("error in else: %w", err) } - return ast.If(if_, then, else_), nil + return ast.IfThenElse(if_, then, else_), nil } func (j arrayJSON) ToNode() (ast.Node, error) { var nodes []ast.Node diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index 44bb02a0..95383f53 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -316,16 +316,16 @@ func (n NodeTypeLike) marshalCedar(buf *bytes.Buffer) { func (n NodeTypeIf) marshalCedar(buf *bytes.Buffer) { buf.WriteString("if ") - marshalChildNode(n.precedenceLevel(), n.NodeTypeIf.If, buf) + marshalChildNode(n.precedenceLevel(), n.NodeTypeIfThenElse.If, buf) buf.WriteString(" then ") - marshalChildNode(n.precedenceLevel(), n.NodeTypeIf.Then, buf) + marshalChildNode(n.precedenceLevel(), n.NodeTypeIfThenElse.Then, buf) buf.WriteString(" else ") - marshalChildNode(n.precedenceLevel(), n.NodeTypeIf.Else, buf) + marshalChildNode(n.precedenceLevel(), n.NodeTypeIfThenElse.Else, buf) } func astNodeToMarshalNode(astNode ast.IsNode) IsNode { switch v := astNode.(type) { - case ast.NodeTypeIf: + case ast.NodeTypeIfThenElse: return NodeTypeIf{v} case ast.NodeTypeOr: return NodeTypeOr{v} diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index b62e5260..37d95603 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -444,7 +444,7 @@ func (p *parser) expression() (ast.Node, error) { return ast.Node{}, err } - return ast.If(condition, ifTrue, ifFalse), nil + return ast.IfThenElse(condition, ifTrue, ifFalse), nil } return p.or() diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index d79c8f23..3b569421 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -342,7 +342,7 @@ when { true || false || true };`, "if then else", `permit ( principal, action, resource ) when { if true then true else false };`, - ast.Permit().When(ast.If(ast.True(), ast.True(), ast.False())), + ast.Permit().When(ast.IfThenElse(ast.True(), ast.True(), ast.False())), }, { "ip extension function", @@ -420,13 +420,13 @@ when { (2 + 3 + 4) * 5 == 18 };`, "parenthesized if", `permit ( principal, action, resource ) when { (if true then 2 else 3 * 4) == 2 };`, - ast.Permit().When(ast.If(ast.True(), ast.Long(2), ast.Long(3).Multiply(ast.Long(4))).Equals(ast.Long(2))), + ast.Permit().When(ast.IfThenElse(ast.True(), ast.Long(2), ast.Long(3).Multiply(ast.Long(4))).Equals(ast.Long(2))), }, { "parenthesized if with trailing mult", `permit ( principal, action, resource ) when { (if true then 2 else 3) * 4 == 8 };`, - ast.Permit().When(ast.If(ast.True(), ast.Long(2), ast.Long(3)).Multiply(ast.Long(4)).Equals(ast.Long(8))), + ast.Permit().When(ast.IfThenElse(ast.True(), ast.Long(2), ast.Long(3)).Multiply(ast.Long(4)).Equals(ast.Long(8))), }, } diff --git a/internal/parser/node.go b/internal/parser/node.go index 41da04cf..0e8dada8 100644 --- a/internal/parser/node.go +++ b/internal/parser/node.go @@ -6,7 +6,7 @@ import ( "github.com/cedar-policy/cedar-go/internal/ast" ) -type NodeTypeIf struct{ ast.NodeTypeIf } +type NodeTypeIf struct{ ast.NodeTypeIfThenElse } func (n NodeTypeIf) precedenceLevel() nodePrecedenceLevel { return ifPrecedence From 65eb6206441775094c40a5d5357a642b7966eacc Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 16:37:26 -0600 Subject: [PATCH 171/216] types: change UnmarshalCedar to ParsePattern Addresses IDX-142 Signed-off-by: philhassey --- internal/eval/evalers_test.go | 3 +-- internal/parser/cedar_unmarshal.go | 4 ++-- types/pattern.go | 11 +++++------ types/patttern_test.go | 7 +++---- 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index f52ab8d7..b1924ee8 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -1424,8 +1424,7 @@ func TestLikeNode(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - var pat types.Pattern - err := pat.UnmarshalCedar([]byte(tt.pattern[1 : len(tt.pattern)-1])) + pat, err := types.ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) testutil.OK(t, err) n := newLikeEval(tt.str, pat) v, err := n.Eval(&Context{}) diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 37d95603..f7fd4f2c 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -557,8 +557,8 @@ func (p *parser) like(lhs ast.Node) (ast.Node, error) { patternRaw := t.Text patternRaw = strings.TrimPrefix(patternRaw, "\"") patternRaw = strings.TrimSuffix(patternRaw, "\"") - var pattern types.Pattern - if err := pattern.UnmarshalCedar([]byte(patternRaw)); err != nil { + pattern, err := types.ParsePattern(patternRaw) + if err != nil { return ast.Node{}, err } return lhs.Like(pattern), nil diff --git a/types/pattern.go b/types/pattern.go index 3b7db9bb..8dc3a85e 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -137,25 +137,24 @@ func matchChunk(chunk, s string) (rest string, ok bool) { return s, true } -func (p *Pattern) UnmarshalCedar(b []byte) error { +// ParsePattern will parse an unquoted rust-style string with \*'s in it. +func ParsePattern(v string) (Pattern, error) { + b := []byte(v) var comps []PatternComponent for len(b) > 0 { for len(b) > 0 && b[0] == '*' { b = b[1:] comps = append(comps, Wildcard) } - var err error var literal string literal, b, err = rust.Unquote(b, true) if err != nil { - return err + return Pattern{}, err } comps = append(comps, String(literal)) } - - *p = NewPattern(comps...) - return nil + return NewPattern(comps...), nil } func (p Pattern) MarshalJSON() ([]byte, error) { diff --git a/types/patttern_test.go b/types/patttern_test.go index 6ff9b132..ce3b6319 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -51,8 +51,8 @@ func TestParsePattern(t *testing.T) { tt := tt t.Run(tt.input, func(t *testing.T) { t.Parallel() - var got Pattern - if err := got.UnmarshalCedar([]byte(tt.input)); err != nil { + got, err := ParsePattern(tt.input) + if err != nil { testutil.Equals(t, tt.wantOk, false) testutil.Equals(t, err.Error(), tt.wantErr) } else { @@ -94,8 +94,7 @@ func TestMatch(t *testing.T) { tt := tt t.Run(tt.pattern+":"+tt.target, func(t *testing.T) { t.Parallel() - var pat Pattern - err := pat.UnmarshalCedar([]byte(tt.pattern[1 : len(tt.pattern)-1])) + pat, err := ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) testutil.OK(t, err) got := pat.Match(tt.target) testutil.Equals(t, got, tt.want) From 693e72083e65a33154c0140e238ed4e357283c0c Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 17:15:28 -0600 Subject: [PATCH 172/216] cedar: minor renames Addresses IDX-142 Signed-off-by: philhassey --- authorize_test.go | 2 +- corpus_test.go | 4 ++-- policy.go | 36 +---------------------------- policy_set.go | 12 ++++------ policy_set_test.go | 8 +++---- policy_slice.go | 55 ++++++++++++++++++++++++++++++++++++++++++++ policy_slice_test.go | 30 ++++++++++++++++++++++++ policy_test.go | 22 ------------------ 8 files changed, 98 insertions(+), 71 deletions(-) create mode 100644 policy_slice.go create mode 100644 policy_slice_test.go diff --git a/authorize_test.go b/authorize_test.go index 70f74204..a3610e9b 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -709,7 +709,7 @@ func TestIsAuthorized(t *testing.T) { tt := tt t.Run(tt.Name, func(t *testing.T) { t.Parallel() - ps, err := NewPolicySetFromFile("policy.cedar", []byte(tt.Policy)) + ps, err := NewPolicySetFromBytes("policy.cedar", []byte(tt.Policy)) testutil.Equals(t, (err != nil), tt.ParseErr) ok, diag := ps.IsAuthorized(tt.Entities, Request{ Principal: tt.Principal, diff --git a/corpus_test.go b/corpus_test.go index 509c7773..5e8d72e2 100644 --- a/corpus_test.go +++ b/corpus_test.go @@ -150,7 +150,7 @@ func TestCorpus(t *testing.T) { t.Fatal("error reading policy content", err) } - policySet, err := NewPolicySetFromFile("policy.cedar", policyContent) + policySet, err := NewPolicySetFromBytes("policy.cedar", policyContent) if err != nil { t.Fatal("error parsing policy set", err) } @@ -336,7 +336,7 @@ func TestCorpusRelated(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - policy, err := NewPolicySetFromFile("", []byte(tt.policy)) + policy, err := NewPolicySetFromBytes("", []byte(tt.policy)) testutil.OK(t, err) ok, diag := policy.IsAuthorized(entities2.Entities{}, tt.request) testutil.Equals(t, ok, tt.decision) diff --git a/policy.go b/policy.go index 205ccac7..367e94b2 100644 --- a/policy.go +++ b/policy.go @@ -2,7 +2,6 @@ package cedar import ( "bytes" - "fmt" "github.com/cedar-policy/cedar-go/ast" internalast "github.com/cedar-policy/cedar-go/internal/ast" @@ -100,39 +99,6 @@ func (p Policy) Position() Position { return Position(p.ast.Position) } -func (p *Policy) SetSourceFile(path string) { +func (p *Policy) SetFileName(path string) { p.ast.Position.FileName = path } - -// PolicySlice represents a set of un-named Policy's. Cedar documents, unlike the JSON format, don't have a means of -// naming individual policies. -type PolicySlice []*Policy - -// UnmarshalCedar parses a concatenation of un-named Cedar policy statements. Names can be assigned to these policies -// when adding them to a PolicySet. -func (p *PolicySlice) UnmarshalCedar(b []byte) error { - var res parser.PolicySlice - if err := res.UnmarshalCedar(b); err != nil { - return fmt.Errorf("parser error: %w", err) - } - policySlice := make([]*Policy, 0, len(res)) - for _, p := range res { - newPolicy := newPolicy((*internalast.Policy)(p)) - policySlice = append(policySlice, &newPolicy) - } - *p = policySlice - return nil -} - -// MarshalCedar emits a concatenated Cedar representation of a PolicySlice -func (p PolicySlice) MarshalCedar() []byte { - var buf bytes.Buffer - for i, policy := range p { - buf.Write(policy.MarshalCedar()) - - if i < len(p)-1 { - buf.WriteString("\n\n") - } - } - return buf.Bytes() -} diff --git a/policy_set.go b/policy_set.go index d5228b6c..fd11d625 100644 --- a/policy_set.go +++ b/policy_set.go @@ -19,21 +19,19 @@ func NewPolicySet() PolicySet { return PolicySet{policies: map[PolicyID]*Policy{}} } -// NewPolicySetFromFile will create a PolicySet from the given text document with the/ given file name used in Position +// NewPolicySetFromBytes will create a PolicySet from the given text document with the/ given file name used in Position // data. If there is an error parsing the document, it will be returned. // -// NewPolicySetFromFile assigns default PolicyIDs to the policies contained in fileName in the format "policy" where +// NewPolicySetFromBytes assigns default PolicyIDs to the policies contained in fileName in the format "policy" where // is incremented for each new policy found in the file. -func NewPolicySetFromFile(fileName string, document []byte) (PolicySet, error) { - var policySlice PolicySlice - if err := policySlice.UnmarshalCedar(document); err != nil { +func NewPolicySetFromBytes(fileName string, document []byte) (PolicySet, error) { + policySlice, err := NewPolicySliceFromBytes(fileName, document) + if err != nil { return PolicySet{}, err } - policyMap := make(map[PolicyID]*Policy, len(policySlice)) for i, p := range policySlice { policyID := PolicyID(fmt.Sprintf("policy%d", i)) - p.SetSourceFile(fileName) policyMap[policyID] = p } return PolicySet{policies: policyMap}, nil diff --git a/policy_set_test.go b/policy_set_test.go index 3ea5046d..1889f8f2 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -13,17 +13,17 @@ func TestNewPolicySetFromFile(t *testing.T) { t.Parallel() t.Run("err-in-tokenize", func(t *testing.T) { t.Parallel() - _, err := cedar.NewPolicySetFromFile("policy.cedar", []byte(`"`)) + _, err := cedar.NewPolicySetFromBytes("policy.cedar", []byte(`"`)) testutil.Error(t, err) }) t.Run("err-in-parse", func(t *testing.T) { t.Parallel() - _, err := cedar.NewPolicySetFromFile("policy.cedar", []byte(`err`)) + _, err := cedar.NewPolicySetFromBytes("policy.cedar", []byte(`err`)) testutil.Error(t, err) }) t.Run("annotations", func(t *testing.T) { t.Parallel() - ps, err := cedar.NewPolicySetFromFile("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) + ps, err := cedar.NewPolicySetFromBytes("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) testutil.OK(t, err) testutil.Equals(t, ps.GetPolicy("policy0").Annotations(), cedar.Annotations{"key": "value"}) }) @@ -162,7 +162,7 @@ forbid ( ps := cedar.NewPolicySet() for i, p := range policies { - p.SetSourceFile("example.cedar") + p.SetFileName("example.cedar") ps.UpsertPolicy(cedar.PolicyID(fmt.Sprintf("policy%d", i)), p) } diff --git a/policy_slice.go b/policy_slice.go new file mode 100644 index 00000000..25a70dd8 --- /dev/null +++ b/policy_slice.go @@ -0,0 +1,55 @@ +package cedar + +import ( + "bytes" + "fmt" + + internalast "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/parser" +) + +// PolicySlice represents a set of un-named Policy's. Cedar documents, unlike the JSON format, don't have a means of +// naming individual policies. +type PolicySlice []*Policy + +// NewPolicySliceFromBytes will create a PolicySet from the given text document with the/ given file name used in Position +// data. If there is an error parsing the document, it will be returned. +func NewPolicySliceFromBytes(fileName string, document []byte) (PolicySlice, error) { + var policySlice PolicySlice + if err := policySlice.UnmarshalCedar(document); err != nil { + return nil, err + } + for _, p := range policySlice { + p.SetFileName(fileName) + } + return policySlice, nil +} + +// UnmarshalCedar parses a concatenation of un-named Cedar policy statements. Names can be assigned to these policies +// when adding them to a PolicySet. +func (p *PolicySlice) UnmarshalCedar(b []byte) error { + var res parser.PolicySlice + if err := res.UnmarshalCedar(b); err != nil { + return fmt.Errorf("parser error: %w", err) + } + policySlice := make([]*Policy, 0, len(res)) + for _, p := range res { + newPolicy := newPolicy((*internalast.Policy)(p)) + policySlice = append(policySlice, &newPolicy) + } + *p = policySlice + return nil +} + +// MarshalCedar emits a concatenated Cedar representation of a PolicySlice +func (p PolicySlice) MarshalCedar() []byte { + var buf bytes.Buffer + for i, policy := range p { + buf.Write(policy.MarshalCedar()) + + if i < len(p)-1 { + buf.WriteString("\n\n") + } + } + return buf.Bytes() +} diff --git a/policy_slice_test.go b/policy_slice_test.go new file mode 100644 index 00000000..fc728085 --- /dev/null +++ b/policy_slice_test.go @@ -0,0 +1,30 @@ +package cedar_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go" + "github.com/cedar-policy/cedar-go/internal/testutil" +) + +func TestPolicySlice(t *testing.T) { + t.Parallel() + + policiesStr := `permit ( + principal, + action == Action::"editPhoto", + resource +) +when { resource.owner == principal }; + +forbid ( + principal in Groups::"bannedUsers", + action, + resource +);` + + var policies cedar.PolicySlice + testutil.OK(t, policies.UnmarshalCedar([]byte(policiesStr))) + + testutil.Equals(t, string(policies.MarshalCedar()), policiesStr) +} diff --git a/policy_test.go b/policy_test.go index 9f4e419f..36e8845a 100644 --- a/policy_test.go +++ b/policy_test.go @@ -94,25 +94,3 @@ func TestPolicyAST(t *testing.T) { _ = cedar.NewPolicyFromAST(astExample) } - -func TestPolicySlice(t *testing.T) { - t.Parallel() - - policiesStr := `permit ( - principal, - action == Action::"editPhoto", - resource -) -when { resource.owner == principal }; - -forbid ( - principal in Groups::"bannedUsers", - action, - resource -);` - - var policies cedar.PolicySlice - testutil.OK(t, policies.UnmarshalCedar([]byte(policiesStr))) - - testutil.Equals(t, string(policies.MarshalCedar()), policiesStr) -} From 8936289ad3c178342a7e1443b03854d173576a82 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 17:15:56 -0600 Subject: [PATCH 173/216] types: hide ZeroValue Addresses IDX-142 Signed-off-by: philhassey --- internal/eval/evalers.go | 148 +++++++++++++++++----------------- internal/eval/evalers_test.go | 82 +++++++++---------- types/json_test.go | 20 +++-- types/value.go | 4 - 4 files changed, 129 insertions(+), 125 deletions(-) diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index b600453e..79b15a1c 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -14,6 +14,10 @@ var errAttributeAccess = fmt.Errorf("does not have the attribute") var errEntityNotExist = fmt.Errorf("does not exist") var errUnspecifiedEntity = fmt.Errorf("unspecified entity") +func zeroValue() types.Value { + return nil +} + // TODO: make private again type Context struct { Entities entities.Entities @@ -133,7 +137,7 @@ func newErrorEval(err error) *errorEval { } func (n *errorEval) Eval(_ *Context) (types.Value, error) { - return types.ZeroValue(), n.err + return zeroValue(), n.err } // literalEval @@ -165,22 +169,22 @@ func newOrNode(lhs Evaler, rhs Evaler) *orEval { func (n *orEval) Eval(ctx *Context) (types.Value, error) { v, err := n.lhs.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } b, err := ValueToBool(v) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } if b { return v, nil } v, err = n.rhs.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } _, err = ValueToBool(v) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return v, nil } @@ -201,22 +205,22 @@ func newAndEval(lhs Evaler, rhs Evaler) *andEval { func (n *andEval) Eval(ctx *Context) (types.Value, error) { v, err := n.lhs.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } b, err := ValueToBool(v) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } if !b { return v, nil } v, err = n.rhs.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } _, err = ValueToBool(v) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return v, nil } @@ -235,11 +239,11 @@ func newNotEval(inner Evaler) *notEval { func (n *notEval) Eval(ctx *Context) (types.Value, error) { v, err := n.inner.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } b, err := ValueToBool(v) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return !b, nil } @@ -304,15 +308,15 @@ func newAddEval(lhs Evaler, rhs Evaler) *addEval { func (n *addEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } res, ok := checkedAddI64(lhs, rhs) if !ok { - return types.ZeroValue(), fmt.Errorf("%w while attempting to add `%d` with `%d`", errOverflow, lhs, rhs) + return zeroValue(), fmt.Errorf("%w while attempting to add `%d` with `%d`", errOverflow, lhs, rhs) } return res, nil } @@ -333,15 +337,15 @@ func newSubtractEval(lhs Evaler, rhs Evaler) *subtractEval { func (n *subtractEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } res, ok := checkedSubI64(lhs, rhs) if !ok { - return types.ZeroValue(), fmt.Errorf("%w while attempting to subtract `%d` from `%d`", errOverflow, rhs, lhs) + return zeroValue(), fmt.Errorf("%w while attempting to subtract `%d` from `%d`", errOverflow, rhs, lhs) } return res, nil } @@ -362,15 +366,15 @@ func newMultiplyEval(lhs Evaler, rhs Evaler) *multiplyEval { func (n *multiplyEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } res, ok := checkedMulI64(lhs, rhs) if !ok { - return types.ZeroValue(), fmt.Errorf("%w while attempting to multiply `%d` by `%d`", errOverflow, lhs, rhs) + return zeroValue(), fmt.Errorf("%w while attempting to multiply `%d` by `%d`", errOverflow, lhs, rhs) } return res, nil } @@ -389,11 +393,11 @@ func newNegateEval(inner Evaler) *negateEval { func (n *negateEval) Eval(ctx *Context) (types.Value, error) { inner, err := evalLong(n.inner, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } res, ok := checkedNegI64(inner) if !ok { - return types.ZeroValue(), fmt.Errorf("%w while attempting to negate `%d`", errOverflow, inner) + return zeroValue(), fmt.Errorf("%w while attempting to negate `%d`", errOverflow, inner) } return res, nil } @@ -414,11 +418,11 @@ func newLongLessThanEval(lhs Evaler, rhs Evaler) *longLessThanEval { func (n *longLessThanEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(lhs < rhs), nil } @@ -439,11 +443,11 @@ func newLongLessThanOrEqualEval(lhs Evaler, rhs Evaler) *longLessThanOrEqualEval func (n *longLessThanOrEqualEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(lhs <= rhs), nil } @@ -464,11 +468,11 @@ func newLongGreaterThanEval(lhs Evaler, rhs Evaler) *longGreaterThanEval { func (n *longGreaterThanEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(lhs > rhs), nil } @@ -489,11 +493,11 @@ func newLongGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) *longGreaterThanOrEqu func (n *longGreaterThanOrEqualEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalLong(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalLong(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(lhs >= rhs), nil } @@ -514,11 +518,11 @@ func newDecimalLessThanEval(lhs Evaler, rhs Evaler) *decimalLessThanEval { func (n *decimalLessThanEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(lhs < rhs), nil } @@ -539,11 +543,11 @@ func newDecimalLessThanOrEqualEval(lhs Evaler, rhs Evaler) *decimalLessThanOrEqu func (n *decimalLessThanOrEqualEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(lhs <= rhs), nil } @@ -564,11 +568,11 @@ func newDecimalGreaterThanEval(lhs Evaler, rhs Evaler) *decimalGreaterThanEval { func (n *decimalGreaterThanEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(lhs > rhs), nil } @@ -589,11 +593,11 @@ func newDecimalGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) *decimalGreaterTha func (n *decimalGreaterThanOrEqualEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalDecimal(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalDecimal(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(lhs >= rhs), nil } @@ -616,7 +620,7 @@ func newIfThenElseEval(if_, then, else_ Evaler) *ifThenElseEval { func (n *ifThenElseEval) Eval(ctx *Context) (types.Value, error) { cond, err := evalBool(n.if_, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } if cond { return n.then.Eval(ctx) @@ -639,11 +643,11 @@ func newEqualEval(lhs, rhs Evaler) *equalEval { func (n *equalEval) Eval(ctx *Context) (types.Value, error) { lv, err := n.lhs.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rv, err := n.rhs.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(lv.Equal(rv)), nil } @@ -663,11 +667,11 @@ func newNotEqualEval(lhs, rhs Evaler) *notEqualEval { func (n *notEqualEval) Eval(ctx *Context) (types.Value, error) { lv, err := n.lhs.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rv, err := n.rhs.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(!lv.Equal(rv)), nil } @@ -686,7 +690,7 @@ func (n *setLiteralEval) Eval(ctx *Context) (types.Value, error) { for _, e := range n.elements { v, err := e.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } vals = append(vals, v) } @@ -708,11 +712,11 @@ func newContainsEval(lhs, rhs Evaler) *containsEval { func (n *containsEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := n.rhs.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(lhs.Contains(rhs)), nil } @@ -732,11 +736,11 @@ func newContainsAllEval(lhs, rhs Evaler) *containsAllEval { func (n *containsAllEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalSet(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } result := true for _, e := range rhs { @@ -763,11 +767,11 @@ func newContainsAnyEval(lhs, rhs Evaler) *containsAnyEval { func (n *containsAnyEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalSet(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalSet(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } result := false for _, e := range rhs { @@ -793,7 +797,7 @@ func (n *recordLiteralEval) Eval(ctx *Context) (types.Value, error) { for k, en := range n.elements { v, err := en.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } vals[k] = v } @@ -813,7 +817,7 @@ func newAttributeAccessEval(record Evaler, attribute string) *attributeAccessEva func (n *attributeAccessEval) Eval(ctx *Context) (types.Value, error) { v, err := n.object.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } var record types.Record key := "record" @@ -822,22 +826,22 @@ func (n *attributeAccessEval) Eval(ctx *Context) (types.Value, error) { key = "`" + vv.String() + "`" var unspecified types.EntityUID if vv == unspecified { - return types.ZeroValue(), fmt.Errorf("cannot access attribute `%s` of %w", n.attribute, errUnspecifiedEntity) + return zeroValue(), fmt.Errorf("cannot access attribute `%s` of %w", n.attribute, errUnspecifiedEntity) } rec, ok := ctx.Entities[vv] if !ok { - return types.ZeroValue(), fmt.Errorf("entity `%v` %w", vv.String(), errEntityNotExist) + return zeroValue(), fmt.Errorf("entity `%v` %w", vv.String(), errEntityNotExist) } else { record = rec.Attributes } case types.Record: record = vv default: - return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, v.TypeName()) + return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, v.TypeName()) } val, ok := record[n.attribute] if !ok { - return types.ZeroValue(), fmt.Errorf("%s %w `%s`", key, errAttributeAccess, n.attribute) + return zeroValue(), fmt.Errorf("%s %w `%s`", key, errAttributeAccess, n.attribute) } return val, nil } @@ -855,7 +859,7 @@ func newHasEval(record Evaler, attribute string) *hasEval { func (n *hasEval) Eval(ctx *Context) (types.Value, error) { v, err := n.object.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } var record types.Record switch vv := v.(type) { @@ -869,7 +873,7 @@ func (n *hasEval) Eval(ctx *Context) (types.Value, error) { case types.Record: record = vv default: - return types.ZeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, v.TypeName()) + return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, v.TypeName()) } _, ok := record[n.attribute] return types.Boolean(ok), nil @@ -888,7 +892,7 @@ func newLikeEval(lhs Evaler, pattern types.Pattern) *likeEval { func (l *likeEval) Eval(ctx *Context) (types.Value, error) { v, err := evalString(l.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(l.pattern.Match(string(v))), nil } @@ -943,12 +947,12 @@ func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entity func (n *inEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalEntity(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := n.rhs.Eval(ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } query := map[types.EntityUID]struct{}{} @@ -959,12 +963,12 @@ func (n *inEval) Eval(ctx *Context) (types.Value, error) { for _, rhv := range rhsv { e, err := ValueToEntity(rhv) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } query[e] = struct{}{} } default: - return types.ZeroValue(), fmt.Errorf( + return zeroValue(), fmt.Errorf( "%w: expected one of [set, (entity of type `any_entity_type`)], got %v", ErrType, rhs.TypeName()) } return types.Boolean(entityIn(lhs, query, ctx.Entities)), nil @@ -982,12 +986,12 @@ func newIsEval(lhs, rhs Evaler) *isEval { func (n *isEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalEntity(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalEntityType(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(types.EntityType(lhs.Type) == rhs), nil @@ -1005,12 +1009,12 @@ func newDecimalLiteralEval(literal Evaler) *decimalLiteralEval { func (n *decimalLiteralEval) Eval(ctx *Context) (types.Value, error) { literal, err := evalString(n.literal, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } d, err := types.ParseDecimal(string(literal)) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return d, nil @@ -1027,12 +1031,12 @@ func newIPLiteralEval(literal Evaler) *ipLiteralEval { func (n *ipLiteralEval) Eval(ctx *Context) (types.Value, error) { literal, err := evalString(n.literal, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } i, err := types.ParseIPAddr(string(literal)) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return i, nil @@ -1058,7 +1062,7 @@ func newIPTestEval(object Evaler, test ipTestType) *ipTestEval { func (n *ipTestEval) Eval(ctx *Context) (types.Value, error) { i, err := evalIP(n.object, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(n.test(i)), nil } @@ -1076,11 +1080,11 @@ func newIPIsInRangeEval(lhs, rhs Evaler) *ipIsInRangeEval { func (n *ipIsInRangeEval) Eval(ctx *Context) (types.Value, error) { lhs, err := evalIP(n.lhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } rhs, err := evalIP(n.rhs, ctx) if err != nil { - return types.ZeroValue(), err + return zeroValue(), err } return types.Boolean(rhs.Contains(lhs)), nil } diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index b1924ee8..bc110314 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -23,8 +23,8 @@ func AssertValue(t *testing.T, got, want types.Value) { t.Helper() testutil.FatalIf( t, - !((got == types.ZeroValue() && want == types.ZeroValue()) || - (got != types.ZeroValue() && want != types.ZeroValue() && got.Equal(want))), + !((got == zeroValue() && want == zeroValue()) || + (got != zeroValue() && want != zeroValue() && got.Equal(want))), "got %v want %v", got, want) } @@ -40,7 +40,7 @@ func AssertLongValue(t *testing.T, got types.Value, want int64) { func AssertZeroValue(t *testing.T, got types.Value) { t.Helper() - testutil.Equals(t, got, types.ZeroValue()) + testutil.Equals(t, got, zeroValue()) } func TestOrNode(t *testing.T) { @@ -960,9 +960,9 @@ func TestIfThenElseNode(t *testing.T) { {"Else", newLiteralEval(types.False), newLiteralEval(types.Long(-1)), newLiteralEval(types.Long(42)), types.Long(42), nil}, - {"Err", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), newLiteralEval(types.ZeroValue()), types.ZeroValue(), + {"Err", newErrorEval(errTest), newLiteralEval(zeroValue()), newLiteralEval(zeroValue()), zeroValue(), errTest}, - {"ErrType", newLiteralEval(types.Long(123)), newLiteralEval(types.ZeroValue()), newLiteralEval(types.ZeroValue()), types.ZeroValue(), + {"ErrType", newLiteralEval(types.Long(123)), newLiteralEval(zeroValue()), newLiteralEval(zeroValue()), zeroValue(), ErrType}, } for _, tt := range tests { @@ -987,8 +987,8 @@ func TestEqualNode(t *testing.T) { }{ {"equals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(42)), types.True, nil}, {"notEquals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(1234)), types.False, nil}, - {"leftErr", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), types.ZeroValue(), errTest}, - {"rightErr", newLiteralEval(types.ZeroValue()), newErrorEval(errTest), types.ZeroValue(), errTest}, + {"leftErr", newErrorEval(errTest), newLiteralEval(zeroValue()), zeroValue(), errTest}, + {"rightErr", newLiteralEval(zeroValue()), newErrorEval(errTest), zeroValue(), errTest}, {"typesNotEqual", newLiteralEval(types.Long(1)), newLiteralEval(types.True), types.False, nil}, } for _, tt := range tests { @@ -1013,8 +1013,8 @@ func TestNotEqualNode(t *testing.T) { }{ {"equals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(42)), types.False, nil}, {"notEquals", newLiteralEval(types.Long(42)), newLiteralEval(types.Long(1234)), types.True, nil}, - {"leftErr", newErrorEval(errTest), newLiteralEval(types.ZeroValue()), types.ZeroValue(), errTest}, - {"rightErr", newLiteralEval(types.ZeroValue()), newErrorEval(errTest), types.ZeroValue(), errTest}, + {"leftErr", newErrorEval(errTest), newLiteralEval(zeroValue()), zeroValue(), errTest}, + {"rightErr", newLiteralEval(zeroValue()), newErrorEval(errTest), zeroValue(), errTest}, {"typesNotEqual", newLiteralEval(types.Long(1)), newLiteralEval(types.True), types.True, nil}, } for _, tt := range tests { @@ -1038,7 +1038,7 @@ func TestSetLiteralNode(t *testing.T) { err error }{ {"empty", []Evaler{}, types.Set{}, nil}, - {"errorNode", []Evaler{newErrorEval(errTest)}, types.ZeroValue(), errTest}, + {"errorNode", []Evaler{newErrorEval(errTest)}, zeroValue(), errTest}, {"nested", []Evaler{ newLiteralEval(types.True), @@ -1244,7 +1244,7 @@ func TestRecordLiteralNode(t *testing.T) { err error }{ {"empty", map[string]Evaler{}, types.Record{}, nil}, - {"errorNode", map[string]Evaler{"foo": newErrorEval(errTest)}, types.ZeroValue(), errTest}, + {"errorNode", map[string]Evaler{"foo": newErrorEval(errTest)}, zeroValue(), errTest}, {"ok", map[string]Evaler{ "foo": newLiteralEval(types.True), @@ -1275,12 +1275,12 @@ func TestAttributeAccessNode(t *testing.T) { result types.Value err error }{ - {"RecordError", newErrorEval(errTest), "foo", types.ZeroValue(), errTest}, - {"RecordTypeError", newLiteralEval(types.True), "foo", types.ZeroValue(), ErrType}, + {"RecordError", newErrorEval(errTest), "foo", zeroValue(), errTest}, + {"RecordTypeError", newLiteralEval(types.True), "foo", zeroValue(), ErrType}, {"UnknownAttribute", newLiteralEval(types.Record{}), "foo", - types.ZeroValue(), + zeroValue(), errAttributeAccess}, {"KnownAttribute", newLiteralEval(types.Record{"foo": types.Long(42)}), @@ -1295,12 +1295,12 @@ func TestAttributeAccessNode(t *testing.T) { {"UnknownEntity", newLiteralEval(types.NewEntityUID("unknownType", "unknownID")), "unknownAttr", - types.ZeroValue(), + zeroValue(), errEntityNotExist}, {"UnspecifiedEntity", newLiteralEval(types.NewEntityUID("", "")), "knownAttr", - types.ZeroValue(), + zeroValue(), errUnspecifiedEntity}, } for _, tt := range tests { @@ -1332,8 +1332,8 @@ func TestHasNode(t *testing.T) { result types.Value err error }{ - {"RecordError", newErrorEval(errTest), "foo", types.ZeroValue(), errTest}, - {"RecordTypeError", newLiteralEval(types.True), "foo", types.ZeroValue(), ErrType}, + {"RecordError", newErrorEval(errTest), "foo", zeroValue(), errTest}, + {"RecordTypeError", newLiteralEval(types.True), "foo", zeroValue(), ErrType}, {"UnknownAttribute", newLiteralEval(types.Record{}), "foo", @@ -1389,8 +1389,8 @@ func TestLikeNode(t *testing.T) { result types.Value err error }{ - {"leftError", newErrorEval(errTest), `"foo"`, types.ZeroValue(), errTest}, - {"leftTypeError", newLiteralEval(types.True), `"foo"`, types.ZeroValue(), ErrType}, + {"leftError", newErrorEval(errTest), `"foo"`, zeroValue(), errTest}, + {"leftTypeError", newLiteralEval(types.True), `"foo"`, zeroValue(), ErrType}, {"noMatch", newLiteralEval(types.String("test")), `"zebra"`, types.False, nil}, {"match", newLiteralEval(types.String("test")), `"*es*"`, types.True, nil}, @@ -1616,10 +1616,10 @@ func TestIsNode(t *testing.T) { }{ {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.EntityType("X")), types.True, nil}, {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.EntityType("Y")), types.False, nil}, - {"badLhs", newLiteralEval(types.Long(42)), newLiteralEval(types.EntityType("X")), types.ZeroValue(), ErrType}, - {"badRhs", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Long(42)), types.ZeroValue(), ErrType}, - {"errLhs", newErrorEval(errTest), newLiteralEval(types.EntityType("X")), types.ZeroValue(), errTest}, - {"errRhs", newLiteralEval(types.NewEntityUID("X", "z")), newErrorEval(errTest), types.ZeroValue(), errTest}, + {"badLhs", newLiteralEval(types.Long(42)), newLiteralEval(types.EntityType("X")), zeroValue(), ErrType}, + {"badRhs", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Long(42)), zeroValue(), ErrType}, + {"errLhs", newErrorEval(errTest), newLiteralEval(types.EntityType("X")), zeroValue(), errTest}, + {"errRhs", newLiteralEval(types.NewEntityUID("X", "z")), newErrorEval(errTest), zeroValue(), errTest}, } for _, tt := range tests { tt := tt @@ -1646,7 +1646,7 @@ func TestInNode(t *testing.T) { newErrorEval(errTest), newLiteralEval(types.Set{}), map[string][]string{}, - types.ZeroValue(), + zeroValue(), errTest, }, { @@ -1654,7 +1654,7 @@ func TestInNode(t *testing.T) { newLiteralEval(types.String("foo")), newLiteralEval(types.Set{}), map[string][]string{}, - types.ZeroValue(), + zeroValue(), ErrType, }, { @@ -1662,7 +1662,7 @@ func TestInNode(t *testing.T) { newLiteralEval(types.NewEntityUID("human", "joe")), newErrorEval(errTest), map[string][]string{}, - types.ZeroValue(), + zeroValue(), errTest, }, { @@ -1670,7 +1670,7 @@ func TestInNode(t *testing.T) { newLiteralEval(types.NewEntityUID("human", "joe")), newLiteralEval(types.String("foo")), map[string][]string{}, - types.ZeroValue(), + zeroValue(), ErrType, }, { @@ -1680,7 +1680,7 @@ func TestInNode(t *testing.T) { types.String("foo"), }), map[string][]string{}, - types.ZeroValue(), + zeroValue(), ErrType, }, { @@ -1757,9 +1757,9 @@ func TestDecimalLiteralNode(t *testing.T) { result types.Value err error }{ - {"Error", newErrorEval(errTest), types.ZeroValue(), errTest}, - {"TypeError", newLiteralEval(types.Long(1)), types.ZeroValue(), ErrType}, - {"DecimalError", newLiteralEval(types.String("frob")), types.ZeroValue(), types.ErrDecimal}, + {"Error", newErrorEval(errTest), zeroValue(), errTest}, + {"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType}, + {"DecimalError", newLiteralEval(types.String("frob")), zeroValue(), types.ErrDecimal}, {"Success", newLiteralEval(types.String("1.0")), types.Decimal(10000), nil}, } for _, tt := range tests { @@ -1784,9 +1784,9 @@ func TestIPLiteralNode(t *testing.T) { result types.Value err error }{ - {"Error", newErrorEval(errTest), types.ZeroValue(), errTest}, - {"TypeError", newLiteralEval(types.Long(1)), types.ZeroValue(), ErrType}, - {"IPError", newLiteralEval(types.String("not-an-IP-address")), types.ZeroValue(), types.ErrIP}, + {"Error", newErrorEval(errTest), zeroValue(), errTest}, + {"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType}, + {"IPError", newLiteralEval(types.String("not-an-IP-address")), zeroValue(), types.ErrIP}, {"Success", newLiteralEval(types.String("::1/128")), ipv6Loopback, nil}, } for _, tt := range tests { @@ -1816,8 +1816,8 @@ func TestIPTestNode(t *testing.T) { result types.Value err error }{ - {"Error", newErrorEval(errTest), ipTestIPv4, types.ZeroValue(), errTest}, - {"TypeError", newLiteralEval(types.Long(1)), ipTestIPv4, types.ZeroValue(), ErrType}, + {"Error", newErrorEval(errTest), ipTestIPv4, zeroValue(), errTest}, + {"TypeError", newLiteralEval(types.Long(1)), ipTestIPv4, zeroValue(), ErrType}, {"IPv4True", newLiteralEval(ipv4Loopback), ipTestIPv4, types.True, nil}, {"IPv4False", newLiteralEval(ipv6Loopback), ipTestIPv4, types.False, nil}, {"IPv6True", newLiteralEval(ipv6Loopback), ipTestIPv6, types.True, nil}, @@ -1853,10 +1853,10 @@ func TestIPIsInRangeNode(t *testing.T) { result types.Value err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(ipv4A), types.ZeroValue(), errTest}, - {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(ipv4A), types.ZeroValue(), ErrType}, - {"RhsError", newLiteralEval(ipv4A), newErrorEval(errTest), types.ZeroValue(), errTest}, - {"RhsTypeError", newLiteralEval(ipv4A), newLiteralEval(types.Long(1)), types.ZeroValue(), ErrType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(ipv4A), zeroValue(), errTest}, + {"LhsTypeError", newLiteralEval(types.Long(1)), newLiteralEval(ipv4A), zeroValue(), ErrType}, + {"RhsError", newLiteralEval(ipv4A), newErrorEval(errTest), zeroValue(), errTest}, + {"RhsTypeError", newLiteralEval(ipv4A), newLiteralEval(types.Long(1)), zeroValue(), ErrType}, {"AA", newLiteralEval(ipv4A), newLiteralEval(ipv4A), types.True, nil}, {"AB", newLiteralEval(ipv4A), newLiteralEval(ipv4B), types.True, nil}, {"BA", newLiteralEval(ipv4B), newLiteralEval(ipv4A), types.False, nil}, diff --git a/types/json_test.go b/types/json_test.go index 93e58317..4e42bb67 100644 --- a/types/json_test.go +++ b/types/json_test.go @@ -8,6 +8,10 @@ import ( "github.com/cedar-policy/cedar-go/internal/testutil" ) +func zeroValue() Value { + return nil +} + func mustDecimalValue(v string) Decimal { r, _ := ParseDecimal(v) return r @@ -22,8 +26,8 @@ func AssertValue(t *testing.T, got, want Value) { t.Helper() testutil.FatalIf( t, - !((got == ZeroValue() && want == ZeroValue()) || - (got != ZeroValue() && want != ZeroValue() && got.Equal(want))), + !((got == zeroValue() && want == zeroValue()) || + (got != zeroValue() && want != zeroValue() && got.Equal(want))), "got %v want %v", got, want) } @@ -39,15 +43,15 @@ func TestJSON_Value(t *testing.T) { {"explicitEntity", `{ "__entity": { "type": "User", "id": "alice" } }`, EntityUID{Type: "User", ID: "alice"}, nil}, {"impliedLongEntity", `{ "type": "User::External", "id": "alice" }`, EntityUID{Type: "User::External", ID: "alice"}, nil}, {"explicitLongEntity", `{ "__entity": { "type": "User::External", "id": "alice" } }`, EntityUID{Type: "User::External", ID: "alice"}, nil}, - {"invalidJSON", `!@#$`, ZeroValue(), errJSONDecode}, - {"numericOverflow", "12341234123412341234", ZeroValue(), errJSONLongOutOfRange}, - {"unsupportedNull", "null", ZeroValue(), errJSONUnsupportedType}, + {"invalidJSON", `!@#$`, zeroValue(), errJSONDecode}, + {"numericOverflow", "12341234123412341234", zeroValue(), errJSONLongOutOfRange}, + {"unsupportedNull", "null", zeroValue(), errJSONUnsupportedType}, {"explicitIP", `{ "__extn": { "fn": "ip", "arg": "222.222.222.7" } }`, mustIPValue("222.222.222.7"), nil}, {"explicitSubnet", `{ "__extn": { "fn": "ip", "arg": "192.168.0.0/16" } }`, mustIPValue("192.168.0.0/16"), nil}, {"explicitDecimal", `{ "__extn": { "fn": "decimal", "arg": "33.57" } }`, mustDecimalValue("33.57"), nil}, - {"invalidExtension", `{ "__extn": { "fn": "asdf", "arg": "blah" } }`, ZeroValue(), errJSONInvalidExtn}, - {"badIP", `{ "__extn": { "fn": "ip", "arg": "bad" } }`, ZeroValue(), ErrIP}, - {"badDecimal", `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, ZeroValue(), ErrDecimal}, + {"invalidExtension", `{ "__extn": { "fn": "asdf", "arg": "blah" } }`, zeroValue(), errJSONInvalidExtn}, + {"badIP", `{ "__extn": { "fn": "ip", "arg": "bad" } }`, zeroValue(), ErrIP}, + {"badDecimal", `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, zeroValue(), ErrDecimal}, {"set", `[42]`, Set{Long(42)}, nil}, {"record", `{"a":"b"}`, Record{"a": String("b")}, nil}, {"bool", `false`, Boolean(false), nil}, diff --git a/types/value.go b/types/value.go index fc528ec8..64b42487 100644 --- a/types/value.go +++ b/types/value.go @@ -20,7 +20,3 @@ type Value interface { TypeName() string deepClone() Value } - -func ZeroValue() Value { - return nil -} From f48a39abf8db79ba9f332b055cff3def0e67a79f Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 17:19:30 -0600 Subject: [PATCH 174/216] types: remove redundant equals method Addresses IDX-142 Signed-off-by: philhassey --- types/record.go | 2 -- types/record_test.go | 18 +++++++++--------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/types/record.go b/types/record.go index a5934b87..a9215ded 100644 --- a/types/record.go +++ b/types/record.go @@ -15,8 +15,6 @@ import ( type Record map[string]Value // Equals returns true if the records are Equal. -func (r Record) Equals(b Record) bool { return r.Equal(b) } - func (a Record) Equal(bi Value) bool { b, ok := bi.(Record) if !ok || len(a) != len(b) { diff --git a/types/record_test.go b/types/record_test.go index 382aee5c..3b9c0c3b 100644 --- a/types/record_test.go +++ b/types/record_test.go @@ -41,18 +41,18 @@ func TestRecord(t *testing.T) { "nest": twoElems, } - testutil.FatalIf(t, !empty.Equals(empty), "%v not Equal to %v", empty, empty) - testutil.FatalIf(t, !empty.Equals(empty2), "%v not Equal to %v", empty, empty2) + testutil.FatalIf(t, !empty.Equal(empty), "%v not Equal to %v", empty, empty) + testutil.FatalIf(t, !empty.Equal(empty2), "%v not Equal to %v", empty, empty2) - testutil.FatalIf(t, !twoElems.Equals(twoElems), "%v not Equal to %v", twoElems, twoElems) - testutil.FatalIf(t, !twoElems.Equals(twoElems2), "%v not Equal to %v", twoElems, twoElems2) + testutil.FatalIf(t, !twoElems.Equal(twoElems), "%v not Equal to %v", twoElems, twoElems) + testutil.FatalIf(t, !twoElems.Equal(twoElems2), "%v not Equal to %v", twoElems, twoElems2) - testutil.FatalIf(t, !nested.Equals(nested), "%v not Equal to %v", nested, nested) - testutil.FatalIf(t, !nested.Equals(nested2), "%v not Equal to %v", nested, nested2) + testutil.FatalIf(t, !nested.Equal(nested), "%v not Equal to %v", nested, nested) + testutil.FatalIf(t, !nested.Equal(nested2), "%v not Equal to %v", nested, nested2) - testutil.FatalIf(t, nested.Equals(twoElems), "%v Equal to %v", nested, twoElems) - testutil.FatalIf(t, twoElems.Equals(differentValues), "%v Equal to %v", twoElems, differentValues) - testutil.FatalIf(t, twoElems.Equals(differentKeys), "%v Equal to %v", twoElems, differentKeys) + testutil.FatalIf(t, nested.Equal(twoElems), "%v Equal to %v", nested, twoElems) + testutil.FatalIf(t, twoElems.Equal(differentValues), "%v Equal to %v", twoElems, differentValues) + testutil.FatalIf(t, twoElems.Equal(differentKeys), "%v Equal to %v", twoElems, differentKeys) }) t.Run("string", func(t *testing.T) { From 327edac7bfd7f328065f9f65adfabe559e4e23f3 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 17:20:46 -0600 Subject: [PATCH 175/216] types: remove redundant equals method from set Addresses IDX-142 Signed-off-by: philhassey --- types/set.go | 4 +--- types/set_test.go | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/types/set.go b/types/set.go index 510b68e8..0c1b8298 100644 --- a/types/set.go +++ b/types/set.go @@ -18,9 +18,7 @@ func (s Set) Contains(v Value) bool { return false } -// Equals returns true if the sets are Equal. -func (s Set) Equals(b Set) bool { return s.Equal(b) } - +// Equal returns true if the sets are Equal. func (as Set) Equal(bi Value) bool { bs, ok := bi.(Set) if !ok { diff --git a/types/set_test.go b/types/set_test.go index 43adc5e8..7e4bc5ff 100644 --- a/types/set_test.go +++ b/types/set_test.go @@ -28,19 +28,19 @@ func TestSet(t *testing.T) { types.Long(3), types.Long(2), types.Long(2), types.Long(1), } - testutil.FatalIf(t, !empty.Equals(empty), "%v not Equal to %v", empty, empty) - testutil.FatalIf(t, !empty.Equals(empty2), "%v not Equal to %v", empty, empty2) - testutil.FatalIf(t, !oneTrue.Equals(oneTrue), "%v not Equal to %v", oneTrue, oneTrue) - testutil.FatalIf(t, !oneTrue.Equals(oneTrue2), "%v not Equal to %v", oneTrue, oneTrue2) - testutil.FatalIf(t, !nestedOnce.Equals(nestedOnce), "%v not Equal to %v", nestedOnce, nestedOnce) - testutil.FatalIf(t, !nestedOnce.Equals(nestedOnce2), "%v not Equal to %v", nestedOnce, nestedOnce2) - testutil.FatalIf(t, !nestedTwice.Equals(nestedTwice), "%v not Equal to %v", nestedTwice, nestedTwice) - testutil.FatalIf(t, !nestedTwice.Equals(nestedTwice2), "%v not Equal to %v", nestedTwice, nestedTwice2) - testutil.FatalIf(t, !oneTwoThree.Equals(threeTwoTwoOne), "%v not Equal to %v", oneTwoThree, threeTwoTwoOne) - - testutil.FatalIf(t, empty.Equals(oneFalse), "%v Equal to %v", empty, oneFalse) - testutil.FatalIf(t, oneTrue.Equals(oneFalse), "%v Equal to %v", oneTrue, oneFalse) - testutil.FatalIf(t, nestedOnce.Equals(nestedTwice), "%v Equal to %v", nestedOnce, nestedTwice) + testutil.FatalIf(t, !empty.Equal(empty), "%v not Equal to %v", empty, empty) + testutil.FatalIf(t, !empty.Equal(empty2), "%v not Equal to %v", empty, empty2) + testutil.FatalIf(t, !oneTrue.Equal(oneTrue), "%v not Equal to %v", oneTrue, oneTrue) + testutil.FatalIf(t, !oneTrue.Equal(oneTrue2), "%v not Equal to %v", oneTrue, oneTrue2) + testutil.FatalIf(t, !nestedOnce.Equal(nestedOnce), "%v not Equal to %v", nestedOnce, nestedOnce) + testutil.FatalIf(t, !nestedOnce.Equal(nestedOnce2), "%v not Equal to %v", nestedOnce, nestedOnce2) + testutil.FatalIf(t, !nestedTwice.Equal(nestedTwice), "%v not Equal to %v", nestedTwice, nestedTwice) + testutil.FatalIf(t, !nestedTwice.Equal(nestedTwice2), "%v not Equal to %v", nestedTwice, nestedTwice2) + testutil.FatalIf(t, !oneTwoThree.Equal(threeTwoTwoOne), "%v not Equal to %v", oneTwoThree, threeTwoTwoOne) + + testutil.FatalIf(t, empty.Equal(oneFalse), "%v Equal to %v", empty, oneFalse) + testutil.FatalIf(t, oneTrue.Equal(oneFalse), "%v Equal to %v", oneTrue, oneFalse) + testutil.FatalIf(t, nestedOnce.Equal(nestedTwice), "%v Equal to %v", nestedOnce, nestedTwice) }) t.Run("string", func(t *testing.T) { From bf3d07d318b2b7163bd807878dedc2fb59cf2a29 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 17:35:04 -0600 Subject: [PATCH 176/216] internal/eval: move typename to eval Addresses IDX-142 Signed-off-by: philhassey --- internal/eval/evalers.go | 6 +++--- internal/eval/util.go | 43 ++++++++++++++++++++++++++++++-------- internal/eval/util_test.go | 29 +++++++++++++++++++++++++ types/boolean_test.go | 5 ----- types/decimal.go | 1 - types/decimal_test.go | 5 ----- types/entity_type.go | 2 -- types/entity_type_test.go | 6 +----- types/entity_uid.go | 2 -- types/entity_uid_test.go | 5 ----- types/ipaddr.go | 1 - types/ipaddr_test.go | 5 ----- types/long.go | 1 - types/long_test.go | 5 ----- types/record.go | 1 - types/record_test.go | 5 ----- types/set.go | 2 -- types/set_test.go | 5 ----- types/string.go | 1 - types/string_test.go | 5 ----- types/value.go | 1 - 21 files changed, 67 insertions(+), 69 deletions(-) diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index 79b15a1c..d8367921 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -837,7 +837,7 @@ func (n *attributeAccessEval) Eval(ctx *Context) (types.Value, error) { case types.Record: record = vv default: - return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, v.TypeName()) + return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, TypeName(v)) } val, ok := record[n.attribute] if !ok { @@ -873,7 +873,7 @@ func (n *hasEval) Eval(ctx *Context) (types.Value, error) { case types.Record: record = vv default: - return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, v.TypeName()) + return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, TypeName(v)) } _, ok := record[n.attribute] return types.Boolean(ok), nil @@ -969,7 +969,7 @@ func (n *inEval) Eval(ctx *Context) (types.Value, error) { } default: return zeroValue(), fmt.Errorf( - "%w: expected one of [set, (entity of type `any_entity_type`)], got %v", ErrType, rhs.TypeName()) + "%w: expected one of [set, (entity of type `any_entity_type`)], got %v", ErrType, TypeName(rhs)) } return types.Boolean(entityIn(lhs, query, ctx.Entities)), nil } diff --git a/internal/eval/util.go b/internal/eval/util.go index ce4ba1f8..81c133fd 100644 --- a/internal/eval/util.go +++ b/internal/eval/util.go @@ -6,12 +6,37 @@ import ( "github.com/cedar-policy/cedar-go/types" ) +func TypeName(v types.Value) string { + switch t := v.(type) { + case types.Boolean: + return "bool" + case types.Decimal: + return "decimal" + case types.EntityType: + return fmt.Sprintf("(EntityType of type `%s`)", t) + case types.EntityUID: + return fmt.Sprintf("(entity of type `%s`)", t.Type) + case types.IPAddr: + return "IP" + case types.Long: + return "long" + case types.Record: + return "record" + case types.Set: + return "set" + case types.String: + return "string" + default: + return "unknown type" + } +} + var ErrType = fmt.Errorf("type error") func ValueToBool(v types.Value) (types.Boolean, error) { bv, ok := v.(types.Boolean) if !ok { - return false, fmt.Errorf("%w: expected bool, got %v", ErrType, v.TypeName()) + return false, fmt.Errorf("%w: expected bool, got %v", ErrType, TypeName(v)) } return bv, nil } @@ -19,7 +44,7 @@ func ValueToBool(v types.Value) (types.Boolean, error) { func ValueToLong(v types.Value) (types.Long, error) { lv, ok := v.(types.Long) if !ok { - return 0, fmt.Errorf("%w: expected long, got %v", ErrType, v.TypeName()) + return 0, fmt.Errorf("%w: expected long, got %v", ErrType, TypeName(v)) } return lv, nil } @@ -27,7 +52,7 @@ func ValueToLong(v types.Value) (types.Long, error) { func ValueToString(v types.Value) (types.String, error) { sv, ok := v.(types.String) if !ok { - return "", fmt.Errorf("%w: expected string, got %v", ErrType, v.TypeName()) + return "", fmt.Errorf("%w: expected string, got %v", ErrType, TypeName(v)) } return sv, nil } @@ -35,7 +60,7 @@ func ValueToString(v types.Value) (types.String, error) { func ValueToSet(v types.Value) (types.Set, error) { sv, ok := v.(types.Set) if !ok { - return nil, fmt.Errorf("%w: expected set, got %v", ErrType, v.TypeName()) + return nil, fmt.Errorf("%w: expected set, got %v", ErrType, TypeName(v)) } return sv, nil } @@ -43,7 +68,7 @@ func ValueToSet(v types.Value) (types.Set, error) { func ValueToRecord(v types.Value) (types.Record, error) { rv, ok := v.(types.Record) if !ok { - return nil, fmt.Errorf("%w: expected record got %v", ErrType, v.TypeName()) + return nil, fmt.Errorf("%w: expected record got %v", ErrType, TypeName(v)) } return rv, nil } @@ -51,7 +76,7 @@ func ValueToRecord(v types.Value) (types.Record, error) { func ValueToEntity(v types.Value) (types.EntityUID, error) { ev, ok := v.(types.EntityUID) if !ok { - return types.EntityUID{}, fmt.Errorf("%w: expected (entity of type `any_entity_type`), got %v", ErrType, v.TypeName()) + return types.EntityUID{}, fmt.Errorf("%w: expected (entity of type `any_entity_type`), got %v", ErrType, TypeName(v)) } return ev, nil } @@ -59,7 +84,7 @@ func ValueToEntity(v types.Value) (types.EntityUID, error) { func ValueToEntityType(v types.Value) (types.EntityType, error) { ev, ok := v.(types.EntityType) if !ok { - return "", fmt.Errorf("%w: expected (EntityType of type `any_entity_type`), got %v", ErrType, v.TypeName()) + return "", fmt.Errorf("%w: expected (EntityType of type `any_entity_type`), got %v", ErrType, TypeName(v)) } return ev, nil } @@ -67,7 +92,7 @@ func ValueToEntityType(v types.Value) (types.EntityType, error) { func ValueToDecimal(v types.Value) (types.Decimal, error) { d, ok := v.(types.Decimal) if !ok { - return 0, fmt.Errorf("%w: expected decimal, got %v", ErrType, v.TypeName()) + return 0, fmt.Errorf("%w: expected decimal, got %v", ErrType, TypeName(v)) } return d, nil } @@ -75,7 +100,7 @@ func ValueToDecimal(v types.Value) (types.Decimal, error) { func ValueToIP(v types.Value) (types.IPAddr, error) { i, ok := v.(types.IPAddr) if !ok { - return types.IPAddr{}, fmt.Errorf("%w: expected ipaddr, got %v", ErrType, v.TypeName()) + return types.IPAddr{}, fmt.Errorf("%w: expected ipaddr, got %v", ErrType, TypeName(v)) } return i, nil } diff --git a/internal/eval/util_test.go b/internal/eval/util_test.go index 7164da1f..0c9a9690 100644 --- a/internal/eval/util_test.go +++ b/internal/eval/util_test.go @@ -151,3 +151,32 @@ func TestUtil(t *testing.T) { }) } + +func TestTypeName(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in types.Value + out string + }{ + + {"boolean", types.Boolean(true), "bool"}, + {"decimal", types.Decimal(42), "decimal"}, + {"entityType", types.EntityType("T"), "(EntityType of type `T`)"}, + {"entityUID", types.NewEntityUID("T", "42"), "(entity of type `T`)"}, + {"ip", types.IPAddr{}, "IP"}, + {"long", types.Long(42), "long"}, + {"record", types.Record{}, "record"}, + {"set", types.Set{}, "set"}, + {"string", types.String("test"), "string"}, + {"nil", nil, "unknown type"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out := TypeName(tt.in) + testutil.Equals(t, out, tt.out) + }) + } +} diff --git a/types/boolean_test.go b/types/boolean_test.go index fdfd2833..25d7e2b8 100644 --- a/types/boolean_test.go +++ b/types/boolean_test.go @@ -28,9 +28,4 @@ func TestBool(t *testing.T) { AssertValueString(t, types.Boolean(true), "true") }) - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := types.Boolean(true).TypeName() - testutil.Equals(t, tn, "bool") - }) } diff --git a/types/decimal.go b/types/decimal.go index 21796602..ecca6b98 100644 --- a/types/decimal.go +++ b/types/decimal.go @@ -112,7 +112,6 @@ func (a Decimal) Equal(bi Value) bool { return ok && a == b } -func (v Decimal) TypeName() string { return "decimal" } // Cedar produces a valid Cedar language representation of the Decimal, e.g. `decimal("12.34")`. func (v Decimal) Cedar() string { return `decimal("` + v.String() + `")` } diff --git a/types/decimal_test.go b/types/decimal_test.go index 839aba8f..d6002178 100644 --- a/types/decimal_test.go +++ b/types/decimal_test.go @@ -120,9 +120,4 @@ func TestDecimal(t *testing.T) { testutil.FatalIf(t, zero.Equal(f), "%v Equal to %v", zero, f) }) - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := types.Decimal(0).TypeName() - testutil.Equals(t, tn, "decimal") - }) } diff --git a/types/entity_type.go b/types/entity_type.go index b1331676..576bcc03 100644 --- a/types/entity_type.go +++ b/types/entity_type.go @@ -2,7 +2,6 @@ package types import ( "encoding/json" - "fmt" "strings" ) @@ -13,7 +12,6 @@ func (a EntityType) Equal(bi Value) bool { b, ok := bi.(EntityType) return ok && a == b } -func (v EntityType) TypeName() string { return fmt.Sprintf("(EntityType of type `%s`)", v) } func (v EntityType) String() string { return string(v) } func (v EntityType) Cedar() string { return string(v) } diff --git a/types/entity_type_test.go b/types/entity_type_test.go index bef081ff..3c59a653 100644 --- a/types/entity_type_test.go +++ b/types/entity_type_test.go @@ -19,11 +19,7 @@ func TestEntityType(t *testing.T) { testutil.Equals(t, a.Equal(c), false) testutil.Equals(t, c.Equal(a), false) }) - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - a := types.EntityType("X") - testutil.Equals(t, a.TypeName(), "(EntityType of type `X`)") - }) + t.Run("String", func(t *testing.T) { t.Parallel() a := types.EntityType("X") diff --git a/types/entity_uid.go b/types/entity_uid.go index e863317c..6e2c6307 100644 --- a/types/entity_uid.go +++ b/types/entity_uid.go @@ -2,7 +2,6 @@ package types import ( "encoding/json" - "fmt" "strconv" ) @@ -28,7 +27,6 @@ func (a EntityUID) Equal(bi Value) bool { b, ok := bi.(EntityUID) return ok && a == b } -func (v EntityUID) TypeName() string { return fmt.Sprintf("(entity of type `%s`)", v.Type) } // String produces a string representation of the EntityUID, e.g. `Type::"id"`. func (v EntityUID) String() string { return v.Cedar() } diff --git a/types/entity_uid_test.go b/types/entity_uid_test.go index 826a2dcb..51e5605c 100644 --- a/types/entity_uid_test.go +++ b/types/entity_uid_test.go @@ -26,9 +26,4 @@ func TestEntity(t *testing.T) { AssertValueString(t, types.EntityUID{Type: "namespace::type", ID: "id"}, `namespace::type::"id"`) }) - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := types.EntityUID{"T", "id"}.TypeName() - testutil.Equals(t, tn, "(entity of type `T`)") - }) } diff --git a/types/ipaddr.go b/types/ipaddr.go index af738c6b..02bfb402 100644 --- a/types/ipaddr.go +++ b/types/ipaddr.go @@ -31,7 +31,6 @@ func (a IPAddr) Equal(bi Value) bool { return ok && a == b } -func (v IPAddr) TypeName() string { return "IP" } // Cedar produces a valid Cedar language representation of the IPAddr, e.g. `ip("127.0.0.1")`. func (v IPAddr) Cedar() string { return `ip("` + v.String() + `")` } diff --git a/types/ipaddr_test.go b/types/ipaddr_test.go index 997a4466..4346539f 100644 --- a/types/ipaddr_test.go +++ b/types/ipaddr_test.go @@ -270,9 +270,4 @@ func TestIP(t *testing.T) { } }) - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := types.IPAddr{}.TypeName() - testutil.Equals(t, tn, "IP") - }) } diff --git a/types/long.go b/types/long.go index 01751c1a..5c1319e3 100644 --- a/types/long.go +++ b/types/long.go @@ -15,7 +15,6 @@ func (a Long) Equal(bi Value) bool { // ExplicitMarshalJSON marshals the Long into JSON. func (v Long) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } -func (v Long) TypeName() string { return "long" } // String produces a string representation of the Long, e.g. `42`. func (v Long) String() string { return v.Cedar() } diff --git a/types/long_test.go b/types/long_test.go index d4de2134..1ae49adb 100644 --- a/types/long_test.go +++ b/types/long_test.go @@ -28,9 +28,4 @@ func TestLong(t *testing.T) { AssertValueString(t, types.Long(1), "1") }) - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := types.Long(1).TypeName() - testutil.Equals(t, tn, "long") - }) } diff --git a/types/record.go b/types/record.go index a9215ded..cc953e26 100644 --- a/types/record.go +++ b/types/record.go @@ -70,7 +70,6 @@ func (v Record) MarshalJSON() ([]byte, error) { // ExplicitMarshalJSON marshals the Record into JSON, the marshaller uses the // explicit JSON form for all the values in the Record. func (v Record) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } -func (r Record) TypeName() string { return "record" } // String produces a string representation of the Record, e.g. `{"a":1,"b":2,"c":3}`. func (r Record) String() string { return r.Cedar() } diff --git a/types/record_test.go b/types/record_test.go index 3b9c0c3b..dd5b6068 100644 --- a/types/record_test.go +++ b/types/record_test.go @@ -71,9 +71,4 @@ func TestRecord(t *testing.T) { `{"bar":"blah", "foo":true}`) }) - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := types.Record{}.TypeName() - testutil.Equals(t, tn, "record") - }) } diff --git a/types/set.go b/types/set.go index 0c1b8298..fc7da4a5 100644 --- a/types/set.go +++ b/types/set.go @@ -76,8 +76,6 @@ func (v Set) MarshalJSON() ([]byte, error) { // explicit JSON form for all the values in the Set. func (v Set) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } -func (v Set) TypeName() string { return "set" } - // String produces a string representation of the Set, e.g. `[1,2,3]`. func (v Set) String() string { return v.Cedar() } diff --git a/types/set_test.go b/types/set_test.go index 7e4bc5ff..6705df6c 100644 --- a/types/set_test.go +++ b/types/set_test.go @@ -52,9 +52,4 @@ func TestSet(t *testing.T) { "[true, 1]") }) - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := types.Set{}.TypeName() - testutil.Equals(t, tn, "set") - }) } diff --git a/types/string.go b/types/string.go index dd15a4a7..9891aa9e 100644 --- a/types/string.go +++ b/types/string.go @@ -15,7 +15,6 @@ func (a String) Equal(bi Value) bool { // ExplicitMarshalJSON marshals the String into JSON. func (v String) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } -func (v String) TypeName() string { return "string" } // String produces an unquoted string representation of the String, e.g. `hello`. func (v String) String() string { diff --git a/types/string_test.go b/types/string_test.go index a995f64e..ca9e262b 100644 --- a/types/string_test.go +++ b/types/string_test.go @@ -26,9 +26,4 @@ func TestString(t *testing.T) { AssertValueString(t, types.String("hello\ngoodbye"), "hello\ngoodbye") }) - t.Run("TypeName", func(t *testing.T) { - t.Parallel() - tn := types.String("hello").TypeName() - testutil.Equals(t, tn, "string") - }) } diff --git a/types/value.go b/types/value.go index 64b42487..a1922e89 100644 --- a/types/value.go +++ b/types/value.go @@ -17,6 +17,5 @@ type Value interface { // Sets or Records where the type is not defined. ExplicitMarshalJSON() ([]byte, error) Equal(Value) bool - TypeName() string deepClone() Value } From 1bc846d0dd2bbaa0e764d0fe0c77a483302235b8 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 17:37:08 -0600 Subject: [PATCH 177/216] ast: made equal not equal more consistently named Addresses IDX-142 Signed-off-by: philhassey --- ast/ast_test.go | 12 ++++++------ ast/operator.go | 8 ++++---- internal/ast/ast_test.go | 8 ++++---- internal/ast/operator.go | 4 ++-- internal/eval/compile.go | 2 +- internal/eval/compile_test.go | 8 ++++---- internal/eval/convert_test.go | 4 ++-- internal/json/json_test.go | 6 +++--- internal/json/json_unmarshal.go | 4 ++-- internal/parser/cedar_marshal.go | 2 +- internal/parser/cedar_unmarshal.go | 4 ++-- internal/parser/cedar_unmarshal_test.go | 22 +++++++++++----------- policy_test.go | 2 +- 13 files changed, 43 insertions(+), 43 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index 19d20969..cf73c172 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -57,10 +57,10 @@ func TestAstExamples(t *testing.T) { } _ = ast.Forbid(). When( - ast.Value(simpleRecord).Access("x").Equals(ast.String("value")), + ast.Value(simpleRecord).Access("x").Equal(ast.String("value")), ). When( - ast.Record(ast.Pairs{{Key: "x", Value: ast.Long(1).Add(ast.Context().Access("fooCount"))}}).Access("x").Equals(ast.Long(3)), + ast.Record(ast.Pairs{{Key: "x", Value: ast.Long(1).Add(ast.Context().Access("fooCount"))}}).Access("x").Equal(ast.Long(3)), ). When( ast.Set( @@ -245,13 +245,13 @@ func TestASTByTable(t *testing.T) { }, { "opEquals", - ast.Permit().When(ast.Long(42).Equals(ast.Long(43))), - internalast.Permit().When(internalast.Long(42).Equals(internalast.Long(43))), + ast.Permit().When(ast.Long(42).Equal(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).Equal(internalast.Long(43))), }, { "opNotEquals", - ast.Permit().When(ast.Long(42).NotEquals(ast.Long(43))), - internalast.Permit().When(internalast.Long(42).NotEquals(internalast.Long(43))), + ast.Permit().When(ast.Long(42).NotEqual(ast.Long(43))), + internalast.Permit().When(internalast.Long(42).NotEqual(internalast.Long(43))), }, { "opLessThan", diff --git a/ast/operator.go b/ast/operator.go index 3dc649e4..6767d0aa 100644 --- a/ast/operator.go +++ b/ast/operator.go @@ -12,12 +12,12 @@ import ( // \____\___/|_| |_| |_| .__/ \__,_|_| |_|___/\___/|_| |_| // |_| -func (lhs Node) Equals(rhs Node) Node { - return wrapNode(lhs.Node.Equals(rhs.Node)) +func (lhs Node) Equal(rhs Node) Node { + return wrapNode(lhs.Node.Equal(rhs.Node)) } -func (lhs Node) NotEquals(rhs Node) Node { - return wrapNode(lhs.Node.NotEquals(rhs.Node)) +func (lhs Node) NotEqual(rhs Node) Node { + return wrapNode(lhs.Node.NotEqual(rhs.Node)) } func (lhs Node) LessThan(rhs Node) Node { diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index 87cded5b..f2d10918 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -56,11 +56,11 @@ func TestAstExamples(t *testing.T) { } _ = ast.Forbid(). When( - ast.Value(simpleRecord).Access("x").Equals(ast.String("value")), + ast.Value(simpleRecord).Access("x").Equal(ast.String("value")), ). When( ast.Record(ast.Pairs{{Key: "x", Value: ast.Long(1).Add(ast.Context().Access("fooCount"))}}). - Access("x").Equals(ast.Long(3)), + Access("x").Equal(ast.Long(3)), ). When( ast.Set( @@ -290,13 +290,13 @@ func TestASTByTable(t *testing.T) { }}, { "opEquals", - ast.Permit().When(ast.Long(42).Equals(ast.Long(43))), + ast.Permit().When(ast.Long(42).Equal(ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeEquals{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, }, { "opNotEquals", - ast.Permit().When(ast.Long(42).NotEquals(ast.Long(43))), + ast.Permit().When(ast.Long(42).NotEqual(ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeNotEquals{BinaryNode: ast.BinaryNode{Left: ast.NodeValue{Value: types.Long(42)}, Right: ast.NodeValue{Value: types.Long(43)}}}}}}, }, diff --git a/internal/ast/operator.go b/internal/ast/operator.go index 8d393805..07287f86 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -9,11 +9,11 @@ import "github.com/cedar-policy/cedar-go/types" // \____\___/|_| |_| |_| .__/ \__,_|_| |_|___/\___/|_| |_| // |_| -func (lhs Node) Equals(rhs Node) Node { +func (lhs Node) Equal(rhs Node) Node { return NewNode(NodeTypeEquals{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } -func (lhs Node) NotEquals(rhs Node) Node { +func (lhs Node) NotEqual(rhs Node) Node { return NewNode(NodeTypeNotEquals{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } diff --git a/internal/eval/compile.go b/internal/eval/compile.go index 7018bb97..a47de7c3 100644 --- a/internal/eval/compile.go +++ b/internal/eval/compile.go @@ -36,7 +36,7 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { case ast.ScopeTypeAll: return ast.True() case ast.ScopeTypeEq: - return ast.NewNode(varNode).Equals(ast.Value(t.Entity)) + return ast.NewNode(varNode).Equal(ast.Value(t.Entity)) case ast.ScopeTypeIn: return ast.NewNode(varNode).In(ast.Value(t.Entity)) case ast.ScopeTypeInSet: diff --git a/internal/eval/compile_test.go b/internal/eval/compile_test.go index dad65f8c..4ff8a1a2 100644 --- a/internal/eval/compile_test.go +++ b/internal/eval/compile_test.go @@ -36,9 +36,9 @@ func TestPolicyToNode(t *testing.T) { ActionEq(types.NewEntityUID("Action", "test")). ResourceEq(types.NewEntityUID("Resource", "database")), - ast.Principal().Equals(ast.EntityUID("Account", "principal")).And( - ast.Action().Equals(ast.EntityUID("Action", "test")).And( - ast.Resource().Equals(ast.EntityUID("Resource", "database")), + ast.Principal().Equal(ast.EntityUID("Account", "principal")).And( + ast.Action().Equal(ast.EntityUID("Action", "test")).And( + ast.Resource().Equal(ast.EntityUID("Resource", "database")), ), ), }, @@ -81,7 +81,7 @@ func TestScopeToNode(t *testing.T) { "eq", ast.NewPrincipalNode(), ast.ScopeTypeEq{Entity: types.NewEntityUID("T", "42")}, - ast.Principal().Equals(ast.EntityUID("T", "42")), + ast.Principal().Equal(ast.EntityUID("T", "42")), }, { "in", diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index 70c8dc74..556bfb09 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -127,13 +127,13 @@ func TestToEval(t *testing.T) { }, { "equals", - ast.Long(42).Equals(ast.Long(43)), + ast.Long(42).Equal(ast.Long(43)), types.False, testutil.OK, }, { "notEquals", - ast.Long(42).NotEquals(ast.Long(43)), + ast.Long(42).NotEqual(ast.Long(43)), types.True, testutil.OK, }, diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 65715d12..a8442584 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -81,7 +81,7 @@ func TestUnmarshalJSON(t *testing.T) { ActionEq(types.NewEntityUID("Action", "view")). ResourceIn(types.NewEntityUID("Folder", "abc")). When( - ast.Context().Access("tls_version").Equals(ast.String("1.3")), + ast.Context().Access("tls_version").Equal(ast.String("1.3")), ), testutil.OK, }, @@ -265,14 +265,14 @@ func TestUnmarshalJSON(t *testing.T) { "equals", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"==":{"left":{"Value":42},"right":{"Value":24}}}}]}`, - ast.Permit().When(ast.Long(42).Equals(ast.Long(24))), + ast.Permit().When(ast.Long(42).Equal(ast.Long(24))), testutil.OK, }, { "notEquals", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"!=":{"left":{"Value":42},"right":{"Value":24}}}}]}`, - ast.Permit().When(ast.Long(42).NotEquals(ast.Long(24))), + ast.Permit().When(ast.Long(42).NotEqual(ast.Long(24))), testutil.OK, }, { diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index cfe9f8dc..ef2b2a1e 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -177,9 +177,9 @@ func (j nodeJSON) ToNode() (ast.Node, error) { // Binary operators: ==, !=, in, <, <=, >, >=, &&, ||, +, -, *, contains, containsAll, containsAny case j.Equals != nil: - return j.Equals.ToNode(ast.Node.Equals) + return j.Equals.ToNode(ast.Node.Equal) case j.NotEquals != nil: - return j.NotEquals.ToNode(ast.Node.NotEquals) + return j.NotEquals.ToNode(ast.Node.NotEqual) case j.In != nil: return j.In.ToNode(ast.Node.In) case j.LessThan != nil: diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index 95383f53..bde5a4d3 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -34,7 +34,7 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node { case ast.ScopeTypeAll: return ast.True() case ast.ScopeTypeEq: - return ast.NewNode(varNode).Equals(ast.Value(t.Entity)) + return ast.NewNode(varNode).Equal(ast.Value(t.Entity)) case ast.ScopeTypeIn: return ast.NewNode(varNode).In(ast.Value(t.Entity)) case ast.ScopeTypeInSet: diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index f7fd4f2c..933a6c00 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -517,9 +517,9 @@ func (p *parser) relation() (ast.Node, error) { case ">=": operator = ast.Node.GreaterThanOrEqual case "!=": - operator = ast.Node.NotEquals + operator = ast.Node.NotEqual case "==": - operator = ast.Node.Equals + operator = ast.Node.Equal case "in": operator = ast.Node.In default: diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 3b569421..98e975ea 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -257,13 +257,13 @@ when { 2 >= 42 };`, "equal", `permit ( principal, action, resource ) when { 2 == 42 };`, - ast.Permit().When(ast.Long(2).Equals(ast.Long(42))), + ast.Permit().When(ast.Long(2).Equal(ast.Long(42))), }, { "not equal", `permit ( principal, action, resource ) when { 2 != 42 };`, - ast.Permit().When(ast.Long(2).NotEquals(ast.Long(42))), + ast.Permit().When(ast.Long(2).NotEqual(ast.Long(42))), }, { "in", @@ -349,7 +349,7 @@ when { if true then true else false };`, `permit ( principal, action, resource ) when { ip("1.2.3.4") == ip("2.3.4.5") };`, ast.Permit().When( - ast.ExtensionCall("ip", ast.String("1.2.3.4")).Equals( + ast.ExtensionCall("ip", ast.String("1.2.3.4")).Equal( ast.ExtensionCall("ip", ast.String("2.3.4.5")), ), ), @@ -359,7 +359,7 @@ when { ip("1.2.3.4") == ip("2.3.4.5") };`, `permit ( principal, action, resource ) when { decimal("12.34") == decimal("23.45") };`, ast.Permit().When( - ast.ExtensionCall("decimal", ast.String("12.34")).Equals(ast.ExtensionCall("decimal", ast.String("23.45"))), + ast.ExtensionCall("decimal", ast.String("12.34")).Equal(ast.ExtensionCall("decimal", ast.String("23.45"))), ), }, { @@ -384,19 +384,19 @@ when { 1 + 1 < 3 };`, "mult over add precedence (rhs add)", `permit ( principal, action, resource ) when { 2 * 3 + 4 == 10 };`, - ast.Permit().When(ast.Long(2).Multiply(ast.Long(3)).Add(ast.Long(4)).Equals(ast.Long(10))), + ast.Permit().When(ast.Long(2).Multiply(ast.Long(3)).Add(ast.Long(4)).Equal(ast.Long(10))), }, { "mult over add precedence (lhs add)", `permit ( principal, action, resource ) when { 2 + 3 * 4 == 14 };`, - ast.Permit().When(ast.Long(2).Add(ast.Long(3).Multiply(ast.Long(4))).Equals(ast.Long(14))), + ast.Permit().When(ast.Long(2).Add(ast.Long(3).Multiply(ast.Long(4))).Equal(ast.Long(14))), }, { "unary over mult precedence", `permit ( principal, action, resource ) when { -2 * 3 == -6 };`, - ast.Permit().When(ast.Long(-2).Multiply(ast.Long(3)).Equals(ast.Long(-6))), + ast.Permit().When(ast.Long(-2).Multiply(ast.Long(3)).Equal(ast.Long(-6))), }, { "member over unary precedence", @@ -408,25 +408,25 @@ when { -context.num };`, "parens over unary precedence", `permit ( principal, action, resource ) when { -(2 + 3) == -5 };`, - ast.Permit().When(ast.Negate(ast.Long(2).Add(ast.Long(3))).Equals(ast.Long(-5))), + ast.Permit().When(ast.Negate(ast.Long(2).Add(ast.Long(3))).Equal(ast.Long(-5))), }, { "multiple parenthesized operations", `permit ( principal, action, resource ) when { (2 + 3 + 4) * 5 == 18 };`, - ast.Permit().When(ast.Long(2).Add(ast.Long(3)).Add(ast.Long(4)).Multiply(ast.Long(5)).Equals(ast.Long(18))), + ast.Permit().When(ast.Long(2).Add(ast.Long(3)).Add(ast.Long(4)).Multiply(ast.Long(5)).Equal(ast.Long(18))), }, { "parenthesized if", `permit ( principal, action, resource ) when { (if true then 2 else 3 * 4) == 2 };`, - ast.Permit().When(ast.IfThenElse(ast.True(), ast.Long(2), ast.Long(3).Multiply(ast.Long(4))).Equals(ast.Long(2))), + ast.Permit().When(ast.IfThenElse(ast.True(), ast.Long(2), ast.Long(3).Multiply(ast.Long(4))).Equal(ast.Long(2))), }, { "parenthesized if with trailing mult", `permit ( principal, action, resource ) when { (if true then 2 else 3) * 4 == 8 };`, - ast.Permit().When(ast.IfThenElse(ast.True(), ast.Long(2), ast.Long(3)).Multiply(ast.Long(4)).Equals(ast.Long(8))), + ast.Permit().When(ast.IfThenElse(ast.True(), ast.Long(2), ast.Long(3)).Multiply(ast.Long(4)).Equal(ast.Long(8))), }, } diff --git a/policy_test.go b/policy_test.go index 36e8845a..4f1e9b21 100644 --- a/policy_test.go +++ b/policy_test.go @@ -90,7 +90,7 @@ func TestPolicyAST(t *testing.T) { astExample := ast.Permit(). ActionEq(types.NewEntityUID("Action", "editPhoto")). - When(ast.Resource().Access("owner").Equals(ast.Principal())) + When(ast.Resource().Access("owner").Equal(ast.Principal())) _ = cedar.NewPolicyFromAST(astExample) } From 75979990879f4b3b8404aebdb9e8ac946ac76858 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 17:45:51 -0600 Subject: [PATCH 178/216] types: have MarshalCedar always emit []byte Addresses IDX-142 Signed-off-by: philhassey --- internal/eval/evalers_test.go | 2 +- internal/parser/cedar_marshal.go | 10 +++++----- types/boolean.go | 12 +++++++----- types/decimal.go | 5 ++--- types/entity_type.go | 2 +- types/entity_type_test.go | 2 +- types/entity_uid.go | 8 ++++---- types/ipaddr.go | 5 ++--- types/json_test.go | 2 +- types/long.go | 8 ++++---- types/record.go | 13 ++++++------- types/set.go | 13 ++++++------- types/string.go | 6 +++--- types/value.go | 4 ++-- types/value_test.go | 6 +++--- 15 files changed, 48 insertions(+), 50 deletions(-) diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index bc110314..68ea34d9 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -1900,7 +1900,7 @@ func TestCedarString(t *testing.T) { t.Parallel() gotString := tt.in.String() testutil.Equals(t, gotString, tt.wantString) - gotCedar := tt.in.Cedar() + gotCedar := string(tt.in.MarshalCedar()) testutil.Equals(t, gotCedar, tt.wantCedar) }) } diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index bde5a4d3..7b1dfcbe 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -87,7 +87,7 @@ func marshalAnnotation(n ast.AnnotationType, buf *bytes.Buffer) { buf.WriteRune('@') buf.WriteString(string(n.Key)) buf.WriteRune('(') - buf.WriteString(n.Value.Cedar()) + buf.Write(n.Value.MarshalCedar()) buf.WriteString(")") } @@ -116,7 +116,7 @@ func marshalCondition(c ast.ConditionType, buf *bytes.Buffer) { } func (n NodeValue) marshalCedar(buf *bytes.Buffer) { - buf.WriteString(n.NodeValue.Value.Cedar()) + buf.Write(n.NodeValue.Value.MarshalCedar()) } func marshalChildNode(thisNodePrecedence nodePrecedenceLevel, childAstNode ast.IsNode, buf *bytes.Buffer) { @@ -157,7 +157,7 @@ func (n NodeTypeAccess) marshalCedar(buf *bytes.Buffer) { buf.WriteString(string(n.NodeTypeAccess.Value)) } else { buf.WriteRune('[') - buf.WriteString(n.NodeTypeAccess.Value.Cedar()) + buf.Write(n.NodeTypeAccess.Value.MarshalCedar()) buf.WriteRune(']') } } @@ -218,7 +218,7 @@ func (n NodeTypeSet) marshalCedar(buf *bytes.Buffer) { func (n NodeTypeRecord) marshalCedar(buf *bytes.Buffer) { buf.WriteRune('{') for i := range n.NodeTypeRecord.Elements { - buf.WriteString(n.NodeTypeRecord.Elements[i].Key.Cedar()) + buf.Write(n.NodeTypeRecord.Elements[i].Key.MarshalCedar()) buf.WriteString(":") marshalChildNode(n.precedenceLevel(), n.NodeTypeRecord.Elements[i].Value, buf) if i != len(n.NodeTypeRecord.Elements)-1 { @@ -290,7 +290,7 @@ func (n NodeTypeHas) marshalCedar(buf *bytes.Buffer) { if canMarshalAsIdent(string(n.NodeTypeHas.Value)) { buf.WriteString(string(n.NodeTypeHas.Value)) } else { - buf.WriteString(n.NodeTypeHas.Value.Cedar()) + buf.Write(n.NodeTypeHas.Value.MarshalCedar()) } } diff --git a/types/boolean.go b/types/boolean.go index c89c0a97..732742a5 100644 --- a/types/boolean.go +++ b/types/boolean.go @@ -2,7 +2,6 @@ package types import ( "encoding/json" - "fmt" ) // A Boolean is a value that is either true or false. @@ -20,11 +19,14 @@ func (a Boolean) Equal(bi Value) bool { func (v Boolean) TypeName() string { return "bool" } // String produces a string representation of the Boolean, e.g. `true`. -func (v Boolean) String() string { return v.Cedar() } +func (v Boolean) String() string { return string(v.MarshalCedar()) } -// Cedar produces a valid Cedar language representation of the Boolean, e.g. `true`. -func (v Boolean) Cedar() string { - return fmt.Sprint(bool(v)) +// MarshalCedar produces a valid MarshalCedar language representation of the Boolean, e.g. `true`. +func (v Boolean) MarshalCedar() []byte { + if v { + return []byte("true") + } + return []byte("false") } // ExplicitMarshalJSON marshals the Boolean into JSON. diff --git a/types/decimal.go b/types/decimal.go index ecca6b98..19e35aad 100644 --- a/types/decimal.go +++ b/types/decimal.go @@ -112,9 +112,8 @@ func (a Decimal) Equal(bi Value) bool { return ok && a == b } - -// Cedar produces a valid Cedar language representation of the Decimal, e.g. `decimal("12.34")`. -func (v Decimal) Cedar() string { return `decimal("` + v.String() + `")` } +// MarshalCedar produces a valid MarshalCedar language representation of the Decimal, e.g. `decimal("12.34")`. +func (v Decimal) MarshalCedar() []byte { return []byte(`decimal("` + v.String() + `")`) } // String produces a string representation of the Decimal, e.g. `12.34`. func (v Decimal) String() string { diff --git a/types/entity_type.go b/types/entity_type.go index 576bcc03..0b595d30 100644 --- a/types/entity_type.go +++ b/types/entity_type.go @@ -14,7 +14,7 @@ func (a EntityType) Equal(bi Value) bool { } func (v EntityType) String() string { return string(v) } -func (v EntityType) Cedar() string { return string(v) } +func (v EntityType) MarshalCedar() []byte { return []byte(v) } func (v EntityType) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(string(v)) } func (v EntityType) deepClone() Value { return v } diff --git a/types/entity_type_test.go b/types/entity_type_test.go index 3c59a653..467dcd29 100644 --- a/types/entity_type_test.go +++ b/types/entity_type_test.go @@ -28,7 +28,7 @@ func TestEntityType(t *testing.T) { t.Run("Cedar", func(t *testing.T) { t.Parallel() a := types.EntityType("X") - testutil.Equals(t, a.Cedar(), "X") + testutil.Equals(t, a.MarshalCedar(), []byte("X")) }) t.Run("ExplicitMarshalJSON", func(t *testing.T) { t.Parallel() diff --git a/types/entity_uid.go b/types/entity_uid.go index 6e2c6307..70277808 100644 --- a/types/entity_uid.go +++ b/types/entity_uid.go @@ -29,11 +29,11 @@ func (a EntityUID) Equal(bi Value) bool { } // String produces a string representation of the EntityUID, e.g. `Type::"id"`. -func (v EntityUID) String() string { return v.Cedar() } +func (v EntityUID) String() string { return string(v.Type.String() + "::" + strconv.Quote(v.ID)) } -// Cedar produces a valid Cedar language representation of the EntityUID, e.g. `Type::"id"`. -func (v EntityUID) Cedar() string { - return v.Type.String() + "::" + strconv.Quote(v.ID) +// MarshalCedar produces a valid MarshalCedar language representation of the EntityUID, e.g. `Type::"id"`. +func (v EntityUID) MarshalCedar() []byte { + return []byte(v.String()) } func (v *EntityUID) UnmarshalJSON(b []byte) error { diff --git a/types/ipaddr.go b/types/ipaddr.go index 02bfb402..ee65ae37 100644 --- a/types/ipaddr.go +++ b/types/ipaddr.go @@ -31,9 +31,8 @@ func (a IPAddr) Equal(bi Value) bool { return ok && a == b } - -// Cedar produces a valid Cedar language representation of the IPAddr, e.g. `ip("127.0.0.1")`. -func (v IPAddr) Cedar() string { return `ip("` + v.String() + `")` } +// MarshalCedar produces a valid MarshalCedar language representation of the IPAddr, e.g. `ip("127.0.0.1")`. +func (v IPAddr) MarshalCedar() []byte { return []byte(`ip("` + v.String() + `")`) } // String produces a string representation of the IPAddr, e.g. `127.0.0.1`. func (v IPAddr) String() string { diff --git a/types/json_test.go b/types/json_test.go index 4e42bb67..f318e8f3 100644 --- a/types/json_test.go +++ b/types/json_test.go @@ -313,7 +313,7 @@ func TestJSONMarshal(t *testing.T) { type jsonErr struct{} func (j *jsonErr) String() string { return "" } -func (j *jsonErr) Cedar() string { return "" } +func (j *jsonErr) MarshalCedar() []byte { return nil } func (j *jsonErr) Equal(Value) bool { return false } func (j *jsonErr) ExplicitMarshalJSON() ([]byte, error) { return nil, fmt.Errorf("jsonErr") } func (j *jsonErr) TypeName() string { return "jsonErr" } diff --git a/types/long.go b/types/long.go index 5c1319e3..0d2a0e70 100644 --- a/types/long.go +++ b/types/long.go @@ -17,10 +17,10 @@ func (a Long) Equal(bi Value) bool { func (v Long) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) } // String produces a string representation of the Long, e.g. `42`. -func (v Long) String() string { return v.Cedar() } +func (v Long) String() string { return fmt.Sprint(int64(v)) } -// Cedar produces a valid Cedar language representation of the Long, e.g. `42`. -func (v Long) Cedar() string { - return fmt.Sprint(int64(v)) +// MarshalCedar produces a valid MarshalCedar language representation of the Long, e.g. `42`. +func (v Long) MarshalCedar() []byte { + return []byte(v.String()) } func (v Long) deepClone() Value { return v } diff --git a/types/record.go b/types/record.go index cc953e26..153ae09b 100644 --- a/types/record.go +++ b/types/record.go @@ -5,7 +5,6 @@ import ( "encoding/json" "slices" "strconv" - "strings" "golang.org/x/exp/maps" ) @@ -72,11 +71,11 @@ func (v Record) MarshalJSON() ([]byte, error) { func (v Record) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } // String produces a string representation of the Record, e.g. `{"a":1,"b":2,"c":3}`. -func (r Record) String() string { return r.Cedar() } +func (r Record) String() string { return string(r.MarshalCedar()) } -// Cedar produces a valid Cedar language representation of the Record, e.g. `{"a":1,"b":2,"c":3}`. -func (r Record) Cedar() string { - var sb strings.Builder +// MarshalCedar produces a valid MarshalCedar language representation of the Record, e.g. `{"a":1,"b":2,"c":3}`. +func (r Record) MarshalCedar() []byte { + var sb bytes.Buffer sb.WriteRune('{') first := true keys := maps.Keys(r) @@ -89,10 +88,10 @@ func (r Record) Cedar() string { first = false sb.WriteString(strconv.Quote(k)) sb.WriteString(":") - sb.WriteString(v.Cedar()) + sb.Write(v.MarshalCedar()) } sb.WriteRune('}') - return sb.String() + return sb.Bytes() } func (v Record) deepClone() Value { return v.DeepClone() } diff --git a/types/set.go b/types/set.go index fc7da4a5..af18a80a 100644 --- a/types/set.go +++ b/types/set.go @@ -3,7 +3,6 @@ package types import ( "bytes" "encoding/json" - "strings" ) // A Set is a collection of elements that can be of the same or different types. @@ -77,20 +76,20 @@ func (v Set) MarshalJSON() ([]byte, error) { func (v Set) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() } // String produces a string representation of the Set, e.g. `[1,2,3]`. -func (v Set) String() string { return v.Cedar() } +func (v Set) String() string { return string(v.MarshalCedar()) } -// Cedar produces a valid Cedar language representation of the Set, e.g. `[1,2,3]`. -func (v Set) Cedar() string { - var sb strings.Builder +// MarshalCedar produces a valid MarshalCedar language representation of the Set, e.g. `[1,2,3]`. +func (v Set) MarshalCedar() []byte { + var sb bytes.Buffer sb.WriteRune('[') for i, elem := range v { if i > 0 { sb.WriteString(", ") } - sb.WriteString(elem.Cedar()) + sb.Write(elem.MarshalCedar()) } sb.WriteRune(']') - return sb.String() + return sb.Bytes() } func (v Set) deepClone() Value { return v.DeepClone() } diff --git a/types/string.go b/types/string.go index 9891aa9e..1341b7ca 100644 --- a/types/string.go +++ b/types/string.go @@ -21,8 +21,8 @@ func (v String) String() string { return string(v) } -// Cedar produces a valid Cedar language representation of the String, e.g. `"hello"`. -func (v String) Cedar() string { - return strconv.Quote(string(v)) +// MarshalCedar produces a valid MarshalCedar language representation of the String, e.g. `"hello"`. +func (v String) MarshalCedar() []byte { + return []byte(strconv.Quote(string(v))) } func (v String) deepClone() Value { return v } diff --git a/types/value.go b/types/value.go index a1922e89..ee9373eb 100644 --- a/types/value.go +++ b/types/value.go @@ -10,8 +10,8 @@ var ErrIP = fmt.Errorf("error parsing ip value") type Value interface { // String produces a string representation of the Value. String() string - // Cedar produces a valid Cedar language representation of the Value. - Cedar() string + // MarshalCedar produces a valid MarshalCedar language representation of the Value. + MarshalCedar() []byte // ExplicitMarshalJSON marshals the Value into JSON using the explicit (if // applicable) JSON form, which is necessary for marshalling values within // Sets or Records where the type is not defined. diff --git a/types/value_test.go b/types/value_test.go index 2712195d..0764d624 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -99,9 +99,9 @@ func TestDeepClone(t *testing.T) { t.Parallel() a := mustIPValue("127.0.0.42") b := a.deepClone() - testutil.Equals(t, a.Cedar(), b.Cedar()) + testutil.Equals(t, a.MarshalCedar(), b.MarshalCedar()) a = mustIPValue("127.0.0.43") - testutil.Equals(t, a.Cedar(), mustIPValue("127.0.0.43").Cedar()) - testutil.Equals(t, b.Cedar(), mustIPValue("127.0.0.42").Cedar()) + testutil.Equals(t, a.MarshalCedar(), mustIPValue("127.0.0.43").MarshalCedar()) + testutil.Equals(t, b.MarshalCedar(), mustIPValue("127.0.0.42").MarshalCedar()) }) } From 587937ebbb7f33bf5815c48015ce134f20152c87 Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 17:49:08 -0600 Subject: [PATCH 179/216] types: hide private test function Addresses IDX-142 Signed-off-by: philhassey --- types/boolean_test.go | 2 +- types/entity_uid_test.go | 4 ++-- types/long_test.go | 2 +- types/record_test.go | 6 +++--- types/set_test.go | 4 ++-- types/string_test.go | 4 ++-- types/testutil_test.go | 4 +--- 7 files changed, 12 insertions(+), 14 deletions(-) diff --git a/types/boolean_test.go b/types/boolean_test.go index 25d7e2b8..a95dbecb 100644 --- a/types/boolean_test.go +++ b/types/boolean_test.go @@ -25,7 +25,7 @@ func TestBool(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - AssertValueString(t, types.Boolean(true), "true") + assertValueString(t, types.Boolean(true), "true") }) } diff --git a/types/entity_uid_test.go b/types/entity_uid_test.go index 51e5605c..60a26a05 100644 --- a/types/entity_uid_test.go +++ b/types/entity_uid_test.go @@ -22,8 +22,8 @@ func TestEntity(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - AssertValueString(t, types.EntityUID{Type: "type", ID: "id"}, `type::"id"`) - AssertValueString(t, types.EntityUID{Type: "namespace::type", ID: "id"}, `namespace::type::"id"`) + assertValueString(t, types.EntityUID{Type: "type", ID: "id"}, `type::"id"`) + assertValueString(t, types.EntityUID{Type: "namespace::type", ID: "id"}, `namespace::type::"id"`) }) } diff --git a/types/long_test.go b/types/long_test.go index 1ae49adb..ec00570c 100644 --- a/types/long_test.go +++ b/types/long_test.go @@ -25,7 +25,7 @@ func TestLong(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - AssertValueString(t, types.Long(1), "1") + assertValueString(t, types.Long(1), "1") }) } diff --git a/types/record_test.go b/types/record_test.go index dd5b6068..a7ce1fe1 100644 --- a/types/record_test.go +++ b/types/record_test.go @@ -57,12 +57,12 @@ func TestRecord(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - AssertValueString(t, types.Record{}, "{}") - AssertValueString( + assertValueString(t, types.Record{}, "{}") + assertValueString( t, types.Record{"foo": types.Boolean(true)}, `{"foo":true}`) - AssertValueString( + assertValueString( t, types.Record{ "foo": types.Boolean(true), diff --git a/types/set_test.go b/types/set_test.go index 6705df6c..770def8d 100644 --- a/types/set_test.go +++ b/types/set_test.go @@ -45,8 +45,8 @@ func TestSet(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - AssertValueString(t, types.Set{}, "[]") - AssertValueString( + assertValueString(t, types.Set{}, "[]") + assertValueString( t, types.Set{types.Boolean(true), types.Long(1)}, "[true, 1]") diff --git a/types/string_test.go b/types/string_test.go index ca9e262b..882f96d2 100644 --- a/types/string_test.go +++ b/types/string_test.go @@ -22,8 +22,8 @@ func TestString(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - AssertValueString(t, types.String("hello"), `hello`) - AssertValueString(t, types.String("hello\ngoodbye"), "hello\ngoodbye") + assertValueString(t, types.String("hello"), `hello`) + assertValueString(t, types.String("hello\ngoodbye"), "hello\ngoodbye") }) } diff --git a/types/testutil_test.go b/types/testutil_test.go index c5dbfa19..f91b13d5 100644 --- a/types/testutil_test.go +++ b/types/testutil_test.go @@ -7,9 +7,7 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -// TODO: this file should not be public, it should be moved into the eval code - -func AssertValueString(t *testing.T, v types.Value, want string) { +func assertValueString(t *testing.T, v types.Value, want string) { t.Helper() testutil.Equals(t, v.String(), want) } From dc57157d9a983360efc64ee9a834c89e8d5afddc Mon Sep 17 00:00:00 2001 From: philhassey Date: Mon, 19 Aug 2024 17:55:36 -0600 Subject: [PATCH 180/216] types: have pattern have MarshalCedar method Addresses IDX-142 Signed-off-by: philhassey --- internal/eval/evalers.go | 1 - internal/json/json_marshal.go | 2 +- internal/parser/cedar_marshal.go | 2 +- types/pattern.go | 4 ++-- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index d8367921..1bef363b 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -18,7 +18,6 @@ func zeroValue() types.Value { return nil } -// TODO: make private again type Context struct { Entities entities.Entities Principal, Action, Resource types.Value diff --git a/internal/json/json_marshal.go b/internal/json/json_marshal.go index 8794da7f..195ec442 100644 --- a/internal/json/json_marshal.go +++ b/internal/json/json_marshal.go @@ -71,7 +71,7 @@ func arrayToJSON(dest *arrayJSON, args []ast.IsNode) { func extToJSON(dest *extensionJSON, name string, src types.Value) { res := arrayJSON{} - str := src.String() // TODO: is this the correct string? + str := src.String() val := valueJSON{v: types.String(str)} res = append(res, nodeJSON{ Value: &val, diff --git a/internal/parser/cedar_marshal.go b/internal/parser/cedar_marshal.go index 7b1dfcbe..a08944a6 100644 --- a/internal/parser/cedar_marshal.go +++ b/internal/parser/cedar_marshal.go @@ -311,7 +311,7 @@ func (n NodeTypeIsIn) marshalCedar(buf *bytes.Buffer) { func (n NodeTypeLike) marshalCedar(buf *bytes.Buffer) { marshalChildNode(n.precedenceLevel(), n.NodeTypeLike.Arg, buf) buf.WriteString(" like ") - buf.WriteString(n.NodeTypeLike.Value.Cedar()) + buf.Write(n.NodeTypeLike.Value.MarshalCedar()) } func (n NodeTypeIf) marshalCedar(buf *bytes.Buffer) { diff --git a/types/pattern.go b/types/pattern.go index 8dc3a85e..87d779c3 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -57,7 +57,7 @@ func NewPattern(components ...PatternComponent) Pattern { return Pattern{comps: comps} } -func (p Pattern) Cedar() string { +func (p Pattern) MarshalCedar() []byte { var buf bytes.Buffer buf.WriteRune('"') for _, comp := range p.comps { @@ -71,7 +71,7 @@ func (p Pattern) Cedar() string { buf.WriteString(quotedString) } buf.WriteRune('"') - return buf.String() + return buf.Bytes() } // ported from Go's stdlib and reduced to our scope. From ecc6b16513b7194c899b0f9c8421e2c0f1720f46 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 09:44:58 -0600 Subject: [PATCH 181/216] cedar: match filename case to go scanner package Addresses IDX-142 Signed-off-by: philhassey --- internal/ast/policy.go | 2 +- policy.go | 4 ++-- policy_set_test.go | 2 +- policy_slice.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/ast/policy.go b/internal/ast/policy.go index b27f3808..2d591d21 100644 --- a/internal/ast/policy.go +++ b/internal/ast/policy.go @@ -30,7 +30,7 @@ const ( // Position is a value that represents a source Position. // A Position is valid if Line > 0. type Position struct { - FileName string // optional name of the source file for the enclosing policy, "" if the source is unknown or not a named file + Filename string // optional name of the source file for the enclosing policy, "" if the source is unknown or not a named file Offset int // byte offset, starting at 0 Line int // line number, starting at 1 Column int // column number, starting at 1 (character count per line) diff --git a/policy.go b/policy.go index 367e94b2..5c60a75b 100644 --- a/policy.go +++ b/policy.go @@ -99,6 +99,6 @@ func (p Policy) Position() Position { return Position(p.ast.Position) } -func (p *Policy) SetFileName(path string) { - p.ast.Position.FileName = path +func (p *Policy) SetFilename(fileName string) { + p.ast.Position.Filename = fileName } diff --git a/policy_set_test.go b/policy_set_test.go index 1889f8f2..5817dcfd 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -162,7 +162,7 @@ forbid ( ps := cedar.NewPolicySet() for i, p := range policies { - p.SetFileName("example.cedar") + p.SetFilename("example.cedar") ps.UpsertPolicy(cedar.PolicyID(fmt.Sprintf("policy%d", i)), p) } diff --git a/policy_slice.go b/policy_slice.go index 25a70dd8..56576f74 100644 --- a/policy_slice.go +++ b/policy_slice.go @@ -20,7 +20,7 @@ func NewPolicySliceFromBytes(fileName string, document []byte) (PolicySlice, err return nil, err } for _, p := range policySlice { - p.SetFileName(fileName) + p.SetFilename(fileName) } return policySlice, nil } From a2aa168afad835b7cd7f7aaed7d0ed535b40e335 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 10:05:18 -0600 Subject: [PATCH 182/216] internal/json: improve tightness of scope definition so JSON errors on invalid scopes Addresses IDX-142 Signed-off-by: philhassey --- internal/ast/policy.go | 12 ++--- internal/ast/scope.go | 78 +++++++++++++++++++++++++-------- internal/json/json_test.go | 10 ++++- internal/json/json_unmarshal.go | 46 ++++++++++++++----- 4 files changed, 108 insertions(+), 38 deletions(-) diff --git a/internal/ast/policy.go b/internal/ast/policy.go index 2d591d21..7c29e888 100644 --- a/internal/ast/policy.go +++ b/internal/ast/policy.go @@ -39,9 +39,9 @@ type Position struct { type Policy struct { Effect Effect Annotations []AnnotationType - Principal IsScopeNode - Action IsScopeNode - Resource IsScopeNode + Principal IsPrincipalScopeNode + Action IsActionScopeNode + Resource IsResourceScopeNode Conditions []ConditionType Position Position } @@ -50,9 +50,9 @@ func newPolicy(effect Effect, annotations []AnnotationType) *Policy { return &Policy{ Effect: effect, Annotations: annotations, - Principal: Scope(NewPrincipalNode()).All(), - Action: Scope(NewActionNode()).All(), - Resource: Scope(NewResourceNode()).All(), + Principal: Scope{}.All(), + Action: Scope{}.All(), + Resource: Scope{}.All(), } } diff --git a/internal/ast/scope.go b/internal/ast/scope.go index 9b9cddbe..fa685447 100644 --- a/internal/ast/scope.go +++ b/internal/ast/scope.go @@ -4,84 +4,84 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -type Scope NodeTypeVariable +type Scope struct{} -func (s Scope) All() IsScopeNode { +func (s Scope) All() ScopeTypeAll { return ScopeTypeAll{} } -func (s Scope) Eq(entity types.EntityUID) IsScopeNode { +func (s Scope) Eq(entity types.EntityUID) ScopeTypeEq { return ScopeTypeEq{Entity: entity} } -func (s Scope) In(entity types.EntityUID) IsScopeNode { +func (s Scope) In(entity types.EntityUID) ScopeTypeIn { return ScopeTypeIn{Entity: entity} } -func (s Scope) InSet(entities []types.EntityUID) IsScopeNode { +func (s Scope) InSet(entities []types.EntityUID) ScopeTypeInSet { return ScopeTypeInSet{Entities: entities} } -func (s Scope) Is(entityType types.EntityType) IsScopeNode { +func (s Scope) Is(entityType types.EntityType) ScopeTypeIs { return ScopeTypeIs{Type: entityType} } -func (s Scope) IsIn(entityType types.EntityType, entity types.EntityUID) IsScopeNode { +func (s Scope) IsIn(entityType types.EntityType, entity types.EntityUID) ScopeTypeIsIn { return ScopeTypeIsIn{Type: entityType, Entity: entity} } func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { - p.Principal = Scope(NewPrincipalNode()).Eq(entity) + p.Principal = Scope{}.Eq(entity) return p } func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { - p.Principal = Scope(NewPrincipalNode()).In(entity) + p.Principal = Scope{}.In(entity) return p } func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { - p.Principal = Scope(NewPrincipalNode()).Is(entityType) + p.Principal = Scope{}.Is(entityType) return p } func (p *Policy) PrincipalIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { - p.Principal = Scope(NewPrincipalNode()).IsIn(entityType, entity) + p.Principal = Scope{}.IsIn(entityType, entity) return p } func (p *Policy) ActionEq(entity types.EntityUID) *Policy { - p.Action = Scope(NewActionNode()).Eq(entity) + p.Action = Scope{}.Eq(entity) return p } func (p *Policy) ActionIn(entity types.EntityUID) *Policy { - p.Action = Scope(NewActionNode()).In(entity) + p.Action = Scope{}.In(entity) return p } func (p *Policy) ActionInSet(entities ...types.EntityUID) *Policy { - p.Action = Scope(NewActionNode()).InSet(entities) + p.Action = Scope{}.InSet(entities) return p } func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { - p.Resource = Scope(NewResourceNode()).Eq(entity) + p.Resource = Scope{}.Eq(entity) return p } func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { - p.Resource = Scope(NewResourceNode()).In(entity) + p.Resource = Scope{}.In(entity) return p } func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { - p.Resource = Scope(NewResourceNode()).Is(entityType) + p.Resource = Scope{}.Is(entityType) return p } func (p *Policy) ResourceIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { - p.Resource = Scope(NewResourceNode()).IsIn(entityType, entity) + p.Resource = Scope{}.IsIn(entityType, entity) return p } @@ -89,37 +89,77 @@ type IsScopeNode interface { isScope() } -type ScopeNode struct { +type IsPrincipalScopeNode interface { + IsScopeNode + isPrincipalScope() } +type IsActionScopeNode interface { + IsScopeNode + isActionScope() +} + +type IsResourceScopeNode interface { + IsScopeNode + isResourceScope() +} + +type ScopeNode struct{} + func (n ScopeNode) isScope() {} +type PrincipalScopeNode struct{} + +func (n PrincipalScopeNode) isPrincipalScope() {} + +type ActionScopeNode struct{} + +func (n ActionScopeNode) isActionScope() {} + +type ResourceScopeNode struct{} + +func (n ResourceScopeNode) isResourceScope() {} + type ScopeTypeAll struct { ScopeNode + PrincipalScopeNode + ActionScopeNode + ResourceScopeNode } type ScopeTypeEq struct { ScopeNode + PrincipalScopeNode + ActionScopeNode + ResourceScopeNode Entity types.EntityUID } type ScopeTypeIn struct { ScopeNode + PrincipalScopeNode + ActionScopeNode + ResourceScopeNode Entity types.EntityUID } type ScopeTypeInSet struct { ScopeNode + ActionScopeNode Entities []types.EntityUID } type ScopeTypeIs struct { ScopeNode + PrincipalScopeNode + ResourceScopeNode Type types.EntityType } type ScopeTypeIsIn struct { ScopeNode + PrincipalScopeNode + ResourceScopeNode Type types.EntityType Entity types.EntityUID } diff --git a/internal/json/json_test.go b/internal/json/json_test.go index a8442584..4750b0ce 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -578,9 +578,17 @@ func TestUnmarshalErrors(t *testing.T) { `{"effect":"unknown","principal":{"op":"=="},"action":{"op":"All"},"resource":{"op":"All"}}`, }, { - "scopeEqMissingEntity", + "principalScopeEqMissingEntity", `{"effect":"permit","principal":{"op":"=="},"action":{"op":"All"},"resource":{"op":"All"}}`, }, + { + "principalScopeInMissingEntity", + `{"effect":"permit","principal":{"op":"in"},"action":{"op":"All"},"resource":{"op":"All"}}`, + }, + { + "actionScopeEqMissingEntity", + `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"action":{"op":"=="}}`, + }, { "scopeUnknownOp", `{"effect":"permit","principal":{"op":"???"},"action":{"op":"All"},"resource":{"op":"All"}}`, diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index ef2b2a1e..6ea0e34a 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -12,26 +12,48 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -func (s *scopeJSON) ToNode(variable ast.Scope) (ast.IsScopeNode, error) { - // TODO: should we be careful to be more strict about what is allowed here? +type isPrincipalResourceScopeNode interface { + ast.IsPrincipalScopeNode + ast.IsResourceScopeNode +} + +func (s *scopeJSON) ToPrincipalResourceNode() (isPrincipalResourceScopeNode, error) { switch s.Op { case "All": - return variable.All(), nil + return ast.Scope{}.All(), nil case "==": if s.Entity == nil { return nil, fmt.Errorf("missing entity") } - return variable.Eq(*s.Entity), nil + return ast.Scope{}.Eq(*s.Entity), nil case "in": - if s.Entity != nil { - return variable.In(*s.Entity), nil + if s.Entity == nil { + return nil, fmt.Errorf("missing entity") } - return variable.InSet(s.Entities), nil + return ast.Scope{}.In(*s.Entity), nil case "is": if s.In == nil { - return variable.Is(types.EntityType(s.EntityType)), nil + return ast.Scope{}.Is(types.EntityType(s.EntityType)), nil + } + return ast.Scope{}.IsIn(types.EntityType(s.EntityType), s.In.Entity), nil + } + return nil, fmt.Errorf("unknown op: %v", s.Op) +} + +func (s *scopeJSON) ToActionNode() (ast.IsActionScopeNode, error) { + switch s.Op { + case "All": + return ast.Scope{}.All(), nil + case "==": + if s.Entity == nil { + return nil, fmt.Errorf("missing entity") + } + return ast.Scope{}.Eq(*s.Entity), nil + case "in": + if s.Entity != nil { + return ast.Scope{}.In(*s.Entity), nil } - return variable.IsIn(types.EntityType(s.EntityType), s.In.Entity), nil + return ast.Scope{}.InSet(s.Entities), nil } return nil, fmt.Errorf("unknown op: %v", s.Op) } @@ -275,15 +297,15 @@ func (p *Policy) UnmarshalJSON(b []byte) error { p.unwrap().Annotate(types.String(k), types.String(v)) } var err error - p.Principal, err = j.Principal.ToNode(ast.Scope(ast.NewPrincipalNode())) + p.Principal, err = j.Principal.ToPrincipalResourceNode() if err != nil { return fmt.Errorf("error in principal: %w", err) } - p.Action, err = j.Action.ToNode(ast.Scope(ast.NewActionNode())) + p.Action, err = j.Action.ToActionNode() if err != nil { return fmt.Errorf("error in action: %w", err) } - p.Resource, err = j.Resource.ToNode(ast.Scope(ast.NewResourceNode())) + p.Resource, err = j.Resource.ToPrincipalResourceNode() if err != nil { return fmt.Errorf("error in resource: %w", err) } From e41ac94bc85b4abe92784fca172e2d54800f602e Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 10:20:43 -0600 Subject: [PATCH 183/216] internal/ast: consistently use the last known value for keys, annotations, and scope changes Addresses IDX-142 Signed-off-by: philhassey --- ast/annotation.go | 10 ++++++---- ast/scope.go | 11 +++++++++++ ast/value.go | 2 +- internal/ast/annotation.go | 22 ++++++++++++++++------ internal/ast/value.go | 11 +++++++++-- 5 files changed, 43 insertions(+), 13 deletions(-) diff --git a/ast/annotation.go b/ast/annotation.go index 6fe9e7df..19101740 100644 --- a/ast/annotation.go +++ b/ast/annotation.go @@ -26,8 +26,9 @@ func Annotation(name, value types.String) *Annotations { return wrapAnnotations(ast.Annotation(name, value)) } -func (a *Annotations) Annotation(name, value types.String) *Annotations { - return wrapAnnotations(a.unwrap().Annotation(name, value)) +// If a previous annotation exists with the same key, this builder will replace it. +func (a *Annotations) Annotation(key, value types.String) *Annotations { + return wrapAnnotations(a.unwrap().Annotation(key, value)) } func (a *Annotations) Permit() *Policy { @@ -38,6 +39,7 @@ func (a *Annotations) Forbid() *Policy { return wrapPolicy(a.unwrap().Forbid()) } -func (p *Policy) Annotate(name, value types.String) *Policy { - return wrapPolicy(p.unwrap().Annotate(name, value)) +// If a previous annotation exists with the same key, this builder will replace it. +func (p *Policy) Annotate(key, value types.String) *Policy { + return wrapPolicy(p.unwrap().Annotate(key, value)) } diff --git a/ast/scope.go b/ast/scope.go index 6db3cd25..3e25f003 100644 --- a/ast/scope.go +++ b/ast/scope.go @@ -4,46 +4,57 @@ import ( "github.com/cedar-policy/cedar-go/types" ) +// This builder will replace the previous principal scope condition. func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().PrincipalEq(entity)) } +// This builder will replace the previous principal scope condition. func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().PrincipalIn(entity)) } +// This builder will replace the previous principal scope condition. func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { return wrapPolicy(p.unwrap().PrincipalIs(entityType)) } +// This builder will replace the previous principal scope condition. func (p *Policy) PrincipalIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().PrincipalIsIn(entityType, entity)) } +// This builder will replace the previous action scope condition. func (p *Policy) ActionEq(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ActionEq(entity)) } +// This builder will replace the previous action scope condition. func (p *Policy) ActionIn(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ActionIn(entity)) } +// This builder will replace the previous action scope condition. func (p *Policy) ActionInSet(entities ...types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ActionInSet(entities...)) } +// This builder will replace the previous resource scope condition. func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ResourceEq(entity)) } +// This builder will replace the previous resource scope condition. func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ResourceIn(entity)) } +// This builder will replace the previous resource scope condition. func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { return wrapPolicy(p.unwrap().ResourceIs(entityType)) } +// This builder will replace the previous resource scope condition. func (p *Policy) ResourceIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ResourceIsIn(entityType, entity)) } diff --git a/ast/value.go b/ast/value.go index a7039b7a..8205e2f9 100644 --- a/ast/value.go +++ b/ast/value.go @@ -54,7 +54,7 @@ type Pair struct { type Pairs []Pair -// Record, TODO: document how duplicate keys might not really get handled in a meaningful way +// In the case where duplicate keys exist, the latter value will be preserved. func Record(elements Pairs) Node { var astNodes []ast.Pair for _, v := range elements { diff --git a/internal/ast/annotation.go b/internal/ast/annotation.go index 895cd1ec..7c053132 100644 --- a/internal/ast/annotation.go +++ b/internal/ast/annotation.go @@ -17,8 +17,18 @@ func Annotation(name, value types.String) *Annotations { return &Annotations{nodes: []AnnotationType{newAnnotation(name, value)}} } -func (a *Annotations) Annotation(name, value types.String) *Annotations { - a.nodes = append(a.nodes, newAnnotation(name, value)) +func addAnnotation(in []AnnotationType, key, value types.String) []AnnotationType { + for i, aa := range in { + if aa.Key == key { + in[i] = newAnnotation(key, value) + return in + } + } + return append(in, newAnnotation(key, value)) +} + +func (a *Annotations) Annotation(key, value types.String) *Annotations { + a.nodes = addAnnotation(a.nodes, key, value) return a } @@ -30,11 +40,11 @@ func (a *Annotations) Forbid() *Policy { return newPolicy(EffectForbid, a.nodes) } -func (p *Policy) Annotate(name, value types.String) *Policy { - p.Annotations = append(p.Annotations, AnnotationType{Key: name, Value: value}) +func (p *Policy) Annotate(key, value types.String) *Policy { + p.Annotations = addAnnotation(p.Annotations, key, value) return p } -func newAnnotation(name, value types.String) AnnotationType { - return AnnotationType{Key: name, Value: value} +func newAnnotation(key, value types.String) AnnotationType { + return AnnotationType{Key: key, Value: value} } diff --git a/internal/ast/value.go b/internal/ast/value.go index 8f4ca163..5d59fd22 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -59,10 +59,17 @@ type Pair struct { type Pairs []Pair +// In the case where duplicate keys exist, the latter value will be preserved. func Record(elements Pairs) Node { var res NodeTypeRecord - for _, e := range elements { - res.Elements = append(res.Elements, RecordElementNode{Key: types.String(e.Key), Value: e.Value.v}) + m := make(map[string]int, len(elements)) + for _, v := range elements { + if i, ok := m[v.Key]; ok { + res.Elements[i] = RecordElementNode{Key: types.String(v.Key), Value: v.Value.v} + continue + } + m[v.Key] = len(res.Elements) + res.Elements = append(res.Elements, RecordElementNode{Key: types.String(v.Key), Value: v.Value.v}) } return NewNode(res) } From 0749276ec9dac2f527467afc7cc9fee8e0da5e47 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 10:30:51 -0600 Subject: [PATCH 184/216] cedar: remove visible references to internal package and improve docs Addresses IDX-142 Signed-off-by: philhassey --- authorize.go | 6 +----- internal/ast/policy.go | 2 +- policy.go | 12 ++++++++---- policy_set.go | 2 +- policy_slice.go | 2 +- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/authorize.go b/authorize.go index feecd70b..e23b0ac1 100644 --- a/authorize.go +++ b/authorize.go @@ -65,14 +65,10 @@ type Request struct { Context types.Record `json:"context"` } -type evalContext = eval.Context - -type evaler = eval.Evaler - // IsAuthorized uses the combination of the PolicySet and Entities to determine // if the given Request to determine Decision and Diagnostic. func (p PolicySet) IsAuthorized(entityMap entities.Entities, req Request) (Decision, Diagnostic) { - c := &evalContext{ + c := &eval.Context{ Entities: entityMap, Principal: req.Principal, Action: req.Action, diff --git a/internal/ast/policy.go b/internal/ast/policy.go index 7c29e888..cd11acac 100644 --- a/internal/ast/policy.go +++ b/internal/ast/policy.go @@ -38,7 +38,7 @@ type Position struct { type Policy struct { Effect Effect - Annotations []AnnotationType + Annotations []AnnotationType // duplicate keys are prevented via the builders Principal IsPrincipalScopeNode Action IsActionScopeNode Resource IsResourceScopeNode diff --git a/policy.go b/policy.go index 5c60a75b..4edddad0 100644 --- a/policy.go +++ b/policy.go @@ -12,7 +12,7 @@ import ( // A Policy is the parsed form of a single Cedar language policy statement. type Policy struct { - eval evaler // determines if a policy matches a request. + eval eval.Evaler // determines if a policy matches a request. ast *internalast.Policy } @@ -70,7 +70,6 @@ func NewPolicyFromAST(astIn *ast.Policy) *Policy { type Annotations map[string]string func (p Policy) Annotations() Annotations { - // TODO: Where should we deal with duplicate keys? res := make(map[string]string, len(p.ast.Annotations)) for _, e := range p.ast.Annotations { res[string(e.Key)] = string(e.Value) @@ -80,7 +79,7 @@ func (p Policy) Annotations() Annotations { // An Effect specifies the intent of the policy, to either permit or forbid any // request that matches the scope and conditions specified in the policy. -type Effect internalast.Effect +type Effect bool // Each Policy has a Permit or Forbid effect that is determined during parsing. const ( @@ -93,7 +92,12 @@ func (p Policy) Effect() Effect { } // A Position describes an arbitrary source position including the file, line, and column location. -type Position internalast.Position +type Position struct { + Filename string // optional name of the source file for the enclosing policy, "" if the source is unknown or not a named file + Offset int // byte offset, starting at 0 + Line int // line number, starting at 1 + Column int // column number, starting at 1 (character count per line) +} func (p Policy) Position() Position { return Position(p.ast.Position) diff --git a/policy_set.go b/policy_set.go index fd11d625..86deae23 100644 --- a/policy_set.go +++ b/policy_set.go @@ -19,7 +19,7 @@ func NewPolicySet() PolicySet { return PolicySet{policies: map[PolicyID]*Policy{}} } -// NewPolicySetFromBytes will create a PolicySet from the given text document with the/ given file name used in Position +// NewPolicySetFromBytes will create a PolicySet from the given text document with the given file name used in Position // data. If there is an error parsing the document, it will be returned. // // NewPolicySetFromBytes assigns default PolicyIDs to the policies contained in fileName in the format "policy" where diff --git a/policy_slice.go b/policy_slice.go index 56576f74..5df57350 100644 --- a/policy_slice.go +++ b/policy_slice.go @@ -12,7 +12,7 @@ import ( // naming individual policies. type PolicySlice []*Policy -// NewPolicySliceFromBytes will create a PolicySet from the given text document with the/ given file name used in Position +// NewPolicySliceFromBytes will create a PolicySet from the given text document with the given file name used in Position // data. If there is an error parsing the document, it will be returned. func NewPolicySliceFromBytes(fileName string, document []byte) (PolicySlice, error) { var policySlice PolicySlice From 3550350cd589c258d49624b3d29c69cf89c902dd Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 10:55:19 -0600 Subject: [PATCH 185/216] types: make decimal type a bit more opaque, add convenience functions Addresses IDX-142 Signed-off-by: philhassey --- internal/eval/convert_test.go | 10 ++++----- internal/eval/evalers.go | 12 +++++----- internal/eval/evalers_test.go | 36 +++++++++++++++--------------- internal/eval/util.go | 2 +- internal/eval/util_test.go | 4 ++-- internal/json/json_test.go | 2 +- types/decimal.go | 42 +++++++++++++++++++++-------------- types/decimal_test.go | 6 ++--- types/json_test.go | 10 ++++----- types/value_test.go | 8 +++---- 10 files changed, 70 insertions(+), 62 deletions(-) diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index 556bfb09..8e80ca68 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -206,30 +206,30 @@ func TestToEval(t *testing.T) { { "decimal", ast.ExtensionCall("decimal", ast.String("42.42")), - types.Decimal(424200), + types.NewDecimal(42.42), testutil.OK, }, { "lessThan", - ast.ExtensionCall("lessThan", ast.Value(types.Decimal(420000)), ast.Value(types.Decimal(430000))), + ast.ExtensionCall("lessThan", ast.Value(types.NewDecimal(42.0)), ast.Value(types.NewDecimal(43))), types.True, testutil.OK, }, { "lessThanOrEqual", - ast.ExtensionCall("lessThanOrEqual", ast.Value(types.Decimal(420000)), ast.Value(types.Decimal(430000))), + ast.ExtensionCall("lessThanOrEqual", ast.Value(types.NewDecimal(42.0)), ast.Value(types.NewDecimal(43))), types.True, testutil.OK, }, { "greaterThan", - ast.ExtensionCall("greaterThan", ast.Value(types.Decimal(420000)), ast.Value(types.Decimal(430000))), + ast.ExtensionCall("greaterThan", ast.Value(types.NewDecimal(42.0)), ast.Value(types.NewDecimal(43))), types.False, testutil.OK, }, { "greaterThanOrEqual", - ast.ExtensionCall("greaterThanOrEqual", ast.Value(types.Decimal(420000)), ast.Value(types.Decimal(430000))), + ast.ExtensionCall("greaterThanOrEqual", ast.Value(types.NewDecimal(42.0)), ast.Value(types.NewDecimal(43))), types.False, testutil.OK, }, diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index 1bef363b..bfc170f2 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -103,11 +103,11 @@ func evalEntityType(n Evaler, ctx *Context) (types.EntityType, error) { func evalDecimal(n Evaler, ctx *Context) (types.Decimal, error) { v, err := n.Eval(ctx) if err != nil { - return types.Decimal(0), err + return types.Decimal{}, err } d, err := ValueToDecimal(v) if err != nil { - return types.Decimal(0), err + return types.Decimal{}, err } return d, nil } @@ -523,7 +523,7 @@ func (n *decimalLessThanEval) Eval(ctx *Context) (types.Value, error) { if err != nil { return zeroValue(), err } - return types.Boolean(lhs < rhs), nil + return types.Boolean(lhs.Value < rhs.Value), nil } // decimalLessThanOrEqualEval @@ -548,7 +548,7 @@ func (n *decimalLessThanOrEqualEval) Eval(ctx *Context) (types.Value, error) { if err != nil { return zeroValue(), err } - return types.Boolean(lhs <= rhs), nil + return types.Boolean(lhs.Value <= rhs.Value), nil } // decimalGreaterThanEval @@ -573,7 +573,7 @@ func (n *decimalGreaterThanEval) Eval(ctx *Context) (types.Value, error) { if err != nil { return zeroValue(), err } - return types.Boolean(lhs > rhs), nil + return types.Boolean(lhs.Value > rhs.Value), nil } // decimalGreaterThanOrEqualEval @@ -598,7 +598,7 @@ func (n *decimalGreaterThanOrEqualEval) Eval(ctx *Context) (types.Value, error) if err != nil { return zeroValue(), err } - return types.Boolean(lhs >= rhs), nil + return types.Boolean(lhs.Value >= rhs.Value), nil } // ifThenElseEval diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 68ea34d9..8b98c6e8 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -758,10 +758,10 @@ func TestDecimalLessThanNode(t *testing.T) { lhs, rhs Evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), ErrType}, - {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), ErrType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal{}), errTest}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal{}), ErrType}, + {"RhsError", newLiteralEval(types.Decimal{}), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal{}), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -815,10 +815,10 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { lhs, rhs Evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), ErrType}, - {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), ErrType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal{}), errTest}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal{}), ErrType}, + {"RhsError", newLiteralEval(types.Decimal{}), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal{}), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -872,10 +872,10 @@ func TestDecimalGreaterThanNode(t *testing.T) { lhs, rhs Evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), ErrType}, - {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), ErrType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal{}), errTest}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal{}), ErrType}, + {"RhsError", newLiteralEval(types.Decimal{}), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal{}), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -929,10 +929,10 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { lhs, rhs Evaler err error }{ - {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal(0)), errTest}, - {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal(0)), ErrType}, - {"RhsError", newLiteralEval(types.Decimal(0)), newErrorEval(errTest), errTest}, - {"RhsTypeError", newLiteralEval(types.Decimal(0)), newLiteralEval(types.True), ErrType}, + {"LhsError", newErrorEval(errTest), newLiteralEval(types.Decimal{}), errTest}, + {"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Decimal{}), ErrType}, + {"RhsError", newLiteralEval(types.Decimal{}), newErrorEval(errTest), errTest}, + {"RhsTypeError", newLiteralEval(types.Decimal{}), newLiteralEval(types.True), ErrType}, } for _, tt := range tests { tt := tt @@ -1760,7 +1760,7 @@ func TestDecimalLiteralNode(t *testing.T) { {"Error", newErrorEval(errTest), zeroValue(), errTest}, {"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType}, {"DecimalError", newLiteralEval(types.String("frob")), zeroValue(), types.ErrDecimal}, - {"Success", newLiteralEval(types.String("1.0")), types.Decimal(10000), nil}, + {"Success", newLiteralEval(types.String("1.0")), types.NewDecimal(1), nil}, } for _, tt := range tests { tt := tt @@ -1892,7 +1892,7 @@ func TestCedarString(t *testing.T) { {"set", types.Set{types.Long(42), types.Long(43)}, `[42, 43]`, `[42, 43]`}, {"singleIP", types.IPAddr(netip.MustParsePrefix("192.168.0.42/32")), `192.168.0.42`, `ip("192.168.0.42")`}, {"ipPrefix", types.IPAddr(netip.MustParsePrefix("192.168.0.42/24")), `192.168.0.42/24`, `ip("192.168.0.42/24")`}, - {"decimal", types.Decimal(12345678), `1234.5678`, `decimal("1234.5678")`}, + {"decimal", types.NewDecimal(1234.5678), `1234.5678`, `decimal("1234.5678")`}, } for _, tt := range tests { tt := tt diff --git a/internal/eval/util.go b/internal/eval/util.go index 81c133fd..786c8342 100644 --- a/internal/eval/util.go +++ b/internal/eval/util.go @@ -92,7 +92,7 @@ func ValueToEntityType(v types.Value) (types.EntityType, error) { func ValueToDecimal(v types.Value) (types.Decimal, error) { d, ok := v.(types.Decimal) if !ok { - return 0, fmt.Errorf("%w: expected decimal, got %v", ErrType, TypeName(v)) + return types.Decimal{}, fmt.Errorf("%w: expected decimal, got %v", ErrType, TypeName(v)) } return d, nil } diff --git a/internal/eval/util_test.go b/internal/eval/util_test.go index 0c9a9690..7e118d4b 100644 --- a/internal/eval/util_test.go +++ b/internal/eval/util_test.go @@ -134,7 +134,7 @@ func TestUtil(t *testing.T) { t.Parallel() v, err := ValueToDecimal(types.Boolean(true)) testutil.AssertError(t, err, ErrType) - testutil.Equals(t, v, 0) + testutil.Equals(t, v, types.Decimal{}) }) }) @@ -161,7 +161,7 @@ func TestTypeName(t *testing.T) { }{ {"boolean", types.Boolean(true), "bool"}, - {"decimal", types.Decimal(42), "decimal"}, + {"decimal", types.NewDecimal(42), "decimal"}, {"entityType", types.EntityType("T"), "(EntityType of type `T`)"}, {"entityUID", types.NewEntityUID("T", "42"), "(entity of type `T`)"}, {"ip", types.IPAddr{}, "IP"}, diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 4750b0ce..0886ea41 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -718,7 +718,7 @@ func TestMarshalExtensions(t *testing.T) { }{ { "decimalType", - ast.Permit().When(ast.Value(types.Decimal(420000))), + ast.Permit().When(ast.Value(types.NewDecimal(42))), `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"},"conditions":[{"kind":"when","body":{"decimal":[{"Value":"42.0"}]}}]}`, }, { diff --git a/types/decimal.go b/types/decimal.go index 19e35aad..ba19eab6 100644 --- a/types/decimal.go +++ b/types/decimal.go @@ -11,7 +11,15 @@ import ( // A Decimal is a value with both a whole number part and a decimal part of no // more than four digits. In Go this is stored as an int64, the precision is // defined by the constant DecimalPrecision. -type Decimal int64 +type Decimal struct { + Value int64 +} + +// NewDecimal creates a decimal via trivial conversion from int, int64, float64. +// Precision may be lost and overflows may occur. +func NewDecimal[T int | int64 | float64](v T) Decimal { + return Decimal{Value: int64(v * DecimalPrecision)} +} // DecimalPrecision is the precision of a Decimal. const DecimalPrecision = 10000 @@ -20,7 +28,7 @@ const DecimalPrecision = 10000 func ParseDecimal(s string) (Decimal, error) { // Check for empty string. if len(s) == 0 { - return Decimal(0), fmt.Errorf("%w: string too short", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: string too short", ErrDecimal) } i := 0 @@ -30,14 +38,14 @@ func ParseDecimal(s string) (Decimal, error) { negative = true i++ if i == len(s) { - return Decimal(0), fmt.Errorf("%w: string too short", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: string too short", ErrDecimal) } } // Parse the required first digit. c := rune(s[i]) if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) + return Decimal{}, fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) } integer := int64(c - '0') i++ @@ -45,18 +53,18 @@ func ParseDecimal(s string) (Decimal, error) { // Parse any other digits, ending with i pointing to '.'. for ; ; i++ { if i == len(s) { - return Decimal(0), fmt.Errorf("%w: string missing decimal point", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: string missing decimal point", ErrDecimal) } c = rune(s[i]) if c == '.' { break } if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) + return Decimal{}, fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) } integer = 10*integer + int64(c-'0') if integer > 922337203685477 { - return Decimal(0), fmt.Errorf("%w: overflow", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: overflow", ErrDecimal) } } @@ -69,7 +77,7 @@ func ParseDecimal(s string) (Decimal, error) { for ; i < len(s); i++ { c = rune(s[i]) if !unicode.IsDigit(c) { - return Decimal(0), fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) + return Decimal{}, fmt.Errorf("%w: unexpected character %s", ErrDecimal, strconv.QuoteRune(c)) } fraction = 10*fraction + int64(c-'0') fractionDigits++ @@ -78,7 +86,7 @@ func ParseDecimal(s string) (Decimal, error) { // Adjust the fraction part based on how many digits we parsed. switch fractionDigits { case 0: - return Decimal(0), fmt.Errorf("%w: missing digits after decimal point", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: missing digits after decimal point", ErrDecimal) case 1: fraction *= 1000 case 2: @@ -87,12 +95,12 @@ func ParseDecimal(s string) (Decimal, error) { fraction *= 10 case 4: default: - return Decimal(0), fmt.Errorf("%w: too many digits after decimal point", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: too many digits after decimal point", ErrDecimal) } // Check for overflow before we put the number together. if integer >= 922337203685477 && (fraction > 5808 || (!negative && fraction == 5808)) { - return Decimal(0), fmt.Errorf("%w: overflow", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: overflow", ErrDecimal) } // Put the number together. @@ -101,9 +109,9 @@ func ParseDecimal(s string) (Decimal, error) { // -922337203685477.5808. This isn't technically necessary because the // go spec defines arithmetic to be well-defined when overflowing. // However, doing things this way doesn't hurt, so let's be pedantic. - return Decimal(DecimalPrecision*-integer - fraction), nil + return Decimal{Value: DecimalPrecision*-integer - fraction}, nil } else { - return Decimal(DecimalPrecision*integer + fraction), nil + return Decimal{Value: DecimalPrecision*integer + fraction}, nil } } @@ -118,13 +126,13 @@ func (v Decimal) MarshalCedar() []byte { return []byte(`decimal("` + v.String() // String produces a string representation of the Decimal, e.g. `12.34`. func (v Decimal) String() string { var res string - if v < 0 { + if v.Value < 0 { // Make sure we don't overflow here. Also, go truncates towards zero. - integer := v / DecimalPrecision - decimal := integer*DecimalPrecision - v + integer := v.Value / DecimalPrecision + decimal := integer*DecimalPrecision - v.Value res = fmt.Sprintf("-%d.%04d", -integer, decimal) } else { - res = fmt.Sprintf("%d.%04d", v/DecimalPrecision, v%DecimalPrecision) + res = fmt.Sprintf("%d.%04d", v.Value/DecimalPrecision, v.Value%DecimalPrecision) } // Trim off up to three trailing zeros. diff --git a/types/decimal_test.go b/types/decimal_test.go index d6002178..604ea6ce 100644 --- a/types/decimal_test.go +++ b/types/decimal_test.go @@ -109,9 +109,9 @@ func TestDecimal(t *testing.T) { t.Run("Equal", func(t *testing.T) { t.Parallel() - one := types.Decimal(10000) - one2 := types.Decimal(10000) - zero := types.Decimal(0) + one := types.NewDecimal(1) + one2 := types.NewDecimal(1) + zero := types.NewDecimal(0) f := types.Boolean(false) testutil.FatalIf(t, !one.Equal(one), "%v not Equal to %v", one, one) testutil.FatalIf(t, !one.Equal(one2), "%v not Equal to %v", one, one2) diff --git a/types/json_test.go b/types/json_test.go index f318e8f3..c649ad37 100644 --- a/types/json_test.go +++ b/types/json_test.go @@ -210,7 +210,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { return res, err }, in: `"bad`, - wantValue: Decimal(0), + wantValue: Decimal{}, wantErr: errJSONDecode, }, { @@ -221,7 +221,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { return res, err }, in: `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, - wantValue: Decimal(0), + wantValue: Decimal{}, wantErr: ErrDecimal, }, { @@ -232,7 +232,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { return res, err }, in: `bad`, - wantValue: Decimal(0), + wantValue: Decimal{}, wantErr: errJSONDecode, }, { @@ -243,7 +243,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { return res, err }, in: `{ "__extn": { "fn": "bad", "arg": "1234.5678" } }`, - wantValue: Decimal(0), + wantValue: Decimal{}, wantErr: errJSONExtFnMatch, }, { @@ -254,7 +254,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { return res, err }, in: `{ }`, - wantValue: Decimal(0), + wantValue: Decimal{}, wantErr: errJSONExtNotFound, }, } diff --git a/types/value_test.go b/types/value_test.go index 0764d624..b82d14ac 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -87,12 +87,12 @@ func TestDeepClone(t *testing.T) { t.Run("Decimal", func(t *testing.T) { t.Parallel() - a := Decimal(42) + a := NewDecimal(42) b := a.deepClone() testutil.Equals(t, Value(a), b) - a = Decimal(43) - testutil.Equals(t, a, Decimal(43)) - testutil.Equals(t, b, Value(Decimal(42))) + a = NewDecimal(43) + testutil.Equals(t, a, NewDecimal(43)) + testutil.Equals(t, b, Value(NewDecimal(42))) }) t.Run("IPAddr", func(t *testing.T) { From 5809d6b4b1118dee720820159854bbb06853d92b Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 10:58:38 -0600 Subject: [PATCH 186/216] ast: fix naming of annotation key Addresses IDX-142 Signed-off-by: philhassey --- ast/annotation.go | 4 ++-- internal/ast/annotation.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ast/annotation.go b/ast/annotation.go index 19101740..01a8990a 100644 --- a/ast/annotation.go +++ b/ast/annotation.go @@ -22,8 +22,8 @@ func wrapAnnotations(a *ast.Annotations) *Annotations { // Annotation("baz", "quux"). // Permit(). // PrincipalEq(superUser) -func Annotation(name, value types.String) *Annotations { - return wrapAnnotations(ast.Annotation(name, value)) +func Annotation(key, value types.String) *Annotations { + return wrapAnnotations(ast.Annotation(key, value)) } // If a previous annotation exists with the same key, this builder will replace it. diff --git a/internal/ast/annotation.go b/internal/ast/annotation.go index 7c053132..acfd3014 100644 --- a/internal/ast/annotation.go +++ b/internal/ast/annotation.go @@ -13,8 +13,8 @@ type Annotations struct { // Annotation("baz", "quux"). // Permit(). // PrincipalEq(superUser) -func Annotation(name, value types.String) *Annotations { - return &Annotations{nodes: []AnnotationType{newAnnotation(name, value)}} +func Annotation(key, value types.String) *Annotations { + return &Annotations{nodes: []AnnotationType{newAnnotation(key, value)}} } func addAnnotation(in []AnnotationType, key, value types.String) []AnnotationType { From 0d61556b9230e8e6c1a98b75e16d4c7e37140821 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 14:14:57 -0600 Subject: [PATCH 187/216] types: improve type consistency across whole package Addresses IDX-142 Signed-off-by: philhassey --- ast/annotation.go | 6 ++-- ast/ast_test.go | 8 +++--- ast/operator.go | 4 +-- ast/scope.go | 8 +++--- ast/value.go | 4 +-- internal/ast/annotation.go | 10 +++---- internal/ast/ast_test.go | 16 +++++------ internal/ast/node.go | 10 +++---- internal/ast/operator.go | 4 +-- internal/ast/policy.go | 2 +- internal/ast/scope.go | 16 +++++------ internal/ast/value.go | 10 +++---- internal/eval/convert.go | 4 +-- internal/eval/evalers.go | 24 +++------------- internal/eval/evalers_test.go | 29 +++++++++---------- internal/eval/util.go | 10 ------- internal/eval/util_test.go | 1 - internal/extensions/extensions.go | 2 +- internal/json/json_test.go | 8 +++--- internal/json/json_unmarshal.go | 16 +++++------ internal/parser/cedar_unmarshal.go | 38 ++++++++++++------------ types/decimal.go | 1 + types/entity_type.go | 23 --------------- types/entity_type_test.go | 46 ------------------------------ types/entity_uid.go | 24 ++++++++-------- types/ident.go | 3 ++ types/json.go | 8 +++--- types/path.go | 10 +++++++ types/path_test.go | 18 ++++++++++++ types/value_test.go | 8 ------ 30 files changed, 149 insertions(+), 222 deletions(-) delete mode 100644 types/entity_type.go delete mode 100644 types/entity_type_test.go create mode 100644 types/ident.go create mode 100644 types/path.go create mode 100644 types/path_test.go diff --git a/ast/annotation.go b/ast/annotation.go index 01a8990a..edb87daf 100644 --- a/ast/annotation.go +++ b/ast/annotation.go @@ -22,12 +22,12 @@ func wrapAnnotations(a *ast.Annotations) *Annotations { // Annotation("baz", "quux"). // Permit(). // PrincipalEq(superUser) -func Annotation(key, value types.String) *Annotations { +func Annotation(key types.Ident, value types.String) *Annotations { return wrapAnnotations(ast.Annotation(key, value)) } // If a previous annotation exists with the same key, this builder will replace it. -func (a *Annotations) Annotation(key, value types.String) *Annotations { +func (a *Annotations) Annotation(key types.Ident, value types.String) *Annotations { return wrapAnnotations(a.unwrap().Annotation(key, value)) } @@ -40,6 +40,6 @@ func (a *Annotations) Forbid() *Policy { } // If a previous annotation exists with the same key, this builder will replace it. -func (p *Policy) Annotate(key, value types.String) *Policy { +func (p *Policy) Annotate(key types.Ident, value types.String) *Policy { return wrapPolicy(p.unwrap().Annotate(key, value)) } diff --git a/ast/ast_test.go b/ast/ast_test.go index cf73c172..b0b28b9f 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -345,13 +345,13 @@ func TestASTByTable(t *testing.T) { }, { "opIs", - ast.Permit().When(ast.Long(42).Is(types.EntityType("T"))), - internalast.Permit().When(internalast.Long(42).Is(types.EntityType("T"))), + ast.Permit().When(ast.Long(42).Is(types.Path("T"))), + internalast.Permit().When(internalast.Long(42).Is(types.Path("T"))), }, { "opIsIn", - ast.Permit().When(ast.Long(42).IsIn(types.EntityType("T"), ast.Long(43))), - internalast.Permit().When(internalast.Long(42).IsIn(types.EntityType("T"), internalast.Long(43))), + ast.Permit().When(ast.Long(42).IsIn(types.Path("T"), ast.Long(43))), + internalast.Permit().When(internalast.Long(42).IsIn(types.Path("T"), internalast.Long(43))), }, { "opContains", diff --git a/ast/operator.go b/ast/operator.go index 6767d0aa..ab7af3f3 100644 --- a/ast/operator.go +++ b/ast/operator.go @@ -112,11 +112,11 @@ func (lhs Node) In(rhs Node) Node { return wrapNode(lhs.Node.In(rhs.Node)) } -func (lhs Node) Is(entityType types.EntityType) Node { +func (lhs Node) Is(entityType types.Path) Node { return wrapNode(lhs.Node.Is(entityType)) } -func (lhs Node) IsIn(entityType types.EntityType, rhs Node) Node { +func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { return wrapNode(lhs.Node.IsIn(entityType, rhs.Node)) } diff --git a/ast/scope.go b/ast/scope.go index 3e25f003..7141105d 100644 --- a/ast/scope.go +++ b/ast/scope.go @@ -15,12 +15,12 @@ func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { } // This builder will replace the previous principal scope condition. -func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { +func (p *Policy) PrincipalIs(entityType types.Path) *Policy { return wrapPolicy(p.unwrap().PrincipalIs(entityType)) } // This builder will replace the previous principal scope condition. -func (p *Policy) PrincipalIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { +func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().PrincipalIsIn(entityType, entity)) } @@ -50,11 +50,11 @@ func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { } // This builder will replace the previous resource scope condition. -func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { +func (p *Policy) ResourceIs(entityType types.Path) *Policy { return wrapPolicy(p.unwrap().ResourceIs(entityType)) } // This builder will replace the previous resource scope condition. -func (p *Policy) ResourceIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { +func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ResourceIsIn(entityType, entity)) } diff --git a/ast/value.go b/ast/value.go index 8205e2f9..ab27ae5d 100644 --- a/ast/value.go +++ b/ast/value.go @@ -48,7 +48,7 @@ func Set(nodes ...Node) Node { } type Pair struct { - Key string + Key types.String Value Node } @@ -63,7 +63,7 @@ func Record(elements Pairs) Node { return wrapNode(ast.Record(astNodes)) } -func EntityUID(typ, id string) Node { +func EntityUID(typ types.Ident, id types.String) Node { return wrapNode(ast.EntityUID(typ, id)) } diff --git a/internal/ast/annotation.go b/internal/ast/annotation.go index acfd3014..a1fafdfc 100644 --- a/internal/ast/annotation.go +++ b/internal/ast/annotation.go @@ -13,11 +13,11 @@ type Annotations struct { // Annotation("baz", "quux"). // Permit(). // PrincipalEq(superUser) -func Annotation(key, value types.String) *Annotations { +func Annotation(key types.Ident, value types.String) *Annotations { return &Annotations{nodes: []AnnotationType{newAnnotation(key, value)}} } -func addAnnotation(in []AnnotationType, key, value types.String) []AnnotationType { +func addAnnotation(in []AnnotationType, key types.Ident, value types.String) []AnnotationType { for i, aa := range in { if aa.Key == key { in[i] = newAnnotation(key, value) @@ -27,7 +27,7 @@ func addAnnotation(in []AnnotationType, key, value types.String) []AnnotationTyp return append(in, newAnnotation(key, value)) } -func (a *Annotations) Annotation(key, value types.String) *Annotations { +func (a *Annotations) Annotation(key types.Ident, value types.String) *Annotations { a.nodes = addAnnotation(a.nodes, key, value) return a } @@ -40,11 +40,11 @@ func (a *Annotations) Forbid() *Policy { return newPolicy(EffectForbid, a.nodes) } -func (p *Policy) Annotate(key, value types.String) *Policy { +func (p *Policy) Annotate(key types.Ident, value types.String) *Policy { p.Annotations = addAnnotation(p.Annotations, key, value) return p } -func newAnnotation(key, value types.String) AnnotationType { +func newAnnotation(key types.Ident, value types.String) AnnotationType { return AnnotationType{Key: key, Value: value} } diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index f2d10918..185602e6 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -135,12 +135,12 @@ func TestASTByTable(t *testing.T) { { "scopePrincipalIs", ast.Permit().PrincipalIs("T"), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIs{Type: types.EntityType("T")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIs{Type: types.Path("T")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, }, { "scopePrincipalIsIn", ast.Permit().PrincipalIsIn("T", types.NewEntityUID("T", "42")), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIsIn{Type: types.EntityType("T"), Entity: types.NewEntityUID("T", "42")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIsIn{Type: types.Path("T"), Entity: types.NewEntityUID("T", "42")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, }, { "scopeActionEq", @@ -170,12 +170,12 @@ func TestASTByTable(t *testing.T) { { "scopeResourceIs", ast.Permit().ResourceIs("T"), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIs{Type: types.EntityType("T")}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIs{Type: types.Path("T")}}, }, { "scopeResourceIsIn", ast.Permit().ResourceIsIn("T", types.NewEntityUID("T", "42")), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIsIn{Type: types.EntityType("T"), Entity: types.NewEntityUID("T", "42")}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIsIn{Type: types.Path("T"), Entity: types.NewEntityUID("T", "42")}}, }, { "variablePrincipal", @@ -410,15 +410,15 @@ func TestASTByTable(t *testing.T) { }, { "opIs", - ast.Permit().When(ast.Long(42).Is(types.EntityType("T"))), + ast.Permit().When(ast.Long(42).Is("T")), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.EntityType("T")}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.Path("T")}}}}, }, { "opIsIn", - ast.Permit().When(ast.Long(42).IsIn(types.EntityType("T"), ast.Long(43))), + ast.Permit().When(ast.Long(42).IsIn("T", ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIsIn{NodeTypeIs: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.EntityType("T")}, Entity: ast.NodeValue{Value: types.Long(43)}}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIsIn{NodeTypeIs: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.Path("T")}, Entity: ast.NodeValue{Value: types.Long(43)}}}}}, }, { "opContains", diff --git a/internal/ast/node.go b/internal/ast/node.go index 41dc7091..30d7c835 100644 --- a/internal/ast/node.go +++ b/internal/ast/node.go @@ -76,7 +76,7 @@ func (n NodeTypeLike) isNode() {} type NodeTypeIs struct { Left IsNode - EntityType types.EntityType + EntityType types.Path } func (n NodeTypeIs) isNode() {} @@ -112,7 +112,7 @@ type NodeTypeNot struct{ UnaryNode } type NodeTypeAccess struct{ StrOpNode } type NodeTypeExtensionCall struct { - Name types.String // TODO: review type + Name types.Path Args []IsNode } @@ -126,14 +126,14 @@ func stripNodes(args []Node) []IsNode { return res } -func NewExtensionCall(method types.String, args ...Node) Node { +func NewExtensionCall(method types.Path, args ...Node) Node { return NewNode(NodeTypeExtensionCall{ Name: method, Args: stripNodes(args), }) } -func NewMethodCall(lhs Node, method types.String, args ...Node) Node { +func NewMethodCall(lhs Node, method types.Path, args ...Node) Node { res := make([]IsNode, 1+len(args)) res[0] = lhs.v for i, v := range args { @@ -179,7 +179,7 @@ type NodeTypeSet struct { func (n NodeTypeSet) isNode() {} type NodeTypeVariable struct { - Name types.String // TODO: Review type + Name types.String } func (n NodeTypeVariable) isNode() {} diff --git a/internal/ast/operator.go b/internal/ast/operator.go index 07287f86..81a39018 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -109,11 +109,11 @@ func (lhs Node) In(rhs Node) Node { return NewNode(NodeTypeIn{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } -func (lhs Node) Is(entityType types.EntityType) Node { +func (lhs Node) Is(entityType types.Path) Node { return NewNode(NodeTypeIs{Left: lhs.v, EntityType: entityType}) } -func (lhs Node) IsIn(entityType types.EntityType, rhs Node) Node { +func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { return NewNode(NodeTypeIsIn{NodeTypeIs: NodeTypeIs{Left: lhs.v, EntityType: entityType}, Entity: rhs.v}) } diff --git a/internal/ast/policy.go b/internal/ast/policy.go index cd11acac..5578ce39 100644 --- a/internal/ast/policy.go +++ b/internal/ast/policy.go @@ -5,7 +5,7 @@ import ( ) type AnnotationType struct { - Key types.String // TODO: review type + Key types.Ident Value types.String } type Condition bool diff --git a/internal/ast/scope.go b/internal/ast/scope.go index fa685447..2cb1a9ef 100644 --- a/internal/ast/scope.go +++ b/internal/ast/scope.go @@ -22,11 +22,11 @@ func (s Scope) InSet(entities []types.EntityUID) ScopeTypeInSet { return ScopeTypeInSet{Entities: entities} } -func (s Scope) Is(entityType types.EntityType) ScopeTypeIs { +func (s Scope) Is(entityType types.Path) ScopeTypeIs { return ScopeTypeIs{Type: entityType} } -func (s Scope) IsIn(entityType types.EntityType, entity types.EntityUID) ScopeTypeIsIn { +func (s Scope) IsIn(entityType types.Path, entity types.EntityUID) ScopeTypeIsIn { return ScopeTypeIsIn{Type: entityType, Entity: entity} } @@ -40,12 +40,12 @@ func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { return p } -func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { +func (p *Policy) PrincipalIs(entityType types.Path) *Policy { p.Principal = Scope{}.Is(entityType) return p } -func (p *Policy) PrincipalIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { +func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { p.Principal = Scope{}.IsIn(entityType, entity) return p } @@ -75,12 +75,12 @@ func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { return p } -func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { +func (p *Policy) ResourceIs(entityType types.Path) *Policy { p.Resource = Scope{}.Is(entityType) return p } -func (p *Policy) ResourceIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { +func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { p.Resource = Scope{}.IsIn(entityType, entity) return p } @@ -153,13 +153,13 @@ type ScopeTypeIs struct { ScopeNode PrincipalScopeNode ResourceScopeNode - Type types.EntityType + Type types.Path } type ScopeTypeIsIn struct { ScopeNode PrincipalScopeNode ResourceScopeNode - Type types.EntityType + Type types.Path Entity types.EntityUID } diff --git a/internal/ast/value.go b/internal/ast/value.go index 5d59fd22..6e835d88 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -53,7 +53,7 @@ func Set(nodes ...Node) Node { } type Pair struct { - Key string + Key types.String Value Node } @@ -62,7 +62,7 @@ type Pairs []Pair // In the case where duplicate keys exist, the latter value will be preserved. func Record(elements Pairs) Node { var res NodeTypeRecord - m := make(map[string]int, len(elements)) + m := make(map[types.String]int, len(elements)) for _, v := range elements { if i, ok := m[v.Key]; ok { res.Elements[i] = RecordElementNode{Key: types.String(v.Key), Value: v.Value.v} @@ -74,8 +74,8 @@ func Record(elements Pairs) Node { return NewNode(res) } -func EntityUID(typ, id string) Node { - e := types.NewEntityUID(types.EntityType(typ), id) +func EntityUID(typ types.Ident, id types.String) Node { + e := types.NewEntityUID(types.Path(typ), types.String(id)) return Value(e) } @@ -83,7 +83,7 @@ func IPAddr[T netip.Prefix | types.IPAddr](i T) Node { return Value(types.IPAddr(i)) } -func ExtensionCall(name types.String, args ...Node) Node { +func ExtensionCall(name types.Path, args ...Node) Node { return NewExtensionCall(name, args...) } diff --git a/internal/eval/convert.go b/internal/eval/convert.go index c50b78bb..4ea615a0 100644 --- a/internal/eval/convert.go +++ b/internal/eval/convert.go @@ -19,10 +19,10 @@ func toEval(n ast.IsNode) Evaler { case ast.NodeTypeIfThenElse: return newIfThenElseEval(toEval(v.If), toEval(v.Then), toEval(v.Else)) case ast.NodeTypeIs: - return newIsEval(toEval(v.Left), newLiteralEval(v.EntityType)) + return newIsEval(toEval(v.Left), v.EntityType) case ast.NodeTypeIsIn: obj := toEval(v.Left) - lhs := newIsEval(obj, newLiteralEval(v.EntityType)) + lhs := newIsEval(obj, v.EntityType) rhs := newInEval(obj, toEval(v.Entity)) return newAndEval(lhs, rhs) case ast.NodeTypeExtensionCall: diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index bfc170f2..92061414 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -88,18 +88,6 @@ func evalEntity(n Evaler, ctx *Context) (types.EntityUID, error) { return e, nil } -func evalEntityType(n Evaler, ctx *Context) (types.EntityType, error) { - v, err := n.Eval(ctx) - if err != nil { - return "", err - } - e, err := ValueToEntityType(v) - if err != nil { - return "", err - } - return e, nil -} - func evalDecimal(n Evaler, ctx *Context) (types.Decimal, error) { v, err := n.Eval(ctx) if err != nil { @@ -975,10 +963,11 @@ func (n *inEval) Eval(ctx *Context) (types.Value, error) { // isEval type isEval struct { - lhs, rhs Evaler + lhs Evaler + rhs types.Path } -func newIsEval(lhs, rhs Evaler) *isEval { +func newIsEval(lhs Evaler, rhs types.Path) *isEval { return &isEval{lhs: lhs, rhs: rhs} } @@ -988,12 +977,7 @@ func (n *isEval) Eval(ctx *Context) (types.Value, error) { return zeroValue(), err } - rhs, err := evalEntityType(n.rhs, ctx) - if err != nil { - return zeroValue(), err - } - - return types.Boolean(types.EntityType(lhs.Type) == rhs), nil + return types.Boolean(lhs.Type == n.rhs), nil } // decimalLiteralEval diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 8b98c6e8..d668c241 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -16,7 +16,7 @@ var errTest = fmt.Errorf("test error") // not a real parser func strEnt(v string) types.EntityUID { p := strings.Split(v, "::\"") - return types.EntityUID{Type: types.EntityType(p[0]), ID: p[1][:len(p[1])-1]} + return types.EntityUID{Type: types.Path(p[0]), ID: types.String(p[1][:len(p[1])-1])} } func AssertValue(t *testing.T, got, want types.Value) { @@ -1586,15 +1586,15 @@ func TestEntityIn(t *testing.T) { entityMap := entities.Entities{} for i := 0; i < 100; i++ { p := []types.EntityUID{ - types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "1"), - types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "2"), + types.NewEntityUID(types.Path(fmt.Sprint(i+1)), "1"), + types.NewEntityUID(types.Path(fmt.Sprint(i+1)), "2"), } - uid1 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "1") + uid1 := types.NewEntityUID(types.Path(fmt.Sprint(i)), "1") entityMap[uid1] = entities.Entity{ UID: uid1, Parents: p, } - uid2 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "2") + uid2 := types.NewEntityUID(types.Path(fmt.Sprint(i)), "2") entityMap[uid2] = entities.Entity{ UID: uid2, Parents: p, @@ -1609,17 +1609,16 @@ func TestEntityIn(t *testing.T) { func TestIsNode(t *testing.T) { t.Parallel() tests := []struct { - name string - lhs, rhs Evaler - result types.Value - err error + name string + lhs Evaler + rhs types.Path + result types.Value + err error }{ - {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.EntityType("X")), types.True, nil}, - {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.EntityType("Y")), types.False, nil}, - {"badLhs", newLiteralEval(types.Long(42)), newLiteralEval(types.EntityType("X")), zeroValue(), ErrType}, - {"badRhs", newLiteralEval(types.NewEntityUID("X", "z")), newLiteralEval(types.Long(42)), zeroValue(), ErrType}, - {"errLhs", newErrorEval(errTest), newLiteralEval(types.EntityType("X")), zeroValue(), errTest}, - {"errRhs", newLiteralEval(types.NewEntityUID("X", "z")), newErrorEval(errTest), zeroValue(), errTest}, + {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), types.Path("X"), types.True, nil}, + {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), types.Path("Y"), types.False, nil}, + {"badLhs", newLiteralEval(types.Long(42)), types.Path("X"), zeroValue(), ErrType}, + {"errLhs", newErrorEval(errTest), types.Path("X"), zeroValue(), errTest}, } for _, tt := range tests { tt := tt diff --git a/internal/eval/util.go b/internal/eval/util.go index 786c8342..6c62a618 100644 --- a/internal/eval/util.go +++ b/internal/eval/util.go @@ -12,8 +12,6 @@ func TypeName(v types.Value) string { return "bool" case types.Decimal: return "decimal" - case types.EntityType: - return fmt.Sprintf("(EntityType of type `%s`)", t) case types.EntityUID: return fmt.Sprintf("(entity of type `%s`)", t.Type) case types.IPAddr: @@ -81,14 +79,6 @@ func ValueToEntity(v types.Value) (types.EntityUID, error) { return ev, nil } -func ValueToEntityType(v types.Value) (types.EntityType, error) { - ev, ok := v.(types.EntityType) - if !ok { - return "", fmt.Errorf("%w: expected (EntityType of type `any_entity_type`), got %v", ErrType, TypeName(v)) - } - return ev, nil -} - func ValueToDecimal(v types.Value) (types.Decimal, error) { d, ok := v.(types.Decimal) if !ok { diff --git a/internal/eval/util_test.go b/internal/eval/util_test.go index 7e118d4b..08521407 100644 --- a/internal/eval/util_test.go +++ b/internal/eval/util_test.go @@ -162,7 +162,6 @@ func TestTypeName(t *testing.T) { {"boolean", types.Boolean(true), "bool"}, {"decimal", types.NewDecimal(42), "decimal"}, - {"entityType", types.EntityType("T"), "(EntityType of type `T`)"}, {"entityUID", types.NewEntityUID("T", "42"), "(entity of type `T`)"}, {"ip", types.IPAddr{}, "IP"}, {"long", types.Long(42), "long"}, diff --git a/internal/extensions/extensions.go b/internal/extensions/extensions.go index fe5a4e2c..3fa7f021 100644 --- a/internal/extensions/extensions.go +++ b/internal/extensions/extensions.go @@ -7,7 +7,7 @@ type extInfo struct { IsMethod bool } -var ExtMap = map[types.String]extInfo{ +var ExtMap = map[types.Path]extInfo{ "ip": {Args: 1, IsMethod: false}, "decimal": {Args: 1, IsMethod: false}, diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 0886ea41..a3383d7f 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -119,13 +119,13 @@ func TestUnmarshalJSON(t *testing.T) { { "principalIs", `{"effect":"permit","principal":{"op":"is","entity_type":"T"},"action":{"op":"All"},"resource":{"op":"All"}}`, - ast.Permit().PrincipalIs(types.EntityType("T")), + ast.Permit().PrincipalIs(types.Path("T")), testutil.OK, }, { "principalIsIn", `{"effect":"permit","principal":{"op":"is","entity_type":"T","in":{"entity":{"type":"P","id":"42"}}},"action":{"op":"All"},"resource":{"op":"All"}}`, - ast.Permit().PrincipalIsIn(types.EntityType("T"), types.NewEntityUID("P", "42")), + ast.Permit().PrincipalIsIn(types.Path("T"), types.NewEntityUID("P", "42")), testutil.OK, }, { @@ -161,13 +161,13 @@ func TestUnmarshalJSON(t *testing.T) { { "resourceIs", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is","entity_type":"T"}}`, - ast.Permit().ResourceIs(types.EntityType("T")), + ast.Permit().ResourceIs(types.Path("T")), testutil.OK, }, { "resourceIsIn", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is","entity_type":"T","in":{"entity":{"type":"P","id":"42"}}}}`, - ast.Permit().ResourceIsIn(types.EntityType("T"), types.NewEntityUID("P", "42")), + ast.Permit().ResourceIsIn(types.Path("T"), types.NewEntityUID("P", "42")), testutil.OK, }, { diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 6ea0e34a..329322de 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -33,9 +33,9 @@ func (s *scopeJSON) ToPrincipalResourceNode() (isPrincipalResourceScopeNode, err return ast.Scope{}.In(*s.Entity), nil case "is": if s.In == nil { - return ast.Scope{}.Is(types.EntityType(s.EntityType)), nil + return ast.Scope{}.Is(types.Path(s.EntityType)), nil } - return ast.Scope{}.IsIn(types.EntityType(s.EntityType), s.In.Entity), nil + return ast.Scope{}.IsIn(types.Path(s.EntityType), s.In.Entity), nil } return nil, fmt.Errorf("unknown op: %v", s.Op) } @@ -101,9 +101,9 @@ func (j isJSON) ToNode() (ast.Node, error) { if err != nil { return ast.Node{}, fmt.Errorf("error in entity: %w", err) } - return left.IsIn(types.EntityType(j.EntityType), right), nil + return left.IsIn(types.Path(j.EntityType), right), nil } - return left.Is(types.EntityType(j.EntityType)), nil + return left.Is(types.Path(j.EntityType)), nil } func (j ifThenElseJSON) ToNode() (ast.Node, error) { if_, err := j.If.ToNode() @@ -139,7 +139,7 @@ func (j recordJSON) ToNode() (ast.Node, error) { if err != nil { return ast.Node{}, fmt.Errorf("error in record: %w", err) } - nodes = append(nodes, ast.Pair{Key: k, Value: n}) + nodes = append(nodes, ast.Pair{Key: types.String(k), Value: n}) } return ast.Record(nodes), nil } @@ -153,7 +153,7 @@ func (e extensionJSON) ToNode() (ast.Node, error) { for k, v = range e { _, _ = k, v } - _, ok := extensions.ExtMap[types.String(k)] + _, ok := extensions.ExtMap[types.Path(k)] if !ok { return ast.Node{}, fmt.Errorf("`%v` is not a known extension function or method", k) } @@ -165,7 +165,7 @@ func (e extensionJSON) ToNode() (ast.Node, error) { } argNodes = append(argNodes, node) } - return ast.NewExtensionCall(types.String(k), argNodes...), nil + return ast.NewExtensionCall(types.Path(k), argNodes...), nil } func (j nodeJSON) ToNode() (ast.Node, error) { @@ -294,7 +294,7 @@ func (p *Policy) UnmarshalJSON(b []byte) error { return fmt.Errorf("unknown effect: %v", j.Effect) } for k, v := range j.Annotations { - p.unwrap().Annotate(types.String(k), types.String(v)) + p.unwrap().Annotate(types.Ident(k), types.String(v)) } var err error p.Principal, err = j.Principal.ToPrincipalResourceNode() diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 933a6c00..ec284fe3 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -127,7 +127,7 @@ func (p *parser) errorf(s string, args ...interface{}) error { func (p *parser) annotations() (ast.Annotations, error) { var res ast.Annotations - known := map[types.String]struct{}{} + known := map[string]struct{}{} for p.peek().Text == "@" { p.advance() err := p.annotation(&res, known) @@ -139,13 +139,13 @@ func (p *parser) annotations() (ast.Annotations, error) { } -func (p *parser) annotation(a *ast.Annotations, known map[types.String]struct{}) error { +func (p *parser) annotation(a *ast.Annotations, known map[string]struct{}) error { var err error t := p.advance() if !t.isIdent() { return p.errorf("expected ident") } - name := types.String(t.Text) + name := t.Text if err = p.exact("("); err != nil { return err } @@ -165,7 +165,7 @@ func (p *parser) annotation(a *ast.Annotations, known map[types.String]struct{}) return err } - a.Annotation(name, types.String(value)) + a.Annotation(types.Ident(name), types.String(value)) return nil } @@ -230,12 +230,11 @@ func (p *parser) entity() (types.EntityUID, error) { if !t.isIdent() { return res, p.errorf("expected ident") } - return p.entityFirstPathPreread(types.EntityType(t.Text)) + return p.entityFirstPathPreread(types.Path(t.Text)) } -func (p *parser) entityFirstPathPreread(firstPath types.EntityType) (types.EntityUID, error) { +func (p *parser) entityFirstPathPreread(firstPath types.Path) (types.EntityUID, error) { var res types.EntityUID - var err error res.Type = firstPath for { if err := p.exact("::"); err != nil { @@ -244,12 +243,13 @@ func (p *parser) entityFirstPathPreread(firstPath types.EntityType) (types.Entit t := p.advance() switch { case t.isIdent(): - res.Type = types.EntityType(res.Type.String() + "::" + t.Text) + res.Type = types.Path(res.Type) + "::" + types.Path(t.Text) case t.isString(): - res.ID, err = t.stringValue() + id, err := t.stringValue() if err != nil { return res, err } + res.ID = types.String(id) return res, nil default: return res, p.errorf("unexpected token") @@ -257,8 +257,8 @@ func (p *parser) entityFirstPathPreread(firstPath types.EntityType) (types.Entit } } -func (p *parser) pathFirstPathPreread(firstPath string) (types.EntityType, error) { - res := types.EntityType(firstPath) +func (p *parser) pathFirstPathPreread(firstPath string) (types.Path, error) { + res := types.Path(firstPath) for { if p.peek().Text != "::" { return res, nil @@ -267,14 +267,14 @@ func (p *parser) pathFirstPathPreread(firstPath string) (types.EntityType, error t := p.advance() switch { case t.isIdent(): - res = types.EntityType(fmt.Sprintf("%v::%v", res, t.Text)) + res = types.Path(fmt.Sprintf("%v::%v", res, t.Text)) default: return res, p.errorf("unexpected token") } } } -func (p *parser) path() (types.EntityType, error) { +func (p *parser) path() (types.Path, error) { t := p.advance() if !t.isIdent() { return "", p.errorf("expected ident") @@ -756,14 +756,14 @@ func (p *parser) entityOrExtFun(prefix string) (ast.Node, error) { if err != nil { return ast.Node{}, err } - return ast.EntityUID(prefix, id), nil + return ast.EntityUID(types.Ident(prefix), types.String(id)), nil default: return ast.Node{}, p.errorf("unexpected token") } case "(": // Although the Cedar grammar says that any name can be provided here, the reference implementation actually // checks at parse time whether the name corresponds to a known extension function. - i, ok := extensions.ExtMap[types.String(prefix)] + i, ok := extensions.ExtMap[types.Path(prefix)] if !ok { return ast.Node{}, p.errorf("`%v` is not a function", prefix) } @@ -776,7 +776,7 @@ func (p *parser) entityOrExtFun(prefix string) (ast.Node, error) { return ast.Node{}, err } p.advance() - return ast.ExtensionCall(types.String(prefix), args...), nil + return ast.ExtensionCall(types.Path(prefix), args...), nil default: return ast.Node{}, p.errorf("unexpected token") } @@ -824,7 +824,7 @@ func (p *parser) record() (ast.Node, error) { return res, p.errorf("duplicate key: %v", k) } known[k] = struct{}{} - elements = append(elements, ast.Pair{Key: k, Value: v}) + elements = append(elements, ast.Pair{Key: types.String(k), Value: v}) } } @@ -884,14 +884,14 @@ func (p *parser) access(lhs ast.Node) (ast.Node, bool, error) { default: // Although the Cedar grammar says that any name can be provided here, the reference implementation // actually checks at parse time whether the name corresponds to a known extension method. - i, ok := extensions.ExtMap[types.String(methodName)] + i, ok := extensions.ExtMap[types.Path(methodName)] if !ok { return ast.Node{}, false, p.errorf("`%v` is not a method", methodName) } if !i.IsMethod { return ast.Node{}, false, p.errorf("`%v` is a function, not a method", methodName) } - return ast.NewMethodCall(lhs, types.String(methodName), exprs...), true, nil + return ast.NewMethodCall(lhs, types.Path(methodName), exprs...), true, nil } if len(exprs) != 1 { diff --git a/types/decimal.go b/types/decimal.go index ba19eab6..74395347 100644 --- a/types/decimal.go +++ b/types/decimal.go @@ -17,6 +17,7 @@ type Decimal struct { // NewDecimal creates a decimal via trivial conversion from int, int64, float64. // Precision may be lost and overflows may occur. +// TODO: reconsider ... func NewDecimal[T int | int64 | float64](v T) Decimal { return Decimal{Value: int64(v * DecimalPrecision)} } diff --git a/types/entity_type.go b/types/entity_type.go deleted file mode 100644 index 0b595d30..00000000 --- a/types/entity_type.go +++ /dev/null @@ -1,23 +0,0 @@ -package types - -import ( - "encoding/json" - "strings" -) - -// EntityType is the type portion of an EntityUID -type EntityType string - -func (a EntityType) Equal(bi Value) bool { - b, ok := bi.(EntityType) - return ok && a == b -} - -func (v EntityType) String() string { return string(v) } -func (v EntityType) MarshalCedar() []byte { return []byte(v) } -func (v EntityType) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(string(v)) } -func (v EntityType) deepClone() Value { return v } - -func EntityTypeFromSlice(v []string) EntityType { - return EntityType(strings.Join(v, "::")) -} diff --git a/types/entity_type_test.go b/types/entity_type_test.go deleted file mode 100644 index 467dcd29..00000000 --- a/types/entity_type_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package types_test - -import ( - "testing" - - "github.com/cedar-policy/cedar-go/internal/testutil" - "github.com/cedar-policy/cedar-go/types" -) - -func TestEntityType(t *testing.T) { - t.Parallel() - t.Run("Equal", func(t *testing.T) { - t.Parallel() - a := types.EntityType("X") - b := types.EntityType("X") - c := types.EntityType("Y") - testutil.Equals(t, a.Equal(b), true) - testutil.Equals(t, b.Equal(a), true) - testutil.Equals(t, a.Equal(c), false) - testutil.Equals(t, c.Equal(a), false) - }) - - t.Run("String", func(t *testing.T) { - t.Parallel() - a := types.EntityType("X") - testutil.Equals(t, a.String(), "X") - }) - t.Run("Cedar", func(t *testing.T) { - t.Parallel() - a := types.EntityType("X") - testutil.Equals(t, a.MarshalCedar(), []byte("X")) - }) - t.Run("ExplicitMarshalJSON", func(t *testing.T) { - t.Parallel() - a := types.EntityType("X") - v, err := a.ExplicitMarshalJSON() - testutil.OK(t, err) - testutil.Equals(t, string(v), `"X"`) - }) - t.Run("pathFromSlice", func(t *testing.T) { - t.Parallel() - a := types.EntityTypeFromSlice([]string{"X", "Y"}) - testutil.Equals(t, a, types.EntityType("X::Y")) - }) - -} diff --git a/types/entity_uid.go b/types/entity_uid.go index 70277808..c1b84868 100644 --- a/types/entity_uid.go +++ b/types/entity_uid.go @@ -7,11 +7,11 @@ import ( // An EntityUID is the identifier for a principal, action, or resource. type EntityUID struct { - Type EntityType - ID string + Type Path + ID String } -func NewEntityUID(typ EntityType, id string) EntityUID { +func NewEntityUID(typ Path, id String) EntityUID { return EntityUID{ Type: typ, ID: id, @@ -29,7 +29,7 @@ func (a EntityUID) Equal(bi Value) bool { } // String produces a string representation of the EntityUID, e.g. `Type::"id"`. -func (v EntityUID) String() string { return string(v.Type.String() + "::" + strconv.Quote(v.ID)) } +func (v EntityUID) String() string { return string(v.Type) + "::" + strconv.Quote(string(v.ID)) } // MarshalCedar produces a valid MarshalCedar language representation of the EntityUID, e.g. `Type::"id"`. func (v EntityUID) MarshalCedar() []byte { @@ -43,12 +43,12 @@ func (v *EntityUID) UnmarshalJSON(b []byte) error { return err } if res.Entity != nil { - v.Type = res.Entity.Type - v.ID = res.Entity.ID + v.Type = Path(res.Entity.Type) + v.ID = String(res.Entity.ID) return nil } else if res.Type != nil && res.ID != nil { // require both Type and ID to parse "implicit" JSON - v.Type = *res.Type - v.ID = *res.ID + v.Type = Path(*res.Type) + v.ID = String(*res.ID) return nil } return errJSONEntityNotFound @@ -57,8 +57,8 @@ func (v *EntityUID) UnmarshalJSON(b []byte) error { // ExplicitMarshalJSON marshals the EntityUID into JSON using the implicit form. func (v EntityUID) MarshalJSON() ([]byte, error) { return json.Marshal(entityValueJSON{ - Type: &v.Type, - ID: &v.ID, + Type: (*string)(&v.Type), + ID: (*string)(&v.ID), }) } @@ -66,8 +66,8 @@ func (v EntityUID) MarshalJSON() ([]byte, error) { func (v EntityUID) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(entityValueJSON{ Entity: &extEntity{ - Type: v.Type, - ID: v.ID, + Type: string(v.Type), + ID: string(v.ID), }, }) } diff --git a/types/ident.go b/types/ident.go new file mode 100644 index 00000000..45d4e334 --- /dev/null +++ b/types/ident.go @@ -0,0 +1,3 @@ +package types + +type Ident string diff --git a/types/json.go b/types/json.go index 9d31245e..3fc2a9c6 100644 --- a/types/json.go +++ b/types/json.go @@ -27,14 +27,14 @@ type extValueJSON struct { } type extEntity struct { - Type EntityType `json:"type"` + Type string `json:"type"` ID string `json:"id"` } type entityValueJSON struct { - Type *EntityType `json:"type,omitempty"` - ID *string `json:"id,omitempty"` - Entity *extEntity `json:"__entity,omitempty"` + Type *string `json:"type,omitempty"` + ID *string `json:"id,omitempty"` + Entity *extEntity `json:"__entity,omitempty"` } type explicitValue struct { diff --git a/types/path.go b/types/path.go new file mode 100644 index 00000000..8c53c402 --- /dev/null +++ b/types/path.go @@ -0,0 +1,10 @@ +package types + +import "strings" + +// Path is the type portion of an EntityUID +type Path string + +func PathFromSlice(v []string) Path { + return Path(strings.Join(v, "::")) +} diff --git a/types/path_test.go b/types/path_test.go new file mode 100644 index 00000000..efc3f4f2 --- /dev/null +++ b/types/path_test.go @@ -0,0 +1,18 @@ +package types_test + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestEntityType(t *testing.T) { + t.Parallel() + t.Run("pathFromSlice", func(t *testing.T) { + t.Parallel() + a := types.PathFromSlice([]string{"X", "Y"}) + testutil.Equals(t, a, types.Path("X::Y")) + }) + +} diff --git a/types/value_test.go b/types/value_test.go index b82d14ac..7054d386 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -44,14 +44,6 @@ func TestDeepClone(t *testing.T) { testutil.Equals(t, a, NewEntityUID("Action", "bananas")) testutil.Equals(t, b, Value(NewEntityUID("Action", "test"))) }) - t.Run("EntityType", func(t *testing.T) { - t.Parallel() - a := EntityType("X") - b := a.deepClone() - c, ok := b.(EntityType) - testutil.Equals(t, ok, true) - testutil.Equals(t, c, a) - }) t.Run("Set", func(t *testing.T) { t.Parallel() a := Set{Long(42)} From 308ed8ecec01d4815d64c02a640775c98b497b29 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 14:21:03 -0600 Subject: [PATCH 188/216] types: make unsafe decimal feel unsafe Addresses IDX-142 Signed-off-by: philhassey --- internal/eval/convert_test.go | 10 +++++----- internal/eval/evalers_test.go | 4 ++-- internal/eval/util_test.go | 2 +- internal/json/json_test.go | 2 +- types/decimal.go | 5 ++--- types/decimal_test.go | 6 +++--- types/value_test.go | 8 ++++---- 7 files changed, 18 insertions(+), 19 deletions(-) diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index 8e80ca68..c2eefc5b 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -206,30 +206,30 @@ func TestToEval(t *testing.T) { { "decimal", ast.ExtensionCall("decimal", ast.String("42.42")), - types.NewDecimal(42.42), + types.UnsafeDecimal(42.42), testutil.OK, }, { "lessThan", - ast.ExtensionCall("lessThan", ast.Value(types.NewDecimal(42.0)), ast.Value(types.NewDecimal(43))), + ast.ExtensionCall("lessThan", ast.Value(types.UnsafeDecimal(42.0)), ast.Value(types.UnsafeDecimal(43))), types.True, testutil.OK, }, { "lessThanOrEqual", - ast.ExtensionCall("lessThanOrEqual", ast.Value(types.NewDecimal(42.0)), ast.Value(types.NewDecimal(43))), + ast.ExtensionCall("lessThanOrEqual", ast.Value(types.UnsafeDecimal(42.0)), ast.Value(types.UnsafeDecimal(43))), types.True, testutil.OK, }, { "greaterThan", - ast.ExtensionCall("greaterThan", ast.Value(types.NewDecimal(42.0)), ast.Value(types.NewDecimal(43))), + ast.ExtensionCall("greaterThan", ast.Value(types.UnsafeDecimal(42.0)), ast.Value(types.UnsafeDecimal(43))), types.False, testutil.OK, }, { "greaterThanOrEqual", - ast.ExtensionCall("greaterThanOrEqual", ast.Value(types.NewDecimal(42.0)), ast.Value(types.NewDecimal(43))), + ast.ExtensionCall("greaterThanOrEqual", ast.Value(types.UnsafeDecimal(42.0)), ast.Value(types.UnsafeDecimal(43))), types.False, testutil.OK, }, diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index d668c241..dfb9b53f 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -1759,7 +1759,7 @@ func TestDecimalLiteralNode(t *testing.T) { {"Error", newErrorEval(errTest), zeroValue(), errTest}, {"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType}, {"DecimalError", newLiteralEval(types.String("frob")), zeroValue(), types.ErrDecimal}, - {"Success", newLiteralEval(types.String("1.0")), types.NewDecimal(1), nil}, + {"Success", newLiteralEval(types.String("1.0")), types.UnsafeDecimal(1), nil}, } for _, tt := range tests { tt := tt @@ -1891,7 +1891,7 @@ func TestCedarString(t *testing.T) { {"set", types.Set{types.Long(42), types.Long(43)}, `[42, 43]`, `[42, 43]`}, {"singleIP", types.IPAddr(netip.MustParsePrefix("192.168.0.42/32")), `192.168.0.42`, `ip("192.168.0.42")`}, {"ipPrefix", types.IPAddr(netip.MustParsePrefix("192.168.0.42/24")), `192.168.0.42/24`, `ip("192.168.0.42/24")`}, - {"decimal", types.NewDecimal(1234.5678), `1234.5678`, `decimal("1234.5678")`}, + {"decimal", types.UnsafeDecimal(1234.5678), `1234.5678`, `decimal("1234.5678")`}, } for _, tt := range tests { tt := tt diff --git a/internal/eval/util_test.go b/internal/eval/util_test.go index 08521407..75615b66 100644 --- a/internal/eval/util_test.go +++ b/internal/eval/util_test.go @@ -161,7 +161,7 @@ func TestTypeName(t *testing.T) { }{ {"boolean", types.Boolean(true), "bool"}, - {"decimal", types.NewDecimal(42), "decimal"}, + {"decimal", types.UnsafeDecimal(42), "decimal"}, {"entityUID", types.NewEntityUID("T", "42"), "(entity of type `T`)"}, {"ip", types.IPAddr{}, "IP"}, {"long", types.Long(42), "long"}, diff --git a/internal/json/json_test.go b/internal/json/json_test.go index a3383d7f..8bab6311 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -718,7 +718,7 @@ func TestMarshalExtensions(t *testing.T) { }{ { "decimalType", - ast.Permit().When(ast.Value(types.NewDecimal(42))), + ast.Permit().When(ast.Value(types.UnsafeDecimal(42))), `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"},"conditions":[{"kind":"when","body":{"decimal":[{"Value":"42.0"}]}}]}`, }, { diff --git a/types/decimal.go b/types/decimal.go index 74395347..3d224e86 100644 --- a/types/decimal.go +++ b/types/decimal.go @@ -15,10 +15,9 @@ type Decimal struct { Value int64 } -// NewDecimal creates a decimal via trivial conversion from int, int64, float64. +// UnsafeDecimal creates a decimal via unsafe conversion from int, int64, float64. // Precision may be lost and overflows may occur. -// TODO: reconsider ... -func NewDecimal[T int | int64 | float64](v T) Decimal { +func UnsafeDecimal[T int | int64 | float64](v T) Decimal { return Decimal{Value: int64(v * DecimalPrecision)} } diff --git a/types/decimal_test.go b/types/decimal_test.go index 604ea6ce..2640af2f 100644 --- a/types/decimal_test.go +++ b/types/decimal_test.go @@ -109,9 +109,9 @@ func TestDecimal(t *testing.T) { t.Run("Equal", func(t *testing.T) { t.Parallel() - one := types.NewDecimal(1) - one2 := types.NewDecimal(1) - zero := types.NewDecimal(0) + one := types.UnsafeDecimal(1) + one2 := types.UnsafeDecimal(1) + zero := types.UnsafeDecimal(0) f := types.Boolean(false) testutil.FatalIf(t, !one.Equal(one), "%v not Equal to %v", one, one) testutil.FatalIf(t, !one.Equal(one2), "%v not Equal to %v", one, one2) diff --git a/types/value_test.go b/types/value_test.go index 7054d386..a148e241 100644 --- a/types/value_test.go +++ b/types/value_test.go @@ -79,12 +79,12 @@ func TestDeepClone(t *testing.T) { t.Run("Decimal", func(t *testing.T) { t.Parallel() - a := NewDecimal(42) + a := UnsafeDecimal(42) b := a.deepClone() testutil.Equals(t, Value(a), b) - a = NewDecimal(43) - testutil.Equals(t, a, NewDecimal(43)) - testutil.Equals(t, b, Value(NewDecimal(42))) + a = UnsafeDecimal(43) + testutil.Equals(t, a, UnsafeDecimal(43)) + testutil.Equals(t, b, Value(UnsafeDecimal(42))) }) t.Run("IPAddr", func(t *testing.T) { From 595915bdb826c626e5b70a770f66b239376c4ad5 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 14:59:32 -0600 Subject: [PATCH 189/216] types: tweak shape of pattern Addresses IDX-142 Signed-off-by: philhassey --- types/json.go | 17 +++++++------- types/pattern.go | 28 ++++++++++++----------- types/patttern_test.go | 52 +++++++++++++++++++++--------------------- 3 files changed, 49 insertions(+), 48 deletions(-) diff --git a/types/json.go b/types/json.go index 3fc2a9c6..8dd29cd7 100644 --- a/types/json.go +++ b/types/json.go @@ -7,14 +7,13 @@ import ( ) var ( - errJSONInvalidExtn = fmt.Errorf("invalid extension") - errJSONDecode = fmt.Errorf("error decoding json") - errJSONLongOutOfRange = fmt.Errorf("long out of range") - errJSONUnsupportedType = fmt.Errorf("unsupported type") - errJSONExtFnMatch = fmt.Errorf("json extn mismatch") - errJSONExtNotFound = fmt.Errorf("json extn not found") - errJSONEntityNotFound = fmt.Errorf("json entity not found") - errJSONInvalidPatternComponent = fmt.Errorf("invalid pattern component") + errJSONInvalidExtn = fmt.Errorf("invalid extension") + errJSONDecode = fmt.Errorf("error decoding json") + errJSONLongOutOfRange = fmt.Errorf("long out of range") + errJSONUnsupportedType = fmt.Errorf("unsupported type") + errJSONExtFnMatch = fmt.Errorf("json extn mismatch") + errJSONExtNotFound = fmt.Errorf("json extn not found") + errJSONEntityNotFound = fmt.Errorf("json entity not found") ) type extn struct { @@ -28,7 +27,7 @@ type extValueJSON struct { type extEntity struct { Type string `json:"type"` - ID string `json:"id"` + ID string `json:"id"` } type entityValueJSON struct { diff --git a/types/pattern.go b/types/pattern.go index 87d779c3..ad64bab5 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -10,6 +10,8 @@ import ( "github.com/cedar-policy/cedar-go/internal/rust" ) +var errJSONInvalidPatternComponent = fmt.Errorf("invalid pattern component") + type patternComponent struct { Wildcard bool Literal string @@ -27,21 +29,21 @@ type PatternComponent interface { isPatternComponent() } -type WildcardPatternComponent struct{} - -func (WildcardPatternComponent) isPatternComponent() {} +type wildcardComponent struct{} -// Wildcard is a constant which can be used to conveniently construct an instance of WildcardPatternComponent -var Wildcard = WildcardPatternComponent{} +func (wildcardComponent) isPatternComponent() {} func (String) isPatternComponent() {} +// Wildcard is a constant which can be used to conveniently construct an instance of WildcardPatternComponent +func Wildcard() PatternComponent { return wildcardComponent{} } + // NewPattern permits for the programmatic construction of a Pattern out of a set of PatternComponents. func NewPattern(components ...PatternComponent) Pattern { var comps []patternComponent for _, c := range components { switch v := c.(type) { - case WildcardPatternComponent: + case wildcardComponent: if len(comps) == 0 || comps[len(comps)-1].Literal != "" { comps = append(comps, patternComponent{Wildcard: true, Literal: ""}) } @@ -85,7 +87,7 @@ func (p Pattern) MarshalCedar() []byte { // term: // '*' matches any sequence of non-Separator characters // c matches character c (c != '*') -func (p Pattern) Match(arg string) (matched bool) { +func (p Pattern) Match(arg String) (matched bool) { Pattern: for i, comp := range p.comps { lastChunk := i == len(p.comps)-1 @@ -93,24 +95,24 @@ Pattern: return true } // Look for Match at current position. - t, ok := matchChunk(comp.Literal, arg) + t, ok := matchChunk(comp.Literal, string(arg)) // if we're the last chunk, make sure we've exhausted the name // otherwise we'll give a false result even if we could still Match // using the star if ok && (len(t) == 0 || !lastChunk) { - arg = t + arg = String(t) continue } if comp.Wildcard { // Look for Match skipping i+1 bytes. for i := 0; i < len(arg); i++ { - t, ok := matchChunk(comp.Literal, arg[i+1:]) + t, ok := matchChunk(comp.Literal, string(arg[i+1:])) if ok { // if we're the last chunk, make sure we exhausted the name if lastChunk && len(t) > 0 { continue } - arg = t + arg = String(t) continue Pattern } } @@ -144,7 +146,7 @@ func ParsePattern(v string) (Pattern, error) { for len(b) > 0 { for len(b) > 0 && b[0] == '*' { b = b[1:] - comps = append(comps, Wildcard) + comps = append(comps, Wildcard()) } var err error var literal string @@ -200,7 +202,7 @@ func (p *Pattern) UnmarshalJSON(b []byte) error { if v != "Wildcard" { return fmt.Errorf(`%w: invalid component string "%v"`, errJSONInvalidPatternComponent, v) } - comps = append(comps, Wildcard) + comps = append(comps, Wildcard()) case map[string]any: if len(v) != 1 { return fmt.Errorf(`%w: too many keys in literal object`, errJSONInvalidPatternComponent) diff --git a/types/patttern_test.go b/types/patttern_test.go index ce3b6319..27837f1f 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -8,8 +8,8 @@ import ( func TestPatternFromBuilder(t *testing.T) { t.Run("saturate two wildcards", func(t *testing.T) { - pattern1 := NewPattern(Wildcard, Wildcard) - pattern2 := NewPattern(Wildcard) + pattern1 := NewPattern(Wildcard(), Wildcard()) + pattern2 := NewPattern(Wildcard()) testutil.Equals(t, pattern1, pattern2) }) t.Run("saturate two literals", func(t *testing.T) { @@ -30,21 +30,21 @@ func TestParsePattern(t *testing.T) { }{ {"", true, NewPattern(), ""}, {"a", true, NewPattern(a), ""}, - {"*", true, NewPattern(Wildcard), ""}, - {"*a", true, NewPattern(Wildcard, a), ""}, - {"a*", true, NewPattern(a, Wildcard), ""}, - {"**", true, NewPattern(Wildcard), ""}, - {"**a", true, NewPattern(Wildcard, a), ""}, - {"a**", true, NewPattern(a, Wildcard), ""}, - {"*a*", true, NewPattern(Wildcard, a, Wildcard), ""}, - {"**a**", true, NewPattern(Wildcard, a, Wildcard), ""}, - {"abra*ca", true, NewPattern(String("abra"), Wildcard, String("ca")), ""}, - {"abra**ca", true, NewPattern(String("abra"), Wildcard, String("ca")), ""}, - {"*abra*ca", true, NewPattern(Wildcard, String("abra"), Wildcard, String("ca")), ""}, - {"abra*ca*", true, NewPattern(String("abra"), Wildcard, String("ca"), Wildcard), ""}, - {"*abra*ca*", true, NewPattern(Wildcard, String("abra"), Wildcard, String("ca"), Wildcard), ""}, - {"*abra*ca*dabra", true, NewPattern(Wildcard, String("abra"), Wildcard, String("ca"), Wildcard, String("dabra")), ""}, - {`*abra*c\**da\*bra`, true, NewPattern(Wildcard, String("abra"), Wildcard, String("c*"), Wildcard, String("da*bra")), ""}, + {"*", true, NewPattern(Wildcard()), ""}, + {"*a", true, NewPattern(Wildcard(), a), ""}, + {"a*", true, NewPattern(a, Wildcard()), ""}, + {"**", true, NewPattern(Wildcard()), ""}, + {"**a", true, NewPattern(Wildcard(), a), ""}, + {"a**", true, NewPattern(a, Wildcard()), ""}, + {"*a*", true, NewPattern(Wildcard(), a, Wildcard()), ""}, + {"**a**", true, NewPattern(Wildcard(), a, Wildcard()), ""}, + {"abra*ca", true, NewPattern(String("abra"), Wildcard(), String("ca")), ""}, + {"abra**ca", true, NewPattern(String("abra"), Wildcard(), String("ca")), ""}, + {"*abra*ca", true, NewPattern(Wildcard(), String("abra"), Wildcard(), String("ca")), ""}, + {"abra*ca*", true, NewPattern(String("abra"), Wildcard(), String("ca"), Wildcard()), ""}, + {"*abra*ca*", true, NewPattern(Wildcard(), String("abra"), Wildcard(), String("ca"), Wildcard()), ""}, + {"*abra*ca*dabra", true, NewPattern(Wildcard(), String("abra"), Wildcard(), String("ca"), Wildcard(), String("dabra")), ""}, + {`*abra*c\**da\*bra`, true, NewPattern(Wildcard(), String("abra"), Wildcard(), String("c*"), Wildcard(), String("da*bra")), ""}, {`\u`, false, Pattern{}, "bad unicode rune"}, } for _, tt := range tests { @@ -96,7 +96,7 @@ func TestMatch(t *testing.T) { t.Parallel() pat, err := ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) testutil.OK(t, err) - got := pat.Match(tt.target) + got := pat.Match(String(tt.target)) testutil.Equals(t, got, tt.want) }) } @@ -115,7 +115,7 @@ func TestJSON(t *testing.T) { "like single wildcard", `["Wildcard"]`, testutil.OK, - NewPattern(Wildcard), + NewPattern(Wildcard()), true, }, { @@ -129,49 +129,49 @@ func TestJSON(t *testing.T) { "like wildcard then literal", `["Wildcard", {"Literal":"foo"}]`, testutil.OK, - NewPattern(Wildcard, String("foo")), + NewPattern(Wildcard(), String("foo")), true, }, { "like literal then wildcard", `[{"Literal":"foo"}, "Wildcard"]`, testutil.OK, - NewPattern(String("foo"), Wildcard), + NewPattern(String("foo"), Wildcard()), true, }, { "like literal with asterisk then wildcard", `[{"Literal":"f*oo"}, "Wildcard"]`, testutil.OK, - NewPattern(String("f*oo"), Wildcard), + NewPattern(String("f*oo"), Wildcard()), true, }, { "like literal sandwich", `[{"Literal":"foo"}, "Wildcard", {"Literal":"bar"}]`, testutil.OK, - NewPattern(String("foo"), Wildcard, String("bar")), + NewPattern(String("foo"), Wildcard(), String("bar")), true, }, { "like wildcard sandwich", `["Wildcard", {"Literal":"foo"}, "Wildcard"]`, testutil.OK, - NewPattern(Wildcard, String("foo"), Wildcard), + NewPattern(Wildcard(), String("foo"), Wildcard()), true, }, { "double wildcard", `["Wildcard", "Wildcard", {"Literal":"foo"}]`, testutil.OK, - NewPattern(Wildcard, String("foo")), + NewPattern(Wildcard(), String("foo")), false, }, { "double literal", `["Wildcard", {"Literal":"foo"}, {"Literal":"bar"}]`, testutil.OK, - NewPattern(Wildcard, String("foobar")), + NewPattern(Wildcard(), String("foobar")), false, }, { From 454072fa4e2f911d60bdafbac464ebd8f4b35ed6 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 15:09:00 -0600 Subject: [PATCH 190/216] internal/parser: move ParsePattern into the parser Addresses IDX-142 Signed-off-by: philhassey --- ast/ast_test.go | 4 +- internal/ast/ast_test.go | 4 +- internal/eval/evalers.go | 2 +- internal/eval/evalers_test.go | 3 +- internal/json/json_test.go | 6 +- internal/parser/cedar_unmarshal.go | 2 +- internal/parser/cedar_unmarshal_test.go | 2 +- internal/parser/pattern.go | 26 +++++++ internal/parser/pattern_test.go | 91 +++++++++++++++++++++++++ types/pattern.go | 22 ------ types/patttern_test.go | 83 ---------------------- 11 files changed, 129 insertions(+), 116 deletions(-) create mode 100644 internal/parser/pattern.go create mode 100644 internal/parser/pattern_test.go diff --git a/ast/ast_test.go b/ast/ast_test.go index b0b28b9f..afa76c52 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -295,8 +295,8 @@ func TestASTByTable(t *testing.T) { }, { "opLike", - ast.Permit().When(ast.Long(42).Like(types.NewPattern(types.Wildcard))), - internalast.Permit().When(internalast.Long(42).Like(types.NewPattern(types.Wildcard))), + ast.Permit().When(ast.Long(42).Like(types.NewPattern(types.Wildcard()))), + internalast.Permit().When(internalast.Long(42).Like(types.NewPattern(types.Wildcard()))), }, { "opAnd", diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index 185602e6..55a10f63 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -350,9 +350,9 @@ func TestASTByTable(t *testing.T) { }, { "opLike", - ast.Permit().When(ast.Long(42).Like(types.NewPattern(types.Wildcard))), + ast.Permit().When(ast.Long(42).Like(types.NewPattern(types.Wildcard()))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLike{Arg: ast.NodeValue{Value: types.Long(42)}, Value: types.NewPattern(types.Wildcard)}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLike{Arg: ast.NodeValue{Value: types.Long(42)}, Value: types.NewPattern(types.Wildcard())}}}}, }, { "opAnd", diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index 92061414..afd9dc33 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -881,7 +881,7 @@ func (l *likeEval) Eval(ctx *Context) (types.Value, error) { if err != nil { return zeroValue(), err } - return types.Boolean(l.pattern.Match(string(v))), nil + return types.Boolean(l.pattern.Match(v)), nil } type variableName func(ctx *Context) types.Value diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index dfb9b53f..3baa3063 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/cedar-policy/cedar-go/internal/entities" + "github.com/cedar-policy/cedar-go/internal/parser" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -1424,7 +1425,7 @@ func TestLikeNode(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - pat, err := types.ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) + pat, err := parser.ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) testutil.OK(t, err) n := newLikeEval(tt.str, pat) v, err := n.Eval(&Context{}) diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 8bab6311..cb1138a3 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -399,7 +399,7 @@ func TestUnmarshalJSON(t *testing.T) { "like single wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.NewPattern(types.Wildcard))), + ast.Permit().When(ast.String("text").Like(types.NewPattern(types.Wildcard()))), testutil.OK, }, { @@ -413,14 +413,14 @@ func TestUnmarshalJSON(t *testing.T) { "like wildcard then literal", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard", {"Literal":"foo"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.NewPattern(types.Wildcard, types.String("foo")))), + ast.Permit().When(ast.String("text").Like(types.NewPattern(types.Wildcard(), types.String("foo")))), testutil.OK, }, { "like literal then wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.NewPattern(types.String("foo"), types.Wildcard))), + ast.Permit().When(ast.String("text").Like(types.NewPattern(types.String("foo"), types.Wildcard()))), testutil.OK, }, { diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index ec284fe3..eeb005fa 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -557,7 +557,7 @@ func (p *parser) like(lhs ast.Node) (ast.Node, error) { patternRaw := t.Text patternRaw = strings.TrimPrefix(patternRaw, "\"") patternRaw = strings.TrimSuffix(patternRaw, "\"") - pattern, err := types.ParsePattern(patternRaw) + pattern, err := ParsePattern(patternRaw) if err != nil { return ast.Node{}, err } diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 98e975ea..937e1cd9 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -300,7 +300,7 @@ when { principal.firstName like "joh\*nny" };`, "like wildcard", `permit ( principal, action, resource ) when { principal.firstName like "*" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(types.NewPattern(types.Wildcard))), + ast.Permit().When(ast.Principal().Access("firstName").Like(types.NewPattern(types.Wildcard()))), }, { "is", diff --git a/internal/parser/pattern.go b/internal/parser/pattern.go new file mode 100644 index 00000000..0450bf77 --- /dev/null +++ b/internal/parser/pattern.go @@ -0,0 +1,26 @@ +package parser + +import ( + "github.com/cedar-policy/cedar-go/internal/rust" + "github.com/cedar-policy/cedar-go/types" +) + +// ParsePattern will parse an unquoted rust-style string with \*'s in it. +func ParsePattern(v string) (types.Pattern, error) { + b := []byte(v) + var comps []types.PatternComponent + for len(b) > 0 { + for len(b) > 0 && b[0] == '*' { + b = b[1:] + comps = append(comps, types.Wildcard()) + } + var err error + var literal string + literal, b, err = rust.Unquote(b, true) + if err != nil { + return types.Pattern{}, err + } + comps = append(comps, types.String(literal)) + } + return types.NewPattern(comps...), nil +} diff --git a/internal/parser/pattern_test.go b/internal/parser/pattern_test.go new file mode 100644 index 00000000..e686e322 --- /dev/null +++ b/internal/parser/pattern_test.go @@ -0,0 +1,91 @@ +package parser + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestParsePattern(t *testing.T) { + t.Parallel() + a := types.String("a") + tests := []struct { + input string + wantOk bool + want types.Pattern + wantErr string + }{ + {"", true, types.NewPattern(), ""}, + {"a", true, types.NewPattern(a), ""}, + {"*", true, types.NewPattern(types.Wildcard()), ""}, + {"*a", true, types.NewPattern(types.Wildcard(), a), ""}, + {"a*", true, types.NewPattern(a, types.Wildcard()), ""}, + {"**", true, types.NewPattern(types.Wildcard()), ""}, + {"**a", true, types.NewPattern(types.Wildcard(), a), ""}, + {"a**", true, types.NewPattern(a, types.Wildcard()), ""}, + {"*a*", true, types.NewPattern(types.Wildcard(), a, types.Wildcard()), ""}, + {"**a**", true, types.NewPattern(types.Wildcard(), a, types.Wildcard()), ""}, + {"abra*ca", true, types.NewPattern(types.String("abra"), types.Wildcard(), types.String("ca")), ""}, + {"abra**ca", true, types.NewPattern(types.String("abra"), types.Wildcard(), types.String("ca")), ""}, + {"*abra*ca", true, types.NewPattern(types.Wildcard(), types.String("abra"), types.Wildcard(), types.String("ca")), ""}, + {"abra*ca*", true, types.NewPattern(types.String("abra"), types.Wildcard(), types.String("ca"), types.Wildcard()), ""}, + {"*abra*ca*", true, types.NewPattern(types.Wildcard(), types.String("abra"), types.Wildcard(), types.String("ca"), types.Wildcard()), ""}, + {"*abra*ca*dabra", true, types.NewPattern(types.Wildcard(), types.String("abra"), types.Wildcard(), types.String("ca"), types.Wildcard(), types.String("dabra")), ""}, + {`*abra*c\**da\*bra`, true, types.NewPattern(types.Wildcard(), types.String("abra"), types.Wildcard(), types.String("c*"), types.Wildcard(), types.String("da*bra")), ""}, + {`\u`, false, types.Pattern{}, "bad unicode rune"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + got, err := ParsePattern(tt.input) + if err != nil { + testutil.Equals(t, tt.wantOk, false) + testutil.Equals(t, err.Error(), tt.wantErr) + } else { + testutil.Equals(t, tt.wantOk, true) + testutil.Equals(t, got, tt.want) + } + }) + } +} + +func TestMatch(t *testing.T) { + t.Parallel() + tests := []struct { + pattern string + target string + want bool + }{ + {`""`, "", true}, + {`""`, "hello", false}, + {`"*"`, "hello", true}, + {`"e"`, "hello", false}, + {`"*e"`, "hello", false}, + {`"*e*"`, "hello", true}, + {`"hello"`, "hello", true}, + {`"hello*"`, "hello", true}, + {`"*h*llo*"`, "hello", true}, + {`"h*e*o"`, "hello", true}, + {`"h*e**o"`, "hello", true}, + {`"h*z*o"`, "hello", false}, + + {`"\u{210d}*"`, "ℍello", true}, + {`"\u{210d}*"`, "Hello", false}, + + {`"\*\**\*\*"`, "**foo**", true}, + {`"\*\**\*\*"`, "**bar**", true}, + {`"\*\**\*\*"`, "*bar*", false}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.pattern+":"+tt.target, func(t *testing.T) { + t.Parallel() + pat, err := ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) + testutil.OK(t, err) + got := pat.Match(types.String(tt.target)) + testutil.Equals(t, got, tt.want) + }) + } +} diff --git a/types/pattern.go b/types/pattern.go index ad64bab5..ba5288e7 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -6,8 +6,6 @@ import ( "fmt" "strconv" "strings" - - "github.com/cedar-policy/cedar-go/internal/rust" ) var errJSONInvalidPatternComponent = fmt.Errorf("invalid pattern component") @@ -139,26 +137,6 @@ func matchChunk(chunk, s string) (rest string, ok bool) { return s, true } -// ParsePattern will parse an unquoted rust-style string with \*'s in it. -func ParsePattern(v string) (Pattern, error) { - b := []byte(v) - var comps []PatternComponent - for len(b) > 0 { - for len(b) > 0 && b[0] == '*' { - b = b[1:] - comps = append(comps, Wildcard()) - } - var err error - var literal string - literal, b, err = rust.Unquote(b, true) - if err != nil { - return Pattern{}, err - } - comps = append(comps, String(literal)) - } - return NewPattern(comps...), nil -} - func (p Pattern) MarshalJSON() ([]byte, error) { var buf bytes.Buffer buf.WriteRune('[') diff --git a/types/patttern_test.go b/types/patttern_test.go index 27837f1f..19381555 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -19,89 +19,6 @@ func TestPatternFromBuilder(t *testing.T) { }) } -func TestParsePattern(t *testing.T) { - t.Parallel() - a := String("a") - tests := []struct { - input string - wantOk bool - want Pattern - wantErr string - }{ - {"", true, NewPattern(), ""}, - {"a", true, NewPattern(a), ""}, - {"*", true, NewPattern(Wildcard()), ""}, - {"*a", true, NewPattern(Wildcard(), a), ""}, - {"a*", true, NewPattern(a, Wildcard()), ""}, - {"**", true, NewPattern(Wildcard()), ""}, - {"**a", true, NewPattern(Wildcard(), a), ""}, - {"a**", true, NewPattern(a, Wildcard()), ""}, - {"*a*", true, NewPattern(Wildcard(), a, Wildcard()), ""}, - {"**a**", true, NewPattern(Wildcard(), a, Wildcard()), ""}, - {"abra*ca", true, NewPattern(String("abra"), Wildcard(), String("ca")), ""}, - {"abra**ca", true, NewPattern(String("abra"), Wildcard(), String("ca")), ""}, - {"*abra*ca", true, NewPattern(Wildcard(), String("abra"), Wildcard(), String("ca")), ""}, - {"abra*ca*", true, NewPattern(String("abra"), Wildcard(), String("ca"), Wildcard()), ""}, - {"*abra*ca*", true, NewPattern(Wildcard(), String("abra"), Wildcard(), String("ca"), Wildcard()), ""}, - {"*abra*ca*dabra", true, NewPattern(Wildcard(), String("abra"), Wildcard(), String("ca"), Wildcard(), String("dabra")), ""}, - {`*abra*c\**da\*bra`, true, NewPattern(Wildcard(), String("abra"), Wildcard(), String("c*"), Wildcard(), String("da*bra")), ""}, - {`\u`, false, Pattern{}, "bad unicode rune"}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.input, func(t *testing.T) { - t.Parallel() - got, err := ParsePattern(tt.input) - if err != nil { - testutil.Equals(t, tt.wantOk, false) - testutil.Equals(t, err.Error(), tt.wantErr) - } else { - testutil.Equals(t, tt.wantOk, true) - testutil.Equals(t, got, tt.want) - } - }) - } -} - -func TestMatch(t *testing.T) { - t.Parallel() - tests := []struct { - pattern string - target string - want bool - }{ - {`""`, "", true}, - {`""`, "hello", false}, - {`"*"`, "hello", true}, - {`"e"`, "hello", false}, - {`"*e"`, "hello", false}, - {`"*e*"`, "hello", true}, - {`"hello"`, "hello", true}, - {`"hello*"`, "hello", true}, - {`"*h*llo*"`, "hello", true}, - {`"h*e*o"`, "hello", true}, - {`"h*e**o"`, "hello", true}, - {`"h*z*o"`, "hello", false}, - - {`"\u{210d}*"`, "ℍello", true}, - {`"\u{210d}*"`, "Hello", false}, - - {`"\*\**\*\*"`, "**foo**", true}, - {`"\*\**\*\*"`, "**bar**", true}, - {`"\*\**\*\*"`, "*bar*", false}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.pattern+":"+tt.target, func(t *testing.T) { - t.Parallel() - pat, err := ParsePattern(tt.pattern[1 : len(tt.pattern)-1]) - testutil.OK(t, err) - got := pat.Match(String(tt.target)) - testutil.Equals(t, got, tt.want) - }) - } -} - func TestJSON(t *testing.T) { t.Parallel() tests := []struct { From db6dc4f5b9c5c0d68cf520c176c3c0d882f2f916 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 15:29:31 -0600 Subject: [PATCH 191/216] types: move entity and entities into types Addresses IDX-142 Signed-off-by: philhassey --- authorize.go | 3 +- authorize_test.go | 133 +++++++++--------- corpus_test.go | 5 +- internal/eval/evalers.go | 5 +- internal/eval/evalers_test.go | 23 ++- {internal/entities => types}/entities.go | 11 +- {internal/entities => types}/entities_test.go | 16 +-- 7 files changed, 95 insertions(+), 101 deletions(-) rename {internal/entities => types}/entities.go (80%) rename {internal/entities => types}/entities_test.go (91%) diff --git a/authorize.go b/authorize.go index e23b0ac1..56b85f5d 100644 --- a/authorize.go +++ b/authorize.go @@ -3,7 +3,6 @@ package cedar import ( "fmt" - "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/eval" "github.com/cedar-policy/cedar-go/types" ) @@ -67,7 +66,7 @@ type Request struct { // IsAuthorized uses the combination of the PolicySet and Entities to determine // if the given Request to determine Decision and Diagnostic. -func (p PolicySet) IsAuthorized(entityMap entities.Entities, req Request) (Decision, Diagnostic) { +func (p PolicySet) IsAuthorized(entityMap types.Entities, req Request) (Decision, Diagnostic) { c := &eval.Context{ Entities: entityMap, Principal: req.Principal, diff --git a/authorize_test.go b/authorize_test.go index a3610e9b..bab9fe43 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -4,7 +4,6 @@ import ( "encoding/json" "testing" - "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -17,7 +16,7 @@ func TestIsAuthorized(t *testing.T) { tests := []struct { Name string Policy string - Entities entities.Entities + Entities types.Entities Principal, Action, Resource types.EntityUID Context types.Record Want Decision @@ -27,7 +26,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "simple-permit", Policy: `permit(principal,action,resource);`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -38,7 +37,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "simple-forbid", Policy: `forbid(principal,action,resource);`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -49,7 +48,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "no-permit", Policy: `permit(principal,action,resource in asdf::"1234");`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -60,7 +59,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "error-in-policy", Policy: `permit(principal,action,resource) when { resource in "foo" };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -73,7 +72,7 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { resource in "foo" }; permit(principal,action,resource); `, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -84,7 +83,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-context-success", Policy: `permit(principal,action,resource) when { context.x == 42 };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -95,7 +94,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-context-fail", Policy: `permit(principal,action,resource) when { context.x == 42 };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -106,8 +105,8 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-entities-success", Policy: `permit(principal,action,resource) when { principal.x == 42 };`, - Entities: entities.Entities{ - cuzco: entities.Entity{ + Entities: types.Entities{ + cuzco: types.Entity{ UID: cuzco, Attributes: types.Record{"x": types.Long(42)}, }, @@ -122,8 +121,8 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-entities-fail", Policy: `permit(principal,action,resource) when { principal.x == 42 };`, - Entities: entities.Entities{ - cuzco: entities.Entity{ + Entities: types.Entities{ + cuzco: types.Entity{ UID: cuzco, Attributes: types.Record{"x": types.Long(43)}, }, @@ -138,8 +137,8 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-entities-parent-success", Policy: `permit(principal,action,resource) when { principal in parent::"bob" };`, - Entities: entities.Entities{ - cuzco: entities.Entity{ + Entities: types.Entities{ + cuzco: types.Entity{ UID: cuzco, Parents: []types.EntityUID{types.NewEntityUID("parent", "bob")}, }, @@ -154,7 +153,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-principal-equals", Policy: `permit(principal == coder::"cuzco",action,resource);`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -165,8 +164,8 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-principal-in", Policy: `permit(principal in team::"osiris",action,resource);`, - Entities: entities.Entities{ - cuzco: entities.Entity{ + Entities: types.Entities{ + cuzco: types.Entity{ UID: cuzco, Parents: []types.EntityUID{types.NewEntityUID("team", "osiris")}, }, @@ -181,7 +180,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-action-equals", Policy: `permit(principal,action == table::"drop",resource);`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -192,8 +191,8 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-action-in", Policy: `permit(principal,action in scary::"stuff",resource);`, - Entities: entities.Entities{ - dropTable: entities.Entity{ + Entities: types.Entities{ + dropTable: types.Entity{ UID: dropTable, Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, }, @@ -208,8 +207,8 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-action-in-set", Policy: `permit(principal,action in [scary::"stuff"],resource);`, - Entities: entities.Entities{ - dropTable: entities.Entity{ + Entities: types.Entities{ + dropTable: types.Entity{ UID: dropTable, Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, }, @@ -224,7 +223,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-resource-equals", Policy: `permit(principal,action,resource == table::"whatever");`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -235,7 +234,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-unless", Policy: `permit(principal,action,resource) unless { false };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -246,7 +245,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-if", Policy: `permit(principal,action,resource) when { (if true then true else true) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -257,7 +256,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-or", Policy: `permit(principal,action,resource) when { (true || false) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -268,7 +267,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-and", Policy: `permit(principal,action,resource) when { (true && true) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -279,7 +278,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-relations", Policy: `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -290,7 +289,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-relations-in", Policy: `permit(principal,action,resource) when { principal in principal };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -301,8 +300,8 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-relations-has", Policy: `permit(principal,action,resource) when { principal has name };`, - Entities: entities.Entities{ - cuzco: entities.Entity{ + Entities: types.Entities{ + cuzco: types.Entity{ UID: cuzco, Attributes: types.Record{"name": types.String("bob")}, }, @@ -317,7 +316,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-add-sub", Policy: `permit(principal,action,resource) when { 40+3-1==42 };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -328,7 +327,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-mul", Policy: `permit(principal,action,resource) when { 6*7==42 };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -339,7 +338,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-negate", Policy: `permit(principal,action,resource) when { -42==-42 };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -350,7 +349,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-not", Policy: `permit(principal,action,resource) when { !(1+1==42) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -361,7 +360,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -372,7 +371,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-record", Policy: `permit(principal,action,resource) when { {name:"bob"} has name };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -383,7 +382,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-action", Policy: `permit(principal,action,resource) when { action in action };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -394,7 +393,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-contains-ok", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -405,7 +404,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-contains-error", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -417,7 +416,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAll-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -428,7 +427,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAll-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -440,7 +439,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAny-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -451,7 +450,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAny-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -463,7 +462,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-record-attr", Policy: `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -474,7 +473,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-unknown-method", Policy: `permit(principal,action,resource) when { [1,2,3].shuffle() };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -486,7 +485,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-like", Policy: `permit(principal,action,resource) when { "bananas" like "*nan*" };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -497,7 +496,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-unknown-ext-fun", Policy: `permit(principal,action,resource) when { fooBar("10") };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -513,7 +512,7 @@ func TestIsAuthorized(t *testing.T) { decimal("10.0").lessThanOrEqual(decimal("11.0")) && decimal("10.0").greaterThan(decimal("9.0")) && decimal("10.0").greaterThanOrEqual(decimal("9.0")) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -524,7 +523,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-decimal-fun-wrong-arity", Policy: `permit(principal,action,resource) when { decimal(1, 2) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -540,7 +539,7 @@ func TestIsAuthorized(t *testing.T) { ip("::1").isLoopback() && ip("224.1.2.3").isMulticast() && ip("127.0.0.1").isInRange(ip("127.0.0.0/16"))};`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -551,7 +550,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-ip-fun-wrong-arity", Policy: `permit(principal,action,resource) when { ip() };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -562,7 +561,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isIpv4-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -573,7 +572,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isIpv6-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -584,7 +583,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isLoopback-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -595,7 +594,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isMulticast-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -606,7 +605,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isInRange-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: cuzco, Action: dropTable, Resource: types.NewEntityUID("table", "whatever"), @@ -617,7 +616,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "negative-unary-op", Policy: `permit(principal,action,resource) when { -context.value > 0 };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Context: types.Record{"value": types.Long(-42)}, Want: true, DiagErr: 0, @@ -625,7 +624,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "principal-is", Policy: `permit(principal is Actor,action,resource);`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -636,7 +635,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "principal-is-in", Policy: `permit(principal is Actor in Actor::"cuzco",action,resource);`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -647,7 +646,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "resource-is", Policy: `permit(principal,action,resource is Resource);`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -658,7 +657,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "resource-is-in", Policy: `permit(principal,action,resource is Resource in Resource::"table");`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -669,7 +668,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "when-is", Policy: `permit(principal,action,resource) when { resource is Resource };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -680,7 +679,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Resource::"table" };`, - Entities: entities.Entities{}, + Entities: types.Entities{}, Principal: types.NewEntityUID("Actor", "cuzco"), Action: types.NewEntityUID("Action", "drop"), Resource: types.NewEntityUID("Resource", "table"), @@ -691,8 +690,8 @@ func TestIsAuthorized(t *testing.T) { { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Parent::"id" };`, - Entities: entities.Entities{ - types.NewEntityUID("Resource", "table"): entities.Entity{ + Entities: types.Entities{ + types.NewEntityUID("Resource", "table"): types.Entity{ UID: types.NewEntityUID("Resource", "table"), Parents: []types.EntityUID{types.NewEntityUID("Parent", "id")}, }, diff --git a/corpus_test.go b/corpus_test.go index 5e8d72e2..e3ae4dcd 100644 --- a/corpus_test.go +++ b/corpus_test.go @@ -12,7 +12,6 @@ import ( "strings" "testing" - entities2 "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -140,7 +139,7 @@ func TestCorpus(t *testing.T) { t.Fatal("error reading entities content", err) } - var entities entities2.Entities + var entities types.Entities if err := json.Unmarshal(entitiesContent, &entities); err != nil { t.Fatal("error unmarshalling test", err) } @@ -338,7 +337,7 @@ func TestCorpusRelated(t *testing.T) { t.Parallel() policy, err := NewPolicySetFromBytes("", []byte(tt.policy)) testutil.OK(t, err) - ok, diag := policy.IsAuthorized(entities2.Entities{}, tt.request) + ok, diag := policy.IsAuthorized(types.Entities{}, tt.request) testutil.Equals(t, ok, tt.decision) var reasons []PolicyID for _, n := range diag.Reasons { diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index afd9dc33..fa5f6a3f 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -3,7 +3,6 @@ package eval import ( "fmt" - "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/types" ) @@ -19,7 +18,7 @@ func zeroValue() types.Value { } type Context struct { - Entities entities.Entities + Entities types.Entities Principal, Action, Resource types.Value Context types.Value } @@ -913,7 +912,7 @@ func newInEval(lhs, rhs Evaler) *inEval { return &inEval{lhs: lhs, rhs: rhs} } -func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entityMap entities.Entities) bool { +func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entityMap types.Entities) bool { checked := map[types.EntityUID]struct{}{} toCheck := []types.EntityUID{entity} for len(toCheck) > 0 { diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 3baa3063..31cccc67 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - "github.com/cedar-policy/cedar-go/internal/entities" "github.com/cedar-policy/cedar-go/internal/parser" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" @@ -1309,12 +1308,12 @@ func TestAttributeAccessNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newAttributeAccessEval(tt.object, tt.attribute) - entity := entities.Entity{ + entity := types.Entity{ UID: types.NewEntityUID("knownType", "knownID"), Attributes: types.Record{"knownAttr": types.Long(42)}, } v, err := n.Eval(&Context{ - Entities: entities.Entities{ + Entities: types.Entities{ entity.UID: entity, }, }) @@ -1366,12 +1365,12 @@ func TestHasNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newHasEval(tt.record, tt.attribute) - entity := entities.Entity{ + entity := types.Entity{ UID: types.NewEntityUID("knownType", "knownID"), Attributes: types.Record{"knownAttr": types.Long(42)}, } v, err := n.Eval(&Context{ - Entities: entities.Entities{ + Entities: types.Entities{ entity.UID: entity, }, }) @@ -1563,14 +1562,14 @@ func TestEntityIn(t *testing.T) { for _, v := range tt.rhs { rhs[strEnt(v)] = struct{}{} } - entityMap := entities.Entities{} + entityMap := types.Entities{} for k, p := range tt.parents { var ps []types.EntityUID for _, pp := range p { ps = append(ps, strEnt(pp)) } uid := strEnt(k) - entityMap[uid] = entities.Entity{ + entityMap[uid] = types.Entity{ UID: uid, Parents: ps, } @@ -1584,19 +1583,19 @@ func TestEntityIn(t *testing.T) { // This test will run for a very long time (O(2^100)) if there isn't caching. ) - entityMap := entities.Entities{} + entityMap := types.Entities{} for i := 0; i < 100; i++ { p := []types.EntityUID{ types.NewEntityUID(types.Path(fmt.Sprint(i+1)), "1"), types.NewEntityUID(types.Path(fmt.Sprint(i+1)), "2"), } uid1 := types.NewEntityUID(types.Path(fmt.Sprint(i)), "1") - entityMap[uid1] = entities.Entity{ + entityMap[uid1] = types.Entity{ UID: uid1, Parents: p, } uid2 := types.NewEntityUID(types.Path(fmt.Sprint(i)), "2") - entityMap[uid2] = entities.Entity{ + entityMap[uid2] = types.Entity{ UID: uid2, Parents: p, } @@ -1729,14 +1728,14 @@ func TestInNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newInEval(tt.lhs, tt.rhs) - entityMap := entities.Entities{} + entityMap := types.Entities{} for k, p := range tt.parents { var ps []types.EntityUID for _, pp := range p { ps = append(ps, strEnt(pp)) } uid := strEnt(k) - entityMap[uid] = entities.Entity{ + entityMap[uid] = types.Entity{ UID: uid, Parents: ps, } diff --git a/internal/entities/entities.go b/types/entities.go similarity index 80% rename from internal/entities/entities.go rename to types/entities.go index 45d7757c..55bd5e14 100644 --- a/internal/entities/entities.go +++ b/types/entities.go @@ -1,24 +1,23 @@ -package entities +package types import ( "encoding/json" "slices" "strings" - "github.com/cedar-policy/cedar-go/types" "golang.org/x/exp/maps" ) // An Entities is a collection of all the Entities that are needed to evaluate // authorization requests. The key is an EntityUID which uniquely identifies // the Entity (it must be the same as the UID within the Entity itself.) -type Entities map[types.EntityUID]Entity +type Entities map[EntityUID]Entity // An Entity defines the parents and attributes for an EntityUID. type Entity struct { - UID types.EntityUID `json:"uid"` - Parents []types.EntityUID `json:"parents,omitempty"` - Attributes types.Record `json:"attrs"` + UID EntityUID `json:"uid"` + Parents []EntityUID `json:"parents,omitempty"` + Attributes Record `json:"attrs"` } func (e Entities) MarshalJSON() ([]byte, error) { diff --git a/internal/entities/entities_test.go b/types/entities_test.go similarity index 91% rename from internal/entities/entities_test.go rename to types/entities_test.go index 1ccaceaa..ed27e64d 100644 --- a/internal/entities/entities_test.go +++ b/types/entities_test.go @@ -1,4 +1,4 @@ -package entities +package types_test import ( "encoding/json" @@ -12,7 +12,7 @@ func TestEntities(t *testing.T) { t.Parallel() t.Run("Clone", func(t *testing.T) { t.Parallel() - e := Entities{ + e := types.Entities{ types.EntityUID{Type: "A", ID: "A"}: {}, types.EntityUID{Type: "A", ID: "B"}: {}, types.EntityUID{Type: "B", ID: "A"}: {}, @@ -28,8 +28,8 @@ func TestEntitiesJSON(t *testing.T) { t.Parallel() t.Run("Marshal", func(t *testing.T) { t.Parallel() - e := Entities{} - ent := Entity{ + e := types.Entities{} + ent := types.Entity{ UID: types.NewEntityUID("Type", "id"), Parents: []types.EntityUID{}, Attributes: types.Record{"key": types.Long(42)}, @@ -43,11 +43,11 @@ func TestEntitiesJSON(t *testing.T) { t.Run("Unmarshal", func(t *testing.T) { t.Parallel() b := []byte(`[{"uid":{"type":"Type","id":"id"},"parents":[],"attrs":{"key":42}}]`) - var e Entities + var e types.Entities err := json.Unmarshal(b, &e) testutil.OK(t, err) - want := Entities{} - ent := Entity{ + want := types.Entities{} + ent := types.Entity{ UID: types.NewEntityUID("Type", "id"), Parents: []types.EntityUID{}, Attributes: types.Record{"key": types.Long(42)}, @@ -58,7 +58,7 @@ func TestEntitiesJSON(t *testing.T) { t.Run("UnmarshalErr", func(t *testing.T) { t.Parallel() - var e Entities + var e types.Entities err := e.UnmarshalJSON([]byte(`!@#$`)) testutil.Error(t, err) }) From 0eea30ad95ce182d587b4c3480426e9741f03f5f Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 15:32:37 -0600 Subject: [PATCH 192/216] internal/parser: remove test related TODO Addresses IDX-142 Signed-off-by: philhassey --- internal/parser/cedar_parse_test.go | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/internal/parser/cedar_parse_test.go b/internal/parser/cedar_parse_test.go index d3a2581f..28593964 100644 --- a/internal/parser/cedar_parse_test.go +++ b/internal/parser/cedar_parse_test.go @@ -302,26 +302,21 @@ func TestParse(t *testing.T) { return } testutil.OK(t, err) - if len(policies) != 1 { - // TODO: handle 0, > 1 - return - } // N.B. Until we support the re-rendering of comments, we have to ignore the position for the purposes of // these tests (see test "ex1") - policies[0].Position = ast.Position{Offset: 0, Line: 1, Column: 1} + for _, pp := range policies { + pp.Position = ast.Position{Offset: 0, Line: 1, Column: 1} - var buf bytes.Buffer - pp := policies[0] - pp.MarshalCedar(&buf) + var buf bytes.Buffer + pp.MarshalCedar(&buf) - var p2 parser.PolicySlice - err = p2.UnmarshalCedar(buf.Bytes()) - testutil.OK(t, err) - - // TODO: support 0, > 1 - testutil.Equals(t, p2[0], policies[0]) + var p2 parser.PolicySlice + err = p2.UnmarshalCedar(buf.Bytes()) + testutil.OK(t, err) + testutil.Equals(t, p2[0], pp) + } }) } } From 922766fad6221d9ad51b34f76e68cd0f1d5af7a4 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 15:50:50 -0600 Subject: [PATCH 193/216] cedar: tweak external shape of main cedar package Addresses IDX-142 Signed-off-by: philhassey --- policy.go | 11 ++-- policy_set.go | 43 +++++++++------- policy_set_test.go | 116 +++++++++++++++++++++---------------------- policy_slice.go | 18 +++---- policy_slice_test.go | 2 +- 5 files changed, 99 insertions(+), 91 deletions(-) diff --git a/policy.go b/policy.go index 4edddad0..ceab5535 100644 --- a/policy.go +++ b/policy.go @@ -8,6 +8,7 @@ import ( "github.com/cedar-policy/cedar-go/internal/eval" "github.com/cedar-policy/cedar-go/internal/json" "github.com/cedar-policy/cedar-go/internal/parser" + "github.com/cedar-policy/cedar-go/types" ) // A Policy is the parsed form of a single Cedar language policy statement. @@ -60,19 +61,19 @@ func (p *Policy) UnmarshalCedar(b []byte) error { return nil } -func NewPolicyFromAST(astIn *ast.Policy) *Policy { +func NewPolicyFromAST(astIn *ast.Policy) Policy { p := newPolicy((*internalast.Policy)(astIn)) - return &p + return p } // An Annotations is a map of key, value pairs found in the policy. Annotations // have no impact on policy evaluation. -type Annotations map[string]string +type Annotations map[types.Ident]types.String func (p Policy) Annotations() Annotations { - res := make(map[string]string, len(p.ast.Annotations)) + res := make(Annotations, len(p.ast.Annotations)) for _, e := range p.ast.Annotations { - res[string(e.Key)] = string(e.Value) + res[e.Key] = e.Value } return res } diff --git a/policy_set.go b/policy_set.go index 86deae23..9359dd0c 100644 --- a/policy_set.go +++ b/policy_set.go @@ -9,14 +9,16 @@ import ( type PolicyID string +type policyMap map[PolicyID]Policy + // PolicySet is a set of named policies against which a request can be authorized. type PolicySet struct { - policies map[PolicyID]*Policy + policies policyMap } // NewPolicySet creates a new, empty PolicySet func NewPolicySet() PolicySet { - return PolicySet{policies: map[PolicyID]*Policy{}} + return PolicySet{policies: policyMap{}} } // NewPolicySetFromBytes will create a PolicySet from the given text document with the given file name used in Position @@ -25,11 +27,11 @@ func NewPolicySet() PolicySet { // NewPolicySetFromBytes assigns default PolicyIDs to the policies contained in fileName in the format "policy" where // is incremented for each new policy found in the file. func NewPolicySetFromBytes(fileName string, document []byte) (PolicySet, error) { - policySlice, err := NewPolicySliceFromBytes(fileName, document) + policySlice, err := NewPoliciesFromBytes(fileName, document) if err != nil { return PolicySet{}, err } - policyMap := make(map[PolicyID]*Policy, len(policySlice)) + policyMap := make(policyMap, len(policySlice)) for i, p := range policySlice { policyID := PolicyID(fmt.Sprintf("policy%d", i)) policyMap[policyID] = p @@ -37,29 +39,34 @@ func NewPolicySetFromBytes(fileName string, document []byte) (PolicySet, error) return PolicySet{policies: policyMap}, nil } -// GetPolicy returns a pointer to the Policy with the given ID. If a policy with the given ID does not exist, nil is -// returned. -func (p PolicySet) GetPolicy(policyID PolicyID) *Policy { +// Get returns a pointer to the Policy with the given ID. If a policy with the given ID does not exist, an empty policy is returned. +func (p PolicySet) Get(policyID PolicyID) Policy { return p.policies[policyID] } -// UpsertPolicy inserts or updates a policy with the given ID. -func (p *PolicySet) UpsertPolicy(policyID PolicyID, policy *Policy) { +// Has indicates if the policy exists. +func (p PolicySet) Has(policyID PolicyID) bool { + _, ok := p.policies[policyID] + return ok +} + +// Upsert inserts or updates a policy with the given ID. +func (p *PolicySet) Upsert(policyID PolicyID, policy Policy) { p.policies[policyID] = policy } -// DeletePolicy removes a policy from the PolicySet. Deleting a non-existent policy is a no-op. -func (p *PolicySet) DeletePolicy(policyID PolicyID) { +// Delete removes a policy from the PolicySet. Deleting a non-existent policy is a no-op. +func (p *PolicySet) Delete(policyID PolicyID) { delete(p.policies, policyID) } -// UpsertPolicySet inserts or updates all the policies from src into this PolicySet. Policies in this PolicySet with -// identical IDs in src are clobbered by the policies from src. -func (p *PolicySet) UpsertPolicySet(src PolicySet) { - for id, policy := range src.policies { - p.policies[id] = policy - } -} +// // UpsertPolicySet inserts or updates all the policies from src into this PolicySet. Policies in this PolicySet with +// // identical IDs in src are clobbered by the policies from src. +// func (p *PolicySet) UpsertPolicySet(src PolicySet) { +// for id, policy := range src.policies { +// p.policies[id] = policy +// } +// } // MarshalCedar emits a concatenated Cedar representation of a PolicySet. The policy names are stripped, but policies // are emitted in lexicographical order by ID. diff --git a/policy_set_test.go b/policy_set_test.go index 5817dcfd..85413076 100644 --- a/policy_set_test.go +++ b/policy_set_test.go @@ -25,7 +25,7 @@ func TestNewPolicySetFromFile(t *testing.T) { t.Parallel() ps, err := cedar.NewPolicySetFromBytes("policy.cedar", []byte(`@key("value") permit (principal, action, resource);`)) testutil.OK(t, err) - testutil.Equals(t, ps.GetPolicy("policy0").Annotations(), cedar.Annotations{"key": "value"}) + testutil.Equals(t, ps.Get("policy0").Annotations(), cedar.Annotations{"key": "value"}) }) } @@ -42,12 +42,12 @@ func TestUpsertPolicy(t *testing.T) { )) ps := cedar.NewPolicySet() - ps.UpsertPolicy("policy0", policy0) - ps.UpsertPolicy("policy1", &policy1) + ps.Upsert("policy0", policy0) + ps.Upsert("policy1", policy1) - testutil.Equals(t, ps.GetPolicy("policy0"), policy0) - testutil.Equals(t, ps.GetPolicy("policy1"), &policy1) - testutil.Equals(t, ps.GetPolicy("policy2"), nil) + testutil.Equals(t, ps.Get("policy0"), policy0) + testutil.Equals(t, ps.Get("policy1"), policy1) + testutil.Equals(t, ps.Get("policy2"), cedar.Policy{}) }) t.Run("upsert", func(t *testing.T) { t.Parallel() @@ -55,68 +55,68 @@ func TestUpsertPolicy(t *testing.T) { ps := cedar.NewPolicySet() p1 := cedar.NewPolicyFromAST(ast.Forbid()) - ps.UpsertPolicy("a wavering policy", p1) + ps.Upsert("a wavering policy", p1) p2 := cedar.NewPolicyFromAST(ast.Permit()) - ps.UpsertPolicy("a wavering policy", p2) + ps.Upsert("a wavering policy", p2) - testutil.Equals(t, ps.GetPolicy("a wavering policy"), p2) + testutil.Equals(t, ps.Get("a wavering policy"), p2) }) } -func TestUpsertPolicySet(t *testing.T) { - t.Parallel() - t.Run("empty dst", func(t *testing.T) { - t.Parallel() +// func TestUpsertPolicySet(t *testing.T) { +// t.Parallel() +// t.Run("empty dst", func(t *testing.T) { +// t.Parallel() - policy0 := cedar.NewPolicyFromAST(ast.Forbid()) +// policy0 := cedar.NewPolicyFromAST(ast.Forbid()) - var policy1 cedar.Policy - testutil.OK(t, policy1.UnmarshalJSON( - []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), - )) +// var policy1 cedar.Policy +// testutil.OK(t, policy1.UnmarshalJSON( +// []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), +// )) - ps1 := cedar.NewPolicySet() - ps1.UpsertPolicy("policy0", policy0) - ps1.UpsertPolicy("policy1", &policy1) +// ps1 := cedar.NewPolicySet() +// ps1.Upsert("policy0", policy0) +// ps1.Upsert("policy1", policy1) - ps2 := cedar.NewPolicySet() - ps2.UpsertPolicySet(ps1) +// ps2 := cedar.NewPolicySet() +// ps2.UpsertPolicySet(ps1) - testutil.Equals(t, ps2.GetPolicy("policy0"), policy0) - testutil.Equals(t, ps2.GetPolicy("policy1"), &policy1) - testutil.Equals(t, ps2.GetPolicy("policy2"), nil) - }) - t.Run("upsert", func(t *testing.T) { - t.Parallel() +// testutil.Equals(t, ps2.Get("policy0"), policy0) +// testutil.Equals(t, ps2.Get("policy1"), &policy1) +// testutil.Equals(t, ps2.Get("policy2"), nil) +// }) +// t.Run("upsert", func(t *testing.T) { +// t.Parallel() - policyA := cedar.NewPolicyFromAST(ast.Forbid()) +// policyA := cedar.NewPolicyFromAST(ast.Forbid()) - var policyB cedar.Policy - testutil.OK(t, policyB.UnmarshalJSON( - []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), - )) +// var policyB cedar.Policy +// testutil.OK(t, policyB.UnmarshalJSON( +// []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), +// )) - policyC := cedar.NewPolicyFromAST(ast.Permit()) +// policyC := cedar.NewPolicyFromAST(ast.Permit()) - // ps1 maps 0 -> A and 1 -> B - ps1 := cedar.NewPolicySet() - ps1.UpsertPolicy("policy0", policyA) - ps1.UpsertPolicy("policy1", &policyB) +// // ps1 maps 0 -> A and 1 -> B +// ps1 := cedar.NewPolicySet() +// ps1.Upsert("policy0", policyA) +// ps1.Upsert("policy1", &policyB) - // ps1 maps 0 -> b and 2 -> C - ps2 := cedar.NewPolicySet() - ps2.UpsertPolicy("policy0", &policyB) - ps2.UpsertPolicy("policy2", policyC) +// // ps1 maps 0 -> b and 2 -> C +// ps2 := cedar.NewPolicySet() +// ps2.Upsert("policy0", &policyB) +// ps2.Upsert("policy2", policyC) - // Upsert should clobber ps2's policy0, insert policy1, and leave policy2 untouched - ps2.UpsertPolicySet(ps1) +// // Upsert should clobber ps2's policy0, insert policy1, and leave policy2 untouched +// ps2.UpsertPolicySet(ps1) - testutil.Equals(t, ps2.GetPolicy("policy0"), policyA) - testutil.Equals(t, ps2.GetPolicy("policy1"), &policyB) - testutil.Equals(t, ps2.GetPolicy("policy2"), policyC) - }) -} +// testutil.Equals(t, ps2.Get("policy0"), policyA) +// testutil.Equals(t, ps2.Get("policy1"), &policyB) +// testutil.Equals(t, ps2.Get("policy2"), policyC) +// }) +// } func TestDeletePolicy(t *testing.T) { t.Parallel() @@ -126,7 +126,7 @@ func TestDeletePolicy(t *testing.T) { ps := cedar.NewPolicySet() // Just verify that this doesn't crash - ps.DeletePolicy("not a policy") + ps.Delete("not a policy") }) t.Run("delete existing", func(t *testing.T) { t.Parallel() @@ -134,10 +134,10 @@ func TestDeletePolicy(t *testing.T) { ps := cedar.NewPolicySet() p1 := cedar.NewPolicyFromAST(ast.Forbid()) - ps.UpsertPolicy("a policy", p1) - ps.DeletePolicy("a policy") + ps.Upsert("a policy", p1) + ps.Delete("a policy") - testutil.Equals(t, ps.GetPolicy("a policy"), nil) + testutil.Equals(t, ps.Get("a policy"), cedar.Policy{}) }) } @@ -157,17 +157,17 @@ forbid ( resource );` - var policies cedar.PolicySlice + var policies cedar.Policies testutil.OK(t, policies.UnmarshalCedar([]byte(policiesStr))) ps := cedar.NewPolicySet() for i, p := range policies { p.SetFilename("example.cedar") - ps.UpsertPolicy(cedar.PolicyID(fmt.Sprintf("policy%d", i)), p) + ps.Upsert(cedar.PolicyID(fmt.Sprintf("policy%d", i)), p) } - testutil.Equals(t, ps.GetPolicy("policy0").Effect(), cedar.Permit) - testutil.Equals(t, ps.GetPolicy("policy1").Effect(), cedar.Forbid) + testutil.Equals(t, ps.Get("policy0").Effect(), cedar.Permit) + testutil.Equals(t, ps.Get("policy1").Effect(), cedar.Forbid) testutil.Equals(t, string(ps.MarshalCedar()), policiesStr) } diff --git a/policy_slice.go b/policy_slice.go index 5df57350..d2bee17d 100644 --- a/policy_slice.go +++ b/policy_slice.go @@ -8,14 +8,14 @@ import ( "github.com/cedar-policy/cedar-go/internal/parser" ) -// PolicySlice represents a set of un-named Policy's. Cedar documents, unlike the JSON format, don't have a means of +// Policies represents a set of un-named Policy's. Cedar documents, unlike the JSON format, don't have a means of // naming individual policies. -type PolicySlice []*Policy +type Policies []Policy -// NewPolicySliceFromBytes will create a PolicySet from the given text document with the given file name used in Position +// NewPoliciesFromBytes will create a PolicySet from the given text document with the given file name used in Position // data. If there is an error parsing the document, it will be returned. -func NewPolicySliceFromBytes(fileName string, document []byte) (PolicySlice, error) { - var policySlice PolicySlice +func NewPoliciesFromBytes(fileName string, document []byte) (Policies, error) { + var policySlice Policies if err := policySlice.UnmarshalCedar(document); err != nil { return nil, err } @@ -27,22 +27,22 @@ func NewPolicySliceFromBytes(fileName string, document []byte) (PolicySlice, err // UnmarshalCedar parses a concatenation of un-named Cedar policy statements. Names can be assigned to these policies // when adding them to a PolicySet. -func (p *PolicySlice) UnmarshalCedar(b []byte) error { +func (p *Policies) UnmarshalCedar(b []byte) error { var res parser.PolicySlice if err := res.UnmarshalCedar(b); err != nil { return fmt.Errorf("parser error: %w", err) } - policySlice := make([]*Policy, 0, len(res)) + policySlice := make([]Policy, 0, len(res)) for _, p := range res { newPolicy := newPolicy((*internalast.Policy)(p)) - policySlice = append(policySlice, &newPolicy) + policySlice = append(policySlice, newPolicy) } *p = policySlice return nil } // MarshalCedar emits a concatenated Cedar representation of a PolicySlice -func (p PolicySlice) MarshalCedar() []byte { +func (p Policies) MarshalCedar() []byte { var buf bytes.Buffer for i, policy := range p { buf.Write(policy.MarshalCedar()) diff --git a/policy_slice_test.go b/policy_slice_test.go index fc728085..152e47a2 100644 --- a/policy_slice_test.go +++ b/policy_slice_test.go @@ -23,7 +23,7 @@ forbid ( resource );` - var policies cedar.PolicySlice + var policies cedar.Policies testutil.OK(t, policies.UnmarshalCedar([]byte(policiesStr))) testutil.Equals(t, string(policies.MarshalCedar()), policiesStr) From 41d0f8a8e5d5310c7349d5410fe376f8de230f05 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 16:02:01 -0600 Subject: [PATCH 194/216] types: remove useless util function Addresses IDX-142 Signed-off-by: philhassey --- types/boolean_test.go | 2 +- types/entity_uid_test.go | 4 ++-- types/long_test.go | 2 +- types/record_test.go | 10 +++++----- types/set_test.go | 6 +++--- types/string_test.go | 4 ++-- types/testutil_test.go | 13 ------------- 7 files changed, 14 insertions(+), 27 deletions(-) delete mode 100644 types/testutil_test.go diff --git a/types/boolean_test.go b/types/boolean_test.go index a95dbecb..745cf5ca 100644 --- a/types/boolean_test.go +++ b/types/boolean_test.go @@ -25,7 +25,7 @@ func TestBool(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, types.Boolean(true), "true") + testutil.Equals(t, types.Boolean(true).String(), "true") }) } diff --git a/types/entity_uid_test.go b/types/entity_uid_test.go index 60a26a05..e09ed5cb 100644 --- a/types/entity_uid_test.go +++ b/types/entity_uid_test.go @@ -22,8 +22,8 @@ func TestEntity(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, types.EntityUID{Type: "type", ID: "id"}, `type::"id"`) - assertValueString(t, types.EntityUID{Type: "namespace::type", ID: "id"}, `namespace::type::"id"`) + testutil.Equals(t, types.EntityUID{Type: "type", ID: "id"}.String(), `type::"id"`) + testutil.Equals(t, types.EntityUID{Type: "namespace::type", ID: "id"}.String(), `namespace::type::"id"`) }) } diff --git a/types/long_test.go b/types/long_test.go index ec00570c..33c98b84 100644 --- a/types/long_test.go +++ b/types/long_test.go @@ -25,7 +25,7 @@ func TestLong(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, types.Long(1), "1") + testutil.Equals(t, types.Long(1).String(), "1") }) } diff --git a/types/record_test.go b/types/record_test.go index a7ce1fe1..3673598f 100644 --- a/types/record_test.go +++ b/types/record_test.go @@ -57,17 +57,17 @@ func TestRecord(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, types.Record{}, "{}") - assertValueString( + testutil.Equals(t, types.Record{}.String(), "{}") + testutil.Equals( t, - types.Record{"foo": types.Boolean(true)}, + types.Record{"foo": types.Boolean(true)}.String(), `{"foo":true}`) - assertValueString( + testutil.Equals( t, types.Record{ "foo": types.Boolean(true), "bar": types.String("blah"), - }, + }.String(), `{"bar":"blah", "foo":true}`) }) diff --git a/types/set_test.go b/types/set_test.go index 770def8d..33b8369e 100644 --- a/types/set_test.go +++ b/types/set_test.go @@ -45,10 +45,10 @@ func TestSet(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, types.Set{}, "[]") - assertValueString( + testutil.Equals(t, types.Set{}.String(), "[]") + testutil.Equals( t, - types.Set{types.Boolean(true), types.Long(1)}, + types.Set{types.Boolean(true), types.Long(1)}.String(), "[true, 1]") }) diff --git a/types/string_test.go b/types/string_test.go index 882f96d2..b6103d6d 100644 --- a/types/string_test.go +++ b/types/string_test.go @@ -22,8 +22,8 @@ func TestString(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() - assertValueString(t, types.String("hello"), `hello`) - assertValueString(t, types.String("hello\ngoodbye"), "hello\ngoodbye") + testutil.Equals(t, types.String("hello").String(), `hello`) + testutil.Equals(t, types.String("hello\ngoodbye").String(), "hello\ngoodbye") }) } diff --git a/types/testutil_test.go b/types/testutil_test.go deleted file mode 100644 index f91b13d5..00000000 --- a/types/testutil_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package types_test - -import ( - "testing" - - "github.com/cedar-policy/cedar-go/internal/testutil" - "github.com/cedar-policy/cedar-go/types" -) - -func assertValueString(t *testing.T, v types.Value, want string) { - t.Helper() - testutil.Equals(t, v.String(), want) -} From b97339fd817ddb6f3004d32780a56ff26c91b191 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 16:03:21 -0600 Subject: [PATCH 195/216] types: remove dead code Addresses IDX-142 Signed-off-by: philhassey --- types/boolean.go | 1 - types/json_test.go | 1 - 2 files changed, 2 deletions(-) diff --git a/types/boolean.go b/types/boolean.go index 732742a5..0526e971 100644 --- a/types/boolean.go +++ b/types/boolean.go @@ -16,7 +16,6 @@ func (a Boolean) Equal(bi Value) bool { b, ok := bi.(Boolean) return ok && a == b } -func (v Boolean) TypeName() string { return "bool" } // String produces a string representation of the Boolean, e.g. `true`. func (v Boolean) String() string { return string(v.MarshalCedar()) } diff --git a/types/json_test.go b/types/json_test.go index c649ad37..cb8efc38 100644 --- a/types/json_test.go +++ b/types/json_test.go @@ -316,7 +316,6 @@ func (j *jsonErr) String() string { return "" } func (j *jsonErr) MarshalCedar() []byte { return nil } func (j *jsonErr) Equal(Value) bool { return false } func (j *jsonErr) ExplicitMarshalJSON() ([]byte, error) { return nil, fmt.Errorf("jsonErr") } -func (j *jsonErr) TypeName() string { return "jsonErr" } func (j *jsonErr) deepClone() Value { return nil } func TestJSONSet(t *testing.T) { From ba0fe9836f99a218b233a0618811cdd9f9c0a177 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 16:18:18 -0600 Subject: [PATCH 196/216] types: obtain coverage of patterns Addresses IDX-142 Signed-off-by: philhassey --- internal/parser/pattern_test.go | 2 +- types/boolean_test.go | 1 + types/decimal_test.go | 5 +++ types/entities_test.go | 8 ++++- types/entity_uid_test.go | 5 +++ types/patttern_test.go | 64 +++++++++++++++++++++++++++++++-- 6 files changed, 81 insertions(+), 4 deletions(-) diff --git a/internal/parser/pattern_test.go b/internal/parser/pattern_test.go index e686e322..852d5069 100644 --- a/internal/parser/pattern_test.go +++ b/internal/parser/pattern_test.go @@ -51,7 +51,7 @@ func TestParsePattern(t *testing.T) { } } -func TestMatch(t *testing.T) { +func TestPatternParseAndMatch(t *testing.T) { t.Parallel() tests := []struct { pattern string diff --git a/types/boolean_test.go b/types/boolean_test.go index 745cf5ca..1af921ef 100644 --- a/types/boolean_test.go +++ b/types/boolean_test.go @@ -26,6 +26,7 @@ func TestBool(t *testing.T) { t.Run("string", func(t *testing.T) { t.Parallel() testutil.Equals(t, types.Boolean(true).String(), "true") + testutil.Equals(t, types.Boolean(false).String(), "false") }) } diff --git a/types/decimal_test.go b/types/decimal_test.go index 2640af2f..b73b46e4 100644 --- a/types/decimal_test.go +++ b/types/decimal_test.go @@ -120,4 +120,9 @@ func TestDecimal(t *testing.T) { testutil.FatalIf(t, zero.Equal(f), "%v Equal to %v", zero, f) }) + t.Run("MarshalCedar", func(t *testing.T) { + t.Parallel() + testutil.Equals(t, string(types.UnsafeDecimal(42).MarshalCedar()), `decimal("42.0")`) + }) + } diff --git a/types/entities_test.go b/types/entities_test.go index ed27e64d..73374a66 100644 --- a/types/entities_test.go +++ b/types/entities_test.go @@ -34,10 +34,16 @@ func TestEntitiesJSON(t *testing.T) { Parents: []types.EntityUID{}, Attributes: types.Record{"key": types.Long(42)}, } + ent2 := types.Entity{ + UID: types.NewEntityUID("Type", "id2"), + Parents: []types.EntityUID{}, + Attributes: types.Record{"key": types.Long(42)}, + } e[ent.UID] = ent + e[ent2.UID] = ent2 b, err := e.MarshalJSON() testutil.OK(t, err) - testutil.Equals(t, string(b), `[{"uid":{"type":"Type","id":"id"},"attrs":{"key":42}}]`) + testutil.Equals(t, string(b), `[{"uid":{"type":"Type","id":"id"},"attrs":{"key":42}},{"uid":{"type":"Type","id":"id2"},"attrs":{"key":42}}]`) }) t.Run("Unmarshal", func(t *testing.T) { diff --git a/types/entity_uid_test.go b/types/entity_uid_test.go index e09ed5cb..e28bbbe7 100644 --- a/types/entity_uid_test.go +++ b/types/entity_uid_test.go @@ -26,4 +26,9 @@ func TestEntity(t *testing.T) { testutil.Equals(t, types.EntityUID{Type: "namespace::type", ID: "id"}.String(), `namespace::type::"id"`) }) + t.Run("MarshalCedar", func(t *testing.T) { + t.Parallel() + testutil.Equals(t, string(types.EntityUID{"type", "id"}.MarshalCedar()), `type::"id"`) + }) + } diff --git a/types/patttern_test.go b/types/patttern_test.go index 19381555..f92ef962 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -6,20 +6,73 @@ import ( "github.com/cedar-policy/cedar-go/internal/testutil" ) -func TestPatternFromBuilder(t *testing.T) { +func TestPattern(t *testing.T) { + t.Parallel() t.Run("saturate two wildcards", func(t *testing.T) { + t.Parallel() pattern1 := NewPattern(Wildcard(), Wildcard()) pattern2 := NewPattern(Wildcard()) testutil.Equals(t, pattern1, pattern2) }) t.Run("saturate two literals", func(t *testing.T) { + t.Parallel() pattern1 := NewPattern(String("foo"), String("bar")) pattern2 := NewPattern(String("foobar")) testutil.Equals(t, pattern1, pattern2) }) + t.Run("panicOnNil", func(t *testing.T) { + t.Parallel() + testutil.AssertPanic(t, func() { + NewPattern(nil) + }) + }) + t.Run("MarshalCedar", func(t *testing.T) { + t.Parallel() + testutil.Equals(t, string(NewPattern(String("*foo"), Wildcard()).MarshalCedar()), `"\*foo*"`) + }) + + t.Run("isPatternComponent", func(t *testing.T) { + t.Parallel() + String("").isPatternComponent() + Wildcard().isPatternComponent() + }) } -func TestJSON(t *testing.T) { +func TestPatternMatch(t *testing.T) { + t.Parallel() + tests := []struct { + pattern Pattern + target string + want bool + }{ + {NewPattern(), "", true}, + {NewPattern(), "hello", false}, + {NewPattern(Wildcard()), "hello", true}, + {NewPattern(String("e")), "hello", false}, + {NewPattern(Wildcard(), String("e")), "hello", false}, + {NewPattern(Wildcard(), String("e"), Wildcard()), "hello", true}, + {NewPattern(String("hello")), "hello", true}, + {NewPattern(String("hello"), Wildcard()), "hello", true}, + {NewPattern(Wildcard(), String("h"), Wildcard(), String("llo"), Wildcard()), "hello", true}, + {NewPattern(String("h"), Wildcard(), String("e"), Wildcard(), String("o")), "hello", true}, + {NewPattern(String("h"), Wildcard(), String("e"), Wildcard(), Wildcard(), String("o")), "hello", true}, + {NewPattern(String("h"), Wildcard(), String("z"), Wildcard(), String("o")), "hello", false}, + + {NewPattern(String("**"), Wildcard(), String("**")), "**foo**", true}, + {NewPattern(String("**"), Wildcard(), String("**")), "**bar**", true}, + {NewPattern(String("**"), Wildcard(), String("**")), "*bar*", false}, + } + for _, tt := range tests { + tt := tt + t.Run(string(tt.pattern.MarshalCedar())+":"+tt.target, func(t *testing.T) { + t.Parallel() + got := tt.pattern.Match(String(tt.target)) + testutil.Equals(t, got, tt.want) + }) + } +} + +func TestPatternJSON(t *testing.T) { t.Parallel() tests := []struct { name string @@ -112,6 +165,13 @@ func TestJSON(t *testing.T) { Pattern{}, false, }, + { + "other type", + `[null]`, + testutil.Error, + Pattern{}, + false, + }, { "lowercase literal", `[{"literal": "foo"}]`, From 3990248c3854b59f7ad42db13827839a38d2362a Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 16:35:24 -0600 Subject: [PATCH 197/216] internal/ast: obtain coverage of package Addresses IDX-142 Signed-off-by: philhassey --- internal/ast/ast_test.go | 19 +++++++++++++++++++ internal/ast/internal_test.go | 8 ++++++++ internal/ast/value.go | 10 ---------- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index 55a10f63..68d0d689 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -480,6 +480,25 @@ func TestASTByTable(t *testing.T) { ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "isInRange", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}}, }, + + { + "duplicateAnnotations", + ast.Permit().Annotate("key", "value").Annotate("key", "value2"), + ast.Policy{Annotations: []ast.AnnotationType{{Key: "key", Value: "value2"}}, Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + }, + + { + "valueRecordElements", + ast.Permit().When(ast.Record(ast.Pairs{{Key: "key", Value: ast.Long(42)}})), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeRecord{Elements: []ast.RecordElementNode{{Key: "key", Value: ast.NodeValue{Value: types.Long(42)}}}}}}}, + }, + { + "duplicateValueRecordElements", + ast.Permit().When(ast.Record(ast.Pairs{{Key: "key", Value: ast.Long(42)}, {Key: "key", Value: ast.Long(43)}})), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeRecord{Elements: []ast.RecordElementNode{{Key: "key", Value: ast.NodeValue{Value: types.Long(43)}}}}}}}, + }, } for _, tt := range tests { diff --git a/internal/ast/internal_test.go b/internal/ast/internal_test.go index f5c9ac75..8f61a828 100644 --- a/internal/ast/internal_test.go +++ b/internal/ast/internal_test.go @@ -31,3 +31,11 @@ func TestAsNode(t *testing.T) { v := n.AsIsNode() testutil.Equals(t, v, (IsNode)(NodeValue{Value: types.Long(42)})) } + +func TestIsScope(t *testing.T) { + t.Parallel() + ScopeNode{}.isScope() + PrincipalScopeNode{}.isPrincipalScope() + ActionScopeNode{}.isActionScope() + ResourceScopeNode{}.isResourceScope() +} diff --git a/internal/ast/value.go b/internal/ast/value.go index 6e835d88..6a926ee7 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -26,16 +26,6 @@ func Long[T int | int64 | types.Long](l T) Node { return Value(types.Long(l)) } -// SetDeprecated is a convenience function that wraps concrete instances of a Cedar SetDeprecated type -// types in AST value nodes and passes them along to SetNodes. -func SetDeprecated(s types.Set) Node { - var nodes []IsNode - for _, v := range s { - nodes = append(nodes, Value(v).v) - } - return NewNode(NodeTypeSet{Elements: nodes}) -} - // Set allows for a complex set definition with values potentially // being Cedar expressions of their own. For example, this Cedar text: // From 4d5f2c3666c05d3467dab35aa62b9797b453020d Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 16:57:58 -0600 Subject: [PATCH 198/216] internal/testutil: add tests for testing tests Addresses IDX-142 Signed-off-by: philhassey --- internal/consts/consts.go | 4 + internal/consts/consts_test.go | 15 +++ internal/eval/compile_test.go | 2 +- internal/eval/convert_test.go | 6 +- internal/eval/evalers_test.go | 64 +++++------ internal/eval/util_test.go | 16 +-- internal/extensions/extensions.go | 4 + internal/extensions/extensions_test.go | 12 ++ internal/json/json_test.go | 10 +- internal/rust/rust_test.go | 4 +- internal/testutil/mocks_test.go | 117 +++++++++++++++++++ internal/testutil/testutil.go | 32 +++--- internal/testutil/testutil_test.go | 149 +++++++++++++++++++++++++ types/decimal_test.go | 2 +- types/json_test.go | 4 +- types/patttern_test.go | 4 +- 16 files changed, 374 insertions(+), 71 deletions(-) create mode 100644 internal/consts/consts_test.go create mode 100644 internal/extensions/extensions_test.go create mode 100644 internal/testutil/mocks_test.go create mode 100644 internal/testutil/testutil_test.go diff --git a/internal/consts/consts.go b/internal/consts/consts.go index 559d2a8f..f46e2818 100644 --- a/internal/consts/consts.go +++ b/internal/consts/consts.go @@ -6,3 +6,7 @@ const ( Resource = "resource" Context = "context" ) + +func init() { + _ = 42 +} diff --git a/internal/consts/consts_test.go b/internal/consts/consts_test.go new file mode 100644 index 00000000..9b4dd3d7 --- /dev/null +++ b/internal/consts/consts_test.go @@ -0,0 +1,15 @@ +package consts + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" +) + +func TestConsts(t *testing.T) { + t.Parallel() + testutil.Equals(t, Principal, "principal") + testutil.Equals(t, Action, "action") + testutil.Equals(t, Resource, "resource") + testutil.Equals(t, Context, "context") +} diff --git a/internal/eval/compile_test.go b/internal/eval/compile_test.go index 4ff8a1a2..ba45a7ed 100644 --- a/internal/eval/compile_test.go +++ b/internal/eval/compile_test.go @@ -120,7 +120,7 @@ func TestScopeToNode(t *testing.T) { func TestScopeToNodePanic(t *testing.T) { t.Parallel() - testutil.AssertPanic(t, func() { + testutil.Panic(t, func() { _ = scopeToNode(ast.NewPrincipalNode(), ast.ScopeNode{}) }) } diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index c2eefc5b..91abef8f 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -15,7 +15,7 @@ func TestToEval(t *testing.T) { name string in ast.Node out types.Value - err func(testing.TB, error) + err func(testutil.TB, error) }{ { "access", @@ -297,14 +297,14 @@ func TestToEval(t *testing.T) { func TestToEvalPanic(t *testing.T) { t.Parallel() - testutil.AssertPanic(t, func() { + testutil.Panic(t, func() { _ = toEval(ast.Node{}.AsIsNode()) }) } func TestToEvalVariablePanic(t *testing.T) { t.Parallel() - testutil.AssertPanic(t, func() { + testutil.Panic(t, func() { _ = toEval(ast.NodeTypeVariable{Name: "bananas"}) }) } diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 31cccc67..28bed1e3 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -92,7 +92,7 @@ func TestOrNode(t *testing.T) { t.Parallel() n := newOrNode(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -147,7 +147,7 @@ func TestAndNode(t *testing.T) { t.Parallel() n := newAndEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -189,7 +189,7 @@ func TestNotNode(t *testing.T) { t.Parallel() n := newNotEval(tt.arg) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -397,7 +397,7 @@ func TestAddNode(t *testing.T) { t.Parallel() n := newAddEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -436,7 +436,7 @@ func TestSubtractNode(t *testing.T) { t.Parallel() n := newSubtractEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -475,7 +475,7 @@ func TestMultiplyNode(t *testing.T) { t.Parallel() n := newMultiplyEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -505,7 +505,7 @@ func TestNegateNode(t *testing.T) { t.Parallel() n := newNegateEval(tt.arg) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -556,7 +556,7 @@ func TestLongLessThanNode(t *testing.T) { t.Parallel() n := newLongLessThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -608,7 +608,7 @@ func TestLongLessThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongLessThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -660,7 +660,7 @@ func TestLongGreaterThanNode(t *testing.T) { t.Parallel() n := newLongGreaterThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -712,7 +712,7 @@ func TestLongGreaterThanOrEqualNode(t *testing.T) { t.Parallel() n := newLongGreaterThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -769,7 +769,7 @@ func TestDecimalLessThanNode(t *testing.T) { t.Parallel() n := newDecimalLessThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -826,7 +826,7 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) { t.Parallel() n := newDecimalLessThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -883,7 +883,7 @@ func TestDecimalGreaterThanNode(t *testing.T) { t.Parallel() n := newDecimalGreaterThanEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -940,7 +940,7 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) { t.Parallel() n := newDecimalGreaterThanOrEqualEval(tt.lhs, tt.rhs) _, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) }) } } @@ -971,7 +971,7 @@ func TestIfThenElseNode(t *testing.T) { t.Parallel() n := newIfThenElseEval(tt.if_, tt.then, tt.else_) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) testutil.Equals(t, v, tt.result) }) } @@ -997,7 +997,7 @@ func TestEqualNode(t *testing.T) { t.Parallel() n := newEqualEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } @@ -1023,7 +1023,7 @@ func TestNotEqualNode(t *testing.T) { t.Parallel() n := newNotEqualEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } @@ -1064,7 +1064,7 @@ func TestSetLiteralNode(t *testing.T) { t.Parallel() n := newSetLiteralEval(tt.elems) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } @@ -1088,7 +1088,7 @@ func TestContainsNode(t *testing.T) { t.Parallel() n := newContainsEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertZeroValue(t, v) }) } @@ -1143,7 +1143,7 @@ func TestContainsAllNode(t *testing.T) { t.Parallel() n := newContainsAllEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertZeroValue(t, v) }) } @@ -1197,7 +1197,7 @@ func TestContainsAnyNode(t *testing.T) { t.Parallel() n := newContainsAnyEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertZeroValue(t, v) }) } @@ -1260,7 +1260,7 @@ func TestRecordLiteralNode(t *testing.T) { t.Parallel() n := newRecordLiteralEval(tt.elems) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } @@ -1317,7 +1317,7 @@ func TestAttributeAccessNode(t *testing.T) { entity.UID: entity, }, }) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } @@ -1374,7 +1374,7 @@ func TestHasNode(t *testing.T) { entity.UID: entity, }, }) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } @@ -1428,7 +1428,7 @@ func TestLikeNode(t *testing.T) { testutil.OK(t, err) n := newLikeEval(tt.str, pat) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } @@ -1625,7 +1625,7 @@ func TestIsNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() got, err := newIsEval(tt.lhs, tt.rhs).Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, got, tt.result) }) } @@ -1742,7 +1742,7 @@ func TestInNode(t *testing.T) { } ec := Context{Entities: entityMap} v, err := n.Eval(&ec) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } @@ -1767,7 +1767,7 @@ func TestDecimalLiteralNode(t *testing.T) { t.Parallel() n := newDecimalLiteralEval(tt.arg) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } @@ -1794,7 +1794,7 @@ func TestIPLiteralNode(t *testing.T) { t.Parallel() n := newIPLiteralEval(tt.arg) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } @@ -1832,7 +1832,7 @@ func TestIPTestNode(t *testing.T) { t.Parallel() n := newIPTestEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } @@ -1870,7 +1870,7 @@ func TestIPIsInRangeNode(t *testing.T) { t.Parallel() n := newIPIsInRangeEval(tt.lhs, tt.rhs) v, err := n.Eval(&Context{}) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, v, tt.result) }) } diff --git a/internal/eval/util_test.go b/internal/eval/util_test.go index 75615b66..2545ed02 100644 --- a/internal/eval/util_test.go +++ b/internal/eval/util_test.go @@ -21,7 +21,7 @@ func TestUtil(t *testing.T) { t.Run("toBoolOnNonBool", func(t *testing.T) { t.Parallel() v, err := ValueToBool(types.Long(0)) - testutil.AssertError(t, err, ErrType) + testutil.ErrorIs(t, err, ErrType) testutil.Equals(t, v, false) }) }) @@ -38,7 +38,7 @@ func TestUtil(t *testing.T) { t.Run("toLongOnNonLong", func(t *testing.T) { t.Parallel() v, err := ValueToLong(types.Boolean(true)) - testutil.AssertError(t, err, ErrType) + testutil.ErrorIs(t, err, ErrType) testutil.Equals(t, v, 0) }) }) @@ -55,7 +55,7 @@ func TestUtil(t *testing.T) { t.Run("toStringOnNonString", func(t *testing.T) { t.Parallel() v, err := ValueToString(types.Boolean(true)) - testutil.AssertError(t, err, ErrType) + testutil.ErrorIs(t, err, ErrType) testutil.Equals(t, v, "") }) }) @@ -74,7 +74,7 @@ func TestUtil(t *testing.T) { t.Run("ToSetOnNonSet", func(t *testing.T) { t.Parallel() v, err := ValueToSet(types.Boolean(true)) - testutil.AssertError(t, err, ErrType) + testutil.ErrorIs(t, err, ErrType) testutil.Equals(t, v, nil) }) }) @@ -96,7 +96,7 @@ func TestUtil(t *testing.T) { t.Run("toRecordOnNonRecord", func(t *testing.T) { t.Parallel() v, err := ValueToRecord(types.String("hello")) - testutil.AssertError(t, err, ErrType) + testutil.ErrorIs(t, err, ErrType) testutil.Equals(t, v, nil) }) }) @@ -113,7 +113,7 @@ func TestUtil(t *testing.T) { t.Run("ToEntityOnNonEntity", func(t *testing.T) { t.Parallel() v, err := ValueToEntity(types.String("hello")) - testutil.AssertError(t, err, ErrType) + testutil.ErrorIs(t, err, ErrType) testutil.Equals(t, v, types.EntityUID{}) }) @@ -133,7 +133,7 @@ func TestUtil(t *testing.T) { t.Run("toDecimalOnNonDecimal", func(t *testing.T) { t.Parallel() v, err := ValueToDecimal(types.Boolean(true)) - testutil.AssertError(t, err, ErrType) + testutil.ErrorIs(t, err, ErrType) testutil.Equals(t, v, types.Decimal{}) }) @@ -145,7 +145,7 @@ func TestUtil(t *testing.T) { t.Run("toIPOnNonIP", func(t *testing.T) { t.Parallel() v, err := ValueToIP(types.Boolean(true)) - testutil.AssertError(t, err, ErrType) + testutil.ErrorIs(t, err, ErrType) testutil.Equals(t, v, types.IPAddr{}) }) }) diff --git a/internal/extensions/extensions.go b/internal/extensions/extensions.go index 3fa7f021..aa610e6a 100644 --- a/internal/extensions/extensions.go +++ b/internal/extensions/extensions.go @@ -22,3 +22,7 @@ var ExtMap = map[types.Path]extInfo{ "isMulticast": {Args: 1, IsMethod: true}, "isInRange": {Args: 2, IsMethod: true}, } + +func init() { + _ = 42 +} diff --git a/internal/extensions/extensions_test.go b/internal/extensions/extensions_test.go new file mode 100644 index 00000000..3946590b --- /dev/null +++ b/internal/extensions/extensions_test.go @@ -0,0 +1,12 @@ +package extensions + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/testutil" +) + +func TestExtensions(t *testing.T) { + t.Parallel() + testutil.Equals(t, len(ExtMap), 11) +} diff --git a/internal/json/json_test.go b/internal/json/json_test.go index cb1138a3..65ce4526 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -16,7 +16,7 @@ func TestUnmarshalJSON(t *testing.T) { name string input string want *ast.Policy - errFunc func(testing.TB, error) + errFunc func(testutil.TB, error) }{ /* @key("value") @@ -485,7 +485,7 @@ func TestMarshalJSON(t *testing.T) { name string input *ast.Policy want string - errFunc func(testing.TB, error) + errFunc func(testutil.TB, error) }{ { "decimal", @@ -529,7 +529,7 @@ func TestMarshalJSON(t *testing.T) { } } -func testNormalizeJSON(t testing.TB, in string) string { +func testNormalizeJSON(t testutil.TB, in string) string { var x any err := json.Unmarshal([]byte(in), &x) testutil.OK(t, err) @@ -551,7 +551,7 @@ func TestMarshalPanics(t *testing.T) { t.Parallel() t.Run("nilScope", func(t *testing.T) { t.Parallel() - testutil.AssertPanic(t, func() { + testutil.Panic(t, func() { s := scopeJSON{} var v ast.IsScopeNode s.FromNode(v) @@ -559,7 +559,7 @@ func TestMarshalPanics(t *testing.T) { }) t.Run("nilNode", func(t *testing.T) { t.Parallel() - testutil.AssertPanic(t, func() { + testutil.Panic(t, func() { s := nodeJSON{} var v ast.IsNode s.FromNode(v) diff --git a/internal/rust/rust_test.go b/internal/rust/rust_test.go index 99b60cb1..af7d1450 100644 --- a/internal/rust/rust_test.go +++ b/internal/rust/rust_test.go @@ -13,7 +13,7 @@ func TestParseUnicodeEscape(t *testing.T) { in []byte out rune outN int - err func(t testing.TB, err error) + err func(t testutil.TB, err error) }{ {"happy", []byte{'{', '4', '2', '}'}, 0x42, 4, testutil.OK}, {"badRune", []byte{'{', 0x80, 0x81}, 0, 1, testutil.Error}, @@ -37,7 +37,7 @@ func TestUnquote(t *testing.T) { name string in string out string - err func(t testing.TB, err error) + err func(t testutil.TB, err error) }{ {"happy", `"test"`, `test`, testutil.OK}, } diff --git a/internal/testutil/mocks_test.go b/internal/testutil/mocks_test.go new file mode 100644 index 00000000..38be4945 --- /dev/null +++ b/internal/testutil/mocks_test.go @@ -0,0 +1,117 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package testutil + +import ( + "sync" +) + +// Ensure, that TBMock does implement TB. +// If this is not the case, regenerate this file with moq. +var _ TB = &TBMock{} + +// TBMock is a mock implementation of TB. +// +// func TestSomethingThatUsesTB(t *testing.T) { +// +// // make and configure a mocked TB +// mockedTB := &TBMock{ +// FatalfFunc: func(format string, args ...any) { +// panic("mock out the Fatalf method") +// }, +// HelperFunc: func() { +// panic("mock out the Helper method") +// }, +// } +// +// // use mockedTB in code that requires TB +// // and then make assertions. +// +// } +type TBMock struct { + // FatalfFunc mocks the Fatalf method. + FatalfFunc func(format string, args ...any) + + // HelperFunc mocks the Helper method. + HelperFunc func() + + // calls tracks calls to the methods. + calls struct { + // Fatalf holds details about calls to the Fatalf method. + Fatalf []struct { + // Format is the format argument value. + Format string + // Args is the args argument value. + Args []any + } + // Helper holds details about calls to the Helper method. + Helper []struct { + } + } + lockFatalf sync.RWMutex + lockHelper sync.RWMutex +} + +// Fatalf calls FatalfFunc. +func (mock *TBMock) Fatalf(format string, args ...any) { + if mock.FatalfFunc == nil { + panic("TBMock.FatalfFunc: method is nil but TB.Fatalf was just called") + } + callInfo := struct { + Format string + Args []any + }{ + Format: format, + Args: args, + } + mock.lockFatalf.Lock() + mock.calls.Fatalf = append(mock.calls.Fatalf, callInfo) + mock.lockFatalf.Unlock() + mock.FatalfFunc(format, args...) +} + +// FatalfCalls gets all the calls that were made to Fatalf. +// Check the length with: +// +// len(mockedTB.FatalfCalls()) +func (mock *TBMock) FatalfCalls() []struct { + Format string + Args []any +} { + var calls []struct { + Format string + Args []any + } + mock.lockFatalf.RLock() + calls = mock.calls.Fatalf + mock.lockFatalf.RUnlock() + return calls +} + +// Helper calls HelperFunc. +func (mock *TBMock) Helper() { + if mock.HelperFunc == nil { + panic("TBMock.HelperFunc: method is nil but TB.Helper was just called") + } + callInfo := struct { + }{} + mock.lockHelper.Lock() + mock.calls.Helper = append(mock.calls.Helper, callInfo) + mock.lockHelper.Unlock() + mock.HelperFunc() +} + +// HelperCalls gets all the calls that were made to Helper. +// Check the length with: +// +// len(mockedTB.HelperCalls()) +func (mock *TBMock) HelperCalls() []struct { +} { + var calls []struct { + } + mock.lockHelper.RLock() + calls = mock.calls.Helper + mock.lockHelper.RUnlock() + return calls +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index 73ae253f..148372de 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -3,10 +3,16 @@ package testutil import ( "errors" "reflect" - "testing" ) -func Equals[T any](t testing.TB, a, b T) { +type TB interface { + Helper() + Fatalf(format string, args ...any) +} + +//go:generate moq -pkg testutil -fmt goimports -out mocks_test.go . TB + +func Equals[T any](t TB, a, b T) { t.Helper() if reflect.DeepEqual(a, b) { return @@ -14,7 +20,7 @@ func Equals[T any](t testing.TB, a, b T) { t.Fatalf("got %+v want %+v", a, b) } -func FatalIf(t testing.TB, c bool, f string, args ...any) { +func FatalIf(t TB, c bool, f string, args ...any) { t.Helper() if !c { return @@ -22,7 +28,7 @@ func FatalIf(t testing.TB, c bool, f string, args ...any) { t.Fatalf(f, args...) } -func OK(t testing.TB, err error) { +func OK(t TB, err error) { t.Helper() if err == nil { return @@ -30,7 +36,7 @@ func OK(t testing.TB, err error) { t.Fatalf("got %v want nil", err) } -func Error(t testing.TB, err error) { +func Error(t TB, err error) { t.Helper() if err != nil { return @@ -38,22 +44,18 @@ func Error(t testing.TB, err error) { t.Fatalf("got nil want error") } -func AssertError(t *testing.T, got, want error) { +func ErrorIs(t TB, got, want error) { t.Helper() - FatalIf(t, !errors.Is(got, want), "err got %v want %v", got, want) -} - -func Must[T any](obj T, err error) T { - if err != nil { - panic(err) + if !errors.Is(got, want) { + t.Fatalf("err got %v want %v", got, want) } - return obj } -func AssertPanic(t *testing.T, f func()) { +func Panic(t TB, f func()) { + t.Helper() defer func() { if e := recover(); e == nil { - t.Fatal("expected panic, got nil") + t.Fatalf("got nil want panic") } }() f() diff --git a/internal/testutil/testutil_test.go b/internal/testutil/testutil_test.go new file mode 100644 index 00000000..9331ceef --- /dev/null +++ b/internal/testutil/testutil_test.go @@ -0,0 +1,149 @@ +package testutil + +import ( + "fmt" + "testing" +) + +func newTB() *TBMock { + return &TBMock{ + HelperFunc: func() {}, + FatalfFunc: func(format string, args ...any) {}, + } +} + +func TestEquals(t *testing.T) { + t.Parallel() + + t.Run("Pass", func(t *testing.T) { + tb := newTB() + Equals(tb, 42, 42) + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 0) + }) + t.Run("Fail", func(t *testing.T) { + tb := newTB() + Equals(tb, 42, 43) + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 1) + }) +} + +func TestFatalIf(t *testing.T) { + t.Parallel() + + t.Run("Pass", func(t *testing.T) { + tb := newTB() + FatalIf(tb, false, "unused") + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 0) + }) + t.Run("Fail", func(t *testing.T) { + tb := newTB() + FatalIf(tb, true, "used") + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 1) + }) +} + +func TestOK(t *testing.T) { + t.Parallel() + + t.Run("Pass", func(t *testing.T) { + tb := newTB() + var err error + OK(tb, err) + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 0) + }) + t.Run("Fail", func(t *testing.T) { + tb := newTB() + err := fmt.Errorf("error") + OK(tb, err) + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 1) + }) +} + +func TestError(t *testing.T) { + t.Parallel() + + t.Run("Pass", func(t *testing.T) { + tb := newTB() + err := fmt.Errorf("error") + Error(tb, err) + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 0) + }) + t.Run("Fail", func(t *testing.T) { + tb := newTB() + var err error + Error(tb, err) + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 1) + }) +} + +func TestErrorIs(t *testing.T) { + t.Parallel() + + t.Run("Pass", func(t *testing.T) { + tb := newTB() + err := fmt.Errorf("error") + ErrorIs(tb, err, err) + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 0) + }) + t.Run("Fail", func(t *testing.T) { + tb := newTB() + err := fmt.Errorf("error") + err2 := fmt.Errorf("error2") + ErrorIs(tb, err, err2) + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 1) + }) +} + +func TestPanic(t *testing.T) { + t.Parallel() + + t.Run("Pass", func(t *testing.T) { + tb := newTB() + Panic(tb, func() { + panic("panic") + }) + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 0) + }) + t.Run("Fail", func(t *testing.T) { + tb := newTB() + Panic(tb, func() { + }) + + // assertions + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 1) + }) +} diff --git a/types/decimal_test.go b/types/decimal_test.go index b73b46e4..899ae3b9 100644 --- a/types/decimal_test.go +++ b/types/decimal_test.go @@ -101,7 +101,7 @@ func TestDecimal(t *testing.T) { t.Run(fmt.Sprintf("%s->%s", tt.in, tt.errStr), func(t *testing.T) { t.Parallel() _, err := types.ParseDecimal(tt.in) - testutil.AssertError(t, err, types.ErrDecimal) + testutil.ErrorIs(t, err, types.ErrDecimal) testutil.Equals(t, err.Error(), tt.errStr) }) } diff --git a/types/json_test.go b/types/json_test.go index cb8efc38..db7752b9 100644 --- a/types/json_test.go +++ b/types/json_test.go @@ -63,7 +63,7 @@ func TestJSON_Value(t *testing.T) { var got Value ptr := &got err := UnmarshalJSON([]byte(tt.in), ptr) - testutil.AssertError(t, err, tt.err) + testutil.ErrorIs(t, err, tt.err) AssertValue(t, got, tt.want) if tt.err != nil { return @@ -264,7 +264,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { t.Parallel() gotValue, gotErr := tt.f([]byte(tt.in)) testutil.Equals(t, gotValue, tt.wantValue) - testutil.AssertError(t, gotErr, tt.wantErr) + testutil.ErrorIs(t, gotErr, tt.wantErr) }) } } diff --git a/types/patttern_test.go b/types/patttern_test.go index f92ef962..3f5a39ed 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -22,7 +22,7 @@ func TestPattern(t *testing.T) { }) t.Run("panicOnNil", func(t *testing.T) { t.Parallel() - testutil.AssertPanic(t, func() { + testutil.Panic(t, func() { NewPattern(nil) }) }) @@ -77,7 +77,7 @@ func TestPatternJSON(t *testing.T) { tests := []struct { name string pattern string - errFunc func(testing.TB, error) + errFunc func(testutil.TB, error) target Pattern shouldRoundTrip bool }{ From 2f33ef0e5bac07b141f5f48b8967b4df72ca1736 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 17:26:57 -0600 Subject: [PATCH 199/216] internal/parser: add coverage for parser Addresses IDX-142 Signed-off-by: philhassey --- internal/parser/cedar_unmarshal_test.go | 59 +++++++++++++++++++++++++ internal/parser/internal_test.go | 33 ++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 internal/parser/internal_test.go diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 937e1cd9..7d1317ba 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -2,6 +2,7 @@ package parser_test import ( "bytes" + "strings" "testing" "github.com/cedar-policy/cedar-go/internal/ast" @@ -525,3 +526,61 @@ func TestParsePolicySet(t *testing.T) { testutil.Equals(t, policies[1], (*parser.Policy)(expectedPolicy1)) }) } + +func TestParseApproximateErrors(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + outErrSubstring string + }{ + {"unexpectedEffect", "!", "unexpected effect"}, + {"nul", "\x00", "invalid character"}, + {"notTerminated", `"`, "literal not terminated"}, + {"principalBadIsIn", `permit (principal is T in error);`, "got ) want ::"}, + {"principalBadIn", `permit (principal in error);`, "got ) want ::"}, + {"resourceBadEq", `permit (principal, action, resource == error);`, "got ) want ::"}, + {"resourceBadIsIn1", `permit (principal, action, resource is "error");`, "expected ident"}, + {"resourceBadIsIn1", `permit (principal, action, resource is T in error);`, "got ) want ::"}, + {"resourceBadIn", `permit (principal, action, resource in error);`, "got ) want ::"}, + {"unlessCondition", `permit (principal, action, resource) unless {`, "invalid primary"}, + {"or", `permit (principal, action, resource) when { true ||`, "invalid primary"}, + {"and", `permit (principal, action, resource) when { true &&`, "invalid primary"}, + {"isPath", `permit (principal, action, resource) when { context is`, "expected ident"}, + {"isIn", `permit (principal, action, resource) when { context is T in`, "invalid primary"}, + {"mult", `permit (principal, action, resource) when { 42 *`, "invalid primary"}, + {"parens", `permit (principal, action, resource) when { (42}`, "got } want )"}, + {"func", `permit (principal, action, resource) when { ip(}`, "invalid primary"}, + {"args", `permit (principal, action, resource) when { ip(42 42)`, "got 42 want ,"}, + {"dupeKey", `permit (principal, action, resource) when { {k:42,k:43}`, "duplicate key"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var pol parser.Policy + err := pol.UnmarshalCedar([]byte(tt.in)) + testutil.FatalIf(t, !strings.Contains(err.Error(), tt.outErrSubstring), "got %v want %v", err.Error(), tt.outErrSubstring) + }) + } +} + +func TestPolicySliceErrors(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + outErrSubstring string + }{ + {"notTerminated", `"`, "literal not terminated"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var pol parser.PolicySlice + err := pol.UnmarshalCedar([]byte(tt.in)) + testutil.FatalIf(t, !strings.Contains(err.Error(), tt.outErrSubstring), "got %v want %v", err.Error(), tt.outErrSubstring) + }) + } +} diff --git a/internal/parser/internal_test.go b/internal/parser/internal_test.go new file mode 100644 index 00000000..1079d1a5 --- /dev/null +++ b/internal/parser/internal_test.go @@ -0,0 +1,33 @@ +package parser + +import ( + "testing" + + "github.com/cedar-policy/cedar-go/internal/ast" + "github.com/cedar-policy/cedar-go/internal/testutil" +) + +func TestScopeToNode(t *testing.T) { + t.Parallel() + t.Run("all", func(t *testing.T) { + t.Parallel() + x := scopeToNode(ast.NodeTypeVariable{Name: "principal"}, ast.ScopeTypeAll{}) + testutil.Equals(t, x, ast.True()) + }) + t.Run("panic", func(t *testing.T) { + t.Parallel() + testutil.Panic(t, func() { + scopeToNode(ast.NodeTypeVariable{Name: "principal"}, nil) + }) + }) +} + +func TestAstNodeToMarshalNode(t *testing.T) { + t.Parallel() + t.Run("panic", func(t *testing.T) { + t.Parallel() + testutil.Panic(t, func() { + astNodeToMarshalNode(nil) + }) + }) +} From fe9036bf7deb771aea3195913918209314507d4d Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 17:33:38 -0600 Subject: [PATCH 200/216] internal/testutil: appease linter Addresses IDX-142 Signed-off-by: philhassey --- internal/testutil/testutil_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/testutil/testutil_test.go b/internal/testutil/testutil_test.go index 9331ceef..5ca2a417 100644 --- a/internal/testutil/testutil_test.go +++ b/internal/testutil/testutil_test.go @@ -8,7 +8,7 @@ import ( func newTB() *TBMock { return &TBMock{ HelperFunc: func() {}, - FatalfFunc: func(format string, args ...any) {}, + FatalfFunc: func(string, ...any) {}, } } From b7d3b7150c2b79888939c9de62e5eb2d7d969e71 Mon Sep 17 00:00:00 2001 From: philhassey Date: Tue, 20 Aug 2024 17:42:03 -0600 Subject: [PATCH 201/216] cedar: add coverage to main policy package Addresses IDX-142 Signed-off-by: philhassey --- authorize_test.go | 19 +++++++++++++++++++ policy_set.go | 6 ------ policy_test.go | 14 ++++++++++++++ 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/authorize_test.go b/authorize_test.go index bab9fe43..f9ba8d11 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -4,6 +4,8 @@ import ( "encoding/json" "testing" + "github.com/cedar-policy/cedar-go/ast" + "github.com/cedar-policy/cedar-go/internal/eval" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -728,6 +730,23 @@ func TestError(t *testing.T) { testutil.Equals(t, e.String(), "while evaluating policy `policy42`: bad error") } +type badEvaler struct{} + +func (e *badEvaler) Eval(*eval.Context) (types.Value, error) { + return types.Long(42), nil +} + +func TestBadEval(t *testing.T) { + t.Parallel() + ps := NewPolicySet() + pol := NewPolicyFromAST(ast.Permit()) + pol.eval = &badEvaler{} + ps.Upsert("pol", pol) + dec, diag := ps.IsAuthorized(nil, Request{}) + testutil.Equals(t, dec, Deny) + testutil.Equals(t, len(diag.Errors), 1) +} + func TestJSONDecision(t *testing.T) { t.Parallel() t.Run("MarshalAllow", func(t *testing.T) { diff --git a/policy_set.go b/policy_set.go index 9359dd0c..a90e8d9c 100644 --- a/policy_set.go +++ b/policy_set.go @@ -44,12 +44,6 @@ func (p PolicySet) Get(policyID PolicyID) Policy { return p.policies[policyID] } -// Has indicates if the policy exists. -func (p PolicySet) Has(policyID PolicyID) bool { - _, ok := p.policies[policyID] - return ok -} - // Upsert inserts or updates a policy with the given ID. func (p *PolicySet) Upsert(policyID PolicyID, policy Policy) { p.policies[policyID] = policy diff --git a/policy_test.go b/policy_test.go index 4f1e9b21..7778ba92 100644 --- a/policy_test.go +++ b/policy_test.go @@ -94,3 +94,17 @@ func TestPolicyAST(t *testing.T) { _ = cedar.NewPolicyFromAST(astExample) } + +func TestUnmarshalJSONPolicyErr(t *testing.T) { + t.Parallel() + var p cedar.Policy + err := p.UnmarshalJSON([]byte("!@#$")) + testutil.Error(t, err) +} + +func TestUnmarshalCedarPolicyErr(t *testing.T) { + t.Parallel() + var p cedar.Policy + err := p.UnmarshalCedar([]byte("!@#$")) + testutil.Error(t, err) +} From bfd35bf96831aaf248fd3e99262b9f9b291064f7 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 21 Aug 2024 09:28:02 -0600 Subject: [PATCH 202/216] cedar: update doc strings Addresses IDX-142 Signed-off-by: philhassey --- policy.go | 6 +++++- policy_set.go => policy_map.go | 1 + policy_set_test.go => policy_map_test.go | 4 ++-- policy_slice.go | 16 ++++++++-------- policy_slice_test.go | 5 ++--- 5 files changed, 18 insertions(+), 14 deletions(-) rename policy_set.go => policy_map.go (97%) rename policy_set_test.go => policy_map_test.go (98%) diff --git a/policy.go b/policy.go index ceab5535..cc1e77ef 100644 --- a/policy.go +++ b/policy.go @@ -56,11 +56,11 @@ func (p *Policy) UnmarshalCedar(b []byte) error { if err := cedarPolicy.UnmarshalCedar(b); err != nil { return err } - *p = newPolicy((*internalast.Policy)(&cedarPolicy)) return nil } +// NewPolicyFromAST lets you create a new policy statement from a programatically created AST. func NewPolicyFromAST(astIn *ast.Policy) Policy { p := newPolicy((*internalast.Policy)(astIn)) return p @@ -70,6 +70,7 @@ func NewPolicyFromAST(astIn *ast.Policy) Policy { // have no impact on policy evaluation. type Annotations map[types.Ident]types.String +// Annotations retrieves the annotations associated with this policy. func (p Policy) Annotations() Annotations { res := make(Annotations, len(p.ast.Annotations)) for _, e := range p.ast.Annotations { @@ -88,6 +89,7 @@ const ( Forbid = Effect(false) ) +// Effect retrieves the effect of this policy. func (p Policy) Effect() Effect { return Effect(p.ast.Effect) } @@ -100,10 +102,12 @@ type Position struct { Column int // column number, starting at 1 (character count per line) } +// Position retrieves the position of this policy. func (p Policy) Position() Position { return Position(p.ast.Position) } +// SetFilename sets the filename of this policy. func (p *Policy) SetFilename(fileName string) { p.ast.Position.Filename = fileName } diff --git a/policy_set.go b/policy_map.go similarity index 97% rename from policy_set.go rename to policy_map.go index a90e8d9c..01a01299 100644 --- a/policy_set.go +++ b/policy_map.go @@ -7,6 +7,7 @@ import ( "slices" ) +// PolicyID is a string identifier for the policy within the PolicySet type PolicyID string type policyMap map[PolicyID]Policy diff --git a/policy_set_test.go b/policy_map_test.go similarity index 98% rename from policy_set_test.go rename to policy_map_test.go index 85413076..255a5768 100644 --- a/policy_set_test.go +++ b/policy_map_test.go @@ -157,8 +157,8 @@ forbid ( resource );` - var policies cedar.Policies - testutil.OK(t, policies.UnmarshalCedar([]byte(policiesStr))) + policies, err := cedar.NewPoliciesFromBytes("", []byte(policiesStr)) + testutil.OK(t, err) ps := cedar.NewPolicySet() for i, p := range policies { diff --git a/policy_slice.go b/policy_slice.go index d2bee17d..5f7470d2 100644 --- a/policy_slice.go +++ b/policy_slice.go @@ -8,14 +8,14 @@ import ( "github.com/cedar-policy/cedar-go/internal/parser" ) -// Policies represents a set of un-named Policy's. Cedar documents, unlike the JSON format, don't have a means of +// PolicyList represents a list of un-named Policy's. Cedar documents, unlike the PolicySet form, don't have a means of // naming individual policies. -type Policies []Policy +type PolicyList []Policy -// NewPoliciesFromBytes will create a PolicySet from the given text document with the given file name used in Position +// NewPoliciesFromBytes will create a Policies from the given text document with the given file name used in Position // data. If there is an error parsing the document, it will be returned. -func NewPoliciesFromBytes(fileName string, document []byte) (Policies, error) { - var policySlice Policies +func NewPoliciesFromBytes(fileName string, document []byte) (PolicyList, error) { + var policySlice PolicyList if err := policySlice.UnmarshalCedar(document); err != nil { return nil, err } @@ -27,7 +27,7 @@ func NewPoliciesFromBytes(fileName string, document []byte) (Policies, error) { // UnmarshalCedar parses a concatenation of un-named Cedar policy statements. Names can be assigned to these policies // when adding them to a PolicySet. -func (p *Policies) UnmarshalCedar(b []byte) error { +func (p *PolicyList) UnmarshalCedar(b []byte) error { var res parser.PolicySlice if err := res.UnmarshalCedar(b); err != nil { return fmt.Errorf("parser error: %w", err) @@ -41,8 +41,8 @@ func (p *Policies) UnmarshalCedar(b []byte) error { return nil } -// MarshalCedar emits a concatenated Cedar representation of a PolicySlice -func (p Policies) MarshalCedar() []byte { +// MarshalCedar emits a concatenated Cedar representation of the policies. +func (p PolicyList) MarshalCedar() []byte { var buf bytes.Buffer for i, policy := range p { buf.Write(policy.MarshalCedar()) diff --git a/policy_slice_test.go b/policy_slice_test.go index 152e47a2..28479ef7 100644 --- a/policy_slice_test.go +++ b/policy_slice_test.go @@ -23,8 +23,7 @@ forbid ( resource );` - var policies cedar.Policies - testutil.OK(t, policies.UnmarshalCedar([]byte(policiesStr))) - + policies, err := cedar.NewPoliciesFromBytes("", []byte(policiesStr)) + testutil.OK(t, err) testutil.Equals(t, string(policies.MarshalCedar()), policiesStr) } From 065925cb3c174322103dccec5561e26171518e27 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 21 Aug 2024 11:00:30 -0600 Subject: [PATCH 203/216] cedar: improve docs Addresses IDX-142 Signed-off-by: philhassey --- policy.go | 1 + policy_map.go | 2 +- policy_map_test.go | 2 +- policy_slice.go | 4 ++-- policy_slice_test.go | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/policy.go b/policy.go index cc1e77ef..ba344f83 100644 --- a/policy.go +++ b/policy.go @@ -61,6 +61,7 @@ func (p *Policy) UnmarshalCedar(b []byte) error { } // NewPolicyFromAST lets you create a new policy statement from a programatically created AST. +// Do not modify the *ast.Policy after passing it into NewPolicyFromAST. func NewPolicyFromAST(astIn *ast.Policy) Policy { p := newPolicy((*internalast.Policy)(astIn)) return p diff --git a/policy_map.go b/policy_map.go index 01a01299..2f0a4244 100644 --- a/policy_map.go +++ b/policy_map.go @@ -28,7 +28,7 @@ func NewPolicySet() PolicySet { // NewPolicySetFromBytes assigns default PolicyIDs to the policies contained in fileName in the format "policy" where // is incremented for each new policy found in the file. func NewPolicySetFromBytes(fileName string, document []byte) (PolicySet, error) { - policySlice, err := NewPoliciesFromBytes(fileName, document) + policySlice, err := NewPolicyListFromBytes(fileName, document) if err != nil { return PolicySet{}, err } diff --git a/policy_map_test.go b/policy_map_test.go index 255a5768..a4e59742 100644 --- a/policy_map_test.go +++ b/policy_map_test.go @@ -157,7 +157,7 @@ forbid ( resource );` - policies, err := cedar.NewPoliciesFromBytes("", []byte(policiesStr)) + policies, err := cedar.NewPolicyListFromBytes("", []byte(policiesStr)) testutil.OK(t, err) ps := cedar.NewPolicySet() diff --git a/policy_slice.go b/policy_slice.go index 5f7470d2..116d4892 100644 --- a/policy_slice.go +++ b/policy_slice.go @@ -12,9 +12,9 @@ import ( // naming individual policies. type PolicyList []Policy -// NewPoliciesFromBytes will create a Policies from the given text document with the given file name used in Position +// NewPolicyListFromBytes will create a Policies from the given text document with the given file name used in Position // data. If there is an error parsing the document, it will be returned. -func NewPoliciesFromBytes(fileName string, document []byte) (PolicyList, error) { +func NewPolicyListFromBytes(fileName string, document []byte) (PolicyList, error) { var policySlice PolicyList if err := policySlice.UnmarshalCedar(document); err != nil { return nil, err diff --git a/policy_slice_test.go b/policy_slice_test.go index 28479ef7..c052206c 100644 --- a/policy_slice_test.go +++ b/policy_slice_test.go @@ -23,7 +23,7 @@ forbid ( resource );` - policies, err := cedar.NewPoliciesFromBytes("", []byte(policiesStr)) + policies, err := cedar.NewPolicyListFromBytes("", []byte(policiesStr)) testutil.OK(t, err) testutil.Equals(t, string(policies.MarshalCedar()), policiesStr) } From 3b3808723af1427f919ea0076a3c8483c18a344a Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 21 Aug 2024 11:30:17 -0600 Subject: [PATCH 204/216] types: improve sugar of NewPattern Addresses IDX-142 Signed-off-by: philhassey --- ast/ast_test.go | 4 +- internal/ast/ast_test.go | 4 +- internal/json/json_test.go | 6 +-- internal/parser/cedar_unmarshal_test.go | 2 +- internal/parser/pattern.go | 4 +- internal/parser/pattern_test.go | 30 +++++++-------- types/pattern.go | 28 +++++--------- types/patttern_test.go | 50 +++++++++++-------------- 8 files changed, 56 insertions(+), 72 deletions(-) diff --git a/ast/ast_test.go b/ast/ast_test.go index afa76c52..886c9275 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -295,8 +295,8 @@ func TestASTByTable(t *testing.T) { }, { "opLike", - ast.Permit().When(ast.Long(42).Like(types.NewPattern(types.Wildcard()))), - internalast.Permit().When(internalast.Long(42).Like(types.NewPattern(types.Wildcard()))), + ast.Permit().When(ast.Long(42).Like(types.NewPattern(types.Wildcard{}))), + internalast.Permit().When(internalast.Long(42).Like(types.NewPattern(types.Wildcard{}))), }, { "opAnd", diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index 68d0d689..acc6ad2f 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -350,9 +350,9 @@ func TestASTByTable(t *testing.T) { }, { "opLike", - ast.Permit().When(ast.Long(42).Like(types.NewPattern(types.Wildcard()))), + ast.Permit().When(ast.Long(42).Like(types.NewPattern(types.Wildcard{}))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLike{Arg: ast.NodeValue{Value: types.Long(42)}, Value: types.NewPattern(types.Wildcard())}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeLike{Arg: ast.NodeValue{Value: types.Long(42)}, Value: types.NewPattern(types.Wildcard{})}}}}, }, { "opAnd", diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 65ce4526..47e6a653 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -399,7 +399,7 @@ func TestUnmarshalJSON(t *testing.T) { "like single wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.NewPattern(types.Wildcard()))), + ast.Permit().When(ast.String("text").Like(types.NewPattern(types.Wildcard{}))), testutil.OK, }, { @@ -413,14 +413,14 @@ func TestUnmarshalJSON(t *testing.T) { "like wildcard then literal", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":["Wildcard", {"Literal":"foo"}]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.NewPattern(types.Wildcard(), types.String("foo")))), + ast.Permit().When(ast.String("text").Like(types.NewPattern(types.Wildcard{}, types.String("foo")))), testutil.OK, }, { "like literal then wildcard", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}, "conditions":[{"kind":"when","body":{"like":{"left":{"Value":"text"},"pattern":[{"Literal":"foo"}, "Wildcard"]}}}]}`, - ast.Permit().When(ast.String("text").Like(types.NewPattern(types.String("foo"), types.Wildcard()))), + ast.Permit().When(ast.String("text").Like(types.NewPattern(types.String("foo"), types.Wildcard{}))), testutil.OK, }, { diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 7d1317ba..743c93b7 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -301,7 +301,7 @@ when { principal.firstName like "joh\*nny" };`, "like wildcard", `permit ( principal, action, resource ) when { principal.firstName like "*" };`, - ast.Permit().When(ast.Principal().Access("firstName").Like(types.NewPattern(types.Wildcard()))), + ast.Permit().When(ast.Principal().Access("firstName").Like(types.NewPattern(types.Wildcard{}))), }, { "is", diff --git a/internal/parser/pattern.go b/internal/parser/pattern.go index 0450bf77..a0e01340 100644 --- a/internal/parser/pattern.go +++ b/internal/parser/pattern.go @@ -8,11 +8,11 @@ import ( // ParsePattern will parse an unquoted rust-style string with \*'s in it. func ParsePattern(v string) (types.Pattern, error) { b := []byte(v) - var comps []types.PatternComponent + var comps []any for len(b) > 0 { for len(b) > 0 && b[0] == '*' { b = b[1:] - comps = append(comps, types.Wildcard()) + comps = append(comps, types.Wildcard{}) } var err error var literal string diff --git a/internal/parser/pattern_test.go b/internal/parser/pattern_test.go index 852d5069..8c7151c1 100644 --- a/internal/parser/pattern_test.go +++ b/internal/parser/pattern_test.go @@ -18,21 +18,21 @@ func TestParsePattern(t *testing.T) { }{ {"", true, types.NewPattern(), ""}, {"a", true, types.NewPattern(a), ""}, - {"*", true, types.NewPattern(types.Wildcard()), ""}, - {"*a", true, types.NewPattern(types.Wildcard(), a), ""}, - {"a*", true, types.NewPattern(a, types.Wildcard()), ""}, - {"**", true, types.NewPattern(types.Wildcard()), ""}, - {"**a", true, types.NewPattern(types.Wildcard(), a), ""}, - {"a**", true, types.NewPattern(a, types.Wildcard()), ""}, - {"*a*", true, types.NewPattern(types.Wildcard(), a, types.Wildcard()), ""}, - {"**a**", true, types.NewPattern(types.Wildcard(), a, types.Wildcard()), ""}, - {"abra*ca", true, types.NewPattern(types.String("abra"), types.Wildcard(), types.String("ca")), ""}, - {"abra**ca", true, types.NewPattern(types.String("abra"), types.Wildcard(), types.String("ca")), ""}, - {"*abra*ca", true, types.NewPattern(types.Wildcard(), types.String("abra"), types.Wildcard(), types.String("ca")), ""}, - {"abra*ca*", true, types.NewPattern(types.String("abra"), types.Wildcard(), types.String("ca"), types.Wildcard()), ""}, - {"*abra*ca*", true, types.NewPattern(types.Wildcard(), types.String("abra"), types.Wildcard(), types.String("ca"), types.Wildcard()), ""}, - {"*abra*ca*dabra", true, types.NewPattern(types.Wildcard(), types.String("abra"), types.Wildcard(), types.String("ca"), types.Wildcard(), types.String("dabra")), ""}, - {`*abra*c\**da\*bra`, true, types.NewPattern(types.Wildcard(), types.String("abra"), types.Wildcard(), types.String("c*"), types.Wildcard(), types.String("da*bra")), ""}, + {"*", true, types.NewPattern(types.Wildcard{}), ""}, + {"*a", true, types.NewPattern(types.Wildcard{}, a), ""}, + {"a*", true, types.NewPattern(a, types.Wildcard{}), ""}, + {"**", true, types.NewPattern(types.Wildcard{}), ""}, + {"**a", true, types.NewPattern(types.Wildcard{}, a), ""}, + {"a**", true, types.NewPattern(a, types.Wildcard{}), ""}, + {"*a*", true, types.NewPattern(types.Wildcard{}, a, types.Wildcard{}), ""}, + {"**a**", true, types.NewPattern(types.Wildcard{}, a, types.Wildcard{}), ""}, + {"abra*ca", true, types.NewPattern(types.String("abra"), types.Wildcard{}, types.String("ca")), ""}, + {"abra**ca", true, types.NewPattern(types.String("abra"), types.Wildcard{}, types.String("ca")), ""}, + {"*abra*ca", true, types.NewPattern(types.Wildcard{}, types.String("abra"), types.Wildcard{}, types.String("ca")), ""}, + {"abra*ca*", true, types.NewPattern(types.String("abra"), types.Wildcard{}, types.String("ca"), types.Wildcard{}), ""}, + {"*abra*ca*", true, types.NewPattern(types.Wildcard{}, types.String("abra"), types.Wildcard{}, types.String("ca"), types.Wildcard{}), ""}, + {"*abra*ca*dabra", true, types.NewPattern(types.Wildcard{}, types.String("abra"), types.Wildcard{}, types.String("ca"), types.Wildcard{}, types.String("dabra")), ""}, + {`*abra*c\**da\*bra`, true, types.NewPattern(types.Wildcard{}, types.String("abra"), types.Wildcard{}, types.String("c*"), types.Wildcard{}, types.String("da*bra")), ""}, {`\u`, false, types.Pattern{}, "bad unicode rune"}, } for _, tt := range tests { diff --git a/types/pattern.go b/types/pattern.go index ba5288e7..48c44fb6 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -21,27 +21,17 @@ type Pattern struct { comps []patternComponent } -// A PatternComponent is either a wildcard (represented as "*" in Cedar text) or a literal string. Note that * -// characters in literal strings are treated as literal asterisks rather than wildcards. -type PatternComponent interface { - isPatternComponent() -} - -type wildcardComponent struct{} - -func (wildcardComponent) isPatternComponent() {} - -func (String) isPatternComponent() {} - -// Wildcard is a constant which can be used to conveniently construct an instance of WildcardPatternComponent -func Wildcard() PatternComponent { return wildcardComponent{} } +// Wildcard is a type which is used as a component to NewPattern. +type Wildcard struct{} -// NewPattern permits for the programmatic construction of a Pattern out of a set of PatternComponents. -func NewPattern(components ...PatternComponent) Pattern { +// NewPattern permits for the programmatic construction of a Pattern out of a slice of pattern components. +// The pattern components may be one of string, types.String, or types.Wildcard. Any other types will +// cause a panic. +func NewPattern(components ...any) Pattern { var comps []patternComponent for _, c := range components { switch v := c.(type) { - case wildcardComponent: + case Wildcard: if len(comps) == 0 || comps[len(comps)-1].Literal != "" { comps = append(comps, patternComponent{Wildcard: true, Literal: ""}) } @@ -173,14 +163,14 @@ func (p *Pattern) UnmarshalJSON(b []byte) error { return fmt.Errorf(`%w: must provide at least one pattern component`, errJSONInvalidPatternComponent) } - var comps []PatternComponent + var comps []any for _, comp := range objs { switch v := comp.(type) { case string: if v != "Wildcard" { return fmt.Errorf(`%w: invalid component string "%v"`, errJSONInvalidPatternComponent, v) } - comps = append(comps, Wildcard()) + comps = append(comps, Wildcard{}) case map[string]any: if len(v) != 1 { return fmt.Errorf(`%w: too many keys in literal object`, errJSONInvalidPatternComponent) diff --git a/types/patttern_test.go b/types/patttern_test.go index 3f5a39ed..b47ef791 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -10,8 +10,8 @@ func TestPattern(t *testing.T) { t.Parallel() t.Run("saturate two wildcards", func(t *testing.T) { t.Parallel() - pattern1 := NewPattern(Wildcard(), Wildcard()) - pattern2 := NewPattern(Wildcard()) + pattern1 := NewPattern(Wildcard{}, Wildcard{}) + pattern2 := NewPattern(Wildcard{}) testutil.Equals(t, pattern1, pattern2) }) t.Run("saturate two literals", func(t *testing.T) { @@ -28,13 +28,7 @@ func TestPattern(t *testing.T) { }) t.Run("MarshalCedar", func(t *testing.T) { t.Parallel() - testutil.Equals(t, string(NewPattern(String("*foo"), Wildcard()).MarshalCedar()), `"\*foo*"`) - }) - - t.Run("isPatternComponent", func(t *testing.T) { - t.Parallel() - String("").isPatternComponent() - Wildcard().isPatternComponent() + testutil.Equals(t, string(NewPattern(String("*foo"), Wildcard{}).MarshalCedar()), `"\*foo*"`) }) } @@ -47,20 +41,20 @@ func TestPatternMatch(t *testing.T) { }{ {NewPattern(), "", true}, {NewPattern(), "hello", false}, - {NewPattern(Wildcard()), "hello", true}, + {NewPattern(Wildcard{}), "hello", true}, {NewPattern(String("e")), "hello", false}, - {NewPattern(Wildcard(), String("e")), "hello", false}, - {NewPattern(Wildcard(), String("e"), Wildcard()), "hello", true}, + {NewPattern(Wildcard{}, String("e")), "hello", false}, + {NewPattern(Wildcard{}, String("e"), Wildcard{}), "hello", true}, {NewPattern(String("hello")), "hello", true}, - {NewPattern(String("hello"), Wildcard()), "hello", true}, - {NewPattern(Wildcard(), String("h"), Wildcard(), String("llo"), Wildcard()), "hello", true}, - {NewPattern(String("h"), Wildcard(), String("e"), Wildcard(), String("o")), "hello", true}, - {NewPattern(String("h"), Wildcard(), String("e"), Wildcard(), Wildcard(), String("o")), "hello", true}, - {NewPattern(String("h"), Wildcard(), String("z"), Wildcard(), String("o")), "hello", false}, + {NewPattern(String("hello"), Wildcard{}), "hello", true}, + {NewPattern(Wildcard{}, String("h"), Wildcard{}, String("llo"), Wildcard{}), "hello", true}, + {NewPattern(String("h"), Wildcard{}, String("e"), Wildcard{}, String("o")), "hello", true}, + {NewPattern(String("h"), Wildcard{}, String("e"), Wildcard{}, Wildcard{}, String("o")), "hello", true}, + {NewPattern(String("h"), Wildcard{}, String("z"), Wildcard{}, String("o")), "hello", false}, - {NewPattern(String("**"), Wildcard(), String("**")), "**foo**", true}, - {NewPattern(String("**"), Wildcard(), String("**")), "**bar**", true}, - {NewPattern(String("**"), Wildcard(), String("**")), "*bar*", false}, + {NewPattern(String("**"), Wildcard{}, String("**")), "**foo**", true}, + {NewPattern(String("**"), Wildcard{}, String("**")), "**bar**", true}, + {NewPattern(String("**"), Wildcard{}, String("**")), "*bar*", false}, } for _, tt := range tests { tt := tt @@ -85,7 +79,7 @@ func TestPatternJSON(t *testing.T) { "like single wildcard", `["Wildcard"]`, testutil.OK, - NewPattern(Wildcard()), + NewPattern(Wildcard{}), true, }, { @@ -99,49 +93,49 @@ func TestPatternJSON(t *testing.T) { "like wildcard then literal", `["Wildcard", {"Literal":"foo"}]`, testutil.OK, - NewPattern(Wildcard(), String("foo")), + NewPattern(Wildcard{}, String("foo")), true, }, { "like literal then wildcard", `[{"Literal":"foo"}, "Wildcard"]`, testutil.OK, - NewPattern(String("foo"), Wildcard()), + NewPattern(String("foo"), Wildcard{}), true, }, { "like literal with asterisk then wildcard", `[{"Literal":"f*oo"}, "Wildcard"]`, testutil.OK, - NewPattern(String("f*oo"), Wildcard()), + NewPattern(String("f*oo"), Wildcard{}), true, }, { "like literal sandwich", `[{"Literal":"foo"}, "Wildcard", {"Literal":"bar"}]`, testutil.OK, - NewPattern(String("foo"), Wildcard(), String("bar")), + NewPattern(String("foo"), Wildcard{}, String("bar")), true, }, { "like wildcard sandwich", `["Wildcard", {"Literal":"foo"}, "Wildcard"]`, testutil.OK, - NewPattern(Wildcard(), String("foo"), Wildcard()), + NewPattern(Wildcard{}, String("foo"), Wildcard{}), true, }, { "double wildcard", `["Wildcard", "Wildcard", {"Literal":"foo"}]`, testutil.OK, - NewPattern(Wildcard(), String("foo")), + NewPattern(Wildcard{}, String("foo")), false, }, { "double literal", `["Wildcard", {"Literal":"foo"}, {"Literal":"bar"}]`, testutil.OK, - NewPattern(Wildcard(), String("foobar")), + NewPattern(Wildcard{}, String("foobar")), false, }, { From bb441b5c2432d81c9bd899a9dd859c28846432fd Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 21 Aug 2024 11:32:02 -0600 Subject: [PATCH 205/216] types: add coverage for string in NewPattern Addresses IDX-142 Signed-off-by: philhassey --- types/pattern.go | 11 ++++++++--- types/patttern_test.go | 4 ++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/types/pattern.go b/types/pattern.go index 48c44fb6..99add627 100644 --- a/types/pattern.go +++ b/types/pattern.go @@ -31,15 +31,20 @@ func NewPattern(components ...any) Pattern { var comps []patternComponent for _, c := range components { switch v := c.(type) { - case Wildcard: - if len(comps) == 0 || comps[len(comps)-1].Literal != "" { - comps = append(comps, patternComponent{Wildcard: true, Literal: ""}) + case string: + if len(comps) == 0 { + comps = []patternComponent{{Wildcard: false, Literal: ""}} } + comps[len(comps)-1].Literal += string(v) case String: if len(comps) == 0 { comps = []patternComponent{{Wildcard: false, Literal: ""}} } comps[len(comps)-1].Literal += string(v) + case Wildcard: + if len(comps) == 0 || comps[len(comps)-1].Literal != "" { + comps = append(comps, patternComponent{Wildcard: true, Literal: ""}) + } default: panic(fmt.Sprintf("unexpected component type: %T", v)) } diff --git a/types/patttern_test.go b/types/patttern_test.go index b47ef791..4ec6ea84 100644 --- a/types/patttern_test.go +++ b/types/patttern_test.go @@ -55,6 +55,10 @@ func TestPatternMatch(t *testing.T) { {NewPattern(String("**"), Wildcard{}, String("**")), "**foo**", true}, {NewPattern(String("**"), Wildcard{}, String("**")), "**bar**", true}, {NewPattern(String("**"), Wildcard{}, String("**")), "*bar*", false}, + + // with native strings + {NewPattern(Wildcard{}, "ell", Wildcard{}), "hello", true}, + {NewPattern("he", Wildcard{}), "hello", true}, } for _, tt := range tests { tt := tt From 58767268303af6ff61b771f431874597351be67f Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 21 Aug 2024 11:41:38 -0600 Subject: [PATCH 206/216] cedar: add Map method to PolicySet Signed-off-by: philhassey --- authorize_test.go | 2 +- policy_map.go | 28 ++++++++--------- policy_map_test.go | 75 ++++++++++------------------------------------ 3 files changed, 30 insertions(+), 75 deletions(-) diff --git a/authorize_test.go b/authorize_test.go index f9ba8d11..00ccfe88 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -741,7 +741,7 @@ func TestBadEval(t *testing.T) { ps := NewPolicySet() pol := NewPolicyFromAST(ast.Permit()) pol.eval = &badEvaler{} - ps.Upsert("pol", pol) + ps.Set("pol", pol) dec, diag := ps.IsAuthorized(nil, Request{}) testutil.Equals(t, dec, Deny) testutil.Equals(t, len(diag.Errors), 1) diff --git a/policy_map.go b/policy_map.go index 2f0a4244..8b059f9f 100644 --- a/policy_map.go +++ b/policy_map.go @@ -4,22 +4,25 @@ package cedar import ( "bytes" "fmt" + "maps" "slices" ) // PolicyID is a string identifier for the policy within the PolicySet type PolicyID string -type policyMap map[PolicyID]Policy +// PolicyMap is a map of policy IDs to policy +type PolicyMap map[PolicyID]Policy // PolicySet is a set of named policies against which a request can be authorized. type PolicySet struct { - policies policyMap + // policies are stored internally so we can handle performance, concurrency bookkeeping however we want + policies PolicyMap } // NewPolicySet creates a new, empty PolicySet func NewPolicySet() PolicySet { - return PolicySet{policies: policyMap{}} + return PolicySet{policies: PolicyMap{}} } // NewPolicySetFromBytes will create a PolicySet from the given text document with the given file name used in Position @@ -32,7 +35,7 @@ func NewPolicySetFromBytes(fileName string, document []byte) (PolicySet, error) if err != nil { return PolicySet{}, err } - policyMap := make(policyMap, len(policySlice)) + policyMap := make(PolicyMap, len(policySlice)) for i, p := range policySlice { policyID := PolicyID(fmt.Sprintf("policy%d", i)) policyMap[policyID] = p @@ -40,13 +43,13 @@ func NewPolicySetFromBytes(fileName string, document []byte) (PolicySet, error) return PolicySet{policies: policyMap}, nil } -// Get returns a pointer to the Policy with the given ID. If a policy with the given ID does not exist, an empty policy is returned. +// Get returns the Policy with the given ID. If a policy with the given ID does not exist, an empty policy is returned. func (p PolicySet) Get(policyID PolicyID) Policy { return p.policies[policyID] } -// Upsert inserts or updates a policy with the given ID. -func (p *PolicySet) Upsert(policyID PolicyID, policy Policy) { +// Set inserts or updates a policy with the given ID. +func (p *PolicySet) Set(policyID PolicyID, policy Policy) { p.policies[policyID] = policy } @@ -55,13 +58,10 @@ func (p *PolicySet) Delete(policyID PolicyID) { delete(p.policies, policyID) } -// // UpsertPolicySet inserts or updates all the policies from src into this PolicySet. Policies in this PolicySet with -// // identical IDs in src are clobbered by the policies from src. -// func (p *PolicySet) UpsertPolicySet(src PolicySet) { -// for id, policy := range src.policies { -// p.policies[id] = policy -// } -// } +// Map returns a new PolicyMap instance of the policies in the PolicySet. +func (p *PolicySet) Map() PolicyMap { + return maps.Clone(p.policies) +} // MarshalCedar emits a concatenated Cedar representation of a PolicySet. The policy names are stripped, but policies // are emitted in lexicographical order by ID. diff --git a/policy_map_test.go b/policy_map_test.go index a4e59742..e90ec203 100644 --- a/policy_map_test.go +++ b/policy_map_test.go @@ -42,8 +42,8 @@ func TestUpsertPolicy(t *testing.T) { )) ps := cedar.NewPolicySet() - ps.Upsert("policy0", policy0) - ps.Upsert("policy1", policy1) + ps.Set("policy0", policy0) + ps.Set("policy1", policy1) testutil.Equals(t, ps.Get("policy0"), policy0) testutil.Equals(t, ps.Get("policy1"), policy1) @@ -55,69 +55,15 @@ func TestUpsertPolicy(t *testing.T) { ps := cedar.NewPolicySet() p1 := cedar.NewPolicyFromAST(ast.Forbid()) - ps.Upsert("a wavering policy", p1) + ps.Set("a wavering policy", p1) p2 := cedar.NewPolicyFromAST(ast.Permit()) - ps.Upsert("a wavering policy", p2) + ps.Set("a wavering policy", p2) testutil.Equals(t, ps.Get("a wavering policy"), p2) }) } -// func TestUpsertPolicySet(t *testing.T) { -// t.Parallel() -// t.Run("empty dst", func(t *testing.T) { -// t.Parallel() - -// policy0 := cedar.NewPolicyFromAST(ast.Forbid()) - -// var policy1 cedar.Policy -// testutil.OK(t, policy1.UnmarshalJSON( -// []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), -// )) - -// ps1 := cedar.NewPolicySet() -// ps1.Upsert("policy0", policy0) -// ps1.Upsert("policy1", policy1) - -// ps2 := cedar.NewPolicySet() -// ps2.UpsertPolicySet(ps1) - -// testutil.Equals(t, ps2.Get("policy0"), policy0) -// testutil.Equals(t, ps2.Get("policy1"), &policy1) -// testutil.Equals(t, ps2.Get("policy2"), nil) -// }) -// t.Run("upsert", func(t *testing.T) { -// t.Parallel() - -// policyA := cedar.NewPolicyFromAST(ast.Forbid()) - -// var policyB cedar.Policy -// testutil.OK(t, policyB.UnmarshalJSON( -// []byte(`{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}`), -// )) - -// policyC := cedar.NewPolicyFromAST(ast.Permit()) - -// // ps1 maps 0 -> A and 1 -> B -// ps1 := cedar.NewPolicySet() -// ps1.Upsert("policy0", policyA) -// ps1.Upsert("policy1", &policyB) - -// // ps1 maps 0 -> b and 2 -> C -// ps2 := cedar.NewPolicySet() -// ps2.Upsert("policy0", &policyB) -// ps2.Upsert("policy2", policyC) - -// // Upsert should clobber ps2's policy0, insert policy1, and leave policy2 untouched -// ps2.UpsertPolicySet(ps1) - -// testutil.Equals(t, ps2.Get("policy0"), policyA) -// testutil.Equals(t, ps2.Get("policy1"), &policyB) -// testutil.Equals(t, ps2.Get("policy2"), policyC) -// }) -// } - func TestDeletePolicy(t *testing.T) { t.Parallel() t.Run("delete non-existent", func(t *testing.T) { @@ -134,7 +80,7 @@ func TestDeletePolicy(t *testing.T) { ps := cedar.NewPolicySet() p1 := cedar.NewPolicyFromAST(ast.Forbid()) - ps.Upsert("a policy", p1) + ps.Set("a policy", p1) ps.Delete("a policy") testutil.Equals(t, ps.Get("a policy"), cedar.Policy{}) @@ -163,11 +109,20 @@ forbid ( ps := cedar.NewPolicySet() for i, p := range policies { p.SetFilename("example.cedar") - ps.Upsert(cedar.PolicyID(fmt.Sprintf("policy%d", i)), p) + ps.Set(cedar.PolicyID(fmt.Sprintf("policy%d", i)), p) } testutil.Equals(t, ps.Get("policy0").Effect(), cedar.Permit) testutil.Equals(t, ps.Get("policy1").Effect(), cedar.Forbid) testutil.Equals(t, string(ps.MarshalCedar()), policiesStr) + +} + +func TestPolicyMap(t *testing.T) { + t.Parallel() + ps, err := cedar.NewPolicySetFromBytes("", []byte(`permit (principal, action, resource);`)) + testutil.OK(t, err) + m := ps.Map() + testutil.Equals(t, len(m), 1) } From 1365afd56d08983b57ca426232411453839d9a2a Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 21 Aug 2024 12:09:31 -0600 Subject: [PATCH 207/216] cedar: add policy set JSON methods Addresses IDX-142 Signed-off-by: philhassey --- internal/json/policy_set.go | 7 ++++++ policy_map.go | 46 ++++++++++++++++++++++++++++++++----- policy_map_test.go | 26 +++++++++++++++++++++ 3 files changed, 73 insertions(+), 6 deletions(-) create mode 100644 internal/json/policy_set.go diff --git a/internal/json/policy_set.go b/internal/json/policy_set.go new file mode 100644 index 00000000..39ec3c5b --- /dev/null +++ b/internal/json/policy_set.go @@ -0,0 +1,7 @@ +package json + +type PolicySet map[string]*Policy + +type PolicySetJSON struct { + StaticPolicies PolicySet `json:"staticPolicies"` +} diff --git a/policy_map.go b/policy_map.go index 8b059f9f..58a7fd1f 100644 --- a/policy_map.go +++ b/policy_map.go @@ -3,9 +3,13 @@ package cedar import ( "bytes" + "encoding/json" "fmt" "maps" "slices" + + internalast "github.com/cedar-policy/cedar-go/internal/ast" + internaljson "github.com/cedar-policy/cedar-go/internal/json" ) // PolicyID is a string identifier for the policy within the PolicySet @@ -21,8 +25,8 @@ type PolicySet struct { } // NewPolicySet creates a new, empty PolicySet -func NewPolicySet() PolicySet { - return PolicySet{policies: PolicyMap{}} +func NewPolicySet() *PolicySet { + return &PolicySet{policies: PolicyMap{}} } // NewPolicySetFromBytes will create a PolicySet from the given text document with the given file name used in Position @@ -30,17 +34,17 @@ func NewPolicySet() PolicySet { // // NewPolicySetFromBytes assigns default PolicyIDs to the policies contained in fileName in the format "policy" where // is incremented for each new policy found in the file. -func NewPolicySetFromBytes(fileName string, document []byte) (PolicySet, error) { +func NewPolicySetFromBytes(fileName string, document []byte) (*PolicySet, error) { policySlice, err := NewPolicyListFromBytes(fileName, document) if err != nil { - return PolicySet{}, err + return &PolicySet{}, err } policyMap := make(PolicyMap, len(policySlice)) for i, p := range policySlice { policyID := PolicyID(fmt.Sprintf("policy%d", i)) policyMap[policyID] = p } - return PolicySet{policies: policyMap}, nil + return &PolicySet{policies: policyMap}, nil } // Get returns the Policy with the given ID. If a policy with the given ID does not exist, an empty policy is returned. @@ -65,7 +69,7 @@ func (p *PolicySet) Map() PolicyMap { // MarshalCedar emits a concatenated Cedar representation of a PolicySet. The policy names are stripped, but policies // are emitted in lexicographical order by ID. -func (p PolicySet) MarshalCedar() []byte { +func (p *PolicySet) MarshalCedar() []byte { ids := make([]PolicyID, 0, len(p.policies)) for k := range p.policies { ids = append(ids, k) @@ -85,3 +89,33 @@ func (p PolicySet) MarshalCedar() []byte { } return buf.Bytes() } + +// MarshalJSON encodes a PolicySet in the JSON format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html +func (p *PolicySet) MarshalJSON() ([]byte, error) { + jsonPolicySet := internaljson.PolicySetJSON{ + StaticPolicies: make(internaljson.PolicySet, len(p.policies)), + } + for k, v := range p.policies { + jsonPolicySet.StaticPolicies[string(k)] = (*internaljson.Policy)(v.ast) + } + return json.Marshal(jsonPolicySet) +} + +// UnmarshalJSON parses and compiles a PolicySet in the JSON format specified by the [Cedar documentation]. +// +// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html +func (p *PolicySet) UnmarshalJSON(b []byte) error { + var jsonPolicySet internaljson.PolicySetJSON + if err := json.Unmarshal(b, &jsonPolicySet); err != nil { + return err + } + *p = PolicySet{ + policies: make(PolicyMap, len(jsonPolicySet.StaticPolicies)), + } + for k, v := range jsonPolicySet.StaticPolicies { + p.policies[PolicyID(k)] = newPolicy((*internalast.Policy)(v)) + } + return nil +} diff --git a/policy_map_test.go b/policy_map_test.go index e90ec203..c0d8652e 100644 --- a/policy_map_test.go +++ b/policy_map_test.go @@ -126,3 +126,29 @@ func TestPolicyMap(t *testing.T) { m := ps.Map() testutil.Equals(t, len(m), 1) } + +func TestPolicySetJSON(t *testing.T) { + t.Parallel() + t.Run("UnmarshalError", func(t *testing.T) { + t.Parallel() + var ps cedar.PolicySet + err := ps.UnmarshalJSON([]byte(`!@#$`)) + testutil.Error(t, err) + }) + t.Run("UnmarshalOK", func(t *testing.T) { + t.Parallel() + var ps cedar.PolicySet + err := ps.UnmarshalJSON([]byte(`{"staticPolicies":{"policy0":{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}}}`)) + testutil.OK(t, err) + testutil.Equals(t, len(ps.Map()), 1) + }) + + t.Run("MarshalOK", func(t *testing.T) { + t.Parallel() + ps, err := cedar.NewPolicySetFromBytes("", []byte(`permit (principal, action, resource);`)) + testutil.OK(t, err) + out, err := ps.MarshalJSON() + testutil.OK(t, err) + testutil.Equals(t, string(out), `{"staticPolicies":{"policy0":{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"All"}}}}`) + }) +} From 16d08e552ae6bc63360aed48195bddf8eab0b564 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 21 Aug 2024 12:19:52 -0600 Subject: [PATCH 208/216] cedar: minor shape/doc improvements Addresses IDX-142 Signed-off-by: philhassey --- policy.go | 16 ++++++++-------- policy_map.go | 8 ++++---- policy_map_test.go | 8 ++++---- policy_slice.go | 4 ++-- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/policy.go b/policy.go index ba344f83..8e8d74a9 100644 --- a/policy.go +++ b/policy.go @@ -17,8 +17,8 @@ type Policy struct { ast *internalast.Policy } -func newPolicy(astIn *internalast.Policy) Policy { - return Policy{eval: eval.Compile(astIn), ast: astIn} +func newPolicy(astIn *internalast.Policy) *Policy { + return &Policy{eval: eval.Compile(astIn), ast: astIn} } // MarshalJSON encodes a single Policy statement in the JSON format specified by the [Cedar documentation]. @@ -38,7 +38,7 @@ func (p *Policy) UnmarshalJSON(b []byte) error { return err } - *p = newPolicy((*internalast.Policy)(&jsonPolicy)) + *p = *newPolicy((*internalast.Policy)(&jsonPolicy)) return nil } @@ -56,13 +56,13 @@ func (p *Policy) UnmarshalCedar(b []byte) error { if err := cedarPolicy.UnmarshalCedar(b); err != nil { return err } - *p = newPolicy((*internalast.Policy)(&cedarPolicy)) + *p = *newPolicy((*internalast.Policy)(&cedarPolicy)) return nil } // NewPolicyFromAST lets you create a new policy statement from a programatically created AST. // Do not modify the *ast.Policy after passing it into NewPolicyFromAST. -func NewPolicyFromAST(astIn *ast.Policy) Policy { +func NewPolicyFromAST(astIn *ast.Policy) *Policy { p := newPolicy((*internalast.Policy)(astIn)) return p } @@ -72,7 +72,7 @@ func NewPolicyFromAST(astIn *ast.Policy) Policy { type Annotations map[types.Ident]types.String // Annotations retrieves the annotations associated with this policy. -func (p Policy) Annotations() Annotations { +func (p *Policy) Annotations() Annotations { res := make(Annotations, len(p.ast.Annotations)) for _, e := range p.ast.Annotations { res[e.Key] = e.Value @@ -91,7 +91,7 @@ const ( ) // Effect retrieves the effect of this policy. -func (p Policy) Effect() Effect { +func (p *Policy) Effect() Effect { return Effect(p.ast.Effect) } @@ -104,7 +104,7 @@ type Position struct { } // Position retrieves the position of this policy. -func (p Policy) Position() Position { +func (p *Policy) Position() Position { return Position(p.ast.Position) } diff --git a/policy_map.go b/policy_map.go index 58a7fd1f..2dfc3289 100644 --- a/policy_map.go +++ b/policy_map.go @@ -16,7 +16,7 @@ import ( type PolicyID string // PolicyMap is a map of policy IDs to policy -type PolicyMap map[PolicyID]Policy +type PolicyMap map[PolicyID]*Policy // PolicySet is a set of named policies against which a request can be authorized. type PolicySet struct { @@ -47,13 +47,13 @@ func NewPolicySetFromBytes(fileName string, document []byte) (*PolicySet, error) return &PolicySet{policies: policyMap}, nil } -// Get returns the Policy with the given ID. If a policy with the given ID does not exist, an empty policy is returned. -func (p PolicySet) Get(policyID PolicyID) Policy { +// Get returns the Policy with the given ID. If a policy with the given ID does not exist, nil is returned. +func (p PolicySet) Get(policyID PolicyID) *Policy { return p.policies[policyID] } // Set inserts or updates a policy with the given ID. -func (p *PolicySet) Set(policyID PolicyID, policy Policy) { +func (p *PolicySet) Set(policyID PolicyID, policy *Policy) { p.policies[policyID] = policy } diff --git a/policy_map_test.go b/policy_map_test.go index c0d8652e..0ca6842b 100644 --- a/policy_map_test.go +++ b/policy_map_test.go @@ -43,11 +43,11 @@ func TestUpsertPolicy(t *testing.T) { ps := cedar.NewPolicySet() ps.Set("policy0", policy0) - ps.Set("policy1", policy1) + ps.Set("policy1", &policy1) testutil.Equals(t, ps.Get("policy0"), policy0) - testutil.Equals(t, ps.Get("policy1"), policy1) - testutil.Equals(t, ps.Get("policy2"), cedar.Policy{}) + testutil.Equals(t, ps.Get("policy1"), &policy1) + testutil.Equals(t, ps.Get("policy2"), nil) }) t.Run("upsert", func(t *testing.T) { t.Parallel() @@ -83,7 +83,7 @@ func TestDeletePolicy(t *testing.T) { ps.Set("a policy", p1) ps.Delete("a policy") - testutil.Equals(t, ps.Get("a policy"), cedar.Policy{}) + testutil.Equals(t, ps.Get("a policy"), nil) }) } diff --git a/policy_slice.go b/policy_slice.go index 116d4892..225dec1b 100644 --- a/policy_slice.go +++ b/policy_slice.go @@ -10,7 +10,7 @@ import ( // PolicyList represents a list of un-named Policy's. Cedar documents, unlike the PolicySet form, don't have a means of // naming individual policies. -type PolicyList []Policy +type PolicyList []*Policy // NewPolicyListFromBytes will create a Policies from the given text document with the given file name used in Position // data. If there is an error parsing the document, it will be returned. @@ -32,7 +32,7 @@ func (p *PolicyList) UnmarshalCedar(b []byte) error { if err := res.UnmarshalCedar(b); err != nil { return fmt.Errorf("parser error: %w", err) } - policySlice := make([]Policy, 0, len(res)) + policySlice := make([]*Policy, 0, len(res)) for _, p := range res { newPolicy := newPolicy((*internalast.Policy)(p)) policySlice = append(policySlice, newPolicy) From 971546b902ac37a92d89693f45439f8e180c842e Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 21 Aug 2024 12:24:47 -0600 Subject: [PATCH 209/216] types: improve docs in types Addresses IDX-142 Signed-off-by: philhassey --- types/ident.go | 1 + types/path.go | 6 ------ types/path_test.go | 18 ------------------ types/value.go | 1 + 4 files changed, 2 insertions(+), 24 deletions(-) delete mode 100644 types/path_test.go diff --git a/types/ident.go b/types/ident.go index 45d4e334..23a65160 100644 --- a/types/ident.go +++ b/types/ident.go @@ -1,3 +1,4 @@ package types +// Ident is the type for a single unquoted identifier in cedar, e.g. in `context.key`, `key` is an ident. type Ident string diff --git a/types/path.go b/types/path.go index 8c53c402..522d7b96 100644 --- a/types/path.go +++ b/types/path.go @@ -1,10 +1,4 @@ package types -import "strings" - // Path is the type portion of an EntityUID type Path string - -func PathFromSlice(v []string) Path { - return Path(strings.Join(v, "::")) -} diff --git a/types/path_test.go b/types/path_test.go deleted file mode 100644 index efc3f4f2..00000000 --- a/types/path_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package types_test - -import ( - "testing" - - "github.com/cedar-policy/cedar-go/internal/testutil" - "github.com/cedar-policy/cedar-go/types" -) - -func TestEntityType(t *testing.T) { - t.Parallel() - t.Run("pathFromSlice", func(t *testing.T) { - t.Parallel() - a := types.PathFromSlice([]string{"X", "Y"}) - testutil.Equals(t, a, types.Path("X::Y")) - }) - -} diff --git a/types/value.go b/types/value.go index ee9373eb..46ebb5d5 100644 --- a/types/value.go +++ b/types/value.go @@ -7,6 +7,7 @@ import ( var ErrDecimal = fmt.Errorf("error parsing decimal value") var ErrIP = fmt.Errorf("error parsing ip value") +// Value defines the interface for all Cedar values (String, Long, Set, Record, Boolean, etc ...) type Value interface { // String produces a string representation of the Value. String() string From 0a1e70ab6bce1aad7dc728050d6b2bf04b057482 Mon Sep 17 00:00:00 2001 From: philhassey Date: Wed, 21 Aug 2024 13:59:38 -0600 Subject: [PATCH 210/216] ast: improve AST docs Addresses IDX-142 Signed-off-by: philhassey --- ast/annotation.go | 4 +++- ast/node.go | 3 +++ ast/policy.go | 4 ++++ ast/scope.go | 22 +++++++++++----------- ast/value.go | 11 ++++++++++- 5 files changed, 31 insertions(+), 13 deletions(-) diff --git a/ast/annotation.go b/ast/annotation.go index edb87daf..eaa67138 100644 --- a/ast/annotation.go +++ b/ast/annotation.go @@ -26,15 +26,17 @@ func Annotation(key types.Ident, value types.String) *Annotations { return wrapAnnotations(ast.Annotation(key, value)) } -// If a previous annotation exists with the same key, this builder will replace it. +// Annotation adds an annotation. If a previous annotation exists with the same key, this builder will replace it. func (a *Annotations) Annotation(key types.Ident, value types.String) *Annotations { return wrapAnnotations(a.unwrap().Annotation(key, value)) } +// Permit begins a permit policy from the given annotations. func (a *Annotations) Permit() *Policy { return wrapPolicy(a.unwrap().Permit()) } +// Forbid begins a forbid policy from the given annotations. func (a *Annotations) Forbid() *Policy { return wrapPolicy(a.unwrap().Forbid()) } diff --git a/ast/node.go b/ast/node.go index 5f7a74da..8a9e804e 100644 --- a/ast/node.go +++ b/ast/node.go @@ -2,6 +2,9 @@ package ast import "github.com/cedar-policy/cedar-go/internal/ast" +// Node is a wrapper type for all the Cedar language operators. See the [Cedar operators documentation] for details. +// +// [Cedar operators documentation]: https://docs.cedarpolicy.com/policies/syntax-operators.html type Node struct { ast.Node } diff --git a/ast/policy.go b/ast/policy.go index 62184bf0..95bcf699 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -12,18 +12,22 @@ func (p *Policy) unwrap() *ast.Policy { return (*ast.Policy)(p) } +// Permit creates a new Permit policy. func Permit() *Policy { return wrapPolicy(ast.Permit()) } +// Forbid creates a new Forbid policy. func Forbid() *Policy { return wrapPolicy(ast.Forbid()) } +// When adds a conditional clause. func (p *Policy) When(node Node) *Policy { return wrapPolicy(p.unwrap().When(node.Node)) } +// Unless adds a conditional clause. func (p *Policy) Unless(node Node) *Policy { return wrapPolicy(p.unwrap().Unless(node.Node)) } diff --git a/ast/scope.go b/ast/scope.go index 7141105d..1ed76a48 100644 --- a/ast/scope.go +++ b/ast/scope.go @@ -4,57 +4,57 @@ import ( "github.com/cedar-policy/cedar-go/types" ) -// This builder will replace the previous principal scope condition. +// PrincipalEq replaces the principal scope condition. func (p *Policy) PrincipalEq(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().PrincipalEq(entity)) } -// This builder will replace the previous principal scope condition. +// PrincipalIn replaces the principal scope condition. func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().PrincipalIn(entity)) } -// This builder will replace the previous principal scope condition. +// PrincipalIs replaces the principal scope condition. func (p *Policy) PrincipalIs(entityType types.Path) *Policy { return wrapPolicy(p.unwrap().PrincipalIs(entityType)) } -// This builder will replace the previous principal scope condition. +// PrincipalIsIn replaces the principal scope condition. func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().PrincipalIsIn(entityType, entity)) } -// This builder will replace the previous action scope condition. +// ActionEq replaces the action scope condition. func (p *Policy) ActionEq(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ActionEq(entity)) } -// This builder will replace the previous action scope condition. +// ActionIn replaces the action scope condition. func (p *Policy) ActionIn(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ActionIn(entity)) } -// This builder will replace the previous action scope condition. +// ActionInSet replaces the action scope condition. func (p *Policy) ActionInSet(entities ...types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ActionInSet(entities...)) } -// This builder will replace the previous resource scope condition. +// ResourceEq replaces the resource scope condition. func (p *Policy) ResourceEq(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ResourceEq(entity)) } -// This builder will replace the previous resource scope condition. +// ResourceIn replaces the resource scope condition. func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ResourceIn(entity)) } -// This builder will replace the previous resource scope condition. +// ResourceIs replaces the resource scope condition. func (p *Policy) ResourceIs(entityType types.Path) *Policy { return wrapPolicy(p.unwrap().ResourceIs(entityType)) } -// This builder will replace the previous resource scope condition. +// ResourceIsIn replaces the resource scope condition. func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ResourceIsIn(entityType, entity)) } diff --git a/ast/value.go b/ast/value.go index ab27ae5d..3399cc0c 100644 --- a/ast/value.go +++ b/ast/value.go @@ -7,22 +7,27 @@ import ( "github.com/cedar-policy/cedar-go/types" ) +// Boolean creates a value node containing a Boolean. func Boolean[T bool | types.Boolean](b T) Node { return wrapNode(ast.Boolean(types.Boolean(b))) } +// True creates a value node containing True. func True() Node { return Boolean(true) } +// False creates a value node containing False. func False() Node { return Boolean(false) } +// String creates a value node containing a String. func String[T string | types.String](s T) Node { return wrapNode(ast.String(types.String(s))) } +// Long creates a value node containing a Long. func Long[T int | int64 | types.Long](l T) Node { return wrapNode(ast.Long(types.Long(l))) } @@ -47,6 +52,7 @@ func Set(nodes ...Node) Node { return wrapNode(ast.Set(astNodes...)) } +// Pair is map of Key string to Value node. type Pair struct { Key types.String Value Node @@ -54,7 +60,7 @@ type Pair struct { type Pairs []Pair -// In the case where duplicate keys exist, the latter value will be preserved. +// Record creates a record node. In the case where duplicate keys exist, the latter value will be preserved. func Record(elements Pairs) Node { var astNodes []ast.Pair for _, v := range elements { @@ -63,14 +69,17 @@ func Record(elements Pairs) Node { return wrapNode(ast.Record(astNodes)) } +// EntityUID creates a value node containing an EntityUID. func EntityUID(typ types.Ident, id types.String) Node { return wrapNode(ast.EntityUID(typ, id)) } +// IPAddr creates an value node containing an IPAddr. func IPAddr[T netip.Prefix | types.IPAddr](i T) Node { return wrapNode(ast.IPAddr(types.IPAddr(i))) } +// Value creates a value node from any value. func Value(v types.Value) Node { return wrapNode(ast.Value(v)) } From ed6ba43470ada36d53d49efc6a3ffa1d21e4073f Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 22 Aug 2024 15:53:59 -0600 Subject: [PATCH 211/216] cedar: add tmp to git ignore Addresses IDX-142 Signed-off-by: philhassey --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 9f11b755..f498092c 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .idea/ +tmp/ \ No newline at end of file From 0167cffe7365cbb94b76b17f675d5377a3b60edb Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 22 Aug 2024 16:03:01 -0600 Subject: [PATCH 212/216] cedar: update README Addresses IDX-146 Signed-off-by: philhassey --- README.md | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index efd8992b..3bf0d74f 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ import ( "log" cedar "github.com/cedar-policy/cedar-go" + "github.com/cedar-policy/cedar-go/types" ) const policyCedar = `permit ( @@ -84,19 +85,19 @@ func main() { if err := policy.UnmarshalCedar([]byte(policyCedar)); err != nil { log.Fatal(err) } - + ps := cedar.NewPolicySet() - ps.UpsertPolicy("policy0", &policy) + ps.Set("policy0", &policy) - var entities cedar.Entities + var entities types.Entities if err := json.Unmarshal([]byte(entitiesJSON), &entities); err != nil { log.Fatal(err) } req := cedar.Request{ - Principal: cedar.EntityUID{Type: "User", ID: "alice"}, - Action: cedar.EntityUID{Type: "Action", ID: "view"}, - Resource: cedar.EntityUID{Type: "Photo", ID: "VacationPhoto94.jpg"}, - Context: cedar.Record{}, + Principal: types.EntityUID{Type: "User", ID: "alice"}, + Action: types.EntityUID{Type: "Action", ID: "view"}, + Resource: types.EntityUID{Type: "Photo", ID: "VacationPhoto94.jpg"}, + Context: types.Record{}, } ok, _ := ps.IsAuthorized(entities, req) @@ -128,6 +129,18 @@ If you're looking to integrate Cedar into a production system, please be sure th x/exp - code in this subrepository is not subject to the Go 1 compatibility promise. +While in development (0.x.y), each tagged release may contain breaking changes. + +### Upgrading from 0.1.x to 0.2.x + +- The Cedar value types have moved from the `cedar` package to the `types` package. +- The PolicyIDs are now `strings`, previously they were numeric. Combining multiple parsed Cedar files +now involves coming up with IDs for each statement in those files. It's best to +create an empty `NewPolicySet` then parse individual files using `NewPolicyListFromBytes` and subsequently +use `PolicySet.Add` to add each of the policy statements. +- The Cedar `Entity` and `Entities` types have moved from the `cedar` package to the `types` package. +- Stronger typing is being used in many more places. + ## Security See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. From 9b75ff9a8f9079bac6e8fc22a692fb0a4450abf8 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 22 Aug 2024 17:01:01 -0600 Subject: [PATCH 213/216] cedar: improve README Addresses IDX-142 Signed-off-by: philhassey --- README.md | 24 ++++++++++++++++++------ authorize_test.go | 2 +- policy_map.go | 4 ++-- policy_map_test.go | 12 ++++++------ 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 3bf0d74f..7e69583b 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ func main() { } ps := cedar.NewPolicySet() - ps.Set("policy0", &policy) + ps.Store("policy0", &policy) var entities types.Entities if err := json.Unmarshal([]byte(entitiesJSON), &entities); err != nil { @@ -131,15 +131,27 @@ compatibility promise. While in development (0.x.y), each tagged release may contain breaking changes. +## Change log + +### New features in 0.2.x + +- A programmatic AST is now available in the `ast` package. +- Policy sets can be marshaled and unmarshaled from JSON. +- Policies can also be marshaled to Cedar text. + ### Upgrading from 0.1.x to 0.2.x - The Cedar value types have moved from the `cedar` package to the `types` package. -- The PolicyIDs are now `strings`, previously they were numeric. Combining multiple parsed Cedar files -now involves coming up with IDs for each statement in those files. It's best to -create an empty `NewPolicySet` then parse individual files using `NewPolicyListFromBytes` and subsequently -use `PolicySet.Add` to add each of the policy statements. +- The PolicyIDs are now `strings`, previously they were numeric. +- Errors and reasons use the new `PolicyID` form. +- Combining multiple parsed Cedar files now involves coming up with IDs for each +statement in those files. It's best to create an empty `NewPolicySet` then +parse individual files using `NewPolicyListFromBytes` and subsequently use +`PolicySet.Store` to add each of the policy statements. - The Cedar `Entity` and `Entities` types have moved from the `cedar` package to the `types` package. -- Stronger typing is being used in many more places. +- Stronger typing is being used in many places. +- The `Value` method `Cedar() string` was changed to `MarshalCedar() []byte` + ## Security diff --git a/authorize_test.go b/authorize_test.go index 00ccfe88..da461e59 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -741,7 +741,7 @@ func TestBadEval(t *testing.T) { ps := NewPolicySet() pol := NewPolicyFromAST(ast.Permit()) pol.eval = &badEvaler{} - ps.Set("pol", pol) + ps.Store("pol", pol) dec, diag := ps.IsAuthorized(nil, Request{}) testutil.Equals(t, dec, Deny) testutil.Equals(t, len(diag.Errors), 1) diff --git a/policy_map.go b/policy_map.go index 2dfc3289..132d6c09 100644 --- a/policy_map.go +++ b/policy_map.go @@ -52,8 +52,8 @@ func (p PolicySet) Get(policyID PolicyID) *Policy { return p.policies[policyID] } -// Set inserts or updates a policy with the given ID. -func (p *PolicySet) Set(policyID PolicyID, policy *Policy) { +// Store inserts or updates a policy with the given ID. +func (p *PolicySet) Store(policyID PolicyID, policy *Policy) { p.policies[policyID] = policy } diff --git a/policy_map_test.go b/policy_map_test.go index 0ca6842b..31fc0d02 100644 --- a/policy_map_test.go +++ b/policy_map_test.go @@ -42,8 +42,8 @@ func TestUpsertPolicy(t *testing.T) { )) ps := cedar.NewPolicySet() - ps.Set("policy0", policy0) - ps.Set("policy1", &policy1) + ps.Store("policy0", policy0) + ps.Store("policy1", &policy1) testutil.Equals(t, ps.Get("policy0"), policy0) testutil.Equals(t, ps.Get("policy1"), &policy1) @@ -55,10 +55,10 @@ func TestUpsertPolicy(t *testing.T) { ps := cedar.NewPolicySet() p1 := cedar.NewPolicyFromAST(ast.Forbid()) - ps.Set("a wavering policy", p1) + ps.Store("a wavering policy", p1) p2 := cedar.NewPolicyFromAST(ast.Permit()) - ps.Set("a wavering policy", p2) + ps.Store("a wavering policy", p2) testutil.Equals(t, ps.Get("a wavering policy"), p2) }) @@ -80,7 +80,7 @@ func TestDeletePolicy(t *testing.T) { ps := cedar.NewPolicySet() p1 := cedar.NewPolicyFromAST(ast.Forbid()) - ps.Set("a policy", p1) + ps.Store("a policy", p1) ps.Delete("a policy") testutil.Equals(t, ps.Get("a policy"), nil) @@ -109,7 +109,7 @@ forbid ( ps := cedar.NewPolicySet() for i, p := range policies { p.SetFilename("example.cedar") - ps.Set(cedar.PolicyID(fmt.Sprintf("policy%d", i)), p) + ps.Store(cedar.PolicyID(fmt.Sprintf("policy%d", i)), p) } testutil.Equals(t, ps.Get("policy0").Effect(), cedar.Permit) From f3275cb92b16883228c9bdf56de6bd8e5a190744 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 22 Aug 2024 17:12:14 -0600 Subject: [PATCH 214/216] types: make record a map of types.String to Value Addresses: IDX-142 Signed-off-by: philhassey --- ast/operator.go | 4 ++-- internal/ast/operator.go | 8 ++++---- internal/eval/convert.go | 9 +++++---- internal/eval/evalers.go | 12 ++++++------ internal/eval/evalers_test.go | 12 ++++++------ internal/json/json_unmarshal.go | 4 ++-- internal/parser/cedar_unmarshal.go | 8 ++++---- types/json_test.go | 2 +- types/record.go | 6 +++--- 9 files changed, 33 insertions(+), 32 deletions(-) diff --git a/ast/operator.go b/ast/operator.go index ab7af3f3..d007cb83 100644 --- a/ast/operator.go +++ b/ast/operator.go @@ -132,11 +132,11 @@ func (lhs Node) ContainsAny(rhs Node) Node { return wrapNode(lhs.Node.ContainsAny(rhs.Node)) } -func (lhs Node) Access(attr string) Node { +func (lhs Node) Access(attr types.String) Node { return wrapNode(lhs.Node.Access(attr)) } -func (lhs Node) Has(attr string) Node { +func (lhs Node) Has(attr types.String) Node { return wrapNode(lhs.Node.Has(attr)) } diff --git a/internal/ast/operator.go b/internal/ast/operator.go index 81a39018..59594878 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -129,12 +129,12 @@ func (lhs Node) ContainsAny(rhs Node) Node { return NewNode(NodeTypeContainsAny{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } -func (lhs Node) Access(attr string) Node { - return NewNode(NodeTypeAccess{StrOpNode: StrOpNode{Arg: lhs.v, Value: types.String(attr)}}) +func (lhs Node) Access(attr types.String) Node { + return NewNode(NodeTypeAccess{StrOpNode: StrOpNode{Arg: lhs.v, Value: attr}}) } -func (lhs Node) Has(attr string) Node { - return NewNode(NodeTypeHas{StrOpNode: StrOpNode{Arg: lhs.v, Value: types.String(attr)}}) +func (lhs Node) Has(attr types.String) Node { + return NewNode(NodeTypeHas{StrOpNode: StrOpNode{Arg: lhs.v, Value: attr}}) } // ___ ____ _ _ _ diff --git a/internal/eval/convert.go b/internal/eval/convert.go index 4ea615a0..a96b21fb 100644 --- a/internal/eval/convert.go +++ b/internal/eval/convert.go @@ -6,14 +6,15 @@ import ( "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/extensions" + "github.com/cedar-policy/cedar-go/types" ) func toEval(n ast.IsNode) Evaler { switch v := n.(type) { case ast.NodeTypeAccess: - return newAttributeAccessEval(toEval(v.Arg), string(v.Value)) + return newAttributeAccessEval(toEval(v.Arg), v.Value) case ast.NodeTypeHas: - return newHasEval(toEval(v.Arg), string(v.Value)) + return newHasEval(toEval(v.Arg), v.Value) case ast.NodeTypeLike: return newLikeEval(toEval(v.Arg), v.Value) case ast.NodeTypeIfThenElse: @@ -61,9 +62,9 @@ func toEval(n ast.IsNode) Evaler { case ast.NodeValue: return newLiteralEval(v.Value) case ast.NodeTypeRecord: - m := make(map[string]Evaler, len(v.Elements)) + m := make(map[types.String]Evaler, len(v.Elements)) for _, e := range v.Elements { - m[string(e.Key)] = toEval(e.Value) + m[e.Key] = toEval(e.Value) } return newRecordLiteralEval(m) case ast.NodeTypeSet: diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index fa5f6a3f..65a55681 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -771,10 +771,10 @@ func (n *containsAnyEval) Eval(ctx *Context) (types.Value, error) { // recordLiteralEval type recordLiteralEval struct { - elements map[string]Evaler + elements map[types.String]Evaler } -func newRecordLiteralEval(elements map[string]Evaler) *recordLiteralEval { +func newRecordLiteralEval(elements map[types.String]Evaler) *recordLiteralEval { return &recordLiteralEval{elements: elements} } @@ -793,10 +793,10 @@ func (n *recordLiteralEval) Eval(ctx *Context) (types.Value, error) { // attributeAccessEval type attributeAccessEval struct { object Evaler - attribute string + attribute types.String } -func newAttributeAccessEval(record Evaler, attribute string) *attributeAccessEval { +func newAttributeAccessEval(record Evaler, attribute types.String) *attributeAccessEval { return &attributeAccessEval{object: record, attribute: attribute} } @@ -835,10 +835,10 @@ func (n *attributeAccessEval) Eval(ctx *Context) (types.Value, error) { // hasEval type hasEval struct { object Evaler - attribute string + attribute types.String } -func newHasEval(record Evaler, attribute string) *hasEval { +func newHasEval(record Evaler, attribute types.String) *hasEval { return &hasEval{object: record, attribute: attribute} } diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 28bed1e3..3505c79e 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -1239,14 +1239,14 @@ func TestRecordLiteralNode(t *testing.T) { t.Parallel() tests := []struct { name string - elems map[string]Evaler + elems map[types.String]Evaler result types.Value err error }{ - {"empty", map[string]Evaler{}, types.Record{}, nil}, - {"errorNode", map[string]Evaler{"foo": newErrorEval(errTest)}, zeroValue(), errTest}, + {"empty", map[types.String]Evaler{}, types.Record{}, nil}, + {"errorNode", map[types.String]Evaler{"foo": newErrorEval(errTest)}, zeroValue(), errTest}, {"ok", - map[string]Evaler{ + map[types.String]Evaler{ "foo": newLiteralEval(types.True), "bar": newLiteralEval(types.String("baz")), }, types.Record{ @@ -1271,7 +1271,7 @@ func TestAttributeAccessNode(t *testing.T) { tests := []struct { name string object Evaler - attribute string + attribute types.String result types.Value err error }{ @@ -1328,7 +1328,7 @@ func TestHasNode(t *testing.T) { tests := []struct { name string record Evaler - attribute string + attribute types.String result types.Value err error }{ diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 329322de..518697f3 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -76,12 +76,12 @@ func (j unaryJSON) ToNode(f func(a ast.Node) ast.Node) (ast.Node, error) { } return f(arg), nil } -func (j strJSON) ToNode(f func(a ast.Node, k string) ast.Node) (ast.Node, error) { +func (j strJSON) ToNode(f func(a ast.Node, k types.String) ast.Node) (ast.Node, error) { left, err := j.Left.ToNode() if err != nil { return ast.Node{}, fmt.Errorf("error in left: %w", err) } - return f(left, j.Attr), nil + return f(left, types.String(j.Attr)), nil } func (j likeJSON) ToNode(f func(a ast.Node, k types.Pattern) ast.Node) (ast.Node, error) { left, err := j.Left.ToNode() diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index eeb005fa..10a10fe5 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -538,13 +538,13 @@ func (p *parser) relation() (ast.Node, error) { func (p *parser) has(lhs ast.Node) (ast.Node, error) { t := p.advance() if t.isIdent() { - return lhs.Has(t.Text), nil + return lhs.Has(types.String(t.Text)), nil } else if t.isString() { str, err := t.stringValue() if err != nil { return ast.Node{}, err } - return lhs.Has(str), nil + return lhs.Has(types.String(str)), nil } return ast.Node{}, p.errorf("expected ident or string") } @@ -899,7 +899,7 @@ func (p *parser) access(lhs ast.Node) (ast.Node, bool, error) { } return knownMethod(lhs, exprs[0]), true, nil } else { - return lhs.Access(t.Text), true, nil + return lhs.Access(types.String(t.Text)), true, nil } case "[": p.advance() @@ -914,7 +914,7 @@ func (p *parser) access(lhs ast.Node) (ast.Node, bool, error) { if err := p.exact("]"); err != nil { return ast.Node{}, false, err } - return lhs.Access(name), true, nil + return lhs.Access(types.String(name)), true, nil default: return lhs, false, nil } diff --git a/types/json_test.go b/types/json_test.go index db7752b9..b3afeda6 100644 --- a/types/json_test.go +++ b/types/json_test.go @@ -346,7 +346,7 @@ func TestJSONRecord(t *testing.T) { t.Parallel() r := Record{} k := []byte{0xde, 0x01} - r[string(k)] = Boolean(false) + r[String(k)] = Boolean(false) v, err := json.Marshal(r) // this demonstrates that invalid keys will still result in json testutil.Equals(t, string(v), `{"\ufffd\u0001":false}`) diff --git a/types/record.go b/types/record.go index 153ae09b..b5d0c1e8 100644 --- a/types/record.go +++ b/types/record.go @@ -11,7 +11,7 @@ import ( // A Record is a collection of attributes. Each attribute consists of a name and // an associated value. Names are simple strings. Values can be of any type. -type Record map[string]Value +type Record map[String]Value // Equals returns true if the records are Equal. func (a Record) Equal(bi Value) bool { @@ -36,7 +36,7 @@ func (v *Record) UnmarshalJSON(b []byte) error { } *v = Record{} for kk, vv := range res { - (*v)[kk] = vv.Value + (*v)[String(kk)] = vv.Value } return nil } @@ -86,7 +86,7 @@ func (r Record) MarshalCedar() []byte { sb.WriteString(", ") } first = false - sb.WriteString(strconv.Quote(k)) + sb.WriteString(strconv.Quote(string(k))) sb.WriteString(":") sb.Write(v.MarshalCedar()) } From 2c4e03be8191cb0ddecde0ebeabe9cbd43efa332 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 22 Aug 2024 17:20:53 -0600 Subject: [PATCH 215/216] types: separate Path from EntityType Addresses IDX-142 Signed-off-by: philhassey --- ast/ast_test.go | 8 ++++---- ast/operator.go | 4 ++-- ast/scope.go | 8 ++++---- internal/ast/ast_test.go | 12 ++++++------ internal/ast/node.go | 2 +- internal/ast/operator.go | 4 ++-- internal/ast/scope.go | 16 ++++++++-------- internal/ast/value.go | 2 +- internal/eval/evalers.go | 4 ++-- internal/eval/evalers_test.go | 20 ++++++++++---------- internal/json/json_test.go | 8 ++++---- internal/json/json_unmarshal.go | 8 ++++---- internal/parser/cedar_unmarshal.go | 14 +++++++------- types/entity_uid.go | 14 ++++++++++---- types/path.go | 4 ---- 15 files changed, 65 insertions(+), 63 deletions(-) delete mode 100644 types/path.go diff --git a/ast/ast_test.go b/ast/ast_test.go index 886c9275..176e0d6d 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -345,13 +345,13 @@ func TestASTByTable(t *testing.T) { }, { "opIs", - ast.Permit().When(ast.Long(42).Is(types.Path("T"))), - internalast.Permit().When(internalast.Long(42).Is(types.Path("T"))), + ast.Permit().When(ast.Long(42).Is(types.EntityType("T"))), + internalast.Permit().When(internalast.Long(42).Is(types.EntityType("T"))), }, { "opIsIn", - ast.Permit().When(ast.Long(42).IsIn(types.Path("T"), ast.Long(43))), - internalast.Permit().When(internalast.Long(42).IsIn(types.Path("T"), internalast.Long(43))), + ast.Permit().When(ast.Long(42).IsIn(types.EntityType("T"), ast.Long(43))), + internalast.Permit().When(internalast.Long(42).IsIn(types.EntityType("T"), internalast.Long(43))), }, { "opContains", diff --git a/ast/operator.go b/ast/operator.go index d007cb83..1985be1c 100644 --- a/ast/operator.go +++ b/ast/operator.go @@ -112,11 +112,11 @@ func (lhs Node) In(rhs Node) Node { return wrapNode(lhs.Node.In(rhs.Node)) } -func (lhs Node) Is(entityType types.Path) Node { +func (lhs Node) Is(entityType types.EntityType) Node { return wrapNode(lhs.Node.Is(entityType)) } -func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { +func (lhs Node) IsIn(entityType types.EntityType, rhs Node) Node { return wrapNode(lhs.Node.IsIn(entityType, rhs.Node)) } diff --git a/ast/scope.go b/ast/scope.go index 1ed76a48..f054af52 100644 --- a/ast/scope.go +++ b/ast/scope.go @@ -15,12 +15,12 @@ func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { } // PrincipalIs replaces the principal scope condition. -func (p *Policy) PrincipalIs(entityType types.Path) *Policy { +func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { return wrapPolicy(p.unwrap().PrincipalIs(entityType)) } // PrincipalIsIn replaces the principal scope condition. -func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { +func (p *Policy) PrincipalIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().PrincipalIsIn(entityType, entity)) } @@ -50,11 +50,11 @@ func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { } // ResourceIs replaces the resource scope condition. -func (p *Policy) ResourceIs(entityType types.Path) *Policy { +func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { return wrapPolicy(p.unwrap().ResourceIs(entityType)) } // ResourceIsIn replaces the resource scope condition. -func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { +func (p *Policy) ResourceIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { return wrapPolicy(p.unwrap().ResourceIsIn(entityType, entity)) } diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index acc6ad2f..e47302ec 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -135,12 +135,12 @@ func TestASTByTable(t *testing.T) { { "scopePrincipalIs", ast.Permit().PrincipalIs("T"), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIs{Type: types.Path("T")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIs{Type: types.EntityType("T")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, }, { "scopePrincipalIsIn", ast.Permit().PrincipalIsIn("T", types.NewEntityUID("T", "42")), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIsIn{Type: types.Path("T"), Entity: types.NewEntityUID("T", "42")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeIsIn{Type: types.EntityType("T"), Entity: types.NewEntityUID("T", "42")}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}}, }, { "scopeActionEq", @@ -170,12 +170,12 @@ func TestASTByTable(t *testing.T) { { "scopeResourceIs", ast.Permit().ResourceIs("T"), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIs{Type: types.Path("T")}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIs{Type: types.EntityType("T")}}, }, { "scopeResourceIsIn", ast.Permit().ResourceIsIn("T", types.NewEntityUID("T", "42")), - ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIsIn{Type: types.Path("T"), Entity: types.NewEntityUID("T", "42")}}, + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeIsIn{Type: types.EntityType("T"), Entity: types.NewEntityUID("T", "42")}}, }, { "variablePrincipal", @@ -412,13 +412,13 @@ func TestASTByTable(t *testing.T) { "opIs", ast.Permit().When(ast.Long(42).Is("T")), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.Path("T")}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.EntityType("T")}}}}, }, { "opIsIn", ast.Permit().When(ast.Long(42).IsIn("T", ast.Long(43))), ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, - Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIsIn{NodeTypeIs: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.Path("T")}, Entity: ast.NodeValue{Value: types.Long(43)}}}}}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeIsIn{NodeTypeIs: ast.NodeTypeIs{Left: ast.NodeValue{Value: types.Long(42)}, EntityType: types.EntityType("T")}, Entity: ast.NodeValue{Value: types.Long(43)}}}}}, }, { "opContains", diff --git a/internal/ast/node.go b/internal/ast/node.go index 30d7c835..53363fd2 100644 --- a/internal/ast/node.go +++ b/internal/ast/node.go @@ -76,7 +76,7 @@ func (n NodeTypeLike) isNode() {} type NodeTypeIs struct { Left IsNode - EntityType types.Path + EntityType types.EntityType } func (n NodeTypeIs) isNode() {} diff --git a/internal/ast/operator.go b/internal/ast/operator.go index 59594878..1361d1ea 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -109,11 +109,11 @@ func (lhs Node) In(rhs Node) Node { return NewNode(NodeTypeIn{BinaryNode: BinaryNode{Left: lhs.v, Right: rhs.v}}) } -func (lhs Node) Is(entityType types.Path) Node { +func (lhs Node) Is(entityType types.EntityType) Node { return NewNode(NodeTypeIs{Left: lhs.v, EntityType: entityType}) } -func (lhs Node) IsIn(entityType types.Path, rhs Node) Node { +func (lhs Node) IsIn(entityType types.EntityType, rhs Node) Node { return NewNode(NodeTypeIsIn{NodeTypeIs: NodeTypeIs{Left: lhs.v, EntityType: entityType}, Entity: rhs.v}) } diff --git a/internal/ast/scope.go b/internal/ast/scope.go index 2cb1a9ef..fa685447 100644 --- a/internal/ast/scope.go +++ b/internal/ast/scope.go @@ -22,11 +22,11 @@ func (s Scope) InSet(entities []types.EntityUID) ScopeTypeInSet { return ScopeTypeInSet{Entities: entities} } -func (s Scope) Is(entityType types.Path) ScopeTypeIs { +func (s Scope) Is(entityType types.EntityType) ScopeTypeIs { return ScopeTypeIs{Type: entityType} } -func (s Scope) IsIn(entityType types.Path, entity types.EntityUID) ScopeTypeIsIn { +func (s Scope) IsIn(entityType types.EntityType, entity types.EntityUID) ScopeTypeIsIn { return ScopeTypeIsIn{Type: entityType, Entity: entity} } @@ -40,12 +40,12 @@ func (p *Policy) PrincipalIn(entity types.EntityUID) *Policy { return p } -func (p *Policy) PrincipalIs(entityType types.Path) *Policy { +func (p *Policy) PrincipalIs(entityType types.EntityType) *Policy { p.Principal = Scope{}.Is(entityType) return p } -func (p *Policy) PrincipalIsIn(entityType types.Path, entity types.EntityUID) *Policy { +func (p *Policy) PrincipalIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { p.Principal = Scope{}.IsIn(entityType, entity) return p } @@ -75,12 +75,12 @@ func (p *Policy) ResourceIn(entity types.EntityUID) *Policy { return p } -func (p *Policy) ResourceIs(entityType types.Path) *Policy { +func (p *Policy) ResourceIs(entityType types.EntityType) *Policy { p.Resource = Scope{}.Is(entityType) return p } -func (p *Policy) ResourceIsIn(entityType types.Path, entity types.EntityUID) *Policy { +func (p *Policy) ResourceIsIn(entityType types.EntityType, entity types.EntityUID) *Policy { p.Resource = Scope{}.IsIn(entityType, entity) return p } @@ -153,13 +153,13 @@ type ScopeTypeIs struct { ScopeNode PrincipalScopeNode ResourceScopeNode - Type types.Path + Type types.EntityType } type ScopeTypeIsIn struct { ScopeNode PrincipalScopeNode ResourceScopeNode - Type types.Path + Type types.EntityType Entity types.EntityUID } diff --git a/internal/ast/value.go b/internal/ast/value.go index 6a926ee7..14673547 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -65,7 +65,7 @@ func Record(elements Pairs) Node { } func EntityUID(typ types.Ident, id types.String) Node { - e := types.NewEntityUID(types.Path(typ), types.String(id)) + e := types.NewEntityUID(types.EntityType(typ), types.String(id)) return Value(e) } diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index 65a55681..1d92afe0 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -963,10 +963,10 @@ func (n *inEval) Eval(ctx *Context) (types.Value, error) { // isEval type isEval struct { lhs Evaler - rhs types.Path + rhs types.EntityType } -func newIsEval(lhs Evaler, rhs types.Path) *isEval { +func newIsEval(lhs Evaler, rhs types.EntityType) *isEval { return &isEval{lhs: lhs, rhs: rhs} } diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 3505c79e..32ed334e 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -16,7 +16,7 @@ var errTest = fmt.Errorf("test error") // not a real parser func strEnt(v string) types.EntityUID { p := strings.Split(v, "::\"") - return types.EntityUID{Type: types.Path(p[0]), ID: types.String(p[1][:len(p[1])-1])} + return types.EntityUID{Type: types.EntityType(p[0]), ID: types.String(p[1][:len(p[1])-1])} } func AssertValue(t *testing.T, got, want types.Value) { @@ -1586,15 +1586,15 @@ func TestEntityIn(t *testing.T) { entityMap := types.Entities{} for i := 0; i < 100; i++ { p := []types.EntityUID{ - types.NewEntityUID(types.Path(fmt.Sprint(i+1)), "1"), - types.NewEntityUID(types.Path(fmt.Sprint(i+1)), "2"), + types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "1"), + types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "2"), } - uid1 := types.NewEntityUID(types.Path(fmt.Sprint(i)), "1") + uid1 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "1") entityMap[uid1] = types.Entity{ UID: uid1, Parents: p, } - uid2 := types.NewEntityUID(types.Path(fmt.Sprint(i)), "2") + uid2 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "2") entityMap[uid2] = types.Entity{ UID: uid2, Parents: p, @@ -1611,14 +1611,14 @@ func TestIsNode(t *testing.T) { tests := []struct { name string lhs Evaler - rhs types.Path + rhs types.EntityType result types.Value err error }{ - {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), types.Path("X"), types.True, nil}, - {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), types.Path("Y"), types.False, nil}, - {"badLhs", newLiteralEval(types.Long(42)), types.Path("X"), zeroValue(), ErrType}, - {"errLhs", newErrorEval(errTest), types.Path("X"), zeroValue(), errTest}, + {"happyEq", newLiteralEval(types.NewEntityUID("X", "z")), types.EntityType("X"), types.True, nil}, + {"happyNeq", newLiteralEval(types.NewEntityUID("X", "z")), types.EntityType("Y"), types.False, nil}, + {"badLhs", newLiteralEval(types.Long(42)), types.EntityType("X"), zeroValue(), ErrType}, + {"errLhs", newErrorEval(errTest), types.EntityType("X"), zeroValue(), errTest}, } for _, tt := range tests { tt := tt diff --git a/internal/json/json_test.go b/internal/json/json_test.go index 47e6a653..0df540cf 100644 --- a/internal/json/json_test.go +++ b/internal/json/json_test.go @@ -119,13 +119,13 @@ func TestUnmarshalJSON(t *testing.T) { { "principalIs", `{"effect":"permit","principal":{"op":"is","entity_type":"T"},"action":{"op":"All"},"resource":{"op":"All"}}`, - ast.Permit().PrincipalIs(types.Path("T")), + ast.Permit().PrincipalIs(types.EntityType("T")), testutil.OK, }, { "principalIsIn", `{"effect":"permit","principal":{"op":"is","entity_type":"T","in":{"entity":{"type":"P","id":"42"}}},"action":{"op":"All"},"resource":{"op":"All"}}`, - ast.Permit().PrincipalIsIn(types.Path("T"), types.NewEntityUID("P", "42")), + ast.Permit().PrincipalIsIn(types.EntityType("T"), types.NewEntityUID("P", "42")), testutil.OK, }, { @@ -161,13 +161,13 @@ func TestUnmarshalJSON(t *testing.T) { { "resourceIs", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is","entity_type":"T"}}`, - ast.Permit().ResourceIs(types.Path("T")), + ast.Permit().ResourceIs(types.EntityType("T")), testutil.OK, }, { "resourceIsIn", `{"effect":"permit","principal":{"op":"All"},"action":{"op":"All"},"resource":{"op":"is","entity_type":"T","in":{"entity":{"type":"P","id":"42"}}}}`, - ast.Permit().ResourceIsIn(types.Path("T"), types.NewEntityUID("P", "42")), + ast.Permit().ResourceIsIn(types.EntityType("T"), types.NewEntityUID("P", "42")), testutil.OK, }, { diff --git a/internal/json/json_unmarshal.go b/internal/json/json_unmarshal.go index 518697f3..b20e320c 100644 --- a/internal/json/json_unmarshal.go +++ b/internal/json/json_unmarshal.go @@ -33,9 +33,9 @@ func (s *scopeJSON) ToPrincipalResourceNode() (isPrincipalResourceScopeNode, err return ast.Scope{}.In(*s.Entity), nil case "is": if s.In == nil { - return ast.Scope{}.Is(types.Path(s.EntityType)), nil + return ast.Scope{}.Is(types.EntityType(s.EntityType)), nil } - return ast.Scope{}.IsIn(types.Path(s.EntityType), s.In.Entity), nil + return ast.Scope{}.IsIn(types.EntityType(s.EntityType), s.In.Entity), nil } return nil, fmt.Errorf("unknown op: %v", s.Op) } @@ -101,9 +101,9 @@ func (j isJSON) ToNode() (ast.Node, error) { if err != nil { return ast.Node{}, fmt.Errorf("error in entity: %w", err) } - return left.IsIn(types.Path(j.EntityType), right), nil + return left.IsIn(types.EntityType(j.EntityType), right), nil } - return left.Is(types.Path(j.EntityType)), nil + return left.Is(types.EntityType(j.EntityType)), nil } func (j ifThenElseJSON) ToNode() (ast.Node, error) { if_, err := j.If.ToNode() diff --git a/internal/parser/cedar_unmarshal.go b/internal/parser/cedar_unmarshal.go index 10a10fe5..bc0fec08 100644 --- a/internal/parser/cedar_unmarshal.go +++ b/internal/parser/cedar_unmarshal.go @@ -230,10 +230,10 @@ func (p *parser) entity() (types.EntityUID, error) { if !t.isIdent() { return res, p.errorf("expected ident") } - return p.entityFirstPathPreread(types.Path(t.Text)) + return p.entityFirstPathPreread(types.EntityType(t.Text)) } -func (p *parser) entityFirstPathPreread(firstPath types.Path) (types.EntityUID, error) { +func (p *parser) entityFirstPathPreread(firstPath types.EntityType) (types.EntityUID, error) { var res types.EntityUID res.Type = firstPath for { @@ -243,7 +243,7 @@ func (p *parser) entityFirstPathPreread(firstPath types.Path) (types.EntityUID, t := p.advance() switch { case t.isIdent(): - res.Type = types.Path(res.Type) + "::" + types.Path(t.Text) + res.Type = types.EntityType(res.Type) + "::" + types.EntityType(t.Text) case t.isString(): id, err := t.stringValue() if err != nil { @@ -257,8 +257,8 @@ func (p *parser) entityFirstPathPreread(firstPath types.Path) (types.EntityUID, } } -func (p *parser) pathFirstPathPreread(firstPath string) (types.Path, error) { - res := types.Path(firstPath) +func (p *parser) pathFirstPathPreread(firstPath string) (types.EntityType, error) { + res := types.EntityType(firstPath) for { if p.peek().Text != "::" { return res, nil @@ -267,14 +267,14 @@ func (p *parser) pathFirstPathPreread(firstPath string) (types.Path, error) { t := p.advance() switch { case t.isIdent(): - res = types.Path(fmt.Sprintf("%v::%v", res, t.Text)) + res = types.EntityType(fmt.Sprintf("%v::%v", res, t.Text)) default: return res, p.errorf("unexpected token") } } } -func (p *parser) path() (types.Path, error) { +func (p *parser) path() (types.EntityType, error) { t := p.advance() if !t.isIdent() { return "", p.errorf("expected ident") diff --git a/types/entity_uid.go b/types/entity_uid.go index c1b84868..a0fd89bb 100644 --- a/types/entity_uid.go +++ b/types/entity_uid.go @@ -5,13 +5,19 @@ import ( "strconv" ) +// Path is a series of idents separated by :: +type Path string + +// EntityType is the type portion of an EntityUID +type EntityType Path + // An EntityUID is the identifier for a principal, action, or resource. type EntityUID struct { - Type Path + Type EntityType ID String } -func NewEntityUID(typ Path, id String) EntityUID { +func NewEntityUID(typ EntityType, id String) EntityUID { return EntityUID{ Type: typ, ID: id, @@ -43,11 +49,11 @@ func (v *EntityUID) UnmarshalJSON(b []byte) error { return err } if res.Entity != nil { - v.Type = Path(res.Entity.Type) + v.Type = EntityType(res.Entity.Type) v.ID = String(res.Entity.ID) return nil } else if res.Type != nil && res.ID != nil { // require both Type and ID to parse "implicit" JSON - v.Type = Path(*res.Type) + v.Type = EntityType(*res.Type) v.ID = String(*res.ID) return nil } diff --git a/types/path.go b/types/path.go deleted file mode 100644 index 522d7b96..00000000 --- a/types/path.go +++ /dev/null @@ -1,4 +0,0 @@ -package types - -// Path is the type portion of an EntityUID -type Path string From 1da3cb2aaef0715c9131a34f587bfbbf6af5ed66 Mon Sep 17 00:00:00 2001 From: philhassey Date: Thu, 22 Aug 2024 17:27:17 -0600 Subject: [PATCH 216/216] types: change Entity to pointer inside of Entities Addresses IDX-142 Signed-off-by: philhassey --- authorize_test.go | 16 ++++++++-------- internal/eval/evalers.go | 5 ++++- internal/eval/evalers_test.go | 12 ++++++------ types/entities.go | 10 +++++----- types/entities_test.go | 6 +++--- 5 files changed, 26 insertions(+), 23 deletions(-) diff --git a/authorize_test.go b/authorize_test.go index da461e59..9070f963 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -108,7 +108,7 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-requires-entities-success", Policy: `permit(principal,action,resource) when { principal.x == 42 };`, Entities: types.Entities{ - cuzco: types.Entity{ + cuzco: &types.Entity{ UID: cuzco, Attributes: types.Record{"x": types.Long(42)}, }, @@ -124,7 +124,7 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-requires-entities-fail", Policy: `permit(principal,action,resource) when { principal.x == 42 };`, Entities: types.Entities{ - cuzco: types.Entity{ + cuzco: &types.Entity{ UID: cuzco, Attributes: types.Record{"x": types.Long(43)}, }, @@ -140,7 +140,7 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-requires-entities-parent-success", Policy: `permit(principal,action,resource) when { principal in parent::"bob" };`, Entities: types.Entities{ - cuzco: types.Entity{ + cuzco: &types.Entity{ UID: cuzco, Parents: []types.EntityUID{types.NewEntityUID("parent", "bob")}, }, @@ -167,7 +167,7 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-principal-in", Policy: `permit(principal in team::"osiris",action,resource);`, Entities: types.Entities{ - cuzco: types.Entity{ + cuzco: &types.Entity{ UID: cuzco, Parents: []types.EntityUID{types.NewEntityUID("team", "osiris")}, }, @@ -194,7 +194,7 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-action-in", Policy: `permit(principal,action in scary::"stuff",resource);`, Entities: types.Entities{ - dropTable: types.Entity{ + dropTable: &types.Entity{ UID: dropTable, Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, }, @@ -210,7 +210,7 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-action-in-set", Policy: `permit(principal,action in [scary::"stuff"],resource);`, Entities: types.Entities{ - dropTable: types.Entity{ + dropTable: &types.Entity{ UID: dropTable, Parents: []types.EntityUID{types.NewEntityUID("scary", "stuff")}, }, @@ -303,7 +303,7 @@ func TestIsAuthorized(t *testing.T) { Name: "permit-when-relations-has", Policy: `permit(principal,action,resource) when { principal has name };`, Entities: types.Entities{ - cuzco: types.Entity{ + cuzco: &types.Entity{ UID: cuzco, Attributes: types.Record{"name": types.String("bob")}, }, @@ -693,7 +693,7 @@ func TestIsAuthorized(t *testing.T) { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Parent::"id" };`, Entities: types.Entities{ - types.NewEntityUID("Resource", "table"): types.Entity{ + types.NewEntityUID("Resource", "table"): &types.Entity{ UID: types.NewEntityUID("Resource", "table"), Parents: []types.EntityUID{types.NewEntityUID("Parent", "id")}, }, diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index 1d92afe0..333d9c71 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -924,7 +924,10 @@ func entityIn(entity types.EntityUID, query map[types.EntityUID]struct{}, entity if _, ok := query[candidate]; ok { return true } - toCheck = append(toCheck, entityMap[candidate].Parents...) + next, ok := entityMap[candidate] + if ok { + toCheck = append(toCheck, next.Parents...) + } checked[candidate] = struct{}{} } return false diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 32ed334e..2f5a3e5b 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -1308,7 +1308,7 @@ func TestAttributeAccessNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newAttributeAccessEval(tt.object, tt.attribute) - entity := types.Entity{ + entity := &types.Entity{ UID: types.NewEntityUID("knownType", "knownID"), Attributes: types.Record{"knownAttr": types.Long(42)}, } @@ -1365,7 +1365,7 @@ func TestHasNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() n := newHasEval(tt.record, tt.attribute) - entity := types.Entity{ + entity := &types.Entity{ UID: types.NewEntityUID("knownType", "knownID"), Attributes: types.Record{"knownAttr": types.Long(42)}, } @@ -1569,7 +1569,7 @@ func TestEntityIn(t *testing.T) { ps = append(ps, strEnt(pp)) } uid := strEnt(k) - entityMap[uid] = types.Entity{ + entityMap[uid] = &types.Entity{ UID: uid, Parents: ps, } @@ -1590,12 +1590,12 @@ func TestEntityIn(t *testing.T) { types.NewEntityUID(types.EntityType(fmt.Sprint(i+1)), "2"), } uid1 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "1") - entityMap[uid1] = types.Entity{ + entityMap[uid1] = &types.Entity{ UID: uid1, Parents: p, } uid2 := types.NewEntityUID(types.EntityType(fmt.Sprint(i)), "2") - entityMap[uid2] = types.Entity{ + entityMap[uid2] = &types.Entity{ UID: uid2, Parents: p, } @@ -1735,7 +1735,7 @@ func TestInNode(t *testing.T) { ps = append(ps, strEnt(pp)) } uid := strEnt(k) - entityMap[uid] = types.Entity{ + entityMap[uid] = &types.Entity{ UID: uid, Parents: ps, } diff --git a/types/entities.go b/types/entities.go index 55bd5e14..8da8249b 100644 --- a/types/entities.go +++ b/types/entities.go @@ -11,7 +11,7 @@ import ( // An Entities is a collection of all the Entities that are needed to evaluate // authorization requests. The key is an EntityUID which uniquely identifies // the Entity (it must be the same as the UID within the Entity itself.) -type Entities map[EntityUID]Entity +type Entities map[EntityUID]*Entity // An Entity defines the parents and attributes for an EntityUID. type Entity struct { @@ -26,7 +26,7 @@ func (e Entities) MarshalJSON() ([]byte, error) { } func (e *Entities) UnmarshalJSON(b []byte) error { - var s []Entity + var s []*Entity if err := json.Unmarshal(b, &s); err != nil { return err } @@ -34,7 +34,7 @@ func (e *Entities) UnmarshalJSON(b []byte) error { return nil } -func entitiesFromSlice(s []Entity) Entities { +func entitiesFromSlice(s []*Entity) Entities { var res = Entities{} for _, e := range s { res[e.UID] = e @@ -42,9 +42,9 @@ func entitiesFromSlice(s []Entity) Entities { return res } -func (e Entities) toSlice() []Entity { +func (e Entities) toSlice() []*Entity { s := maps.Values(e) - slices.SortFunc(s, func(a, b Entity) int { + slices.SortFunc(s, func(a, b *Entity) int { return strings.Compare(a.UID.String(), b.UID.String()) }) return s diff --git a/types/entities_test.go b/types/entities_test.go index 73374a66..a02efa0a 100644 --- a/types/entities_test.go +++ b/types/entities_test.go @@ -29,12 +29,12 @@ func TestEntitiesJSON(t *testing.T) { t.Run("Marshal", func(t *testing.T) { t.Parallel() e := types.Entities{} - ent := types.Entity{ + ent := &types.Entity{ UID: types.NewEntityUID("Type", "id"), Parents: []types.EntityUID{}, Attributes: types.Record{"key": types.Long(42)}, } - ent2 := types.Entity{ + ent2 := &types.Entity{ UID: types.NewEntityUID("Type", "id2"), Parents: []types.EntityUID{}, Attributes: types.Record{"key": types.Long(42)}, @@ -53,7 +53,7 @@ func TestEntitiesJSON(t *testing.T) { err := json.Unmarshal(b, &e) testutil.OK(t, err) want := types.Entities{} - ent := types.Entity{ + ent := &types.Entity{ UID: types.NewEntityUID("Type", "id"), Parents: []types.EntityUID{}, Attributes: types.Record{"key": types.Long(42)},