Commit

SNOW-1844094: Skip optimization if error is raised during optimization stage (#2710)
sfc-gh-aalam authored Dec 5, 2024
1 parent 41fb8d5 commit 7a994fa
Showing 4 changed files with 153 additions and 88 deletions.
193 changes: 105 additions & 88 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
@@ -3,6 +3,7 @@
 #
 
 import copy
+import logging
 import time
 from typing import Any, Dict, List
 
@@ -32,6 +33,8 @@
 from snowflake.snowpark._internal.utils import random_name_for_temp_object
 from snowflake.snowpark.mock._connection import MockServerConnection
 
+_logger = logging.getLogger(__name__)
+
 
 class PlanCompiler:
     """
@@ -77,98 +80,112 @@ def should_start_query_compilation(self) -> bool:
         )
 
     def compile(self) -> Dict[PlanQueryType, List[Query]]:
+        # initialize the queries with the original queries without optimization
+        final_plan = self._plan
+        queries = {
+            PlanQueryType.QUERIES: final_plan.queries,
+            PlanQueryType.POST_ACTIONS: final_plan.post_actions,
+        }
+
         if self.should_start_query_compilation():
             session = self._plan.session
-            # preparation for compilation
-            # 1. make a copy of the original plan
-            start_time = time.time()
-            complexity_score_before_compilation = get_complexity_score(self._plan)
-            logical_plans: List[LogicalPlan] = [copy.deepcopy(self._plan)]
-            plot_plan_if_enabled(self._plan, "original_plan")
-            plot_plan_if_enabled(logical_plans[0], "deep_copied_plan")
-            deep_copy_end_time = time.time()
-
-            # 2. create a code generator with the original plan
-            query_generator = create_query_generator(self._plan)
-
-            extra_optimization_status: Dict[str, Any] = {}
-            # 3. apply each optimizations if needed
-            # CTE optimization
-            cte_start_time = time.time()
-            if session.cte_optimization_enabled:
-                repeated_subquery_eliminator = RepeatedSubqueryElimination(
-                    logical_plans, query_generator
-                )
-                elimination_result = repeated_subquery_eliminator.apply()
-                logical_plans = elimination_result.logical_plans
-                # add the extra repeated subquery elimination status
-                extra_optimization_status[
-                    CompilationStageTelemetryField.CTE_NODE_CREATED.value
-                ] = elimination_result.total_num_of_ctes
-
-            cte_end_time = time.time()
-            complexity_scores_after_cte = [
-                get_complexity_score(logical_plan) for logical_plan in logical_plans
-            ]
-            for i, plan in enumerate(logical_plans):
-                plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}")
-
-            # Large query breakdown
-            breakdown_failure_summary, skipped_summary = {}, {}
-            if session.large_query_breakdown_enabled:
-                large_query_breakdown = LargeQueryBreakdown(
-                    session,
-                    query_generator,
-                    logical_plans,
-                    session.large_query_breakdown_complexity_bounds,
-                )
-                breakdown_result = large_query_breakdown.apply()
-                logical_plans = breakdown_result.logical_plans
-                breakdown_failure_summary = breakdown_result.breakdown_summary
-                skipped_summary = breakdown_result.skipped_summary
-
-            large_query_breakdown_end_time = time.time()
-            complexity_scores_after_large_query_breakdown = [
-                get_complexity_score(logical_plan) for logical_plan in logical_plans
-            ]
-            for i, plan in enumerate(logical_plans):
-                plot_plan_if_enabled(plan, f"large_query_breakdown_plan_{i}")
-
-            # 4. do a final pass of code generation
-            queries = query_generator.generate_queries(logical_plans)
-
-            # log telemetry data
-            deep_copy_time = deep_copy_end_time - start_time
-            cte_time = cte_end_time - cte_start_time
-            large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time
-            total_time = time.time() - start_time
-            summary_value = {
-                TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled,
-                TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled,
-                CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds,
-                CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time,
-                CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time,
-                CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time,
-                CompilationStageTelemetryField.TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN.value: large_query_breakdown_time,
-                CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value: complexity_score_before_compilation,
-                CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION.value: complexity_scores_after_cte,
-                CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN.value: complexity_scores_after_large_query_breakdown,
-                CompilationStageTelemetryField.BREAKDOWN_FAILURE_SUMMARY.value: breakdown_failure_summary,
-                CompilationStageTelemetryField.TYPE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED.value: skipped_summary,
-            }
-            # add the extra optimization status
-            summary_value.update(extra_optimization_status)
-            session._conn._telemetry_client.send_query_compilation_summary_telemetry(
-                session_id=session.session_id,
-                plan_uuid=self._plan.uuid,
-                compilation_stage_summary=summary_value,
-            )
-        else:
-            final_plan = self._plan
-            queries = {
-                PlanQueryType.QUERIES: final_plan.queries,
-                PlanQueryType.POST_ACTIONS: final_plan.post_actions,
-            }
+            try:
+                # preparation for compilation
+                # 1. make a copy of the original plan
+                start_time = time.time()
+                complexity_score_before_compilation = get_complexity_score(self._plan)
+                logical_plans: List[LogicalPlan] = [copy.deepcopy(self._plan)]
+                plot_plan_if_enabled(self._plan, "original_plan")
+                plot_plan_if_enabled(logical_plans[0], "deep_copied_plan")
+                deep_copy_end_time = time.time()
+
+                # 2. create a code generator with the original plan
+                query_generator = create_query_generator(self._plan)
+
+                extra_optimization_status: Dict[str, Any] = {}
+                # 3. apply each optimizations if needed
+                # CTE optimization
+                cte_start_time = time.time()
+                if session.cte_optimization_enabled:
+                    repeated_subquery_eliminator = RepeatedSubqueryElimination(
+                        logical_plans, query_generator
+                    )
+                    elimination_result = repeated_subquery_eliminator.apply()
+                    logical_plans = elimination_result.logical_plans
+                    # add the extra repeated subquery elimination status
+                    extra_optimization_status[
+                        CompilationStageTelemetryField.CTE_NODE_CREATED.value
+                    ] = elimination_result.total_num_of_ctes
+
+                cte_end_time = time.time()
+                complexity_scores_after_cte = [
+                    get_complexity_score(logical_plan) for logical_plan in logical_plans
+                ]
+                for i, plan in enumerate(logical_plans):
+                    plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}")
+
+                # Large query breakdown
+                breakdown_failure_summary, skipped_summary = {}, {}
+                if session.large_query_breakdown_enabled:
+                    large_query_breakdown = LargeQueryBreakdown(
+                        session,
+                        query_generator,
+                        logical_plans,
+                        session.large_query_breakdown_complexity_bounds,
+                    )
+                    breakdown_result = large_query_breakdown.apply()
+                    logical_plans = breakdown_result.logical_plans
+                    breakdown_failure_summary = breakdown_result.breakdown_summary
+                    skipped_summary = breakdown_result.skipped_summary
+
+                large_query_breakdown_end_time = time.time()
+                complexity_scores_after_large_query_breakdown = [
+                    get_complexity_score(logical_plan) for logical_plan in logical_plans
+                ]
+                for i, plan in enumerate(logical_plans):
+                    plot_plan_if_enabled(plan, f"large_query_breakdown_plan_{i}")
+
+                # 4. do a final pass of code generation
+                queries = query_generator.generate_queries(logical_plans)
+
+                # log telemetry data
+                deep_copy_time = deep_copy_end_time - start_time
+                cte_time = cte_end_time - cte_start_time
+                large_query_breakdown_time = (
+                    large_query_breakdown_end_time - cte_end_time
+                )
+                total_time = time.time() - start_time
+                summary_value = {
+                    TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled,
+                    TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled,
+                    CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds,
+                    CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time,
+                    CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time,
+                    CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time,
+                    CompilationStageTelemetryField.TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN.value: large_query_breakdown_time,
+                    CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value: complexity_score_before_compilation,
+                    CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION.value: complexity_scores_after_cte,
+                    CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN.value: complexity_scores_after_large_query_breakdown,
+                    CompilationStageTelemetryField.BREAKDOWN_FAILURE_SUMMARY.value: breakdown_failure_summary,
+                    CompilationStageTelemetryField.TYPE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED.value: skipped_summary,
+                }
+                # add the extra optimization status
+                summary_value.update(extra_optimization_status)
+                session._conn._telemetry_client.send_query_compilation_summary_telemetry(
+                    session_id=session.session_id,
+                    plan_uuid=self._plan.uuid,
+                    compilation_stage_summary=summary_value,
+                )
+            except Exception as e:
+                # if any error occurs during the compilation, we should fall back to the original plan
+                _logger.debug(f"Skipping optimization due to error: {e}")
+                session._conn._telemetry_client.send_query_compilation_stage_failed_telemetry(
+                    session_id=session.session_id,
+                    plan_uuid=self._plan.uuid,
+                    error_type=type(e).__name__,
+                    error_message=str(e),
+                )
+                pass
 
         return self.replace_temp_obj_placeholders(queries)
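
Note: the net effect of the plan_compiler.py change is that the whole optimization stage now runs inside a single try/except. The queries dictionary is pre-initialized from the unoptimized plan, so on any exception the compiler logs a debug message, reports the failure through the new send_query_compilation_stage_failed_telemetry hook, and returns the original queries. A minimal, standalone sketch of that fallback pattern (compile_with_fallback, failing_optimizer, and report_failure are illustrative names, not the Snowpark API):

# Minimal sketch of the fallback pattern above: try the optimizer, and on any
# error log it, report it, and return the original plan unchanged.
# (compile_with_fallback, failing_optimizer, and report_failure are
# illustrative names, not part of the Snowpark codebase.)
import logging

logging.basicConfig(level=logging.DEBUG)
_logger = logging.getLogger(__name__)


def compile_with_fallback(plan, optimize, report_failure):
    try:
        return optimize(plan)
    except Exception as e:
        _logger.debug(f"Skipping optimization due to error: {e}")
        report_failure(error_type=type(e).__name__, error_message=str(e))
        return plan  # fall back to the original, unoptimized plan


def failing_optimizer(plan):
    raise ValueError("boom")


result = compile_with_fallback(
    plan={"queries": ["SELECT 1"], "post_actions": []},
    optimize=failing_optimizer,
    report_failure=lambda **kwargs: print("telemetry:", kwargs),
)
print(result)  # {'queries': ['SELECT 1'], 'post_actions': []}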

@@ -23,13 +23,16 @@ class CompilationStageTelemetryField(Enum):
         "snowpark_large_query_breakdown_optimization_skipped"
     )
     TYPE_COMPILATION_STAGE_STATISTICS = "snowpark_compilation_stage_statistics"
+    TYPE_COMPILATION_STAGE_FAILED = "snowpark_compilation_stage_failed"
     TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS = (
         "snowpark_large_query_breakdown_update_complexity_bounds"
     )
 
     # keys
     KEY_REASON = "reason"
     PLAN_UUID = "plan_uuid"
+    ERROR_TYPE = "error_type"
+    ERROR_MESSAGE = "error_message"
     TIME_TAKEN_FOR_COMPILATION = "time_taken_for_compilation_sec"
     TIME_TAKEN_FOR_DEEP_COPY_PLAN = "time_taken_for_deep_copy_plan_sec"
     TIME_TAKEN_FOR_CTE_OPTIMIZATION = "time_taken_for_cte_optimization_sec"
16 changes: 16 additions & 0 deletions src/snowflake/snowpark/_internal/telemetry.py
@@ -482,6 +482,22 @@ def send_query_compilation_summary_telemetry(
         }
         self.send(message)
 
+    def send_query_compilation_stage_failed_telemetry(
+        self, session_id: int, plan_uuid: str, error_type: str, error_message: str
+    ) -> None:
+        message = {
+            **self._create_basic_telemetry_data(
+                CompilationStageTelemetryField.TYPE_COMPILATION_STAGE_FAILED.value
+            ),
+            TelemetryField.KEY_DATA.value: {
+                TelemetryField.SESSION_ID.value: session_id,
+                CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid,
+                CompilationStageTelemetryField.ERROR_TYPE.value: error_type,
+                CompilationStageTelemetryField.ERROR_MESSAGE.value: error_message,
+            },
+        }
+        self.send(message)
+
     def send_temp_table_cleanup_telemetry(
         self,
         session_id: str,
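
Note: the new telemetry method nests the failure details under the client's data key using the enum values added in this commit. A rough approximation of the payload it produces is sketched below; the event name and the plan_uuid, error_type, and error_message keys come from this diff, while the literal "type", "data", and "session_id" spellings and the rest of the envelope (built by _create_basic_telemetry_data) are assumptions here:

# Rough approximation of the message built by
# send_query_compilation_stage_failed_telemetry. The "type", "data", and
# "session_id" key spellings are assumptions; "plan_uuid", "error_type",
# "error_message", and the event name come from the enum values in this commit.
message = {
    "type": "snowpark_compilation_stage_failed",  # TYPE_COMPILATION_STAGE_FAILED
    "data": {
        "session_id": 123456789,
        "plan_uuid": "00000000-0000-0000-0000-000000000000",
        "error_type": "ValueError",
        "error_message": "test exception",
    },
}
print(message)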
29 changes: 29 additions & 0 deletions tests/integ/test_large_query_breakdown.py
@@ -705,6 +705,35 @@ def test_add_parent_plan_uuid_to_statement_params(session, large_query_df):
         assert call.kwargs["_statement_params"]["_PLAN_UUID"] == plan.uuid
 
 
+@pytest.mark.skipif(
+    IS_IN_STORED_PROC, reason="SNOW-609328: support caplog in SP regression test"
+)
+@pytest.mark.parametrize("error_type", [AssertionError, ValueError, RuntimeError])
+@patch("snowflake.snowpark._internal.compiler.plan_compiler.LargeQueryBreakdown.apply")
+def test_optimization_skipped_with_exceptions(
+    mock_lqb_apply, session, large_query_df, caplog, error_type
+):
+    """Test large query breakdown is skipped when there are exceptions"""
+    caplog.clear()
+    mock_lqb_apply.side_effect = error_type("test exception")
+    with caplog.at_level(logging.DEBUG):
+        with patch.object(
+            session._conn._telemetry_client,
+            "send_query_compilation_stage_failed_telemetry",
+        ) as patch_send:
+            queries = large_query_df.queries
+
+    assert "Skipping optimization due to error:" in caplog.text
+    assert len(queries["queries"]) == 1
+    assert len(queries["post_actions"]) == 0
+
+    patch_send.assert_called_once()
+    _, kwargs = patch_send.call_args
+    print(kwargs)
+    assert kwargs["error_message"] == "test exception"
+    assert kwargs["error_type"] == error_type.__name__
+
+
 def test_complexity_bounds_affect_num_partitions(session, large_query_df):
     """Test complexity bounds affect number of partitions.
     Also test that when partitions are added, drop table queries are added.
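
Note: the new test drives the failure path by patching LargeQueryBreakdown.apply to raise, then checks both the debug log (via caplog) and the failure-telemetry call (via patch.object). A minimal, standalone sketch of that caplog-plus-mock pattern, with toy names in place of the Snowpark fixtures:

# Toy pytest example of the caplog + unittest.mock pattern used above.
# (run_compile and send_failure_telemetry are illustrative stand-ins.)
import logging
from unittest import mock

_logger = logging.getLogger("toy_compiler")


def send_failure_telemetry(error_type, error_message):
    pass  # stand-in for the real telemetry client method


def run_compile():
    try:
        raise ValueError("test exception")
    except Exception as e:
        _logger.debug(f"Skipping optimization due to error: {e}")
        send_failure_telemetry(type(e).__name__, str(e))


def test_failure_is_logged_and_reported(caplog):
    with caplog.at_level(logging.DEBUG):
        with mock.patch(f"{__name__}.send_failure_telemetry") as patched:
            run_compile()
    assert "Skipping optimization due to error:" in caplog.text
    patched.assert_called_once_with("ValueError", "test exception")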

0 comments on commit 7a994fa
