Use CTEs for metrics in multi-metric cases #1526

Open · wants to merge 8 commits into base: main
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20241112-215817.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Use CTEs instead of sub-queries in generated SQL.
time: 2024-11-12T21:58:17.127471-08:00
custom:
Author: plypaul
Issue: "1040"
49 changes: 49 additions & 0 deletions metricflow/dataflow/dataflow_plan_analyzer.py
@@ -1,12 +1,15 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, FrozenSet, Mapping, Sequence, Set

from metricflow_semantics.collection_helpers.merger import Mergeable
from typing_extensions import override

from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode


class DataflowPlanAnalyzer:
@@ -36,6 +39,12 @@ def find_common_branches(dataflow_plan: DataflowPlan) -> Sequence[DataflowPlanNode]:

return tuple(sorted(dataflow_plan.sink_node.accept(common_branches_visitor)))

@staticmethod
def group_nodes_by_type(dataflow_plan: DataflowPlan) -> DataflowPlanNodeSet:
"""Groups dataflow plan nodes by type."""
grouping_visitor = _GroupNodesByTypeVisitor()
return dataflow_plan.sink_node.accept(grouping_visitor)

Review comment (Contributor):

This is just a meta-question that's not specific to this PR - do we have any concerns about the performance of traversing the dataflow plan DAG so many times with all these new visitors? If so, I wonder if we could combine the functionality of some of these visitors that collect metadata about the DAG for optimization purposes into one visitor so that we only need to traverse the DAG once.

Reply (Contributor, Author):

This should be fine since the number of nodes in a dataflow plan is generally small, e.g. < 100 nodes.
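
As a hypothetical aside, the reviewer's suggestion could take the form of a single visitor that gathers both the per-node counts and the per-type grouping in one traversal. The `_CombinedAnalysisVisitor` below is only a sketch of that idea, not part of this PR; it reuses the visitor API that does appear in the diff (`_default_handler`, `parent_nodes`, `accept`):

from __future__ import annotations

from collections import defaultdict
from typing import Dict, Set

from typing_extensions import override

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode


class _CombinedAnalysisVisitor(DataflowPlanNodeVisitorWithDefaultHandler[None]):
    """Sketch: collect node counts and type groupings in a single DAG traversal."""

    def __init__(self) -> None:
        self.node_counts: Dict[DataflowPlanNode, int] = defaultdict(int)
        self.compute_metric_nodes: Set[ComputeMetricsNode] = set()

    @override
    def _default_handler(self, node: DataflowPlanNode) -> None:
        # Count this visit, then recurse into the parent branches.
        self.node_counts[node] += 1
        for parent_node in node.parent_nodes:
            parent_node.accept(self)

    @override
    def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> None:
        # Record the node in its type bucket, then fall through to counting.
        self.compute_metric_nodes.add(node)
        self._default_handler(node)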



class _CountDataflowNodeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[None]):
"""Helper visitor to build a dict from a node in the plan to the number of times it appears in the plan."""
@@ -77,3 +86,43 @@ def _default_handler(self, node: DataflowPlanNode) -> FrozenSet[DataflowPlanNode]:
common_branch_leaf_nodes.update(parent_node.accept(self))

return frozenset(common_branch_leaf_nodes)


@dataclass(frozen=True)
class DataflowPlanNodeSet(Mergeable):
"""Contains a set of dataflow plan nodes with fields for different types.

`ComputeMetricsNode` is the only node of interest for current use cases, but fields for other types can be added
later.
"""

compute_metric_nodes: FrozenSet[ComputeMetricsNode]

@override
def merge(self, other: DataflowPlanNodeSet) -> DataflowPlanNodeSet:
return DataflowPlanNodeSet(
compute_metric_nodes=self.compute_metric_nodes.union(other.compute_metric_nodes),
)

@classmethod
@override
def empty_instance(cls) -> DataflowPlanNodeSet:
return DataflowPlanNodeSet(
compute_metric_nodes=frozenset(),
)


class _GroupNodesByTypeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[DataflowPlanNodeSet]):
"""Groups dataflow nodes by type."""

@override
def _default_handler(self, node: DataflowPlanNode) -> DataflowPlanNodeSet:
node_sets = []
for parent_node in node.parent_nodes:
node_sets.append(parent_node.accept(self))

return DataflowPlanNodeSet.merge_iterable(node_sets)

@override
def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> DataflowPlanNodeSet:
return self._default_handler(node).merge(DataflowPlanNodeSet(frozenset({node})))
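
For orientation, here is a minimal sketch of how the two analyzer entry points are consumed together; the `describe_plan` helper is illustrative and not part of this PR:

from metricflow.dataflow.dataflow_plan import DataflowPlan
from metricflow.dataflow.dataflow_plan_analyzer import DataflowPlanAnalyzer


def describe_plan(dataflow_plan: DataflowPlan) -> None:
    """Illustrative helper that runs both analyzer entry points over one plan."""
    common_branches = DataflowPlanAnalyzer.find_common_branches(dataflow_plan)
    node_set = DataflowPlanAnalyzer.group_nodes_by_type(dataflow_plan)

    # A multi-metric query yields one ComputeMetricsNode per metric; the SQL
    # converter uses this to decide whether to render those nodes as CTEs.
    print(f"common branches: {len(common_branches)}")
    print(f"compute-metrics nodes: {len(node_set.compute_metric_nodes)}")
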
32 changes: 16 additions & 16 deletions metricflow/engine/metricflow_engine.py
@@ -100,21 +100,21 @@ class MetricFlowQueryRequest:
"""

request_id: MetricFlowRequestId
saved_query_name: Optional[str] = None
metric_names: Optional[Sequence[str]] = None
metrics: Optional[Sequence[MetricQueryParameter]] = None
group_by_names: Optional[Sequence[str]] = None
group_by: Optional[Tuple[GroupByParameter, ...]] = None
limit: Optional[int] = None
time_constraint_start: Optional[datetime.datetime] = None
time_constraint_end: Optional[datetime.datetime] = None
where_constraints: Optional[Sequence[str]] = None
order_by_names: Optional[Sequence[str]] = None
order_by: Optional[Sequence[OrderByQueryParameter]] = None
min_max_only: bool = False
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4
dataflow_plan_optimizations: FrozenSet[DataflowPlanOptimization] = DataflowPlanOptimization.enabled_optimizations()
query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC
saved_query_name: Optional[str]
metric_names: Optional[Sequence[str]]
metrics: Optional[Sequence[MetricQueryParameter]]
group_by_names: Optional[Sequence[str]]
group_by: Optional[Tuple[GroupByParameter, ...]]
limit: Optional[int]
time_constraint_start: Optional[datetime.datetime]
time_constraint_end: Optional[datetime.datetime]
where_constraints: Optional[Sequence[str]]
order_by_names: Optional[Sequence[str]]
order_by: Optional[Sequence[OrderByQueryParameter]]
min_max_only: bool
sql_optimization_level: SqlQueryOptimizationLevel
dataflow_plan_optimizations: FrozenSet[DataflowPlanOptimization]
query_type: MetricFlowQueryType

@staticmethod
def create_with_random_request_id( # noqa: D102
@@ -129,7 +129,7 @@ def create_with_random_request_id(  # noqa: D102
where_constraints: Optional[Sequence[str]] = None,
order_by_names: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[OrderByQueryParameter]] = None,
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4,
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.default_level(),
dataflow_plan_optimizations: FrozenSet[
DataflowPlanOptimization
] = DataflowPlanOptimization.enabled_optimizations(),
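
Since the dataclass fields above no longer declare defaults, direct construction requires every field; the factory remains the intended entry point. A minimal sketch, assuming `create_with_random_request_id` accepts the same metric and group-by parameters as the dataclass (the metric and dimension names are invented):

from metricflow.engine.metricflow_engine import MetricFlowQueryRequest

# Hypothetical query: one metric grouped by one dimension. All other
# parameters fall back to the factory defaults, including
# sql_optimization_level=SqlQueryOptimizationLevel.default_level().
request = MetricFlowQueryRequest.create_with_random_request_id(
    metric_names=["bookings"],
    group_by_names=["metric_time__day"],
)
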
7 changes: 5 additions & 2 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -182,7 +182,7 @@ def convert_to_sql_query_plan(
self,
sql_engine_type: SqlEngine,
dataflow_plan_node: DataflowPlanNode,
optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4,
optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.default_level(),
sql_query_plan_id: Optional[DagId] = None,
) -> ConvertToSqlPlanResult:
"""Create an SQL query plan that represents the computation up to the given dataflow plan node."""
@@ -273,7 +273,10 @@ def _get_nodes_to_convert_to_cte(
"""Handles logic for selecting which nodes to convert to CTEs based on the request."""
dataflow_plan = dataflow_plan_node.as_plan()
nodes_to_convert_to_cte: Set[DataflowPlanNode] = set(DataflowPlanAnalyzer.find_common_branches(dataflow_plan))
# Additional nodes will be added later.

compute_metric_nodes = DataflowPlanAnalyzer.group_nodes_by_type(dataflow_plan).compute_metric_nodes
if len(compute_metric_nodes) > 1:
nodes_to_convert_to_cte.update(compute_metric_nodes)

return frozenset(nodes_to_convert_to_cte)

2 changes: 1 addition & 1 deletion metricflow/sql/optimizer/optimization_levels.py
@@ -25,7 +25,7 @@ class SqlQueryOptimizationLevel(Enum):

@staticmethod
def default_level() -> SqlQueryOptimizationLevel: # noqa: D102
return SqlQueryOptimizationLevel.O4
return SqlQueryOptimizationLevel.O5


@dataclass(frozen=True)
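
With the default raised from O4 to O5 and call sites switched to `default_level()`, callers can omit the level entirely. A sketch of the intended pattern; `converter`, `sql_client`, and `dataflow_plan` are assumed to be set up as elsewhere in this PR:

from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel

# default_level() is the single source of truth (O5 after this PR).
assert SqlQueryOptimizationLevel.default_level() == SqlQueryOptimizationLevel.O5

# Call sites now omit optimization_level and inherit the default:
conversion_result = converter.convert_to_sql_query_plan(
    sql_engine_type=sql_client.sql_engine_type,
    dataflow_plan_node=dataflow_plan.sink_node,
)
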
3 changes: 0 additions & 3 deletions tests_metricflow/examples/test_node_sql.py
@@ -20,7 +20,6 @@
from metricflow.dataset.convert_semantic_model import SemanticModelToDataSetConverter
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer

logger = logging.getLogger(__name__)
@@ -57,7 +56,6 @@ def test_view_sql_generated_at_a_node(
conversion_result = to_sql_plan_converter.convert_to_sql_query_plan(
sql_engine_type=sql_client.sql_engine_type,
dataflow_plan_node=read_source_node,
optimization_level=SqlQueryOptimizationLevel.O4,
)
sql_plan_at_read_node = conversion_result.sql_plan
sql_at_read_node = sql_renderer.render_sql_query_plan(sql_plan_at_read_node).sql
@@ -86,7 +84,6 @@ def test_view_sql_generated_at_a_node(
conversion_result = to_sql_plan_converter.convert_to_sql_query_plan(
sql_engine_type=sql_client.sql_engine_type,
dataflow_plan_node=filter_elements_node,
optimization_level=SqlQueryOptimizationLevel.O4,
)
sql_plan_at_filter_elements_node = conversion_result.sql_plan
sql_at_filter_elements_node = sql_renderer.render_sql_query_plan(sql_plan_at_filter_elements_node).sql
@@ -31,7 +31,7 @@ def make_execution_plan_converter(  # noqa: D103
),
sql_plan_renderer=DefaultSqlQueryPlanRenderer(),
sql_client=sql_client,
sql_optimization_level=SqlQueryOptimizationLevel.O4,
sql_optimization_level=SqlQueryOptimizationLevel.default_level(),
)


@@ -98,7 +98,6 @@ def convert_and_check(
sql_engine_type=sql_client.sql_engine_type,
sql_query_plan_id=DagId.from_str("plan0_optimized"),
dataflow_plan_node=node,
optimization_level=SqlQueryOptimizationLevel.O4,
)
sql_query_plan = conversion_result.sql_plan
display_graph_if_requested(
1 change: 0 additions & 1 deletion tests_metricflow/query_rendering/compare_rendered_query.py
@@ -63,7 +63,6 @@ def render_and_check(
conversion_result = dataflow_to_sql_converter.convert_to_sql_query_plan(
sql_engine_type=sql_client.sql_engine_type,
dataflow_plan_node=optimized_plan.sink_node,
optimization_level=SqlQueryOptimizationLevel.O4,
sql_query_plan_id=DagId.from_str("plan0_optimized"),
)
sql_query_plan = conversion_result.sql_plan
@@ -6,16 +6,25 @@ sql_engine: BigQuery
---
-- Combine Aggregated Outputs
-- Compute Metrics via Expressions
SELECT
COALESCE(MAX(subq_28.buys), 0) AS visit_buy_conversions
FROM (
WITH sma_28019_cte AS (
-- Read Elements From Semantic Model 'visits_source'
-- Metric Time Dimension 'ds'
SELECT
DATETIME_TRUNC(ds, day) AS metric_time__day
, user_id AS user
, 1 AS visits
FROM ***************************.fct_visits visits_source_src_28000
)

SELECT
COALESCE(MAX(subq_27.buys), 0) AS visit_buy_conversions
FROM (
-- Read From CTE For node_id=sma_28019
-- Pass Only Elements: ['visits',]
-- Aggregate Measures
SELECT
SUM(1) AS visits
FROM ***************************.fct_visits visits_source_src_28000
SUM(visits) AS visits
FROM sma_28019_cte sma_28019_cte
) subq_18
CROSS JOIN (
-- Find conversions for user within the range of 7 day
@@ -26,42 +35,33 @@ CROSS JOIN (
FROM (
-- Dedupe the fanout with mf_internal_uuid in the conversion data set
SELECT DISTINCT
FIRST_VALUE(subq_21.visits) OVER (
FIRST_VALUE(sma_28019_cte.visits) OVER (
PARTITION BY
subq_24.user
, subq_24.metric_time__day
, subq_24.mf_internal_uuid
ORDER BY subq_21.metric_time__day DESC
subq_23.user
, subq_23.metric_time__day
, subq_23.mf_internal_uuid
ORDER BY sma_28019_cte.metric_time__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS visits
, FIRST_VALUE(subq_21.metric_time__day) OVER (
, FIRST_VALUE(sma_28019_cte.metric_time__day) OVER (
PARTITION BY
subq_24.user
, subq_24.metric_time__day
, subq_24.mf_internal_uuid
ORDER BY subq_21.metric_time__day DESC
subq_23.user
, subq_23.metric_time__day
, subq_23.mf_internal_uuid
ORDER BY sma_28019_cte.metric_time__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS metric_time__day
, FIRST_VALUE(subq_21.user) OVER (
, FIRST_VALUE(sma_28019_cte.user) OVER (
PARTITION BY
subq_24.user
, subq_24.metric_time__day
, subq_24.mf_internal_uuid
ORDER BY subq_21.metric_time__day DESC
subq_23.user
, subq_23.metric_time__day
, subq_23.mf_internal_uuid
ORDER BY sma_28019_cte.metric_time__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS user
, subq_24.mf_internal_uuid AS mf_internal_uuid
, subq_24.buys AS buys
FROM (
-- Read Elements From Semantic Model 'visits_source'
-- Metric Time Dimension 'ds'
-- Pass Only Elements: ['visits', 'metric_time__day', 'user']
SELECT
DATETIME_TRUNC(ds, day) AS metric_time__day
, user_id AS user
, 1 AS visits
FROM ***************************.fct_visits visits_source_src_28000
) subq_21
, subq_23.mf_internal_uuid AS mf_internal_uuid
, subq_23.buys AS buys
FROM sma_28019_cte sma_28019_cte
INNER JOIN (
-- Read Elements From Semantic Model 'buys_source'
-- Metric Time Dimension 'ds'
Expand All @@ -72,16 +72,16 @@ CROSS JOIN (
, 1 AS buys
, GENERATE_UUID() AS mf_internal_uuid
FROM ***************************.fct_buys buys_source_src_28000
) subq_24
) subq_23
ON
(
subq_21.user = subq_24.user
sma_28019_cte.user = subq_23.user
) AND (
(
subq_21.metric_time__day <= subq_24.metric_time__day
sma_28019_cte.metric_time__day <= subq_23.metric_time__day
) AND (
subq_21.metric_time__day > DATE_SUB(CAST(subq_24.metric_time__day AS DATETIME), INTERVAL 7 day)
sma_28019_cte.metric_time__day > DATE_SUB(CAST(subq_23.metric_time__day AS DATETIME), INTERVAL 7 day)
)
)
) subq_25
) subq_28
) subq_24
) subq_27