Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed May 2, 2024
1 parent 5255d6c commit 0c8da56
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 39 deletions.
52 changes: 24 additions & 28 deletions spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,10 @@ sealed trait SupportLevel
object Compatible extends SupportLevel

/** We support this feature but results can be different from Spark */
object Incompatible extends SupportLevel
case class Incompatible(reason: Option[String] = None) extends SupportLevel

/** We do not support this feature */
object Unsupported extends SupportLevel

/** We do not support this feature and we explain why */
case class UnsupportedWithReason(reason: String) extends SupportLevel
case class Unsupported(reason: Option[String] = None) extends SupportLevel

object CometCast {

Expand All @@ -54,20 +51,18 @@ object CometCast {
case (dt: DataType, _) if dt.typeName == "timestamp_ntz" =>
toType match {
case DataTypes.TimestampType | DataTypes.DateType | DataTypes.StringType =>
Incompatible
Incompatible()
case _ =>
Unsupported
Unsupported()
}
case (DataTypes.DoubleType, _: DecimalType) =>
Incompatible
case (DataTypes.TimestampType, DataTypes.LongType) =>
Incompatible
Incompatible()
case (DataTypes.BinaryType, DataTypes.StringType) =>
Incompatible
Incompatible()
case (_: DecimalType, _: DecimalType) =>
// TODO we need to file an issue for adding specific tests for casting
// between decimal types with different precision and scale
Incompatible
Incompatible()
case (DataTypes.StringType, _) =>
canCastFromString(toType, timeZoneId, evalMode)
case (_, DataTypes.StringType) =>
Expand All @@ -86,7 +81,7 @@ object CometCast {
canCastFromFloat(toType)
case (DataTypes.DoubleType, _) =>
canCastFromDouble(toType)
case _ => Unsupported
case _ => Unsupported()
}
}

Expand All @@ -104,22 +99,22 @@ object CometCast {
Compatible
case DataTypes.FloatType | DataTypes.DoubleType =>
// https://github.com/apache/datafusion-comet/issues/326
Unsupported
Unsupported()
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/325
Unsupported
Unsupported()
case DataTypes.DateType =>
// https://github.com/apache/datafusion-comet/issues/327
Unsupported
Unsupported()
case DataTypes.TimestampType if !timeZoneId.contains("UTC") =>
UnsupportedWithReason(s"Unsupported timezone $timeZoneId")
Unsupported(Some(s"Unsupported timezone $timeZoneId"))
case DataTypes.TimestampType if evalMode == "ANSI" =>
UnsupportedWithReason(s"ANSI mode not supported")
Unsupported(Some(s"ANSI mode not supported"))
case DataTypes.TimestampType =>
// https://github.com/apache/datafusion-comet/issues/328
Compatible
case _ =>
Unsupported
Unsupported()
}
}

Expand All @@ -133,8 +128,8 @@ object CometCast {
case DataTypes.TimestampType => Compatible
case DataTypes.FloatType | DataTypes.DoubleType =>
// https://github.com/apache/datafusion-comet/issues/326
Incompatible
case _ => Unsupported
Incompatible()
case _ => Unsupported()
}
}

Expand All @@ -144,44 +139,45 @@ object CometCast {
DataTypes.IntegerType =>
// https://github.com/apache/datafusion-comet/issues/352
// this seems like an edge case that isn't important for us to support
Unsupported
Unsupported()
case DataTypes.LongType =>
// https://github.com/apache/datafusion-comet/issues/352
Compatible
case DataTypes.StringType => Compatible
case DataTypes.DateType => Compatible
case _ => Unsupported
case _ => Unsupported()
}
}

private def canCastFromBoolean(toType: DataType): SupportLevel = toType match {
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
DataTypes.FloatType | DataTypes.DoubleType =>
Compatible
case _ => Unsupported
case _ => Unsupported()
}

private def canCastFromInt(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType | DataTypes.DoubleType |
_: DecimalType =>
Compatible
case _ => Unsupported
case _ => Unsupported()
}

private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType | DataTypes.DoubleType => Compatible
case _ => Unsupported
case _ => Unsupported()
}

private def canCastFromDouble(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType | DataTypes.FloatType => Compatible
case _ => Unsupported
case _: DecimalType => Incompatible()
case _ => Unsupported()
}

private def canCastFromDecimal(toType: DataType): SupportLevel = toType match {
case DataTypes.FloatType | DataTypes.DoubleType => Compatible
case _ => Unsupported
case _ => Unsupported()
}

}
19 changes: 8 additions & 11 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus, withInfo}
import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported, UnsupportedWithReason}
import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported}
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
Expand Down Expand Up @@ -591,20 +591,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
castSupport match {
case Compatible =>
castToProto(timeZoneId, dt, childExpr, evalModeStr)
case Incompatible if CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get() =>
logWarning(s"Calling incompatible CAST expression: $cast")
case Incompatible(reason) if CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get() =>
logWarning(s"Calling incompatible CAST expression: $cast" +
reason.map(str => s" ($str)").getOrElse(""))
castToProto(timeZoneId, dt, childExpr, evalModeStr)
case UnsupportedWithReason(reason) =>
case Unsupported(reason) =>
withInfo(
expr,
s"Unsupported cast from ${child.dataType} to $dt " +
s"with timezone $timeZoneId and evalMode $evalModeStr: $reason")
None
case Unsupported =>
withInfo(
expr,
s"Unsupported cast from ${child.dataType} to $dt " +
s"with timezone $timeZoneId and evalMode $evalModeStr")
s"with timezone $timeZoneId and evalMode $evalModeStr" +
reason.map(str => s" ($str)").getOrElse("")
)
None
}
} else {
Expand Down

0 comments on commit 0c8da56

Please sign in to comment.