diff --git a/spark/src/test/scala/org/apache/comet/DataGenerator.scala b/spark/src/test/scala/org/apache/comet/DataGenerator.scala index 691a371b5..80e7c2288 100644 --- a/spark/src/test/scala/org/apache/comet/DataGenerator.scala +++ b/spark/src/test/scala/org/apache/comet/DataGenerator.scala @@ -21,14 +21,20 @@ package org.apache.comet import scala.util.Random +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.types.{StringType, StructType} + object DataGenerator { // note that we use `def` rather than `val` intentionally here so that // each test suite starts with a fresh data generator to help ensure // that tests are deterministic def DEFAULT = new DataGenerator(new Random(42)) + // matches the probability of nulls in Spark's RandomDataGenerator + private val PROBABILITY_OF_NULL: Float = 0.1f } class DataGenerator(r: Random) { + import DataGenerator._ /** Generate a random string using the specified characters */ def generateString(chars: String, maxLen: Int): String = { @@ -95,4 +101,39 @@ class DataGenerator(r: Random) { Range(0, n).map(_ => r.nextLong()) } + // Generate a random row according to the schema, the string field in the struct can be + // configured to generate strings by passing a stringGen function. Other types are delegated + // to Spark's RandomDataGenerator. 
+ def generateRow(schema: StructType, stringGen: Option[() => String] = None): Row = { + val fields = schema.fields.map { f => + f.dataType match { + case StructType(children) => + generateRow(StructType(children), stringGen) + case StringType if stringGen.isDefined => + val gen = stringGen.get + val data = if (f.nullable && r.nextFloat() <= PROBABILITY_OF_NULL) { + null + } else { + gen() + } + data + case _ => + val gen = RandomDataGenerator.forType(f.dataType, f.nullable, r) match { + case Some(g) => g + case None => + throw new IllegalStateException(s"No RandomDataGenerator for type ${f.dataType}") + } + gen() + } + }.toSeq + Row.fromSeq(fields) + } + + def generateRows( + num: Int, + schema: StructType, + stringGen: Option[() => String] = None): Seq[Row] = { + Range(0, num).map(_ => generateRow(schema, stringGen)) + } + } diff --git a/spark/src/test/scala/org/apache/comet/DataGeneratorSuite.scala b/spark/src/test/scala/org/apache/comet/DataGeneratorSuite.scala new file mode 100644 index 000000000..02dfb9d7b --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/DataGeneratorSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.types.StructType + +class DataGeneratorSuite extends CometTestBase { + + test("test configurable stringGen in row generator") { + val gen = DataGenerator.DEFAULT + val chars = "abcde" + val maxLen = 10 + val stringGen = () => gen.generateString(chars, maxLen) + val numRows = 100 + val schema = new StructType().add("a", "string") + var numNulls = 0 + gen + .generateRows(numRows, schema, Some(stringGen)) + .foreach(row => { + if (row.getString(0) != null) { + assert(row.getString(0).forall(chars.toSeq.contains)) + assert(row.getString(0).length <= maxLen) + } else { + numNulls += 1 + } + }) + // 0.1 null probability + assert(numNulls >= 0.05 * numRows && numNulls <= 0.15 * numRows) + } + +}