From 23082da3506f7ada8bb11882622e2cde72316f15 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Tue, 14 May 2024 09:17:49 -0600
Subject: [PATCH] move more data gen methods

---
 .../org/apache/comet/CometCastSuite.scala     | 70 ++-----------------
 .../org/apache/comet/DataGenerator.scala      | 48 ++++++++++++++
 2 files changed, 54 insertions(+), 64 deletions(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index bcc171a87..51fa0626b 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -745,35 +745,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   private def generateFloats(): DataFrame = {
-    val r = new Random(0)
-    val values = Seq(
-      Float.MaxValue,
-      Float.MinPositiveValue,
-      Float.MinValue,
-      Float.NaN,
-      Float.PositiveInfinity,
-      Float.NegativeInfinity,
-      1.0f,
-      -1.0f,
-      Short.MinValue.toFloat,
-      Short.MaxValue.toFloat,
-      0.0f) ++
-      Range(0, dataSize).map(_ => r.nextFloat())
-    withNulls(values).toDF("a")
+    withNulls(gen.generateFloats(dataSize)).toDF("a")
   }
 
   private def generateDoubles(): DataFrame = {
-    val r = new Random(0)
-    val values = Seq(
-      Double.MaxValue,
-      Double.MinPositiveValue,
-      Double.MinValue,
-      Double.NaN,
-      Double.PositiveInfinity,
-      Double.NegativeInfinity,
-      0.0d) ++
-      Range(0, dataSize).map(_ => r.nextDouble())
-    withNulls(values).toDF("a")
+    withNulls(gen.generateDoubles(dataSize)).toDF("a")
   }
 
   private def generateBools(): DataFrame = {
@@ -781,31 +757,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   private def generateBytes(): DataFrame = {
-    val r = new Random(0)
-    val values = Seq(Byte.MinValue, Byte.MaxValue) ++
-      Range(0, dataSize).map(_ => r.nextInt().toByte)
-    withNulls(values).toDF("a")
+    withNulls(gen.generateBytes(dataSize)).toDF("a")
   }
 
   private def generateShorts(): DataFrame = {
-    val r = new Random(0)
-    val values = Seq(Short.MinValue, Short.MaxValue) ++
-      Range(0, dataSize).map(_ => r.nextInt().toShort)
-    withNulls(values).toDF("a")
+    withNulls(gen.generateShorts(dataSize)).toDF("a")
   }
 
   private def generateInts(): DataFrame = {
-    val r = new Random(0)
-    val values = Seq(Int.MinValue, Int.MaxValue) ++
-      Range(0, dataSize).map(_ => r.nextInt())
-    withNulls(values).toDF("a")
+    withNulls(gen.generateInts(dataSize)).toDF("a")
  }
 
   private def generateLongs(): DataFrame = {
-    val r = new Random(0)
-    val values = Seq(Long.MinValue, Long.MaxValue) ++
-      Range(0, dataSize).map(_ => r.nextLong())
-    withNulls(values).toDF("a")
+    withNulls(gen.generateLongs(dataSize)).toDF("a")
   }
 
   private def generateDecimalsPrecision10Scale2(): DataFrame = {
@@ -902,28 +866,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  // TODO Commented out to work around scalafix since this is currently unused.
-  // private def castFallbackTestTimezone(
-  //     input: DataFrame,
-  //     toType: DataType,
-  //     expectedMessage: String): Unit = {
-  //   withTempPath { dir =>
-  //     val data = roundtripParquet(input, dir).coalesce(1)
-  //     data.createOrReplaceTempView("t")
-  //
-  //     withSQLConf(
-  //       (SQLConf.ANSI_ENABLED.key, "false"),
-  //       (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key, "true"),
-  //       (SQLConf.SESSION_LOCAL_TIMEZONE.key, "America/Los_Angeles")) {
-  //       val df = data.withColumn("converted", col("a").cast(toType))
-  //       df.collect()
-  //       val str =
-  //         new ExtendedExplainInfo().generateExtendedInfo(df.queryExecution.executedPlan)
-  //       assert(str.contains(expectedMessage))
-  //     }
-  //   }
-  // }
-
   private def castTimestampTest(input: DataFrame, toType: DataType) = {
     withTempPath { dir =>
       val data = roundtripParquet(input, dir).coalesce(1)
diff --git a/spark/src/test/scala/org/apache/comet/DataGenerator.scala b/spark/src/test/scala/org/apache/comet/DataGenerator.scala
index 8e361e803..19ef9eb9d 100644
--- a/spark/src/test/scala/org/apache/comet/DataGenerator.scala
+++ b/spark/src/test/scala/org/apache/comet/DataGenerator.scala
@@ -39,4 +39,52 @@ class DataGenerator(r: Random) {
     Range(0, n).map(_ => generateString(chars, maxLen))
   }
 
+  def generateFloats(n: Int): Seq[Float] = {
+    Seq(
+      Float.MaxValue,
+      Float.MinPositiveValue,
+      Float.MinValue,
+      Float.NaN,
+      Float.PositiveInfinity,
+      Float.NegativeInfinity,
+      1.0f,
+      -1.0f,
+      Short.MinValue.toFloat,
+      Short.MaxValue.toFloat,
+      0.0f) ++
+      Range(0, n).map(_ => r.nextFloat())
+  }
+
+  def generateDoubles(n: Int): Seq[Double] = {
+    Seq(
+      Double.MaxValue,
+      Double.MinPositiveValue,
+      Double.MinValue,
+      Double.NaN,
+      Double.PositiveInfinity,
+      Double.NegativeInfinity,
+      0.0d) ++
+      Range(0, n).map(_ => r.nextDouble())
+  }
+
+  def generateBytes(n: Int): Seq[Byte] = {
+    Seq(Byte.MinValue, Byte.MaxValue) ++
+      Range(0, n).map(_ => r.nextInt().toByte)
+  }
+
+  def generateShorts(n: Int): Seq[Short] = {
+    Seq(Short.MinValue, Short.MaxValue) ++
+      Range(0, n).map(_ => r.nextInt().toShort)
+  }
+
+  def generateInts(n: Int): Seq[Int] = {
+    Seq(Int.MinValue, Int.MaxValue) ++
+      Range(0, n).map(_ => r.nextInt())
+  }
+
+  def generateLongs(n: Int): Seq[Long] = {
+    Seq(Long.MinValue, Long.MaxValue) ++
+      Range(0, n).map(_ => r.nextLong())
+  }
+
 }
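
Usage sketch: a minimal, hypothetical driver showing how the generators moved
into DataGenerator might be called after this patch. It relies only on what
the patch itself defines (the DataGenerator(r: Random) constructor and the new
generate* methods); the object name DataGeneratorExample and the sample size
of 5 are illustrative, not part of the Comet test suite.

    import scala.util.Random

    object DataGeneratorExample {
      def main(args: Array[String]): Unit = {
        // Seeding the Random passed to the constructor makes the
        // generated test data reproducible across runs.
        val gen = new DataGenerator(new Random(0))

        // Each generator prepends boundary values (min/max, and for
        // floating point also NaN and the infinities) before appending
        // n random values, so edge cases are always exercised.
        val floats: Seq[Float] = gen.generateFloats(5)
        val longs: Seq[Long] = gen.generateLongs(5)

        println(floats.mkString(", "))
        println(longs.mkString(", "))
      }
    }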