From 4c2eecd7524db806d0984b2d4c94a0c5056f8817 Mon Sep 17 00:00:00 2001
From: Eren Avsarogullari
Date: Sat, 24 Feb 2024 10:09:12 -0800
Subject: [PATCH] fix: Cast string to boolean not compatible with Spark

---
 .../execution/datafusion/expressions/cast.rs  | 35 +++++++++++--
 .../apache/comet/exec/CometExecSuite.scala    | 49 ++++++++++++++++++-
 2 files changed, 80 insertions(+), 4 deletions(-)

diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index d8450686d..da2faaf61 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -27,7 +27,7 @@ use arrow::{
     record_batch::RecordBatch,
     util::display::FormatOptions,
 };
-use arrow_array::ArrayRef;
+use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait};
 use arrow_schema::{DataType, Schema};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{Result as DataFusionResult, ScalarValue};
@@ -75,8 +75,37 @@ impl Cast {
     fn cast_array(&self, array: ArrayRef) -> DataFusionResult<ArrayRef> {
         let array = array_with_timezone(array, self.timezone.clone(), Some(&self.data_type));
         let from_type = array.data_type();
-        let cast_result = cast_with_options(&array, &self.data_type, &CAST_OPTIONS)?;
-        Ok(spark_cast(cast_result, from_type, &self.data_type))
+        let to_type = &self.data_type;
+        let cast_result = match (from_type, to_type) {
+            (DataType::Utf8, DataType::Boolean) => Self::spark_cast_utf8_to_boolean::<i32>(&array),
+            (DataType::LargeUtf8, DataType::Boolean) => Self::spark_cast_utf8_to_boolean::<i64>(&array),
+            _ => cast_with_options(&array, &self.data_type, &CAST_OPTIONS)?
+        };
+        let result = spark_cast(cast_result, from_type, &self.data_type);
+        Ok(result)
+    }
+
+    fn spark_cast_utf8_to_boolean<OffsetSize>(from: &dyn Array) -> ArrayRef
+    where
+        OffsetSize: OffsetSizeTrait,
+    {
+        let array = from
+            .as_any()
+            .downcast_ref::<GenericStringArray<OffsetSize>>()
+            .unwrap();
+
+        let output_array = array
+            .iter()
+            .map(|value| match value {
+                Some(value) => match value.to_ascii_lowercase().trim() {
+                    "t" | "true" | "y" | "yes" | "1" => Some(true),
+                    "f" | "false" | "n" | "no" | "0" => Some(false),
+                    _ => None
+                },
+                _ => None
+            }).collect::<BooleanArray>();
+
+        Arc::new(output_array)
+    }
 }

diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 0414671c2..089cb7695 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -20,6 +20,7 @@ package org.apache.comet.exec

 import scala.collection.JavaConverters._
+import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
 import scala.collection.mutable
 import scala.util.Random

@@ -37,9 +38,10 @@ import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{date_add, expr}
+import org.apache.spark.sql.functions.{col, date_add, expr}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
+import org.apache.spark.sql.types.DataTypes
 import org.apache.spark.unsafe.types.UTF8String

 import org.apache.comet.CometConf
@@ -218,6 +220,51 @@
     }
   }

+  test("test cast utf8 to boolean as compatible with Spark") {
+    withSQLConf(
+      CometConf.COMET_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true") {
+      withTable("test_table1", "test_table2", "test_table3", "test_table4") {
+        // Supported boolean values as true by both Arrow and Spark
+        val inputDF = Seq("t", "true", "y", "yes", "1", "T", "TrUe", "Y", "YES").toDF("c1")
+        inputDF.write.format("parquet").saveAsTable("test_table1")
+        val resultDF = this.spark
+          .table("test_table1")
+          .withColumn("converted", col("c1").cast(DataTypes.BooleanType))
+        val resultArr = resultDF.collectAsList().toList
+        resultArr.foreach(x => assert(x.get(1) == true))
+
+        // Supported boolean values as false by both Arrow and Spark
+        val inputDF2 = Seq("f", "false", "n", "no", "0", "F", "FaLSe", "N", "No").toDF("c1")
+        inputDF2.write.format("parquet").saveAsTable("test_table2")
+        val resultDF2 = this.spark
+          .table("test_table2")
+          .withColumn("converted", col("c1").cast(DataTypes.BooleanType))
+        val resultArr2 = resultDF2.collectAsList().toList
+        resultArr2.foreach(x => assert(x.get(1) == false))
+
+        // Supported boolean values by Arrow but not Spark
+        val inputDF3 =
+          Seq("TR", "FA", "tr", "tru", "ye", "on", "fa", "fal", "fals", "of", "off").toDF("c1")
+        inputDF3.write.format("parquet").saveAsTable("test_table3")
+        val resultDF3 = this.spark
+          .table("test_table3")
+          .withColumn("converted", col("c1").cast(DataTypes.BooleanType))
+        val resultArr3 = resultDF3.collectAsList().toList
+        resultArr3.foreach(x => assert(x.get(1) == null))
+
+        // Invalid boolean casting values for Arrow and Spark
+        val inputDF4 = Seq("car", "Truck").toDF("c1")
+        inputDF4.write.format("parquet").saveAsTable("test_table4")
+        val resultDF4 = this.spark
+          .table("test_table4")
+          .withColumn("converted", col("c1").cast(DataTypes.BooleanType))
+        val resultArr4 = resultDF4.collectAsList().toList
+        resultArr4.foreach(x => assert(x.get(1) == null))
+      }
+    }
+  }
+
   test(
     "fix: ReusedExchangeExec + CometShuffleExchangeExec under QueryStageExec " +
       "should be CometRoot") {