Skip to content

Commit

Permalink
Add support to select using graph-operators when using `LoadMode.CUST…
Browse files Browse the repository at this point in the history
…OM` or `LoadMode.DBT_MANIFEST` (astronomer#728)

Add support for the following when using `LoadMode.CUSTOM` or
`LoadMode.DBT_MANIFEST`:

* Support selection of model by name
* Support the selection of models by name & their children (with or
without degrees)
* Support the selection of models by name & their parents (with or
without degrees)
* Support intersections and unions involving graph selectors (with or
without other supported selectors, eg. tags)

Examples of select/exclusion statements that now work regardless of the
`LoadMode` being used:

```
model_a
+model_b
model_c+
+model_d+
2+model_e
model_f+3
model_f+,tag:nightly
```

Related dbt documentation:
https://docs.getdbt.com/reference/node-selection/graph-operators
https://docs.getdbt.com/reference/node-selection/set-operators

Limitations:
* The at operator is not supported yet (`@`)
* If users opt to use graph selector, it will increase the DAG parsing
time and the task execution time when using `LoadMode.CUSTOM` or
`LoadMode.DBT_MANIFEST`


This PR improves and extends the original implementation proposed by
@tseruga in astronomer#429. Some of the changes that were introduced on top of the
original PR:
* Add support to descendants (before only precursors were supported)
* Add support to different depths/degrees of precursors/descendants
* Add support to the union between graph operators and graph/non-graph
operators
* Add support to the intersection between graph operators and
graph/non-graph operators

Closes: astronomer#684

Co-authored-by: Tyler Seruga <[email protected]>
  • Loading branch information
2 people authored and arojasb3 committed Jul 14, 2024
1 parent 1373cf7 commit f3c21b5
Show file tree
Hide file tree
Showing 3 changed files with 424 additions and 45 deletions.
215 changes: 197 additions & 18 deletions cosmos/dbt/selector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations
from pathlib import Path
import copy

import re
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

from cosmos.constants import DbtResourceType
Expand All @@ -16,11 +18,154 @@
PATH_SELECTOR = "path:"
TAG_SELECTOR = "tag:"
CONFIG_SELECTOR = "config."

PLUS_SELECTOR = "+"
GRAPH_SELECTOR_REGEX = r"^([0-9]*\+)?([^\+]+)(\+[0-9]*)?$|"

logger = get_logger(__name__)


@dataclass
class GraphSelector:
"""
Implements dbt graph operator selectors:
model_a
+model_b
model_c+
+model_d+
2+model_e
model_f+3
https://docs.getdbt.com/reference/node-selection/graph-operators
"""

node_name: str
precursors: str | None
descendants: str | None

@property
def precursors_depth(self) -> int:
"""
Calculates the depth/degrees/generations of precursors (parents).
Return:
-1: if it should return all the generations of precursors
0: if it shouldn't return any precursors
>0: upperbound number of parent generations
"""
if not self.precursors:
return 0
if self.precursors == "+":
return -1
else:
return int(self.precursors[:-1])

@property
def descendants_depth(self) -> int:
"""
Calculates the depth/degrees/generations of descendants (children).
Return:
-1: if it should return all the generations of children
0: if it shouldn't return any children
>0: upperbound of children generations
"""
if not self.descendants:
return 0
if self.descendants == "+":
return -1
else:
return int(self.descendants[1:])

@staticmethod
def parse(text: str) -> GraphSelector | None:
"""
Parse a string and identify if there are graph selectors, including the desired node name, descendants and
precursors. Return a GraphSelector instance if the pattern matches.
"""
regex_match = re.search(GRAPH_SELECTOR_REGEX, text)
if regex_match:
precursors, node_name, descendants = regex_match.groups()
return GraphSelector(node_name, precursors, descendants)
return None

def select_node_precursors(self, nodes: dict[str, DbtNode], root_id: str, selected_nodes: set[str]) -> None:
"""
Parse original nodes and add the precursor nodes related to this config to the selected_nodes set.
:param nodes: Original dbt nodes list
:param root_id: Unique identifier of self.node_name
:param selected_nodes: Set where precursor nodes will be added to.
"""
if self.precursors:
depth = self.precursors_depth
previous_generation = {root_id}
processed_nodes = set()
while depth and previous_generation:
new_generation: set[str] = set()
for node_id in previous_generation:
if node_id not in processed_nodes:
new_generation.update(set(nodes[node_id].depends_on))
processed_nodes.add(node_id)
selected_nodes.update(new_generation)
previous_generation = new_generation
depth -= 1

def select_node_descendants(self, nodes: dict[str, DbtNode], root_id: str, selected_nodes: set[str]) -> None:
"""
Parse original nodes and add the descendant nodes related to this config to the selected_nodes set.
:param nodes: Original dbt nodes list
:param root_id: Unique identifier of self.node_name
:param selected_nodes: Set where descendant nodes will be added to.
"""
if self.descendants:
children_by_node = defaultdict(set)
# Index nodes by parent id
# We could optimize by doing this only once for the dbt project and giving it
# as a parameter to the GraphSelector
for node_id, node in nodes.items():
for parent_id in node.depends_on:
children_by_node[parent_id].add(node_id)

depth = self.descendants_depth
previous_generation = {root_id}
processed_nodes = set()
while depth and previous_generation:
new_generation: set[str] = set()
for node_id in previous_generation:
if node_id not in processed_nodes:
new_generation.update(children_by_node[node_id])
processed_nodes.add(node_id)
selected_nodes.update(new_generation)
previous_generation = new_generation
depth -= 1

def filter_nodes(self, nodes: dict[str, DbtNode]) -> set[str]:
"""
Given a dictionary with the original dbt project nodes, applies the current graph selector to
identify the subset of nodes that matches the selection criteria.
:param nodes: dbt project nodes
:return: set of node ids that matches current graph selector
"""
selected_nodes: set[str] = set()

# Index nodes by name, we can improve performance by doing this once
# for multiple GraphSelectors
node_by_name = {}
for node_id, node in nodes.items():
node_by_name[node.name] = node_id

if self.node_name in node_by_name:
root_id = node_by_name[self.node_name]
else:
logger.warn(f"Selector {self.node_name} not found.")
return selected_nodes

selected_nodes.add(root_id)
self.select_node_precursors(nodes, root_id, selected_nodes)
self.select_node_descendants(nodes, root_id, selected_nodes)
return selected_nodes


class SelectorConfig:
"""
Represents a select/exclude statement.
Expand All @@ -43,11 +188,12 @@ def __init__(self, project_dir: Path | None, statement: str):
self.tags: list[str] = []
self.config: dict[str, str] = {}
self.other: list[str] = []
self.graph_selectors: list[GraphSelector] = []
self.load_from_statement(statement)

@property
def is_empty(self) -> bool:
return not (self.paths or self.tags or self.config or self.other)
return not (self.paths or self.tags or self.config or self.graph_selectors or self.other)

def load_from_statement(self, statement: str) -> None:
"""
Expand All @@ -61,6 +207,7 @@ def load_from_statement(self, statement: str) -> None:
https://docs.getdbt.com/reference/node-selection/yaml-selectors
"""
items = statement.split(",")

for item in items:
if item.startswith(PATH_SELECTOR):
index = len(PATH_SELECTOR)
Expand All @@ -77,11 +224,16 @@ def load_from_statement(self, statement: str) -> None:
if key in SUPPORTED_CONFIG:
self.config[key] = value
else:
self.other.append(item)
logger.warning("Unsupported select statement: %s", item)
if item:
graph_selector = GraphSelector.parse(item)
if graph_selector is not None:
self.graph_selectors.append(graph_selector)
else:
self.other.append(item)
logger.warning("Unsupported select statement: %s", item)

def __repr__(self) -> str:
return f"SelectorConfig(paths={self.paths}, tags={self.tags}, config={self.config}, other={self.other})"
return f"SelectorConfig(paths={self.paths}, tags={self.tags}, config={self.config}, other={self.other}, graph_selectors={self.graph_selectors})"


class NodeSelector:
Expand All @@ -95,7 +247,9 @@ class NodeSelector:
def __init__(self, nodes: dict[str, DbtNode], config: SelectorConfig) -> None:
self.nodes = nodes
self.config = config
self.selected_nodes: set[str] = set()

@property
def select_nodes_ids_by_intersection(self) -> set[str]:
"""
Return a list of node ids which matches the configuration defined in config.
Expand All @@ -107,14 +261,19 @@ def select_nodes_ids_by_intersection(self) -> set[str]:
if self.config.is_empty:
return set(self.nodes.keys())

self.selected_nodes: set[str] = set()
selected_nodes: set[str] = set()
self.visited_nodes: set[str] = set()

for node_id, node in self.nodes.items():
if self._should_include_node(node_id, node):
self.selected_nodes.add(node_id)
selected_nodes.add(node_id)

if self.config.graph_selectors:
nodes_by_graph_selector = self.select_by_graph_operator()
selected_nodes = selected_nodes.intersection(nodes_by_graph_selector)

return self.selected_nodes
self.selected_nodes = selected_nodes
return selected_nodes

def _should_include_node(self, node_id: str, node: DbtNode) -> bool:
"Checks if a single node should be included. Only runs once per node with caching."
Expand Down Expand Up @@ -175,6 +334,22 @@ def _is_path_matching(self, node: DbtNode) -> bool:
return self._should_include_node(node.depends_on[0], model_node)
return False

def select_by_graph_operator(self) -> set[str]:
"""
Return a list of node ids which match the configuration defined in the config.
Return all nodes that are parents (or parents from parents) of the root defined in the configuration.
References:
https://docs.getdbt.com/reference/node-selection/syntax
https://docs.getdbt.com/reference/node-selection/yaml-selectors
"""
selected_nodes_by_selector: list[set[str]] = []

for graph_selector in self.config.graph_selectors:
selected_nodes_by_selector.append(graph_selector.filter_nodes(self.nodes))
return set.intersection(*selected_nodes_by_selector)


def retrieve_by_label(statement_list: list[str], label: str) -> set[str]:
"""
Expand All @@ -189,7 +364,7 @@ def retrieve_by_label(statement_list: list[str], label: str) -> set[str]:
for statement in statement_list:
config = SelectorConfig(Path(), statement)
item_values = getattr(config, label)
label_values = label_values.union(item_values)
label_values.update(item_values)

return label_values

Expand Down Expand Up @@ -217,20 +392,24 @@ def select_nodes(
filters = [["select", select], ["exclude", exclude]]
for filter_type, filter in filters:
for filter_parameter in filter:
if filter_parameter.startswith(PATH_SELECTOR) or filter_parameter.startswith(TAG_SELECTOR):
if (
filter_parameter.startswith(PATH_SELECTOR)
or filter_parameter.startswith(TAG_SELECTOR)
or PLUS_SELECTOR in filter_parameter
or any([filter_parameter.startswith(CONFIG_SELECTOR + config + ":") for config in SUPPORTED_CONFIG])
):
continue
elif any([filter_parameter.startswith(CONFIG_SELECTOR + config + ":") for config in SUPPORTED_CONFIG]):
continue
else:
elif ":" in filter_parameter:
raise CosmosValueError(f"Invalid {filter_type} filter: {filter_parameter}")

subset_ids: set[str] = set()

for statement in select:
config = SelectorConfig(project_dir, statement)
node_selector = NodeSelector(nodes, config)
select_ids = node_selector.select_nodes_ids_by_intersection()
subset_ids = subset_ids.union(set(select_ids))

select_ids = node_selector.select_nodes_ids_by_intersection
subset_ids.update(set(select_ids))

if select:
nodes = {id_: nodes[id_] for id_ in subset_ids}
Expand All @@ -241,7 +420,7 @@ def select_nodes(
for statement in exclude:
config = SelectorConfig(project_dir, statement)
node_selector = NodeSelector(nodes, config)
exclude_ids = exclude_ids.union(set(node_selector.select_nodes_ids_by_intersection()))
exclude_ids.update(set(node_selector.select_nodes_ids_by_intersection))
subset_ids = set(nodes_ids) - set(exclude_ids)

return {id_: nodes[id_] for id_ in subset_ids}
35 changes: 34 additions & 1 deletion docs/configuration/selecting-excluding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ The ``select`` and ``exclude`` parameters are lists, with values like the follow
- ``tag:my_tag``: include/exclude models with the tag ``my_tag``
- ``config.materialized:table``: include/exclude models with the config ``materialized: table``
- ``path:analytics/tables``: include/exclude models in the ``analytics/tables`` directory

- ``+node_name+1`` (graph operators): include/exclude the node with name ``node_name``, all its parents, and its first generation of children (`dbt graph selector docs <https://docs.getdbt.com/reference/node-selection/graph-operators>`_)
- ``tag:my_tag,+node_name`` (intersection): include/exclude ``node_name`` and its parents if they have the tag ``my_tag`` (`dbt set operator docs <https://docs.getdbt.com/reference/node-selection/set-operators>`_)
- ``['tag:first_tag', 'tag:second_tag']`` (union): include/exclude nodes that have either ``tag:first_tag`` or ``tag:second_tag``

.. note::

Expand Down Expand Up @@ -51,3 +53,34 @@ Examples:
select=["path:analytics/tables"],
)
)
.. code-block:: python
from cosmos import DbtDag, RenderConfig
jaffle_shop = DbtDag(
render_config=RenderConfig(
select=["tag:include_tag1", "tag:include_tag2"], # union
)
)
.. code-block:: python
from cosmos import DbtDag, RenderConfig
jaffle_shop = DbtDag(
render_config=RenderConfig(
select=["tag:include_tag1,tag:include_tag2"], # intersection
)
)
.. code-block:: python
from cosmos import DbtDag, RenderConfig
jaffle_shop = DbtDag(
render_config=RenderConfig(
exclude=["node_name+"], # node_name and its children
)
)
Loading

0 comments on commit f3c21b5

Please sign in to comment.