From c40bc7c26f13773248921646ac4f755025e4823f Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 6 May 2024 17:20:45 -0700 Subject: [PATCH 1/3] feat: Supports Stddev (#348) * feat: Supports Stddev * fix fmt * update q39a.sql.out * address comments * disable q93a and q93b for now * address comments --------- Co-authored-by: Huaxin Gao --- .../execution/datafusion/expressions/mod.rs | 1 + .../datafusion/expressions/stddev.rs | 179 ++++++++++++++++++ .../datafusion/expressions/variance.rs | 2 - core/src/execution/datafusion/planner.rs | 25 +++ core/src/execution/proto/expr.proto | 8 + docs/source/user-guide/expressions.md | 2 + .../apache/comet/serde/QueryPlanSerde.scala | 42 +++- .../comet/exec/CometAggregateSuite.scala | 43 +++++ .../spark/sql/CometTPCDSQuerySuite.scala | 9 +- 9 files changed, 306 insertions(+), 5 deletions(-) create mode 100644 core/src/execution/datafusion/expressions/stddev.rs diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs index 78763fc2a..10cac1696 100644 --- a/core/src/execution/datafusion/expressions/mod.rs +++ b/core/src/execution/datafusion/expressions/mod.rs @@ -29,6 +29,7 @@ pub mod avg_decimal; pub mod bloom_filter_might_contain; pub mod covariance; pub mod stats; +pub mod stddev; pub mod strings; pub mod subquery; pub mod sum_decimal; diff --git a/core/src/execution/datafusion/expressions/stddev.rs b/core/src/execution/datafusion/expressions/stddev.rs new file mode 100644 index 000000000..bbddf9aa4 --- /dev/null +++ b/core/src/execution/datafusion/expressions/stddev.rs @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use crate::execution::datafusion::expressions::{ + stats::StatsType, utils::down_cast_any_ref, variance::VarianceAccumulator, +}; +use arrow::{ + array::ArrayRef, + datatypes::{DataType, Field}, +}; +use datafusion::logical_expr::Accumulator; +use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; + +/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression +/// The implementation mostly is the same as the DataFusion's implementation. The reason +/// we have our own implementation is that DataFusion has UInt64 for state_field `count`, +/// while Spark has Double for count. Also we have added `null_on_divide_by_zero` +/// to be consistent with Spark's implementation. +#[derive(Debug)] +pub struct Stddev { + name: String, + expr: Arc, + stats_type: StatsType, + null_on_divide_by_zero: bool, +} + +impl Stddev { + /// Create a new STDDEV aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + stats_type: StatsType, + null_on_divide_by_zero: bool, + ) -> Self { + // the result of stddev just support FLOAT64. + assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + stats_type, + null_on_divide_by_zero, + } + } +} + +impl AggregateExpr for Stddev { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(StddevAccumulator::try_new( + self.stats_type, + self.null_on_divide_by_zero, + )?)) + } + + fn create_sliding_accumulator(&self) -> Result> { + Ok(Box::new(StddevAccumulator::try_new( + self.stats_type, + self.null_on_divide_by_zero, + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + format_state_name(&self.name, "count"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for Stddev { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.expr.eq(&x.expr) + && self.null_on_divide_by_zero == x.null_on_divide_by_zero + && self.stats_type == x.stats_type + }) + .unwrap_or(false) + } +} + +/// An accumulator to compute the standard deviation +#[derive(Debug)] +pub struct StddevAccumulator { + variance: VarianceAccumulator, +} + +impl StddevAccumulator { + /// Creates a new `StddevAccumulator` + pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result { + Ok(Self { + variance: VarianceAccumulator::try_new(s_type, null_on_divide_by_zero)?, + }) + } + + pub fn get_m2(&self) -> f64 { + self.variance.get_m2() + } +} + +impl Accumulator for StddevAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.variance.get_count()), + ScalarValue::from(self.variance.get_mean()), + ScalarValue::from(self.variance.get_m2()), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.variance.update_batch(values) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.variance.retract_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.variance.merge_batch(states) + } + + fn evaluate(&mut self) -> Result { + let variance = self.variance.evaluate()?; + match variance { + ScalarValue::Float64(Some(e)) => Ok(ScalarValue::Float64(Some(e.sqrt()))), + ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)), + _ => internal_err!("Variance should be f64"), + } + } + + fn size(&self) -> usize { + std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance) + self.variance.size() + } +} diff --git a/core/src/execution/datafusion/expressions/variance.rs b/core/src/execution/datafusion/expressions/variance.rs index 6aae01ed8..f996c13d8 100644 --- a/core/src/execution/datafusion/expressions/variance.rs +++ b/core/src/execution/datafusion/expressions/variance.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution - use std::{any::Any, sync::Arc}; use crate::execution::datafusion::expressions::{stats::StatsType, utils::down_cast_any_ref}; diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 72174790b..6a050eb8b 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -71,6 +71,7 @@ use crate::{ if_expr::IfExpr, scalar_funcs::create_comet_physical_fun, stats::StatsType, + stddev::Stddev, strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExec, SubstringExec}, subquery::Subquery, sum_decimal::SumDecimal, @@ -1260,6 +1261,30 @@ impl PhysicalPlanner { ))), } } + AggExprStruct::Stddev(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; + let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + match expr.stats_type { + 0 => Ok(Arc::new(Stddev::new( + child, + "stddev", + datatype, + StatsType::Sample, + expr.null_on_divide_by_zero, + ))), + 1 => Ok(Arc::new(Stddev::new( + child, + "stddev_pop", + datatype, + StatsType::Population, + expr.null_on_divide_by_zero, + ))), + stats_type => Err(ExecutionError::GeneralError(format!( + "Unknown StatisticsType {:?} for stddev", + stats_type + ))), + } + } } } diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index 042a981f4..ee3de865a 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -95,6 +95,7 @@ message AggExpr { CovSample covSample = 12; CovPopulation covPopulation = 13; Variance variance = 14; + Stddev stddev = 15; } } @@ -178,6 +179,13 @@ message Variance { StatisticsType stats_type = 4; } +message Stddev { + Expr child = 1; + bool null_on_divide_by_zero = 2; + DataType datatype = 3; + StatisticsType stats_type = 4; +} + message Literal { oneof value { bool bool_val = 1; diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index f67a4eada..38c86c727 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -107,3 +107,5 @@ The following Spark expressions are currently available: - CovSample - VariancePop - VarianceSamp + - StddevPop + - StddevSamp diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e77adc9bb..1e8877c8d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ @@ -506,6 +506,46 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(aggExpr, child) None } + case std @ StddevSamp(child, nullOnDivideByZero) => + val childExpr = exprToProto(child, inputs, binding) + val dataType = serializeDataType(std.dataType) + + if (childExpr.isDefined && dataType.isDefined) { + val stdBuilder = ExprOuterClass.Stddev.newBuilder() + stdBuilder.setChild(childExpr.get) + stdBuilder.setNullOnDivideByZero(nullOnDivideByZero) + stdBuilder.setDatatype(dataType.get) + stdBuilder.setStatsTypeValue(0) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setStddev(stdBuilder) + .build()) + } else { + withInfo(aggExpr, child) + None + } + case std @ StddevPop(child, nullOnDivideByZero) => + val childExpr = exprToProto(child, inputs, binding) + val dataType = serializeDataType(std.dataType) + + if (childExpr.isDefined && dataType.isDefined) { + val stdBuilder = ExprOuterClass.Stddev.newBuilder() + stdBuilder.setChild(childExpr.get) + stdBuilder.setNullOnDivideByZero(nullOnDivideByZero) + stdBuilder.setDatatype(dataType.get) + stdBuilder.setStatsTypeValue(1) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setStddev(stdBuilder) + .build()) + } else { + withInfo(aggExpr, child) + None + } case fn => val msg = s"unsupported Spark aggregate function: ${fn.prettyName}" emitWarning(msg) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 64c031eed..310a24ee3 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1169,6 +1169,49 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("stddev_pop and stddev_samp") { + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + Seq(true, false).foreach { cometColumnShuffleEnabled => + withSQLConf( + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) { + Seq(true, false).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + Seq(true, false).foreach { nullOnDivideByZero => + withSQLConf( + "spark.sql.legacy.statisticalAggregate" -> nullOnDivideByZero.toString) { + val table = "test" + withTable(table) { + sql(s"create table $table(col1 int, col2 int, col3 int, col4 float, " + + "col5 double, col6 int) using parquet") + sql(s"insert into $table values(1, null, null, 1.1, 2.2, 1), " + + "(2, null, null, 3.4, 5.6, 1), (3, null, 4, 7.9, 2.4, 2)") + val expectedNumOfCometAggregates = 2 + checkSparkAnswerWithTolAndNumOfAggregates( + "SELECT stddev_samp(col1), stddev_samp(col2), stddev_samp(col3), " + + "stddev_samp(col4), stddev_samp(col5) FROM test", + expectedNumOfCometAggregates) + checkSparkAnswerWithTolAndNumOfAggregates( + "SELECT stddev_pop(col1), stddev_pop(col2), stddev_pop(col3), " + + "stddev_pop(col4), stddev_pop(col5) FROM test", + expectedNumOfCometAggregates) + checkSparkAnswerAndNumOfAggregates( + "SELECT stddev_samp(col1), stddev_samp(col2), stddev_samp(col3), " + + "stddev_samp(col4), stddev_samp(col5) FROM test GROUP BY col6", + expectedNumOfCometAggregates) + checkSparkAnswerWithTolAndNumOfAggregates( + "SELECT stddev_pop(col1), stddev_pop(col2), stddev_pop(col3), " + + "stddev_pop(col4), stddev_pop(col5) FROM test GROUP BY col6", + expectedNumOfCometAggregates) + } + } + } + } + } + } + } + } + } + protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = { val df = sql(query) checkSparkAnswer(df) diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala index cdbd7194d..3342d750c 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala @@ -73,8 +73,13 @@ class CometTPCDSQuerySuite "q36", "q37", "q38", - "q39a", - "q39b", + // TODO: https://github.com/apache/datafusion-comet/issues/392 + // comment out 39a and 39b for now because the expected result for stddev failed: + // expected: 1.5242630430075292, actual: 1.524263043007529. + // Will change the comparison logic to detect floating-point numbers and compare + // with epsilon + // "q39a", + // "q39b", "q40", "q41", "q42", From 8e73f7cab5489d5918512b4ae206e39b96242320 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 6 May 2024 22:41:45 -0600 Subject: [PATCH 2/3] feat: Improve cast compatibility tests and docs (#379) --- .../user-guide/compatibility-template.md | 18 +- docs/source/user-guide/compatibility.md | 203 ++++++++---------- .../scala/org/apache/comet/GenerateDocs.scala | 37 ++-- .../apache/comet/expressions/CometCast.scala | 107 ++++++--- .../apache/comet/serde/QueryPlanSerde.scala | 2 +- .../org/apache/comet/CometCastSuite.scala | 59 +++-- .../apache/comet/exec/CometExecSuite.scala | 3 +- 7 files changed, 251 insertions(+), 178 deletions(-) diff --git a/docs/source/user-guide/compatibility-template.md b/docs/source/user-guide/compatibility-template.md index deaca2d24..64f871354 100644 --- a/docs/source/user-guide/compatibility-template.md +++ b/docs/source/user-guide/compatibility-template.md @@ -44,7 +44,19 @@ Cast operations in Comet fall into three levels of support: - **Unsupported**: Comet does not provide a native version of this cast expression and the query stage will fall back to Spark. -The following table shows the current cast operations supported by Comet. Any cast that does not appear in this -table (such as those involving complex types and timestamp_ntz, for example) are not supported by Comet. +### Compatible Casts - +The following cast operations are generally compatible with Spark except for the differences noted here. + + + +### Incompatible Casts + +The following cast operations are not compatible with Spark for all inputs and are disabled by default. + + + +### Unsupported Casts + +Any cast not listed in the previous tables is currently unsupported. We are working on adding more. See the +[tracking issue](https://github.com/apache/datafusion-comet/issues/286) for more details. diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md index 9a2478d37..57a4271f4 100644 --- a/docs/source/user-guide/compatibility.md +++ b/docs/source/user-guide/compatibility.md @@ -38,122 +38,89 @@ Cast operations in Comet fall into three levels of support: - **Compatible**: The results match Apache Spark - **Incompatible**: The results may match Apache Spark for some inputs, but there are known issues where some inputs - will result in incorrect results or exceptions. The query stage will fall back to Spark by default. Setting - `spark.comet.cast.allowIncompatible=true` will allow all incompatible casts to run natively in Comet, but this is not - recommended for production use. +will result in incorrect results or exceptions. The query stage will fall back to Spark by default. Setting +`spark.comet.cast.allowIncompatible=true` will allow all incompatible casts to run natively in Comet, but this is not +recommended for production use. - **Unsupported**: Comet does not provide a native version of this cast expression and the query stage will fall back to - Spark. - -The following table shows the current cast operations supported by Comet. Any cast that does not appear in this -table (such as those involving complex types and timestamp_ntz, for example) are not supported by Comet. - -| From Type | To Type | Compatible? | Notes | -| --------- | --------- | ------------ | ----------------------------------- | -| boolean | byte | Compatible | | -| boolean | short | Compatible | | -| boolean | integer | Compatible | | -| boolean | long | Compatible | | -| boolean | float | Compatible | | -| boolean | double | Compatible | | -| boolean | decimal | Unsupported | | -| boolean | string | Compatible | | -| boolean | timestamp | Unsupported | | -| byte | boolean | Compatible | | -| byte | short | Compatible | | -| byte | integer | Compatible | | -| byte | long | Compatible | | -| byte | float | Compatible | | -| byte | double | Compatible | | -| byte | decimal | Compatible | | -| byte | string | Compatible | | -| byte | binary | Unsupported | | -| byte | timestamp | Unsupported | | -| short | boolean | Compatible | | -| short | byte | Compatible | | -| short | integer | Compatible | | -| short | long | Compatible | | -| short | float | Compatible | | -| short | double | Compatible | | -| short | decimal | Compatible | | -| short | string | Compatible | | -| short | binary | Unsupported | | -| short | timestamp | Unsupported | | -| integer | boolean | Compatible | | -| integer | byte | Compatible | | -| integer | short | Compatible | | -| integer | long | Compatible | | -| integer | float | Compatible | | -| integer | double | Compatible | | -| integer | decimal | Compatible | | -| integer | string | Compatible | | -| integer | binary | Unsupported | | -| integer | timestamp | Unsupported | | -| long | boolean | Compatible | | -| long | byte | Compatible | | -| long | short | Compatible | | -| long | integer | Compatible | | -| long | float | Compatible | | -| long | double | Compatible | | -| long | decimal | Compatible | | -| long | string | Compatible | | -| long | binary | Unsupported | | -| long | timestamp | Unsupported | | -| float | boolean | Compatible | | -| float | byte | Unsupported | | -| float | short | Unsupported | | -| float | integer | Unsupported | | -| float | long | Unsupported | | -| float | double | Compatible | | -| float | decimal | Unsupported | | -| float | string | Incompatible | | -| float | timestamp | Unsupported | | -| double | boolean | Compatible | | -| double | byte | Unsupported | | -| double | short | Unsupported | | -| double | integer | Unsupported | | -| double | long | Unsupported | | -| double | float | Compatible | | -| double | decimal | Incompatible | | -| double | string | Incompatible | | -| double | timestamp | Unsupported | | -| decimal | boolean | Unsupported | | -| decimal | byte | Unsupported | | -| decimal | short | Unsupported | | -| decimal | integer | Unsupported | | -| decimal | long | Unsupported | | -| decimal | float | Compatible | | -| decimal | double | Compatible | | -| decimal | string | Unsupported | | -| decimal | timestamp | Unsupported | | -| string | boolean | Compatible | | -| string | byte | Compatible | | -| string | short | Compatible | | -| string | integer | Compatible | | -| string | long | Compatible | | -| string | float | Unsupported | | -| string | double | Unsupported | | -| string | decimal | Unsupported | | -| string | binary | Compatible | | -| string | date | Unsupported | | -| string | timestamp | Incompatible | Not all valid formats are supported | -| binary | string | Incompatible | | -| date | boolean | Unsupported | | -| date | byte | Unsupported | | -| date | short | Unsupported | | -| date | integer | Unsupported | | -| date | long | Unsupported | | -| date | float | Unsupported | | -| date | double | Unsupported | | -| date | decimal | Unsupported | | -| date | string | Compatible | | -| date | timestamp | Unsupported | | -| timestamp | boolean | Unsupported | | -| timestamp | byte | Unsupported | | -| timestamp | short | Unsupported | | -| timestamp | integer | Unsupported | | -| timestamp | long | Compatible | | -| timestamp | float | Unsupported | | -| timestamp | double | Unsupported | | -| timestamp | decimal | Unsupported | | -| timestamp | string | Compatible | | -| timestamp | date | Compatible | | +Spark. + +### Compatible Casts + +The following cast operations are generally compatible with Spark except for the differences noted here. + +| From Type | To Type | Notes | +|-|-|-| +| boolean | byte | | +| boolean | short | | +| boolean | integer | | +| boolean | long | | +| boolean | float | | +| boolean | double | | +| boolean | string | | +| byte | boolean | | +| byte | short | | +| byte | integer | | +| byte | long | | +| byte | float | | +| byte | double | | +| byte | decimal | | +| byte | string | | +| short | boolean | | +| short | byte | | +| short | integer | | +| short | long | | +| short | float | | +| short | double | | +| short | decimal | | +| short | string | | +| integer | boolean | | +| integer | byte | | +| integer | short | | +| integer | long | | +| integer | float | | +| integer | double | | +| integer | string | | +| long | boolean | | +| long | byte | | +| long | short | | +| long | integer | | +| long | float | | +| long | double | | +| long | string | | +| float | boolean | | +| float | double | | +| float | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 | +| double | boolean | | +| double | float | | +| double | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 | +| decimal | float | | +| decimal | double | | +| string | boolean | | +| string | byte | | +| string | short | | +| string | integer | | +| string | long | | +| string | binary | | +| date | string | | +| timestamp | long | | +| timestamp | decimal | | +| timestamp | string | | +| timestamp | date | | + +### Incompatible Casts + +The following cast operations are not compatible with Spark for all inputs and are disabled by default. + +| From Type | To Type | Notes | +|-|-|-| +| integer | decimal | No overflow check | +| long | decimal | No overflow check | +| float | decimal | No overflow check | +| double | decimal | No overflow check | +| string | timestamp | Not all valid formats are supported | +| binary | string | Only works for binary data representing valid UTF-8 strings | + +### Unsupported Casts + +Any cast not listed in the previous tables is currently unsupported. We are working on adding more. See the +[tracking issue](https://github.com/apache/datafusion-comet/issues/286) for more details. diff --git a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala index 8c414c7fe..1e28efd52 100644 --- a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala +++ b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala @@ -25,7 +25,7 @@ import scala.io.Source import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported} +import org.apache.comet.expressions.{CometCast, Compatible, Incompatible} /** * Utility for generating markdown documentation from the configs. @@ -64,23 +64,36 @@ object GenerateDocs { val outputFilename = "docs/source/user-guide/compatibility.md" val w = new BufferedOutputStream(new FileOutputStream(outputFilename)) for (line <- Source.fromFile(templateFilename).getLines()) { - if (line.trim == "") { - w.write("| From Type | To Type | Compatible? | Notes |\n".getBytes) - w.write("|-|-|-|-|\n".getBytes) + if (line.trim == "") { + w.write("| From Type | To Type | Notes |\n".getBytes) + w.write("|-|-|-|\n".getBytes) for (fromType <- CometCast.supportedTypes) { for (toType <- CometCast.supportedTypes) { if (Cast.canCast(fromType, toType) && fromType != toType) { val fromTypeName = fromType.typeName.replace("(10,2)", "") val toTypeName = toType.typeName.replace("(10,2)", "") CometCast.isSupported(fromType, toType, None, "LEGACY") match { - case Compatible => - w.write(s"| $fromTypeName | $toTypeName | Compatible | |\n".getBytes) - case Incompatible(Some(reason)) => - w.write(s"| $fromTypeName | $toTypeName | Incompatible | $reason |\n".getBytes) - case Incompatible(None) => - w.write(s"| $fromTypeName | $toTypeName | Incompatible | |\n".getBytes) - case Unsupported => - w.write(s"| $fromTypeName | $toTypeName | Unsupported | |\n".getBytes) + case Compatible(notes) => + val notesStr = notes.getOrElse("").trim + w.write(s"| $fromTypeName | $toTypeName | $notesStr |\n".getBytes) + case _ => + } + } + } + } + } else if (line.trim == "") { + w.write("| From Type | To Type | Notes |\n".getBytes) + w.write("|-|-|-|\n".getBytes) + for (fromType <- CometCast.supportedTypes) { + for (toType <- CometCast.supportedTypes) { + if (Cast.canCast(fromType, toType) && fromType != toType) { + val fromTypeName = fromType.typeName.replace("(10,2)", "") + val toTypeName = toType.typeName.replace("(10,2)", "") + CometCast.isSupported(fromType, toType, None, "LEGACY") match { + case Incompatible(notes) => + val notesStr = notes.getOrElse("").trim + w.write(s"| $fromTypeName | $toTypeName | $notesStr |\n".getBytes) + case _ => } } } diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 5641c94a8..57e07b8cd 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -24,10 +24,10 @@ import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType} sealed trait SupportLevel /** We support this feature with full compatibility with Spark */ -object Compatible extends SupportLevel +case class Compatible(notes: Option[String] = None) extends SupportLevel /** We support this feature but results can be different from Spark */ -case class Incompatible(reason: Option[String] = None) extends SupportLevel +case class Incompatible(notes: Option[String] = None) extends SupportLevel /** We do not support this feature */ object Unsupported extends SupportLevel @@ -58,7 +58,7 @@ object CometCast { evalMode: String): SupportLevel = { if (fromType == toType) { - return Compatible + return Compatible() } (fromType, toType) match { @@ -83,10 +83,14 @@ object CometCast { canCastFromDecimal(toType) case (DataTypes.BooleanType, _) => canCastFromBoolean(toType) - case ( - DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType, - _) => + case (DataTypes.ByteType, _) => + canCastFromByte(toType) + case (DataTypes.ShortType, _) => + canCastFromShort(toType) + case (DataTypes.IntegerType, _) => canCastFromInt(toType) + case (DataTypes.LongType, _) => + canCastFromLong(toType) case (DataTypes.FloatType, _) => canCastFromFloat(toType) case (DataTypes.DoubleType, _) => @@ -101,12 +105,12 @@ object CometCast { evalMode: String): SupportLevel = { toType match { case DataTypes.BooleanType => - Compatible + Compatible() case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType => - Compatible + Compatible() case DataTypes.BinaryType => - Compatible + Compatible() case DataTypes.FloatType | DataTypes.DoubleType => // https://github.com/apache/datafusion-comet/issues/326 Unsupported @@ -130,18 +134,21 @@ object CometCast { private def canCastToString(fromType: DataType): SupportLevel = { fromType match { - case DataTypes.BooleanType => Compatible + case DataTypes.BooleanType => Compatible() case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType => - Compatible - case DataTypes.DateType => Compatible - case DataTypes.TimestampType => Compatible + Compatible() + case DataTypes.DateType => Compatible() + case DataTypes.TimestampType => Compatible() case DataTypes.FloatType | DataTypes.DoubleType => - // https://github.com/apache/datafusion-comet/issues/326 - Incompatible() + Compatible( + Some( + "There can be differences in precision. " + + "For example, the input \"1.4E-45\" will produce 1.0E-45 " + + "instead of 1.4E-45")) case DataTypes.BinaryType => // https://github.com/apache/datafusion-comet/issues/377 - Incompatible() + Incompatible(Some("Only works for binary data representing valid UTF-8 strings")) case _ => Unsupported } } @@ -155,9 +162,10 @@ object CometCast { Unsupported case DataTypes.LongType => // https://github.com/apache/datafusion-comet/issues/352 - Compatible - case DataTypes.StringType => Compatible - case DataTypes.DateType => Compatible + Compatible() + case DataTypes.StringType => Compatible() + case DataTypes.DateType => Compatible() + case _: DecimalType => Compatible() case _ => Unsupported } } @@ -165,31 +173,72 @@ object CometCast { private def canCastFromBoolean(toType: DataType): SupportLevel = toType match { case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType | DataTypes.DoubleType => - Compatible + Compatible() case _ => Unsupported } + private def canCastFromByte(toType: DataType): SupportLevel = toType match { + case DataTypes.BooleanType => + Compatible() + case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType => + Compatible() + case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType => + Compatible() + case _ => + Unsupported + } + + private def canCastFromShort(toType: DataType): SupportLevel = toType match { + case DataTypes.BooleanType => + Compatible() + case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType => + Compatible() + case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType => + Compatible() + case _ => + Unsupported + } + private def canCastFromInt(toType: DataType): SupportLevel = toType match { - case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType | - DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType | DataTypes.DoubleType | - _: DecimalType => - Compatible - case _ => Unsupported + case DataTypes.BooleanType => + Compatible() + case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType => + Compatible() + case DataTypes.FloatType | DataTypes.DoubleType => + Compatible() + case _: DecimalType => + Incompatible(Some("No overflow check")) + case _ => + Unsupported + } + + private def canCastFromLong(toType: DataType): SupportLevel = toType match { + case DataTypes.BooleanType => + Compatible() + case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType => + Compatible() + case DataTypes.FloatType | DataTypes.DoubleType => + Compatible() + case _: DecimalType => + Incompatible(Some("No overflow check")) + case _ => + Unsupported } private def canCastFromFloat(toType: DataType): SupportLevel = toType match { - case DataTypes.BooleanType | DataTypes.DoubleType => Compatible + case DataTypes.BooleanType | DataTypes.DoubleType => Compatible() + case _: DecimalType => Incompatible(Some("No overflow check")) case _ => Unsupported } private def canCastFromDouble(toType: DataType): SupportLevel = toType match { - case DataTypes.BooleanType | DataTypes.FloatType => Compatible - case _: DecimalType => Incompatible() + case DataTypes.BooleanType | DataTypes.FloatType => Compatible() + case _: DecimalType => Incompatible(Some("No overflow check")) case _ => Unsupported } private def canCastFromDecimal(toType: DataType): SupportLevel = toType match { - case DataTypes.FloatType | DataTypes.DoubleType => Compatible + case DataTypes.FloatType | DataTypes.DoubleType => Compatible() case _ => Unsupported } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 1e8877c8d..86e9f10b9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -636,7 +636,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { reason.map(str => s" ($str)").getOrElse("") castSupport match { - case Compatible => + case Compatible(_) => castToProto(timeZoneId, dt, childExpr, evalModeStr) case Incompatible(reason) => if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) { diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 483301e02..1d698a49a 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DataTypes} -import org.apache.comet.expressions.CometCast +import org.apache.comet.expressions.{CometCast, Compatible} class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ @@ -66,6 +66,23 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } else if (!testExists) { fail(s"Missing test: $expectedTestName") + } else { + val testIgnored = + tags.get(expectedTestName).exists(s => s.contains("org.scalatest.Ignore")) + CometCast.isSupported(fromType, toType, None, "LEGACY") match { + case Compatible(_) => + if (testIgnored) { + fail( + s"Cast from $fromType to $toType is reported as compatible " + + "with Spark but the test is ignored") + } + case _ => + if (!testIgnored) { + fail( + s"We claim that cast from $fromType to $toType is not compatible " + + "with Spark but the test is not ignored") + } + } } } else if (testExists) { fail(s"Found test for cast that Spark does not support: $expectedTestName") @@ -347,7 +364,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { Short.MaxValue.toFloat, 0.0f) ++ Range(0, dataSize).map(_ => r.nextFloat()) - withNulls(values).toDF("a") + castTest(withNulls(values).toDF("a"), DataTypes.StringType) } ignore("cast FloatType to TimestampType") { @@ -401,7 +418,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { Double.NegativeInfinity, 0.0d) ++ Range(0, dataSize).map(_ => r.nextDouble()) - withNulls(values).toDF("a") + castTest(withNulls(values).toDF("a"), DataTypes.StringType) } ignore("cast DoubleType to TimestampType") { @@ -559,6 +576,14 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + ignore("cast StringType to TimestampType") { + // https://github.com/apache/datafusion-comet/issues/328 + withSQLConf((CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key, "true")) { + val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ generateStrings(timestampPattern, 8) + castTest(values.toDF("a"), DataTypes.TimestampType) + } + } + test("cast StringType to TimestampType disabled for non-UTC timezone") { withSQLConf((SQLConf.SESSION_LOCAL_TIMEZONE.key, "America/Denver")) { val values = Seq("2020-01-01T12:34:56.123456", "T2").toDF("a") @@ -569,15 +594,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - ignore("cast StringType to TimestampType (fuzz test)") { - // https://github.com/apache/datafusion-comet/issues/328 - withSQLConf((CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key, "true")) { - val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ generateStrings(timestampPattern, 8) - castTest(values.toDF("a"), DataTypes.TimestampType) - } - } - - test("cast StringType to TimestampType") { + test("cast StringType to TimestampType - subset of supported values") { withSQLConf( SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC", CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { @@ -606,8 +623,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // CAST from BinaryType ignore("cast BinaryType to StringType") { - // TODO implement this // https://github.com/apache/datafusion-comet/issues/377 + castTest(generateBinary(), DataTypes.StringType) + } + + test("cast BinaryType to StringType - valid UTF-8 inputs") { + castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.StringType) } // CAST from DateType @@ -795,7 +816,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { Seq( "2024-01-01T12:34:56.123456", "2024-01-01T01:00:00Z", - "2024-12-31T01:00:00-02:00", + "9999-12-31T01:00:00-02:00", "2024-12-31T01:00:00+02:00") withNulls(values) .toDF("str") @@ -814,6 +835,16 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { Range(0, dataSize).map(_ => generateString(r, chars, maxLen)) } + private def generateBinary(): DataFrame = { + val r = new Random(0) + val bytes = new Array[Byte](8) + val values: Seq[Array[Byte]] = Range(0, dataSize).map(_ => { + r.nextBytes(bytes) + bytes.clone() + }) + values.toDF("a") + } + private def withNulls[T](values: Seq[T]): Seq[Option[T]] = { values.map(v => Some(v)) ++ Seq(None) } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 47c2c696a..8f022988f 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -253,7 +253,8 @@ class CometExecSuite extends CometTestBase { dataTypes.map { subqueryType => withSQLConf( CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", - CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") { var column1 = s"CAST(max(_1) AS $subqueryType)" if (subqueryType == "BINARY") { From 8bba58e4157cc4af80916b3a001a03a1f08c34e0 Mon Sep 17 00:00:00 2001 From: Xin Hao Date: Tue, 7 May 2024 12:58:20 +0800 Subject: [PATCH 3/3] docs: fix the docs url of installation instructions (#393) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5cbd32c65..fb17535aa 100644 --- a/README.md +++ b/README.md @@ -69,8 +69,8 @@ Linux, Apple OSX (Intel and M1) ## Getting started -See the [DataFusion Comet User Guide](https://datafusion.apache.org/comet/user-guide/) for installation instructions. +See the [DataFusion Comet User Guide](https://datafusion.apache.org/comet/user-guide/installation.html) for installation instructions. ## Contributing See the [DataFusion Comet Contribution Guide](https://datafusion.apache.org/comet/contributor-guide/contributing.html) -for information on how to get started contributing to the project. \ No newline at end of file +for information on how to get started contributing to the project.