Skip to content

Commit

Permalink
fix: Enable cast string to int tests and fix compatibility issue (apa…
Browse files Browse the repository at this point in the history
…che#453)

* simplify cast string to int logic and use untrimmed string in error messages

* remove state enum

(cherry picked from commit de8fe45)
  • Loading branch information
andygrove authored and Huaxin Gao committed May 23, 2024
1 parent 7a9bb0b commit ab4f297
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 55 deletions.
57 changes: 11 additions & 46 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ macro_rules! cast_utf8_to_int {
for i in 0..len {
if $array.is_null(i) {
cast_array.append_null()
} else if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? {
} else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
cast_array.append_value(cast_value);
} else {
cast_array.append_null()
Expand Down Expand Up @@ -1010,14 +1010,6 @@ fn cast_string_to_int_with_range_check(
}
}

#[derive(PartialEq)]
enum State {
SkipLeadingWhiteSpace,
SkipTrailingWhiteSpace,
ParseSignAndDigits,
ParseFractionalDigits,
}

/// Equivalent to
/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
Expand All @@ -1029,34 +1021,22 @@ fn do_cast_string_to_int<
type_name: &str,
min_value: T,
) -> CometResult<Option<T>> {
let len = str.len();
if str.is_empty() {
let trimmed_str = str.trim();
if trimmed_str.is_empty() {
return none_or_err(eval_mode, type_name, str);
}

let len = trimmed_str.len();
let mut result: T = T::zero();
let mut negative = false;
let radix = T::from(10);
let stop_value = min_value / radix;
let mut state = State::SkipLeadingWhiteSpace;
let mut parsed_sign = false;

for (i, ch) in str.char_indices() {
// skip leading whitespace
if state == State::SkipLeadingWhiteSpace {
if ch.is_whitespace() {
// consume this char
continue;
}
// change state and fall through to next section
state = State::ParseSignAndDigits;
}
let mut parse_sign_and_digits = true;

if state == State::ParseSignAndDigits {
if !parsed_sign {
for (i, ch) in trimmed_str.char_indices() {
if parse_sign_and_digits {
if i == 0 {
negative = ch == '-';
let positive = ch == '+';
parsed_sign = true;
if negative || positive {
if i + 1 == len {
// input string is just "+" or "-"
Expand All @@ -1070,7 +1050,7 @@ fn do_cast_string_to_int<
if ch == '.' {
if eval_mode == EvalMode::Legacy {
// truncate decimal in legacy mode
state = State::ParseFractionalDigits;
parse_sign_and_digits = false;
continue;
} else {
return none_or_err(eval_mode, type_name, str);
Expand Down Expand Up @@ -1102,27 +1082,12 @@ fn do_cast_string_to_int<
return none_or_err(eval_mode, type_name, str);
}
}
}

if state == State::ParseFractionalDigits {
// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well-formed.
if ch.is_whitespace() {
// finished parsing fractional digits, now need to skip trailing whitespace
state = State::SkipTrailingWhiteSpace;
// consume this char
continue;
}
} else {
// make sure fractional digits are valid digits but ignore them
if !ch.is_ascii_digit() {
return none_or_err(eval_mode, type_name, str);
}
}

// skip trailing whitespace
if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() {
return none_or_err(eval_mode, type_name, str);
}
}

if !negative {
Expand Down
8 changes: 4 additions & 4 deletions docs/source/user-guide/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ The following cast operations are generally compatible with Spark except for the
| decimal | float | |
| decimal | double | |
| string | boolean | |
| string | byte | |
| string | short | |
| string | integer | |
| string | long | |
| string | binary | |
| date | string | |
| timestamp | long | |
Expand All @@ -125,10 +129,6 @@ The following cast operations are not compatible with Spark for all inputs and a
|-|-|-|
| integer | decimal | No overflow check |
| long | decimal | No overflow check |
| string | byte | Not all invalid inputs are detected |
| string | short | Not all invalid inputs are detected |
| string | integer | Not all invalid inputs are detected |
| string | long | Not all invalid inputs are detected |
| string | timestamp | Not all valid formats are supported |
| binary | string | Only works for binary data representing valid UTF-8 strings |

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ object CometCast {
Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
DataTypes.LongType =>
Incompatible(Some("Not all invalid inputs are detected"))
Compatible()
case DataTypes.BinaryType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
Expand Down
8 changes: 4 additions & 4 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -519,28 +519,28 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
"9223372036854775808" // Long.MaxValue + 1
)

ignore("cast StringType to ByteType") {
test("cast StringType to ByteType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType)
// fuzz test
castTest(gen.generateStrings(dataSize, numericPattern, 4).toDF("a"), DataTypes.ByteType)
}

ignore("cast StringType to ShortType") {
test("cast StringType to ShortType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType)
// fuzz test
castTest(gen.generateStrings(dataSize, numericPattern, 5).toDF("a"), DataTypes.ShortType)
}

ignore("cast StringType to IntegerType") {
test("cast StringType to IntegerType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType)
// fuzz test
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.IntegerType)
}

ignore("cast StringType to LongType") {
test("cast StringType to LongType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType)
// fuzz test
Expand Down

0 comments on commit ab4f297

Please sign in to comment.