feat: Implement Spark-compatible CAST between integer types #340

Merged · 16 commits · May 3, 2024
9 changes: 9 additions & 0 deletions core/src/errors.rs
@@ -72,6 +72,15 @@ pub enum CometError {
to_type: String,
},

#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
CastOverFlow {
value: String,
from_type: String,
to_type: String,
},

#[error(transparent)]
Arrow {
#[from]
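
For illustration, a minimal sketch of the Display output the new variant should produce, assuming the enum derives `thiserror::Error` as the `#[error(...)]` attributes suggest (the value and type strings here are hypothetical examples):

// Hypothetical: thiserror renders the #[error(...)] template above as Display.
let err = CometError::CastOverFlow {
    value: "2147483648L".to_string(),
    from_type: "BIGINT".to_string(),
    to_type: "INT".to_string(),
};
assert!(err
    .to_string()
    .starts_with("[CAST_OVERFLOW] The value 2147483648L of the type \"BIGINT\""));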
98 changes: 98 additions & 0 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -107,6 +107,62 @@ macro_rules! cast_utf8_to_timestamp {
}};
}

macro_rules! cast_int_to_int_macro {
(
$array: expr,
$eval_mode:expr,
$from_arrow_primitive_type: ty,
$to_arrow_primitive_type: ty,
$from_data_type: expr,
$to_native_type: ty,
$spark_from_data_type_name: expr,
$spark_to_data_type_name: expr
) => {{
let cast_array = $array
.as_any()
.downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
.unwrap();
let spark_int_literal_suffix = match $from_data_type {
&DataType::Int64 => "L",
&DataType::Int16 => "S",
&DataType::Int8 => "T",
_ => "",
};

let output_array = match $eval_mode {
EvalMode::Legacy => cast_array
.iter()
.map(|value| match value {
Some(value) => {
Ok::<Option<$to_native_type>, CometError>(Some(value as $to_native_type))
}
_ => Ok(None),
})
.collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
_ => cast_array
.iter()
.map(|value| match value {
Some(value) => {
let res = <$to_native_type>::try_from(value);
if res.is_err() {
Err(CometError::CastOverFlow {
value: value.to_string() + spark_int_literal_suffix,
from_type: $spark_from_data_type_name.to_string(),
to_type: $spark_to_data_type_name.to_string(),
})
} else {
Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
}
}
_ => Ok(None),
})
.collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
}?;
let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
result
}};
}
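
To make the two eval-mode paths concrete, here is a minimal standalone sketch (hypothetical helpers, not part of the diff) of what the macro generates for an i64 to i32 cast:

// Legacy mode: a plain `as` cast, which wraps on overflow just like Spark's
// non-ANSI CAST (e.g. 2_147_483_648i64 becomes -2_147_483_648i32).
fn legacy_downcast(value: i64) -> i32 {
    value as i32
}

// Ansi mode: `try_from` rejects out-of-range values, so the caller can raise
// CastOverFlow instead of silently wrapping.
fn ansi_downcast(value: i64) -> Option<i32> {
    i32::try_from(value).ok()
}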

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -149,6 +205,16 @@ impl Cast {
(DataType::Utf8, DataType::Timestamp(_, _)) => {
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
}
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int64, DataType::Int8)
| (DataType::Int32, DataType::Int16)
| (DataType::Int32, DataType::Int8)
| (DataType::Int16, DataType::Int8)
if self.eval_mode != EvalMode::Try =>
{
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
}
(
DataType::Utf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
@@ -248,6 +314,38 @@ impl Cast {
Ok(cast_array)
}

fn spark_cast_int_to_int(
array: &dyn Array,
eval_mode: EvalMode,
from_type: &DataType,
to_type: &DataType,
) -> CometResult<ArrayRef> {
match (from_type, to_type) {
(DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
),
(DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
),
(DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
),
(DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
),
(DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
),
(DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
),
_ => unreachable!("invalid integer type {to_type} in cast from {from_type}"),
}
}
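
A hypothetical call-site sketch (illustrative only; it assumes `ArrayRef`/`Int64Array` from the arrow crate and an `EvalMode::Ansi` variant, and ignores the method's private visibility):

use std::sync::Arc;
use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::DataType;

// i64::MAX does not fit in an i32, so Ansi mode surfaces CometError::CastOverFlow.
let input: ArrayRef = Arc::new(Int64Array::from(vec![Some(1_i64), Some(i64::MAX), None]));
let result =
    Cast::spark_cast_int_to_int(input.as_ref(), EvalMode::Ansi, &DataType::Int64, &DataType::Int32);
assert!(result.is_err());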

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
38 changes: 36 additions & 2 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -665,6 +665,30 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateTimestamps(), DataTypes.DateType)
}

test("cast short to byte") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These methods already exist in main but have different naming, so I think you need to upmerge/rebase against main.

Example:

ignore("cast ShortType to ByteType")

castTest(generateShorts, DataTypes.ByteType)
}

test("cast int to byte") {
castTest(generateInts, DataTypes.ByteType)
}

test("cast int to short") {
castTest(generateInts, DataTypes.ShortType)
}

test("cast long to byte") {
castTest(generateLongs, DataTypes.ByteType)
}

test("cast long to short") {
castTest(generateLongs, DataTypes.ShortType)
}

test("cast long to int") {
castTest(generateLongs, DataTypes.IntegerType)
}

private def generateFloats(): DataFrame = {
val r = new Random(0)
val values = Seq(
@@ -868,11 +892,21 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
} else {
// Spark 3.2 and 3.3 have a different error message format so we can't do a direct
// comparison between Spark and Comet.
// In the case of CAST_INVALID_INPUT,
// the Spark message has the format `invalid input syntax for type TYPE: VALUE`
// and the Comet message has the format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE`,
// so we check that the Comet message contains the same invalid value as the Spark message.
// In the case of CAST_OVERFLOW,
// the Spark message has the format `Casting VALUE to TO_TYPE causes overflow`
// and the Comet message has the format `The value VALUE of the type FROM_TYPE cannot be cast to TO_TYPE
// due to an overflow`,
// so we just check that the Comet message contains 'overflow'.
if (sparkMessage.indexOf(':') == -1) {
assert(cometMessage.contains("overflow"))
} else {
assert(
cometMessage.contains(sparkMessage.substring(sparkMessage.indexOf(':') + 2)))
}
}
}
