From 55836febed9c415337baef66a011ba0457d5046e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 10:53:59 -0600 Subject: [PATCH 01/19] Add ANSI support to cast string to boolean --- core/src/errors.rs | 16 +++ .../execution/datafusion/expressions/cast.rs | 52 +++++++--- core/src/execution/datafusion/planner.rs | 50 +++++++--- core/src/execution/jni_api.rs | 8 +- .../org/apache/comet/CometExecIterator.scala | 6 +- .../comet/CometSparkSessionExtensions.scala | 8 +- .../sql/comet/CometCollectLimitExec.scala | 14 ++- .../spark/sql/comet/CometExecUtils.scala | 5 +- .../CometTakeOrderedAndProjectExec.scala | 9 +- .../shuffle/CometShuffleExchangeExec.scala | 7 +- .../apache/spark/sql/comet/operators.scala | 18 +++- .../org/apache/comet/CometCastSuite.scala | 98 ++++++++++++++----- .../org/apache/spark/sql/CometTestBase.scala | 12 +++ 13 files changed, 233 insertions(+), 70 deletions(-) diff --git a/core/src/errors.rs b/core/src/errors.rs index 1d5766cb9..f4b5d738a 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -60,6 +60,18 @@ pub enum CometError { #[error("Comet Internal Error: {0}")] Internal(String), + // TODO this error message is likely to change between Spark versions and it would be better + // to have the full error in Scala and just pass the invalid value back here + #[error("[[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + because it is malformed. Correct the value as per the syntax, or change its target type. \ + Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error")] + CastInvalidValue { + value: String, + from_type: String, + to_type: String, + }, + #[error(transparent)] Arrow { #[from] @@ -183,6 +195,10 @@ impl jni::errors::ToException for CometError { class: "java/lang/NullPointerException".to_string(), msg: self.to_string(), }, + CometError::CastInvalidValue { .. } => Exception { + class: "org/apache/spark/SparkException".to_string(), + msg: self.to_string(), + }, CometError::NumberIntFormat { source: s } => Exception { class: "java/lang/NumberFormatException".to_string(), msg: s.to_string(), diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 447c27783..7c8caece9 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -22,6 +22,7 @@ use std::{ sync::Arc, }; +use crate::errors::{CometError, CometResult}; use arrow::{ compute::{cast_with_options, CastOptions}, record_batch::RecordBatch, @@ -49,6 +50,7 @@ static CAST_OPTIONS: CastOptions = CastOptions { pub struct Cast { pub child: Arc, pub data_type: DataType, + pub ansi_mode: bool, /// When cast from/to timezone related types, we need timezone, which will be resolved with /// session local timezone by an analyzer in Spark. 
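
(For context, the end-to-end behavior this patch series works toward can be sketched at the Spark SQL level. The snippet below is a hypothetical spark-shell session, assuming `spark` is the active SparkSession and using 'tru' as an arbitrary example of malformed input; it illustrates the intended semantics described by the CAST_INVALID_INPUT error message above and by the CometCastSuite changes later in this patch, not output captured from this build.)

    // Legacy mode: a malformed string casts to null
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT CAST('tru' AS BOOLEAN)").show()       // null

    // ANSI mode: the same cast fails with error class CAST_INVALID_INPUT;
    // on the native side this is raised as CometError::CastInvalidValue and
    // surfaced to Spark as org.apache.spark.SparkException with the same message
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("SELECT CAST('tru' AS BOOLEAN)").show()       // throws

    // try_cast tolerates malformed input and returns null regardless of ANSI mode
    spark.sql("SELECT try_cast('tru' AS BOOLEAN)").show()   // null
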
@@ -56,19 +58,30 @@ pub struct Cast { } impl Cast { - pub fn new(child: Arc, data_type: DataType, timezone: String) -> Self { + pub fn new( + child: Arc, + data_type: DataType, + ansi_mode: bool, + timezone: String, + ) -> Self { Self { child, data_type, timezone, + ansi_mode, } } - pub fn new_without_timezone(child: Arc, data_type: DataType) -> Self { + pub fn new_without_timezone( + child: Arc, + data_type: DataType, + ansi_mode: bool, + ) -> Self { Self { child, data_type, timezone: "".to_string(), + ansi_mode, } } @@ -77,17 +90,22 @@ impl Cast { let array = array_with_timezone(array, self.timezone.clone(), Some(to_type)); let from_type = array.data_type(); let cast_result = match (from_type, to_type) { - (DataType::Utf8, DataType::Boolean) => Self::spark_cast_utf8_to_boolean::(&array), + (DataType::Utf8, DataType::Boolean) => { + Self::spark_cast_utf8_to_boolean::(&array, self.ansi_mode) + } (DataType::LargeUtf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array) + Self::spark_cast_utf8_to_boolean::(&array, self.ansi_mode) } - _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?, + _ => Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?), }; - let result = spark_cast(cast_result, from_type, to_type); + let result = spark_cast(cast_result?, from_type, to_type); Ok(result) } - fn spark_cast_utf8_to_boolean(from: &dyn Array) -> ArrayRef + fn spark_cast_utf8_to_boolean( + from: &dyn Array, + ansi_mode: bool, + ) -> CometResult where OffsetSize: OffsetSizeTrait, { @@ -100,15 +118,22 @@ impl Cast { .iter() .map(|value| match value { Some(value) => match value.to_ascii_lowercase().trim() { - "t" | "true" | "y" | "yes" | "1" => Some(true), - "f" | "false" | "n" | "no" | "0" => Some(false), - _ => None, + "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), + "f" | "false" | "n" | "no" | "0" => Ok(Some(false)), + other if ansi_mode => { + return Err(CometError::CastInvalidValue { + value: other.to_string(), + from_type: "STRING".to_string(), + to_type: "BOOLEAN".to_string(), + }) + } + _ => Ok(None), }, - _ => None, + _ => Ok(None), }) - .collect::(); + .collect::>()?; - Arc::new(output_array) + Ok(Arc::new(output_array)) } } @@ -174,6 +199,7 @@ impl PhysicalExpr for Cast { Ok(Arc::new(Cast::new( children[0].clone(), self.data_type.clone(), + self.ansi_mode, self.timezone.clone(), ))) } diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 052ecc44d..745976f02 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -112,27 +112,17 @@ pub struct PhysicalPlanner { exec_context_id: i64, execution_props: ExecutionProps, session_ctx: Arc, -} - -impl Default for PhysicalPlanner { - fn default() -> Self { - let session_ctx = Arc::new(SessionContext::new()); - let execution_props = ExecutionProps::new(); - Self { - exec_context_id: TEST_EXEC_CONTEXT_ID, - execution_props, - session_ctx, - } - } + ansi_mode: bool, } impl PhysicalPlanner { - pub fn new(session_ctx: Arc) -> Self { + pub fn new(session_ctx: Arc, ansi_mode: bool) -> Self { let execution_props = ExecutionProps::new(); Self { exec_context_id: TEST_EXEC_CONTEXT_ID, execution_props, session_ctx, + ansi_mode, } } @@ -141,6 +131,7 @@ impl PhysicalPlanner { exec_context_id, execution_props: self.execution_props, session_ctx: self.session_ctx.clone(), + ansi_mode: self.ansi_mode, } } @@ -343,7 +334,12 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let datatype = 
to_arrow_datatype(expr.datatype.as_ref().unwrap()); let timezone = expr.timezone.clone(); - Ok(Arc::new(Cast::new(child, datatype, timezone))) + Ok(Arc::new(Cast::new( + child, + datatype, + self.ansi_mode, + timezone, + ))) } ExprStruct::Hour(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; @@ -638,13 +634,19 @@ impl PhysicalPlanner { let left = Arc::new(Cast::new_without_timezone( left, DataType::Decimal256(p1, s1), + self.ansi_mode, )); let right = Arc::new(Cast::new_without_timezone( right, DataType::Decimal256(p2, s2), + self.ansi_mode, )); let child = Arc::new(BinaryExpr::new(left, op, right)); - Ok(Arc::new(Cast::new_without_timezone(child, data_type))) + Ok(Arc::new(Cast::new_without_timezone( + child, + data_type, + self.ansi_mode, + ))) } ( DataFusionOperator::Divide, @@ -1432,7 +1434,9 @@ mod tests { use arrow_array::{DictionaryArray, Int32Array, StringArray}; use arrow_schema::DataType; - use datafusion::{physical_plan::common::collect, prelude::SessionContext}; + use datafusion::{ + execution::context::ExecutionProps, physical_plan::common::collect, prelude::SessionContext, + }; use tokio::sync::mpsc; use crate::execution::{ @@ -1442,9 +1446,23 @@ mod tests { spark_operator, }; + use crate::execution::datafusion::planner::TEST_EXEC_CONTEXT_ID; use spark_expression::expr::ExprStruct::*; use spark_operator::{operator::OpStruct, Operator}; + impl Default for PhysicalPlanner { + fn default() -> Self { + let session_ctx = Arc::new(SessionContext::new()); + let execution_props = ExecutionProps::new(); + Self { + exec_context_id: TEST_EXEC_CONTEXT_ID, + execution_props, + session_ctx, + ansi_mode: false, + } + } + } + #[test] fn test_unpack_dictionary_primitive() { let op_scan = Operator { diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs index 8249097a1..b8218c20a 100644 --- a/core/src/execution/jni_api.rs +++ b/core/src/execution/jni_api.rs @@ -317,11 +317,17 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let exec_context_id = exec_context.id; + // TODO there must be a cleaner way to write this + let ansi_mode = match exec_context.conf.get("ansi_mode") { + Some(value) if value == "true" => true, + _ => false, + }; + // Initialize the execution stream. // Because we don't know if input arrays are dictionary-encoded when we create // query plan, we need to defer stream initialization to first time execution. 
if exec_context.root_op.is_none() { - let planner = PhysicalPlanner::new(exec_context.session_ctx.clone()) + let planner = PhysicalPlanner::new(exec_context.session_ctx.clone(), ansi_mode) .with_exec_id(exec_context_id); let (scans, root_op) = planner.create_plan( &exec_context.spark_plan, diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index b3604c9e0..8f8dc17b8 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -20,7 +20,9 @@ package org.apache.comet import org.apache.spark._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.comet.CometMetricNode +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized._ import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION} @@ -44,7 +46,8 @@ class CometExecIterator( val id: Long, inputs: Seq[Iterator[ColumnarBatch]], protobufQueryPlan: Array[Byte], - nativeMetrics: CometMetricNode) + nativeMetrics: CometMetricNode, + ansiEnabled: Boolean) extends Iterator[ColumnarBatch] { private val nativeLib = new Native() @@ -99,6 +102,7 @@ class CometExecIterator( result.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get())) result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get())) result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get())) + result.put("ansi_mode", String.valueOf(ansiEnabled)) // Strip mandatory prefix spark. which is not required for DataFusion session params conf.getAll.foreach { diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 275b9ebff..81af55044 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -567,10 +567,10 @@ class CometSparkSessionExtensions override def apply(plan: SparkPlan): SparkPlan = { // DataFusion doesn't have ANSI mode. For now we just disable CometExec if ANSI mode is // enabled. - if (isANSIEnabled(conf)) { - logInfo("Comet extension disabled for ANSI mode") - return plan - } +// if (isANSIEnabled(conf)) { +// logInfo("Comet extension disabled for ANSI mode") +// return plan +// } // We shouldn't transform Spark query plan if Comet is disabled. 
if (!isCometEnabled(conf)) return plan diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala index dd4855126..3addb3ba1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala @@ -21,10 +21,12 @@ package org.apache.spark.sql.comet import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.common.base.Objects @@ -73,7 +75,11 @@ case class CometCollectLimitExec( childRDD } else { val localLimitedRDD = if (limit >= 0) { - CometExecUtils.getNativeLimitRDD(childRDD, output, limit) + CometExecUtils.getNativeLimitRDD( + childRDD, + output, + limit, + SparkSession.active.conf.get(SQLConf.ANSI_ENABLED)) } else { childRDD } @@ -88,7 +94,11 @@ case class CometCollectLimitExec( new CometShuffledBatchRDD(dep, readMetrics) } - CometExecUtils.getNativeLimitRDD(singlePartitionRDD, output, limit) + CometExecUtils.getNativeLimitRDD( + singlePartitionRDD, + output, + limit, + SparkSession.active.conf.get(SQLConf.ANSI_ENABLED)) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index 5931920a2..805216d66 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -50,10 +50,11 @@ object CometExecUtils { def getNativeLimitRDD( childPlan: RDD[ColumnarBatch], outputAttribute: Seq[Attribute], - limit: Int): RDD[ColumnarBatch] = { + limit: Int, + ansiMode: Boolean): RDD[ColumnarBatch] = { childPlan.mapPartitionsInternal { iter => val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get - CometExec.getCometIterator(Seq(iter), limitOp) + CometExec.getCometIterator(Seq(iter), limitOp, ansiMode) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index 26ec401ed..7ecdbfd5a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -21,12 +21,14 @@ package org.apache.spark.sql.comet import org.apache.spark.rdd.{ParallelCollectionRDD, RDD} import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode, UnsafeRowSerializer} import org.apache.spark.sql.execution.metric.{SQLMetrics, 
SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.serde.QueryPlanSerde.exprToProto @@ -73,18 +75,19 @@ case class CometTakeOrderedAndProjectExec( if (childRDD.getNumPartitions == 0) { new ParallelCollectionRDD(sparkContext, Seq.empty[ColumnarBatch], 1, Map.empty) } else { + val ansiEnabled = SparkSession.active.conf.get(SQLConf.ANSI_ENABLED) val singlePartitionRDD = if (childRDD.getNumPartitions == 1) { childRDD } else { val localTopK = if (orderingSatisfies) { - CometExecUtils.getNativeLimitRDD(childRDD, output, limit) + CometExecUtils.getNativeLimitRDD(childRDD, output, limit, ansiEnabled) } else { childRDD.mapPartitionsInternal { iter => val topK = CometExecUtils .getTopKNativePlan(output, sortOrder, child, limit) .get - CometExec.getCometIterator(Seq(iter), topK) + CometExec.getCometIterator(Seq(iter), topK, ansiEnabled) } } @@ -104,7 +107,7 @@ case class CometTakeOrderedAndProjectExec( val topKAndProjection = CometExecUtils .getProjectionNativePlan(projectList, output, sortOrder, child, limit) .get - CometExec.getCometIterator(Seq(iter), topKAndProjection) + CometExec.getCometIterator(Seq(iter), topKAndProjection, ansiEnabled) } } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 232b6bf17..742015a5d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -33,6 +33,7 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter, ShuffleWriteProcessor} import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering @@ -462,7 +463,11 @@ class CometShuffleWriteProcessor( val nativeMetrics = CometMetricNode(nativeSQLMetrics) val rawIter = cometRDD.iterator(partition, context) - val cometIter = CometExec.getCometIterator(Seq(rawIter), nativePlan, nativeMetrics) + val cometIter = CometExec.getCometIterator( + Seq(rawIter), + nativePlan, + nativeMetrics, + SparkSession.active.conf.get(SQLConf.ANSI_ENABLED)) while (cometIter.hasNext) { cometIter.next() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 8545eee90..e45b191c2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -91,19 +91,21 @@ object CometExec { def getCometIterator( inputs: Seq[Iterator[ColumnarBatch]], - nativePlan: Operator): CometExecIterator = { - getCometIterator(inputs, nativePlan, CometMetricNode(Map.empty)) + nativePlan: Operator, + ansiMode: Boolean): CometExecIterator = { + getCometIterator(inputs, nativePlan, CometMetricNode(Map.empty), ansiMode) } def getCometIterator( inputs: Seq[Iterator[ColumnarBatch]], nativePlan: Operator, - 
nativeMetrics: CometMetricNode): CometExecIterator = { + nativeMetrics: CometMetricNode, + ansiMode: Boolean): CometExecIterator = { val outputStream = new ByteArrayOutputStream() nativePlan.writeTo(outputStream) outputStream.close() val bytes = outputStream.toByteArray - new CometExecIterator(newIterId, inputs, bytes, nativeMetrics) + new CometExecIterator(newIterId, inputs, bytes, nativeMetrics, ansiMode) } /** @@ -199,6 +201,7 @@ abstract class CometNativeExec extends CometExec { // Switch to use Decimal128 regardless of precision, since Arrow native execution // doesn't support Decimal32 and Decimal64 yet. SQLConf.get.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true") + val ansiEnabled: Boolean = SQLConf.get.getConf[Boolean](SQLConf.ANSI_ENABLED) val serializedPlanCopy = serializedPlan // TODO: support native metrics for all operators. @@ -206,7 +209,12 @@ abstract class CometNativeExec extends CometExec { def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = { val it = - new CometExecIterator(CometExec.newIterId, inputs, serializedPlanCopy, nativeMetrics) + new CometExecIterator( + CometExec.newIterId, + inputs, + serializedPlanCopy, + nativeMetrics, + ansiEnabled) setSubqueries(it.id, originalPlan) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index e8d340f21..6e711e5a6 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -20,17 +20,31 @@ package org.apache.comet import java.io.File +import java.sql.SQLException import scala.util.Random +import org.apache.spark.SparkException import org.apache.spark.sql.{CometTestBase, DataFrame, SaveMode} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, exp, expr} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DataTypes} class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ + private val dataSize = 1000 + + // we should eventually add more whitespace chars here as documented in + // https://docs.oracle.com/javase/8/docs/api/java/lang/Character.html#isWhitespace-char- + // but this is likely a reasonable starting point for now + private val whitespaceChars = " \t\r\n" + + private val numericPattern = "0123456789e+-." 
+ whitespaceChars + private val datePattern = "0123456789/" + whitespaceChars + private val timestampPattern = "0123456789/:T" + whitespaceChars + ignore("cast long to short") { castTest(generateLongs, DataTypes.ShortType) } @@ -48,59 +62,99 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("cast string to bool") { - castTest( - Seq("TRUE", "True", "true", "FALSE", "False", "false", "1", "0", "").toDF("a"), - DataTypes.BooleanType) - fuzzCastFromString("truefalseTRUEFALSEyesno10 \t\r\n", 8, DataTypes.BooleanType) + val testValues = + (Seq("TRUE", "True", "true", "FALSE", "False", "false", "1", "0", "", null) ++ + generateStrings("truefalseTRUEFALSEyesno10 \t\r\n", 8)).toDF("a") + castTest(testValues, DataTypes.BooleanType) + } + + ignore("cast string to byte") { + castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ByteType) } ignore("cast string to short") { - fuzzCastFromString("0123456789e+- \t\r\n", 8, DataTypes.ShortType) + castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ShortType) + } + + ignore("cast string to int") { + castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType) + } + + ignore("cast string to long") { + castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType) } ignore("cast string to float") { - fuzzCastFromString("0123456789e+- \t\r\n", 8, DataTypes.FloatType) + castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.FloatType) } ignore("cast string to double") { - fuzzCastFromString("0123456789e+- \t\r\n", 8, DataTypes.DoubleType) + castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.DoubleType) + } + + ignore("cast string to decimal") { + val values = generateStrings(numericPattern, 8).toDF("a") + castTest(values, DataTypes.createDecimalType(10, 2)) + castTest(values, DataTypes.createDecimalType(10, 0)) + castTest(values, DataTypes.createDecimalType(10, -2)) } ignore("cast string to date") { - fuzzCastFromString("0123456789/ \t\r\n", 16, DataTypes.DateType) + castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.DoubleType) } ignore("cast string to timestamp") { - castTest(Seq("2020-01-01T12:34:56.123456", "T2").toDF("a"), DataTypes.TimestampType) - fuzzCastFromString("0123456789/:T \t\r\n", 32, DataTypes.TimestampType) + val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ generateStrings(numericPattern, 8) + castTest(values.toDF("a"), DataTypes.DoubleType) } - private def generateFloats = { + private def generateFloats(): DataFrame = { val r = new Random(0) - Range(0, 10000).map(_ => r.nextFloat()).toDF("a") + Range(0, dataSize).map(_ => r.nextFloat()).toDF("a") } - private def generateLongs = { + private def generateLongs(): DataFrame = { val r = new Random(0) - Range(0, 10000).map(_ => r.nextLong()).toDF("a") + Range(0, dataSize).map(_ => r.nextLong()).toDF("a") } - private def genString(r: Random, chars: String, maxLen: Int): String = { + private def generateString(r: Random, chars: String, maxLen: Int): String = { val len = r.nextInt(maxLen) Range(0, len).map(_ => chars.charAt(r.nextInt(chars.length))).mkString } - private def fuzzCastFromString(chars: String, maxLen: Int, toType: DataType): Unit = { + private def generateStrings(chars: String, maxLen: Int): Seq[String] = { val r = new Random(0) - val inputs = Range(0, 10000).map(_ => genString(r, chars, maxLen)) - castTest(inputs.toDF("a"), toType) + Range(0, dataSize).map(_ => generateString(r, chars, maxLen)) } private def castTest(input: DataFrame, toType: DataType): 
Unit = { withTempPath { dir => - val df = roundtripParquet(input, dir) - .withColumn("converted", col("a").cast(toType)) - checkSparkAnswer(df) + val data = roundtripParquet(input, dir) + data.createOrReplaceTempView("t") + + withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) { + // cast() should return null for invalid inputs when ansi mode is disabled + val df = data.withColumn("converted", col("a").cast(toType)) + checkSparkAnswer(df) + + // try_cast() should always return null for invalid inputs + val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t") + checkSparkAnswer(df2) + } + + // with ANSI enabled, we should produce the same exception as Spark + withSQLConf((SQLConf.ANSI_ENABLED.key, "true")) { + + // cast() should throw exception on invalid inputs when ansi mode is enabled + val df = data.withColumn("converted", col("a").cast(toType)) + val (expected, actual) = checkSparkThrows(df) + assert(expected.getMessage == actual.getMessage) + + // try_cast() should always return null for invalid inputs + val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t") + checkSparkAnswer(df2) + } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index de5866580..c6cad08e4 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -22,6 +22,7 @@ package org.apache.spark.sql import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import scala.util.Try import org.scalatest.BeforeAndAfterEach @@ -215,6 +216,17 @@ abstract class CometTestBase checkAnswerWithTol(dfComet, expected, absTol: Double) } + protected def checkSparkThrows(df: => DataFrame): (Throwable, Throwable) = { + var expected: Option[Throwable] = None + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val dfSpark = Dataset.ofRows(spark, df.logicalPlan) + expected = Try(dfSpark.collect()).failed.toOption + } + val dfComet = Dataset.ofRows(spark, df.logicalPlan) + val actual = Try(dfComet.collect()).failed.get + (expected.get.getCause, actual.getCause) + } + private var _spark: SparkSession = _ protected implicit def spark: SparkSession = _spark protected implicit def sqlContext: SQLContext = _spark.sqlContext From 0e2099e3fb4fb7f83a7bae29ef68bdd75eb6b9e8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 10:58:05 -0600 Subject: [PATCH 02/19] fix error handling --- core/src/execution/datafusion/expressions/cast.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 7c8caece9..2dc2ce62d 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -91,14 +91,14 @@ impl Cast { let from_type = array.data_type(); let cast_result = match (from_type, to_type) { (DataType::Utf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.ansi_mode) + Self::spark_cast_utf8_to_boolean::(&array, self.ansi_mode)? } (DataType::LargeUtf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.ansi_mode) + Self::spark_cast_utf8_to_boolean::(&array, self.ansi_mode)? 
} - _ => Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?), + _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?, }; - let result = spark_cast(cast_result?, from_type, to_type); + let result = spark_cast(cast_result, from_type, to_type); Ok(result) } From 7c5139ff7b6f9b76686b79d07640216c12ea4825 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 11:05:20 -0600 Subject: [PATCH 03/19] add comet conf to enable experimental ansi mode --- .../src/main/scala/org/apache/comet/CometConf.scala | 8 ++++++++ .../apache/comet/CometSparkSessionExtensions.scala | 12 ++++++++---- .../test/scala/org/apache/comet/CometCastSuite.scala | 8 +++++--- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 341ec98df..d993d6ddd 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -357,6 +357,14 @@ object CometConf { .toSequence .createWithDefault(Seq("Range,InMemoryTableScan")) + val COMET_ANSI_MODE_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.ansi.enabled") + .doc( + "Comet does not respect ANSI mode in most cases and by default will not accelerate " + + "queries when ansi mode is enabled. Enable this setting to test Comet's experimental " + + "support for ANSI mode. This should not be used in production.") + .booleanConf + .createWithDefault(false) + } object ConfigHelpers { diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 81af55044..d51874321 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -567,10 +567,14 @@ class CometSparkSessionExtensions override def apply(plan: SparkPlan): SparkPlan = { // DataFusion doesn't have ANSI mode. For now we just disable CometExec if ANSI mode is // enabled. -// if (isANSIEnabled(conf)) { -// logInfo("Comet extension disabled for ANSI mode") -// return plan -// } + if (isANSIEnabled(conf)) { + if (COMET_ANSI_MODE_ENABLED.get()) { + logWarning("Using Comet's experimental support for ANSI mode.") + } else { + logInfo("Comet extension disabled for ANSI mode") + return plan + } + } // We shouldn't transform Spark query plan if Comet is disabled. 
if (!isCometEnabled(conf)) return plan diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 6e711e5a6..0e5367389 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -49,11 +49,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(generateLongs, DataTypes.ShortType) } - test("cast float to bool") { + ignore("cast float to bool") { castTest(generateFloats, DataTypes.BooleanType) } - test("cast float to int") { + ignore("cast float to int") { castTest(generateFloats, DataTypes.IntegerType) } @@ -144,7 +144,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } // with ANSI enabled, we should produce the same exception as Spark - withSQLConf((SQLConf.ANSI_ENABLED.key, "true")) { + withSQLConf( + (SQLConf.ANSI_ENABLED.key, "true"), + (CometConf.COMET_ANSI_MODE_ENABLED.key, "true")) { // cast() should throw exception on invalid inputs when ansi mode is enabled val df = data.withColumn("converted", col("a").cast(toType)) From a96042420c301e02604272325c79980dd6ba25b9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 11:06:32 -0600 Subject: [PATCH 04/19] fix --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 0e5367389..3cc90253a 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -100,11 +100,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } ignore("cast string to date") { - castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.DoubleType) + castTest(generateStrings(datePattern, 8).toDF("a"), DataTypes.DoubleType) } ignore("cast string to timestamp") { - val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ generateStrings(numericPattern, 8) + val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ generateStrings(timestampPattern, 8) castTest(values.toDF("a"), DataTypes.DoubleType) } From 26af2fca026a0f1b5be53434511b710d73bc33b0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 12:46:58 -0600 Subject: [PATCH 05/19] code cleanup --- core/src/execution/jni_api.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs index b8218c20a..4f7bc3df8 100644 --- a/core/src/execution/jni_api.rs +++ b/core/src/execution/jni_api.rs @@ -317,11 +317,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let exec_context_id = exec_context.id; - // TODO there must be a cleaner way to write this - let ansi_mode = match exec_context.conf.get("ansi_mode") { - Some(value) if value == "true" => true, - _ => false, - }; + let ansi_mode = + matches!(exec_context.conf.get("ansi_mode"), Some(value) if value == "true"); // Initialize the execution stream. 
// Because we don't know if input arrays are dictionary-encoded when we create From c90cce9804f1bd1c9ede136a96029b800bda197e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 12:57:35 -0600 Subject: [PATCH 06/19] clippy --- core/src/execution/datafusion/expressions/cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 2dc2ce62d..ca9916f74 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -121,7 +121,7 @@ impl Cast { "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), "f" | "false" | "n" | "no" | "0" => Ok(Some(false)), other if ansi_mode => { - return Err(CometError::CastInvalidValue { + Err(CometError::CastInvalidValue { value: other.to_string(), from_type: "STRING".to_string(), to_type: "BOOLEAN".to_string(), From 5023635f6bbf693056d5d3182274ce6ec4ffb97c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 14:50:56 -0600 Subject: [PATCH 07/19] Add eval_mode to cast proto, remove ansi mode from planner --- core/src/errors.rs | 4 +- .../execution/datafusion/expressions/cast.rs | 29 ++++++---- core/src/execution/datafusion/planner.rs | 53 +++++++++---------- core/src/execution/jni_api.rs | 5 +- core/src/execution/proto/expr.proto | 2 + .../apache/comet/serde/QueryPlanSerde.scala | 11 ++-- .../org/apache/comet/CometCastSuite.scala | 7 ++- 7 files changed, 59 insertions(+), 52 deletions(-) diff --git a/core/src/errors.rs b/core/src/errors.rs index f4b5d738a..676e1fd9a 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -62,10 +62,10 @@ pub enum CometError { // TODO this error message is likely to change between Spark versions and it would be better // to have the full error in Scala and just pass the invalid value back here - #[error("[[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ because it is malformed. Correct the value as per the syntax, or change its target type. \ Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error")] + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] CastInvalidValue { value: String, from_type: String, diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index ca9916f74..91066e0e2 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -46,11 +46,18 @@ static CAST_OPTIONS: CastOptions = CastOptions { .with_timestamp_format(TIMESTAMP_FORMAT), }; +#[derive(Debug, Hash, PartialEq, Clone, Copy)] +pub enum EvalMode { + Legacy, + Ansi, + Try, +} + #[derive(Debug, Hash)] pub struct Cast { pub child: Arc, pub data_type: DataType, - pub ansi_mode: bool, + pub eval_mode: EvalMode, /// When cast from/to timezone related types, we need timezone, which will be resolved with /// session local timezone by an analyzer in Spark. 
@@ -61,27 +68,27 @@ impl Cast { pub fn new( child: Arc, data_type: DataType, - ansi_mode: bool, + eval_mode: EvalMode, timezone: String, ) -> Self { Self { child, data_type, timezone, - ansi_mode, + eval_mode, } } pub fn new_without_timezone( child: Arc, data_type: DataType, - ansi_mode: bool, + eval_mode: EvalMode, ) -> Self { Self { child, data_type, timezone: "".to_string(), - ansi_mode, + eval_mode, } } @@ -91,10 +98,10 @@ impl Cast { let from_type = array.data_type(); let cast_result = match (from_type, to_type) { (DataType::Utf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.ansi_mode)? + Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? } (DataType::LargeUtf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.ansi_mode)? + Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? } _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?, }; @@ -104,7 +111,7 @@ impl Cast { fn spark_cast_utf8_to_boolean( from: &dyn Array, - ansi_mode: bool, + eval_mode: EvalMode, ) -> CometResult where OffsetSize: OffsetSizeTrait, @@ -120,9 +127,9 @@ impl Cast { Some(value) => match value.to_ascii_lowercase().trim() { "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), "f" | "false" | "n" | "no" | "0" => Ok(Some(false)), - other if ansi_mode => { + _ if eval_mode == EvalMode::Ansi => { Err(CometError::CastInvalidValue { - value: other.to_string(), + value: value.to_string(), from_type: "STRING".to_string(), to_type: "BOOLEAN".to_string(), }) @@ -199,7 +206,7 @@ impl PhysicalExpr for Cast { Ok(Arc::new(Cast::new( children[0].clone(), self.data_type.clone(), - self.ansi_mode, + self.eval_mode, self.timezone.clone(), ))) } diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 745976f02..3385ff4f8 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -89,6 +89,7 @@ use crate::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }, }; +use crate::execution::datafusion::expressions::cast::EvalMode; // For clippy error on type_complexity. 
type ExecResult = Result; @@ -112,17 +113,27 @@ pub struct PhysicalPlanner { exec_context_id: i64, execution_props: ExecutionProps, session_ctx: Arc, - ansi_mode: bool, +} + +impl Default for PhysicalPlanner { + fn default() -> Self { + let session_ctx = Arc::new(SessionContext::new()); + let execution_props = ExecutionProps::new(); + Self { + exec_context_id: TEST_EXEC_CONTEXT_ID, + execution_props, + session_ctx, + } + } } impl PhysicalPlanner { - pub fn new(session_ctx: Arc, ansi_mode: bool) -> Self { + pub fn new(session_ctx: Arc) -> Self { let execution_props = ExecutionProps::new(); Self { exec_context_id: TEST_EXEC_CONTEXT_ID, execution_props, session_ctx, - ansi_mode, } } @@ -131,7 +142,6 @@ impl PhysicalPlanner { exec_context_id, execution_props: self.execution_props, session_ctx: self.session_ctx.clone(), - ansi_mode: self.ansi_mode, } } @@ -334,10 +344,15 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let timezone = expr.timezone.clone(); + let eval_mode = match expr.eval_mode.as_str() { + "ANSI" => EvalMode::Ansi, + "TRY" => EvalMode::Try, + _ => EvalMode::Legacy, + }; Ok(Arc::new(Cast::new( child, datatype, - self.ansi_mode, + eval_mode, timezone, ))) } @@ -634,19 +649,15 @@ impl PhysicalPlanner { let left = Arc::new(Cast::new_without_timezone( left, DataType::Decimal256(p1, s1), - self.ansi_mode, + EvalMode::Legacy )); let right = Arc::new(Cast::new_without_timezone( right, DataType::Decimal256(p2, s2), - self.ansi_mode, + EvalMode::Legacy )); let child = Arc::new(BinaryExpr::new(left, op, right)); - Ok(Arc::new(Cast::new_without_timezone( - child, - data_type, - self.ansi_mode, - ))) + Ok(Arc::new(Cast::new_without_timezone(child, data_type, EvalMode::Legacy))) } ( DataFusionOperator::Divide, @@ -1434,9 +1445,7 @@ mod tests { use arrow_array::{DictionaryArray, Int32Array, StringArray}; use arrow_schema::DataType; - use datafusion::{ - execution::context::ExecutionProps, physical_plan::common::collect, prelude::SessionContext, - }; + use datafusion::{physical_plan::common::collect, prelude::SessionContext}; use tokio::sync::mpsc; use crate::execution::{ @@ -1446,23 +1455,9 @@ mod tests { spark_operator, }; - use crate::execution::datafusion::planner::TEST_EXEC_CONTEXT_ID; use spark_expression::expr::ExprStruct::*; use spark_operator::{operator::OpStruct, Operator}; - impl Default for PhysicalPlanner { - fn default() -> Self { - let session_ctx = Arc::new(SessionContext::new()); - let execution_props = ExecutionProps::new(); - Self { - exec_context_id: TEST_EXEC_CONTEXT_ID, - execution_props, - session_ctx, - ansi_mode: false, - } - } - } - #[test] fn test_unpack_dictionary_primitive() { let op_scan = Operator { diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs index 4f7bc3df8..8249097a1 100644 --- a/core/src/execution/jni_api.rs +++ b/core/src/execution/jni_api.rs @@ -317,14 +317,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let exec_context_id = exec_context.id; - let ansi_mode = - matches!(exec_context.conf.get("ansi_mode"), Some(value) if value == "true"); - // Initialize the execution stream. // Because we don't know if input arrays are dictionary-encoded when we create // query plan, we need to defer stream initialization to first time execution. 
if exec_context.root_op.is_none() { - let planner = PhysicalPlanner::new(exec_context.session_ctx.clone(), ansi_mode) + let planner = PhysicalPlanner::new(exec_context.session_ctx.clone()) .with_exec_id(exec_context_id); let (scans, root_op) = planner.create_plan( &exec_context.spark_plan, diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index 58f607fc0..8ac428cdb 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -208,6 +208,8 @@ message Cast { Expr child = 1; DataType datatype = 2; string timezone = 3; + // LEGACY, ANSI, or TRY + string eval_mode = 4; } message Equal { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 26fc708ff..79b00e13e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -414,7 +414,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { def castToProto( timeZoneId: Option[String], dt: DataType, - childExpr: Option[Expr]): Option[Expr] = { + childExpr: Option[Expr], + evalMode: EvalMode.Value): Option[Expr] = { val dataType = serializeDataType(dt) if (childExpr.isDefined && dataType.isDefined) { @@ -425,6 +426,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val timeZone = timeZoneId.getOrElse("UTC") castBuilder.setTimezone(timeZone) + castBuilder.setEvalMode(evalMode.toString) + Some( ExprOuterClass.Expr .newBuilder() @@ -446,9 +449,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val value = cast.eval() exprToProtoInternal(Literal(value, dataType), inputs) - case Cast(child, dt, timeZoneId, _) => + case Cast(child, dt, timeZoneId, evalMode) => val childExpr = exprToProtoInternal(child, inputs) - castToProto(timeZoneId, dt, childExpr) + castToProto(timeZoneId, dt, childExpr, evalMode) case add @ Add(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) @@ -1565,7 +1568,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val childExpr = scalarExprToProto("coalesce", exprChildren: _*) // TODO: Remove this once we have new DataFusion release which includes // the fix: https://github.com/apache/arrow-datafusion/pull/9459 - castToProto(None, a.dataType, childExpr) + castToProto(None, a.dataType, childExpr, EvalMode.LEGACY) // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for // char types. Use rpad to achieve the behavior. 
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 3cc90253a..bdf93140d 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -130,7 +130,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { private def castTest(input: DataFrame, toType: DataType): Unit = { withTempPath { dir => - val data = roundtripParquet(input, dir) + val data = roundtripParquet(input, dir).coalesce(1) data.createOrReplaceTempView("t") withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) { @@ -151,7 +151,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // cast() should throw exception on invalid inputs when ansi mode is enabled val df = data.withColumn("converted", col("a").cast(toType)) val (expected, actual) = checkSparkThrows(df) - assert(expected.getMessage == actual.getMessage) + + // TODO we have to strip off a prefix that is added by DataFusion and it would be nice + // to stop this being added + assert(expected.getMessage == actual.getMessage.substring("Execution error: ".length)) // try_cast() should always return null for invalid inputs val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t") From 7169bada7a7028b0199d2277537cd98c4172048c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 14:55:03 -0600 Subject: [PATCH 08/19] revert ansi mode changes --- .../org/apache/comet/CometExecIterator.scala | 6 +----- .../sql/comet/CometCollectLimitExec.scala | 14 ++------------ .../spark/sql/comet/CometExecUtils.scala | 5 ++--- .../comet/CometTakeOrderedAndProjectExec.scala | 9 +++------ .../shuffle/CometShuffleExchangeExec.scala | 7 +------ .../org/apache/spark/sql/comet/operators.scala | 18 +++++------------- 6 files changed, 14 insertions(+), 45 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 8f8dc17b8..b3604c9e0 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -20,9 +20,7 @@ package org.apache.comet import org.apache.spark._ -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.comet.CometMetricNode -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized._ import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION} @@ -46,8 +44,7 @@ class CometExecIterator( val id: Long, inputs: Seq[Iterator[ColumnarBatch]], protobufQueryPlan: Array[Byte], - nativeMetrics: CometMetricNode, - ansiEnabled: Boolean) + nativeMetrics: CometMetricNode) extends Iterator[ColumnarBatch] { private val nativeLib = new Native() @@ -102,7 +99,6 @@ class CometExecIterator( result.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get())) result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get())) result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get())) - result.put("ansi_mode", String.valueOf(ansiEnabled)) // Strip mandatory prefix spark. 
which is not required for DataFusion session params conf.getAll.foreach { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala index 3addb3ba1..dd4855126 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala @@ -21,12 +21,10 @@ package org.apache.spark.sql.comet import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.common.base.Objects @@ -75,11 +73,7 @@ case class CometCollectLimitExec( childRDD } else { val localLimitedRDD = if (limit >= 0) { - CometExecUtils.getNativeLimitRDD( - childRDD, - output, - limit, - SparkSession.active.conf.get(SQLConf.ANSI_ENABLED)) + CometExecUtils.getNativeLimitRDD(childRDD, output, limit) } else { childRDD } @@ -94,11 +88,7 @@ case class CometCollectLimitExec( new CometShuffledBatchRDD(dep, readMetrics) } - CometExecUtils.getNativeLimitRDD( - singlePartitionRDD, - output, - limit, - SparkSession.active.conf.get(SQLConf.ANSI_ENABLED)) + CometExecUtils.getNativeLimitRDD(singlePartitionRDD, output, limit) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index 805216d66..5931920a2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -50,11 +50,10 @@ object CometExecUtils { def getNativeLimitRDD( childPlan: RDD[ColumnarBatch], outputAttribute: Seq[Attribute], - limit: Int, - ansiMode: Boolean): RDD[ColumnarBatch] = { + limit: Int): RDD[ColumnarBatch] = { childPlan.mapPartitionsInternal { iter => val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get - CometExec.getCometIterator(Seq(iter), limitOp, ansiMode) + CometExec.getCometIterator(Seq(iter), limitOp) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index 7ecdbfd5a..26ec401ed 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -21,14 +21,12 @@ package org.apache.spark.sql.comet import org.apache.spark.rdd.{ParallelCollectionRDD, RDD} import org.apache.spark.serializer.Serializer -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode, UnsafeRowSerializer} import 
org.apache.spark.sql.execution.metric.{SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.serde.QueryPlanSerde.exprToProto @@ -75,19 +73,18 @@ case class CometTakeOrderedAndProjectExec( if (childRDD.getNumPartitions == 0) { new ParallelCollectionRDD(sparkContext, Seq.empty[ColumnarBatch], 1, Map.empty) } else { - val ansiEnabled = SparkSession.active.conf.get(SQLConf.ANSI_ENABLED) val singlePartitionRDD = if (childRDD.getNumPartitions == 1) { childRDD } else { val localTopK = if (orderingSatisfies) { - CometExecUtils.getNativeLimitRDD(childRDD, output, limit, ansiEnabled) + CometExecUtils.getNativeLimitRDD(childRDD, output, limit) } else { childRDD.mapPartitionsInternal { iter => val topK = CometExecUtils .getTopKNativePlan(output, sortOrder, child, limit) .get - CometExec.getCometIterator(Seq(iter), topK, ansiEnabled) + CometExec.getCometIterator(Seq(iter), topK) } } @@ -107,7 +104,7 @@ case class CometTakeOrderedAndProjectExec( val topKAndProjection = CometExecUtils .getProjectionNativePlan(projectList, output, sortOrder, child, limit) .get - CometExec.getCometIterator(Seq(iter), topKAndProjection, ansiEnabled) + CometExec.getCometIterator(Seq(iter), topKAndProjection) } } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 742015a5d..232b6bf17 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -33,7 +33,6 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter, ShuffleWriteProcessor} import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering @@ -463,11 +462,7 @@ class CometShuffleWriteProcessor( val nativeMetrics = CometMetricNode(nativeSQLMetrics) val rawIter = cometRDD.iterator(partition, context) - val cometIter = CometExec.getCometIterator( - Seq(rawIter), - nativePlan, - nativeMetrics, - SparkSession.active.conf.get(SQLConf.ANSI_ENABLED)) + val cometIter = CometExec.getCometIterator(Seq(rawIter), nativePlan, nativeMetrics) while (cometIter.hasNext) { cometIter.next() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index e45b191c2..8545eee90 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -91,21 +91,19 @@ object CometExec { def getCometIterator( inputs: Seq[Iterator[ColumnarBatch]], - nativePlan: Operator, - ansiMode: Boolean): CometExecIterator = { - getCometIterator(inputs, nativePlan, CometMetricNode(Map.empty), ansiMode) + nativePlan: Operator): CometExecIterator = { + getCometIterator(inputs, nativePlan, CometMetricNode(Map.empty)) } def getCometIterator( inputs: 
Seq[Iterator[ColumnarBatch]], nativePlan: Operator, - nativeMetrics: CometMetricNode, - ansiMode: Boolean): CometExecIterator = { + nativeMetrics: CometMetricNode): CometExecIterator = { val outputStream = new ByteArrayOutputStream() nativePlan.writeTo(outputStream) outputStream.close() val bytes = outputStream.toByteArray - new CometExecIterator(newIterId, inputs, bytes, nativeMetrics, ansiMode) + new CometExecIterator(newIterId, inputs, bytes, nativeMetrics) } /** @@ -201,7 +199,6 @@ abstract class CometNativeExec extends CometExec { // Switch to use Decimal128 regardless of precision, since Arrow native execution // doesn't support Decimal32 and Decimal64 yet. SQLConf.get.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true") - val ansiEnabled: Boolean = SQLConf.get.getConf[Boolean](SQLConf.ANSI_ENABLED) val serializedPlanCopy = serializedPlan // TODO: support native metrics for all operators. @@ -209,12 +206,7 @@ abstract class CometNativeExec extends CometExec { def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = { val it = - new CometExecIterator( - CometExec.newIterId, - inputs, - serializedPlanCopy, - nativeMetrics, - ansiEnabled) + new CometExecIterator(CometExec.newIterId, inputs, serializedPlanCopy, nativeMetrics) setSubqueries(it.id, originalPlan) From ef563eb126b8df4657b939cc4104cabda945d9d2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 14:57:18 -0600 Subject: [PATCH 09/19] Update core/src/execution/datafusion/expressions/cast.rs Co-authored-by: Edmondo Porcu --- core/src/execution/datafusion/expressions/cast.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 91066e0e2..d4bed07ca 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -203,7 +203,14 @@ impl PhysicalExpr for Cast { self: Arc, children: Vec>, ) -> datafusion_common::Result> { - Ok(Arc::new(Cast::new( + match children.len() { + 1 => Ok(Arc::new(Cast::new( + children[0].clone(), + self.data_type.clone(), + self.timezone.clone(), + ))), + _ => internal_err!("Cast should have exactly one child"), + } children[0].clone(), self.data_type.clone(), self.eval_mode, From 24e739babdd3d00dbab4ef1651bdd35e396b1ccc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 15:01:23 -0600 Subject: [PATCH 10/19] fix merge conflict + cargo fmt --- .../execution/datafusion/expressions/cast.rs | 20 +++++++------------ core/src/execution/datafusion/planner.rs | 20 +++++++++---------- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index d4bed07ca..480d13c81 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -31,7 +31,7 @@ use arrow::{ use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait}; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; -use datafusion_common::{Result as DataFusionResult, ScalarValue}; +git use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; use crate::execution::datafusion::expressions::utils::{ @@ -127,13 +127,11 @@ impl Cast { Some(value) => match value.to_ascii_lowercase().trim() { "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), 
"f" | "false" | "n" | "no" | "0" => Ok(Some(false)), - _ if eval_mode == EvalMode::Ansi => { - Err(CometError::CastInvalidValue { - value: value.to_string(), - from_type: "STRING".to_string(), - to_type: "BOOLEAN".to_string(), - }) - } + _ if eval_mode == EvalMode::Ansi => Err(CometError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "BOOLEAN".to_string(), + }), _ => Ok(None), }, _ => Ok(None), @@ -207,15 +205,11 @@ impl PhysicalExpr for Cast { 1 => Ok(Arc::new(Cast::new( children[0].clone(), self.data_type.clone(), + self.eval_mode, self.timezone.clone(), ))), _ => internal_err!("Cast should have exactly one child"), } - children[0].clone(), - self.data_type.clone(), - self.eval_mode, - self.timezone.clone(), - ))) } fn dyn_hash(&self, state: &mut dyn Hasher) { diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 3385ff4f8..0ffdf70f4 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -65,7 +65,7 @@ use crate::{ avg_decimal::AvgDecimal, bitwise_not::BitwiseNotExpr, bloom_filter_might_contain::BloomFilterMightContain, - cast::Cast, + cast::{Cast, EvalMode}, checkoverflow::CheckOverflow, if_expr::IfExpr, scalar_funcs::create_comet_physical_fun, @@ -89,7 +89,6 @@ use crate::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }, }; -use crate::execution::datafusion::expressions::cast::EvalMode; // For clippy error on type_complexity. type ExecResult = Result; @@ -349,12 +348,7 @@ impl PhysicalPlanner { "TRY" => EvalMode::Try, _ => EvalMode::Legacy, }; - Ok(Arc::new(Cast::new( - child, - datatype, - eval_mode, - timezone, - ))) + Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone))) } ExprStruct::Hour(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; @@ -649,15 +643,19 @@ impl PhysicalPlanner { let left = Arc::new(Cast::new_without_timezone( left, DataType::Decimal256(p1, s1), - EvalMode::Legacy + EvalMode::Legacy, )); let right = Arc::new(Cast::new_without_timezone( right, DataType::Decimal256(p2, s2), - EvalMode::Legacy + EvalMode::Legacy, )); let child = Arc::new(BinaryExpr::new(left, op, right)); - Ok(Arc::new(Cast::new_without_timezone(child, data_type, EvalMode::Legacy))) + Ok(Arc::new(Cast::new_without_timezone( + child, + data_type, + EvalMode::Legacy, + ))) } ( DataFusionOperator::Divide, From 588bc3b48f9338400443ead2d7b6e846a4ceef3f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 15:21:03 -0600 Subject: [PATCH 11/19] add error for invalid eval mode --- core/src/execution/datafusion/expressions/cast.rs | 2 +- core/src/execution/datafusion/planner.rs | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 480d13c81..a74467956 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -31,7 +31,7 @@ use arrow::{ use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait}; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; -git use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; +use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; use crate::execution::datafusion::expressions::utils::{ diff --git 
a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 0ffdf70f4..c3d931396 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -346,7 +346,12 @@ impl PhysicalPlanner { let eval_mode = match expr.eval_mode.as_str() { "ANSI" => EvalMode::Ansi, "TRY" => EvalMode::Try, - _ => EvalMode::Legacy, + "LEGACY" => EvalMode::Legacy, + other => { + return Err(ExecutionError::GeneralError(format!( + "Invalid Cast EvalMode: {other}" + ))) + } }; Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone))) } From 3e7f273f17d15b885a925faada77acedaa405f82 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 15:44:00 -0600 Subject: [PATCH 12/19] add link to follow-on issue --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index bdf93140d..ce6a2b761 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -152,8 +152,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { val df = data.withColumn("converted", col("a").cast(toType)) val (expected, actual) = checkSparkThrows(df) - // TODO we have to strip off a prefix that is added by DataFusion and it would be nice - // to stop this being added + // We have to workaround https://github.com/apache/datafusion-comet/issues/293 here by + // removing the "Execution error: " error message prefix that is added by DataFusion assert(expected.getMessage == actual.getMessage.substring("Execution error: ".length)) // try_cast() should always return null for invalid inputs From 53719d6a3d2ba1c3e8bae5ea19662b5b7d85160b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 15:52:13 -0600 Subject: [PATCH 13/19] remove unused imports --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index ce6a2b761..970c40b68 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -20,14 +20,12 @@ package org.apache.comet import java.io.File -import java.sql.SQLException import scala.util.Random -import org.apache.spark.SparkException import org.apache.spark.sql.{CometTestBase, DataFrame, SaveMode} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.functions.{col, exp, expr} +import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DataTypes} From 40feb175aced669dbcba0686e0bb33d6570717c4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 15:54:23 -0600 Subject: [PATCH 14/19] minor cleanup --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 970c40b68..f04f94d56 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -62,7 +62,7 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper { test("cast string to bool") { val testValues = (Seq("TRUE", "True", "true", "FALSE", "False", "false", "1", "0", "", null) ++ - generateStrings("truefalseTRUEFALSEyesno10 \t\r\n", 8)).toDF("a") + generateStrings("truefalseTRUEFALSEyesno10" + whitespaceChars, 8)).toDF("a") castTest(testValues, DataTypes.BooleanType) } From cc1905adea2454e7cb5bf4f62bcdf7513a0eed7f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 20 Apr 2024 09:29:07 -0600 Subject: [PATCH 15/19] Fix compilation issue with Spark 3.2 and 3.3 --- .../org/apache/comet/serde/QueryPlanSerde.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 79b00e13e..cae1d6b37 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -415,19 +415,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { timeZoneId: Option[String], dt: DataType, childExpr: Option[Expr], - evalMode: EvalMode.Value): Option[Expr] = { + evalMode: String): Option[Expr] = { val dataType = serializeDataType(dt) if (childExpr.isDefined && dataType.isDefined) { val castBuilder = ExprOuterClass.Cast.newBuilder() castBuilder.setChild(childExpr.get) castBuilder.setDatatype(dataType.get) + castBuilder.setEvalMode(evalMode) val timeZone = timeZoneId.getOrElse("UTC") castBuilder.setTimezone(timeZone) - castBuilder.setEvalMode(evalMode.toString) - Some( ExprOuterClass.Expr .newBuilder() @@ -451,7 +450,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case Cast(child, dt, timeZoneId, evalMode) => val childExpr = exprToProtoInternal(child, inputs) - castToProto(timeZoneId, dt, childExpr, evalMode) + val evalModeStr = if (evalMode.isInstanceOf[Boolean]) { + // Spark 3.2 & 3.3 has ansiEnabled boolean + if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY" + } else { + // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY + evalMode.toString + } + castToProto(timeZoneId, dt, childExpr, evalModeStr) case add @ Add(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) @@ -1568,7 +1574,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val childExpr = scalarExprToProto("coalesce", exprChildren: _*) // TODO: Remove this once we have new DataFusion release which includes // the fix: https://github.com/apache/arrow-datafusion/pull/9459 - castToProto(None, a.dataType, childExpr, EvalMode.LEGACY) + castToProto(None, a.dataType, childExpr, "LEGACY") // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for // char types. Use rpad to achieve the behavior. 
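At this point in the series the cast's eval mode travels as a plain string ("LEGACY", "ANSI", or "TRY") from QueryPlanSerde to the native planner, which maps it onto the EvalMode enum consumed by Cast. The enum definition itself never appears in these patches, so the sketch below is only a plausible reconstruction inferred from how the enum is matched and compared in this series; eval_mode_from_str is a hypothetical helper standing in for the inline match in planner.rs.

// Sketch only: EvalMode is not defined in these patches; this is a plausible shape
// inferred from how it is pattern-matched, compared, and hashed elsewhere in the series.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalMode {
    Legacy,
    Ansi,
    Try,
}

// Hypothetical helper mirroring the planner-side match: after patch 11 an unknown
// string is rejected instead of silently falling back to Legacy.
fn eval_mode_from_str(s: &str) -> Result<EvalMode, String> {
    match s {
        "ANSI" => Ok(EvalMode::Ansi),
        "TRY" => Ok(EvalMode::Try),
        "LEGACY" => Ok(EvalMode::Legacy),
        other => Err(format!("Invalid Cast EvalMode: \"{other}\"")),
    }
}

Rejecting unknown values keeps the Scala and Rust sides honest with each other: a serialization mismatch surfaces as an explicit error rather than quietly behaving like legacy mode.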
From 7bbdf0076eac8e2165c8f82bbad0adb59ab78a37 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 20 Apr 2024 10:04:17 -0600 Subject: [PATCH 16/19] fix test regression --- core/src/execution/datafusion/expressions/cast.rs | 6 ++++-- core/src/execution/datafusion/planner.rs | 2 +- .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 1 + 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index a74467956..10079855d 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -146,8 +146,8 @@ impl Display for Cast { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, - "Cast [data_type: {}, timezone: {}, child: {}]", - self.data_type, self.timezone, self.child + "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]", + self.data_type, self.timezone, self.child, &self.eval_mode ) } } @@ -160,6 +160,7 @@ impl PartialEq for Cast { self.child.eq(&x.child) && self.timezone.eq(&x.timezone) && self.data_type.eq(&x.data_type) + && self.eval_mode.eq(&x.eval_mode) }) .unwrap_or(false) } @@ -217,6 +218,7 @@ impl PhysicalExpr for Cast { self.child.hash(&mut s); self.data_type.hash(&mut s); self.timezone.hash(&mut s); + self.eval_mode.hash(&mut s); self.hash(&mut s); } } diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index c3d931396..089f0e092 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -349,7 +349,7 @@ impl PhysicalPlanner { "LEGACY" => EvalMode::Legacy, other => { return Err(ExecutionError::GeneralError(format!( - "Invalid Cast EvalMode: {other}" + "Invalid Cast EvalMode: \"{other}\"" ))) } }; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index cae1d6b37..2e4379251 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1000,6 +1000,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .newBuilder() .setChild(e) .setDatatype(serializeDataType(IntegerType).get) + .setEvalMode("LEGACY") // year is not affected by ANSI mode .build()) .build() }) From 6d042addd898b6760e4c8a014cb2cbcec4bcf189 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 21 Apr 2024 09:14:28 -0600 Subject: [PATCH 17/19] Fix test failures --- .../scala/org/apache/comet/CometCastSuite.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index f04f94d56..1eba86d56 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -152,7 +152,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // We have to workaround https://github.com/apache/datafusion-comet/issues/293 here by // removing the "Execution error: " error message prefix that is added by DataFusion - assert(expected.getMessage == actual.getMessage.substring("Execution error: ".length)) + val actualCometMessage = actual.getMessage + .substring("Execution error: ".length) + + val cometMessage = if (CometSparkSessionExtensions.isSpark34Plus) { + actualCometMessage + } else { + // Comet follows Spark 3.4 
behavior and starts the string with [CAST_INVALID_INPUT] but + // this does not appear in Spark 3.2 or 3.3, so we strip it off before comparing + actualCometMessage.replace("[CAST_INVALID_INPUT]", "") + } + + assert(expected.getMessage == cometMessage) // try_cast() should always return null for invalid inputs val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t") From 4c161eca0ebce8066e00030802590f4ca01680ed Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 21 Apr 2024 09:53:25 -0600 Subject: [PATCH 18/19] change test approach for Spark 3.2 and 3.3 --- core/src/errors.rs | 4 +-- .../org/apache/comet/CometCastSuite.scala | 25 +++++++++++-------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/core/src/errors.rs b/core/src/errors.rs index 676e1fd9a..f02bd1969 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -60,8 +60,8 @@ pub enum CometError { #[error("Comet Internal Error: {0}")] Internal(String), - // TODO this error message is likely to change between Spark versions and it would be better - // to have the full error in Scala and just pass the invalid value back here + // Note that this message format is based on Spark 3.4 and is more detailed than the message + // returned by Spark 3.2 or 3.3 #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ because it is malformed. Correct the value as per the syntax, or change its target type. \ Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 1eba86d56..074b4c047 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -150,21 +150,24 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { val df = data.withColumn("converted", col("a").cast(toType)) val (expected, actual) = checkSparkThrows(df) - // We have to workaround https://github.com/apache/datafusion-comet/issues/293 here by - // removing the "Execution error: " error message prefix that is added by DataFusion - val actualCometMessage = actual.getMessage - .substring("Execution error: ".length) + if (CometSparkSessionExtensions.isSpark34Plus) { + // We have to workaround https://github.com/apache/datafusion-comet/issues/293 here by + // removing the "Execution error: " error message prefix that is added by DataFusion + val cometMessage = actual.getMessage + .substring("Execution error: ".length) - val cometMessage = if (CometSparkSessionExtensions.isSpark34Plus) { - actualCometMessage + assert(expected.getMessage == cometMessage) } else { - // Comet follows Spark 3.4 behavior and starts the string with [CAST_INVALID_INPUT] but - // this does not appear in Spark 3.2 or 3.3, so we strip it off before comparing - actualCometMessage.replace("[CAST_INVALID_INPUT]", "") + // Spark 3.2 and 3.3 have a different error message format so we can't do a direct + // comparison between Spark and Comet. 
+ // Spark message is in format `invalid input syntax for type TYPE: VALUE` + // Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE` + // We just check that the comet message contains the same invalid value as the Spark message + val sparkInvalidValue = + expected.getMessage.substring(expected.getMessage.indexOf(':') + 1) + assert(actual.getMessage.contains(sparkInvalidValue)) } - assert(expected.getMessage == cometMessage) - // try_cast() should always return null for invalid inputs val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t") checkSparkAnswer(df2) From 5ae8479084c41061710378efe1eb682aa497dab0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 21 Apr 2024 09:54:07 -0600 Subject: [PATCH 19/19] fix --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 074b4c047..8abd24598 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -164,7 +164,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE` // We just check that the comet message contains the same invalid value as the Spark message val sparkInvalidValue = - expected.getMessage.substring(expected.getMessage.indexOf(':') + 1) + expected.getMessage.substring(expected.getMessage.indexOf(':') + 2) assert(actual.getMessage.contains(sparkInvalidValue)) }
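The end state of the series: with spark.sql.ansi.enabled set to true, casting a malformed string to boolean raises a SparkException carrying Spark's CAST_INVALID_INPUT message, while legacy mode and try_cast continue to return NULL. The self-contained sketch below summarizes those semantics; cast_string_to_bool is not a real Comet function and the error text is abbreviated, so treat it purely as an illustration of the mapping the native cast implements.

// Illustrative sketch only, not Comet code: the accepted string forms and the
// ANSI-vs-legacy behavior for malformed input, mirroring the cast logic above.
fn cast_string_to_bool(value: &str, ansi_mode: bool) -> Result<Option<bool>, String> {
    match value.to_ascii_lowercase().trim() {
        "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
        "f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
        other if ansi_mode => Err(format!(
            "[CAST_INVALID_INPUT] The value '{other}' of the type \"STRING\" \
             cannot be cast to \"BOOLEAN\" because it is malformed."
        )),
        _ => Ok(None), // legacy mode (and try_cast) turn malformed input into NULL
    }
}

fn main() {
    assert_eq!(cast_string_to_bool(" Yes ", false), Ok(Some(true)));
    assert_eq!(cast_string_to_bool("comet", false), Ok(None));
    assert!(cast_string_to_bool("comet", true).is_err());
}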