diff --git a/ast/ast_test.go b/ast/ast_test.go index d1cbe912..9f12822e 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -3,6 +3,7 @@ package ast_test import ( "net/netip" "testing" + "time" "github.com/cedar-policy/cedar-go/ast" internalast "github.com/cedar-policy/cedar-go/internal/ast" @@ -411,7 +412,51 @@ func TestASTByTable(t *testing.T) { internalast.Permit().When(internalast.Long(42).IsInRange(internalast.Long(43))), }, { - "decimalExtension", + "opOffset", + ast.Permit().When(ast.Datetime(time.Time{}).Offset(ast.Duration(time.Duration(100)))), + internalast.Permit().When(internalast.Datetime(time.Time{}).Offset(internalast.Duration(time.Duration(100)))), + }, + { + "opDurationSince", + ast.Permit().When(ast.Datetime(time.Time{}).DurationSince(ast.Datetime(time.Time{}))), + internalast.Permit().When(internalast.Datetime(time.Time{}).DurationSince(internalast.Datetime(time.Time{}))), + }, + { + "opToDate", + ast.Permit().When(ast.Datetime(time.Time{}).ToDate()), + internalast.Permit().When(internalast.Datetime(time.Time{}).ToDate()), + }, + { + "opToTime", + ast.Permit().When(ast.Datetime(time.Time{}).ToTime()), + internalast.Permit().When(internalast.Datetime(time.Time{}).ToTime()), + }, + { + "opToDays", + ast.Permit().When(ast.Duration(time.Duration(100)).ToDays()), + internalast.Permit().When(internalast.Duration(100).ToDays()), + }, + { + "opToHours", + ast.Permit().When(ast.Duration(time.Duration(100)).ToHours()), + internalast.Permit().When(internalast.Duration(100).ToHours()), + }, + { + "opToMinutes", + ast.Permit().When(ast.Duration(time.Duration(100)).ToMinutes()), + internalast.Permit().When(internalast.Duration(100).ToMinutes()), + }, + { + "opToSeconds", + ast.Permit().When(ast.Duration(time.Duration(100)).ToSeconds()), + internalast.Permit().When(internalast.Duration(100).ToSeconds()), + }, + { + "opToMilliseconds", + ast.Permit().When(ast.Duration(time.Duration(100)).ToMilliseconds()), + internalast.Permit().When(internalast.Duration(100).ToMilliseconds()), + }, + "decimalExtension", ast.Permit().When(ast.DecimalExtensionCall(ast.Value(types.String("3.14")))), internalast.Permit().When(internalast.ExtensionCall("decimal", internalast.String("3.14"))), }, @@ -429,7 +474,7 @@ func TestASTByTable(t *testing.T) { "duration", ast.Permit().When(ast.DurationExtensionCall(ast.Value(types.String("1d2h3m4s5ms")))), internalast.Permit().When(internalast.ExtensionCall("duration", internalast.String("1d2h3m4s5ms"))), - }, + }, } for _, tt := range tests { diff --git a/ast/operator.go b/ast/operator.go index 1985be1c..833f92a2 100644 --- a/ast/operator.go +++ b/ast/operator.go @@ -165,3 +165,33 @@ func (lhs Node) IsLoopback() Node { func (lhs Node) IsInRange(rhs Node) Node { return wrapNode(lhs.Node.IsInRange(rhs.Node)) } + +// ____ _ _ _ +// | _ \ __ _| |_ ___| |_(_)_ __ ___ ___ +// | | | |/ _` | __/ _ \ __| | '_ ` _ \ / _ \ +// | |_| | (_| | || __/ |_| | | | | | | __/ +// |____/ \__,_|\__\___|\__|_|_| |_| |_|\___| + +func (lhs Node) Offset(rhs Node) Node { return wrapNode(lhs.Node.Offset(rhs.Node)) } + +func (lhs Node) DurationSince(rhs Node) Node { return wrapNode(lhs.Node.DurationSince(rhs.Node)) } + +func (lhs Node) ToDate() Node { return wrapNode(lhs.Node.ToDate()) } + +func (lhs Node) ToTime() Node { return wrapNode(lhs.Node.ToTime()) } + +// ____ _ _ +// | _ \ _ _ _ __ __ _| |_(_) ___ _ __ +// | | | | | | | '__/ _` | __| |/ _ \| '_ \ +// | |_| | |_| | | | (_| | |_| | (_) | | | | +// |____/ \__,_|_| \__,_|\__|_|\___/|_| |_| + +func (lhs Node) ToDays() Node { return wrapNode(lhs.Node.ToDays()) } + +func (lhs Node) ToHours() Node { return wrapNode(lhs.Node.ToHours()) } + +func (lhs Node) ToMinutes() Node { return wrapNode(lhs.Node.ToMinutes()) } + +func (lhs Node) ToSeconds() Node { return wrapNode(lhs.Node.ToSeconds()) } + +func (lhs Node) ToMilliseconds() Node { return wrapNode(lhs.Node.ToMilliseconds()) } diff --git a/ast/value.go b/ast/value.go index 6f48799a..bd76be59 100644 --- a/ast/value.go +++ b/ast/value.go @@ -2,6 +2,7 @@ package ast import ( "net/netip" + "time" "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/types" @@ -79,6 +80,14 @@ func IPAddr[T netip.Prefix | types.IPAddr](i T) Node { return wrapNode(ast.IPAddr(types.IPAddr(i))) } +func Datetime(t time.Time) Node { + return Value(types.FromStdTime(t)) +} + +func Duration(d time.Duration) Node { + return Value(types.FromStdDuration(d)) +} + // Value creates a value node from any value. func Value(v types.Value) Node { return wrapNode(ast.Value(v)) diff --git a/authorize_test.go b/authorize_test.go index 3101c162..ee4f393e 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -15,7 +15,7 @@ func TestIsAuthorized(t *testing.T) { tests := []struct { Name string Policy string - Entities cedar.Entities + Entities cedar.EntityMap Principal, Action, Resource cedar.EntityUID Context cedar.Record Want cedar.Decision @@ -25,7 +25,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "simple-permit", Policy: `permit(principal,action,resource);`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -36,7 +36,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "simple-forbid", Policy: `forbid(principal,action,resource);`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -47,7 +47,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "no-permit", Policy: `permit(principal,action,resource in asdf::"1234");`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -58,7 +58,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "error-in-policy", Policy: `permit(principal,action,resource) when { resource in "foo" };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -71,7 +71,7 @@ func TestIsAuthorized(t *testing.T) { Policy: `permit(principal,action,resource) when { resource in "foo" }; permit(principal,action,resource); `, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -82,7 +82,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-context-success", Policy: `permit(principal,action,resource) when { context.x == 42 };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -93,7 +93,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-context-fail", Policy: `permit(principal,action,resource) when { context.x == 42 };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -104,7 +104,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-entities-success", Policy: `permit(principal,action,resource) when { principal.x == 42 };`, - Entities: cedar.Entities{ + Entities: cedar.EntityMap{ cuzco: cedar.Entity{ UID: cuzco, Attributes: cedar.NewRecord(cedar.RecordMap{"x": cedar.Long(42)}), @@ -120,7 +120,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-entities-fail", Policy: `permit(principal,action,resource) when { principal.x == 42 };`, - Entities: cedar.Entities{ + Entities: cedar.EntityMap{ cuzco: cedar.Entity{ UID: cuzco, Attributes: cedar.NewRecord(cedar.RecordMap{"x": cedar.Long(43)}), @@ -136,7 +136,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-requires-entities-parent-success", Policy: `permit(principal,action,resource) when { principal in parent::"bob" };`, - Entities: cedar.Entities{ + Entities: cedar.EntityMap{ cuzco: cedar.Entity{ UID: cuzco, Parents: cedar.NewEntityUIDSet(cedar.NewEntityUID("parent", "bob")), @@ -152,7 +152,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-principal-equals", Policy: `permit(principal == coder::"cuzco",action,resource);`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -163,7 +163,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-principal-in", Policy: `permit(principal in team::"osiris",action,resource);`, - Entities: cedar.Entities{ + Entities: cedar.EntityMap{ cuzco: cedar.Entity{ UID: cuzco, Parents: cedar.NewEntityUIDSet(cedar.NewEntityUID("team", "osiris")), @@ -179,7 +179,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-action-equals", Policy: `permit(principal,action == table::"drop",resource);`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -190,7 +190,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-action-in", Policy: `permit(principal,action in scary::"stuff",resource);`, - Entities: cedar.Entities{ + Entities: cedar.EntityMap{ dropTable: cedar.Entity{ UID: dropTable, Parents: cedar.NewEntityUIDSet(cedar.NewEntityUID("scary", "stuff")), @@ -206,7 +206,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-action-in-set", Policy: `permit(principal,action in [scary::"stuff"],resource);`, - Entities: cedar.Entities{ + Entities: cedar.EntityMap{ dropTable: cedar.Entity{ UID: dropTable, Parents: cedar.NewEntityUIDSet(cedar.NewEntityUID("scary", "stuff")), @@ -222,7 +222,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-resource-equals", Policy: `permit(principal,action,resource == table::"whatever");`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -233,7 +233,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-unless", Policy: `permit(principal,action,resource) unless { false };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -244,7 +244,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-if", Policy: `permit(principal,action,resource) when { (if true then true else true) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -255,7 +255,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-or", Policy: `permit(principal,action,resource) when { (true || false) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -266,7 +266,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-and", Policy: `permit(principal,action,resource) when { (true && true) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -277,7 +277,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-relations", Policy: `permit(principal,action,resource) when { (1<2) && (1<=1) && (2>1) && (1>=1) && (1!=2) && (1==1)};`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -288,7 +288,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-relations-in", Policy: `permit(principal,action,resource) when { principal in principal };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -299,7 +299,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-relations-has", Policy: `permit(principal,action,resource) when { principal has name };`, - Entities: cedar.Entities{ + Entities: cedar.EntityMap{ cuzco: cedar.Entity{ UID: cuzco, Attributes: cedar.NewRecord(cedar.RecordMap{"name": cedar.String("bob")}), @@ -315,7 +315,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-add-sub", Policy: `permit(principal,action,resource) when { 40+3-1==42 };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -326,7 +326,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-mul", Policy: `permit(principal,action,resource) when { 6*7==42 };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -337,7 +337,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-negate", Policy: `permit(principal,action,resource) when { -42==-42 };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -348,7 +348,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-not", Policy: `permit(principal,action,resource) when { !(1+1==42) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -359,7 +359,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -370,7 +370,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-record", Policy: `permit(principal,action,resource) when { {name:"bob"} has name };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -381,7 +381,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-action", Policy: `permit(principal,action,resource) when { action in action };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -392,7 +392,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-contains-ok", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -403,7 +403,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-contains-error", Policy: `permit(principal,action,resource) when { [1,2,3].contains(2,3) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -415,7 +415,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAll-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll([2,3]) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -426,7 +426,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAll-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAll(2,3) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -438,7 +438,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAny-ok", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny([2,5]) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -449,7 +449,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-set-containsAny-error", Policy: `permit(principal,action,resource) when { [1,2,3].containsAny(2,5) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -461,7 +461,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-record-attr", Policy: `permit(principal,action,resource) when { {name:"bob"}["name"] == "bob" };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -472,7 +472,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-unknown-method", Policy: `permit(principal,action,resource) when { [1,2,3].shuffle() };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -484,7 +484,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-like", Policy: `permit(principal,action,resource) when { "bananas" like "*nan*" };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -495,7 +495,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-unknown-ext-fun", Policy: `permit(principal,action,resource) when { fooBar("10") };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -511,7 +511,7 @@ func TestIsAuthorized(t *testing.T) { decimal("10.0").lessThanOrEqual(decimal("11.0")) && decimal("10.0").greaterThan(decimal("9.0")) && decimal("10.0").greaterThanOrEqual(decimal("9.0")) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -522,7 +522,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-decimal-fun-wrong-arity", Policy: `permit(principal,action,resource) when { decimal(1, 2) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -538,7 +538,7 @@ func TestIsAuthorized(t *testing.T) { datetime("1970-01-01T09:08:07Z") > (datetime("1970-01-01")) && datetime("1970-01-01T09:08:07Z") >= (datetime("1970-01-01")) && datetime("1970-01-01T09:08:07Z").toDate() == datetime("1970-01-01")};`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -549,7 +549,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-datetime-fun-wrong-arity", Policy: `permit(principal,action,resource) when { datetime("1970-01-01", "UTC") };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -573,7 +573,7 @@ func TestIsAuthorized(t *testing.T) { datetime("1970-01-01").offset(duration("1ms")).toTime() == duration("1ms") && datetime("1970-01-01T00:00:00.001Z").durationSince(datetime("1970-01-01")) == duration("1ms")};`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -584,7 +584,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-duration-fun-wrong-arity", Policy: `permit(principal,action,resource) when { duration("1h", "huh?") };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -600,7 +600,7 @@ func TestIsAuthorized(t *testing.T) { ip("::1").isLoopback() && ip("224.1.2.3").isMulticast() && ip("127.0.0.1").isInRange(ip("127.0.0.0/16"))};`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -611,7 +611,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-ip-fun-wrong-arity", Policy: `permit(principal,action,resource) when { ip() };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -622,7 +622,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isIpv4-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv4(true) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -633,7 +633,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isIpv6-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isIpv6(true) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -644,7 +644,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isLoopback-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isLoopback(true) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -655,7 +655,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isMulticast-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isMulticast(true) };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -666,7 +666,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "permit-when-isInRange-wrong-arity", Policy: `permit(principal,action,resource) when { ip("1.2.3.4").isInRange() };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cuzco, Action: dropTable, Resource: cedar.NewEntityUID("table", "whatever"), @@ -677,7 +677,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "negative-unary-op", Policy: `permit(principal,action,resource) when { -context.value > 0 };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Context: cedar.NewRecord(cedar.RecordMap{"value": cedar.Long(-42)}), Want: true, DiagErr: 0, @@ -685,7 +685,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "principal-is", Policy: `permit(principal is Actor,action,resource);`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cedar.NewEntityUID("Actor", "cuzco"), Action: cedar.NewEntityUID("Action", "drop"), Resource: cedar.NewEntityUID("Resource", "table"), @@ -696,7 +696,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "principal-is-in", Policy: `permit(principal is Actor in Actor::"cuzco",action,resource);`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cedar.NewEntityUID("Actor", "cuzco"), Action: cedar.NewEntityUID("Action", "drop"), Resource: cedar.NewEntityUID("Resource", "table"), @@ -707,7 +707,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "resource-is", Policy: `permit(principal,action,resource is Resource);`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cedar.NewEntityUID("Actor", "cuzco"), Action: cedar.NewEntityUID("Action", "drop"), Resource: cedar.NewEntityUID("Resource", "table"), @@ -718,7 +718,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "resource-is-in", Policy: `permit(principal,action,resource is Resource in Resource::"table");`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cedar.NewEntityUID("Actor", "cuzco"), Action: cedar.NewEntityUID("Action", "drop"), Resource: cedar.NewEntityUID("Resource", "table"), @@ -729,7 +729,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "when-is", Policy: `permit(principal,action,resource) when { resource is Resource };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cedar.NewEntityUID("Actor", "cuzco"), Action: cedar.NewEntityUID("Action", "drop"), Resource: cedar.NewEntityUID("Resource", "table"), @@ -740,7 +740,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Resource::"table" };`, - Entities: cedar.Entities{}, + Entities: cedar.EntityMap{}, Principal: cedar.NewEntityUID("Actor", "cuzco"), Action: cedar.NewEntityUID("Action", "drop"), Resource: cedar.NewEntityUID("Resource", "table"), @@ -751,7 +751,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "when-is-in", Policy: `permit(principal,action,resource) when { resource is Resource in Parent::"id" };`, - Entities: cedar.Entities{ + Entities: cedar.EntityMap{ cedar.NewEntityUID("Resource", "table"): cedar.Entity{ UID: cedar.NewEntityUID("Resource", "table"), Parents: cedar.NewEntityUIDSet(cedar.NewEntityUID("Parent", "id")), @@ -767,7 +767,7 @@ func TestIsAuthorized(t *testing.T) { { Name: "rfc-57", // https://github.com/cedar-policy/rfcs/blob/main/text/0057-general-multiplication.md Policy: `permit(principal, action, resource) when { context.foo * principal.bar >= 100 };`, - Entities: cedar.Entities{ + Entities: cedar.EntityMap{ cedar.NewEntityUID("Principal", "1"): cedar.Entity{ UID: cedar.NewEntityUID("Principal", "1"), Attributes: cedar.NewRecord(cedar.RecordMap{"bar": cedar.Long(42)}), diff --git a/corpus_test.go b/corpus_test.go index b29c4774..66ae991d 100644 --- a/corpus_test.go +++ b/corpus_test.go @@ -141,7 +141,7 @@ func TestCorpus(t *testing.T) { t.Fatal("error reading entities content", err) } - var entities cedar.Entities + var entities cedar.EntityMap if err := json.Unmarshal(entitiesContent, &entities); err != nil { t.Fatal("error unmarshalling test", err) } @@ -384,7 +384,7 @@ func TestCorpusRelated(t *testing.T) { t.Parallel() policy, err := cedar.NewPolicySetFromBytes("", []byte(tt.policy)) testutil.OK(t, err) - ok, diag := policy.IsAuthorized(cedar.Entities{}, tt.request) + ok, diag := policy.IsAuthorized(cedar.EntityMap{}, tt.request) testutil.Equals(t, ok, tt.decision) var reasons []cedar.PolicyID for _, n := range diag.Reasons { diff --git a/internal/ast/ast_test.go b/internal/ast/ast_test.go index a9b61488..6bdedfa0 100644 --- a/internal/ast/ast_test.go +++ b/internal/ast/ast_test.go @@ -3,6 +3,7 @@ package ast_test import ( "net/netip" "testing" + "time" "github.com/cedar-policy/cedar-go/internal/ast" "github.com/cedar-policy/cedar-go/internal/testutil" @@ -480,6 +481,59 @@ func TestASTByTable(t *testing.T) { ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "isInRange", Args: []ast.IsNode{ast.NodeValue{Value: types.Long(42)}, ast.NodeValue{Value: types.Long(43)}}}}}}, }, + { + "opOffset", + ast.Permit().When(ast.Datetime(time.Time{}).Offset(ast.Duration(time.Duration(100)))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "offset", Args: []ast.IsNode{ast.NodeValue{Value: types.FromStdTime(time.Time{})}, ast.NodeValue{Value: types.FromStdDuration(time.Duration(100))}}}}}}, + }, + { + "opDurationSince", + ast.Permit().When(ast.Datetime(time.Time{}).DurationSince(ast.Datetime(time.Time{}))), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "durationSince", Args: []ast.IsNode{ast.NodeValue{Value: types.FromStdTime(time.Time{})}, ast.NodeValue{Value: types.FromStdTime(time.Time{})}}}}}}, + }, + { + "opToDate", + ast.Permit().When(ast.Datetime(time.Time{}).ToDate()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "toDate", Args: []ast.IsNode{ast.NodeValue{Value: types.FromStdTime(time.Time{})}}}}}}, + }, + { + "opToTime", + ast.Permit().When(ast.Datetime(time.Time{}).ToTime()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "toTime", Args: []ast.IsNode{ast.NodeValue{Value: types.FromStdTime(time.Time{})}}}}}}, + }, + { + "opToDays", + ast.Permit().When(ast.Duration(time.Duration(100)).ToDays()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "toDays", Args: []ast.IsNode{ast.NodeValue{Value: types.FromStdDuration(time.Duration(100))}}}}}}, + }, + { + "opToHours", + ast.Permit().When(ast.Duration(time.Duration(100)).ToHours()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "toHours", Args: []ast.IsNode{ast.NodeValue{Value: types.FromStdDuration(time.Duration(100))}}}}}}, + }, + {"opToMinutes", + ast.Permit().When(ast.Duration(time.Duration(100)).ToMinutes()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "toMinutes", Args: []ast.IsNode{ast.NodeValue{Value: types.FromStdDuration(time.Duration(100))}}}}}}, + }, + { + "opToSeconds", + ast.Permit().When(ast.Duration(time.Duration(100)).ToSeconds()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "toSeconds", Args: []ast.IsNode{ast.NodeValue{Value: types.FromStdDuration(time.Duration(100))}}}}}}, + }, + { + "opToMilliseconds", + ast.Permit().When(ast.Duration(time.Duration(100)).ToMilliseconds()), + ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{}, + Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeTypeExtensionCall{Name: "toMilliseconds", Args: []ast.IsNode{ast.NodeValue{Value: types.FromStdDuration(time.Duration(100))}}}}}}, + }, { "duplicateAnnotations", diff --git a/internal/ast/operator.go b/internal/ast/operator.go index 1361d1ea..9d61dea2 100644 --- a/internal/ast/operator.go +++ b/internal/ast/operator.go @@ -162,3 +162,33 @@ func (lhs Node) IsLoopback() Node { func (lhs Node) IsInRange(rhs Node) Node { return NewMethodCall(lhs, "isInRange", rhs) } + +// ____ _ _ _ +// | _ \ __ _| |_ ___| |_(_)_ __ ___ ___ +// | | | |/ _` | __/ _ \ __| | '_ ` _ \ / _ \ +// | |_| | (_| | || __/ |_| | | | | | | __/ +// |____/ \__,_|\__\___|\__|_|_| |_| |_|\___| + +func (lhs Node) Offset(rhs Node) Node { return NewMethodCall(lhs, "offset", rhs) } + +func (lhs Node) DurationSince(rhs Node) Node { return NewMethodCall(lhs, "durationSince", rhs) } + +func (lhs Node) ToDate() Node { return NewMethodCall(lhs, "toDate") } + +func (lhs Node) ToTime() Node { return NewMethodCall(lhs, "toTime") } + +// ____ _ _ +// | _ \ _ _ _ __ __ _| |_(_) ___ _ __ +// | | | | | | | '__/ _` | __| |/ _ \| '_ \ +// | |_| | |_| | | | (_| | |_| | (_) | | | | +// |____/ \__,_|_| \__,_|\__|_|\___/|_| |_| + +func (lhs Node) ToDays() Node { return NewMethodCall(lhs, "toDays") } + +func (lhs Node) ToHours() Node { return NewMethodCall(lhs, "toHours") } + +func (lhs Node) ToMinutes() Node { return NewMethodCall(lhs, "toMinutes") } + +func (lhs Node) ToSeconds() Node { return NewMethodCall(lhs, "toSeconds") } + +func (lhs Node) ToMilliseconds() Node { return NewMethodCall(lhs, "toMilliseconds") } diff --git a/internal/ast/value.go b/internal/ast/value.go index 14673547..d8a4376f 100644 --- a/internal/ast/value.go +++ b/internal/ast/value.go @@ -2,6 +2,7 @@ package ast import ( "net/netip" + "time" "github.com/cedar-policy/cedar-go/types" ) @@ -73,6 +74,14 @@ func IPAddr[T netip.Prefix | types.IPAddr](i T) Node { return Value(types.IPAddr(i)) } +func Datetime(t time.Time) Node { + return Value(types.FromStdTime(t)) +} + +func Duration(d time.Duration) Node { + return Value(types.FromStdDuration(d)) +} + func ExtensionCall(name types.Path, args ...Node) Node { return NewExtensionCall(name, args...) } diff --git a/internal/error.go b/internal/error.go new file mode 100644 index 00000000..893bfa6d --- /dev/null +++ b/internal/error.go @@ -0,0 +1,13 @@ +package internal + +import "fmt" + +// These errors are declared here in order to allow the tests outside of the +// types package to assert on the error type returned. One day, we could +// consider making them public. + +var ErrDatetime = fmt.Errorf("error parsing datetime value") +var ErrDecimal = fmt.Errorf("error parsing decimal value") +var ErrDuration = fmt.Errorf("error parsing duration value") +var ErrIP = fmt.Errorf("error parsing ip value") +var ErrNotComparable = fmt.Errorf("incompatible types in comparison") diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index e8aba635..54869771 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -854,7 +854,7 @@ func (n *attributeAccessEval) Eval(env Env) (types.Value, error) { if vv == unspecified { return zeroValue(), fmt.Errorf("cannot access attribute `%s` of %w", n.attribute, errUnspecifiedEntity) } - rec, ok := env.Entities.Load(vv) + rec, ok := env.Entities[vv] if !ok { return zeroValue(), fmt.Errorf("entity `%v` %w", vv.String(), errEntityNotExist) } @@ -892,7 +892,7 @@ func (n *hasEval) Eval(env Env) (types.Value, error) { var record types.Record switch vv := v.(type) { case types.EntityUID: - if rec, ok := env.Entities.Load(vv); ok { + if rec, ok := env.Entities[vv]; ok { record = rec.Attributes } case types.Record: @@ -961,12 +961,12 @@ func entityInOne(env Env, entity types.EntityUID, parent types.EntityUID) bool { var todo []types.EntityUID var candidate = entity for { - if fe, ok := env.Entities.Load(candidate); ok { + if fe, ok := env.Entities[candidate]; ok { if fe.Parents.Contains(parent) { return true } fe.Parents.Iterate(func(k types.EntityUID) bool { - p, ok := env.Entities.Load(k) + p, ok := env.Entities[k] if !ok || p.Parents.Len() == 0 || k == entity || known.Contains(k) { return true } @@ -990,12 +990,12 @@ func entityInSet(env Env, entity types.EntityUID, parents mapset.Container[types var todo []types.EntityUID var candidate = entity for { - if fe, ok := env.Entities.Load(candidate); ok { + if fe, ok := env.Entities[candidate]; ok { if fe.Parents.Intersects(parents) { return true } fe.Parents.Iterate(func(k types.EntityUID) bool { - p, ok := env.Entities.Load(k) + p, ok := env.Entities[k] if !ok || p.Parents.Len() == 0 || k == entity || known.Contains(k) { return true } diff --git a/internal/eval/evalers_test.go b/internal/eval/evalers_test.go index 060b502a..00ed87ca 100644 --- a/internal/eval/evalers_test.go +++ b/internal/eval/evalers_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/cedar-policy/cedar-go/internal" "github.com/cedar-policy/cedar-go/internal/consts" "github.com/cedar-policy/cedar-go/internal/parser" "github.com/cedar-policy/cedar-go/internal/testutil" @@ -2177,7 +2178,7 @@ func TestDecimalLiteralNode(t *testing.T) { }{ {"Error", newErrorEval(errTest), zeroValue(), errTest}, {"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType}, - {"DecimalError", newLiteralEval(types.String("frob")), zeroValue(), types.ErrDecimal}, + {"DecimalError", newLiteralEval(types.String("frob")), zeroValue(), internal.ErrDecimal}, {"Success", newLiteralEval(types.String("1.0")), testutil.Must(types.NewDecimalFromInt(1)), nil}, } for _, tt := range tests { @@ -2204,7 +2205,7 @@ func TestIPLiteralNode(t *testing.T) { }{ {"Error", newErrorEval(errTest), zeroValue(), errTest}, {"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType}, - {"IPError", newLiteralEval(types.String("not-an-IP-address")), zeroValue(), types.ErrIP}, + {"IPError", newLiteralEval(types.String("not-an-IP-address")), zeroValue(), internal.ErrIP}, {"Success", newLiteralEval(types.String("::1/128")), ipv6Loopback, nil}, } for _, tt := range tests { @@ -2335,7 +2336,7 @@ func TestDatetimeLiteralNode(t *testing.T) { }{ {"Error", newErrorEval(errTest), zeroValue(), errTest}, {"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType}, - {"DatetimeError", newLiteralEval(types.String("frob")), zeroValue(), types.ErrDatetime}, + {"DatetimeError", newLiteralEval(types.String("frob")), zeroValue(), internal.ErrDatetime}, {"Success", newLiteralEval(types.String("1970-01-01")), types.FromStdTime(time.UnixMilli(0)), nil}, } for _, tt := range tests { @@ -2480,7 +2481,7 @@ func TestDurationLiteralNode(t *testing.T) { }{ {"Error", newErrorEval(errTest), zeroValue(), errTest}, {"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType}, - {"DurationError", newLiteralEval(types.String("frob")), zeroValue(), types.ErrDuration}, + {"DurationError", newLiteralEval(types.String("frob")), zeroValue(), internal.ErrDuration}, {"Success", newLiteralEval(types.String("1h")), types.FromStdDuration(1 * time.Hour), nil}, } for _, tt := range tests { diff --git a/internal/eval/partial.go b/internal/eval/partial.go index cfd9998c..4fc87b12 100644 --- a/internal/eval/partial.go +++ b/internal/eval/partial.go @@ -504,7 +504,7 @@ func (n *partialHasEval) Eval(env Env) (types.Value, error) { var record types.Record switch vv := v.(type) { case types.EntityUID: - if rec, ok := env.Entities.Load(vv); ok { + if rec, ok := env.Entities[vv]; ok { record = rec.Attributes } case types.Record: diff --git a/internal/extensions/extensions.go b/internal/extensions/extensions.go index 86a760c6..1ee84090 100644 --- a/internal/extensions/extensions.go +++ b/internal/extensions/extensions.go @@ -24,16 +24,16 @@ var ExtMap = map[types.Path]extInfo{ "isMulticast": {Args: 1, IsMethod: true}, "isInRange": {Args: 2, IsMethod: true}, - "toDate": {Args: 1, IsMethod: true}, - "toTime": {Args: 1, IsMethod: true}, + "toDate": {Args: 1, IsMethod: true}, + "toTime": {Args: 1, IsMethod: true}, + "offset": {Args: 2, IsMethod: true}, + "durationSince": {Args: 2, IsMethod: true}, + "toDays": {Args: 1, IsMethod: true}, "toHours": {Args: 1, IsMethod: true}, "toMinutes": {Args: 1, IsMethod: true}, "toSeconds": {Args: 1, IsMethod: true}, "toMilliseconds": {Args: 1, IsMethod: true}, - - "offset": {Args: 2, IsMethod: true}, - "durationSince": {Args: 2, IsMethod: true}, } func init() { diff --git a/types.go b/types.go index af2cdbbb..7e7d1996 100644 --- a/types.go +++ b/types.go @@ -31,7 +31,7 @@ type String = types.String // Other Cedar types -type Entities = types.EntityMap +type EntityMap = types.EntityMap type Entity = types.Entity type EntityType = types.EntityType type EntityUIDSet = types.EntityUIDSet diff --git a/types/datetime.go b/types/datetime.go index 66ae1ee2..4e40eb7d 100644 --- a/types/datetime.go +++ b/types/datetime.go @@ -8,8 +8,12 @@ import ( "strconv" "time" "unicode" + + "github.com/cedar-policy/cedar-go/internal" ) +var errDatetime = internal.ErrDatetime + // Datetime represents a Cedar datetime value type Datetime struct { // value is a timestamp in milliseconds @@ -44,7 +48,7 @@ func ParseDatetime(s string) (Datetime, error) { length := len(s) if length < 10 { - return Datetime{}, fmt.Errorf("%w: string too short", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: string too short", errDatetime) } // Date: YYYY-MM-DD @@ -57,7 +61,7 @@ func ParseDatetime(s string) (Datetime, error) { unicode.IsDigit(rune(s[1])) && unicode.IsDigit(rune(s[2])) && unicode.IsDigit(rune(s[3]))) { - return Datetime{}, fmt.Errorf("%w: invalid year", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid year", errDatetime) } year = 1000*int(rune(s[0])-'0') + 100*int(rune(s[1])-'0') + @@ -65,24 +69,24 @@ func ParseDatetime(s string) (Datetime, error) { int(rune(s[3])-'0') if s[4] != '-' { - return Datetime{}, fmt.Errorf("%w: unexpected character %s", ErrDatetime, strconv.QuoteRune(rune(s[4]))) + return Datetime{}, fmt.Errorf("%w: unexpected character %s", errDatetime, strconv.QuoteRune(rune(s[4]))) } // MM if !(unicode.IsDigit(rune(s[5])) && unicode.IsDigit(rune(s[6]))) { - return Datetime{}, fmt.Errorf("%w: invalid month", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid month", errDatetime) } month = 10*int(rune(s[5])-'0') + int(rune(s[6])-'0') if s[7] != '-' { - return Datetime{}, fmt.Errorf("%w: unexpected character %s", ErrDatetime, strconv.QuoteRune(rune(s[7]))) + return Datetime{}, fmt.Errorf("%w: unexpected character %s", errDatetime, strconv.QuoteRune(rune(s[7]))) } // DD if !(unicode.IsDigit(rune(s[8])) && unicode.IsDigit(rune(s[9]))) { - return Datetime{}, fmt.Errorf("%w: invalid day", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid day", errDatetime) } day = 10*int(rune(s[8])-'0') + int(rune(s[9])-'0') @@ -94,7 +98,7 @@ func ParseDatetime(s string) (Datetime, error) { // If the length is less than 20, we can't have a valid time. if length < 20 { - return Datetime{}, fmt.Errorf("%w: invalid time", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid time", errDatetime) } // Time: Thh:mm:ss? @@ -106,32 +110,32 @@ func ParseDatetime(s string) (Datetime, error) { // ? is at 19, and... we'll skip to get back to that. if s[10] != 'T' { - return Datetime{}, fmt.Errorf("%w: unexpected character %s", ErrDatetime, strconv.QuoteRune(rune(s[10]))) + return Datetime{}, fmt.Errorf("%w: unexpected character %s", errDatetime, strconv.QuoteRune(rune(s[10]))) } if !(unicode.IsDigit(rune(s[11])) && unicode.IsDigit(rune(s[12]))) { - return Datetime{}, fmt.Errorf("%w: invalid hour", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid hour", errDatetime) } hour = 10*int(rune(s[11])-'0') + int(rune(s[12])-'0') if s[13] != ':' { - return Datetime{}, fmt.Errorf("%w: unexpected character %s", ErrDatetime, strconv.QuoteRune(rune(s[13]))) + return Datetime{}, fmt.Errorf("%w: unexpected character %s", errDatetime, strconv.QuoteRune(rune(s[13]))) } if !(unicode.IsDigit(rune(s[14])) && unicode.IsDigit(rune(s[15]))) { - return Datetime{}, fmt.Errorf("%w: invalid minute", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid minute", errDatetime) } minute = 10*int(rune(s[14])-'0') + int(rune(s[15])-'0') if s[16] != ':' { - return Datetime{}, fmt.Errorf("%w: unexpected character %s", ErrDatetime, strconv.QuoteRune(rune(s[16]))) + return Datetime{}, fmt.Errorf("%w: unexpected character %s", errDatetime, strconv.QuoteRune(rune(s[16]))) } if !(unicode.IsDigit(rune(s[17])) && unicode.IsDigit(rune(s[18]))) { - return Datetime{}, fmt.Errorf("%w: invalid second", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid second", errDatetime) } second = 10*int(rune(s[17])-'0') + int(rune(s[18])-'0') @@ -142,13 +146,13 @@ func ParseDatetime(s string) (Datetime, error) { trailerOffset := 19 if s[19] == '.' { if length < 23 { - return Datetime{}, fmt.Errorf("%w: invalid millisecond", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid millisecond", errDatetime) } if !(unicode.IsDigit(rune(s[20])) && unicode.IsDigit(rune(s[21])) && unicode.IsDigit(rune(s[22]))) { - return Datetime{}, fmt.Errorf("%w: invalid millisecond", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid millisecond", errDatetime) } milli = 100*int(rune(s[20])-'0') + 10*int(rune(s[21])-'0') + int(rune(s[22])-'0') @@ -156,7 +160,7 @@ func ParseDatetime(s string) (Datetime, error) { } if length == trailerOffset { - return Datetime{}, fmt.Errorf("%w: expected time zone designator", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: expected time zone designator", errDatetime) } // At this point, we can only have 2 possible lengths. Anything else is an error. @@ -164,7 +168,7 @@ func ParseDatetime(s string) (Datetime, error) { case 'Z': if length > trailerOffset+1 { // If something comes after the Z, it's an error - return Datetime{}, fmt.Errorf("%w: unexpected trailer after time zone designator", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: unexpected trailer after time zone designator", errDatetime) } case '+', '-': sign := 1 @@ -173,9 +177,9 @@ func ParseDatetime(s string) (Datetime, error) { } if length > trailerOffset+5 { - return Datetime{}, fmt.Errorf("%w: unexpected trailer after time zone designator", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: unexpected trailer after time zone designator", errDatetime) } else if length != trailerOffset+5 { - return Datetime{}, fmt.Errorf("%w: invalid time zone offset", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid time zone offset", errDatetime) } // get the time zone offset hhmm. @@ -183,7 +187,7 @@ func ParseDatetime(s string) (Datetime, error) { unicode.IsDigit(rune(s[trailerOffset+2])) && unicode.IsDigit(rune(s[trailerOffset+3])) && unicode.IsDigit(rune(s[trailerOffset+4]))) { - return Datetime{}, fmt.Errorf("%w: invalid time zone offset", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid time zone offset", errDatetime) } hh := time.Duration(10*int64(rune(s[trailerOffset+1])-'0')+int64(rune(s[trailerOffset+2])-'0')) * time.Hour @@ -191,7 +195,7 @@ func ParseDatetime(s string) (Datetime, error) { offset = time.Duration(sign) * (hh + mm) default: - return Datetime{}, fmt.Errorf("%w: invalid time zone designator", ErrDatetime) + return Datetime{}, fmt.Errorf("%w: invalid time zone designator", errDatetime) } t := time.Date(year, time.Month(month), day, @@ -213,7 +217,7 @@ func (a Datetime) Equal(bi Value) bool { func (a Datetime) LessThan(bi Value) (bool, error) { b, ok := bi.(Datetime) if !ok { - return false, ErrNotComparable + return false, internal.ErrNotComparable } return a.value < b.value, nil } @@ -224,7 +228,7 @@ func (a Datetime) LessThan(bi Value) (bool, error) { func (a Datetime) LessThanOrEqual(bi Value) (bool, error) { b, ok := bi.(Datetime) if !ok { - return false, ErrNotComparable + return false, internal.ErrNotComparable } return a.value <= b.value, nil } diff --git a/types/datetime_test.go b/types/datetime_test.go index 9f5658b6..aa75805d 100644 --- a/types/datetime_test.go +++ b/types/datetime_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/cedar-policy/cedar-go/internal" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -104,7 +105,7 @@ func TestDatetime(t *testing.T) { t.Run(fmt.Sprintf("%d_%s->%s", ti, tt.in, tt.errStr), func(t *testing.T) { t.Parallel() _, err := types.ParseDatetime(tt.in) - testutil.ErrorIs(t, err, types.ErrDatetime) + testutil.ErrorIs(t, err, internal.ErrDatetime) testutil.Equals(t, err.Error(), tt.errStr) }) } @@ -145,7 +146,7 @@ func TestDatetime(t *testing.T) { {one, zero, false, nil}, {zero, one, true, nil}, {zero, zero, false, nil}, - {zero, f, false, types.ErrNotComparable}, + {zero, f, false, internal.ErrNotComparable}, } for ti, tt := range tests { @@ -175,7 +176,7 @@ func TestDatetime(t *testing.T) { {one, zero, false, nil}, {zero, one, true, nil}, {zero, zero, true, nil}, - {zero, f, false, types.ErrNotComparable}, + {zero, f, false, internal.ErrNotComparable}, } for ti, tt := range tests { diff --git a/types/decimal.go b/types/decimal.go index 8478d09e..a217ab01 100644 --- a/types/decimal.go +++ b/types/decimal.go @@ -9,9 +9,12 @@ import ( "strconv" "strings" + "github.com/cedar-policy/cedar-go/internal" "golang.org/x/exp/constraints" ) +var errDecimal = internal.ErrDecimal + // decimalPrecision is the precision of a Decimal. const decimalPrecision = 10000 @@ -26,9 +29,9 @@ type Decimal struct { // sign of intPart and tenThousandths should match. func newDecimal(intPart int64, tenThousandths int16) (Decimal, error) { if intPart > 922337203685477 || (intPart == 922337203685477 && tenThousandths > 5807) { - return Decimal{}, fmt.Errorf("%w: value would overflow", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: value would overflow", errDecimal) } else if intPart < -922337203685477 || (intPart == -922337203685477 && tenThousandths < -5808) { - return Decimal{}, fmt.Errorf("%w: value would underflow", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: value would underflow", errDecimal) } return Decimal{value: intPart*decimalPrecision + int64(tenThousandths)}, nil @@ -37,7 +40,7 @@ func newDecimal(intPart int64, tenThousandths int16) (Decimal, error) { // NewDecimal returns a Decimal value of i * 10^exponent. func NewDecimal(i int64, exponent int) (Decimal, error) { if exponent < -4 || exponent > 14 { - return Decimal{}, fmt.Errorf("%w: exponent value of %v exceeds maximum range of Decimal", ErrDecimal, exponent) + return Decimal{}, fmt.Errorf("%w: exponent value of %v exceeds maximum range of Decimal", errDecimal, exponent) } var intPart int64 @@ -48,9 +51,9 @@ func NewDecimal(i int64, exponent int) (Decimal, error) { } else { intPart = i * int64(math.Pow10(exponent)) if i > 0 && intPart < i { - return Decimal{}, fmt.Errorf("%w: value %ve%v would overflow", ErrDecimal, i, exponent) + return Decimal{}, fmt.Errorf("%w: value %ve%v would overflow", errDecimal, i, exponent) } else if i < 0 && intPart > i { - return Decimal{}, fmt.Errorf("%w: value %ve%v would underflow", ErrDecimal, i, exponent) + return Decimal{}, fmt.Errorf("%w: value %ve%v would underflow", errDecimal, i, exponent) } } @@ -73,9 +76,9 @@ func NewDecimalFromInt[T constraints.Signed](i T) (Decimal, error) { func NewDecimalFromFloat[T constraints.Float](f T) (Decimal, error) { f = f * decimalPrecision if f > math.MaxInt64 { - return Decimal{}, fmt.Errorf("%w: value %v would overflow", ErrDecimal, f) + return Decimal{}, fmt.Errorf("%w: value %v would overflow", errDecimal, f) } else if f < math.MinInt64 { - return Decimal{}, fmt.Errorf("%w: value %v would underflow", ErrDecimal, f) + return Decimal{}, fmt.Errorf("%w: value %v would underflow", errDecimal, f) } return Decimal{int64(f)}, nil @@ -94,29 +97,29 @@ func (d Decimal) Compare(other Decimal) int { func ParseDecimal(s string) (Decimal, error) { decimalIndex := strings.Index(s, ".") if decimalIndex < 0 { - return Decimal{}, fmt.Errorf("%w: missing decimal point", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: missing decimal point", errDecimal) } intPart, err := strconv.ParseInt(s[0:decimalIndex], 10, 64) if err != nil { if errors.Is(err, strconv.ErrRange) { - return Decimal{}, fmt.Errorf("%w: value would overflow", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: value would overflow", errDecimal) } - return Decimal{}, fmt.Errorf("%w: %w", ErrDecimal, err) + return Decimal{}, fmt.Errorf("%w: %w", errDecimal, err) } fracPartStr := s[decimalIndex+1:] fracPart, err := strconv.ParseUint(fracPartStr, 10, 16) if err != nil { if errors.Is(err, strconv.ErrRange) { - return Decimal{}, fmt.Errorf("%w: fractional part exceeds Decimal precision", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: fractional part exceeds Decimal precision", errDecimal) } - return Decimal{}, fmt.Errorf("%w: %w", ErrDecimal, err) + return Decimal{}, fmt.Errorf("%w: %w", errDecimal, err) } decimalPlaces := len(fracPartStr) if decimalPlaces > 4 { - return Decimal{}, fmt.Errorf("%w: fractional part exceeds Decimal precision", ErrDecimal) + return Decimal{}, fmt.Errorf("%w: fractional part exceeds Decimal precision", errDecimal) } tenThousandths := int16(fracPart) * int16(math.Pow10(4-decimalPlaces)) diff --git a/types/decimal_test.go b/types/decimal_test.go index 8f56e1ab..2d94d6ce 100644 --- a/types/decimal_test.go +++ b/types/decimal_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "github.com/cedar-policy/cedar-go/internal" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -103,7 +104,7 @@ func TestDecimal(t *testing.T) { t.Run(fmt.Sprintf("%s->%s", tt.in, tt.errStr), func(t *testing.T) { t.Parallel() _, err := types.ParseDecimal(tt.in) - testutil.ErrorIs(t, err, types.ErrDecimal) + testutil.ErrorIs(t, err, internal.ErrDecimal) testutil.Equals(t, err.Error(), tt.errStr) }) } @@ -208,7 +209,7 @@ func TestDecimal(t *testing.T) { t.Run(fmt.Sprintf("%ve%v", tt.in, tt.exp), func(t *testing.T) { t.Parallel() _, err := types.NewDecimal(tt.in, tt.exp) - testutil.ErrorIs(t, err, types.ErrDecimal) + testutil.ErrorIs(t, err, internal.ErrDecimal) }) } }) @@ -232,7 +233,7 @@ func TestDecimal(t *testing.T) { t.Run(fmt.Sprintf("%ve%v", tt.in, tt.exp), func(t *testing.T) { t.Parallel() _, err := types.NewDecimal(tt.in, tt.exp) - testutil.ErrorIs(t, err, types.ErrDecimal) + testutil.ErrorIs(t, err, internal.ErrDecimal) }) } }) @@ -274,7 +275,7 @@ func TestDecimal(t *testing.T) { for _, tt := range tests { t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { _, err := types.NewDecimalFromFloat(tt.in) - testutil.ErrorIs(t, err, types.ErrDecimal) + testutil.ErrorIs(t, err, internal.ErrDecimal) }) } }) diff --git a/types/duration.go b/types/duration.go index 2ed1e5cd..d103ab1c 100644 --- a/types/duration.go +++ b/types/duration.go @@ -10,9 +10,12 @@ import ( "time" "unicode" + "github.com/cedar-policy/cedar-go/internal" "github.com/cedar-policy/cedar-go/internal/consts" ) +var errDuration = internal.ErrDuration + var unitToMillis = map[string]int64{ "d": consts.MillisPerDay, "h": consts.MillisPerHour, @@ -52,7 +55,7 @@ func DurationFromMillis(ms int64) Duration { func ParseDuration(s string) (Duration, error) { // Check for empty string. if len(s) <= 1 { - return Duration{}, fmt.Errorf("%w: string too short", ErrDuration) + return Duration{}, fmt.Errorf("%w: string too short", errDuration) } i := 0 @@ -78,13 +81,13 @@ func ParseDuration(s string) (Duration, error) { // check overflow if value > math.MaxInt32 { - return Duration{}, fmt.Errorf("%w: overflow", ErrDuration) + return Duration{}, fmt.Errorf("%w: overflow", errDuration) } hasValue = true i++ } else if s[i] == 'd' || s[i] == 'h' || s[i] == 'm' || s[i] == 's' { if !hasValue { - return Duration{}, fmt.Errorf("%w: unit found without quantity", ErrDuration) + return Duration{}, fmt.Errorf("%w: unit found without quantity", errDuration) } // is it ms? @@ -104,7 +107,7 @@ func ParseDuration(s string) (Duration, error) { } if !unitOK { - return Duration{}, fmt.Errorf("%w: unexpected unit '%s'", ErrDuration, unit) + return Duration{}, fmt.Errorf("%w: unexpected unit '%s'", errDuration, unit) } total = total + value*unitToMillis[unit] @@ -112,18 +115,18 @@ func ParseDuration(s string) (Duration, error) { hasValue = false value = 0 } else { - return Duration{}, fmt.Errorf("%w: unexpected character %s", ErrDuration, strconv.QuoteRune(rune(s[i]))) + return Duration{}, fmt.Errorf("%w: unexpected character %s", errDuration, strconv.QuoteRune(rune(s[i]))) } } // We didn't have a trailing unit if hasValue { - return Duration{}, fmt.Errorf("%w: expected unit", ErrDuration) + return Duration{}, fmt.Errorf("%w: expected unit", errDuration) } // We still have characters left, but no more units to assign. if i < len(s) { - return Duration{}, fmt.Errorf("%w: invalid duration", ErrDuration) + return Duration{}, fmt.Errorf("%w: invalid duration", errDuration) } return Duration{value: negative * total}, nil @@ -141,7 +144,7 @@ func (a Duration) Equal(bi Value) bool { func (a Duration) LessThan(bi Value) (bool, error) { b, ok := bi.(Duration) if !ok { - return false, ErrNotComparable + return false, internal.ErrNotComparable } return a.value < b.value, nil } @@ -152,7 +155,7 @@ func (a Duration) LessThan(bi Value) (bool, error) { func (a Duration) LessThanOrEqual(bi Value) (bool, error) { b, ok := bi.(Duration) if !ok { - return false, ErrNotComparable + return false, internal.ErrNotComparable } return a.value <= b.value, nil } diff --git a/types/duration_test.go b/types/duration_test.go index 43879958..fcfac607 100644 --- a/types/duration_test.go +++ b/types/duration_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/cedar-policy/cedar-go/internal" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -72,7 +73,7 @@ func TestDuration(t *testing.T) { t.Run(fmt.Sprintf("%d_%s->%s", ti, tt.in, tt.errStr), func(t *testing.T) { t.Parallel() _, err := types.ParseDuration(tt.in) - testutil.ErrorIs(t, err, types.ErrDuration) + testutil.ErrorIs(t, err, internal.ErrDuration) testutil.Equals(t, err.Error(), tt.errStr) }) } @@ -113,7 +114,7 @@ func TestDuration(t *testing.T) { {one, zero, false, nil}, {zero, one, true, nil}, {zero, zero, false, nil}, - {zero, f, false, types.ErrNotComparable}, + {zero, f, false, internal.ErrNotComparable}, } for ti, tt := range tests { @@ -143,7 +144,7 @@ func TestDuration(t *testing.T) { {one, zero, false, nil}, {zero, one, true, nil}, {zero, zero, true, nil}, - {zero, f, false, types.ErrNotComparable}, + {zero, f, false, internal.ErrNotComparable}, } for ti, tt := range tests { diff --git a/types/entity_map.go b/types/entity_map.go index cbb6c2a4..70b90344 100644 --- a/types/entity_map.go +++ b/types/entity_map.go @@ -13,11 +13,6 @@ import ( // the Entity (it must be the same as the UID within the Entity itself.) type EntityMap map[EntityUID]Entity -func (e EntityMap) Load(k EntityUID) (Entity, bool) { - v, ok := e[k] - return v, ok -} - func (e EntityMap) MarshalJSON() ([]byte, error) { s := maps.Values(e) slices.SortFunc(s, func(a, b Entity) int { diff --git a/types/ipaddr.go b/types/ipaddr.go index 03132657..90c54229 100644 --- a/types/ipaddr.go +++ b/types/ipaddr.go @@ -7,8 +7,12 @@ import ( "hash/fnv" "net/netip" "strings" + + "github.com/cedar-policy/cedar-go/internal" ) +var errIP = internal.ErrIP + // An IPAddr is value that represents an IP address. It can be either IPv4 or IPv6. // The value can represent an individual address or a range of addresses. type IPAddr netip.Prefix @@ -17,13 +21,13 @@ type IPAddr netip.Prefix func ParseIPAddr(s string) (IPAddr, error) { // We disallow IPv4-mapped IPv6 addresses in dotted notation because Cedar does. if strings.Count(s, ":") >= 2 && strings.Count(s, ".") >= 2 { - return IPAddr{}, fmt.Errorf("%w: cannot parse IPv4 addresses embedded in IPv6 addresses", ErrIP) + return IPAddr{}, fmt.Errorf("%w: cannot parse IPv4 addresses embedded in IPv6 addresses", errIP) } else if net, err := netip.ParsePrefix(s); err == nil { return IPAddr(net), nil } else if addr, err := netip.ParseAddr(s); err == nil { return IPAddr(netip.PrefixFrom(addr, addr.BitLen())), nil } else { - return IPAddr{}, fmt.Errorf("%w: error parsing IP address %s", ErrIP, s) + return IPAddr{}, fmt.Errorf("%w: error parsing IP address %s", errIP, s) } } diff --git a/types/json_test.go b/types/json_test.go index ccb474a4..3ae0d892 100644 --- a/types/json_test.go +++ b/types/json_test.go @@ -62,10 +62,10 @@ func TestJSON_Value(t *testing.T) { {"explicitDatetime", `{ "__extn": { "fn": "datetime", "arg": "1970-01-01T00:00:01Z" } }`, mustDatetimeValue("1970-01-01T00:00:01Z"), nil}, {"explicitDuration", `{ "__extn": { "fn": "duration", "arg": "1d12h30m30s500ms" } }`, mustDurationValue("1d12h30m30s500ms"), nil}, {"invalidExtension", `{ "__extn": { "fn": "asdf", "arg": "blah" } }`, zeroValue(), errJSONInvalidExtn}, - {"badIP", `{ "__extn": { "fn": "ip", "arg": "bad" } }`, zeroValue(), ErrIP}, - {"badDecimal", `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, zeroValue(), ErrDecimal}, - {"badDatetime", `{ "__extn": { "fn": "datetime", "arg": "bad" } }`, zeroValue(), ErrDatetime}, - {"badDuration", `{ "__extn": { "fn": "duration", "arg": "bad" } }`, zeroValue(), ErrDuration}, + {"badIP", `{ "__extn": { "fn": "ip", "arg": "bad" } }`, zeroValue(), errIP}, + {"badDecimal", `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, zeroValue(), errDecimal}, + {"badDatetime", `{ "__extn": { "fn": "datetime", "arg": "bad" } }`, zeroValue(), errDatetime}, + {"badDuration", `{ "__extn": { "fn": "duration", "arg": "bad" } }`, zeroValue(), errDuration}, {"set", `[42]`, NewSet(Long(42)), nil}, {"record", `{"a":"b"}`, NewRecord(RecordMap{"a": String("b")}), nil}, {"bool", `false`, Boolean(false), nil}, @@ -158,7 +158,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { }, in: `{ "__extn": { "fn": "ip", "arg": "bad" } }`, wantValue: IPAddr{}, - wantErr: ErrIP, + wantErr: errIP, }, { name: "ip/badJSON", @@ -236,7 +236,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { }, in: `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, wantValue: Decimal{}, - wantErr: ErrDecimal, + wantErr: errDecimal, }, { name: "decimal/badJSON", @@ -336,7 +336,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { }, in: `{ "__extn": { "fn": "datetime", "arg": "bad" } }`, wantValue: Datetime{}, - wantErr: ErrDatetime, + wantErr: errDatetime, }, { name: "datetime/badJSON", @@ -460,7 +460,7 @@ func TestTypedJSONUnmarshal(t *testing.T) { }, in: `{ "__extn": { "fn": "duration", "arg": "bad" } }`, wantValue: Duration{}, - wantErr: ErrDuration, + wantErr: errDuration, }, { name: "duration/badJSON", diff --git a/types/long.go b/types/long.go index 2dec16e3..69c399aa 100644 --- a/types/long.go +++ b/types/long.go @@ -2,6 +2,8 @@ package types import ( "fmt" + + "github.com/cedar-policy/cedar-go/internal" ) // A Long is a whole number without decimals that can range from -9223372036854775808 to 9223372036854775807. @@ -15,7 +17,7 @@ func (a Long) Equal(bi Value) bool { func (a Long) LessThan(bi Value) (bool, error) { b, ok := bi.(Long) if !ok { - return false, ErrNotComparable + return false, internal.ErrNotComparable } return a < b, nil } @@ -23,7 +25,7 @@ func (a Long) LessThan(bi Value) (bool, error) { func (a Long) LessThanOrEqual(bi Value) (bool, error) { b, ok := bi.(Long) if !ok { - return false, ErrNotComparable + return false, internal.ErrNotComparable } return a <= b, nil } diff --git a/types/long_test.go b/types/long_test.go index d5e53e9e..dc185d3d 100644 --- a/types/long_test.go +++ b/types/long_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "github.com/cedar-policy/cedar-go/internal" "github.com/cedar-policy/cedar-go/internal/testutil" "github.com/cedar-policy/cedar-go/types" ) @@ -39,7 +40,7 @@ func TestLong(t *testing.T) { {one, zero, false, nil}, {zero, one, true, nil}, {zero, zero, false, nil}, - {zero, f, false, types.ErrNotComparable}, + {zero, f, false, internal.ErrNotComparable}, } for ti, tt := range tests { @@ -69,7 +70,7 @@ func TestLong(t *testing.T) { {one, zero, false, nil}, {zero, one, true, nil}, {zero, zero, true, nil}, - {zero, f, false, types.ErrNotComparable}, + {zero, f, false, internal.ErrNotComparable}, } for ti, tt := range tests { diff --git a/types/value.go b/types/value.go index 777a576e..b65b0429 100644 --- a/types/value.go +++ b/types/value.go @@ -4,12 +4,6 @@ import ( "fmt" ) -var ErrDatetime = fmt.Errorf("error parsing datetime value") -var ErrDecimal = fmt.Errorf("error parsing decimal value") -var ErrDuration = fmt.Errorf("error parsing duration value") -var ErrIP = fmt.Errorf("error parsing ip value") -var ErrNotComparable = fmt.Errorf("incompatible types in comparison") - // Value defines the interface for all Cedar values (String, Long, Set, Record, Boolean, etc ...) // // Implementations of Value _must_ be able to be safely copied shallowly, which means they must either be immutable