Skip to content

Commit

Permalink
use checked math operations
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Apr 25, 2024
1 parent 5d0730f commit 074fddd
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 20 deletions.
48 changes: 35 additions & 13 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ impl Cast {
) if key_type.as_ref() == &DataType::Int32
&& value_type.as_ref() == &DataType::Utf8 =>
{
// Note that we are unpacking a dictionary-encoded array and then performing
// TODO: we are unpacking a dictionary-encoded array and then performing
// the cast. We could potentially improve performance here by casting the
// dictionary values directly without unpacking the array first, although this
// would add more complexity to the code
Expand Down Expand Up @@ -328,10 +328,18 @@ impl CastStringToInt for CastStringToInt32 {
self.reset();
return none_or_err(eval_mode, type_name, str);
}
self.result = Some(self.result.unwrap_or(0) * self.radix - digit as i32);
if self.result.unwrap() > 0 {
let v = self.result.unwrap_or(0) * self.radix;
if let Some(x) = v.checked_sub(digit as i32) {
if x > 0 {
self.reset();
return none_or_err(eval_mode, type_name, str);
} else {
self.result = Some(x);
}
} else {
self.reset();
return none_or_err(eval_mode, type_name, str);

}
Ok(())
}
Expand All @@ -346,10 +354,13 @@ impl CastStringToInt for CastStringToInt32 {
str: &str,
negative: bool,
) -> CometResult<()> {
if self.result.is_some() && !negative {
self.result = Some(-self.result.unwrap());
if self.result.unwrap() < 0 {
return none_or_err(eval_mode, type_name, str);
if !negative {
if let Some(r) = self.result {
let negated = r.checked_neg().unwrap_or(-1);
if negated < 0 {
return none_or_err(eval_mode, type_name, str);
}
self.result = Some(negated);
}
}
Ok(())
Expand Down Expand Up @@ -384,10 +395,18 @@ impl CastStringToInt for CastStringToInt64 {
self.reset();
return none_or_err(eval_mode, type_name, str);
}
self.result = Some(self.result.unwrap_or(0) * self.radix - digit as i64);
if self.result.unwrap() > 0 {
let v = self.result.unwrap_or(0) * self.radix;
if let Some(x) = v.checked_sub(digit as i64) {
if x > 0 {
self.reset();
return none_or_err(eval_mode, type_name, str);
} else {
self.result = Some(x);
}
} else {
self.reset();
return none_or_err(eval_mode, type_name, str);

}
Ok(())
}
Expand All @@ -403,10 +422,13 @@ impl CastStringToInt for CastStringToInt64 {
str: &str,
negative: bool,
) -> CometResult<()> {
if self.result.is_some() && !negative {
self.result = Some(-self.result.unwrap());
if self.result.unwrap() < 0 {
return none_or_err(eval_mode, type_name, str);
if !negative {
if let Some(r) = self.result {
let negated = r.checked_neg().unwrap_or(-1);
if negated < 0 {
return none_or_err(eval_mode, type_name, str);
}
self.result = Some(negated);
}
}
Ok(())
Expand Down
24 changes: 17 additions & 7 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
private val whitespaceChars = " \t\r\n"

/**
* We use these characters to construct strings that potentially represent valid numbers
* such as `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as
* `+e.-d`.
* We use these characters to construct strings that potentially represent valid numbers such as
* `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as `+e.-d`.
*/
private val numericPattern = "0123456789def+-." + whitespaceChars

Expand Down Expand Up @@ -73,9 +72,20 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

val castStringToIntegralInputs = Seq(
"", ".", "+", "-", "+.", "=.",
"-0", "+1", "-1", ".2", "-.2",
"1e1", "1.1d", "1.1f",
"",
".",
"+",
"-",
"+.",
"-.",
"-0",
"+1",
"-1",
".2",
"-.2",
"1e1",
"1.1d",
"1.1f",
Byte.MinValue.toString,
(Byte.MinValue.toShort - 1).toString,
Byte.MaxValue.toString,
Expand Down Expand Up @@ -197,7 +207,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val cometMessage = actual.getMessage
.replace("Execution error: ", "")

assert(expected.getMessage == cometMessage)
assert(cometMessage == expected.getMessage)
} else {
// Spark 3.2 and 3.3 have a different error message format so we can't do a direct
// comparison between Spark and Comet.
Expand Down

0 comments on commit 074fddd

Please sign in to comment.