From e2c2cfd0c9f9babca81f41e120eab6839b7a4f09 Mon Sep 17 00:00:00 2001 From: Parth Chandra Date: Mon, 22 Apr 2024 09:22:32 -0700 Subject: [PATCH] address review comments --- .../comet/CometSparkSessionExtensions.scala | 20 ++- .../apache/comet/ExtendedExplainInfo.scala | 3 +- .../apache/comet/serde/QueryPlanSerde.scala | 142 +++++++++--------- .../ShimCometSparkSessionExtensions.scala | 2 + .../spark/sql/ExtendedExplainGenerator.scala | 2 + .../spark/sql/CometTPCQueryListBase.scala | 1 + 6 files changed, 97 insertions(+), 73 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index f96e87840..b178a6d26 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -175,7 +175,7 @@ class CometSparkSessionExtensions s"Schema $requiredSchema is not supported") val info2 = createMessage( !isSchemaSupported(partitionSchema), - s"Schema $partitionSchema is not supported") + s"Partition schema $partitionSchema is not supported") withInfo(scanExec, Seq(info1, info2).flatten.mkString(",")) scanExec } @@ -1016,6 +1016,24 @@ object CometSparkSessionExtensions extends Logging { node } + /** + * Attaches explain information to a TreeNode, rolling up the corresponding information tags + * from any child nodes + * + * @param node + * The node to attach the explain information to. Typically a SparkPlan + * @param exprs + * Child nodes. Information attached in these nodes will be included in the information + * attached to @node + * @tparam T + * The type of the TreeNode. Typically SparkPlan, AggregateExpression, or Expression + * @return + * The node with information (if any) attached + */ + def withInfo[T <: TreeNode[_]](node: T, exprs: T*): T = { + withInfo(node, "", exprs: _*) + } + // Helper to reduce boilerplate def createMessage(condition: Boolean, message: => String): Option[String] = { if (condition) { diff --git a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala index cbf3e76db..8d27501c8 100644 --- a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala +++ b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala @@ -58,7 +58,8 @@ class ExtendedExplainInfo extends ExtendedExplainGenerator { info.filter(!_.contentEquals("\n")) } - // get all plan nodes, breadth first, leaf nodes first + // get all plan nodes, breadth first traversal, then return the reversed list so + // leaf nodes are first private def sortup(node: TreeNode[_]): mutable.Queue[TreeNode[_]] = { val ordered = new mutable.Queue[TreeNode[_]]() val traversed = mutable.Queue[TreeNode[_]](getActualPlan(node)) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index f07cde50c..26adbde89 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -221,7 +221,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { if (dataType.isEmpty) { withInfo(aggExpr, s"datatype ${s.dataType} is not supported", child) } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) } None } @@ -258,7 +258,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${s.dataType} is not 
supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case Count(children) => @@ -274,7 +274,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setCount(countBuilder) .build()) } else { - withInfo(aggExpr, null, children: _*) + withInfo(aggExpr, children: _*) None } case min @ Min(child) if minMaxDataTypeSupported(min.dataType) => @@ -295,7 +295,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${min.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case max @ Max(child) if minMaxDataTypeSupported(max.dataType) => @@ -316,7 +316,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${max.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case first @ First(child, ignoreNulls) @@ -338,7 +338,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${first.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case last @ Last(child, ignoreNulls) @@ -360,7 +360,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${last.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case bitAnd @ BitAndAgg(child) if bitwiseAggTypeSupported(bitAnd.dataType) => @@ -381,7 +381,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${bitAnd.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case bitOr @ BitOrAgg(child) if bitwiseAggTypeSupported(bitOr.dataType) => @@ -402,7 +402,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${bitOr.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case bitXor @ BitXorAgg(child) if bitwiseAggTypeSupported(bitXor.dataType) => @@ -423,7 +423,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, s"datatype ${bitXor.dataType} is not supported", child) None } else { - withInfo(aggExpr, null, child) + withInfo(aggExpr, child) None } case cov @ CovSample(child1, child2, _) => @@ -525,7 +525,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case a @ Alias(_, _) => val r = exprToProtoInternal(a.child, inputs) if (r.isEmpty) { - withInfo(expr, null, a.child) + withInfo(expr, a.child) } r @@ -536,7 +536,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case Cast(child, dt, timeZoneId, evalMode) => val childExpr = exprToProtoInternal(child, inputs) -if(childExpr.isDefined) { + if (childExpr.isDefined) { val evalModeStr = if (evalMode.isInstanceOf[Boolean]) { // Spark 3.2 & 3.3 has ansiEnabled boolean if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY" @@ -545,10 +545,10 @@ if(childExpr.isDefined) { evalMode.toString } castToProto(timeZoneId, dt, childExpr, evalModeStr) -} else { -withInfo(expr, null, child) -None -} + } else { + withInfo(expr, child) + None + } case add @ Add(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) @@ -569,7 +569,7 @@ None .setAdd(addBuilder) .build()) } else { - withInfo(add, null, left, right) + withInfo(add, 
left, right) None } @@ -596,7 +596,7 @@ None .setSubtract(builder) .build()) } else { - withInfo(sub, null, left, right) + withInfo(sub, left, right) None } @@ -624,7 +624,7 @@ None .setMultiply(builder) .build()) } else { - withInfo(mul, null, left, right) + withInfo(mul, left, right) None } @@ -660,7 +660,7 @@ None .setDivide(builder) .build()) } else { - withInfo(div, null, left, right) + withInfo(div, left, right) None } case div @ Divide(left, right, _) => @@ -692,7 +692,7 @@ None .setRemainder(builder) .build()) } else { - withInfo(rem, null, left, right) + withInfo(rem, left, right) None } case rem @ Remainder(left, _, _) => @@ -719,7 +719,7 @@ None .setEq(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -738,7 +738,7 @@ None .setNeq(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -757,7 +757,7 @@ None .setEqNullSafe(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -776,7 +776,7 @@ None .setNeqNullSafe(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -795,7 +795,7 @@ None .setGt(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -814,7 +814,7 @@ None .setGtEq(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -833,7 +833,7 @@ None .setLt(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -852,7 +852,7 @@ None .setLtEq(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -922,7 +922,7 @@ None .setSubstring(builder) .build()) } else { - withInfo(expr, null, str) + withInfo(expr, str) None } @@ -942,7 +942,7 @@ None .setLike(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -980,7 +980,7 @@ None .setStartsWith(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -999,7 +999,7 @@ None .setEndsWith(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1018,7 +1018,7 @@ None .setContains(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1035,7 +1035,7 @@ None .setStringSpace(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1055,7 +1055,7 @@ None .setHour(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1075,7 +1075,7 @@ None .setMinute(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1094,7 +1094,7 @@ None .setTruncDate(builder) .build()) } else { - withInfo(expr, null, child, format) + withInfo(expr, child, format) None } @@ -1116,7 +1116,7 @@ None .setTruncTimestamp(builder) .build()) } else { - withInfo(expr, null, child, format) + withInfo(expr, child, format) None } @@ -1136,7 +1136,7 @@ None .setSecond(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1171,7 +1171,7 @@ None .setIsNull(castBuilder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1188,7 +1188,7 @@ None .setIsNotNull(castBuilder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1215,7 +1215,7 @@ None .setSortOrder(sortOrderBuilder) .build()) } else 
{ - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1234,7 +1234,7 @@ None .setAnd(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1253,7 +1253,7 @@ None .setOr(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1280,7 +1280,7 @@ None .setCheckOverflow(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1331,7 +1331,7 @@ None .build() Some(Expr.newBuilder().setAbs(abs).build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1501,7 +1501,7 @@ None .setIf(builder) .build()) } else { - withInfo(expr, null, predicate, trueValue, falseValue) + withInfo(expr, predicate, trueValue, falseValue) None } @@ -1526,7 +1526,7 @@ None if (elseValueExpr.isDefined) { builder.setElseExpr(elseValueExpr.get) } else { - withInfo(expr, null, elseValue.get) + withInfo(expr, elseValue.get) return None } } @@ -1536,7 +1536,7 @@ None .setCaseWhen(builder) .build()) } else { - withInfo(expr, null, allBranches: _*) + withInfo(expr, allBranches: _*) None } case ConcatWs(children) => @@ -1659,7 +1659,7 @@ None .setBitwiseAnd(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1676,7 +1676,7 @@ None .setBitwiseNot(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1695,7 +1695,7 @@ None .setBitwiseOr(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1714,7 +1714,7 @@ None .setBitwiseXor(builder) .build()) } else { - withInfo(expr, null, left, right) + withInfo(expr, left, right) None } @@ -1740,7 +1740,7 @@ None .setBitwiseShiftRight(builder) .build()) } else { - withInfo(expr, null, left, rightExpression) + withInfo(expr, left, rightExpression) None } @@ -1766,7 +1766,7 @@ None .setBitwiseShiftLeft(builder) .build()) } else { - withInfo(expr, null, left, rightExpression) + withInfo(expr, left, rightExpression) None } @@ -1796,7 +1796,7 @@ None .setNot(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1811,7 +1811,7 @@ None .setNegative(builder) .build()) } else { - withInfo(expr, null, child) + withInfo(expr, child) None } @@ -1823,7 +1823,7 @@ None if (childExpr.isDefined) { castToProto(None, a.dataType, childExpr, "LEGACY") } else { - withInfo(expr, null, a.children: _*) + withInfo(expr, a.children: _*) None } @@ -1850,7 +1850,7 @@ None Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) } else { - withInfo(expr, null, arguments: _*) + withInfo(expr, arguments: _*) None } @@ -1910,7 +1910,7 @@ None .setBloomFilterMightContain(builder) .build()) } else { - withInfo(expr, null, bloomFilter, value) + withInfo(expr, bloomFilter, value) None } @@ -1959,7 +1959,7 @@ None .build()) } else { val allExprs = list ++ Seq(value) - withInfo(expr, null, allExprs: _*) + withInfo(expr, allExprs: _*) None } } @@ -2038,7 +2038,7 @@ None .addAllProjectList(exprs.map(_.get).asJava) Some(result.setProjection(projectBuilder).build()) } else { - withInfo(op, null, projectList: _*) + withInfo(op, projectList: _*) None } @@ -2049,7 +2049,7 @@ None val filterBuilder = OperatorOuterClass.Filter.newBuilder().setPredicate(cond.get) Some(result.setFilter(filterBuilder).build()) } else { - withInfo(op, null, condition, child) + withInfo(op, condition, child) None } @@ -2062,7 +2062,7 @@ None .addAllSortOrders(sortOrders.map(_.get).asJava) 
Some(result.setSort(sortBuilder).build()) } else { - withInfo(op, null, sortOrder: _*) + withInfo(op, sortOrder: _*) None } @@ -2112,7 +2112,7 @@ None .setNumExprPerProject(projections.head.size) Some(result.setExpand(expandBuilder).build()) } else { - withInfo(op, null, allProjExprs: _*) + withInfo(op, allProjExprs: _*) None } @@ -2215,7 +2215,7 @@ None } else { val allChildren: Seq[Expression] = groupingExpressions ++ aggregateExpressions ++ aggregateAttributes - withInfo(op, null, allChildren: _*) + withInfo(op, allChildren: _*) None } } @@ -2241,7 +2241,7 @@ None val condition = join.condition.map { cond => val condProto = exprToProto(cond, join.left.output ++ join.right.output) if (condProto.isEmpty) { - withInfo(join, null, cond) + withInfo(join, cond) return None } condProto.get } @@ -2275,7 +2275,7 @@ None Some(result.setHashJoin(joinBuilder).build()) } else { val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys - withInfo(join, null, allExprs: _*) + withInfo(join, allExprs: _*) None } @@ -2337,7 +2337,7 @@ None Some(result.setSortMergeJoin(joinBuilder).build()) } else { val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys - withInfo(join, null, allExprs: _*) + withInfo(join, allExprs: _*) None } @@ -2497,7 +2497,7 @@ None childExpr: Expression*): Option[Expr] = { optExpr match { case None => - withInfo(expr, null, childExpr: _*) + withInfo(expr, childExpr: _*) None case o => o } diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala index 4291d6605..ffec1bd40 100644 --- a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -50,6 +50,8 @@ object ShimCometSparkSessionExtensions { .map(_.asInstanceOf[Int]) .headOption + // Extended info is available only since Spark 4.0.0 + // (https://issues.apache.org/jira/browse/SPARK-47289) def supportsExtendedExplainInfo(qe: QueryExecution): Boolean = { try { // Look for QueryExecution.extendedExplainInfo(scala.Function1[String, Unit], SparkPlan) diff --git a/spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala b/spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala index e72c894b3..3c5ae95e2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala +++ b/spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.execution.SparkPlan /** * A trait for a session extension to implement that provides addition explain plan information. + * We copy this from Spark 4.0 since this trait is not available in Spark 3.x. We can remove this + * after dropping Spark 3.x support. */ trait ExtendedExplainGenerator { diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala index a008dbd58..6a5832060 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala @@ -88,6 +88,7 @@ trait CometTPCQueryListBase CometConf.COMET_EXEC_ENABLED.key -> "true", CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + // Lower bloom filter thresholds to allow us to simulate the plan produced at larger scale. 
"spark.sql.optimizer.runtime.bloomFilter.creationSideThreshold" -> "1MB", "spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold" -> "1MB") {