diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index fa6542f73..f2870c9d3 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -822,8 +822,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } None - case rem @ Remainder(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) @@ -905,7 +904,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim (left != negZeroLeft && right != negZeroRight) && (left != leftZero && right != rightZero) && (left != doubleNan && right != doubleNan) && - (left != floatNan && right != floatNan)) { + (left != floatNan && right != floatNan) && isSpark34Plus) { withInfo(expr, left, right) return None } @@ -915,16 +914,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } } - val leftExpr = if (left.dataType == DoubleType || left.dataType == FloatType) { - exprToProtoInternal(If(EqualTo(left, negZeroLeft), leftZero, left), inputs) - } else { - exprToProtoInternal(left, inputs) - } - val rightExpr = if (right.dataType == DoubleType || right.dataType == FloatType) { - exprToProtoInternal(If(EqualTo(right, negZeroRight), rightZero, right), inputs) - } else { - exprToProtoInternal(right, inputs) - } + val leftExpr = + if (left.dataType == DoubleType || left.dataType == FloatType) { + exprToProtoInternal(If(EqualTo(left, negZeroLeft), leftZero, left), inputs) + } else { + exprToProtoInternal(left, inputs) + } + val rightExpr = + if (right.dataType == DoubleType || right.dataType == FloatType) { + exprToProtoInternal(If(EqualTo(right, negZeroRight), rightZero, right), inputs) + } else { + exprToProtoInternal(right, inputs) + } buildEqualExpr(leftExpr, rightExpr)