From e674d034d58be0a29d7724e845b193c7c75cb7cf Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 07:42:35 -0600 Subject: [PATCH 01/46] Implement Spark-compatible cast from string to integral types --- .../execution/datafusion/expressions/cast.rs | 407 +++++++++++++++++- .../org/apache/comet/CometCastSuite.scala | 23 +- .../org/apache/spark/sql/CometTestBase.scala | 1 + 3 files changed, 420 insertions(+), 11 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 10079855d..afe87bdad 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -24,11 +24,15 @@ use std::{ use crate::errors::{CometError, CometResult}; use arrow::{ - compute::{cast_with_options, CastOptions}, + compute::{cast_with_options, take, CastOptions}, record_batch::RecordBatch, util::display::FormatOptions, }; -use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait}; +use arrow_array::{ + types::{Int16Type, Int32Type, Int64Type, Int8Type}, + Array, ArrayRef, BooleanArray, DictionaryArray, GenericStringArray, OffsetSizeTrait, + PrimitiveArray, +}; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; @@ -103,12 +107,59 @@ impl Cast { (DataType::LargeUtf8, DataType::Boolean) => { Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? } + (DataType::Utf8, DataType::Int8) => { + Self::spark_cast_utf8_to_i8::(&array, self.eval_mode)? + } + (DataType::Dictionary(a, b), DataType::Int8) + if a.as_ref() == &DataType::Int32 && b.as_ref() == &DataType::Utf8 => + { + // TODO file follow on issue for optimizing this to avoid unpacking first + let unpacked_array = Self::unpack_dict_string_array::(&array)?; + Self::spark_cast_utf8_to_i8::(&unpacked_array, self.eval_mode)? 
+ } + (DataType::Utf8, DataType::Int16) => { + Self::spark_cast_utf8_to_i16::(&array, self.eval_mode)? + } + (DataType::Dictionary(a, b), DataType::Int16) + if a.as_ref() == &DataType::Int32 && b.as_ref() == &DataType::Utf8 => + { + let unpacked_array = Self::unpack_dict_string_array::(&array)?; + Self::spark_cast_utf8_to_i16::(&unpacked_array, self.eval_mode)? + } + (DataType::Utf8, DataType::Int32) => { + Self::spark_cast_utf8_to_i32::(&array, self.eval_mode)? + } + (DataType::Dictionary(a, b), DataType::Int32) + if a.as_ref() == &DataType::Int32 && b.as_ref() == &DataType::Utf8 => + { + let unpacked_array = Self::unpack_dict_string_array::(&array)?; + Self::spark_cast_utf8_to_i32::(&unpacked_array, self.eval_mode)? + } + (DataType::Utf8, DataType::Int64) => { + Self::spark_cast_utf8_to_i64::(&array, self.eval_mode)? + } + (DataType::Dictionary(a, b), DataType::Int64) + if a.as_ref() == &DataType::Int32 && b.as_ref() == &DataType::Utf8 => + { + let unpacked_array = Self::unpack_dict_string_array::(&array)?; + Self::spark_cast_utf8_to_i64::(&unpacked_array, self.eval_mode)? 
+ } _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?, }; let result = spark_cast(cast_result, from_type, to_type); Ok(result) } + fn unpack_dict_string_array(array: &ArrayRef) -> DataFusionResult { + let dict_array = array + .as_any() + .downcast_ref::>() + .expect("DictionaryArray"); + + let unpacked_array = take(dict_array.values().as_ref(), dict_array.keys(), None)?; + Ok(unpacked_array) + } + fn spark_cast_utf8_to_boolean( from: &dyn Array, eval_mode: EvalMode, @@ -140,6 +191,327 @@ impl Cast { Ok(Arc::new(output_array)) } + + // TODO reduce code duplication + + fn spark_cast_utf8_to_i8( + from: &dyn Array, + eval_mode: EvalMode, + ) -> CometResult + where + OffsetSize: OffsetSizeTrait, + { + let string_array = from + .as_any() + .downcast_ref::>() + .expect("spark_cast_utf8_to_i8 expected a string array"); + + // cast the dictionary values from string to int8 + let mut cast_array = PrimitiveArray::::builder(string_array.len()); + for i in 0..string_array.len() { + if string_array.is_null(i) { + cast_array.append_null() + } else { + if let Some(cast_value) = + cast_string_to_i8(string_array.value(i).trim(), eval_mode)? + { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + } + Ok(Arc::new(cast_array.finish())) + } + + fn spark_cast_utf8_to_i16( + from: &dyn Array, + eval_mode: EvalMode, + ) -> CometResult + where + OffsetSize: OffsetSizeTrait, + { + let string_array = from + .as_any() + .downcast_ref::>() + .expect("spark_cast_utf8_to_i16 expected a string array"); + + // cast the dictionary values from string to int8 + let mut cast_array = PrimitiveArray::::builder(string_array.len()); + for i in 0..string_array.len() { + if string_array.is_null(i) { + cast_array.append_null() + } else { + if let Some(cast_value) = + cast_string_to_i16(string_array.value(i).trim(), eval_mode)? 
+ { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + } + Ok(Arc::new(cast_array.finish())) + } + + fn spark_cast_utf8_to_i32( + from: &dyn Array, + eval_mode: EvalMode, + ) -> CometResult + where + OffsetSize: OffsetSizeTrait, + { + let string_array = from + .as_any() + .downcast_ref::>() + .expect("spark_cast_utf8_to_i32 expected a string array"); + + // cast the dictionary values from string to int8 + let mut cast_array = PrimitiveArray::::builder(string_array.len()); + for i in 0..string_array.len() { + if string_array.is_null(i) { + cast_array.append_null() + } else { + if let Some(cast_value) = + cast_string_to_i32(string_array.value(i).trim(), eval_mode)? + { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + } + Ok(Arc::new(cast_array.finish())) + } + + fn spark_cast_utf8_to_i64( + from: &dyn Array, + eval_mode: EvalMode, + ) -> CometResult + where + OffsetSize: OffsetSizeTrait, + { + let string_array = from + .as_any() + .downcast_ref::>() + .expect("spark_cast_utf8_to_i64 expected a string array"); + + // cast the dictionary values from string to int8 + let mut cast_array = PrimitiveArray::::builder(string_array.len()); + for i in 0..string_array.len() { + if string_array.is_null(i) { + cast_array.append_null() + } else { + if let Some(cast_value) = + cast_string_to_i64(string_array.value(i).trim(), eval_mode)? + { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + } + Ok(Arc::new(cast_array.finish())) + } +} + +fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> { + Ok( + do_cast_string_to_integral(str, eval_mode, "TINYINT", i8::MIN as i32, i8::MAX as i32)? + .map(|v| v as i8), + ) +} + +fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult> { + Ok( + do_cast_string_to_integral(str, eval_mode, "SMALLINT", i16::MIN as i32, i16::MAX as i32)? 
+ .map(|v| v as i16), + ) +} + +fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { + do_cast_string_to_i32(str, eval_mode, "INT") +} + +fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { + do_cast_string_to_i64(str, eval_mode, "BIGINT") +} + +fn do_cast_string_to_integral( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min: i32, + max: i32, +) -> CometResult> { + match do_cast_string_to_i32(str, eval_mode, type_name)? { + None => Ok(None), + Some(v) if v >= min && v <= max => Ok(Some(v)), + _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +fn do_cast_string_to_i32( + str: &str, + eval_mode: EvalMode, + type_name: &str, +) -> CometResult> { + //TODO avoid trim and parse and skip whitespace chars instead + let str = str.trim(); + if str.is_empty() { + return Ok(None); + } + let chars: Vec = str.chars().collect(); + let mut i = 0; + + // skip + or - + let negative = chars[0] == '-'; + if negative || chars[0] == '+' { + i += 1; + if i == chars.len() { + return Ok(None); + } + } + + let mut result = 0; + let radix = 10; + let stop_value = i32::MIN / radix; + + while i < chars.len() { + let b = chars[i]; + i += 1; + + if b == '.' && eval_mode == EvalMode::Legacy { + // truncate decimal in legacy mode + break; + } + + let digit; + if b >= '0' && b <= '9' { + digit = (b as u32) - ('0' as u32); + } else { + return none_or_err(eval_mode, type_name, str); + } + + if result < stop_value { + return none_or_err(eval_mode, type_name, str); + } + result = result * radix - digit as i32; + if result > 0 { + return none_or_err(eval_mode, type_name, str); + } + } + + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well-formed. 
+ while i < chars.len() { + let b = chars[i]; + if b < '0' || b > '9' { + return none_or_err(eval_mode, type_name, str); + } + i += 1; + } + + if !negative { + result = -result; + if result < 0 { + return none_or_err(eval_mode, type_name, str); + } + } + Ok(Some(result)) +} + +fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult> { + match eval_mode { + EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +fn do_cast_string_to_i64( + str: &str, + eval_mode: EvalMode, + type_name: &str, +) -> CometResult> { + //TODO avoid trim and parse and skip whitespace chars instead + let str = str.trim(); + if str.is_empty() { + return Ok(None); + } + let chars: Vec = str.chars().collect(); + let mut i = 0; + + // skip + or - + let negative = chars[0] == '-'; + if negative || chars[0] == '+' { + i += 1; + if i == chars.len() { + return Ok(None); + } + } + + let mut result = 0; + let radix = 10; + let stop_value = i64::MIN / radix; + + while i < chars.len() { + let b = chars[i]; + i += 1; + + if b == '.' && eval_mode == EvalMode::Legacy { + // truncate decimal in legacy mode + break; + } + + let digit; + if b >= '0' && b <= '9' { + digit = (b as u32) - ('0' as u32); + } else { + return none_or_err(eval_mode, type_name, str); + } + + if result < stop_value { + return none_or_err(eval_mode, type_name, str); + } + result = result * radix - digit as i64; + if result > 0 { + return none_or_err(eval_mode, type_name, str); + } + } + + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well-formed. 
+ while i < chars.len() { + let b = chars[i]; + if b < '0' || b > '9' { + return none_or_err(eval_mode, type_name, str); + } + i += 1; + } + + if !negative { + result = -result; + if result < 0 { + return none_or_err(eval_mode, type_name, str); + } + } + Ok(Some(result)) +} + +fn is_java_whitespace(character: char) -> bool { + // TODO we need to use Java's rules here not Rust's (or maybe they are the same?) + character.is_whitespace() +} + +fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError { + CometError::CastInvalidValue { + value: value.to_string(), + from_type: from_type.to_string(), + to_type: to_type.to_string(), + } } impl Display for Cast { @@ -222,3 +594,34 @@ impl PhysicalExpr for Cast { self.hash(&mut s); } } + +#[cfg(test)] +mod test { + use super::{cast_string_to_i8, EvalMode}; + + #[test] + fn test_cast_string_as_i8() { + // basic + assert_eq!( + cast_string_to_i8("127", EvalMode::Legacy).unwrap(), + Some(127_i8) + ); + assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None); + assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err()); + // decimals + assert_eq!( + cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(), + Some(0_i8) + ); + assert_eq!( + cast_string_to_i8(".", EvalMode::Legacy).unwrap(), + Some(0_i8) + ); + // note that TRY behavior is different to LEGACY in some cases + assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), Some(0_i8)); + assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None); + // ANSI mode should throw error on decimal + assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err()); + assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err()); + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 8abd24598..ae49fc68a 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -66,19 +66,22 @@ class 
CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(testValues, DataTypes.BooleanType) } - ignore("cast string to byte") { - castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ByteType) + test("cast string to byte") { + val testValues = + Seq("", ".", "0", "-0", "+1", "-1", ".2", "-.2", "1e1", "127", "128", "-128", "-129") ++ + generateStrings(numericPattern, 8) + castTest(testValues.toDF("a"), DataTypes.ByteType) } - ignore("cast string to short") { + test("cast string to short") { castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ShortType) } - ignore("cast string to int") { + test("cast string to int") { castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType) } - ignore("cast string to long") { + test("cast string to long") { castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType) } @@ -133,11 +136,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) { // cast() should return null for invalid inputs when ansi mode is disabled - val df = data.withColumn("converted", col("a").cast(toType)) + val df = spark.sql(s"select a, cast(a as ${toType.sql}) from t order by a") checkSparkAnswer(df) // try_cast() should always return null for invalid inputs - val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t") + val df2 = + spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a") checkSparkAnswer(df2) } @@ -154,7 +158,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // We have to workaround https://github.com/apache/datafusion-comet/issues/293 here by // removing the "Execution error: " error message prefix that is added by DataFusion val cometMessage = actual.getMessage - .substring("Execution error: ".length) + .replace("Execution error: ", "") assert(expected.getMessage == cometMessage) } else { @@ -169,7 +173,8 @@ class CometCastSuite extends 
CometTestBase with AdaptiveSparkPlanHelper { } // try_cast() should always return null for invalid inputs - val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t") + val df2 = + spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a") checkSparkAnswer(df2) } } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 85e58824c..72b069469 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -224,6 +224,7 @@ abstract class CometTestBase } val dfComet = Dataset.ofRows(spark, df.logicalPlan) val actual = Try(dfComet.collect()).failed.get + println(dfComet.queryExecution.executedPlan) (expected.get.getCause, actual.getCause) } From 5628a3e1c695213e1a4dc520b2562a7536f6c876 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 07:45:21 -0600 Subject: [PATCH 02/46] remove debug println --- spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 72b069469..85e58824c 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -224,7 +224,6 @@ abstract class CometTestBase } val dfComet = Dataset.ofRows(spark, df.logicalPlan) val actual = Try(dfComet.collect()).failed.get - println(dfComet.queryExecution.executedPlan) (expected.get.getCause, actual.getCause) } From 79633bcfeccaad5ea3b615048ecab5048c001b59 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 07:46:36 -0600 Subject: [PATCH 03/46] update rust tests --- core/src/execution/datafusion/expressions/cast.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs 
b/core/src/execution/datafusion/expressions/cast.rs index afe87bdad..e11333260 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -617,8 +617,8 @@ mod test { cast_string_to_i8(".", EvalMode::Legacy).unwrap(), Some(0_i8) ); - // note that TRY behavior is different to LEGACY in some cases - assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), Some(0_i8)); + // TRY should always return null for decimals + assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None); assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None); // ANSI mode should throw error on decimal assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err()); From 004f43194d3c96decc717987e65ded1a29ea6ec6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 09:39:22 -0600 Subject: [PATCH 04/46] clippy --- .../execution/datafusion/expressions/cast.rs | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index e11333260..74d1624c0 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -29,7 +29,7 @@ use arrow::{ util::display::FormatOptions, }; use arrow_array::{ - types::{Int16Type, Int32Type, Int64Type, Int8Type}, + types::{ArrowDictionaryKeyType, Int16Type, Int32Type, Int64Type, Int8Type}, Array, ArrayRef, BooleanArray, DictionaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; @@ -150,11 +150,13 @@ impl Cast { Ok(result) } - fn unpack_dict_string_array(array: &ArrayRef) -> DataFusionResult { + fn unpack_dict_string_array( + array: &ArrayRef, + ) -> DataFusionResult { let dict_array = array .as_any() - .downcast_ref::>() - .expect("DictionaryArray"); + .downcast_ref::>() + .expect("DictionaryArray"); let unpacked_array = take(dict_array.values().as_ref(), dict_array.keys(), None)?; 
Ok(unpacked_array) @@ -387,12 +389,11 @@ fn do_cast_string_to_i32( break; } - let digit; - if b >= '0' && b <= '9' { - digit = (b as u32) - ('0' as u32); + let digit = if ('0'..='9').contains(&b) { + (b as u32) - ('0' as u32) } else { return none_or_err(eval_mode, type_name, str); - } + }; if result < stop_value { return none_or_err(eval_mode, type_name, str); @@ -408,7 +409,7 @@ fn do_cast_string_to_i32( // is well-formed. while i < chars.len() { let b = chars[i]; - if b < '0' || b > '9' { + if !('0'..='9').contains(&b) { return none_or_err(eval_mode, type_name, str); } i += 1; @@ -423,6 +424,7 @@ fn do_cast_string_to_i32( Ok(Some(result)) } +/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult> { match eval_mode { EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), @@ -465,12 +467,11 @@ fn do_cast_string_to_i64( break; } - let digit; - if b >= '0' && b <= '9' { - digit = (b as u32) - ('0' as u32); + let digit = if ('0'..='9').contains(&b) { + (b as u32) - ('0' as u32) } else { return none_or_err(eval_mode, type_name, str); - } + }; if result < stop_value { return none_or_err(eval_mode, type_name, str); @@ -486,7 +487,7 @@ fn do_cast_string_to_i64( // is well-formed. 
while i < chars.len() { let b = chars[i]; - if b < '0' || b > '9' { + if !('0'..='9').contains(&b) { return none_or_err(eval_mode, type_name, str); } i += 1; From 7cf98169aa878c17f69cb925a50e473c40910740 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 09:59:37 -0600 Subject: [PATCH 05/46] more clippy --- .../execution/datafusion/expressions/cast.rs | 56 ++++++++----------- 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 74d1624c0..f258d1afc 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -213,14 +213,12 @@ impl Cast { for i in 0..string_array.len() { if string_array.is_null(i) { cast_array.append_null() + } else if let Some(cast_value) = + cast_string_to_i8(string_array.value(i).trim(), eval_mode)? + { + cast_array.append_value(cast_value); } else { - if let Some(cast_value) = - cast_string_to_i8(string_array.value(i).trim(), eval_mode)? - { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } + cast_array.append_null() } } Ok(Arc::new(cast_array.finish())) @@ -243,14 +241,12 @@ impl Cast { for i in 0..string_array.len() { if string_array.is_null(i) { cast_array.append_null() + } else if let Some(cast_value) = + cast_string_to_i16(string_array.value(i).trim(), eval_mode)? + { + cast_array.append_value(cast_value); } else { - if let Some(cast_value) = - cast_string_to_i16(string_array.value(i).trim(), eval_mode)? - { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } + cast_array.append_null() } } Ok(Arc::new(cast_array.finish())) @@ -273,14 +269,12 @@ impl Cast { for i in 0..string_array.len() { if string_array.is_null(i) { cast_array.append_null() + } else if let Some(cast_value) = + cast_string_to_i32(string_array.value(i).trim(), eval_mode)? 
+ { + cast_array.append_value(cast_value); } else { - if let Some(cast_value) = - cast_string_to_i32(string_array.value(i).trim(), eval_mode)? - { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } + cast_array.append_null() } } Ok(Arc::new(cast_array.finish())) @@ -303,14 +297,12 @@ impl Cast { for i in 0..string_array.len() { if string_array.is_null(i) { cast_array.append_null() + } else if let Some(cast_value) = + cast_string_to_i64(string_array.value(i).trim(), eval_mode)? + { + cast_array.append_value(cast_value); } else { - if let Some(cast_value) = - cast_string_to_i64(string_array.value(i).trim(), eval_mode)? - { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } + cast_array.append_null() } } Ok(Arc::new(cast_array.finish())) @@ -389,7 +381,7 @@ fn do_cast_string_to_i32( break; } - let digit = if ('0'..='9').contains(&b) { + let digit = if b.is_ascii_digit() { (b as u32) - ('0' as u32) } else { return none_or_err(eval_mode, type_name, str); @@ -409,7 +401,7 @@ fn do_cast_string_to_i32( // is well-formed. while i < chars.len() { let b = chars[i]; - if !('0'..='9').contains(&b) { + if !b.is_ascii_digit() { return none_or_err(eval_mode, type_name, str); } i += 1; @@ -467,7 +459,7 @@ fn do_cast_string_to_i64( break; } - let digit = if ('0'..='9').contains(&b) { + let digit = if b.is_ascii_digit() { (b as u32) - ('0' as u32) } else { return none_or_err(eval_mode, type_name, str); @@ -487,7 +479,7 @@ fn do_cast_string_to_i64( // is well-formed. 
while i < chars.len() { let b = chars[i]; - if !('0'..='9').contains(&b) { + if !b.is_ascii_digit() { return none_or_err(eval_mode, type_name, str); } i += 1; From 73cbe8f233a8d1684ed7dcbb12b29139233c0ae1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 10:42:32 -0600 Subject: [PATCH 06/46] minor refactort to reduce code duplication --- .../execution/datafusion/expressions/cast.rs | 78 +++++++++---------- 1 file changed, 37 insertions(+), 41 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index f258d1afc..e7e126b92 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -107,47 +107,48 @@ impl Cast { (DataType::LargeUtf8, DataType::Boolean) => { Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? } - (DataType::Utf8, DataType::Int8) => { - Self::spark_cast_utf8_to_i8::(&array, self.eval_mode)? - } - (DataType::Dictionary(a, b), DataType::Int8) - if a.as_ref() == &DataType::Int32 && b.as_ref() == &DataType::Utf8 => + ( + DataType::Utf8, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) => match to_type { + DataType::Int8 => Self::spark_cast_utf8_to_i8::(&array, self.eval_mode)?, + DataType::Int16 => Self::spark_cast_utf8_to_i16::(&array, self.eval_mode)?, + DataType::Int32 => Self::spark_cast_utf8_to_i32::(&array, self.eval_mode)?, + DataType::Int64 => Self::spark_cast_utf8_to_i64::(&array, self.eval_mode)?, + _ => unreachable!("invalid integral type in cast from string"), + }, + ( + DataType::Dictionary(key_type, value_type), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) if key_type.as_ref() == &DataType::Int32 + && value_type.as_ref() == &DataType::Utf8 => { // TODO file follow on issue for optimizing this to avoid unpacking first let unpacked_array = Self::unpack_dict_string_array::(&array)?; - Self::spark_cast_utf8_to_i8::(&unpacked_array, 
self.eval_mode)? - } - (DataType::Utf8, DataType::Int16) => { - Self::spark_cast_utf8_to_i16::(&array, self.eval_mode)? - } - (DataType::Dictionary(a, b), DataType::Int16) - if a.as_ref() == &DataType::Int32 && b.as_ref() == &DataType::Utf8 => - { - let unpacked_array = Self::unpack_dict_string_array::(&array)?; - Self::spark_cast_utf8_to_i16::(&unpacked_array, self.eval_mode)? - } - (DataType::Utf8, DataType::Int32) => { - Self::spark_cast_utf8_to_i32::(&array, self.eval_mode)? - } - (DataType::Dictionary(a, b), DataType::Int32) - if a.as_ref() == &DataType::Int32 && b.as_ref() == &DataType::Utf8 => - { - let unpacked_array = Self::unpack_dict_string_array::(&array)?; - Self::spark_cast_utf8_to_i32::(&unpacked_array, self.eval_mode)? + match to_type { + DataType::Int8 => { + Self::spark_cast_utf8_to_i8::(&unpacked_array, self.eval_mode)? + } + DataType::Int16 => { + Self::spark_cast_utf8_to_i16::(&unpacked_array, self.eval_mode)? + } + DataType::Int32 => { + Self::spark_cast_utf8_to_i32::(&unpacked_array, self.eval_mode)? + } + DataType::Int64 => { + Self::spark_cast_utf8_to_i64::(&unpacked_array, self.eval_mode)? + } + _ => { + unreachable!("invalid integral type in cast from dictionary-encoded string") + } + } } - (DataType::Utf8, DataType::Int64) => { - Self::spark_cast_utf8_to_i64::(&array, self.eval_mode)? - } - (DataType::Dictionary(a, b), DataType::Int64) - if a.as_ref() == &DataType::Int32 && b.as_ref() == &DataType::Utf8 => - { - let unpacked_array = Self::unpack_dict_string_array::(&array)?; - Self::spark_cast_utf8_to_i64::(&unpacked_array, self.eval_mode)? - } - _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?, + _ => { + // when we have no Spark-specific casting we delegate to DataFusion + cast_with_options(&array, to_type, &CAST_OPTIONS)? 
+ }, }; - let result = spark_cast(cast_result, from_type, to_type); - Ok(result) + Ok(spark_cast(cast_result, from_type, to_type)) } fn unpack_dict_string_array( @@ -494,11 +495,6 @@ fn do_cast_string_to_i64( Ok(Some(result)) } -fn is_java_whitespace(character: char) -> bool { - // TODO we need to use Java's rules here not Rust's (or maybe they are the same?) - character.is_whitespace() -} - fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError { CometError::CastInvalidValue { value: value.to_string(), From 32632c11361406e59f632fe0d5962f9cd6285175 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 12:02:36 -0600 Subject: [PATCH 07/46] introduce accumulator --- .../execution/datafusion/expressions/cast.rs | 249 +++++++++++------- 1 file changed, 158 insertions(+), 91 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index e7e126b92..938d427ca 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -146,7 +146,7 @@ impl Cast { _ => { // when we have no Spark-specific casting we delegate to DataFusion cast_with_options(&array, to_type, &CAST_OPTIONS)? - }, + } }; Ok(spark_cast(cast_result, from_type, to_type)) } @@ -311,35 +311,49 @@ impl Cast { } fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> { - Ok( - do_cast_string_to_integral(str, eval_mode, "TINYINT", i8::MIN as i32, i8::MAX as i32)? - .map(|v| v as i8), - ) + Ok(cast_string_to_integral_with_range_check( + str, + eval_mode, + "TINYINT", + i8::MIN as i32, + i8::MAX as i32, + )? + .map(|v| v as i8)) } fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult> { - Ok( - do_cast_string_to_integral(str, eval_mode, "SMALLINT", i16::MIN as i32, i16::MAX as i32)? - .map(|v| v as i16), - ) + Ok(cast_string_to_integral_with_range_check( + str, + eval_mode, + "SMALLINT", + i16::MIN as i32, + i16::MAX as i32, + )? 
+ .map(|v| v as i16)) } fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { - do_cast_string_to_i32(str, eval_mode, "INT") + let mut accum = CastStringToIntegral32::default(); + do_cast_string_to_int(&mut accum, str, eval_mode, "INT")?; + Ok(accum.result) } fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { - do_cast_string_to_i64(str, eval_mode, "BIGINT") + let mut accum = CastStringToIntegral64::default(); + do_cast_string_to_int(&mut accum, str, eval_mode, "BIGINT")?; + Ok(accum.result) } -fn do_cast_string_to_integral( +fn cast_string_to_integral_with_range_check( str: &str, eval_mode: EvalMode, type_name: &str, min: i32, max: i32, ) -> CometResult> { - match do_cast_string_to_i32(str, eval_mode, type_name)? { + let mut accum = CastStringToIntegral32::default(); + do_cast_string_to_int(&mut accum, str, eval_mode, type_name)?; + match accum.result { None => Ok(None), Some(v) if v >= min && v <= max => Ok(Some(v)), _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), @@ -347,93 +361,149 @@ fn do_cast_string_to_integral( } } -fn do_cast_string_to_i32( - str: &str, - eval_mode: EvalMode, - type_name: &str, -) -> CometResult> { - //TODO avoid trim and parse and skip whitespace chars instead - let str = str.trim(); - if str.is_empty() { - return Ok(None); - } - let chars: Vec = str.chars().collect(); - let mut i = 0; - - // skip + or - - let negative = chars[0] == '-'; - if negative || chars[0] == '+' { - i += 1; - if i == chars.len() { - return Ok(None); - } - } +trait CastStringToIntegral { + fn accumulate( + &mut self, + eval_mode: EvalMode, + type_name: &str, + str: &str, + digit: u32, + ) -> CometResult<()>; - let mut result = 0; - let radix = 10; - let stop_value = i32::MIN / radix; + fn reset(&mut self); - while i < chars.len() { - let b = chars[i]; - i += 1; + fn finish( + &mut self, + eval_mode: EvalMode, + type_name: &str, + str: &str, + negative: bool, + ) -> CometResult<()>; +} +struct 
CastStringToIntegral32 { + negative: bool, + result: Option, + radix: i32, +} - if b == '.' && eval_mode == EvalMode::Legacy { - // truncate decimal in legacy mode - break; +impl Default for CastStringToIntegral32 { + fn default() -> Self { + Self { + negative: false, + result: Some(0), + radix: 10, } + } +} - let digit = if b.is_ascii_digit() { - (b as u32) - ('0' as u32) - } else { - return none_or_err(eval_mode, type_name, str); - }; - - if result < stop_value { +impl CastStringToIntegral for CastStringToIntegral32 { + fn accumulate( + &mut self, + eval_mode: EvalMode, + type_name: &str, + str: &str, + digit: u32, + ) -> CometResult<()> { + if self.result.is_some() && self.result.unwrap() < i32::MIN / self.radix { + self.reset(); return none_or_err(eval_mode, type_name, str); } - result = result * radix - digit as i32; - if result > 0 { + self.result = Some(self.result.unwrap_or(0) * self.radix - digit as i32); + if self.result.unwrap() > 0 { + self.reset(); return none_or_err(eval_mode, type_name, str); } + Ok(()) + } + fn reset(&mut self) { + self.result = None; } - // This is the case when we've encountered a decimal separator. The fractional - // part will not change the number, but we will verify that the fractional part - // is well-formed. 
- while i < chars.len() { - let b = chars[i]; - if !b.is_ascii_digit() { - return none_or_err(eval_mode, type_name, str); + fn finish( + &mut self, + eval_mode: EvalMode, + type_name: &str, + str: &str, + negative: bool, + ) -> CometResult<()> { + if self.result.is_some() && !negative { + self.result = Some(-self.result.unwrap()); + if self.result.unwrap() < 0 { + return none_or_err(eval_mode, type_name, str); + } } - i += 1; + Ok(()) } +} - if !negative { - result = -result; - if result < 0 { - return none_or_err(eval_mode, type_name, str); +struct CastStringToIntegral64 { + negative: bool, + result: Option, + radix: i64, +} + +impl Default for CastStringToIntegral64 { + fn default() -> Self { + Self { + negative: false, + result: Some(0), + radix: 10, } } - Ok(Some(result)) } -/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode -fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult> { - match eval_mode { - EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), - _ => Ok(None), +impl CastStringToIntegral for CastStringToIntegral64 { + fn accumulate( + &mut self, + eval_mode: EvalMode, + type_name: &str, + str: &str, + digit: u32, + ) -> CometResult<()> { + if self.result.unwrap_or(0) < i64::MIN / self.radix { + self.reset(); + return none_or_err(eval_mode, type_name, str); + } + self.result = Some(self.result.unwrap_or(0) * self.radix - digit as i64); + if self.result.unwrap() > 0 { + self.reset(); + return none_or_err(eval_mode, type_name, str); + } + Ok(()) + } + + fn reset(&mut self) { + self.result = None; + } + + fn finish( + &mut self, + eval_mode: EvalMode, + type_name: &str, + str: &str, + negative: bool, + ) -> CometResult<()> { + if self.result.is_some() && !negative { + self.result = Some(-self.result.unwrap()); + if self.result.unwrap() < 0 { + return none_or_err(eval_mode, type_name, str); + } + } + Ok(()) } } -fn do_cast_string_to_i64( +fn do_cast_string_to_int( + 
accumulator: &mut dyn CastStringToIntegral, str: &str, eval_mode: EvalMode, type_name: &str, -) -> CometResult> { +) -> CometResult<()> { //TODO avoid trim and parse and skip whitespace chars instead let str = str.trim(); if str.is_empty() { - return Ok(None); + accumulator.reset(); + return Ok(()); } let chars: Vec = str.chars().collect(); let mut i = 0; @@ -443,14 +513,11 @@ fn do_cast_string_to_i64( if negative || chars[0] == '+' { i += 1; if i == chars.len() { - return Ok(None); + accumulator.reset(); + return Ok(()); } } - let mut result = 0; - let radix = 10; - let stop_value = i64::MIN / radix; - while i < chars.len() { let b = chars[i]; i += 1; @@ -463,16 +530,11 @@ fn do_cast_string_to_i64( let digit = if b.is_ascii_digit() { (b as u32) - ('0' as u32) } else { + accumulator.reset(); return none_or_err(eval_mode, type_name, str); }; - if result < stop_value { - return none_or_err(eval_mode, type_name, str); - } - result = result * radix - digit as i64; - if result > 0 { - return none_or_err(eval_mode, type_name, str); - } + accumulator.accumulate(eval_mode, type_name, str, digit)?; } // This is the case when we've encountered a decimal separator. 
The fractional @@ -481,18 +543,23 @@ fn do_cast_string_to_i64( while i < chars.len() { let b = chars[i]; if !b.is_ascii_digit() { + accumulator.reset(); return none_or_err(eval_mode, type_name, str); } i += 1; } - if !negative { - result = -result; - if result < 0 { - return none_or_err(eval_mode, type_name, str); - } + accumulator.finish(eval_mode, type_name, str, negative)?; + + Ok(()) +} + +/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode +fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<()> { + match eval_mode { + EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(()), } - Ok(Some(result)) } fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError { From 2ea444236f40cf9a25211264ba311752c84c39a1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 12:18:24 -0600 Subject: [PATCH 08/46] small refactor --- .../execution/datafusion/expressions/cast.rs | 69 +++++++++---------- 1 file changed, 34 insertions(+), 35 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 938d427ca..63e714b39 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -110,13 +110,28 @@ impl Cast { ( DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => match to_type { - DataType::Int8 => Self::spark_cast_utf8_to_i8::(&array, self.eval_mode)?, - DataType::Int16 => Self::spark_cast_utf8_to_i16::(&array, self.eval_mode)?, - DataType::Int32 => Self::spark_cast_utf8_to_i32::(&array, self.eval_mode)?, - DataType::Int64 => Self::spark_cast_utf8_to_i64::(&array, self.eval_mode)?, - _ => unreachable!("invalid integral type in cast from string"), - }, + ) => { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("spark_cast_utf8_to_i8 expected a string array"); + + match 
to_type { + DataType::Int8 => { + Self::spark_cast_utf8_to_i8::(string_array, self.eval_mode)? + } + DataType::Int16 => { + Self::spark_cast_utf8_to_i16::(string_array, self.eval_mode)? + } + DataType::Int32 => { + Self::spark_cast_utf8_to_i32::(string_array, self.eval_mode)? + } + DataType::Int64 => { + Self::spark_cast_utf8_to_i64::(string_array, self.eval_mode)? + } + _ => unreachable!("invalid integral type in cast from string"), + } + } ( DataType::Dictionary(key_type, value_type), DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, @@ -125,18 +140,22 @@ impl Cast { { // TODO file follow on issue for optimizing this to avoid unpacking first let unpacked_array = Self::unpack_dict_string_array::(&array)?; + let string_array = unpacked_array + .as_any() + .downcast_ref::>() + .expect("spark_cast_utf8_to_i8 expected a string array"); match to_type { DataType::Int8 => { - Self::spark_cast_utf8_to_i8::(&unpacked_array, self.eval_mode)? + Self::spark_cast_utf8_to_i8::(&string_array, self.eval_mode)? } DataType::Int16 => { - Self::spark_cast_utf8_to_i16::(&unpacked_array, self.eval_mode)? + Self::spark_cast_utf8_to_i16::(&string_array, self.eval_mode)? } DataType::Int32 => { - Self::spark_cast_utf8_to_i32::(&unpacked_array, self.eval_mode)? + Self::spark_cast_utf8_to_i32::(&string_array, self.eval_mode)? } DataType::Int64 => { - Self::spark_cast_utf8_to_i64::(&unpacked_array, self.eval_mode)? + Self::spark_cast_utf8_to_i64::(&string_array, self.eval_mode)? 
} _ => { unreachable!("invalid integral type in cast from dictionary-encoded string") @@ -198,17 +217,12 @@ impl Cast { // TODO reduce code duplication fn spark_cast_utf8_to_i8( - from: &dyn Array, + string_array: &GenericStringArray, eval_mode: EvalMode, ) -> CometResult where OffsetSize: OffsetSizeTrait, { - let string_array = from - .as_any() - .downcast_ref::>() - .expect("spark_cast_utf8_to_i8 expected a string array"); - // cast the dictionary values from string to int8 let mut cast_array = PrimitiveArray::::builder(string_array.len()); for i in 0..string_array.len() { @@ -226,17 +240,12 @@ impl Cast { } fn spark_cast_utf8_to_i16( - from: &dyn Array, + string_array: &GenericStringArray, eval_mode: EvalMode, ) -> CometResult where OffsetSize: OffsetSizeTrait, { - let string_array = from - .as_any() - .downcast_ref::>() - .expect("spark_cast_utf8_to_i16 expected a string array"); - // cast the dictionary values from string to int8 let mut cast_array = PrimitiveArray::::builder(string_array.len()); for i in 0..string_array.len() { @@ -254,17 +263,12 @@ impl Cast { } fn spark_cast_utf8_to_i32( - from: &dyn Array, + string_array: &GenericStringArray, eval_mode: EvalMode, ) -> CometResult where OffsetSize: OffsetSizeTrait, { - let string_array = from - .as_any() - .downcast_ref::>() - .expect("spark_cast_utf8_to_i32 expected a string array"); - // cast the dictionary values from string to int8 let mut cast_array = PrimitiveArray::::builder(string_array.len()); for i in 0..string_array.len() { @@ -282,17 +286,12 @@ impl Cast { } fn spark_cast_utf8_to_i64( - from: &dyn Array, + string_array: &GenericStringArray, eval_mode: EvalMode, ) -> CometResult where OffsetSize: OffsetSizeTrait, { - let string_array = from - .as_any() - .downcast_ref::>() - .expect("spark_cast_utf8_to_i64 expected a string array"); - // cast the dictionary values from string to int8 let mut cast_array = PrimitiveArray::::builder(string_array.len()); for i in 0..string_array.len() { From 
d4fd8ff5509943f177a94cfcbb8f12cc93a76342 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 12:56:51 -0600 Subject: [PATCH 09/46] introduce a macro --- .../execution/datafusion/expressions/cast.rs | 200 ++++++------------ 1 file changed, 62 insertions(+), 138 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 63e714b39..7b1229aa6 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -68,6 +68,25 @@ pub struct Cast { pub timezone: String, } +macro_rules! spark_cast_utf8_to_integral { + ($string_array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + let mut cast_array = PrimitiveArray::<$array_type>::builder($string_array.len()); + for i in 0..$string_array.len() { + if $string_array.is_null(i) { + cast_array.append_null() + } else if let Some(cast_value) = + $cast_method($string_array.value(i).trim(), $eval_mode)? + { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + let result: CometResult = Ok(Arc::new(cast_array.finish()) as ArrayRef); + result + }}; +} + impl Cast { pub fn new( child: Arc, @@ -110,57 +129,19 @@ impl Cast { ( DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => { - let string_array = array - .as_any() - .downcast_ref::>() - .expect("spark_cast_utf8_to_i8 expected a string array"); - - match to_type { - DataType::Int8 => { - Self::spark_cast_utf8_to_i8::(string_array, self.eval_mode)? - } - DataType::Int16 => { - Self::spark_cast_utf8_to_i16::(string_array, self.eval_mode)? - } - DataType::Int32 => { - Self::spark_cast_utf8_to_i32::(string_array, self.eval_mode)? - } - DataType::Int64 => { - Self::spark_cast_utf8_to_i64::(string_array, self.eval_mode)? 
- } - _ => unreachable!("invalid integral type in cast from string"), - } - } + ) => Self::spark_cast_string_to_integral(to_type, &array, self.eval_mode)?, ( DataType::Dictionary(key_type, value_type), DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, ) if key_type.as_ref() == &DataType::Int32 && value_type.as_ref() == &DataType::Utf8 => { - // TODO file follow on issue for optimizing this to avoid unpacking first + // Note that we are unpacking a dictionary-encoded array and then performing + // the cast. We could potentially improve performance here by casting the + // dictionary values directly without unpacking the array first, although this + // would add more complexity to the code let unpacked_array = Self::unpack_dict_string_array::(&array)?; - let string_array = unpacked_array - .as_any() - .downcast_ref::>() - .expect("spark_cast_utf8_to_i8 expected a string array"); - match to_type { - DataType::Int8 => { - Self::spark_cast_utf8_to_i8::(&string_array, self.eval_mode)? - } - DataType::Int16 => { - Self::spark_cast_utf8_to_i16::(&string_array, self.eval_mode)? - } - DataType::Int32 => { - Self::spark_cast_utf8_to_i32::(&string_array, self.eval_mode)? - } - DataType::Int64 => { - Self::spark_cast_utf8_to_i64::(&string_array, self.eval_mode)? - } - _ => { - unreachable!("invalid integral type in cast from dictionary-encoded string") - } - } + Self::spark_cast_string_to_integral(to_type, &unpacked_array, self.eval_mode)? 
} _ => { // when we have no Spark-specific casting we delegate to DataFusion @@ -170,6 +151,43 @@ impl Cast { Ok(spark_cast(cast_result, from_type, to_type)) } + fn spark_cast_string_to_integral( + to_type: &DataType, + array: &ArrayRef, + eval_mode: EvalMode, + ) -> CometResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("spark_cast_string_to_integral expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Int8 => { + spark_cast_utf8_to_integral!(string_array, eval_mode, Int8Type, cast_string_to_i8)? + } + DataType::Int16 => spark_cast_utf8_to_integral!( + string_array, + eval_mode, + Int16Type, + cast_string_to_i16 + )?, + DataType::Int32 => spark_cast_utf8_to_integral!( + string_array, + eval_mode, + Int32Type, + cast_string_to_i32 + )?, + DataType::Int64 => spark_cast_utf8_to_integral!( + string_array, + eval_mode, + Int64Type, + cast_string_to_i64 + )?, + _ => unreachable!("invalid integral type in cast from string"), + }; + Ok(cast_array) + } + fn unpack_dict_string_array( array: &ArrayRef, ) -> DataFusionResult { @@ -213,100 +231,6 @@ impl Cast { Ok(Arc::new(output_array)) } - - // TODO reduce code duplication - - fn spark_cast_utf8_to_i8( - string_array: &GenericStringArray, - eval_mode: EvalMode, - ) -> CometResult - where - OffsetSize: OffsetSizeTrait, - { - // cast the dictionary values from string to int8 - let mut cast_array = PrimitiveArray::::builder(string_array.len()); - for i in 0..string_array.len() { - if string_array.is_null(i) { - cast_array.append_null() - } else if let Some(cast_value) = - cast_string_to_i8(string_array.value(i).trim(), eval_mode)? 
- { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } - } - Ok(Arc::new(cast_array.finish())) - } - - fn spark_cast_utf8_to_i16( - string_array: &GenericStringArray, - eval_mode: EvalMode, - ) -> CometResult - where - OffsetSize: OffsetSizeTrait, - { - // cast the dictionary values from string to int8 - let mut cast_array = PrimitiveArray::::builder(string_array.len()); - for i in 0..string_array.len() { - if string_array.is_null(i) { - cast_array.append_null() - } else if let Some(cast_value) = - cast_string_to_i16(string_array.value(i).trim(), eval_mode)? - { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } - } - Ok(Arc::new(cast_array.finish())) - } - - fn spark_cast_utf8_to_i32( - string_array: &GenericStringArray, - eval_mode: EvalMode, - ) -> CometResult - where - OffsetSize: OffsetSizeTrait, - { - // cast the dictionary values from string to int8 - let mut cast_array = PrimitiveArray::::builder(string_array.len()); - for i in 0..string_array.len() { - if string_array.is_null(i) { - cast_array.append_null() - } else if let Some(cast_value) = - cast_string_to_i32(string_array.value(i).trim(), eval_mode)? - { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } - } - Ok(Arc::new(cast_array.finish())) - } - - fn spark_cast_utf8_to_i64( - string_array: &GenericStringArray, - eval_mode: EvalMode, - ) -> CometResult - where - OffsetSize: OffsetSizeTrait, - { - // cast the dictionary values from string to int8 - let mut cast_array = PrimitiveArray::::builder(string_array.len()); - for i in 0..string_array.len() { - if string_array.is_null(i) { - cast_array.append_null() - } else if let Some(cast_value) = - cast_string_to_i64(string_array.value(i).trim(), eval_mode)? 
- { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() - } - } - Ok(Arc::new(cast_array.finish())) - } } fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> { From 92029ba6a050b6f6c654110fa614ea4fcb37a55f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 13:11:48 -0600 Subject: [PATCH 10/46] remove overhead of trim on each string --- .../execution/datafusion/expressions/cast.rs | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 7b1229aa6..170b07740 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -422,26 +422,40 @@ fn do_cast_string_to_int( eval_mode: EvalMode, type_name: &str, ) -> CometResult<()> { - //TODO avoid trim and parse and skip whitespace chars instead - let str = str.trim(); - if str.is_empty() { + + // TODO avoid building a vec of chars + let chars: Vec = str.chars().collect(); + + let mut i = 0; + let mut end = chars.len(); + + // skip leading whitespace + while i < end && chars[i].is_whitespace() { + i += 1; + } + + // skip trailing whitespace + while end > i && chars[end-1].is_whitespace() { + end -= 1; + } + + // check for empty string + if i == end { accumulator.reset(); return Ok(()); } - let chars: Vec = str.chars().collect(); - let mut i = 0; // skip + or - let negative = chars[0] == '-'; if negative || chars[0] == '+' { i += 1; - if i == chars.len() { + if i == end { accumulator.reset(); return Ok(()); } } - while i < chars.len() { + while i < end { let b = chars[i]; i += 1; @@ -463,7 +477,7 @@ fn do_cast_string_to_int( // This is the case when we've encountered a decimal separator. The fractional // part will not change the number, but we will verify that the fractional part // is well-formed. 
- while i < chars.len() { + while i < end { let b = chars[i]; if !b.is_ascii_digit() { accumulator.reset(); From 85c646c2745a9d50b8061588d505e21ce501d8bd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 13:22:43 -0600 Subject: [PATCH 11/46] comment --- core/src/execution/datafusion/expressions/cast.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 170b07740..231c94c8e 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -284,6 +284,10 @@ fn cast_string_to_integral_with_range_check( } } +/// We support parsing strings to i32 and i64 to match Spark's logic. Support for i8 and i16 is +/// implemented by first parsing as i32 and then downcasting. The CastStringToIntegral trait is +/// introduced so that we can have the parsing logic delegate either to an i32 or i64 accumulator +/// and avoid the need to use macros here. 
trait CastStringToIntegral { fn accumulate( &mut self, @@ -422,10 +426,7 @@ fn do_cast_string_to_int( eval_mode: EvalMode, type_name: &str, ) -> CometResult<()> { - - // TODO avoid building a vec of chars let chars: Vec = str.chars().collect(); - let mut i = 0; let mut end = chars.len(); @@ -435,7 +436,7 @@ fn do_cast_string_to_int( } // skip trailing whitespace - while end > i && chars[end-1].is_whitespace() { + while end > i && chars[end - 1].is_whitespace() { end -= 1; } From 89e8cca1c88a79f39dc96922f0438c4f57ad9cfb Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 13:41:02 -0600 Subject: [PATCH 12/46] remove spark from some method names --- .../execution/datafusion/expressions/cast.rs | 39 +++++++------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 231c94c8e..d4503aac5 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -68,7 +68,7 @@ pub struct Cast { pub timezone: String, } -macro_rules! spark_cast_utf8_to_integral { +macro_rules! 
cast_utf8_to_integral { ($string_array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ let mut cast_array = PrimitiveArray::<$array_type>::builder($string_array.len()); for i in 0..$string_array.len() { @@ -129,7 +129,7 @@ impl Cast { ( DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => Self::spark_cast_string_to_integral(to_type, &array, self.eval_mode)?, + ) => Self::cast_string_to_integral(to_type, &array, self.eval_mode)?, ( DataType::Dictionary(key_type, value_type), DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, @@ -141,7 +141,7 @@ impl Cast { // dictionary values directly without unpacking the array first, although this // would add more complexity to the code let unpacked_array = Self::unpack_dict_string_array::(&array)?; - Self::spark_cast_string_to_integral(to_type, &unpacked_array, self.eval_mode)? + Self::cast_string_to_integral(to_type, &unpacked_array, self.eval_mode)? } _ => { // when we have no Spark-specific casting we delegate to DataFusion @@ -151,7 +151,7 @@ impl Cast { Ok(spark_cast(cast_result, from_type, to_type)) } - fn spark_cast_string_to_integral( + fn cast_string_to_integral( to_type: &DataType, array: &ArrayRef, eval_mode: EvalMode, @@ -159,30 +159,21 @@ impl Cast { let string_array = array .as_any() .downcast_ref::>() - .expect("spark_cast_string_to_integral expected a string array"); + .expect("cast_string_to_integral expected a string array"); let cast_array: ArrayRef = match to_type { DataType::Int8 => { - spark_cast_utf8_to_integral!(string_array, eval_mode, Int8Type, cast_string_to_i8)? + cast_utf8_to_integral!(string_array, eval_mode, Int8Type, cast_string_to_i8)? + } + DataType::Int16 => { + cast_utf8_to_integral!(string_array, eval_mode, Int16Type, cast_string_to_i16)? + } + DataType::Int32 => { + cast_utf8_to_integral!(string_array, eval_mode, Int32Type, cast_string_to_i32)? 
+ } + DataType::Int64 => { + cast_utf8_to_integral!(string_array, eval_mode, Int64Type, cast_string_to_i64)? } - DataType::Int16 => spark_cast_utf8_to_integral!( - string_array, - eval_mode, - Int16Type, - cast_string_to_i16 - )?, - DataType::Int32 => spark_cast_utf8_to_integral!( - string_array, - eval_mode, - Int32Type, - cast_string_to_i32 - )?, - DataType::Int64 => spark_cast_utf8_to_integral!( - string_array, - eval_mode, - Int64Type, - cast_string_to_i64 - )?, _ => unreachable!("invalid integral type in cast from string"), }; Ok(cast_array) From 0f4da8fe8b62e2f87577492b6ca97ed79359267c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 13:47:25 -0600 Subject: [PATCH 13/46] rename some methods --- .../execution/datafusion/expressions/cast.rs | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index d4503aac5..7b52a2462 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -68,7 +68,7 @@ pub struct Cast { pub timezone: String, } -macro_rules! cast_utf8_to_integral { +macro_rules! 
cast_utf8_to_int { ($string_array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ let mut cast_array = PrimitiveArray::<$array_type>::builder($string_array.len()); for i in 0..$string_array.len() { @@ -129,7 +129,7 @@ impl Cast { ( DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => Self::cast_string_to_integral(to_type, &array, self.eval_mode)?, + ) => Self::cast_string_to_int(to_type, &array, self.eval_mode)?, ( DataType::Dictionary(key_type, value_type), DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, @@ -141,7 +141,7 @@ impl Cast { // dictionary values directly without unpacking the array first, although this // would add more complexity to the code let unpacked_array = Self::unpack_dict_string_array::(&array)?; - Self::cast_string_to_integral(to_type, &unpacked_array, self.eval_mode)? + Self::cast_string_to_int(to_type, &unpacked_array, self.eval_mode)? } _ => { // when we have no Spark-specific casting we delegate to DataFusion @@ -151,7 +151,7 @@ impl Cast { Ok(spark_cast(cast_result, from_type, to_type)) } - fn cast_string_to_integral( + fn cast_string_to_int( to_type: &DataType, array: &ArrayRef, eval_mode: EvalMode, @@ -159,20 +159,20 @@ impl Cast { let string_array = array .as_any() .downcast_ref::>() - .expect("cast_string_to_integral expected a string array"); + .expect("cast_string_to_int expected a string array"); let cast_array: ArrayRef = match to_type { DataType::Int8 => { - cast_utf8_to_integral!(string_array, eval_mode, Int8Type, cast_string_to_i8)? + cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)? } DataType::Int16 => { - cast_utf8_to_integral!(string_array, eval_mode, Int16Type, cast_string_to_i16)? + cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)? } DataType::Int32 => { - cast_utf8_to_integral!(string_array, eval_mode, Int32Type, cast_string_to_i32)? 
+ cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)? } DataType::Int64 => { - cast_utf8_to_integral!(string_array, eval_mode, Int64Type, cast_string_to_i64)? + cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)? } _ => unreachable!("invalid integral type in cast from string"), }; @@ -225,7 +225,7 @@ impl Cast { } fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> { - Ok(cast_string_to_integral_with_range_check( + Ok(cast_string_to_int_with_range_check( str, eval_mode, "TINYINT", @@ -236,7 +236,7 @@ fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> } fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult> { - Ok(cast_string_to_integral_with_range_check( + Ok(cast_string_to_int_with_range_check( str, eval_mode, "SMALLINT", @@ -247,25 +247,25 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult } fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { - let mut accum = CastStringToIntegral32::default(); + let mut accum = CastStringToInt32::default(); do_cast_string_to_int(&mut accum, str, eval_mode, "INT")?; Ok(accum.result) } fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { - let mut accum = CastStringToIntegral64::default(); + let mut accum = CastStringToInt64::default(); do_cast_string_to_int(&mut accum, str, eval_mode, "BIGINT")?; Ok(accum.result) } -fn cast_string_to_integral_with_range_check( +fn cast_string_to_int_with_range_check( str: &str, eval_mode: EvalMode, type_name: &str, min: i32, max: i32, ) -> CometResult> { - let mut accum = CastStringToIntegral32::default(); + let mut accum = CastStringToInt32::default(); do_cast_string_to_int(&mut accum, str, eval_mode, type_name)?; match accum.result { None => Ok(None), @@ -276,10 +276,10 @@ fn cast_string_to_integral_with_range_check( } /// We support parsing strings to i32 and i64 to match Spark's logic. 
Support for i8 and i16 is -/// implemented by first parsing as i32 and then downcasting. The CastStringToIntegral trait is +/// implemented by first parsing as i32 and then downcasting. The CastStringToInt trait is /// introduced so that we can have the parsing logic delegate either to an i32 or i64 accumulator /// and avoid the need to use macros here. -trait CastStringToIntegral { +trait CastStringToInt { fn accumulate( &mut self, eval_mode: EvalMode, @@ -298,13 +298,13 @@ trait CastStringToIntegral { negative: bool, ) -> CometResult<()>; } -struct CastStringToIntegral32 { +struct CastStringToInt32 { negative: bool, result: Option, radix: i32, } -impl Default for CastStringToIntegral32 { +impl Default for CastStringToInt32 { fn default() -> Self { Self { negative: false, @@ -314,7 +314,7 @@ impl Default for CastStringToIntegral32 { } } -impl CastStringToIntegral for CastStringToIntegral32 { +impl CastStringToInt for CastStringToInt32 { fn accumulate( &mut self, eval_mode: EvalMode, @@ -354,13 +354,13 @@ impl CastStringToIntegral for CastStringToIntegral32 { } } -struct CastStringToIntegral64 { +struct CastStringToInt64 { negative: bool, result: Option, radix: i64, } -impl Default for CastStringToIntegral64 { +impl Default for CastStringToInt64 { fn default() -> Self { Self { negative: false, @@ -370,7 +370,7 @@ impl Default for CastStringToIntegral64 { } } -impl CastStringToIntegral for CastStringToIntegral64 { +impl CastStringToInt for CastStringToInt64 { fn accumulate( &mut self, eval_mode: EvalMode, @@ -412,7 +412,7 @@ impl CastStringToIntegral for CastStringToIntegral64 { } fn do_cast_string_to_int( - accumulator: &mut dyn CastStringToIntegral, + accumulator: &mut dyn CastStringToInt, str: &str, eval_mode: EvalMode, type_name: &str, From 8a856d78f64118326d301f66330a56bc6777fd26 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 13:49:27 -0600 Subject: [PATCH 14/46] Update core/src/execution/datafusion/expressions/cast.rs Co-authored-by: comphead 
--- core/src/execution/datafusion/expressions/cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index d4503aac5..ab2c2eaf7 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -174,7 +174,7 @@ impl Cast { DataType::Int64 => { cast_utf8_to_integral!(string_array, eval_mode, Int64Type, cast_string_to_i64)? } - _ => unreachable!("invalid integral type in cast from string"), + dt => unreachable!(format!("invalid integral type {dt} in cast from string")), }; Ok(cast_array) } From 8c870590e2dfc922274b2780fd5c4255004d28f1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 23 Apr 2024 15:44:34 -0600 Subject: [PATCH 15/46] addressing feedback --- .../execution/datafusion/expressions/cast.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 4e412efff..7786f1d29 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -69,14 +69,13 @@ pub struct Cast { } macro_rules! cast_utf8_to_int { - ($string_array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ - let mut cast_array = PrimitiveArray::<$array_type>::builder($string_array.len()); - for i in 0..$string_array.len() { - if $string_array.is_null(i) { + ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + let len = $array.len(); + let mut cast_array = PrimitiveArray::<$array_type>::builder(len); + for i in 0..len { + if $array.is_null(i) { cast_array.append_null() - } else if let Some(cast_value) = - $cast_method($string_array.value(i).trim(), $eval_mode)? - { + } else if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? 
{ cast_array.append_value(cast_value); } else { cast_array.append_null() @@ -174,7 +173,10 @@ impl Cast { DataType::Int64 => { cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)? } - dt => unreachable!(format!("invalid integral type {dt} in cast from string")), + dt => unreachable!( + "{}", + format!("invalid integer type {dt} in cast from string") + ), }; Ok(cast_array) } From 1659a4cc9598aa314fb3d036d8a5b5c0e5f81c39 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 25 Apr 2024 08:17:20 -0600 Subject: [PATCH 16/46] bug fix --- core/src/execution/datafusion/expressions/cast.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 7786f1d29..2ad5a7eb7 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -440,8 +440,8 @@ fn do_cast_string_to_int( } // skip + or - - let negative = chars[0] == '-'; - if negative || chars[0] == '+' { + let negative = chars[i] == '-'; + if negative || chars[i] == '+' { i += 1; if i == end { accumulator.reset(); From 5d0730fd23e09ed9c88153c1b0f9879e93d5dca4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 25 Apr 2024 08:27:12 -0600 Subject: [PATCH 17/46] improve tests --- .../org/apache/comet/CometCastSuite.scala | 53 ++++++++++++++++--- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index ae49fc68a..27eeb0324 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -39,7 +39,13 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // but this is likely a reasonable starting point for now private val whitespaceChars = " \t\r\n" - private val numericPattern = "0123456789e+-." 
+ whitespaceChars + /** + * We use these characters to construct strings that potentially represent valid numbers + * such as `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as + * `+e.-d`. + */ + private val numericPattern = "0123456789def+-." + whitespaceChars + private val datePattern = "0123456789/" + whitespaceChars private val timestampPattern = "0123456789/:T" + whitespaceChars @@ -66,23 +72,54 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(testValues, DataTypes.BooleanType) } + val castStringToIntegralInputs = Seq( + "", ".", "+", "-", "+.", "=.", + "-0", "+1", "-1", ".2", "-.2", + "1e1", "1.1d", "1.1f", + Byte.MinValue.toString, + (Byte.MinValue.toShort - 1).toString, + Byte.MaxValue.toString, + (Byte.MaxValue.toShort + 1).toString, + Short.MinValue.toString, + (Short.MinValue.toInt - 1).toString, + Short.MaxValue.toString, + (Short.MaxValue.toInt + 1).toString, + Int.MinValue.toString, + (Int.MinValue.toLong - 1).toString, + Int.MaxValue.toString, + (Int.MaxValue.toLong + 1).toString, + Long.MinValue.toString, + Long.MaxValue.toString, + "-9223372036854775809", // Long.MinValue -1 + "9223372036854775808" // Long.MaxValue + 1 + ) + test("cast string to byte") { - val testValues = - Seq("", ".", "0", "-0", "+1", "-1", ".2", "-.2", "1e1", "127", "128", "-128", "-129") ++ - generateStrings(numericPattern, 8) - castTest(testValues.toDF("a"), DataTypes.ByteType) + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + // fuzz test + castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) } test("cast string to short") { - castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ShortType) + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + // fuzz test + castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) } test("cast string to int") { - 
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType) + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + // fuzz test + castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) } test("cast string to long") { - castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType) + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + // fuzz test + castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) } ignore("cast string to float") { From 074fddd31e690f7ee4995e68cc0b72248cd70910 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 25 Apr 2024 08:51:07 -0600 Subject: [PATCH 18/46] use checked math operations --- .../execution/datafusion/expressions/cast.rs | 48 ++++++++++++++----- .../org/apache/comet/CometCastSuite.scala | 24 +++++++--- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 2ad5a7eb7..07ea41139 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -135,7 +135,7 @@ impl Cast { ) if key_type.as_ref() == &DataType::Int32 && value_type.as_ref() == &DataType::Utf8 => { - // Note that we are unpacking a dictionary-encoded array and then performing + // TODO: we are unpacking a dictionary-encoded array and then performing // the cast. 
We could potentially improve performance here by casting the // dictionary values directly without unpacking the array first, although this // would add more complexity to the code @@ -328,10 +328,18 @@ impl CastStringToInt for CastStringToInt32 { self.reset(); return none_or_err(eval_mode, type_name, str); } - self.result = Some(self.result.unwrap_or(0) * self.radix - digit as i32); - if self.result.unwrap() > 0 { + let v = self.result.unwrap_or(0) * self.radix; + if let Some(x) = v.checked_sub(digit as i32) { + if x > 0 { + self.reset(); + return none_or_err(eval_mode, type_name, str); + } else { + self.result = Some(x); + } + } else { self.reset(); return none_or_err(eval_mode, type_name, str); + } Ok(()) } @@ -346,10 +354,13 @@ impl CastStringToInt for CastStringToInt32 { str: &str, negative: bool, ) -> CometResult<()> { - if self.result.is_some() && !negative { - self.result = Some(-self.result.unwrap()); - if self.result.unwrap() < 0 { - return none_or_err(eval_mode, type_name, str); + if !negative { + if let Some(r) = self.result { + let negated = r.checked_neg().unwrap_or(-1); + if negated < 0 { + return none_or_err(eval_mode, type_name, str); + } + self.result = Some(negated); } } Ok(()) @@ -384,10 +395,18 @@ impl CastStringToInt for CastStringToInt64 { self.reset(); return none_or_err(eval_mode, type_name, str); } - self.result = Some(self.result.unwrap_or(0) * self.radix - digit as i64); - if self.result.unwrap() > 0 { + let v = self.result.unwrap_or(0) * self.radix; + if let Some(x) = v.checked_sub(digit as i64) { + if x > 0 { + self.reset(); + return none_or_err(eval_mode, type_name, str); + } else { + self.result = Some(x); + } + } else { self.reset(); return none_or_err(eval_mode, type_name, str); + } Ok(()) } @@ -403,10 +422,13 @@ impl CastStringToInt for CastStringToInt64 { str: &str, negative: bool, ) -> CometResult<()> { - if self.result.is_some() && !negative { - self.result = Some(-self.result.unwrap()); - if self.result.unwrap() < 0 { - return 
none_or_err(eval_mode, type_name, str); + if !negative { + if let Some(r) = self.result { + let negated = r.checked_neg().unwrap_or(-1); + if negated < 0 { + return none_or_err(eval_mode, type_name, str); + } + self.result = Some(negated); } } Ok(()) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 27eeb0324..59941b0c6 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -40,9 +40,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { private val whitespaceChars = " \t\r\n" /** - * We use these characters to construct strings that potentially represent valid numbers - * such as `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as - * `+e.-d`. + * We use these characters to construct strings that potentially represent valid numbers such as + * `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as `+e.-d`. */ private val numericPattern = "0123456789def+-." + whitespaceChars @@ -73,9 +72,20 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } val castStringToIntegralInputs = Seq( - "", ".", "+", "-", "+.", "=.", - "-0", "+1", "-1", ".2", "-.2", - "1e1", "1.1d", "1.1f", + "", + ".", + "+", + "-", + "+.", + "-.", + "-0", + "+1", + "-1", + ".2", + "-.2", + "1e1", + "1.1d", + "1.1f", Byte.MinValue.toString, (Byte.MinValue.toShort - 1).toString, Byte.MaxValue.toString, @@ -197,7 +207,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { val cometMessage = actual.getMessage .replace("Execution error: ", "") - assert(expected.getMessage == cometMessage) + assert(cometMessage == expected.getMessage) } else { // Spark 3.2 and 3.3 have a different error message format so we can't do a direct // comparison between Spark and Comet. 
From 818ab9cbc9b04887c757536559421d77eedf8eb8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 25 Apr 2024 09:03:04 -0600 Subject: [PATCH 19/46] fix regressions --- core/src/execution/datafusion/expressions/cast.rs | 6 ++---- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 07ea41139..541b51077 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -339,7 +339,6 @@ impl CastStringToInt for CastStringToInt32 { } else { self.reset(); return none_or_err(eval_mode, type_name, str); - } Ok(()) } @@ -406,7 +405,6 @@ impl CastStringToInt for CastStringToInt64 { } else { self.reset(); return none_or_err(eval_mode, type_name, str); - } Ok(()) } @@ -458,7 +456,7 @@ fn do_cast_string_to_int( // check for empty string if i == end { accumulator.reset(); - return Ok(()); + return none_or_err(eval_mode, type_name, str); } // skip + or - @@ -467,7 +465,7 @@ fn do_cast_string_to_int( i += 1; if i == end { accumulator.reset(); - return Ok(()); + return none_or_err(eval_mode, type_name, str); } } diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 59941b0c6..2eaf01b50 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -43,7 +43,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { * We use these characters to construct strings that potentially represent valid numbers such as * `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as `+e.-d`. */ - private val numericPattern = "0123456789def+-." + whitespaceChars + private val numericPattern = "0123456789deEf+-." 
+ whitespaceChars private val datePattern = "0123456789/" + whitespaceChars private val timestampPattern = "0123456789/:T" + whitespaceChars From 4f2539d1541dea531ee27a0e5fbfe31ead150ae5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 25 Apr 2024 09:10:00 -0600 Subject: [PATCH 20/46] improve tests --- .../org/apache/comet/CometCastSuite.scala | 45 +++++++++++-------- .../org/apache/spark/sql/CometTestBase.scala | 7 +-- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 2eaf01b50..e1b70a2fd 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -199,24 +199,33 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // cast() should throw exception on invalid inputs when ansi mode is enabled val df = data.withColumn("converted", col("a").cast(toType)) - val (expected, actual) = checkSparkThrows(df) - - if (CometSparkSessionExtensions.isSpark34Plus) { - // We have to workaround https://github.com/apache/datafusion-comet/issues/293 here by - // removing the "Execution error: " error message prefix that is added by DataFusion - val cometMessage = actual.getMessage - .replace("Execution error: ", "") - - assert(cometMessage == expected.getMessage) - } else { - // Spark 3.2 and 3.3 have a different error message format so we can't do a direct - // comparison between Spark and Comet. 
- // Spark message is in format `invalid input syntax for type TYPE: VALUE` - // Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE` - // We just check that the comet message contains the same invalid value as the Spark message - val sparkInvalidValue = - expected.getMessage.substring(expected.getMessage.indexOf(':') + 2) - assert(actual.getMessage.contains(sparkInvalidValue)) + checkSparkMaybeThrows(df) match { + case (None, None) => + // neither system threw an exception + case (None, Some(e)) => + // Spark succeeded but Comet failed + throw e + case (Some(e), None) => + // Spark failed but Comet succeeded + fail(s"Comet should have failed with ${e.getCause.getMessage}") + case (Some(sparkException), Some(cometException)) => + // both systems threw an exception so we make sure they are the same + val sparkMessage = sparkException.getCause.getMessage + // We have to workaround https://github.com/apache/datafusion-comet/issues/293 here by + // removing the "Execution error: " error message prefix that is added by DataFusion + val cometMessage = cometException.getCause.getMessage + .replace("Execution error: ", "") + if (CometSparkSessionExtensions.isSpark34Plus) { + assert(cometMessage == sparkMessage) + } else { + // Spark 3.2 and 3.3 have a different error message format so we can't do a direct + // comparison between Spark and Comet. 
+ // Spark message is in format `invalid input syntax for type TYPE: VALUE` + // Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE` + // We just check that the comet message contains the same invalid value as the Spark message + val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2) + assert(cometMessage.contains(sparkInvalidValue)) + } } // try_cast() should always return null for invalid inputs diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 85e58824c..8957f5289 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -216,15 +216,16 @@ abstract class CometTestBase checkAnswerWithTol(dfComet, expected, absTol: Double) } - protected def checkSparkThrows(df: => DataFrame): (Throwable, Throwable) = { + protected def checkSparkMaybeThrows( + df: => DataFrame): (Option[Throwable], Option[Throwable]) = { var expected: Option[Throwable] = None withSQLConf(CometConf.COMET_ENABLED.key -> "false") { val dfSpark = Dataset.ofRows(spark, df.logicalPlan) expected = Try(dfSpark.collect()).failed.toOption } val dfComet = Dataset.ofRows(spark, df.logicalPlan) - val actual = Try(dfComet.collect()).failed.get - (expected.get.getCause, actual.getCause) + val actual = Try(dfComet.collect()).failed.toOption + (expected, actual) } private var _spark: SparkSession = _ From de87b5d7b002000d6e7c5cc8ae51f721c631ebe0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 25 Apr 2024 09:29:28 -0600 Subject: [PATCH 21/46] code cleanup --- .../execution/datafusion/expressions/cast.rs | 50 +++++++++++-------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 541b51077..7b125211f 100644 --- 
a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -324,21 +324,25 @@ impl CastStringToInt for CastStringToInt32 { str: &str, digit: u32, ) -> CometResult<()> { - if self.result.is_some() && self.result.unwrap() < i32::MIN / self.radix { - self.reset(); - return none_or_err(eval_mode, type_name, str); + // We are going to process the new digit and accumulate the result. However, before doing + // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), + // then result * 10 will definitely be smaller than minValue, and we can stop + if let Some(r) = self.result { + let stop_value = i32::MIN / self.radix; + if r < stop_value { + self.reset(); + return none_or_err(eval_mode, type_name, str); + } } + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), + // we can just use `result > 0` to check overflow. If result overflows, we should stop let v = self.result.unwrap_or(0) * self.radix; - if let Some(x) = v.checked_sub(digit as i32) { - if x > 0 { + match v.checked_sub(digit as i32) { + Some(x) if x <= 0 => self.result = Some(x), + _ => { self.reset(); return none_or_err(eval_mode, type_name, str); - } else { - self.result = Some(x); } - } else { - self.reset(); - return none_or_err(eval_mode, type_name, str); } Ok(()) } @@ -357,6 +361,7 @@ impl CastStringToInt for CastStringToInt32 { if let Some(r) = self.result { let negated = r.checked_neg().unwrap_or(-1); if negated < 0 { + self.reset(); return none_or_err(eval_mode, type_name, str); } self.result = Some(negated); @@ -390,21 +395,25 @@ impl CastStringToInt for CastStringToInt64 { str: &str, digit: u32, ) -> CometResult<()> { - if self.result.unwrap_or(0) < i64::MIN / self.radix { - self.reset(); - return none_or_err(eval_mode, type_name, str); + // We are going to process the new digit and accumulate the result. 
However, before doing + // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), + // then result * 10 will definitely be smaller than minValue, and we can stop + if let Some(r) = self.result { + let stop_value = i64::MIN / self.radix; + if r < stop_value { + self.reset(); + return none_or_err(eval_mode, type_name, str); + } } + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), + // we can just use `result > 0` to check overflow. If result overflows, we should stop let v = self.result.unwrap_or(0) * self.radix; - if let Some(x) = v.checked_sub(digit as i64) { - if x > 0 { + match v.checked_sub(digit as i64) { + Some(x) if x <= 0 => self.result = Some(x), + _ => { self.reset(); return none_or_err(eval_mode, type_name, str); - } else { - self.result = Some(x); } - } else { - self.reset(); - return none_or_err(eval_mode, type_name, str); } Ok(()) } @@ -424,6 +433,7 @@ impl CastStringToInt for CastStringToInt64 { if let Some(r) = self.result { let negated = r.checked_neg().unwrap_or(-1); if negated < 0 { + self.reset(); return none_or_err(eval_mode, type_name, str); } self.result = Some(negated); From 6cbc543c9cc5107956f9658bfc53f18a3744eb46 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 25 Apr 2024 09:51:15 -0600 Subject: [PATCH 22/46] lint --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index e1b70a2fd..f9e8b257d 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -71,7 +71,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(testValues, DataTypes.BooleanType) } - val castStringToIntegralInputs = Seq( + private val castStringToIntegralInputs: Seq[String] = Seq( "", ".", 
"+", From 1e5867c36dfb5f7aeb2b0ab00910077e01db8091 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 08:00:08 -0600 Subject: [PATCH 23/46] Update spark/src/test/scala/org/apache/comet/CometCastSuite.scala Co-authored-by: Liang-Chi Hsieh --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index f9e8b257d..ff77834c3 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -113,7 +113,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("cast string to short") { // test with hand-picked values - castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType) // fuzz test castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) } From b28371133a20342da2cf6f837e6b353445e2a31b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 08:00:34 -0600 Subject: [PATCH 24/46] Update spark/src/test/scala/org/apache/comet/CometCastSuite.scala Co-authored-by: Liang-Chi Hsieh --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index ff77834c3..a03cb58bc 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -115,7 +115,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // test with hand-picked values castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType) // fuzz test - castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) + 
castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ShortType) } test("cast string to int") { From 9a5123990f164b4a3deb09365de99aafe1986b73 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 08:01:10 -0600 Subject: [PATCH 25/46] Update spark/src/test/scala/org/apache/comet/CometCastSuite.scala Co-authored-by: Liang-Chi Hsieh --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index a03cb58bc..db56f1ccd 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -120,7 +120,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("cast string to int") { // test with hand-picked values - castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType) // fuzz test castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) } From 0131005c14c8122dc603fcfe9b42293e0e167e68 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 08:01:37 -0600 Subject: [PATCH 26/46] Update spark/src/test/scala/org/apache/comet/CometCastSuite.scala Co-authored-by: Liang-Chi Hsieh --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index db56f1ccd..6f0829649 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -127,9 +127,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("cast string to long") { // test with hand-picked values - 
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType) // fuzz test - castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) + castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.LongType) } ignore("cast string to float") { From d9b1fb6066a1bc24442d6eca664e4f59d0228a26 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 08:02:03 -0600 Subject: [PATCH 27/46] Update spark/src/test/scala/org/apache/comet/CometCastSuite.scala Co-authored-by: Liang-Chi Hsieh --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 6f0829649..b1fea2f73 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -122,7 +122,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // test with hand-picked values castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType) // fuzz test - castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ByteType) + castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.IntegerType) } test("cast string to long") { From 317ea08099b345cad9c1399b9f9a3d9bd09dfd2c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 11:16:01 -0600 Subject: [PATCH 28/46] simplify the code by removing the i32/i64 specialization via traits and just implement two versions of the function --- .../execution/datafusion/expressions/cast.rs | 261 +++++++----------- 1 file changed, 104 insertions(+), 157 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 7b125211f..5e4296c14 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++
b/core/src/execution/datafusion/expressions/cast.rs @@ -249,15 +249,11 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult } fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { - let mut accum = CastStringToInt32::default(); - do_cast_string_to_int(&mut accum, str, eval_mode, "INT")?; - Ok(accum.result) + Ok(do_cast_string_to_i32(str, eval_mode, "INT", i32::MIN)?.map(|n| n as i32)) } fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { - let mut accum = CastStringToInt64::default(); - do_cast_string_to_int(&mut accum, str, eval_mode, "BIGINT")?; - Ok(accum.result) + do_cast_string_to_i64(str, eval_mode, "BIGINT", i64::MIN) } fn cast_string_to_int_with_range_check( @@ -267,188 +263,116 @@ fn cast_string_to_int_with_range_check( min: i32, max: i32, ) -> CometResult> { - let mut accum = CastStringToInt32::default(); - do_cast_string_to_int(&mut accum, str, eval_mode, type_name)?; - match accum.result { + match do_cast_string_to_i32(str, eval_mode, type_name, i32::MIN)? { None => Ok(None), - Some(v) if v >= min && v <= max => Ok(Some(v)), + Some(v) if v >= min && v <= max => Ok(Some(v as i32)), _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), _ => Ok(None), } } -/// We support parsing strings to i32 and i64 to match Spark's logic. Support for i8 and i16 is -/// implemented by first parsing as i32 and then downcasting. The CastStringToInt trait is -/// introduced so that we can have the parsing logic delegate either to an i32 or i64 accumulator -/// and avoid the need to use macros here. 
-trait CastStringToInt { - fn accumulate( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - digit: u32, - ) -> CometResult<()>; - - fn reset(&mut self); - - fn finish( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - negative: bool, - ) -> CometResult<()>; -} -struct CastStringToInt32 { - negative: bool, - result: Option, - radix: i32, -} +fn do_cast_string_to_i32( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min_value: i32, +) -> CometResult> { + let chars: Vec = str.chars().collect(); + let mut i = 0; + let mut end = chars.len(); -impl Default for CastStringToInt32 { - fn default() -> Self { - Self { - negative: false, - result: Some(0), - radix: 10, - } + // skip leading whitespace + while i < end && chars[i].is_whitespace() { + i += 1; } -} -impl CastStringToInt for CastStringToInt32 { - fn accumulate( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - digit: u32, - ) -> CometResult<()> { - // We are going to process the new digit and accumulate the result. However, before doing - // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), - // then result * 10 will definitely be smaller than minValue, and we can stop - if let Some(r) = self.result { - let stop_value = i32::MIN / self.radix; - if r < stop_value { - self.reset(); - return none_or_err(eval_mode, type_name, str); - } - } - // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), - // we can just use `result > 0` to check overflow. 
If result overflows, we should stop - let v = self.result.unwrap_or(0) * self.radix; - match v.checked_sub(digit as i32) { - Some(x) if x <= 0 => self.result = Some(x), - _ => { - self.reset(); - return none_or_err(eval_mode, type_name, str); - } - } - Ok(()) + // skip trailing whitespace + while end > i && chars[end - 1].is_whitespace() { + end -= 1; } - fn reset(&mut self) { - self.result = None; + + // check for empty string + if i == end { + return none_or_err(eval_mode, type_name, str); } - fn finish( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - negative: bool, - ) -> CometResult<()> { - if !negative { - if let Some(r) = self.result { - let negated = r.checked_neg().unwrap_or(-1); - if negated < 0 { - self.reset(); - return none_or_err(eval_mode, type_name, str); - } - self.result = Some(negated); - } + // skip + or - + let negative = chars[i] == '-'; + if negative || chars[i] == '+' { + i += 1; + if i == end { + return none_or_err(eval_mode, type_name, str); } - Ok(()) } -} -struct CastStringToInt64 { - negative: bool, - result: Option, - radix: i64, -} + let mut result = 0; + let radix = 10; + let stop_value = min_value / radix; + while i < end { + let b = chars[i]; + i += 1; -impl Default for CastStringToInt64 { - fn default() -> Self { - Self { - negative: false, - result: Some(0), - radix: 10, + if b == '.' && eval_mode == EvalMode::Legacy { + // truncate decimal in legacy mode + break; } - } -} -impl CastStringToInt for CastStringToInt64 { - fn accumulate( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - digit: u32, - ) -> CometResult<()> { + let digit = if b.is_ascii_digit() { + (b as u32) - ('0' as u32) + } else { + return none_or_err(eval_mode, type_name, str); + }; + // We are going to process the new digit and accumulate the result. 
However, before doing // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), // then result * 10 will definitely be smaller than minValue, and we can stop - if let Some(r) = self.result { - let stop_value = i64::MIN / self.radix; - if r < stop_value { - self.reset(); - return none_or_err(eval_mode, type_name, str); - } + if result < stop_value { + return none_or_err(eval_mode, type_name, str); } + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), // we can just use `result > 0` to check overflow. If result overflows, we should stop - let v = self.result.unwrap_or(0) * self.radix; - match v.checked_sub(digit as i64) { - Some(x) if x <= 0 => self.result = Some(x), + let v = result * radix; + match v.checked_sub(digit as i32) { + Some(x) if x <= 0 => result = x, _ => { - self.reset(); return none_or_err(eval_mode, type_name, str); } } - Ok(()) } - fn reset(&mut self) { - self.result = None; + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well-formed. 
+ while i < end { + let b = chars[i]; + if !b.is_ascii_digit() { + return none_or_err(eval_mode, type_name, str); + } + i += 1; } - fn finish( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - negative: bool, - ) -> CometResult<()> { - if !negative { - if let Some(r) = self.result { - let negated = r.checked_neg().unwrap_or(-1); - if negated < 0 { - self.reset(); - return none_or_err(eval_mode, type_name, str); - } - self.result = Some(negated); + if !negative { + if let Some(x) = result.checked_neg() { + if x < 0 { + return none_or_err(eval_mode, type_name, str); } + result = x; + } else { + return none_or_err(eval_mode, type_name, str); } - Ok(()) } + + Ok(Some(result)) } -fn do_cast_string_to_int( - accumulator: &mut dyn CastStringToInt, +/// This is a copy of do_cast_string_to_i32 but with the type changed to i64 +fn do_cast_string_to_i64( str: &str, eval_mode: EvalMode, type_name: &str, -) -> CometResult<()> { + min_value: i64, +) -> CometResult> { let chars: Vec = str.chars().collect(); let mut i = 0; let mut end = chars.len(); @@ -465,7 +389,6 @@ fn do_cast_string_to_int( // check for empty string if i == end { - accumulator.reset(); return none_or_err(eval_mode, type_name, str); } @@ -474,11 +397,13 @@ fn do_cast_string_to_int( if negative || chars[i] == '+' { i += 1; if i == end { - accumulator.reset(); return none_or_err(eval_mode, type_name, str); } } + let mut result = 0; + let radix = 10; + let stop_value = min_value / radix; while i < end { let b = chars[i]; i += 1; @@ -491,11 +416,25 @@ fn do_cast_string_to_int( let digit = if b.is_ascii_digit() { (b as u32) - ('0' as u32) } else { - accumulator.reset(); return none_or_err(eval_mode, type_name, str); }; - accumulator.accumulate(eval_mode, type_name, str, digit)?; + // We are going to process the new digit and accumulate the result. 
However, before doing + // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), + // then result * 10 will definitely be smaller than minValue, and we can stop + if result < stop_value { + return none_or_err(eval_mode, type_name, str); + } + + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), + // we can just use `result > 0` to check overflow. If result overflows, we should stop + let v = result * radix; + match v.checked_sub(digit as i64) { + Some(x) if x <= 0 => result = x, + _ => { + return none_or_err(eval_mode, type_name, str); + } + } } // This is the case when we've encountered a decimal separator. The fractional @@ -504,22 +443,30 @@ fn do_cast_string_to_int( while i < end { let b = chars[i]; if !b.is_ascii_digit() { - accumulator.reset(); return none_or_err(eval_mode, type_name, str); } i += 1; } - accumulator.finish(eval_mode, type_name, str, negative)?; + if !negative { + if let Some(x) = result.checked_neg() { + if x < 0 { + return none_or_err(eval_mode, type_name, str); + } + result = x; + } else { + return none_or_err(eval_mode, type_name, str); + } + } - Ok(()) + Ok(Some(result)) } /// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode -fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<()> { +fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult> { match eval_mode { EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), - _ => Ok(()), + _ => Ok(None), } } From efa7747880ce9596b29dea9c51328d1bd77338b5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 11:28:01 -0600 Subject: [PATCH 29/46] reimplement using generics --- .../execution/datafusion/expressions/cast.rs | 122 +++--------------- 1 file changed, 15 insertions(+), 107 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 
5e4296c14..01c34ae24 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -37,6 +37,7 @@ use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; +use num::{traits::CheckedNeg, CheckedSub, Integer, Num}; use crate::execution::datafusion::expressions::utils::{ array_with_timezone, down_cast_any_ref, spark_cast, @@ -249,11 +250,11 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult } fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { - Ok(do_cast_string_to_i32(str, eval_mode, "INT", i32::MIN)?.map(|n| n as i32)) + Ok(do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN)?.map(|n| n as i32)) } fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { - do_cast_string_to_i64(str, eval_mode, "BIGINT", i64::MIN) + do_cast_string_to_int::(str, eval_mode, "BIGINT", i64::MIN) } fn cast_string_to_int_with_range_check( @@ -263,7 +264,7 @@ fn cast_string_to_int_with_range_check( min: i32, max: i32, ) -> CometResult> { - match do_cast_string_to_i32(str, eval_mode, type_name, i32::MIN)? { + match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? 
{ None => Ok(None), Some(v) if v >= min && v <= max => Ok(Some(v as i32)), _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), @@ -271,108 +272,14 @@ fn cast_string_to_int_with_range_check( } } -fn do_cast_string_to_i32( +fn do_cast_string_to_int< + T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From + Copy, +>( str: &str, eval_mode: EvalMode, type_name: &str, - min_value: i32, -) -> CometResult> { - let chars: Vec = str.chars().collect(); - let mut i = 0; - let mut end = chars.len(); - - // skip leading whitespace - while i < end && chars[i].is_whitespace() { - i += 1; - } - - // skip trailing whitespace - while end > i && chars[end - 1].is_whitespace() { - end -= 1; - } - - // check for empty string - if i == end { - return none_or_err(eval_mode, type_name, str); - } - - // skip + or - - let negative = chars[i] == '-'; - if negative || chars[i] == '+' { - i += 1; - if i == end { - return none_or_err(eval_mode, type_name, str); - } - } - - let mut result = 0; - let radix = 10; - let stop_value = min_value / radix; - while i < end { - let b = chars[i]; - i += 1; - - if b == '.' && eval_mode == EvalMode::Legacy { - // truncate decimal in legacy mode - break; - } - - let digit = if b.is_ascii_digit() { - (b as u32) - ('0' as u32) - } else { - return none_or_err(eval_mode, type_name, str); - }; - - // We are going to process the new digit and accumulate the result. However, before doing - // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), - // then result * 10 will definitely be smaller than minValue, and we can stop - if result < stop_value { - return none_or_err(eval_mode, type_name, str); - } - - // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), - // we can just use `result > 0` to check overflow. 
If result overflows, we should stop - let v = result * radix; - match v.checked_sub(digit as i32) { - Some(x) if x <= 0 => result = x, - _ => { - return none_or_err(eval_mode, type_name, str); - } - } - } - - // This is the case when we've encountered a decimal separator. The fractional - // part will not change the number, but we will verify that the fractional part - // is well-formed. - while i < end { - let b = chars[i]; - if !b.is_ascii_digit() { - return none_or_err(eval_mode, type_name, str); - } - i += 1; - } - - if !negative { - if let Some(x) = result.checked_neg() { - if x < 0 { - return none_or_err(eval_mode, type_name, str); - } - result = x; - } else { - return none_or_err(eval_mode, type_name, str); - } - } - - Ok(Some(result)) -} - -/// This is a copy of do_cast_string_to_i32 but with the type changed to i64 -fn do_cast_string_to_i64( - str: &str, - eval_mode: EvalMode, - type_name: &str, - min_value: i64, -) -> CometResult> { + min_value: T, +) -> CometResult> { let chars: Vec = str.chars().collect(); let mut i = 0; let mut end = chars.len(); @@ -401,8 +308,8 @@ fn do_cast_string_to_i64( } } - let mut result = 0; - let radix = 10; + let mut result: T = T::zero(); + let radix = T::from(10); let stop_value = min_value / radix; while i < end { let b = chars[i]; @@ -429,8 +336,9 @@ fn do_cast_string_to_i64( // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), // we can just use `result > 0` to check overflow. 
If result overflows, we should stop let v = result * radix; - match v.checked_sub(digit as i64) { - Some(x) if x <= 0 => result = x, + let digit = (digit as i32).into(); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, _ => { return none_or_err(eval_mode, type_name, str); } @@ -450,7 +358,7 @@ fn do_cast_string_to_i64( if !negative { if let Some(x) = result.checked_neg() { - if x < 0 { + if x < T::zero() { return none_or_err(eval_mode, type_name, str); } result = x; From 344c594603986d70844521e5424a665d91810fc1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 11:38:44 -0600 Subject: [PATCH 30/46] clippy --- core/src/execution/datafusion/expressions/cast.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 01c34ae24..5264d48d9 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -250,7 +250,7 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult } fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { - Ok(do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN)?.map(|n| n as i32)) + Ok(do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN)?) } fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { @@ -266,7 +266,7 @@ fn cast_string_to_int_with_range_check( ) -> CometResult> { match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? 
{ None => Ok(None), - Some(v) if v >= min && v <= max => Ok(Some(v as i32)), + Some(v) if v >= min && v <= max => Ok(Some(v)), _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), _ => Ok(None), } From 83ad6a2695787437779538518c0a4ecceb379f95 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 11:46:38 -0600 Subject: [PATCH 31/46] add LargeUtf8 handling --- .../execution/datafusion/expressions/cast.rs | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 5264d48d9..77279f128 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -129,7 +129,11 @@ impl Cast { ( DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => Self::cast_string_to_int(to_type, &array, self.eval_mode)?, + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode)?, + ( + DataType::LargeUtf8, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode)?, ( DataType::Dictionary(key_type, value_type), DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, @@ -141,7 +145,7 @@ impl Cast { // dictionary values directly without unpacking the array first, although this // would add more complexity to the code let unpacked_array = Self::unpack_dict_string_array::(&array)?; - Self::cast_string_to_int(to_type, &unpacked_array, self.eval_mode)? + Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? 
} _ => { // when we have no Spark-specific casting we delegate to DataFusion @@ -151,14 +155,14 @@ impl Cast { Ok(spark_cast(cast_result, from_type, to_type)) } - fn cast_string_to_int( + fn cast_string_to_int( to_type: &DataType, array: &ArrayRef, eval_mode: EvalMode, ) -> CometResult { let string_array = array .as_any() - .downcast_ref::>() + .downcast_ref::>() .expect("cast_string_to_int expected a string array"); let cast_array: ArrayRef = match to_type { @@ -250,7 +254,12 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult } fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { - Ok(do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN)?) + Ok(do_cast_string_to_int::( + str, + eval_mode, + "INT", + i32::MIN, + )?) } fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { From 6cea67ff3fa55b1d42dfa8e6ebd3344744146195 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 11:52:10 -0600 Subject: [PATCH 32/46] add LargeUtf8 handling --- core/src/execution/datafusion/expressions/cast.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 77279f128..8514f6242 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -138,14 +138,23 @@ impl Cast { DataType::Dictionary(key_type, value_type), DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, ) if key_type.as_ref() == &DataType::Int32 - && value_type.as_ref() == &DataType::Utf8 => + && (value_type.as_ref() == &DataType::Utf8 + || value_type.as_ref() == &DataType::LargeUtf8) => { // TODO: we are unpacking a dictionary-encoded array and then performing // the cast. 
We could potentially improve performance here by casting the // dictionary values directly without unpacking the array first, although this // would add more complexity to the code let unpacked_array = Self::unpack_dict_string_array::(&array)?; - Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? + match value_type.as_ref() { + DataType::Utf8 => { + Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? + } + DataType::LargeUtf8 => { + Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? + } + _ => unreachable!("invalid value type for dictionary-encoded string array"), + } } _ => { // when we have no Spark-specific casting we delegate to DataFusion From ff7056748609783c37c888c70e5df2eb77dde3aa Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 12:06:17 -0600 Subject: [PATCH 33/46] clippy --- core/src/execution/datafusion/expressions/cast.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 8514f6242..5f88cefbb 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -263,12 +263,7 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult } fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { - Ok(do_cast_string_to_int::( - str, - eval_mode, - "INT", - i32::MIN, - )?) 
+ do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN) } fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { From b7b6bfb2e4370d96272949d892d7bbdd8bf5c34c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 12:12:52 -0600 Subject: [PATCH 34/46] address feedback --- .../execution/datafusion/expressions/cast.rs | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 5f88cefbb..a1b3686e5 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -24,14 +24,13 @@ use std::{ use crate::errors::{CometError, CometResult}; use arrow::{ - compute::{cast_with_options, take, CastOptions}, + compute::{cast_with_options, CastOptions}, record_batch::RecordBatch, util::display::FormatOptions, }; use arrow_array::{ - types::{ArrowDictionaryKeyType, Int16Type, Int32Type, Int64Type, Int8Type}, - Array, ArrayRef, BooleanArray, DictionaryArray, GenericStringArray, OffsetSizeTrait, - PrimitiveArray, + types::{Int16Type, Int32Type, Int64Type, Int8Type}, + Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; @@ -145,7 +144,7 @@ impl Cast { // the cast. We could potentially improve performance here by casting the // dictionary values directly without unpacking the array first, although this // would add more complexity to the code - let unpacked_array = Self::unpack_dict_string_array::(&array)?; + let unpacked_array = cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; match value_type.as_ref() { DataType::Utf8 => { Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? 
@@ -195,18 +194,6 @@ impl Cast { Ok(cast_array) } - fn unpack_dict_string_array( - array: &ArrayRef, - ) -> DataFusionResult { - let dict_array = array - .as_any() - .downcast_ref::>() - .expect("DictionaryArray"); - - let unpacked_array = take(dict_array.values().as_ref(), dict_array.keys(), None)?; - Ok(unpacked_array) - } - fn spark_cast_utf8_to_boolean( from: &dyn Array, eval_mode: EvalMode, From 2be0f9df9c371fd3908147d826732111aaf8a1a5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 27 Apr 2024 12:36:43 -0600 Subject: [PATCH 35/46] fix --- core/src/execution/datafusion/expressions/cast.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index a1b3686e5..43f878d98 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -144,12 +144,15 @@ impl Cast { // the cast. We could potentially improve performance here by casting the // dictionary values directly without unpacking the array first, although this // would add more complexity to the code - let unpacked_array = cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; match value_type.as_ref() { DataType::Utf8 => { + let unpacked_array = + cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? } DataType::LargeUtf8 => { + let unpacked_array = + cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? 
} _ => unreachable!("invalid value type for dictionary-encoded string array"), From 659b191f1353246a9a8e33d8d693d5f595821b8c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 28 Apr 2024 16:31:25 -0600 Subject: [PATCH 36/46] Add criterion benchmark --- core/Cargo.toml | 3 + core/benches/cast.rs | 85 ++++++++++++++++++++++++++++ core/src/execution/datafusion/mod.rs | 2 +- 3 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 core/benches/cast.rs diff --git a/core/Cargo.toml b/core/Cargo.toml index 5d1604952..d3416ce29 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -118,3 +118,6 @@ harness = false name = "row_columnar" harness = false +[[bench]] +name = "cast" +harness = false diff --git a/core/benches/cast.rs b/core/benches/cast.rs new file mode 100644 index 000000000..281fe82e2 --- /dev/null +++ b/core/benches/cast.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_array::{builder::StringBuilder, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use comet::execution::datafusion::expressions::cast::{Cast, EvalMode}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else if i % 2 == 0 { + b.append_value(format!("{}", rand::random::())); + } else { + b.append_value(format!("{}", rand::random::())); + } + } + let array = b.finish(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap(); + let expr = Arc::new(Column::new("a", 0)); + let timezone = "".to_string(); + let cast_string_to_i8 = Cast::new( + expr.clone(), + DataType::Int8, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i16 = Cast::new( + expr.clone(), + DataType::Int16, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i32 = Cast::new( + expr.clone(), + DataType::Int32, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone); + + let mut group = c.benchmark_group("cast"); + group.bench_function("cast_string_to_i8", |b| { + b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i16", |b| { + b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i32", |b| { + b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i64", |b| { + b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap()); + }); +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! 
{ + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/core/src/execution/datafusion/mod.rs b/core/src/execution/datafusion/mod.rs index c464eeed0..76f0b1c76 100644 --- a/core/src/execution/datafusion/mod.rs +++ b/core/src/execution/datafusion/mod.rs @@ -17,7 +17,7 @@ //! Native execution through DataFusion -mod expressions; +pub mod expressions; mod operators; pub mod planner; pub(crate) mod shuffle_writer; From 52de3f12ef63213e58e2f533e4a008ae3a0c7437 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 28 Apr 2024 16:40:18 -0600 Subject: [PATCH 37/46] optimize cast implementation to avoid copying string to chars (~60% faster) --- .../execution/datafusion/expressions/cast.rs | 143 ++++++++++-------- 1 file changed, 83 insertions(+), 60 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 43f878d98..861b698a9 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -275,6 +275,14 @@ fn cast_string_to_int_with_range_check( } } +#[derive(PartialEq)] +enum State { + SkipLeadingWhiteSpace, + SkipTrailingWhiteSpace, + ParseSignAndDigits, + ParseFractionalDigits, +} + fn do_cast_string_to_int< T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From + Copy, >( @@ -283,88 +291,103 @@ fn do_cast_string_to_int< type_name: &str, min_value: T, ) -> CometResult> { - let chars: Vec = str.chars().collect(); - let mut i = 0; - let mut end = chars.len(); - - // skip leading whitespace - while i < end && chars[i].is_whitespace() { - i += 1; - } - - // skip trailing whitespace - while end > i && chars[end - 1].is_whitespace() { - end -= 1; - } - - // check for empty string - if i == end { + let len = str.len(); + if len == 0 { return none_or_err(eval_mode, type_name, str); } - // skip + or - - let negative = chars[i] == '-'; - if negative || chars[i] == '+' { - i += 1; - 
if i == end { - return none_or_err(eval_mode, type_name, str); - } - } - let mut result: T = T::zero(); + let mut negative = false; let radix = T::from(10); let stop_value = min_value / radix; - while i < end { - let b = chars[i]; - i += 1; - - if b == '.' && eval_mode == EvalMode::Legacy { - // truncate decimal in legacy mode - break; + let mut state = State::SkipLeadingWhiteSpace; + let mut parsed_sign = false; + + for (i, ch) in str.char_indices() { + // skip leading whitespace + if state == State::SkipLeadingWhiteSpace { + if ch.is_whitespace() { + // consume this char + continue; + } + // change state and fall through to next section + state = State::ParseSignAndDigits; } - let digit = if b.is_ascii_digit() { - (b as u32) - ('0' as u32) - } else { - return none_or_err(eval_mode, type_name, str); - }; + if state == State::ParseSignAndDigits { + if !parsed_sign { + negative = ch == '-'; + let positive = ch == '+'; + parsed_sign = true; + if negative || positive { + // consume this char + continue; + } else if i + 1 == len { + return none_or_err(eval_mode, type_name, str); + } + } - // We are going to process the new digit and accumulate the result. However, before doing - // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), - // then result * 10 will definitely be smaller than minValue, and we can stop - if result < stop_value { - return none_or_err(eval_mode, type_name, str); + if ch == '.' && eval_mode == EvalMode::Legacy { + // truncate decimal in legacy mode + state = State::ParseFractionalDigits; + // consume this char + continue; + } + + let digit = if ch.is_ascii_digit() { + (ch as u32) - ('0' as u32) + } else { + return none_or_err(eval_mode, type_name, str); + }; + + // We are going to process the new digit and accumulate the result. 
However, before + // doing this, if the result is already smaller than the + // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be + // smaller than minValue, and we can stop + if result < stop_value { + return none_or_err(eval_mode, type_name, str); + } + + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / + // radix), we can just use `result > 0` to check overflow. If result + // overflows, we should stop + let v = result * radix; + let digit = (digit as i32).into(); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { + return none_or_err(eval_mode, type_name, str); + } + } } - // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), - // we can just use `result > 0` to check overflow. If result overflows, we should stop - let v = result * radix; - let digit = (digit as i32).into(); - match v.checked_sub(&digit) { - Some(x) if x <= T::zero() => result = x, - _ => { + if state == State::ParseFractionalDigits { + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well-formed. + if ch.is_whitespace() { + // finished parsing fractional digits, now need to skip trailing whitespace + state = State::SkipTrailingWhiteSpace; + // consume this char + continue; + } + if !ch.is_ascii_digit() { return none_or_err(eval_mode, type_name, str); } } - } - // This is the case when we've encountered a decimal separator. The fractional - // part will not change the number, but we will verify that the fractional part - // is well-formed. 
- while i < end { - let b = chars[i]; - if !b.is_ascii_digit() { + // skip trailing whitespace + if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() { return none_or_err(eval_mode, type_name, str); } - i += 1; } if !negative { - if let Some(x) = result.checked_neg() { - if x < T::zero() { + if let Some(neg) = result.checked_neg() { + if neg < T::zero() { return none_or_err(eval_mode, type_name, str); } - result = x; + result = neg; } else { return none_or_err(eval_mode, type_name, str); } From 7196ea1ea8ea92bf25201362eea019e6d5bc36cc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 28 Apr 2024 16:59:56 -0600 Subject: [PATCH 38/46] fix regression --- .../execution/datafusion/expressions/cast.rs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 861b698a9..f6edb1cdd 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -320,18 +320,23 @@ fn do_cast_string_to_int< let positive = ch == '+'; parsed_sign = true; if negative || positive { + if i + 1 == len { + // input string is just "+" or "-" + return none_or_err(eval_mode, type_name, str); + } // consume this char continue; - } else if i + 1 == len { - return none_or_err(eval_mode, type_name, str); } } - if ch == '.' && eval_mode == EvalMode::Legacy { - // truncate decimal in legacy mode - state = State::ParseFractionalDigits; - // consume this char - continue; + if ch == '.' 
{ + if eval_mode == EvalMode::Legacy { + // truncate decimal in legacy mode + state = State::ParseFractionalDigits; + continue; + } else { + return none_or_err(eval_mode, type_name, str); + } } let digit = if ch.is_ascii_digit() { From b9ca5e2e270144d0ad5d7ae945aad488a09930dd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 29 Apr 2024 17:01:06 -0600 Subject: [PATCH 39/46] Update core/src/execution/datafusion/expressions/cast.rs Co-authored-by: comphead --- core/src/execution/datafusion/expressions/cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index f6edb1cdd..29d8181c5 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -292,7 +292,7 @@ fn do_cast_string_to_int< min_value: T, ) -> CometResult> { let len = str.len(); - if len == 0 { + if str.isEmpty() { return none_or_err(eval_mode, type_name, str); } From 4ab7c3713301fb8ea056f8eee6fd899c60ebba11 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 29 Apr 2024 17:04:48 -0600 Subject: [PATCH 40/46] fix error in suggested code change --- core/src/execution/datafusion/expressions/cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 29d8181c5..f8b9cdc19 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -292,7 +292,7 @@ fn do_cast_string_to_int< min_value: T, ) -> CometResult> { let len = str.len(); - if str.isEmpty() { + if str.is_empty() { return none_or_err(eval_mode, type_name, str); } From 9906c91e124cd4d95ddbf6553b81c1ebd42a2714 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 29 Apr 2024 17:13:19 -0600 Subject: [PATCH 41/46] address feedback --- core/src/execution/datafusion/expressions/cast.rs | 4 +++- 1 file 
changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index f8b9cdc19..03e24d655 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -353,7 +353,7 @@ fn do_cast_string_to_int< return none_or_err(eval_mode, type_name, str); } - // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / + // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE / // radix), we can just use `result > 0` to check overflow. If result // overflows, we should stop let v = result * radix; @@ -402,6 +402,7 @@ fn do_cast_string_to_int< } /// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode +#[inline] fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult> { match eval_mode { EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), @@ -409,6 +410,7 @@ fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResul } } +#[inline] fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError { CometError::CastInvalidValue { value: value.to_string(), From a8fada1bcdbca70f5dccfe5547522af9da7da65e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 29 Apr 2024 17:15:11 -0600 Subject: [PATCH 42/46] Update core/src/execution/datafusion/expressions/cast.rs Co-authored-by: comphead --- core/src/execution/datafusion/expressions/cast.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 03e24d655..2d6035e5a 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -155,7 +155,10 @@ impl Cast { cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; Self::cast_string_to_int::(to_type, 
&unpacked_array, self.eval_mode)? } - _ => unreachable!("invalid value type for dictionary-encoded string array"), + dt => unreachable!( + "{}", + format!("invalid value type {dt} for dictionary-encoded string array") + ), } } _ => { From 68a5f5da042f854a2017bbf2f95e75231d97aa4d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 29 Apr 2024 17:15:44 -0600 Subject: [PATCH 43/46] cargo fmt --- core/src/execution/datafusion/expressions/cast.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 2d6035e5a..29fb60b6a 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -155,10 +155,10 @@ impl Cast { cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? } - dt => unreachable!( - "{}", - format!("invalid value type {dt} for dictionary-encoded string array") - ), + dt => unreachable!( + "{}", + format!("invalid value type {dt} for dictionary-encoded string array") + ), } } _ => { From e5118db8521629cf66aa90cb2ba41fa7063b440b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 30 Apr 2024 08:14:56 -0600 Subject: [PATCH 44/46] add comments with references to Spark code that this code is based on --- core/src/execution/datafusion/expressions/cast.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 29fb60b6a..f5839fd4c 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -233,6 +233,7 @@ impl Cast { } } +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> { Ok(cast_string_to_int_with_range_check( str, @@ -244,6 +245,7 @@ fn 
cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> .map(|v| v as i8)) } +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult> { Ok(cast_string_to_int_with_range_check( str, @@ -255,10 +257,12 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult .map(|v| v as i16)) } +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper) fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN) } +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper) fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { do_cast_string_to_int::(str, eval_mode, "BIGINT", i64::MIN) } @@ -286,6 +290,9 @@ enum State { ParseFractionalDigits, } +/// Equivalent to +/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) +/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) fn do_cast_string_to_int< T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From + Copy, >( From 3511d21762d030de97d58ea26fff765553748ee1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 30 Apr 2024 22:16:03 -0600 Subject: [PATCH 45/46] fix merge conflict --- spark/src/test/scala/org/apache/comet/CometCastSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 549f5115f..11aed1b70 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -489,7 +489,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // test with hand-picked values castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType) // fuzz test - 
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ShortType) + castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType) } From c42e326aa5732eb96679e624e3c0dfd6c7a2539f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 30 Apr 2024 22:18:45 -0600 Subject: [PATCH 46/46] fix merge conflict --- .../test/scala/org/apache/comet/CometCastSuite.scala | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 11aed1b70..1bddedde9 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -471,29 +471,28 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { "9223372036854775808" // Long.MaxValue + 1 ) - ignore("cast StringType to ByteType") { + test("cast StringType to ByteType") { // test with hand-picked values castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) // fuzz test - castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ByteType) + castTest(generateStrings(numericPattern, 4).toDF("a"), DataTypes.ByteType) } - ignore("cast StringType to ShortType") { + test("cast StringType to ShortType") { // test with hand-picked values castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType) // fuzz test castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ShortType) } - ignore("cast StringType to IntegerType") { + test("cast StringType to IntegerType") { // test with hand-picked values castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType) // fuzz test castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType) } - - ignore("cast StringType to LongType") { + test("cast StringType to LongType") { // test with hand-picked values castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType) // fuzz test