diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 77ee4fc41..7ce74c87d 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -29,13 +29,10 @@ sealed trait SupportLevel object Compatible extends SupportLevel /** We support this feature but results can be different from Spark */ -object Incompatible extends SupportLevel +case class Incompatible(reason: Option[String] = None) extends SupportLevel /** We do not support this feature */ -object Unsupported extends SupportLevel - -/** We do not support this feature and we explain why */ -case class UnsupportedWithReason(reason: String) extends SupportLevel +case class Unsupported(reason: Option[String] = None) extends SupportLevel object CometCast { @@ -54,20 +51,18 @@ object CometCast { case (dt: DataType, _) if dt.typeName == "timestamp_ntz" => toType match { case DataTypes.TimestampType | DataTypes.DateType | DataTypes.StringType => - Incompatible + Incompatible() case _ => - Unsupported + Unsupported() } - case (DataTypes.DoubleType, _: DecimalType) => - Incompatible case (DataTypes.TimestampType, DataTypes.LongType) => - Incompatible + Incompatible() case (DataTypes.BinaryType, DataTypes.StringType) => - Incompatible + Incompatible() case (_: DecimalType, _: DecimalType) => // TODO we need to file an issue for adding specific tests for casting // between decimal types with different precision and scale - Incompatible + Incompatible() case (DataTypes.StringType, _) => canCastFromString(toType, timeZoneId, evalMode) case (_, DataTypes.StringType) => @@ -86,7 +81,7 @@ object CometCast { canCastFromFloat(toType) case (DataTypes.DoubleType, _) => canCastFromDouble(toType) - case _ => Unsupported + case _ => Unsupported() } } @@ -104,22 +99,22 @@ object CometCast { Compatible case DataTypes.FloatType | DataTypes.DoubleType 
=> // https://github.com/apache/datafusion-comet/issues/326 - Unsupported + Unsupported() case _: DecimalType => // https://github.com/apache/datafusion-comet/issues/325 - Unsupported + Unsupported() case DataTypes.DateType => // https://github.com/apache/datafusion-comet/issues/327 - Unsupported + Unsupported() case DataTypes.TimestampType if !timeZoneId.contains("UTC") => - UnsupportedWithReason(s"Unsupported timezone $timeZoneId") + Unsupported(Some(s"Unsupported timezone $timeZoneId")) case DataTypes.TimestampType if evalMode == "ANSI" => - UnsupportedWithReason(s"ANSI mode not supported") + Unsupported(Some("ANSI mode not supported")) case DataTypes.TimestampType => // https://github.com/apache/datafusion-comet/issues/328 Compatible case _ => - Unsupported + Unsupported() } } @@ -133,8 +128,8 @@ object CometCast { case DataTypes.TimestampType => Compatible case DataTypes.FloatType | DataTypes.DoubleType => // https://github.com/apache/datafusion-comet/issues/326 - Incompatible - case _ => Unsupported + Incompatible() + case _ => Unsupported() } } @@ -144,13 +139,13 @@ object CometCast { DataTypes.IntegerType => // https://github.com/apache/datafusion-comet/issues/352 // this seems like an edge case that isn't important for us to support - Unsupported + Unsupported() case DataTypes.LongType => // https://github.com/apache/datafusion-comet/issues/352 Compatible case DataTypes.StringType => Compatible case DataTypes.DateType => Compatible - case _ => Unsupported + case _ => Unsupported() } } @@ -158,7 +153,7 @@ object CometCast { case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType | DataTypes.DoubleType => Compatible - case _ => Unsupported + case _ => Unsupported() } private def canCastFromInt(toType: DataType): SupportLevel = toType match { @@ -166,22 +161,23 @@ object CometCast { DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType => Compatible - 
case _ => Unsupported + case _ => Unsupported() } private def canCastFromFloat(toType: DataType): SupportLevel = toType match { case DataTypes.BooleanType | DataTypes.DoubleType => Compatible - case _ => Unsupported + case _ => Unsupported() } private def canCastFromDouble(toType: DataType): SupportLevel = toType match { case DataTypes.BooleanType | DataTypes.FloatType => Compatible - case _ => Unsupported + case _: DecimalType => Incompatible() + case _ => Unsupported() } private def canCastFromDecimal(toType: DataType): SupportLevel = toType match { case DataTypes.FloatType | DataTypes.DoubleType => Compatible - case _ => Unsupported + case _ => Unsupported() } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e3bbd6de7..12ed31ff4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -43,7 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus, withInfo} -import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported, UnsupportedWithReason} +import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported} import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} @@ -591,20 +591,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { castSupport match { case Compatible => castToProto(timeZoneId, dt, childExpr, evalModeStr) - case Incompatible if CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get() => - 
logWarning(s"Calling incompatible CAST expression: $cast") + case Incompatible(reason) if CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get() => + logWarning(s"Calling incompatible CAST expression: $cast" + + reason.map(str => s" ($str)").getOrElse("")) castToProto(timeZoneId, dt, childExpr, evalModeStr) - case UnsupportedWithReason(reason) => + case Unsupported(reason) => withInfo( expr, s"Unsupported cast from ${child.dataType} to $dt " + - s"with timezone $timeZoneId and evalMode $evalModeStr: $reason") - None - case Unsupported => - withInfo( - expr, - s"Unsupported cast from ${child.dataType} to $dt " + - s"with timezone $timeZoneId and evalMode $evalModeStr") + s"with timezone $timeZoneId and evalMode $evalModeStr" + + reason.map(str => s" ($str)").getOrElse("") + ) None } } else {