diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 46eb1b0035..a481d8422d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1376,7 +1376,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case ShiftRight(left, right) => val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val rightExpr = if (left.dataType == LongType) { + // DataFusion bitwise shift right expression requires + // same data type between left and right side + exprToProtoInternal(Cast(right, LongType), inputs) + } else { + exprToProtoInternal(right, inputs) + } if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.BitwiseShiftRight.newBuilder() @@ -1394,7 +1400,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case ShiftLeft(left, right) => val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val rightExpr = if (left.dataType == LongType) { + // DataFusion bitwise shift right expression requires + // same data type between left and right side + exprToProtoInternal(Cast(right, LongType), inputs) + } else { + exprToProtoInternal(right, inputs) + } if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.BitwiseShiftLeft.newBuilder() diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 3f29e950ec..2609bd3efe 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -34,6 +34,26 @@ import org.apache.comet.CometSparkSessionExtensions.{isSpark32, isSpark34Plus} class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ + test("bitwise shift with different left/right types") { + Seq(false, true).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + val table = "test" + withTable(table) { + sql(s"create table $table(col1 long, col2 int) using parquet") + sql(s"insert into $table values(1111, 2)") + sql(s"insert into $table values(1111, 2)") + sql(s"insert into $table values(3333, 4)") + sql(s"insert into $table values(5555, 6)") + + checkSparkAnswerAndOperator( + s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table") + checkSparkAnswerAndOperator( + s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table") + } + } + } + } + test("basic data type support") { Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir =>