diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 3f924952d..47572f7c9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -61,8 +61,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim def supportedDataType(dt: DataType): Boolean = dt match { case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType | - _: DateType | _: BooleanType | _: NullType | _: TimestampNTZType => + _: DateType | _: BooleanType | _: NullType => true + case dt if dt.typeName == "timestamp_ntz" => true case dt => emitWarning(s"unsupported Spark data type: $dt") false diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index c7c80826a..3f4d7bfd3 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -581,3 +581,18 @@ class CometShuffleWriteProcessor( } } } + +/** + * Copied from Spark `PartitionIdPassthrough` as it is private in Spark 3.2. + */ +private[spark] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner { + override def getPartition(key: Any): Int = key.asInstanceOf[Int] +} + +/** + * Copied from Spark `ConstantPartitioner` as it doesn't exist in Spark 3.2. + */ +private[spark] class ConstantPartitioner extends Partitioner { + override def numPartitions: Int = 1 + override def getPartition(key: Any): Int = 0 +}