diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 0b93f3206..73592d785 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

 import org.apache.comet.CometConf._
-import org.apache.comet.CometSparkSessionExtensions.{createMessage, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, shouldApplyRowToColumnar, withInfo}
+import org.apache.comet.CometSparkSessionExtensions.{createMessage, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, shouldApplyRowToColumnar, withInfo, withInfos}
 import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
@@ -120,7 +120,7 @@ class CometSparkSessionExtensions
           val info3 = createMessage(
             getPushedAggregate(scanExec.scan.asInstanceOf[ParquetScan]).isDefined,
             "Comet does not support pushed aggregate")
-          withInfo(scanExec, Seq(info1, info2, info3).flatten.mkString("\n"))
+          withInfos(scanExec, Seq(info1, info2, info3).flatten.toSet)
           scanExec

         // Other datasource V2 scan
@@ -147,7 +147,7 @@ class CometSparkSessionExtensions
             !isSchemaSupported(scanExec.scan.readSchema()),
             "Comet extension is not enabled for " +
               s"${scanExec.scan.getClass.getSimpleName}: Schema not supported")
-          withInfo(scanExec, Seq(info1, info2).flatten.mkString("\n"))
+          withInfos(scanExec, Seq(info1, info2).flatten.toSet)

         // If it is data source V2 other than Parquet or Iceberg,
         // attach the unsupported reason to the plan.
@@ -1045,17 +1045,36 @@ object CometSparkSessionExtensions extends Logging {
    *   The node with information (if any) attached
    */
   def withInfo[T <: TreeNode[_]](node: T, info: String, exprs: T*): T = {
-    val exprInfo = exprs
-      .flatMap { e => Seq(e.getTagValue(CometExplainInfo.EXTENSION_INFO)) }
-      .flatten
-      .mkString("\n")
-    if (info != null && info.nonEmpty && exprInfo.nonEmpty) {
-      node.setTagValue(CometExplainInfo.EXTENSION_INFO, Seq(exprInfo, info).mkString("\n"))
-    } else if (exprInfo.nonEmpty) {
-      node.setTagValue(CometExplainInfo.EXTENSION_INFO, exprInfo)
-    } else if (info != null && info.nonEmpty) {
-      node.setTagValue(CometExplainInfo.EXTENSION_INFO, info)
+    // support existing approach of passing in multiple infos in a newline-delimited string
+    val infoSet = if (info == null || info.isEmpty) {
+      Set.empty[String]
+    } else {
+      info.split("\n").toSet
     }
+    withInfos(node, infoSet, exprs: _*)
+  }
+
+  /**
+   * Attaches explain information to a TreeNode, rolling up the corresponding information tags
+   * from any child nodes. For now, we are using this to attach the reasons why certain Spark
+   * operators or expressions are disabled.
+   *
+   * @param node
+   *   The node to attach the explain information to. Typically a SparkPlan
+   * @param info
+   *   Information text. May contain zero or more strings. If not provided, then only information
+   *   from child nodes will be included.
+   * @param exprs
+   *   Child nodes. Information attached in these nodes will be included in the information
+   *   attached to @node
+   * @tparam T
+   *   The type of the TreeNode. Typically SparkPlan, AggregateExpression, or Expression
+   * @return
+   *   The node with information (if any) attached
+   */
+  def withInfos[T <: TreeNode[_]](node: T, info: Set[String], exprs: T*): T = {
+    val exprInfo = exprs.flatMap(_.getTagValue(CometExplainInfo.EXTENSION_INFO)).flatten.toSet
+    node.setTagValue(CometExplainInfo.EXTENSION_INFO, exprInfo ++ info)
     node
   }

@@ -1074,7 +1093,7 @@ object CometSparkSessionExtensions extends Logging {
    *   The node with information (if any) attached
    */
   def withInfo[T <: TreeNode[_]](node: T, exprs: T*): T = {
-    withInfo(node, "", exprs: _*)
+    withInfos(node, Set.empty, exprs: _*)
   }

   // Helper to reduce boilerplate
diff --git a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala
index 8d27501c8..8e5aee8b6 100644
--- a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala
+++ b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala
@@ -32,7 +32,7 @@ class ExtendedExplainInfo extends ExtendedExplainGenerator {

   override def generateExtendedInfo(plan: SparkPlan): String = {
     val info = extensionInfo(plan)
-    info.distinct.mkString("\n").trim
+    info.toSeq.sorted.mkString("\n").trim
   }

   private def getActualPlan(node: TreeNode[_]): TreeNode[_] = {
@@ -45,17 +45,17 @@ class ExtendedExplainInfo extends ExtendedExplainGenerator {
     }
   }

-  private def extensionInfo(node: TreeNode[_]): mutable.Seq[String] = {
+  private def extensionInfo(node: TreeNode[_]): Set[String] = {
     var info = mutable.Seq[String]()
     val sorted = sortup(node)
     sorted.foreach { p =>
-      val all: Array[String] =
-        getActualPlan(p).getTagValue(CometExplainInfo.EXTENSION_INFO).getOrElse("").split("\n")
+      val all: Set[String] =
+        getActualPlan(p).getTagValue(CometExplainInfo.EXTENSION_INFO).getOrElse(Set.empty[String])
       for (s <- all) {
         info = info :+ s
       }
     }
-    info.filter(!_.contentEquals("\n"))
+    info.toSet
   }

   // get all plan nodes, breadth first traversal, then returned the reversed list so
@@ -84,5 +84,5 @@ class ExtendedExplainInfo extends ExtendedExplainGenerator {
 }

 object CometExplainInfo {
-  val EXTENSION_INFO = new TreeNodeTag[String]("CometExtensionInfo")
+  val EXTENSION_INFO = new TreeNodeTag[Set[String]]("CometExtensionInfo")
 }
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index e3390296a..eb4429dc6 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1384,29 +1384,30 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       Seq(
         (
           s"SELECT cast(make_interval(c0, c1, c0, c1, c0, c0, c2) as string) as C from $table",
-          "make_interval is not supported"),
+          Set("make_interval is not supported")),
         (
           "SELECT " +
             "date_part('YEAR', make_interval(c0, c1, c0, c1, c0, c0, c2))" +
             " + " +
             "date_part('MONTH', make_interval(c0, c1, c0, c1, c0, c0, c2))" +
             s" as yrs_and_mths from $table",
-          "extractintervalyears is not supported\n" +
-            "extractintervalmonths is not supported"),
+          Set(
+            "extractintervalyears is not supported",
+            "extractintervalmonths is not supported")),
         (
           s"SELECT sum(c0), sum(c2) from $table group by c1",
-          "Native shuffle is not enabled\n" +
-            "AQEShuffleRead is not supported"),
+          Set("Native shuffle is not enabled", "AQEShuffleRead is not supported")),
         (
           "SELECT A.c1, A.sum_c0, A.sum_c2, B.casted from " +
            s"(SELECT c1, sum(c0) as sum_c0, sum(c2) as sum_c2 from $table group by c1) as A, " +
            s"(SELECT c1, cast(make_interval(c0, c1, c0, c1, c0, c0, c2) as string) as casted from $table) as B " +
            "where A.c1 = B.c1 ",
-          "Native shuffle is not enabled\n" +
-            "AQEShuffleRead is not supported\n" +
-            "make_interval is not supported\n" +
-            "BroadcastExchange is not supported\n" +
-            "BroadcastHashJoin disabled because not all child plans are native"))
+          Set(
+            "Native shuffle is not enabled",
+            "AQEShuffleRead is not supported",
+            "make_interval is not supported",
+            "BroadcastExchange is not supported",
+            "BroadcastHashJoin disabled because not all child plans are native")))
         .foreach(test => {
           val qry = test._1
           val expected = test._2
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 1ed447dc3..8ff287dec 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -245,7 +245,7 @@ abstract class CometTestBase

   protected def checkSparkAnswerAndCompareExplainPlan(
       df: DataFrame,
-      expectedInfo: String): Unit = {
+      expectedInfo: Set[String]): Unit = {
     var expected: Array[Row] = Array.empty
     var dfSpark: Dataset[Row] = null
     withSQLConf(
@@ -264,7 +264,7 @@ abstract class CometTestBase
     }
     val extendedInfo =
       new ExtendedExplainInfo().generateExtendedInfo(dfComet.queryExecution.executedPlan)
-    assert(extendedInfo.equalsIgnoreCase(expectedInfo))
+    assert(extendedInfo.equalsIgnoreCase(expectedInfo.toSeq.sorted.mkString("\n")))
   }

   private var _spark: SparkSession = _
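
Reviewer note, not part of the patch: the short Scala sketch below illustrates the semantics this change introduces, assuming the patched Comet classes and Spark's catalyst Literal are on the classpath; the object name ExtensionInfoExample is invented for the example. A newline-delimited string passed to the old withInfo entry point is split into a Set[String], merged with any reasons already attached to child nodes, and de-duplicated because EXTENSION_INFO now holds a Set rather than a single newline-delimited String.

import org.apache.spark.sql.catalyst.expressions.Literal

import org.apache.comet.CometExplainInfo
import org.apache.comet.CometSparkSessionExtensions.{withInfo, withInfos}

object ExtensionInfoExample extends App {
  // Two catalyst expressions stand in for arbitrary TreeNodes (e.g. SparkPlans).
  val child = Literal(1)
  val parent = Literal(2)

  // Old-style call: a newline-delimited string is split into individual reasons.
  withInfo(child, "make_interval is not supported\nAQEShuffleRead is not supported")

  // New-style call: reasons on `child` are rolled up into `parent` and merged with the
  // explicitly supplied set; the duplicate reason collapses because the tag is a Set.
  withInfos(parent, Set("AQEShuffleRead is not supported"), child)

  // EXTENSION_INFO is now a Set[String] tag, so reasons are unique and order-free.
  val reasons: Set[String] =
    parent.getTagValue(CometExplainInfo.EXTENSION_INFO).getOrElse(Set.empty[String])
  assert(
    reasons == Set(
      "make_interval is not supported",
      "AQEShuffleRead is not supported"))
}

ExtendedExplainInfo.generateExtendedInfo then renders the collected reasons sorted and newline-joined, which is why the updated test helper compares against expectedInfo.toSeq.sorted.mkString("\n").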