diff --git a/dev/scalastyle-config.xml b/dev/scalastyle-config.xml
index 92ba690ad..9de6df51e 100644
--- a/dev/scalastyle-config.xml
+++ b/dev/scalastyle-config.xml
@@ -242,7 +242,7 @@ This file is divided into 3 sections:
java,scala,org,apache,3rdParty,comet
javax?\..*
scala\..*
- org\.(?!apache\.comet).*
+ org\.(?!apache).*
org\.apache\.(?!comet).*
(?!(javax?\.|scala\.|org\.apache\.comet\.)).*
org\.apache\.comet\..*
diff --git a/fuzz-testing/.gitignore b/fuzz-testing/.gitignore
new file mode 100644
index 000000000..570ff02a7
--- /dev/null
+++ b/fuzz-testing/.gitignore
@@ -0,0 +1,6 @@
+.idea
+target
+spark-warehouse
+queries.sql
+results*.md
+test*.parquet
\ No newline at end of file
diff --git a/fuzz-testing/README.md b/fuzz-testing/README.md
new file mode 100644
index 000000000..56af359f2
--- /dev/null
+++ b/fuzz-testing/README.md
@@ -0,0 +1,100 @@
+
+
+# Comet Fuzz
+
+Comet Fuzz is a standalone project for generating random data and queries, executing those queries against Spark
+with Comet disabled and then enabled, and checking the two runs for incompatibilities.
+
+Although it is a simple tool, it has already proved useful in finding many bugs.
+
+Comet Fuzz is inspired by the [SparkFuzz](https://ir.cwi.nl/pub/30222) paper from Databricks and CWI.
+
+## Roadmap
+
+Planned areas of improvement:
+
+- Support for all data types, expressions, and operators supported by Comet
+- Explicit casts
+- Unary and binary arithmetic expressions
+- IF and CASE WHEN expressions
+- Complex (nested) expressions
+- Literal scalar values in queries
+- Add option to avoid grouping and sorting on floating-point columns
+- Improve join query support:
+ - Support joins without join keys
+ - Support composite join keys
+ - Support multiple join keys
+ - Support join conditions that use expressions
+
+## Usage
+
+Build the jar file first.
+
+```shell
+mvn package
+```
+
+Set the `SPARK_HOME`, `SPARK_MASTER`, and `COMET_JAR` environment variables to appropriate values, then use
+`spark-submit` to run Comet Fuzz against a Spark cluster.
+
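+For example (the values below are placeholders; adjust the paths and master URL for your environment):
+
+```shell
+export SPARK_HOME=/path/to/spark
+export SPARK_MASTER=local[*]
+export COMET_JAR=/path/to/comet-spark.jar
+```
+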
+### Generating Data Files
+
+```shell
+$SPARK_HOME/bin/spark-submit \
+ --master $SPARK_MASTER \
+ --class org.apache.comet.fuzz.Main \
+ target/comet-fuzz-spark3.4_2.12-0.1.0-SNAPSHOT-jar-with-dependencies.jar \
+ data --num-files=2 --num-rows=200 --num-columns=100
+```
+
+### Generating Queries
+
+Generate random queries based on the test files created in the previous step.
+
+```shell
+$SPARK_HOME/bin/spark-submit \
+ --master $SPARK_MASTER \
+ --class org.apache.comet.fuzz.Main \
+ target/comet-fuzz-spark3.4_2.12-0.1.0-SNAPSHOT-jar-with-dependencies.jar \
+ queries --num-files=2 --num-queries=500
+```
+
+Note that the output filename is currently hard-coded as `queries.sql`.
+
+### Executing Queries
+
+```shell
+$SPARK_HOME/bin/spark-submit \
+ --master $SPARK_MASTER \
+ --conf spark.sql.extensions=org.apache.comet.CometSparkSessionExtensions \
+ --conf spark.comet.enabled=true \
+ --conf spark.comet.exec.enabled=true \
+ --conf spark.comet.exec.all.enabled=true \
+ --conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
+ --conf spark.comet.exec.shuffle.enabled=true \
+ --conf spark.comet.exec.shuffle.mode=auto \
+ --jars $COMET_JAR \
+ --driver-class-path $COMET_JAR \
+ --class org.apache.comet.fuzz.Main \
+ target/comet-fuzz-spark3.4_2.12-0.1.0-SNAPSHOT-jar-with-dependencies.jar \
+ run --num-files=2 --filename=queries.sql
+```
+
+Note that the output filename is currently hard-coded as `results-${System.currentTimeMillis()}.md`.
diff --git a/fuzz-testing/pom.xml b/fuzz-testing/pom.xml
new file mode 100644
index 000000000..f69d959f9
--- /dev/null
+++ b/fuzz-testing/pom.xml
@@ -0,0 +1,105 @@
+
+
+
+ 4.0.0
+
+
+ org.apache.comet
+ comet-parent-spark${spark.version.short}_${scala.binary.version}
+ 0.1.0-SNAPSHOT
+ ../pom.xml
+
+
+ comet-fuzz-spark${spark.version.short}_${scala.binary.version}
+ comet-fuzz
+ http://maven.apache.org
+ jar
+
+
+
+ false
+
+
+
+
+ org.scala-lang
+ scala-library
+ ${scala.version}
+ provided
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ provided
+
+
+ org.rogach
+ scallop_${scala.binary.version}
+
+
+
+
+ src/main/scala
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+ 3.1
+
+
+ ${java.version}
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 4.7.2
+
+
+
+ compile
+ testCompile
+
+
+
+
+
+ maven-assembly-plugin
+ 3.3.0
+
+
+ jar-with-dependencies
+
+
+
+
+ make-assembly
+ package
+
+ single
+
+
+
+
+
+
+
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala
new file mode 100644
index 000000000..47a6bd879
--- /dev/null
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.fuzz
+
+import java.math.{BigDecimal, RoundingMode}
+import java.nio.charset.Charset
+import java.sql.Timestamp
+
+import scala.util.Random
+
+import org.apache.spark.sql.{Row, SaveMode, SparkSession}
+import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructField, StructType}
+
+object DataGen {
+
+ def generateRandomFiles(
+ r: Random,
+ spark: SparkSession,
+ numFiles: Int,
+ numRows: Int,
+ numColumns: Int): Unit = {
+ for (i <- 0 until numFiles) {
+ generateRandomParquetFile(r, spark, s"test$i.parquet", numRows, numColumns)
+ }
+ }
+
+ def generateRandomParquetFile(
+ r: Random,
+ spark: SparkSession,
+ filename: String,
+ numRows: Int,
+ numColumns: Int): Unit = {
+
+ // generate schema using random data types
+ val fields = Range(0, numColumns)
+ .map(i => StructField(s"c$i", Utils.randomWeightedChoice(Meta.dataTypes), nullable = true))
+ val schema = StructType(fields)
+
+ // generate columnar data
+ val cols: Seq[Seq[Any]] = fields.map(f => generateColumn(r, f.dataType, numRows))
+
+ // convert to rows
+ val rows = Range(0, numRows).map(rowIndex => {
+ Row.fromSeq(cols.map(_(rowIndex)))
+ })
+
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
+ df.write.mode(SaveMode.Overwrite).parquet(filename)
+ }
+
+ def generateColumn(r: Random, dataType: DataType, numRows: Int): Seq[Any] = {
+ dataType match {
+ case DataTypes.BooleanType =>
+ generateColumn(r, DataTypes.LongType, numRows)
+ .map(_.asInstanceOf[Long].toShort)
+ .map(s => s % 2 == 0)
+ case DataTypes.ByteType =>
+ generateColumn(r, DataTypes.LongType, numRows).map(_.asInstanceOf[Long].toByte)
+ case DataTypes.ShortType =>
+ generateColumn(r, DataTypes.LongType, numRows).map(_.asInstanceOf[Long].toShort)
+ case DataTypes.IntegerType =>
+ generateColumn(r, DataTypes.LongType, numRows).map(_.asInstanceOf[Long].toInt)
+ case DataTypes.LongType =>
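+        // Bias values towards edge cases: null, zero, and the min/max values of each integral
+        // width, falling back to uniformly random longs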
+ Range(0, numRows).map(_ => {
+ r.nextInt(50) match {
+ case 0 => null
+ case 1 => 0L
+ case 2 => Byte.MinValue.toLong
+ case 3 => Byte.MaxValue.toLong
+ case 4 => Short.MinValue.toLong
+ case 5 => Short.MaxValue.toLong
+ case 6 => Int.MinValue.toLong
+ case 7 => Int.MaxValue.toLong
+ case 8 => Long.MinValue
+ case 9 => Long.MaxValue
+ case _ => r.nextLong()
+ }
+ })
+ case DataTypes.FloatType =>
+ Range(0, numRows).map(_ => {
+ r.nextInt(20) match {
+ case 0 => null
+ case 1 => Float.NegativeInfinity
+ case 2 => Float.PositiveInfinity
+ case 3 => Float.MinValue
+ case 4 => Float.MaxValue
+ case 5 => 0.0f
+ case 6 => -0.0f
+ case _ => r.nextFloat()
+ }
+ })
+ case DataTypes.DoubleType =>
+ Range(0, numRows).map(_ => {
+ r.nextInt(20) match {
+ case 0 => null
+ case 1 => Double.NegativeInfinity
+ case 2 => Double.PositiveInfinity
+ case 3 => Double.MinValue
+ case 4 => Double.MaxValue
+ case 5 => 0.0
+ case 6 => -0.0
+ case _ => r.nextDouble()
+ }
+ })
+ case dt: DecimalType =>
+ Range(0, numRows).map(_ =>
+ new BigDecimal(r.nextDouble()).setScale(dt.scale, RoundingMode.HALF_UP))
+ case DataTypes.StringType =>
+ Range(0, numRows).map(_ => {
+ r.nextInt(10) match {
+ case 0 => null
+ case 1 => r.nextInt().toByte.toString
+ case 2 => r.nextLong().toString
+ case 3 => r.nextDouble().toString
+ case _ => r.nextString(8)
+ }
+ })
+ case DataTypes.BinaryType =>
+ generateColumn(r, DataTypes.StringType, numRows)
+ .map {
+ case x: String =>
+ x.getBytes(Charset.defaultCharset())
+ case _ =>
+ null
+ }
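+      // Dates and timestamps are generated from a fixed base instant plus a random
+      // millisecond offset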
+ case DataTypes.DateType =>
+ Range(0, numRows).map(_ => new java.sql.Date(1716645600011L + r.nextInt()))
+ case DataTypes.TimestampType =>
+ Range(0, numRows).map(_ => new Timestamp(1716645600011L + r.nextInt()))
+ case _ => throw new IllegalStateException(s"Cannot generate data for $dataType yet")
+ }
+ }
+
+}
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala
new file mode 100644
index 000000000..799885d65
--- /dev/null
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.fuzz
+
+import scala.util.Random
+
+import org.rogach.scallop.{ScallopConf, Subcommand}
+import org.rogach.scallop.ScallopOption
+
+import org.apache.spark.sql.SparkSession
+
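+// Command-line arguments, parsed with Scallop into three subcommands:
+// `data` (generate input files), `queries` (generate random queries), and `run` (execute them)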
+class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
+ object generateData extends Subcommand("data") {
+ val numFiles: ScallopOption[Int] = opt[Int](required = true)
+ val numRows: ScallopOption[Int] = opt[Int](required = true)
+ val numColumns: ScallopOption[Int] = opt[Int](required = true)
+ }
+ addSubcommand(generateData)
+ object generateQueries extends Subcommand("queries") {
+ val numFiles: ScallopOption[Int] = opt[Int](required = false)
+ val numQueries: ScallopOption[Int] = opt[Int](required = true)
+ }
+ addSubcommand(generateQueries)
+ object runQueries extends Subcommand("run") {
+ val filename: ScallopOption[String] = opt[String](required = true)
+ val numFiles: ScallopOption[Int] = opt[Int](required = false)
+ val showMatchingResults: ScallopOption[Boolean] = opt[Boolean](required = false)
+ }
+ addSubcommand(runQueries)
+ verify()
+}
+
+object Main {
+
+ lazy val spark: SparkSession = SparkSession
+ .builder()
+ .getOrCreate()
+
+ def main(args: Array[String]): Unit = {
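+    // Use a fixed seed so that generated data and queries are reproducible between runs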
+ val r = new Random(42)
+
+ val conf = new Conf(args.toIndexedSeq)
+ conf.subcommand match {
+ case Some(conf.generateData) =>
+ DataGen.generateRandomFiles(
+ r,
+ spark,
+ numFiles = conf.generateData.numFiles(),
+ numRows = conf.generateData.numRows(),
+ numColumns = conf.generateData.numColumns())
+ case Some(conf.generateQueries) =>
+ QueryGen.generateRandomQueries(
+ r,
+ spark,
+ numFiles = conf.generateQueries.numFiles(),
+ conf.generateQueries.numQueries())
+ case Some(conf.runQueries) =>
+ QueryRunner.runQueries(
+ spark,
+ conf.runQueries.numFiles(),
+ conf.runQueries.filename(),
+ conf.runQueries.showMatchingResults())
+ case _ =>
+ // scalastyle:off println
+ println("Invalid subcommand")
+ // scalastyle:on println
+ sys.exit(-1)
+ }
+ }
+}
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
new file mode 100644
index 000000000..13ebbf9ed
--- /dev/null
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.fuzz
+
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.DataTypes
+
+object Meta {
+
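+  // Data types that can appear in generated schemas, with the relative weight used when
+  // choosing a column's type (see Utils.randomWeightedChoice)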
+ val dataTypes: Seq[(DataType, Double)] = Seq(
+ (DataTypes.BooleanType, 0.1),
+ (DataTypes.ByteType, 0.2),
+ (DataTypes.ShortType, 0.2),
+ (DataTypes.IntegerType, 0.2),
+ (DataTypes.LongType, 0.2),
+ (DataTypes.FloatType, 0.2),
+ (DataTypes.DoubleType, 0.2),
+ (DataTypes.createDecimalType(10, 2), 0.2),
+ (DataTypes.DateType, 0.2),
+ (DataTypes.TimestampType, 0.2),
+ // TimestampNTZType only in Spark 3.4+
+ // (DataTypes.TimestampNTZType, 0.2),
+ (DataTypes.StringType, 0.2),
+ (DataTypes.BinaryType, 0.1))
+
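+  // Scalar functions that can appear in generated queries, along with the number of
+  // column arguments to generate for each call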
+ val stringScalarFunc: Seq[Function] = Seq(
+ Function("substring", 3),
+ Function("coalesce", 1),
+    Function("startswith", 2),
+    Function("endswith", 2),
+ Function("contains", 2),
+ Function("ascii", 1),
+ Function("bit_length", 1),
+ Function("octet_length", 1),
+ Function("upper", 1),
+ Function("lower", 1),
+ Function("chr", 1),
+    Function("initcap", 1),
+ Function("trim", 1),
+ Function("ltrim", 1),
+ Function("rtrim", 1),
+ Function("btrim", 1),
+ Function("concat_ws", 2),
+ Function("repeat", 2),
+ Function("length", 1),
+ Function("reverse", 1),
+    Function("instr", 2),
+    Function("replace", 2),
+    Function("translate", 3))
+
+ val dateScalarFunc: Seq[Function] =
+ Seq(Function("year", 1), Function("hour", 1), Function("minute", 1), Function("second", 1))
+
+  val mathScalarFunc: Seq[Function] = Seq(
+    Function("abs", 1),
+    Function("acos", 1),
+    Function("asin", 1),
+    Function("atan", 1),
+    Function("atan2", 2),
+    Function("cos", 1),
+    Function("exp", 1),
+    Function("ln", 1),
+    Function("log10", 1),
+    Function("log2", 1),
+    Function("pow", 2),
+    Function("round", 1),
+    Function("signum", 1),
+    Function("sin", 1),
+    Function("sqrt", 1),
+    Function("tan", 1),
+    Function("ceil", 1),
+    Function("floor", 1))
+
+ val scalarFunc: Seq[Function] = stringScalarFunc ++ dateScalarFunc ++ mathScalarFunc
+
+ val aggFunc: Seq[Function] = Seq(
+ Function("min", 1),
+ Function("max", 1),
+ Function("count", 1),
+ Function("avg", 1),
+ Function("sum", 1),
+ Function("first", 1),
+ Function("last", 1),
+ Function("var_pop", 1),
+ Function("var_samp", 1),
+ Function("covar_pop", 1),
+ Function("covar_samp", 1),
+ Function("stddev_pop", 1),
+ Function("stddev_samp", 1),
+ Function("corr", 2))
+
+}
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
new file mode 100644
index 000000000..1daa26200
--- /dev/null
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.fuzz
+
+import java.io.{BufferedWriter, FileWriter}
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.sql.SparkSession
+
+object QueryGen {
+
+ def generateRandomQueries(
+ r: Random,
+ spark: SparkSession,
+ numFiles: Int,
+ numQueries: Int): Unit = {
+ for (i <- 0 until numFiles) {
+ spark.read.parquet(s"test$i.parquet").createTempView(s"test$i")
+ }
+
+ val w = new BufferedWriter(new FileWriter("queries.sql"))
+
+ val uniqueQueries = mutable.HashSet[String]()
+
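+    // Each query is randomly a join, an aggregate, or a scalar function projection.
+    // Duplicates are skipped so that the output file only contains unique queries.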
+ for (_ <- 0 until numQueries) {
+      val sql = r.nextInt(3) match {
+ case 0 => generateJoin(r, spark, numFiles)
+ case 1 => generateAggregate(r, spark, numFiles)
+ case 2 => generateScalar(r, spark, numFiles)
+ }
+ if (!uniqueQueries.contains(sql)) {
+ uniqueQueries += sql
+ w.write(sql + "\n")
+ }
+ }
+ w.close()
+ }
+
+ private def generateAggregate(r: Random, spark: SparkSession, numFiles: Int): String = {
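+    // Example: SELECT c0, c1, max(c2) FROM test0 GROUP BY c0, c1 ORDER BY c0, c1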
+ val tableName = s"test${r.nextInt(numFiles)}"
+ val table = spark.table(tableName)
+
+ val func = Utils.randomChoice(Meta.aggFunc, r)
+ val args = Range(0, func.num_args)
+ .map(_ => Utils.randomChoice(table.columns, r))
+
+ val groupingCols = Range(0, 2).map(_ => Utils.randomChoice(table.columns, r))
+
+ if (groupingCols.isEmpty) {
+ s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
+ s"FROM $tableName " +
+ s"ORDER BY ${args.mkString(", ")};"
+ } else {
+ s"SELECT ${groupingCols.mkString(", ")}, ${func.name}(${args.mkString(", ")}) " +
+ s"FROM $tableName " +
+ s"GROUP BY ${groupingCols.mkString(",")} " +
+ s"ORDER BY ${groupingCols.mkString(", ")};"
+ }
+ }
+
+ private def generateScalar(r: Random, spark: SparkSession, numFiles: Int): String = {
+ val tableName = s"test${r.nextInt(numFiles)}"
+ val table = spark.table(tableName)
+
+ val func = Utils.randomChoice(Meta.scalarFunc, r)
+ val args = Range(0, func.num_args)
+ .map(_ => Utils.randomChoice(table.columns, r))
+
+ // Example SELECT c0, log(c0) as x FROM test0
+ s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
+ s"FROM $tableName " +
+ s"ORDER BY ${args.mkString(", ")};"
+ }
+
+ private def generateJoin(r: Random, spark: SparkSession, numFiles: Int): String = {
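+    // Example: SELECT l.c0, l.c1, r.c0, r.c1 FROM test0 l INNER JOIN test1 r
+    //          ON l.c1 = r.c0 ORDER BY l.c0, l.c1, r.c0, r.c1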
+ val leftTableName = s"test${r.nextInt(numFiles)}"
+ val rightTableName = s"test${r.nextInt(numFiles)}"
+ val leftTable = spark.table(leftTableName)
+ val rightTable = spark.table(rightTableName)
+
+ val leftCol = Utils.randomChoice(leftTable.columns, r)
+ val rightCol = Utils.randomChoice(rightTable.columns, r)
+
+ val joinTypes = Seq(("INNER", 0.4), ("LEFT", 0.3), ("RIGHT", 0.3))
+ val joinType = Utils.randomWeightedChoice(joinTypes)
+
+ val leftColProjection = leftTable.columns.map(c => s"l.$c").mkString(", ")
+ val rightColProjection = rightTable.columns.map(c => s"r.$c").mkString(", ")
+ "SELECT " +
+ s"$leftColProjection, " +
+ s"$rightColProjection " +
+ s"FROM $leftTableName l " +
+ s"$joinType JOIN $rightTableName r " +
+ s"ON l.$leftCol = r.$rightCol " +
+ "ORDER BY " +
+ s"$leftColProjection, " +
+ s"$rightColProjection;"
+ }
+
+}
+
+case class Function(name: String, num_args: Int)
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
new file mode 100644
index 000000000..49f9fc3bd
--- /dev/null
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.fuzz
+
+import java.io.{BufferedWriter, FileWriter}
+
+import scala.io.Source
+
+import org.apache.spark.sql.{Row, SparkSession}
+
+object QueryRunner {
+
+ def runQueries(
+ spark: SparkSession,
+ numFiles: Int,
+ filename: String,
+ showMatchingResults: Boolean,
+ showFailedSparkQueries: Boolean = false): Unit = {
+
+ val outputFilename = s"results-${System.currentTimeMillis()}.md"
+ // scalastyle:off println
+ println(s"Writing results to $outputFilename")
+ // scalastyle:on println
+
+ val w = new BufferedWriter(new FileWriter(outputFilename))
+
+ // register input files
+ for (i <- 0 until numFiles) {
+ val table = spark.read.parquet(s"test$i.parquet")
+ val tableName = s"test$i"
+ table.createTempView(tableName)
+ w.write(
+ s"Created table $tableName with schema:\n\t" +
+ s"${table.schema.fields.map(f => s"${f.name}: ${f.dataType}").mkString("\n\t")}\n\n")
+ }
+
+ val querySource = Source.fromFile(filename)
+ try {
+ querySource
+ .getLines()
+ .foreach(sql => {
+
+ try {
+ // execute with Spark
+ spark.conf.set("spark.comet.enabled", "false")
+ val df = spark.sql(sql)
+ val sparkRows = df.collect()
+ val sparkPlan = df.queryExecution.executedPlan.toString
+
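+          // execute with Comet enabled and compare the results against the Spark run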
+ try {
+ spark.conf.set("spark.comet.enabled", "true")
+ val df = spark.sql(sql)
+ val cometRows = df.collect()
+ val cometPlan = df.queryExecution.executedPlan.toString
+
+ if (sparkRows.length == cometRows.length) {
+ var i = 0
+ while (i < sparkRows.length) {
+ val l = sparkRows(i)
+ val r = cometRows(i)
+ assert(l.length == r.length)
+ for (j <- 0 until l.length) {
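+                  // Compare values with special handling for floating-point numbers:
+                  // infinities and NaNs only need to match in kind, and finite values are
+                  // compared with a small absolute tolerance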
+ val same = (l(j), r(j)) match {
+ case (a: Float, b: Float) if a.isInfinity => b.isInfinity
+ case (a: Float, b: Float) if a.isNaN => b.isNaN
+ case (a: Float, b: Float) => (a - b).abs <= 0.000001f
+ case (a: Double, b: Double) if a.isInfinity => b.isInfinity
+ case (a: Double, b: Double) if a.isNaN => b.isNaN
+ case (a: Double, b: Double) => (a - b).abs <= 0.000001
+ case (a: Array[Byte], b: Array[Byte]) => a.sameElements(b)
+ case (a, b) => a == b
+ }
+ if (!same) {
+ showSQL(w, sql)
+ showPlans(w, sparkPlan, cometPlan)
+ w.write(s"First difference at row $i:\n")
+ w.write("Spark: `" + formatRow(l) + "`\n")
+ w.write("Comet: `" + formatRow(r) + "`\n")
+ i = sparkRows.length
+ }
+ }
+ i += 1
+ }
+ } else {
+ showSQL(w, sql)
+ showPlans(w, sparkPlan, cometPlan)
+ w.write(
+ s"[ERROR] Spark produced ${sparkRows.length} rows and " +
+ s"Comet produced ${cometRows.length} rows.\n")
+ }
+ } catch {
+ case e: Exception =>
+ // the query worked in Spark but failed in Comet, so this is likely a bug in Comet
+ showSQL(w, sql)
+ w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}\n")
+ }
+
+ // flush after every query so that results are saved in the event of the driver crashing
+ w.flush()
+
+ } catch {
+ case e: Exception =>
+ // we expect many generated queries to be invalid
+ if (showFailedSparkQueries) {
+ showSQL(w, sql)
+ w.write(s"Query failed in Spark: ${e.getMessage}\n")
+ }
+ }
+ })
+
+ } finally {
+ w.close()
+ querySource.close()
+ }
+ }
+
+ private def formatRow(row: Row): String = {
+ row.toSeq
+ .map {
+ case v: Array[Byte] => v.mkString
+ case other => other.toString
+ }
+ .mkString(",")
+ }
+
+ private def showSQL(w: BufferedWriter, sql: String, maxLength: Int = 120): Unit = {
+ w.write("## SQL\n")
+ w.write("```\n")
+ val words = sql.split(" ")
+ val currentLine = new StringBuilder
+ for (word <- words) {
+ if (currentLine.length + word.length + 1 > maxLength) {
+ w.write(currentLine.toString.trim)
+ w.write("\n")
+ currentLine.setLength(0)
+ }
+ currentLine.append(word).append(" ")
+ }
+ if (currentLine.nonEmpty) {
+ w.write(currentLine.toString.trim)
+ w.write("\n")
+ }
+ w.write("```\n")
+ }
+
+ private def showPlans(w: BufferedWriter, sparkPlan: String, cometPlan: String): Unit = {
+ w.write("### Spark Plan\n")
+ w.write(s"```\n$sparkPlan\n```\n")
+ w.write("### Comet Plan\n")
+ w.write(s"```\n$cometPlan\n```\n")
+ }
+
+}
diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Utils.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Utils.scala
new file mode 100644
index 000000000..19f9695a9
--- /dev/null
+++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Utils.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.fuzz
+
+import scala.util.Random
+
+object Utils {
+
+ def randomChoice[T](list: Seq[T], r: Random): T = {
+ list(r.nextInt(list.length))
+ }
+
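+  // Choose a random value, with probability proportional to each value's weight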
+ def randomWeightedChoice[T](valuesWithWeights: Seq[(T, Double)]): T = {
+ val totalWeight = valuesWithWeights.map(_._2).sum
+ val randomValue = Random.nextDouble() * totalWeight
+ var cumulativeWeight = 0.0
+
+ for ((value, weight) <- valuesWithWeights) {
+ cumulativeWeight += weight
+ if (cumulativeWeight >= randomValue) {
+ return value
+ }
+ }
+
+ // If for some reason the loop doesn't return, return the last value
+ valuesWithWeights.last._1
+ }
+
+}
diff --git a/pom.xml b/pom.xml
index 0ec834982..8c322bae0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -33,6 +33,7 @@ under the License.
     <module>common</module>
     <module>spark</module>
     <module>spark-integration</module>
+    <module>fuzz-testing</module>
@@ -409,6 +410,12 @@ under the License.
test
+      <dependency>
+        <groupId>org.rogach</groupId>
+        <artifactId>scallop_${scala.binary.version}</artifactId>
+        <version>5.1.0</version>
+      </dependency>
+