diff --git a/internal/mapset/mapset_test.go b/internal/mapset/mapset_test.go index cf1a49c..02c726c 100644 --- a/internal/mapset/mapset_test.go +++ b/internal/mapset/mapset_test.go @@ -150,6 +150,12 @@ func TestMapSet(t *testing.T) { testutil.Equals(t, string(out), "[]") }) + t.Run("marshal error", func(t *testing.T) { + s := FromItems(complex(0, 0)) + _, err := json.Marshal(s) + testutil.Error(t, err) + }) + t.Run("encode json one int", func(t *testing.T) { s := FromItems(1) diff --git a/internal/parser/cedar_unmarshal_test.go b/internal/parser/cedar_unmarshal_test.go index 20e6a72..a654638 100644 --- a/internal/parser/cedar_unmarshal_test.go +++ b/internal/parser/cedar_unmarshal_test.go @@ -564,6 +564,7 @@ func TestParseApproximateErrors(t *testing.T) { {"reservedKeywordAsEntityType", `permit (principal == false::"42", action, resource)`, "expected ident"}, {"reservedKeywordAsAttributeAccess", `permit (principal, action, resource) when { context.false }`, "expected ident"}, {"invalidPrimary", `permit (principal, action, resource) when { foobar }`, "invalid primary"}, + {"unexpectedTokenInEntityOrExtFun", `permit (principal, action, resource) when { A::B 42 }`, "unexpected token"}, } for _, tt := range tests { tt := tt diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index cb8c1f6..8bbf2fd 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -73,6 +73,7 @@ func Must[T any](t T, err error) T { // JSONMarshalsTo asserts that obj marshals as JSON to the given string, allowing for formatting differences and // displaying an easy-to-read diff. func JSONMarshalsTo[T any](t TB, obj T, want string) { + t.Helper() b, err := json.MarshalIndent(obj, "", "\t") OK(t, err) diff --git a/internal/testutil/testutil_test.go b/internal/testutil/testutil_test.go index 5ca2a41..76626b8 100644 --- a/internal/testutil/testutil_test.go +++ b/internal/testutil/testutil_test.go @@ -147,3 +147,60 @@ func TestPanic(t *testing.T) { Equals(t, len(tb.FatalfCalls()), 1) }) } + +func TestMust(t *testing.T) { + t.Parallel() + t.Run("Panic", func(t *testing.T) { + tb := newTB() + var x bool + Panic(tb, func() { + x = Must(true, fmt.Errorf("panic")) + }) + // assertions + Equals(t, x, false) + Equals(t, len(tb.HelperCalls()), 1) + Equals(t, len(tb.FatalfCalls()), 0) + }) + t.Run("Okay", func(t *testing.T) { + tb := newTB() + var x bool + Panic(tb, func() { + x = Must(true, nil) + }) + // assertions + Equals(t, x, true) + }) +} + +func TestJSONMarshalsTo(t *testing.T) { + t.Parallel() + t.Run("Okay", func(t *testing.T) { + tb := newTB() + JSONMarshalsTo(tb, "test", `"test"`) + Equals(t, len(tb.HelperCalls()), 4) + Equals(t, len(tb.FatalfCalls()), 0) + }) + + t.Run("ErrNotEqual", func(t *testing.T) { + tb := newTB() + JSONMarshalsTo(tb, "test", `"asdf"`) + Equals(t, len(tb.HelperCalls()), 4) + Equals(t, len(tb.FatalfCalls()), 1) + }) + + t.Run("ErrNotJSON", func(t *testing.T) { + tb := newTB() + JSONMarshalsTo(tb, "test", `asdf`) + Equals(t, len(tb.HelperCalls()), 4) + Equals(t, len(tb.FatalfCalls()), 2) + }) + + t.Run("ErrNotMarshalable", func(t *testing.T) { + tb := newTB() + cx := complex(0, 0) + JSONMarshalsTo(tb, cx, `null`) + Equals(t, len(tb.HelperCalls()), 4) + Equals(t, len(tb.FatalfCalls()), 2) + }) + +} diff --git a/types.go b/types.go index 7ac49d7..3dd4b38 100644 --- a/types.go +++ b/types.go @@ -93,7 +93,7 @@ func NewEntityUIDSet(args ...EntityUID) EntityUIDSet { // The pattern components may be one of string, cedar.String, or cedar.Wildcard. Any other types will // cause a panic. func NewPattern(components ...any) Pattern { - return types.NewPattern(components) + return types.NewPattern(components...) } // NewRecord returns an immutable Record given a Go map of Strings to Values diff --git a/types/ipaddr_test.go b/types/ipaddr_test.go index e7c9924..7d87905 100644 --- a/types/ipaddr_test.go +++ b/types/ipaddr_test.go @@ -269,4 +269,12 @@ func TestIP(t *testing.T) { }) } }) + t.Run("MarshalCedar", func(t *testing.T) { + t.Parallel() + testutil.Equals( + t, + string(testutil.Must(types.ParseIPAddr("10.0.0.42")).MarshalCedar()), + `ip("10.0.0.42")`) + }) + } diff --git a/types/record_test.go b/types/record_test.go index ac1c0b7..c422abb 100644 --- a/types/record_test.go +++ b/types/record_test.go @@ -41,6 +41,12 @@ func TestRecord(t *testing.T) { "two": types.Long(2), "nest": twoElems, }) + sameHash1 := types.NewRecord(types.RecordMap{ + "key": types.Long(0), + }) + sameHash2 := types.NewRecord(types.RecordMap{ + "key": testutil.Must(types.NewDecimalFromInt(0)), + }) testutil.FatalIf(t, !empty.Equal(empty), "%v not Equal to %v", empty, empty) testutil.FatalIf(t, !empty.Equal(empty2), "%v not Equal to %v", empty, empty2) @@ -55,6 +61,9 @@ func TestRecord(t *testing.T) { testutil.FatalIf(t, nested.Equal(twoElems), "%v Equal to %v", nested, twoElems) testutil.FatalIf(t, twoElems.Equal(differentValues), "%v Equal to %v", twoElems, differentValues) testutil.FatalIf(t, twoElems.Equal(differentKeys), "%v Equal to %v", twoElems, differentKeys) + + testutil.FatalIf(t, sameHash1.Equal(sameHash2), "%v Equal to %v", sameHash1, sameHash2) + }) t.Run("string", func(t *testing.T) { diff --git a/types/set_test.go b/types/set_test.go index 4053ff7..1292697 100644 --- a/types/set_test.go +++ b/types/set_test.go @@ -206,4 +206,78 @@ func TestSet(t *testing.T) { testutil.Equals(t, got, types.Long(42)) }) + + t.Run("Contains", func(t *testing.T) { + t.Parallel() + tests := []struct { + name string + set types.Set + value types.Value + want bool + }{ + {"trueLong", types.NewSet(types.Long(42)), types.Long(42), true}, + {"falseLong", types.NewSet(types.Long(42)), types.Long(1234), false}, + {"trueDecimal", + types.NewSet(testutil.Must(types.NewDecimalFromInt(42))), + testutil.Must(types.NewDecimalFromInt(42)), + true, + }, + {"falseDecimal", + types.NewSet(testutil.Must(types.NewDecimalFromInt(42))), + testutil.Must(types.NewDecimalFromInt(1234)), + false, + }, + {"trueDuration", + types.NewSet(types.NewDurationFromMillis(42)), + types.NewDurationFromMillis(42), + true, + }, + {"falseDuration", + types.NewSet(types.NewDurationFromMillis(42)), + types.NewDurationFromMillis(1234), + false, + }, + {"trueDatetime", + types.NewSet(types.NewDatetimeFromMillis(42)), + types.NewDatetimeFromMillis(42), + true, + }, + {"falseDatetime", + types.NewSet(types.NewDatetimeFromMillis(42)), + types.NewDatetimeFromMillis(1234), + false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := tt.set.Contains(tt.value) + testutil.Equals(t, got, tt.want) + }) + } + }) + + t.Run("Equals", func(t *testing.T) { + t.Parallel() + tests := []struct { + name string + set types.Set + value types.Value + want bool + }{ + {"true", types.NewSet(types.Long(42)), types.NewSet(types.Long(42)), true}, + {"falseSet", types.NewSet(types.Long(42)), types.NewSet(types.Long(1234)), false}, + {"falseOtherType", types.NewSet(types.Long(42)), types.Long(24), false}, + {"falseSameHash", types.NewSet(types.Long(0)), types.NewSet(testutil.Must(types.NewDecimalFromInt(0))), false}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := tt.set.Equal(tt.value) + testutil.Equals(t, got, tt.want) + }) + } + }) } diff --git a/types_test.go b/types_test.go new file mode 100644 index 0000000..88ed06e --- /dev/null +++ b/types_test.go @@ -0,0 +1,39 @@ +package cedar_test + +import ( + "testing" + "time" + + "github.com/cedar-policy/cedar-go" + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" +) + +func TestTypes(t *testing.T) { + t.Parallel() + testutil.Equals(t, cedar.NewDatetimeFromMillis(42), types.NewDatetimeFromMillis(42)) + testutil.Equals(t, cedar.NewDurationFromMillis(42), types.NewDurationFromMillis(42)) + ts := time.Now() + testutil.Equals(t, cedar.NewDatetime(ts), types.NewDatetime(ts)) + testutil.Equals(t, cedar.NewDuration(time.Second), types.NewDuration(time.Second)) + testutil.Equals(t, + cedar.NewPattern("test", cedar.Wildcard{}), + types.NewPattern("test", types.Wildcard{}), + ) + testutil.Equals(t, + cedar.NewSet(cedar.Long(42), cedar.Long(43)), + types.NewSet(types.Long(42), types.Long(43)), + ) + testutil.Equals(t, + testutil.Must(cedar.NewDecimal(42, 0)), + testutil.Must(types.NewDecimal(42, 0)), + ) + testutil.Equals(t, + testutil.Must(cedar.NewDecimalFromInt(42)), + testutil.Must(types.NewDecimalFromInt(42)), + ) + testutil.Equals(t, + testutil.Must(cedar.NewDecimalFromFloat(42.0)), + testutil.Must(types.NewDecimalFromFloat(42.0)), + ) +}