feat: Support ANSI mode in CAST from String to Bool (apache#290)
andygrove authored and Steve Vaughan Jr committed Apr 23, 2024
1 parent adb3682 commit 370d641
Showing 9 changed files with 226 additions and 51 deletions.
8 changes: 8 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -357,6 +357,14 @@ object CometConf {
.toSequence
.createWithDefault(Seq("Range,InMemoryTableScan"))

val COMET_ANSI_MODE_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.ansi.enabled")
.doc(
"Comet does not respect ANSI mode in most cases and by default will not accelerate " +
"queries when ansi mode is enabled. Enable this setting to test Comet's experimental " +
"support for ANSI mode. This should not be used in production.")
.booleanConf
.createWithDefault(false)

}

object ConfigHelpers {
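The flag is off by default and only matters when Spark's own ANSI mode is on. A minimal usage sketch, assuming an existing SparkSession named spark (illustrative, not part of this patch):

    // Turn on Spark's ANSI mode and opt in to Comet's experimental ANSI support.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.conf.set("spark.comet.ansi.enabled", "true")
    // With only the first setting, Comet leaves the plan to Spark
    // (see the CometSparkSessionExtensions change further down).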
16 changes: 16 additions & 0 deletions core/src/errors.rs
@@ -60,6 +60,18 @@ pub enum CometError {
#[error("Comet Internal Error: {0}")]
Internal(String),

// Note that this message format is based on Spark 3.4 and is more detailed than the message
// returned by Spark 3.2 or 3.3
#[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
because it is malformed. Correct the value as per the syntax, or change its target type. \
Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
CastInvalidValue {
value: String,
from_type: String,
to_type: String,
},

#[error(transparent)]
Arrow {
#[from]
@@ -183,6 +195,10 @@ impl jni::errors::ToException for CometError {
class: "java/lang/NullPointerException".to_string(),
msg: self.to_string(),
},
CometError::CastInvalidValue { .. } => Exception {
class: "org/apache/spark/SparkException".to_string(),
msg: self.to_string(),
},
CometError::NumberIntFormat { source: s } => Exception {
class: "java/lang/NumberFormatException".to_string(),
msg: s.to_string(),
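On the JVM side this variant is rethrown as an org.apache.spark.SparkException (see the ToException mapping above) carrying the CAST_INVALID_INPUT message. A hedged sketch of what a failing query would report, assuming the session setup from the previous example; the 'not-a-bool' value is illustrative:

    spark.sql("SELECT CAST('not-a-bool' AS BOOLEAN)").collect()
    // Expected to fail with roughly:
    //   org.apache.spark.SparkException: [CAST_INVALID_INPUT] The value 'not-a-bool' of the type
    //   "STRING" cannot be cast to "BOOLEAN" because it is malformed. Correct the value as per
    //   the syntax, or change its target type. ...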
74 changes: 55 additions & 19 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -22,6 +22,7 @@ use std::{
sync::Arc,
};

use crate::errors::{CometError, CometResult};
use arrow::{
compute::{cast_with_options, CastOptions},
record_batch::RecordBatch,
@@ -30,7 +31,7 @@ use arrow::{
use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{Result as DataFusionResult, ScalarValue};
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;

use crate::execution::datafusion::expressions::utils::{
Expand All @@ -45,30 +46,49 @@ static CAST_OPTIONS: CastOptions = CastOptions {
.with_timestamp_format(TIMESTAMP_FORMAT),
};

#[derive(Debug, Hash, PartialEq, Clone, Copy)]
pub enum EvalMode {
Legacy,
Ansi,
Try,
}

#[derive(Debug, Hash)]
pub struct Cast {
pub child: Arc<dyn PhysicalExpr>,
pub data_type: DataType,
pub eval_mode: EvalMode,

/// When casting from/to timezone-related types, we need a timezone, which will be resolved with
/// the session local timezone by an analyzer in Spark.
pub timezone: String,
}

impl Cast {
pub fn new(child: Arc<dyn PhysicalExpr>, data_type: DataType, timezone: String) -> Self {
pub fn new(
child: Arc<dyn PhysicalExpr>,
data_type: DataType,
eval_mode: EvalMode,
timezone: String,
) -> Self {
Self {
child,
data_type,
timezone,
eval_mode,
}
}

pub fn new_without_timezone(child: Arc<dyn PhysicalExpr>, data_type: DataType) -> Self {
pub fn new_without_timezone(
child: Arc<dyn PhysicalExpr>,
data_type: DataType,
eval_mode: EvalMode,
) -> Self {
Self {
child,
data_type,
timezone: "".to_string(),
eval_mode,
}
}

@@ -77,17 +97,22 @@ impl Cast {
let array = array_with_timezone(array, self.timezone.clone(), Some(to_type));
let from_type = array.data_type();
let cast_result = match (from_type, to_type) {
(DataType::Utf8, DataType::Boolean) => Self::spark_cast_utf8_to_boolean::<i32>(&array),
(DataType::Utf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i32>(&array, self.eval_mode)?
}
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array)
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
}
_ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
};
let result = spark_cast(cast_result, from_type, to_type);
Ok(result)
}

fn spark_cast_utf8_to_boolean<OffsetSize>(from: &dyn Array) -> ArrayRef
fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
) -> CometResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
{
@@ -100,24 +125,29 @@ impl Cast {
.iter()
.map(|value| match value {
Some(value) => match value.to_ascii_lowercase().trim() {
"t" | "true" | "y" | "yes" | "1" => Some(true),
"f" | "false" | "n" | "no" | "0" => Some(false),
_ => None,
"t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
"f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
_ if eval_mode == EvalMode::Ansi => Err(CometError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "BOOLEAN".to_string(),
}),
_ => Ok(None),
},
_ => None,
_ => Ok(None),
})
.collect::<BooleanArray>();
.collect::<Result<BooleanArray, _>>()?;

Arc::new(output_array)
Ok(Arc::new(output_array))
}
}

impl Display for Cast {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Cast [data_type: {}, timezone: {}, child: {}]",
self.data_type, self.timezone, self.child
"Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]",
self.data_type, self.timezone, self.child, &self.eval_mode
)
}
}
Expand All @@ -130,6 +160,7 @@ impl PartialEq<dyn Any> for Cast {
self.child.eq(&x.child)
&& self.timezone.eq(&x.timezone)
&& self.data_type.eq(&x.data_type)
&& self.eval_mode.eq(&x.eval_mode)
})
.unwrap_or(false)
}
@@ -171,18 +202,23 @@ impl PhysicalExpr for Cast {
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(Cast::new(
children[0].clone(),
self.data_type.clone(),
self.timezone.clone(),
)))
match children.len() {
1 => Ok(Arc::new(Cast::new(
children[0].clone(),
self.data_type.clone(),
self.eval_mode,
self.timezone.clone(),
))),
_ => internal_err!("Cast should have exactly one child"),
}
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.child.hash(&mut s);
self.data_type.hash(&mut s);
self.timezone.hash(&mut s);
self.eval_mode.hash(&mut s);
self.hash(&mut s);
}
}
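The accepted literals in spark_cast_utf8_to_boolean mirror Spark's own string-to-boolean cast; only the fallback arm differs per eval mode. A behavior sketch under the same assumed session setup (values are illustrative):

    spark.sql("SELECT CAST(' TRUE ' AS BOOLEAN)").collect() // true:  't', 'true', 'y', 'yes', '1' (trimmed, case-insensitive)
    spark.sql("SELECT CAST('no' AS BOOLEAN)").collect()     // false: 'f', 'false', 'n', 'no', '0'
    spark.sql("SELECT CAST('maybe' AS BOOLEAN)").collect()
    // LEGACY and TRY eval modes: NULL; ANSI eval mode: the CAST_INVALID_INPUT error shown earlier.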
22 changes: 19 additions & 3 deletions core/src/execution/datafusion/planner.rs
@@ -65,7 +65,7 @@ use crate::{
avg_decimal::AvgDecimal,
bitwise_not::BitwiseNotExpr,
bloom_filter_might_contain::BloomFilterMightContain,
cast::Cast,
cast::{Cast, EvalMode},
checkoverflow::CheckOverflow,
covariance::Covariance,
if_expr::IfExpr,
@@ -345,7 +345,17 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let timezone = expr.timezone.clone();
Ok(Arc::new(Cast::new(child, datatype, timezone)))
let eval_mode = match expr.eval_mode.as_str() {
"ANSI" => EvalMode::Ansi,
"TRY" => EvalMode::Try,
"LEGACY" => EvalMode::Legacy,
other => {
return Err(ExecutionError::GeneralError(format!(
"Invalid Cast EvalMode: \"{other}\""
)))
}
};
Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone)))
}
ExprStruct::Hour(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
@@ -640,13 +650,19 @@ impl PhysicalPlanner {
let left = Arc::new(Cast::new_without_timezone(
left,
DataType::Decimal256(p1, s1),
EvalMode::Legacy,
));
let right = Arc::new(Cast::new_without_timezone(
right,
DataType::Decimal256(p2, s2),
EvalMode::Legacy,
));
let child = Arc::new(BinaryExpr::new(left, op, right));
Ok(Arc::new(Cast::new_without_timezone(child, data_type)))
Ok(Arc::new(Cast::new_without_timezone(
child,
data_type,
EvalMode::Legacy,
)))
}
(
DataFusionOperator::Divide,
2 changes: 2 additions & 0 deletions core/src/execution/proto/expr.proto
@@ -224,6 +224,8 @@ message Cast {
Expr child = 1;
DataType datatype = 2;
string timezone = 3;
// LEGACY, ANSI, or TRY
string eval_mode = 4;
}

message Equal {
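On the Scala side the new field is set through the generated protobuf builder. A minimal sketch (childExpr and booleanType stand in for a previously serialized child expression and data type):

    val castProto = ExprOuterClass.Cast
      .newBuilder()
      .setChild(childExpr)
      .setDatatype(booleanType)
      .setTimezone("UTC")
      .setEvalMode("ANSI") // LEGACY, ANSI, or TRY
      .build()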
@@ -572,8 +572,12 @@ class CometSparkSessionExtensions
// DataFusion doesn't have ANSI mode. For now we just disable CometExec if ANSI mode is
// enabled.
if (isANSIEnabled(conf)) {
logInfo("Comet extension disabled for ANSI mode")
return plan
if (COMET_ANSI_MODE_ENABLED.get()) {
logWarning("Using Comet's experimental support for ANSI mode.")
} else {
logInfo("Comet extension disabled for ANSI mode")
return plan
}
}

// We shouldn't transform Spark query plan if Comet is disabled.
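The gating above reduces to: keep Comet only when ANSI mode is off or the experimental flag is on. A compact sketch of that decision (the helper name is illustrative):

    // Illustrative: should Comet keep transforming the plan?
    def cometHandlesPlan(sparkAnsiEnabled: Boolean, cometAnsiEnabled: Boolean): Boolean =
      !sparkAnsiEnabled || cometAnsiEnabled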
18 changes: 14 additions & 4 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -451,13 +451,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
def castToProto(
timeZoneId: Option[String],
dt: DataType,
childExpr: Option[Expr]): Option[Expr] = {
childExpr: Option[Expr],
evalMode: String): Option[Expr] = {
val dataType = serializeDataType(dt)

if (childExpr.isDefined && dataType.isDefined) {
val castBuilder = ExprOuterClass.Cast.newBuilder()
castBuilder.setChild(childExpr.get)
castBuilder.setDatatype(dataType.get)
castBuilder.setEvalMode(evalMode)

val timeZone = timeZoneId.getOrElse("UTC")
castBuilder.setTimezone(timeZone)
@@ -483,9 +485,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
val value = cast.eval()
exprToProtoInternal(Literal(value, dataType), inputs)

case Cast(child, dt, timeZoneId, _) =>
case Cast(child, dt, timeZoneId, evalMode) =>
val childExpr = exprToProtoInternal(child, inputs)
castToProto(timeZoneId, dt, childExpr)
val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
// Spark 3.2 & 3.3 has ansiEnabled boolean
if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
} else {
// Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
evalMode.toString
}
castToProto(timeZoneId, dt, childExpr, evalModeStr)

case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
val leftExpr = exprToProtoInternal(left, inputs)
@@ -1028,6 +1037,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
.newBuilder()
.setChild(e)
.setDatatype(serializeDataType(IntegerType).get)
.setEvalMode("LEGACY") // year is not affected by ANSI mode
.build())
.build()
})
@@ -1602,7 +1612,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
val childExpr = scalarExprToProto("coalesce", exprChildren: _*)
// TODO: Remove this once we have a new DataFusion release which includes
// the fix: https://github.com/apache/arrow-datafusion/pull/9459
castToProto(None, a.dataType, childExpr)
castToProto(None, a.dataType, childExpr, "LEGACY")

// With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
// char types. Use rpad to achieve the behavior.
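The Cast case above bridges two Spark representations of the eval mode before serializing it as a string. A standalone sketch of that mapping (the helper name and the Any-typed parameter are illustrative):

    // Spark 3.2/3.3 expose a Boolean ansiEnabled; Spark 3.4+ expose an EvalMode enum.
    def evalModeToString(evalMode: Any): String = evalMode match {
      case b: Boolean => if (b) "ANSI" else "LEGACY" // Spark 3.2 / 3.3
      case other => other.toString                   // Spark 3.4+: LEGACY, ANSI, or TRY
    }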
