handled cast for all overflow cases
ganesh.maddula committed Apr 28, 2024
1 parent 54a5f59 commit 8b29641
Showing 3 changed files with 114 additions and 57 deletions.
core/src/errors.rs: 3 changes (1 addition, 2 deletions)
@@ -71,8 +71,7 @@ pub enum CometError {
from_type: String,
to_type: String,
},
// Note that this message format is based on Spark 3.4 and is more detailed than the message
// returned by Spark 3.2 or 3.3

#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
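For reference, the variant's `thiserror` attribute interpolates the struct fields straight into the Spark-style message. A minimal sketch (a hypothetical `SketchError` enum standing in for the full `CometError`, assuming the `thiserror` crate) of how the message renders:

```rust
use thiserror::Error;

#[derive(Debug, Error)]
enum SketchError {
    #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \
        \"{to_type}\" due to an overflow. Use `try_cast` to tolerate overflow and return NULL \
        instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
    CastOverFlow {
        value: String,
        from_type: String,
        to_type: String,
    },
}

fn main() {
    let err = SketchError::CastOverFlow {
        value: "40000L".to_string(),
        from_type: "BIGINT".to_string(),
        to_type: "SMALLINT".to_string(),
    };
    // Display is generated from the #[error(...)] attribute above.
    println!("{err}");
}
```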
core/src/execution/datafusion/expressions/cast.rs: 131 changes (89 additions, 42 deletions)
@@ -28,7 +28,8 @@ use arrow::{
record_batch::RecordBatch,
util::display::FormatOptions,
};
use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, Int16Array, Int64Array, OffsetSizeTrait};
use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray};
use arrow_array::types::{Int16Type, Int32Type, Int64Type, Int8Type};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
@@ -64,6 +65,62 @@ pub struct Cast {
pub timezone: String,
}

macro_rules! cast_int_to_int_macro {
(
$array: expr,
$eval_mode:expr,
$from_arrow_primitive_type: ty,
$to_arrow_primitive_type: ty,
$from_data_type: expr,
$to_native_type: ty,
$spark_from_data_type_name: expr,
$spark_to_data_type_name: expr
) => {{
let cast_array = $array
.as_any()
.downcast_ref::<PrimitiveArray::<$from_arrow_primitive_type>>()
.unwrap();
let spark_int_literal_suffix = match $from_data_type {
&DataType::Int64 => "L",
&DataType::Int16 => "S",
&DataType::Int8 => "T",
_ => ""
};

let output_array = match $eval_mode {
EvalMode::Legacy =>
cast_array.iter()
.map(|value| match value {
Some(value) => Ok::<Option<$to_native_type>, CometError>(Some(value as $to_native_type)),
_ => Ok(None)
})
.collect::<Result<PrimitiveArray::<$to_arrow_primitive_type>, _>>(),
_ => {
cast_array.iter()
.map(|value| match value {
Some(value) => match <$to_native_type>::try_from(value) {
Err(_) => Err(CometError::CastOverFlow {
value: value.to_string() + spark_int_literal_suffix,
from_type: $spark_from_data_type_name.to_string(),
to_type: $spark_to_data_type_name.to_string(),
}),
Ok(converted) => Ok::<Option<$to_native_type>, CometError>(Some(converted)),
},
_ => Ok(None)
})
.collect::<Result<PrimitiveArray::<$to_arrow_primitive_type>, _>>()
}
}?;
let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
result
}};
}
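Hand-expanding the macro makes the overflow handling easier to follow. Below is a rough, self-contained sketch (hypothetical function name, `String` standing in for `CometError`, assuming the `arrow-array` crate) of what the non-Legacy arm amounts to for the Int64 to Int16 case:

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int16Array, Int64Array};

// Sketch of the checked (non-Legacy) path; the Legacy arm instead
// truncates with `value as i16` and can never fail.
fn cast_i64_to_i16_checked(array: &dyn Array) -> Result<ArrayRef, String> {
    let input = array.as_any().downcast_ref::<Int64Array>().unwrap();
    let output = input
        .iter()
        .map(|v| match v {
            // i16::try_from fails exactly when the value falls outside
            // [i16::MIN, i16::MAX], i.e. on overflow.
            Some(v) => i16::try_from(v)
                .map(Some)
                .map_err(|_| format!("CAST_OVERFLOW: {v}L to SMALLINT")),
            None => Ok(None), // NULL slots pass through as NULL
        })
        .collect::<Result<Int16Array, _>>()?;
    Ok(Arc::new(output) as ArrayRef)
}
```

Collecting an iterator of `Result<Option<i16>, E>` into `Result<Int16Array, E>` short-circuits on the first error, which is how a single overflowing value aborts the cast of the whole column.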

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -103,56 +160,46 @@ impl Cast {
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
}
(DataType::Int64, DataType::Int16) if self.eval_mode != EvalMode::Try => {
// (DataType::Int64, DataType::Int16) => {
Self::spark_cast_int64_to_int16(&array, self.eval_mode)?
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int64, DataType::Int8)
| (DataType::Int32, DataType::Int16)
| (DataType::Int32, DataType::Int8)
| (DataType::Int16, DataType::Int8)
if self.eval_mode != EvalMode::Try => {
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
}
_ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
};
let result = spark_cast(cast_result, from_type, to_type);
Ok(result)
}
fn spark_cast_int64_to_int16(
from: &dyn Array,

fn spark_cast_int_to_int(
array: &dyn Array,
eval_mode: EvalMode,
from_type: &DataType,
to_type: &DataType,
) -> CometResult<ArrayRef> {
let array = from
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();

let output_array = match eval_mode {
EvalMode::Legacy => {
array.iter()
.map(|value| match value{
Some(value) => Ok::<Option<i16>, CometError>(Some(value as i16)),
_ => Ok(None)
})
.collect::<Result<Int16Array, _>>()?
},
_ => {
array.iter()
.map(|value| match value{
Some(value) => {
let res = i16::try_from(value);
if res.is_err() {
Err(CometError::CastOverFlow{
value: value.to_string() + "L",
from_type: "BIGINT".to_string(),
to_type: "SMALLINT".to_string(),
})
}else{
Ok::<Option<i16>, CometError>(Some(i16::try_from(value).unwrap()))
}

},
_ => Ok(None)
})
.collect::<Result<Int16Array, _>>()?
}
};
Ok(Arc::new(output_array))
match (from_type, to_type) {
(DataType::Int64, DataType::Int32) =>
cast_int_to_int_macro!(array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"),
(DataType::Int64, DataType::Int16) =>
cast_int_to_int_macro!(array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"),
(DataType::Int64, DataType::Int8) =>
cast_int_to_int_macro!(array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"),
(DataType::Int32, DataType::Int16) =>
cast_int_to_int_macro!(array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"),
(DataType::Int32, DataType::Int8) =>
cast_int_to_int_macro!(array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"),
(DataType::Int16, DataType::Int8) =>
cast_int_to_int_macro!(array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"),
_ => unreachable!("invalid integer type {to_type} in cast from {from_type}"),
}
}
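The split between the Legacy arm and the checked arm mirrors Spark's two cast behaviours. A tiny standard-library illustration (hypothetical value) of the difference:

```rust
fn main() {
    let v: i64 = 40_000; // does not fit in i16 (i16::MAX is 32_767)

    // Legacy semantics: `as` wraps via two's-complement truncation,
    // matching Spark's historical non-ANSI cast.
    assert_eq!(v as i16, -25_536);

    // Checked semantics: try_from detects the overflow, which the macro
    // surfaces as CometError::CastOverFlow.
    assert!(i16::try_from(v).is_err());

    println!("both assertions passed");
}
```

Note that `EvalMode::Try` never reaches `spark_cast_int_to_int`: the guard on the match arm above routes it to Arrow's generic `cast_with_options` path instead.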

fn spark_cast_utf8_to_boolean<OffsetSize>(
spark/src/test/scala/org/apache/comet/CometCastSuite.scala: 37 changes (24 additions, 13 deletions)
@@ -43,10 +43,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
private val datePattern = "0123456789/" + whitespaceChars
private val timestampPattern = "0123456789/:T" + whitespaceChars

// ignore("cast long to short") {
// castTest(generateLongs, DataTypes.ShortType)
// }
//
ignore("cast float to bool") {
castTest(generateFloats, DataTypes.BooleanType)
}
@@ -106,26 +102,29 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(values.toDF("a"), DataTypes.DoubleType)
}

// spotless:off
test("cast short to int"){

test("cast short to byte") {
castTest(generateShorts, DataTypes.ByteType)
}
test("cast short to long"){

test("cast int to byte") {
castTest(generateInts, DataTypes.ByteType)
}
test("cast int to short"){

test("cast int to short") {
castTest(generateInts, DataTypes.ShortType)
}
test("cast int to long"){

test("cast long to byte") {
castTest(generateLongs, DataTypes.ByteType)
}
test("cast long to short"){

test("cast long to short") {
castTest(generateLongs, DataTypes.ShortType)
}
test("cast long to int"){

test("cast long to int") {
castTest(generateLongs, DataTypes.IntegerType)
}
// spotless:on

private def generateFloats(): DataFrame = {
val r = new Random(0)
@@ -137,6 +136,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
(Range(0, dataSize).map(_ => r.nextLong()) ++ Seq(Long.MaxValue, Long.MinValue)).toDF("a")
}

private def generateInts(): DataFrame = {
val r = new Random(0)
(Range(0, dataSize).map(_ => r.nextInt()) ++ Seq(Int.MaxValue, Int.MinValue)).toDF("a")
}

private def generateShorts(): DataFrame = {
val r = new Random(0)
(Range(0, dataSize).map(_ => r.nextInt(Short.MaxValue).toShort) ++ Seq(
Short.MaxValue,
Short.MinValue)).toDF("a")
}

private def generateString(r: Random, chars: String, maxLen: Int): String = {
val len = r.nextInt(maxLen)
Range(0, len).map(_ => chars.charAt(r.nextInt(chars.length))).mkString
