From 5023635f6bbf693056d5d3182274ce6ec4ffb97c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Apr 2024 14:50:56 -0600 Subject: [PATCH] Add eval_mode to cast proto, remove ansi mode from planner --- core/src/errors.rs | 4 +- .../execution/datafusion/expressions/cast.rs | 29 ++++++---- core/src/execution/datafusion/planner.rs | 53 +++++++++---------- core/src/execution/jni_api.rs | 5 +- core/src/execution/proto/expr.proto | 2 + .../apache/comet/serde/QueryPlanSerde.scala | 11 ++-- .../org/apache/comet/CometCastSuite.scala | 7 ++- 7 files changed, 59 insertions(+), 52 deletions(-) diff --git a/core/src/errors.rs b/core/src/errors.rs index f4b5d738a..676e1fd9a 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -62,10 +62,10 @@ pub enum CometError { // TODO this error message is likely to change between Spark versions and it would be better // to have the full error in Scala and just pass the invalid value back here - #[error("[[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ because it is malformed. Correct the value as per the syntax, or change its target type. \ Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error")] + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] CastInvalidValue { value: String, from_type: String, diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index ca9916f74..91066e0e2 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -46,11 +46,18 @@ static CAST_OPTIONS: CastOptions = CastOptions { .with_timestamp_format(TIMESTAMP_FORMAT), }; +#[derive(Debug, Hash, PartialEq, Clone, Copy)] +pub enum EvalMode { + Legacy, + Ansi, + Try, +} + #[derive(Debug, Hash)] pub struct Cast { pub child: Arc, pub data_type: DataType, - pub ansi_mode: bool, + pub eval_mode: EvalMode, /// When cast from/to timezone related types, we need timezone, which will be resolved with /// session local timezone by an analyzer in Spark. @@ -61,27 +68,27 @@ impl Cast { pub fn new( child: Arc, data_type: DataType, - ansi_mode: bool, + eval_mode: EvalMode, timezone: String, ) -> Self { Self { child, data_type, timezone, - ansi_mode, + eval_mode, } } pub fn new_without_timezone( child: Arc, data_type: DataType, - ansi_mode: bool, + eval_mode: EvalMode, ) -> Self { Self { child, data_type, timezone: "".to_string(), - ansi_mode, + eval_mode, } } @@ -91,10 +98,10 @@ impl Cast { let from_type = array.data_type(); let cast_result = match (from_type, to_type) { (DataType::Utf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.ansi_mode)? + Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? } (DataType::LargeUtf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.ansi_mode)? + Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? } _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?, }; @@ -104,7 +111,7 @@ impl Cast { fn spark_cast_utf8_to_boolean( from: &dyn Array, - ansi_mode: bool, + eval_mode: EvalMode, ) -> CometResult where OffsetSize: OffsetSizeTrait, @@ -120,9 +127,9 @@ impl Cast { Some(value) => match value.to_ascii_lowercase().trim() { "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), "f" | "false" | "n" | "no" | "0" => Ok(Some(false)), - other if ansi_mode => { + _ if eval_mode == EvalMode::Ansi => { Err(CometError::CastInvalidValue { - value: other.to_string(), + value: value.to_string(), from_type: "STRING".to_string(), to_type: "BOOLEAN".to_string(), }) @@ -199,7 +206,7 @@ impl PhysicalExpr for Cast { Ok(Arc::new(Cast::new( children[0].clone(), self.data_type.clone(), - self.ansi_mode, + self.eval_mode, self.timezone.clone(), ))) } diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 745976f02..3385ff4f8 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -89,6 +89,7 @@ use crate::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }, }; +use crate::execution::datafusion::expressions::cast::EvalMode; // For clippy error on type_complexity. type ExecResult = Result; @@ -112,17 +113,27 @@ pub struct PhysicalPlanner { exec_context_id: i64, execution_props: ExecutionProps, session_ctx: Arc, - ansi_mode: bool, +} + +impl Default for PhysicalPlanner { + fn default() -> Self { + let session_ctx = Arc::new(SessionContext::new()); + let execution_props = ExecutionProps::new(); + Self { + exec_context_id: TEST_EXEC_CONTEXT_ID, + execution_props, + session_ctx, + } + } } impl PhysicalPlanner { - pub fn new(session_ctx: Arc, ansi_mode: bool) -> Self { + pub fn new(session_ctx: Arc) -> Self { let execution_props = ExecutionProps::new(); Self { exec_context_id: TEST_EXEC_CONTEXT_ID, execution_props, session_ctx, - ansi_mode, } } @@ -131,7 +142,6 @@ impl PhysicalPlanner { exec_context_id, execution_props: self.execution_props, session_ctx: self.session_ctx.clone(), - ansi_mode: self.ansi_mode, } } @@ -334,10 +344,15 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let timezone = expr.timezone.clone(); + let eval_mode = match expr.eval_mode.as_str() { + "ANSI" => EvalMode::Ansi, + "TRY" => EvalMode::Try, + _ => EvalMode::Legacy, + }; Ok(Arc::new(Cast::new( child, datatype, - self.ansi_mode, + eval_mode, timezone, ))) } @@ -634,19 +649,15 @@ impl PhysicalPlanner { let left = Arc::new(Cast::new_without_timezone( left, DataType::Decimal256(p1, s1), - self.ansi_mode, + EvalMode::Legacy )); let right = Arc::new(Cast::new_without_timezone( right, DataType::Decimal256(p2, s2), - self.ansi_mode, + EvalMode::Legacy )); let child = Arc::new(BinaryExpr::new(left, op, right)); - Ok(Arc::new(Cast::new_without_timezone( - child, - data_type, - self.ansi_mode, - ))) + Ok(Arc::new(Cast::new_without_timezone(child, data_type, EvalMode::Legacy))) } ( DataFusionOperator::Divide, @@ -1434,9 +1445,7 @@ mod tests { use arrow_array::{DictionaryArray, Int32Array, StringArray}; use arrow_schema::DataType; - use datafusion::{ - execution::context::ExecutionProps, physical_plan::common::collect, prelude::SessionContext, - }; + use datafusion::{physical_plan::common::collect, prelude::SessionContext}; use tokio::sync::mpsc; use crate::execution::{ @@ -1446,23 +1455,9 @@ mod tests { spark_operator, }; - use crate::execution::datafusion::planner::TEST_EXEC_CONTEXT_ID; use spark_expression::expr::ExprStruct::*; use spark_operator::{operator::OpStruct, Operator}; - impl Default for PhysicalPlanner { - fn default() -> Self { - let session_ctx = Arc::new(SessionContext::new()); - let execution_props = ExecutionProps::new(); - Self { - exec_context_id: TEST_EXEC_CONTEXT_ID, - execution_props, - session_ctx, - ansi_mode: false, - } - } - } - #[test] fn test_unpack_dictionary_primitive() { let op_scan = Operator { diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs index 4f7bc3df8..8249097a1 100644 --- a/core/src/execution/jni_api.rs +++ b/core/src/execution/jni_api.rs @@ -317,14 +317,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let exec_context_id = exec_context.id; - let ansi_mode = - matches!(exec_context.conf.get("ansi_mode"), Some(value) if value == "true"); - // Initialize the execution stream. // Because we don't know if input arrays are dictionary-encoded when we create // query plan, we need to defer stream initialization to first time execution. if exec_context.root_op.is_none() { - let planner = PhysicalPlanner::new(exec_context.session_ctx.clone(), ansi_mode) + let planner = PhysicalPlanner::new(exec_context.session_ctx.clone()) .with_exec_id(exec_context_id); let (scans, root_op) = planner.create_plan( &exec_context.spark_plan, diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index 58f607fc0..8ac428cdb 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -208,6 +208,8 @@ message Cast { Expr child = 1; DataType datatype = 2; string timezone = 3; + // LEGACY, ANSI, or TRY + string eval_mode = 4; } message Equal { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 26fc708ff..79b00e13e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -414,7 +414,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { def castToProto( timeZoneId: Option[String], dt: DataType, - childExpr: Option[Expr]): Option[Expr] = { + childExpr: Option[Expr], + evalMode: EvalMode.Value): Option[Expr] = { val dataType = serializeDataType(dt) if (childExpr.isDefined && dataType.isDefined) { @@ -425,6 +426,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val timeZone = timeZoneId.getOrElse("UTC") castBuilder.setTimezone(timeZone) + castBuilder.setEvalMode(evalMode.toString) + Some( ExprOuterClass.Expr .newBuilder() @@ -446,9 +449,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val value = cast.eval() exprToProtoInternal(Literal(value, dataType), inputs) - case Cast(child, dt, timeZoneId, _) => + case Cast(child, dt, timeZoneId, evalMode) => val childExpr = exprToProtoInternal(child, inputs) - castToProto(timeZoneId, dt, childExpr) + castToProto(timeZoneId, dt, childExpr, evalMode) case add @ Add(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) @@ -1565,7 +1568,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val childExpr = scalarExprToProto("coalesce", exprChildren: _*) // TODO: Remove this once we have new DataFusion release which includes // the fix: https://github.com/apache/arrow-datafusion/pull/9459 - castToProto(None, a.dataType, childExpr) + castToProto(None, a.dataType, childExpr, EvalMode.LEGACY) // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for // char types. Use rpad to achieve the behavior. diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 3cc90253a..bdf93140d 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -130,7 +130,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { private def castTest(input: DataFrame, toType: DataType): Unit = { withTempPath { dir => - val data = roundtripParquet(input, dir) + val data = roundtripParquet(input, dir).coalesce(1) data.createOrReplaceTempView("t") withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) { @@ -151,7 +151,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // cast() should throw exception on invalid inputs when ansi mode is enabled val df = data.withColumn("converted", col("a").cast(toType)) val (expected, actual) = checkSparkThrows(df) - assert(expected.getMessage == actual.getMessage) + + // TODO we have to strip off a prefix that is added by DataFusion and it would be nice + // to stop this being added + assert(expected.getMessage == actual.getMessage.substring("Execution error: ".length)) // try_cast() should always return null for invalid inputs val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t")