From d2b3d1c7538b9fb7ab9cfc0c4c6a238b0dcd91e6 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Mon, 1 Jan 2024 14:09:41 -0500 Subject: [PATCH 01/18] Rename `expr::window_function::WindowFunction` to `WindowFunctionDefinition`, make structure consistent with ScalarFunction (#8382) * Refactoring WindowFunction into coherent structure with AggregateFunction * One more cargo fmt --------- Co-authored-by: Andrew Lamb --- datafusion/core/src/dataframe/mod.rs | 6 +- .../core/src/physical_optimizer/test_utils.rs | 4 +- datafusion/core/tests/dataframe/mod.rs | 4 +- .../core/tests/fuzz_cases/window_fuzz.rs | 46 +- .../expr/src/built_in_window_function.rs | 207 ++++++++ datafusion/expr/src/expr.rs | 291 ++++++++++- datafusion/expr/src/lib.rs | 6 +- datafusion/expr/src/udwf.rs | 2 +- datafusion/expr/src/utils.rs | 22 +- datafusion/expr/src/window_function.rs | 483 ------------------ .../src/analyzer/count_wildcard_rule.rs | 10 +- .../optimizer/src/analyzer/type_coercion.rs | 8 +- .../optimizer/src/push_down_projection.rs | 6 +- datafusion/physical-plan/src/windows/mod.rs | 28 +- .../proto/src/logical_plan/from_proto.rs | 8 +- datafusion/proto/src/logical_plan/to_proto.rs | 10 +- .../proto/src/physical_plan/from_proto.rs | 10 +- .../tests/cases/roundtrip_logical_plan.rs | 20 +- datafusion/sql/src/expr/function.rs | 19 +- .../substrait/src/logical_plan/consumer.rs | 4 +- 20 files changed, 613 insertions(+), 581 deletions(-) create mode 100644 datafusion/expr/src/built_in_window_function.rs delete mode 100644 datafusion/expr/src/window_function.rs diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 3c3bcd497b7f..5a8c706e32cd 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1360,7 +1360,7 @@ mod tests { use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, - WindowFunction, + WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::get_plan_string; @@ -1525,7 +1525,9 @@ mod tests { // build plan using Table API let t = test_table().await?; let first_row = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![col("aggregate_test_100.c1")], vec![col("aggregate_test_100.c2")], vec![], diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 6e14cca21fed..debafefe39ab 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -41,7 +41,7 @@ use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::{JoinType, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; +use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -234,7 +234,7 @@ pub fn bounded_window_exec( Arc::new( crate::physical_plan::windows::BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + 
&WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index ba661aa2445c..cca23ac6847c 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -45,7 +45,7 @@ use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::{ array_agg, avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, Expr, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_physical_expr::var_provider::{VarProvider, VarType}; @@ -170,7 +170,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { .table("t1") .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 44ff71d02392..3037b4857a3b 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -33,7 +33,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunction, + WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -143,7 +143,7 @@ fn get_random_function( schema: &SchemaRef, rng: &mut StdRng, is_linear: bool, -) -> (WindowFunction, Vec>, String) { +) -> (WindowFunctionDefinition, Vec>, String) { let mut args = if is_linear { // In linear test for the test version with WindowAggExec we use insert SortExecs to the plan to be able to generate // same result with BoundedWindowAggExec which doesn't use any SortExec. 
To make result @@ -159,28 +159,28 @@ fn get_random_function( window_fn_map.insert( "sum", ( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![], ), ); window_fn_map.insert( "count", ( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![], ), ); window_fn_map.insert( "min", ( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![], ), ); window_fn_map.insert( "max", ( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![], ), ); @@ -191,28 +191,36 @@ fn get_random_function( window_fn_map.insert( "row_number", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::RowNumber, + ), vec![], ), ); window_fn_map.insert( "rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Rank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Rank, + ), vec![], ), ); window_fn_map.insert( "dense_rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::DenseRank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::DenseRank, + ), vec![], ), ); window_fn_map.insert( "lead", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lead, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -222,7 +230,9 @@ fn get_random_function( window_fn_map.insert( "lag", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lag, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -233,21 +243,27 @@ fn get_random_function( window_fn_map.insert( "first_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![], ), ); window_fn_map.insert( "last_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::LastValue, + ), vec![], ), ); window_fn_map.insert( "nth_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::NthValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::NthValue, + ), vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))], ), ); @@ -255,7 +271,7 @@ fn get_random_function( let rand_fn_idx = rng.gen_range(0..window_fn_map.len()); let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, new_args) = window_fn_map.values().collect::>()[rand_fn_idx]; - if let WindowFunction::AggregateFunction(f) = window_fn { + if let WindowFunctionDefinition::AggregateFunction(f) = window_fn { let a = args[0].clone(); let dt = a.data_type(schema.as_ref()).unwrap(); let sig = f.signature(); diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs new file mode 100644 index 000000000000..a03e3d2d24a9 --- /dev/null +++ 
b/datafusion/expr/src/built_in_window_function.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Built-in functions module contains all the built-in functions definitions. + +use std::fmt; +use std::str::FromStr; + +use crate::type_coercion::functions::data_types; +use crate::utils; +use crate::{Signature, TypeSignature, Volatility}; +use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; + +use arrow::datatypes::DataType; + +use strum_macros::EnumIter; + +impl fmt::Display for BuiltInWindowFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + +/// A [window function] built in to DataFusion +/// +/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) +#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] +pub enum BuiltInWindowFunction { + /// number of the current row within its partition, counting from 1 + RowNumber, + /// rank of the current row with gaps; same as row_number of its first peer + Rank, + /// rank of the current row without gaps; this function counts peer groups + DenseRank, + /// relative rank of the current row: (rank - 1) / (total rows - 1) + PercentRank, + /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) + CumeDist, + /// integer ranging from 1 to the argument value, dividing the partition as equally as possible + Ntile, + /// returns value evaluated at the row that is offset rows before the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null + Lag, + /// returns value evaluated at the row that is offset rows after the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. 
+    /// If omitted, offset defaults to 1 and default to null
+    Lead,
+    /// returns value evaluated at the row that is the first row of the window frame
+    FirstValue,
+    /// returns value evaluated at the row that is the last row of the window frame
+    LastValue,
+    /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
+    NthValue,
+}
+
+impl BuiltInWindowFunction {
+    fn name(&self) -> &str {
+        use BuiltInWindowFunction::*;
+        match self {
+            RowNumber => "ROW_NUMBER",
+            Rank => "RANK",
+            DenseRank => "DENSE_RANK",
+            PercentRank => "PERCENT_RANK",
+            CumeDist => "CUME_DIST",
+            Ntile => "NTILE",
+            Lag => "LAG",
+            Lead => "LEAD",
+            FirstValue => "FIRST_VALUE",
+            LastValue => "LAST_VALUE",
+            NthValue => "NTH_VALUE",
+        }
+    }
+}
+
+impl FromStr for BuiltInWindowFunction {
+    type Err = DataFusionError;
+    fn from_str(name: &str) -> Result<BuiltInWindowFunction> {
+        Ok(match name.to_uppercase().as_str() {
+            "ROW_NUMBER" => BuiltInWindowFunction::RowNumber,
+            "RANK" => BuiltInWindowFunction::Rank,
+            "DENSE_RANK" => BuiltInWindowFunction::DenseRank,
+            "PERCENT_RANK" => BuiltInWindowFunction::PercentRank,
+            "CUME_DIST" => BuiltInWindowFunction::CumeDist,
+            "NTILE" => BuiltInWindowFunction::Ntile,
+            "LAG" => BuiltInWindowFunction::Lag,
+            "LEAD" => BuiltInWindowFunction::Lead,
+            "FIRST_VALUE" => BuiltInWindowFunction::FirstValue,
+            "LAST_VALUE" => BuiltInWindowFunction::LastValue,
+            "NTH_VALUE" => BuiltInWindowFunction::NthValue,
+            _ => return plan_err!("There is no built-in window function named {name}"),
+        })
+    }
+}
+
+/// Returns the datatype of the built-in window function
+impl BuiltInWindowFunction {
+    pub fn return_type(&self, input_expr_types: &[DataType]) -> Result<DataType> {
+        // Note that this function *must* return the same type that the respective physical expression returns
+        // or the execution panics.
+
+        // verify that this is a valid set of data types for this function
+        data_types(input_expr_types, &self.signature())
+            // original errors are all related to wrong function signature
+            // aggregate them for better error message
+            .map_err(|_| {
+                plan_datafusion_err!(
+                    "{}",
+                    utils::generate_signature_error_msg(
+                        &format!("{self}"),
+                        self.signature(),
+                        input_expr_types,
+                    )
+                )
+            })?;
+
+        match self {
+            BuiltInWindowFunction::RowNumber
+            | BuiltInWindowFunction::Rank
+            | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64),
+            BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => {
+                Ok(DataType::Float64)
+            }
+            BuiltInWindowFunction::Ntile => Ok(DataType::UInt64),
+            BuiltInWindowFunction::Lag
+            | BuiltInWindowFunction::Lead
+            | BuiltInWindowFunction::FirstValue
+            | BuiltInWindowFunction::LastValue
+            | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()),
+        }
+    }
+
+    /// the signatures supported by the built-in window function `fun`.
+    pub fn signature(&self) -> Signature {
+        // note: the physical expression must accept the type returned by this function or the execution panics.
+        match self {
+            BuiltInWindowFunction::RowNumber
+            | BuiltInWindowFunction::Rank
+            | BuiltInWindowFunction::DenseRank
+            | BuiltInWindowFunction::PercentRank
+            | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable),
+            BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
+                Signature::one_of(
+                    vec![
+                        TypeSignature::Any(1),
+                        TypeSignature::Any(2),
+                        TypeSignature::Any(3),
+                    ],
+                    Volatility::Immutable,
+                )
+            }
+            BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => {
+                Signature::any(1, Volatility::Immutable)
+            }
+            BuiltInWindowFunction::Ntile => Signature::uniform(
+                1,
+                vec![
+                    DataType::UInt64,
+                    DataType::UInt32,
+                    DataType::UInt16,
+                    DataType::UInt8,
+                    DataType::Int64,
+                    DataType::Int32,
+                    DataType::Int16,
+                    DataType::Int8,
+                ],
+                Volatility::Immutable,
+            ),
+            BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use strum::IntoEnumIterator;
+    #[test]
+    // Test for BuiltInWindowFunction's Display and from_str() implementations.
+    // For each variant in BuiltInWindowFunction, it converts the variant to a string
+    // and then back to a variant. The test asserts that the original variant and
+    // the reconstructed variant are the same. This assertion is also necessary for
+    // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082
+    fn test_display_and_from_str() {
+        for func_original in BuiltInWindowFunction::iter() {
+            let func_name = func_original.to_string();
+            let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap();
+            assert_eq!(func_from_str, func_original);
+        }
+    }
+}
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 0ec19bcadbf6..ebf4d3143c12 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -19,13 +19,13 @@
 use crate::expr_fn::binary_expr;
 use crate::logical_plan::Subquery;
-use crate::udaf;
 use crate::utils::{expr_to_columns, find_out_reference_exprs};
 use crate::window_frame;
-use crate::window_function;
+
 use crate::Operator;
 use crate::{aggregate_function, ExprSchemable};
 use crate::{built_in_function, BuiltinScalarFunction};
+use crate::{built_in_window_function, udaf};
 use arrow::datatypes::DataType;
 use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{internal_err, DFSchema, OwnedTableReference};
@@ -34,8 +34,11 @@ use std::collections::HashSet;
 use std::fmt;
 use std::fmt::{Display, Formatter, Write};
 use std::hash::{BuildHasher, Hash, Hasher};
+use std::str::FromStr;
 use std::sync::Arc;
 
+use crate::Signature;
+
 /// `Expr` is a central struct of DataFusion's query API, and
 /// represent logical expressions such as `A + 1`, or `CAST(c1 AS
 /// int)`.
@@ -566,11 +569,64 @@ impl AggregateFunction {
     }
 }
 
+/// WindowFunctionDefinition
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+/// Defines which implementation of a window function DataFusion should call.
+pub enum WindowFunctionDefinition {
+    /// A built in aggregate function that leverages an aggregate function
+    AggregateFunction(aggregate_function::AggregateFunction),
+    /// A built-in window function
+    BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction),
+    /// A user defined aggregate function
+    AggregateUDF(Arc<crate::AggregateUDF>),
+    /// A user defined window function
+    WindowUDF(Arc<crate::WindowUDF>),
+}
+
+impl WindowFunctionDefinition {
+    /// Returns the datatype of the window function
+    pub fn return_type(&self, input_expr_types: &[DataType]) -> Result<DataType> {
+        match self {
+            WindowFunctionDefinition::AggregateFunction(fun) => {
+                fun.return_type(input_expr_types)
+            }
+            WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
+                fun.return_type(input_expr_types)
+            }
+            WindowFunctionDefinition::AggregateUDF(fun) => {
+                fun.return_type(input_expr_types)
+            }
+            WindowFunctionDefinition::WindowUDF(fun) => fun.return_type(input_expr_types),
+        }
+    }
+
+    /// the signatures supported by the function `fun`.
+    pub fn signature(&self) -> Signature {
+        match self {
+            WindowFunctionDefinition::AggregateFunction(fun) => fun.signature(),
+            WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(),
+            WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(),
+            WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(),
+        }
+    }
+}
+
+impl fmt::Display for WindowFunctionDefinition {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f),
+            WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f),
+            WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f),
+            WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f),
+        }
+    }
+}
+
 /// Window function
 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
 pub struct WindowFunction {
     /// Name of the function
-    pub fun: window_function::WindowFunction,
+    pub fun: WindowFunctionDefinition,
     /// List of expressions to feed to the functions as arguments
     pub args: Vec<Expr>,
     /// List of partition by expressions
@@ -584,7 +640,7 @@ pub struct WindowFunction {
 impl WindowFunction {
     /// Create a new Window expression
     pub fn new(
-        fun: window_function::WindowFunction,
+        fun: WindowFunctionDefinition,
         args: Vec<Expr>,
         partition_by: Vec<Expr>,
         order_by: Vec<Expr>,
@@ -600,6 +656,50 @@ impl WindowFunction {
     }
 }
+
+/// Find DataFusion's built-in window function by name.
+pub fn find_df_window_func(name: &str) -> Option<WindowFunctionDefinition> {
+    let name = name.to_lowercase();
+    // Code paths for window functions leveraging ordinary aggregators and
+    // built-in window functions are quite different, and the same function
+    // may have different implementations for these cases. If the sought
+    // function is not found among built-in window functions, we search for
+    // it among aggregate functions.
+    if let Ok(built_in_function) =
+        built_in_window_function::BuiltInWindowFunction::from_str(name.as_str())
+    {
+        Some(WindowFunctionDefinition::BuiltInWindowFunction(
+            built_in_function,
+        ))
+    } else if let Ok(aggregate) =
+        aggregate_function::AggregateFunction::from_str(name.as_str())
+    {
+        Some(WindowFunctionDefinition::AggregateFunction(aggregate))
+    } else {
+        None
+    }
+}
+
+/// Returns the datatype of the window function
+#[deprecated(
+    since = "27.0.0",
+    note = "please use `WindowFunctionDefinition::return_type` instead"
+)]
+pub fn return_type(
+    fun: &WindowFunctionDefinition,
+    input_expr_types: &[DataType],
+) -> Result<DataType> {
+    fun.return_type(input_expr_types)
+}
+
+/// the signatures supported by the function `fun`.
+#[deprecated(
+    since = "27.0.0",
+    note = "please use `WindowFunctionDefinition::signature` instead"
+)]
+pub fn signature(fun: &WindowFunctionDefinition) -> Signature {
+    fun.signature()
+}
+
 // Exists expression.
 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
 pub struct Exists {
@@ -1890,4 +1990,187 @@ mod test {
             .is_volatile()
             .expect_err("Shouldn't determine volatility of unresolved function");
     }
+
+    use super::*;
+
+    #[test]
+    fn test_count_return_type() -> Result<()> {
+        let fun = find_df_window_func("count").unwrap();
+        let observed = fun.return_type(&[DataType::Utf8])?;
+        assert_eq!(DataType::Int64, observed);
+
+        let observed = fun.return_type(&[DataType::UInt64])?;
+        assert_eq!(DataType::Int64, observed);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_first_value_return_type() -> Result<()> {
+        let fun = find_df_window_func("first_value").unwrap();
+        let observed = fun.return_type(&[DataType::Utf8])?;
+        assert_eq!(DataType::Utf8, observed);
+
+        let observed = fun.return_type(&[DataType::UInt64])?;
+        assert_eq!(DataType::UInt64, observed);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_last_value_return_type() -> Result<()> {
+        let fun = find_df_window_func("last_value").unwrap();
+        let observed = fun.return_type(&[DataType::Utf8])?;
+        assert_eq!(DataType::Utf8, observed);
+
+        let observed = fun.return_type(&[DataType::Float64])?;
+        assert_eq!(DataType::Float64, observed);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_lead_return_type() -> Result<()> {
+        let fun = find_df_window_func("lead").unwrap();
+        let observed = fun.return_type(&[DataType::Utf8])?;
+        assert_eq!(DataType::Utf8, observed);
+
+        let observed = fun.return_type(&[DataType::Float64])?;
+        assert_eq!(DataType::Float64, observed);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_lag_return_type() -> Result<()> {
+        let fun = find_df_window_func("lag").unwrap();
+        let observed = fun.return_type(&[DataType::Utf8])?;
+        assert_eq!(DataType::Utf8, observed);
+
+        let observed = fun.return_type(&[DataType::Float64])?;
+        assert_eq!(DataType::Float64, observed);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_nth_value_return_type() -> Result<()> {
+        let fun = find_df_window_func("nth_value").unwrap();
+        let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?;
+        assert_eq!(DataType::Utf8, observed);
+
+        let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?;
+        assert_eq!(DataType::Float64, observed);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_percent_rank_return_type() -> Result<()> {
+        let fun = find_df_window_func("percent_rank").unwrap();
+        let observed = fun.return_type(&[])?;
+        assert_eq!(DataType::Float64, observed);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_cume_dist_return_type() -> Result<()> {
+        let fun = find_df_window_func("cume_dist").unwrap();
+        let observed = fun.return_type(&[])?;
+        assert_eq!(DataType::Float64, observed);
+
+        Ok(())
+    }
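+
+    // Illustrative sketch (added for exposition; not in the original commit):
+    // `find_df_window_func` tries the built-in window functions first and only
+    // then falls back to aggregate functions, so "rank" resolves to a
+    // `BuiltInWindowFunction` while "sum" resolves to an `AggregateFunction`,
+    // and `WindowFunctionDefinition::return_type` dispatches accordingly.
+    // The test name is hypothetical.
+    #[test]
+    fn test_name_resolution_order_sketch() -> Result<()> {
+        let rank = find_df_window_func("rank").unwrap();
+        assert_eq!(rank.return_type(&[])?, DataType::UInt64);
+
+        let sum = find_df_window_func("sum").unwrap();
+        assert_eq!(sum.return_type(&[DataType::Int64])?, DataType::Int64);
+
+        Ok(())
+    }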
+ + #[test] + fn test_ntile_return_type() -> Result<()> { + let fun = find_df_window_func("ntile").unwrap(); + let observed = fun.return_type(&[DataType::Int16])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_window_function_case_insensitive() -> Result<()> { + let names = vec![ + "row_number", + "rank", + "dense_rank", + "percent_rank", + "cume_dist", + "ntile", + "lag", + "lead", + "first_value", + "last_value", + "nth_value", + "min", + "max", + "count", + "avg", + "sum", + ]; + for name in names { + let fun = find_df_window_func(name).unwrap(); + let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); + assert_eq!(fun, fun2); + assert_eq!(fun.to_string(), name.to_uppercase()); + } + Ok(()) + } + + #[test] + fn test_find_df_window_function() { + assert_eq!( + find_df_window_func("max"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Max + )) + ); + assert_eq!( + find_df_window_func("min"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Min + )) + ); + assert_eq!( + find_df_window_func("avg"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Avg + )) + ); + assert_eq!( + find_df_window_func("cume_dist"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::CumeDist + )) + ); + assert_eq!( + find_df_window_func("first_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::FirstValue + )) + ); + assert_eq!( + find_df_window_func("LAST_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::LastValue + )) + ); + assert_eq!( + find_df_window_func("LAG"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lag + )) + ); + assert_eq!( + find_df_window_func("LEAD"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lead + )) + ); + assert_eq!(find_df_window_func("not_exist"), None) + } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index bf8e9e2954f4..ab213a19a352 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -27,6 +27,7 @@ mod accumulator; mod built_in_function; +mod built_in_window_function; mod columnar_value; mod literal; mod nullif; @@ -53,16 +54,16 @@ pub mod tree_node; pub mod type_coercion; pub mod utils; pub mod window_frame; -pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; +pub use built_in_window_function::BuiltInWindowFunction; pub use columnar_value::ColumnarValue; pub use expr::{ Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, - Like, ScalarFunctionDefinition, TryCast, + Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; @@ -83,7 +84,6 @@ pub use udaf::AggregateUDF; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; -pub use window_function::{BuiltInWindowFunction, WindowFunction}; #[cfg(test)] #[ctor::ctor] diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index c233ee84b32d..a97a68341f5c 100644 --- 
a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -107,7 +107,7 @@ impl WindowUDF { order_by: Vec, window_frame: WindowFrame, ) -> Expr { - let fun = crate::WindowFunction::WindowUDF(Arc::new(self.clone())); + let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); Expr::WindowFunction(crate::expr::WindowFunction { fun, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 09f4842c9e64..e3ecdf154e61 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1234,7 +1234,7 @@ mod tests { use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction, - WindowFrame, WindowFunction, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -1248,28 +1248,28 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![], @@ -1291,28 +1291,28 @@ mod tests { let created_at_desc = Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], @@ -1343,7 +1343,7 @@ mod tests { fn test_find_sort_exprs() -> Result<()> { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![ @@ -1353,7 +1353,7 @@ mod tests { WindowFrame::new(true), )), Expr::WindowFunction(expr::WindowFunction::new( - 
WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![ diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs deleted file mode 100644 index 610f1ecaeae9..000000000000 --- a/datafusion/expr/src/window_function.rs +++ /dev/null @@ -1,483 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Window functions provide the ability to perform calculations across -//! sets of rows that are related to the current query row. -//! -//! see also - -use crate::aggregate_function::AggregateFunction; -use crate::type_coercion::functions::data_types; -use crate::utils; -use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF}; -use arrow::datatypes::DataType; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; -use std::sync::Arc; -use std::{fmt, str::FromStr}; -use strum_macros::EnumIter; - -/// WindowFunction -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum WindowFunction { - /// A built in aggregate function that leverages an aggregate function - AggregateFunction(AggregateFunction), - /// A a built-in window function - BuiltInWindowFunction(BuiltInWindowFunction), - /// A user defined aggregate function - AggregateUDF(Arc), - /// A user defined aggregate function - WindowUDF(Arc), -} - -/// Find DataFusion's built-in window function by name. -pub fn find_df_window_func(name: &str) -> Option { - let name = name.to_lowercase(); - // Code paths for window functions leveraging ordinary aggregators and - // built-in window functions are quite different, and the same function - // may have different implementations for these cases. If the sought - // function is not found among built-in window functions, we search for - // it among aggregate functions. 
- if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { - Some(WindowFunction::BuiltInWindowFunction(built_in_function)) - } else if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { - Some(WindowFunction::AggregateFunction(aggregate)) - } else { - None - } -} - -impl fmt::Display for BuiltInWindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name()) - } -} - -impl fmt::Display for WindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.fmt(f), - WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), - WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), - WindowFunction::WindowUDF(fun) => fun.fmt(f), - } - } -} - -/// A [window function] built in to DataFusion -/// -/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) -#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] -pub enum BuiltInWindowFunction { - /// number of the current row within its partition, counting from 1 - RowNumber, - /// rank of the current row with gaps; same as row_number of its first peer - Rank, - /// rank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, - /// returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. 
- /// If omitted, offset defaults to 1 and default to null - Lead, - /// returns value evaluated at the row that is the first row of the window frame - FirstValue, - /// returns value evaluated at the row that is the last row of the window frame - LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row - NthValue, -} - -impl BuiltInWindowFunction { - fn name(&self) -> &str { - use BuiltInWindowFunction::*; - match self { - RowNumber => "ROW_NUMBER", - Rank => "RANK", - DenseRank => "DENSE_RANK", - PercentRank => "PERCENT_RANK", - CumeDist => "CUME_DIST", - Ntile => "NTILE", - Lag => "LAG", - Lead => "LEAD", - FirstValue => "FIRST_VALUE", - LastValue => "LAST_VALUE", - NthValue => "NTH_VALUE", - } - } -} - -impl FromStr for BuiltInWindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name.to_uppercase().as_str() { - "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, - "CUME_DIST" => BuiltInWindowFunction::CumeDist, - "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, - "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, - "LAST_VALUE" => BuiltInWindowFunction::LastValue, - "NTH_VALUE" => BuiltInWindowFunction::NthValue, - _ => return plan_err!("There is no built-in window function named {name}"), - }) - } -} - -/// Returns the datatype of the window function -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::return_type` instead" -)] -pub fn return_type( - fun: &WindowFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -impl WindowFunction { - /// Returns the datatype of the window function - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.return_type(input_expr_types), - WindowFunction::BuiltInWindowFunction(fun) => { - fun.return_type(input_expr_types) - } - WindowFunction::AggregateUDF(fun) => fun.return_type(input_expr_types), - WindowFunction::WindowUDF(fun) => fun.return_type(input_expr_types), - } - } -} - -/// Returns the datatype of the built-in window function -impl BuiltInWindowFunction { - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. 
- - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message - .map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - - match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } - BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue - | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), - } - } -} - -/// the signatures supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::signature` instead" -)] -pub fn signature(fun: &WindowFunction) -> Signature { - fun.signature() -} - -impl WindowFunction { - /// the signatures supported by the function `fun`. - pub fn signature(&self) -> Signature { - match self { - WindowFunction::AggregateFunction(fun) => fun.signature(), - WindowFunction::BuiltInWindowFunction(fun) => fun.signature(), - WindowFunction::AggregateUDF(fun) => fun.signature().clone(), - WindowFunction::WindowUDF(fun) => fun.signature().clone(), - } - } -} - -/// the signatures supported by the built-in window function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `BuiltInWindowFunction::signature` instead" -)] -pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { - fun.signature() -} - -impl BuiltInWindowFunction { - /// the signatures supported by the built-in window function `fun`. - pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. 
- match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { - Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - } - BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { - Signature::any(1, Volatility::Immutable) - } - BuiltInWindowFunction::Ntile => Signature::uniform( - 1, - vec![ - DataType::UInt64, - DataType::UInt32, - DataType::UInt16, - DataType::UInt8, - DataType::Int64, - DataType::Int32, - DataType::Int16, - DataType::Int8, - ], - Volatility::Immutable, - ), - BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use strum::IntoEnumIterator; - - #[test] - fn test_count_return_type() -> Result<()> { - let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::Int64, observed); - - Ok(()) - } - - #[test] - fn test_first_value_return_type() -> Result<()> { - let fun = find_df_window_func("first_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_last_value_return_type() -> Result<()> { - let fun = find_df_window_func("last_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lead_return_type() -> Result<()> { - let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lag_return_type() -> Result<()> { - let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_nth_value_return_type() -> Result<()> { - let fun = find_df_window_func("nth_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_percent_rank_return_type() -> Result<()> { - let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_cume_dist_return_type() -> Result<()> { - let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_ntile_return_type() -> Result<()> { - let fun = find_df_window_func("ntile").unwrap(); - let observed = 
fun.return_type(&[DataType::Int16])?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "row_number", - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", - "lag", - "lead", - "first_value", - "last_value", - "nth_value", - "min", - "max", - "count", - "avg", - "sum", - ]; - for name in names { - let fun = find_df_window_func(name).unwrap(); - let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); - assert_eq!(fun, fun2); - assert_eq!(fun.to_string(), name.to_uppercase()); - } - Ok(()) - } - - #[test] - fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("max"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Max)) - ); - assert_eq!( - find_df_window_func("min"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Min)) - ); - assert_eq!( - find_df_window_func("avg"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Avg)) - ); - assert_eq!( - find_df_window_func("cume_dist"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::CumeDist - )) - ); - assert_eq!( - find_df_window_func("first_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::FirstValue - )) - ); - assert_eq!( - find_df_window_func("LAST_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::LastValue - )) - ); - assert_eq!( - find_df_window_func("LAG"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lag - )) - ); - assert_eq!( - find_df_window_func("LEAD"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lead - )) - ); - assert_eq!(find_df_window_func("not_exist"), None) - } - - #[test] - // Test for BuiltInWindowFunction's Display and from_str() implementations. - // For each variant in BuiltInWindowFunction, it converts the variant to a string - // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. This assertion is also necessary for - // function suggestion. 
See https://github.com/apache/arrow-datafusion/issues/8082 - fn test_display_and_from_str() { - for func_original in BuiltInWindowFunction::iter() { - let func_name = func_original.to_string(); - let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); - assert_eq!(func_from_str, func_original); - } - } -} diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index fd84bb80160b..953716713e41 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -24,7 +24,7 @@ use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; use datafusion_expr::{ - aggregate_function, expr, lit, window_function, Aggregate, Expr, Filter, LogicalPlan, + aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, Sort, Subquery, }; use std::sync::Arc; @@ -121,7 +121,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { let new_expr = match old_expr.clone() { Expr::WindowFunction(expr::WindowFunction { fun: - window_function::WindowFunction::AggregateFunction( + expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args, @@ -131,7 +131,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { Expr::WindowFunction(expr::WindowFunction { - fun: window_function::WindowFunction::AggregateFunction( + fun: expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args: vec![lit(COUNT_STAR_EXPANSION)], @@ -229,7 +229,7 @@ mod tests { use datafusion_expr::{ col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder, max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -342,7 +342,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index b6298f5b552f..4d54dad99670 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -45,9 +45,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, - type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, - ExprSchemable, LogicalPlan, Operator, Projection, ScalarFunctionDefinition, - Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, + type_coercion, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, + LogicalPlan, Operator, Projection, ScalarFunctionDefinition, Signature, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; use crate::analyzer::AnalyzerRule; @@ -390,7 +390,7 @@ impl 
TreeNodeRewriter for TypeCoercionRewriter { coerce_window_frame(window_frame, &self.schema, &order_by)?; let args = match &fun { - window_function::WindowFunction::AggregateFunction(fun) => { + expr::WindowFunctionDefinition::AggregateFunction(fun) => { coerce_agg_exprs_for_signature( fun, &args, diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 10cc1879aeeb..4ee4f7e417a6 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -37,7 +37,7 @@ mod tests { }; use datafusion_expr::{ col, count, lit, max, min, AggregateFunction, Expr, LogicalPlan, Projection, - WindowFrame, WindowFunction, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -582,7 +582,7 @@ mod tests { let table_scan = test_table_scan()?; let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.a")], vec![col("test.b")], vec![], @@ -590,7 +590,7 @@ mod tests { )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.b")], vec![], vec![], diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 3187e6b0fbd3..fec168fabf48 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -34,8 +34,8 @@ use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - window_function::{BuiltInWindowFunction, WindowFunction}, - PartitionEvaluator, WindowFrame, WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, + WindowUDF, }; use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ @@ -56,7 +56,7 @@ pub use datafusion_physical_expr::window::{ /// Create a physical expression for window function pub fn create_window_expr( - fun: &WindowFunction, + fun: &WindowFunctionDefinition, name: String, args: &[Arc], partition_by: &[Arc], @@ -65,7 +65,7 @@ pub fn create_window_expr( input_schema: &Schema, ) -> Result> { Ok(match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { let aggregate = aggregates::create_aggregate_expr( fun, false, @@ -81,13 +81,15 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new( - create_built_in_window_expr(fun, args, input_schema, name)?, - partition_by, - order_by, - window_frame, - )), - WindowFunction::AggregateUDF(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + Arc::new(BuiltInWindowExpr::new( + create_built_in_window_expr(fun, args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )) + } + WindowFunctionDefinition::AggregateUDF(fun) => { let aggregate = udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; window_expr_from_aggregate_expr( @@ -97,7 +99,7 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( + WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name)?, partition_by, 
order_by, @@ -647,7 +649,7 @@ mod tests { let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col("a", &schema)?], &[], diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index c582e92dc11c..36c5b44f00b9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1112,7 +1112,7 @@ pub fn parse_expr( let aggr_function = parse_i32_to_aggregate_function(i)?; Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateFunction( + datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction( aggr_function, ), vec![parse_required_expr(expr.expr.as_deref(), registry, "expr")?], @@ -1131,7 +1131,7 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction( + datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), args, @@ -1146,7 +1146,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateUDF( + datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF( udaf_function, ), args, @@ -1161,7 +1161,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::WindowUDF( + datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( udwf_function, ), args, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b9987ff6c727..a162b2389cd1 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -51,7 +51,7 @@ use datafusion_expr::expr::{ use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, Expr, JoinConstraint, JoinType, - TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; #[derive(Debug)] @@ -605,22 +605,22 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref window_frame, }) => { let window_function = match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { protobuf::window_expr_node::WindowFunction::AggrFunction( protobuf::AggregateFunction::from(fun).into(), ) } - WindowFunction::BuiltInWindowFunction(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { protobuf::window_expr_node::WindowFunction::BuiltInFunction( protobuf::BuiltInWindowFunction::from(fun).into(), ) } - WindowFunction::AggregateUDF(aggr_udf) => { + WindowFunctionDefinition::AggregateUDF(aggr_udf) => { protobuf::window_expr_node::WindowFunction::Udaf( aggr_udf.name().to_string(), ) } - WindowFunction::WindowUDF(window_udf) => { + WindowFunctionDefinition::WindowUDF(window_udf) => { protobuf::window_expr_node::WindowFunction::Udwf( window_udf.name().to_string(), ) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 
8ad6d679df4d..23ab813ca739 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -31,7 +31,7 @@ use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; use datafusion::execution::context::ExecutionProps; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::window_function::WindowFunction; +use datafusion::logical_expr::WindowFunctionDefinition; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, @@ -414,7 +414,9 @@ fn parse_required_physical_expr( }) } -impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFunction { +impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> + for WindowFunctionDefinition +{ type Error = DataFusionError; fn try_from( @@ -428,7 +430,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::AggregateFunction(f.into())) + Ok(WindowFunctionDefinition::AggregateFunction(f.into())) } protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { @@ -437,7 +439,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::BuiltInWindowFunction(f.into())) + Ok(WindowFunctionDefinition::BuiltInWindowFunction(f.into())) } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2d7d85abda96..dea99f91e392 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -53,7 +53,7 @@ use datafusion_expr::{ col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::{Sqrt, Substr}, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, WindowUDF, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1663,8 +1663,8 @@ fn roundtrip_window() { // 1. without window_frame let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1674,8 +1674,8 @@ fn roundtrip_window() { // 2. 
with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1691,8 +1691,8 @@ fn roundtrip_window() { }; let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1708,7 +1708,7 @@ fn roundtrip_window() { }; let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1759,7 +1759,7 @@ fn roundtrip_window() { ); let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())), + WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1808,7 +1808,7 @@ fn roundtrip_window() { ); let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())), + WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 3934d6701c63..395f10b6f783 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -23,8 +23,8 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::function::suggest_valid_function; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, - WindowFunction, + expr, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, + WindowFunctionDefinition, }; use sqlparser::ast::{ Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType, @@ -121,12 +121,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let Ok(fun) = self.find_window_func(&name) { let expr = match fun { - WindowFunction::AggregateFunction(aggregate_fun) => { + WindowFunctionDefinition::AggregateFunction(aggregate_fun) => { let args = self.function_args_to_expr(args, schema, planner_context)?; Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(aggregate_fun), + WindowFunctionDefinition::AggregateFunction(aggregate_fun), args, partition_by, order_by, @@ -191,19 +191,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } - pub(super) fn find_window_func(&self, name: &str) -> Result { - window_function::find_df_window_func(name) + pub(super) fn find_window_func( + &self, + name: &str, + ) -> Result { + expr::find_df_window_func(name) // next check user defined aggregates .or_else(|| { self.context_provider .get_aggregate_meta(name) - .map(WindowFunction::AggregateUDF) + .map(WindowFunctionDefinition::AggregateUDF) }) // next check user defined window functions .or_else(|| { self.context_provider .get_window_meta(name) - .map(WindowFunction::WindowUDF) + 
.map(WindowFunctionDefinition::WindowUDF) }) .ok_or_else(|| { plan_datafusion_err!("There is no window function named {name}") diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9931dd15aec8..a4ec3e7722a2 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -23,8 +23,8 @@ use datafusion::common::{ use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ - aggregate_function, window_function::find_df_window_func, BinaryExpr, - BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, + aggregate_function, expr::find_df_window_func, BinaryExpr, BuiltinScalarFunction, + Case, Expr, LogicalPlan, Operator, }; use datafusion::logical_expr::{ expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, From bf0a39a791e7cd0e965abb8c87950cc4101149f7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 2 Jan 2024 00:28:36 -0800 Subject: [PATCH 02/18] Deprecate duplicate function `LogicalPlan::with_new_inputs` (#8707) * Remove duplicate function with_new_inputs * Make it as deprecated function --- datafusion/expr/src/logical_plan/builder.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 47 ++----------------- datafusion/expr/src/tree_node/plan.rs | 2 +- .../optimizer/src/eliminate_outer_join.rs | 3 +- .../optimizer/src/optimize_projections.rs | 3 +- datafusion/optimizer/src/optimizer.rs | 2 +- datafusion/optimizer/src/push_down_filter.rs | 28 +++++++---- datafusion/optimizer/src/push_down_limit.rs | 23 +++++---- datafusion/optimizer/src/utils.rs | 2 +- 9 files changed, 45 insertions(+), 67 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 549c25f89bae..cfc052cfc14c 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -445,7 +445,7 @@ impl LogicalPlanBuilder { ) }) .collect::>>()?; - curr_plan.with_new_inputs(&new_inputs) + curr_plan.with_new_exprs(curr_plan.expressions(), &new_inputs) } } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9b0f441ef902..c0c520c4e211 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -541,35 +541,9 @@ impl LogicalPlan { } /// Returns a copy of this `LogicalPlan` with the new inputs + #[deprecated(since = "35.0.0", note = "please use `with_new_exprs` instead")] pub fn with_new_inputs(&self, inputs: &[LogicalPlan]) -> Result { - // with_new_inputs use original expression, - // so we don't need to recompute Schema. - match &self { - LogicalPlan::Projection(projection) => { - // Schema of the projection may change - // when its input changes. Hence we should use - // `try_new` method instead of `try_new_with_schema`. - Projection::try_new(projection.expr.to_vec(), Arc::new(inputs[0].clone())) - .map(LogicalPlan::Projection) - } - LogicalPlan::Window(Window { window_expr, .. }) => Ok(LogicalPlan::Window( - Window::try_new(window_expr.to_vec(), Arc::new(inputs[0].clone()))?, - )), - LogicalPlan::Aggregate(Aggregate { - group_expr, - aggr_expr, - .. - }) => Aggregate::try_new( - // Schema of the aggregate may change - // when its input changes. Hence we should use - // `try_new` method instead of `try_new_with_schema`. 
- Arc::new(inputs[0].clone()), - group_expr.to_vec(), - aggr_expr.to_vec(), - ) - .map(LogicalPlan::Aggregate), - _ => self.with_new_exprs(self.expressions(), inputs), - } + self.with_new_exprs(self.expressions(), inputs) } /// Returns a new `LogicalPlan` based on `self` with inputs and @@ -591,10 +565,6 @@ impl LogicalPlan { /// // create new plan using rewritten_exprs in same position /// let new_plan = plan.new_with_exprs(rewritten_exprs, new_inputs); /// ``` - /// - /// Note: sometimes [`Self::with_new_exprs`] will use schema of - /// original plan, it will not change the scheam. Such as - /// `Projection/Aggregate/Window` pub fn with_new_exprs( &self, mut expr: Vec, @@ -706,17 +676,10 @@ impl LogicalPlan { })) } }, - LogicalPlan::Window(Window { - window_expr, - schema, - .. - }) => { + LogicalPlan::Window(Window { window_expr, .. }) => { assert_eq!(window_expr.len(), expr.len()); - Ok(LogicalPlan::Window(Window { - input: Arc::new(inputs[0].clone()), - window_expr: expr, - schema: schema.clone(), - })) + Window::try_new(expr, Arc::new(inputs[0].clone())) + .map(LogicalPlan::Window) } LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { // group exprs are the first expressions diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 217116530d4a..208a8b57d7b0 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -113,7 +113,7 @@ impl TreeNode for LogicalPlan { .zip(new_children.iter()) .any(|(c1, c2)| c1 != &c2) { - self.with_new_inputs(new_children.as_slice()) + self.with_new_exprs(self.expressions(), new_children.as_slice()) } else { Ok(self) } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index e4d57f0209a4..53c4b3702b1e 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -106,7 +106,8 @@ impl OptimizerRule for EliminateOuterJoin { schema: join.schema.clone(), null_equals_null: join.null_equals_null, }); - let new_plan = plan.with_new_inputs(&[new_join])?; + let new_plan = + plan.with_new_exprs(plan.expressions(), &[new_join])?; Ok(Some(new_plan)) } _ => Ok(None), diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 7ae9f7edf5e5..891a909a3378 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -373,7 +373,8 @@ fn optimize_projections( // `old_child` during construction: .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone())) .collect::>(); - plan.with_new_inputs(&new_inputs).map(Some) + plan.with_new_exprs(plan.expressions(), &new_inputs) + .map(Some) } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 0dc34cb809eb..2cb59d511ccf 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -382,7 +382,7 @@ impl Optimizer { }) .collect::>(); - Ok(Some(plan.with_new_inputs(&new_inputs)?)) + Ok(Some(plan.with_new_exprs(plan.expressions(), &new_inputs)?)) } /// Use a rule to optimize the whole plan. 
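// The call-site migration this patch repeats across the optimizer is purely
// mechanical: the deprecated `with_new_inputs` becomes `with_new_exprs` with
// the plan's own expressions passed through unchanged. A minimal sketch, with
// `plan` and `new_inputs` as illustrative bindings rather than any particular
// call site:
//
//     // before
//     let rewritten = plan.with_new_inputs(&new_inputs)?;
//     // after
//     let rewritten = plan.with_new_exprs(plan.expressions(), &new_inputs)?;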
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 9d277d18d2f7..4eb925ac0629 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -691,9 +691,11 @@ impl OptimizerRule for PushDownFilter { | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) => { // commutable - let new_filter = - plan.with_new_inputs(&[child_plan.inputs()[0].clone()])?; - child_plan.with_new_inputs(&[new_filter])? + let new_filter = plan.with_new_exprs( + plan.expressions(), + &[child_plan.inputs()[0].clone()], + )?; + child_plan.with_new_exprs(child_plan.expressions(), &[new_filter])? } LogicalPlan::SubqueryAlias(subquery_alias) => { let mut replace_map = HashMap::new(); @@ -716,7 +718,7 @@ impl OptimizerRule for PushDownFilter { new_predicate, subquery_alias.input.clone(), )?); - child_plan.with_new_inputs(&[new_filter])? + child_plan.with_new_exprs(child_plan.expressions(), &[new_filter])? } LogicalPlan::Projection(projection) => { // A projection is filter-commutable if it do not contain volatile predicates or contain volatile @@ -760,10 +762,15 @@ impl OptimizerRule for PushDownFilter { )?); match conjunction(keep_predicates) { - None => child_plan.with_new_inputs(&[new_filter])?, + None => child_plan.with_new_exprs( + child_plan.expressions(), + &[new_filter], + )?, Some(keep_predicate) => { - let child_plan = - child_plan.with_new_inputs(&[new_filter])?; + let child_plan = child_plan.with_new_exprs( + child_plan.expressions(), + &[new_filter], + )?; LogicalPlan::Filter(Filter::try_new( keep_predicate, Arc::new(child_plan), @@ -837,7 +844,9 @@ impl OptimizerRule for PushDownFilter { )?), None => (*agg.input).clone(), }; - let new_agg = filter.input.with_new_inputs(&vec![child])?; + let new_agg = filter + .input + .with_new_exprs(filter.input.expressions(), &vec![child])?; match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, @@ -942,7 +951,8 @@ impl OptimizerRule for PushDownFilter { None => extension_plan.node.inputs().into_iter().cloned().collect(), }; // extension with new inputs. - let new_extension = child_plan.with_new_inputs(&new_children)?; + let new_extension = + child_plan.with_new_exprs(child_plan.expressions(), &new_children)?; match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 6703a1d787a7..c2f35a790616 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -126,7 +126,7 @@ impl OptimizerRule for PushDownLimit { fetch: scan.fetch.map(|x| min(x, limit)).or(Some(limit)), projected_schema: scan.projected_schema.clone(), }); - Some(plan.with_new_inputs(&[new_input])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_input])?) } } LogicalPlan::Union(union) => { @@ -145,7 +145,7 @@ impl OptimizerRule for PushDownLimit { inputs: new_inputs, schema: union.schema.clone(), }); - Some(plan.with_new_inputs(&[union])?) + Some(plan.with_new_exprs(plan.expressions(), &[union])?) } LogicalPlan::CrossJoin(cross_join) => { @@ -166,15 +166,16 @@ impl OptimizerRule for PushDownLimit { right: Arc::new(new_right), schema: plan.schema().clone(), }); - Some(plan.with_new_inputs(&[new_cross_join])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_cross_join])?) 
} LogicalPlan::Join(join) => { let new_join = push_down_join(join, fetch + skip); match new_join { - Some(new_join) => { - Some(plan.with_new_inputs(&[LogicalPlan::Join(new_join)])?) - } + Some(new_join) => Some(plan.with_new_exprs( + plan.expressions(), + &[LogicalPlan::Join(new_join)], + )?), None => None, } } @@ -192,14 +193,16 @@ impl OptimizerRule for PushDownLimit { input: Arc::new((*sort.input).clone()), fetch: new_fetch, }); - Some(plan.with_new_inputs(&[new_sort])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_sort])?) } } LogicalPlan::Projection(_) | LogicalPlan::SubqueryAlias(_) => { // commute - let new_limit = - plan.with_new_inputs(&[child_plan.inputs()[0].clone()])?; - Some(child_plan.with_new_inputs(&[new_limit])?) + let new_limit = plan.with_new_exprs( + plan.expressions(), + &[child_plan.inputs()[0].clone()], + )?; + Some(child_plan.with_new_exprs(child_plan.expressions(), &[new_limit])?) } _ => None, }; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 48f72ee7a0f8..44f2404afade 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -46,7 +46,7 @@ pub fn optimize_children( new_inputs.push(new_input.unwrap_or_else(|| input.clone())) } if plan_is_changed { - Ok(Some(plan.with_new_inputs(&new_inputs)?)) + Ok(Some(plan.with_new_exprs(plan.expressions(), &new_inputs)?)) } else { Ok(None) } From f4233a92761e9144b8747e66b95bf0b3f82464b8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 2 Jan 2024 10:36:43 -0500 Subject: [PATCH 03/18] Minor: refactor bloom filter tests to reduce duplication (#8435) --- .../physical_plan/parquet/row_groups.rs | 343 ++++++++---------- 1 file changed, 153 insertions(+), 190 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 5d18eac7d9fb..24c65423dd4c 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -1013,82 +1013,28 @@ mod tests { create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() } - // Note the values in the `String` column are: - // ❯ select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; - // +-----------+ - // | String | - // +-----------+ - // | Hello | - // | This is | - // | a | - // | test | - // | How | - // | are you | - // | doing | - // | today | - // | the quick | - // | brown fox | - // | jumps | - // | over | - // | the lazy | - // | dog | - // +-----------+ #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "Hello_Not_exists")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#).eq(lit("Hello_Not_Exists")); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert!(pruned_row_groups.is_empty()); + 
BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists")` + .run(col(r#""String""#).eq(lit("Hello_Not_Exists"))) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_mutiple_expr() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = lit("1").eq(lit("1")).and( - col(r#""String""#) - .eq(lit("Hello_Not_Exists")) - .or(col(r#""String""#).eq(lit("Hello_Not_Exists2"))), - ); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert!(pruned_row_groups.is_empty()); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` + .run( + lit("1").eq(lit("1")).and( + col(r#""String""#) + .eq(lit("Hello_Not_Exists")) + .or(col(r#""String""#).eq(lit("Hello_Not_Exists2"))), + ), + ) + .await } #[tokio::test] @@ -1129,144 +1075,161 @@ mod tests { #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_with_exists_value() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "Hello")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#).eq(lit("Hello")); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello")` + .run(col(r#""String""#).eq(lit("Hello"))) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_with_exists_2_values() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "Hello") OR (String = "the quick")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#) - .eq(lit("Hello")) - .or(col(r#""String""#).eq(lit("the quick"))); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let 
pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick")` + .run( + col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))), + ) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#) - .eq(lit("Hello")) - .or(col(r#""String""#).eq(lit("the quick"))) - .or(col(r#""String""#).eq(lit("are you"))); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` + .run( + col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))) + .or(col(r#""String""#).eq(lit("are you"))), + ) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_with_or_not_eq() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "foo") OR (String != "bar")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#) - .not_eq(lit("foo")) - .or(col(r#""String""#).not_eq(lit("bar"))); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "foo") OR (String != "bar")` + .run( + col(r#""String""#) + .not_eq(lit("foo")) + .or(col(r#""String""#).not_eq(lit("bar"))), + ) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_without_bloom_filter() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "alltypes_plain.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate on a column without a bloom filter - let schema = 
Schema::new(vec![Field::new("string_col", DataType::Utf8, false)]); - let expr = col(r#""string_col""#).eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + BloomFilterTest::new_all_types() + .with_expect_none_pruned() + .run(col(r#""string_col""#).eq(lit("0"))) + .await + } - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + struct BloomFilterTest { + file_name: String, + schema: Schema, + // which row groups should be attempted to prune + row_groups: Vec, + // which row groups are expected to be left after pruning. Must be set + // otherwise will panic on run() + post_pruning_row_groups: Option>, + } + + impl BloomFilterTest { + /// Return a test for data_index_bloom_encoding_stats.parquet + /// Note the values in the `String` column are: + /// ```sql + /// ❯ select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; + /// +-----------+ + /// | String | + /// +-----------+ + /// | Hello | + /// | This is | + /// | a | + /// | test | + /// | How | + /// | are you | + /// | doing | + /// | today | + /// | the quick | + /// | brown fox | + /// | jumps | + /// | over | + /// | the lazy | + /// | dog | + /// +-----------+ + /// ``` + fn new_data_index_bloom_encoding_stats() -> Self { + Self { + file_name: String::from("data_index_bloom_encoding_stats.parquet"), + schema: Schema::new(vec![Field::new("String", DataType::Utf8, false)]), + row_groups: vec![0], + post_pruning_row_groups: None, + } + } + + // Return a test for alltypes_plain.parquet + fn new_all_types() -> Self { + Self { + file_name: String::from("alltypes_plain.parquet"), + schema: Schema::new(vec![Field::new( + "string_col", + DataType::Utf8, + false, + )]), + row_groups: vec![0], + post_pruning_row_groups: None, + } + } + + /// Expect all row groups to be pruned + pub fn with_expect_all_pruned(mut self) -> Self { + self.post_pruning_row_groups = Some(vec![]); + self + } + + /// Expect all row groups not to be pruned + pub fn with_expect_none_pruned(mut self) -> Self { + self.post_pruning_row_groups = Some(self.row_groups.clone()); + self + } + + /// Prune this file using the specified expression and check that the expected row groups are left + async fn run(self, expr: Expr) { + let Self { + file_name, + schema, + row_groups, + post_pruning_row_groups, + } = self; + + let post_pruning_row_groups = + post_pruning_row_groups.expect("post_pruning_row_groups must be set"); + + let testdata = datafusion_common::test_util::parquet_test_data(); + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + &file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, post_pruning_row_groups); + } } async fn test_row_group_bloom_filter_pruning_predicate( From 82656af2c79246f28b8519210be42de6e5a82e54 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 2 Jan 2024 23:48:29 +0800 Subject: [PATCH 04/18] clean up code (#8715) --- .../core/src/datasource/file_format/write/demux.rs | 4 ++-- 
.../datasource/physical_plan/parquet/page_filter.rs | 2 +- datafusion/core/src/execution/context/mod.rs | 10 ++++++---- .../core/src/physical_optimizer/sort_pushdown.rs | 4 +--- datafusion/core/src/physical_planner.rs | 5 +++-- datafusion/substrait/src/logical_plan/producer.rs | 2 +- 6 files changed, 14 insertions(+), 13 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index fa4ed8437015..dbfeb67eaeb9 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -383,7 +383,7 @@ fn compute_take_arrays( fn remove_partition_by_columns( parted_batch: &RecordBatch, - partition_by: &Vec<(String, DataType)>, + partition_by: &[(String, DataType)], ) -> Result { let end_idx = parted_batch.num_columns() - partition_by.len(); let non_part_cols = &parted_batch.columns()[..end_idx]; @@ -405,7 +405,7 @@ fn remove_partition_by_columns( } fn compute_hive_style_file_path( - part_key: &Vec, + part_key: &[String], partition_by: &[(String, DataType)], write_id: &str, file_extension: &str, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index f6310c49bcd6..a0637f379610 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -372,7 +372,7 @@ fn prune_pages_in_one_row_group( } fn create_row_count_in_each_page( - location: &Vec, + location: &[PageLocation], num_rows: usize, ) -> Vec { let mut vec = Vec::with_capacity(location.len()); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index c51f2d132aad..d6b7f046f3e3 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1719,7 +1719,7 @@ impl SessionState { let mut stringified_plans = e.stringified_plans.clone(); // analyze & capture output of each rule - let analyzed_plan = match self.analyzer.execute_and_check( + let analyzer_result = self.analyzer.execute_and_check( e.plan.as_ref(), self.options(), |analyzed_plan, analyzer| { @@ -1727,7 +1727,8 @@ impl SessionState { let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; stringified_plans.push(analyzed_plan.to_stringified(plan_type)); }, - ) { + ); + let analyzed_plan = match analyzer_result { Ok(plan) => plan, Err(DataFusionError::Context(analyzer_name, err)) => { let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; @@ -1750,7 +1751,7 @@ impl SessionState { .push(analyzed_plan.to_stringified(PlanType::FinalAnalyzedLogicalPlan)); // optimize the child plan, capturing the output of each optimizer - let (plan, logical_optimization_succeeded) = match self.optimizer.optimize( + let optimized_plan = self.optimizer.optimize( &analyzed_plan, self, |optimized_plan, optimizer| { @@ -1758,7 +1759,8 @@ impl SessionState { let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; stringified_plans.push(optimized_plan.to_stringified(plan_type)); }, - ) { + ); + let (plan, logical_optimization_succeeded) = match optimized_plan { Ok(plan) => (Arc::new(plan), true), Err(DataFusionError::Context(optimizer_name, err)) => { let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs 
index 97ca47baf05f..f0a8c8cfd3cb 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -405,9 +405,7 @@ fn shift_right_required( let new_right_required: Vec = parent_required .iter() .filter_map(|r| { - let Some(col) = r.expr.as_any().downcast_ref::() else { - return None; - }; + let col = r.expr.as_any().downcast_ref::()?; if col.index() < left_columns_len { return None; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 31d50be10f70..d696c55a8c13 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1879,7 +1879,7 @@ impl DefaultPhysicalPlanner { ); } - match self.optimize_internal( + let optimized_plan = self.optimize_internal( input, session_state, |plan, optimizer| { @@ -1891,7 +1891,8 @@ impl DefaultPhysicalPlanner { .to_stringified(e.verbose, plan_type), ); }, - ) { + ); + match optimized_plan { Ok(input) => { // This plan will includes statistics if show_statistics is on stringified_plans.push( diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 926883251a63..ab0e8c860858 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -608,7 +608,7 @@ pub fn parse_flat_grouping_exprs( pub fn to_substrait_groupings( ctx: &SessionContext, - exprs: &Vec, + exprs: &[Expr], schema: &DFSchemaRef, extension_info: &mut ( Vec, From 94aff5555874f023c934cd6c3a52dd956a773342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Tue, 2 Jan 2024 18:52:12 +0300 Subject: [PATCH 05/18] Update analyze.rs (#8717) --- datafusion/physical-plan/src/analyze.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index ded37983bb21..4f1578e220dd 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -115,8 +115,12 @@ impl ExecutionPlan for AnalyzeExec { /// Specifies whether this plan generates an infinite stream of records. /// If the plan does not support pipelining, but its input(s) are /// infinite, returns an error to indicate this. 
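// After this change AnalyzeExec propagates boundedness instead of refusing
// the question outright: a bounded child yields `Ok(false)`, and only a truly
// unbounded child still returns the streaming-not-possible error, as the
// replacement body below shows.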
- fn unbounded_output(&self, _children: &[bool]) -> Result { - internal_err!("Optimization not supported for ANALYZE") + fn unbounded_output(&self, children: &[bool]) -> Result { + if children[0] { + internal_err!("Streaming execution of AnalyzeExec is not possible") + } else { + Ok(false) + } } /// Get the output partitioning of this plan From d4b96a80c86d216613ecbec24d4908bb31ed4c7e Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 2 Jan 2024 23:57:26 +0800 Subject: [PATCH 06/18] support LargeList in array_position (#8714) --- .../physical-expr/src/array_expressions.rs | 14 ++-- datafusion/sqllogictest/test_files/array.slt | 71 +++++++++++++++++++ 2 files changed, 81 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 250250630eff..9b93782237f8 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1367,8 +1367,14 @@ pub fn array_position(args: &[ArrayRef]) -> Result { if args.len() < 2 || args.len() > 3 { return exec_err!("array_position expects two or three arguments"); } - - let list_array = as_list_array(&args[0])?; + match &args[0].data_type() { + DataType::List(_) => general_position_dispatch::(args), + DataType::LargeList(_) => general_position_dispatch::(args), + array_type => exec_err!("array_position does not support type '{array_type:?}'."), + } +} +fn general_position_dispatch(args: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&args[0])?; let element_array = &args[1]; check_datatypes("array_position", &[list_array.values(), element_array])?; @@ -1395,10 +1401,10 @@ pub fn array_position(args: &[ArrayRef]) -> Result { } } - general_position::(list_array, element_array, arr_from) + generic_position::(list_array, element_array, arr_from) } -fn general_position( +fn generic_position( list_array: &GenericListArray, element_array: &ArrayRef, arr_from: Vec, // 0-indexed diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 6dab3b3084a9..4205f64c19d0 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -363,6 +363,17 @@ AS VALUES (make_array(31, 32, 33, 34, 35, 26, 37, 38, 39, 40), 34, 4, 'ok', [8,9]) ; +statement ok +CREATE TABLE large_arrays_values_without_nulls +AS SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4, + arrow_cast(column5, 'LargeList(Int64)') AS column5 +FROM arrays_values_without_nulls +; + statement ok CREATE TABLE arrays_range AS VALUES @@ -2054,12 +2065,22 @@ select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3, ---- 3 5 1 +query III +select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # array_position scalar function #2 (with optional argument) query III select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2); ---- 4 5 2 +query III +select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l', 4), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5, 4), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1, 2); +---- +4 5 2 + # array_position scalar function #3 (element is list) query II select 
array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); @@ -2072,24 +2093,44 @@ select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, ---- 4 3 +query II +select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), [4, 5, 6]), array_position(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), [2, 3, 4]); +---- +2 2 + # list_position scalar function #5 (function alias `array_position`) query III select list_position(['h', 'e', 'l', 'l', 'o'], 'l'), list_position([1, 2, 3, 4, 5], 5), list_position([1, 1, 1], 1); ---- 3 5 1 +query III +select list_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # array_indexof scalar function #6 (function alias `array_position`) query III select array_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), array_indexof([1, 2, 3, 4, 5], 5), array_indexof([1, 1, 1], 1); ---- 3 5 1 +query III +select array_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # list_indexof scalar function #7 (function alias `array_position`) query III select list_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), list_indexof([1, 2, 3, 4, 5], 5), list_indexof([1, 1, 1], 1); ---- 3 5 1 +query III +select list_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # array_position with columns #1 query II select array_position(column1, column2), array_position(column1, column2, column3) from arrays_values_without_nulls; @@ -2099,6 +2140,14 @@ select array_position(column1, column2), array_position(column1, column2, column 3 3 4 4 +query II +select array_position(column1, column2), array_position(column1, column2, column3) from large_arrays_values_without_nulls; +---- +1 1 +2 2 +3 3 +4 4 + # array_position with columns #2 (element is list) query II select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays; @@ -2106,6 +2155,13 @@ select array_position(column1, column2), array_position(column1, column2, column 3 3 2 5 +#TODO: add this test when #8305 is fixed +#query II +#select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays; +#---- +#3 3 +#2 5 + # array_position with columns and scalars #1 query III select array_position(make_array(1, 2, 3, 4, 5), column2), array_position(column1, 3), array_position(column1, 3, 5) from arrays_values_without_nulls; @@ -2115,6 +2171,14 @@ NULL NULL NULL NULL NULL NULL NULL NULL NULL +query III +select array_position(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_position(column1, 3), array_position(column1, 3, 5) from large_arrays_values_without_nulls; +---- +1 3 NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL + # array_position with columns and scalars #2 (element is list) query III select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), array_position(column1, make_array(4, 
5, 6)), array_position(column1, make_array(1, 2, 3), 2) from nested_arrays; @@ -2122,6 +2186,13 @@ select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), NULL 6 4 NULL 1 NULL +#TODO: add this test when #8305 is fixed +#query III +#select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), 'LargeList(List(Int64))'), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from large_nested_arrays; +#---- +#NULL 6 4 +#NULL 1 NULL + ## array_positions (aliases: `list_positions`) # array_positions scalar function #1 From 96cede202a8a554051001143e8345883992c3f74 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 2 Jan 2024 23:58:18 +0800 Subject: [PATCH 07/18] support LargeList in array_ndims (#8716) --- datafusion/common/src/utils.rs | 9 +-- .../physical-expr/src/array_expressions.rs | 24 ++++++-- datafusion/sqllogictest/test_files/array.slt | 57 ++++++++++++++++++- 3 files changed, 77 insertions(+), 13 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index cfdef309a4ee..49a00b24d10e 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -469,10 +469,11 @@ pub fn coerced_type_with_base_type_only( /// Compute the number of dimensions in a list data type. pub fn list_ndims(data_type: &DataType) -> u64 { - if let DataType::List(field) = data_type { - 1 + list_ndims(field.data_type()) - } else { - 0 + match data_type { + DataType::List(field) | DataType::LargeList(field) => { + 1 + list_ndims(field.data_type()) + } + _ => 0, } } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 9b93782237f8..92ba7a4d1dcd 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -2250,11 +2250,13 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { return exec_err!("array_ndims needs one argument"); } - if let Some(list_array) = args[0].as_list_opt::() { - let ndims = datafusion_common::utils::list_ndims(list_array.data_type()); + fn general_list_ndims( + array: &GenericListArray, + ) -> Result { + let mut data = Vec::new(); + let ndims = datafusion_common::utils::list_ndims(array.data_type()); - let mut data = vec![]; - for arr in list_array.iter() { + for arr in array.iter() { if arr.is_some() { data.push(Some(ndims)) } else { @@ -2263,8 +2265,18 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { } Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) - } else { - Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef) + } + + match args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_list_ndims::(array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_list_ndims::(array) + } + _ => Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef), } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 4205f64c19d0..2f8e3c805f73 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -3504,7 +3504,7 @@ NULL [3] [4] # array_ndims scalar function #1 query III -select +select array_ndims(1), array_ndims(null), array_ndims([2, 3]); @@ -3520,8 +3520,17 @@ AS VALUES (3, [6], [[9]], [[[[[10]]]]]) ; +statement ok +CREATE TABLE large_array_ndims_table +AS SELECT + column1, + arrow_cast(column2, 'LargeList(Int64)') as 
column2, + arrow_cast(column3, 'LargeList(List(Int64))') as column3, + arrow_cast(column4, 'LargeList(List(List(List(List(Int64)))))') as column4 +FROM array_ndims_table; + query IIII -select +select array_ndims(column1), array_ndims(column2), array_ndims(column3), @@ -3533,9 +3542,25 @@ from array_ndims_table; 0 1 2 5 0 1 2 5 +query IIII +select + array_ndims(column1), + array_ndims(column2), + array_ndims(column3), + array_ndims(column4) +from large_array_ndims_table; +---- +0 1 2 5 +0 1 2 5 +0 1 2 5 +0 1 2 5 + statement ok drop table array_ndims_table; +statement ok +drop table large_array_ndims_table + query I select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); ---- @@ -3553,14 +3578,29 @@ select array_ndims(make_array()), array_ndims(make_array(make_array())) ---- 1 2 +query II +select array_ndims(arrow_cast(make_array(), 'LargeList(Null)')), array_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +1 2 + # list_ndims scalar function #4 (function alias `array_ndims`) query III select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])), list_ndims(make_array([[[[1], [2]]]])); ---- 1 2 5 +query III +select list_ndims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_ndims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), list_ndims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +1 2 5 + query II -select array_ndims(make_array()), array_ndims(make_array(make_array())) +select list_ndims(make_array()), list_ndims(make_array(make_array())) +---- +1 2 + +query II +select list_ndims(arrow_cast(make_array(), 'LargeList(Null)')), list_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) ---- 1 2 @@ -3576,6 +3616,17 @@ NULL 1 1 2 NULL 1 2 1 NULL +query III +select array_ndims(column1), array_ndims(column2), array_ndims(column3) from large_arrays; +---- +2 1 1 +2 1 1 +2 1 1 +2 1 1 +NULL 1 1 +2 NULL 1 +2 1 NULL + ## array_has/array_has_all/array_has_any query BBBBBBBBBBBB From c1fe3dd8f95ab75511c3295e87782373ad060877 Mon Sep 17 00:00:00 2001 From: Ashim Sedhain <38435962+asimsedhain@users.noreply.github.com> Date: Tue, 2 Jan 2024 09:59:15 -0600 Subject: [PATCH 08/18] feat: remove filters with null constants (#8700) --- datafusion/optimizer/src/eliminate_filter.rs | 33 +++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index c97906a81adf..fea14342ca77 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to replace `where false` on a plan with an empty relation. +//! Optimizer rule to replace `where false or null` on a plan with an empty relation. //! This saves time in planning and executing the query. //! Note that this rule should be applied after simplify expressions optimizer rule. 
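// In SQL terms, a predicate that is literally NULL selects no rows, exactly
// like `false`, so once this rule runs both of the queries below plan to an
// empty relation (the table `t` here is hypothetical):
//
//     SELECT * FROM t WHERE false;
//     SELECT * FROM t WHERE null;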
 use crate::optimizer::ApplyOrder;
@@ -27,7 +27,7 @@ use datafusion_expr::{
 
 use crate::{OptimizerConfig, OptimizerRule};
 
-/// Optimization rule that eliminate the scalar value (true/false) filter with an [LogicalPlan::EmptyRelation]
+/// Optimization rule that eliminate the scalar value (true/false/null) filter with an [LogicalPlan::EmptyRelation]
 #[derive(Default)]
 pub struct EliminateFilter;
 
@@ -46,20 +46,22 @@ impl OptimizerRule for EliminateFilter {
     ) -> Result<Option<LogicalPlan>> {
         match plan {
             LogicalPlan::Filter(Filter {
-                predicate: Expr::Literal(ScalarValue::Boolean(Some(v))),
+                predicate: Expr::Literal(ScalarValue::Boolean(v)),
                 input,
                 ..
             }) => {
                 match *v {
                     // input also can be filter, apply again
-                    true => Ok(Some(
+                    Some(true) => Ok(Some(
                         self.try_optimize(input, _config)?
                             .unwrap_or_else(|| input.as_ref().clone()),
                     )),
-                    false => Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
-                        produce_one_row: false,
-                        schema: input.schema().clone(),
-                    }))),
+                    Some(false) | None => {
+                        Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
+                            produce_one_row: false,
+                            schema: input.schema().clone(),
+                        })))
+                    }
                 }
             }
             _ => Ok(None),
@@ -105,6 +107,21 @@ mod tests {
         assert_optimized_plan_equal(&plan, expected)
     }
 
+    #[test]
+    fn filter_null() -> Result<()> {
+        let filter_expr = Expr::Literal(ScalarValue::Boolean(None));
+
+        let table_scan = test_table_scan().unwrap();
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![col("a")], vec![sum(col("b"))])?
+            .filter(filter_expr)?
+            .build()?;
+
+        // No aggregate / scan / limit
+        let expected = "EmptyRelation";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
     #[test]
     fn filter_false_nested() -> Result<()> {
         let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false)));

From 67baf10249b26b4983d3cc3145817903dad8dcd4 Mon Sep 17 00:00:00 2001
From: Alex Huang
Date: Wed, 3 Jan 2024 06:37:20 +0800
Subject: [PATCH 09/18] support `LargeList` in `array_prepend` and
 `array_append` (#8679)

* support largelist

* fix cast error

* fix cast

* add tests

* fix conflict

* add TODO comment for future tests

---------

Co-authored-by: hwj
---
 datafusion/common/src/utils.rs                |  23 ++-
 .../expr/src/type_coercion/functions.rs       |  24 +--
 .../physical-expr/src/array_expressions.rs    | 144 +++++++------
 datafusion/sqllogictest/test_files/array.slt  | 184 +++++++++++++++++-
 4 files changed, 284 insertions(+), 91 deletions(-)

diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs
index 49a00b24d10e..0a61fce15482 100644
--- a/datafusion/common/src/utils.rs
+++ b/datafusion/common/src/utils.rs
@@ -424,10 +424,11 @@ pub fn arrays_into_list_array(
 /// assert_eq!(base_type(&data_type), DataType::Int32);
 /// ```
 pub fn base_type(data_type: &DataType) -> DataType {
-    if let DataType::List(field) = data_type {
-        base_type(field.data_type())
-    } else {
-        data_type.to_owned()
+    match data_type {
+        DataType::List(field) | DataType::LargeList(field) => {
+            base_type(field.data_type())
+        }
+        _ => data_type.to_owned(),
     }
 }
 
@@ -462,6 +463,20 @@ pub fn coerced_type_with_base_type_only(
                 field.is_nullable(),
             )))
         }
+        DataType::LargeList(field) => {
+            let data_type = match field.data_type() {
+                DataType::LargeList(_) => {
+                    coerced_type_with_base_type_only(field.data_type(), base_type)
+                }
+                _ => base_type.to_owned(),
+            };
+
+            DataType::LargeList(Arc::new(Field::new(
+                field.name(),
+                data_type,
+                field.is_nullable(),
+            )))
+        }
         _ => base_type.clone(),
     }
 }
diff --git a/datafusion/expr/src/type_coercion/functions.rs
b/datafusion/expr/src/type_coercion/functions.rs index fa47c92762bf..63908d539bd0 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -116,18 +116,18 @@ fn get_valid_types( &new_base_type, ); - if let DataType::List(ref field) = array_type { - let elem_type = field.data_type(); - if is_append { - Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]) - } else { - Ok(vec![vec![elem_type.to_owned(), array_type.clone()]]) + match array_type { + DataType::List(ref field) | DataType::LargeList(ref field) => { + let elem_type = field.data_type(); + if is_append { + Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]) + } else { + Ok(vec![vec![elem_type.to_owned(), array_type.clone()]]) + } } - } else { - Ok(vec![vec![]]) + _ => Ok(vec![vec![]]), } } - let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() @@ -311,9 +311,9 @@ fn coerced_from<'a>( Utf8 | LargeUtf8 => Some(type_into.clone()), Null if can_cast_types(type_from, type_into) => Some(type_into.clone()), - // Only accept list with the same number of dimensions unless the type is Null. - // List with different dimensions should be handled in TypeSignature or other places before this. - List(_) + // Only accept list and largelist with the same number of dimensions unless the type is Null. + // List or LargeList with different dimensions should be handled in TypeSignature or other places before this. + List(_) | LargeList(_) if datafusion_common::utils::base_type(type_from).eq(&Null) || list_ndims(type_from) == list_ndims(type_into) => { diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 92ba7a4d1dcd..aad021610fcb 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -52,22 +52,6 @@ macro_rules! downcast_arg { }}; } -/// Downcasts multiple arguments into a single concrete type -/// $ARGS: &[ArrayRef] -/// $ARRAY_TYPE: type to downcast to -/// -/// $returns a Vec<$ARRAY_TYPE> -macro_rules! downcast_vec { - ($ARGS:expr, $ARRAY_TYPE:ident) => {{ - $ARGS - .iter() - .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { - Some(array) => Ok(array), - _ => internal_err!("failed to downcast"), - }) - }}; -} - /// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. 
/// /// # Arguments @@ -832,17 +816,20 @@ pub fn array_pop_back(args: &[ArrayRef]) -> Result<ArrayRef> { /// /// # Examples /// -/// general_append_and_prepend( +/// generic_append_and_prepend( /// [1, 2, 3], 4, append => [1, 2, 3, 4] /// 5, [6, 7, 8], prepend => [5, 6, 7, 8] /// ) -fn general_append_and_prepend( - list_array: &ListArray, +fn generic_append_and_prepend<O: OffsetSizeTrait>( + list_array: &GenericListArray<O>, element_array: &ArrayRef, data_type: &DataType, is_append: bool, -) -> Result<ArrayRef> { - let mut offsets = vec![0]; +) -> Result<ArrayRef> +where + i64: TryInto<O>, +{ + let mut offsets = vec![O::usize_as(0)]; let values = list_array.values(); let original_data = values.to_data(); let element_data = element_array.to_data(); @@ -858,8 +845,8 @@ fn general_append_and_prepend( let element_index = 1; for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { - let start = offset_window[0] as usize; - let end = offset_window[1] as usize; + let start = offset_window[0].to_usize().unwrap(); + let end = offset_window[1].to_usize().unwrap(); if is_append { mutable.extend(values_index, start, end); mutable.extend(element_index, row_index, row_index + 1); @@ -867,12 +854,12 @@ fn general_append_and_prepend( mutable.extend(element_index, row_index, row_index + 1); mutable.extend(values_index, start, end); } - offsets.push(offsets[row_index] + (end - start + 1) as i32); + offsets.push(offsets[row_index] + O::usize_as(end - start + 1)); } let data = mutable.freeze(); - Ok(Arc::new(ListArray::try_new( + Ok(Arc::new(GenericListArray::<O>::try_new( Arc::new(Field::new("item", data_type.to_owned(), true)), OffsetBuffer::new(offsets.into()), arrow_array::make_array(data), @@ -938,36 +925,6 @@ pub fn gen_range(args: &[ArrayRef]) -> Result<ArrayRef> { Ok(arr) } -/// Array_append SQL function -pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> { - if args.len() != 2 { - return exec_err!("array_append expects two arguments"); - } - - let list_array = as_list_array(&args[0])?; - let element_array = &args[1]; - - let res = match list_array.value_type() { - DataType::List(_) => concat_internal(args)?, - DataType::Null => { - return make_array(&[ - list_array.values().to_owned(), - element_array.to_owned(), - ]); - } - data_type => { - return general_append_and_prepend( - list_array, - element_array, - &data_type, - true, - ); - } - }; - - Ok(res) -} - /// Array_sort SQL function pub fn array_sort(args: &[ArrayRef]) -> Result<ArrayRef> { if args.is_empty() || args.len() > 3 { @@ -1051,25 +1008,40 @@ fn order_nulls_first(modifier: &str) -> Result<bool> { } } -/// Array_prepend SQL function -pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> { - if args.len() != 2 { - return exec_err!("array_prepend expects two arguments"); - } - - let list_array = as_list_array(&args[1])?; - let element_array = &args[0]; +fn general_append_and_prepend<O: OffsetSizeTrait>( + args: &[ArrayRef], + is_append: bool, +) -> Result<ArrayRef> +where + i64: TryInto<O>, +{ + let (list_array, element_array) = if is_append { + let list_array = as_generic_list_array::<O>(&args[0])?; + let element_array = &args[1]; + check_datatypes("array_append", &[element_array, list_array.values()])?; + (list_array, element_array) + } else { + let list_array = as_generic_list_array::<O>(&args[1])?; + let element_array = &args[0]; + check_datatypes("array_prepend", &[list_array.values(), element_array])?; + (list_array, element_array) + }; - check_datatypes("array_prepend", &[element_array, list_array.values()])?; let res = match list_array.value_type() { - DataType::List(_) => concat_internal(args)?, - DataType::Null => return make_array(&[element_array.to_owned()]), + DataType::List(_) => concat_internal::<i32>(args)?, + DataType::LargeList(_) => concat_internal::<i64>(args)?, + DataType::Null => { + return make_array(&[ + list_array.values().to_owned(), + element_array.to_owned(), + ]); + } data_type => { - return general_append_and_prepend( + return generic_append_and_prepend::<O>( list_array, element_array, &data_type, - false, + is_append, ); } }; @@ -1077,6 +1049,30 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> { Ok(res) } +/// Array_append SQL function +pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> { + if args.len() != 2 { + return exec_err!("array_append expects two arguments"); + } + + match args[0].data_type() { + DataType::LargeList(_) => general_append_and_prepend::<i64>(args, true), + _ => general_append_and_prepend::<i32>(args, true), + } +} + +/// Array_prepend SQL function +pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> { + if args.len() != 2 { + return exec_err!("array_prepend expects two arguments"); + } + + match args[1].data_type() { + DataType::LargeList(_) => general_append_and_prepend::<i64>(args, false), + _ => general_append_and_prepend::<i32>(args, false), + } +} + fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> { let args_ndim = args .iter() @@ -1114,11 +1110,13 @@ fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> { } // Concatenate arrays on the same row. -fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> { +fn concat_internal<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { let args = align_array_dimensions(args.to_vec())?; - let list_arrays = - downcast_vec!(args, ListArray).collect::<Result<Vec<&ListArray>>>()?; + let list_arrays = args + .iter() + .map(|arg| as_generic_list_array::<O>(arg)) + .collect::<Result<Vec<_>>>()?; // Assume number of rows is the same for all arrays let row_count = list_arrays[0].len(); @@ -1165,7 +1163,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> { .map(|a| a.as_ref()) .collect::<Vec<&dyn Array>>(); - let list_arr = ListArray::new( + let list_arr = GenericListArray::<O>::new( Arc::new(Field::new("item", data_type, true)), OffsetBuffer::from_lengths(array_lengths), Arc::new(compute::concat(elements.as_slice())?), @@ -1192,7 +1190,7 @@ pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> { } } - concat_internal(new_args.as_slice()) + concat_internal::<i32>(new_args.as_slice()) } /// Array_empty SQL function diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 2f8e3c805f73..a3b2c8cdf1e9 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -107,6 +107,19 @@ AS VALUES (make_array(make_array(4, 5, 6), make_array(10, 11, 12), make_array(4, 9, 8), make_array(7, 8, 9), make_array(10, 11, 12), make_array(1, 8, 7)), make_array(10, 11, 12), 3, make_array([[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]), make_array(121, 131, 141)) ; +# TODO: add this when #8305 is fixed +# statement ok +# CREATE TABLE large_nested_arrays +# AS +# SELECT +# arrow_cast(column1, 'LargeList(LargeList(Int64))') AS column1, +# arrow_cast(column2, 'LargeList(Int64)') AS column2, +# column3, +# arrow_cast(column4, 'LargeList(LargeList(List(Int64)))') AS column4, +# arrow_cast(column5, 'LargeList(Int64)') AS column5 +# FROM nested_arrays +# ; + statement ok CREATE TABLE arrays_values AS VALUES @@ -120,6 +133,17 @@ AS VALUES (make_array(61, 62, 63, 64, 65, 66, 67, 68, 69, 70), 66, 7, NULL) ; +statement ok +CREATE TABLE large_arrays_values +AS SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4 +FROM arrays_values +; + +
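The refactor above works because `List` and `LargeList` differ only in their offset type (`i32` vs `i64`), so one function generic over `OffsetSizeTrait` can serve both, with a `DataType` match at the public entry points picking the monomorphized body. As a hedged illustration of that pattern (not part of the patch; the `row_lengths` helper is invented for this sketch, which assumes only the arrow crate):

```rust
use arrow::array::{GenericListArray, ListArray, OffsetSizeTrait};
use arrow::datatypes::Int32Type;

/// Works for `ListArray` (O = i32) and `LargeListArray` (O = i64) alike:
/// each adjacent pair of offsets delimits one row's elements, which is
/// exactly how `generic_append_and_prepend` walks the offsets above.
fn row_lengths<O: OffsetSizeTrait>(list: &GenericListArray<O>) -> Vec<usize> {
    list.offsets()
        .windows(2)
        .map(|w| w[1].to_usize().unwrap() - w[0].to_usize().unwrap())
        .collect()
}

fn main() {
    // A two-row list array: [1, 2] and [3]
    let list: ListArray = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
        Some(vec![Some(1), Some(2)]),
        Some(vec![Some(3)]),
    ]);
    assert_eq!(row_lengths(&list), vec![2, 1]);
}
```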
statement ok CREATE TABLE arrays_values_v2 AS VALUES @@ -131,6 +155,17 @@ AS VALUES (NULL, NULL, NULL, NULL) ; +# TODO: add this when #8305 is fixed +# statement ok +# CREATE TABLE large_arrays_values_v2 +# AS SELECT +# arrow_cast(column1, 'LargeList(Int64)') AS column1, +# arrow_cast(column2, 'LargeList(Int64)') AS column2, +# column3, +# arrow_cast(column4, 'LargeList(LargeList(Int64))') AS column4 +# FROM arrays_values_v2 +# ; + statement ok CREATE TABLE flatten_table AS VALUES @@ -1532,7 +1567,7 @@ query error select array_append(null, [[4]]); query ???? -select +select array_append(make_array(), 4), array_append(make_array(), null), array_append(make_array(1, null, 3), 4), @@ -1541,6 +1576,17 @@ select ---- [4] [] [1, , 3, 4] [, , 1] +# TODO: add this when #8305 is fixed +# query ???? +# select +# array_append(arrow_cast(make_array(), 'LargeList(Null)'), 4), +# array_append(make_array(), null), +# array_append(make_array(1, null, 3), 4), +# array_append(make_array(null, null), 1) +# ; +# ---- +# [4] [] [1, , 3, 4] [, , 1] + # test invalid (non-null) query error select array_append(1, 2); @@ -1552,42 +1598,76 @@ query error select array_append([1], [2]); query ?? -select +select array_append(make_array(make_array(1, null, 3)), make_array(null)), array_append(make_array(make_array(1, null, 3)), null); ---- [[1, , 3], []] [[1, , 3], ] +# TODO: add this when #8305 is fixed +# query ?? +# select +# array_append(arrow_cast(make_array(make_array(1, null, 3), 'LargeList(LargeList(Int64))')), arrow_cast(make_array(null), 'LargeList(Int64)')), +# array_append(arrow_cast(make_array(make_array(1, null, 3), 'LargeList(LargeList(Int64))')), null); +# ---- +# [[1, , 3], []] [[1, , 3], ] + # array_append scalar function #3 query ??? select array_append(make_array(1, 2, 3), 4), array_append(make_array(1.0, 2.0, 3.0), 4.0), array_append(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_append(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), array_append(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), array_append(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_append scalar function #4 (element is list) query ??? select array_append(make_array([1], [2], [3]), make_array(4)), array_append(make_array([1.0], [2.0], [3.0]), make_array(4.0)), array_append(make_array(['h'], ['e'], ['l'], ['l']), make_array('o')); ---- [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] +# TODO: add this when #8305 is fixed +# query ??? +# select array_append(arrow_cast(make_array([1], [2], [3]), 'LargeList(LargeList(Int64))'), arrow_cast(make_array(4), 'LargeList(Int64)')), array_append(arrow_cast(make_array([1.0], [2.0], [3.0]), 'LargeList(LargeList(Float64))'), arrow_cast(make_array(4.0), 'LargeList(Float64)')), array_append(arrow_cast(make_array(['h'], ['e'], ['l'], ['l']), 'LargeList(LargeList(Utf8))'), arrow_cast(make_array('o'), 'LargeList(Utf8)')); +# ---- +# [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + # list_append scalar function #5 (function alias `array_append`) query ??? select list_append(make_array(1, 2, 3), 4), list_append(make_array(1.0, 2.0, 3.0), 4.0), list_append(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? 
+select list_append(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), list_append(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), list_append(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_push_back scalar function #6 (function alias `array_append`) query ??? select array_push_back(make_array(1, 2, 3), 4), array_push_back(make_array(1.0, 2.0, 3.0), 4.0), array_push_back(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_push_back(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), array_push_back(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), array_push_back(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # list_push_back scalar function #7 (function alias `array_append`) query ??? select list_push_back(make_array(1, 2, 3), 4), list_push_back(make_array(1.0, 2.0, 3.0), 4.0), list_push_back(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_push_back(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), list_push_back(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), list_push_back(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_append with columns #1 query ? select array_append(column1, column2) from arrays_values; @@ -1601,6 +1681,18 @@ select array_append(column1, column2) from arrays_values; [51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] +query ? +select array_append(column1, column2) from large_arrays_values; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1] +[11, 12, 13, 14, 15, 16, 17, 18, , 20, 12] +[21, 22, 23, , 25, 26, 27, 28, 29, 30, 23] +[31, 32, 33, 34, 35, , 37, 38, 39, 40, 34] +[44] +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, ] +[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] + # array_append with columns #2 (element is list) query ? select array_append(column1, column2) from nested_arrays; @@ -1608,6 +1700,13 @@ select array_append(column1, column2) from nested_arrays; [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] +# TODO: add this when #8305 is fixed +# query ? +# select array_append(column1, column2) from large_nested_arrays; +# ---- +# [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] +# [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] + # array_append with columns and scalars #1 query ?? select array_append(column2, 100.1), array_append(column3, '.') from arrays; @@ -1620,6 +1719,17 @@ select array_append(column2, 100.1), array_append(column3, '.') from arrays; [100.1] [,, .] [16.6, 17.7, 18.8, 100.1] [.] +query ?? +select array_append(column2, 100.1), array_append(column3, '.') from large_arrays; +---- +[1.1, 2.2, 3.3, 100.1] [L, o, r, e, m, .] +[, 5.5, 6.6, 100.1] [i, p, , u, m, .] +[7.7, 8.8, 9.9, 100.1] [d, , l, o, r, .] +[10.1, , 12.2, 100.1] [s, i, t, .] +[13.3, 14.4, 15.5, 100.1] [a, m, e, t, .] +[100.1] [,, .] +[16.6, 17.7, 18.8, 100.1] [.] + # array_append with columns and scalars #2 query ?? 
select array_append(column1, make_array(1, 11, 111)), array_append(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), column2) from nested_arrays; @@ -1627,6 +1737,13 @@ select array_append(column1, make_array(1, 11, 111)), array_append(make_array(ma [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] +# TODO: add this when #8305 is fixed +# query ?? +# select array_append(column1, arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)')), array_append(arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))'), column2) from large_nested_arrays; +# ---- +# [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] +# [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] + ## array_prepend (aliases: `list_prepend`, `array_push_front`, `list_push_front`) # array_prepend with NULLs @@ -1688,30 +1805,56 @@ select array_prepend(1, make_array(2, 3, 4)), array_prepend(1.0, make_array(2.0, ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_prepend(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), array_prepend(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), array_prepend('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_prepend scalar function #4 (element is list) query ??? select array_prepend(make_array(1), make_array(make_array(2), make_array(3), make_array(4))), array_prepend(make_array(1.0), make_array([2.0], [3.0], [4.0])), array_prepend(make_array('h'), make_array(['e'], ['l'], ['l'], ['o'])); ---- [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] +# TODO: add this when #8305 is fixed +# query ??? +# select array_prepend(arrow_cast(make_array(1), 'LargeList(Int64)'), arrow_cast(make_array(make_array(2), make_array(3), make_array(4)), 'LargeList(LargeList(Int64))')), array_prepend(arrow_cast(make_array(1.0), 'LargeList(Float64)'), arrow_cast(make_array([2.0], [3.0], [4.0]), 'LargeList(LargeList(Float64))')), array_prepend(arrow_cast(make_array('h'), 'LargeList(Utf8)'), arrow_cast(make_array(['e'], ['l'], ['l'], ['o']), 'LargeList(LargeList(Utf8))'')); +# ---- +# [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + # list_prepend scalar function #5 (function alias `array_prepend`) query ??? select list_prepend(1, make_array(2, 3, 4)), list_prepend(1.0, make_array(2.0, 3.0, 4.0)), list_prepend('h', make_array('e', 'l', 'l', 'o')); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_prepend(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), list_prepend(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), list_prepend('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_push_front scalar function #6 (function alias `array_prepend`) query ??? select array_push_front(1, make_array(2, 3, 4)), array_push_front(1.0, make_array(2.0, 3.0, 4.0)), array_push_front('h', make_array('e', 'l', 'l', 'o')); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? 
+select array_push_front(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), array_push_front(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), array_push_front('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # list_push_front scalar function #7 (function alias `array_prepend`) query ??? select list_push_front(1, make_array(2, 3, 4)), list_push_front(1.0, make_array(2.0, 3.0, 4.0)), list_push_front('h', make_array('e', 'l', 'l', 'o')); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_push_front(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), list_push_front(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), list_push_front('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_prepend with columns #1 query ? select array_prepend(column2, column1) from arrays_values; @@ -1725,6 +1868,18 @@ select array_prepend(column2, column1) from arrays_values; [55, 51, 52, , 54, 55, 56, 57, 58, 59, 60] [66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70] +query ? +select array_prepend(column2, column1) from large_arrays_values; +---- +[1, , 2, 3, 4, 5, 6, 7, 8, 9, 10] +[12, 11, 12, 13, 14, 15, 16, 17, 18, , 20] +[23, 21, 22, 23, , 25, 26, 27, 28, 29, 30] +[34, 31, 32, 33, 34, 35, , 37, 38, 39, 40] +[44] +[, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50] +[55, 51, 52, , 54, 55, 56, 57, 58, 59, 60] +[66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70] + # array_prepend with columns #2 (element is list) query ? select array_prepend(column2, column1) from nested_arrays; @@ -1732,6 +1887,13 @@ select array_prepend(column2, column1) from nested_arrays; [[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] +# TODO: add this when #8305 is fixed +# query ? +# select array_prepend(column2, column1) from large_nested_arrays; +# ---- +# [[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] +# [[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] + # array_prepend with columns and scalars #1 query ?? select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; @@ -1744,6 +1906,17 @@ select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; [100.1] [., ,] [100.1, 16.6, 17.7, 18.8] [.] +query ?? +select array_prepend(100.1, column2), array_prepend('.', column3) from large_arrays; +---- +[100.1, 1.1, 2.2, 3.3] [., L, o, r, e, m] +[100.1, , 5.5, 6.6] [., i, p, , u, m] +[100.1, 7.7, 8.8, 9.9] [., d, , l, o, r] +[100.1, 10.1, , 12.2] [., s, i, t] +[100.1, 13.3, 14.4, 15.5] [., a, m, e, t] +[100.1] [., ,] +[100.1, 16.6, 17.7, 18.8] [.] + # array_prepend with columns and scalars #2 (element is list) query ?? select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, make_array(make_array(1, 2, 3), make_array(11, 12, 13))) from nested_arrays; @@ -1751,6 +1924,13 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] +# TODO: add this when #8305 is fixed +# query ?? 
+# select array_prepend(arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)'), column1), array_prepend(column2, arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))')) from large_nested_arrays; +# ---- +# [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] +# [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] + ## array_repeat (aliases: `list_repeat`) # array_repeat scalar function #1 From 9a6cc889a40e4740bfc859557a9ca9c8d043891e Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:17:26 +1100 Subject: [PATCH 10/18] Support for `extract(epoch from date)` for Date32 and Date64 (#8695) --- datafusion/core/tests/sql/expr.rs | 34 ++++++++++++++ .../physical-expr/src/datetime_expressions.rs | 44 ++++++++++--------- 2 files changed, 58 insertions(+), 20 deletions(-) diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 7d41ad4a881c..8ac0e3e5ef19 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -741,6 +741,7 @@ async fn test_extract_date_part() -> Result<()> { #[tokio::test] async fn test_extract_epoch() -> Result<()> { + // timestamp test_expression!( "extract(epoch from '1870-01-01T07:29:10.256'::timestamp)", "-3155646649.744" @@ -754,6 +755,39 @@ async fn test_extract_epoch() -> Result<()> { "946684800.0" ); test_expression!("extract(epoch from NULL::timestamp)", "NULL"); + // date + test_expression!( + "extract(epoch from arrow_cast('1970-01-01', 'Date32'))", + "0.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-02', 'Date32'))", + "86400.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-11', 'Date32'))", + "864000.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1969-12-31', 'Date32'))", + "-86400.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-01', 'Date64'))", + "0.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-02', 'Date64'))", + "86400.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-11', 'Date64'))", + "864000.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1969-12-31', 'Date64'))", + "-86400.0" + ); Ok(()) } diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index f6373d40d965..589bbc8a952b 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -19,7 +19,6 @@ use crate::datetime_expressions; use crate::expressions::cast_column; -use arrow::array::Float64Builder; use arrow::compute::cast; use arrow::{ array::{Array, ArrayRef, Float64Array, OffsetSizeTrait, PrimitiveArray}, @@ -887,28 +886,33 @@ where T: ArrowTemporalType + ArrowNumericType, i64: From, { - let mut b = Float64Builder::with_capacity(array.len()); - match array.data_type() { + let b = match array.data_type() { DataType::Timestamp(tu, _) => { - for i in 0..array.len() { - if array.is_null(i) { - b.append_null(); - } else { - let scale = match tu { - TimeUnit::Second => 1, - TimeUnit::Millisecond => 1_000, - TimeUnit::Microsecond => 1_000_000, - TimeUnit::Nanosecond => 1_000_000_000, - }; - - let n: i64 = array.value(i).into(); - b.append_value(n as f64 / scale as f64); - } - } + let scale = match tu { + TimeUnit::Second => 1, + 
TimeUnit::Millisecond => 1_000, + TimeUnit::Microsecond => 1_000_000, + TimeUnit::Nanosecond => 1_000_000_000, + } as f64; + array.unary(|n| { + let n: i64 = n.into(); + n as f64 / scale + }) } + DataType::Date32 => { + let seconds_in_a_day = 86400_f64; + array.unary(|n| { + let n: i64 = n.into(); + n as f64 * seconds_in_a_day + }) + } + DataType::Date64 => array.unary(|n| { + let n: i64 = n.into(); + n as f64 / 1_000_f64 + }), _ => return internal_err!("Can not convert {:?} to epoch", array.data_type()), - } - Ok(b.finish()) + }; + Ok(b) } /// to_timestammp() SQL function implementation From 6b1e9c6a3ae95b7065e902d99d9fde66f0f8e054 Mon Sep 17 00:00:00 2001 From: junxiangMu <63799833+guojidan@users.noreply.github.com> Date: Wed, 3 Jan 2024 20:24:58 +0800 Subject: [PATCH 11/18] Implement trait based API for defining WindowUDF (#8719) * Implement trait based API for defining WindowUDF * add test case & docs * fix docs * rename WindowUDFImpl function --- datafusion-examples/README.md | 1 + datafusion-examples/examples/advanced_udwf.rs | 230 ++++++++++++++++++ .../user_defined_window_functions.rs | 64 +++-- datafusion/expr/src/expr_fn.rs | 67 ++++- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udwf.rs | 116 ++++++++- .../tests/cases/roundtrip_logical_plan.rs | 55 +++-- docs/source/library-user-guide/adding-udfs.md | 7 +- 8 files changed, 498 insertions(+), 44 deletions(-) create mode 100644 datafusion-examples/examples/advanced_udwf.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 1296c74ea277..aae451add9e7 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -63,6 +63,7 @@ cargo run --example csv_sql - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) +- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) ## Distributed diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs new file mode 100644 index 000000000000..91869d80a41a --- /dev/null +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
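The full example file below is long; at its core, the trait this commit introduces requires only a name, a signature, a return type, and a `PartitionEvaluator` factory, plus `as_any`. A minimal sketch of that surface, distilled from the example that follows (not additional patch content; `MinimalUdwf` and `make_udwf` are invented names for illustration):

```rust
use std::any::Any;

use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::{
    PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
};

struct MinimalUdwf {
    signature: Signature,
}

impl WindowUDFImpl for MinimalUdwf {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "minimal_udwf"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }
    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
        unimplemented!("construct and return your PartitionEvaluator here")
    }
}

/// Converting the impl into a registerable `WindowUDF` is then one line.
fn make_udwf() -> WindowUDF {
    WindowUDF::from(MinimalUdwf {
        signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
    })
}
```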
+ +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use std::any::Any; + +use arrow::{ + array::{ArrayRef, AsArray, Float64Array}, + datatypes::Float64Type, +}; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::ScalarValue; +use datafusion_expr::{ + PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl, +}; + +/// This example shows how to use the full WindowUDFImpl API to implement a user +/// defined window function. As in the `simple_udwf.rs` example, this struct implements +/// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance. +/// +/// To do so, we must implement the `WindowUDFImpl` trait. +struct SmoothItUdf { + signature: Signature, +} + +impl SmoothItUdf { + /// Create a new instance of the SmoothItUdf struct + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take one argument of type f64 + vec![DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + } + } +} + +impl WindowUDFImpl for SmoothItUdf { + /// We implement as_any so that we can downcast the WindowUDFImpl trait object + fn as_any(&self) -> &dyn Any { + self + } + + /// Return the name of this function + fn name(&self) -> &str { + "smooth_it" + } + + /// Return the "signature" of this function -- namely the types of arguments it will take + fn signature(&self) -> &Signature { + &self.signature + } + + /// What is the type of value that will be returned by this function. + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(DataType::Float64) + } + + /// Create a `PartitionEvaluator` to evaluate this function on a new + /// partition. + fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> { + Ok(Box::new(MyPartitionEvaluator::new())) + } +} + +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct value of `PARTITION BY` (each car type in our example) +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// Different evaluation methods are called depending on the various +/// settings of WindowUDF. This example uses the simplest and most +/// general, `evaluate`. See `PartitionEvaluator` for the other more +/// advanced uses. +impl PartitionEvaluator for MyPartitionEvaluator { + /// Tell DataFusion the window function varies based on the value + /// of the window frame. + fn uses_window_frame(&self) -> bool { + true + } + + /// This function is called once per input row. + /// + /// `range` specifies which indexes of `values` should be + /// considered for the calculation. + /// + /// Note this is the SLOWEST, but simplest, way to evaluate a + /// window function.
It is much faster to implement + /// evaluate_all or evaluate_all_with_rank, if possible + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &std::ops::Range<usize>, + ) -> Result<ScalarValue> { + // Again, the input argument is an array of floating + // point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>(); + + let range_len = range.end - range.start; + + // our smoothing function will average all the values in the window + let output = if range_len > 0 { + let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } +} + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result<SessionContext> { + // declare a new context. In spark API, this corresponds to a new spark SQL session + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + let smooth_it = WindowUDF::from(SmoothItUdf::new()); + ctx.register_udwf(smooth_it.clone()); + + // Use SQL to run the new window function + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + // Use SQL to run the new window function: + // + // `PARTITION BY car`: each distinct value of car (red and green) + // should be treated as a separate partition (and will result in + // creating a new `PartitionEvaluator`) + // + // `ORDER BY time`: within each partition ('green' or 'red') the + // rows will be ordered by the value in the `time` column + // + // `evaluate` is invoked with a window defined by the + // SQL. In this case: + // + // The first invocation will be passed row 0, the first row in the + // partition. + // + // The second invocation will be passed rows 0 and 1, the first + // two rows in the partition. + // + // etc. + let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + // print the results + df.show().await?; + + // this time, call the new window function with an explicit + // window so evaluate will be invoked with each window. + // + // `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation + // sees at most 3 rows: the row before, the current row, and the 1 + // row afterward.
+ let df = ctx.sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ).await?; + // print the results + df.show().await?; + + // Now, run the function using the DataFrame API: + let window_expr = smooth_it.call( + vec![col("speed")], // smooth_it(speed) + vec![col("car")], // PARTITION BY car + vec![col("time").sort(true, true)], // ORDER BY time ASC + WindowFrame::new(false), + ); + let df = ctx.table("cars").await?.window(vec![window_expr])?; + + // print the results + df.show().await?; + + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 5f9939157217..3040fbafe81a 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -19,6 +19,7 @@ //! user defined window functions use std::{ + any::Any, ops::Range, sync::{ atomic::{AtomicUsize, Ordering}, @@ -32,8 +33,7 @@ use arrow_schema::DataType; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - function::PartitionEvaluatorFactory, PartitionEvaluator, ReturnTypeFunction, - Signature, Volatility, WindowUDF, + PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; /// A query with a window function evaluated over the entire partition @@ -471,24 +471,48 @@ impl OddCounter { } fn register(ctx: &mut SessionContext, test_state: Arc) { - let name = "odd_counter"; - let volatility = Volatility::Immutable; - - let signature = Signature::exact(vec![DataType::Int64], volatility); - - let return_type = Arc::new(DataType::Int64); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::clone(&return_type))); - - let partition_evaluator_factory: PartitionEvaluatorFactory = - Arc::new(move || Ok(Box::new(OddCounter::new(Arc::clone(&test_state))))); - - ctx.register_udwf(WindowUDF::new( - name, - &signature, - &return_type, - &partition_evaluator_factory, - )) + struct SimpleWindowUDF { + signature: Signature, + return_type: DataType, + test_state: Arc, + } + + impl SimpleWindowUDF { + fn new(test_state: Arc) -> Self { + let signature = + Signature::exact(vec![DataType::Float64], Volatility::Immutable); + let return_type = DataType::Int64; + Self { + signature, + return_type, + test_state, + } + } + } + + impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "odd_counter" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn partition_evaluator(&self) -> Result> { + Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state)))) + } + } + + ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state))) } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index eed41d97ccba..f76fb17b38bb 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -28,7 +28,7 @@ use crate::{ BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, }; -use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF}; +use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use 
arrow::datatypes::DataType; use datafusion_common::{Column, Result}; use std::any::Any; @@ -1059,13 +1059,66 @@ pub fn create_udwf( volatility: Volatility, partition_evaluator_factory: PartitionEvaluatorFactory, ) -> WindowUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - WindowUDF::new( + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + WindowUDF::from(SimpleWindowUDF::new( name, - &Signature::exact(vec![input_type], volatility), - &return_type, - &partition_evaluator_factory, - ) + input_type, + return_type, + volatility, + partition_evaluator_factory, + )) +} + +/// Implements [`WindowUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleWindowUDF { + name: String, + signature: Signature, + return_type: DataType, + partition_evaluator_factory: PartitionEvaluatorFactory, +} + +impl SimpleWindowUDF { + /// Create a new `SimpleWindowUDF` from a name, input types, return type and + /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_type: DataType, + return_type: DataType, + volatility: Volatility, + partition_evaluator_factory: PartitionEvaluatorFactory, + ) -> Self { + let name = name.into(); + let signature = Signature::exact([input_type].to_vec(), volatility); + Self { + name, + signature, + return_type, + partition_evaluator_factory, + } + } +} + +impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn partition_evaluator(&self) -> Result> { + (self.partition_evaluator_factory)() + } } /// Calls a named built in function diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index ab213a19a352..077681d21725 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -82,7 +82,7 @@ pub use signature::{ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; pub use udf::{ScalarUDF, ScalarUDFImpl}; -pub use udwf::WindowUDF; +pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; #[cfg(test)] diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index a97a68341f5c..800386bfc77b 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -24,6 +24,7 @@ use crate::{ use arrow::datatypes::DataType; use datafusion_common::Result; use std::{ + any::Any, fmt::{self, Debug, Display, Formatter}, sync::Arc, }; @@ -80,7 +81,11 @@ impl std::hash::Hash for WindowUDF { } impl WindowUDF { - /// Create a new WindowUDF + /// Create a new WindowUDF from low level details. 
+ /// + /// See [`WindowUDFImpl`] for a more convenient way to create a + /// `WindowUDF` using trait objects + #[deprecated(since = "34.0.0", note = "please implement WindowUDFImpl instead")] pub fn new( name: &str, signature: &Signature, @@ -95,6 +100,32 @@ impl WindowUDF { } } + /// Create a new `WindowUDF` from a [`WindowUDFImpl`] trait object + /// + /// Note this is the same as using the `From` impl (`WindowUDF::from`) + pub fn new_from_impl<F>(fun: F) -> WindowUDF + where + F: WindowUDFImpl + Send + Sync + 'static, + { + let arc_fun = Arc::new(fun); + let captured_self = arc_fun.clone(); + let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { + let return_type = captured_self.return_type(arg_types)?; + Ok(Arc::new(return_type)) + }); + + let captured_self = arc_fun.clone(); + let partition_evaluator_factory: PartitionEvaluatorFactory = + Arc::new(move || captured_self.partition_evaluator()); + + Self { + name: arc_fun.name().to_string(), + signature: arc_fun.signature().clone(), + return_type: return_type.clone(), + partition_evaluator_factory, + } + } + /// creates a [`Expr`] that calls the window function given /// the `partition_by`, `order_by`, and `window_frame` definition /// @@ -140,3 +171,86 @@ impl WindowUDF { (self.partition_evaluator_factory)() } } + +impl<F> From<F> for WindowUDF +where + F: WindowUDFImpl + Send + Sync + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// Trait for implementing [`WindowUDF`]. +/// +/// This trait exposes the full API for implementing user defined window functions and +/// can be used to implement any function. +/// +/// See [`advanced_udwf.rs`] for a full example with complete implementation and +/// [`WindowUDF`] for other available options. +/// +/// +/// [`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame}; +/// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; +/// struct SmoothIt { +/// signature: Signature +/// }; +/// +/// impl SmoothIt { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// } +/// } +/// } +/// +/// /// Implement the WindowUDFImpl trait for SmoothIt +/// impl WindowUDFImpl for SmoothIt { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "smooth_it" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result<DataType> { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("smooth_it only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would smooth the values over each window +/// fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> { unimplemented!() } +/// } +/// +/// // Create a new WindowUDF from the implementation +/// let smooth_it = WindowUDF::from(SmoothIt::new()); +/// +/// // Call the function `smooth_it(speed)` over the window +/// let expr = smooth_it.call( +/// vec![col("speed")], // smooth_it(speed) +/// vec![col("car")], // PARTITION BY car +/// vec![col("time").sort(true, true)], // ORDER BY time ASC +/// WindowFrame::new(false), +/// ); +/// ``` +pub trait WindowUDFImpl { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this
function's name + fn name(&self) -> &str; + + /// Returns the function's [`Signature`] for information about what input + /// types are accepted and the function's Volatility. + fn signature(&self) -> &Signature; + + /// What [`DataType`] will be returned by this function, given the types of + /// the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; + + /// Invoke the function, returning the [`PartitionEvaluator`] instance + fn partition_evaluator(&self) -> Result>; +} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index dea99f91e392..402781e17e6f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -54,6 +55,7 @@ use datafusion_expr::{ BuiltinScalarFunction::{Sqrt, Substr}, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, + WindowUDFImpl, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1785,27 +1787,52 @@ fn roundtrip_window() { } } - fn return_type(arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return plan_err!( - "dummy_udwf expects 1 argument, got {}: {:?}", - arg_types.len(), - arg_types - ); + struct SimpleWindowUDF { + signature: Signature, + } + + impl SimpleWindowUDF { + fn new() -> Self { + let signature = + Signature::exact(vec![DataType::Float64], Volatility::Immutable); + Self { signature } + } + } + + impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udwf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return plan_err!( + "dummy_udwf expects 1 argument, got {}: {:?}", + arg_types.len(), + arg_types + ); + } + Ok(arg_types[0].clone()) + } + + fn partition_evaluator(&self) -> Result> { + make_partition_evaluator() } - Ok(Arc::new(arg_types[0].clone())) } fn make_partition_evaluator() -> Result> { Ok(Box::new(DummyWindow {})) } - let dummy_window_udf = WindowUDF::new( - "dummy_udwf", - &Signature::exact(vec![DataType::Float64], Volatility::Immutable), - &(Arc::new(return_type) as _), - &(Arc::new(make_partition_evaluator) as _), - ); + let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index c51e4de3236c..1f687f978f30 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -201,7 +201,8 @@ fn make_partition_evaluator() -> Result> { ### Registering a Window UDF -To register a Window UDF, you need to wrap the function implementation in a `WindowUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udwf` helper functions to make this easier. 
+To register a Window UDF, you need to wrap the function implementation in a [`WindowUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udwf`] helper function to make this easier. +There is also a lower level API, with more functionality but more complexity, that is documented in [`advanced_udwf.rs`]. ```rust use datafusion::logical_expr::{Volatility, create_udwf}; use datafusion::arrow::datatypes::DataType; use std::sync::Arc; // here is where we define the UDWF. We also declare its signature: let smooth_it = create_udwf( "smooth_it", DataType::Float64, Arc::new(DataType::Float64), Volatility::Immutable, Arc::new(make_partition_evaluator), ); ``` +[`windowudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.WindowUDF.html +[`create_udwf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udwf.html +[`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs + The `create_udwf` has five arguments to check: - The first argument is the name of the function. This is the name that will be used in SQL queries. From 1179a76567892b259c88f08243ee01f05c4c3d5c Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 4 Jan 2024 01:50:46 +0800 Subject: [PATCH 12/18] Minor: Introduce utils::hash for StructArray (#8552) * hash struct Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * row-wise hash Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * create hashes once Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * add comment Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/common/src/hash_utils.rs | 92 ++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 9198461e00bf..5c36f41a6e42 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -27,7 +27,8 @@ use arrow::{downcast_dictionary_array, downcast_primitive_array}; use arrow_buffer::i256; use crate::cast::{ - as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array, + as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, + as_primitive_array, as_string_array, as_struct_array, }; use crate::error::{DataFusionError, Result, _internal_err}; @@ -207,6 +208,35 @@ fn hash_dictionary( Ok(()) } +fn hash_struct_array( + array: &StructArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let nulls = array.nulls(); + let num_columns = array.num_columns(); + + // Skip null rows + let valid_indices: Vec<usize> = if let Some(nulls) = nulls { + nulls.valid_indices().collect() + } else { + (0..num_columns).collect() + }; + + // Create hashes for each row that combine the hashes over all the columns at that row. + // array.len() is the number of rows. + let mut values_hashes = vec![0u64; array.len()]; + create_hashes(array.columns(), random_state, &mut values_hashes)?; + + // Skip the null rows, nulls should get hash value 0.
+ for i in valid_indices { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } + + Ok(()) +} + fn hash_list_array( array: &GenericListArray, random_state: &RandomState, @@ -327,12 +357,16 @@ pub fn create_hashes<'a>( array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, _ => unreachable!() } + DataType::Struct(_) => { + let array = as_struct_array(array)?; + hash_struct_array(array, random_state, hashes_buffer)?; + } DataType::List(_) => { - let array = as_list_array(array); + let array = as_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } DataType::LargeList(_) => { - let array = as_large_list_array(array); + let array = as_large_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } _ => { @@ -515,6 +549,58 @@ mod tests { assert_eq!(hashes[2], hashes[3]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_struct_arrays() { + use arrow_buffer::Buffer; + + let boolarr = Arc::new(BooleanArray::from(vec![ + false, false, true, true, true, true, + ])); + let i32arr = Arc::new(Int32Array::from(vec![10, 10, 20, 20, 30, 31])); + + let struct_array = StructArray::from(( + vec![ + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + boolarr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("i32", DataType::Int32, false)), + i32arr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("i32", DataType::Int32, false)), + i32arr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + boolarr.clone() as ArrayRef, + ), + ], + Buffer::from(&[0b001011]), + )); + + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_valid(1)); + assert!(struct_array.is_null(2)); + assert!(struct_array.is_valid(3)); + assert!(struct_array.is_null(4)); + assert!(struct_array.is_null(5)); + + let array = Arc::new(struct_array) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + // same value but the third row ( hashes[2] ) is null + assert_ne!(hashes[2], hashes[3]); + // different values but both are null + assert_eq!(hashes[4], hashes[5]); + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] From 93da699c0e9a6d60c075c252dcf537112b06996a Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 3 Jan 2024 13:58:11 -0800 Subject: [PATCH 13/18] [CI] Improve windows machine CI test time (#8730) * Test WIN64 CI * Test WIN64 CI * Test WIN64 CI * Test WIN64 CI * Adding incremental compilation * Adding codegen units * Try without opt-level * set opt level only for win machines * set opt level only for win machines. 
remove incremental compile * update comments --- .github/workflows/rust.yml | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index a541091e3a2b..622521a6fbc7 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -99,6 +99,14 @@ jobs: rust-version: stable - name: Run tests (excluding doctests) run: cargo test --lib --tests --bins --features avro,json,backtrace + env: + # do not produce debug symbols to keep memory usage down + # hardcoding other profile params to avoid profile override values + # More on Cargo profiles https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings + RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" + RUST_BACKTRACE: "1" + # avoid rust stack overflows on tpc-ds tests + RUST_MINSTACK: "3000000" - name: Verify Working Directory Clean run: git diff --exit-code @@ -290,6 +298,7 @@ jobs: # with a OS-dependent path. - name: Setup Rust toolchain run: | + rustup update stable rustup toolchain install stable rustup default stable rustup component add rustfmt @@ -302,9 +311,13 @@ jobs: cargo test --lib --tests --bins --all-features env: # do not produce debug symbols to keep memory usage down - RUSTFLAGS: "-C debuginfo=0" + # use higher optimization level to overcome Windows rust slowness for tpc-ds + # and speed builds: https://github.com/apache/arrow-datafusion/issues/8696 + # Cargo profile docs https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings + RUSTFLAGS: "-C debuginfo=0 -C opt-level=1 -C target-feature=+crt-static -C incremental=false -C codegen-units=256" RUST_BACKTRACE: "1" - + # avoid rust stack overflows on tpc-ds tests + RUST_MINSTACK: "3000000" macos: name: cargo test (mac) runs-on: macos-latest @@ -327,6 +340,7 @@ jobs: # with a OS-dependent path. 
- name: Setup Rust toolchain run: | + rustup update stable rustup toolchain install stable rustup default stable rustup component add rustfmt @@ -338,8 +352,12 @@ jobs: cargo test --lib --tests --bins --all-features env: # do not produce debug symbols to keep memory usage down - RUSTFLAGS: "-C debuginfo=0" + # hardcoding other profile params to avoid profile override values + # More on Cargo profiles https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings + RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" RUST_BACKTRACE: "1" + # avoid rust stack overflows on tpc-ds tests + RUST_MINSTACK: "3000000" test-datafusion-pyarrow: name: cargo test pyarrow (amd64) From ad4b7b7cfd4a2f93bbef3c2bff8a6ce65db24b53 Mon Sep 17 00:00:00 2001 From: yi wang <48236141+my-vegetable-has-exploded@users.noreply.github.com> Date: Thu, 4 Jan 2024 06:01:13 +0800 Subject: [PATCH 14/18] fix guarantees in allways_true of PruningPredicate (#8732) * fix: check guarantees in allways_true * Add test for allways_true * refine comment --------- Co-authored-by: Andrew Lamb --- .../datasource/physical_plan/parquet/mod.rs | 78 ++++++++++++------- .../core/src/physical_optimizer/pruning.rs | 7 +- 2 files changed, 53 insertions(+), 32 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 76a6cc297b0e..9d81d8d083c2 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -1768,8 +1768,9 @@ mod tests { ); } - #[tokio::test] - async fn parquet_exec_metrics() { + /// Returns a string array with contents: + /// "[Foo, null, bar, bar, bar, bar, zzz]" + fn string_batch() -> RecordBatch { let c1: ArrayRef = Arc::new(StringArray::from(vec![ Some("Foo"), None, @@ -1781,9 +1782,15 @@ mod tests { ])); // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + create_batch(vec![("c1", c1.clone())]) + } + + #[tokio::test] + async fn parquet_exec_metrics() { + // batch1: c1(string) + let batch1 = string_batch(); - // on + // c1 != 'bar' let filter = col("c1").not_eq(lit("bar")); // read/write them files: @@ -1812,20 +1819,10 @@ mod tests { #[tokio::test] async fn parquet_exec_display() { - let c1: ArrayRef = Arc::new(StringArray::from(vec![ - Some("Foo"), - None, - Some("bar"), - Some("bar"), - Some("bar"), - Some("bar"), - Some("zzz"), - ])); - // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + let batch1 = string_batch(); - // on + // c1 != 'bar' let filter = col("c1").not_eq(lit("bar")); let rt = RoundTrip::new() @@ -1854,21 +1851,15 @@ mod tests { } #[tokio::test] - async fn parquet_exec_skip_empty_pruning() { - let c1: ArrayRef = Arc::new(StringArray::from(vec![ - Some("Foo"), - None, - Some("bar"), - Some("bar"), - Some("bar"), - Some("bar"), - Some("zzz"), - ])); - + async fn parquet_exec_has_no_pruning_predicate_if_can_not_prune() { // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + let batch1 = string_batch(); - // filter is too complicated for pruning + // filter is too complicated for pruning (PruningPredicate code does not + // handle case expressions), so the pruning predicate will always be + // "true" + + // WHEN c1 != bar THEN true ELSE false END let filter = when(col("c1").not_eq(lit("bar")), lit(true)) .otherwise(lit(false)) .unwrap(); @@ -1879,7 +1870,7 @@ mod tests { .round_trip(vec![batch1]) .await; 
- // Should not contain a pruning predicate + // Should not contain a pruning predicate (since nothing can be pruned) let pruning_predicate = &rt.parquet_exec.pruning_predicate; assert!( pruning_predicate.is_none(), @@ -1892,6 +1883,33 @@ mod tests { assert_eq!(predicate.unwrap().to_string(), filter_phys.to_string()); } + #[tokio::test] + async fn parquet_exec_has_pruning_predicate_for_guarantees() { + // batch1: c1(string) + let batch1 = string_batch(); + + // part of the filter is too complicated for pruning (PruningPredicate code does not + // handle case expressions), but part (c1 = 'foo') can be used for bloom filtering, so + // should still have the pruning predicate. + + // c1 = 'foo' AND (WHEN c1 != bar THEN true ELSE false END) + let filter = col("c1").eq(lit("foo")).and( + when(col("c1").not_eq(lit("bar")), lit(true)) + .otherwise(lit(false)) + .unwrap(), + ); + + let rt = RoundTrip::new() + .with_predicate(filter.clone()) + .with_pushdown_predicate() + .round_trip(vec![batch1]) + .await; + + // Should have a pruning predicate + let pruning_predicate = &rt.parquet_exec.pruning_predicate; + assert!(pruning_predicate.is_some()); + } + /// returns the sum of all the metrics with the specified name /// the returned set. /// diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index fecbffdbb041..06cfc7282468 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -295,9 +295,12 @@ impl PruningPredicate { &self.predicate_expr } - /// Returns true if this pruning predicate is "always true" (aka will not prune anything) + /// Returns true if this pruning predicate can not prune anything. + /// + /// This happens if the predicate is a literal `true` and + /// literal_guarantees is empty. pub fn allways_true(&self) -> bool { - is_always_true(&self.predicate_expr) + is_always_true(&self.predicate_expr) && self.literal_guarantees.is_empty() } pub(crate) fn required_columns(&self) -> &RequiredColumns { From 881d03f72cddec7e1cd659ef0c748760c6177b1c Mon Sep 17 00:00:00 2001 From: Yang Jiang Date: Thu, 4 Jan 2024 06:02:38 +0800 Subject: [PATCH 15/18] [Minor] Avoid mem copy in generate window exprs (#8718) --- datafusion/expr/src/logical_plan/builder.rs | 4 +-- datafusion/expr/src/utils.rs | 30 ++++++++++----------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index cfc052cfc14c..a684f3e97485 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -292,7 +292,7 @@ impl LogicalPlanBuilder { window_exprs: Vec, ) -> Result { let mut plan = input; - let mut groups = group_window_expr_by_sort_keys(&window_exprs)?; + let mut groups = group_window_expr_by_sort_keys(window_exprs)?; // To align with the behavior of PostgreSQL, we want the sort_keys sorted as same rule as PostgreSQL that first // we compare the sort key themselves and if one window's sort keys are a prefix of another // put the window with more sort keys first. so more deeply sorted plans gets nested further down as children. @@ -314,7 +314,7 @@ impl LogicalPlanBuilder { key_b.len().cmp(&key_a.len()) }); for (_, exprs) in groups { - let window_exprs = exprs.into_iter().cloned().collect::>(); + let window_exprs = exprs.into_iter().collect::>(); // Partition and sorting is done at physical level, see the EnforceDistribution // and EnforceSorting rules. 
             plan = LogicalPlanBuilder::from(plan)
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index e3ecdf154e61..914b354d2950 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -575,14 +575,14 @@ pub fn compare_sort_expr(
 
 /// group window expressions by their ORDER BY expressions
 pub fn group_window_expr_by_sort_keys(
-    window_expr: &[Expr],
-) -> Result<Vec<(WindowSortKey, Vec<&Expr>)>> {
+    window_expr: Vec<Expr>,
+) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
     let mut result = vec![];
-    window_expr.iter().try_for_each(|expr| match expr {
-        Expr::WindowFunction(WindowFunction{ partition_by, order_by, .. }) => {
+    window_expr.into_iter().try_for_each(|expr| match &expr {
+        Expr::WindowFunction(WindowFunction{ partition_by, order_by, .. }) => {
             let sort_key = generate_sort_key(partition_by, order_by)?;
             if let Some((_, values)) = result.iter_mut().find(
-                |group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, (key, _) if *key == sort_key),
+                |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key),
             ) {
                 values.push(expr);
             } else {
@@ -1239,8 +1239,8 @@ mod tests {
 
     #[test]
     fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
-        let result = group_window_expr_by_sort_keys(&[])?;
-        let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![];
+        let result = group_window_expr_by_sort_keys(vec![])?;
+        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
         assert_eq!(expected, result);
         Ok(())
     }
@@ -1276,10 +1276,10 @@ mod tests {
             WindowFrame::new(false),
         ));
         let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
-        let result = group_window_expr_by_sort_keys(exprs)?;
+        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
         let key = vec![];
-        let expected: Vec<(WindowSortKey, Vec<&Expr>)> =
-            vec![(key, vec![&max1, &max2, &min3, &sum4])];
+        let expected: Vec<(WindowSortKey, Vec<Expr>)> =
+            vec![(key, vec![max1, max2, min3, sum4])];
         assert_eq!(expected, result);
         Ok(())
     }
@@ -1320,7 +1320,7 @@ mod tests {
         ));
         // FIXME use as_ref
         let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
-        let result = group_window_expr_by_sort_keys(exprs)?;
+        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
 
         let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
         let key2 = vec![];
@@ -1330,10 +1330,10 @@ mod tests {
             (created_at_desc, false),
         ];
 
-        let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![
-            (key1, vec![&max1, &min3]),
-            (key2, vec![&max2]),
-            (key3, vec![&sum4]),
+        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
+            (key1, vec![max1, min3]),
+            (key2, vec![max2]),
+            (key3, vec![sum4]),
         ];
         assert_eq!(expected, result);
         Ok(())

From ca260d99f17ef667b7f06d2da4a67255d27c94a9 Mon Sep 17 00:00:00 2001
From: Alex Huang
Date: Thu, 4 Jan 2024 06:14:23 +0800
Subject: [PATCH 16/18] support LargeList in array_repeat (#8725)

---
 .../physical-expr/src/array_expressions.rs   | 16 +++++---
 datafusion/sqllogictest/test_files/array.slt | 37 +++++++++++++++++++
 2 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index aad021610fcb..15330af640ae 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -1233,7 +1233,11 @@ pub fn array_repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
     match element.data_type() {
         DataType::List(_) => {
             let list_array = as_list_array(element)?;
-            general_list_repeat(list_array, count_array)
+            general_list_repeat::<i32>(list_array, count_array)
+        }
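+        // `List` indexes its child values with `i32` offsets while `LargeList`
+        // uses `i64`; both arms now share one implementation through the
+        // `O: OffsetSizeTrait` type parameter on `general_list_repeat`.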
+        DataType::LargeList(_) => {
+            let list_array = as_large_list_array(element)?;
+            general_list_repeat::<i64>(list_array, count_array)
         }
         _ => general_repeat(element, count_array),
     }
@@ -1302,8 +1306,8 @@ fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result<ArrayRef>
 ///   [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]]
 /// )
 /// ```
-fn general_list_repeat(
-    list_array: &ListArray,
+fn general_list_repeat<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
     count_array: &Int64Array,
 ) -> Result<ArrayRef> {
     let data_type = list_array.data_type();
@@ -1335,9 +1339,9 @@ fn general_list_repeat(
         let data = mutable.freeze();
         let repeated_array = arrow_array::make_array(data);
 
-        let list_arr = ListArray::try_new(
+        let list_arr = GenericListArray::<O>::try_new(
             Arc::new(Field::new("item", value_type.clone(), true)),
-            OffsetBuffer::from_lengths(vec![original_data.len(); count]),
+            OffsetBuffer::<O>::from_lengths(vec![original_data.len(); count]),
             repeated_array,
             None,
         )?;
@@ -1354,7 +1358,7 @@ fn general_list_repeat(
 
     Ok(Arc::new(ListArray::try_new(
         Arc::new(Field::new("item", data_type.to_owned(), true)),
-        OffsetBuffer::from_lengths(lengths),
+        OffsetBuffer::<i32>::from_lengths(lengths),
         values,
         None,
     )?))
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index a3b2c8cdf1e9..7cee615a5729 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -1957,6 +1957,15 @@ select
 ----
 [[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
 
+query ????
+select
+    array_repeat(arrow_cast([1], 'LargeList(Int64)'), 5),
+    array_repeat(arrow_cast([1.1, 2.2, 3.3], 'LargeList(Float64)'), 3),
+    array_repeat(arrow_cast([null, null], 'LargeList(Null)'), 3),
+    array_repeat(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2);
+----
+[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
+
 # array_repeat with columns #1
 
 statement ok
@@ -1967,6 +1976,16 @@ AS VALUES
     (3, 2, 2.2, 'rust', make_array(7)),
     (0, 3, 3.3, 'datafusion', make_array(8, 9));
 
+statement ok
+CREATE TABLE large_array_repeat_table
+AS SELECT
+    column1,
+    column2,
+    column3,
+    column4,
+    arrow_cast(column5, 'LargeList(Int64)') as column5
+FROM array_repeat_table;
+
 query ??????
 select
     array_repeat(column2, column1),
@@ -1982,9 +2001,27 @@ from array_repeat_table;
 [2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]]
 [] [] [] [] [3, 3, 3] []
 
+query ??????
+select + array_repeat(column2, column1), + array_repeat(column3, column1), + array_repeat(column4, column1), + array_repeat(column5, column1), + array_repeat(column2, 3), + array_repeat(make_array(1), column1) +from large_array_repeat_table; +---- +[1] [1.1] [a] [[4, 5, 6]] [1, 1, 1] [[1]] +[, ] [, ] [, ] [, ] [, , ] [[1], [1]] +[2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]] +[] [] [] [] [3, 3, 3] [] + statement ok drop table array_repeat_table; +statement ok +drop table large_array_repeat_table; + ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) # array_concat error From e6b9f527d3a1823887b32a8d3dfca85ea21b204c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Thu, 4 Jan 2024 09:18:04 +0300 Subject: [PATCH 17/18] ctrl+c termination (#8739) --- datafusion-cli/Cargo.lock | 10 ++++++++++ datafusion-cli/Cargo.toml | 2 +- datafusion-cli/src/exec.rs | 13 ++++++++++--- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index e85e8b1a9edb..252b00ca0adc 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -3060,6 +3060,15 @@ dependencies = [ "digest", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "siphasher" version = "0.3.11" @@ -3371,6 +3380,7 @@ dependencies = [ "num_cpus", "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.48.0", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index e1ddba4cad1a..eab7c8e0d1f8 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -45,7 +45,7 @@ parking_lot = { version = "0.12" } parquet = { version = "49.0.0", default-features = false } regex = "1.8" rustyline = "11.0" -tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } +tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } url = "2.2" [dev-dependencies] diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index ba9aa2e69aa6..2320a8c314cf 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -45,6 +45,7 @@ use datafusion::sql::{parser::DFParser, sqlparser::dialect::dialect_from_str}; use object_store::ObjectStore; use rustyline::error::ReadlineError; use rustyline::Editor; +use tokio::signal; use url::Url; /// run and execute SQL statements and commands, against a context with the given print options @@ -165,9 +166,15 @@ pub async fn exec_from_repl( } Ok(line) => { rl.add_history_entry(line.trim_end())?; - match exec_and_print(ctx, print_options, line).await { - Ok(_) => {} - Err(err) => eprintln!("{err}"), + tokio::select! { + res = exec_and_print(ctx, print_options, line) => match res { + Ok(_) => {} + Err(err) => eprintln!("{err}"), + }, + _ = signal::ctrl_c() => { + println!("^C"); + continue + }, } // dialect might have changed rl.helper_mut().unwrap().set_dialect( From 819d3577872a082f2aea7a68ae83d68534049662 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Thu, 4 Jan 2024 09:39:02 +0300 Subject: [PATCH 18/18] Add support for functional dependency for ROW_NUMBER window function. 
 (#8737)

* Add primary key support for row_number window function

* Add comments, minor changes

* Add new test

* Review

---------

Co-authored-by: Mehmet Ozan Kabak
---
 datafusion/expr/src/logical_plan/plan.rs      | 59 ++++++++++++++++---
 datafusion/sqllogictest/test_files/window.slt | 40 ++++++++++++-
 2 files changed, 91 insertions(+), 8 deletions(-)

diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index c0c520c4e211..93a38fb40df5 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -25,7 +25,9 @@ use std::sync::Arc;
 use super::dml::CopyTo;
 use super::DdlStatement;
 use crate::dml::CopyOptions;
-use crate::expr::{Alias, Exists, InSubquery, Placeholder, Sort as SortExpr};
+use crate::expr::{
+    Alias, Exists, InSubquery, Placeholder, Sort as SortExpr, WindowFunction,
+};
 use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols};
 use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
 use crate::logical_plan::extension::UserDefinedLogicalNode;
@@ -36,9 +38,9 @@ use crate::utils::{
     split_conjunction,
 };
 use crate::{
-    build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr,
-    ExprSchemable, LogicalPlanBuilder, Operator, TableProviderFilterPushDown,
-    TableSource,
+    build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction,
+    CreateMemoryTable, CreateView, Expr, ExprSchemable, LogicalPlanBuilder, Operator,
+    TableProviderFilterPushDown, TableSource, WindowFunctionDefinition,
 };
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
@@ -48,9 +50,10 @@ use datafusion_common::tree_node::{
 };
 use datafusion_common::{
     aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
-    DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependencies,
-    OwnedTableReference, ParamValues, Result, UnnestOptions,
+    DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence,
+    FunctionalDependencies, OwnedTableReference, ParamValues, Result, UnnestOptions,
 };
+
 // backwards compatibility
 pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
 pub use datafusion_common::{JoinConstraint, JoinType};
@@ -1967,7 +1970,9 @@ pub struct Window {
 impl Window {
     /// Create a new window operator.
     pub fn try_new(window_expr: Vec<Expr>, input: Arc<LogicalPlan>) -> Result<Self> {
-        let mut window_fields: Vec<DFField> = input.schema().fields().clone();
+        let fields = input.schema().fields();
+        let input_len = fields.len();
+        let mut window_fields = fields.clone();
         window_fields.extend_from_slice(&exprlist_to_fields(window_expr.iter(), &input)?);
         let metadata = input.schema().metadata().clone();
 
@@ -1976,6 +1981,46 @@ impl Window {
             input.schema().functional_dependencies().clone();
         window_func_dependencies.extend_target_indices(window_fields.len());
 
+        // Since we know that ROW_NUMBER outputs will be unique (i.e. it consists
+        // of consecutive numbers per partition), we can represent this fact with
+        // functional dependencies.
+        let mut new_dependencies = window_expr
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, expr)| {
+                if let Expr::WindowFunction(WindowFunction {
+                    // Function is ROW_NUMBER
+                    fun:
+                        WindowFunctionDefinition::BuiltInWindowFunction(
+                            BuiltInWindowFunction::RowNumber,
+                        ),
+                    partition_by,
+                    ..
+                }) = expr
+                {
+                    // When there is no PARTITION BY, row number will be unique
+                    // across the entire table.
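+                    // For example, ROW_NUMBER() OVER () numbers the whole input
+                    // 1..N, making the new column a unique key of the result,
+                    // whereas ROW_NUMBER() OVER (PARTITION BY c) restarts from 1
+                    // in every partition and gives no such guarantee.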
+                    if partition_by.is_empty() {
+                        return Some(idx + input_len);
+                    }
+                }
+                None
+            })
+            .map(|idx| {
+                FunctionalDependence::new(vec![idx], vec![], false)
+                    .with_mode(Dependency::Single)
+            })
+            .collect::<Vec<_>>();
+
+        if !new_dependencies.is_empty() {
+            for dependence in new_dependencies.iter_mut() {
+                dependence.target_indices = (0..window_fields.len()).collect();
+            }
+            // Add the dependency introduced because of ROW_NUMBER window function to the functional dependency
+            let new_deps = FunctionalDependencies::new(new_dependencies);
+            window_func_dependencies.extend(new_deps);
+        }
+
         Ok(Window {
             input,
             window_expr,
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index aa083290b4f4..7d6d59201396 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -3832,4 +3832,42 @@ select row_number() over (partition by 1 order by 1) rn,
   from (select 1 a union all select 2 a) x;
 ----
 1 1 1 1 1 1
-2 1 1 2 2 1
\ No newline at end of file
+2 1 1 2 2 1
+
+# when partition by expression is empty row number result will be unique.
+query TII
+SELECT *
+FROM (SELECT c1, c2, ROW_NUMBER() OVER() as rn
+      FROM aggregate_test_100
+      LIMIT 5)
+GROUP BY rn
+ORDER BY rn;
+----
+c 2 1
+d 5 2
+b 1 3
+a 1 4
+b 5 5
+
+# when partition by expression is constant row number result will be unique.
+query TII
+SELECT *
+FROM (SELECT c1, c2, ROW_NUMBER() OVER(PARTITION BY 3) as rn
+      FROM aggregate_test_100
+      LIMIT 5)
+GROUP BY rn
+ORDER BY rn;
+----
+c 2 1
+d 5 2
+b 1 3
+a 1 4
+b 5 5
+
+statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression aggregate_test_100.c1 could not be resolved from available columns: rn
+SELECT *
+FROM (SELECT c1, c2, ROW_NUMBER() OVER(PARTITION BY c1) as rn
+      FROM aggregate_test_100
+      LIMIT 5)
+GROUP BY rn
+ORDER BY rn;
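
The dependency bookkeeping introduced by this last patch can be modeled in isolation. The sketch below is illustrative only: `Dependence`, `WindowExprInfo`, and `row_number_dependencies` are made-up stand-ins for DataFusion's `FunctionalDependence` machinery, not real APIs. It shows the rule the patch encodes: each unpartitioned ROW_NUMBER output column, appended after the `input_len` input columns, becomes a single-column determinant of every column in the windowed schema, which is what lets the `GROUP BY rn` queries above project arbitrary columns.

/// Illustrative stand-in for DataFusion's `FunctionalDependence`.
#[derive(Debug)]
struct Dependence {
    /// Determinant columns (here: a single ROW_NUMBER output column).
    source_indices: Vec<usize>,
    /// Columns whose values are fixed once the determinant is known.
    target_indices: Vec<usize>,
}

/// A window expression reduced to the two facts the analysis inspects.
struct WindowExprInfo {
    is_row_number: bool,
    has_partition_by: bool,
}

/// Mirrors the patch's rule: every ROW_NUMBER without PARTITION BY yields a
/// dependency from its output column (at `idx + input_len`) to all columns.
fn row_number_dependencies(
    window_exprs: &[WindowExprInfo],
    input_len: usize,
    total_fields: usize,
) -> Vec<Dependence> {
    window_exprs
        .iter()
        .enumerate()
        .filter(|(_, e)| e.is_row_number && !e.has_partition_by)
        .map(|(idx, _)| Dependence {
            source_indices: vec![idx + input_len],
            target_indices: (0..total_fields).collect(),
        })
        .collect()
}

fn main() {
    // Two input columns (c1, c2) plus one ROW_NUMBER() OVER () as column 2.
    let exprs = [WindowExprInfo { is_row_number: true, has_partition_by: false }];
    let deps = row_number_dependencies(&exprs, 2, 3);
    assert_eq!(deps.len(), 1);
    assert_eq!(deps[0].source_indices, vec![2]);
    assert_eq!(deps[0].target_indices, vec![0, 1, 2]);
    println!("{deps:?}");
}

Under this model, grouping on rn determines c1 and c2 as well, which is why the planner accepts `SELECT *` over `GROUP BY rn` in the unpartitioned tests above and rejects the `PARTITION BY c1` variant.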