Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Enable cast string to int tests and fix compatibility issue #453

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 11 additions & 46 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ macro_rules! cast_utf8_to_int {
for i in 0..len {
if $array.is_null(i) {
cast_array.append_null()
} else if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, the input string was trimmed here before being passed to the cast method, so error messages could not include the original (untrimmed) input.

} else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
cast_array.append_value(cast_value);
} else {
cast_array.append_null()
Expand Down Expand Up @@ -1010,14 +1010,6 @@ fn cast_string_to_int_with_range_check(
}
}

#[derive(PartialEq)]
enum State {
SkipLeadingWhiteSpace,
SkipTrailingWhiteSpace,
ParseSignAndDigits,
ParseFractionalDigits,
}

/// Equivalent to
/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
Expand All @@ -1029,34 +1021,22 @@ fn do_cast_string_to_int<
type_name: &str,
min_value: T,
) -> CometResult<Option<T>> {
let len = str.len();
if str.is_empty() {
let trimmed_str = str.trim();
if trimmed_str.is_empty() {
return none_or_err(eval_mode, type_name, str);
}

let len = trimmed_str.len();
let mut result: T = T::zero();
let mut negative = false;
let radix = T::from(10);
let stop_value = min_value / radix;
let mut state = State::SkipLeadingWhiteSpace;
let mut parsed_sign = false;

for (i, ch) in str.char_indices() {
// skip leading whitespace
if state == State::SkipLeadingWhiteSpace {
if ch.is_whitespace() {
// consume this char
continue;
}
// change state and fall through to next section
state = State::ParseSignAndDigits;
}
let mut parse_sign_and_digits = true;

if state == State::ParseSignAndDigits {
if !parsed_sign {
for (i, ch) in trimmed_str.char_indices() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now iterate over the trimmed string here, so the logic for skipping leading and trailing whitespace is no longer needed.

if parse_sign_and_digits {
if i == 0 {
negative = ch == '-';
let positive = ch == '+';
parsed_sign = true;
if negative || positive {
if i + 1 == len {
// input string is just "+" or "-"
Expand All @@ -1070,7 +1050,7 @@ fn do_cast_string_to_int<
if ch == '.' {
if eval_mode == EvalMode::Legacy {
// truncate decimal in legacy mode
state = State::ParseFractionalDigits;
parse_sign_and_digits = false;
continue;
} else {
return none_or_err(eval_mode, type_name, str);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error messages refer to the original input string (`str`), not the trimmed version.

Expand Down Expand Up @@ -1102,27 +1082,12 @@ fn do_cast_string_to_int<
return none_or_err(eval_mode, type_name, str);
}
}
}

if state == State::ParseFractionalDigits {
// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well-formed.
if ch.is_whitespace() {
// finished parsing fractional digits, now need to skip trailing whitespace
state = State::SkipTrailingWhiteSpace;
// consume this char
continue;
}
} else {
// make sure fractional digits are valid digits but ignore them
if !ch.is_ascii_digit() {
return none_or_err(eval_mode, type_name, str);
}
}

// skip trailing whitespace
if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() {
return none_or_err(eval_mode, type_name, str);
}
}

if !negative {
Expand Down
8 changes: 4 additions & 4 deletions docs/source/user-guide/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ The following cast operations are generally compatible with Spark except for the
| decimal | float | |
| decimal | double | |
| string | boolean | |
| string | byte | |
| string | short | |
| string | integer | |
| string | long | |
| string | binary | |
| date | string | |
| timestamp | long | |
Expand All @@ -125,10 +129,6 @@ The following cast operations are not compatible with Spark for all inputs and a
|-|-|-|
| integer | decimal | No overflow check |
| long | decimal | No overflow check |
| string | byte | Not all invalid inputs are detected |
| string | short | Not all invalid inputs are detected |
| string | integer | Not all invalid inputs are detected |
| string | long | Not all invalid inputs are detected |
| string | timestamp | Not all valid formats are supported |
| binary | string | Only works for binary data representing valid UTF-8 strings |

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ object CometCast {
Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
DataTypes.LongType =>
Incompatible(Some("Not all invalid inputs are detected"))
Compatible()
case DataTypes.BinaryType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
Expand Down
8 changes: 4 additions & 4 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -519,28 +519,28 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
"9223372036854775808" // Long.MaxValue + 1
)

ignore("cast StringType to ByteType") {
test("cast StringType to ByteType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType)
// fuzz test
castTest(gen.generateStrings(dataSize, numericPattern, 4).toDF("a"), DataTypes.ByteType)
}

ignore("cast StringType to ShortType") {
test("cast StringType to ShortType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType)
// fuzz test
castTest(gen.generateStrings(dataSize, numericPattern, 5).toDF("a"), DataTypes.ShortType)
}

ignore("cast StringType to IntegerType") {
test("cast StringType to IntegerType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType)
// fuzz test
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.IntegerType)
}

ignore("cast StringType to LongType") {
test("cast StringType to LongType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType)
// fuzz test
Expand Down
Loading