From 074fddd31e690f7ee4995e68cc0b72248cd70910 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 25 Apr 2024 08:51:07 -0600 Subject: [PATCH] use checked math operations --- .../execution/datafusion/expressions/cast.rs | 48 ++++++++++++++----- .../org/apache/comet/CometCastSuite.scala | 24 +++++++--- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 2ad5a7eb7..07ea41139 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -135,7 +135,7 @@ impl Cast { ) if key_type.as_ref() == &DataType::Int32 && value_type.as_ref() == &DataType::Utf8 => { - // Note that we are unpacking a dictionary-encoded array and then performing + // TODO: we are unpacking a dictionary-encoded array and then performing // the cast. We could potentially improve performance here by casting the // dictionary values directly without unpacking the array first, although this // would add more complexity to the code @@ -328,10 +328,18 @@ impl CastStringToInt for CastStringToInt32 { self.reset(); return none_or_err(eval_mode, type_name, str); } - self.result = Some(self.result.unwrap_or(0) * self.radix - digit as i32); - if self.result.unwrap() > 0 { + let v = self.result.unwrap_or(0) * self.radix; + if let Some(x) = v.checked_sub(digit as i32) { + if x > 0 { + self.reset(); + return none_or_err(eval_mode, type_name, str); + } else { + self.result = Some(x); + } + } else { self.reset(); return none_or_err(eval_mode, type_name, str); + } Ok(()) } @@ -346,10 +354,13 @@ impl CastStringToInt for CastStringToInt32 { str: &str, negative: bool, ) -> CometResult<()> { - if self.result.is_some() && !negative { - self.result = Some(-self.result.unwrap()); - if self.result.unwrap() < 0 { - return none_or_err(eval_mode, type_name, str); + if !negative { + if let Some(r) = self.result { + let negated = r.checked_neg().unwrap_or(-1); + if negated < 0 { + return none_or_err(eval_mode, type_name, str); + } + self.result = Some(negated); } } Ok(()) @@ -384,10 +395,18 @@ impl CastStringToInt for CastStringToInt64 { self.reset(); return none_or_err(eval_mode, type_name, str); } - self.result = Some(self.result.unwrap_or(0) * self.radix - digit as i64); - if self.result.unwrap() > 0 { + let v = self.result.unwrap_or(0) * self.radix; + if let Some(x) = v.checked_sub(digit as i64) { + if x > 0 { + self.reset(); + return none_or_err(eval_mode, type_name, str); + } else { + self.result = Some(x); + } + } else { self.reset(); return none_or_err(eval_mode, type_name, str); + } Ok(()) } @@ -403,10 +422,13 @@ impl CastStringToInt for CastStringToInt64 { str: &str, negative: bool, ) -> CometResult<()> { - if self.result.is_some() && !negative { - self.result = Some(-self.result.unwrap()); - if self.result.unwrap() < 0 { - return none_or_err(eval_mode, type_name, str); + if !negative { + if let Some(r) = self.result { + let negated = r.checked_neg().unwrap_or(-1); + if negated < 0 { + return none_or_err(eval_mode, type_name, str); + } + self.result = Some(negated); } } Ok(()) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 27eeb0324..59941b0c6 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -40,9 +40,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { private val whitespaceChars = " \t\r\n" /** - * We use these characters to construct strings that potentially represent valid numbers - * such as `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as - * `+e.-d`. + * We use these characters to construct strings that potentially represent valid numbers such as + * `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as `+e.-d`. */ private val numericPattern = "0123456789def+-." + whitespaceChars @@ -73,9 +72,20 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } val castStringToIntegralInputs = Seq( - "", ".", "+", "-", "+.", "=.", - "-0", "+1", "-1", ".2", "-.2", - "1e1", "1.1d", "1.1f", + "", + ".", + "+", + "-", + "+.", + "-.", + "-0", + "+1", + "-1", + ".2", + "-.2", + "1e1", + "1.1d", + "1.1f", Byte.MinValue.toString, (Byte.MinValue.toShort - 1).toString, Byte.MaxValue.toString, @@ -197,7 +207,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { val cometMessage = actual.getMessage .replace("Execution error: ", "") - assert(expected.getMessage == cometMessage) + assert(cometMessage == expected.getMessage) } else { // Spark 3.2 and 3.3 have a different error message format so we can't do a direct // comparison between Spark and Comet.