From 8cc5772224c30de43543a4388b527adbbb9184dd Mon Sep 17 00:00:00 2001 From: huaxingao Date: Mon, 19 Aug 2024 14:22:03 -0700 Subject: [PATCH] fix offset datatype --- .../core/src/execution/datafusion/planner.rs | 32 +++++++++++++++---- .../apache/comet/exec/CometExecSuite.scala | 12 +++---- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 4a9a10fb5..485df8f8e 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -1693,15 +1693,24 @@ impl PhysicalPlanner { { Some(l) => match l { LowerFrameBoundStruct::UnboundedPreceding(_) => { - WindowFrameBound::Preceding(ScalarValue::UInt64(None)) + match units { + WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Preceding(ScalarValue::Int64(None)), + } } LowerFrameBoundStruct::Preceding(offset) => { - let offset_value = offset.offset.unsigned_abs(); - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset_value))) + let offset_value = offset.offset.abs() as i64; + match units { + WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset_value as u64))), + WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Preceding(ScalarValue::Int64(Some(offset_value))), + } } LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow, }, - None => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + None => match units { + WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Preceding(ScalarValue::Int64(None)), + }, }; let upper_bound: WindowFrameBound = match spark_window_frame @@ -1711,14 +1720,23 @@ impl PhysicalPlanner { { Some(u) => match u { UpperFrameBoundStruct::UnboundedFollowing(_) => { - WindowFrameBound::Following(ScalarValue::UInt64(None)) + match units { + WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)), + WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Following(ScalarValue::Int64(None)), + } } UpperFrameBoundStruct::Following(offset) => { - WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64))) + match units { + WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64))), + WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Following(ScalarValue::Int64(Some(offset.offset as i64))), + } } UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow, }, - None => WindowFrameBound::Following(ScalarValue::UInt64(None)), + None => match units { + WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)), + WindowFrameUnits::Range | WindowFrameUnits::Groups => WindowFrameBound::Following(ScalarValue::Int64(None)), + }, }; let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound); diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 6c3d68a82..5a7c88b5a 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -193,23 +193,21 @@ class CometExecSuite extends CometTestBase { } } - test("Window range frame should fall back to Spark") { + test("Window range frame with long boundary should not fail") { val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) .toDF("key", "value") - checkAnswer( + checkSparkAnswer( df.select( $"key", count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), - Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1))) - checkAnswer( + Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L)))) + checkSparkAnswer( df.select( $"key", count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), - Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1))) + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0)))) } test("Unsupported window expression should fall back to Spark") {