Commit: spotless

andygrove committed May 26, 2024
1 parent 4409377 commit 34b7624
Showing 4 changed files with 121 additions and 104 deletions.
30 changes: 18 additions & 12 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala
@@ -26,33 +26,40 @@ import scala.util.Random
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}


object DataGen {

def generateRandomFiles(r: Random, spark: SparkSession, numFiles: Int, numRows: Int,
def generateRandomFiles(
r: Random,
spark: SparkSession,
numFiles: Int,
numRows: Int,
numColumns: Int): Unit = {
for (i <- 0 until numFiles) {
generateRandomParquetFile(r, spark, s"test$i.parquet", numRows, numColumns)
}
}

def generateRandomParquetFile(r: Random, spark: SparkSession, filename: String,
numRows: Int, numColumns: Int): Unit = {
def generateRandomParquetFile(
r: Random,
spark: SparkSession,
filename: String,
numRows: Int,
numColumns: Int): Unit = {

// TODO add examples of all supported types, including complex types
val dataTypes = Seq(
(DataTypes.ByteType, 0.2),
(DataTypes.ShortType, 0.2),
(DataTypes.IntegerType, 0.2),
(DataTypes.LongType, 0.2),
(DataTypes.FloatType, 0.2),
(DataTypes.DoubleType, 0.2),
(DataTypes.ShortType, 0.2),
(DataTypes.IntegerType, 0.2),
(DataTypes.LongType, 0.2),
(DataTypes.FloatType, 0.2),
(DataTypes.DoubleType, 0.2),
// TODO add support for all Comet supported types
// (DataTypes.createDecimalType(10,2), 0.2),
// (DataTypes.createDecimalType(10,0), 0.2),
// (DataTypes.createDecimalType(4,0), 0.2),
(DataTypes.DateType, 0.2),
(DataTypes.TimestampType, 0.2),
(DataTypes.DateType, 0.2),
(DataTypes.TimestampType, 0.2),
// (DataTypes.TimestampNTZType, 0.2),
(DataTypes.StringType, 0.2))

@@ -143,5 +150,4 @@ object DataGen {
}
}


}
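
A note on the weighted type list above: each data type is paired with a selection weight. The weighted-pick helper itself (Utils.randomWeightedChoice, used later in QueryGen.scala) is not part of this diff, so the following is only a minimal sketch of how such a pick could work, with an assumed signature:

    import scala.util.Random

    import org.apache.spark.sql.types.{DataType, DataTypes}

    // Hypothetical helper: pick an item in proportion to its weight.
    // The real Utils.randomWeightedChoice in this module may differ in signature and behavior.
    def weightedChoice[T](choices: Seq[(T, Double)], r: Random): T = {
      val total = choices.map(_._2).sum
      val target = r.nextDouble() * total
      // cumulative weight at the end of each choice
      val cumulative = choices.scanLeft(0.0)(_ + _._2).tail
      choices
        .zip(cumulative)
        .collectFirst { case ((item, _), cum) if target < cum => item }
        .getOrElse(choices.last._1)
    }

    val r = new Random(42)
    val weightedTypes: Seq[(DataType, Double)] =
      Seq((DataTypes.IntegerType, 0.2), (DataTypes.DoubleType, 0.2), (DataTypes.StringType, 0.2))
    val chosen: DataType = weightedChoice(weightedTypes, r)
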
21 changes: 15 additions & 6 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala
@@ -21,20 +21,24 @@ package org.apache.comet.fuzz

import scala.util.Random

import org.apache.spark.sql.SparkSession
import org.rogach.scallop.{ScallopConf, Subcommand}

import org.apache.spark.sql.SparkSession

class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
val generateData = new Subcommand("data") {
val generateData: generateData = new generateData
class generateData extends Subcommand("data") {
val numFiles = opt[Int](required = true)
val numRows = opt[Int](required = true)
val numColumns = opt[Int](required = true)
}
val generateQueries = new Subcommand("queries") {
val generateQueries: generateQueries = new generateQueries
class generateQueries extends Subcommand("queries") {
val numFiles = opt[Int](required = false)
val numQueries = opt[Int](required = true)
}
val runQueries = new Subcommand("run") {
val runQueries: runQueries = new runQueries
class runQueries extends Subcommand("run") {
val filename = opt[String](required = true)
val numFiles = opt[Int](required = false)
val showMatchingResults = opt[Boolean](required = false)
@@ -47,7 +51,8 @@ class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {

object Main {

lazy val spark = SparkSession.builder()
lazy val spark: SparkSession = SparkSession
.builder()
.master("local[*]")
.getOrCreate()

@@ -57,7 +62,11 @@ object Main {
val conf = new Conf(args)
conf.subcommand match {
case Some(opt @ conf.generateData) =>
DataGen.generateRandomFiles(r, spark, numFiles = opt.numFiles(), numRows = opt.numRows(),
DataGen.generateRandomFiles(
r,
spark,
numFiles = opt.numFiles(),
numRows = opt.numRows(),
numColumns = opt.numColumns())
case Some(opt @ conf.generateQueries) =>
QueryGen.generateRandomQueries(r, spark, numFiles = opt.numFiles(), opt.numQueries())
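
The Conf change above replaces anonymous Subcommand instances with named inner classes and explicitly typed vals, so the pattern match in Main (opt @ conf.generateData) binds opt to a concrete subcommand type whose options can be read directly. A minimal usage sketch follows; the argument values are illustrative only, and it assumes scallop's default camelCase-to-kebab-case flag naming and that Conf calls verify() (not shown in this diff):

    import org.apache.comet.fuzz.Conf

    // Hypothetical invocation of the refactored Conf with made-up argument values.
    val conf = new Conf(Seq("data", "--num-files", "2", "--num-rows", "200", "--num-columns", "10"))
    conf.subcommand match {
      case Some(opt @ conf.generateData) =>
        // `opt` has the named generateData type, so its options are accessible without a cast.
        println(s"files=${opt.numFiles()}, rows=${opt.numRows()}, columns=${opt.numColumns()}")
      case _ =>
        conf.printHelp()
    }
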
22 changes: 10 additions & 12 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
@@ -28,8 +28,11 @@ import org.apache.spark.sql.SparkSession

object QueryGen {

def generateRandomQueries(r: Random, spark: SparkSession, numFiles: Int,
numQueries: Int): Unit = {
def generateRandomQueries(
r: Random,
spark: SparkSession,
numFiles: Int,
numQueries: Int): Unit = {
for (i <- 0 until numFiles) {
spark.read.parquet(s"test$i.parquet").createTempView(s"test$i")
}
@@ -55,7 +58,7 @@ object QueryGen {
w.close()
}

val scalarFunc = Seq(
val scalarFunc: Seq[Function] = Seq(
// string
Function("substring", 3),
Function("coalesce", 1),
@@ -103,10 +106,9 @@
Function("Sqrt", 1),
Function("Tan", 1),
Function("Ceil", 1),
Function("Floor", 1)
)
Function("Floor", 1))

val aggFunc = Seq(
val aggFunc: Seq[Function] = Seq(
Function("min", 1),
Function("max", 1),
Function("count", 1),
@@ -163,14 +165,10 @@
val leftCol = Utils.randomChoice(leftTable.columns, r)
val rightCol = Utils.randomChoice(rightTable.columns, r)

val joinTypes = Seq(
("INNER", 0.4),
("LEFT", 0.3),
("RIGHT", 0.3)
)
val joinTypes = Seq(("INNER", 0.4), ("LEFT", 0.3), ("RIGHT", 0.3))
val joinType = Utils.randomWeightedChoice(joinTypes)

s"SELECT * " +
"SELECT * " +
s"FROM ${leftTableName} " +
s"${joinType} JOIN ${rightTableName} " +
s"ON ${leftTableName}.${leftCol} = ${rightTableName}.${rightCol};"
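
For reference, the join template above produces SQL statements of the following shape. The table and column names here are assumptions for illustration; in QueryGen they are drawn with Utils.randomChoice and Utils.randomWeightedChoice:

    // Illustrative values only.
    val leftTableName = "test0"
    val rightTableName = "test1"
    val leftCol = "c1"
    val rightCol = "c2"
    val joinType = "LEFT"

    val sql = "SELECT * " +
      s"FROM $leftTableName " +
      s"$joinType JOIN $rightTableName " +
      s"ON $leftTableName.$leftCol = $rightTableName.$rightCol;"
    // sql == "SELECT * FROM test0 LEFT JOIN test1 ON test0.c1 = test1.c2;"
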
152 changes: 78 additions & 74 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala
@@ -28,11 +28,11 @@ import org.apache.spark.sql.SparkSession
object QueryRunner {

def runQueries(
spark: SparkSession,
numFiles: Int,
filename: String,
showMatchingResults: Boolean,
showFailedSparkQueries: Boolean = false): Unit = {
spark: SparkSession,
numFiles: Int,
filename: String,
showMatchingResults: Boolean,
showFailedSparkQueries: Boolean = false): Unit = {

val outputFilename = s"results-${System.currentTimeMillis()}.md"
// scalastyle:off println
@@ -46,98 +46,102 @@ object QueryRunner {
val table = spark.read.parquet(s"test$i.parquet")
val tableName = s"test$i"
table.createTempView(tableName)
w.write(s"Created table $tableName with schema:\n\t" +
s"${table.schema.fields.map(f => s"${f.name}: ${f.dataType}").mkString("\n\t")}\n\n")
w.write(
s"Created table $tableName with schema:\n\t" +
s"${table.schema.fields.map(f => s"${f.name}: ${f.dataType}").mkString("\n\t")}\n\n")
}

val querySource = Source.fromFile(filename)
try {
querySource.getLines().foreach(sql => {

try {
// execute with Spark
spark.conf.set("spark.comet.enabled", "false")
val df = spark.sql(sql)
val sparkRows = df.collect()

// TODO for now we sort the output to make this deterministic, but this means
// that we are never testing Comet's sort for correctness
val sparkRowsAsStrings = sparkRows.map(_.toString()).sorted
val sparkResult = sparkRowsAsStrings.mkString("\n")

val sparkPlan = df.queryExecution.executedPlan.toString

w.write(s"## $sql\n\n")
querySource
.getLines()
.foreach(sql => {

try {
spark.conf.set("spark.comet.enabled", "true")
// execute with Spark
spark.conf.set("spark.comet.enabled", "false")
val df = spark.sql(sql)
val cometRows = df.collect()
val sparkRows = df.collect()

// TODO for now we sort the output to make this deterministic, but this means
// that we are never testing Comet's sort for correctness
val cometRowsAsStrings = cometRows.map(_.toString()).sorted
val cometResult = cometRowsAsStrings.mkString("\n")
val cometPlan = df.queryExecution.executedPlan.toString
val sparkRowsAsStrings = sparkRows.map(_.toString()).sorted
val sparkResult = sparkRowsAsStrings.mkString("\n")

val sparkPlan = df.queryExecution.executedPlan.toString

w.write(s"## $sql\n\n")

try {
spark.conf.set("spark.comet.enabled", "true")
val df = spark.sql(sql)
val cometRows = df.collect()
// TODO for now we sort the output to make this deterministic, but this means
// that we are never testing Comet's sort for correctness
val cometRowsAsStrings = cometRows.map(_.toString()).sorted
val cometResult = cometRowsAsStrings.mkString("\n")
val cometPlan = df.queryExecution.executedPlan.toString

if (sparkResult == cometResult) {
w.write(s"Spark and Comet produce the same results (${cometRows.length} rows).\n")
if (showMatchingResults) {
w.write("### Spark Plan\n")
w.write(s"```\n$sparkPlan\n```\n")

w.write("### Comet Plan\n")
w.write(s"```\n$cometPlan\n```\n")

w.write("### Query Result\n")
w.write("```\n")
w.write(s"$cometResult\n")
w.write("```\n\n")
}
} else {
w.write("[ERROR] Spark and Comet produced different results.\n")

if (sparkResult == cometResult) {
w.write(s"Spark and Comet produce the same results (${cometRows.length} rows).\n")
if (showMatchingResults) {
w.write("### Spark Plan\n")
w.write(s"```\n$sparkPlan\n```\n")

w.write("### Comet Plan\n")
w.write(s"```\n$cometPlan\n```\n")

w.write("### Query Result\n")
w.write("```\n")
w.write(s"$cometResult\n")
w.write("```\n\n")
}
} else {
w.write("[ERROR] Spark and Comet produced different results.\n")

w.write("### Spark Plan\n")
w.write(s"```\n$sparkPlan\n```\n")

w.write("### Comet Plan\n")
w.write(s"```\n$cometPlan\n```\n")

w.write("### Results \n")

w.write(s"Spark produced ${sparkRows.length} rows and " +
s"Comet produced ${cometRows.length} rows.\n")

if (sparkRows.length == cometRows.length) {
var i = 0
while (i < sparkRows.length) {
if (sparkRowsAsStrings(i) != cometRowsAsStrings(i)) {
w.write(s"First difference at row $i:\n")
w.write("Spark: `" + sparkRowsAsStrings(i) + "`\n")
w.write("Comet: `" + cometRowsAsStrings(i) + "`\n")
i = sparkRows.length
w.write("### Results \n")

w.write(
s"Spark produced ${sparkRows.length} rows and " +
s"Comet produced ${cometRows.length} rows.\n")

if (sparkRows.length == cometRows.length) {
var i = 0
while (i < sparkRows.length) {
if (sparkRowsAsStrings(i) != cometRowsAsStrings(i)) {
w.write(s"First difference at row $i:\n")
w.write("Spark: `" + sparkRowsAsStrings(i) + "`\n")
w.write("Comet: `" + cometRowsAsStrings(i) + "`\n")
i = sparkRows.length
}
i += 1
}
i += 1
}
}
} catch {
case e: Exception =>
// the query worked in Spark but failed in Comet, so this is likely a bug in Comet
w.write(s"Query failed in Comet: ${e.getMessage}\n")
}

// flush after every query so that results are saved in the event of the driver crashing
w.flush()

} catch {
case e: Exception =>
// the query worked in Spark but failed in Comet, so this is likely a bug in Comet
w.write(s"Query failed in Comet: ${e.getMessage}\n")
// we expect many generated queries to be invalid
if (showFailedSparkQueries) {
w.write(s"## $sql\n\n")
w.write(s"Query failed in Spark: ${e.getMessage}\n")
}
}

// flush after every query so that results are saved in the event of the driver crashing
w.flush()

} catch {
case e: Exception =>
// we expect many generated queries to be invalid
if (showFailedSparkQueries) {
w.write(s"## $sql\n\n")
w.write(s"Query failed in Spark: ${e.getMessage}\n")
}
}
})
})

} finally {
w.close()
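
The heart of the runner above is the same-query, two-configuration comparison: execute the SQL with Comet disabled, then enabled, sort the collected rows so that ordering differences are not reported as mismatches, and compare. A minimal standalone sketch of that strategy, assuming a local SparkSession with the Comet plugin available and a temp view named test0 already registered:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    val sql = "SELECT * FROM test0" // assumes test0 was registered as a temp view

    // Plain Spark run.
    spark.conf.set("spark.comet.enabled", "false")
    val sparkRows = spark.sql(sql).collect().map(_.toString()).sorted

    // Comet run of the same query.
    spark.conf.set("spark.comet.enabled", "true")
    val cometRows = spark.sql(sql).collect().map(_.toString()).sorted

    if (!sparkRows.sameElements(cometRows)) {
      println(s"[ERROR] Spark returned ${sparkRows.length} rows, Comet returned ${cometRows.length} rows, or their contents differ")
    }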
