diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 485df8f8e..89719859f 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -1692,24 +1692,32 @@ impl PhysicalPlanner { .and_then(|inner| inner.lower_frame_bound_struct.as_ref()) { Some(l) => match l { - LowerFrameBoundStruct::UnboundedPreceding(_) => { - match units { - WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), - WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Preceding(ScalarValue::Int64(None)), + LowerFrameBoundStruct::UnboundedPreceding(_) => match units { + WindowFrameUnits::Rows => { + WindowFrameBound::Preceding(ScalarValue::UInt64(None)) } - } + WindowFrameUnits::Range | WindowFrameUnits::Groups => { + WindowFrameBound::Preceding(ScalarValue::Int64(None)) + } + }, LowerFrameBoundStruct::Preceding(offset) => { - let offset_value = offset.offset.abs() as i64; + let offset_value = offset.offset.abs(); match units { - WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset_value as u64))), - WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Preceding(ScalarValue::Int64(Some(offset_value))), + WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64( + Some(offset_value as u64), + )), + WindowFrameUnits::Range | WindowFrameUnits::Groups => { + WindowFrameBound::Preceding(ScalarValue::Int64(Some(offset_value))) + } } } LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow, }, None => match units { WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), - WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Preceding(ScalarValue::Int64(None)), + WindowFrameUnits::Range | WindowFrameUnits::Groups => { + WindowFrameBound::Preceding(ScalarValue::Int64(None)) + } }, }; @@ -1719,23 +1727,29 @@ impl PhysicalPlanner { .and_then(|inner| inner.upper_frame_bound_struct.as_ref()) { Some(u) => match u { - UpperFrameBoundStruct::UnboundedFollowing(_) => { - match units { - WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)), - WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Following(ScalarValue::Int64(None)), + UpperFrameBoundStruct::UnboundedFollowing(_) => match units { + WindowFrameUnits::Rows => { + WindowFrameBound::Following(ScalarValue::UInt64(None)) } - } - UpperFrameBoundStruct::Following(offset) => { - match units { - WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64))), - WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Following(ScalarValue::Int64(Some(offset.offset as i64))), + WindowFrameUnits::Range | WindowFrameUnits::Groups => { + WindowFrameBound::Following(ScalarValue::Int64(None)) } - } + }, + UpperFrameBoundStruct::Following(offset) => match units { + WindowFrameUnits::Rows => { + WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64))) + } + WindowFrameUnits::Range | WindowFrameUnits::Groups => { + WindowFrameBound::Following(ScalarValue::Int64(Some(offset.offset))) + } + }, UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow, }, None => match units { WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)), - WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Following(ScalarValue::Int64(None)), + WindowFrameUnits::Range | WindowFrameUnits::Groups => { + WindowFrameBound::Following(ScalarValue::Int64(None)) + } }, }; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 749bc9130..460c80924 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2787,6 +2787,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim return None } + if (partitionSpec.nonEmpty && orderSpec.nonEmpty && + !validatePartitionAndSortSpecsForWindowFunc(partitionSpec, orderSpec, op)) { + return None + } + val windowExprProto = winExprs.map(windowExprToProto(_, output, op.conf)) val partitionExprs = partitionSpec.map(exprToProto(_, child.output)) @@ -3290,4 +3295,41 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim true } } + + private def validatePartitionAndSortSpecsForWindowFunc( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + op: SparkPlan): Boolean = { + if (partitionSpec.length != orderSpec.length) { + withInfo(op, "Partitioning and sorting specifications do not match") + return false + } else { + val partitionColumnNames = partitionSpec.collect { case a: AttributeReference => + a.name + } + + if (partitionColumnNames.length != partitionSpec.length) { + withInfo(op, "Unsupported partitioning specification") + return false + } + + val orderColumnNames = orderSpec.collect { case s: SortOrder => + s.child match { + case a: AttributeReference => a.name + } + } + + if (orderColumnNames.length != orderSpec.length) { + withInfo(op, "Unsupported SortOrder") + return false + } + + if (partitionColumnNames.toSet != orderColumnNames.toSet) { + withInfo(op, "Partitioning and sorting specifications do not match") + return false + } + + true + } + } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 5a7c88b5a..e1dde458c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -149,6 +149,22 @@ class CometExecSuite extends CometTestBase { } } + test( + "fall back to Spark when the partition spec and order spec are not the same for window function") { + withTempView("test") { + sql(""" + |CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES + | (1, true), (1, false), + |(2, true), (3, false), (4, true) AS test(k, v) + |""".stripMargin) + + val df = sql(""" + SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg + |""".stripMargin) + checkSparkAnswer(df) + } + } + test("Native window operator should be CometUnaryExec") { withTempView("testData") { sql(""" @@ -164,11 +180,11 @@ class CometExecSuite extends CometTestBase { |(3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null) |AS testData(val, val_long, val_double, val_date, val_timestamp, cate) |""".stripMargin) - val df = sql(""" + val df1 = sql(""" |SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW) |FROM testData ORDER BY cate, val |""".stripMargin) - checkSparkAnswer(df) + checkSparkAnswer(df1) } }