diff --git a/python_modules/dagster/dagster/_core/definitions/antlr_asset_selection/antlr_asset_selection.py b/python_modules/dagster/dagster/_core/definitions/antlr_asset_selection/antlr_asset_selection.py index 82fc5520e63bf..0ef314d766fed 100644 --- a/python_modules/dagster/dagster/_core/definitions/antlr_asset_selection/antlr_asset_selection.py +++ b/python_modules/dagster/dagster/_core/definitions/antlr_asset_selection/antlr_asset_selection.py @@ -26,40 +26,48 @@ def __init__(self, include_sources: bool): def visitStart(self, ctx: AssetSelectionParser.StartContext): return self.visit(ctx.expr()) - def visitParenthesizedExpression( - self, ctx: AssetSelectionParser.ParenthesizedExpressionContext + def visitTraversalAllowedExpression( + self, ctx: AssetSelectionParser.TraversalAllowedExpressionContext ): - return self.visit(ctx.expr()) + return self.visit(ctx.traversalAllowedExpr()) + + def visitUpAndDownTraversalExpression( + self, ctx: AssetSelectionParser.UpAndDownTraversalExpressionContext + ): + selection: AssetSelection = self.visit(ctx.traversalAllowedExpr()) + up_depth = self.visit(ctx.traversal(0)) + down_depth = self.visit(ctx.traversal(1)) + return selection.upstream(depth=up_depth) | selection.downstream(depth=down_depth) def visitUpTraversalExpression(self, ctx: AssetSelectionParser.UpTraversalExpressionContext): - selection: AssetSelection = self.visit(ctx.expr()) + selection: AssetSelection = self.visit(ctx.traversalAllowedExpr()) traversal_depth = self.visit(ctx.traversal()) return selection.upstream(depth=traversal_depth) - def visitAndExpression(self, ctx: AssetSelectionParser.AndExpressionContext): - left: AssetSelection = self.visit(ctx.expr(0)) - right: AssetSelection = self.visit(ctx.expr(1)) - return left & right - - def visitAllExpression(self, ctx: AssetSelectionParser.AllExpressionContext): - return AssetSelection.all(include_sources=self.include_sources) - - def visitNotExpression(self, ctx: AssetSelectionParser.NotExpressionContext): - selection: AssetSelection = self.visit(ctx.expr()) - return AssetSelection.all(include_sources=self.include_sources) - selection - def visitDownTraversalExpression( self, ctx: AssetSelectionParser.DownTraversalExpressionContext ): - selection: AssetSelection = self.visit(ctx.expr()) + selection: AssetSelection = self.visit(ctx.traversalAllowedExpr()) traversal_depth = self.visit(ctx.traversal()) return selection.downstream(depth=traversal_depth) + def visitNotExpression(self, ctx: AssetSelectionParser.NotExpressionContext): + selection: AssetSelection = self.visit(ctx.expr()) + return AssetSelection.all(include_sources=self.include_sources) - selection + + def visitAndExpression(self, ctx: AssetSelectionParser.AndExpressionContext): + left: AssetSelection = self.visit(ctx.expr(0)) + right: AssetSelection = self.visit(ctx.expr(1)) + return left & right + def visitOrExpression(self, ctx: AssetSelectionParser.OrExpressionContext): left: AssetSelection = self.visit(ctx.expr(0)) right: AssetSelection = self.visit(ctx.expr(1)) return left | right + def visitAllExpression(self, ctx: AssetSelectionParser.AllExpressionContext): + return AssetSelection.all(include_sources=self.include_sources) + def visitAttributeExpression(self, ctx: AssetSelectionParser.AttributeExpressionContext): return self.visit(ctx.attributeExpr()) @@ -71,13 +79,10 @@ def visitFunctionCallExpression(self, ctx: AssetSelectionParser.FunctionCallExpr elif function == "roots": return selection.roots() - def visitUpAndDownTraversalExpression( - self, ctx: AssetSelectionParser.UpAndDownTraversalExpressionContext + def visitParenthesizedExpression( + self, ctx: AssetSelectionParser.ParenthesizedExpressionContext ): - selection: AssetSelection = self.visit(ctx.expr()) - up_depth = self.visit(ctx.traversal(0)) - down_depth = self.visit(ctx.traversal(1)) - return selection.upstream(depth=up_depth) | selection.downstream(depth=down_depth) + return self.visit(ctx.expr()) def visitTraversal(self, ctx: AssetSelectionParser.TraversalContext): # Get traversal depth from a traversal context diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_antlr_asset_selection.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_antlr_asset_selection.py index 7648a59971df8..1da01a7f08bde 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_antlr_asset_selection.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_antlr_asset_selection.py @@ -11,58 +11,62 @@ "selection_str, expected_tree_str", [ ("*", "(start (expr *) )"), + ("key:a", "(start (expr (traversalAllowedExpr (attributeExpr key : (value a)))) )"), ( - "***+++", - "(start (expr (expr (traversal *) (expr (traversal *) (expr *))) (traversal + + +)) )", + "key_substring:a", + "(start (expr (traversalAllowedExpr (attributeExpr key_substring : (value a)))) )", + ), + ( + 'key:"*/a+"', + '(start (expr (traversalAllowedExpr (attributeExpr key : (value "*/a+")))) )', ), - ("key:a", "(start (expr (attributeExpr key : (value a))) )"), - ("key_substring:a", "(start (expr (attributeExpr key_substring : (value a))) )"), - ('key:"*/a+"', '(start (expr (attributeExpr key : (value "*/a+"))) )'), ( 'key_substring:"*/a+"', - '(start (expr (attributeExpr key_substring : (value "*/a+"))) )', + '(start (expr (traversalAllowedExpr (attributeExpr key_substring : (value "*/a+")))) )', ), ( "sinks(key:a)", - "(start (expr (functionName sinks) ( (expr (attributeExpr key : (value a))) )) )", + "(start (expr (traversalAllowedExpr (functionName sinks) ( (expr (traversalAllowedExpr (attributeExpr key : (value a)))) ))) )", ), ( "roots(key:a)", - "(start (expr (functionName roots) ( (expr (attributeExpr key : (value a))) )) )", + "(start (expr (traversalAllowedExpr (functionName roots) ( (expr (traversalAllowedExpr (attributeExpr key : (value a)))) ))) )", + ), + ( + "tag:foo=bar", + "(start (expr (traversalAllowedExpr (attributeExpr tag : (value foo) = (value bar)))) )", ), - ("tag:foo=bar", "(start (expr (attributeExpr tag : (value foo) = (value bar))) )"), ( 'owner:"owner@owner.com"', - '(start (expr (attributeExpr owner : (value "owner@owner.com"))) )', + '(start (expr (traversalAllowedExpr (attributeExpr owner : (value "owner@owner.com")))) )', ), ( 'group:"my_group"', - '(start (expr (attributeExpr group : (value "my_group"))) )', + '(start (expr (traversalAllowedExpr (attributeExpr group : (value "my_group")))) )', + ), + ( + "kind:my_kind", + "(start (expr (traversalAllowedExpr (attributeExpr kind : (value my_kind)))) )", ), - ("kind:my_kind", "(start (expr (attributeExpr kind : (value my_kind))) )"), ( "code_location:my_location", - "(start (expr (attributeExpr code_location : (value my_location))) )", + "(start (expr (traversalAllowedExpr (attributeExpr code_location : (value my_location)))) )", ), ( "(((key:a)))", - "(start (expr ( (expr ( (expr ( (expr (attributeExpr key : (value a))) )) )) )) )", + "(start (expr (traversalAllowedExpr ( (expr (traversalAllowedExpr ( (expr (traversalAllowedExpr ( (expr (traversalAllowedExpr (attributeExpr key : (value a)))) ))) ))) ))) )", ), ( - 'not not key:"not"', - '(start (expr not (expr not (expr (attributeExpr key : (value "not"))))) )', + "not key:a", + "(start (expr not (expr (traversalAllowedExpr (attributeExpr key : (value a))))) )", ), ( - "(roots(key:a) and owner:billing)*", - "(start (expr (expr ( (expr (expr (functionName roots) ( (expr (attributeExpr key : (value a))) )) and (expr (attributeExpr owner : (value billing)))) )) (traversal *)) )", + "key:a and key:b", + "(start (expr (expr (traversalAllowedExpr (attributeExpr key : (value a)))) and (expr (traversalAllowedExpr (attributeExpr key : (value b))))) )", ), ( - "++(key:a+)", - "(start (expr (traversal + +) (expr ( (expr (expr (attributeExpr key : (value a))) (traversal +)) ))) )", - ), - ( - "key:a* and *key:b", - "(start (expr (expr (expr (attributeExpr key : (value a))) (traversal *)) and (expr (traversal *) (expr (attributeExpr key : (value b))))) )", + "key:a or key:b", + "(start (expr (expr (traversalAllowedExpr (attributeExpr key : (value a)))) or (expr (traversalAllowedExpr (attributeExpr key : (value b))))) )", ), ], ) @@ -74,6 +78,9 @@ def test_antlr_tree(selection_str, expected_tree_str): @pytest.mark.parametrize( "selection_str", [ + "+", + "*+", + "**key:a", "not", "key:a key:b", "key:a and and", @@ -103,10 +110,26 @@ def test_antlr_tree_invalid(selection_str): ("++key:a", AssetSelection.assets("a").upstream(2)), ("key:a+", AssetSelection.assets("a").downstream(1)), ("key:a++", AssetSelection.assets("a").downstream(2)), + ( + "+key:a+", + AssetSelection.assets("a").upstream(1) | AssetSelection.assets("a").downstream(1), + ), + ("*key:a", AssetSelection.assets("a").upstream()), + ("key:a*", AssetSelection.assets("a").downstream()), + ( + "*key:a*", + AssetSelection.assets("a").downstream() | AssetSelection.assets("a").upstream(), + ), ( "key:a* and *key:b", AssetSelection.assets("a").downstream() & AssetSelection.assets("b").upstream(), ), + ( + "*key:a and key:b* and *key:c*", + AssetSelection.assets("a").upstream() + & AssetSelection.assets("b").downstream() + & (AssetSelection.assets("c").upstream() | AssetSelection.assets("c").downstream()), + ), ("sinks(key:a)", AssetSelection.assets("a").sinks()), ("roots(key:c)", AssetSelection.assets("c").roots()), ("tag:foo", AssetSelection.tag("foo", "")),