Skip to content

Commit

Permalink
feat: Add support for TryCast expression in Spark 3.2 and 3.3 (apache…
Browse files Browse the repository at this point in the history
…#416)

* working on trycast

* code refactor

* compilation fix

* bug fixes and supporting try_cast

* removing trycast var and comment

* removing issue comment

* adding comments

(cherry picked from commit ddf6a6f)
  • Loading branch information
vaibhawvipul authored and Huaxin Gao committed May 23, 2024
1 parent 12a85ce commit c51fe13
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 55 deletions.
97 changes: 56 additions & 41 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,52 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim

def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): Option[Expr] = {
SQLConf.get

/**
 * Serializes a cast of `child` to `dt` into a protobuf expression.
 *
 * The child is converted first with `exprToProtoInternal`. If that fails, the enclosing `expr`
 * is annotated through `withInfo(expr, child)` and `None` is returned. Otherwise the outcome
 * depends on what `CometCast.isSupported` reports for the (source type, target type, time zone,
 * eval mode) combination:
 *   - `Compatible`: the cast is serialized with `castToProto`.
 *   - `Incompatible`: serialized (after a logged warning) only when
 *     `CometConf.COMET_CAST_ALLOW_INCOMPATIBLE` is enabled; otherwise the reason is attached to
 *     `expr` via `withInfo` and `None` is returned.
 *   - `Unsupported`: an explanatory message is attached to `expr` via `withInfo` and `None` is
 *     returned.
 *
 * @param child expression being cast
 * @param inputs attributes forwarded to `exprToProtoInternal` when converting `child`
 * @param dt target data type of the cast
 * @param timeZoneId optional time zone id, passed to the support check and to `castToProto`
 * @param actualEvalModeStr name of the evaluation mode (for example "TRY", "ANSI" or "LEGACY")
 * @return the value produced by `castToProto` when the cast is accepted, otherwise `None`
 */
def handleCast(
    child: Expression,
    inputs: Seq[Attribute],
    dt: DataType,
    timeZoneId: Option[String],
    actualEvalModeStr: String): Option[Expr] = {

  val protoChild = exprToProtoInternal(child, inputs)

  // Kept as a `def` so the description is only built for the outcomes that report it.
  def castDetails: String =
    s"from ${child.dataType} to $dt " +
      s"with timezone $timeZoneId and evalMode $actualEvalModeStr"

  // Message used both for the warning log and for the info attached to `expr`.
  def incompatMessage(reason: Option[String]): String = {
    val suffix = reason.fold("")(str => s" ($str)")
    s"Comet does not guarantee correct results for cast $castDetails$suffix"
  }

  if (protoChild.isEmpty) {
    // The child itself could not be serialized, so the cast cannot be either.
    withInfo(expr, child)
    None
  } else {
    CometCast.isSupported(child.dataType, dt, timeZoneId, actualEvalModeStr) match {
      case Compatible(_) =>
        castToProto(timeZoneId, dt, protoChild, actualEvalModeStr)

      case Incompatible(reason) if CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get() =>
        // The user explicitly opted in to incompatible casts: warn, then serialize anyway.
        logWarning(incompatMessage(reason))
        castToProto(timeZoneId, dt, protoChild, actualEvalModeStr)

      case Incompatible(reason) =>
        withInfo(
          expr,
          s"${incompatMessage(reason)}. To enable all incompatible casts, set " +
            s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
        None

      case Unsupported =>
        withInfo(expr, s"Unsupported cast $castDetails")
        None
    }
  }
}

expr match {
case a @ Alias(_, _) =>
val r = exprToProtoInternal(a.child, inputs)
Expand All @@ -617,50 +663,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
val value = cast.eval()
exprToProtoInternal(Literal(value, dataType), inputs)

case UnaryExpression(child) if expr.prettyName == "trycast" =>
val timeZoneId = SQLConf.get.sessionLocalTimeZone
handleCast(child, inputs, expr.dataType, Some(timeZoneId), "TRY")

case Cast(child, dt, timeZoneId, evalMode) =>
val childExpr = exprToProtoInternal(child, inputs)
if (childExpr.isDefined) {
val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
// Spark 3.2 & 3.3 has ansiEnabled boolean
if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
} else {
// Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
evalMode.toString
}
val castSupport =
CometCast.isSupported(child.dataType, dt, timeZoneId, evalModeStr)

def getIncompatMessage(reason: Option[String]) =
"Comet does not guarantee correct results for cast " +
s"from ${child.dataType} to $dt " +
s"with timezone $timeZoneId and evalMode $evalModeStr" +
reason.map(str => s" ($str)").getOrElse("")

castSupport match {
case Compatible(_) =>
castToProto(timeZoneId, dt, childExpr, evalModeStr)
case Incompatible(reason) =>
if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
logWarning(getIncompatMessage(reason))
castToProto(timeZoneId, dt, childExpr, evalModeStr)
} else {
withInfo(
expr,
s"${getIncompatMessage(reason)}. To enable all incompatible casts, set " +
s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
None
}
case Unsupported =>
withInfo(
expr,
s"Unsupported cast from ${child.dataType} to $dt " +
s"with timezone $timeZoneId and evalMode $evalModeStr")
None
}
val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
// Spark 3.2 & 3.3 has ansiEnabled boolean
if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
} else {
withInfo(expr, child)
None
// Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
evalMode.toString
}
handleCast(child, inputs, dt, timeZoneId, evalModeStr)

case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
val leftExpr = exprToProtoInternal(left, inputs)
Expand Down
22 changes: 8 additions & 14 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -886,10 +886,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

private def castTest(input: DataFrame, toType: DataType): Unit = {

// we do not support the TryCast expression in Spark 3.2 and 3.3
// https://github.com/apache/datafusion-comet/issues/374
val testTryCast = CometSparkSessionExtensions.isSpark34Plus

// TryCast is now supported on Spark 3.2 and 3.3 as well, so try_cast is always exercised
withTempPath { dir =>
val data = roundtripParquet(input, dir).coalesce(1)
data.createOrReplaceTempView("t")
Expand All @@ -900,11 +897,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
checkSparkAnswerAndOperator(df)

// try_cast() should always return null for invalid inputs
if (testTryCast) {
val df2 =
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
checkSparkAnswerAndOperator(df2)
}
val df2 =
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
checkSparkAnswerAndOperator(df2)
}

// with ANSI enabled, we should produce the same exception as Spark
Expand Down Expand Up @@ -963,11 +958,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

// try_cast() should always return null for invalid inputs
if (testTryCast) {
val df2 =
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
checkSparkAnswerAndOperator(df2)
}
val df2 =
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
checkSparkAnswerAndOperator(df2)

}
}
}
Expand Down

0 comments on commit c51fe13

Please sign in to comment.