Skip to content

Commit

Permalink
Merge pull request #59 from micahhausler/extension-ast-calls
Browse files Browse the repository at this point in the history
Expose extension function calls in AST package
  • Loading branch information
patjakdev authored Nov 10, 2024
2 parents ef88a40 + b9ff4a5 commit 1a0d55f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
35 changes: 31 additions & 4 deletions ast/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ import (
"github.com/cedar-policy/cedar-go/types"
)

// These tests serve mostly as examples of how to translate from Cedar text into programmatic AST construction. They
// don't verify anything.
func TestAstExamples(t *testing.T) {
t.Parallel()
// This example shows how you can construct polcies with the ast package
func Example() {

johnny := types.NewEntityUID("User", "johnny")
sow := types.NewEntityUID("Action", "sow")
Expand Down Expand Up @@ -70,6 +68,15 @@ func TestAstExamples(t *testing.T) {
ast.Context().Access("fooCount"),
).Contains(ast.Long(1)),
)

// forbid (principal, action, resource)
// when { resource.angleRadians.greaterThan(decimal("3.1415")) }
_ = ast.Forbid().
When(
ast.Resource().Access("angleRadians").DecimalGreaterThan(
ast.DecimalExtensionCall(ast.String("3.1415")),
),
)
}

func TestASTByTable(t *testing.T) {
Expand Down Expand Up @@ -449,6 +456,26 @@ func TestASTByTable(t *testing.T) {
ast.Permit().When(ast.Duration(time.Duration(100)).ToMilliseconds()),
internalast.Permit().When(internalast.Duration(100).ToMilliseconds()),
},
{
"decimalExtension",
ast.Permit().When(ast.DecimalExtensionCall(ast.Value(types.String("3.14")))),
internalast.Permit().When(internalast.ExtensionCall("decimal", internalast.String("3.14"))),
},
{
"ipExtension",
ast.Permit().When(ast.IPExtensionCall(ast.Value(types.String("127.0.0.1")))),
internalast.Permit().When(internalast.ExtensionCall("ip", internalast.String("127.0.0.1"))),
},
{
"datetime",
ast.Permit().When(ast.DatetimeExtensionCall(ast.Value(types.String("2006-01-02T15:04:05Z07:00")))),
internalast.Permit().When(internalast.ExtensionCall("datetime", internalast.String("2006-01-02T15:04:05Z07:00"))),
},
{
"duration",
ast.Permit().When(ast.DurationExtensionCall(ast.Value(types.String("1d2h3m4s5ms")))),
internalast.Permit().When(internalast.ExtensionCall("duration", internalast.String("1d2h3m4s5ms"))),
},
}

for _, tt := range tests {
Expand Down
20 changes: 20 additions & 0 deletions ast/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,23 @@ func Duration(d time.Duration) Node {
func Value(v types.Value) Node {
return wrapNode(ast.Value(v))
}

// DecimalExtensionCall wraps a node with the cedar `decimal()` extension call
func DecimalExtensionCall(rhs Node) Node {
return wrapNode(ast.ExtensionCall("decimal", rhs.Node))
}

// IPExtensionCall wraps a node with the cedar `ip()` extension call
func IPExtensionCall(rhs Node) Node {
return wrapNode(ast.ExtensionCall("ip", rhs.Node))
}

// DatetimeExtensionCall wraps a node with the cedar `datetime()` extension call
func DatetimeExtensionCall(rhs Node) Node {
return wrapNode(ast.ExtensionCall("datetime", rhs.Node))
}

// DurationExtensionCall wraps a node with the cedar `duration()` extension call
func DurationExtensionCall(rhs Node) Node {
return wrapNode(ast.ExtensionCall("duration", rhs.Node))
}

0 comments on commit 1a0d55f

Please sign in to comment.