diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py
index 257d60721..1e7b42667 100644
--- a/cosmos/dbt/selector.py
+++ b/cosmos/dbt/selector.py
@@ -35,6 +35,10 @@ class GraphSelector:
         +model_d+
         2+model_e
         model_f+3
+        +/path/to/model_g+
+        path:/path/to/model_h+
+        +tag:nightly
+        +config.materialized:view
 
     https://docs.getdbt.com/reference/node-selection/graph-operators
     """
@@ -84,6 +88,8 @@ def parse(text: str) -> GraphSelector | None:
         regex_match = re.search(GRAPH_SELECTOR_REGEX, text)
         if regex_match:
             precursors, node_name, descendants = regex_match.groups()
+            if node_name and "/" in node_name and not node_name.startswith(PATH_SELECTOR):
+                node_name = f"{PATH_SELECTOR}{node_name}"
             return GraphSelector(node_name, precursors, descendants)
         return None
 
@@ -148,22 +154,66 @@ def filter_nodes(self, nodes: dict[str, DbtNode]) -> set[str]:
         :return: set of node ids that matches current graph selector
         """
         selected_nodes: set[str] = set()
+        root_nodes: set[str] = set()
 
         # Index nodes by name, we can improve performance by doing this once
         # for multiple GraphSelectors
-        node_by_name = {}
-        for node_id, node in nodes.items():
-            node_by_name[node.name] = node_id
+        # Match each prefix with startswith(): the prefix is stripped by length below, so a
+        # substring test would mis-slice a name such as "tag:path:x".
+        if self.node_name.startswith(PATH_SELECTOR):
+            path_selection = self.node_name[len(PATH_SELECTOR) :]
+            root_nodes.update({node_id for node_id, node in nodes.items() if path_selection in str(node.file_path)})
+
+        elif self.node_name.startswith(TAG_SELECTOR):
+            tag_selection = self.node_name[len(TAG_SELECTOR) :]
+            root_nodes.update({node_id for node_id, node in nodes.items() if tag_selection in node.tags})
+
+        elif self.node_name.startswith(CONFIG_SELECTOR):
+            # Split on the first ":" only so that a value containing ":" still unpacks.
+            config_selection_key, config_selection_value = self.node_name[len(CONFIG_SELECTOR) :].split(":", 1)
+            if config_selection_key not in SUPPORTED_CONFIG:
+                logger.warning("Unsupported config key selector: %s", config_selection_key)
+
+            # currently tags, materialized, and schema are the only supported config keys
+            # logic is separated into two conditions because the config 'tags' contains a
+            # list of tags, but the config 'materialized', and 'schema' contain strings
+            elif config_selection_key == "tags":
+                root_nodes.update(
+                    {
+                        node_id
+                        for node_id, node in nodes.items()
+                        if config_selection_value in node.config.get(config_selection_key, [])
+                    }
+                )
+            elif config_selection_key in (
+                "materialized",
+                "schema",
+            ):
+                root_nodes.update(
+                    {
+                        node_id
+                        for node_id, node in nodes.items()
+                        if config_selection_value == node.config.get(config_selection_key, "")
+                    }
+                )
 
-        if self.node_name in node_by_name:
-            root_id = node_by_name[self.node_name]
         else:
-            logger.warn(f"Selector {self.node_name} not found.")
-            return selected_nodes
+            node_by_name = {}
+            for node_id, node in nodes.items():
+                node_by_name[node.name] = node_id
+
+            if self.node_name in node_by_name:
+                root_id = node_by_name[self.node_name]
+                root_nodes.add(root_id)
+            else:
+                logger.warning("Selector %s not found.", self.node_name)
+                return selected_nodes
+
+        selected_nodes.update(root_nodes)
 
-        selected_nodes.add(root_id)
-        self.select_node_precursors(nodes, root_id, selected_nodes)
-        self.select_node_descendants(nodes, root_id, selected_nodes)
+        for root_id in root_nodes:
+            self.select_node_precursors(nodes, root_id, selected_nodes)
+            self.select_node_descendants(nodes, root_id, selected_nodes)
 
         return selected_nodes
 
@@ -210,14 +260,24 @@ def load_from_statement(self, statement: str) -> None:
         items = statement.split(",")
 
         for item in items:
-            if item.startswith(PATH_SELECTOR):
-                self._parse_path_selector(item)
-            elif item.startswith(TAG_SELECTOR):
-                self._parse_tag_selector(item)
-            elif item.startswith(CONFIG_SELECTOR):
-                self._parse_config_selector(item)
-            else:
-                self._parse_unknown_selector(item)
+            regex_match = re.search(GRAPH_SELECTOR_REGEX, item)
+            if regex_match:
+                precursors, node_name, descendants = regex_match.groups()
+                if node_name is None:
+                    # The regex captured no node name, so there is nothing to dispatch.
+                    continue
+                elif precursors or descendants:
+                    self._parse_unknown_selector(item)
+                elif node_name.startswith(PATH_SELECTOR):
+                    self._parse_path_selector(item)
+                elif "/" in node_name:
+                    self._parse_path_selector(f"{PATH_SELECTOR}{node_name}")
+                elif node_name.startswith(TAG_SELECTOR):
+                    self._parse_tag_selector(item)
+                elif node_name.startswith(CONFIG_SELECTOR):
+                    self._parse_config_selector(item)
+                else:
+                    self._parse_unknown_selector(item)
 
     def _parse_unknown_selector(self, item: str) -> None:
         if item:
diff --git a/tests/dbt/test_selector.py b/tests/dbt/test_selector.py
index ece32ac95..56f65dad0 100644
--- a/tests/dbt/test_selector.py
+++ b/tests/dbt/test_selector.py
@@ -191,6 +191,14 @@ def test_select_nodes_by_select_path():
     assert selected == expected
 
 
+def test_select_nodes_with_slash_but_no_path_selector():
+    selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["gen2/models"])
+    expected = {
+        parent_node.unique_id: parent_node,
+    }
+    assert selected == expected
+
+
 def test_select_nodes_by_select_union():
     selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["tag:has_child", "tag:nightly"])
     expected = {
@@ -432,3 +440,71 @@ def test_should_include_node_without_depends_on(selector_config):
     selector = NodeSelector({}, selector_config)
     selector.visited_nodes = set()
     selector._should_include_node(node.unique_id, node)
+
+
+@pytest.mark.parametrize(
+    "select_statement, expected",
+    [
+        (
+            ["+path:gen2/models"],
+            [
+                "model.dbt-proj.another_grandparent_node",
+                "model.dbt-proj.grandparent",
+                "model.dbt-proj.parent",
+            ],
+        ),
+        (
+            ["path:gen2/models+"],
+            [
+                "model.dbt-proj.child",
+                "model.dbt-proj.parent",
+                "model.dbt-proj.sibling1",
+                "model.dbt-proj.sibling2",
+            ],
+        ),
+        (
+            ["gen2/models+"],
+            [
+                "model.dbt-proj.child",
+                "model.dbt-proj.parent",
+                "model.dbt-proj.sibling1",
+                "model.dbt-proj.sibling2",
+            ],
+        ),
+        (
+            ["+gen2/models"],
+            [
+                "model.dbt-proj.another_grandparent_node",
+                "model.dbt-proj.grandparent",
+                "model.dbt-proj.parent",
+            ],
+        ),
+        (
+            ["1+tag:deprecated"],
+            [
+                "model.dbt-proj.parent",
+                "model.dbt-proj.sibling1",
+                "model.dbt-proj.sibling2",
+            ],
+        ),
+        (
+            ["1+config.tags:deprecated"],
+            [
+                "model.dbt-proj.parent",
+                "model.dbt-proj.sibling1",
+                "model.dbt-proj.sibling2",
+            ],
+        ),
+        (
+            ["config.materialized:table+"],
+            [
+                "model.dbt-proj.child",
+                "model.dbt-proj.sibling1",
+                "model.dbt-proj.sibling2",
+            ],
+        ),
+    ],
+)
+def test_select_using_graph_operators(select_statement, expected):
+    selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=select_statement)
+    assert sorted(selected.keys()) == expected