Skip to content

Commit

Permalink
add new DataGenerator class
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed May 14, 2024
1 parent 3808306 commit 892d66d
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 23 deletions.
41 changes: 18 additions & 23 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ import org.apache.comet.expressions.{CometCast, Compatible}
class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._

/** Create a data generator using a fixed seed so that tests are reproducible */
private val gen = new DataGenerator(new Random(42))

private val dataSize = 1000

// we should eventually add more whitespace chars here as documented in
Expand Down Expand Up @@ -478,7 +481,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
test("cast StringType to BooleanType") {
val testValues =
(Seq("TRUE", "True", "true", "FALSE", "False", "false", "1", "0", "", null) ++
generateStrings("truefalseTRUEFALSEyesno10" + whitespaceChars, 8)).toDF("a")
gen.generateStrings(dataSize, "truefalseTRUEFALSEyesno10" + whitespaceChars, 8)).toDF("a")
castTest(testValues, DataTypes.BooleanType)
}

Expand Down Expand Up @@ -519,53 +522,53 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType)
// fuzz test
castTest(generateStrings(numericPattern, 4).toDF("a"), DataTypes.ByteType)
castTest(gen.generateStrings(dataSize, numericPattern, 4).toDF("a"), DataTypes.ByteType)
}

test("cast StringType to ShortType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType)
// fuzz test
castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ShortType)
castTest(gen.generateStrings(dataSize, numericPattern, 5).toDF("a"), DataTypes.ShortType)
}

test("cast StringType to IntegerType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType)
// fuzz test
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.IntegerType)
}

test("cast StringType to LongType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType)
// fuzz test
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.LongType)
}

ignore("cast StringType to FloatType") {
// https://github.com/apache/datafusion-comet/issues/326
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.FloatType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.FloatType)
}

ignore("cast StringType to DoubleType") {
// https://github.com/apache/datafusion-comet/issues/326
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.DoubleType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
}

ignore("cast StringType to DecimalType(10,2)") {
// https://github.com/apache/datafusion-comet/issues/325
val values = generateStrings(numericPattern, 8).toDF("a")
val values = gen.generateStrings(dataSize, numericPattern, 8).toDF("a")
castTest(values, DataTypes.createDecimalType(10, 2))
}

test("cast StringType to BinaryType") {
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.BinaryType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.BinaryType)
}

ignore("cast StringType to DateType") {
// https://github.com/apache/datafusion-comet/issues/327
castTest(generateStrings(datePattern, 8).toDF("a"), DataTypes.DateType)
castTest(gen.generateStrings(dataSize, datePattern, 8).toDF("a"), DataTypes.DateType)
}

test("cast StringType to TimestampType disabled by default") {
Expand All @@ -581,7 +584,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
ignore("cast StringType to TimestampType") {
// https://github.com/apache/datafusion-comet/issues/328
withSQLConf((CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key, "true")) {
val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ generateStrings(timestampPattern, 8)
val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ gen.generateStrings(
dataSize,
timestampPattern,
8)
castTest(values.toDF("a"), DataTypes.TimestampType)
}
}
Expand Down Expand Up @@ -630,7 +636,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("cast BinaryType to StringType - valid UTF-8 inputs") {
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.StringType)
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.StringType)
}

// CAST from DateType
Expand Down Expand Up @@ -864,17 +870,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
.drop("str")
}

private def generateString(r: Random, chars: String, maxLen: Int): String = {
val len = r.nextInt(maxLen)
Range(0, len).map(_ => chars.charAt(r.nextInt(chars.length))).mkString
}

// TODO return DataFrame for consistency with other generators and include null values
private def generateStrings(chars: String, maxLen: Int): Seq[String] = {
val r = new Random(0)
Range(0, dataSize).map(_ => generateString(r, chars, maxLen))
}

private def generateBinary(): DataFrame = {
val r = new Random(0)
val bytes = new Array[Byte](8)
Expand Down
42 changes: 42 additions & 0 deletions spark/src/test/scala/org/apache/comet/DataGenerator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet

import scala.util.Random

/**
 * Produces random string test data from the supplied random number generator.
 *
 * Every value is drawn from `r`, so the output depends on the state of `r` at the time of
 * each call; the methods advance that state as they draw from it.
 *
 * @param r
 *   source of randomness used by all generator methods
 */
class DataGenerator(r: Random) {

  /**
   * Generate a single random string whose characters are all taken from `chars`.
   *
   * The length is drawn first with `r.nextInt(maxLen)`, then one index into `chars` is drawn
   * per character.
   *
   * @param chars
   *   the characters to pick from
   * @param maxLen
   *   bound passed to `r.nextInt` when choosing the length
   * @return
   *   the generated string
   */
  def generateString(chars: String, maxLen: Int): String = {
    // Draw the length before any character so the order of RNG calls stays fixed.
    val length = r.nextInt(maxLen)
    Iterator.fill(length)(chars.charAt(r.nextInt(chars.length))).mkString
  }

  /**
   * Generate `n` random strings, each produced by `r.nextString(maxLen)`.
   *
   * Note that `maxLen` is handed to `nextString` as the requested length of every string; no
   * separate random length is drawn here.
   *
   * @param n
   *   number of strings to generate
   * @param maxLen
   *   length argument passed to `r.nextString`
   * @return
   *   the generated strings, in generation order
   */
  def generateStrings(n: Int, maxLen: Int): Seq[String] =
    IndexedSeq.fill(n)(r.nextString(maxLen))

  /**
   * Generate `n` random strings by calling `generateString(chars, maxLen)` once per string.
   *
   * @param n
   *   number of strings to generate
   * @param chars
   *   the characters to pick from
   * @param maxLen
   *   bound forwarded to `generateString` for choosing each length
   * @return
   *   the generated strings, in generation order
   */
  def generateStrings(n: Int, chars: String, maxLen: Int): Seq[String] =
    IndexedSeq.fill(n)(generateString(chars, maxLen))

}

0 comments on commit 892d66d

Please sign in to comment.