diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 63e714b39..7b1229aa6 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -68,6 +68,25 @@ pub struct Cast { pub timezone: String, } +macro_rules! spark_cast_utf8_to_integral { + ($string_array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + let mut cast_array = PrimitiveArray::<$array_type>::builder($string_array.len()); + for i in 0..$string_array.len() { + if $string_array.is_null(i) { + cast_array.append_null() + } else if let Some(cast_value) = + $cast_method($string_array.value(i).trim(), $eval_mode)? + { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + let result: CometResult = Ok(Arc::new(cast_array.finish()) as ArrayRef); + result + }}; +} + impl Cast { pub fn new( child: Arc, @@ -110,57 +129,19 @@ impl Cast { ( DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => { - let string_array = array - .as_any() - .downcast_ref::>() - .expect("spark_cast_utf8_to_i8 expected a string array"); - - match to_type { - DataType::Int8 => { - Self::spark_cast_utf8_to_i8::(string_array, self.eval_mode)? - } - DataType::Int16 => { - Self::spark_cast_utf8_to_i16::(string_array, self.eval_mode)? - } - DataType::Int32 => { - Self::spark_cast_utf8_to_i32::(string_array, self.eval_mode)? - } - DataType::Int64 => { - Self::spark_cast_utf8_to_i64::(string_array, self.eval_mode)? - } - _ => unreachable!("invalid integral type in cast from string"), - } - } + ) => Self::spark_cast_string_to_integral(to_type, &array, self.eval_mode)?, ( DataType::Dictionary(key_type, value_type), DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, ) if key_type.as_ref() == &DataType::Int32 && value_type.as_ref() == &DataType::Utf8 => { - // TODO file follow on issue for optimizing this to avoid unpacking first + // Note that we are unpacking a dictionary-encoded array and then performing + // the cast. We could potentially improve performance here by casting the + // dictionary values directly without unpacking the array first, although this + // would add more complexity to the code let unpacked_array = Self::unpack_dict_string_array::(&array)?; - let string_array = unpacked_array - .as_any() - .downcast_ref::>() - .expect("spark_cast_utf8_to_i8 expected a string array"); - match to_type { - DataType::Int8 => { - Self::spark_cast_utf8_to_i8::(&string_array, self.eval_mode)? - } - DataType::Int16 => { - Self::spark_cast_utf8_to_i16::(&string_array, self.eval_mode)? - } - DataType::Int32 => { - Self::spark_cast_utf8_to_i32::(&string_array, self.eval_mode)? - } - DataType::Int64 => { - Self::spark_cast_utf8_to_i64::(&string_array, self.eval_mode)? - } - _ => { - unreachable!("invalid integral type in cast from dictionary-encoded string") - } - } + Self::spark_cast_string_to_integral(to_type, &unpacked_array, self.eval_mode)? } _ => { // when we have no Spark-specific casting we delegate to DataFusion @@ -170,6 +151,43 @@ impl Cast { Ok(spark_cast(cast_result, from_type, to_type)) } + fn spark_cast_string_to_integral( + to_type: &DataType, + array: &ArrayRef, + eval_mode: EvalMode, + ) -> CometResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("spark_cast_string_to_integral expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Int8 => { + spark_cast_utf8_to_integral!(string_array, eval_mode, Int8Type, cast_string_to_i8)? + } + DataType::Int16 => spark_cast_utf8_to_integral!( + string_array, + eval_mode, + Int16Type, + cast_string_to_i16 + )?, + DataType::Int32 => spark_cast_utf8_to_integral!( + string_array, + eval_mode, + Int32Type, + cast_string_to_i32 + )?, + DataType::Int64 => spark_cast_utf8_to_integral!( + string_array, + eval_mode, + Int64Type, + cast_string_to_i64 + )?, + _ => unreachable!("invalid integral type in cast from string"), + }; + Ok(cast_array) + } + fn unpack_dict_string_array( array: &ArrayRef, ) -> DataFusionResult { @@ -213,100 +231,6 @@ impl Cast { Ok(Arc::new(output_array)) } - - // TODO reduce code duplication - - fn spark_cast_utf8_to_i8( - string_array: &GenericStringArray, - eval_mode: EvalMode, - ) -> CometResult - where - OffsetSize: OffsetSizeTrait, - { - // cast the dictionary values from string to int8 - let mut cast_array = PrimitiveArray::::builder(string_array.len()); - for i in 0..string_array.len() { - if string_array.is_null(i) { - cast_array.append_null() - } else if let Some(cast_value) = - cast_string_to_i8(string_array.value(i).trim(), eval_mode)? - { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } - } - Ok(Arc::new(cast_array.finish())) - } - - fn spark_cast_utf8_to_i16( - string_array: &GenericStringArray, - eval_mode: EvalMode, - ) -> CometResult - where - OffsetSize: OffsetSizeTrait, - { - // cast the dictionary values from string to int8 - let mut cast_array = PrimitiveArray::::builder(string_array.len()); - for i in 0..string_array.len() { - if string_array.is_null(i) { - cast_array.append_null() - } else if let Some(cast_value) = - cast_string_to_i16(string_array.value(i).trim(), eval_mode)? - { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } - } - Ok(Arc::new(cast_array.finish())) - } - - fn spark_cast_utf8_to_i32( - string_array: &GenericStringArray, - eval_mode: EvalMode, - ) -> CometResult - where - OffsetSize: OffsetSizeTrait, - { - // cast the dictionary values from string to int8 - let mut cast_array = PrimitiveArray::::builder(string_array.len()); - for i in 0..string_array.len() { - if string_array.is_null(i) { - cast_array.append_null() - } else if let Some(cast_value) = - cast_string_to_i32(string_array.value(i).trim(), eval_mode)? - { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } - } - Ok(Arc::new(cast_array.finish())) - } - - fn spark_cast_utf8_to_i64( - string_array: &GenericStringArray, - eval_mode: EvalMode, - ) -> CometResult - where - OffsetSize: OffsetSizeTrait, - { - // cast the dictionary values from string to int8 - let mut cast_array = PrimitiveArray::::builder(string_array.len()); - for i in 0..string_array.len() { - if string_array.is_null(i) { - cast_array.append_null() - } else if let Some(cast_value) = - cast_string_to_i64(string_array.value(i).trim(), eval_mode)? - { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } - } - Ok(Arc::new(cast_array.finish())) - } } fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> {