Skip to content

Commit

Permalink
More
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jul 5, 2024
1 parent ee8c25e commit ad2aa77
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql

import scala.util.control.NonFatal

import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.ErrorMessageFormat.MINIMAL
import org.apache.spark.SparkThrowableHelper.getMessage
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeCommandBase}
import org.apache.spark.sql.types.StructType

/**
 * Test helper that runs SQL statements, normalizes their textual output so that it can be
 * compared across runs, and converts exceptions into the same `(schema, output)` shape.
 */
trait CometSQLQueryTestHelper {

  // Placeholder substituted for values that differ between runs (paths, timestamps, ...).
  private val notIncludedMsg = "[not included in comparison]"
  // Canonical name of the concrete class mixing in this trait; used to recognize paths that
  // embed the class name.
  // NOTE(review): the name is spliced into regular expressions unquoted, so its dots match any
  // character, and `getCanonicalName` is null for anonymous/local classes -- confirm that every
  // class mixing in this trait is a named top-level class.
  private val clsName = this.getClass.getCanonicalName
  // Schema string reported when a query fails and therefore has no output schema.
  protected val emptySchema = StructType(Seq.empty).catalogString

  /**
   * Masks the run-dependent parts of a single output line.
   *
   * The substitutions are applied in order: expression ids (`#123` becomes `#x`), plan ids,
   * `Location` entries and `file:` URIs that contain the test class name, the `Created By`,
   * `Created Time` and `Last Access` entries, partition statistics, and WholeStageCodegen stage
   * id prefixes.
   *
   * @param line
   *   one line of query output
   * @return
   *   the line with the parts above replaced by stable placeholders
   */
  protected def replaceNotIncludedMsg(line: String): String = {
    line
      .replaceAll("#\\d+", "#x")
      .replaceAll("plan_id=\\d+", "plan_id=x")
      .replaceAll(s"Location.*$clsName/", s"Location $notIncludedMsg/{warehouse_dir}/")
      .replaceAll(s"file:[^\\s,]*$clsName", s"file:$notIncludedMsg/{warehouse_dir}")
      .replaceAll("Created By.*", s"Created By $notIncludedMsg")
      .replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
      .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")
      .replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg")
      .replaceAll("\\*\\(\\d+\\) ", "*") // remove the WholeStageCodegen codegenStageIds
  }

  /**
   * Executes a query and returns the result as (schema of the output, normalized output).
   *
   * The output lines are sorted unless the analyzed plan indicates that the result is already
   * ordered, so that results without a defined order can still be compared.
   *
   * @param session
   *   the session used to run the statement
   * @param sql
   *   the SQL text to execute
   * @return
   *   the catalog string of the result schema and the normalized output lines
   */
  protected def getNormalizedResult(session: SparkSession, sql: String): (String, Seq[String]) = {
    // Returns true if the plan is supposed to be sorted.
    def isSorted(plan: LogicalPlan): Boolean = plan match {
      case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
      case _: DescribeCommandBase | _: DescribeColumnCommand | _: DescribeRelation |
          _: DescribeColumn =>
        true
      case PhysicalOperation(_, _, Sort(_, true, _)) => true
      case _ => plan.children.iterator.exists(isSorted)
    }

    // The query is executed exactly once, below. No `explain()`/`show()` calls here: they would
    // print to stdout for every query and `show()` would run the query an additional time.
    val df = session.sql(sql)
    val schema = df.schema.catalogString
    // Get answer, but also get rid of the #1234 expression ids that show up in explain plans
    val answer = SQLExecution.withNewExecutionId(df.queryExecution, Some(sql)) {
      hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
    }

    // If the output is not pre-sorted, sort it.
    if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
  }

  /**
   * This method handles exceptions occurred during query execution as they may need special care
   * to become comparable to the expected output.
   *
   * On failure the returned schema is [[emptySchema]] and the output consists of two entries:
   * the exception class name followed by its (normalized) message.
   *
   * @param result
   *   a function that returns a pair of schema and output
   * @return
   *   the value of `result`, or the pair describing the exception it threw
   */
  protected def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
    val format = MINIMAL
    try {
      result
    } catch {
      // Errors that carry an error class are rendered with the MINIMAL message format.
      case e: SparkThrowable with Throwable if e.getErrorClass != null =>
        (emptySchema, Seq(e.getClass.getName, getMessage(e, format)))
      case a: AnalysisException =>
        // Do not output the logical plan tree which contains expression IDs.
        // Also implement a crude way of masking expression IDs in the error message
        // with a generic pattern "#x".
        val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
        (emptySchema, Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
      case s: SparkException if s.getCause != null =>
        // For a runtime exception, it is hard to match because its message contains
        // information of stage, task ID, etc.
        // To make result matching simpler, here we match the cause of the exception if it exists.
        s.getCause match {
          case e: SparkThrowable with Throwable if e.getErrorClass != null =>
            (emptySchema, Seq(e.getClass.getName, getMessage(e, format)))
          case cause =>
            (emptySchema, Seq(cause.getClass.getName, cause.getMessage))
        }
      case NonFatal(e) =>
        // If there is an exception, put the exception class followed by the message.
        (emptySchema, Seq(e.getClass.getName, e.getMessage))
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import org.apache.spark.sql.test.TestSparkSession
* Because we need to modify some methods of Spark `TPCDSQueryTestSuite` but they are private, we
* copy Spark `TPCDSQueryTestSuite`.
*/
class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelper {
class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with CometSQLQueryTestHelper {

private val tpcdsDataPath = sys.env.get("SPARK_TPCDS_DATA")

Expand Down Expand Up @@ -89,29 +89,6 @@ class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTes
""".stripMargin)
}

/**
 * Runs `sql` on `session` and returns the output schema (as a catalog string) together with the
 * normalized output lines. Lines are sorted unless the analyzed plan is already ordered.
 */
protected def normalizedResult(session: SparkSession, sql: String): (String, Seq[String]) = {
  // Tells whether the given plan is expected to produce ordered output.
  def isSorted(plan: LogicalPlan): Boolean = plan match {
    case _: DescribeCommandBase | _: DescribeColumnCommand | _: DescribeRelation |
        _: DescribeColumn =>
      true
    case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
    case PhysicalOperation(_, _, Sort(_, true, _)) => true
    case other => other.children.iterator.exists(isSorted)
  }

  val dataFrame = session.sql(sql)
  val schemaString = dataFrame.schema.catalogString
  val execution = dataFrame.queryExecution

  // Collect the rows under a fresh execution id, masking the #1234 expression ids and other
  // run-dependent fragments on every line.
  val lines = SQLExecution.withNewExecutionId(execution, Some(sql)) {
    hiveResultString(execution.executedPlan).map(line => replaceNotIncludedMsg(line))
  }

  // Output without a guaranteed order is sorted to make it comparable.
  val comparable = if (isSorted(execution.analyzed)) lines else lines.sorted
  (schemaString, comparable)
}

private def runQuery(query: String, goldenFile: File, conf: Map[String, String]): Unit = {
// This is `sortMergeJoinConf != conf` in Spark, i.e., it sorts results for other joins
// than sort merge join. But in some queries DataFusion sort returns correct results
Expand All @@ -120,7 +97,7 @@ class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTes
val shouldSortResults = true
withSQLConf(conf.toSeq: _*) {
try {
val (schema, output) = handleExceptions(normalizedResult(spark, query))
val (schema, output) = handleExceptions(getNormalizedResult(spark, query))
val queryString = query.trim
val outputString = output.mkString("\n").replaceAll("\\s+$", "")
if (regenerateGoldenFiles) {
Expand Down

0 comments on commit ad2aa77

Please sign in to comment.