diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 4d592edc1..fa6542f73 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -865,7 +865,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim def buildEqualExpr( leftExpr: Option[Expr], rightExpr: Option[Expr]): Option[ExprOuterClass.Expr] = { - if (leftExpr.isDefined && rightExpr.isDefined) { Some( ExprOuterClass.Expr @@ -881,11 +880,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None } } - - if (left.dataType == DoubleType || - left.dataType == FloatType || - right.dataType == DoubleType || - right.dataType == FloatType) { + if ((left.dataType == DoubleType && + right.dataType == DoubleType) || + (left.dataType == FloatType && + right.dataType == FloatType)) { (left, right) match { case (`negZeroLeft`, `negZeroRight`) => return buildEqualExpr( @@ -900,8 +898,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim exprToProtoInternal(left, inputs), exprToProtoInternal(Abs(right).child, inputs)) case _ => + val doubleNan = Literal(Double.NaN, DoubleType) + val floatNan = Literal(Float.NaN, FloatType) + if ((left.nullable && !right.nullable) && - (left != leftZero && right != rightZero)) { + (left != negZeroLeft && right != negZeroRight) && + (left != leftZero && right != rightZero) && + (left != doubleNan && right != doubleNan) && + (left != floatNan && right != floatNan)) { withInfo(expr, left, right) return None }