Skip to content

Commit

Permalink
feat: Add fuzz testing for arithmetic expressions (#519)
Browse files Browse the repository at this point in the history
* Add fuzz tests for arithmetic expressions

* add unary math

* add bit-wise expressions

* bug fix
  • Loading branch information
andygrove authored Jun 7, 2024
1 parent c3e27c0 commit c6d387c
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 4 deletions.
1 change: 0 additions & 1 deletion fuzz-testing/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ Planned areas of improvement:

- ANSI mode
- Support for all data types, expressions, and operators supported by Comet
- Unary and binary arithmetic expressions
- IF and CASE WHEN expressions
- Complex (nested) expressions
- Literal scalar values in queries
Expand Down
4 changes: 4 additions & 0 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,8 @@ object Meta {
Function("stddev_samp", 1),
Function("corr", 2))

// Unary prefix operators that the fuzzer applies to a single column, e.g. "-a".
val unaryArithmeticOps: Seq[String] = Seq("+", "-")

// Binary operators applied between two columns: arithmetic (+, -, *, /, %)
// and bitwise (&, |, ^), e.g. "a % b".
val binaryArithmeticOps: Seq[String] = Seq("+", "-", "*", "/", "%", "&", "|", "^")

}
31 changes: 30 additions & 1 deletion fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ object QueryGen {
val uniqueQueries = mutable.HashSet[String]()

for (_ <- 0 until numQueries) {
val sql = r.nextInt().abs % 4 match {
val sql = r.nextInt().abs % 6 match {
case 0 => generateJoin(r, spark, numFiles)
case 1 => generateAggregate(r, spark, numFiles)
case 2 => generateScalar(r, spark, numFiles)
case 3 => generateCast(r, spark, numFiles)
case 4 => generateUnaryArithmetic(r, spark, numFiles)
case 5 => generateBinaryArithmetic(r, spark, numFiles)
}
if (!uniqueQueries.contains(sql)) {
uniqueQueries += sql
Expand Down Expand Up @@ -92,6 +94,33 @@ object QueryGen {
s"ORDER BY ${args.mkString(", ")};"
}

/**
 * Generates a random SQL query that applies a unary arithmetic operator
 * (from [[Meta.unaryArithmeticOps]]) to one randomly chosen column of a
 * randomly chosen test table.
 *
 * Example output: `SELECT a, -a FROM test0 ORDER BY a;`
 */
private def generateUnaryArithmetic(r: Random, spark: SparkSession, numFiles: Int): String = {
  // Pick one of the generated test tables at random.
  val tableName = s"test${r.nextInt(numFiles)}"
  val table = spark.table(tableName)

  // Draw the operator first, then the column, to keep the random sequence
  // consumption deterministic for a given seed.
  val op = Utils.randomChoice(Meta.unaryArithmeticOps, r)
  val col = Utils.randomChoice(table.columns, r)

  // ORDER BY makes the result deterministic so Spark/Comet output can be compared.
  s"SELECT $col, $op$col FROM $tableName ORDER BY $col;"
}

/**
 * Generates a random SQL query that combines two randomly chosen columns
 * with a binary arithmetic or bitwise operator (from
 * [[Meta.binaryArithmeticOps]]) against a randomly chosen test table.
 *
 * Example output: `SELECT a, b, a + b FROM test0 ORDER BY a, b;`
 */
private def generateBinaryArithmetic(r: Random, spark: SparkSession, numFiles: Int): String = {
  // Pick one of the generated test tables at random.
  val tableName = s"test${r.nextInt(numFiles)}"
  val table = spark.table(tableName)

  // Draw operator, then left operand, then right operand — the same order
  // as the original so the random sequence consumption is unchanged.
  val op = Utils.randomChoice(Meta.binaryArithmeticOps, r)
  val lhs = Utils.randomChoice(table.columns, r)
  val rhs = Utils.randomChoice(table.columns, r)

  // Both operands may be the same column; ORDER BY keeps results deterministic.
  s"SELECT $lhs, $rhs, $lhs $op $rhs FROM $tableName ORDER BY $lhs, $rhs;"
}

private def generateCast(r: Random, spark: SparkSession, numFiles: Int): String = {
val tableName = s"test${r.nextInt(numFiles)}"
val table = spark.table(tableName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

package org.apache.comet.fuzz

import java.io.{BufferedWriter, FileWriter, PrintWriter}
import java.io.{BufferedWriter, FileWriter, PrintWriter, StringWriter}

import scala.io.Source

Expand Down Expand Up @@ -111,9 +111,11 @@ object QueryRunner {
showSQL(w, sql)
w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n")
w.write("```\n")
val p = new PrintWriter(w)
val sw = new StringWriter()
val p = new PrintWriter(sw)
e.printStackTrace(p)
p.close()
w.write(s"${sw.toString}\n")
w.write("```\n")
}

Expand Down

0 comments on commit c6d387c

Please sign in to comment.