Skip to content

Commit

Permalink
fix data type
Browse files Browse the repository at this point in the history
  • Loading branch information
huaxingao committed Sep 18, 2024
1 parent 8cc5772 commit 0d84fd2
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 22 deletions.
54 changes: 34 additions & 20 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1692,24 +1692,32 @@ impl PhysicalPlanner {
.and_then(|inner| inner.lower_frame_bound_struct.as_ref())
{
Some(l) => match l {
LowerFrameBoundStruct::UnboundedPreceding(_) => {
match units {
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Preceding(ScalarValue::Int64(None)),
LowerFrameBoundStruct::UnboundedPreceding(_) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Preceding(ScalarValue::UInt64(None))
}
}
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
WindowFrameBound::Preceding(ScalarValue::Int64(None))
}
},
LowerFrameBoundStruct::Preceding(offset) => {
let offset_value = offset.offset.abs() as i64;
let offset_value = offset.offset.abs();
match units {
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset_value as u64))),
WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Preceding(ScalarValue::Int64(Some(offset_value))),
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(
Some(offset_value as u64),
)),
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
WindowFrameBound::Preceding(ScalarValue::Int64(Some(offset_value)))
}
}
}
LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Preceding(ScalarValue::Int64(None)),
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
WindowFrameBound::Preceding(ScalarValue::Int64(None))
}
},
};

Expand All @@ -1719,23 +1727,29 @@ impl PhysicalPlanner {
.and_then(|inner| inner.upper_frame_bound_struct.as_ref())
{
Some(u) => match u {
UpperFrameBoundStruct::UnboundedFollowing(_) => {
match units {
WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Following(ScalarValue::Int64(None)),
UpperFrameBoundStruct::UnboundedFollowing(_) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Following(ScalarValue::UInt64(None))
}
}
UpperFrameBoundStruct::Following(offset) => {
match units {
WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64))),
WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Following(ScalarValue::Int64(Some(offset.offset as i64))),
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
WindowFrameBound::Following(ScalarValue::Int64(None))
}
}
},
UpperFrameBoundStruct::Following(offset) => match units {
WindowFrameUnits::Rows => {
WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
}
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
WindowFrameBound::Following(ScalarValue::Int64(Some(offset.offset)))
}
},
UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Following(ScalarValue::Int64(None)),
WindowFrameUnits::Range | WindowFrameUnits::Groups => {
WindowFrameBound::Following(ScalarValue::Int64(None))
}
},
};

Expand Down
42 changes: 42 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2787,6 +2787,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
return None
}

if (partitionSpec.nonEmpty && orderSpec.nonEmpty &&
!validatePartitionAndSortSpecsForWindowFunc(partitionSpec, orderSpec, op)) {
return None
}

val windowExprProto = winExprs.map(windowExprToProto(_, output, op.conf))
val partitionExprs = partitionSpec.map(exprToProto(_, child.output))

Expand Down Expand Up @@ -3290,4 +3295,41 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
true
}
}

private def validatePartitionAndSortSpecsForWindowFunc(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
op: SparkPlan): Boolean = {
if (partitionSpec.length != orderSpec.length) {
withInfo(op, "Partitioning and sorting specifications do not match")
return false
} else {
val partitionColumnNames = partitionSpec.collect { case a: AttributeReference =>
a.name
}

if (partitionColumnNames.length != partitionSpec.length) {
withInfo(op, "Unsupported partitioning specification")
return false
}

val orderColumnNames = orderSpec.collect { case s: SortOrder =>
s.child match {
case a: AttributeReference => a.name
}
}

if (orderColumnNames.length != orderSpec.length) {
withInfo(op, "Unsupported SortOrder")
return false
}

if (partitionColumnNames.toSet != orderColumnNames.toSet) {
withInfo(op, "Partitioning and sorting specifications do not match")
return false
}

true
}
}
}
20 changes: 18 additions & 2 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,22 @@ class CometExecSuite extends CometTestBase {
}
}

test(
"fall back to Spark when the partition spec and order spec are not the same for window function") {
withTempView("test") {
sql("""
|CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES
| (1, true), (1, false),
|(2, true), (3, false), (4, true) AS test(k, v)
|""".stripMargin)

val df = sql("""
SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
|""".stripMargin)
checkSparkAnswer(df)
}
}

test("Native window operator should be CometUnaryExec") {
withTempView("testData") {
sql("""
Expand All @@ -164,11 +180,11 @@ class CometExecSuite extends CometTestBase {
|(3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null)
|AS testData(val, val_long, val_double, val_date, val_timestamp, cate)
|""".stripMargin)
val df = sql("""
val df1 = sql("""
|SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW)
|FROM testData ORDER BY cate, val
|""".stripMargin)
checkSparkAnswer(df)
checkSparkAnswer(df1)
}
}

Expand Down

0 comments on commit 0d84fd2

Please sign in to comment.