Skip to content

Commit

Permalink
feat: Add random row generator in data generator (apache#451)
Browse files Browse the repository at this point in the history
* feat: Add random row generator in data gen

* fix

* remove array type match case, which should already be handled in RandomDataGenerator.forType

* fix style issue

* address comments

(cherry picked from commit 9125e6a)
  • Loading branch information
advancedxy authored and Huaxin Gao committed May 31, 2024
1 parent c78c8c0 commit dce4332
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
41 changes: 41 additions & 0 deletions spark/src/test/scala/org/apache/comet/DataGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@ package org.apache.comet

import scala.util.Random

import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.types.{StringType, StructType}

/** Companion holding the default generator factory and constants shared with the class. */
object DataGenerator {

  // Chance that a nullable field is generated as null; matches the probability of nulls
  // in Spark's RandomDataGenerator.
  private val PROBABILITY_OF_NULL: Float = 0.1f

  /**
   * Create a data generator backed by a `Random` seeded with 42.
   *
   * This is intentionally a `def` rather than a `val`: every access builds a new instance,
   * so each test suite starts with a fresh data generator, which helps ensure that tests
   * are deterministic.
   */
  def DEFAULT: DataGenerator = new DataGenerator(new Random(42))
}

class DataGenerator(r: Random) {
import DataGenerator._

/** Generate a random string using the specified characters */
def generateString(chars: String, maxLen: Int): String = {
Expand Down Expand Up @@ -95,4 +101,39 @@ class DataGenerator(r: Random) {
Range(0, n).map(_ => r.nextLong())
}

/**
 * Generate a random row according to the schema.
 *
 * String fields in the struct can be configured by passing a `stringGen` function; every
 * other type is delegated to Spark's RandomDataGenerator. Nested struct fields are generated
 * recursively with the same `stringGen`. Note that a nested struct value itself is never
 * generated as null here, even when the field is declared nullable.
 *
 * @param schema
 *   the schema describing the fields of the row to generate
 * @param stringGen
 *   optional generator for StringType fields; when None, strings are delegated to Spark's
 *   RandomDataGenerator like any other type
 * @return
 *   a row holding one generated value per field of `schema`, in field order
 * @throws IllegalStateException
 *   if RandomDataGenerator.forType returns no generator for a field's data type
 */
def generateRow(schema: StructType, stringGen: Option[() => String] = None): Row = {
  val values = schema.fields.map { field =>
    (field.dataType, stringGen) match {
      case (nested: StructType, _) =>
        generateRow(nested, stringGen)
      case (StringType, Some(customGen)) =>
        // `&&` short-circuits, so the random source is only consulted for nullable fields;
        // this keeps the sequence of random draws stable for non-nullable ones
        if (field.nullable && r.nextFloat() <= PROBABILITY_OF_NULL) null else customGen()
      case (dataType, _) =>
        val gen = RandomDataGenerator
          .forType(dataType, field.nullable, r)
          .getOrElse(
            throw new IllegalStateException(s"No RandomDataGenerator for type $dataType"))
        gen()
    }
  }
  Row.fromSeq(values.toSeq)
}

/**
 * Generate `num` random rows for the given schema by calling `generateRow` once per row,
 * in order.
 *
 * @param num
 *   the number of rows to generate
 * @param schema
 *   the schema passed to `generateRow` for every row
 * @param stringGen
 *   optional string generator passed through to `generateRow`
 * @return
 *   the generated rows, in generation order
 */
def generateRows(
    num: Int,
    schema: StructType,
    stringGen: Option[() => String] = None): Seq[Row] = {
  // the by-name element is re-evaluated for every slot, so each row is generated afresh
  Vector.fill(num)(generateRow(schema, stringGen))
}

}
49 changes: 49 additions & 0 deletions spark/src/test/scala/org/apache/comet/DataGeneratorSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet

import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.types.StructType

/** Tests for the row generation helpers in DataGenerator. */
class DataGeneratorSuite extends CometTestBase {

  test("test configurable stringGen in row generator") {
    val gen = DataGenerator.DEFAULT
    val chars = "abcde"
    val allowedChars = chars.toSet
    val maxLen = 10
    val numRows = 100
    val schema = new StructType().add("a", "string")
    val stringGen = () => gen.generateString(chars, maxLen)

    // wrap each generated value in an Option so that nulls become None
    val values = gen
      .generateRows(numRows, schema, Some(stringGen))
      .map(row => Option(row.getString(0)))

    // every non-null string must come from the custom generator
    values.flatten.foreach { s =>
      assert(s.forall(allowedChars))
      assert(s.length <= maxLen)
    }

    // 0.1 null probability
    val numNulls = values.count(_.isEmpty)
    assert(numNulls >= 0.05 * numRows && numNulls <= 0.15 * numRows)
  }

}

0 comments on commit dce4332

Please sign in to comment.