Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Build DataflowPlan for custom offset window #1584

Open
wants to merge 4 commits into
base: court/custom-offset5
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20241218-133743.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support for custom offset windows.
time: 2024-12-18T13:37:43.23915-08:00
custom:
Author: courtneyholcomb
Issue: "1584"
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim
window_function=self.window_function,
)

def with_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionSpec: # noqa: D102
def with_window_function(self, window_function: Optional[SqlWindowFunction]) -> TimeDimensionSpec: # noqa: D102
return TimeDimensionSpec(
element_name=self.element_name,
entity_links=self.entity_links,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,7 @@ def test_min_queryable_time_granularity_for_different_agg_time_grains( # noqa:
def test_custom_offset_window_for_metric(
simple_semantic_manifest_lookup: SemanticManifestLookup,
) -> None:
"""Test offset window with custom grain supplied.

TODO: As of now, the functionality of an offset window with a custom grain is not supported in MF.
This test is added to show that at least the parsing is successful using a custom grain offset window.
Once support for that is added in MF + relevant tests, this test can be removed.
"""
"""Test offset window with custom grain supplied."""
metric = simple_semantic_manifest_lookup.metric_lookup.get_metric(MetricReference("bookings_offset_martian_day"))

assert len(metric.input_metrics) == 1
Expand Down
126 changes: 97 additions & 29 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.specs.where_filter.where_filter_spec_set import WhereFilterSpecSet
from metricflow_semantics.specs.where_filter.where_filter_transform import WhereSpecFactory
from metricflow_semantics.sql.sql_exprs import SqlWindowFunction
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.dateutil_adjuster import DateutilTimePeriodAdjuster
Expand Down Expand Up @@ -84,6 +85,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -92,6 +94,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -658,13 +661,18 @@ def _build_derived_metric_output_node(
)
if metric_spec.has_time_offset and queried_agg_time_dimension_specs:
# TODO: move this to a helper method
time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs)
time_spine_node = self._build_time_spine_node(
queried_time_spine_specs=queried_agg_time_dimension_specs,
offset_window=metric_spec.offset_window,
)
output_node = JoinToTimeSpineNode.create(
metric_source_node=output_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0],
offset_window=metric_spec.offset_window,
offset_window=(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there cases where you can use a JoinToTimeSpineNode with a custom offset window? If not, an assertion in the class would help.

metric_spec.offset_window if not self._offset_window_is_custom(metric_spec.offset_window) else None
),
offset_to_grain=metric_spec.offset_to_grain,
join_type=SqlJoinType.INNER,
)
Expand Down Expand Up @@ -1651,13 +1659,20 @@ def _build_aggregated_measure_from_measure_source_node(
required_time_spine_specs = base_queried_agg_time_dimension_specs
if join_on_time_dimension_spec not in required_time_spine_specs:
required_time_spine_specs = (join_on_time_dimension_spec,) + required_time_spine_specs
time_spine_node = self._build_time_spine_node(required_time_spine_specs)
time_spine_node = self._build_time_spine_node(
queried_time_spine_specs=required_time_spine_specs,
offset_window=before_aggregation_time_spine_join_description.offset_window,
)
unaggregated_measure_node = JoinToTimeSpineNode.create(
metric_source_node=unaggregated_measure_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=base_queried_agg_time_dimension_specs,
join_on_time_dimension_spec=join_on_time_dimension_spec,
offset_window=before_aggregation_time_spine_join_description.offset_window,
offset_window=(
before_aggregation_time_spine_join_description.offset_window
if not self._offset_window_is_custom(before_aggregation_time_spine_join_description.offset_window)
else None
),
offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain,
join_type=before_aggregation_time_spine_join_description.join_type,
)
Expand Down Expand Up @@ -1864,6 +1879,7 @@ def _build_time_spine_node(
queried_time_spine_specs: Sequence[TimeDimensionSpec],
where_filter_specs: Sequence[WhereFilterSpec] = (),
time_range_constraint: Optional[TimeRangeConstraint] = None,
offset_window: Optional[MetricTimeWindow] = None,
) -> DataflowPlanNode:
"""Return the time spine node needed to satisfy the specs."""
required_time_spine_spec_set = self.__get_required_linkable_specs(
Expand All @@ -1872,39 +1888,85 @@ def _build_time_spine_node(
)
required_time_spine_specs = required_time_spine_spec_set.time_dimension_specs

# TODO: support multiple time spines here. Build node on the one with the smallest base grain.
# Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
read_node = self._choose_time_spine_read_node(time_spine_source)
time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node)

# Change the column aliases to match the specs that were requested in the query.
time_spine_node = AliasSpecsNode.create(
parent_node=read_node,
change_specs=tuple(
SpecToAlias(
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=required_spec.time_granularity.name, date_part=required_spec.date_part
).spec,
output_spec=required_spec,
)
for required_spec in required_time_spine_specs
),
)

# If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in {
spec.time_granularity for spec in queried_time_spine_specs
}
should_dedupe = False
filter_to_specs = tuple(queried_time_spine_specs)
if offset_window and self._offset_window_is_custom(offset_window):
time_spine_node = self._build_custom_offset_time_spine_node(
offset_window=offset_window, required_time_spine_specs=required_time_spine_specs
)
filter_to_specs = self._node_data_set_resolver.get_output_data_set(
time_spine_node
).instance_set.spec_set.time_dimension_specs
else:
# For simpler time spine queries, choose the appropriate time spine node and apply requested aliases.
time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
# TODO: support multiple time spines here. Build node on the one with the smallest base grain.
# Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
read_node = self._choose_time_spine_read_node(time_spine_source)
time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node)
# Change the column aliases to match the specs that were requested in the query.
time_spine_node = AliasSpecsNode.create(
parent_node=read_node,
change_specs=tuple(
SpecToAlias(
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=required_spec.time_granularity.name, date_part=required_spec.date_part
).spec,
output_spec=required_spec,
)
for required_spec in required_time_spine_specs
),
)
# If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in {
spec.time_granularity for spec in queried_time_spine_specs
}

return self._build_pre_aggregation_plan(
source_node=time_spine_node,
filter_to_specs=InstanceSpecSet(time_dimension_specs=tuple(queried_time_spine_specs)),
filter_to_specs=InstanceSpecSet(time_dimension_specs=filter_to_specs),
time_range_constraint=time_range_constraint,
where_filter_specs=where_filter_specs,
distinct=should_dedupe,
)

def _build_custom_offset_time_spine_node(
self, offset_window: MetricTimeWindow, required_time_spine_specs: Tuple[TimeDimensionSpec, ...]
) -> DataflowPlanNode:
# Build time spine node that offsets agg time dimensions by a custom grain.
custom_grain = self._semantic_model_lookup._custom_granularities[offset_window.granularity]
time_spine_source = self._choose_time_spine_source((DataSet.metric_time_dimension_spec(custom_grain),))
time_spine_read_node = self._choose_time_spine_read_node(time_spine_source)
if {spec.time_granularity for spec in required_time_spine_specs} == {custom_grain}:
# TODO: If querying with only the same grain as is used in the offset_window, can use a simpler plan.
pass
# For custom offset windows queried with other granularities, first, build CustomGranularityBoundsNode.
# This will be used twice in the output node, and ideally will be turned into a CTE.
bounds_node = CustomGranularityBoundsNode.create(
parent_node=time_spine_read_node, custom_granularity_name=custom_grain.name
)
# Build a FilterElementsNode from bounds node to get required unique rows.
bounds_data_set = self._node_data_set_resolver.get_output_data_set(bounds_node)
bounds_specs = tuple(
bounds_data_set.instance_from_window_function(window_func).spec
for window_func in (SqlWindowFunction.FIRST_VALUE, SqlWindowFunction.LAST_VALUE)
)
custom_grain_spec = bounds_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=custom_grain.name, date_part=None
).spec
filter_elements_node = FilterElementsNode.create(
parent_node=bounds_node,
include_specs=InstanceSpecSet(time_dimension_specs=(custom_grain_spec,) + bounds_specs),
distinct=True,
)
# Pass both the CustomGranularityBoundsNode and the FilterElementsNode into the OffsetByCustomGranularityNode.
return OffsetByCustomGranularityNode.create(
custom_granularity_bounds_node=bounds_node,
filter_elements_node=filter_elements_node,
offset_window=offset_window,
required_time_spine_specs=required_time_spine_specs,
)

def _sort_by_base_granularity(self, time_dimension_specs: Sequence[TimeDimensionSpec]) -> List[TimeDimensionSpec]:
"""Sort the time dimensions by their base granularity.

Expand Down Expand Up @@ -1935,3 +1997,9 @@ def _determine_time_spine_join_spec(
time_granularity=join_spec_grain, date_part=None
)
return join_on_time_dimension_spec

def _offset_window_is_custom(self, offset_window: Optional[MetricTimeWindow]) -> bool:
return (
offset_window is not None
and offset_window.granularity in self._semantic_model_lookup.custom_granularity_names
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when the granularity is not in custom_granularity_names? Is this handled through a validation?

)
58 changes: 42 additions & 16 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,47 +1447,73 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
time_spine_alias = self._next_unique_table_alias()

required_agg_time_dimension_specs = tuple(node.requested_agg_time_dimension_specs)
if node.join_on_time_dimension_spec not in node.requested_agg_time_dimension_specs:
join_spec_was_requested = node.join_on_time_dimension_spec in node.requested_agg_time_dimension_specs
if not join_spec_was_requested:
required_agg_time_dimension_specs += (node.join_on_time_dimension_spec,)

# Build join expression.
join_column_name = self._column_association_resolver.resolve_spec(node.join_on_time_dimension_spec).column_name
parent_join_column_name = self._column_association_resolver.resolve_spec(
node.join_on_time_dimension_spec
).column_name
time_spine_jon_column_name = time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=node.join_on_time_dimension_spec.time_granularity.name, date_part=None
).associated_column.column_name
join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description(
node=node,
time_spine_alias=time_spine_alias,
agg_time_dimension_column_name=join_column_name,
time_spine_column_name=time_spine_jon_column_name,
parent_column_name=parent_join_column_name,
parent_sql_select_node=parent_data_set.checked_sql_select_node,
parent_alias=parent_alias,
)

# Build combined instance set.
# Build new instances and columns.
time_spine_required_spec_set = InstanceSpecSet(time_dimension_specs=required_agg_time_dimension_specs)
parent_instance_set = parent_data_set.instance_set.transform(
output_parent_instance_set = parent_data_set.instance_set.transform(
FilterElements(exclude_specs=time_spine_required_spec_set)
)
time_spine_instance_set = time_spine_data_set.instance_set.transform(
FilterElements(include_specs=time_spine_required_spec_set)
)
output_instance_set = InstanceSet.merge([parent_instance_set, time_spine_instance_set])
output_time_spine_instances: Tuple[TimeDimensionInstance, ...] = ()
output_time_spine_columns: Tuple[SqlSelectColumn, ...] = ()
for old_instance in time_spine_data_set.instance_set.time_dimension_instances:
new_spec = old_instance.spec.with_window_function(None)
if new_spec not in required_agg_time_dimension_specs:
continue
if old_instance.spec.window_function:
new_instance = old_instance.with_new_spec(
new_spec=new_spec, column_association_resolver=self._column_association_resolver
)
column = SqlSelectColumn(
expr=SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=old_instance.associated_column.column_name
),
column_alias=new_instance.associated_column.column_name,
)
else:
new_instance = old_instance
column = SqlSelectColumn.from_table_and_column_names(
table_alias=time_spine_alias, column_name=old_instance.associated_column.column_name
)
output_time_spine_instances += (new_instance,)
output_time_spine_columns += (column,)

# Build new simple select columns.
select_columns = create_simple_select_columns_for_instance_sets(
output_instance_set = InstanceSet.merge(
[output_parent_instance_set, InstanceSet(time_dimension_instances=output_time_spine_instances)]
)
select_columns = output_time_spine_columns + create_simple_select_columns_for_instance_sets(
self._column_association_resolver,
OrderedDict({parent_alias: parent_instance_set, time_spine_alias: time_spine_instance_set}),
OrderedDict({parent_alias: output_parent_instance_set}),
)

# If offset_to_grain is used, will need to filter down to rows that match selected granularities.
# Does not apply if one of the granularities selected matches the time spine column granularity.
where_filter: Optional[SqlExpressionNode] = None
need_where_filter = (
node.offset_to_grain and node.join_on_time_dimension_spec not in node.requested_agg_time_dimension_specs
)
need_where_filter = node.offset_to_grain and not join_spec_was_requested

# Filter down to one row per granularity period requested in the group by. Any other granularities
# included here will be filtered out before aggregation and so should not be included in where filter.
if need_where_filter:
join_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=join_column_name
table_alias=time_spine_alias, column_name=parent_join_column_name
)
for requested_spec in node.requested_agg_time_dimension_specs:
column_name = self._column_association_resolver.resolve_spec(requested_spec).column_name
Expand Down
Loading
Loading