diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 490e0b0f26..648dc5dfbc 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -175,7 +175,7 @@ class CometSparkSessionExtensions s"Schema $requiredSchema is not supported") val info2 = createMessage( !isSchemaSupported(partitionSchema), - s"Schema $partitionSchema is not supported") + s"Partition schema $partitionSchema is not supported") withInfo(scanExec, Seq(info1, info2).flatten.mkString(",")) scanExec } @@ -996,6 +996,24 @@ object CometSparkSessionExtensions extends Logging { node } + /** + * Attaches explain information to a TreeNode, rolling up the corresponding information tags + * from any child nodes + * + * @param node + * The node to attach the explain information to. Typically a SparkPlan + * @param exprs + * Child nodes. Information attached in these nodes will be included in the information + * attached to @node + * @tparam T + * The type of the TreeNode. 
Typically SparkPlan, AggregateExpression, or Expression + * @return + * The node with information (if any) attached + */ + def withInfo[T <: TreeNode[_]](node: T, exprs: T*): T = { + withInfo(node, "", exprs: _*) + } + // Helper to reduce boilerplate def createMessage(condition: Boolean, message: => String): Option[String] = { if (condition) { diff --git a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala index cbf3e76dbd..8d27501c8a 100644 --- a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala +++ b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala @@ -58,7 +58,8 @@ class ExtendedExplainInfo extends ExtendedExplainGenerator { info.filter(!_.contentEquals("\n")) } - // get all plan nodes, breadth first, leaf nodes first + // get all plan nodes, breadth first traversal, then return the reversed list so + // leaf nodes are first private def sortup(node: TreeNode[_]): mutable.Queue[TreeNode[_]] = { val ordered = new mutable.Queue[TreeNode[_]]() val traversed = mutable.Queue[TreeNode[_]](getActualPlan(node)) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 92ad94cd33..36aab13813 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -221,7 +221,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { if (dataType.isEmpty) { withInfo(aggExpr, s"datatype ${s.dataType} is not supported", child) } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) } None } @@ -258,7 +258,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${s.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case Count(children) => @@ -274,7 +274,7 @@ object 
QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setCount(countBuilder) .build()) } else { - withInfo(aggExpr, null, children: _*) + withInfo(aggExpr, children: _*) None } case min @ Min(child) if minMaxDataTypeSupported(min.dataType) => @@ -295,7 +295,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${min.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case max @ Max(child) if minMaxDataTypeSupported(max.dataType) => @@ -316,7 +316,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${max.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case first @ First(child, ignoreNulls) @@ -338,7 +338,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${first.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case last @ Last(child, ignoreNulls) @@ -360,7 +360,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${last.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case bitAnd @ BitAndAgg(child) if bitwiseAggTypeSupported(bitAnd.dataType) => @@ -381,7 +381,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${bitAnd.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case bitOr @ BitOrAgg(child) if bitwiseAggTypeSupported(bitOr.dataType) => @@ -402,7 +402,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${bitOr.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case bitXor @ BitXorAgg(child) if 
bitwiseAggTypeSupported(bitXor.dataType) => @@ -423,7 +423,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${bitXor.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case cov @ CovSample(child1, child2, _) => @@ -523,7 +523,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case a @ Alias(_, _) => val r = exprToProtoInternal(a.child, inputs) if (r.isEmpty) { - withInfo(expr, null, a.child) + withInfo(expr, a.child) } r @@ -537,7 +537,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { if (childExpr.isDefined) { castToProto(timeZoneId, dt, childExpr) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -560,7 +560,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setAdd(addBuilder) .build()) } else { - withInfo(add, null, left, right) + withInfo(add, left, right) None } @@ -587,7 +587,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setSubtract(builder) .build()) } else { - withInfo(sub, null, left, right) + withInfo(sub, left, right) None } @@ -615,7 +615,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setMultiply(builder) .build()) } else { - withInfo(mul, null, left, right) + withInfo(mul, left, right) None } @@ -651,7 +651,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setDivide(builder) .build()) } else { - withInfo(div, null, left, right) + withInfo(div, left, right) None } case div @ Divide(left, right, _) => @@ -683,7 +683,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setRemainder(builder) .build()) } else { - withInfo(rem, null, left, right) + withInfo(rem, left, right) None } case rem @ Remainder(left, _, _) => @@ -710,7 +710,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setEq(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, 
left, right) None } @@ -729,7 +729,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setNeq(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -748,7 +748,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setEqNullSafe(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -767,7 +767,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setNeqNullSafe(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -786,7 +786,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setGt(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -805,7 +805,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setGtEq(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -824,7 +824,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setLt(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -843,7 +843,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setLtEq(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -913,7 +913,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setSubstring(builder) .build()) } else { - withInfo(expr, null, str) + withInfo(expr, str) None } @@ -933,7 +933,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setLike(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -971,7 +971,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setStartsWith(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -990,7 +990,7 @@ object 
QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setEndsWith(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1009,7 +1009,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setContains(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1026,7 +1026,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setStringSpace(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1046,7 +1046,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setHour(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1066,7 +1066,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setMinute(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1085,7 +1085,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setTruncDate(builder) .build()) } else { - withInfo(expr, null, child, format) + withInfo(expr, child, format) None } @@ -1107,7 +1107,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setTruncTimestamp(builder) .build()) } else { - withInfo(expr, null, child, format) + withInfo(expr, child, format) None } @@ -1127,7 +1127,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setSecond(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1161,7 +1161,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setIsNull(castBuilder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1178,7 +1178,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setIsNotNull(castBuilder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1205,7 +1205,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde 
{ .setSortOrder(sortOrderBuilder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1224,7 +1224,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setAnd(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1243,7 +1243,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setOr(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1270,7 +1270,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setCheckOverflow(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1321,7 +1321,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .build() Some(Expr.newBuilder().setAbs(abs).build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1491,7 +1491,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setIf(builder) .build()) } else { - withInfo(expr, null, predicate, trueValue, falseValue) + withInfo(expr, predicate, trueValue, falseValue) None } @@ -1516,7 +1516,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { if (elseValueExpr.isDefined) { builder.setElseExpr(elseValueExpr.get) } else { - withInfo(expr, null, elseValue.get) + withInfo(expr, elseValue.get) return None } } @@ -1526,7 +1526,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setCaseWhen(builder) .build()) } else { - withInfo(expr, null, allBranches: _*) + withInfo(expr, allBranches: _*) None } case ConcatWs(children) => @@ -1649,7 +1649,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setBitwiseAnd(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1666,7 +1666,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setBitwiseNot(builder) .build()) } else { - withInfo(expr, null, child) + 
withInfo(expr, child) None } @@ -1685,7 +1685,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setBitwiseOr(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1704,7 +1704,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setBitwiseXor(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1730,7 +1730,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setBitwiseShiftRight(builder) .build()) } else { - withInfo(expr, null, left, rightExpression) + withInfo(expr, left, rightExpression) None } @@ -1756,7 +1756,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setBitwiseShiftLeft(builder) .build()) } else { - withInfo(expr, null, left, rightExpression) + withInfo(expr, left, rightExpression) None } @@ -1786,7 +1786,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setNot(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1801,7 +1801,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setNegative(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1813,7 +1813,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { if (childExpr.isDefined) { castToProto(None, a.dataType, childExpr) } else { - withInfo(expr, null, a.children: _*) + withInfo(expr, a.children: _*) None } @@ -1840,7 +1840,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) } else { - withInfo(expr, null, arguments: _*) + withInfo(expr, arguments: _*) None } @@ -1900,7 +1900,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setBloomFilterMightContain(builder) .build()) } else { - withInfo(expr, null, bloomFilter, value) + withInfo(expr, bloomFilter, value) None } @@ -1949,7 +1949,7 @@ 
object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .build()) } else { val allExprs = list ++ Seq(value) - withInfo(expr, null, allExprs: _*) + withInfo(expr, allExprs: _*) None } } @@ -2028,7 +2028,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .addAllProjectList(exprs.map(_.get).asJava) Some(result.setProjection(projectBuilder).build()) } else { - withInfo(op, null, projectList: _*) + withInfo(op, projectList: _*) None } @@ -2039,7 +2039,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val filterBuilder = OperatorOuterClass.Filter.newBuilder().setPredicate(cond.get) Some(result.setFilter(filterBuilder).build()) } else { - withInfo(op, null, condition, child) + withInfo(op, condition, child) None } @@ -2052,7 +2052,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .addAllSortOrders(sortOrders.map(_.get).asJava) Some(result.setSort(sortBuilder).build()) } else { - withInfo(op, null, sortOrder: _*) + withInfo(op, sortOrder: _*) None } @@ -2102,7 +2102,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setNumExprPerProject(projections.head.size) Some(result.setExpand(expandBuilder).build()) } else { - withInfo(op, null, allProjExprs: _*) + withInfo(op, allProjExprs: _*) None } @@ -2205,7 +2205,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } else { val allChildren: Seq[Expression] = groupingExpressions ++ aggregateExpressions ++ aggregateAttributes - withInfo(op, null, allChildren: _*) + withInfo(op, allChildren: _*) None } } @@ -2231,7 +2231,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val condition = join.condition.map { cond => val condProto = exprToProto(cond, join.left.output ++ join.right.output) if (condProto.isEmpty) { - withInfo(join, null, cond) + withInfo(join, cond) return None } condProto.get @@ -2265,7 +2265,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { 
Some(result.setHashJoin(joinBuilder).build()) } else { val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys - withInfo(join, null, allExprs: _*) + withInfo(join, allExprs: _*) None } @@ -2327,7 +2327,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { Some(result.setSortMergeJoin(joinBuilder).build()) } else { val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys - withInfo(join, null, allExprs: _*) + withInfo(join, allExprs: _*) None } @@ -2487,7 +2487,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { childExpr: Expression*): Option[Expr] = { optExpr match { case None => - withInfo(expr, null, childExpr: _*) + withInfo(expr, childExpr: _*) None case o => o } diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala index 4291d66054..ffec1bd402 100644 --- a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -50,6 +50,8 @@ object ShimCometSparkSessionExtensions { .map(_.asInstanceOf[Int]) .headOption + // Extended info is available only since Spark 4.0.0 + // (https://issues.apache.org/jira/browse/SPARK-47289) def supportsExtendedExplainInfo(qe: QueryExecution): Boolean = { try { // Look for QueryExecution.extendedExplainInfo(scala.Function1[String, Unit], SparkPlan) diff --git a/spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala b/spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala index e72c894b3d..3c5ae95e29 100644 --- a/spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala +++ b/spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.execution.SparkPlan /** * A trait for a session extension to implement that provides addition explain plan 
information. + * We copy this from Spark 4.0 since this trait is not available in Spark 3.x. We can remove this + * after dropping Spark 3.x support. */ trait ExtendedExplainGenerator { diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala index a008dbd58e..6a58320606 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala @@ -88,6 +88,7 @@ trait CometTPCQueryListBase CometConf.COMET_EXEC_ENABLED.key -> "true", CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + // Lower bloom filter thresholds to allow us to simulate the plan produced at larger scale. "spark.sql.optimizer.runtime.bloomFilter.creationSideThreshold" -> "1MB", "spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold" -> "1MB") {