def select_nodes_ids_by_dfs(nodes: dict[str, DbtNode], config: SelectorConfig) -> set[str]:
    """
    Return the ids of the model named by ``config.model_upstream`` and of everything upstream of it.

    The model is looked up by ``node.name``; from there the ``depends_on`` edges are followed
    transitively. Ids listed in ``depends_on`` that are not keys of ``nodes`` are ignored.
    If no node carries the requested name, an empty set is returned.

    The traversal is iterative and visits every node at most once, so nodes with shared
    ancestors are not walked repeatedly, deep dependency chains cannot exhaust the recursion
    limit, and an accidental dependency cycle cannot loop forever.

    :param nodes: Dictionary mapping dbt nodes (node.unique_id to node)
    :param config: User-defined select statements
    :returns: Set of unique ids of the selected nodes (the named model included)

    References:
    https://docs.getdbt.com/reference/node-selection/syntax
    https://docs.getdbt.com/reference/node-selection/yaml-selectors
    """

    # Model selection is done via node name, not unique id, so we have to linearly search for it
    root_id = next(
        (node_id for node_id, node in nodes.items() if node.name == config.model_upstream),
        None,
    )
    if root_id is None:
        return set()

    selected_nodes: set[str] = set()
    pending = [root_id]  # explicit stack instead of recursion
    while pending:
        node_id = pending.pop()
        if node_id in selected_nodes:
            # Already expanded through another path; skipping keeps the walk linear.
            continue
        selected_nodes.add(node_id)
        pending.extend(
            parent_id
            for parent_id in nodes[node_id].depends_on
            if parent_id in nodes and parent_id not in selected_nodes
        )

    return selected_nodes
@@ -184,7 +230,11 @@ def select_nodes( filters = [["select", select], ["exclude", exclude]] for filter_type, filter in filters: for filter_parameter in filter: - if filter_parameter.startswith(PATH_SELECTOR) or filter_parameter.startswith(TAG_SELECTOR): + if ( + filter_parameter.startswith(PATH_SELECTOR) + or filter_parameter.startswith(TAG_SELECTOR) + or filter_parameter.startswith(MODEL_UPSTREAM_SELECTOR) + ): continue elif any([filter_parameter.startswith(CONFIG_SELECTOR + config + ":") for config in SUPPORTED_CONFIG]): continue @@ -195,8 +245,12 @@ def select_nodes( for statement in select: config = SelectorConfig(project_dir, statement) - select_ids = select_nodes_ids_by_intersection(nodes, config) - subset_ids = subset_ids.union(set(select_ids)) + if config.model_upstream: + select_ids = select_nodes_ids_by_dfs(nodes, config) + subset_ids = subset_ids.union(set(select_ids)) + else: + select_ids = select_nodes_ids_by_intersection(nodes, config) + subset_ids = subset_ids.union(set(select_ids)) if select: nodes = {id_: nodes[id_] for id_ in subset_ids} diff --git a/tests/dbt/test_selector.py b/tests/dbt/test_selector.py index 8e3fc8c61..09b7cabb9 100644 --- a/tests/dbt/test_selector.py +++ b/tests/dbt/test_selector.py @@ -47,11 +47,20 @@ def test_is_empty_config(selector_config, paths, tags, config, other, expected): tags=["has_child"], config={"materialized": "view"}, ) +grandparent_sibling_node = DbtNode( + name="grandparent_sibling", + unique_id="grandparent_sibling", + resource_type=DbtResourceType.MODEL, + depends_on=[], + file_path=SAMPLE_PROJ_PATH / "gen1/models/grandparent_sibling.sql", + tags=[], + config={}, +) parent_node = DbtNode( name="parent", unique_id="parent", resource_type=DbtResourceType.MODEL, - depends_on=["grandparent"], + depends_on=["grandparent", "grandparent_sibling"], file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql", tags=["has_child"], config={"materialized": "view"}, @@ -85,13 +94,24 @@ def test_is_empty_config(selector_config, 
def test_select_nodes_by_dfs():
    """``+child`` selects ``child`` plus every node it transitively depends on."""
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["+child"])
    wanted = (grandparent_node, grandparent_sibling_node, parent_node, child_node)
    assert result == {node.unique_id: node for node in wanted}


def test_select_nodes_by_dfs_exclude_tags():
    """Excluding ``tag:has_child`` from ``+child`` leaves only the sibling and the child."""
    result = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["+child"],
        exclude=["tag:has_child"],
    )
    wanted = (grandparent_sibling_node, child_node)
    assert result == {node.unique_id: node for node in wanted}


def test_select_node_by_dfs_partial_tree():
    """``+parent`` stops at ``parent``: nodes downstream of it are not selected."""
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["+parent"])
    wanted = (grandparent_node, grandparent_sibling_node, parent_node)
    assert result == {node.unique_id: node for node in wanted}


def test_select_node_by_dfs_leaf():
    """A model without dependencies is selected on its own."""
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["+orphaned"])
    assert result == {orphaned_node.unique_id: orphaned_node}


def test_select_node_by_dfs_no_node():
    """An unknown model name yields an empty selection."""
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["+modelDoesntExist"])
    assert result == {}