Add eval_mode to cast proto, remove ansi mode from planner
andygrove committed Apr 19, 2024
1 parent c90cce9 commit 5023635
Showing 7 changed files with 59 additions and 52 deletions.
4 changes: 2 additions & 2 deletions core/src/errors.rs
@@ -62,10 +62,10 @@ pub enum CometError {

// TODO this error message is likely to change between Spark versions and it would be better
// to have the full error in Scala and just pass the invalid value back here
#[error("[[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
#[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
because it is malformed. Correct the value as per the syntax, or change its target type. \
Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error")]
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
CastInvalidValue {
value: String,
from_type: String,
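
As a reference for the message change above, here is a minimal standalone sketch of how the corrected thiserror attribute renders (the variant and fields mirror the snippet above; it is only an illustration, not the crate's full CometError type):

use thiserror::Error;

#[derive(Debug, Error)]
enum CometError {
    #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
        because it is malformed. Correct the value as per the syntax, or change its target type. \
        Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \
        set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
    CastInvalidValue {
        value: String,
        from_type: String,
        to_type: String,
    },
}

fn main() {
    let err = CometError::CastInvalidValue {
        value: "abc".to_string(),
        from_type: "STRING".to_string(),
        to_type: "BOOLEAN".to_string(),
    };
    // Prints the single-bracket, period-terminated message fixed in this commit.
    println!("{err}");
}
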
29 changes: 18 additions & 11 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -46,11 +46,18 @@ static CAST_OPTIONS: CastOptions = CastOptions {
.with_timestamp_format(TIMESTAMP_FORMAT),
};

#[derive(Debug, Hash, PartialEq, Clone, Copy)]
pub enum EvalMode {
Legacy,
Ansi,
Try,
}

#[derive(Debug, Hash)]
pub struct Cast {
pub child: Arc<dyn PhysicalExpr>,
pub data_type: DataType,
pub ansi_mode: bool,
pub eval_mode: EvalMode,

/// When cast from/to timezone related types, we need timezone, which will be resolved with
/// session local timezone by an analyzer in Spark.
@@ -61,27 +68,27 @@ impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
data_type: DataType,
ansi_mode: bool,
eval_mode: EvalMode,
timezone: String,
) -> Self {
Self {
child,
data_type,
timezone,
ansi_mode,
eval_mode,
}
}

pub fn new_without_timezone(
child: Arc<dyn PhysicalExpr>,
data_type: DataType,
ansi_mode: bool,
eval_mode: EvalMode,
) -> Self {
Self {
child,
data_type,
timezone: "".to_string(),
ansi_mode,
eval_mode,
}
}

@@ -91,10 +98,10 @@ impl Cast {
let from_type = array.data_type();
let cast_result = match (from_type, to_type) {
(DataType::Utf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i32>(&array, self.ansi_mode)?
Self::spark_cast_utf8_to_boolean::<i32>(&array, self.eval_mode)?
}
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.ansi_mode)?
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
}
_ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
};
@@ -104,7 +111,7 @@ impl Cast {

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
ansi_mode: bool,
eval_mode: EvalMode,
) -> CometResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
@@ -120,9 +127,9 @@ impl Cast {
Some(value) => match value.to_ascii_lowercase().trim() {
"t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
"f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
other if ansi_mode => {
_ if eval_mode == EvalMode::Ansi => {
Err(CometError::CastInvalidValue {
value: other.to_string(),
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "BOOLEAN".to_string(),
})
@@ -199,7 +206,7 @@ impl PhysicalExpr for Cast {
Ok(Arc::new(Cast::new(
children[0].clone(),
self.data_type.clone(),
self.ansi_mode,
self.eval_mode,
self.timezone.clone(),
)))
}
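
To make the mode semantics concrete, here is a self-contained sketch of the string-to-boolean behavior implemented above (the real code operates on Arrow arrays and builds a BooleanArray; the fall-through to NULL for Legacy and Try is assumed from the truncated match arm and Spark's cast semantics):

#[derive(Debug, PartialEq, Clone, Copy)]
enum EvalMode {
    Legacy,
    Ansi,
    Try,
}

// Approximates Spark's string-to-boolean cast for a single value.
fn cast_utf8_to_boolean(value: &str, eval_mode: EvalMode) -> Result<Option<bool>, String> {
    match value.to_ascii_lowercase().trim() {
        "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
        "f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
        _ if eval_mode == EvalMode::Ansi => Err(format!(
            "[CAST_INVALID_INPUT] The value '{value}' of the type \"STRING\" cannot be cast to \"BOOLEAN\""
        )),
        // Legacy and Try tolerate malformed input and produce NULL.
        _ => Ok(None),
    }
}

fn main() {
    assert_eq!(cast_utf8_to_boolean("TRUE", EvalMode::Legacy), Ok(Some(true)));
    assert_eq!(cast_utf8_to_boolean("abc", EvalMode::Legacy), Ok(None));
    assert_eq!(cast_utf8_to_boolean("abc", EvalMode::Try), Ok(None));
    assert!(cast_utf8_to_boolean("abc", EvalMode::Ansi).is_err());
}
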
53 changes: 24 additions & 29 deletions core/src/execution/datafusion/planner.rs
@@ -89,6 +89,7 @@ use crate::{
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
},
};
use crate::execution::datafusion::expressions::cast::EvalMode;

// For clippy error on type_complexity.
type ExecResult<T> = Result<T, ExecutionError>;
@@ -112,17 +113,27 @@ pub struct PhysicalPlanner {
exec_context_id: i64,
execution_props: ExecutionProps,
session_ctx: Arc<SessionContext>,
ansi_mode: bool,
}

impl Default for PhysicalPlanner {
fn default() -> Self {
let session_ctx = Arc::new(SessionContext::new());
let execution_props = ExecutionProps::new();
Self {
exec_context_id: TEST_EXEC_CONTEXT_ID,
execution_props,
session_ctx,
}
}
}

impl PhysicalPlanner {
pub fn new(session_ctx: Arc<SessionContext>, ansi_mode: bool) -> Self {
pub fn new(session_ctx: Arc<SessionContext>) -> Self {
let execution_props = ExecutionProps::new();
Self {
exec_context_id: TEST_EXEC_CONTEXT_ID,
execution_props,
session_ctx,
ansi_mode,
}
}

@@ -131,7 +142,6 @@ impl PhysicalPlanner {
exec_context_id,
execution_props: self.execution_props,
session_ctx: self.session_ctx.clone(),
ansi_mode: self.ansi_mode,
}
}

@@ -334,10 +344,15 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let timezone = expr.timezone.clone();
let eval_mode = match expr.eval_mode.as_str() {
"ANSI" => EvalMode::Ansi,
"TRY" => EvalMode::Try,
_ => EvalMode::Legacy,
};
Ok(Arc::new(Cast::new(
child,
datatype,
self.ansi_mode,
eval_mode,
timezone,
)))
}
@@ -634,19 +649,15 @@ impl PhysicalPlanner {
let left = Arc::new(Cast::new_without_timezone(
left,
DataType::Decimal256(p1, s1),
self.ansi_mode,
EvalMode::Legacy
));
let right = Arc::new(Cast::new_without_timezone(
right,
DataType::Decimal256(p2, s2),
self.ansi_mode,
EvalMode::Legacy
));
let child = Arc::new(BinaryExpr::new(left, op, right));
Ok(Arc::new(Cast::new_without_timezone(
child,
data_type,
self.ansi_mode,
)))
Ok(Arc::new(Cast::new_without_timezone(child, data_type, EvalMode::Legacy)))
}
(
DataFusionOperator::Divide,
@@ -1434,9 +1445,7 @@ mod tests {

use arrow_array::{DictionaryArray, Int32Array, StringArray};
use arrow_schema::DataType;
use datafusion::{
execution::context::ExecutionProps, physical_plan::common::collect, prelude::SessionContext,
};
use datafusion::{physical_plan::common::collect, prelude::SessionContext};
use tokio::sync::mpsc;

use crate::execution::{
Expand All @@ -1446,23 +1455,9 @@ mod tests {
spark_operator,
};

use crate::execution::datafusion::planner::TEST_EXEC_CONTEXT_ID;
use spark_expression::expr::ExprStruct::*;
use spark_operator::{operator::OpStruct, Operator};

impl Default for PhysicalPlanner {
fn default() -> Self {
let session_ctx = Arc::new(SessionContext::new());
let execution_props = ExecutionProps::new();
Self {
exec_context_id: TEST_EXEC_CONTEXT_ID,
execution_props,
session_ctx,
ansi_mode: false,
}
}
}

#[test]
fn test_unpack_dictionary_primitive() {
let op_scan = Operator {
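
With the session-wide ansi_mode flag gone, the planner derives the mode per expression from the protobuf string, defaulting to Legacy for anything it does not recognize. A small sketch of that mapping (the real code matches on expr.eval_mode.as_str() as shown above; the names here are illustrative):

#[derive(Debug, PartialEq, Clone, Copy)]
enum EvalMode {
    Legacy,
    Ansi,
    Try,
}

// Maps the proto's eval_mode string ("LEGACY", "ANSI", "TRY") to the enum,
// falling back to Legacy for unknown or unset values.
fn parse_eval_mode(s: &str) -> EvalMode {
    match s {
        "ANSI" => EvalMode::Ansi,
        "TRY" => EvalMode::Try,
        _ => EvalMode::Legacy,
    }
}

fn main() {
    assert_eq!(parse_eval_mode("ANSI"), EvalMode::Ansi);
    assert_eq!(parse_eval_mode("TRY"), EvalMode::Try);
    assert_eq!(parse_eval_mode("LEGACY"), EvalMode::Legacy);
    // proto3 string fields default to "", which also resolves to Legacy.
    assert_eq!(parse_eval_mode(""), EvalMode::Legacy);
}
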
5 changes: 1 addition & 4 deletions core/src/execution/jni_api.rs
@@ -317,14 +317,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(

let exec_context_id = exec_context.id;

let ansi_mode =
matches!(exec_context.conf.get("ansi_mode"), Some(value) if value == "true");

// Initialize the execution stream.
// Because we don't know if input arrays are dictionary-encoded when we create
// query plan, we need to defer stream initialization to first time execution.
if exec_context.root_op.is_none() {
let planner = PhysicalPlanner::new(exec_context.session_ctx.clone(), ansi_mode)
let planner = PhysicalPlanner::new(exec_context.session_ctx.clone())
.with_exec_id(exec_context_id);
let (scans, root_op) = planner.create_plan(
&exec_context.spark_plan,
2 changes: 2 additions & 0 deletions core/src/execution/proto/expr.proto
@@ -208,6 +208,8 @@ message Cast {
Expr child = 1;
DataType datatype = 2;
string timezone = 3;
// LEGACY, ANSI, or TRY
string eval_mode = 4;
}

message Equal {
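
On the wire this is only an extra string field on the Cast message, set by the Scala serde below and parsed by the Rust planner above. A purely illustrative Rust stand-in for the message (the actual type is generated by the protobuf build from expr.proto):

// Illustrative stand-in; the child and datatype fields are omitted for brevity.
#[derive(Debug, Default)]
struct CastProto {
    timezone: String,
    // "LEGACY", "ANSI", or "TRY"
    eval_mode: String,
}

fn main() {
    let cast = CastProto {
        timezone: "UTC".to_string(),
        eval_mode: "ANSI".to_string(),
    };
    println!("{cast:?}");
}
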
11 changes: 7 additions & 4 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -414,7 +414,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
def castToProto(
timeZoneId: Option[String],
dt: DataType,
childExpr: Option[Expr]): Option[Expr] = {
childExpr: Option[Expr],
evalMode: EvalMode.Value): Option[Expr] = {
val dataType = serializeDataType(dt)

if (childExpr.isDefined && dataType.isDefined) {
@@ -425,6 +426,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
val timeZone = timeZoneId.getOrElse("UTC")
castBuilder.setTimezone(timeZone)

castBuilder.setEvalMode(evalMode.toString)

Some(
ExprOuterClass.Expr
.newBuilder()
@@ -446,9 +449,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
val value = cast.eval()
exprToProtoInternal(Literal(value, dataType), inputs)

case Cast(child, dt, timeZoneId, _) =>
case Cast(child, dt, timeZoneId, evalMode) =>
val childExpr = exprToProtoInternal(child, inputs)
castToProto(timeZoneId, dt, childExpr)
castToProto(timeZoneId, dt, childExpr, evalMode)

case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
val leftExpr = exprToProtoInternal(left, inputs)
@@ -1565,7 +1568,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
val childExpr = scalarExprToProto("coalesce", exprChildren: _*)
// TODO: Remove this once we have new DataFusion release which includes
// the fix: https://github.com/apache/arrow-datafusion/pull/9459
castToProto(None, a.dataType, childExpr)
castToProto(None, a.dataType, childExpr, EvalMode.LEGACY)

// With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
// char types. Use rpad to achieve the behavior.
7 changes: 5 additions & 2 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -130,7 +130,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

private def castTest(input: DataFrame, toType: DataType): Unit = {
withTempPath { dir =>
val data = roundtripParquet(input, dir)
val data = roundtripParquet(input, dir).coalesce(1)
data.createOrReplaceTempView("t")

withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
@@ -151,7 +151,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// cast() should throw exception on invalid inputs when ansi mode is enabled
val df = data.withColumn("converted", col("a").cast(toType))
val (expected, actual) = checkSparkThrows(df)
assert(expected.getMessage == actual.getMessage)

// TODO we have to strip off a prefix that is added by DataFusion and it would be nice
// to stop this being added
assert(expected.getMessage == actual.getMessage.substring("Execution error: ".length))

// try_cast() should always return null for invalid inputs
val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t")
