diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 728727235..26eefef2e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -857,61 +857,52 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case EqualTo(left, right) => val zero = Literal.default(left.dataType) val negZero = UnaryMinus(zero) - if (left.dataType == DoubleType || left.dataType == FloatType) { - if (right == negZero) { - return Some( - ExprOuterClass.Expr - .newBuilder() - .setEq( - ExprOuterClass.Equal - .newBuilder() - .setLeft(exprToProtoInternal(left, inputs).get) - .setRight(exprToProtoInternal(Abs(right).child, inputs).get)) - .build()) - } else if (left == negZero) { - return Some( - ExprOuterClass.Expr - .newBuilder() - .setEq( - ExprOuterClass.Equal - .newBuilder() - .setLeft(exprToProtoInternal(Abs(left).child, inputs).get) - .setRight(exprToProtoInternal(right, inputs).get)) - .build()) - } else { - Some( - ExprOuterClass.Expr + + def buildEqualExpr(leftExpr: Expr, rightExpr: Expr): ExprOuterClass.Expr = { + ExprOuterClass.Expr + .newBuilder() + .setEq( + ExprOuterClass.Equal .newBuilder() - .setEq( - ExprOuterClass.Equal - .newBuilder() - .setLeft(exprToProtoInternal(left, inputs).get) - .setRight(exprToProtoInternal(right, inputs).get)) - .build()) - } + .setLeft(leftExpr) + .setRight(rightExpr)) + .build() } - var leftExpr, rightExpr: Option[Expr] = None if (left.dataType == DoubleType || left.dataType == FloatType) { - leftExpr = exprToProtoInternal(If(EqualTo(left, negZero), zero, left), inputs) - rightExpr = exprToProtoInternal(If(EqualTo(right, negZero), zero, right), inputs) - } else { - leftExpr = exprToProtoInternal(left, inputs) - rightExpr = exprToProtoInternal(right, inputs) + (left, right) match { + case (`negZero`, _) => + return Some( + buildEqualExpr( + exprToProtoInternal(Abs(left).child, inputs).get, + exprToProtoInternal(right, inputs).get)) + case (_, `negZero`) => + return Some( + buildEqualExpr( + exprToProtoInternal(left, inputs).get, + exprToProtoInternal(Abs(right).child, inputs).get)) + case _ => + Some( + buildEqualExpr( + exprToProtoInternal(left, inputs).get, + exprToProtoInternal(right, inputs).get)) + } } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Equal.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setEq(builder) - .build()) - } else { - withInfo(expr, left, right) - None + val (leftExpr, rightExpr) = + if (left.dataType == DoubleType || left.dataType == FloatType) { + ( + exprToProtoInternal(If(EqualTo(left, negZero), zero, left), inputs), + exprToProtoInternal(If(EqualTo(right, negZero), zero, right), inputs)) + } else { + (exprToProtoInternal(left, inputs), exprToProtoInternal(right, inputs)) + } + + (leftExpr, rightExpr) match { + case (Some(l), Some(r)) => Some(buildEqualExpr(l, r)) + case _ => + withInfo(expr, left, right) + None } case Not(EqualTo(left, right)) => diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 6d2243d2d..b6b6a393e 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -856,6 +856,30 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("zero equality") { + withParquetTable( + Seq( + (-0.0, 0.0), + (0.0, -0.0), + (-0.0, -0.0), + (0.0, 0.0), + (1.0, 2.0), + (1.0, 1.0), + (1.0, 0.0), + (0.0, 1.0), + (-0.0, 1.0), + (1.0, -0.0), + (1.0, -1.0), + (-1.0, 1.0), + (-1.0, -0.0), + (-1.0, -1.0), + (-1.0, 0.0), + (0.0, -1.0)), + "t") { + checkSparkAnswerAndOperator("SELECT _1 == _2 FROM t") + } + } + test("remainder") { val query = "SELECT _1, _2, _1 % _2 FROM t" withParquetTable(Seq((21840, -0.0), (21840, 5.0)), "t") {