Skip to content

Commit

Permalink
Fix backend tests
Browse files Browse the repository at this point in the history
  • Loading branch information
briantu committed Nov 21, 2024
1 parent 3a7a1a3 commit ed0bd17
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,48 @@ def __init__(self, include_sources: bool):
def visitStart(self, ctx: AssetSelectionParser.StartContext):
return self.visit(ctx.expr())

def visitParenthesizedExpression(
self, ctx: AssetSelectionParser.ParenthesizedExpressionContext
def visitTraversalAllowedExpression(
self, ctx: AssetSelectionParser.TraversalAllowedExpressionContext
):
return self.visit(ctx.expr())
return self.visit(ctx.traversalAllowedExpr())

def visitUpAndDownTraversalExpression(
self, ctx: AssetSelectionParser.UpAndDownTraversalExpressionContext
):
selection: AssetSelection = self.visit(ctx.traversalAllowedExpr())
up_depth = self.visit(ctx.traversal(0))
down_depth = self.visit(ctx.traversal(1))
return selection.upstream(depth=up_depth) | selection.downstream(depth=down_depth)

def visitUpTraversalExpression(self, ctx: AssetSelectionParser.UpTraversalExpressionContext):
selection: AssetSelection = self.visit(ctx.expr())
selection: AssetSelection = self.visit(ctx.traversalAllowedExpr())
traversal_depth = self.visit(ctx.traversal())
return selection.upstream(depth=traversal_depth)

def visitAndExpression(self, ctx: AssetSelectionParser.AndExpressionContext):
left: AssetSelection = self.visit(ctx.expr(0))
right: AssetSelection = self.visit(ctx.expr(1))
return left & right

def visitAllExpression(self, ctx: AssetSelectionParser.AllExpressionContext):
return AssetSelection.all(include_sources=self.include_sources)

def visitNotExpression(self, ctx: AssetSelectionParser.NotExpressionContext):
selection: AssetSelection = self.visit(ctx.expr())
return AssetSelection.all(include_sources=self.include_sources) - selection

def visitDownTraversalExpression(
self, ctx: AssetSelectionParser.DownTraversalExpressionContext
):
selection: AssetSelection = self.visit(ctx.expr())
selection: AssetSelection = self.visit(ctx.traversalAllowedExpr())
traversal_depth = self.visit(ctx.traversal())
return selection.downstream(depth=traversal_depth)

def visitNotExpression(self, ctx: AssetSelectionParser.NotExpressionContext):
selection: AssetSelection = self.visit(ctx.expr())
return AssetSelection.all(include_sources=self.include_sources) - selection

def visitAndExpression(self, ctx: AssetSelectionParser.AndExpressionContext):
left: AssetSelection = self.visit(ctx.expr(0))
right: AssetSelection = self.visit(ctx.expr(1))
return left & right

def visitOrExpression(self, ctx: AssetSelectionParser.OrExpressionContext):
left: AssetSelection = self.visit(ctx.expr(0))
right: AssetSelection = self.visit(ctx.expr(1))
return left | right

def visitAllExpression(self, ctx: AssetSelectionParser.AllExpressionContext):
return AssetSelection.all(include_sources=self.include_sources)

def visitAttributeExpression(self, ctx: AssetSelectionParser.AttributeExpressionContext):
return self.visit(ctx.attributeExpr())

Expand All @@ -71,13 +79,10 @@ def visitFunctionCallExpression(self, ctx: AssetSelectionParser.FunctionCallExpr
elif function == "roots":
return selection.roots()

def visitUpAndDownTraversalExpression(
self, ctx: AssetSelectionParser.UpAndDownTraversalExpressionContext
def visitParenthesizedExpression(
self, ctx: AssetSelectionParser.ParenthesizedExpressionContext
):
selection: AssetSelection = self.visit(ctx.expr())
up_depth = self.visit(ctx.traversal(0))
down_depth = self.visit(ctx.traversal(1))
return selection.upstream(depth=up_depth) | selection.downstream(depth=down_depth)
return self.visit(ctx.expr())

def visitTraversal(self, ctx: AssetSelectionParser.TraversalContext):
# Get traversal depth from a traversal context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,58 +11,62 @@
"selection_str, expected_tree_str",
[
("*", "(start (expr *) <EOF>)"),
("key:a", "(start (expr (traversalAllowedExpr (attributeExpr key : (value a)))) <EOF>)"),
(
"***+++",
"(start (expr (expr (traversal *) (expr (traversal *) (expr *))) (traversal + + +)) <EOF>)",
"key_substring:a",
"(start (expr (traversalAllowedExpr (attributeExpr key_substring : (value a)))) <EOF>)",
),
(
'key:"*/a+"',
'(start (expr (traversalAllowedExpr (attributeExpr key : (value "*/a+")))) <EOF>)',
),
("key:a", "(start (expr (attributeExpr key : (value a))) <EOF>)"),
("key_substring:a", "(start (expr (attributeExpr key_substring : (value a))) <EOF>)"),
('key:"*/a+"', '(start (expr (attributeExpr key : (value "*/a+"))) <EOF>)'),
(
'key_substring:"*/a+"',
'(start (expr (attributeExpr key_substring : (value "*/a+"))) <EOF>)',
'(start (expr (traversalAllowedExpr (attributeExpr key_substring : (value "*/a+")))) <EOF>)',
),
(
"sinks(key:a)",
"(start (expr (functionName sinks) ( (expr (attributeExpr key : (value a))) )) <EOF>)",
"(start (expr (traversalAllowedExpr (functionName sinks) ( (expr (traversalAllowedExpr (attributeExpr key : (value a)))) ))) <EOF>)",
),
(
"roots(key:a)",
"(start (expr (functionName roots) ( (expr (attributeExpr key : (value a))) )) <EOF>)",
"(start (expr (traversalAllowedExpr (functionName roots) ( (expr (traversalAllowedExpr (attributeExpr key : (value a)))) ))) <EOF>)",
),
(
"tag:foo=bar",
"(start (expr (traversalAllowedExpr (attributeExpr tag : (value foo) = (value bar)))) <EOF>)",
),
("tag:foo=bar", "(start (expr (attributeExpr tag : (value foo) = (value bar))) <EOF>)"),
(
'owner:"[email protected]"',
'(start (expr (attributeExpr owner : (value "[email protected]"))) <EOF>)',
'(start (expr (traversalAllowedExpr (attributeExpr owner : (value "[email protected]")))) <EOF>)',
),
(
'group:"my_group"',
'(start (expr (attributeExpr group : (value "my_group"))) <EOF>)',
'(start (expr (traversalAllowedExpr (attributeExpr group : (value "my_group")))) <EOF>)',
),
(
"kind:my_kind",
"(start (expr (traversalAllowedExpr (attributeExpr kind : (value my_kind)))) <EOF>)",
),
("kind:my_kind", "(start (expr (attributeExpr kind : (value my_kind))) <EOF>)"),
(
"code_location:my_location",
"(start (expr (attributeExpr code_location : (value my_location))) <EOF>)",
"(start (expr (traversalAllowedExpr (attributeExpr code_location : (value my_location)))) <EOF>)",
),
(
"(((key:a)))",
"(start (expr ( (expr ( (expr ( (expr (attributeExpr key : (value a))) )) )) )) <EOF>)",
"(start (expr (traversalAllowedExpr ( (expr (traversalAllowedExpr ( (expr (traversalAllowedExpr ( (expr (traversalAllowedExpr (attributeExpr key : (value a)))) ))) ))) ))) <EOF>)",
),
(
'not not key:"not"',
'(start (expr not (expr not (expr (attributeExpr key : (value "not"))))) <EOF>)',
"not key:a",
"(start (expr not (expr (traversalAllowedExpr (attributeExpr key : (value a))))) <EOF>)",
),
(
"(roots(key:a) and owner:billing)*",
"(start (expr (expr ( (expr (expr (functionName roots) ( (expr (attributeExpr key : (value a))) )) and (expr (attributeExpr owner : (value billing)))) )) (traversal *)) <EOF>)",
"key:a and key:b",
"(start (expr (expr (traversalAllowedExpr (attributeExpr key : (value a)))) and (expr (traversalAllowedExpr (attributeExpr key : (value b))))) <EOF>)",
),
(
"++(key:a+)",
"(start (expr (traversal + +) (expr ( (expr (expr (attributeExpr key : (value a))) (traversal +)) ))) <EOF>)",
),
(
"key:a* and *key:b",
"(start (expr (expr (expr (attributeExpr key : (value a))) (traversal *)) and (expr (traversal *) (expr (attributeExpr key : (value b))))) <EOF>)",
"key:a or key:b",
"(start (expr (expr (traversalAllowedExpr (attributeExpr key : (value a)))) or (expr (traversalAllowedExpr (attributeExpr key : (value b))))) <EOF>)",
),
],
)
Expand All @@ -74,6 +78,9 @@ def test_antlr_tree(selection_str, expected_tree_str):
@pytest.mark.parametrize(
"selection_str",
[
"+",
"*+",
"**key:a",
"not",
"key:a key:b",
"key:a and and",
Expand Down Expand Up @@ -103,10 +110,26 @@ def test_antlr_tree_invalid(selection_str):
("++key:a", AssetSelection.assets("a").upstream(2)),
("key:a+", AssetSelection.assets("a").downstream(1)),
("key:a++", AssetSelection.assets("a").downstream(2)),
(
"+key:a+",
AssetSelection.assets("a").upstream(1) | AssetSelection.assets("a").downstream(1),
),
("*key:a", AssetSelection.assets("a").upstream()),
("key:a*", AssetSelection.assets("a").downstream()),
(
"*key:a*",
AssetSelection.assets("a").downstream() | AssetSelection.assets("a").upstream(),
),
(
"key:a* and *key:b",
AssetSelection.assets("a").downstream() & AssetSelection.assets("b").upstream(),
),
(
"*key:a and key:b* and *key:c*",
AssetSelection.assets("a").upstream()
& AssetSelection.assets("b").downstream()
& (AssetSelection.assets("c").upstream() | AssetSelection.assets("c").downstream()),
),
("sinks(key:a)", AssetSelection.assets("a").sinks()),
("roots(key:c)", AssetSelection.assets("c").roots()),
("tag:foo", AssetSelection.tag("foo", "")),
Expand Down

0 comments on commit ed0bd17

Please sign in to comment.