From 655c0e3eb05c88942400d8049323c07cdf078430 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Sat, 28 Dec 2024 18:45:08 -0800 Subject: [PATCH 1/4] fix query gen when cte is updated --- .../snowpark/_internal/compiler/large_query_breakdown.py | 5 ----- .../snowpark/_internal/compiler/plan_compiler.py | 1 - .../snowpark/_internal/compiler/query_generator.py | 9 ++++----- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py index 1f1971b9f99..4ecce744099 100644 --- a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py +++ b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py @@ -565,19 +565,14 @@ def _replace_child_and_update_ancestors( temp_table_selectable.post_actions = [drop_table_query] parents = self._parent_map[child] - updated_nodes = set() for parent in parents: replace_child(parent, child, temp_table_selectable, self._query_generator) nodes_to_reset = list(parents) while nodes_to_reset: node = nodes_to_reset.pop() - if node in updated_nodes: - # Skip if the node is already updated. - continue update_resolvable_node(node, self._query_generator) - updated_nodes.add(node) parents = self._parent_map[node] nodes_to_reset.extend(parents) diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 5f73e6d63ff..78808f8ec60 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -185,7 +185,6 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: error_type=type(e).__name__, error_message=str(e), ) - pass return self.replace_temp_obj_placeholders(queries) diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index c9e61e6c850..6a48a488265 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -218,11 +218,10 @@ def do_resolve_with_resolved_children( elif isinstance(logical_plan, WithQueryBlock): resolved_child = resolved_children[logical_plan.children[0]] - # record the CTE definition of the current block - if logical_plan.name not in self.resolved_with_query_block: - self.resolved_with_query_block[ - logical_plan.name - ] = resolved_child.queries[-1] + # record/update the CTE definition of the current block + self.resolved_with_query_block[logical_plan.name] = resolved_child.queries[ + -1 + ] resolved_plan = self.plan_builder.with_query_block( logical_plan, From 08a2b8dcb70de9734fc3411a4b28443cda7cc184 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Sun, 29 Dec 2024 17:31:09 -0800 Subject: [PATCH 2/4] add test; update comment --- .../_internal/compiler/query_generator.py | 3 +- tests/integ/test_large_query_breakdown.py | 43 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index 6a48a488265..6da4a974590 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -218,7 +218,8 @@ def do_resolve_with_resolved_children( elif isinstance(logical_plan, WithQueryBlock): resolved_child = resolved_children[logical_plan.children[0]] - # record/update the CTE definition of the current block + # record the CTE definition of the current block and update the query when + # the child is re-resolved during optimization stage. self.resolved_with_query_block[logical_plan.name] = resolved_child.queries[ -1 ] diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py index 8daa9818a28..79f9d4e89d1 100644 --- a/tests/integ/test_large_query_breakdown.py +++ b/tests/integ/test_large_query_breakdown.py @@ -5,6 +5,7 @@ import logging import os +import re import tempfile from unittest.mock import patch @@ -734,6 +735,48 @@ def test_optimization_skipped_with_exceptions( assert kwargs["error_type"] == error_type.__name__ +def test_large_query_breakdown_with_nested_cte(session): + session.cte_optimization_enabled = True + set_bounds(session, 15, 20) + + temp_table = Utils.random_table_name() + session.create_dataframe([(1, 2), (3, 4)], ["A", "B"]).write.save_as_table( + temp_table, table_type="temp" + ) + base_select = session.table(temp_table) + for i in range(2): + base_select = base_select.with_column("A", col("A") + lit(i)) + + base_df = base_select.union_all(base_select) + + df1 = base_df.with_column("A", col("A") + 1) + df2 = base_df.with_column("B", col("B") + 1) + for i in range(2): + df1 = df1.with_column("A", col("A") + i) + + df1 = df1.group_by("A").agg(sum_distinct(col("B")).alias("B")) + df2 = df2.group_by("B").agg(sum_distinct(col("A")).alias("A")) + mid_final_df = df1.union_all(df2) + + mid1 = mid_final_df.filter(col("A") > 10) + mid2 = mid_final_df.filter(col("B") > 3) + final_df = mid1.union_all(mid2) + + with SqlCounter(query_count=1, describe_count=0): + queries = final_df.queries + # TODO: update when to_selectable memoization is merged + assert len(queries["queries"]) == 3 + assert len(queries["post_actions"]) == 2 + match = re.search(r"SNOWPARK_TEMP_CTE_[\w]+", queries["queries"][0]) + assert match is not None + cte_name_for_first_partition = match.group() + # assert that query for upper cte node is re-written and does not + # contain the cte name for the first partition + assert cte_name_for_first_partition not in queries["queries"][2] + + check_result_with_and_without_breakdown(session, final_df) + + def test_complexity_bounds_affect_num_partitions(session, large_query_df): """Test complexity bounds affect number of partitions. Also test that when partitions are added, drop table queries are added. From 8c8dbbd82a77681375416d1b581194722ee3029c Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 2 Jan 2025 14:08:37 -0800 Subject: [PATCH 3/4] Update src/snowflake/snowpark/_internal/compiler/query_generator.py Co-authored-by: Hazem Elmeleegy --- src/snowflake/snowpark/_internal/compiler/query_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index 6da4a974590..f45f8eadfa6 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -218,7 +218,7 @@ def do_resolve_with_resolved_children( elif isinstance(logical_plan, WithQueryBlock): resolved_child = resolved_children[logical_plan.children[0]] - # record the CTE definition of the current block and update the query when + # record the CTE definition of the current block or update the query when # the child is re-resolved during optimization stage. self.resolved_with_query_block[logical_plan.name] = resolved_child.queries[ -1 From bef4e25b351ae060beb8d77c7ec147896001b9f0 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 3 Jan 2025 10:32:53 -0800 Subject: [PATCH 4/4] remove todo --- tests/integ/test_large_query_breakdown.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py index fdfa30505e3..a73e7487de2 100644 --- a/tests/integ/test_large_query_breakdown.py +++ b/tests/integ/test_large_query_breakdown.py @@ -764,15 +764,14 @@ def test_large_query_breakdown_with_nested_cte(session): with SqlCounter(query_count=1, describe_count=0): queries = final_df.queries - # TODO: update when to_selectable memoization is merged - assert len(queries["queries"]) == 3 - assert len(queries["post_actions"]) == 2 + assert len(queries["queries"]) == 2 + assert len(queries["post_actions"]) == 1 match = re.search(r"SNOWPARK_TEMP_CTE_[\w]+", queries["queries"][0]) assert match is not None cte_name_for_first_partition = match.group() # assert that query for upper cte node is re-written and does not # contain the cte name for the first partition - assert cte_name_for_first_partition not in queries["queries"][2] + assert cte_name_for_first_partition not in queries["queries"][1] check_result_with_and_without_breakdown(session, final_df)