feat: Implement Spark-compatible CAST between integer types #340

Merged: 16 commits, May 3, 2024
1 change: 1 addition & 0 deletions common/src/main/scala/org/apache/comet/Constants.scala
@@ -22,4 +22,5 @@ package org.apache.comet
object Constants {
val LOG_CONF_PATH = "comet.log.file.path"
val LOG_CONF_NAME = "log4rs.yaml"
val EMPTY_STRING = ""
}
9 changes: 9 additions & 0 deletions core/src/errors.rs
@@ -72,6 +72,15 @@ pub enum CometError {
to_type: String,
},

#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
CastOverFlow {
value: String,
from_type: String,
to_type: String,
},

#[error(transparent)]
Arrow {
#[from]
98 changes: 98 additions & 0 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -86,6 +86,62 @@ macro_rules! cast_utf8_to_int {
}};
}

macro_rules! cast_int_to_int_macro {
(
$array: expr,
$eval_mode:expr,
$from_arrow_primitive_type: ty,
$to_arrow_primitive_type: ty,
$from_data_type: expr,
$to_native_type: ty,
$spark_from_data_type_name: expr,
$spark_to_data_type_name: expr
) => {{
let cast_array = $array
.as_any()
.downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
.unwrap();
let spark_int_literal_suffix = match $from_data_type {
&DataType::Int64 => "L",
&DataType::Int16 => "S",
&DataType::Int8 => "T",
_ => "",
};

let output_array = match $eval_mode {
EvalMode::Legacy => cast_array
.iter()
.map(|value| match value {
Some(value) => {
Ok::<Option<$to_native_type>, CometError>(Some(value as $to_native_type))
}
_ => Ok(None),
})
.collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
_ => cast_array
.iter()
.map(|value| match value {
Some(value) => {
let res = <$to_native_type>::try_from(value);
if res.is_err() {
Err(CometError::CastOverFlow {
value: value.to_string() + spark_int_literal_suffix,
from_type: $spark_from_data_type_name.to_string(),
to_type: $spark_to_data_type_name.to_string(),
})
} else {
Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
}
}
_ => Ok(None),
})
.collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
}?;
let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
result
}};
}
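Stripped of the Arrow machinery, the macro above reduces to two behaviors: Legacy mode truncates with a plain `as` cast, while the other modes use a checked `TryFrom` conversion and report overflow. A minimal standalone sketch of that core logic for the i64 to i32 case (helper names are illustrative, not part of the PR):

```rust
// Illustrative sketch of the two cast modes for i64 -> i32, without Arrow.

// Legacy mode: truncating cast, like Spark's non-ANSI behavior.
// `as` keeps the low 32 bits, so i64::MAX becomes -1.
fn cast_legacy(v: i64) -> i32 {
    v as i32
}

// ANSI/default mode: checked conversion that surfaces overflow as an error,
// mirroring the CastOverFlow variant added in core/src/errors.rs.
fn cast_checked(v: i64) -> Result<i32, String> {
    i32::try_from(v).map_err(|_| {
        format!(
            "The value {v}L of the type \"BIGINT\" cannot be cast to \"INT\" due to an overflow."
        )
    })
}

fn main() {
    assert_eq!(cast_legacy(42), 42);
    assert_eq!(cast_legacy(i64::MAX), -1); // silent truncation
    assert_eq!(cast_checked(42), Ok(42));
    assert!(cast_checked(i64::MAX).is_err()); // overflow reported
    println!("ok");
}
```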

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -125,6 +181,16 @@ impl Cast {
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
}
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int64, DataType::Int8)
| (DataType::Int32, DataType::Int16)
| (DataType::Int32, DataType::Int8)
| (DataType::Int16, DataType::Int8)
if self.eval_mode != EvalMode::Try =>
{
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
}
(
DataType::Utf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
@@ -200,6 +266,38 @@ impl Cast {
Ok(cast_array)
}

fn spark_cast_int_to_int(
array: &dyn Array,
eval_mode: EvalMode,
from_type: &DataType,
to_type: &DataType,
) -> CometResult<ArrayRef> {
match (from_type, to_type) {
(DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
),
(DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
),
(DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
),
(DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
),
(DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
),
(DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
),
_ => unreachable!("invalid integer type {to_type} in cast from {from_type}"),
}
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
41 changes: 39 additions & 2 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DataTypes}

import org.apache.comet.Constants.EMPTY_STRING

class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._

@@ -642,6 +644,30 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateTimestamps(), DataTypes.DateType)
}

test("cast short to byte") {

[Review comment, Member] These methods already exist in main but have different naming, so I think you need to upmerge/rebase against main. Example: ignore("cast ShortType to ByteType")

castTest(generateShorts, DataTypes.ByteType)
}

test("cast int to byte") {
castTest(generateInts, DataTypes.ByteType)
}

test("cast int to short") {
castTest(generateInts, DataTypes.ShortType)
}

test("cast long to byte") {
castTest(generateLongs, DataTypes.ByteType)
}

test("cast long to short") {
castTest(generateLongs, DataTypes.ShortType)
}

test("cast long to int") {
castTest(generateLongs, DataTypes.IntegerType)
}

private def generateFloats(): DataFrame = {
val r = new Random(0)
val values = Seq(
Expand Down Expand Up @@ -807,11 +833,22 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
} else {
// Spark 3.2 and 3.3 have a different error message format so we can't do a direct
// comparison between Spark and Comet.
// In the case of CAST_INVALID_INPUT
// Spark message is in format `invalid input syntax for type TYPE: VALUE`
// Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE`
// We just check that the comet message contains the same invalid value as the Spark message
val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
assert(cometMessage.contains(sparkInvalidValue))
// In the case of CAST_OVERFLOW
// Spark message is in format `Casting VALUE to TO_TYPE causes overflow`
// Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE
// due to an overflow`
// We check if the comet message contains 'overflow'.
val sparkInvalidValue = if (sparkMessage.indexOf(':') == -1) {
EMPTY_STRING
} else {
sparkMessage.substring(sparkMessage.indexOf(':') + 2)
}
assert(
cometMessage.contains(sparkInvalidValue) || cometMessage.contains("overflow"))
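The fallback matching rule above can be exercised outside Spark; a small sketch of the same logic (Rust for illustration, with hypothetical helper names mirroring the Scala test code):

```rust
// Sketch of the test's fallback matching, mirroring the Scala logic above.
// Spark's CAST_INVALID_INPUT messages contain ": <value>"; CAST_OVERFLOW
// messages (e.g. Spark 3.2's "Casting 9223372036854775807 to int causes
// overflow") contain no ':' at all.
fn spark_invalid_value(spark_message: &str) -> &str {
    match spark_message.find(':') {
        Some(i) => spark_message.get(i + 2..).unwrap_or(""),
        None => "", // overflow case: no invalid value to extract
    }
}

// Note: contains("") is trivially true, so the "overflow" check is what
// actually carries the overflow case -- the point raised in the review thread.
fn messages_agree(comet_message: &str, spark_message: &str) -> bool {
    let value = spark_invalid_value(spark_message);
    comet_message.contains(value) || comet_message.contains("overflow")
}

fn main() {
    let spark = "Casting 9223372036854775807 to int causes overflow";
    let comet = "The value 9223372036854775807L of the type \"BIGINT\" cannot \
                 be cast to \"INT\" due to an overflow.";
    assert!(messages_agree(comet, spark));
    println!("ok");
}
```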
[Review comment, Member] If sparkInvalidValue is EMPTY_STRING, won't cometMessage.contains(sparkInvalidValue) always be true?

[Reply, ganeshkumar269 (Contributor Author), May 2, 2024] You are right, my bad 😅. So in case sparkMessage doesn't have ':', should I assert on just cometMessage.contains("overflow")? Something like this:

if sparkMessage.indexOf(':') == -1 then assert(cometMessage.contains("overflow"))
else assert(cometMessage.contains(sparkInvalidValue))

[Reply, Member] Yes, something like that. I haven't reviewed the overflow messages to see if they contain ':' though (in any of the Spark versions 3.2, 3.3, and 3.4).

[Reply, ganeshkumar269 (Contributor Author), May 3, 2024] It doesn't look like the overflow error message has ':' in it; I ran spark.sql("select cast(9223372036854775807 as int)").show() locally on various Spark versions:

3.4 - [CAST_OVERFLOW] The value 9223372036854775807L of the type "BIGINT" cannot be cast to "INT" due to an overflow. Use try_cast to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error
3.3 - The value 9223372036854775807L of the type "BIGINT" cannot be cast to "INT" due to an overflow. Use try_cast to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error
3.2 - Casting 9223372036854775807 to int causes overflow

}
}
