diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index ae49fc68a..27eeb0324 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -39,7 +39,13 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // but this is likely a reasonable starting point for now private val whitespaceChars = " \t\r\n" - private val numericPattern = "0123456789e+-." + whitespaceChars + /** + * We use these characters to construct strings that potentially represent valid numbers + * such as `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as + * `+e.-d`. + */ + private val numericPattern = "0123456789def+-." + whitespaceChars + private val datePattern = "0123456789/" + whitespaceChars private val timestampPattern = "0123456789/:T" + whitespaceChars @@ -66,23 +72,54 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(testValues, DataTypes.BooleanType) } + val castStringToIntegralInputs = Seq( + "", ".", "+", "-", "+.", "=.", + "-0", "+1", "-1", ".2", "-.2", + "1e1", "1.1d", "1.1f", + Byte.MinValue.toString, + (Byte.MinValue.toShort - 1).toString, + Byte.MaxValue.toString, + (Byte.MaxValue.toShort + 1).toString, + Short.MinValue.toString, + (Short.MinValue.toInt - 1).toString, + Short.MaxValue.toString, + (Short.MaxValue.toInt + 1).toString, + Int.MinValue.toString, + (Int.MinValue.toLong - 1).toString, + Int.MaxValue.toString, + (Int.MaxValue.toLong + 1).toString, + Long.MinValue.toString, + Long.MaxValue.toString, + "-9223372036854775809", // Long.MinValue -1 + "9223372036854775808" // Long.MaxValue + 1 + ) + test("cast string to byte") { - val testValues = - Seq("", ".", "0", "-0", "+1", "-1", ".2", "-.2", "1e1", "127", "128", "-128", "-129") ++ - generateStrings(numericPattern, 8) - castTest(testValues.toDF("a"), DataTypes.ByteType) + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + // fuzz test + castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) } test("cast string to short") { - castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ShortType) + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + // fuzz test + castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) } test("cast string to int") { - castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType) + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + // fuzz test + castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) } test("cast string to long") { - castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType) + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + // fuzz test + castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) } ignore("cast string to float") {