SNOW-1865904 fix query gen when nested cte node is partitioned #2816

Merged
@@ -565,19 +565,14 @@ def _replace_child_and_update_ancestors(
        temp_table_selectable.post_actions = [drop_table_query]

        parents = self._parent_map[child]
-       updated_nodes = set()
        for parent in parents:
            replace_child(parent, child, temp_table_selectable, self._query_generator)

        nodes_to_reset = list(parents)
        while nodes_to_reset:
            node = nodes_to_reset.pop()
-           if node in updated_nodes:
Collaborator
Can you give a more detailed description of what the problem is if we skip the node update here? The CTE optimization does a level-by-level update, so it can skip a node that is already updated, but for large query breakdown, I recall you were doing a DFS, right?

Contributor Author

Consider the tree below:

     node1
        |
     node2
   /       \
node3      node4

Suppose we start by re-resolving node3 -> node2 -> node1. In this pass, node2 and node1 are marked in updated_nodes. Now, when we go on to update the ancestors of node4, re-resolution of node2 and node1 is skipped. This is not ideal, because the node4 update can also trigger a re-update of node2 and node1.

For example, this can be problematic when node3 and node4 had a referenced_cte map before the update. After the first pass node3 -> node2 -> node1, node2 is resolved against an older version of node4. If the re-resolved node4 no longer has any referenced_ctes, node2 will not be updated with this information.
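
To make the failure mode concrete, here is a toy model of the ancestor update (the classes and helper names below are invented for illustration and are not the Snowpark compiler's; re_resolve stands in for update_resolvable_node, and referenced_ctes is reduced to a plain set):

# Toy model only: a node derives its referenced CTEs from its current children.
class Node:
    def __init__(self, name, children=(), ctes=()):
        self.name = name
        self.children = list(children)
        self.referenced_ctes = set(ctes)


def re_resolve(node):
    # Stand-in for update_resolvable_node: recompute from the current children.
    node.referenced_ctes = set().union(*(c.referenced_ctes for c in node.children))


# The shared CTE child is referenced by both node3 and node4.
cte_child = Node("cte_child", ctes={"SNOWPARK_TEMP_CTE_X"})
node3 = Node("node3", [cte_child])
node4 = Node("node4", [cte_child])
node2 = Node("node2", [node3, node4])
node1 = Node("node1", [node2])
for n in (node3, node4, node2, node1):
    re_resolve(n)

parent_map = {
    cte_child: [node4, node3],
    node3: [node2],
    node4: [node2],
    node2: [node1],
    node1: [],
}


def replace_child_and_update_ancestors(child, replacement, skip_updated):
    # Swap the child for its replacement in every parent, then re-resolve ancestors.
    for parent in parent_map[child]:
        parent.children = [replacement if c is child else c for c in parent.children]
    updated, nodes_to_reset = set(), list(parent_map[child])
    while nodes_to_reset:
        node = nodes_to_reset.pop()
        if skip_updated and node in updated:
            continue  # the skip removed by this PR
        re_resolve(node)
        updated.add(node)
        nodes_to_reset.extend(parent_map[node])


# The shared child is materialized into a temp table that references no CTEs.
replace_child_and_update_ancestors(cte_child, Node("temp_table"), skip_updated=True)
print(node2.referenced_ctes)  # {'SNOWPARK_TEMP_CTE_X'} -- stale: node2 was resolved before node4

# Without the skip (the behaviour after this change), node2 and node1 are
# re-resolved again once node4 has been updated, and the stale reference disappears.
replace_child_and_update_ancestors(cte_child, Node("temp_table"), skip_updated=False)
print(node2.referenced_ctes)  # set()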

Collaborator
So it is because you are doing a DFS. If you did the updates in the order node3, node4, then node2, and last node1, you shouldn't run into this problem.

-               # Skip if the node is already updated.
-               continue

            update_resolvable_node(node, self._query_generator)
-           updated_nodes.add(node)

            parents = self._parent_map[node]
            nodes_to_reset.extend(parents)
@@ -185,7 +185,6 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
                    error_type=type(e).__name__,
                    error_message=str(e),
                )
-               pass
Collaborator
Did we have a bug before?

Contributor Author
No. I just noticed an unnecessary pass and cleaned it up.


        return self.replace_temp_obj_placeholders(queries)

10 changes: 5 additions & 5 deletions src/snowflake/snowpark/_internal/compiler/query_generator.py
@@ -218,11 +218,11 @@ def do_resolve_with_resolved_children(

        elif isinstance(logical_plan, WithQueryBlock):
            resolved_child = resolved_children[logical_plan.children[0]]
-           # record the CTE definition of the current block
-           if logical_plan.name not in self.resolved_with_query_block:
-               self.resolved_with_query_block[
-                   logical_plan.name
-               ] = resolved_child.queries[-1]
+           # record the CTE definition of the current block and update the query when
+           # the child is re-resolved during optimization stage.
+           self.resolved_with_query_block[logical_plan.name] = resolved_child.queries[
+               -1
+           ]

            resolved_plan = self.plan_builder.with_query_block(
                logical_plan,
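The effect of dropping the "if logical_plan.name not in self.resolved_with_query_block" guard can be sketched with a plain dictionary (a simplified illustration with made-up query strings, not the resolver itself): with the guard, the first resolution of a CTE child is cached permanently, so a child that is re-resolved during large query breakdown never refreshes the stored definition, while the unconditional assignment always records the latest resolution.

# Simplified stand-in for resolved_with_query_block; the query strings are made up.
cache = {}

def record_with_guard(name, child_query):
    if name not in cache:  # old behaviour: the first resolution wins forever
        cache[name] = child_query

def record_unconditionally(name, child_query):
    cache[name] = child_query  # new behaviour: the latest re-resolution wins

record_with_guard("SNOWPARK_TEMP_CTE_X", "SELECT ... FROM original_child")
record_with_guard("SNOWPARK_TEMP_CTE_X", "SELECT ... FROM partition_temp_table")
print(cache["SNOWPARK_TEMP_CTE_X"])  # stale: still the original child's query

cache.clear()
record_unconditionally("SNOWPARK_TEMP_CTE_X", "SELECT ... FROM original_child")
record_unconditionally("SNOWPARK_TEMP_CTE_X", "SELECT ... FROM partition_temp_table")
print(cache["SNOWPARK_TEMP_CTE_X"])  # the re-resolved child's query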
43 changes: 43 additions & 0 deletions tests/integ/test_large_query_breakdown.py
@@ -5,6 +5,7 @@

import logging
import os
+import re
import tempfile
from unittest.mock import patch

@@ -734,6 +735,48 @@ def test_optimization_skipped_with_exceptions(
    assert kwargs["error_type"] == error_type.__name__


def test_large_query_breakdown_with_nested_cte(session):
    session.cte_optimization_enabled = True
    set_bounds(session, 15, 20)

    temp_table = Utils.random_table_name()
    session.create_dataframe([(1, 2), (3, 4)], ["A", "B"]).write.save_as_table(
        temp_table, table_type="temp"
    )
    base_select = session.table(temp_table)
    for i in range(2):
        base_select = base_select.with_column("A", col("A") + lit(i))

    base_df = base_select.union_all(base_select)

    df1 = base_df.with_column("A", col("A") + 1)
    df2 = base_df.with_column("B", col("B") + 1)
    for i in range(2):
        df1 = df1.with_column("A", col("A") + i)

    df1 = df1.group_by("A").agg(sum_distinct(col("B")).alias("B"))
    df2 = df2.group_by("B").agg(sum_distinct(col("A")).alias("A"))
    mid_final_df = df1.union_all(df2)

    mid1 = mid_final_df.filter(col("A") > 10)
    mid2 = mid_final_df.filter(col("B") > 3)
    final_df = mid1.union_all(mid2)

    with SqlCounter(query_count=1, describe_count=0):
        queries = final_df.queries
    # TODO: update when to_selectable memoization is merged
    assert len(queries["queries"]) == 3
    assert len(queries["post_actions"]) == 2
    match = re.search(r"SNOWPARK_TEMP_CTE_[\w]+", queries["queries"][0])
    assert match is not None
    cte_name_for_first_partition = match.group()
    # assert that query for upper cte node is re-written and does not
    # contain the cte name for the first partition
    assert cte_name_for_first_partition not in queries["queries"][2]

    check_result_with_and_without_breakdown(session, final_df)


def test_complexity_bounds_affect_num_partitions(session, large_query_df):
"""Test complexity bounds affect number of partitions.
Also test that when partitions are added, drop table queries are added.