From 5cf3264aa0964181fa3b4def668526885cd3d4d4 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sun, 30 Jun 2024 21:06:09 +0200 Subject: [PATCH] Add support for CreateNamedStruct --- .../expressions/create_named_struct.rs | 127 ++++++++++++++++++ .../execution/datafusion/expressions/mod.rs | 1 + core/src/execution/datafusion/planner.rs | 11 +- core/src/execution/proto/expr.proto | 6 + .../apache/comet/serde/QueryPlanSerde.scala | 19 +++ .../apache/comet/CometExpressionSuite.scala | 14 ++ 6 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 core/src/execution/datafusion/expressions/create_named_struct.rs diff --git a/core/src/execution/datafusion/expressions/create_named_struct.rs b/core/src/execution/datafusion/expressions/create_named_struct.rs new file mode 100644 index 0000000000..17d763cca9 --- /dev/null +++ b/core/src/execution/datafusion/expressions/create_named_struct.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::{Hash, Hasher}, + sync::Arc, +}; + +use arrow::record_batch::RecordBatch; +use arrow_array::StructArray; +use arrow_schema::{DataType, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, Result as DataFusionResult}; +use datafusion_physical_expr::PhysicalExpr; + +use crate::execution::datafusion::expressions::utils::down_cast_any_ref; + +#[derive(Debug, Hash)] +pub struct CreateNamedStruct { + values: Vec>, + data_type: DataType, +} + +impl CreateNamedStruct { + pub fn new(values: Vec>, data_type: DataType) -> Self { + Self { values, data_type } + } +} + +impl PhysicalExpr for CreateNamedStruct { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(self.data_type.clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let values = self + .values + .iter() + .map(|expr| expr.evaluate(batch)) + .collect::>>()?; + let arrays = ColumnarValue::values_to_arrays(&values)?; + let fields = match &self.data_type { + DataType::Struct(fields) => fields, + _ => { + return Err(DataFusionError::Internal(format!( + "Expected struct data type, got {:?}", + self.data_type + ))) + } + }; + Ok(ColumnarValue::Array(Arc::new(StructArray::new( + fields.clone(), + arrays, + None, + )))) + } + + fn children(&self) -> Vec<&Arc> { + self.values.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(CreateNamedStruct::new( + children.clone(), + self.data_type.clone(), + ))) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.values.hash(&mut s); + self.data_type.hash(&mut s); + self.hash(&mut s); + } +} + +impl Display for CreateNamedStruct { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "CreateNamedStruct [values: {:?}, data_type: {:?}]", + self.values, self.data_type + ) + } +} + +impl PartialEq for CreateNamedStruct { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.values + .iter() + .zip(x.values.iter()) + .all(|(a, b)| a.eq(b)) + && self.data_type.eq(&x.data_type) + }) + .unwrap_or(false) + } +} diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs index d91e25980c..385857f1e4 100644 --- a/core/src/execution/datafusion/expressions/mod.rs +++ b/core/src/execution/datafusion/expressions/mod.rs @@ -33,6 +33,7 @@ pub mod avg_decimal; pub mod bloom_filter_might_contain; pub mod correlation; pub mod covariance; +pub mod create_named_struct; pub mod negative; pub mod stats; pub mod stddev; diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index afdebce328..acad261c51 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -97,7 +97,7 @@ use crate::{ }, }; -use super::expressions::{abs::CometAbsFunc, EvalMode}; +use super::expressions::{abs::CometAbsFunc, create_named_struct::CreateNamedStruct, EvalMode}; // For clippy error on type_complexity. type ExecResult = Result; @@ -584,6 +584,15 @@ impl PhysicalPlanner { value_expr, )?)) } + ExprStruct::CreateNamedStruct(expr) => { + let values = expr + .values + .iter() + .map(|expr| self.create_expr(expr, input_schema.clone())) + .collect::, _>>()?; + let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + Ok(Arc::new(CreateNamedStruct::new(values, data_type))) + } expr => Err(ExecutionError::GeneralError(format!( "Not implemented: {:?}", expr diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index 6b66a307ad..56518d9eed 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -77,6 +77,7 @@ message Expr { Subquery subquery = 50; UnboundReference unbound = 51; BloomFilterMightContain bloom_filter_might_contain = 52; + CreateNamedStruct create_named_struct = 53; } } @@ -486,6 +487,11 @@ message BloomFilterMightContain { Expr value = 2; } +message CreateNamedStruct { + repeated Expr values = 1; + DataType datatype = 2; +} + enum SortDirection { Ascending = 0; Descending = 1; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 1e61ef75e7..76e31ba48b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2141,6 +2141,25 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim scalarExprToProtoWithReturnType(algorithm, StringType, childExpr) } + case struct @ CreateNamedStruct(_) => + val valExprs = struct.valExprs.map(exprToProto(_, inputs, binding)) + val dataType = serializeDataType(struct.dataType) + + if (valExprs.forall(_.isDefined) && dataType.isDefined) { + val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder() + structBuilder.addAllValues(valExprs.map(_.get).asJava) + structBuilder.setDatatype(dataType.get) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setCreateNamedStruct(structBuilder) + .build()) + } else { + withInfo(expr, struct.valExprs: _*) + None + } + case _ => withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*) None diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 8dbfb71b38..a5c19e6a05 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1719,4 +1719,18 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + test("named_struct") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + withParquetTable(path.toString, "tbl") { + checkSparkAnswerAndOperator("SELECT named_struct('a', _1, 'b', _2) FROM tbl") + checkSparkAnswerAndOperator("SELECT named_struct('a', _1, 'b', 2) FROM tbl") + checkSparkAnswerAndOperator( + "SELECT named_struct('a', named_struct('b', _1, 'c', _2)) FROM tbl") + } + } + } + } }