Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement prefix + DBT selector graph operator for model selection #429

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 58 additions & 4 deletions cosmos/dbt/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
PATH_SELECTOR = "path:"
TAG_SELECTOR = "tag:"
CONFIG_SELECTOR = "config."

MODEL_UPSTREAM_SELECTOR = "+"

logger = get_logger(__name__)

Expand All @@ -42,6 +42,7 @@
self.paths: list[Path] = []
self.tags: list[str] = []
self.config: dict[str, str] = {}
self.model_upstream: str = ""
self.other: list[str] = []
self.load_from_statement(statement)

Expand Down Expand Up @@ -73,6 +74,9 @@
key, value = item[index:].split(":")
if key in SUPPORTED_CONFIG:
self.config[key] = value
elif item.startswith(MODEL_UPSTREAM_SELECTOR):
index = len(MODEL_UPSTREAM_SELECTOR)
self.model_upstream = item[index:]
else:
self.other.append(item)
logger.warning("Unsupported select statement: %s", item)
Expand Down Expand Up @@ -146,6 +150,48 @@
return selected_nodes


def select_nodes_ids_by_dfs(nodes: dict[str, DbtNode], config: SelectorConfig) -> set[str]:
    """
    Return the set of node ids selected by an upstream (``+model``) select statement.

    The model whose name equals ``config.model_upstream`` is looked up in ``nodes`` and
    the graph is walked depth-first through each node's ``depends_on`` list. The result
    holds the id of that model plus the ids of all of its direct and indirect
    dependencies that are present in ``nodes``; dependency ids missing from ``nodes``
    are ignored.

    The walk is iterative and never expands a node twice, so shared ancestors are
    processed once and neither cyclic ``depends_on`` references nor very long
    dependency chains can exhaust the interpreter's recursion limit.

    :param nodes: Dictionary mapping dbt nodes (node.unique_id to node)
    :param config: User-defined select statements
    :returns: Set of selected node ids; empty when no node has the requested name

    References:
    https://docs.getdbt.com/reference/node-selection/syntax
    https://docs.getdbt.com/reference/node-selection/yaml-selectors
    """
    # Model selection is done via node name, not unique id, so we have to linearly search for it
    root_id = next(
        (node_id for node_id, node in nodes.items() if node.name == config.model_upstream),
        None,
    )
    if root_id is None:
        return set()

    selected_nodes: set[str] = set()
    pending = [root_id]
    while pending:
        node_id = pending.pop()
        if node_id in selected_nodes:
            # Already expanded through another path (diamond) or a cycle.
            continue
        selected_nodes.add(node_id)
        # Only follow dependencies that exist in the graph we were given.
        pending.extend(
            parent_id
            for parent_id in nodes[node_id].depends_on
            if parent_id in nodes and parent_id not in selected_nodes
        )

    return selected_nodes


def retrieve_by_label(statement_list: list[str], label: str) -> set[str]:
"""
Return a set of values associated with a label.
Expand Down Expand Up @@ -184,7 +230,11 @@
filters = [["select", select], ["exclude", exclude]]
for filter_type, filter in filters:
for filter_parameter in filter:
if filter_parameter.startswith(PATH_SELECTOR) or filter_parameter.startswith(TAG_SELECTOR):
if (
filter_parameter.startswith(PATH_SELECTOR)
or filter_parameter.startswith(TAG_SELECTOR)
or filter_parameter.startswith(MODEL_UPSTREAM_SELECTOR)
):
continue
elif any([filter_parameter.startswith(CONFIG_SELECTOR + config + ":") for config in SUPPORTED_CONFIG]):
continue
Expand All @@ -195,8 +245,12 @@

for statement in select:
config = SelectorConfig(project_dir, statement)
select_ids = select_nodes_ids_by_intersection(nodes, config)
subset_ids = subset_ids.union(set(select_ids))
if config.model_upstream:
select_ids = select_nodes_ids_by_dfs(nodes, config)
subset_ids = subset_ids.union(set(select_ids))
else:
select_ids = select_nodes_ids_by_intersection(nodes, config)
subset_ids = subset_ids.union(set(select_ids))

if select:
nodes = {id_: nodes[id_] for id_ in subset_ids}
Expand Down
71 changes: 70 additions & 1 deletion tests/dbt/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,20 @@ def test_is_empty_config(selector_config, paths, tags, config, other, expected):
tags=["has_child"],
config={"materialized": "view"},
)
# Fixture: model with an empty depends_on list; parent_node lists it in its depends_on
# alongside "grandparent".
grandparent_sibling_node = DbtNode(
name="grandparent_sibling",
unique_id="grandparent_sibling",
resource_type=DbtResourceType.MODEL,
depends_on=[],
file_path=SAMPLE_PROJ_PATH / "gen1/models/grandparent_sibling.sql",
tags=[],
config={},
)
parent_node = DbtNode(
name="parent",
unique_id="parent",
resource_type=DbtResourceType.MODEL,
depends_on=["grandparent"],
depends_on=["grandparent", "grandparent_sibling"],
file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql",
tags=["has_child"],
config={"materialized": "view"},
Expand Down Expand Up @@ -85,13 +94,24 @@ def test_is_empty_config(selector_config, paths, tags, config, other, expected):
tags=["nightly"],
config={"materialized": "table", "tags": ["deprecated", "test2"]},
)
# Fixture: model with an empty depends_on list, no tags and no config; selecting
# "+orphaned" is expected to return this node alone.
orphaned_node = DbtNode(
name="orphaned",
unique_id="orphaned",
resource_type=DbtResourceType.MODEL,
depends_on=[],
file_path=SAMPLE_PROJ_PATH / "gen3/models/orphaned.sql",
tags=[],
config={},
)

# Node graph passed to select_nodes by the tests in this module, keyed by each node's unique_id.
sample_nodes = {
grandparent_node.unique_id: grandparent_node,
grandparent_sibling_node.unique_id: grandparent_sibling_node,
parent_node.unique_id: parent_node,
child_node.unique_id: child_node,
grandchild_1_test_node.unique_id: grandchild_1_test_node,
grandchild_2_test_node.unique_id: grandchild_2_test_node,
orphaned_node.unique_id: orphaned_node,
}


Expand Down Expand Up @@ -184,6 +204,8 @@ def test_select_nodes_by_exclude_tag():
child_node.unique_id: child_node,
grandchild_1_test_node.unique_id: grandchild_1_test_node,
grandchild_2_test_node.unique_id: grandchild_2_test_node,
grandparent_sibling_node.unique_id: grandparent_sibling_node,
orphaned_node.unique_id: orphaned_node,
}
assert selected == expected

Expand All @@ -208,6 +230,19 @@ def test_select_nodes_by_exclude_union_config_test_tags():
)
expected = {
grandparent_node.unique_id: grandparent_node,
grandparent_sibling_node.unique_id: grandparent_sibling_node,
parent_node.unique_id: parent_node,
child_node.unique_id: child_node,
orphaned_node.unique_id: orphaned_node,
}
assert selected == expected


def test_select_nodes_by_dfs():
selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["+child"])
expected = {
grandparent_node.unique_id: grandparent_node,
grandparent_sibling_node.unique_id: grandparent_sibling_node,
parent_node.unique_id: parent_node,
child_node.unique_id: child_node,
}
Expand All @@ -220,6 +255,18 @@ def test_select_nodes_by_path_dir():
child_node.unique_id: child_node,
grandchild_1_test_node.unique_id: grandchild_1_test_node,
grandchild_2_test_node.unique_id: grandchild_2_test_node,
orphaned_node.unique_id: orphaned_node,
}
assert selected == expected


def test_select_nodes_by_dfs_exclude_tags():
    """Selecting ``+child`` while excluding ``tag:has_child`` leaves only grandparent_sibling and child."""
    kept_nodes = (grandparent_sibling_node, child_node)
    expected = {node.unique_id: node for node in kept_nodes}

    selected = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["+child"],
        exclude=["tag:has_child"],
    )

    assert selected == expected

Expand All @@ -230,3 +277,25 @@ def test_select_nodes_by_path_file():
parent_node.unique_id: parent_node,
}
assert selected == expected


def test_select_node_by_dfs_partial_tree():
    """Selecting ``+parent`` returns parent, grandparent and grandparent_sibling only."""
    upstream_of_parent = (grandparent_node, grandparent_sibling_node, parent_node)
    expected = {node.unique_id: node for node in upstream_of_parent}

    selected = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["+parent"],
    )

    assert selected == expected


def test_select_node_by_dfs_leaf():
    """Selecting ``+orphaned`` returns only the orphaned node itself."""
    selected = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["+orphaned"],
    )

    assert selected == {orphaned_node.unique_id: orphaned_node}


def test_select_node_by_dfs_no_node():
    """Selecting upstream of a model name that is not in the graph yields no nodes."""
    selected = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["+modelDoesntExist"],
    )

    assert selected == {}
Loading