diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 7b125211f..5e4296c14 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -249,15 +249,11 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult } fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { - let mut accum = CastStringToInt32::default(); - do_cast_string_to_int(&mut accum, str, eval_mode, "INT")?; - Ok(accum.result) + Ok(do_cast_string_to_i32(str, eval_mode, "INT", i32::MIN)?.map(|n| n as i32)) } fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { - let mut accum = CastStringToInt64::default(); - do_cast_string_to_int(&mut accum, str, eval_mode, "BIGINT")?; - Ok(accum.result) + do_cast_string_to_i64(str, eval_mode, "BIGINT", i64::MIN) } fn cast_string_to_int_with_range_check( @@ -267,188 +263,116 @@ fn cast_string_to_int_with_range_check( min: i32, max: i32, ) -> CometResult> { - let mut accum = CastStringToInt32::default(); - do_cast_string_to_int(&mut accum, str, eval_mode, type_name)?; - match accum.result { + match do_cast_string_to_i32(str, eval_mode, type_name, i32::MIN)? { None => Ok(None), - Some(v) if v >= min && v <= max => Ok(Some(v)), + Some(v) if v >= min && v <= max => Ok(Some(v as i32)), _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), _ => Ok(None), } } -/// We support parsing strings to i32 and i64 to match Spark's logic. Support for i8 and i16 is -/// implemented by first parsing as i32 and then downcasting. The CastStringToInt trait is -/// introduced so that we can have the parsing logic delegate either to an i32 or i64 accumulator -/// and avoid the need to use macros here. -trait CastStringToInt { - fn accumulate( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - digit: u32, - ) -> CometResult<()>; - - fn reset(&mut self); - - fn finish( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - negative: bool, - ) -> CometResult<()>; -} -struct CastStringToInt32 { - negative: bool, - result: Option, - radix: i32, -} +fn do_cast_string_to_i32( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min_value: i32, +) -> CometResult> { + let chars: Vec = str.chars().collect(); + let mut i = 0; + let mut end = chars.len(); -impl Default for CastStringToInt32 { - fn default() -> Self { - Self { - negative: false, - result: Some(0), - radix: 10, - } + // skip leading whitespace + while i < end && chars[i].is_whitespace() { + i += 1; } -} -impl CastStringToInt for CastStringToInt32 { - fn accumulate( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - digit: u32, - ) -> CometResult<()> { - // We are going to process the new digit and accumulate the result. However, before doing - // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), - // then result * 10 will definitely be smaller than minValue, and we can stop - if let Some(r) = self.result { - let stop_value = i32::MIN / self.radix; - if r < stop_value { - self.reset(); - return none_or_err(eval_mode, type_name, str); - } - } - // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), - // we can just use `result > 0` to check overflow. If result overflows, we should stop - let v = self.result.unwrap_or(0) * self.radix; - match v.checked_sub(digit as i32) { - Some(x) if x <= 0 => self.result = Some(x), - _ => { - self.reset(); - return none_or_err(eval_mode, type_name, str); - } - } - Ok(()) + // skip trailing whitespace + while end > i && chars[end - 1].is_whitespace() { + end -= 1; } - fn reset(&mut self) { - self.result = None; + + // check for empty string + if i == end { + return none_or_err(eval_mode, type_name, str); } - fn finish( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - negative: bool, - ) -> CometResult<()> { - if !negative { - if let Some(r) = self.result { - let negated = r.checked_neg().unwrap_or(-1); - if negated < 0 { - self.reset(); - return none_or_err(eval_mode, type_name, str); - } - self.result = Some(negated); - } + // skip + or - + let negative = chars[i] == '-'; + if negative || chars[i] == '+' { + i += 1; + if i == end { + return none_or_err(eval_mode, type_name, str); } - Ok(()) } -} -struct CastStringToInt64 { - negative: bool, - result: Option, - radix: i64, -} + let mut result = 0; + let radix = 10; + let stop_value = min_value / radix; + while i < end { + let b = chars[i]; + i += 1; -impl Default for CastStringToInt64 { - fn default() -> Self { - Self { - negative: false, - result: Some(0), - radix: 10, + if b == '.' && eval_mode == EvalMode::Legacy { + // truncate decimal in legacy mode + break; } - } -} -impl CastStringToInt for CastStringToInt64 { - fn accumulate( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - digit: u32, - ) -> CometResult<()> { + let digit = if b.is_ascii_digit() { + (b as u32) - ('0' as u32) + } else { + return none_or_err(eval_mode, type_name, str); + }; + // We are going to process the new digit and accumulate the result. However, before doing // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), // then result * 10 will definitely be smaller than minValue, and we can stop - if let Some(r) = self.result { - let stop_value = i64::MIN / self.radix; - if r < stop_value { - self.reset(); - return none_or_err(eval_mode, type_name, str); - } + if result < stop_value { + return none_or_err(eval_mode, type_name, str); } + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), // we can just use `result > 0` to check overflow. If result overflows, we should stop - let v = self.result.unwrap_or(0) * self.radix; - match v.checked_sub(digit as i64) { - Some(x) if x <= 0 => self.result = Some(x), + let v = result * radix; + match v.checked_sub(digit as i32) { + Some(x) if x <= 0 => result = x, _ => { - self.reset(); return none_or_err(eval_mode, type_name, str); } } - Ok(()) } - fn reset(&mut self) { - self.result = None; + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well-formed. + while i < end { + let b = chars[i]; + if !b.is_ascii_digit() { + return none_or_err(eval_mode, type_name, str); + } + i += 1; } - fn finish( - &mut self, - eval_mode: EvalMode, - type_name: &str, - str: &str, - negative: bool, - ) -> CometResult<()> { - if !negative { - if let Some(r) = self.result { - let negated = r.checked_neg().unwrap_or(-1); - if negated < 0 { - self.reset(); - return none_or_err(eval_mode, type_name, str); - } - self.result = Some(negated); + if !negative { + if let Some(x) = result.checked_neg() { + if x < 0 { + return none_or_err(eval_mode, type_name, str); } + result = x; + } else { + return none_or_err(eval_mode, type_name, str); } - Ok(()) } + + Ok(Some(result)) } -fn do_cast_string_to_int( - accumulator: &mut dyn CastStringToInt, +/// This is a copy of do_cast_string_to_i32 but with the type changed to i64 +fn do_cast_string_to_i64( str: &str, eval_mode: EvalMode, type_name: &str, -) -> CometResult<()> { + min_value: i64, +) -> CometResult> { let chars: Vec = str.chars().collect(); let mut i = 0; let mut end = chars.len(); @@ -465,7 +389,6 @@ fn do_cast_string_to_int( // check for empty string if i == end { - accumulator.reset(); return none_or_err(eval_mode, type_name, str); } @@ -474,11 +397,13 @@ fn do_cast_string_to_int( if negative || chars[i] == '+' { i += 1; if i == end { - accumulator.reset(); return none_or_err(eval_mode, type_name, str); } } + let mut result = 0; + let radix = 10; + let stop_value = min_value / radix; while i < end { let b = chars[i]; i += 1; @@ -491,11 +416,25 @@ fn do_cast_string_to_int( let digit = if b.is_ascii_digit() { (b as u32) - ('0' as u32) } else { - accumulator.reset(); return none_or_err(eval_mode, type_name, str); }; - accumulator.accumulate(eval_mode, type_name, str, digit)?; + // We are going to process the new digit and accumulate the result. However, before doing + // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), + // then result * 10 will definitely be smaller than minValue, and we can stop + if result < stop_value { + return none_or_err(eval_mode, type_name, str); + } + + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), + // we can just use `result > 0` to check overflow. If result overflows, we should stop + let v = result * radix; + match v.checked_sub(digit as i64) { + Some(x) if x <= 0 => result = x, + _ => { + return none_or_err(eval_mode, type_name, str); + } + } } // This is the case when we've encountered a decimal separator. The fractional @@ -504,22 +443,30 @@ fn do_cast_string_to_int( while i < end { let b = chars[i]; if !b.is_ascii_digit() { - accumulator.reset(); return none_or_err(eval_mode, type_name, str); } i += 1; } - accumulator.finish(eval_mode, type_name, str, negative)?; + if !negative { + if let Some(x) = result.checked_neg() { + if x < 0 { + return none_or_err(eval_mode, type_name, str); + } + result = x; + } else { + return none_or_err(eval_mode, type_name, str); + } + } - Ok(()) + Ok(Some(result)) } /// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode -fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<()> { +fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult> { match eval_mode { EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), - _ => Ok(()), + _ => Ok(None), } }