Add CometEvalMode enum
andygrove committed Jun 3, 2024
1 parent 5ab3ee0 commit e377e44
Showing 2 changed files with 57 additions and 54 deletions.
@@ -19,8 +19,6 @@

package org.apache.comet.expressions

import java.util.Locale

import org.apache.spark.sql.internal.SQLConf

/**
@@ -48,15 +46,5 @@ object CometEvalMode extends Enumeration {
LEGACY
}

def fromString(str: String): CometEvalMode.Value = {
str.toUpperCase(Locale.ROOT) match {
case "LEGACY" => CometEvalMode.LEGACY
case "TRY" => CometEvalMode.TRY
case "ANSI" => CometEvalMode.ANSI
case _ =>
throw new IllegalArgumentException(
s"Invalid eval mode '$str' "
) // Assuming we want to catch errors strictly
}
}
def fromString(str: String): CometEvalMode.Value = CometEvalMode.withName(str)
}
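
The rewritten fromString delegates to Scala's Enumeration.withName, which resolves a value by its exact name. A behavioral note: withName is case-sensitive and throws java.util.NoSuchElementException for unknown input, whereas the removed hand-written match upper-cased the string (Locale.ROOT) and threw IllegalArgumentException. A minimal sketch, reproducing only the values implied by the removed match arms; the FromStringDemo wrapper is purely illustrative:

```scala
// Minimal sketch, not the full file: only the values implied by the removed match arms
// (LEGACY, TRY, ANSI) and the new one-line fromString are reproduced here.
object CometEvalMode extends Enumeration {
  val LEGACY, ANSI, TRY = Value

  def fromString(str: String): CometEvalMode.Value = CometEvalMode.withName(str)
}

object FromStringDemo extends App {
  println(CometEvalMode.fromString("ANSI")) // ANSI
  // CometEvalMode.fromString("ansi")       // throws java.util.NoSuchElementException;
  //                                        // the removed match upper-cased first and
  //                                        // threw IllegalArgumentException instead
}
```
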
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala (97 changes: 56 additions & 41 deletions)
@@ -22,7 +22,7 @@ package org.apache.comet.serde
import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{EvalMode, _}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
@@ -1584,53 +1584,68 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
val optExpr = scalarExprToProto("pow", leftExpr, rightExpr)
optExprWithInfo(optExpr, expr, left, right)

case r: Round if r.ansiEnabled && !CometConf.COMET_ANSI_MODE_ENABLED.get =>
case r: Round if r.ansiEnabled =>
// https://github.com/apache/datafusion-comet/issues/466
withInfo(
r,
"Round does not support ANSI mode. " +
s"Set ${CometConf.COMET_ANSI_MODE_ENABLED.key}=true to enable it anyway")
None

// round function for Spark 3.2 does not allow negative round target scale. In addition,
// it has different result precision/scale for decimals. Supporting only 3.3 and above.
case r: Round if !isSpark32 =>
// _scale is a constant, copied from Spark's RoundBase because it is a protected val
val scaleV: Any = r.scale.eval(EmptyRow)
val _scale: Int = scaleV.asInstanceOf[Int]
case r: Round if isSpark32 =>
// round function for Spark 3.2 does not allow negative round target scale. In addition,
// it has different result precision/scale for decimals. Supporting only 3.3 and above.
withInfo(r, "Round not supported prior to Spark 3.3")
None

lazy val childExpr = exprToProtoInternal(r.child, inputs)
r.child.dataType match {
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
withInfo(r, "Decimal type has negative scale")
None
case _ if scaleV == null =>
exprToProtoInternal(Literal(null), inputs)
case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark
case _: FloatType | DoubleType =>
// We cannot properly match with the Spark behavior for floating-point numbers.
// Spark uses BigDecimal for rounding float/double, and BigDecimal first converts a
// double to string internally in order to create its own internal representation.
// The problem is BigDecimal uses java.lang.Double.toString() and it has a complicated
// rounding algorithm. E.g. -5.81855622136895E8 is actually
// -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of
// 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a
// difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be
// -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that
// toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can
// be rounded up to 6.13171162472835E18 that still represents the same double number.
// I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not.
// That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead
// of 6.1317116247283999E18.
withInfo(r, "Comet does not support Spark's BigDecimal rounding")
None
case _ =>
// `scale` must be Int64 type in DataFusion
val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs)
val optExpr =
scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
optExprWithInfo(optExpr, expr, r.child)
case r @ Round(_, _, ansiEnabledOrEvalMode) =>
val evalMode = getCometEvalMode(ansiEnabledOrEvalMode)
if (evalMode == CometEvalMode.ANSI && !CometConf.COMET_ANSI_MODE_ENABLED.get) {
// https://github.com/apache/datafusion-comet/issues/466
withInfo(
r,
"Round does not support ANSI mode. " +
s"Set ${CometConf.COMET_ANSI_MODE_ENABLED.key}=true to enable it anyway")
None
} else {
// _scale is a constant, copied from Spark's RoundBase because it is a protected val
val scaleV: Any = r.scale.eval(EmptyRow)
val _scale: Int = scaleV.asInstanceOf[Int]

lazy val childExpr = exprToProtoInternal(r.child, inputs)
r.child.dataType match {
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
withInfo(r, "Decimal type has negative scale")
None
case _ if scaleV == null =>
exprToProtoInternal(Literal(null), inputs)
case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark
case _: FloatType | DoubleType =>
// We cannot properly match with the Spark behavior for floating-point numbers.
// Spark uses BigDecimal for rounding float/double, and BigDecimal first converts a
// double to string internally in order to create its own internal representation.
// The problem is BigDecimal uses java.lang.Double.toString() and it has a
// complicated rounding algorithm. E.g. -5.81855622136895E8 is actually
// -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead
// of 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a
// difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should
// be -5.818556221369E8, instead of -5.8185562213689E8. There is also an example
// that toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696.
// It can be rounded up to 6.13171162472835E18 that still represents the same
// double number. I.e. 6.13171162472835E18 == 6.1317116247283497E18. However,
// toString() does not.
// That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18
// instead of 6.1317116247283999E18.
withInfo(r, "Comet does not support Spark's BigDecimal rounding")
None
case _ =>
// `scale` must be Int64 type in DataFusion
val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs)
val optExpr =
scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
optExprWithInfo(optExpr, expr, r.child)
}
}

case Signum(child) =>
@@ -2201,7 +2216,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
exprToProtoInternal(newExpr, input)
}

private def getCometEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = {
private def getCometEvalMode(evalMode: Any): CometEvalMode.Value = {
if (evalMode.isInstanceOf[Boolean]) {
// Spark 3.2 & 3.3 has ansiEnabled boolean
CometEvalMode.fromBoolean(evalMode.asInstanceOf[Boolean])
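The final hunk widens getCometEvalMode to accept Any because a Round expression carries a plain ansiEnabled: Boolean on Spark 3.2/3.3 but an EvalMode value on Spark 3.4+. The rest of the method is collapsed in the commit view, so the sketch below is an assumption about its overall shape: the Boolean branch mirrors the visible diff, while the fallback branch and the EvalModeMapping wrapper are illustrative only and rely on the CometEvalMode object from the first file.

```scala
// Sketch under assumptions: the Boolean branch mirrors the visible diff; the fallback
// branch for Spark 3.4+'s EvalMode is assumed, because that part of the method is
// collapsed in the commit view above.
object EvalModeMapping {
  def getCometEvalMode(evalMode: Any): CometEvalMode.Value = evalMode match {
    case ansiEnabled: Boolean =>
      // Spark 3.2 & 3.3 expose a plain ansiEnabled flag on expressions such as Round
      CometEvalMode.fromBoolean(ansiEnabled)
    case other =>
      // Spark 3.4+ carries an EvalMode value (LEGACY, ANSI, TRY); its toString matches
      // the CometEvalMode names, so the withName-based fromString can parse it
      CometEvalMode.fromString(other.toString)
  }
}
```
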
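The long comment in the Round case earlier in this diff explains why Comet declines to round float/double: Spark rounds through a BigDecimal built from Double.toString, which applies its own shortest-representation rounding before the requested scale is applied. A small standalone illustration of the -5.81855622136895E8 example from that comment, using plain java.math.BigDecimal rather than any Spark or Comet API, and assuming HALF_UP as the rounding mode:

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object RoundMismatchDemo extends App {
  val d = -5.81855622136895e8

  // Spark-style: go through Double.toString ("-5.81855622136895E8"), so the 6th
  // fractional digit is 5 and HALF_UP carries into the 5th digit.
  val viaString = JBigDecimal.valueOf(d).setScale(5, RoundingMode.HALF_UP)

  // Rounding the exact binary value (-581855622.13689494...) keeps the 5th digit as 9,
  // because the digit after it is 4.
  val viaExactBinary = new JBigDecimal(d).setScale(5, RoundingMode.HALF_UP)

  println(viaString)      // -581855622.13690  (i.e. -5.818556221369E8)
  println(viaExactBinary) // -581855622.13689  (i.e. -5.8185562213689E8)
}
```

Rounding the string-derived value carries the trailing 5 into the fifth fractional digit (-5.818556221369E8), while rounding the exact binary value leaves -5.8185562213689E8, which is exactly the mismatch the comment describes.
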
