diff --git a/dev/scalastyle-config.xml b/dev/scalastyle-config.xml index 92ba690ad..9de6df51e 100644 --- a/dev/scalastyle-config.xml +++ b/dev/scalastyle-config.xml @@ -242,7 +242,7 @@ This file is divided into 3 sections: java,scala,org,apache,3rdParty,comet javax?\..* scala\..* - org\.(?!apache\.comet).* + org\.(?!apache).* org\.apache\.(?!comet).* (?!(javax?\.|scala\.|org\.apache\.comet\.)).* org\.apache\.comet\..* diff --git a/fuzz-testing/.gitignore b/fuzz-testing/.gitignore new file mode 100644 index 000000000..570ff02a7 --- /dev/null +++ b/fuzz-testing/.gitignore @@ -0,0 +1,6 @@ +.idea +target +spark-warehouse +queries.sql +results*.md +test*.parquet \ No newline at end of file diff --git a/fuzz-testing/README.md b/fuzz-testing/README.md new file mode 100644 index 000000000..56af359f2 --- /dev/null +++ b/fuzz-testing/README.md @@ -0,0 +1,100 @@ + + +# Comet Fuzz + +Comet Fuzz is a standalone project for generating random data and queries and executing queries against Spark +with Comet disabled and enabled and checking for incompatibilities. + +Although it is a simple tool it has already been useful in finding many bugs. + +Comet Fuzz is inspired by the [SparkFuzz](https://ir.cwi.nl/pub/30222) paper from Databricks and CWI. + +## Roadmap + +Planned areas of improvement: + +- Support for all data types, expressions, and operators supported by Comet +- Explicit casts +- Unary and binary arithmetic expressions +- IF and CASE WHEN expressions +- Complex (nested) expressions +- Literal scalar values in queries +- Add option to avoid grouping and sorting on floating-point columns +- Improve join query support: + - Support joins without join keys + - Support composite join keys + - Support multiple join keys + - Support join conditions that use expressions + +## Usage + +Build the jar file first. + +```shell +mvn package +``` + +Set appropriate values for `SPARK_HOME`, `SPARK_MASTER`, and `COMET_JAR` environment variables and then use +`spark-submit` to run CometFuzz against a Spark cluster. + +### Generating Data Files + +```shell +$SPARK_HOME/bin/spark-submit \ + --master $SPARK_MASTER \ + --class org.apache.comet.fuzz.Main \ + target/comet-fuzz-spark3.4_2.12-0.1.0-SNAPSHOT-jar-with-dependencies.jar \ + data --num-files=2 --num-rows=200 --num-columns=100 +``` + +### Generating Queries + +Generate random queries that are based on the available test files. + +```shell +$SPARK_HOME/bin/spark-submit \ + --master $SPARK_MASTER \ + --class org.apache.comet.fuzz.Main \ + target/comet-fuzz-spark3.4_2.12-0.1.0-SNAPSHOT-jar-with-dependencies.jar \ + queries --num-files=2 --num-queries=500 +``` + +Note that the output filename is currently hard-coded as `queries.sql` + +### Execute Queries + +```shell +$SPARK_HOME/bin/spark-submit \ + --master $SPARK_MASTER \ + --conf spark.sql.extensions=org.apache.comet.CometSparkSessionExtensions \ + --conf spark.comet.enabled=true \ + --conf spark.comet.exec.enabled=true \ + --conf spark.comet.exec.all.enabled=true \ + --conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \ + --conf spark.comet.exec.shuffle.enabled=true \ + --conf spark.comet.exec.shuffle.mode=auto \ + --jars $COMET_JAR \ + --driver-class-path $COMET_JAR \ + --class org.apache.comet.fuzz.Main \ + target/comet-fuzz-spark3.4_2.12-0.1.0-SNAPSHOT-jar-with-dependencies.jar \ + run --num-files=2 --filename=queries.sql +``` + +Note that the output filename is currently hard-coded as `results-${System.currentTimeMillis()}.md` diff --git a/fuzz-testing/pom.xml b/fuzz-testing/pom.xml new file mode 100644 index 000000000..f69d959f9 --- /dev/null +++ b/fuzz-testing/pom.xml @@ -0,0 +1,105 @@ + + + + 4.0.0 + + + org.apache.comet + comet-parent-spark${spark.version.short}_${scala.binary.version} + 0.1.0-SNAPSHOT + ../pom.xml + + + comet-fuzz-spark${spark.version.short}_${scala.binary.version} + comet-fuzz + http://maven.apache.org + jar + + + + false + + + + + org.scala-lang + scala-library + ${scala.version} + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + provided + + + org.rogach + scallop_${scala.binary.version} + + + + + src/main/scala + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + ${java.version} + ${java.version} + + + + net.alchim31.maven + scala-maven-plugin + 4.7.2 + + + + compile + testCompile + + + + + + maven-assembly-plugin + 3.3.0 + + + jar-with-dependencies + + + + + make-assembly + package + + single + + + + + + + diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala new file mode 100644 index 000000000..47a6bd879 --- /dev/null +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.fuzz + +import java.math.{BigDecimal, RoundingMode} +import java.nio.charset.Charset +import java.sql.Timestamp + +import scala.util.Random + +import org.apache.spark.sql.{Row, SaveMode, SparkSession} +import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructField, StructType} + +object DataGen { + + def generateRandomFiles( + r: Random, + spark: SparkSession, + numFiles: Int, + numRows: Int, + numColumns: Int): Unit = { + for (i <- 0 until numFiles) { + generateRandomParquetFile(r, spark, s"test$i.parquet", numRows, numColumns) + } + } + + def generateRandomParquetFile( + r: Random, + spark: SparkSession, + filename: String, + numRows: Int, + numColumns: Int): Unit = { + + // generate schema using random data types + val fields = Range(0, numColumns) + .map(i => StructField(s"c$i", Utils.randomWeightedChoice(Meta.dataTypes), nullable = true)) + val schema = StructType(fields) + + // generate columnar data + val cols: Seq[Seq[Any]] = fields.map(f => generateColumn(r, f.dataType, numRows)) + + // convert to rows + val rows = Range(0, numRows).map(rowIndex => { + Row.fromSeq(cols.map(_(rowIndex))) + }) + + val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema) + df.write.mode(SaveMode.Overwrite).parquet(filename) + } + + def generateColumn(r: Random, dataType: DataType, numRows: Int): Seq[Any] = { + dataType match { + case DataTypes.BooleanType => + generateColumn(r, DataTypes.LongType, numRows) + .map(_.asInstanceOf[Long].toShort) + .map(s => s % 2 == 0) + case DataTypes.ByteType => + generateColumn(r, DataTypes.LongType, numRows).map(_.asInstanceOf[Long].toByte) + case DataTypes.ShortType => + generateColumn(r, DataTypes.LongType, numRows).map(_.asInstanceOf[Long].toShort) + case DataTypes.IntegerType => + generateColumn(r, DataTypes.LongType, numRows).map(_.asInstanceOf[Long].toInt) + case DataTypes.LongType => + Range(0, numRows).map(_ => { + r.nextInt(50) match { + case 0 => null + case 1 => 0L + case 2 => Byte.MinValue.toLong + case 3 => Byte.MaxValue.toLong + case 4 => Short.MinValue.toLong + case 5 => Short.MaxValue.toLong + case 6 => Int.MinValue.toLong + case 7 => Int.MaxValue.toLong + case 8 => Long.MinValue + case 9 => Long.MaxValue + case _ => r.nextLong() + } + }) + case DataTypes.FloatType => + Range(0, numRows).map(_ => { + r.nextInt(20) match { + case 0 => null + case 1 => Float.NegativeInfinity + case 2 => Float.PositiveInfinity + case 3 => Float.MinValue + case 4 => Float.MaxValue + case 5 => 0.0f + case 6 => -0.0f + case _ => r.nextFloat() + } + }) + case DataTypes.DoubleType => + Range(0, numRows).map(_ => { + r.nextInt(20) match { + case 0 => null + case 1 => Double.NegativeInfinity + case 2 => Double.PositiveInfinity + case 3 => Double.MinValue + case 4 => Double.MaxValue + case 5 => 0.0 + case 6 => -0.0 + case _ => r.nextDouble() + } + }) + case dt: DecimalType => + Range(0, numRows).map(_ => + new BigDecimal(r.nextDouble()).setScale(dt.scale, RoundingMode.HALF_UP)) + case DataTypes.StringType => + Range(0, numRows).map(_ => { + r.nextInt(10) match { + case 0 => null + case 1 => r.nextInt().toByte.toString + case 2 => r.nextLong().toString + case 3 => r.nextDouble().toString + case _ => r.nextString(8) + } + }) + case DataTypes.BinaryType => + generateColumn(r, DataTypes.StringType, numRows) + .map { + case x: String => + x.getBytes(Charset.defaultCharset()) + case _ => + null + } + case DataTypes.DateType => + Range(0, numRows).map(_ => new java.sql.Date(1716645600011L + r.nextInt())) + case DataTypes.TimestampType => + Range(0, numRows).map(_ => new Timestamp(1716645600011L + r.nextInt())) + case _ => throw new IllegalStateException(s"Cannot generate data for $dataType yet") + } + } + +} diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala new file mode 100644 index 000000000..799885d65 --- /dev/null +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.fuzz + +import scala.util.Random + +import org.rogach.scallop.{ScallopConf, Subcommand} +import org.rogach.scallop.ScallopOption + +import org.apache.spark.sql.SparkSession + +class Conf(arguments: Seq[String]) extends ScallopConf(arguments) { + object generateData extends Subcommand("data") { + val numFiles: ScallopOption[Int] = opt[Int](required = true) + val numRows: ScallopOption[Int] = opt[Int](required = true) + val numColumns: ScallopOption[Int] = opt[Int](required = true) + } + addSubcommand(generateData) + object generateQueries extends Subcommand("queries") { + val numFiles: ScallopOption[Int] = opt[Int](required = false) + val numQueries: ScallopOption[Int] = opt[Int](required = true) + } + addSubcommand(generateQueries) + object runQueries extends Subcommand("run") { + val filename: ScallopOption[String] = opt[String](required = true) + val numFiles: ScallopOption[Int] = opt[Int](required = false) + val showMatchingResults: ScallopOption[Boolean] = opt[Boolean](required = false) + } + addSubcommand(runQueries) + verify() +} + +object Main { + + lazy val spark: SparkSession = SparkSession + .builder() + .getOrCreate() + + def main(args: Array[String]): Unit = { + val r = new Random(42) + + val conf = new Conf(args.toIndexedSeq) + conf.subcommand match { + case Some(conf.generateData) => + DataGen.generateRandomFiles( + r, + spark, + numFiles = conf.generateData.numFiles(), + numRows = conf.generateData.numRows(), + numColumns = conf.generateData.numColumns()) + case Some(conf.generateQueries) => + QueryGen.generateRandomQueries( + r, + spark, + numFiles = conf.generateQueries.numFiles(), + conf.generateQueries.numQueries()) + case Some(conf.runQueries) => + QueryRunner.runQueries( + spark, + conf.runQueries.numFiles(), + conf.runQueries.filename(), + conf.runQueries.showMatchingResults()) + case _ => + // scalastyle:off println + println("Invalid subcommand") + // scalastyle:on println + sys.exit(-1) + } + } +} diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala new file mode 100644 index 000000000..13ebbf9ed --- /dev/null +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.fuzz + +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.DataTypes + +object Meta { + + val dataTypes: Seq[(DataType, Double)] = Seq( + (DataTypes.BooleanType, 0.1), + (DataTypes.ByteType, 0.2), + (DataTypes.ShortType, 0.2), + (DataTypes.IntegerType, 0.2), + (DataTypes.LongType, 0.2), + (DataTypes.FloatType, 0.2), + (DataTypes.DoubleType, 0.2), + (DataTypes.createDecimalType(10, 2), 0.2), + (DataTypes.DateType, 0.2), + (DataTypes.TimestampType, 0.2), + // TimestampNTZType only in Spark 3.4+ + // (DataTypes.TimestampNTZType, 0.2), + (DataTypes.StringType, 0.2), + (DataTypes.BinaryType, 0.1)) + + val stringScalarFunc: Seq[Function] = Seq( + Function("substring", 3), + Function("coalesce", 1), + Function("starts_with", 2), + Function("ends_with", 2), + Function("contains", 2), + Function("ascii", 1), + Function("bit_length", 1), + Function("octet_length", 1), + Function("upper", 1), + Function("lower", 1), + Function("chr", 1), + Function("init_cap", 1), + Function("trim", 1), + Function("ltrim", 1), + Function("rtrim", 1), + Function("btrim", 1), + Function("concat_ws", 2), + Function("repeat", 2), + Function("length", 1), + Function("reverse", 1), + Function("in_str", 2), + Function("replace", 2), + Function("translate", 2)) + + val dateScalarFunc: Seq[Function] = + Seq(Function("year", 1), Function("hour", 1), Function("minute", 1), Function("second", 1)) + + val mathScalarFunc: Seq[Function] = Seq( + Function("abs", 1), + Function("acos", 1), + Function("asin", 1), + Function("atan", 1), + Function("Atan2", 1), + Function("Cos", 1), + Function("Exp", 2), + Function("Ln", 1), + Function("Log10", 1), + Function("Log2", 1), + Function("Pow", 2), + Function("Round", 1), + Function("Signum", 1), + Function("Sin", 1), + Function("Sqrt", 1), + Function("Tan", 1), + Function("Ceil", 1), + Function("Floor", 1)) + + val scalarFunc: Seq[Function] = stringScalarFunc ++ dateScalarFunc ++ mathScalarFunc + + val aggFunc: Seq[Function] = Seq( + Function("min", 1), + Function("max", 1), + Function("count", 1), + Function("avg", 1), + Function("sum", 1), + Function("first", 1), + Function("last", 1), + Function("var_pop", 1), + Function("var_samp", 1), + Function("covar_pop", 1), + Function("covar_samp", 1), + Function("stddev_pop", 1), + Function("stddev_samp", 1), + Function("corr", 2)) + +} diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala new file mode 100644 index 000000000..1daa26200 --- /dev/null +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.fuzz + +import java.io.{BufferedWriter, FileWriter} + +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.sql.SparkSession + +object QueryGen { + + def generateRandomQueries( + r: Random, + spark: SparkSession, + numFiles: Int, + numQueries: Int): Unit = { + for (i <- 0 until numFiles) { + spark.read.parquet(s"test$i.parquet").createTempView(s"test$i") + } + + val w = new BufferedWriter(new FileWriter("queries.sql")) + + val uniqueQueries = mutable.HashSet[String]() + + for (_ <- 0 until numQueries) { + val sql = r.nextInt().abs % 3 match { + case 0 => generateJoin(r, spark, numFiles) + case 1 => generateAggregate(r, spark, numFiles) + case 2 => generateScalar(r, spark, numFiles) + } + if (!uniqueQueries.contains(sql)) { + uniqueQueries += sql + w.write(sql + "\n") + } + } + w.close() + } + + private def generateAggregate(r: Random, spark: SparkSession, numFiles: Int): String = { + val tableName = s"test${r.nextInt(numFiles)}" + val table = spark.table(tableName) + + val func = Utils.randomChoice(Meta.aggFunc, r) + val args = Range(0, func.num_args) + .map(_ => Utils.randomChoice(table.columns, r)) + + val groupingCols = Range(0, 2).map(_ => Utils.randomChoice(table.columns, r)) + + if (groupingCols.isEmpty) { + s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " + + s"FROM $tableName " + + s"ORDER BY ${args.mkString(", ")};" + } else { + s"SELECT ${groupingCols.mkString(", ")}, ${func.name}(${args.mkString(", ")}) " + + s"FROM $tableName " + + s"GROUP BY ${groupingCols.mkString(",")} " + + s"ORDER BY ${groupingCols.mkString(", ")};" + } + } + + private def generateScalar(r: Random, spark: SparkSession, numFiles: Int): String = { + val tableName = s"test${r.nextInt(numFiles)}" + val table = spark.table(tableName) + + val func = Utils.randomChoice(Meta.scalarFunc, r) + val args = Range(0, func.num_args) + .map(_ => Utils.randomChoice(table.columns, r)) + + // Example SELECT c0, log(c0) as x FROM test0 + s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " + + s"FROM $tableName " + + s"ORDER BY ${args.mkString(", ")};" + } + + private def generateJoin(r: Random, spark: SparkSession, numFiles: Int): String = { + val leftTableName = s"test${r.nextInt(numFiles)}" + val rightTableName = s"test${r.nextInt(numFiles)}" + val leftTable = spark.table(leftTableName) + val rightTable = spark.table(rightTableName) + + val leftCol = Utils.randomChoice(leftTable.columns, r) + val rightCol = Utils.randomChoice(rightTable.columns, r) + + val joinTypes = Seq(("INNER", 0.4), ("LEFT", 0.3), ("RIGHT", 0.3)) + val joinType = Utils.randomWeightedChoice(joinTypes) + + val leftColProjection = leftTable.columns.map(c => s"l.$c").mkString(", ") + val rightColProjection = rightTable.columns.map(c => s"r.$c").mkString(", ") + "SELECT " + + s"$leftColProjection, " + + s"$rightColProjection " + + s"FROM $leftTableName l " + + s"$joinType JOIN $rightTableName r " + + s"ON l.$leftCol = r.$rightCol " + + "ORDER BY " + + s"$leftColProjection, " + + s"$rightColProjection;" + } + +} + +case class Function(name: String, num_args: Int) diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala new file mode 100644 index 000000000..49f9fc3bd --- /dev/null +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.fuzz + +import java.io.{BufferedWriter, FileWriter} + +import scala.io.Source + +import org.apache.spark.sql.{Row, SparkSession} + +object QueryRunner { + + def runQueries( + spark: SparkSession, + numFiles: Int, + filename: String, + showMatchingResults: Boolean, + showFailedSparkQueries: Boolean = false): Unit = { + + val outputFilename = s"results-${System.currentTimeMillis()}.md" + // scalastyle:off println + println(s"Writing results to $outputFilename") + // scalastyle:on println + + val w = new BufferedWriter(new FileWriter(outputFilename)) + + // register input files + for (i <- 0 until numFiles) { + val table = spark.read.parquet(s"test$i.parquet") + val tableName = s"test$i" + table.createTempView(tableName) + w.write( + s"Created table $tableName with schema:\n\t" + + s"${table.schema.fields.map(f => s"${f.name}: ${f.dataType}").mkString("\n\t")}\n\n") + } + + val querySource = Source.fromFile(filename) + try { + querySource + .getLines() + .foreach(sql => { + + try { + // execute with Spark + spark.conf.set("spark.comet.enabled", "false") + val df = spark.sql(sql) + val sparkRows = df.collect() + val sparkPlan = df.queryExecution.executedPlan.toString + + try { + spark.conf.set("spark.comet.enabled", "true") + val df = spark.sql(sql) + val cometRows = df.collect() + val cometPlan = df.queryExecution.executedPlan.toString + + if (sparkRows.length == cometRows.length) { + var i = 0 + while (i < sparkRows.length) { + val l = sparkRows(i) + val r = cometRows(i) + assert(l.length == r.length) + for (j <- 0 until l.length) { + val same = (l(j), r(j)) match { + case (a: Float, b: Float) if a.isInfinity => b.isInfinity + case (a: Float, b: Float) if a.isNaN => b.isNaN + case (a: Float, b: Float) => (a - b).abs <= 0.000001f + case (a: Double, b: Double) if a.isInfinity => b.isInfinity + case (a: Double, b: Double) if a.isNaN => b.isNaN + case (a: Double, b: Double) => (a - b).abs <= 0.000001 + case (a: Array[Byte], b: Array[Byte]) => a.sameElements(b) + case (a, b) => a == b + } + if (!same) { + showSQL(w, sql) + showPlans(w, sparkPlan, cometPlan) + w.write(s"First difference at row $i:\n") + w.write("Spark: `" + formatRow(l) + "`\n") + w.write("Comet: `" + formatRow(r) + "`\n") + i = sparkRows.length + } + } + i += 1 + } + } else { + showSQL(w, sql) + showPlans(w, sparkPlan, cometPlan) + w.write( + s"[ERROR] Spark produced ${sparkRows.length} rows and " + + s"Comet produced ${cometRows.length} rows.\n") + } + } catch { + case e: Exception => + // the query worked in Spark but failed in Comet, so this is likely a bug in Comet + showSQL(w, sql) + w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}\n") + } + + // flush after every query so that results are saved in the event of the driver crashing + w.flush() + + } catch { + case e: Exception => + // we expect many generated queries to be invalid + if (showFailedSparkQueries) { + showSQL(w, sql) + w.write(s"Query failed in Spark: ${e.getMessage}\n") + } + } + }) + + } finally { + w.close() + querySource.close() + } + } + + private def formatRow(row: Row): String = { + row.toSeq + .map { + case v: Array[Byte] => v.mkString + case other => other.toString + } + .mkString(",") + } + + private def showSQL(w: BufferedWriter, sql: String, maxLength: Int = 120): Unit = { + w.write("## SQL\n") + w.write("```\n") + val words = sql.split(" ") + val currentLine = new StringBuilder + for (word <- words) { + if (currentLine.length + word.length + 1 > maxLength) { + w.write(currentLine.toString.trim) + w.write("\n") + currentLine.setLength(0) + } + currentLine.append(word).append(" ") + } + if (currentLine.nonEmpty) { + w.write(currentLine.toString.trim) + w.write("\n") + } + w.write("```\n") + } + + private def showPlans(w: BufferedWriter, sparkPlan: String, cometPlan: String): Unit = { + w.write("### Spark Plan\n") + w.write(s"```\n$sparkPlan\n```\n") + w.write("### Comet Plan\n") + w.write(s"```\n$cometPlan\n```\n") + } + +} diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Utils.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Utils.scala new file mode 100644 index 000000000..19f9695a9 --- /dev/null +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Utils.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.fuzz + +import scala.util.Random + +object Utils { + + def randomChoice[T](list: Seq[T], r: Random): T = { + list(r.nextInt(list.length)) + } + + def randomWeightedChoice[T](valuesWithWeights: Seq[(T, Double)]): T = { + val totalWeight = valuesWithWeights.map(_._2).sum + val randomValue = Random.nextDouble() * totalWeight + var cumulativeWeight = 0.0 + + for ((value, weight) <- valuesWithWeights) { + cumulativeWeight += weight + if (cumulativeWeight >= randomValue) { + return value + } + } + + // If for some reason the loop doesn't return, return the last value + valuesWithWeights.last._1 + } + +} diff --git a/pom.xml b/pom.xml index 0ec834982..8c322bae0 100644 --- a/pom.xml +++ b/pom.xml @@ -33,6 +33,7 @@ under the License. common spark spark-integration + fuzz-testing @@ -409,6 +410,12 @@ under the License. test + + org.rogach + scallop_${scala.binary.version} + 5.1.0 + +