From 992ad20e4fb35eadda7098f30e1884bd0a1f27da Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Tue, 4 Jun 2024 15:43:05 -0600
Subject: [PATCH] simplify PR

---
 .../apache/comet/serde/QueryPlanSerde.scala   | 95 ++++++++-----------
 1 file changed, 42 insertions(+), 53 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 432a43a41..e528b5c41 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -727,6 +727,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       withInfo(expr, ansiNotSupported)
       None
 
+    case expr: Round if evalMode(expr) == CometEvalMode.ANSI && !cometAnsiEnabled =>
+      withInfo(expr, ansiNotSupported)
+      None
+
     case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
       val leftExpr = exprToProtoInternal(left, inputs)
       val rightExpr = exprToProtoInternal(right, inputs)
@@ -1612,60 +1616,45 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         val optExpr = scalarExprToProto("pow", leftExpr, rightExpr)
         optExprWithInfo(optExpr, expr, left, right)
 
-      case r: Round if isSpark32 =>
-        // round function for Spark 3.2 does not allow negative round target scale. In addition,
-        // it has different result precision/scale for decimals. Supporting only 3.3 and above.
-        withInfo(r, "Round not supported prior to Spark 3.3")
-        None
-
-      case r: Round =>
-        val cometEvalMode = evalMode(r)
-        if (cometEvalMode == CometEvalMode.ANSI && !cometAnsiEnabled) {
-          // https://github.com/apache/datafusion-comet/issues/466
-          withInfo(
-            r,
-            "Round does not support ANSI mode. " +
-              s"Set ${CometConf.COMET_ANSI_MODE_ENABLED.key}=true to enable it anyway")
-          None
-        } else {
-          // _scale is a constant, copied from Spark's RoundBase because it is a protected val
-          val scaleV: Any = r.scale.eval(EmptyRow)
-          val _scale: Int = scaleV.asInstanceOf[Int]
+      // round function for Spark 3.2 does not allow negative round target scale. In addition,
+      // it has different result precision/scale for decimals. Supporting only 3.3 and above.
+      case r: Round if !isSpark32 =>
+        // _scale is a constant, copied from Spark's RoundBase because it is a protected val
+        val scaleV: Any = r.scale.eval(EmptyRow)
+        val _scale: Int = scaleV.asInstanceOf[Int]
 
-          lazy val childExpr = exprToProtoInternal(r.child, inputs)
-          r.child.dataType match {
-            case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
-              withInfo(r, "Decimal type has negative scale")
-              None
-            case _ if scaleV == null =>
-              exprToProtoInternal(Literal(null), inputs)
-            case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
-              childExpr // _scale (i.e. decimal places) >= 0 is a no-op for integer types in Spark
-            case _: FloatType | DoubleType =>
-              // We cannot properly match the Spark behavior for floating-point numbers.
-              // Spark uses BigDecimal for rounding float/double, and BigDecimal first converts a
-              // double to string internally in order to create its own internal representation.
-              // The problem is BigDecimal uses java.lang.Double.toString() and it has a
-              // complicated rounding algorithm. E.g. -5.81855622136895E8 is actually
-              // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead
-              // of 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a
-              // difference when rounding at the 5th digit, i.e. round(-5.81855622136895E8, 5)
-              // should be -5.818556221369E8, instead of -5.8185562213689E8. There is also an
-              // example where toString() does NOT round up. 6.1317116247283497E18 is
-              // 6131711624728349696. It can be rounded up to 6.13171162472835E18, which still
-              // represents the same double number, i.e. 6.13171162472835E18 ==
-              // 6.1317116247283497E18. However, toString() does not.
-              // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18
-              // instead of 6.1317116247283999E18.
-              withInfo(r, "Comet does not support Spark's BigDecimal rounding")
-              None
-            case _ =>
-              // `scale` must be Int64 type in DataFusion
-              val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs)
-              val optExpr =
-                scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
-              optExprWithInfo(optExpr, expr, r.child)
-          }
+        lazy val childExpr = exprToProtoInternal(r.child, inputs)
+        r.child.dataType match {
+          case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
+            withInfo(r, "Decimal type has negative scale")
+            None
+          case _ if scaleV == null =>
+            exprToProtoInternal(Literal(null), inputs)
+          case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
+            childExpr // _scale (i.e. decimal places) >= 0 is a no-op for integer types in Spark
+          case _: FloatType | DoubleType =>
+            // We cannot properly match the Spark behavior for floating-point numbers.
+            // Spark uses BigDecimal for rounding float/double, and BigDecimal first converts a
+            // double to string internally in order to create its own internal representation.
+            // The problem is BigDecimal uses java.lang.Double.toString() and it has a complicated
+            // rounding algorithm. E.g. -5.81855622136895E8 is actually
+            // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of
+            // 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a
+            // difference when rounding at the 5th digit, i.e. round(-5.81855622136895E8, 5) should
+            // be -5.818556221369E8, instead of -5.8185562213689E8. There is also an example where
+            // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can
+            // be rounded up to 6.13171162472835E18, which still represents the same double number,
+            // i.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not.
+            // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead
+            // of 6.1317116247283999E18.
+            withInfo(r, "Comet does not support Spark's BigDecimal rounding")
+            None
+          case _ =>
+            // `scale` must be Int64 type in DataFusion
+            val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs)
+            val optExpr =
+              scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
+            optExprWithInfo(optExpr, expr, r.child)
         }
 
       case Signum(child) =>
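To make the FloatType/DoubleType comment above concrete, here is a standalone Scala sketch that reproduces the toString()-based rounding mismatch for the first value quoted in the comment. It is not part of the patch; it assumes only the JDK's java.math.BigDecimal, and the object name RoundingMismatchDemo is illustrative.

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object RoundingMismatchDemo {
  def main(args: Array[String]): Unit = {
    val d = -5.81855622136895e8

    // Exact binary value of the double: the digit after the 5th fractional
    // place is 4, so exact rounding at scale 5 rounds the magnitude down.
    println(new JBigDecimal(d).toPlainString) // -581855622.13689494132995605468750

    // Double.toString() has already rounded that representation up.
    println(java.lang.Double.toString(d)) // -5.81855622136895E8

    // Spark's BigDecimal path goes through the string form
    // (JBigDecimal.valueOf uses Double.toString internally), so HALF_UP at
    // scale 5 sees a trailing 5 and rounds away from zero.
    println(JBigDecimal.valueOf(d).setScale(5, RoundingMode.HALF_UP)) // -581855622.13690

    // Rounding the exact binary value instead, which is effectively what a
    // native engine such as DataFusion sees, gives a different last digit.
    println(new JBigDecimal(d).setScale(5, RoundingMode.HALF_UP)) // -581855622.13689
  }
}

The two code paths legitimately disagree in the last digit, which is why the patch keeps reporting withInfo(r, "Comet does not support Spark's BigDecimal rounding") for float/double inputs rather than produce results that diverge from Spark.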