diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index f544b899a..2a4bb3a9a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
@@ -2417,10 +2417,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       return None
     }
 
-    if (join.buildSide == BuildRight && join.joinType == LeftAnti) {
-      // DataFusion HashJoin LeftAnti has bugs on null keys.
-      withInfo(join, "BuildRight with LeftAnti is not supported")
-      return None
+    join match {
+      case b: BroadcastHashJoinExec if b.isNullAwareAntiJoin =>
+        // DataFusion HashJoin LeftAnti has bugs on null keys.
+        withInfo(join, "DataFusion doesn't support null-aware anti join")
+        return None
+      case _ => // no-op
     }
 
     val condition = join.condition.map { cond =>
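For context (not part of the patch): a minimal sketch, assuming stock Spark 3.x planning, of the kind of query that typically becomes a `BroadcastHashJoinExec` with `isNullAwareAntiJoin = true` — a `NOT IN` subquery over a nullable key with a broadcastable right side. With this change such plans fall back to Spark rather than being serialized for DataFusion; a plain `LEFT ANTI JOIN` is unaffected.

```scala
// Hypothetical repro, not taken from the PR. Session/config names are assumptions.
import org.apache.spark.sql.SparkSession

object NullAwareAntiJoinExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    // Nullable key columns on both sides.
    Seq(Some(1), Some(2), None).toDF("a").createOrReplaceTempView("l")
    Seq(Some(1), None).toDF("b").createOrReplaceTempView("r")

    // NOT IN requires null-aware anti-join semantics: a left row qualifies only
    // if its key is non-null and not present in r.b, and the whole result is
    // empty if r.b contains any null.
    val q = spark.sql("SELECT a FROM l WHERE a NOT IN (SELECT b FROM r)")

    // Expect a BroadcastHashJoinExec ... LeftAnti with isNullAwareAntiJoin=true
    // in the physical plan; Comet now leaves this operator to Spark.
    q.explain()
    q.show()
  }
}
```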