From ddf6a6ffe34eb26f065397034618b9249131121b Mon Sep 17 00:00:00 2001
From: Vipul Vaibhaw
Date: Thu, 16 May 2024 05:25:40 +0530
Subject: [PATCH] feat: Add support for TryCast expression in Spark 3.2 and
 3.3 (#416)

* working on trycast

* code refactor

* compilation fix

* bug fixes and supporting try_cast

* removing trycast var and comment

* removing issue comment

* adding comments
---
 .../apache/comet/serde/QueryPlanSerde.scala   | 97 +++++++++++--------
 .../org/apache/comet/CometCastSuite.scala     | 22 ++---
 2 files changed, 64 insertions(+), 55 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7238990ad..cf7c86a9f 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -604,6 +604,52 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
   def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): Option[Expr] = {
     SQLConf.get
+
+    def handleCast(
+        child: Expression,
+        inputs: Seq[Attribute],
+        dt: DataType,
+        timeZoneId: Option[String],
+        actualEvalModeStr: String): Option[Expr] = {
+
+      val childExpr = exprToProtoInternal(child, inputs)
+      if (childExpr.isDefined) {
+        val castSupport =
+          CometCast.isSupported(child.dataType, dt, timeZoneId, actualEvalModeStr)
+
+        def getIncompatMessage(reason: Option[String]): String =
+          "Comet does not guarantee correct results for cast " +
+            s"from ${child.dataType} to $dt " +
+            s"with timezone $timeZoneId and evalMode $actualEvalModeStr" +
+            reason.map(str => s" ($str)").getOrElse("")
+
+        castSupport match {
+          case Compatible(_) =>
+            castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
+          case Incompatible(reason) =>
+            if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
+              logWarning(getIncompatMessage(reason))
+              castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
+            } else {
+              withInfo(
+                expr,
+                s"${getIncompatMessage(reason)}. To enable all incompatible casts, set " +
+                  s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
+              None
+            }
+          case Unsupported =>
+            withInfo(
+              expr,
+              s"Unsupported cast from ${child.dataType} to $dt " +
+                s"with timezone $timeZoneId and evalMode $actualEvalModeStr")
+            None
+        }
+      } else {
+        withInfo(expr, child)
+        None
+      }
+    }
+
     expr match {
       case a @ Alias(_, _) =>
         val r = exprToProtoInternal(a.child, inputs)
@@ -617,50 +663,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         val value = cast.eval()
         exprToProtoInternal(Literal(value, dataType), inputs)
 
+      case UnaryExpression(child) if expr.prettyName == "trycast" =>
+        val timeZoneId = SQLConf.get.sessionLocalTimeZone
+        handleCast(child, inputs, expr.dataType, Some(timeZoneId), "TRY")
+
       case Cast(child, dt, timeZoneId, evalMode) =>
-        val childExpr = exprToProtoInternal(child, inputs)
-        if (childExpr.isDefined) {
-          val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
-            // Spark 3.2 & 3.3 has ansiEnabled boolean
-            if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
-          } else {
-            // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
-            evalMode.toString
-          }
-          val castSupport =
-            CometCast.isSupported(child.dataType, dt, timeZoneId, evalModeStr)
-
-          def getIncompatMessage(reason: Option[String]) =
-            "Comet does not guarantee correct results for cast " +
-              s"from ${child.dataType} to $dt " +
-              s"with timezone $timeZoneId and evalMode $evalModeStr" +
-              reason.map(str => s" ($str)").getOrElse("")
-
-          castSupport match {
-            case Compatible(_) =>
-              castToProto(timeZoneId, dt, childExpr, evalModeStr)
-            case Incompatible(reason) =>
-              if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
-                logWarning(getIncompatMessage(reason))
-                castToProto(timeZoneId, dt, childExpr, evalModeStr)
-              } else {
-                withInfo(
-                  expr,
-                  s"${getIncompatMessage(reason)}. To enable all incompatible casts, set " +
-                    s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
-                None
-              }
-            case Unsupported =>
-              withInfo(
-                expr,
-                s"Unsupported cast from ${child.dataType} to $dt " +
-                  s"with timezone $timeZoneId and evalMode $evalModeStr")
-              None
-          }
+        val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
+          // Spark 3.2 & 3.3 have an ansiEnabled boolean
+          if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
         } else {
-          withInfo(expr, child)
-          None
+          // Spark 3.4+ has an EvalMode enum with values LEGACY, ANSI, and TRY
+          evalMode.toString
         }
+        handleCast(child, inputs, dt, timeZoneId, evalModeStr)
 
       case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
         val leftExpr = exprToProtoInternal(left, inputs)
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 50311a9b5..ea3355d05 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -886,10 +886,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   private def castTest(input: DataFrame, toType: DataType): Unit = {
-    // we do not support the TryCast expression in Spark 3.2 and 3.3
-    // https://github.com/apache/datafusion-comet/issues/374
-    val testTryCast = CometSparkSessionExtensions.isSpark34Plus
-
+    // we now support the TryCast expression in Spark 3.2 and 3.3
     withTempPath { dir =>
       val data = roundtripParquet(input, dir).coalesce(1)
       data.createOrReplaceTempView("t")
@@ -900,11 +897,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       checkSparkAnswerAndOperator(df)
 
       // try_cast() should always return null for invalid inputs
-      if (testTryCast) {
-        val df2 =
-          spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
-        checkSparkAnswerAndOperator(df2)
-      }
+      val df2 =
+        spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
+      checkSparkAnswerAndOperator(df2)
     }
 
     // with ANSI enabled, we should produce the same exception as Spark
@@ -963,11 +958,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
 
       // try_cast() should always return null for invalid inputs
-      if (testTryCast) {
-        val df2 =
-          spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
-        checkSparkAnswerAndOperator(df2)
-      }
+      val df2 =
+        spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
+      checkSparkAnswerAndOperator(df2)
+
     }
   }
 }
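
Editor's note (illustrative, not part of the patch): the new serde arm matches
`UnaryExpression(child) if expr.prettyName == "trycast"` because Spark 3.2/3.3
plan try_cast as a dedicated TryCast expression, whereas Spark 3.4+ folds TRY
into Cast's EvalMode; both paths now funnel into the shared handleCast helper
with the eval mode passed as a string ("TRY", "ANSI", or "LEGACY"). Below is a
minimal sketch of the end-user behavior the updated CometCastSuite exercises;
the object name and the local SparkSession setup are assumptions made for the
example, not code taken from this patch.

    // A minimal sketch, assuming spark-sql 3.2+ on the classpath.
    import org.apache.spark.sql.SparkSession

    object TryCastBehaviorSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession
          .builder()
          .master("local[1]")
          .appName("trycast-sketch")
          .getOrCreate()

        // try_cast returns NULL for an invalid cast instead of failing;
        // this is the "TRY" eval mode path in handleCast.
        spark.sql("SELECT try_cast('abc' AS INT) AS v").show()
        // +----+
        // |   v|
        // +----+
        // |null|
        // +----+

        // With ANSI mode enabled, the equivalent plain CAST throws at
        // runtime instead (the exact exception class varies by Spark
        // version); this is the "ANSI" eval mode path.
        spark.conf.set("spark.sql.ansi.enabled", "true")
        // spark.sql("SELECT cast('abc' AS INT) AS v").show()  // would throw

        spark.stop()
      }
    }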