SNOW-1787235 fix test multithreading with lqb and cte (#2576)
sfc-gh-aalam authored Nov 8, 2024
1 parent e2300ce commit f01e9a7
Showing 5 changed files with 122 additions and 49 deletions.

In short, the commit replaces lazy cached_property-style evaluation of encoded_node_id_with_query, cumulative_node_complexity, and plan_state with caching guarded by a new session-level _plan_lock, re-enables the previously skipped test_large_query_breakdown_with_cte, and adds a regression test for concurrent lazy plan evaluation.
26 changes: 17 additions & 9 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -7,7 +7,7 @@
 from collections import UserDict, defaultdict
 from copy import copy, deepcopy
 from enum import Enum
-from functools import cached_property, reduce
+from functools import reduce
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
@@ -241,14 +241,15 @@ def __init__(
         ] = defaultdict(dict)
         self._api_calls = api_calls.copy() if api_calls is not None else None
         self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None
+        self._encoded_node_id_with_query: Optional[str] = None
 
     @property
     @abstractmethod
     def sql_query(self) -> str:
         """Returns the sql query of this Selectable logical plan."""
         pass
 
-    @cached_property
+    @property
     def encoded_node_id_with_query(self) -> str:
         """
         Returns an encoded node id of this Selectable logical plan.
@@ -257,7 +258,10 @@ def encoded_node_id_with_query(self) -> str:
         two selectable node with same queries. This is currently used by repeated subquery
         elimination to detect two nodes with same query, please use it with careful.
         """
-        return encode_node_id_with_query(self)
+        with self.analyzer.session._plan_lock:
+            if self._encoded_node_id_with_query is None:
+                self._encoded_node_id_with_query = encode_node_id_with_query(self)
+            return self._encoded_node_id_with_query
 
     @property
     @abstractmethod
@@ -324,12 +328,16 @@ def plan_state(self) -> Dict[PlanState, Any]:
 
     @property
     def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
-        if self._cumulative_node_complexity is None:
-            self._cumulative_node_complexity = sum_node_complexities(
-                self.individual_node_complexity,
-                *(node.cumulative_node_complexity for node in self.children_plan_nodes),
-            )
-        return self._cumulative_node_complexity
+        with self.analyzer.session._plan_lock:
+            if self._cumulative_node_complexity is None:
+                self._cumulative_node_complexity = sum_node_complexities(
+                    self.individual_node_complexity,
+                    *(
+                        node.cumulative_node_complexity
+                        for node in self.children_plan_nodes
+                    ),
+                )
+            return self._cumulative_node_complexity
 
     @cumulative_node_complexity.setter
     def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]):
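Both rewritten properties above share one shape: double-checked caching. The lock is acquired first, the cache attribute is re-tested under the lock, and only then does the expensive computation run, so two racing threads can no longer both compute the value (the failure mode behind the flaky test). A minimal, self-contained sketch of the pattern; LazyNode, _value, and _compute are hypothetical stand-ins for the Snowpark internals:

    import threading

    class LazyNode:
        def __init__(self) -> None:
            self._lock = threading.RLock()  # stands in for session._plan_lock
            self._value = None              # cache slot, like _encoded_node_id_with_query

        @property
        def value(self) -> int:
            # Double-checked caching: re-test the cache *after* taking the
            # lock, so the expensive computation runs at most once.
            with self._lock:
                if self._value is None:
                    self._value = self._compute()
                return self._value

        def _compute(self) -> int:
            return 42  # placeholder for encode_node_id_with_query(self)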
87 changes: 48 additions & 39 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -264,6 +264,7 @@ def __init__(
             self.session._analyzer,
             self.df_aliased_col_name_to_real_col_name,
         )
+        self._plan_state: Optional[Dict[PlanState, Any]] = None
 
     @property
     def uuid(self) -> str:
@@ -379,38 +380,45 @@ def output_dict(self) -> Dict[str, Any]:
         }
         return self._output_dict
 
-    @cached_property
+    @property
     def plan_state(self) -> Dict[PlanState, Any]:
-        from snowflake.snowpark._internal.analyzer.select_statement import (
-            SelectStatement,
-        )
+        with self.session._plan_lock:
+            if self._plan_state is not None:
+                # return the cached plan state
+                return self._plan_state
 
-        # calculate plan height and num_selects_with_complexity_merged
-        height = 0
-        num_selects_with_complexity_merged = 0
-        current_level = [self]
-        while len(current_level) > 0:
-            next_level = []
-            for node in current_level:
-                next_level.extend(node.children_plan_nodes)
-                if (
-                    isinstance(node, SelectStatement)
-                    and node._merge_projection_complexity_with_subquery
-                ):
-                    num_selects_with_complexity_merged += 1
-            height += 1
-            current_level = next_level
-        # calculate the repeated node status
-        cte_nodes, duplicated_node_complexity_distribution = find_duplicate_subtrees(
-            self, propagate_complexity_hist=True
-        )
+            from snowflake.snowpark._internal.analyzer.select_statement import (
+                SelectStatement,
+            )
 
-        return {
-            PlanState.PLAN_HEIGHT: height,
-            PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED: num_selects_with_complexity_merged,
-            PlanState.NUM_CTE_NODES: len(cte_nodes),
-            PlanState.DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION: duplicated_node_complexity_distribution,
-        }
+            # calculate plan height and num_selects_with_complexity_merged
+            height = 0
+            num_selects_with_complexity_merged = 0
+            current_level = [self]
+            while len(current_level) > 0:
+                next_level = []
+                for node in current_level:
+                    next_level.extend(node.children_plan_nodes)
+                    if (
+                        isinstance(node, SelectStatement)
+                        and node._merge_projection_complexity_with_subquery
+                    ):
+                        num_selects_with_complexity_merged += 1
+                height += 1
+                current_level = next_level
+            # calculate the repeated node status
+            (
+                cte_nodes,
+                duplicated_node_complexity_distribution,
+            ) = find_duplicate_subtrees(self, propagate_complexity_hist=True)
+
+            self._plan_state = {
+                PlanState.PLAN_HEIGHT: height,
+                PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED: num_selects_with_complexity_merged,
+                PlanState.NUM_CTE_NODES: len(cte_nodes),
+                PlanState.DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION: duplicated_node_complexity_distribution,
+            }
+            return self._plan_state
 
     @property
     def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
@@ -420,16 +428,17 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
 
     @property
     def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
-        if self._cumulative_node_complexity is None:
-            # if source plan is available, the source plan complexity
-            # is the snowflake plan complexity.
-            if self.source_plan:
-                self._cumulative_node_complexity = (
-                    self.source_plan.cumulative_node_complexity
-                )
-            else:
-                self._cumulative_node_complexity = {}
-        return self._cumulative_node_complexity
+        with self.session._plan_lock:
+            if self._cumulative_node_complexity is None:
+                # if source plan is available, the source plan complexity
+                # is the snowflake plan complexity.
+                if self.source_plan:
+                    self._cumulative_node_complexity = (
+                        self.source_plan.cumulative_node_complexity
+                    )
+                else:
+                    self._cumulative_node_complexity = {}
+            return self._cumulative_node_complexity
 
     @cumulative_node_complexity.setter
     def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]):
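A plausible reason cached_property no longer fits, beyond unifying everything under one session lock: in CPython 3.12 the undocumented lock inside functools.cached_property was removed, so concurrent first accesses may each run the getter, while before 3.12 the lock lived on the descriptor and serialized first access across all instances of the class. A small, timing-dependent demo of the post-3.12 behavior (Plan and plan_state here are hypothetical, not the Snowpark classes):

    import time
    from concurrent.futures import ThreadPoolExecutor
    from functools import cached_property

    class Plan:
        def __init__(self) -> None:
            self.compute_count = 0

        @cached_property
        def plan_state(self) -> dict:
            self.compute_count += 1  # the regression test expects exactly one computation
            time.sleep(0.01)         # widen the race window for the demo
            return {"plan_height": 13}

    plan = Plan()
    with ThreadPoolExecutor(max_workers=8) as pool:
        list(pool.map(lambda _: plan.plan_state, range(8)))

    # Prints 1 on Python <= 3.11 (cached_property held a lock); can print > 1
    # on 3.12+, which is the duplicated work the explicit _plan_lock avoids.
    print(plan.compute_count)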
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/session.py
@@ -607,6 +607,10 @@ def __init__(
         # query can be slow and prevent other threads from moving on waiting for _lock.
         self._package_lock = create_rlock(self._conn._thread_safe_session_enabled)
 
+        # this lock is used to protect race-conditions when evaluating critical lazy properties
+        # of SnowflakePlan or Selectable objects
+        self._plan_lock = create_rlock(self._conn._thread_safe_session_enabled)
+
         self._custom_package_usage_config: Dict = {}
         self._conf = self.RuntimeConfig(self, options or {})
         self._runtime_version_from_requirement: str = None
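create_rlock is the same conditional lock factory already used for _package_lock just above. Reentrancy matters here: computing a node's cumulative_node_complexity holds _plan_lock while iterating its children, whose getters re-acquire the same lock on the same thread, and a plain non-reentrant mutex would deadlock. A sketch of what such a factory plausibly does; the real create_rlock lives in Snowpark's internals and may differ:

    import threading
    from contextlib import nullcontext

    def create_rlock_sketch(thread_safe_enabled: bool):
        # Hand out a real reentrant lock when the thread-safe session is
        # enabled; otherwise a no-op context manager, so single-threaded
        # sessions pay no locking overhead.
        return threading.RLock() if thread_safe_enabled else nullcontext()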
53 changes: 52 additions & 1 deletion tests/integ/test_multithreading.py
@@ -15,6 +15,11 @@
 
 import pytest
 
+from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
+    PlanNodeCategory,
+    PlanState,
+)
+from snowflake.snowpark._internal.compiler.cte_utils import find_duplicate_subtrees
 from snowflake.snowpark.session import (
     _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION,
     Session,
@@ -678,7 +683,6 @@ def change_config_value(session_):
     reason="large query breakdown is not supported in local testing mode",
     run=False,
 )
-@pytest.mark.skip("SNOW-1787235: Investigate and fix flaky test")
 def test_large_query_breakdown_with_cte(threadsafe_session):
     bounds = (300, 600) if threadsafe_session.sql_simplifier_enabled else (60, 90)
     try:
@@ -911,3 +915,50 @@ def run_query(session_, thread_id):
     # otherwise, we will use the same cursor created by the main thread
     # thus creating 0 new cursors.
     assert mock_telemetry.call_count == (num_workers if is_enabled else 0)
+
+
+@pytest.mark.xfail(
+    "config.getoption('local_testing_mode', default=False)",
+    reason="local testing does not execute sql queries",
+    run=False,
+)
+@patch("snowflake.snowpark._internal.analyzer.snowflake_plan.find_duplicate_subtrees")
+def test_critical_lazy_evaluation_for_plan(
+    mock_find_duplicate_subtrees, threadsafe_session
+):
+    mock_find_duplicate_subtrees.side_effect = find_duplicate_subtrees
+
+    df = threadsafe_session.sql("select 1 as a, 2 as b").filter(col("a") == 1)
+    for i in range(10):
+        df = df.with_column("a", col("a") + i + col("a"))
+    df = df.union_all(df)
+
+    def call_critical_lazy_methods(df_):
+        assert df_._plan.cumulative_node_complexity == {
+            PlanNodeCategory.FILTER: 2,
+            PlanNodeCategory.LITERAL: 22,
+            PlanNodeCategory.COLUMN: 64,
+            PlanNodeCategory.LOW_IMPACT: 42,
+            PlanNodeCategory.SET_OPERATION: 1,
+        }
+        assert df_._plan.plan_state == {
+            PlanState.PLAN_HEIGHT: 13,
+            PlanState.NUM_CTE_NODES: 1,
+            PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED: 0,
+            PlanState.DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION: [2, 0, 0, 0, 0, 0, 0],
+        }
+        assert (
+            df_._select_statement.encoded_node_id_with_query
+            == "b04d566533_SelectStatement"
+        )
+
+    with ThreadPoolExecutor(max_workers=5) as executor:
+        futures = [executor.submit(call_critical_lazy_methods, df) for _ in range(10)]
+
+        for future in as_completed(futures):
+            future.result()
+
+    # SnowflakePlan.plan_state calls find_duplicate_subtrees. This should be
+    # called only once and the cached result should be used for the rest of
+    # the calls.
+    mock_find_duplicate_subtrees.assert_called_once()
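One detail of the new test worth calling out: find_duplicate_subtrees is patched with a mock whose side_effect is the real function, so plan analysis still behaves normally while the call count becomes observable. A tiny self-contained sketch of that spy pattern; expensive is a hypothetical stand-in:

    from unittest.mock import MagicMock

    def expensive(x: int) -> int:
        return x * 2  # stands in for find_duplicate_subtrees

    # side_effect delegates every call to the real function, so results are
    # unchanged while the mock records the calls.
    spy = MagicMock(side_effect=expensive)
    assert spy(3) == 6
    spy.assert_called_once()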
1 change: 1 addition & 0 deletions tests/unit/conftest.py
@@ -79,5 +79,6 @@ def mock_session(mock_analyzer) -> Session:
     fake_session = mock.create_autospec(Session)
     fake_session._cte_optimization_enabled = False
     fake_session._analyzer = mock_analyzer
+    fake_session._plan_lock = mock.MagicMock()
     mock_analyzer.session = fake_session
     return fake_session
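The fake session uses mock.MagicMock() rather than mock.Mock() for _plan_lock because MagicMock pre-configures the magic methods, including __enter__ and __exit__, so it can stand in for the RLock inside a with statement:

    from unittest.mock import MagicMock

    lock = MagicMock()
    with lock:  # works: MagicMock supplies __enter__/__exit__
        pass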
