diff --git a/.github/workflows/spark_sql_test_ansi.yml b/.github/workflows/spark_sql_test_ansi.yml new file mode 100644 index 000000000..5c5d28589 --- /dev/null +++ b/.github/workflows/spark_sql_test_ansi.yml @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: Spark SQL Tests (ANSI mode) + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +on: + # enable the following once Ansi support is completed + # push: + # paths-ignore: + # - "doc/**" + # - "**.md" + # pull_request: + # paths-ignore: + # - "doc/**" + # - "**.md" + + # manual trigger ONLY + # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow + workflow_dispatch: + +env: + RUST_VERSION: nightly + +jobs: + spark-sql-catalyst: + strategy: + matrix: + os: [ubuntu-latest] + java-version: [11] + spark-version: [{short: '3.4', full: '3.4.2'}] + module: + - {name: "catalyst", args1: "catalyst/test", args2: ""} + - {name: "sql/core-1", args1: "", args2: sql/testOnly * -- -l org.apache.spark.tags.ExtendedSQLTest -l org.apache.spark.tags.SlowSQLTest} + - {name: "sql/core-2", args1: "", args2: "sql/testOnly * -- -n org.apache.spark.tags.ExtendedSQLTest"} + - {name: "sql/core-3", args1: "", args2: "sql/testOnly * -- -n org.apache.spark.tags.SlowSQLTest"} + - {name: "sql/hive-1", args1: "", args2: "hive/testOnly * -- -l org.apache.spark.tags.ExtendedHiveTest -l org.apache.spark.tags.SlowHiveTest"} + - {name: "sql/hive-2", args1: "", args2: "hive/testOnly * -- -n org.apache.spark.tags.ExtendedHiveTest"} + - {name: "sql/hive-3", args1: "", args2: "hive/testOnly * -- -n org.apache.spark.tags.SlowHiveTest"} + fail-fast: false + name: spark-sql-${{ matrix.module.name }}/${{ matrix.os }}/spark-${{ matrix.spark-version.full }}/java-${{ matrix.java-version }} + runs-on: ${{ matrix.os }} + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust & Java toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{env.RUST_VERSION}} + jdk-version: ${{ matrix.java-version }} + - name: Setup Spark + uses: ./.github/actions/setup-spark-builder + with: + spark-version: ${{ matrix.spark-version.full }} + spark-short-version: ${{ matrix.spark-version.short }} + comet-version: '0.1.0-SNAPSHOT' # TODO: get this from pom.xml + - name: Run Spark tests + run: | + cd apache-spark + ENABLE_COMET=true ENABLE_COMET_ANSI_MODE=true build/sbt ${{ matrix.module.args1 }} "${{ matrix.module.args2 }}" + env: + LC_ALL: "C.UTF-8" + diff --git a/README.md b/README.md index b793a7904..8329ca3a7 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ Linux, Apple OSX (Intel and M1) ## Requirements - Apache Spark 3.2, 3.3, or 3.4 -- JDK 8 and up +- JDK 8, 11 and 
17 (JDK 11 recommended because Spark 3.2 doesn't support 17) - GLIBC 2.17 (Centos 7) and up ## Getting started diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index fda906d03..38dd07fc0 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -19,11 +19,9 @@ package org.apache.comet -import java.io.{BufferedOutputStream, FileOutputStream} import java.util.concurrent.TimeUnit import scala.collection.mutable.ListBuffer -import scala.io.Source import org.apache.spark.network.util.ByteUnit import org.apache.spark.network.util.JavaUtils @@ -376,12 +374,14 @@ object CometConf { .booleanConf .createWithDefault(false) - val COMET_CAST_STRING_TO_TIMESTAMP: ConfigEntry[Boolean] = conf( - "spark.comet.cast.stringToTimestamp") - .doc( - "Comet is not currently fully compatible with Spark when casting from String to Timestamp.") - .booleanConf - .createWithDefault(false) + val COMET_CAST_ALLOW_INCOMPATIBLE: ConfigEntry[Boolean] = + conf("spark.comet.cast.allowIncompatible") + .doc( + "Comet is not currently fully compatible with Spark for all cast operations. " + + "Set this config to true to allow them anyway. See compatibility guide " + + "for more information.") + .booleanConf + .createWithDefault(false) } @@ -625,36 +625,3 @@ private[comet] case class ConfigBuilder(key: String) { private object ConfigEntry { val UNDEFINED = "" } - -/** - * Utility for generating markdown documentation from the configs. - * - * This is invoked when running `mvn clean package -DskipTests`. - */ -object CometConfGenerateDocs { - def main(args: Array[String]): Unit = { - if (args.length != 2) { - // scalastyle:off println - println("Missing arguments for template file and output file") - // scalastyle:on println - sys.exit(-1) - } - val templateFilename = args.head - val outputFilename = args(1) - val w = new BufferedOutputStream(new FileOutputStream(outputFilename)) - for (line <- Source.fromFile(templateFilename).getLines()) { - if (line.trim == "") { - val publicConfigs = CometConf.allConfs.filter(_.isPublic) - val confs = publicConfigs.sortBy(_.key) - w.write("| Config | Description | Default Value |\n".getBytes) - w.write("|--------|-------------|---------------|\n".getBytes) - for (conf <- confs) { - w.write(s"| ${conf.key} | ${conf.doc.trim} | ${conf.defaultValueString} |\n".getBytes) - } - } else { - w.write(s"${line.trim}\n".getBytes) - } - } - w.close() - } -} diff --git a/core/Cargo.toml b/core/Cargo.toml index b09b0ea7f..cbca7f629 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -118,3 +118,6 @@ harness = false name = "row_columnar" harness = false +[[bench]] +name = "cast" +harness = false diff --git a/core/benches/cast.rs b/core/benches/cast.rs new file mode 100644 index 000000000..281fe82e2 --- /dev/null +++ b/core/benches/cast.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::{builder::StringBuilder, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use comet::execution::datafusion::expressions::cast::{Cast, EvalMode}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else if i % 2 == 0 { + b.append_value(format!("{}", rand::random::())); + } else { + b.append_value(format!("{}", rand::random::())); + } + } + let array = b.finish(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap(); + let expr = Arc::new(Column::new("a", 0)); + let timezone = "".to_string(); + let cast_string_to_i8 = Cast::new( + expr.clone(), + DataType::Int8, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i16 = Cast::new( + expr.clone(), + DataType::Int16, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i32 = Cast::new( + expr.clone(), + DataType::Int32, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone); + + let mut group = c.benchmark_group("cast"); + group.bench_function("cast_string_to_i8", |b| { + b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i16", |b| { + b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i32", |b| { + b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i64", |b| { + b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap()); + }); +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! 
{ + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 10079855d..45859c5fb 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -17,7 +17,7 @@ use std::{ any::Any, - fmt::{Display, Formatter}, + fmt::{Debug, Display, Formatter}, hash::{Hash, Hasher}, sync::Arc, }; @@ -25,14 +25,22 @@ use std::{ use crate::errors::{CometError, CometResult}; use arrow::{ compute::{cast_with_options, CastOptions}, + datatypes::TimestampMicrosecondType, record_batch::RecordBatch, util::display::FormatOptions, }; -use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait}; +use arrow_array::{ + types::{Int16Type, Int32Type, Int64Type, Int8Type}, + Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericStringArray, OffsetSizeTrait, + PrimitiveArray, +}; use arrow_schema::{DataType, Schema}; +use chrono::{TimeZone, Timelike}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; +use num::{traits::CheckedNeg, CheckedSub, Integer, Num}; +use regex::Regex; use crate::execution::datafusion::expressions::utils::{ array_with_timezone, down_cast_any_ref, spark_cast, @@ -64,6 +72,110 @@ pub struct Cast { pub timezone: String, } +macro_rules! cast_utf8_to_int { + ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + let len = $array.len(); + let mut cast_array = PrimitiveArray::<$array_type>::builder(len); + for i in 0..len { + if $array.is_null(i) { + cast_array.append_null() + } else if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + let result: CometResult = Ok(Arc::new(cast_array.finish()) as ArrayRef); + result + }}; +} + +macro_rules! cast_utf8_to_timestamp { + ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + let len = $array.len(); + let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC"); + for i in 0..len { + if $array.is_null(i) { + cast_array.append_null() + } else if let Ok(Some(cast_value)) = $cast_method($array.value(i).trim(), $eval_mode) { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef; + result + }}; +} + +macro_rules! cast_float_to_string { + ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{ + + fn cast( + from: &dyn Array, + _eval_mode: EvalMode, + ) -> CometResult + where + OffsetSize: OffsetSizeTrait, { + let array = from.as_any().downcast_ref::<$output_type>().unwrap(); + + // If the absolute number is less than 10,000,000 and greater or equal than 0.001, the + // result is expressed without scientific notation with at least one digit on either side of + // the decimal point. Otherwise, Spark uses a mantissa followed by E and an + // exponent. The mantissa has an optional leading minus sign followed by one digit to the + // left of the decimal point, and the minimal number of digits greater than zero to the + // right. The exponent has and optional leading minus sign. 
+ // source: https://docs.databricks.com/en/sql/language-manual/functions/cast.html + + const LOWER_SCIENTIFIC_BOUND: $type = 0.001; + const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0; + + let output_array = array + .iter() + .map(|value| match value { + Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())), + Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())), + Some(value) + if (value.abs() < UPPER_SCIENTIFIC_BOUND + && value.abs() >= LOWER_SCIENTIFIC_BOUND) + || value.abs() == 0.0 => + { + let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" }; + + Ok(Some(format!("{value}{trailing_zero}"))) + } + Some(value) + if value.abs() >= UPPER_SCIENTIFIC_BOUND + || value.abs() < LOWER_SCIENTIFIC_BOUND => + { + let formatted = format!("{value:E}"); + + if formatted.contains(".") { + Ok(Some(formatted)) + } else { + // `formatted` is already in scientific notation and can be split up by E + // in order to add the missing trailing 0 which gets removed for numbers with a fraction of 0.0 + let prepare_number: Vec<&str> = formatted.split("E").collect(); + + let coefficient = prepare_number[0]; + + let exponent = prepare_number[1]; + + Ok(Some(format!("{coefficient}.0E{exponent}"))) + } + } + Some(value) => Ok(Some(value.to_string())), + _ => Ok(None), + }) + .collect::, CometError>>()?; + + Ok(Arc::new(output_array)) + } + + cast::<$offset_type>($from, $eval_mode) + }}; +} + impl Cast { pub fn new( child: Arc, @@ -103,10 +215,138 @@ impl Cast { (DataType::LargeUtf8, DataType::Boolean) => { Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? } - _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?, + (DataType::Utf8, DataType::Timestamp(_, _)) => { + Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)? + } + ( + DataType::Utf8, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode)?, + ( + DataType::LargeUtf8, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode)?, + ( + DataType::Dictionary(key_type, value_type), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) if key_type.as_ref() == &DataType::Int32 + && (value_type.as_ref() == &DataType::Utf8 + || value_type.as_ref() == &DataType::LargeUtf8) => + { + // TODO: we are unpacking a dictionary-encoded array and then performing + // the cast. We could potentially improve performance here by casting the + // dictionary values directly without unpacking the array first, although this + // would add more complexity to the code + match value_type.as_ref() { + DataType::Utf8 => { + let unpacked_array = + cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; + Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? + } + DataType::LargeUtf8 => { + let unpacked_array = + cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; + Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? + } + dt => unreachable!( + "{}", + format!("invalid value type {dt} for dictionary-encoded string array") + ), + } + } + (DataType::Float64, DataType::Utf8) => { + Self::spark_cast_float64_to_utf8::(&array, self.eval_mode)? + } + (DataType::Float64, DataType::LargeUtf8) => { + Self::spark_cast_float64_to_utf8::(&array, self.eval_mode)? + } + (DataType::Float32, DataType::Utf8) => { + Self::spark_cast_float32_to_utf8::(&array, self.eval_mode)? 
+ } + (DataType::Float32, DataType::LargeUtf8) => { + Self::spark_cast_float32_to_utf8::(&array, self.eval_mode)? + } + _ => { + // when we have no Spark-specific casting we delegate to DataFusion + cast_with_options(&array, to_type, &CAST_OPTIONS)? + } }; - let result = spark_cast(cast_result, from_type, to_type); - Ok(result) + Ok(spark_cast(cast_result, from_type, to_type)) + } + + fn cast_string_to_int( + to_type: &DataType, + array: &ArrayRef, + eval_mode: EvalMode, + ) -> CometResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("cast_string_to_int expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Int8 => { + cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)? + } + DataType::Int16 => { + cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)? + } + DataType::Int32 => { + cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)? + } + DataType::Int64 => { + cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)? + } + dt => unreachable!( + "{}", + format!("invalid integer type {dt} in cast from string") + ), + }; + Ok(cast_array) + } + + fn cast_string_to_timestamp( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, + ) -> CometResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Timestamp(_, _) => { + cast_utf8_to_timestamp!( + string_array, + eval_mode, + TimestampMicrosecondType, + timestamp_parser + ) + } + _ => unreachable!("Invalid data type {:?} in cast from string", to_type), + }; + Ok(cast_array) + } + + fn spark_cast_float64_to_utf8( + from: &dyn Array, + _eval_mode: EvalMode, + ) -> CometResult + where + OffsetSize: OffsetSizeTrait, + { + cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize) + } + + fn spark_cast_float32_to_utf8( + from: &dyn Array, + _eval_mode: EvalMode, + ) -> CometResult + where + OffsetSize: OffsetSizeTrait, + { + cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize) } fn spark_cast_utf8_to_boolean( @@ -142,6 +382,202 @@ impl Cast { } } +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte +fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> { + Ok(cast_string_to_int_with_range_check( + str, + eval_mode, + "TINYINT", + i8::MIN as i32, + i8::MAX as i32, + )? + .map(|v| v as i8)) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort +fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult> { + Ok(cast_string_to_int_with_range_check( + str, + eval_mode, + "SMALLINT", + i16::MIN as i32, + i16::MAX as i32, + )? + .map(|v| v as i16)) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper) +fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { + do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper) +fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { + do_cast_string_to_int::(str, eval_mode, "BIGINT", i64::MIN) +} + +fn cast_string_to_int_with_range_check( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min: i32, + max: i32, +) -> CometResult> { + match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? 
{ + None => Ok(None), + Some(v) if v >= min && v <= max => Ok(Some(v)), + _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +#[derive(PartialEq)] +enum State { + SkipLeadingWhiteSpace, + SkipTrailingWhiteSpace, + ParseSignAndDigits, + ParseFractionalDigits, +} + +/// Equivalent to +/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) +/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) +fn do_cast_string_to_int< + T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From + Copy, +>( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min_value: T, +) -> CometResult> { + let len = str.len(); + if str.is_empty() { + return none_or_err(eval_mode, type_name, str); + } + + let mut result: T = T::zero(); + let mut negative = false; + let radix = T::from(10); + let stop_value = min_value / radix; + let mut state = State::SkipLeadingWhiteSpace; + let mut parsed_sign = false; + + for (i, ch) in str.char_indices() { + // skip leading whitespace + if state == State::SkipLeadingWhiteSpace { + if ch.is_whitespace() { + // consume this char + continue; + } + // change state and fall through to next section + state = State::ParseSignAndDigits; + } + + if state == State::ParseSignAndDigits { + if !parsed_sign { + negative = ch == '-'; + let positive = ch == '+'; + parsed_sign = true; + if negative || positive { + if i + 1 == len { + // input string is just "+" or "-" + return none_or_err(eval_mode, type_name, str); + } + // consume this char + continue; + } + } + + if ch == '.' { + if eval_mode == EvalMode::Legacy { + // truncate decimal in legacy mode + state = State::ParseFractionalDigits; + continue; + } else { + return none_or_err(eval_mode, type_name, str); + } + } + + let digit = if ch.is_ascii_digit() { + (ch as u32) - ('0' as u32) + } else { + return none_or_err(eval_mode, type_name, str); + }; + + // We are going to process the new digit and accumulate the result. However, before + // doing this, if the result is already smaller than the + // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be + // smaller than minValue, and we can stop + if result < stop_value { + return none_or_err(eval_mode, type_name, str); + } + + // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE / + // radix), we can just use `result > 0` to check overflow. If result + // overflows, we should stop + let v = result * radix; + let digit = (digit as i32).into(); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { + return none_or_err(eval_mode, type_name, str); + } + } + } + + if state == State::ParseFractionalDigits { + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well-formed. 
+ if ch.is_whitespace() { + // finished parsing fractional digits, now need to skip trailing whitespace + state = State::SkipTrailingWhiteSpace; + // consume this char + continue; + } + if !ch.is_ascii_digit() { + return none_or_err(eval_mode, type_name, str); + } + } + + // skip trailing whitespace + if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() { + return none_or_err(eval_mode, type_name, str); + } + } + + if !negative { + if let Some(neg) = result.checked_neg() { + if neg < T::zero() { + return none_or_err(eval_mode, type_name, str); + } + result = neg; + } else { + return none_or_err(eval_mode, type_name, str); + } + } + + Ok(Some(result)) +} + +/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode +#[inline] +fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult> { + match eval_mode { + EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +#[inline] +fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError { + CometError::CastInvalidValue { + value: value.to_string(), + from_type: from_type.to_string(), + to_type: to_type.to_string(), + } +} + impl Display for Cast { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( @@ -222,3 +658,298 @@ impl PhysicalExpr for Cast { self.hash(&mut s); } } + +fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult> { + let value = value.trim(); + if value.is_empty() { + return Ok(None); + } + // Define regex patterns and corresponding parsing functions + let patterns = &[ + ( + Regex::new(r"^\d{4}$").unwrap(), + parse_str_to_year_timestamp as fn(&str) -> CometResult>, + ), + ( + Regex::new(r"^\d{4}-\d{2}$").unwrap(), + parse_str_to_month_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap(), + parse_str_to_day_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{1,2}$").unwrap(), + parse_str_to_hour_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(), + parse_str_to_minute_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(), + parse_str_to_second_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(), + parse_str_to_microsecond_timestamp, + ), + ( + Regex::new(r"^T\d{1,2}$").unwrap(), + parse_str_to_time_only_timestamp, + ), + ]; + + let mut timestamp = None; + + // Iterate through patterns and try matching + for (pattern, parse_func) in patterns { + if pattern.is_match(value) { + timestamp = parse_func(value)?; + break; + } + } + + if timestamp.is_none() { + if eval_mode == EvalMode::Ansi { + return Err(CometError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "TIMESTAMP".to_string(), + }); + } else { + return Ok(None); + } + } + + match timestamp { + Some(ts) => Ok(Some(ts)), + None => Err(CometError::Internal( + "Failed to parse timestamp".to_string(), + )), + } +} + +fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> CometResult> { + let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, 0, 0, 0); + + // Check if datetime is not None + let utc_datetime = match datetime.single() { + Some(dt) => dt.with_timezone(&chrono::Utc), + None => { + return Err(CometError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + Ok(Some(utc_datetime.timestamp_micros())) +} + +fn parse_hms_timestamp( + year: i32, + month: u32, + day: u32, + hour: u32, + minute: u32, + second: u32, + microsecond: u32, +) -> 
CometResult> { + let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, hour, minute, second); + + // Check if datetime is not None + let utc_datetime = match datetime.single() { + Some(dt) => dt + .with_timezone(&chrono::Utc) + .with_nanosecond(microsecond * 1000), + None => { + return Err(CometError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + let result = match utc_datetime { + Some(dt) => dt.timestamp_micros(), + None => { + return Err(CometError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + Ok(Some(result)) +} + +fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult> { + let values: Vec<_> = value + .split(|c| c == 'T' || c == '-' || c == ':' || c == '.') + .collect(); + let year = values[0].parse::().unwrap_or_default(); + let month = values.get(1).map_or(1, |m| m.parse::().unwrap_or(1)); + let day = values.get(2).map_or(1, |d| d.parse::().unwrap_or(1)); + let hour = values.get(3).map_or(0, |h| h.parse::().unwrap_or(0)); + let minute = values.get(4).map_or(0, |m| m.parse::().unwrap_or(0)); + let second = values.get(5).map_or(0, |s| s.parse::().unwrap_or(0)); + let microsecond = values.get(6).map_or(0, |ms| ms.parse::().unwrap_or(0)); + + match timestamp_type { + "year" => parse_ymd_timestamp(year, 1, 1), + "month" => parse_ymd_timestamp(year, month, 1), + "day" => parse_ymd_timestamp(year, month, day), + "hour" => parse_hms_timestamp(year, month, day, hour, 0, 0, 0), + "minute" => parse_hms_timestamp(year, month, day, hour, minute, 0, 0), + "second" => parse_hms_timestamp(year, month, day, hour, minute, second, 0), + "microsecond" => parse_hms_timestamp(year, month, day, hour, minute, second, microsecond), + _ => Err(CometError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "TIMESTAMP".to_string(), + }), + } +} + +fn parse_str_to_year_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "year") +} + +fn parse_str_to_month_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "month") +} + +fn parse_str_to_day_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "day") +} + +fn parse_str_to_hour_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "hour") +} + +fn parse_str_to_minute_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "minute") +} + +fn parse_str_to_second_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "second") +} + +fn parse_str_to_microsecond_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "microsecond") +} + +fn parse_str_to_time_only_timestamp(value: &str) -> CometResult> { + let values: Vec<&str> = value.split('T').collect(); + let time_values: Vec = values[1] + .split(':') + .map(|v| v.parse::().unwrap_or(0)) + .collect(); + + let datetime = chrono::Utc::now(); + let timestamp = datetime + .with_hour(time_values.first().copied().unwrap_or_default()) + .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0))) + .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0))) + .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000)) + .map(|dt| dt.to_utc().timestamp_micros()) + .unwrap_or_default(); + + Ok(Some(timestamp)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::TimestampMicrosecondType; + use arrow_array::StringArray; + use arrow_schema::TimeUnit; + + #[test] + fn timestamp_parser_test() { + // write for all formats + assert_eq!( + 
timestamp_parser("2020", EvalMode::Legacy).unwrap(), + Some(1577836800000000) // this is in milliseconds + ); + assert_eq!( + timestamp_parser("2020-01", EvalMode::Legacy).unwrap(), + Some(1577836800000000) + ); + assert_eq!( + timestamp_parser("2020-01-01", EvalMode::Legacy).unwrap(), + Some(1577836800000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12", EvalMode::Legacy).unwrap(), + Some(1577880000000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34", EvalMode::Legacy).unwrap(), + Some(1577882040000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy).unwrap(), + Some(1577882096000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy).unwrap(), + Some(1577882096123456) + ); + // assert_eq!( + // timestamp_parser("T2", EvalMode::Legacy).unwrap(), + // Some(1714356000000000) // this value needs to change everyday. + // ); + } + + #[test] + fn test_cast_string_to_timestamp() { + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020-01-01T12:34:56.123456"), + Some("T2"), + ])); + + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let eval_mode = EvalMode::Legacy; + let result = cast_utf8_to_timestamp!( + &string_array, + eval_mode, + TimestampMicrosecondType, + timestamp_parser + ); + + assert_eq!( + result.data_type(), + &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())) + ); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_cast_string_as_i8() { + // basic + assert_eq!( + cast_string_to_i8("127", EvalMode::Legacy).unwrap(), + Some(127_i8) + ); + assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None); + assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err()); + // decimals + assert_eq!( + cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(), + Some(0_i8) + ); + assert_eq!( + cast_string_to_i8(".", EvalMode::Legacy).unwrap(), + Some(0_i8) + ); + // TRY should always return null for decimals + assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None); + assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None); + // ANSI mode should throw error on decimal + assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err()); + assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err()); + } +} diff --git a/core/src/execution/datafusion/mod.rs b/core/src/execution/datafusion/mod.rs index c464eeed0..76f0b1c76 100644 --- a/core/src/execution/datafusion/mod.rs +++ b/core/src/execution/datafusion/mod.rs @@ -17,7 +17,7 @@ //! Native execution through DataFusion -mod expressions; +pub mod expressions; mod operators; pub mod planner; pub(crate) mod shuffle_writer; diff --git a/docs/source/contributor-guide/development.md b/docs/source/contributor-guide/development.md index 63146c191..356c81b33 100644 --- a/docs/source/contributor-guide/development.md +++ b/docs/source/contributor-guide/development.md @@ -40,7 +40,7 @@ A few common commands are specified in project's `Makefile`: - `make`: compile the entire project, but don't run tests - `make test-rust`: compile the project and run tests in Rust side -- `make test-java`: compile the project and run tests in Java side +- `make test-jvm`: compile the project and run tests in Java side - `make test`: compile the project and run tests in both Rust and Java side. - `make release`: compile the project and creates a release build. 
This diff --git a/docs/source/user-guide/compatibility-template.md b/docs/source/user-guide/compatibility-template.md new file mode 100644 index 000000000..deaca2d24 --- /dev/null +++ b/docs/source/user-guide/compatibility-template.md @@ -0,0 +1,50 @@ + + +# Compatibility Guide + +Comet aims to provide consistent results with the version of Apache Spark that is being used. + +This guide offers information about areas of functionality where there are known differences. + +## ANSI mode + +Comet currently ignores ANSI mode in most cases, and therefore can produce different results than Spark. By default, +Comet will fall back to Spark if ANSI mode is enabled. To enable Comet to accelerate queries when ANSI mode is enabled, +specify `spark.comet.ansi.enabled=true` in the Spark configuration. Comet's ANSI support is experimental and should not +be used in production. + +There is an [epic](https://github.com/apache/datafusion-comet/issues/313) where we are tracking the work to fully implement ANSI support. + +## Cast + +Cast operations in Comet fall into three levels of support: + +- **Compatible**: The results match Apache Spark +- **Incompatible**: The results may match Apache Spark for some inputs, but there are known issues where some inputs + will result in incorrect results or exceptions. The query stage will fall back to Spark by default. Setting + `spark.comet.cast.allowIncompatible=true` will allow all incompatible casts to run natively in Comet, but this is not + recommended for production use. +- **Unsupported**: Comet does not provide a native version of this cast expression and the query stage will fall back to + Spark. + +The following table shows the current cast operations supported by Comet. Any cast that does not appear in this +table (such as those involving complex types and timestamp_ntz, for example) are not supported by Comet. + + diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md index b4b4c92eb..9a2478d37 100644 --- a/docs/source/user-guide/compatibility.md +++ b/docs/source/user-guide/compatibility.md @@ -1,20 +1,20 @@ # Compatibility Guide @@ -34,13 +34,126 @@ There is an [epic](https://github.com/apache/datafusion-comet/issues/313) where ## Cast -Comet currently delegates to Apache DataFusion for most cast operations, and this means that the behavior is not -guaranteed to be consistent with Spark. +Cast operations in Comet fall into three levels of support: -There is an [epic](https://github.com/apache/datafusion-comet/issues/286) where we are tracking the work to implement Spark-compatible cast expressions. +- **Compatible**: The results match Apache Spark +- **Incompatible**: The results may match Apache Spark for some inputs, but there are known issues where some inputs + will result in incorrect results or exceptions. The query stage will fall back to Spark by default. Setting + `spark.comet.cast.allowIncompatible=true` will allow all incompatible casts to run natively in Comet, but this is not + recommended for production use. +- **Unsupported**: Comet does not provide a native version of this cast expression and the query stage will fall back to + Spark. -### Cast from String to Timestamp +The following table shows the current cast operations supported by Comet. Any cast that does not appear in this +table (such as those involving complex types and timestamp_ntz, for example) are not supported by Comet. 
-Casting from String to Timestamp is disabled by default due to incompatibilities with Spark, including timezone
-issues, and can be enabled by setting `spark.comet.castStringToTimestamp=true`. See the
-[tracking issue](https://github.com/apache/datafusion-comet/issues/328) for more information.
+| From Type | To Type | Compatible? | Notes |
+| --------- | --------- | ------------ | ----------------------------------- |
+| boolean | byte | Compatible | |
+| boolean | short | Compatible | |
+| boolean | integer | Compatible | |
+| boolean | long | Compatible | |
+| boolean | float | Compatible | |
+| boolean | double | Compatible | |
+| boolean | decimal | Unsupported | |
+| boolean | string | Compatible | |
+| boolean | timestamp | Unsupported | |
+| byte | boolean | Compatible | |
+| byte | short | Compatible | |
+| byte | integer | Compatible | |
+| byte | long | Compatible | |
+| byte | float | Compatible | |
+| byte | double | Compatible | |
+| byte | decimal | Compatible | |
+| byte | string | Compatible | |
+| byte | binary | Unsupported | |
+| byte | timestamp | Unsupported | |
+| short | boolean | Compatible | |
+| short | byte | Compatible | |
+| short | integer | Compatible | |
+| short | long | Compatible | |
+| short | float | Compatible | |
+| short | double | Compatible | |
+| short | decimal | Compatible | |
+| short | string | Compatible | |
+| short | binary | Unsupported | |
+| short | timestamp | Unsupported | |
+| integer | boolean | Compatible | |
+| integer | byte | Compatible | |
+| integer | short | Compatible | |
+| integer | long | Compatible | |
+| integer | float | Compatible | |
+| integer | double | Compatible | |
+| integer | decimal | Compatible | |
+| integer | string | Compatible | |
+| integer | binary | Unsupported | |
+| integer | timestamp | Unsupported | |
+| long | boolean | Compatible | |
+| long | byte | Compatible | |
+| long | short | Compatible | |
+| long | integer | Compatible | |
+| long | float | Compatible | |
+| long | double | Compatible | |
+| long | decimal | Compatible | |
+| long | string | Compatible | |
+| long | binary | Unsupported | |
+| long | timestamp | Unsupported | |
+| float | boolean | Compatible | |
+| float | byte | Unsupported | |
+| float | short | Unsupported | |
+| float | integer | Unsupported | |
+| float | long | Unsupported | |
+| float | double | Compatible | |
+| float | decimal | Unsupported | |
+| float | string | Incompatible | |
+| float | timestamp | Unsupported | |
+| double | boolean | Compatible | |
+| double | byte | Unsupported | |
+| double | short | Unsupported | |
+| double | integer | Unsupported | |
+| double | long | Unsupported | |
+| double | float | Compatible | |
+| double | decimal | Incompatible | |
+| double | string | Incompatible | |
+| double | timestamp | Unsupported | |
+| decimal | boolean | Unsupported | |
+| decimal | byte | Unsupported | |
+| decimal | short | Unsupported | |
+| decimal | integer | Unsupported | |
+| decimal | long | Unsupported | |
+| decimal | float | Compatible | |
+| decimal | double | Compatible | |
+| decimal | string | Unsupported | |
+| decimal | timestamp | Unsupported | |
+| string | boolean | Compatible | |
+| string | byte | Compatible | |
+| string | short | Compatible | |
+| string | integer | Compatible | |
+| string | long | Compatible | |
+| string | float | Unsupported | |
+| string | double | Unsupported | |
+| string | decimal | Unsupported | |
+| string | binary | Compatible | |
+| string | date | Unsupported | |
+| string | timestamp | Incompatible | Not all valid formats are supported |
+| binary | string | Incompatible | |
+| date | boolean | Unsupported | |
+| date | byte | Unsupported | |
+| date | short | Unsupported | |
+| date | integer | Unsupported | |
+| date | long | Unsupported | |
+| date | float | Unsupported | |
+| date | double | Unsupported | |
+| date | decimal | Unsupported | |
+| date | string | Compatible | |
+| date | timestamp | Unsupported | |
+| timestamp | boolean | Unsupported | |
+| timestamp | byte | Unsupported | |
+| timestamp | short | Unsupported | |
+| timestamp | integer | Unsupported | |
+| timestamp | long | Compatible | |
+| timestamp | float | Unsupported | |
+| timestamp | double | Unsupported | |
+| timestamp | decimal | Unsupported | |
+| timestamp | string | Compatible | |
+| timestamp | date | Compatible | |
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index 2a7003327..22a7a0982 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -25,7 +25,7 @@ Comet provides the following configuration settings.
 |--------|-------------|---------------|
 | spark.comet.ansi.enabled | Comet does not respect ANSI mode in most cases and by default will not accelerate queries when ansi mode is enabled. Enable this setting to test Comet's experimental support for ANSI mode. This should not be used in production. | false |
 | spark.comet.batchSize | The columnar batch size, i.e., the maximum number of rows that a batch can contain. | 8192 |
-| spark.comet.cast.stringToTimestamp | Comet is not currently fully compatible with Spark when casting from String to Timestamp. | false |
+| spark.comet.cast.allowIncompatible | Comet is not currently fully compatible with Spark for all cast operations. Set this config to true to allow them anyway. See compatibility guide for more information. | false |
 | spark.comet.columnar.shuffle.async.enabled | Whether to enable asynchronous shuffle for Arrow-based shuffle. By default, this config is false. | false |
 | spark.comet.columnar.shuffle.async.max.thread.num | Maximum number of threads on an executor used for Comet async columnar shuffle. By default, this config is 100. This is the upper bound of total number of shuffle threads per executor. In other words, if the number of cores * the number of shuffle threads per task `spark.comet.columnar.shuffle.async.thread.num` is larger than this config. Comet will use this config as the number of shuffle threads per executor instead. | 100 |
 | spark.comet.columnar.shuffle.async.thread.num | Number of threads used for Comet async columnar shuffle per shuffle task. By default, this config is 3. Note that more threads means more memory requirement to buffer shuffle data before flushing to disk. Also, more threads may not always improve performance, and should be set based on the number of cores available. | 3 |
diff --git a/pom.xml b/pom.xml
index 6d28c8168..d47953fac 100644
--- a/pom.xml
+++ b/pom.xml
@@ -886,7 +886,7 @@ under the License.
 rust-toolchain
 Makefile
 dev/Dockerfile*
-dev/diff/**
+dev/diffs/**
 dev/deploy-file
 **/test/resources/**
 **/benchmarks/*.txt
diff --git a/spark/pom.xml b/spark/pom.xml
index 9392b7fe9..7d3d3d758 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -58,6 +58,11 @@ under the License.
       <groupId>org.scala-lang</groupId>
       <artifactId>scala-library</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-reflect</artifactId>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>com.google.protobuf</groupId>
       <artifactId>protobuf-java</artifactId>
@@ -270,17 +275,13 @@ under the License.
3.2.0 - generate-config-docs + generate-user-guide-reference-docs package java - org.apache.comet.CometConfGenerateDocs - - docs/source/user-guide/configs-template.md - docs/source/user-guide/configs.md - + org.apache.comet.GenerateDocs compile diff --git a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala new file mode 100644 index 000000000..8c414c7fe --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import java.io.{BufferedOutputStream, FileOutputStream} + +import scala.io.Source + +import org.apache.spark.sql.catalyst.expressions.Cast + +import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported} + +/** + * Utility for generating markdown documentation from the configs. + * + * This is invoked when running `mvn clean package -DskipTests`. + */ +object GenerateDocs { + + def main(args: Array[String]): Unit = { + generateConfigReference() + generateCompatibilityGuide() + } + + private def generateConfigReference(): Unit = { + val templateFilename = "docs/source/user-guide/configs-template.md" + val outputFilename = "docs/source/user-guide/configs.md" + val w = new BufferedOutputStream(new FileOutputStream(outputFilename)) + for (line <- Source.fromFile(templateFilename).getLines()) { + if (line.trim == "") { + val publicConfigs = CometConf.allConfs.filter(_.isPublic) + val confs = publicConfigs.sortBy(_.key) + w.write("| Config | Description | Default Value |\n".getBytes) + w.write("|--------|-------------|---------------|\n".getBytes) + for (conf <- confs) { + w.write(s"| ${conf.key} | ${conf.doc.trim} | ${conf.defaultValueString} |\n".getBytes) + } + } else { + w.write(s"${line.trim}\n".getBytes) + } + } + w.close() + } + + private def generateCompatibilityGuide(): Unit = { + val templateFilename = "docs/source/user-guide/compatibility-template.md" + val outputFilename = "docs/source/user-guide/compatibility.md" + val w = new BufferedOutputStream(new FileOutputStream(outputFilename)) + for (line <- Source.fromFile(templateFilename).getLines()) { + if (line.trim == "") { + w.write("| From Type | To Type | Compatible? 
| Notes |\n".getBytes) + w.write("|-|-|-|-|\n".getBytes) + for (fromType <- CometCast.supportedTypes) { + for (toType <- CometCast.supportedTypes) { + if (Cast.canCast(fromType, toType) && fromType != toType) { + val fromTypeName = fromType.typeName.replace("(10,2)", "") + val toTypeName = toType.typeName.replace("(10,2)", "") + CometCast.isSupported(fromType, toType, None, "LEGACY") match { + case Compatible => + w.write(s"| $fromTypeName | $toTypeName | Compatible | |\n".getBytes) + case Incompatible(Some(reason)) => + w.write(s"| $fromTypeName | $toTypeName | Incompatible | $reason |\n".getBytes) + case Incompatible(None) => + w.write(s"| $fromTypeName | $toTypeName | Incompatible | |\n".getBytes) + case Unsupported => + w.write(s"| $fromTypeName | $toTypeName | Unsupported | |\n".getBytes) + } + } + } + } + } else { + w.write(s"${line.trim}\n".getBytes) + } + } + w.close() + } +} diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala new file mode 100644 index 000000000..5641c94a8 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.expressions + +import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType} + +sealed trait SupportLevel + +/** We support this feature with full compatibility with Spark */ +object Compatible extends SupportLevel + +/** We support this feature but results can be different from Spark */ +case class Incompatible(reason: Option[String] = None) extends SupportLevel + +/** We do not support this feature */ +object Unsupported extends SupportLevel + +object CometCast { + + def supportedTypes: Seq[DataType] = + Seq( + DataTypes.BooleanType, + DataTypes.ByteType, + DataTypes.ShortType, + DataTypes.IntegerType, + DataTypes.LongType, + DataTypes.FloatType, + DataTypes.DoubleType, + DataTypes.createDecimalType(10, 2), + DataTypes.StringType, + DataTypes.BinaryType, + DataTypes.DateType, + DataTypes.TimestampType) + // TODO add DataTypes.TimestampNTZType for Spark 3.4 and later + // https://github.com/apache/datafusion-comet/issues/378 + + def isSupported( + fromType: DataType, + toType: DataType, + timeZoneId: Option[String], + evalMode: String): SupportLevel = { + + if (fromType == toType) { + return Compatible + } + + (fromType, toType) match { + case (dt: DataType, _) if dt.typeName == "timestamp_ntz" => + // https://github.com/apache/datafusion-comet/issues/378 + toType match { + case DataTypes.TimestampType | DataTypes.DateType | DataTypes.StringType => + Incompatible() + case _ => + Unsupported + } + case (_: DecimalType, _: DecimalType) => + // https://github.com/apache/datafusion-comet/issues/375 + Incompatible() + case (DataTypes.StringType, _) => + canCastFromString(toType, timeZoneId, evalMode) + case (_, DataTypes.StringType) => + canCastToString(fromType) + case (DataTypes.TimestampType, _) => + canCastFromTimestamp(toType) + case (_: DecimalType, _) => + canCastFromDecimal(toType) + case (DataTypes.BooleanType, _) => + canCastFromBoolean(toType) + case ( + DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType, + _) => + canCastFromInt(toType) + case (DataTypes.FloatType, _) => + canCastFromFloat(toType) + case (DataTypes.DoubleType, _) => + canCastFromDouble(toType) + case _ => Unsupported + } + } + + private def canCastFromString( + toType: DataType, + timeZoneId: Option[String], + evalMode: String): SupportLevel = { + toType match { + case DataTypes.BooleanType => + Compatible + case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | + DataTypes.LongType => + Compatible + case DataTypes.BinaryType => + Compatible + case DataTypes.FloatType | DataTypes.DoubleType => + // https://github.com/apache/datafusion-comet/issues/326 + Unsupported + case _: DecimalType => + // https://github.com/apache/datafusion-comet/issues/325 + Unsupported + case DataTypes.DateType => + // https://github.com/apache/datafusion-comet/issues/327 + Unsupported + case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") => + Incompatible(Some(s"Cast will use UTC instead of $timeZoneId")) + case DataTypes.TimestampType if evalMode == "ANSI" => + Incompatible(Some("ANSI mode not supported")) + case DataTypes.TimestampType => + // https://github.com/apache/datafusion-comet/issues/328 + Incompatible(Some("Not all valid formats are supported")) + case _ => + Unsupported + } + } + + private def canCastToString(fromType: DataType): SupportLevel = { + fromType match { + case DataTypes.BooleanType => Compatible + case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | + DataTypes.LongType => + Compatible + 
case DataTypes.DateType => Compatible + case DataTypes.TimestampType => Compatible + case DataTypes.FloatType | DataTypes.DoubleType => + // https://github.com/apache/datafusion-comet/issues/326 + Incompatible() + case DataTypes.BinaryType => + // https://github.com/apache/datafusion-comet/issues/377 + Incompatible() + case _ => Unsupported + } + } + + private def canCastFromTimestamp(toType: DataType): SupportLevel = { + toType match { + case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType | + DataTypes.IntegerType => + // https://github.com/apache/datafusion-comet/issues/352 + // this seems like an edge case that isn't important for us to support + Unsupported + case DataTypes.LongType => + // https://github.com/apache/datafusion-comet/issues/352 + Compatible + case DataTypes.StringType => Compatible + case DataTypes.DateType => Compatible + case _ => Unsupported + } + } + + private def canCastFromBoolean(toType: DataType): SupportLevel = toType match { + case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType | + DataTypes.FloatType | DataTypes.DoubleType => + Compatible + case _ => Unsupported + } + + private def canCastFromInt(toType: DataType): SupportLevel = toType match { + case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType | + DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType | DataTypes.DoubleType | + _: DecimalType => + Compatible + case _ => Unsupported + } + + private def canCastFromFloat(toType: DataType): SupportLevel = toType match { + case DataTypes.BooleanType | DataTypes.DoubleType => Compatible + case _ => Unsupported + } + + private def canCastFromDouble(toType: DataType): SupportLevel = toType match { + case DataTypes.BooleanType | DataTypes.FloatType => Compatible + case _: DecimalType => Incompatible() + case _ => Unsupported + } + + private def canCastFromDecimal(toType: DataType): SupportLevel = toType match { + case DataTypes.FloatType | DataTypes.DoubleType => Compatible + case _ => Unsupported + } + +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index c1d787fbb..63b23ba1e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -43,6 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus, withInfo} +import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported} import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} @@ -585,20 +586,35 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY evalMode.toString } - val supportedCast = (child.dataType, dt) match { - case (DataTypes.StringType, DataTypes.TimestampType) - if !CometConf.COMET_CAST_STRING_TO_TIMESTAMP.get() => - // https://github.com/apache/datafusion-comet/issues/328 - withInfo(expr, s"${CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key} is disabled") - false - case _ => true - } - if (supportedCast) { - 
castToProto(timeZoneId, dt, childExpr, evalModeStr) - } else { - // no need to call withInfo here since it was called when determining - // the value for `supportedCast` - None + val castSupport = + CometCast.isSupported(child.dataType, dt, timeZoneId, evalModeStr) + + def getIncompatMessage(reason: Option[String]) = + "Comet does not guarantee correct results for cast " + + s"from ${child.dataType} to $dt " + + s"with timezone $timeZoneId and evalMode $evalModeStr" + + reason.map(str => s" ($str)").getOrElse("") + + castSupport match { + case Compatible => + castToProto(timeZoneId, dt, childExpr, evalModeStr) + case Incompatible(reason) => + if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) { + logWarning(getIncompatMessage(reason)) + castToProto(timeZoneId, dt, childExpr, evalModeStr) + } else { + withInfo( + expr, + s"${getIncompatMessage(reason)}. To enable all incompatible casts, set " + + s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true") + None + } + case Unsupported => + withInfo( + expr, + s"Unsupported cast from ${child.dataType} to $dt " + + s"with timezone $timeZoneId and evalMode $evalModeStr") + None } } else { withInfo(expr, child) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index c6a7c7223..54b136791 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.comet.expressions.CometCast + class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ @@ -40,7 +42,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // but this is likely a reasonable starting point for now private val whitespaceChars = " \t\r\n" - private val numericPattern = "0123456789e+-." + whitespaceChars + /** + * We use these characters to construct strings that potentially represent valid numbers such as + * `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as `+e.-d`. + */ + private val numericPattern = "0123456789deEf+-." 
+ whitespaceChars
+
   private val datePattern = "0123456789/" + whitespaceChars
   private val timestampPattern = "0123456789/:T" + whitespaceChars
 
@@ -67,22 +74,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
 
-    // make sure we have tests for all combinations of our supported types
-    val supportedTypes =
-      Seq(
-        DataTypes.BooleanType,
-        DataTypes.ByteType,
-        DataTypes.ShortType,
-        DataTypes.IntegerType,
-        DataTypes.LongType,
-        DataTypes.FloatType,
-        DataTypes.DoubleType,
-        DataTypes.createDecimalType(10, 2),
-        DataTypes.StringType,
-        DataTypes.DateType,
-        DataTypes.TimestampType)
-    // TODO add DataTypes.TimestampNTZType for Spark 3.4 and later
-    assertTestsExist(supportedTypes, supportedTypes)
+    assertTestsExist(CometCast.supportedTypes, CometCast.supportedTypes)
   }
 
   // CAST from BooleanType
@@ -159,6 +151,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateBytes(), DataTypes.StringType)
   }
 
+  ignore("cast ByteType to BinaryType") {
+    castTest(generateBytes(), DataTypes.BinaryType)
+  }
+
   ignore("cast ByteType to TimestampType") {
     // input: -1, expected: 1969-12-31 15:59:59.0, actual: 1969-12-31 15:59:59.999999
     castTest(generateBytes(), DataTypes.TimestampType)
@@ -199,6 +195,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateShorts(), DataTypes.StringType)
   }
 
+  ignore("cast ShortType to BinaryType") {
+    castTest(generateShorts(), DataTypes.BinaryType)
+  }
+
   ignore("cast ShortType to TimestampType") {
     // input: -1003, expected: 1969-12-31 15:43:17.0, actual: 1969-12-31 15:59:59.998997
     castTest(generateShorts(), DataTypes.TimestampType)
@@ -241,6 +241,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateInts(), DataTypes.StringType)
   }
 
+  ignore("cast IntegerType to BinaryType") {
+    castTest(generateInts(), DataTypes.BinaryType)
+  }
+
   ignore("cast IntegerType to TimestampType") {
     // input: -1000479329, expected: 1938-04-19 01:04:31.0, actual: 1969-12-31 15:43:19.520671
     castTest(generateInts(), DataTypes.TimestampType)
@@ -284,6 +288,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateLongs(), DataTypes.StringType)
   }
 
+  ignore("cast LongType to BinaryType") {
+    castTest(generateLongs(), DataTypes.BinaryType)
+  }
+
   ignore("cast LongType to TimestampType") {
     // java.lang.ArithmeticException: long overflow
     castTest(generateLongs(), DataTypes.TimestampType)
@@ -324,9 +332,22 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
   }
 
-  ignore("cast FloatType to StringType") {
+  test("cast FloatType to StringType") {
     // https://github.com/apache/datafusion-comet/issues/312
-    castTest(generateFloats(), DataTypes.StringType)
+    val r = new Random(0)
+    val values = Seq(
+      Float.MaxValue,
+      Float.MinValue,
+      Float.NaN,
+      Float.PositiveInfinity,
+      Float.NegativeInfinity,
+      1.0f,
+      -1.0f,
+      Short.MinValue.toFloat,
+      Short.MaxValue.toFloat,
+      0.0f) ++
+      Range(0, dataSize).map(_ => r.nextFloat())
+    castTest(withNulls(values).toDF("a"), DataTypes.StringType)
   }
 
   ignore("cast FloatType to TimestampType") {
@@ -369,9 +390,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
   }
 
-  ignore("cast DoubleType to StringType") {
+  test("cast DoubleType to StringType") {
     // https://github.com/apache/datafusion-comet/issues/312
-    castTest(generateDoubles(), DataTypes.StringType)
+    val
r = new Random(0)
+    val values = Seq(
+      Double.MaxValue,
+      Double.MinValue,
+      Double.NaN,
+      Double.PositiveInfinity,
+      Double.NegativeInfinity,
+      0.0d) ++
+      Range(0, dataSize).map(_ => r.nextDouble())
+    castTest(withNulls(values).toDF("a"), DataTypes.StringType)
   }
 
   ignore("cast DoubleType to TimestampType") {
@@ -433,23 +463,64 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(testValues, DataTypes.BooleanType)
   }
 
-  ignore("cast StringType to ByteType") {
-    // https://github.com/apache/datafusion-comet/issues/15
-    castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ByteType)
-  }
-
-  ignore("cast StringType to ShortType") {
-    // https://github.com/apache/datafusion-comet/issues/15
-    castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ShortType)
-  }
-
-  ignore("cast StringType to IntegerType") {
-    // https://github.com/apache/datafusion-comet/issues/15
+  private val castStringToIntegralInputs: Seq[String] = Seq(
+    "",
+    ".",
+    "+",
+    "-",
+    "+.",
+    "-.",
+    "-0",
+    "+1",
+    "-1",
+    ".2",
+    "-.2",
+    "1e1",
+    "1.1d",
+    "1.1f",
+    Byte.MinValue.toString,
+    (Byte.MinValue.toShort - 1).toString,
+    Byte.MaxValue.toString,
+    (Byte.MaxValue.toShort + 1).toString,
+    Short.MinValue.toString,
+    (Short.MinValue.toInt - 1).toString,
+    Short.MaxValue.toString,
+    (Short.MaxValue.toInt + 1).toString,
+    Int.MinValue.toString,
+    (Int.MinValue.toLong - 1).toString,
+    Int.MaxValue.toString,
+    (Int.MaxValue.toLong + 1).toString,
+    Long.MinValue.toString,
+    Long.MaxValue.toString,
+    "-9223372036854775809", // Long.MinValue -1
+    "9223372036854775808" // Long.MaxValue + 1
+  )
+
+  test("cast StringType to ByteType") {
+    // test with hand-picked values
+    castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType)
+    // fuzz test
+    castTest(generateStrings(numericPattern, 4).toDF("a"), DataTypes.ByteType)
+  }
+
+  test("cast StringType to ShortType") {
+    // test with hand-picked values
+    castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType)
+    // fuzz test
+    castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ShortType)
+  }
+
+  test("cast StringType to IntegerType") {
+    // test with hand-picked values
+    castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType)
+    // fuzz test
     castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType)
   }
 
-  ignore("cast StringType to LongType") {
-    // https://github.com/apache/datafusion-comet/issues/15
+  test("cast StringType to LongType") {
+    // test with hand-picked values
+    castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType)
+    // fuzz test
     castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType)
   }
 
@@ -469,27 +540,76 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(values, DataTypes.createDecimalType(10, 2))
   }
 
+  test("cast StringType to BinaryType") {
+    castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.BinaryType)
+  }
+
   ignore("cast StringType to DateType") {
     // https://github.com/apache/datafusion-comet/issues/327
     castTest(generateStrings(datePattern, 8).toDF("a"), DataTypes.DateType)
   }
 
   test("cast StringType to TimestampType disabled by default") {
-    val values = Seq("2020-01-01T12:34:56.123456", "T2").toDF("a")
-    castFallbackTest(
-      values.toDF("a"),
-      DataTypes.TimestampType,
-      "spark.comet.cast.stringToTimestamp is disabled")
+    withSQLConf((SQLConf.SESSION_LOCAL_TIMEZONE.key, "UTC")) {
+      val values = Seq("2020-01-01T12:34:56.123456", "T2").toDF("a")
+      castFallbackTest(
+        values.toDF("a"),
+
DataTypes.TimestampType, + "Not all valid formats are supported") + } + } + + test("cast StringType to TimestampType disabled for non-UTC timezone") { + withSQLConf((SQLConf.SESSION_LOCAL_TIMEZONE.key, "America/Denver")) { + val values = Seq("2020-01-01T12:34:56.123456", "T2").toDF("a") + castFallbackTest( + values.toDF("a"), + DataTypes.TimestampType, + "Cast will use UTC instead of Some(America/Denver)") + } } - ignore("cast StringType to TimestampType") { + ignore("cast StringType to TimestampType (fuzz test)") { // https://github.com/apache/datafusion-comet/issues/328 - withSQLConf((CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key, "true")) { + withSQLConf((CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key, "true")) { val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ generateStrings(timestampPattern, 8) castTest(values.toDF("a"), DataTypes.TimestampType) } } + test("cast StringType to TimestampType") { + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + val values = Seq( + "2020", + "2020-01", + "2020-01-01", + "2020-01-01T12", + "2020-01-01T12:34", + "2020-01-01T12:34:56", + "2020-01-01T12:34:56.123456", + "T2", + "-9?") + castTimestampTest(values.toDF("a"), DataTypes.TimestampType) + } + + // test for invalid inputs + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + val values = Seq("-9?", "1-", "0.5") + castTimestampTest(values.toDF("a"), DataTypes.TimestampType) + } + } + + // CAST from BinaryType + + ignore("cast BinaryType to StringType") { + // TODO implement this + // https://github.com/apache/datafusion-comet/issues/377 + } + // CAST from DateType ignore("cast DateType to BooleanType") { @@ -566,9 +686,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(generateTimestamps(), DataTypes.IntegerType) } - ignore("cast TimestampType to LongType") { - // https://github.com/apache/datafusion-comet/issues/352 - // input: 2023-12-31 17:00:00.0, expected: 1.70407078E9, actual: 1.70407082E15] + test("cast TimestampType to LongType") { + assume(CometSparkSessionExtensions.isSpark33Plus) castTest(generateTimestamps(), DataTypes.LongType) } @@ -717,7 +836,28 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - private def castTest(input: DataFrame, toType: DataType): Unit = { + private def castFallbackTestTimezone( + input: DataFrame, + toType: DataType, + expectedMessage: String): Unit = { + withTempPath { dir => + val data = roundtripParquet(input, dir).coalesce(1) + data.createOrReplaceTempView("t") + + withSQLConf( + (SQLConf.ANSI_ENABLED.key, "false"), + (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key, "true"), + (SQLConf.SESSION_LOCAL_TIMEZONE.key, "America/Los_Angeles")) { + val df = data.withColumn("converted", col("a").cast(toType)) + df.collect() + val str = + new ExtendedExplainInfo().generateExtendedInfo(df.queryExecution.executedPlan) + assert(str.contains(expectedMessage)) + } + } + } + + private def castTimestampTest(input: DataFrame, toType: DataType) = { withTempPath { dir => val data = roundtripParquet(input, dir).coalesce(1) data.createOrReplaceTempView("t") @@ -731,6 +871,31 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t") checkSparkAnswer(df2) } + } + } + + private def castTest(input: DataFrame, toType: DataType): Unit = { + + // we do not support the TryCast expression in Spark 
3.2 and 3.3 + // https://github.com/apache/datafusion-comet/issues/374 + val testTryCast = CometSparkSessionExtensions.isSpark34Plus + + withTempPath { dir => + val data = roundtripParquet(input, dir).coalesce(1) + data.createOrReplaceTempView("t") + + withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) { + // cast() should return null for invalid inputs when ansi mode is disabled + val df = spark.sql(s"select a, cast(a as ${toType.sql}) from t order by a") + checkSparkAnswerAndOperator(df) + + // try_cast() should always return null for invalid inputs + if (testTryCast) { + val df2 = + spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a") + checkSparkAnswerAndOperator(df2) + } + } // with ANSI enabled, we should produce the same exception as Spark withSQLConf( @@ -769,9 +934,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } // try_cast() should always return null for invalid inputs - val df2 = - spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a") - checkSparkAnswer(df2) + if (testTryCast) { + val df2 = + spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a") + checkSparkAnswerAndOperator(df2) + } } } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 3683c8d44..c8c7ffd5c 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -259,7 +259,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("cast timestamp and timestamp_ntz") { - withSQLConf(SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu") { + withSQLConf( + SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir => val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") @@ -282,7 +284,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { // TODO: make the test pass for Spark 3.2 & 3.3 assume(isSpark34Plus) - withSQLConf(SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu") { + withSQLConf( + SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir => val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") @@ -305,7 +309,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { // TODO: make the test pass for Spark 3.2 & 3.3 assume(isSpark34Plus) - withSQLConf(SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu") { + withSQLConf( + SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir => val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") @@ -394,32 +400,34 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("date_trunc with timestamp_ntz") { assume(!isSpark32, "timestamp functions for timestamp_ntz have incorrect behavior in 3.2") - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") - makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) - withParquetTable(path.toString, "timetbl") { - Seq( - "YEAR", - "YYYY", - "YY", - "MON", - "MONTH", - "MM", - "QUARTER", 
- "WEEK", - "DAY", - "DD", - "HOUR", - "MINUTE", - "SECOND", - "MILLISECOND", - "MICROSECOND").foreach { format => - checkSparkAnswerAndOperator( - "SELECT " + - s"date_trunc('$format', _3), " + - s"date_trunc('$format', _5) " + - " from timetbl") + withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") + makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) + withParquetTable(path.toString, "timetbl") { + Seq( + "YEAR", + "YYYY", + "YY", + "MON", + "MONTH", + "MM", + "QUARTER", + "WEEK", + "DAY", + "DD", + "HOUR", + "MINUTE", + "SECOND", + "MILLISECOND", + "MICROSECOND").foreach { format => + checkSparkAnswerAndOperator( + "SELECT " + + s"date_trunc('$format', _3), " + + s"date_trunc('$format', _5) " + + " from timetbl") + } } } } @@ -428,22 +436,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("date_trunc with format array") { assume(isSpark33Plus, "TimestampNTZ is supported in Spark 3.3+, See SPARK-36182") - val numRows = 1000 - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "timestamp_trunc_with_format.parquet") - makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) - withParquetTable(path.toString, "timeformattbl") { - checkSparkAnswerAndOperator( - "SELECT " + - "format, _0, _1, _2, _3, _4, _5, " + - "date_trunc(format, _0), " + - "date_trunc(format, _1), " + - "date_trunc(format, _2), " + - "date_trunc(format, _3), " + - "date_trunc(format, _4), " + - "date_trunc(format, _5) " + - " from timeformattbl ") + withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + val numRows = 1000 + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "timestamp_trunc_with_format.parquet") + makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) + withParquetTable(path.toString, "timeformattbl") { + checkSparkAnswerAndOperator( + "SELECT " + + "format, _0, _1, _2, _3, _4, _5, " + + "date_trunc(format, _0), " + + "date_trunc(format, _1), " + + "date_trunc(format, _2), " + + "date_trunc(format, _3), " + + "date_trunc(format, _4), " + + "date_trunc(format, _5) " + + " from timeformattbl ") + } } } } @@ -818,7 +828,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("ceil and floor") { Seq("true", "false").foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary) { + withSQLConf( + "parquet.enable.dictionary" -> dictionary, + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { withParquetTable( (-5 until 5).map(i => (i.toDouble + 0.3, i.toDouble + 0.8)), "tbl", @@ -1406,7 +1418,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("hash functions") { Seq(true, false).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + withSQLConf( + "parquet.enable.dictionary" -> dictionary.toString, + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { val table = "test" withTable(table) { sql(s"create table $table(col string, a int, b float) using parquet") diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index bd4042ec1..fc6876fd1 100644 --- 
a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -863,7 +863,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { Seq(true, false).foreach { nativeShuffleEnabled => withSQLConf( CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> nativeShuffleEnabled.toString, - CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { withTempDir { dir => val path = new Path(dir.toURI.toString, "test") makeParquetFile(path, 1000, 20, dictionaryEnabled) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index f37183563..bfde14033 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -949,7 +949,9 @@ class CometExecSuite extends CometTestBase { } test("SPARK-33474: Support typed literals as partition spec values") { - withSQLConf(SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu") { + withSQLConf( + SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { withTable("t1") { val binaryStr = "Spark SQL" val binaryHexStr = Hex.hex(UTF8String.fromString(binaryStr).getBytes).toString diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala index 48969ea41..90ea79473 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala @@ -261,6 +261,7 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key -> "true", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true", // needed for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64 "spark.sql.readSideCharPadding" -> "false", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") { val qe = sql(queryString).queryExecution
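
Note (not part of the patch above): the new fallback behavior added to QueryPlanSerde reduces to a three-way dispatch on the result of CometCast.isSupported. The sketch below illustrates that dispatch only. It assumes isSupported takes the source type, the target type, an optional timezone and the eval-mode string, as in the hunk above; the helper name shouldRunNatively and its simplified boolean return are hypothetical, since the real code also builds the protobuf cast via castToProto and records fallback reasons with withInfo.

import org.apache.spark.sql.types.DataType

import org.apache.comet.CometConf
import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported}

// Hypothetical helper (illustration only): decide whether Comet should run a cast
// natively or fall back to Spark, mirroring the match added in QueryPlanSerde.
def shouldRunNatively(
    fromType: DataType,
    toType: DataType,
    timeZoneId: Option[String],
    evalMode: String): Boolean =
  CometCast.isSupported(fromType, toType, timeZoneId, evalMode) match {
    case Compatible =>
      // results are expected to match Spark exactly
      true
    case Incompatible(_) =>
      // results may differ from Spark; only run natively when the user opts in
      CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()
    case Unsupported =>
      // no native implementation yet; always fall back to Spark
      false
  }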