Skip to content

Commit

Permalink
tests: Move random data generation methods from CometCastSuite to new…
Browse files Browse the repository at this point in the history
… DataGenerator class (apache#426)

* add new DataGenerator class

* move more data gen methods

* ignore some failing tests, update compat docs

* address feedback

* fix regression
  • Loading branch information
andygrove authored May 15, 2024
1 parent 1a04805 commit fcf7d5b
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 98 deletions.
8 changes: 4 additions & 4 deletions docs/source/user-guide/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,6 @@ The following cast operations are generally compatible with Spark except for the
| decimal | float | |
| decimal | double | |
| string | boolean | |
| string | byte | |
| string | short | |
| string | integer | |
| string | long | |
| string | binary | |
| date | string | |
| timestamp | long | |
Expand All @@ -129,6 +125,10 @@ The following cast operations are not compatible with Spark for all inputs and a
|-|-|-|
| integer | decimal | No overflow check |
| long | decimal | No overflow check |
| string | byte | Not all invalid inputs are detected |
| string | short | Not all invalid inputs are detected |
| string | integer | Not all invalid inputs are detected |
| string | long | Not all invalid inputs are detected |
| string | timestamp | Not all valid formats are supported |
| binary | string | Only works for binary data representing valid UTF-8 strings |

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ object CometCast {
Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
DataTypes.LongType =>
Compatible()
Incompatible(Some("Not all invalid inputs are detected"))
case DataTypes.BinaryType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
Expand Down
122 changes: 30 additions & 92 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ import org.apache.comet.expressions.{CometCast, Compatible}
class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._

private val dataSize = 1000
/** Create a data generator using a fixed seed so that tests are reproducible */
private val gen = DataGenerator.DEFAULT

/** Number of random data items to generate in each test */
private val dataSize = 10000

// we should eventually add more whitespace chars here as documented in
// https://docs.oracle.com/javase/8/docs/api/java/lang/Character.html#isWhitespace-char-
Expand Down Expand Up @@ -478,7 +482,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
test("cast StringType to BooleanType") {
val testValues =
(Seq("TRUE", "True", "true", "FALSE", "False", "false", "1", "0", "", null) ++
generateStrings("truefalseTRUEFALSEyesno10" + whitespaceChars, 8)).toDF("a")
gen.generateStrings(dataSize, "truefalseTRUEFALSEyesno10" + whitespaceChars, 8)).toDF("a")
castTest(testValues, DataTypes.BooleanType)
}

Expand Down Expand Up @@ -515,57 +519,57 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
"9223372036854775808" // Long.MaxValue + 1
)

test("cast StringType to ByteType") {
ignore("cast StringType to ByteType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType)
// fuzz test
castTest(generateStrings(numericPattern, 4).toDF("a"), DataTypes.ByteType)
castTest(gen.generateStrings(dataSize, numericPattern, 4).toDF("a"), DataTypes.ByteType)
}

test("cast StringType to ShortType") {
ignore("cast StringType to ShortType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType)
// fuzz test
castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ShortType)
castTest(gen.generateStrings(dataSize, numericPattern, 5).toDF("a"), DataTypes.ShortType)
}

test("cast StringType to IntegerType") {
ignore("cast StringType to IntegerType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType)
// fuzz test
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.IntegerType)
}

test("cast StringType to LongType") {
ignore("cast StringType to LongType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType)
// fuzz test
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.LongType)
}

ignore("cast StringType to FloatType") {
// https://github.com/apache/datafusion-comet/issues/326
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.FloatType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.FloatType)
}

ignore("cast StringType to DoubleType") {
// https://github.com/apache/datafusion-comet/issues/326
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.DoubleType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
}

ignore("cast StringType to DecimalType(10,2)") {
// https://github.com/apache/datafusion-comet/issues/325
val values = generateStrings(numericPattern, 8).toDF("a")
val values = gen.generateStrings(dataSize, numericPattern, 8).toDF("a")
castTest(values, DataTypes.createDecimalType(10, 2))
}

test("cast StringType to BinaryType") {
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.BinaryType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.BinaryType)
}

ignore("cast StringType to DateType") {
// https://github.com/apache/datafusion-comet/issues/327
castTest(generateStrings(datePattern, 8).toDF("a"), DataTypes.DateType)
castTest(gen.generateStrings(dataSize, datePattern, 8).toDF("a"), DataTypes.DateType)
}

test("cast StringType to TimestampType disabled by default") {
Expand All @@ -581,7 +585,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
ignore("cast StringType to TimestampType") {
// https://github.com/apache/datafusion-comet/issues/328
withSQLConf((CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key, "true")) {
val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ generateStrings(timestampPattern, 8)
val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ gen.generateStrings(
dataSize,
timestampPattern,
8)
castTest(values.toDF("a"), DataTypes.TimestampType)
}
}
Expand Down Expand Up @@ -630,7 +637,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("cast BinaryType to StringType - valid UTF-8 inputs") {
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.StringType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.StringType)
}

// CAST from DateType
Expand Down Expand Up @@ -739,67 +746,31 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

private def generateFloats(): DataFrame = {
val r = new Random(0)
val values = Seq(
Float.MaxValue,
Float.MinPositiveValue,
Float.MinValue,
Float.NaN,
Float.PositiveInfinity,
Float.NegativeInfinity,
1.0f,
-1.0f,
Short.MinValue.toFloat,
Short.MaxValue.toFloat,
0.0f) ++
Range(0, dataSize).map(_ => r.nextFloat())
withNulls(values).toDF("a")
withNulls(gen.generateFloats(dataSize)).toDF("a")
}

private def generateDoubles(): DataFrame = {
val r = new Random(0)
val values = Seq(
Double.MaxValue,
Double.MinPositiveValue,
Double.MinValue,
Double.NaN,
Double.PositiveInfinity,
Double.NegativeInfinity,
0.0d) ++
Range(0, dataSize).map(_ => r.nextDouble())
withNulls(values).toDF("a")
withNulls(gen.generateDoubles(dataSize)).toDF("a")
}

/** Boolean test input: both boolean values, passed through `withNulls`, as a single column "a". */
private def generateBools(): DataFrame = {
  val flags = Seq(true, false)
  withNulls(flags).toDF("a")
}

private def generateBytes(): DataFrame = {
val r = new Random(0)
val values = Seq(Byte.MinValue, Byte.MaxValue) ++
Range(0, dataSize).map(_ => r.nextInt().toByte)
withNulls(values).toDF("a")
withNulls(gen.generateBytes(dataSize)).toDF("a")
}

private def generateShorts(): DataFrame = {
val r = new Random(0)
val values = Seq(Short.MinValue, Short.MaxValue) ++
Range(0, dataSize).map(_ => r.nextInt().toShort)
withNulls(values).toDF("a")
withNulls(gen.generateShorts(dataSize)).toDF("a")
}

private def generateInts(): DataFrame = {
val r = new Random(0)
val values = Seq(Int.MinValue, Int.MaxValue) ++
Range(0, dataSize).map(_ => r.nextInt())
withNulls(values).toDF("a")
withNulls(gen.generateInts(dataSize)).toDF("a")
}

private def generateLongs(): DataFrame = {
val r = new Random(0)
val values = Seq(Long.MinValue, Long.MaxValue) ++
Range(0, dataSize).map(_ => r.nextLong())
withNulls(values).toDF("a")
withNulls(gen.generateLongs(dataSize)).toDF("a")
}

private def generateDecimalsPrecision10Scale2(): DataFrame = {
Expand Down Expand Up @@ -864,17 +835,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
.drop("str")
}

private def generateString(r: Random, chars: String, maxLen: Int): String = {
val len = r.nextInt(maxLen)
Range(0, len).map(_ => chars.charAt(r.nextInt(chars.length))).mkString
}

// TODO return DataFrame for consistency with other generators and include null values
private def generateStrings(chars: String, maxLen: Int): Seq[String] = {
val r = new Random(0)
Range(0, dataSize).map(_ => generateString(r, chars, maxLen))
}

private def generateBinary(): DataFrame = {
val r = new Random(0)
val bytes = new Array[Byte](8)
Expand Down Expand Up @@ -907,28 +867,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

// TODO Commented out to work around scalafix since this is currently unused.
// private def castFallbackTestTimezone(
// input: DataFrame,
// toType: DataType,
// expectedMessage: String): Unit = {
// withTempPath { dir =>
// val data = roundtripParquet(input, dir).coalesce(1)
// data.createOrReplaceTempView("t")
//
// withSQLConf(
// (SQLConf.ANSI_ENABLED.key, "false"),
// (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key, "true"),
// (SQLConf.SESSION_LOCAL_TIMEZONE.key, "America/Los_Angeles")) {
// val df = data.withColumn("converted", col("a").cast(toType))
// df.collect()
// val str =
// new ExtendedExplainInfo().generateExtendedInfo(df.queryExecution.executedPlan)
// assert(str.contains(expectedMessage))
// }
// }
// }

private def castTimestampTest(input: DataFrame, toType: DataType) = {
withTempPath { dir =>
val data = roundtripParquet(input, dir).coalesce(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {

test("Chr") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
withSQLConf(
"parquet.enable.dictionary" -> dictionary.toString,
CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
val table = "test"
withTable(table) {
sql(s"create table $table(col varchar(20)) using parquet")
Expand Down
98 changes: 98 additions & 0 deletions spark/src/test/scala/org/apache/comet/DataGenerator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet

import scala.util.Random

object DataGenerator {

  /**
   * A new generator backed by `new Random(42)`.
   *
   * Note that this is intentionally a `def` rather than a `val`: every access builds a
   * separate generator, so each test suite starts with a fresh data generator, which helps
   * ensure that tests are deterministic.
   */
  def DEFAULT: DataGenerator = new DataGenerator(new Random(42))
}

/**
 * Produces pseudo-random test data from the supplied random number generator.
 *
 * Every method draws from the same `r`, so the values returned depend both on the seed of
 * `r` and on the order in which the methods are called.
 *
 * @param r source of randomness shared by all generator methods
 */
class DataGenerator(r: Random) {

  /**
   * Generate a random string using the specified characters.
   *
   * The length is drawn with `r.nextInt(maxLen)`, so it lies in `[0, maxLen)`: the empty
   * string is a possible result and a string of exactly `maxLen` characters is not.
   *
   * @param chars  characters to pick from; each position is chosen independently
   * @param maxLen exclusive upper bound on the length; it is passed to `Random.nextInt` as
   *               the bound, which rejects non-positive values
   * @return the generated string
   */
  def generateString(chars: String, maxLen: Int): String = {
    val len = r.nextInt(maxLen)
    Range(0, len).map(_ => chars.charAt(r.nextInt(chars.length))).mkString
  }

  /**
   * Generate random strings.
   *
   * Each string comes from `r.nextString(maxLen)`; `maxLen` is handed to `nextString` as
   * the requested length, and the characters are not restricted to any particular set.
   *
   * @param n      number of strings to generate
   * @param maxLen length argument passed to `Random.nextString`
   */
  def generateStrings(n: Int, maxLen: Int): Seq[String] = {
    Range(0, n).map(_ => r.nextString(maxLen))
  }

  /**
   * Generate random strings using the specified characters.
   *
   * @param n      number of strings to generate
   * @param chars  characters to pick from
   * @param maxLen exclusive upper bound on the length of each string
   */
  def generateStrings(n: Int, chars: String, maxLen: Int): Seq[String] = {
    Range(0, n).map(_ => generateString(chars, maxLen))
  }

  /**
   * Generate floats: eleven fixed edge-case values followed by `n` values from
   * `r.nextFloat()`.
   */
  def generateFloats(n: Int): Seq[Float] = {
    Seq(
      Float.MaxValue,
      Float.MinPositiveValue,
      Float.MinValue,
      Float.NaN,
      Float.PositiveInfinity,
      Float.NegativeInfinity,
      1.0f,
      -1.0f,
      Short.MinValue.toFloat,
      Short.MaxValue.toFloat,
      0.0f) ++
      Range(0, n).map(_ => r.nextFloat())
  }

  /**
   * Generate doubles: seven fixed edge-case values followed by `n` values from
   * `r.nextDouble()`.
   */
  def generateDoubles(n: Int): Seq[Double] = {
    Seq(
      Double.MaxValue,
      Double.MinPositiveValue,
      Double.MinValue,
      Double.NaN,
      Double.PositiveInfinity,
      Double.NegativeInfinity,
      0.0d) ++
      Range(0, n).map(_ => r.nextDouble())
  }

  /** Generate bytes: `Byte.MinValue`, `Byte.MaxValue`, then `n` random ints truncated to bytes. */
  def generateBytes(n: Int): Seq[Byte] = {
    Seq(Byte.MinValue, Byte.MaxValue) ++
      Range(0, n).map(_ => r.nextInt().toByte)
  }

  /**
   * Generate shorts: `Short.MinValue`, `Short.MaxValue`, then `n` random ints truncated to
   * shorts.
   */
  def generateShorts(n: Int): Seq[Short] = {
    // Draw from the generator's own `r`, like every other method here. A method-local
    // Random would shadow it and return the same values regardless of the injected seed.
    Seq(Short.MinValue, Short.MaxValue) ++
      Range(0, n).map(_ => r.nextInt().toShort)
  }

  /** Generate ints: `Int.MinValue`, `Int.MaxValue`, then `n` values from `r.nextInt()`. */
  def generateInts(n: Int): Seq[Int] = {
    Seq(Int.MinValue, Int.MaxValue) ++
      Range(0, n).map(_ => r.nextInt())
  }

  /** Generate longs: `Long.MinValue`, `Long.MaxValue`, then `n` values from `r.nextLong()`. */
  def generateLongs(n: Int): Seq[Long] = {
    Seq(Long.MinValue, Long.MaxValue) ++
      Range(0, n).map(_ => r.nextLong())
  }

}

0 comments on commit fcf7d5b

Please sign in to comment.