feat: Add random row generator in data gen
advancedxy committed May 20, 2024
1 parent b4c2dc2 commit f9ddaab
Showing 2 changed files with 107 additions and 0 deletions.
58 changes: 58 additions & 0 deletions spark/src/test/scala/org/apache/comet/DataGenerator.scala
@@ -19,16 +19,23 @@

package org.apache.comet

import scala.collection.mutable
import scala.util.Random

import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

object DataGenerator {
  // note that we use `def` rather than `val` intentionally here so that
  // each test suite starts with a fresh data generator to help ensure
  // that tests are deterministic
  def DEFAULT = new DataGenerator(new Random(42))

  // matches the probability of nulls in Spark's RandomDataGenerator
  private val PROBABILITY_OF_NULL: Float = 0.1f
}

class DataGenerator(r: Random) {
  import DataGenerator._

  /** Generate a random string using the specified characters */
  def generateString(chars: String, maxLen: Int): String = {
@@ -95,4 +95,55 @@ class DataGenerator(r: Random) {
    Range(0, n).map(_ => r.nextLong())
  }

  // Generate a random row according to the schema. String fields can be configured to use a
  // custom stringGen function; nested structs are generated recursively and all other types
  // are delegated to Spark's RandomDataGenerator.
  def generateRow(schema: StructType, stringGen: Option[() => String] = None): Row = {
    val fields = mutable.ArrayBuffer.empty[Any]
    schema.fields.foreach { f =>
      f.dataType match {
        case ArrayType(childType, containsNull) =>
          val data = if (f.nullable && r.nextFloat() <= PROBABILITY_OF_NULL) {
            null
          } else {
            val arr = mutable.ArrayBuffer.empty[Any]
            val n = 1 // could be r.nextInt(10) to produce variable-length arrays
            var i = 0
            val generator = RandomDataGenerator.forType(childType, containsNull, r)
            assert(generator.isDefined, "Unsupported type")
            val gen = generator.get
            while (i < n) {
              arr += gen()
              i += 1
            }
            arr.toSeq
          }
          fields += data
        case StructType(children) =>
          // recurse so that string fields in nested structs also honor stringGen
          fields += generateRow(StructType(children), stringGen)
        case StringType if stringGen.isDefined =>
          val gen = stringGen.get
          val data = if (f.nullable && r.nextFloat() <= PROBABILITY_OF_NULL) {
            null
          } else {
            gen()
          }
          fields += data
        case _ =>
          val generator = RandomDataGenerator.forType(f.dataType, f.nullable, r)
          assert(generator.isDefined, "Unsupported type")
          val gen = generator.get
          fields += gen()
      }
    }
    Row.fromSeq(fields)
  }

  def generateRows(
      num: Int,
      schema: StructType,
      stringGen: Option[() => String] = None): Seq[Row] = {
    Range(0, num).map(_ => generateRow(schema, stringGen))
  }

}
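
For illustration, here is a minimal usage sketch of the new row generator. The schema, field names, and the DataGeneratorExample object below are assumed for the example and are not part of this commit:

package org.apache.comet

import scala.util.Random

import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}

object DataGeneratorExample {
  def main(args: Array[String]): Unit = {
    val gen = new DataGenerator(new Random(42))
    // illustrative schema: a string column, an int array, and a nested struct
    val schema = new StructType()
      .add("name", StringType, nullable = true)
      .add("scores", ArrayType(IntegerType), nullable = true)
      .add("nested", new StructType().add("tag", StringType, nullable = true))
    // restrict generated strings to a small alphabet with a bounded length
    val stringGen = () => gen.generateString("abcde", 10)
    gen.generateRows(5, schema, Some(stringGen)).foreach(println)
  }
}

Run this way, it would print five Row objects whose string fields draw only from the characters "abcde" (or are null), while the array and other fields come from Spark's RandomDataGenerator.
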
49 changes: 49 additions & 0 deletions spark/src/test/scala/org/apache/comet/DataGeneratorSuite.scala
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet

import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.types.StructType

class DataGeneratorSuite extends CometTestBase {

  test("test configurable stringGen in row generator") {
    val gen = DataGenerator.DEFAULT
    val chars = "abcde"
    val maxLen = 10
    val stringGen = () => gen.generateString(chars, maxLen)
    val numRows = 100
    val schema = new StructType().add("a", "string")
    var numNulls = 0
    gen
      .generateRows(numRows, schema, Some(stringGen))
      .foreach(row => {
        if (row.getString(0) != null) {
          assert(row.getString(0).forall(chars.toSeq.contains))
          assert(row.getString(0).length <= maxLen)
        } else {
          numNulls += 1
        }
      })
    // nulls are generated with probability 0.1, so expect roughly 10 of the 100 rows to be null
    assert(numNulls >= 5 && numNulls <= 15)
  }

}
