
feat: Implement Spark-compatible CAST between integer types #340

Merged (16 commits, May 3, 2024)
9 changes: 9 additions & 0 deletions core/src/errors.rs
@@ -72,6 +72,15 @@ pub enum CometError {
to_type: String,
},

#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
CastOverFlow {
value: String,
from_type: String,
to_type: String,
},

#[error(transparent)]
Arrow {
#[from]
103 changes: 102 additions & 1 deletion core/src/execution/datafusion/expressions/cast.rs
@@ -28,7 +28,10 @@ use arrow::{
record_batch::RecordBatch,
util::display::FormatOptions,
};
use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait};
use arrow_array::{
types::{Int16Type, Int32Type, Int64Type, Int8Type},
Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
@@ -64,6 +67,62 @@ pub struct Cast {
pub timezone: String,
}

macro_rules! cast_int_to_int_macro {
(
$array: expr,
$eval_mode:expr,
$from_arrow_primitive_type: ty,
$to_arrow_primitive_type: ty,
$from_data_type: expr,
$to_native_type: ty,
$spark_from_data_type_name: expr,
$spark_to_data_type_name: expr
) => {{
let cast_array = $array
.as_any()
.downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
.unwrap();
let spark_int_literal_suffix = match $from_data_type {
&DataType::Int64 => "L",
&DataType::Int16 => "S",
&DataType::Int8 => "T",
_ => "",
};

let output_array = match $eval_mode {
EvalMode::Legacy => cast_array
.iter()
.map(|value| match value {
Some(value) => {
Ok::<Option<$to_native_type>, CometError>(Some(value as $to_native_type))
}
_ => Ok(None),
})
.collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
_ => cast_array
.iter()
.map(|value| match value {
Some(value) => {
let res = <$to_native_type>::try_from(value);
if res.is_err() {
Err(CometError::CastOverFlow {
value: value.to_string() + spark_int_literal_suffix,
from_type: $spark_from_data_type_name.to_string(),
to_type: $spark_to_data_type_name.to_string(),
})
} else {
Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
}
}
_ => Ok(None),
})
.collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
}?;
let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
result
}};
}
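The macro above takes one of two paths depending on eval_mode: in Legacy mode it narrows with Rust's `as` cast (matching Spark's non-ANSI truncating behavior), while in the other modes it uses `try_from` so an out-of-range value surfaces as an error. A minimal standalone sketch of that difference (plain Rust, not the PR's code):

```rust
fn main() {
    let v: i64 = i64::MAX; // 9223372036854775807

    // Legacy-mode path: `as` performs a bitwise truncation, so i64::MAX
    // wraps to -1 when narrowed to i32 (its low 32 bits are all ones).
    let legacy = v as i32;
    assert_eq!(legacy, -1);

    // ANSI-mode path: `try_from` reports the overflow instead of wrapping,
    // which the macro turns into CometError::CastOverFlow.
    assert!(i32::try_from(v).is_err());

    // In-range values convert identically on both paths.
    assert_eq!(42i64 as i32, 42);
    assert_eq!(i32::try_from(42i64), Ok(42));

    println!("legacy truncation of i64::MAX to i32: {legacy}");
}
```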

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -103,12 +162,54 @@ impl Cast {
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
}
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int64, DataType::Int8)
| (DataType::Int32, DataType::Int16)
| (DataType::Int32, DataType::Int8)
| (DataType::Int16, DataType::Int8)
if self.eval_mode != EvalMode::Try =>
{
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
}
_ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
};
let result = spark_cast(cast_result, from_type, to_type);
Ok(result)
}

fn spark_cast_int_to_int(
array: &dyn Array,
eval_mode: EvalMode,
from_type: &DataType,
to_type: &DataType,
) -> CometResult<ArrayRef> {
match (from_type, to_type) {
(DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
),
(DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
),
(DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
),
(DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
),
(DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
),
(DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
),
_ => unreachable!(
"{}",
format!("invalid integer type {to_type} in cast from {from_type}")
),
}
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
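A note on the error text: the match on $from_data_type in the macro picks the Spark literal suffix ("L" for BIGINT, "S" for SMALLINT, "T" for TINYINT) so the Rust-side message renders the value the way Spark would. A hedged sketch of how such a message could be assembled (overflow_message is a hypothetical helper for illustration, not part of the PR):

```rust
// Hypothetical helper: formats an overflow message in the style of Spark's
// CAST_OVERFLOW error, appending the Spark literal suffix to the value.
fn overflow_message(value: i64, suffix: &str, from: &str, to: &str) -> String {
    format!(
        "[CAST_OVERFLOW] The value {value}{suffix} of the type \"{from}\" \
         cannot be cast to \"{to}\" due to an overflow."
    )
}

fn main() {
    let msg = overflow_message(i64::MAX, "L", "BIGINT", "INT");
    // The suffix makes the value render as a Spark BIGINT literal.
    assert!(msg.contains("9223372036854775807L"));
    println!("{msg}");
}
```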
27 changes: 26 additions & 1 deletion spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -574,6 +574,30 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// TODO: implement
}

test("cast short to byte") {
castTest(generateShorts, DataTypes.ByteType)
}

[Review comment, Member]: These methods already exist in main but have different naming, so I think you need to upmerge/rebase against main. Example: ignore("cast ShortType to ByteType")

test("cast int to byte") {
castTest(generateInts, DataTypes.ByteType)
}

test("cast int to short") {
castTest(generateInts, DataTypes.ShortType)
}

test("cast long to byte") {
castTest(generateLongs, DataTypes.ByteType)
}

test("cast long to short") {
castTest(generateLongs, DataTypes.ShortType)
}

test("cast long to int") {
castTest(generateLongs, DataTypes.IntegerType)
}

private def generateFloats(): DataFrame = {
val r = new Random(0)
val values = Seq(
@@ -722,7 +746,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE`
// We just check that the comet message contains the same invalid value as the Spark message
val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
[Review comment, Member]: Now that we are handling multiple types of error here, we should probably check that sparkMessage.indexOf(':') returns a non-negative value before trying to use it.
assert(cometMessage.contains(sparkInvalidValue))
assert(
cometMessage.contains(sparkInvalidValue) || cometMessage.contains("overflow"))
[Review comment, Member]: If sparkInvalidValue is an empty string, won't cometMessage.contains(sparkInvalidValue) always be true?

[Reply, ganeshkumar269 (Contributor, Author), May 2, 2024]: You are right, my bad. So in case sparkMessage doesn't have ':', should I assert on just cometMessage.contains("overflow")? Something like this:

if (sparkMessage.indexOf(':') == -1) assert(cometMessage.contains("overflow"))
else assert(cometMessage.contains(sparkInvalidValue))

[Reply, Member]: Yes, something like that. I haven't reviewed the overflow messages to see if they contain ':' though (in any of the Spark versions 3.2, 3.3, and 3.4).

[Reply, ganeshkumar269 (Contributor, Author), May 3, 2024]: It doesn't look like the overflow error message has ':' in it; I ran spark.sql("select cast(9223372036854775807 as int)").show() locally on various Spark versions:

3.4 - [CAST_OVERFLOW] The value 9223372036854775807L of the type "BIGINT" cannot be cast to "INT" due to an overflow. Use try_cast to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error
3.3 - The value 9223372036854775807L of the type "BIGINT" cannot be cast to "INT" due to an overflow. Use try_cast to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error
3.2 - Casting 9223372036854775807 to int causes overflow
}
}
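The rule the reviewers converge on above is: when the Spark message contains no ':' it is the overflow shape, so only check the Comet message for "overflow"; otherwise compare the invalid value that follows the ':'. A small Rust sketch of that logic (messages_agree is a hypothetical name; the real suite implements this in Scala):

```rust
// Hypothetical sketch of the message-matching rule discussed above:
// if the Spark message has a ':', compare the invalid value that follows it;
// otherwise (the overflow messages have no ':') just look for "overflow".
fn messages_agree(spark_message: &str, comet_message: &str) -> bool {
    match spark_message.find(':') {
        Some(idx) => {
            // Value-style Spark messages end in "...: VALUE", so skip ": ".
            let invalid_value = spark_message.get(idx + 2..).unwrap_or("");
            comet_message.contains(invalid_value)
        }
        None => comet_message.contains("overflow"),
    }
}

fn main() {
    // Spark 3.2-style overflow message: no ':', fall back to "overflow".
    assert!(messages_agree(
        "Casting 9223372036854775807 to int causes overflow",
        "[CAST_OVERFLOW] The value 9223372036854775807L ... due to an overflow",
    ));
    // Value-style message: the value after ": " must appear in both.
    assert!(messages_agree("invalid input: 123abc", "The value '123abc' ..."));
    println!("message checks agree");
}
```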
