diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 35ab23a76..f68732fb1 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -82,7 +82,7 @@ macro_rules! cast_utf8_to_int { for i in 0..len { if $array.is_null(i) { cast_array.append_null() - } else if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? { + } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { cast_array.append_value(cast_value); } else { cast_array.append_null() @@ -1010,14 +1010,6 @@ fn cast_string_to_int_with_range_check( } } -#[derive(PartialEq)] -enum State { - SkipLeadingWhiteSpace, - SkipTrailingWhiteSpace, - ParseSignAndDigits, - ParseFractionalDigits, -} - /// Equivalent to /// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) /// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) @@ -1029,34 +1021,22 @@ fn do_cast_string_to_int< type_name: &str, min_value: T, ) -> CometResult> { - let len = str.len(); - if str.is_empty() { + let trimmed_str = str.trim(); + if trimmed_str.is_empty() { return none_or_err(eval_mode, type_name, str); } - + let len = trimmed_str.len(); let mut result: T = T::zero(); let mut negative = false; let radix = T::from(10); let stop_value = min_value / radix; - let mut state = State::SkipLeadingWhiteSpace; - let mut parsed_sign = false; - - for (i, ch) in str.char_indices() { - // skip leading whitespace - if state == State::SkipLeadingWhiteSpace { - if ch.is_whitespace() { - // consume this char - continue; - } - // change state and fall through to next section - state = State::ParseSignAndDigits; - } + let mut parse_sign_and_digits = true; - if state == State::ParseSignAndDigits { - if !parsed_sign { + for (i, ch) in trimmed_str.char_indices() { + if parse_sign_and_digits { + if i == 0 { negative = ch == '-'; let positive = ch == '+'; - parsed_sign = true; if negative || positive { if i + 1 == len { // input string is just "+" or "-" @@ -1070,7 +1050,7 @@ fn do_cast_string_to_int< if ch == '.' { if eval_mode == EvalMode::Legacy { // truncate decimal in legacy mode - state = State::ParseFractionalDigits; + parse_sign_and_digits = false; continue; } else { return none_or_err(eval_mode, type_name, str); @@ -1102,27 +1082,12 @@ fn do_cast_string_to_int< return none_or_err(eval_mode, type_name, str); } } - } - - if state == State::ParseFractionalDigits { - // This is the case when we've encountered a decimal separator. The fractional - // part will not change the number, but we will verify that the fractional part - // is well-formed. - if ch.is_whitespace() { - // finished parsing fractional digits, now need to skip trailing whitespace - state = State::SkipTrailingWhiteSpace; - // consume this char - continue; - } + } else { + // make sure fractional digits are valid digits but ignore them if !ch.is_ascii_digit() { return none_or_err(eval_mode, type_name, str); } } - - // skip trailing whitespace - if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() { - return none_or_err(eval_mode, type_name, str); - } } if !negative { diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md index 278edb848..a4ed9289f 100644 --- a/docs/source/user-guide/compatibility.md +++ b/docs/source/user-guide/compatibility.md @@ -110,6 +110,10 @@ The following cast operations are generally compatible with Spark except for the | decimal | float | | | decimal | double | | | string | boolean | | +| string | byte | | +| string | short | | +| string | integer | | +| string | long | | | string | binary | | | date | string | | | timestamp | long | | @@ -125,10 +129,6 @@ The following cast operations are not compatible with Spark for all inputs and a |-|-|-| | integer | decimal | No overflow check | | long | decimal | No overflow check | -| string | byte | Not all invalid inputs are detected | -| string | short | Not all invalid inputs are detected | -| string | integer | Not all invalid inputs are detected | -| string | long | Not all invalid inputs are detected | | string | timestamp | Not all valid formats are supported | | binary | string | Only works for binary data representing valid UTF-8 strings | diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 9c3695ba5..795bdb428 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -108,7 +108,7 @@ object CometCast { Compatible() case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType => - Incompatible(Some("Not all invalid inputs are detected")) + Compatible() case DataTypes.BinaryType => Compatible() case DataTypes.FloatType | DataTypes.DoubleType => diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index ea3355d05..8caba14c6 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -519,28 +519,28 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { "9223372036854775808" // Long.MaxValue + 1 ) - ignore("cast StringType to ByteType") { + test("cast StringType to ByteType") { // test with hand-picked values castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) // fuzz test castTest(gen.generateStrings(dataSize, numericPattern, 4).toDF("a"), DataTypes.ByteType) } - ignore("cast StringType to ShortType") { + test("cast StringType to ShortType") { // test with hand-picked values castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType) // fuzz test castTest(gen.generateStrings(dataSize, numericPattern, 5).toDF("a"), DataTypes.ShortType) } - ignore("cast StringType to IntegerType") { + test("cast StringType to IntegerType") { // test with hand-picked values castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType) // fuzz test castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.IntegerType) } - ignore("cast StringType to LongType") { + test("cast StringType to LongType") { // test with hand-picked values castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType) // fuzz test