diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index 9e2fabe42..f4c9e2165 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::timezone;
+use crate::utils::array_with_timezone;
+use crate::{EvalMode, SparkError, SparkResult};
 use arrow::{
     array::{
         cast::AsArray,
@@ -56,11 +59,6 @@ use std::{
     sync::Arc,
 };
 
-use crate::timezone;
-use crate::utils::array_with_timezone;
-
-use crate::{EvalMode, SparkError, SparkResult};
-
 static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
 
 const MICROS_PER_SECOND: i64 = 1000000;
@@ -166,6 +164,11 @@ pub fn cast_supported(
 
     match (from_type, to_type) {
         (Boolean, _) => can_cast_from_boolean(from_type, options),
+        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
+            if options.allow_cast_unsigned_ints =>
+        {
+            true
+        }
         (Int8, _) => can_cast_from_byte(from_type, options),
         (Int16, _) => can_cast_from_short(from_type, options),
         (Int32, _) => can_cast_from_int(from_type, options),
@@ -783,6 +786,8 @@ pub struct SparkCastOptions {
     pub timezone: String,
     /// Allow casts that are supported but not guaranteed to be 100% compatible
     pub allow_incompat: bool,
+    /// Support casting unsigned ints to signed ints (used by Parquet SchemaAdapter)
+    pub allow_cast_unsigned_ints: bool,
 }
 
 impl SparkCastOptions {
@@ -791,6 +796,7 @@ impl SparkCastOptions {
             eval_mode,
             timezone: timezone.to_string(),
             allow_incompat,
+            allow_cast_unsigned_ints: false,
         }
     }
 
@@ -799,6 +805,7 @@ impl SparkCastOptions {
             eval_mode,
             timezone: "".to_string(),
             allow_incompat,
+            allow_cast_unsigned_ints: false,
         }
     }
 }
@@ -834,14 +841,14 @@ fn cast_array(
     to_type: &DataType,
     cast_options: &SparkCastOptions,
 ) -> DataFusionResult<ArrayRef> {
+    use DataType::*;
     let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
     let from_type = array.data_type().clone();
 
     let array = match &from_type {
-        DataType::Dictionary(key_type, value_type)
-            if key_type.as_ref() == &DataType::Int32
-                && (value_type.as_ref() == &DataType::Utf8
-                    || value_type.as_ref() == &DataType::LargeUtf8) =>
+        Dictionary(key_type, value_type)
+            if key_type.as_ref() == &Int32
+                && (value_type.as_ref() == &Utf8 || value_type.as_ref() == &LargeUtf8) =>
         {
             let dict_array = array
                 .as_any()
@@ -854,7 +861,7 @@ fn cast_array(
             );
 
             let casted_result = match to_type {
-                DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
+                Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
                 _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
             };
             return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
@@ -865,70 +872,66 @@ fn cast_array(
     let eval_mode = cast_options.eval_mode;
 
     let cast_result = match (from_type, to_type) {
-        (DataType::Utf8, DataType::Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
-        (DataType::LargeUtf8, DataType::Boolean) => {
-            spark_cast_utf8_to_boolean::<i64>(&array, eval_mode)
-        }
-        (DataType::Utf8, DataType::Timestamp(_, _)) => {
+        (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
+        (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
+        (Utf8, Timestamp(_, _)) => {
             cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
         }
-        (DataType::Utf8, DataType::Date32) => cast_string_to_date(&array, to_type, eval_mode),
-        (DataType::Int64, DataType::Int32)
-        | (DataType::Int64, DataType::Int16)
-        | (DataType::Int64, DataType::Int8)
-        | (DataType::Int32, DataType::Int16)
-        | (DataType::Int32, DataType::Int8)
-        | (DataType::Int16, DataType::Int8)
+        (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+        (Int64, Int32)
+        | (Int64, Int16)
+        | (Int64, Int8)
+        | (Int32, Int16)
+        | (Int32, Int8)
+        | (Int16, Int8)
             if eval_mode != EvalMode::Try =>
         {
             spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
         }
-        (DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64) => {
+        (Utf8, Int8 | Int16 | Int32 | Int64) => {
             cast_string_to_int::<i32>(to_type, &array, eval_mode)
         }
-        (
-            DataType::LargeUtf8,
-            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
-        ) => cast_string_to_int::<i64>(to_type, &array, eval_mode),
-        (DataType::Float64, DataType::Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
-        (DataType::Float64, DataType::LargeUtf8) => {
-            spark_cast_float64_to_utf8::<i64>(&array, eval_mode)
-        }
-        (DataType::Float32, DataType::Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
-        (DataType::Float32, DataType::LargeUtf8) => {
-            spark_cast_float32_to_utf8::<i64>(&array, eval_mode)
-        }
-        (DataType::Float32, DataType::Decimal128(precision, scale)) => {
+        (LargeUtf8, Int8 | Int16 | Int32 | Int64) => {
+            cast_string_to_int::<i64>(to_type, &array, eval_mode)
+        }
+        (Float64, Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
+        (Float64, LargeUtf8) => spark_cast_float64_to_utf8::<i64>(&array, eval_mode),
+        (Float32, Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
+        (Float32, LargeUtf8) => spark_cast_float32_to_utf8::<i64>(&array, eval_mode),
+        (Float32, Decimal128(precision, scale)) => {
             cast_float32_to_decimal128(&array, *precision, *scale, eval_mode)
         }
-        (DataType::Float64, DataType::Decimal128(precision, scale)) => {
+        (Float64, Decimal128(precision, scale)) => {
             cast_float64_to_decimal128(&array, *precision, *scale, eval_mode)
         }
-        (DataType::Float32, DataType::Int8)
-        | (DataType::Float32, DataType::Int16)
-        | (DataType::Float32, DataType::Int32)
-        | (DataType::Float32, DataType::Int64)
-        | (DataType::Float64, DataType::Int8)
-        | (DataType::Float64, DataType::Int16)
-        | (DataType::Float64, DataType::Int32)
-        | (DataType::Float64, DataType::Int64)
-        | (DataType::Decimal128(_, _), DataType::Int8)
-        | (DataType::Decimal128(_, _), DataType::Int16)
-        | (DataType::Decimal128(_, _), DataType::Int32)
-        | (DataType::Decimal128(_, _), DataType::Int64)
+        (Float32, Int8)
+        | (Float32, Int16)
+        | (Float32, Int32)
+        | (Float32, Int64)
+        | (Float64, Int8)
+        | (Float64, Int16)
+        | (Float64, Int32)
+        | (Float64, Int64)
+        | (Decimal128(_, _), Int8)
+        | (Decimal128(_, _), Int16)
+        | (Decimal128(_, _), Int32)
+        | (Decimal128(_, _), Int64)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
        }
-        (DataType::Struct(_), DataType::Utf8) => {
-            Ok(casts_struct_to_string(array.as_struct(), cast_options)?)
-        }
-        (DataType::Struct(_), DataType::Struct(_)) => Ok(cast_struct_to_struct(
+        (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
+        (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
             array.as_struct(),
             from_type,
             to_type,
             cast_options,
         )?),
+        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
+            if cast_options.allow_cast_unsigned_ints =>
+        {
+            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
+        }
         _ if is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) => {
             // use DataFusion cast only when we know that it is compatible with Spark
             Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
diff --git a/native/spark-expr/src/schema_adapter.rs b/native/spark-expr/src/schema_adapter.rs
index 77da5fbc5..8f30735f6 100644
--- a/native/spark-expr/src/schema_adapter.rs
+++ b/native/spark-expr/src/schema_adapter.rs
@@ -291,6 +291,7 @@ mod test {
     use arrow::array::{Int32Array, StringArray};
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::record_batch::RecordBatch;
+    use arrow_array::UInt32Array;
     use arrow_schema::SchemaRef;
     use datafusion::datasource::listing::PartitionedFile;
     use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
@@ -304,7 +305,7 @@ mod test {
     use std::sync::Arc;
 
     #[tokio::test]
-    async fn parquet_roundtrip() -> Result<(), DataFusionError> {
+    async fn parquet_roundtrip_int_as_string() -> Result<(), DataFusionError> {
         let file_schema = Arc::new(Schema::new(vec![
             Field::new("id", DataType::Int32, false),
             Field::new("name", DataType::Utf8, false),
@@ -325,6 +326,20 @@ mod test {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn parquet_roundtrip_unsigned_int() -> Result<(), DataFusionError> {
+        let file_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt32, false)]));
+
+        let ids = Arc::new(UInt32Array::from(vec![1, 2, 3])) as Arc<dyn Array>;
+        let batch = RecordBatch::try_new(Arc::clone(&file_schema), vec![ids])?;
+
+        let required_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let _ = roundtrip(&batch, required_schema).await?;
+
+        Ok(())
+    }
+
     /// Create a Parquet file containing a single batch and then read the batch back using
     /// the specified required_schema. This will cause the SchemaAdapter code to be used.
     async fn roundtrip(
@@ -344,7 +359,9 @@ mod test {
             filename.to_string(),
         )?]]);
 
-        let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+        let mut spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+        spark_cast_options.allow_cast_unsigned_ints = true;
+
         let parquet_exec = ParquetExec::builder(file_scan_config)
             .with_schema_adapter_factory(Arc::new(SparkSchemaAdapterFactory::new(
                 spark_cast_options,
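
For context: the new `allow_cast_unsigned_ints` flag defaults to false, and in this patch it is only switched on where the Parquet SchemaAdapter builds its cast options, as the roundtrip test above shows. Below is a minimal sketch of opting in from calling code directly; it assumes the crate's public `spark_cast` entry point and the `datafusion-comet-spark-expr` crate-level re-exports of `SparkCastOptions` and `EvalMode`, none of which appear in this diff.

    use std::sync::Arc;

    use arrow::array::UInt32Array;
    use arrow::datatypes::DataType;
    use datafusion::physical_plan::ColumnarValue;
    use datafusion_comet_spark_expr::{spark_cast, EvalMode, SparkCastOptions};

    fn cast_unsigned_demo() -> datafusion::error::Result<()> {
        // The flag defaults to false, so existing callers keep current behavior;
        // it must be enabled explicitly, as the SchemaAdapter does.
        let mut options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
        options.allow_cast_unsigned_ints = true;

        // With the flag set, a UInt32 array may be cast to Spark's signed Int32
        // via the arrow cast kernel, mirroring the new match arm in cast_array.
        let input = ColumnarValue::Array(Arc::new(UInt32Array::from(vec![1, 2, 3])));
        let result = spark_cast(input, &DataType::Int32, &options)?;
        assert_eq!(result.data_type(), DataType::Int32);
        Ok(())
    }

Routing the cast through the options struct rather than a new function parameter keeps the unsigned-to-signed widening scoped to readers (such as the Parquet SchemaAdapter) without changing Spark-facing cast semantics elsewhere.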