Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use IfExpr to check when input to log2 is <=0 and return null #506

Merged
merged 8 commits into from
Jul 15, 2024
14 changes: 11 additions & 3 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
@@ -1703,18 +1703,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
optExprWithInfo(optExpr, expr, child)
}

// The `log` family of functions (ln, log10, log2) is defined to return null for inputs less
// than or equal to 0. This matches Spark and Hive behavior, where non-positive values evaluate
// to null instead of NaN or -Infinity.
case Log(child) =>
val childExpr = exprToProtoInternal(child, inputs)
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
val optExpr = scalarExprToProto("ln", childExpr)
optExprWithInfo(optExpr, expr, child)

case Log10(child) =>
val childExpr = exprToProtoInternal(child, inputs)
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
val optExpr = scalarExprToProto("log10", childExpr)
optExprWithInfo(optExpr, expr, child)

case Log2(child) =>
val childExpr = exprToProtoInternal(child, inputs)
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
val optExpr = scalarExprToProto("log2", childExpr)
optExprWithInfo(optExpr, expr, child)

@@ -2393,6 +2396,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
expression
}

/**
 * Wraps `expression` in a conditional that yields a typed null whenever the input compares
 * less than or equal to the default literal of its own data type, and yields the input
 * unchanged otherwise.
 *
 * Note that, despite the name, the comparison is `<=`, so an input equal to the default
 * literal (presumably zero for numeric types -- confirm against `Literal.default`) also maps
 * to null, not only strictly negative inputs.
 *
 * @param expression
 *   the input expression to guard
 * @return
 *   an `If` expression with the same data type as `expression`
 */
def nullIfNegative(expression: Expression): Expression = {
  val inputType = expression.dataType
  // Predicate: input <= default literal of the input's type.
  val isNonPositive = LessThanOrEqual(expression, Literal.default(inputType))
  // The null branch carries the input's data type so both branches of the `If` agree.
  val typedNull = Literal.create(null, inputType)
  If(isNonPositive, typedNull, expression)
}

/**
* Returns true if given datatype is supported as a key in DataFusion sort merge join.
*/
Original file line number Diff line number Diff line change
@@ -819,7 +819,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
Seq("true", "false").foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary) {
withParquetTable(
(0 until 5).map(i => (i.toDouble + 0.3, i.toDouble + 0.8)),
(-5 until 5).map(i => (i.toDouble + 0.3, i.toDouble + 0.8)),
"tbl",
withDictionary = dictionary.toBoolean) {
checkSparkAnswerWithTol(