diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 74253e717..4ffe0ffd6 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -19,6 +19,8 @@
 
 package org.apache.comet
 
+import java.time.{Duration, Period}
+
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -1583,67 +1585,17 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
     withTempDir { dir =>
       // Array values test
-      val arrayPath = new Path(dir.toURI.toString, "array_test.parquet").toString
-      Seq(Int.MaxValue, Int.MinValue).toDF("a").write.mode("overwrite").parquet(arrayPath)
-      val arrayQuery = "select a, -a from t"
-      runArrayTest(arrayQuery, "integer", arrayPath)
-
-      // long values test
-      val longArrayPath = new Path(dir.toURI.toString, "long_array_test.parquet").toString
-      Seq(Long.MaxValue, Long.MinValue)
-        .toDF("a")
-        .write
-        .mode("overwrite")
-        .parquet(longArrayPath)
-      val longArrayQuery = "select a, -a from t"
-      runArrayTest(longArrayQuery, "long", longArrayPath)
-
-      // short values test
-      val shortArrayPath = new Path(dir.toURI.toString, "short_array_test.parquet").toString
-      Seq(Short.MaxValue, Short.MinValue)
-        .toDF("a")
-        .write
-        .mode("overwrite")
-        .parquet(shortArrayPath)
-      val shortArrayQuery = "select a, -a from t"
-      runArrayTest(shortArrayQuery, "", shortArrayPath)
-
-      // byte values test
-      val byteArrayPath = new Path(dir.toURI.toString, "byte_array_test.parquet").toString
-      Seq(Byte.MaxValue, Byte.MinValue)
-        .toDF("a")
-        .write
-        .mode("overwrite")
-        .parquet(byteArrayPath)
-      val byteArrayQuery = "select a, -a from t"
-      runArrayTest(byteArrayQuery, "", byteArrayPath)
-
-      // interval values test
-      withTable("t_interval") {
-        spark.sql("CREATE TABLE t_interval(a STRING) USING PARQUET")
-        spark.sql("INSERT INTO t_interval VALUES ('INTERVAL 10000000000 YEAR')")
-        withAnsiMode(enabled = true) {
-          spark
-            .sql("SELECT CAST(a AS INTERVAL) AS a FROM t_interval")
-            .createOrReplaceTempView("t_interval_casted")
-          checkOverflow("SELECT a, -a FROM t_interval_casted", "interval")
-        }
-      }
-
-      withTable("t") {
-        sql("create table t(a int) using parquet")
-        sql("insert into t values (-2147483648)")
-        withAnsiMode(enabled = true) {
-          checkOverflow("select a, -a from t", "integer")
-        }
-      }
-
-      withTable("t_float") {
-        sql("create table t_float(a float) using parquet")
-        sql("insert into t_float values (3.4128235E38)")
-        withAnsiMode(enabled = true) {
-          checkOverflow("select a, -a from t_float", "float")
-        }
+      val dataTypes = Seq(
+        ("array_test.parquet", Seq(Int.MaxValue, Int.MinValue).toDF("a"), "integer"),
+        ("long_array_test.parquet", Seq(Long.MaxValue, Long.MinValue).toDF("a"), "long"),
+        ("short_array_test.parquet", Seq(Short.MaxValue, Short.MinValue).toDF("a"), ""),
+        ("byte_array_test.parquet", Seq(Byte.MaxValue, Byte.MinValue).toDF("a"), ""))
+
+      dataTypes.foreach { case (fileName, df, dtype) =>
+        val path = new Path(dir.toURI.toString, fileName).toString
+        df.write.mode("overwrite").parquet(path)
+        val query = s"select a, -a from t"
+        runArrayTest(query, dtype, path)
       }
 
       // scalar tests
@@ -1669,6 +1621,25 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       for (n <- Seq("3.4028235E38", "-3.4028235E38")) {
         checkOverflow(s"select -(cast(${n} as float)) FROM tbl", "float")
       }
+      // interval test without cast
+      val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
+      val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+        .map(Period.ofMonths)
+        .toDF("v")
+      val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+        .map(Duration.ofDays)
+        .toDF("v")
+      Seq(longDf, yearMonthDf, dayTimeDf).foreach { df =>
+        withSQLConf(
+          "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding",
+          SQLConf.ANSI_ENABLED.key -> "true",
+          CometConf.COMET_ANSI_MODE_ENABLED.key -> "true",
+          CometConf.COMET_ENABLED.key -> "true",
+          CometConf.COMET_EXEC_ENABLED.key -> "true") {
+          df.createOrReplaceTempView("temp_interval_table")
+          checkOverflow("select -v from temp_interval_table", "interval")
+        }
+      }
     }
   }
 }