improve comparison logic, make queries deterministic
andygrove committed May 26, 2024
1 parent 34b7624 commit a2dce70
Showing 6 changed files with 59 additions and 65 deletions.
2 changes: 1 addition & 1 deletion dev/scalastyle-config.xml
@@ -242,7 +242,7 @@ This file is divided into 3 sections:
<parameter name="groups">java,scala,org,apache,3rdParty,comet</parameter>
<parameter name="group.java">javax?\..*</parameter>
<parameter name="group.scala">scala\..*</parameter>
<parameter name="group.org">org\.(?!apache\.comet).*</parameter>
<parameter name="group.org">org\.(?!apache).*</parameter>
<parameter name="group.apache">org\.apache\.(?!comet).*</parameter>
<parameter name="group.3rdParty">(?!(javax?\.|scala\.|org\.apache\.comet\.)).*</parameter>
<parameter name="group.comet">org\.apache\.comet\..*</parameter>
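For illustration only (none of the import lines below are part of this commit, and `com.example.somelib` is a made-up placeholder), an import block satisfying the updated grouping would sort roughly like this:

```scala
// Hypothetical example of the enforced order: java/javax, scala,
// non-Apache org packages, org.apache (excluding Comet), other third
// party, and finally org.apache.comet.
import java.util.UUID

import scala.util.Random

import org.rogach.scallop.ScallopConf

import org.apache.spark.sql.SparkSession

import com.example.somelib.SomeClass

import org.apache.comet.fuzz.QueryGen
```

The practical effect of the regex change is that non-Apache `org.*` imports such as `org.rogach` stay in the `org` group, while all `org.apache.*` imports other than Comet's own now fall through to the `apache` group instead of also matching the `org` pattern.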
2 changes: 1 addition & 1 deletion fuzz-testing/README.md
@@ -43,7 +43,7 @@ Set appropriate values for `SPARK_HOME`, `SPARK_MASTER`, and `COMET_JAR` environ
$SPARK_HOME/bin/spark-submit \
--master $SPARK_MASTER \
--class org.apache.comet.fuzz.Main \
target/cometfuzz-0.1.0-SNAPSHOT-jar-with-dependencies.jar \
target/comet-fuzz-0.1.0-SNAPSHOT-jar-with-dependencies.jar \
data --num-files=2 --num-rows=200 --num-columns=100
```

1 change: 1 addition & 0 deletions fuzz-testing/pom.xml
@@ -32,6 +32,7 @@ under the License.
<artifactId>comet-fuzz-spark${spark.version.short}_${scala.binary.version}</artifactId>
<name>comet-fuzz</name>
<url>http://maven.apache.org</url>
<packaging>jar</packaging>

<properties>
<!-- Reverse default (skip installation), and then enable only for child modules -->
17 changes: 9 additions & 8 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala
@@ -22,26 +22,27 @@ package org.apache.comet.fuzz
import scala.util.Random

import org.rogach.scallop.{ScallopConf, Subcommand}
import org.rogach.scallop.ScallopOption

import org.apache.spark.sql.SparkSession

class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
val generateData: generateData = new generateData
class generateData extends Subcommand("data") {
val numFiles = opt[Int](required = true)
val numRows = opt[Int](required = true)
val numColumns = opt[Int](required = true)
val numFiles: ScallopOption[Int] = opt[Int](required = true)
val numRows: ScallopOption[Int] = opt[Int](required = true)
val numColumns: ScallopOption[Int] = opt[Int](required = true)
}
val generateQueries: generateQueries = new generateQueries
class generateQueries extends Subcommand("queries") {
val numFiles = opt[Int](required = false)
val numQueries = opt[Int](required = true)
val numFiles: ScallopOption[Int] = opt[Int](required = false)
val numQueries: ScallopOption[Int] = opt[Int](required = true)
}
val runQueries: runQueries = new runQueries
class runQueries extends Subcommand("run") {
val filename = opt[String](required = true)
val numFiles = opt[Int](required = false)
val showMatchingResults = opt[Boolean](required = false)
val filename: ScallopOption[String] = opt[String](required = true)
val numFiles: ScallopOption[Int] = opt[Int](required = false)
val showMatchingResults: ScallopOption[Boolean] = opt[Boolean](required = false)
}
addSubcommand(generateData)
addSubcommand(generateQueries)
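As a hypothetical usage sketch (not code from this commit), the typed options above are read back by unwrapping each `ScallopOption` after matching on the parsed subcommand. It assumes the truncated tail of `Conf` also registers `runQueries` and calls `verify()`, and the argument values are invented:

```scala
import org.apache.comet.fuzz.Conf

// Minimal sketch of consuming Conf: required options are unwrapped with
// apply(), optional ones are read defensively with getOrElse.
object ConfUsageExample extends App {
  val conf = new Conf(Seq("data", "--num-files=2", "--num-rows=200", "--num-columns=100"))
  conf.subcommand match {
    case Some(conf.generateData) =>
      val numFiles: Int = conf.generateData.numFiles()
      val numRows: Int = conf.generateData.numRows()
      val numColumns: Int = conf.generateData.numColumns()
      println(s"would generate $numFiles files, $numRows rows, $numColumns columns")
    case Some(conf.runQueries) =>
      val filename: String = conf.runQueries.filename()
      val showMatching: Boolean = conf.runQueries.showMatchingResults.getOrElse(false)
      println(s"would run queries from $filename (showMatchingResults=$showMatching)")
    case _ =>
      println("no subcommand specified")
  }
}
```

The explicit `ScallopOption[...]` annotations added in this file do not change behavior; they simply make the types of these public members explicit.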
20 changes: 14 additions & 6 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
@@ -49,6 +49,7 @@ object QueryGen {
// TODO add explicit casts
// TODO add unary and binary arithmetic expressions
// TODO add IF and CASE WHEN expressions
// TODO support nested expressions
}
if (!uniqueQueries.contains(sql)) {
uniqueQueries += sql
@@ -134,10 +135,14 @@ object QueryGen {
.map(_ => Utils.randomChoice(table.columns, r))
val groupingCols = Range(0, 2).map(_ => Utils.randomChoice(table.columns, r))
if (groupingCols.isEmpty) {
s"SELECT ${func.name}(${args.mkString(", ")}) FROM $tableName"
s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
s"FROM $tableName" +
s"ORDER BY ${args.mkString(", ")}"
} else {
s"SELECT ${groupingCols.mkString(", ")}, ${func.name}(${args.mkString(", ")}) " +
s"FROM $tableName GROUP BY ${groupingCols.mkString(",")}"
s"FROM $tableName " +
s"GROUP BY ${groupingCols.mkString(",")}" +
s"ORDER BY ${groupingCols.mkString(", ")}"
}
}

@@ -150,7 +155,9 @@
// TODO support using literals as well as columns
.map(_ => Utils.randomChoice(table.columns, r))

s"SELECT ${func.name}(${args.mkString(", ")}) FROM $tableName"
s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
s"FROM $tableName " +
s"ORDER BY ${args.mkString(", ")}"
}

private def generateJoin(r: Random, spark: SparkSession, numFiles: Int): String = {
@@ -169,9 +176,10 @@
val joinType = Utils.randomWeightedChoice(joinTypes)

"SELECT * " +
s"FROM ${leftTableName} " +
s"${joinType} JOIN ${rightTableName} " +
s"ON ${leftTableName}.${leftCol} = ${rightTableName}.${rightCol};"
s"FROM $leftTableName " +
s"$joinType JOIN $rightTableName " +
s"ON $leftTableName.$leftCol = $rightTableName.$rightCol " +
s"ORDER BY $leftCol;"
}

}
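Purely to illustrate the intent of the change (the table and column names below are hypothetical, not output of this commit), the aggregate and join generators now emit queries whose output order is pinned down by a trailing `ORDER BY`, which is what lets the runner compare Spark and Comet results row by row without sorting them first:

```scala
// Hypothetical shapes of the SQL now produced by the aggregate and join
// generators: the ORDER BY clause makes the result order deterministic.
// (Assumes the randomly chosen join type happened to be INNER.)
val exampleAggregateQuery =
  "SELECT c1, c2, min(c0) FROM test0 GROUP BY c1,c2 ORDER BY c1, c2"

val exampleJoinQuery =
  "SELECT * FROM test0 INNER JOIN test1 ON test0.c0 = test1.c0 ORDER BY c0;"
```

This is also why the TODO about sorting the collected output can be dropped from `QueryRunner.scala` below.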
82 changes: 33 additions & 49 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
@@ -62,72 +62,49 @@ object QueryRunner {
spark.conf.set("spark.comet.enabled", "false")
val df = spark.sql(sql)
val sparkRows = df.collect()

// TODO for now we sort the output to make this deterministic, but this means
// that we are never testing Comet's sort for correctness
val sparkRowsAsStrings = sparkRows.map(_.toString()).sorted
val sparkResult = sparkRowsAsStrings.mkString("\n")

val sparkPlan = df.queryExecution.executedPlan.toString

w.write(s"## $sql\n\n")

try {
spark.conf.set("spark.comet.enabled", "true")
val df = spark.sql(sql)
val cometRows = df.collect()
// TODO for now we sort the output to make this deterministic, but this means
// that we are never testing Comet's sort for correctness
val cometRowsAsStrings = cometRows.map(_.toString()).sorted
val cometResult = cometRowsAsStrings.mkString("\n")
val cometPlan = df.queryExecution.executedPlan.toString

if (sparkResult == cometResult) {
w.write(s"Spark and Comet produce the same results (${cometRows.length} rows).\n")
if (showMatchingResults) {
w.write("### Spark Plan\n")
w.write(s"```\n$sparkPlan\n```\n")

w.write("### Comet Plan\n")
w.write(s"```\n$cometPlan\n```\n")

w.write("### Query Result\n")
w.write("```\n")
w.write(s"$cometResult\n")
w.write("```\n\n")
}
} else {
w.write("[ERROR] Spark and Comet produced different results.\n")

w.write("### Spark Plan\n")
w.write(s"```\n$sparkPlan\n```\n")

w.write("### Comet Plan\n")
w.write(s"```\n$cometPlan\n```\n")

w.write("### Results \n")

w.write(
s"Spark produced ${sparkRows.length} rows and " +
s"Comet produced ${cometRows.length} rows.\n")

if (sparkRows.length == cometRows.length) {
var i = 0
while (i < sparkRows.length) {
if (sparkRowsAsStrings(i) != cometRowsAsStrings(i)) {
if (sparkRows.length == cometRows.length) {
var i = 0
while (i < sparkRows.length) {
val l = sparkRows(i)
val r = cometRows(i)
assert(l.length == r.length)
for (j <- 0 until l.length) {
val same = (l(j), r(j)) match {
case (a: Float, b: Float) => (a - b).abs <= 0.000001f
case (a: Double, b: Double) => (a - b).abs <= 0.000001
case (a, b) => a == b
}
if (!same) {
w.write(s"## $sql\n\n")
showPlans(w, sparkPlan, cometPlan)
w.write(s"First difference at row $i:\n")
w.write("Spark: `" + sparkRowsAsStrings(i) + "`\n")
w.write("Comet: `" + cometRowsAsStrings(i) + "`\n")
w.write("Spark: `" + l.mkString(",") + "`\n")
w.write("Comet: `" + r.mkString(", ") + "`\n")
i = sparkRows.length
}
i += 1
}
i += 1
}
} else {
w.write(s"## $sql\n\n")
showPlans(w, sparkPlan, cometPlan)
w.write(
s"[ERROR] Spark produced ${sparkRows.length} rows and " +
s"Comet produced ${cometRows.length} rows.\n")
}
} catch {
case e: Exception =>
// the query worked in Spark but failed in Comet, so this is likely a bug in Comet
w.write(s"Query failed in Comet: ${e.getMessage}\n")
w.write(s"## $sql\n\n")
w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}\n")
}

// flush after every query so that results are saved in the event of the driver crashing
@@ -149,4 +126,11 @@
}
}

private def showPlans(w: BufferedWriter, sparkPlan: String, cometPlan: String): Unit = {
w.write("### Spark Plan\n")
w.write(s"```\n$sparkPlan\n```\n")
w.write("### Comet Plan\n")
w.write(s"```\n$cometPlan\n```\n")
}

}
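To see the new comparison rule in isolation, here is a small standalone helper (a hypothetical sketch, not part of the commit) that mirrors the per-value match used in the rewritten loop, together with the kind of rounding noise it is meant to tolerate:

```scala
// Hypothetical standalone version of the per-value comparison: floats and
// doubles are compared with a small absolute tolerance, everything else
// with plain equality.
def approxEqual(l: Any, r: Any): Boolean = (l, r) match {
  case (a: Float, b: Float) => (a - b).abs <= 0.000001f
  case (a: Double, b: Double) => (a - b).abs <= 0.000001
  case (a, b) => a == b
}

// Values that differ only by floating-point noise compare as equal,
// while genuinely different values still fail.
assert(approxEqual(1.0000001d, 1.0000002d))
assert(!approxEqual(1.0d, 1.1d))
assert(approxEqual("abc", "abc"))
```

Note that this is an absolute tolerance, so it is only meaningful when the compared values are of moderate magnitude; a relative tolerance would scale better across magnitudes, but that is a separate design choice from what this commit implements.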
