diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala
index c89891fe1..cc5dd47e1 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala
@@ -19,8 +19,6 @@
 
 package org.apache.comet.expressions
 
-import java.util.Locale
-
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -48,15 +46,5 @@ object CometEvalMode extends Enumeration {
     LEGACY
   }
 
-  def fromString(str: String): CometEvalMode.Value = {
-    str.toUpperCase(Locale.ROOT) match {
-      case "LEGACY" => CometEvalMode.LEGACY
-      case "TRY" => CometEvalMode.TRY
-      case "ANSI" => CometEvalMode.ANSI
-      case _ =>
-        throw new IllegalArgumentException(
-          s"Invalid eval mode '$str'"
-        )
-    }
-  }
+  def fromString(str: String): CometEvalMode.Value = CometEvalMode.withName(str)
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index a7aa3bb19..5f9c91061 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -22,7 +22,7 @@ package org.apache.comet.serde
 import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.expressions.{EvalMode, _}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
@@ -1584,53 +1584,60 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       val optExpr = scalarExprToProto("pow", leftExpr, rightExpr)
       optExprWithInfo(optExpr, expr, left, right)
 
-    case r: Round if r.ansiEnabled && !CometConf.COMET_ANSI_MODE_ENABLED.get =>
-      // https://github.com/apache/datafusion-comet/issues/466
-      withInfo(
-        r,
-        "Round does not support ANSI mode. " +
-          s"Set ${CometConf.COMET_ANSI_MODE_ENABLED.key}=true to enable it anyway")
-      None
-
-    // round function for Spark 3.2 does not allow negative round target scale. In addition,
-    // it has different result precision/scale for decimals. Supporting only 3.3 and above.
-    case r: Round if !isSpark32 =>
-      // _scale s a constant, copied from Spark's RoundBase because it is a protected val
-      val scaleV: Any = r.scale.eval(EmptyRow)
-      val _scale: Int = scaleV.asInstanceOf[Int]
+    case r: Round if isSpark32 =>
+      // round function for Spark 3.2 does not allow negative round target scale. In addition,
+      // it has different result precision/scale for decimals. Supporting only 3.3 and above.
+ withInfo(r, "Round not supported prior to Spark 3.3") + None - lazy val childExpr = exprToProtoInternal(r.child, inputs) - r.child.dataType match { - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - withInfo(r, "Decimal type has negative scale") - None - case _ if scaleV == null => - exprToProtoInternal(Literal(null), inputs) - case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 => - childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark - case _: FloatType | DoubleType => - // We cannot properly match with the Spark behavior for floating-point numbers. - // Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a - // double to string internally in order to create its own internal representation. - // The problem is BigDecimal uses java.lang.Double.toString() and it has complicated - // rounding algorithm. E.g. -5.81855622136895E8 is actually - // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of - // 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a - // difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be - // -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that - // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can - // be rounded up to 6.13171162472835E18 that still represents the same double number. - // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not. - // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead - // of 6.1317116247283999E18. - withInfo(r, "Comet does not support Spark's BigDecimal rounding") - None - case _ => - // `scale` must be Int64 type in DataFusion - val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs) - val optExpr = - scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr) - optExprWithInfo(optExpr, expr, r.child) + case r @ Round(_, _, ansiEnabledOrEvalMode) => + val evalMode = getCometEvalMode(ansiEnabledOrEvalMode) + if (evalMode == CometEvalMode.ANSI && !CometConf.COMET_ANSI_MODE_ENABLED.get) { + // https://github.com/apache/datafusion-comet/issues/466 + withInfo( + r, + "Round does not support ANSI mode. " + + s"Set ${CometConf.COMET_ANSI_MODE_ENABLED.key}=true to enable it anyway") + None + } else { + // _scale s a constant, copied from Spark's RoundBase because it is a protected val + val scaleV: Any = r.scale.eval(EmptyRow) + val _scale: Int = scaleV.asInstanceOf[Int] + + lazy val childExpr = exprToProtoInternal(r.child, inputs) + r.child.dataType match { + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + withInfo(r, "Decimal type has negative scale") + None + case _ if scaleV == null => + exprToProtoInternal(Literal(null), inputs) + case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 => + childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark + case _: FloatType | DoubleType => + // We cannot properly match with the Spark behavior for floating-point numbers. + // Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a + // double to string internally in order to create its own internal representation. + // The problem is BigDecimal uses java.lang.Double.toString() and it has + // rounding algorithm. E.g. -5.81855622136895E8 is actually + // -581855622.13689494132995605468750. 
+            // of 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a
+            // difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should
+            // be -5.818556221369E8, instead of -5.8185562213689E8. There is also an example
+            // that toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696.
+            // It can be rounded up to 6.13171162472835E18 that still represents the same
+            // double number. I.e. 6.13171162472835E18 == 6.1317116247283497E18. However,
+            // toString() does not.
+            // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18
+            // instead of 6.1317116247283999E18.
+            withInfo(r, "Comet does not support Spark's BigDecimal rounding")
+            None
+          case _ =>
+            // `scale` must be Int64 type in DataFusion
+            val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs)
+            val optExpr =
+              scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
+            optExprWithInfo(optExpr, expr, r.child)
+        }
       }
 
     case Signum(child) =>
@@ -2201,7 +2208,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       exprToProtoInternal(newExpr, input)
     }
 
-  private def getCometEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = {
+  private def getCometEvalMode(evalMode: Any): CometEvalMode.Value = {
     if (evalMode.isInstanceOf[Boolean]) {
       // Spark 3.2 & 3.3 has ansiEnabled boolean
       CometEvalMode.fromBoolean(evalMode.asInstanceOf[Boolean])
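A note on the CometEvalMode.fromString change above: Enumeration.withName matches value names exactly and throws NoSuchElementException on a miss, whereas the removed implementation upper-cased its input with Locale.ROOT first and threw IllegalArgumentException. A minimal standalone sketch of the new behavior (demo object only, not Comet code):

    object EvalModeDemo extends Enumeration {
      val LEGACY, ANSI, TRY = Value
      // Same shape as the patched fromString: delegate to Enumeration.withName
      def fromString(str: String): Value = withName(str)
    }

    EvalModeDemo.fromString("ANSI") // returns ANSI
    EvalModeDemo.fromString("ansi") // throws NoSuchElementException; the removed
                                    // implementation accepted this spelling

If lower-case spellings can reach fromString (for example from a config value), the upper-casing now has to happen at the call site.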
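The float/double fallback comment in the Round case is easy to verify outside Spark: scala.math.BigDecimal(d: Double) goes through java.lang.Double.toString, which produces the shortest decimal string that round-trips, while java.math.BigDecimal's double constructor keeps the exact binary expansion. A worked sketch of the -5.81855622136895E8 example from that comment (plain Scala, no Comet involved):

    val d = -5.81855622136895e8

    BigDecimal(d)               // -581855622.136895 (via Double.toString)
    new java.math.BigDecimal(d) // -581855622.1368949413299560546875 (exact)

    // Rounding at the 5th fractional digit therefore disagrees: the exact
    // expansion has ...894... where the short form has already rounded to ...895.
    BigDecimal(d).setScale(5, BigDecimal.RoundingMode.HALF_UP)
    // -581855622.13690, i.e. Spark's round(d, 5) = -5.818556221369E8
    new java.math.BigDecimal(d).setScale(5, java.math.RoundingMode.HALF_UP)
    // -581855622.13689, what rounding the raw binary value yields

Since the native side rounds the binary double directly, Comet falls back to Spark for float/double rather than silently diverging.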
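The getCometEvalMode signature widens to Any because the value it receives is version-dependent: Spark 3.2 and 3.3 expose a Boolean ansiEnabled flag, while Spark 3.4 models eval modes with the EvalMode enum (which is why the version-specific EvalMode import is dropped from QueryPlanSerde above). Hypothetical call sites, assuming the body truncated above handles the non-Boolean case by mapping the EvalMode value through CometEvalMode.fromString:

    getCometEvalMode(true)         // Spark 3.2/3.3 ansiEnabled => CometEvalMode.ANSI
    getCometEvalMode(false)        // Spark 3.2/3.3 ansiEnabled => CometEvalMode.LEGACY
    getCometEvalMode(EvalMode.TRY) // Spark 3.4 EvalMode        => CometEvalMode.TRY

This is also why fromString's exact-name matching matters: EvalMode's value names (LEGACY, ANSI, TRY) must line up with CometEvalMode's.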