diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 622521a6fbc7..ae6c1ee56129 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -106,7 +106,7 @@ jobs: RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" RUST_BACKTRACE: "1" # avoid rust stack overflows on tpc-ds tests - RUST_MINSTACK: "3000000" + RUST_MIN_STACK: "3000000" - name: Verify Working Directory Clean run: git diff --exit-code @@ -310,14 +310,13 @@ jobs: cd datafusion-cli cargo test --lib --tests --bins --all-features env: - # do not produce debug symbols to keep memory usage down - # use higher optimization level to overcome Windows rust slowness for tpc-ds - # and speed builds: https://github.com/apache/arrow-datafusion/issues/8696 - # Cargo profile docs https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings - RUSTFLAGS: "-C debuginfo=0 -C opt-level=1 -C target-feature=+crt-static -C incremental=false -C codegen-units=256" + # Minimize producing debug symbols to keep memory usage down + # Set debuginfo=line-tables-only as debuginfo=0 causes an immensely slow build + # See for more details: https://github.com/rust-lang/rust/issues/119560 + RUSTFLAGS: "-C debuginfo=line-tables-only" RUST_BACKTRACE: "1" # avoid rust stack overflows on tpc-ds tests - RUST_MINSTACK: "3000000" + RUST_MIN_STACK: "3000000" macos: name: cargo test (mac) runs-on: macos-latest @@ -357,7 +356,7 @@ jobs: RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" RUST_BACKTRACE: "1" # avoid rust stack overflows on tpc-ds tests - RUST_MINSTACK: "3000000" + RUST_MIN_STACK: "3000000" test-datafusion-pyarrow: name: cargo test pyarrow (amd64) diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index b8594352b585..b382eb34f62c 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -141,7 +141,8 @@ impl PrintOptions { let mut row_count = 0_usize; let mut with_header = true; - while let Some(Ok(batch)) = stream.next().await { + while let Some(maybe_batch) = stream.next().await { + let batch = maybe_batch?; row_count += batch.num_rows(); self.format.print_batches( &mut writer, diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index 6ebf88a0b671..d530b9abe030 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -40,6 +40,7 @@ use std::sync::Arc; /// the power of the second argument `a^b`. /// /// To do so, we must implement the `ScalarUDFImpl` trait. +#[derive(Debug, Clone)] struct PowUdf { signature: Signature, aliases: Vec<String>, diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 91869d80a41a..f46031434fc9 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -34,6 +34,7 @@ use datafusion_expr::{ /// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance. /// /// To do so, we must implement the `WindowUDFImpl` trait. +#[derive(Debug, Clone)] struct SmoothItUdf { signature: Signature, } diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index 2e729c128e73..f0edc7175948 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -17,6 +17,7 @@ //!
Column +use crate::error::_schema_err; use crate::utils::{parse_identifiers_normalized, quote_identifier}; use crate::{DFSchema, DataFusionError, OwnedTableReference, Result, SchemaError}; use std::collections::HashSet; @@ -211,13 +212,13 @@ impl Column { } } - Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { + _schema_err!(SchemaError::FieldNotFound { field: Box::new(Column::new(self.relation.clone(), self.name)), valid_fields: schemas .iter() .flat_map(|s| s.fields().iter().map(|f| f.qualified_column())) .collect(), - })) + }) } /// Qualify column if not done yet. @@ -299,23 +300,21 @@ impl Column { } // If not due to USING columns then due to ambiguous column name - return Err(DataFusionError::SchemaError( - SchemaError::AmbiguousReference { - field: Column::new_unqualified(self.name), - }, - )); + return _schema_err!(SchemaError::AmbiguousReference { + field: Column::new_unqualified(self.name), + }); } } } - Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { + _schema_err!(SchemaError::FieldNotFound { field: Box::new(self), valid_fields: schemas .iter() .flat_map(|s| s.iter()) .flat_map(|s| s.fields().iter().map(|f| f.qualified_column())) .collect(), - })) + }) } } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index d6e4490cec4c..85b97aac037d 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -26,6 +26,7 @@ use std::sync::Arc; use crate::error::{ unqualified_field_not_found, DataFusionError, Result, SchemaError, _plan_err, + _schema_err, }; use crate::{ field_not_found, Column, FunctionalDependencies, OwnedTableReference, TableReference, @@ -141,11 +142,9 @@ impl DFSchema { if let Some(qualifier) = field.qualifier() { qualified_names.insert((qualifier, field.name())); } else if !unqualified_names.insert(field.name()) { - return Err(DataFusionError::SchemaError( - SchemaError::DuplicateUnqualifiedField { - name: field.name().to_string(), - }, - )); + return _schema_err!(SchemaError::DuplicateUnqualifiedField { + name: field.name().to_string(), + }); } } @@ -159,14 +158,12 @@ impl DFSchema { qualified_names.sort(); for (qualifier, name) in &qualified_names { if unqualified_names.contains(name) { - return Err(DataFusionError::SchemaError( - SchemaError::AmbiguousReference { - field: Column { - relation: Some((*qualifier).clone()), - name: name.to_string(), - }, - }, - )); + return _schema_err!(SchemaError::AmbiguousReference { + field: Column { + relation: Some((*qualifier).clone()), + name: name.to_string(), + } + }); } } Ok(Self { @@ -230,9 +227,9 @@ impl DFSchema { for field in other_schema.fields() { // skip duplicate columns let duplicated_field = match field.qualifier() { - Some(q) => self.field_with_name(Some(q), field.name()).is_ok(), + Some(q) => self.has_column_with_qualified_name(q, field.name()), // for unqualified columns, check as unqualified name - None => self.field_with_unqualified_name(field.name()).is_ok(), + None => self.has_column_with_unqualified_name(field.name()), }; if !duplicated_field { self.fields.push(field.clone()); @@ -392,14 +389,12 @@ impl DFSchema { if fields_without_qualifier.len() == 1 { Ok(fields_without_qualifier[0]) } else { - Err(DataFusionError::SchemaError( - SchemaError::AmbiguousReference { - field: Column { - relation: None, - name: name.to_string(), - }, + _schema_err!(SchemaError::AmbiguousReference { + field: Column { + relation: None, + name: name.to_string(), }, - )) + }) } } } diff --git a/datafusion/common/src/error.rs 
b/datafusion/common/src/error.rs index e58faaa15096..978938809c1b 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -82,7 +82,9 @@ pub enum DataFusionError { Configuration(String), /// This error happens with schema-related errors, such as schema inference not possible /// and non-unique column names. - SchemaError(SchemaError), + /// 2nd argument is for optional backtrace + /// Boxing the optional backtrace to prevent the error enum from growing too large + SchemaError(SchemaError, Box<Option<String>>), /// Error returned during execution of the query. /// Examples include files not found, errors in parsing certain types. Execution(String), @@ -125,34 +127,6 @@ pub enum SchemaError { }, } -/// Create a "field not found" DataFusion::SchemaError -pub fn field_not_found<R: Into<OwnedTableReference>>( - qualifier: Option<R>, - name: &str, - schema: &DFSchema, -) -> DataFusionError { - DataFusionError::SchemaError(SchemaError::FieldNotFound { - field: Box::new(Column::new(qualifier, name)), - valid_fields: schema - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect(), - }) -} - -/// Convenience wrapper over [`field_not_found`] for when there is no qualifier -pub fn unqualified_field_not_found(name: &str, schema: &DFSchema) -> DataFusionError { - DataFusionError::SchemaError(SchemaError::FieldNotFound { - field: Box::new(Column::new_unqualified(name)), - valid_fields: schema - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect(), - }) -} - impl Display for SchemaError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -298,7 +272,7 @@ impl Display for DataFusionError { write!(f, "IO error: {desc}") } DataFusionError::SQL(ref desc, ref backtrace) => { - let backtrace = backtrace.clone().unwrap_or("".to_owned()); + let backtrace: String = backtrace.clone().unwrap_or("".to_owned()); write!(f, "SQL error: {desc:?}{backtrace}") } DataFusionError::Configuration(ref desc) => { @@ -314,8 +288,10 @@ DataFusionError::Plan(ref desc) => { write!(f, "Error during planning: {desc}") } - DataFusionError::SchemaError(ref desc) => { - write!(f, "Schema error: {desc}") + DataFusionError::SchemaError(ref desc, ref backtrace) => { + let backtrace: &str = + &backtrace.as_ref().clone().unwrap_or("".to_owned()); + write!(f, "Schema error: {desc}{backtrace}") } DataFusionError::Execution(ref desc) => { write!(f, "Execution error: {desc}") @@ -356,7 +332,7 @@ impl Error for DataFusionError { DataFusionError::Internal(_) => None, DataFusionError::Configuration(_) => None, DataFusionError::Plan(_) => None, - DataFusionError::SchemaError(e) => Some(e), + DataFusionError::SchemaError(e, _) => Some(e), DataFusionError::Execution(_) => None, DataFusionError::ResourcesExhausted(_) => None, DataFusionError::External(e) => Some(e.as_ref()), @@ -556,12 +532,63 @@ macro_rules! arrow_err { }; } +// Exposes a macro to create `DataFusionError::SchemaError` with optional backtrace +#[macro_export] +macro_rules! schema_datafusion_err { + ($ERR:expr) => { + DataFusionError::SchemaError( + $ERR, + Box::new(Some(DataFusionError::get_back_trace())), + ) + }; +} + +// Exposes a macro to create `Err(DataFusionError::SchemaError)` with optional backtrace +#[macro_export] +macro_rules!
schema_err { + ($ERR:expr) => { + Err(DataFusionError::SchemaError( + $ERR, + Box::new(Some(DataFusionError::get_back_trace())), + )) + }; +} + // To avoid compiler error when using macro in the same crate: // macros from the current crate cannot be referred to by absolute paths pub use internal_datafusion_err as _internal_datafusion_err; pub use internal_err as _internal_err; pub use not_impl_err as _not_impl_err; pub use plan_err as _plan_err; +pub use schema_err as _schema_err; + +/// Create a "field not found" DataFusion::SchemaError +pub fn field_not_found<R: Into<OwnedTableReference>>( + qualifier: Option<R>, + name: &str, + schema: &DFSchema, +) -> DataFusionError { + schema_datafusion_err!(SchemaError::FieldNotFound { + field: Box::new(Column::new(qualifier, name)), + valid_fields: schema + .fields() + .iter() + .map(|f| f.qualified_column()) + .collect(), + }) +} + +/// Convenience wrapper over [`field_not_found`] for when there is no qualifier +pub fn unqualified_field_not_found(name: &str, schema: &DFSchema) -> DataFusionError { + schema_datafusion_err!(SchemaError::FieldNotFound { + field: Box::new(Column::new_unqualified(name)), + valid_fields: schema + .fields() + .iter() + .map(|f| f.qualified_column()) + .collect(), + }) +} #[cfg(test)] mod test {
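[Editor's note] The boxed `Option<String>` added above keeps the `DataFusionError` enum itself small (the backtrace sits behind a pointer) while `Display` can still append it. A minimal sketch of how downstream code constructs and matches the new two-field variant via the `schema_err!` macro defined above; the `check_unique` helper is hypothetical:

```rust
use datafusion_common::error::{DataFusionError, Result, SchemaError};
use datafusion_common::schema_err;

// Hypothetical helper: report a duplicate unqualified field with a backtrace.
fn check_unique(name: &str, already_seen: bool) -> Result<()> {
    if already_seen {
        // Expands to Err(DataFusionError::SchemaError(err, Box::new(Some(backtrace))))
        return schema_err!(SchemaError::DuplicateUnqualifiedField {
            name: name.to_string(),
        });
    }
    Ok(())
}

fn main() {
    // Callers that only care about the error kind ignore the backtrace slot,
    // mirroring the `FieldNotFound { .. }, _` match in dataframe/mod.rs below.
    if let Err(DataFusionError::SchemaError(
        SchemaError::DuplicateUnqualifiedField { .. },
        _backtrace,
    )) = check_unique("a", true)
    {
        eprintln!("duplicate column `a`");
    }
}
```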
diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 5c36f41a6e42..8dcc00ca1c29 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -214,22 +214,19 @@ fn hash_struct_array( hashes_buffer: &mut [u64], ) -> Result<()> { let nulls = array.nulls(); - let num_columns = array.num_columns(); + let row_len = array.len(); - // Skip null columns - let valid_indices: Vec<usize> = if let Some(nulls) = nulls { + let valid_row_indices: Vec<usize> = if let Some(nulls) = nulls { nulls.valid_indices().collect() } else { - (0..num_columns).collect() + (0..row_len).collect() }; // Create hashes for each row that combines the hashes over all the columns at that row. - // array.len() is the number of rows. - let mut values_hashes = vec![0u64; array.len()]; + let mut values_hashes = vec![0u64; row_len]; create_hashes(array.columns(), random_state, &mut values_hashes)?; - // Skip the null columns, nulls should get hash value 0. - for i in valid_indices { + for i in valid_row_indices { let hash = &mut hashes_buffer[i]; *hash = combine_hashes(*hash, values_hashes[i]); } @@ -601,6 +598,39 @@ mod tests { assert_eq!(hashes[4], hashes[5]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_struct_arrays_more_column_than_row() { + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + Arc::new(BooleanArray::from(vec![false, false])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-1", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-2", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-3", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ]); + + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_valid(1)); + + let array = Arc::new(struct_array) as ArrayRef; + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 5a8c706e32cd..f15f1e9ba6fb 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -802,6 +802,7 @@ impl DataFrame { /// Executes this DataFrame and returns a stream over a single partition /// + /// # Example /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// # Ok(()) /// # } /// ``` + /// + /// # Aborting Execution + /// + /// Dropping the stream will abort the execution of the query, and free up + /// any allocated resources pub async fn execute_stream(self) -> Result<SendableRecordBatchStream> { let task_ctx = Arc::new(self.task_ctx()); let plan = self.create_physical_plan().await?; @@ -841,6 +847,7 @@ impl DataFrame { /// Executes this DataFrame and returns one stream per partition. /// + /// # Example /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// # Ok(()) /// # } /// ``` + /// # Aborting Execution + /// + /// Dropping the stream will abort the execution of the query, and free up + /// any allocated resources pub async fn execute_stream_partitioned( self, ) -> Result<Vec<SendableRecordBatchStream>> { @@ -1175,7 +1186,7 @@ impl DataFrame { let field_to_rename = match self.plan.schema().field_from_column(&old_column) { Ok(field) => field, // no-op if field not found - Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { .. })) => { + Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }, _)) => { return Ok(self) } Err(err) => return Err(err), diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index b3ebbc6e3637..8fc724a22443 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -285,37 +285,38 @@ //! //! ### Logical Plans //! Logical planning yields [`LogicalPlan`] nodes and [`Expr`] -//! expressions which are [`Schema`] aware and represent statements +//! representing expressions which are [`Schema`] aware and represent statements //! independent of how they are physically executed. //!
A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! -//! Examples of working with and executing `Expr`s can be found in the +//! [`Expr`]s can be rewritten using the [`TreeNode`] API and simplified using +//! [`ExprSimplifier`]. Examples of working with and executing `Expr`s can be found in the //! [`expr_api`.rs] example //! +//! [`TreeNode`]: datafusion_common::tree_node::TreeNode +//! [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier //! [`expr_api`.rs]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs //! //! ### Physical Plans //! //! An [`ExecutionPlan`] (sometimes referred to as a "physical plan") //! is a plan that can be executed against data. It is a DAG of other -//! [`ExecutionPlan`]s each potentially containing expressions of the -//! following types: +//! [`ExecutionPlan`]s each potentially containing expressions that implement the +//! [`PhysicalExpr`] trait. //! -//! 1. [`PhysicalExpr`]: Scalar functions -//! -//! 2. [`AggregateExpr`]: Aggregate functions -//! -//! 2. [`WindowExpr`]: Window functions -//! -//! Compared to a [`LogicalPlan`], an [`ExecutionPlan`] has concrete +//! Compared to a [`LogicalPlan`], an [`ExecutionPlan`] has additional concrete //! information about how to perform calculations (e.g. hash vs merge //! join), and how data flows during execution (e.g. partitioning and //! sortedness). //! +//! [cp_solver] performs range propagation analysis on [`PhysicalExpr`]s and +//! [`PruningPredicate`] can prove certain boolean [`PhysicalExpr`]s used for +//! filtering can never be `true` using additional statistical information. +//! +//! [cp_solver]: crate::physical_expr::intervals::cp_solver +//! [`PruningPredicate`]: crate::physical_optimizer::pruning::PruningPredicate //! [`PhysicalExpr`]: crate::physical_plan::PhysicalExpr -//! [`AggregateExpr`]: crate::physical_plan::AggregateExpr -//! [`WindowExpr`]: crate::physical_plan::WindowExpr //! //! ## Execution //!
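[Editor's note] Since the module docs above now cross-link [`ExprSimplifier`], here is a compact illustration of the simplify workflow they describe. This is a sketch against the API of this era of the crate (module paths may have shifted since); the schema and expression are made up:

```rust
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::ToDFSchema;
use datafusion::error::Result;
use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
use datafusion::physical_expr::execution_props::ExecutionProps;
use datafusion::prelude::*;

fn main() -> Result<()> {
    // `true AND b < 2` should fold to `b < 2`
    let expr = lit(true).and(col("b").lt(lit(2)));

    // The simplifier needs a schema to resolve column types
    let schema =
        Schema::new(vec![Field::new("b", DataType::Int32, false)]).to_dfschema_ref()?;
    let props = ExecutionProps::new();
    let simplifier = ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));

    let simplified = simplifier.simplify(expr)?;
    assert_eq!(simplified, col("b").lt(lit(2)));
    Ok(())
}
```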
diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index e6dd7b21ed31..351801884485 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -926,9 +926,8 @@ fn add_hash_on_top( hash_exprs: Vec<Arc<dyn PhysicalExpr>>, n_target: usize, ) -> Result<DistributionContext> { - let partition_count = input.plan.output_partitioning().partition_count(); // Early return if hash repartition is unnecessary - if n_target == partition_count && n_target == 1 { + if n_target == 1 { return Ok(input); } @@ -1191,13 +1190,13 @@ fn ensure_distribution( ) .map( |(mut child, requirement, required_input_ordering, would_benefit, maintains)| { - // Don't need to apply when the returned row count is not greater than 1: + // Don't need to apply when the returned row count is not greater than batch size let num_rows = child.plan.statistics()?.num_rows; let repartition_beneficial_stats = if num_rows.is_exact().unwrap_or(false) { num_rows .get_value() .map(|value| value > &batch_size) - .unwrap_or(true) + .unwrap() // safe to unwrap since is_exact() is true } else { true }; @@ -1208,21 +1207,22 @@ // Unless partitioning doesn't increase the partition count, it is not beneficial: && child.plan.output_partitioning().partition_count() < target_partitions; + // When `repartition_file_scans` is set, attempt to increase + // parallelism at the source. + if repartition_file_scans && repartition_beneficial_stats { + if let Some(new_child) = + child.plan.repartitioned(target_partitions, config)? + { + child.plan = new_child; + } + } + if enable_round_robin // Operator benefits from partitioning (e.g. filter): && (would_benefit && repartition_beneficial_stats) // Unless partitioning doesn't increase the partition count, it is not beneficial: && child.plan.output_partitioning().partition_count() < target_partitions { - // When `repartition_file_scans` is set, attempt to increase - // parallelism at the source. - if repartition_file_scans { - if let Some(new_child) = - child.plan.repartitioned(target_partitions, config)? - { - child.plan = new_child; - } - } // Increase parallelism by adding round-robin repartitioning // on top of the operator. Note that we only do this if the // partition count is not already equal to the desired partition
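[Editor's note] The rewritten block above consults three session settings: `target_partitions`, `repartition_file_scans`, and the batch size used in the row-count heuristic. For context, this is a sketch of how those knobs are typically configured with the standard `SessionConfig` builders (values are arbitrary):

```rust
use datafusion::prelude::*;

fn main() {
    // Knobs consulted by ensure_distribution() above
    let config = SessionConfig::new()
        .with_target_partitions(8) // desired partition count for repartitioning
        .with_repartition_file_scans(true) // allow increasing parallelism at the source
        .with_batch_size(8192); // threshold for `repartition_beneficial_stats`
    let ctx = SessionContext::new_with_config(config);
    let _ = ctx; // plans created from this context use the settings above
}
```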
@@ -1367,17 +1367,10 @@ impl DistributionContext { fn update_children(mut self) -> Result<Self> { for child_context in self.children_nodes.iter_mut() { - child_context.distribution_connection = match child_context.plan.as_any() { - plan_any if plan_any.is::<RepartitionExec>() => matches!( - plan_any - .downcast_ref::<RepartitionExec>() - .unwrap() - .partitioning(), - Partitioning::RoundRobinBatch(_) | Partitioning::Hash(_, _) - ), - plan_any - if plan_any.is::<CoalescePartitionsExec>() - || plan_any.is::<SortPreservingMergeExec>() => + child_context.distribution_connection = match &child_context.plan { + plan if is_repartition(plan) + || is_coalesce_partitions(plan) + || is_sort_preserving_merge(plan) => { true } @@ -3876,14 +3869,14 @@ pub(crate) mod tests { "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", // Plan already has two partitions - "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e]", + "ParquetExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e]", ]; let expected_csv = [ "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", // Plan already has two partitions - "CsvExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], has_header=false", + "CsvExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], has_header=false", ]; assert_optimized!(expected_parquet, plan_parquet, true, false, 2, true, 10); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 06cfc7282468..0cbbaf2bf6cd 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -45,12 +45,23 @@ use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; -/// Interface to pass statistics (min/max/nulls) information to [`PruningPredicate`]. /// -/// Returns statistics for containers / files as Arrow [`ArrayRef`], so the -/// evaluation happens once on a single `RecordBatch`, amortizing the overhead -/// of evaluating of the predicate. This is important when pruning 1000s of -/// containers which often happens in analytic systems. +/// A source of runtime statistical information to [`PruningPredicate`]s. +/// +/// # Supported Information +/// +/// 1. Minimum and maximum values for columns +/// +/// 2. Null counts for columns +/// +/// 3. Whether the values in a column are contained in a set of literals +/// +/// # Vectorized Interface +/// +/// Information for containers / files is returned as Arrow [`ArrayRef`], so +/// the evaluation happens once on a single `RecordBatch`, which amortizes the +/// overhead of evaluating the predicate. This is important when pruning 1000s +/// of containers which often happens in analytic systems that have 1000s of +/// potential files to consider. /// /// For example, for the following three files with a single column `a`: /// ```text @@ -83,8 +94,11 @@ pub trait PruningStatistics { /// Note: the returned array must contain [`Self::num_containers`] rows fn max_values(&self, column: &Column) -> Option<ArrayRef>; - /// Return the number of containers (e.g.
row groups) being - /// pruned with these statistics (the number of rows in each returned array) + /// Return the number of containers (e.g. Row Groups) being pruned with + /// these statistics. + /// + /// This value corresponds to the size of the [`ArrayRef`] returned by + /// [`Self::min_values`], [`Self::max_values`], and [`Self::null_counts`]. fn num_containers(&self) -> usize; /// Return the number of null values for the named column as an @@ -95,13 +109,11 @@ /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option<ArrayRef>; - /// Returns an array where each row represents information known about - /// the `values` contained in a column. + /// Returns [`BooleanArray`] where each row represents information known + /// about specific literal `values` in a column. /// - /// This API is designed to be used along with [`LiteralGuarantee`] to prove - /// that predicates can not possibly evaluate to `true` and thus prune - /// containers. For example, Parquet Bloom Filters can prove that values are - /// not present. + /// For example, Parquet Bloom Filters implement this API to communicate + /// that `values` are known not to be present in a Row Group. /// /// The returned array has one row for each container, with the following /// meanings:
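[Editor's note] To make the revised trait docs concrete, here is a minimal `PruningStatistics` implementation that only knows per-container min/max values and leaves the null-count and set-membership questions unanswered (returning `None` is always safe). The struct and field names are invented for illustration, and it ignores which column is asked about, as if there were a single column:

```rust
use std::collections::HashSet;

use datafusion::arrow::array::{ArrayRef, BooleanArray};
use datafusion::common::Column;
use datafusion::physical_optimizer::pruning::PruningStatistics;
use datafusion::scalar::ScalarValue;

/// One row per container: mins[i] / maxes[i] describe container i.
struct MinMaxStats {
    mins: ArrayRef,
    maxes: ArrayRef,
}

impl PruningStatistics for MinMaxStats {
    fn min_values(&self, _column: &Column) -> Option<ArrayRef> {
        Some(self.mins.clone())
    }

    fn max_values(&self, _column: &Column) -> Option<ArrayRef> {
        Some(self.maxes.clone())
    }

    fn num_containers(&self) -> usize {
        // must match the length of every array returned above
        self.mins.len()
    }

    fn null_counts(&self, _column: &Column) -> Option<ArrayRef> {
        None // unknown: the predicate must assume nulls may be present
    }

    fn contained(
        &self,
        _column: &Column,
        _values: &HashSet<ScalarValue>,
    ) -> Option<BooleanArray> {
        None // no bloom-filter-like information available
    }
}
```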
@@ -120,28 +132,34 @@ ) -> Option<BooleanArray>; } -/// Evaluates filter expressions on statistics such as min/max values and null -/// counts, attempting to prove a "container" (e.g. Parquet Row Group) can be -/// skipped without reading the actual data, potentially leading to significant -/// performance improvements. +/// Used to prove that arbitrary predicates (boolean expression) can not +/// possibly evaluate to `true` given information about a column provided by +/// [`PruningStatistics`]. +/// +/// `PruningPredicate` analyzes filter expressions using statistics such as +/// min/max values and null counts, attempting to prove a "container" (e.g. +/// Parquet Row Group) can be skipped without reading the actual data, +/// potentially leading to significant performance improvements. /// -/// For example, [`PruningPredicate`]s are used to prune Parquet Row Groups -/// based on the min/max values found in the Parquet metadata. If the -/// `PruningPredicate` can guarantee that no rows in the Row Group match the -/// filter, the entire Row Group is skipped during query execution. +/// For example, `PruningPredicate`s are used to prune Parquet Row Groups based +/// on the min/max values found in the Parquet metadata. If the +/// `PruningPredicate` can prove that the filter can never evaluate to `true` +/// for any row in the Row Group, the entire Row Group is skipped during query +/// execution. /// -/// The `PruningPredicate` API is general, allowing it to be used for pruning -/// other types of containers (e.g. files) based on statistics that may be -/// known from external catalogs (e.g. Delta Lake) or other sources. Thus it -/// supports: +/// The `PruningPredicate` API is designed to be general, so it can be used for +/// pruning other types of containers (e.g. files) based on statistics that may +/// be known from external catalogs (e.g. Delta Lake) or other sources. /// -/// 1. Arbitrary expressions expressions (including user defined functions) +/// It currently supports: +/// +/// 1. Arbitrary expressions (including user defined functions) /// /// 2. Vectorized evaluation (provide more than one set of statistics at a time) /// so it is suitable for pruning 1000s of containers. /// -/// 3. Anything that implements the [`PruningStatistics`] trait, not just -/// Parquet metadata. +/// 3. Any source of information that implements the [`PruningStatistics`] trait +/// (not just Parquet metadata). /// /// # Example /// @@ -154,7 +172,8 @@ /// C: {x_min = 5, x_max = 8} /// ``` /// -/// Applying the `PruningPredicate` will concludes that `A` can be pruned: +/// `PruningPredicate` will conclude that the predicate can never be true for +/// any row in container `A` (as the maximum value is only `4`), so it can be pruned: /// /// ```text /// A: false (no rows could possibly match x = 5) @@ -2017,54 +2036,52 @@ mod tests { DataType::Decimal128(9, 2), true, )])); - // s1 > 5 - let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))); - let expr = logical2physical(&expr, &schema); - // If the data is written by spark, the physical data type is INT32 in the parquet - // So we use the INT32 type of statistic. - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i32( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + + prune_with_expr( + // s1 > 5 + col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))), + &schema, + // If the data is written by spark, the physical data type is INT32 in the parquet + // So we use the INT32 type of statistic. + &TestStatistics::new().with( + "s1", + ContainerStats::new_i32( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4), None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); - // with cast column to other type - let expr = cast(col("s1"), DataType::Decimal128(14, 3)) - .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i32( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + prune_with_expr( + // with cast column to other type + cast(col("s1"), DataType::Decimal128(14, 3)) + .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_i32( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4), None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); - // with try cast column to other type - let expr = try_cast(col("s1"), DataType::Decimal128(14, 3)) - .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i32( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + prune_with_expr( + // with try cast column to other type + try_cast(col("s1"), DataType::Decimal128(14, 3)) + .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_i32( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4),
None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); // decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( @@ -2072,22 +2089,21 @@ DataType::Decimal128(18, 2), true, )])); - // s1 > 5 - let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 18, 2))); - let expr = logical2physical(&expr, &schema); - // If the data is written by spark, the physical data type is INT64 in the parquet - // So we use the INT32 type of statistic. - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i64( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + prune_with_expr( + // s1 > 5 + col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 18, 2))), + &schema, + // If the data is written by spark, the physical data type is INT64 in the parquet + // So we use the INT64 type of statistic. + &TestStatistics::new().with( + "s1", + ContainerStats::new_i64( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4), None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); // decimal(23,2) let schema = Arc::new(Schema::new(vec![Field::new( @@ -2095,22 +2111,22 @@ DataType::Decimal128(23, 2), true, )])); - // s1 > 5 - let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 23, 2))); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_decimal128( - vec![Some(0), Some(400), None, Some(300)], // min - vec![Some(500), Some(600), Some(400), None], // max - 23, - 2, + + prune_with_expr( + // s1 > 5 + col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 23, 2))), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_decimal128( + vec![Some(0), Some(400), None, Some(300)], // min + vec![Some(500), Some(600), Some(400), None], // max + 23, + 2, + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); } #[test] @@ -2120,10 +2136,6 @@ Field::new("s2", DataType::Int32, true), ])); - // Prune using s2 > 5 - let expr = col("s2").gt(lit(5)); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( "s2", ContainerStats::new_i32( vec![Some(0), Some(4), None, Some(3)], // min vec![Some(5), Some(6), None, None], // max ), ); + prune_with_expr( + // Prune using s2 > 5 + col("s2").gt(lit(5)), + &schema, + &statistics, + // s2 [0, 5] ==> no rows should pass + // s2 [4, 6] ==> some rows could pass + // No stats for s2 ==> some rows could pass + // s2 [3, None] (null max) ==> some rows could pass + &[false, true, true, true], + ); - // s2 [0, 5] ==> no rows should pass - // s2 [4, 6] ==> some rows could pass - // No stats for s2 ==> some rows could pass - // s2 [3, None] (null max) ==> some rows could pass - - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, true, true]; - assert_eq!(result, expected); - - // filter with cast - let expr =
cast(col("s2"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(5)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, true, true]; - assert_eq!(result, expected); + prune_with_expr( + // filter with cast + cast(col("s2"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(5)))), + &schema, + &statistics, + &[false, true, true, true], + ); } #[test] fn prune_not_eq_data() { let schema = Arc::new(Schema::new(vec![Field::new("s1", DataType::Utf8, true)])); - // Prune using s2 != 'M' - let expr = col("s1").not_eq(lit("M")); - let expr = logical2physical(&expr, &schema); - - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_utf8( - vec![Some("A"), Some("A"), Some("N"), Some("M"), None, Some("A")], // min - vec![Some("Z"), Some("L"), Some("Z"), Some("M"), None, None], // max + prune_with_expr( + // Prune using s2 != 'M' + col("s1").not_eq(lit("M")), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_utf8( + vec![Some("A"), Some("A"), Some("N"), Some("M"), None, Some("A")], // min + vec![Some("Z"), Some("L"), Some("Z"), Some("M"), None, None], // max + ), ), + // s1 [A, Z] ==> might have values that pass predicate + // s1 [A, L] ==> all rows pass the predicate + // s1 [N, Z] ==> all rows pass the predicate + // s1 [M, M] ==> all rows do not pass the predicate + // No stats for s2 ==> some rows could pass + // s2 [3, None] (null max) ==> some rows could pass + &[true, true, true, false, true, true], ); - - // s1 [A, Z] ==> might have values that pass predicate - // s1 [A, L] ==> all rows pass the predicate - // s1 [N, Z] ==> all rows pass the predicate - // s1 [M, M] ==> all rows do not pass the predicate - // No stats for s2 ==> some rows could pass - // s2 [3, None] (null max) ==> some rows could pass - - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![true, true, true, false, true, true]; - assert_eq!(result, expected); } /// Creates setup for boolean chunk pruning @@ -2216,69 +2225,75 @@ mod tests { fn prune_bool_const_expr() { let (schema, statistics, _, _) = bool_setup(); - // true - let expr = lit(true); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, vec![true, true, true, true, true]); + prune_with_expr( + // true + lit(true), + &schema, + &statistics, + &[true, true, true, true, true], + ); - // false - // constant literals that do NOT refer to any columns are currently not evaluated at all, hence the result is - // "all true" - let expr = lit(false); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, vec![true, true, true, true, true]); + prune_with_expr( + // false + // constant literals that do NOT refer to any columns are currently not evaluated at all, hence the result is + // "all true" + lit(false), + &schema, + &statistics, + &[true, true, true, true, true], + ); } #[test] fn prune_bool_column() { let (schema, statistics, expected_true, _) = bool_setup(); - // b1 - let expr = col("b1"); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, 
expected_true); + prune_with_expr( + // b1 + col("b1"), + &schema, + &statistics, + &expected_true, + ); } #[test] fn prune_bool_not_column() { let (schema, statistics, _, expected_false) = bool_setup(); - // !b1 - let expr = col("b1").not(); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_false); + prune_with_expr( + // !b1 + col("b1").not(), + &schema, + &statistics, + &expected_false, + ); } #[test] fn prune_bool_column_eq_true() { let (schema, statistics, expected_true, _) = bool_setup(); - // b1 = true - let expr = col("b1").eq(lit(true)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_true); + prune_with_expr( + // b1 = true + col("b1").eq(lit(true)), + &schema, + &statistics, + &expected_true, + ); } #[test] fn prune_bool_not_column_eq_true() { let (schema, statistics, _, expected_false) = bool_setup(); - // !b1 = true - let expr = col("b1").not().eq(lit(true)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_false); + prune_with_expr( + // !b1 = true + col("b1").not().eq(lit(true)), + &schema, + &statistics, + &expected_false, + ); } /// Creates a setup for chunk pruning, modeling an int32 column "i" @@ -2313,21 +2328,18 @@ // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> unknown (must keep) - let expected_ret = vec![true, true, false, true, true]; + let expected_ret = &[true, true, false, true, true]; // i > 0 - let expr = col("i").gt(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr(col("i").gt(lit(0)), &schema, &statistics, expected_ret); // -i < 0 - let expr = Expr::Negative(Box::new(col("i"))).lt(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + Expr::Negative(Box::new(col("i"))).lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2340,21 +2352,23 @@ // i [-11, -1] ==> all rows must pass (must keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (not keep) - let expected_ret = vec![true, false, true, true, false]; + let expected_ret = &[true, false, true, true, false]; - // i <= 0 - let expr = col("i").lt_eq(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i <= 0 + col("i").lt_eq(lit(0)), + &schema, + &statistics, + expected_ret, + ); - // -i >= 0 - let expr = Expr::Negative(Box::new(col("i"))).gt_eq(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // -i >= 0 + Expr::Negative(Box::new(col("i"))).gt_eq(lit(0)), + &schema, + &statistics, + expected_ret, + ); } #[test]
@@ -2367,37 +2381,39 @@ mod tests { // i [-11, -1] ==> no rows could pass in theory (conservatively keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (conservatively keep) - let expected_ret = vec![true, true, true, true, true]; + let expected_ret = &[true, true, true, true, true]; - // cast(i as utf8) <= 0 - let expr = cast(col("i"), DataType::Utf8).lt_eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // cast(i as utf8) <= 0 + cast(col("i"), DataType::Utf8).lt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); - // try_cast(i as utf8) <= 0 - let expr = try_cast(col("i"), DataType::Utf8).lt_eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // try_cast(i as utf8) <= 0 + try_cast(col("i"), DataType::Utf8).lt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); - // cast(-i as utf8) >= 0 - let expr = - cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // cast(-i as utf8) >= 0 + cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); - // try_cast(-i as utf8) >= 0 - let expr = - try_cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // try_cast(-i as utf8) >= 0 + try_cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2410,14 +2426,15 @@ mod tests { // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (not keep) - let expected_ret = vec![true, false, false, true, false]; + let expected_ret = &[true, false, false, true, false]; - // i = 0 - let expr = col("i").eq(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i = 0 + col("i").eq(lit(0)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2430,19 +2447,21 @@ mod tests { // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (not keep) - let expected_ret = vec![true, false, false, true, false]; + let expected_ret = &[true, false, false, true, false]; - let expr = cast(col("i"), DataType::Int64).eq(lit(0i64)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + cast(col("i"), DataType::Int64).eq(lit(0i64)), + &schema, + &statistics, + expected_ret, + ); - let expr = try_cast(col("i"), DataType::Int64).eq(lit(0i64)); - let expr = 
logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + try_cast(col("i"), DataType::Int64).eq(lit(0i64)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2458,13 +2477,14 @@ // i [-11, -1] ==> no rows could pass (could keep) // i [NULL, NULL] ==> unknown (keep) // i [1, NULL] ==> no rows can pass (could keep) - let expected_ret = vec![true, true, true, true, true]; + let expected_ret = &[true, true, true, true, true]; - let expr = cast(col("i"), DataType::Utf8).eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + cast(col("i"), DataType::Utf8).eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2477,21 +2497,23 @@ // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> all rows must pass (must keep) - let expected_ret = vec![true, true, false, true, true]; + let expected_ret = &[true, true, false, true, true]; - // i > -1 - let expr = col("i").gt(lit(-1)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i > -1 + col("i").gt(lit(-1)), + &schema, + &statistics, + expected_ret, + ); - // -i < 1 - let expr = Expr::Negative(Box::new(col("i"))).lt(lit(1)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // -i < 1 + Expr::Negative(Box::new(col("i"))).lt(lit(1)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2500,14 +2522,15 @@ // Expression "i IS NULL" when there are no null statistics, // should all be kept - let expected_ret = vec![true, true, true, true, true]; + let expected_ret = &[true, true, true, true, true]; - // i IS NULL, no null statistics - let expr = col("i").is_null(); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i IS NULL, no null statistics + col("i").is_null(), + &schema, + &statistics, + expected_ret, + ); // provide null counts for each column let statistics = statistics.with_null_counts( ], ); @@ -2521,51 +2544,55 @@ - let expected_ret = vec![false, true, true, true, false]; + let expected_ret = &[false, true, true, true, false]; - // i IS NULL, with actual null statistcs - let expr = col("i").is_null(); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i IS NULL, with actual null statistics + col("i").is_null(), + &schema, + &statistics, + expected_ret, + ); } #[test] fn prune_cast_column_scalar() { // The data type of column i is INT32 let (schema, statistics) = int32_setup(); let expected_ret = &[true, true, false, true, true]; - // i > int64(0) - let expr =
col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), DataType::Int32)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i > int64(0) + col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), DataType::Int32)), + &schema, + &statistics, + expected_ret, + ); - // cast(i as int64) > int64(0) - let expr = cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // cast(i as int64) > int64(0) + cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), + &schema, + &statistics, + expected_ret, + ); - // try_cast(i as int64) > int64(0) - let expr = - try_cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // try_cast(i as int64) > int64(0) + try_cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), + &schema, + &statistics, + expected_ret, + ); - // `-cast(i as int64) < 0` convert to `cast(i as int64) > -0` - let expr = Expr::Negative(Box::new(cast(col("i"), DataType::Int64))) - .lt(lit(ScalarValue::Int64(Some(0)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // `-cast(i as int64) < 0` convert to `cast(i as int64) > -0` + Expr::Negative(Box::new(cast(col("i"), DataType::Int64))) + .lt(lit(ScalarValue::Int64(Some(0)))), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2721,7 +2748,7 @@ mod tests { &schema, &statistics, // rule out containers ('false) where we know foo is not present - vec![true, false, true, true, false, true, true, false, true], + &[true, false, true, true, false, true, true, false, true], ); // s1 = 'bar' @@ -2730,7 +2757,7 @@ mod tests { &schema, &statistics, // rule out containers where we know bar is not present - vec![true, true, true, false, false, false, true, true, true], + &[true, true, true, false, false, false, true, true, true], ); // s1 = 'baz' (unknown value) @@ -2739,7 +2766,7 @@ mod tests { &schema, &statistics, // can't rule out anything - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 = 'foo' AND s1 = 'bar' @@ -2750,7 +2777,7 @@ mod tests { // logically this predicate can't possibly be true (the column can't // take on both values) but we could rule it out if the stats tell // us that both values are not present - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 = 'foo' OR s1 = 'bar' @@ -2759,7 +2786,7 @@ mod tests { &schema, &statistics, // can rule out containers that we know contain neither foo nor bar - vec![true, true, true, true, true, true, false, false, false], + &[true, true, true, true, true, true, false, false, false], ); // s1 = 'foo' OR s1 = 'baz' @@ -2768,7 +2795,7 @@ mod tests { &schema, &statistics, // can't rule out anything container - vec![true, true, true, true, 
true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 = 'foo' OR s1 = 'bar' OR s1 = 'baz' @@ -2781,7 +2808,7 @@ &statistics, // can rule out any containers based on knowledge of s1 and `foo`, // `bar` and (`foo`, `bar`) - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 != foo @@ -2790,7 +2817,7 @@ &schema, &statistics, // rule out containers we know for sure only contain foo - vec![false, true, true, false, true, true, false, true, true], + &[false, true, true, false, true, true, false, true, true], ); // s1 != bar @@ -2799,7 +2826,7 @@ &schema, &statistics, // rule out when we know for sure s1 has the value bar - vec![false, false, false, true, true, true, true, true, true], + &[false, false, false, true, true, true, true, true, true], ); // s1 != foo AND s1 != bar @@ -2810,7 +2837,7 @@ &schema, &statistics, // can rule out any container where we know s1 does not have either 'foo' or 'bar' - vec![true, true, true, false, false, false, true, true, true], + &[true, true, true, false, false, false, true, true, true], ); // s1 != foo AND s1 != bar AND s1 != baz @@ -2822,7 +2849,7 @@ &schema, &statistics, // can't rule out any container based on knowledge of s1,s2 - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 != foo OR s1 != bar @@ -2833,7 +2860,7 @@ &schema, &statistics, // can't rule out anything based on contains information - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 != foo OR s1 != bar OR s1 != baz @@ -2845,7 +2872,7 @@ &schema, &statistics, // can't rule out anything based on contains information - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); } @@ -2907,7 +2934,7 @@ &schema, &statistics, // rule out containers where we know s1 is not present - vec![true, false, true, true, false, true, true, false, true], + &[true, false, true, true, false, true, true, false, true], ); // s1 = 'foo' OR s2 = 'bar' @@ -2917,7 +2944,7 @@ &schema, &statistics, // can't rule out any container (would need to prove that s1 != foo AND s2 != bar) - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 = 'foo' AND s2 != 'bar' @@ -2928,7 +2955,7 @@ // can only rule out container where we know either: // 1. s1 doesn't have the value 'foo' or // 2. s2 has only the value of 'bar' - vec![false, false, false, true, false, true, true, false, true], + &[false, false, false, true, false, true, true, false, true], ); // s1 != 'foo' AND s2 != 'bar' @@ -2941,7 +2968,7 @@ // Can rule out any container where we know either // 1. s1 has only the value 'foo' // 2. s2 has only the value 'bar' - vec![false, false, false, false, true, true, false, true, true], + &[false, false, false, false, true, true, false, true, true], ); // s1 != 'foo' AND (s2 = 'bar' OR s2 = 'baz') @@ -2953,7 +2980,7 @@ &statistics, // Can rule out any container where we know s1 has only the value // 'foo'.
Can't use knowledge of s2 and bar to rule out anything - vec![false, true, true, false, true, true, false, true, true], + &[false, true, true, false, true, true, false, true, true], ); // s1 like '%foo%bar%' @@ -2962,7 +2989,7 @@ &schema, &statistics, // can't rule out anything with information we know - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); // s1 like '%foo%bar%' AND s2 = 'bar' @@ -2973,7 +3000,7 @@ &schema, &statistics, // can rule out any container where we know s2 does not have the value 'bar' - vec![true, true, true, false, false, false, true, true, true], + &[true, true, true, false, false, false, true, true, true], ); // s1 like '%foo%bar%' OR s2 = 'bar' @@ -2983,7 +3010,7 @@ &statistics, // can't rule out anything (we would have to prove that both the // like and the equality must be false) - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); } @@ -3055,7 +3082,7 @@ // 1. 0 is outside the min/max range of i // 2. s does not contain foo // (range is false, and contained is false) - vec![true, false, true, false, false, false, true, false, true], + &[true, false, true, false, false, false, true, false, true], ); // i = 0 and s != 'foo' @@ -3066,7 +3093,7 @@ // Can rule out containers where either: // 1. 0 is outside the min/max range of i // 2. s only contains foo - vec![false, false, false, true, false, true, true, false, true], + &[false, false, false, true, false, true, true, false, true], ); // i = 0 OR s = 'foo' @@ -3076,7 +3103,7 @@ &statistics, // in theory could rule out containers if we had min/max values for // s as well.
But in this case we don't, so we can't rule out anything - vec![true, true, true, true, true, true, true, true, true], + &[true, true, true, true, true, true, true, true, true], ); } @@ -3091,7 +3118,7 @@ mod tests { expr: Expr, schema: &SchemaRef, statistics: &TestStatistics, - expected: Vec<bool>, + expected: &[bool], ) { println!("Pruning with expr: {}", expr); let expr = logical2physical(&expr, schema); diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index abe6ab283aff..dd8eb52f67c7 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -168,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> { assert_eq!(min_limit, resulting_limit); let s = ScalarValue::try_from_array(results[0].column(1), 0)?; - let month = match s { - ScalarValue::Utf8(Some(month)) => month, - s => panic!("Expected month as Utf8 found {s:?}"), + let month = match extract_as_utf(&s) { + Some(month) => month, + s => panic!("Expected month as Dict(_, Utf8) found {s:?}"), }; let sql_on_partition_boundary = format!( @@ -191,6 +191,15 @@ async fn parquet_distinct_partition_col() -> Result<()> { Ok(()) } +fn extract_as_utf(v: &ScalarValue) -> Option<String> { + if let ScalarValue::Dictionary(_, v) = v { + if let ScalarValue::Utf8(v) = v.as_ref() { + return v.clone(); + } + } + None +} + #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 8ac0e3e5ef19..e8a3d27c089a 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -19,55 +19,6 @@ use datafusion::datasource::empty::EmptyTable; use super::*; -#[tokio::test] -async fn test_boolean_expressions() -> Result<()> { - test_expression!("true", "true"); - test_expression!("false", "false"); - test_expression!("false = false", "true"); - test_expression!("true = false", "false"); - Ok(()) -} - -#[tokio::test] -async fn test_mathematical_expressions_with_null() -> Result<()> { - test_expression!("sqrt(NULL)", "NULL"); - test_expression!("cbrt(NULL)", "NULL"); - test_expression!("sin(NULL)", "NULL"); - test_expression!("cos(NULL)", "NULL"); - test_expression!("tan(NULL)", "NULL"); - test_expression!("asin(NULL)", "NULL"); - test_expression!("acos(NULL)", "NULL"); - test_expression!("atan(NULL)", "NULL"); - test_expression!("sinh(NULL)", "NULL"); - test_expression!("cosh(NULL)", "NULL"); - test_expression!("tanh(NULL)", "NULL"); - test_expression!("asinh(NULL)", "NULL"); - test_expression!("acosh(NULL)", "NULL"); - test_expression!("atanh(NULL)", "NULL"); - test_expression!("floor(NULL)", "NULL"); - test_expression!("ceil(NULL)", "NULL"); - test_expression!("round(NULL)", "NULL"); - test_expression!("trunc(NULL)", "NULL"); - test_expression!("abs(NULL)", "NULL"); - test_expression!("signum(NULL)", "NULL"); - test_expression!("exp(NULL)", "NULL"); - test_expression!("ln(NULL)", "NULL"); - test_expression!("log2(NULL)", "NULL"); - test_expression!("log10(NULL)", "NULL"); - test_expression!("power(NULL, 2)", "NULL"); - test_expression!("power(NULL, NULL)", "NULL"); - test_expression!("power(2, NULL)", "NULL"); - test_expression!("atan2(NULL, NULL)", "NULL"); - test_expression!("atan2(1, NULL)", "NULL"); - test_expression!("atan2(NULL, 1)", "NULL"); - test_expression!("nanvl(NULL, NULL)", "NULL"); - test_expression!("nanvl(1, NULL)", "NULL"); - test_expression!("nanvl(NULL, 1)", "NULL"); - test_expression!("isnan(NULL)",
"NULL"); - test_expression!("iszero(NULL)", "NULL"); - Ok(()) -} - #[tokio::test] #[cfg_attr(not(feature = "crypto_expressions"), ignore)] async fn test_encoding_expressions() -> Result<()> { @@ -128,14 +79,6 @@ async fn test_encoding_expressions() -> Result<()> { Ok(()) } -#[should_panic(expected = "Invalid timezone \\\"Foo\\\": 'Foo' is not a valid timezone")] -#[tokio::test] -async fn test_array_cast_invalid_timezone_will_panic() { - let ctx = SessionContext::new(); - let sql = "SELECT arrow_cast('2021-01-02T03:04:00', 'Timestamp(Nanosecond, Some(\"Foo\"))')"; - execute(&ctx, sql).await; -} - #[tokio::test] #[cfg_attr(not(feature = "crypto_expressions"), ignore)] async fn test_crypto_expressions() -> Result<()> { @@ -212,242 +155,6 @@ async fn test_crypto_expressions() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_array_index() -> Result<()> { - // By default PostgreSQL uses a one-based numbering convention for arrays, that is, an array of n elements starts with array[1] and ends with array[n] - test_expression!("([5,4,3,2,1])[1]", "5"); - test_expression!("([5,4,3,2,1])[2]", "4"); - test_expression!("([5,4,3,2,1])[5]", "1"); - test_expression!("([[1, 2], [2, 3], [3,4]])[1]", "[1, 2]"); - test_expression!("([[1, 2], [2, 3], [3,4]])[3]", "[3, 4]"); - test_expression!("([[1, 2], [2, 3], [3,4]])[1][1]", "1"); - test_expression!("([[1, 2], [2, 3], [3,4]])[2][2]", "3"); - test_expression!("([[1, 2], [2, 3], [3,4]])[3][2]", "4"); - // out of bounds - test_expression!("([5,4,3,2,1])[0]", "NULL"); - test_expression!("([5,4,3,2,1])[6]", "NULL"); - // test_expression!("([5,4,3,2,1])[-1]", "NULL"); - test_expression!("([5,4,3,2,1])[100]", "NULL"); - - Ok(()) -} - -#[tokio::test] -async fn test_array_literals() -> Result<()> { - // Named, just another syntax - test_expression!("ARRAY[1,2,3,4,5]", "[1, 2, 3, 4, 5]"); - // Unnamed variant - test_expression!("[1,2,3,4,5]", "[1, 2, 3, 4, 5]"); - test_expression!("[true, false]", "[true, false]"); - test_expression!("['str1', 'str2']", "[str1, str2]"); - test_expression!("[[1,2], [3,4]]", "[[1, 2], [3, 4]]"); - - // TODO: Not supported in parser, uncomment when it will be available - // test_expression!( - // "[]", - // "[]" - // ); - - Ok(()) -} - -#[tokio::test] -async fn test_struct_literals() -> Result<()> { - test_expression!("STRUCT(1,2,3,4,5)", "{c0: 1, c1: 2, c2: 3, c3: 4, c4: 5}"); - test_expression!("STRUCT(Null)", "{c0: }"); - test_expression!("STRUCT(2)", "{c0: 2}"); - test_expression!("STRUCT('1',Null)", "{c0: 1, c1: }"); - test_expression!("STRUCT(true, false)", "{c0: true, c1: false}"); - test_expression!("STRUCT('str1', 'str2')", "{c0: str1, c1: str2}"); - - Ok(()) -} - -#[tokio::test] -async fn binary_bitwise_shift() -> Result<()> { - test_expression!("2 << 10", "2048"); - test_expression!("2048 >> 10", "2"); - test_expression!("2048 << NULL", "NULL"); - test_expression!("2048 >> NULL", "NULL"); - - Ok(()) -} - -#[tokio::test] -async fn test_interval_expressions() -> Result<()> { - // day nano intervals - test_expression!( - "interval '1'", - "0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '1 second'", - "0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '500 milliseconds'", - "0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs" - ); - test_expression!( - "interval '5 second'", - "0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs" - ); - test_expression!( - "interval '0.5 minute'", - "0 years 0 mons 0 days 0 hours 0 mins 30.000000000 
secs" - ); - // https://github.com/apache/arrow-rs/issues/4424 - // test_expression!( - // "interval '.5 minute'", - // "0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs" - // ); - test_expression!( - "interval '5 minute'", - "0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 minute 1 second'", - "0 years 0 mons 0 days 0 hours 5 mins 1.000000000 secs" - ); - test_expression!( - "interval '1 hour'", - "0 years 0 mons 0 days 1 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 hour'", - "0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 day'", - "0 years 0 mons 1 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 week'", - "0 years 0 mons 7 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '2 weeks'", - "0 years 0 mons 14 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 day 1'", - "0 years 0 mons 1 days 0 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '0.5'", - "0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs" - ); - test_expression!( - "interval '0.5 day 1'", - "0 years 0 mons 0 days 12 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '0.49 day'", - "0 years 0 mons 0 days 11 hours 45 mins 36.000000000 secs" - ); - test_expression!( - "interval '0.499 day'", - "0 years 0 mons 0 days 11 hours 58 mins 33.600000000 secs" - ); - test_expression!( - "interval '0.4999 day'", - "0 years 0 mons 0 days 11 hours 59 mins 51.360000000 secs" - ); - test_expression!( - "interval '0.49999 day'", - "0 years 0 mons 0 days 11 hours 59 mins 59.136000000 secs" - ); - test_expression!( - "interval '0.49999999999 day'", - "0 years 0 mons 0 days 11 hours 59 mins 59.999999136 secs" - ); - test_expression!( - "interval '5 day'", - "0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs" - ); - // Hour is ignored, this matches PostgreSQL - test_expression!( - "interval '5 day' hour", - "0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds'", - "0 years 0 mons 5 days 4 hours 3 mins 2.100000000 secs" - ); - // month intervals - test_expression!( - "interval '0.5 month'", - "0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '0.5' month", - "0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 month'", - "0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1' MONTH", - "0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 month'", - "0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '13 month'", - "0 years 13 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '0.5 year'", - "0 years 6 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 year'", - "0 years 12 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 decade'", - "0 years 120 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '2 decades'", - "0 years 240 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 century'", - "0 years 1200 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '2 year'", - "0 years 24 mons 0 days 0 hours 0 mins 
0.000000000 secs" - ); - test_expression!( - "interval '2' year", - "0 years 24 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - // complex - test_expression!( - "interval '1 year 1 day'", - "0 years 12 mons 1 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 year 1 day 1 hour'", - "0 years 12 mons 1 days 1 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 year 1 day 1 hour 1 minute'", - "0 years 12 mons 1 days 1 hours 1 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 year 1 day 1 hour 1 minute 1 second'", - "0 years 12 mons 1 days 1 hours 1 mins 1.000000000 secs" - ); - - Ok(()) -} - #[cfg(feature = "unicode_expressions")] #[tokio::test] async fn test_substring_expr() -> Result<()> { @@ -458,108 +165,6 @@ async fn test_substring_expr() -> Result<()> { Ok(()) } -/// Test string expressions test split into two batches -/// to prevent stack overflow error -#[tokio::test] -async fn test_string_expressions_batch1() -> Result<()> { - test_expression!("ascii('')", "0"); - test_expression!("ascii('x')", "120"); - test_expression!("ascii(NULL)", "NULL"); - test_expression!("bit_length('')", "0"); - test_expression!("bit_length('chars')", "40"); - test_expression!("bit_length('josé')", "40"); - test_expression!("bit_length(NULL)", "NULL"); - test_expression!("btrim(' xyxtrimyyx ', NULL)", "NULL"); - test_expression!("btrim(' xyxtrimyyx ')", "xyxtrimyyx"); - test_expression!("btrim('\n xyxtrimyyx \n')", "\n xyxtrimyyx \n"); - test_expression!("btrim('xyxtrimyyx', 'xyz')", "trim"); - test_expression!("btrim('\nxyxtrimyyx\n', 'xyz\n')", "trim"); - test_expression!("btrim(NULL, 'xyz')", "NULL"); - test_expression!("chr(CAST(120 AS int))", "x"); - test_expression!("chr(CAST(128175 AS int))", "💯"); - test_expression!("chr(CAST(NULL AS int))", "NULL"); - test_expression!("concat('a','b','c')", "abc"); - test_expression!("concat('abcde', 2, NULL, 22)", "abcde222"); - test_expression!("concat(NULL)", ""); - test_expression!("concat_ws(',', 'abcde', 2, NULL, 22)", "abcde,2,22"); - test_expression!("concat_ws('|','a','b','c')", "a|b|c"); - test_expression!("concat_ws('|',NULL)", ""); - test_expression!("concat_ws(NULL,'a',NULL,'b','c')", "NULL"); - test_expression!("concat_ws('|','a',NULL)", "a"); - test_expression!("concat_ws('|','a',NULL,NULL)", "a"); - test_expression!("initcap('')", ""); - test_expression!("initcap('hi THOMAS')", "Hi Thomas"); - test_expression!("initcap(NULL)", "NULL"); - test_expression!("lower('')", ""); - test_expression!("lower('TOM')", "tom"); - test_expression!("lower(NULL)", "NULL"); - test_expression!("ltrim(' zzzytest ', NULL)", "NULL"); - test_expression!("ltrim(' zzzytest ')", "zzzytest "); - test_expression!("ltrim('zzzytest', 'xyz')", "test"); - test_expression!("ltrim(NULL, 'xyz')", "NULL"); - test_expression!("octet_length('')", "0"); - test_expression!("octet_length('chars')", "5"); - test_expression!("octet_length('josé')", "5"); - test_expression!("octet_length(NULL)", "NULL"); - test_expression!("repeat('Pg', 4)", "PgPgPgPg"); - test_expression!("repeat('Pg', CAST(NULL AS INT))", "NULL"); - test_expression!("repeat(NULL, 4)", "NULL"); - test_expression!("replace('abcdefabcdef', 'cd', 'XX')", "abXXefabXXef"); - test_expression!("replace('abcdefabcdef', 'cd', NULL)", "NULL"); - test_expression!("replace('abcdefabcdef', 'notmatch', 'XX')", "abcdefabcdef"); - test_expression!("replace('abcdefabcdef', NULL, 'XX')", "NULL"); - test_expression!("replace(NULL, 'cd', 'XX')", "NULL"); - test_expression!("rtrim(' 
testxxzx ')", " testxxzx"); - test_expression!("rtrim(' zzzytest ', NULL)", "NULL"); - test_expression!("rtrim('testxxzx', 'xyz')", "test"); - test_expression!("rtrim(NULL, 'xyz')", "NULL"); - Ok(()) -} - -/// Test string expressions test split into two batches -/// to prevent stack overflow error -#[tokio::test] -async fn test_string_expressions_batch2() -> Result<()> { - test_expression!("split_part('abc~@~def~@~ghi', '~@~', 2)", "def"); - test_expression!("split_part('abc~@~def~@~ghi', '~@~', 20)", ""); - test_expression!("split_part(NULL, '~@~', 20)", "NULL"); - test_expression!("split_part('abc~@~def~@~ghi', NULL, 20)", "NULL"); - test_expression!( - "split_part('abc~@~def~@~ghi', '~@~', CAST(NULL AS INT))", - "NULL" - ); - test_expression!("starts_with('alphabet', 'alph')", "true"); - test_expression!("starts_with('alphabet', 'blph')", "false"); - test_expression!("starts_with(NULL, 'blph')", "NULL"); - test_expression!("starts_with('alphabet', NULL)", "NULL"); - test_expression!("to_hex(2147483647)", "7fffffff"); - test_expression!("to_hex(9223372036854775807)", "7fffffffffffffff"); - test_expression!("to_hex(CAST(NULL AS int))", "NULL"); - test_expression!("trim(' tom ')", "tom"); - test_expression!("trim(LEADING ' tom ')", "tom "); - test_expression!("trim(TRAILING ' tom ')", " tom"); - test_expression!("trim(BOTH ' tom ')", "tom"); - test_expression!("trim(LEADING ' ' FROM ' tom ')", "tom "); - test_expression!("trim(TRAILING ' ' FROM ' tom ')", " tom"); - test_expression!("trim(BOTH ' ' FROM ' tom ')", "tom"); - test_expression!("trim(' ' FROM ' tom ')", "tom"); - test_expression!("trim(LEADING 'x' FROM 'xxxtomxxx')", "tomxxx"); - test_expression!("trim(TRAILING 'x' FROM 'xxxtomxxx')", "xxxtom"); - test_expression!("trim(BOTH 'x' FROM 'xxxtomxx')", "tom"); - test_expression!("trim('x' FROM 'xxxtomxx')", "tom"); - test_expression!("trim(LEADING 'xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdefxyx"); - test_expression!("trim(TRAILING 'xy' FROM 'xyxabcxyzdefxyx')", "xyxabcxyzdef"); - test_expression!("trim(BOTH 'xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdef"); - test_expression!("trim('xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdef"); - test_expression!("trim(' tom')", "tom"); - test_expression!("trim('')", ""); - test_expression!("trim('tom ')", "tom"); - test_expression!("upper('')", ""); - test_expression!("upper('tom')", "TOM"); - test_expression!("upper(NULL)", "NULL"); - Ok(()) -} - #[tokio::test] #[cfg_attr(not(feature = "regex_expressions"), ignore)] async fn test_regex_expressions() -> Result<()> { @@ -593,329 +198,6 @@ async fn test_regex_expressions() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_cast_expressions() -> Result<()> { - test_expression!("CAST('0' AS INT)", "0"); - test_expression!("CAST(NULL AS INT)", "NULL"); - test_expression!("TRY_CAST('0' AS INT)", "0"); - test_expression!("TRY_CAST('x' AS INT)", "NULL"); - Ok(()) -} - -#[tokio::test] -#[ignore] -// issue: https://github.com/apache/arrow-datafusion/issues/6596 -async fn test_array_cast_expressions() -> Result<()> { - test_expression!("CAST([1,2,3,4] AS INT[])", "[1, 2, 3, 4]"); - test_expression!( - "CAST([1,2,3,4] AS NUMERIC(10,4)[])", - "[1.0000, 2.0000, 3.0000, 4.0000]" - ); - Ok(()) -} - -#[tokio::test] -async fn test_random_expression() -> Result<()> { - let ctx = SessionContext::new(); - let sql = "SELECT random() r1"; - let actual = execute(&ctx, sql).await; - let r1 = actual[0][0].parse::<f64>().unwrap(); - assert!(0.0 <= r1); - assert!(r1 < 1.0); - Ok(()) -} - -#[tokio::test] -async fn test_uuid_expression()
-> Result<()> { - let ctx = SessionContext::new(); - let sql = "SELECT uuid()"; - let actual = execute(&ctx, sql).await; - let uuid = actual[0][0].parse::<uuid::Uuid>().unwrap(); - assert_eq!(uuid.get_version_num(), 4); - Ok(()) -} - -#[tokio::test] -async fn test_extract_date_part() -> Result<()> { - test_expression!("date_part('YEAR', CAST('2000-01-01' AS DATE))", "2000.0"); - test_expression!( - "EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00')", - "2020.0" - ); - test_expression!("date_part('QUARTER', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "3.0" - ); - test_expression!("date_part('MONTH', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "9.0" - ); - test_expression!("date_part('WEEK', CAST('2003-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "37.0" - ); - test_expression!("date_part('DAY', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "8.0" - ); - test_expression!("date_part('DOY', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "252.0" - ); - test_expression!("date_part('DOW', CAST('2000-01-01' AS DATE))", "6.0"); - test_expression!( - "EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "2.0" - ); - test_expression!("date_part('HOUR', CAST('2000-01-01' AS DATE))", "0.0"); - test_expression!( - "EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00'))", - "12.0" - ); - test_expression!( - "EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00'))", - "12.0" - ); - test_expression!( - "date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00'))", - "12.0" - ); - test_expression!( - "EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12.12345678" - ); - test_expression!( - "EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12123.45678" - ); - test_expression!( - "EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12123456.78" - ); - test_expression!( - "EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", - "1.212345678e10" - ); - test_expression!( - "date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12.12345678" - ); - test_expression!( - "date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12123.45678" - ); - test_expression!( - "date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00')", - "12123456.78" - ); - test_expression!( - "date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00')", - "1.212345678e10" - ); - - // Keep precision when coercing Utf8 to Timestamp - test_expression!( - "date_part('second', '2020-09-08T12:00:12.12345678+00:00')", - "12.12345678" - ); - test_expression!( - "date_part('millisecond', '2020-09-08T12:00:12.12345678+00:00')", - "12123.45678" - ); - test_expression!( - "date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00')", - "12123456.78" - ); - test_expression!( - "date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00')", - "1.212345678e10" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_extract_epoch() -> Result<()> { - // timestamp - test_expression!( - "extract(epoch from '1870-01-01T07:29:10.256'::timestamp)", - "-3155646649.744" - ); - test_expression!(
- "extract(epoch from '2000-01-01T00:00:00.000'::timestamp)", - "946684800.0" - ); - test_expression!( - "extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00'))", - "946684800.0" - ); - test_expression!("extract(epoch from NULL::timestamp)", "NULL"); - // date - test_expression!( - "extract(epoch from arrow_cast('1970-01-01', 'Date32'))", - "0.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1970-01-02', 'Date32'))", - "86400.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1970-01-11', 'Date32'))", - "864000.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1969-12-31', 'Date32'))", - "-86400.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1970-01-01', 'Date64'))", - "0.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1970-01-02', 'Date64'))", - "86400.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1970-01-11', 'Date64'))", - "864000.0" - ); - test_expression!( - "extract(epoch from arrow_cast('1969-12-31', 'Date64'))", - "-86400.0" - ); - Ok(()) -} - -#[tokio::test] -async fn test_extract_date_part_func() -> Result<()> { - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "year" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "quarter" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "month" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "week" - ), - "true" - ); - test_expression!( - format!("(date_part('{0}', now()) = EXTRACT({0} FROM now()))", "day"), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "hour" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "minute" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "second" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "millisecond" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "microsecond" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "nanosecond" - ), - "true" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_in_list_scalar() -> Result<()> { - test_expression!("'a' IN ('a','b')", "true"); - test_expression!("'c' IN ('a','b')", "false"); - test_expression!("'c' NOT IN ('a','b')", "true"); - test_expression!("'a' NOT IN ('a','b')", "false"); - test_expression!("NULL IN ('a','b')", "NULL"); - test_expression!("NULL NOT IN ('a','b')", "NULL"); - test_expression!("'a' IN ('a','b',NULL)", "true"); - test_expression!("'c' IN ('a','b',NULL)", "NULL"); - test_expression!("'a' NOT IN ('a','b',NULL)", "false"); - test_expression!("'c' NOT IN ('a','b',NULL)", "NULL"); - test_expression!("0 IN (0,1,2)", "true"); - test_expression!("3 IN (0,1,2)", "false"); - test_expression!("3 NOT IN (0,1,2)", "true"); - test_expression!("0 NOT IN (0,1,2)", "false"); - test_expression!("NULL IN (0,1,2)", "NULL"); - test_expression!("NULL NOT IN (0,1,2)", "NULL"); - test_expression!("0 IN (0,1,2,NULL)", "true"); - test_expression!("3 IN (0,1,2,NULL)", "NULL"); - test_expression!("0 NOT IN (0,1,2,NULL)", "false"); - test_expression!("3 NOT IN (0,1,2,NULL)", "NULL"); - test_expression!("0.0 IN 
(0.0,0.1,0.2)", "true"); - test_expression!("0.3 IN (0.0,0.1,0.2)", "false"); - test_expression!("0.3 NOT IN (0.0,0.1,0.2)", "true"); - test_expression!("0.0 NOT IN (0.0,0.1,0.2)", "false"); - test_expression!("NULL IN (0.0,0.1,0.2)", "NULL"); - test_expression!("NULL NOT IN (0.0,0.1,0.2)", "NULL"); - test_expression!("0.0 IN (0.0,0.1,0.2,NULL)", "true"); - test_expression!("0.3 IN (0.0,0.1,0.2,NULL)", "NULL"); - test_expression!("0.0 NOT IN (0.0,0.1,0.2,NULL)", "false"); - test_expression!("0.3 NOT IN (0.0,0.1,0.2,NULL)", "NULL"); - test_expression!("'1' IN ('a','b',1)", "true"); - test_expression!("'2' IN ('a','b',1)", "false"); - test_expression!("'2' NOT IN ('a','b',1)", "true"); - test_expression!("'1' NOT IN ('a','b',1)", "false"); - test_expression!("NULL IN ('a','b',1)", "NULL"); - test_expression!("NULL NOT IN ('a','b',1)", "NULL"); - test_expression!("'1' IN ('a','b',NULL,1)", "true"); - test_expression!("'2' IN ('a','b',NULL,1)", "NULL"); - test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); - test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL"); - Ok(()) -} - #[tokio::test] async fn csv_query_nullif_divide_by_0() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 3040fbafe81a..54eab4315a97 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -471,6 +471,7 @@ impl OddCounter { } fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) { + #[derive(Debug, Clone)] struct SimpleWindowUDF { signature: Signature, return_type: DataType, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ebf4d3143c12..5617d217eb9f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1948,6 +1948,7 @@ mod test { ); // UDF + #[derive(Debug)] struct TestScalarUDF { signature: Signature, } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index f76fb17b38bb..0491750d18a9 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -32,6 +32,7 @@ use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; use std::any::Any; +use std::fmt::Debug; use std::ops::Not; use std::sync::Arc; @@ -983,6 +984,16 @@ pub struct SimpleScalarUDF { fun: ScalarFunctionImplementation, } +impl Debug for SimpleScalarUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"<FUNC>") + .finish() + } +} + impl SimpleScalarUDF { /// Create a new `SimpleScalarUDF` from a name, input types, return type and /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility @@ -1078,6 +1089,17 @@ pub struct SimpleWindowUDF { partition_evaluator_factory: PartitionEvaluatorFactory, } +impl Debug for SimpleWindowUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("WindowUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("return_type", &"<func>") + .field("partition_evaluator_factory", &"<FUNC>") + .finish() + } +} + impl SimpleWindowUDF { /// Create a new `SimpleWindowUDF` from a name, input types, return type and /// implementation.
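The two `Debug` impls added above follow one pattern: `ScalarFunctionImplementation` and `PartitionEvaluatorFactory` are closures with no `Debug` impl of their own, so the struct prints a fixed placeholder for those fields. A self-contained sketch of the pattern (the struct, type, and field names here are illustrative, not the DataFusion types):

```rust
use std::fmt;
use std::sync::Arc;

// Stand-in for a non-Debug closure field such as ScalarFunctionImplementation
type UdfImpl = Arc<dyn Fn(i64) -> i64 + Send + Sync>;

struct SimpleUdf {
    name: String,
    fun: UdfImpl,
}

impl fmt::Debug for SimpleUdf {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SimpleUdf")
            .field("name", &self.name)
            // the closure cannot be printed, so emit a placeholder instead
            .field("fun", &"<FUNC>")
            .finish()
    }
}

fn main() {
    let udf = SimpleUdf { name: "add_one".into(), fun: Arc::new(|x| x + 1) };
    println!("{udf:?}"); // SimpleUdf { name: "add_one", fun: "<FUNC>" }
}
```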
Implementing [`WindowUDFImpl`] allows more flexibility diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index a684f3e97485..847fbbbf61c7 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1845,13 +1845,16 @@ mod tests { .project(vec![col("id"), col("first_name").alias("id")]); match plan { - Err(DataFusionError::SchemaError(SchemaError::AmbiguousReference { - field: - Column { - relation: Some(OwnedTableReference::Bare { table }), - name, - }, - })) => { + Err(DataFusionError::SchemaError( + SchemaError::AmbiguousReference { + field: + Column { + relation: Some(OwnedTableReference::Bare { table }), + name, + }, + }, + _, + )) => { assert_eq!("employee_csv", table); assert_eq!("id", &name); Ok(()) @@ -1872,13 +1875,16 @@ mod tests { .aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]); match plan { - Err(DataFusionError::SchemaError(SchemaError::AmbiguousReference { - field: - Column { - relation: Some(OwnedTableReference::Bare { table }), - name, - }, - })) => { + Err(DataFusionError::SchemaError( + SchemaError::AmbiguousReference { + field: + Column { + relation: Some(OwnedTableReference::Bare { table }), + name, + }, + }, + _, + )) => { assert_eq!("employee_csv", table); assert_eq!("state", &name); Ok(()) diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 1b62c1bc05c1..6bacc1870079 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -667,8 +667,6 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { string_concat_internal_coercion(from_type, &LargeUtf8) } - // TODO: cast between array elements (#6558) - (List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()), _ => None, }) } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 2ec80a4a9ea1..8b35d5834c61 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -35,48 +35,26 @@ use std::sync::Arc; /// functions you supply such name, type signature, return type, and actual /// implementation. /// -/// /// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`]. /// /// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`]. /// +/// # API Note +/// +/// This is a separate struct from `ScalarUDFImpl` to maintain backwards +/// compatibility with the older API. +/// /// [`create_udf`]: crate::expr_fn::create_udf /// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs /// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct ScalarUDF { - /// The name of the function - name: String, - /// The signature (the types of arguments that are supported) - signature: Signature, - /// Function that returns the return type given the argument types - return_type: ReturnTypeFunction, - /// actual implementation - /// - /// The fn param is the wrapped function but be aware that the function will - /// be passed with the slice / vec of columnar values (either scalar or array) - /// with the exception of zero param function, where a singular element vec - /// will be passed.
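Stepping back to the `builder.rs` hunks above: `DataFusionError::SchemaError` is now matched as a two-field tuple variant, so every pattern gains a trailing `_` for the extra payload the `_schema_err!` macro captures. A hedged sketch of the new match shape (assuming the usual `Column` inside `AmbiguousReference`; the helper name is hypothetical):

```rust
use datafusion_common::{Column, DataFusionError, SchemaError};

// Sketch: pull the ambiguous column out of an error, ignoring the variant's
// second tuple field (the payload added alongside the SchemaError) with `_`.
fn ambiguous_column(err: &DataFusionError) -> Option<&Column> {
    match err {
        DataFusionError::SchemaError(SchemaError::AmbiguousReference { field }, _) => {
            Some(field)
        }
        _ => None,
    }
}
```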
In that case the single element is a null array to indicate - /// the batch's row count (so that the generative zero-argument function can know - /// the result array size). - fun: ScalarFunctionImplementation, - /// Optional aliases for the function. This list should NOT include the value of `name` as well - aliases: Vec<String>, -} - -impl Debug for ScalarUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"<FUNC>") - .finish() - } + inner: Arc<dyn ScalarUDFImpl>, } impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature + self.name() == other.name() && self.signature() == other.signature() } } @@ -84,8 +62,8 @@ impl Eq for ScalarUDF {} impl std::hash::Hash for ScalarUDF { fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); + self.name().hash(state); + self.signature().hash(state); } } @@ -101,13 +79,12 @@ impl ScalarUDF { return_type: &ReturnTypeFunction, fun: &ScalarFunctionImplementation, ) -> Self { - Self { + Self::new_from_impl(ScalarUdfLegacyWrapper { name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), fun: fun.clone(), - aliases: vec![], - } + }) } /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object /// /// Note this is the same as using the `From` impl (`ScalarUDF::from`) pub fn new_from_impl<F>(fun: F) -> ScalarUDF where - F: ScalarUDFImpl + Send + Sync + 'static, + F: ScalarUDFImpl + 'static, { - // TODO change the internal implementation to use the trait object - let arc_fun = Arc::new(fun); - let captured_self = arc_fun.clone(); - let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { - let return_type = captured_self.return_type(arg_types)?; - Ok(Arc::new(return_type)) - }); - - let captured_self = arc_fun.clone(); - let func: ScalarFunctionImplementation = - Arc::new(move |args| captured_self.invoke(args)); - Self { - name: arc_fun.name().to_string(), - signature: arc_fun.signature().clone(), - return_type: return_type.clone(), - fun: func, - aliases: arc_fun.aliases().to_vec(), + inner: Arc::new(fun), } } - /// Adds additional names that can be used to invoke this function, in addition to `name` - pub fn with_aliases( - mut self, - aliases: impl IntoIterator<Item = &'static str>, - ) -> Self { - self.aliases - .extend(aliases.into_iter().map(|s| s.to_string())); - self + /// Return the underlying [`ScalarUDFImpl`] trait object for this function + pub fn inner(&self) -> Arc<dyn ScalarUDFImpl> { + self.inner.clone() + } + + /// Adds additional names that can be used to invoke this function, in + /// addition to `name` + /// + /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly. + pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self { + Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases)) } /// Returns a [`Expr`] logical expression to call this UDF with specified /// arguments @@ -159,31 +123,46 @@ impl ScalarUDF { )) } - /// Returns this function's name + /// Returns this function's name. + /// + /// See [`ScalarUDFImpl::name`] for more details. pub fn name(&self) -> &str { - &self.name + self.inner.name() } - /// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details + /// Returns the aliases for this function.
+ /// + /// See [`ScalarUDF::with_aliases`] for more details pub fn aliases(&self) -> &[String] { - &self.aliases + self.inner.aliases() } - /// Returns this function's [`Signature`] (what input types are accepted) + /// Returns this function's [`Signature`] (what input types are accepted). + /// + /// See [`ScalarUDFImpl::signature`] for more details. pub fn signature(&self) -> &Signature { - &self.signature + self.inner.signature() } - /// The datatype this function returns given the input argument input types + /// The datatype this function returns given the input argument input types. + /// + /// See [`ScalarUDFImpl::return_type`] for more details. pub fn return_type(&self, args: &[DataType]) -> Result<DataType> { - // Old API returns an Arc of the datatype for some reason - let res = (self.return_type)(args)?; - Ok(res.as_ref().clone()) + self.inner.return_type(args) + } + + /// Invoke the function on `args`, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke`] for more details. + pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + self.inner.invoke(args) } - /// Return an [`Arc`] to the function implementation + /// Returns a `ScalarFunctionImplementation` that can invoke the function + /// during execution pub fn fun(&self) -> ScalarFunctionImplementation { - self.fun.clone() + let captured = self.inner.clone(); + Arc::new(move |args| captured.invoke(args)) } } @@ -213,6 +192,7 @@ where /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; /// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +/// #[derive(Debug)] /// struct AddOne { /// signature: Signature /// }; @@ -246,7 +226,7 @@ where /// // Call the function `add_one(col)` /// let expr = add_one.call(vec![col("a")]); /// ``` -pub trait ScalarUDFImpl { +pub trait ScalarUDFImpl: Debug + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -292,3 +272,106 @@ pub trait ScalarUDFImpl { &[] } } + +/// ScalarUDF that adds an alias to the underlying function. It is better to +/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
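Since `with_aliases` now returns a rebuilt `ScalarUDF` around the alias wrapper defined next, rather than mutating a `Vec<String>` field in place, call sites keep working unchanged. A small usage sketch (`my_udf` stands for any existing `ScalarUDF`; the alias strings are hypothetical):

```rust
use datafusion_expr::ScalarUDF;

// Sketch: wrap an existing ScalarUDF with extra invocation names. Internally
// this builds an AliasedScalarUDFImpl (below) around the original function,
// so the aliases come back out through the same trait-object path.
fn alias_it(my_udf: ScalarUDF) -> ScalarUDF {
    let aliased = my_udf.with_aliases(["my_func_alias", "legacy_name"]);
    assert!(aliased.aliases().contains(&"my_func_alias".to_string()));
    aliased
}
```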
+#[derive(Debug)] +struct AliasedScalarUDFImpl { + inner: ScalarUDF, + aliases: Vec<String>, +} + +impl AliasedScalarUDFImpl { + pub fn new( + inner: ScalarUDF, + new_aliases: impl IntoIterator<Item = &'static str>, + ) -> Self { + let mut aliases = inner.aliases().to_vec(); + aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); + + Self { inner, aliases } + } +} + +impl ScalarUDFImpl for AliasedScalarUDFImpl { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + self.inner.name() + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + self.inner.return_type(arg_types) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + self.inner.invoke(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers +/// of the older API (see +/// for more details) +struct ScalarUdfLegacyWrapper { + /// The name of the function + name: String, + /// The signature (the types of arguments that are supported) + signature: Signature, + /// Function that returns the return type given the argument types + return_type: ReturnTypeFunction, + /// actual implementation + /// + /// The fn param is the wrapped function but be aware that the function will + /// be passed with the slice / vec of columnar values (either scalar or array) + /// with the exception of zero param function, where a singular element vec + /// will be passed. In that case the single element is a null array to indicate + /// the batch's row count (so that the generative zero-argument function can know + /// the result array size). + fun: ScalarFunctionImplementation, +} + +impl Debug for ScalarUdfLegacyWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"<FUNC>") + .finish() + } +} + +impl ScalarUDFImpl for ScalarUdfLegacyWrapper { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(arg_types)?; + Ok(res.as_ref().clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + (self.fun)(args) + } + + fn aliases(&self) -> &[String] { + &[] + } +} diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 800386bfc77b..239a5e24cbf2 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -34,40 +34,33 @@ use std::{ /// /// See the documentation on [`PartitionEvaluator`] for more details /// +/// 1. For simple (less performant) use cases, use [`create_udwf`] and [`simple_udwf.rs`]. +/// +/// 2. For advanced use cases, use [`WindowUDFImpl`] and [`advanced_udwf.rs`]. +/// +/// # API Note +/// This is a separate struct from `WindowUDFImpl` to maintain backwards +/// compatibility with the older API.
+/// /// [`PartitionEvaluator`]: crate::PartitionEvaluator -#[derive(Clone)] +/// [`create_udwf`]: crate::expr_fn::create_udwf +/// [`simple_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs +/// [`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs +#[derive(Debug, Clone)] pub struct WindowUDF { - /// name - name: String, - /// signature - signature: Signature, - /// Return type - return_type: ReturnTypeFunction, - /// Return the partition evaluator - partition_evaluator_factory: PartitionEvaluatorFactory, -} - -impl Debug for WindowUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("WindowUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("return_type", &"<func>") - .field("partition_evaluator_factory", &"<FUNC>") - .finish_non_exhaustive() - } + inner: Arc<dyn WindowUDFImpl>, } /// Defines how the WindowUDF is shown to users impl Display for WindowUDF { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "{}", self.name) + write!(f, "{}", self.name()) } } impl PartialEq for WindowUDF { fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature + self.name() == other.name() && self.signature() == other.signature() } } @@ -75,8 +68,8 @@ impl Eq for WindowUDF {} impl std::hash::Hash for WindowUDF { fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); + self.name().hash(state); + self.signature().hash(state); } } @@ -92,12 +85,12 @@ impl WindowUDF { return_type: &ReturnTypeFunction, partition_evaluator_factory: &PartitionEvaluatorFactory, ) -> Self { - Self { - name: name.to_string(), + Self::new_from_impl(WindowUDFLegacyWrapper { + name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), partition_evaluator_factory: partition_evaluator_factory.clone(), - } + }) } /// Create a new `WindowUDF` from a `[WindowUDFImpl]` trait object /// /// Note this is the same as using the `From` impl (`WindowUDF::from`) pub fn new_from_impl<F>(fun: F) -> WindowUDF where - F: WindowUDFImpl + Send + Sync + 'static, + F: WindowUDFImpl + 'static, { - let arc_fun = Arc::new(fun); - let captured_self = arc_fun.clone(); - let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { - let return_type = captured_self.return_type(arg_types)?; - Ok(Arc::new(return_type)) - }); - - let captured_self = arc_fun.clone(); - let partition_evaluator_factory: PartitionEvaluatorFactory = - Arc::new(move || captured_self.partition_evaluator()); - Self { - name: arc_fun.name().to_string(), - signature: arc_fun.signature().clone(), - return_type: return_type.clone(), - partition_evaluator_factory, + inner: Arc::new(fun), } } + /// Return the underlying [`WindowUDFImpl`] trait object for this function + pub fn inner(&self) -> Arc<dyn WindowUDFImpl> { + self.inner.clone() + } + /// creates a [`Expr`] that calls the window function given /// the `partition_by`, `order_by`, and `window_frame` definition /// @@ -150,25 +134,29 @@ impl WindowUDF { } /// Returns this function's name + /// + /// See [`WindowUDFImpl::name`] for more details. pub fn name(&self) -> &str { - &self.name + self.inner.name() } /// Returns this function's signature (what input types are accepted) + /// + /// See [`WindowUDFImpl::signature`] for more details.
pub fn signature(&self) -> &Signature { - &self.signature + self.inner.signature() } /// Return the type of the function given its input types + /// + /// See [`WindowUDFImpl::return_type`] for more details. pub fn return_type(&self, args: &[DataType]) -> Result<DataType> { - // Old API returns an Arc of the datatype for some reason - let res = (self.return_type)(args)?; - Ok(res.as_ref().clone()) + self.inner.return_type(args) } /// Return a `PartitionEvaluator` for evaluating this window function pub fn partition_evaluator_factory(&self) -> Result<Box<dyn PartitionEvaluator>> { - (self.partition_evaluator_factory)() + self.inner.partition_evaluator() } } @@ -198,6 +186,7 @@ where /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; +/// #[derive(Debug, Clone)] /// struct SmoothIt { /// signature: Signature /// }; @@ -236,7 +225,7 @@ where /// WindowFrame::new(false), /// ); /// ``` -pub trait WindowUDFImpl { +pub trait WindowUDFImpl: Debug + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -254,3 +243,52 @@ pub trait WindowUDFImpl { /// Invoke the function, returning the [`PartitionEvaluator`] instance fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>>; } + +/// Implementation of [`WindowUDFImpl`] that wraps the function style pointers +/// of the older API (see +/// for more details) +pub struct WindowUDFLegacyWrapper { + /// name + name: String, + /// signature + signature: Signature, + /// Return type + return_type: ReturnTypeFunction, + /// Return the partition evaluator + partition_evaluator_factory: PartitionEvaluatorFactory, +} + +impl Debug for WindowUDFLegacyWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("WindowUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("return_type", &"<func>") + .field("partition_evaluator_factory", &"<FUNC>") + .finish_non_exhaustive() + } +} + +impl WindowUDFImpl for WindowUDFLegacyWrapper { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(arg_types)?; + Ok(res.as_ref().clone()) + } + + fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> { + (self.partition_evaluator_factory)() + } +} diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 14d5ddf47378..9d47299a5616 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -17,6 +17,7 @@ pub mod count_wildcard_rule; pub mod inline_table_scan; +pub mod rewrite_expr; pub mod subquery; pub mod type_coercion; @@ -37,6 +38,8 @@ use log::debug; use std::sync::Arc; use std::time::Instant; +use self::rewrite_expr::OperatorToFunction; + /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make /// the plan valid prior to the rest of the DataFusion optimization process. /// @@ -72,6 +75,9 @@ impl Analyzer { pub fn new() -> Self { let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> = vec![ Arc::new(InlineTableScan::new()), + // OperatorToFunction should be run before TypeCoercion, since it rewrites based on the argument types (List or Scalar), + // and TypeCoercion may cast the argument types from Scalar to List.
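To make that ordering constraint concrete: the new rule picks a replacement function from the list dimensionality of each operand, so it must see the pre-coercion types. A toy model of the dispatch, mirroring the column-to-column arm of `rewrite_array_concat_operator_to_func_for_column` in the new file below (a sketch, not the real implementation):

```rust
// Toy dispatch on list dimensionality (0 = scalar, >= 1 = list). The first
// matching arm wins, as in the real rule. If TypeCoercion ran first and cast
// the scalar side to a list, both sides would report >= 1 and everything
// would collapse into array_concat.
fn pick_function(left_ndims: usize, right_ndims: usize) -> &'static str {
    match (left_ndims, right_ndims) {
        (0, _) => "array_prepend", // scalar || array
        (_, 0) => "array_append",  // array || scalar
        _ => "array_concat",       // array || array
    }
}

fn main() {
    assert_eq!(pick_function(0, 1), "array_prepend"); // 1 || [2, 3]
    assert_eq!(pick_function(1, 0), "array_append");  // [1, 2] || 3
    assert_eq!(pick_function(1, 1), "array_concat");  // [1, 2] || [3, 4]
}
```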
+ Arc::new(OperatorToFunction::new()), Arc::new(TypeCoercion::new()), Arc::new(CountWildcardRule::new()), ]; diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs new file mode 100644 index 000000000000..8f1c844ed062 --- /dev/null +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Analyzer rule to replace operators with function calls (e.g. `||` to `array_concat`) + +use std::sync::Arc; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::utils::list_ndims; +use datafusion_common::DFSchema; +use datafusion_common::DFSchemaRef; +use datafusion_common::Result; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr_rewriter::rewrite_preserving_name; +use datafusion_expr::utils::merge_schema; +use datafusion_expr::BuiltinScalarFunction; +use datafusion_expr::Operator; +use datafusion_expr::ScalarFunctionDefinition; +use datafusion_expr::{BinaryExpr, Expr, LogicalPlan}; + +use super::AnalyzerRule; + +#[derive(Default)] +pub struct OperatorToFunction {} + +impl OperatorToFunction { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for OperatorToFunction { + fn name(&self) -> &str { + "operator_to_function" + } + + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> { + analyze_internal(&plan) + } +} + +fn analyze_internal(plan: &LogicalPlan) -> Result<LogicalPlan> { + // optimize child plans first + let new_inputs = plan + .inputs() + .iter() + .map(|p| analyze_internal(p)) + .collect::<Result<Vec<_>>>()?; + + // get schema representing all available input fields.
This is used for data type + // resolution only, so order does not matter here + let mut schema = merge_schema(new_inputs.iter().collect()); + + if let LogicalPlan::TableScan(ts) = plan { + let source_schema = + DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?; + schema.merge(&source_schema); + } + + let mut expr_rewrite = OperatorToFunctionRewriter { + schema: Arc::new(schema), + }; + + let new_expr = plan + .expressions() + .into_iter() + .map(|expr| { + // ensure names don't change: + // https://github.com/apache/arrow-datafusion/issues/3555 + rewrite_preserving_name(expr, &mut expr_rewrite) + }) + .collect::<Result<Vec<_>>>()?; + + plan.with_new_exprs(new_expr, &new_inputs) +} + +pub(crate) struct OperatorToFunctionRewriter { + pub(crate) schema: DFSchemaRef, +} + +impl TreeNodeRewriter for OperatorToFunctionRewriter { + type N = Expr; + + fn mutate(&mut self, expr: Expr) -> Result<Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { + ref left, + op, + ref right, + }) => { + if let Some(fun) = rewrite_array_concat_operator_to_func_for_column( + left.as_ref(), + op, + right.as_ref(), + self.schema.as_ref(), + )? + .or_else(|| { + rewrite_array_concat_operator_to_func( + left.as_ref(), + op, + right.as_ref(), + ) + }) { + // Convert &Box<Expr> -> Expr + let left = (**left).clone(); + let right = (**right).clone(); + return Ok(Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args: vec![left, right], + })); + } + + Ok(expr) + } + _ => Ok(expr), + } + } +} + +/// Summary of the logic below: +/// +/// 1) array || array -> array concat +/// +/// 2) array || scalar -> array append +/// +/// 3) scalar || array -> array prepend +/// +/// 4) (array concat, array append, array prepend) || array -> array concat +/// +/// 5) (array concat, array append, array prepend) || scalar -> array append +fn rewrite_array_concat_operator_to_func( + left: &Expr, + op: Operator, + right: &Expr, +) -> Option<BuiltinScalarFunction> { + // Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat + + if op != Operator::StringConcat { + return None; + } + + match (left, right) { + // Chain concat operator (a || b) || array, + // (array concat, array append, array prepend) || array -> array concat + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayConcat), + // Chain concat operator (a || b) || scalar, + // (array concat, array append, array prepend) || scalar -> array append + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + _scalar, + ) + | ( + Expr::ScalarFunction(ScalarFunction {
func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + _scalar, + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + _scalar, + ) => Some(BuiltinScalarFunction::ArrayAppend), + // array || array -> array concat + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayConcat), + // array || scalar -> array append + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _left_args, + }), + _right_scalar, + ) => Some(BuiltinScalarFunction::ArrayAppend), + // scalar || array -> array prepend + ( + _left_scalar, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayPrepend), + + _ => None, + } +} + +/// Summary of the logic below: + +/// +/// 1) (array concat, array append, array prepend) || column -> (array append, array concat) +/// +/// 2) column1 || column2 -> (array prepend, array append, array concat) +fn rewrite_array_concat_operator_to_func_for_column( + left: &Expr, + op: Operator, + right: &Expr, + schema: &DFSchema, +) -> Result<Option<BuiltinScalarFunction>> { + if op != Operator::StringConcat { + return Ok(None); + } + + match (left, right) { + // Column cases: + // 1) array_prepend/append/concat || column + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + Expr::Column(c), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + Expr::Column(c), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + Expr::Column(c), + ) => { + let d = schema.field_from_column(c)?.data_type(); + let ndim = list_ndims(d); + match ndim { + 0 => Ok(Some(BuiltinScalarFunction::ArrayAppend)), + _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)), + } + } + // 2) select column1 || column2 + (Expr::Column(c1), Expr::Column(c2)) => { + let d1 = schema.field_from_column(c1)?.data_type(); + let d2 = schema.field_from_column(c2)?.data_type(); + let ndim1 = list_ndims(d1); + let ndim2 = list_ndims(d2); + match (ndim1, ndim2) { + (0, _) => Ok(Some(BuiltinScalarFunction::ArrayPrepend)), + (_, 0) => Ok(Some(BuiltinScalarFunction::ArrayAppend)), + _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)), + } + } + _ => Ok(None), + } +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4d54dad99670..6f1da5f4e6d9 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -811,6 +811,7 @@ mod test { static TEST_SIGNATURE: OnceLock<Signature> = OnceLock::new(); + #[derive(Debug, Clone, Default)] struct TestScalarUDF {} impl ScalarUDFImpl for TestScalarUDF { fn as_any(&self) -> &dyn Any { diff --git a/datafusion/optimizer/src/optimize_projections.rs
b/datafusion/optimizer/src/optimize_projections.rs
index 891a909a3378..1d4eda0bd23e 100644
--- a/datafusion/optimizer/src/optimize_projections.rs
+++ b/datafusion/optimizer/src/optimize_projections.rs
@@ -583,11 +583,11 @@ fn rewrite_expr(expr: &Expr, input: &Projection) -> Result<Option<Expr>> {
 ///
 /// # Returns
 ///
-/// If the function can safely infer all outer-referenced columns, returns a
-/// `Some(HashSet<Column>)` containing these columns. Otherwise, returns `None`.
-fn outer_columns(expr: &Expr) -> Option<HashSet<Column>> {
+/// Returns a `HashSet<Column>` containing all outer-referenced columns.
+fn outer_columns(expr: &Expr) -> HashSet<Column> {
     let mut columns = HashSet::new();
-    outer_columns_helper(expr, &mut columns).then_some(columns)
+    outer_columns_helper(expr, &mut columns);
+    columns
 }
 
 /// A recursive subroutine that accumulates outer-referenced columns by the
@@ -598,87 +598,104 @@ fn outer_columns(expr: &Expr) -> Option<HashSet<Column>> {
 /// * `expr` - The expression to analyze for outer-referenced columns.
 /// * `columns` - A mutable reference to a `HashSet<Column>` where detected
 ///   columns are collected.
-///
-/// Returns `true` if it can safely collect all outer-referenced columns.
-/// Otherwise, returns `false`.
-fn outer_columns_helper(expr: &Expr, columns: &mut HashSet<Column>) -> bool {
+fn outer_columns_helper(expr: &Expr, columns: &mut HashSet<Column>) {
     match expr {
         Expr::OuterReferenceColumn(_, col) => {
             columns.insert(col.clone());
-            true
         }
         Expr::BinaryExpr(binary_expr) => {
-            outer_columns_helper(&binary_expr.left, columns)
-                && outer_columns_helper(&binary_expr.right, columns)
+            outer_columns_helper(&binary_expr.left, columns);
+            outer_columns_helper(&binary_expr.right, columns);
         }
         Expr::ScalarSubquery(subquery) => {
             let exprs = subquery.outer_ref_columns.iter();
-            outer_columns_helper_multi(exprs, columns)
+            outer_columns_helper_multi(exprs, columns);
         }
         Expr::Exists(exists) => {
            let exprs = exists.subquery.outer_ref_columns.iter();
-            outer_columns_helper_multi(exprs, columns)
+            outer_columns_helper_multi(exprs, columns);
         }
         Expr::Alias(alias) => outer_columns_helper(&alias.expr, columns),
         Expr::InSubquery(insubquery) => {
             let exprs = insubquery.subquery.outer_ref_columns.iter();
-            outer_columns_helper_multi(exprs, columns)
+            outer_columns_helper_multi(exprs, columns);
         }
-        Expr::IsNotNull(expr) | Expr::IsNull(expr) => outer_columns_helper(expr, columns),
         Expr::Cast(cast) => outer_columns_helper(&cast.expr, columns),
         Expr::Sort(sort) => outer_columns_helper(&sort.expr, columns),
         Expr::AggregateFunction(aggregate_fn) => {
-            outer_columns_helper_multi(aggregate_fn.args.iter(), columns)
-                && aggregate_fn
-                    .order_by
-                    .as_ref()
-                    .map_or(true, |obs| outer_columns_helper_multi(obs.iter(), columns))
-                && aggregate_fn
-                    .filter
-                    .as_ref()
-                    .map_or(true, |filter| outer_columns_helper(filter, columns))
+            outer_columns_helper_multi(aggregate_fn.args.iter(), columns);
+            if let Some(filter) = aggregate_fn.filter.as_ref() {
+                outer_columns_helper(filter, columns);
+            }
+            if let Some(obs) = aggregate_fn.order_by.as_ref() {
+                outer_columns_helper_multi(obs.iter(), columns);
+            }
         }
         Expr::WindowFunction(window_fn) => {
-            outer_columns_helper_multi(window_fn.args.iter(), columns)
-                && outer_columns_helper_multi(window_fn.order_by.iter(), columns)
-                && outer_columns_helper_multi(window_fn.partition_by.iter(), columns)
+            outer_columns_helper_multi(window_fn.args.iter(), columns);
+            outer_columns_helper_multi(window_fn.order_by.iter(), columns);
+            outer_columns_helper_multi(window_fn.partition_by.iter(), columns);
         }
         Expr::GroupingSet(groupingset) => match groupingset {
-            GroupingSet::GroupingSets(multi_exprs) => multi_exprs
-                .iter()
-                .all(|e| outer_columns_helper_multi(e.iter(), columns)),
+            GroupingSet::GroupingSets(multi_exprs) => {
+                multi_exprs
+                    .iter()
+                    .for_each(|e| outer_columns_helper_multi(e.iter(), columns));
+            }
             GroupingSet::Cube(exprs) | GroupingSet::Rollup(exprs) => {
-                outer_columns_helper_multi(exprs.iter(), columns)
+                outer_columns_helper_multi(exprs.iter(), columns);
             }
         },
         Expr::ScalarFunction(scalar_fn) => {
-            outer_columns_helper_multi(scalar_fn.args.iter(), columns)
+            outer_columns_helper_multi(scalar_fn.args.iter(), columns);
         }
         Expr::Like(like) => {
-            outer_columns_helper(&like.expr, columns)
-                && outer_columns_helper(&like.pattern, columns)
+            outer_columns_helper(&like.expr, columns);
+            outer_columns_helper(&like.pattern, columns);
         }
         Expr::InList(in_list) => {
-            outer_columns_helper(&in_list.expr, columns)
-                && outer_columns_helper_multi(in_list.list.iter(), columns)
+            outer_columns_helper(&in_list.expr, columns);
+            outer_columns_helper_multi(in_list.list.iter(), columns);
         }
         Expr::Case(case) => {
             let when_then_exprs = case
                 .when_then_expr
                 .iter()
                 .flat_map(|(first, second)| [first.as_ref(), second.as_ref()]);
-            outer_columns_helper_multi(when_then_exprs, columns)
-                && case
-                    .expr
-                    .as_ref()
-                    .map_or(true, |expr| outer_columns_helper(expr, columns))
-                && case
-                    .else_expr
-                    .as_ref()
-                    .map_or(true, |expr| outer_columns_helper(expr, columns))
+            outer_columns_helper_multi(when_then_exprs, columns);
+            if let Some(expr) = case.expr.as_ref() {
+                outer_columns_helper(expr, columns);
+            }
+            if let Some(expr) = case.else_expr.as_ref() {
+                outer_columns_helper(expr, columns);
+            }
+        }
+        Expr::SimilarTo(similar_to) => {
+            outer_columns_helper(&similar_to.expr, columns);
+            outer_columns_helper(&similar_to.pattern, columns);
+        }
+        Expr::TryCast(try_cast) => outer_columns_helper(&try_cast.expr, columns),
+        Expr::GetIndexedField(index) => outer_columns_helper(&index.expr, columns),
+        Expr::Between(between) => {
+            outer_columns_helper(&between.expr, columns);
+            outer_columns_helper(&between.low, columns);
+            outer_columns_helper(&between.high, columns);
         }
-        Expr::Column(_) | Expr::Literal(_) | Expr::Wildcard { .. } => true,
-        _ => false,
+        Expr::Not(expr)
+        | Expr::IsNotFalse(expr)
+        | Expr::IsFalse(expr)
+        | Expr::IsTrue(expr)
+        | Expr::IsNotTrue(expr)
+        | Expr::IsUnknown(expr)
+        | Expr::IsNotUnknown(expr)
+        | Expr::IsNotNull(expr)
+        | Expr::IsNull(expr)
+        | Expr::Negative(expr) => outer_columns_helper(expr, columns),
+        Expr::Column(_)
+        | Expr::Literal(_)
+        | Expr::Wildcard { .. }
+        | Expr::ScalarVariable { .. }
+        | Expr::Placeholder(_) => (),
     }
 }
@@ -690,14 +707,11 @@ fn outer_columns_helper(expr: &Expr, columns: &mut HashSet<Column>) {
 /// * `exprs` - The expressions to analyze for outer-referenced columns.
 /// * `columns` - A mutable reference to a `HashSet<Column>` where detected
 ///   columns are collected.
-///
-/// Returns `true` if it can safely collect all outer-referenced columns.
-/// Otherwise, returns `false`.
 fn outer_columns_helper_multi<'a>(
-    mut exprs: impl Iterator<Item = &'a Expr>,
+    exprs: impl Iterator<Item = &'a Expr>,
     columns: &mut HashSet<Column>,
-) -> bool {
-    exprs.all(|e| outer_columns_helper(e, columns))
+) {
+    exprs.for_each(|e| outer_columns_helper(e, columns));
 }
 
 /// Generates the required expressions (columns) that reside at `indices` of
@@ -766,13 +780,7 @@ fn indices_referred_by_expr(
 ) -> Result<Vec<usize>> {
     let mut cols = expr.to_columns()?;
     // Get outer-referenced columns:
-    if let Some(outer_cols) = outer_columns(expr) {
-        cols.extend(outer_cols);
-    } else {
-        // Expression is not known to contain outer columns or not. Hence, do
-        // not assume anything and require all the schema indices at the input:
-        return Ok((0..input_schema.fields().len()).collect());
-    }
+    cols.extend(outer_columns(expr));
     Ok(cols
         .iter()
         .flat_map(|col| input_schema.index_of_column(col))
@@ -978,8 +986,8 @@ mod tests {
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::{Result, TableReference};
     use datafusion_expr::{
-        binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder,
-        table_scan, Expr, LogicalPlan, Operator,
+        binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, not,
+        table_scan, try_cast, Expr, Like, LogicalPlan, Operator,
     };
 
     fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
@@ -1060,4 +1068,187 @@ mod tests {
         \n  TableScan: ?table? projection=[]";
         assert_optimized_plan_equal(&plan, expected)
     }
+
+    #[test]
+    fn test_struct_field_push_down() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, false),
+            Field::new_struct(
+                "s",
+                vec![
+                    Field::new("x", DataType::Int64, false),
+                    Field::new("y", DataType::Int64, false),
+                ],
+                false,
+            ),
+        ]));
+
+        let table_scan = table_scan(TableReference::none(), &schema, None)?.build()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("s").field("x")])?
+            .build()?;
+        let expected = "Projection: (?table?.s)[x]\
+        \n  TableScan: ?table? projection=[s]";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
+    #[test]
+    fn test_neg_push_down() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![-col("a")])?
+            .build()?;
+
+        let expected = "Projection: (- test.a)\
+        \n  TableScan: test projection=[a]";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
+    #[test]
+    fn test_is_null() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a").is_null()])?
+            .build()?;
+
+        let expected = "Projection: test.a IS NULL\
+        \n  TableScan: test projection=[a]";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
+    #[test]
+    fn test_is_not_null() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a").is_not_null()])?
+            .build()?;
+
+        let expected = "Projection: test.a IS NOT NULL\
+        \n  TableScan: test projection=[a]";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
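A quick aside before the remaining tests: the `outer_columns_helper` rewrite above trades a fallible traversal (`-> bool`, giving up on unrecognized variants) for an exhaustive, infallible one, which is why `indices_referred_by_expr` could drop its conservative all-columns fallback. The shape of that pattern, reduced to a toy expression type (names here are illustrative, not DataFusion's):

    use std::collections::HashSet;

    // Toy expression tree; matching without a catch-all arm makes the
    // compiler flag any newly added variant that the walker forgets.
    enum Node {
        Outer(String),
        Binary(Box<Node>, Box<Node>),
        Leaf,
    }

    fn collect_outer(node: &Node, out: &mut HashSet<String>) {
        match node {
            Node::Outer(name) => {
                out.insert(name.clone());
            }
            Node::Binary(l, r) => {
                collect_outer(l, out);
                collect_outer(r, out);
            }
            Node::Leaf => (),
        }
    }

    fn main() {
        let expr = Node::Binary(
            Box::new(Node::Outer("t.a".to_string())),
            Box::new(Node::Leaf),
        );
        let mut cols = HashSet::new();
        collect_outer(&expr, &mut cols);
        assert!(cols.contains("t.a"));
    }

+    #[test]
+    fn test_is_true() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a").is_true()])?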
+ .build()?; + + let expected = "Projection: test.a IS TRUE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_true() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_true()])? + .build()?; + + let expected = "Projection: test.a IS NOT TRUE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_false() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_false()])? + .build()?; + + let expected = "Projection: test.a IS FALSE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_false() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_false()])? + .build()?; + + let expected = "Projection: test.a IS NOT FALSE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_unknown() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_unknown()])? + .build()?; + + let expected = "Projection: test.a IS UNKNOWN\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_unknown() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_unknown()])? + .build()?; + + let expected = "Projection: test.a IS NOT UNKNOWN\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_not() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![not(col("a"))])? + .build()?; + + let expected = "Projection: NOT test.a\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_try_cast() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![try_cast(col("a"), DataType::Float64)])? + .build()?; + + let expected = "Projection: TRY_CAST(test.a AS Float64)\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_similar_to() -> Result<()> { + let table_scan = test_table_scan()?; + let expr = Box::new(col("a")); + let pattern = Box::new(lit("[0-9]")); + let similar_to_expr = + Expr::SimilarTo(Like::new(false, expr, pattern, None, false)); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![similar_to_expr])? + .build()?; + + let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_between() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").between(lit(1), lit(3))])? 
+ .build()?; + + let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index c2fd32a96c4f..f7c13948b2dc 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -15,21 +15,32 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, TimeUnit}; +use arrow_array::types::{ + ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_array::PrimitiveArray; use std::any::Any; +use std::cmp::Eq; use std::fmt::Debug; +use std::hash::Hash; use std::sync::Arc; use ahash::RandomState; use arrow::array::{Array, ArrayRef}; use std::collections::HashSet; -use crate::aggregate::utils::down_cast_any_ref; +use crate::aggregate::utils::{down_cast_any_ref, Hashable}; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use datafusion_common::cast::{as_list_array, as_primitive_array}; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; type DistinctScalarValues = ScalarValue; @@ -60,6 +71,18 @@ impl DistinctCount { } } +macro_rules! native_distinct_count_accumulator { + ($TYPE:ident) => {{ + Ok(Box::new(NativeDistinctCountAccumulator::<$TYPE>::new())) + }}; +} + +macro_rules! 
float_distinct_count_accumulator {
+    ($TYPE:ident) => {{
+        Ok(Box::new(FloatDistinctCountAccumulator::<$TYPE>::new()))
+    }};
+}
+
 impl AggregateExpr for DistinctCount {
     /// Return a reference to Any that can be used for downcasting
     fn as_any(&self) -> &dyn Any {
@@ -83,10 +106,57 @@ impl AggregateExpr for DistinctCount {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(DistinctCountAccumulator {
-            values: HashSet::default(),
-            state_data_type: self.state_data_type.clone(),
-        }))
+        use DataType::*;
+        use TimeUnit::*;
+
+        match &self.state_data_type {
+            Int8 => native_distinct_count_accumulator!(Int8Type),
+            Int16 => native_distinct_count_accumulator!(Int16Type),
+            Int32 => native_distinct_count_accumulator!(Int32Type),
+            Int64 => native_distinct_count_accumulator!(Int64Type),
+            UInt8 => native_distinct_count_accumulator!(UInt8Type),
+            UInt16 => native_distinct_count_accumulator!(UInt16Type),
+            UInt32 => native_distinct_count_accumulator!(UInt32Type),
+            UInt64 => native_distinct_count_accumulator!(UInt64Type),
+            Decimal128(_, _) => native_distinct_count_accumulator!(Decimal128Type),
+            Decimal256(_, _) => native_distinct_count_accumulator!(Decimal256Type),
+
+            Date32 => native_distinct_count_accumulator!(Date32Type),
+            Date64 => native_distinct_count_accumulator!(Date64Type),
+            Time32(Millisecond) => {
+                native_distinct_count_accumulator!(Time32MillisecondType)
+            }
+            Time32(Second) => {
+                native_distinct_count_accumulator!(Time32SecondType)
+            }
+            Time64(Microsecond) => {
+                native_distinct_count_accumulator!(Time64MicrosecondType)
+            }
+            Time64(Nanosecond) => {
+                native_distinct_count_accumulator!(Time64NanosecondType)
+            }
+            Timestamp(Microsecond, _) => {
+                native_distinct_count_accumulator!(TimestampMicrosecondType)
+            }
+            Timestamp(Millisecond, _) => {
+                native_distinct_count_accumulator!(TimestampMillisecondType)
+            }
+            Timestamp(Nanosecond, _) => {
+                native_distinct_count_accumulator!(TimestampNanosecondType)
+            }
+            Timestamp(Second, _) => {
+                native_distinct_count_accumulator!(TimestampSecondType)
+            }
+
+            Float16 => float_distinct_count_accumulator!(Float16Type),
+            Float32 => float_distinct_count_accumulator!(Float32Type),
+            Float64 => float_distinct_count_accumulator!(Float64Type),
+
+            _ => Ok(Box::new(DistinctCountAccumulator {
+                values: HashSet::default(),
+                state_data_type: self.state_data_type.clone(),
+            })),
+        }
     }
 
     fn name(&self) -> &str {
@@ -192,6 +262,182 @@ impl Accumulator for DistinctCountAccumulator {
     }
 }
 
+#[derive(Debug)]
+struct NativeDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+    T::Native: Eq + Hash,
+{
+    values: HashSet<T::Native>,
+}
+
+impl<T> NativeDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+    T::Native: Eq + Hash,
+{
+    fn new() -> Self {
+        Self {
+            values: HashSet::default(),
+        }
+    }
+}
+
+impl<T> Accumulator for NativeDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send + Debug,
+    T::Native: Eq + Hash,
+{
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
+            self.values.iter().cloned(),
+        )) as ArrayRef;
+        let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
+        Ok(vec![ScalarValue::List(list)])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        let arr = as_primitive_array::<T>(&values[0])?;
+        arr.iter().for_each(|value| {
+            if let Some(value) = value {
+                self.values.insert(value);
+            }
+        });
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(
+            states.len(),
+            1,
+            "count_distinct states must be single array"
+        );
+
+        let arr = as_list_array(&states[0])?;
+        arr.iter().try_for_each(|maybe_list| {
+            if let Some(list) = maybe_list {
+                let list = as_primitive_array::<T>(&list)?;
+                self.values.extend(list.values())
+            };
+            Ok(())
+        })
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
+            / 7)
+        .next_power_of_two();
+
+        // Size of accumulator
+        // + size of entry * number of buckets
+        // + 1 byte for each bucket
+        // + fixed size of HashSet
+        std::mem::size_of_val(self)
+            + std::mem::size_of::<T::Native>() * estimated_buckets
+            + estimated_buckets
+            + std::mem::size_of_val(&self.values)
+    }
+}
+
+#[derive(Debug)]
+struct FloatDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+{
+    values: HashSet<Hashable<T::Native>, RandomState>,
+}
+
+impl<T> FloatDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+{
+    fn new() -> Self {
+        Self {
+            values: HashSet::default(),
+        }
+    }
+}
+
+impl<T> Accumulator for FloatDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send + Debug,
+{
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
+            self.values.iter().map(|v| v.0),
+        )) as ArrayRef;
+        let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
+        Ok(vec![ScalarValue::List(list)])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        let arr = as_primitive_array::<T>(&values[0])?;
+        arr.iter().for_each(|value| {
+            if let Some(value) = value {
+                self.values.insert(Hashable(value));
+            }
+        });
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(
+            states.len(),
+            1,
+            "count_distinct states must be single array"
+        );
+
+        let arr = as_list_array(&states[0])?;
+        arr.iter().try_for_each(|maybe_list| {
+            if let Some(list) = maybe_list {
+                let list = as_primitive_array::<T>(&list)?;
+                self.values
+                    .extend(list.values().iter().map(|v| Hashable(*v)));
+            };
+            Ok(())
+        })
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
+            / 7)
+        .next_power_of_two();
+
+        // Size of accumulator
+        // + size of entry * number of buckets
+        // + 1 byte for each bucket
+        // + fixed size of HashSet
+        std::mem::size_of_val(self)
+            + std::mem::size_of::<T::Native>() * estimated_buckets
+            + estimated_buckets
+            + std::mem::size_of_val(&self.values)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use crate::expressions::NoOp;
@@ -206,6 +452,8 @@ mod tests {
         Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
         UInt32Type, UInt64Type, UInt8Type,
     };
+    use arrow_array::Decimal256Array;
+    use arrow_buffer::i256;
     use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array};
     use datafusion_common::internal_err;
     use datafusion_common::DataFusionError;
@@ -367,6 +615,35 @@ mod tests {
         }};
     }
 
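What the specialized accumulators above buy: distinct values are kept as raw native values (hashed by their bit patterns in the float case) instead of boxed `ScalarValue`s. A condensed, standalone illustration of the float trick (toy types, not the DataFusion ones; hashing bits is the same idea as `Hashable`'s byte-slice hash):

    use std::collections::HashSet;

    // f64 is neither Eq nor Hash; compare and hash its bit pattern instead.
    #[derive(Copy, Clone)]
    struct BitsEq(f64);

    impl PartialEq for BitsEq {
        fn eq(&self, other: &Self) -> bool {
            self.0.to_bits() == other.0.to_bits()
        }
    }

    impl Eq for BitsEq {}

    impl std::hash::Hash for BitsEq {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            self.0.to_bits().hash(state);
        }
    }

    fn main() {
        let mut distinct = HashSet::new();
        for v in [1.0_f64, 2.0, 1.0, 3.0, 2.0] {
            distinct.insert(BitsEq(v));
        }
        assert_eq!(distinct.len(), 3); // COUNT(DISTINCT) = 3
    }

+    macro_rules!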
test_count_distinct_update_batch_bigint {
+        ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
+            let values: Vec<Option<i256>> = vec![
+                Some(i256::from(1)),
+                Some(i256::from(1)),
+                None,
+                Some(i256::from(3)),
+                Some(i256::from(2)),
+                None,
+                Some(i256::from(2)),
+                Some(i256::from(3)),
+                Some(i256::from(1)),
+            ];
+
+            let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];
+
+            let (states, result) = run_update_batch(&arrays)?;
+
+            let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE);
+            state_vec.sort();
+
+            assert_eq!(states.len(), 1);
+            assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]);
+            assert_eq!(result, ScalarValue::Int64(Some(3)));
+
+            Ok(())
+        }};
+    }
+
     #[test]
     fn count_distinct_update_batch_i8() -> Result<()> {
         test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8)
@@ -417,6 +694,11 @@ mod tests {
         test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64)
     }
 
+    #[test]
+    fn count_distinct_update_batch_i256() -> Result<()> {
+        test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256)
+    }
+
     #[test]
     fn count_distinct_update_batch_boolean() -> Result<()> {
         let get_count = |data: BooleanArray| -> Result<(Vec<bool>, i64)> {
diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
index 0cf4a90ab8cc..6dbb39224629 100644
--- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
@@ -25,11 +25,11 @@ use arrow::array::{Array, ArrayRef};
 use arrow_array::cast::AsArray;
 use arrow_array::types::*;
 use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType};
-use arrow_buffer::{ArrowNativeType, ToByteSlice};
+use arrow_buffer::ArrowNativeType;
 use std::collections::HashSet;
 
 use crate::aggregate::sum::downcast_sum;
-use crate::aggregate::utils::down_cast_any_ref;
+use crate::aggregate::utils::{down_cast_any_ref, Hashable};
 use crate::{AggregateExpr, PhysicalExpr};
 use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
 use datafusion_expr::type_coercion::aggregates::sum_return_type;
@@ -119,24 +119,6 @@ impl PartialEq<dyn Any> for DistinctSum {
     }
 }
 
-/// A wrapper around a type to provide hash for floats
-#[derive(Copy, Clone)]
-struct Hashable<T>(T);
-
-impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.0.to_byte_slice().hash(state)
-    }
-}
-
-impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
-    fn eq(&self, other: &Self) -> bool {
-        self.0.is_eq(other.0)
-    }
-}
-
-impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
-
 struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
     values: HashSet<Hashable<T::Native>, RandomState>,
     data_type: DataType,
diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs
index 9777158da133..d73c46a0f687 100644
--- a/datafusion/physical-expr/src/aggregate/utils.rs
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -28,7 +28,7 @@ use arrow_array::types::{
     Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType,
     TimestampNanosecondType, TimestampSecondType,
 };
-use arrow_buffer::ArrowNativeType;
+use arrow_buffer::{ArrowNativeType, ToByteSlice};
 use arrow_schema::{DataType, Field, SortOptions};
 use datafusion_common::{exec_err, DataFusionError, Result};
 use datafusion_expr::Accumulator;
@@ -211,3 +211,21 @@ pub(crate) fn ordering_fields(
 pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec<SortOptions> {
     ordering_req.iter().map(|item| item.options).collect()
 }
+
+/// A wrapper around a type to provide hash for floats
+#[derive(Copy, Clone, Debug)]
+pub(crate) struct Hashable<T>(pub T);
+
+impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.0.to_byte_slice().hash(state)
+    }
+}
+
+impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.is_eq(other.0)
+    }
+}
+
+impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index 15330af640ae..9665116b04ab 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -529,7 +529,7 @@ fn general_except<OffsetSize: OffsetSizeTrait>(
 
 pub fn array_except(args: &[ArrayRef]) -> Result<ArrayRef> {
     if args.len() != 2 {
-        return internal_err!("array_except needs two arguments");
+        return exec_err!("array_except needs two arguments");
     }
 
     let array1 = &args[0];
@@ -894,7 +894,7 @@ pub fn gen_range(args: &[ArrayRef]) -> Result<ArrayRef> {
             as_int64_array(&args[1])?,
             Some(as_int64_array(&args[2])?),
         ),
-        _ => return internal_err!("gen_range expects 1 to 3 arguments"),
+        _ => return exec_err!("gen_range expects 1 to 3 arguments"),
     };
 
     let mut values = vec![];
@@ -948,7 +948,7 @@ pub fn array_sort(args: &[ArrayRef]) -> Result<ArrayRef> {
                 nulls_first: order_nulls_first(nulls_first)?,
             })
         }
-        _ => return internal_err!("array_sort expects 1 to 3 arguments"),
+        _ => return exec_err!("array_sort expects 1 to 3 arguments"),
     };
 
     let list_array = as_list_array(&args[0])?;
@@ -994,7 +994,7 @@ fn order_desc(modifier: &str) -> Result<bool> {
     match modifier.to_uppercase().as_str() {
         "DESC" => Ok(true),
         "ASC" => Ok(false),
-        _ => internal_err!("the second parameter of array_sort expects DESC or ASC"),
+        _ => exec_err!("the second parameter of array_sort expects DESC or ASC"),
     }
 }
 
@@ -1002,7 +1002,7 @@ fn order_nulls_first(modifier: &str) -> Result<bool> {
     match modifier.to_uppercase().as_str() {
         "NULLS FIRST" => Ok(true),
         "NULLS LAST" => Ok(false),
-        _ => internal_err!(
+        _ => exec_err!(
             "the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
         ),
     }
@@ -1208,7 +1208,7 @@ pub fn array_empty(args: &[ArrayRef]) -> Result<ArrayRef> {
     match array_type {
         DataType::List(_) => array_empty_dispatch::<i32>(&args[0]),
         DataType::LargeList(_) => array_empty_dispatch::<i64>(&args[0]),
-        _ => internal_err!("array_empty does not support type '{array_type:?}'."),
+        _ => exec_err!("array_empty does not support type '{array_type:?}'."),
     }
 }
 
@@ -1598,7 +1598,9 @@ fn array_remove_internal(
             let list_array = array.as_list::<i64>();
             general_remove::<i64>(list_array, element_array, arr_n)
         }
-        _ => internal_err!("array_remove_all expects a list array"),
+        array_type => {
+            exec_err!("array_remove_all does not support type '{array_type:?}'.")
+        }
     }
 }
 
@@ -2022,8 +2024,21 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
     ) -> Result<&mut String> {
         match arr.data_type() {
             DataType::List(..) => {
-                let list_array = downcast_arg!(arr, ListArray);
+                let list_array = as_list_array(&arr)?;
+                for i in 0..list_array.len() {
+                    compute_array_to_string(
+                        arg,
+                        list_array.value(i),
+                        delimiter.clone(),
+                        null_string.clone(),
+                        with_null_string,
+                    )?;
+                }
+                Ok(arg)
+            }
+            DataType::LargeList(..)
=> { + let list_array = as_large_list_array(&arr)?; for i in 0..list_array.len() { compute_array_to_string( arg, @@ -2055,35 +2070,61 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { } } - let mut arg = String::from(""); - let mut res: Vec> = Vec::new(); - - match arr.data_type() { - DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => { - let list_array = arr.as_list::(); - for (arr, &delimiter) in list_array.iter().zip(delimiters.iter()) { - if let (Some(arr), Some(delimiter)) = (arr, delimiter) { - arg = String::from(""); - let s = compute_array_to_string( - &mut arg, - arr, - delimiter.to_string(), - null_string.clone(), - with_null_string, - )? - .clone(); - - if let Some(s) = s.strip_suffix(delimiter) { - res.push(Some(s.to_string())); - } else { - res.push(Some(s)); - } + fn generate_string_array( + list_arr: &GenericListArray, + delimiters: Vec>, + null_string: String, + with_null_string: bool, + ) -> Result { + let mut res: Vec> = Vec::new(); + for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { + if let (Some(arr), Some(delimiter)) = (arr, delimiter) { + let mut arg = String::from(""); + let s = compute_array_to_string( + &mut arg, + arr, + delimiter.to_string(), + null_string.clone(), + with_null_string, + )? + .clone(); + + if let Some(s) = s.strip_suffix(delimiter) { + res.push(Some(s.to_string())); } else { - res.push(None); + res.push(Some(s)); } + } else { + res.push(None); } } + + Ok(StringArray::from(res)) + } + + let arr_type = arr.data_type(); + let string_arr = match arr_type { + DataType::List(_) | DataType::FixedSizeList(_, _) => { + let list_array = as_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } _ => { + let mut arg = String::from(""); + let mut res: Vec> = Vec::new(); // delimiter length is 1 assert_eq!(delimiters.len(), 1); let delimiter = delimiters[0].unwrap(); @@ -2102,10 +2143,11 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { } else { res.push(Some(s)); } + StringArray::from(res) } - } + }; - Ok(Arc::new(StringArray::from(res))) + Ok(Arc::new(string_arr)) } /// Cardinality SQL function @@ -2114,16 +2156,31 @@ pub fn cardinality(args: &[ArrayRef]) -> Result { return exec_err!("cardinality expects one argument"); } - let list_array = as_list_array(&args[0])?.clone(); + match &args[0].data_type() { + DataType::List(_) => { + let list_array = as_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + other => { + exec_err!("cardinality does not support type '{:?}'", other) + } + } +} - let result = list_array +fn generic_list_cardinality( + array: &GenericListArray, +) -> Result { + let result = array .iter() .map(|arr| match compute_array_dims(arr)? 
{ Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), None => Ok(None), }) .collect::>()?; - Ok(Arc::new(result) as ArrayRef) } @@ -2205,10 +2262,7 @@ pub fn array_length(args: &[ArrayRef]) -> Result { match &args[0].data_type() { DataType::List(_) => array_length_dispatch::(args), DataType::LargeList(_) => array_length_dispatch::(args), - _ => internal_err!( - "array_length does not support type '{:?}'", - args[0].data_type() - ), + array_type => exec_err!("array_length does not support type '{array_type:?}'"), } } @@ -2233,11 +2287,8 @@ pub fn array_dims(args: &[ArrayRef]) -> Result { .map(compute_array_dims) .collect::>>()? } - _ => { - return exec_err!( - "array_dims does not support type '{:?}'", - args[0].data_type() - ); + array_type => { + return exec_err!("array_dims does not support type '{array_type:?}'"); } }; @@ -2386,7 +2437,7 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result { DataType::LargeList(_) => { general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) } - _ => internal_err!("array_has_any does not support type '{array_type:?}'."), + _ => exec_err!("array_has_any does not support type '{array_type:?}'."), } } @@ -2405,7 +2456,7 @@ pub fn array_has_all(args: &[ArrayRef]) -> Result { DataType::LargeList(_) => { general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) } - _ => internal_err!("array_has_all does not support type '{array_type:?}'."), + _ => exec_err!("array_has_all does not support type '{array_type:?}'."), } } @@ -2488,7 +2539,7 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result { - return internal_err!( + return exec_err!( "Expect string_to_array function to take two or three parameters" ) } @@ -2556,7 +2607,7 @@ pub fn array_distinct(args: &[ArrayRef]) -> Result { let array = as_large_list_array(&args[0])?; general_array_distinct(array, field) } - _ => internal_err!("array_distinct only support list array"), + array_type => exec_err!("array_distinct does not support type '{array_type:?}'"), } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index c17081398cb8..8c4078dbce8c 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -20,9 +20,7 @@ mod kernels; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::array_expressions::{ - array_append, array_concat, array_has_all, array_prepend, -}; +use crate::array_expressions::array_has_all; use crate::expressions::datum::{apply, apply_cmp}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::physical_expr::down_cast_any_ref; @@ -598,12 +596,7 @@ impl BinaryExpr { BitwiseXor => bitwise_xor_dyn(left, right), BitwiseShiftRight => bitwise_shift_right_dyn(left, right), BitwiseShiftLeft => bitwise_shift_left_dyn(left, right), - StringConcat => match (left_data_type, right_data_type) { - (DataType::List(_), DataType::List(_)) => array_concat(&[left, right]), - (DataType::List(_), _) => array_append(&[left, right]), - (_, DataType::List(_)) => array_prepend(&[left, right]), - _ => binary_string_array_op!(left, right, concat_elements), - }, + StringConcat => binary_string_array_op!(left, right, concat_elements), AtArrow => array_has_all(&[left, right]), ArrowAt => array_has_all(&[right, left]), } diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index 0ec1cf3f256b..9daa9eb173dd 100644 --- a/datafusion/physical-expr/src/udf.rs +++ 
b/datafusion/physical-expr/src/udf.rs @@ -36,7 +36,7 @@ pub fn create_physical_expr( Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), - fun.fun().clone(), + fun.fun(), input_phy_exprs.to_vec(), fun.return_type(&input_exprs_types)?, None, diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index e7c7a42cf902..10ff9edb8912 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -17,18 +17,22 @@ use crate::aggregates::group_values::GroupValues; use ahash::RandomState; +use arrow::compute::cast; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; -use arrow_array::ArrayRef; -use arrow_schema::SchemaRef; +use arrow_array::{Array, ArrayRef}; +use arrow_schema::{DataType, SchemaRef}; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::Result; +use datafusion_common::{DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_physical_expr::EmitTo; use hashbrown::raw::RawTable; /// A [`GroupValues`] making use of [`Rows`] pub struct GroupValuesRows { + /// The output schema + schema: SchemaRef, + /// Converter for the group values row_converter: RowConverter, @@ -75,6 +79,7 @@ impl GroupValuesRows { let map = RawTable::with_capacity(0); Ok(Self { + schema, row_converter, map, map_size: 0, @@ -165,7 +170,7 @@ impl GroupValues for GroupValuesRows { .take() .expect("Can not emit from empty rows"); - let output = match emit_to { + let mut output = match emit_to { EmitTo::All => { let output = self.row_converter.convert_rows(&group_values)?; group_values.clear(); @@ -198,6 +203,20 @@ impl GroupValues for GroupValuesRows { } }; + // TODO: Materialize dictionaries in group keys (#7647) + for (field, array) in self.schema.fields.iter().zip(&mut output) { + let expected = field.data_type(); + if let DataType::Dictionary(_, v) = expected { + let actual = array.data_type(); + if v.as_ref() != actual { + return Err(DataFusionError::Internal(format!( + "Converted group rows expected dictionary of {v} got {actual}" + ))); + } + *array = cast(array.as_ref(), expected)?; + } + } + self.group_values = Some(group_values); Ok(output) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index a38044de02e3..0b94dd01cfd4 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -36,7 +36,6 @@ use crate::{ use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_schema::DataType; use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -254,9 +253,6 @@ pub struct AggregateExec { limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, - /// Original aggregation schema, could be different from `schema` before dictionary group - /// keys get materialized - original_schema: SchemaRef, /// Schema after the aggregate is applied schema: SchemaRef, /// Input schema before any aggregation is applied. 
For partial aggregate this will be the
@@ -287,7 +283,7 @@ impl AggregateExec {
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
     ) -> Result<Self> {
-        let original_schema = create_schema(
+        let schema = create_schema(
             &input.schema(),
             &group_by.expr,
             &aggr_expr,
             mode,
         )?;
 
-        let schema = Arc::new(materialize_dict_group_keys(
-            &original_schema,
-            group_by.expr.len(),
-        ));
-        let original_schema = Arc::new(original_schema);
+        let schema = Arc::new(schema);
         AggregateExec::try_new_with_schema(
             mode,
             group_by,
@@ -308,7 +300,6 @@ impl AggregateExec {
             input,
             input_schema,
             schema,
-            original_schema,
         )
     }
 
@@ -329,7 +320,6 @@ impl AggregateExec {
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
         schema: SchemaRef,
-        original_schema: SchemaRef,
     ) -> Result<Self> {
         let input_eq_properties = input.equivalence_properties();
         // Get GROUP BY expressions:
@@ -382,7 +372,6 @@ impl AggregateExec {
             aggr_expr,
             filter_expr,
             input,
-            original_schema,
             schema,
             input_schema,
             projection_mapping,
@@ -693,7 +682,6 @@ impl ExecutionPlan for AggregateExec {
             children[0].clone(),
             self.input_schema.clone(),
             self.schema.clone(),
-            self.original_schema.clone(),
         )?;
         me.limit = self.limit;
         Ok(Arc::new(me))
@@ -800,24 +789,6 @@ fn create_schema(
     Ok(Schema::new(fields))
 }
 
-/// returns schema with dictionary group keys materialized as their value types
-/// The actual conversion happens in `RowConverter` and we don't do unnecessary
-/// conversion back into dictionaries
-fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema {
-    let fields = schema
-        .fields
-        .iter()
-        .enumerate()
-        .map(|(i, field)| match field.data_type() {
-            DataType::Dictionary(_, value_data_type) if i < group_count => {
-                Field::new(field.name(), *value_data_type.clone(), field.is_nullable())
-            }
-            _ => Field::clone(field),
-        })
-        .collect::<Vec<_>>();
-    Schema::new(fields)
-}
-
 fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
     let group_fields = schema.fields()[0..group_count].to_vec();
     Arc::new(Schema::new(group_fields))
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 89614fd3020c..6a0c02f5caf3 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -324,9 +324,7 @@ impl GroupedHashAggregateStream {
             .map(create_group_accumulator)
             .collect::<Result<Vec<_>>>()?;
 
-        // we need to use original schema so RowConverter in group_values below
-        // will do the proper conversion of dictionaries into value types
-        let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len());
+        let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
         let spill_expr = group_schema
             .fields
             .into_iter()
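The net effect of the hunks above is to retire the schema-rewriting approach (materializing dictionary group keys up front) in favor of the cast added to `GroupValuesRows` earlier: emitted group arrays are converted back to the dictionary type declared in the output schema. A minimal sketch of that cast-back step with arrow's `cast` kernel (assuming the `arrow` crate; the concrete types here are chosen for illustration):

    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, StringArray};
    use arrow::compute::cast;
    use arrow::datatypes::DataType;

    fn main() -> Result<(), arrow::error::ArrowError> {
        // Row conversion hands back plain values...
        let plain: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "a"]));
        // ...so cast them back to the dictionary type the output schema declares.
        let expected = DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Utf8),
        );
        let dict = cast(plain.as_ref(), &expected)?;
        assert_eq!(dict.data_type(), &expected);
        Ok(())
    }

diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs
index 6c9e97e03cb7..1dd1392b9d86 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -288,6 +288,24 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync {
     /// [`TryStreamExt`]: futures::stream::TryStreamExt
     /// [`RecordBatchStreamAdapter`]: crate::stream::RecordBatchStreamAdapter
     ///
+    /// # Cancellation / Aborting Execution
+    ///
+    /// The [`Stream`] that is returned must ensure that any allocated resources
+    /// are freed when the stream itself is dropped. This is particularly
+    /// important for [`spawn`]ed tasks or threads.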
Unless care is taken to
+    /// "abort" such tasks, they may continue to consume resources even after
+    /// the plan is dropped, generating intermediate results that are never
+    /// used.
+    ///
+    /// See [`AbortOnDropSingle`], [`AbortOnDropMany`] and
+    /// [`RecordBatchReceiverStreamBuilder`] for structures to help ensure all
+    /// background tasks are cancelled.
+    ///
+    /// [`spawn`]: tokio::task::spawn
+    /// [`AbortOnDropSingle`]: crate::common::AbortOnDropSingle
+    /// [`AbortOnDropMany`]: crate::common::AbortOnDropMany
+    /// [`RecordBatchReceiverStreamBuilder`]: crate::stream::RecordBatchReceiverStreamBuilder
+    ///
     /// # Implementation Examples
     ///
     /// While `async` `Stream`s have a non trivial learning curve, the
@@ -491,7 +509,12 @@ pub async fn collect(
     common::collect(stream).await
 }
 
-/// Execute the [ExecutionPlan] and return a single stream of results
+/// Execute the [ExecutionPlan] and return a single stream of results.
+///
+/// # Aborting Execution
+///
+/// Dropping the stream will abort the execution of the query, and free up
+/// any allocated resources
 pub fn execute_stream(
     plan: Arc<dyn ExecutionPlan>,
     context: Arc<TaskContext>,
@@ -549,7 +572,13 @@ pub async fn collect_partitioned(
     Ok(batches)
 }
 
-/// Execute the [ExecutionPlan] and return a vec with one stream per output partition
+/// Execute the [ExecutionPlan] and return a vec with one stream per output
+/// partition
+///
+/// # Aborting Execution
+///
+/// Dropping the stream will abort the execution of the query, and free up
+/// any allocated resources
 pub fn execute_stream_partitioned(
     plan: Arc<dyn ExecutionPlan>,
     context: Arc<TaskContext>,
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 402781e17e6f..03daf535f201 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -1787,6 +1787,7 @@ fn roundtrip_window() {
         }
     }
 
+    #[derive(Debug, Clone)]
     struct SimpleWindowUDF {
         signature: Signature,
     }
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 27351e10eb34..9fded63af3fc 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -98,11 +98,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 StackEntry::Operator(op) => {
                     let right = eval_stack.pop().unwrap();
                     let left = eval_stack.pop().unwrap();
+
                     let expr = Expr::BinaryExpr(BinaryExpr::new(
                         Box::new(left),
                         op,
                         Box::new(right),
                     ));
+
                     eval_stack.push(expr);
                 }
             }
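On the cancellation docs added to `execute_stream` and friends above: "dropping the stream aborts the query" is typically implemented by tying a task handle's lifetime to the stream, which is what `AbortOnDropSingle` does. A generic sketch of the abort-on-drop idiom (assuming the tokio crate; this shows the idea, not DataFusion's actual type):

    use tokio::task::JoinHandle;

    // Abort the wrapped task when the wrapper is dropped, so background
    // work cannot outlive the consumer of its results.
    struct AbortOnDrop<T>(JoinHandle<T>);

    impl<T> Drop for AbortOnDrop<T> {
        fn drop(&mut self) {
            self.0.abort();
        }
    }

    #[tokio::main]
    async fn main() {
        let guard = AbortOnDrop(tokio::spawn(async {
            loop {
                tokio::task::yield_now().await; // stand-in for producing batches
            }
        }));
        drop(guard); // the consumer went away: the spawned task is aborted
    }

diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index c5c30e3a2253..a04df5589b85 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -250,7 +250,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         // Default expressions are restricted, column references are not allowed
         let empty_schema = DFSchema::empty();
         let error_desc = |e: DataFusionError| match e {
-            DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }) => {
+            DataFusionError::SchemaError(SchemaError::FieldNotFound { ..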
}, _) => { plan_datafusion_err!( "Column reference is not allowed in the DEFAULT expression : {}", e diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index b96553ffbf86..b9fb4c65dc2c 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -31,9 +31,10 @@ use arrow_schema::DataType; use datafusion_common::file_options::StatementOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, Column, - Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, - Result, ScalarValue, SchemaReference, TableReference, ToDFSchema, + not_impl_err, plan_datafusion_err, plan_err, schema_err, unqualified_field_not_found, + Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, + OwnedTableReference, Result, ScalarValue, SchemaError, SchemaReference, + TableReference, ToDFSchema, }; use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; @@ -1138,11 +1139,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .index_of_column_by_name(None, &c)? .ok_or_else(|| unqualified_field_not_found(&c, &table_schema))?; if value_indices[column_index].is_some() { - return Err(DataFusionError::SchemaError( - datafusion_common::SchemaError::DuplicateUnqualifiedField { - name: c, - }, - )); + return schema_err!(SchemaError::DuplicateUnqualifiedField { + name: c, + }); } else { value_indices[column_index] = Some(i); } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 48ba50145308..4de08a7124cf 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -756,9 +756,11 @@ fn join_with_ambiguous_column() { #[test] fn where_selection_with_ambiguous_column() { let sql = "SELECT * FROM person a, person b WHERE id = id + 1"; - let err = logical_plan(sql).expect_err("query should have failed"); + let err = logical_plan(sql) + .expect_err("query should have failed") + .strip_backtrace(); assert_eq!( - "SchemaError(AmbiguousReference { field: Column { relation: None, name: \"id\" } })", + "\"Schema error: Ambiguous reference to unqualified field id\"", format!("{err:?}") ); } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 78575c9dffc5..aa512f6e2600 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2469,11 +2469,11 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict); query T select arrow_typeof(x_dict) from value_dict group by x_dict; ---- -Int32 -Int32 -Int32 -Int32 -Int32 +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) statement ok drop table value diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 7cee615a5729..d864091a8588 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -3238,30 +3238,55 @@ select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, ---- h,e,l,l,o 1-2-3-4-5 1|2|3 +query TTT +select list_to_string(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), list_to_string(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), list_to_string(arrow_cast([1.0, 
2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # array_join scalar function #5 (function alias `array_to_string`) query TTT select array_join(['h', 'e', 'l', 'l', 'o'], ','), array_join([1, 2, 3, 4, 5], '-'), array_join([1.0, 2.0, 3.0], '|'); ---- h,e,l,l,o 1-2-3-4-5 1|2|3 +query TTT +select array_join(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), array_join(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), array_join(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # list_join scalar function #6 (function alias `list_join`) query TTT select list_join(['h', 'e', 'l', 'l', 'o'], ','), list_join([1, 2, 3, 4, 5], '-'), list_join([1.0, 2.0, 3.0], '|'); ---- h,e,l,l,o 1-2-3-4-5 1|2|3 +query TTT +select list_join(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), list_join(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), list_join(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # array_to_string scalar function with nulls #1 query TTT select array_to_string(make_array('h', NULL, 'l', NULL, 'o'), ','), array_to_string(make_array(1, NULL, 3, NULL, 5), '-'), array_to_string(make_array(NULL, 2.0, 3.0), '|'); ---- h,l,o 1-3-5 2|3 +query TTT +select array_to_string(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), array_to_string(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), array_to_string(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # array_to_string scalar function with nulls #2 query TTT select array_to_string(make_array('h', NULL, NULL, NULL, 'o'), ',', '-'), array_to_string(make_array(NULL, 2, NULL, 4, 5), '-', 'nil'), array_to_string(make_array(1.0, NULL, 3.0), '|', '0'); ---- h,-,-,-,o nil-2-nil-4-5 1|0|3 +query TTT +select array_to_string(arrow_cast(make_array('h', NULL, NULL, NULL, 'o'), 'LargeList(Utf8)'), ',', '-'), array_to_string(arrow_cast(make_array(NULL, 2, NULL, 4, 5), 'LargeList(Int64)'), '-', 'nil'), array_to_string(arrow_cast(make_array(1.0, NULL, 3.0), 'LargeList(Float64)'), '|', '0'); +---- +h,-,-,-,o nil-2-nil-4-5 1|0|3 + # array_to_string with columns #1 # For reference @@ -3288,6 +3313,18 @@ NULL 51^52^54^55^56^57^58^59^60 NULL +query T +select array_to_string(column1, column4) from large_arrays_values; +---- +2,3,4,5,6,7,8,9,10 +11.12.13.14.15.16.17.18.20 +21-22-23-25-26-27-28-29-30 +31ok32ok33ok34ok35ok37ok38ok39ok40 +NULL +41$42$43$44$45$46$47$48$49$50 +51^52^54^55^56^57^58^59^60 +NULL + query TT select array_to_string(column1, '_'), array_to_string(make_array(1,2,3), '/') from arrays_values; ---- @@ -3300,6 +3337,18 @@ NULL 1/2/3 51_52_54_55_56_57_58_59_60 1/2/3 61_62_63_64_65_66_67_68_69_70 1/2/3 +query TT +select array_to_string(column1, '_'), array_to_string(make_array(1,2,3), '/') from large_arrays_values; +---- +2_3_4_5_6_7_8_9_10 1/2/3 +11_12_13_14_15_16_17_18_20 1/2/3 +21_22_23_25_26_27_28_29_30 1/2/3 +31_32_33_34_35_37_38_39_40 1/2/3 +NULL 1/2/3 +41_42_43_44_45_46_47_48_49_50 1/2/3 +51_52_54_55_56_57_58_59_60 1/2/3 +61_62_63_64_65_66_67_68_69_70 1/2/3 + query TT select array_to_string(column1, '_', '*'), array_to_string(make_array(make_array(1,2,3)), '.') from arrays_values; ---- @@ -3312,6 +3361,18 @@ NULL 1.2.3 51_52_*_54_55_56_57_58_59_60 1.2.3 61_62_63_64_65_66_67_68_69_70 1.2.3 +query TT +select array_to_string(column1, '_', '*'), array_to_string(make_array(make_array(1,2,3)), '.') from large_arrays_values; +---- +*_2_3_4_5_6_7_8_9_10 
1.2.3 +11_12_13_14_15_16_17_18_*_20 1.2.3 +21_22_23_*_25_26_27_28_29_30 1.2.3 +31_32_33_34_35_*_37_38_39_40 1.2.3 +NULL 1.2.3 +41_42_43_44_45_46_47_48_49_50 1.2.3 +51_52_*_54_55_56_57_58_59_60 1.2.3 +61_62_63_64_65_66_67_68_69_70 1.2.3 + ## cardinality # cardinality scalar function @@ -3320,18 +3381,33 @@ select cardinality(make_array(1, 2, 3, 4, 5)), cardinality([1, 3, 5]), cardinali ---- 5 3 5 +query III +select cardinality(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), cardinality(arrow_cast([1, 3, 5], 'LargeList(Int64)')), cardinality(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +5 3 5 + # cardinality scalar function #2 query II select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_repeat(array_repeat(array_repeat(3, 3), 2), 3)); ---- 6 18 +query I +select cardinality(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +6 + # cardinality scalar function #3 query II select cardinality(make_array()), cardinality(make_array(make_array())) ---- NULL 0 +query II +select cardinality(arrow_cast(make_array(), 'LargeList(Null)')), cardinality(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +NULL 0 + # cardinality with columns query III select cardinality(column1), cardinality(column2), cardinality(column3) from arrays; @@ -3344,6 +3420,17 @@ NULL 3 4 4 NULL 1 4 3 NULL +query III +select cardinality(column1), cardinality(column2), cardinality(column3) from large_arrays; +---- +4 3 5 +4 3 5 +4 3 5 +4 3 3 +NULL 3 4 +4 NULL 1 +4 3 NULL + ## array_remove (aliases: `list_remove`) # array_remove scalar function #1 @@ -4530,6 +4617,45 @@ select 1 || make_array(2, 3, 4), 1.0 || make_array(2.0, 3.0, 4.0), 'h' || make_a ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +# array concatenate operator with scalars #4 (mixed) +query ? +select 0 || [1,2,3] || 4 || [5] || [6,7]; +---- +[0, 1, 2, 3, 4, 5, 6, 7] + +# array concatenate operator with nd-list #5 (mixed) +query ? +select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10]; +---- +[[0, 1, 2, 3], [4, 5], [6, 7, 8], [9, 10]] + +# array concatenate operator non-valid cases +## concat 2D with scalar is not valid +query error +select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10] || 11; + +## concat scalar with 2D is not valid +query error +select 0 || [[1,2,3]]; + +# array concatenate operator with column + +statement ok +CREATE TABLE array_concat_operator_table +AS VALUES + (0, [1, 2, 2, 3], 4, [5, 6, 5]), + (-1, [4, 5, 6], 7, [8, 1, 1]) +; + +query ? +select column1 || column2 || column3 || column4 from array_concat_operator_table; +---- +[0, 1, 2, 2, 3, 4, 5, 6, 5] +[-1, 4, 5, 6, 7, 8, 1, 1] + +statement ok +drop table array_concat_operator_table; + ## array containment operator # array containment operator with scalars #1 (at arrow) diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt new file mode 100644 index 000000000000..002aade2528e --- /dev/null +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -0,0 +1,282 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Tests for querying on dictionary encoded data
+
+# Note: These tables model data as is common for timeseries, such as in InfluxDB IOx
+# There are three types of columns:
+# 1. tag columns, which are string dictionaries, often with low cardinality
+# 2. field columns, which are typed,
+# 3. a `time` column, which is a nanosecond timestamp
+
+# It is common to group and filter on the "tag" columns (and thus on dictionary
+# encoded values)
+
+# Table m1 with a tag column `tag_id`, 4 fields `f1` - `f4`, and `time`
+
+statement ok
+CREATE VIEW m1 AS
+SELECT
+  arrow_cast(column1, 'Dictionary(Int32, Utf8)') as tag_id,
+  arrow_cast(column2, 'Float64') as f1,
+  arrow_cast(column3, 'Utf8') as f2,
+  arrow_cast(column4, 'Utf8') as f3,
+  arrow_cast(column5, 'Float64') as f4,
+  arrow_cast(column6, 'Timestamp(Nanosecond, None)') as time
+FROM (
+  VALUES
+    -- equivalent to the following line protocol data
+    -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=1.0 1703030400000000000
+    -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=2.0 1703031000000000000
+    -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=3.0 1703031600000000000
+    -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=4.0 1703032200000000000
+    -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=5.0 1703032800000000000
+    -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=6.0 1703033400000000000
+    -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=7.0 1703034000000000000
+    -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=8.0 1703034600000000000
+    -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=9.0 1703035200000000000
+    -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=10.0 1703035800000000000
+    ('1000', 32, 'foo', 'True', 1.0, 1703030400000000000),
+    ('1000', 32, 'foo', 'True', 2.0, 1703031000000000000),
+    ('1000', 32, 'foo', 'True', 3.0, 1703031600000000000),
+    ('1000', 32, 'foo', 'True', 4.0, 1703032200000000000),
+    ('1000', 32, 'foo', 'True', 5.0, 1703032800000000000),
+    ('1000', 32, 'foo', 'True', 6.0, 1703033400000000000),
+    ('1000', 32, 'foo', 'True', 7.0, 1703034000000000000),
+    ('1000', 32, 'foo', 'True', 8.0, 1703034600000000000),
+    ('1000', 32, 'foo', 'True', 9.0, 1703035200000000000),
+    ('1000', 32, 'foo', 'True', 10.0, 1703035800000000000)
+);
+
+query ?RTTRP
+SELECT * FROM m1;
+----
+1000 32 foo True 1 2023-12-20T00:00:00
+1000 32 foo True 2 2023-12-20T00:10:00
+1000 32 foo True 3 2023-12-20T00:20:00
+1000 32 foo True 4 2023-12-20T00:30:00
+1000 32 foo True 5 2023-12-20T00:40:00
+1000 32 foo True 6 2023-12-20T00:50:00
+1000 32 foo True 7 2023-12-20T01:00:00
+1000 32 foo True 8 2023-12-20T01:10:00
+1000 32 foo True 9 2023-12-20T01:20:00
+1000 32 foo True 10 2023-12-20T01:30:00
+
+# Note that the type of the tag column is `Dictionary(Int32, Utf8)`
+query TTT
+DESCRIBE m1;
+----
+tag_id Dictionary(Int32, Utf8) YES
+f1 Float64 YES
+f2 Utf8 YES
+f3 Utf8 YES
+f4 Float64 YES
+time Timestamp(Nanosecond, None) YES
+
+
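Before the second table is defined: the views here deliberately model low-cardinality tag columns as `Dictionary(Int32, Utf8)`. In arrow-rs terms, such a column stores each distinct string once plus compact integer keys, along these lines (illustrative sketch, assuming the `arrow` crate):

    use arrow::array::{Array, DictionaryArray};
    use arrow::datatypes::Int32Type;

    fn main() {
        // Four rows, only two distinct tag values.
        let tags: DictionaryArray<Int32Type> =
            vec!["1000", "1000", "2000", "1000"].into_iter().collect();
        assert_eq!(tags.len(), 4);
        // Prints Dictionary(Int32, Utf8): Int32 keys into a Utf8 value set.
        println!("{:?}", tags.data_type());
    }

+# Table m2 with tag columns `tag_id` and `type`, a field column `f5`, and `time`
+statement ok
+CREATE VIEW m2 AS
+SELECT
+  arrow_cast(column1, 'Dictionary(Int32, Utf8)') as type,
+  arrow_cast(column2, 'Dictionary(Int32, Utf8)') as tag_id,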
+
+# Table m2 with tag columns `tag_id` and `type`, a field column `f5`, and `time`
+statement ok
+CREATE VIEW m2 AS
+SELECT
+  arrow_cast(column1, 'Dictionary(Int32, Utf8)') as type,
+  arrow_cast(column2, 'Dictionary(Int32, Utf8)') as tag_id,
+  arrow_cast(column3, 'Float64') as f5,
+  arrow_cast(column4, 'Timestamp(Nanosecond, None)') as time
+FROM (
+  VALUES
+    -- equivalent to the following line protocol data
+    -- m2,type=active,tag_id=1000 f5=100 1701648000000000000
+    -- m2,type=active,tag_id=1000 f5=200 1701648600000000000
+    -- m2,type=active,tag_id=1000 f5=300 1701649200000000000
+    -- m2,type=active,tag_id=1000 f5=400 1701649800000000000
+    -- m2,type=active,tag_id=1000 f5=500 1701650400000000000
+    -- m2,type=active,tag_id=1000 f5=600 1701651000000000000
+    -- m2,type=passive,tag_id=1000 f5=700 1701651600000000000
+    -- m2,type=passive,tag_id=1000 f5=800 1701652200000000000
+    -- m2,type=passive,tag_id=1000 f5=900 1701652800000000000
+    -- m2,type=passive,tag_id=1000 f5=1000 1701653400000000000
+    ('active', '1000', 100, 1701648000000000000),
+    ('active', '1000', 200, 1701648600000000000),
+    ('active', '1000', 300, 1701649200000000000),
+    ('active', '1000', 400, 1701649800000000000),
+    ('active', '1000', 500, 1701650400000000000),
+    ('active', '1000', 600, 1701651000000000000),
+    ('passive', '1000', 700, 1701651600000000000),
+    ('passive', '1000', 800, 1701652200000000000),
+    ('passive', '1000', 900, 1701652800000000000),
+    ('passive', '1000', 1000, 1701653400000000000)
+);
+
+query ??RP
+SELECT * FROM m2;
+----
+active 1000 100 2023-12-04T00:00:00
+active 1000 200 2023-12-04T00:10:00
+active 1000 300 2023-12-04T00:20:00
+active 1000 400 2023-12-04T00:30:00
+active 1000 500 2023-12-04T00:40:00
+active 1000 600 2023-12-04T00:50:00
+passive 1000 700 2023-12-04T01:00:00
+passive 1000 800 2023-12-04T01:10:00
+passive 1000 900 2023-12-04T01:20:00
+passive 1000 1000 2023-12-04T01:30:00
+
+query TTT
+DESCRIBE m2;
+----
+type Dictionary(Int32, Utf8) YES
+tag_id Dictionary(Int32, Utf8) YES
+f5 Float64 YES
+time Timestamp(Nanosecond, None) YES
+
+query I
+select count(*) from m1 where tag_id = '1000' and time < '2024-01-03T14:46:35+01:00';
+----
+10
+
+query RRR rowsort
+select min(f5), max(f5), avg(f5) from m2 where tag_id = '1000' and time < '2024-01-03T14:46:35+01:00' group by type;
+----
+100 600 350
+700 1000 850
+
+query IRRRP
+select count(*), min(f5), max(f5), avg(f5), date_bin('30 minutes', time) as "time"
+from m2 where tag_id = '1000' and time < '2024-01-03T14:46:35+01:00'
+group by date_bin('30 minutes', time)
+order by date_bin('30 minutes', time) DESC
+----
+1 1000 1000 1000 2023-12-04T01:30:00
+3 700 900 800 2023-12-04T01:00:00
+3 400 600 500 2023-12-04T00:30:00
+3 100 300 200 2023-12-04T00:00:00
+
+
+
+# Reproducer for https://github.com/apache/arrow-datafusion/issues/8738
+# This query should work correctly
+query P?TT rowsort
+SELECT
+  "data"."timestamp" as "time",
+  "data"."tag_id",
+  "data"."field",
+  "data"."value"
+FROM (
+  (
+    SELECT "m2"."time" as "timestamp", "m2"."tag_id", 'active_power' as "field", "m2"."f5" as "value"
+    FROM "m2"
+    WHERE "m2"."time" >= '2023-12-05T14:46:35+01:00' AND "m2"."time" < '2024-01-03T14:46:35+01:00'
+    AND "m2"."f5" IS NOT NULL
+    AND "m2"."type" IN ('active')
+    AND "m2"."tag_id" IN ('1000')
+  ) UNION (
+    SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f1' as "field", "m1"."f1" as "value"
+    FROM "m1"
+    WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00'
+    AND "m1"."f1" IS NOT NULL
+    AND "m1"."tag_id" IN ('1000')
+  ) UNION (
+    SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f2' as "field", "m1"."f2" as "value"
+    FROM "m1"
+    WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00'
+    AND "m1"."f2" IS NOT NULL
+ AND "m1"."tag_id" IN ('1000') + ) +) as "data" +ORDER BY + "time", + "data"."tag_id" +; +---- +2023-12-20T00:00:00 1000 f1 32.0 +2023-12-20T00:00:00 1000 f2 foo +2023-12-20T00:10:00 1000 f1 32.0 +2023-12-20T00:10:00 1000 f2 foo +2023-12-20T00:20:00 1000 f1 32.0 +2023-12-20T00:20:00 1000 f2 foo +2023-12-20T00:30:00 1000 f1 32.0 +2023-12-20T00:30:00 1000 f2 foo +2023-12-20T00:40:00 1000 f1 32.0 +2023-12-20T00:40:00 1000 f2 foo +2023-12-20T00:50:00 1000 f1 32.0 +2023-12-20T00:50:00 1000 f2 foo +2023-12-20T01:00:00 1000 f1 32.0 +2023-12-20T01:00:00 1000 f2 foo +2023-12-20T01:10:00 1000 f1 32.0 +2023-12-20T01:10:00 1000 f2 foo +2023-12-20T01:20:00 1000 f1 32.0 +2023-12-20T01:20:00 1000 f2 foo +2023-12-20T01:30:00 1000 f1 32.0 +2023-12-20T01:30:00 1000 f2 foo + + +# deterministic sort (so we can avoid rowsort) +query P?TT +SELECT + "data"."timestamp" as "time", + "data"."tag_id", + "data"."field", + "data"."value" +FROM ( + ( + SELECT "m2"."time" as "timestamp", "m2"."tag_id", 'active_power' as "field", "m2"."f5" as "value" + FROM "m2" + WHERE "m2"."time" >= '2023-12-05T14:46:35+01:00' AND "m2"."time" < '2024-01-03T14:46:35+01:00' + AND "m2"."f5" IS NOT NULL + AND "m2"."type" IN ('active') + AND "m2"."tag_id" IN ('1000') + ) UNION ( + SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f1' as "field", "m1"."f1" as "value" + FROM "m1" + WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00' + AND "m1"."f1" IS NOT NULL + AND "m1"."tag_id" IN ('1000') + ) UNION ( + SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f2' as "field", "m1"."f2" as "value" + FROM "m1" + WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00' + AND "m1"."f2" IS NOT NULL + AND "m1"."tag_id" IN ('1000') + ) +) as "data" +ORDER BY + "time", + "data"."tag_id", + "data"."field", + "data"."value" +; +---- +2023-12-20T00:00:00 1000 f1 32.0 +2023-12-20T00:00:00 1000 f2 foo +2023-12-20T00:10:00 1000 f1 32.0 +2023-12-20T00:10:00 1000 f2 foo +2023-12-20T00:20:00 1000 f1 32.0 +2023-12-20T00:20:00 1000 f2 foo +2023-12-20T00:30:00 1000 f1 32.0 +2023-12-20T00:30:00 1000 f2 foo +2023-12-20T00:40:00 1000 f1 32.0 +2023-12-20T00:40:00 1000 f2 foo +2023-12-20T00:50:00 1000 f1 32.0 +2023-12-20T00:50:00 1000 f2 foo +2023-12-20T01:00:00 1000 f1 32.0 +2023-12-20T01:00:00 1000 f2 foo +2023-12-20T01:10:00 1000 f1 32.0 +2023-12-20T01:10:00 1000 f2 foo +2023-12-20T01:20:00 1000 f1 32.0 +2023-12-20T01:20:00 1000 f2 foo +2023-12-20T01:30:00 1000 f1 32.0 +2023-12-20T01:30:00 1000 f2 foo diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 4583ef319b7f..2a39e3138869 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -180,6 +180,7 @@ initial_logical_plan Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c --TableScan: simple_explain_test logical_plan after inline_table_scan SAME TEXT AS ABOVE +logical_plan after operator_to_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE logical_plan after count_wildcard_rule SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt new file mode 100644 index 000000000000..a2a8d9c6475c --- /dev/null +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -0,0 +1,1251 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor 
license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# test_boolean_expressions +query BBBB +SELECT true, false, false = false, true = false +---- +true false true false + +# test_mathematical_expressions_with_null +query RRRRRRRRRRRRRRRRRR?RRRRRRRIRRRRRRBB +SELECT + sqrt(NULL), + cbrt(NULL), + sin(NULL), + cos(NULL), + tan(NULL), + asin(NULL), + acos(NULL), + atan(NULL), + sinh(NULL), + cosh(NULL), + tanh(NULL), + asinh(NULL), + acosh(NULL), + atanh(NULL), + floor(NULL), + ceil(NULL), + round(NULL), + trunc(NULL), + abs(NULL), + signum(NULL), + exp(NULL), + ln(NULL), + log2(NULL), + log10(NULL), + power(NULL, 2), + power(NULL, NULL), + power(2, NULL), + atan2(NULL, NULL), + atan2(1, NULL), + atan2(NULL, 1), + nanvl(NULL, NULL), + nanvl(1, NULL), + nanvl(NULL, 1), + isnan(NULL), + iszero(NULL) +---- +NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL + +# test_array_cast_invalid_timezone_will_panic +statement error Parser error: Invalid timezone "Foo": 'Foo' is not a valid timezone +SELECT arrow_cast('2021-01-02T03:04:00', 'Timestamp(Nanosecond, Some("Foo"))') + +# test_array_index +query III??IIIIII +SELECT + ([5,4,3,2,1])[1], + ([5,4,3,2,1])[2], + ([5,4,3,2,1])[5], + ([[1, 2], [2, 3], [3,4]])[1], + ([[1, 2], [2, 3], [3,4]])[3], + ([[1, 2], [2, 3], [3,4]])[1][1], + ([[1, 2], [2, 3], [3,4]])[2][2], + ([[1, 2], [2, 3], [3,4]])[3][2], + -- out of bounds + ([5,4,3,2,1])[0], + ([5,4,3,2,1])[6], + -- ([5,4,3,2,1])[-1], -- TODO: wrong answer + -- ([5,4,3,2,1])[null], -- TODO: not supported + ([5,4,3,2,1])[100] +---- +5 4 1 [1, 2] [3, 4] 1 3 4 NULL NULL NULL + +# test_array_literals +query ????? +SELECT + [1,2,3,4,5], + [true, false], + ['str1', 'str2'], + [[1,2], [3,4]], + [] +---- +[1, 2, 3, 4, 5] [true, false] [str1, str2] [[1, 2], [3, 4]] [] + +# test_struct_literals +query ?????? +SELECT + STRUCT(1,2,3,4,5), + STRUCT(Null), + STRUCT(2), + STRUCT('1',Null), + STRUCT(true, false), + STRUCT('str1', 'str2') +---- +{c0: 1, c1: 2, c2: 3, c3: 4, c4: 5} {c0: } {c0: 2} {c0: 1, c1: } {c0: true, c1: false} {c0: str1, c1: str2} + +# test binary_bitwise_shift +query IIII +SELECT + 2 << 10, + 2048 >> 10, + 2048 << NULL, + 2048 >> NULL +---- +2048 2 NULL NULL + +query ? +SELECT interval '1' +---- +0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs + +query ? +SELECT interval '1 second' +---- +0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs + +query ? +SELECT interval '500 milliseconds' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs + +query ? +SELECT interval '5 second' +---- +0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs + +query ? +SELECT interval '0.5 minute' +---- +0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs + +query ? 
+SELECT interval '.5 minute'
+----
+0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs
+
+query ?
+SELECT interval '5 minute'
+----
+0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs
+
+query ?
+SELECT interval '5 minute 1 second'
+----
+0 years 0 mons 0 days 0 hours 5 mins 1.000000000 secs
+
+query ?
+SELECT interval '1 hour'
+----
+0 years 0 mons 0 days 1 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '5 hour'
+----
+0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '1 day'
+----
+0 years 0 mons 1 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '1 week'
+----
+0 years 0 mons 7 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '2 weeks'
+----
+0 years 0 mons 14 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '1 day 1'
+----
+0 years 0 mons 1 days 0 hours 0 mins 1.000000000 secs
+
+query ?
+SELECT interval '0.5'
+----
+0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs
+
+query ?
+SELECT interval '0.5 day 1'
+----
+0 years 0 mons 0 days 12 hours 0 mins 1.000000000 secs
+
+query ?
+SELECT interval '0.49 day'
+----
+0 years 0 mons 0 days 11 hours 45 mins 36.000000000 secs
+
+query ?
+SELECT interval '0.499 day'
+----
+0 years 0 mons 0 days 11 hours 58 mins 33.600000000 secs
+
+query ?
+SELECT interval '0.4999 day'
+----
+0 years 0 mons 0 days 11 hours 59 mins 51.360000000 secs
+
+query ?
+SELECT interval '0.49999 day'
+----
+0 years 0 mons 0 days 11 hours 59 mins 59.136000000 secs
+
+query ?
+SELECT interval '0.49999999999 day'
+----
+0 years 0 mons 0 days 11 hours 59 mins 59.999999136 secs
+
+query ?
+SELECT interval '5 day'
+----
+0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs
+
+# Hour is ignored; this matches PostgreSQL
+query ?
+SELECT interval '5 day' hour
+----
+0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds'
+----
+0 years 0 mons 5 days 4 hours 3 mins 2.100000000 secs
+
+query ?
+SELECT interval '0.5 month'
+----
+0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '0.5' month
+----
+0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '1 month'
+----
+0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '1' MONTH
+----
+0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '5 month'
+----
+0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '13 month'
+----
+0 years 13 mons 0 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '0.5 year'
+----
+0 years 6 mons 0 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '1 year'
+----
+0 years 12 mons 0 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '1 decade'
+----
+0 years 120 mons 0 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '2 decades'
+----
+0 years 240 mons 0 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '1 century'
+----
+0 years 1200 mons 0 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '2 year'
+----
+0 years 24 mons 0 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '1 year 1 day'
+----
+0 years 12 mons 1 days 0 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '1 year 1 day 1 hour'
+----
+0 years 12 mons 1 days 1 hours 0 mins 0.000000000 secs
+
+query ?
+SELECT interval '1 year 1 day 1 hour 1 minute' +---- +0 years 12 mons 1 days 1 hours 1 mins 0.000000000 secs + +query ? +SELECT interval '1 year 1 day 1 hour 1 minute 1 second' +---- +0 years 12 mons 1 days 1 hours 1 mins 1.000000000 secs + +query I +SELECT ascii('') +---- +0 + +query I +SELECT ascii('x') +---- +120 + +query I +SELECT ascii(NULL) +---- +NULL + +query I +SELECT bit_length('') +---- +0 + +query I +SELECT bit_length('chars') +---- +40 + +query I +SELECT bit_length('josé') +---- +40 + +query ? +SELECT bit_length(NULL) +---- +NULL + +query T +SELECT btrim(' xyxtrimyyx ', NULL) +---- +NULL + +query T +SELECT btrim(' xyxtrimyyx ') +---- +xyxtrimyyx + +query T +SELECT btrim('\n xyxtrimyyx \n') +---- +\n xyxtrimyyx \n + +query T +SELECT btrim('xyxtrimyyx', 'xyz') +---- +trim + +query T +SELECT btrim('\nxyxtrimyyx\n', 'xyz\n') +---- +trim + +query ? +SELECT btrim(NULL, 'xyz') +---- +NULL + +query T +SELECT chr(CAST(120 AS int)) +---- +x + +query T +SELECT chr(CAST(128175 AS int)) +---- +💯 + +query T +SELECT chr(CAST(NULL AS int)) +---- +NULL + +query T +SELECT concat('a','b','c') +---- +abc + +query T +SELECT concat('abcde', 2, NULL, 22) +---- +abcde222 + +query T +SELECT concat(NULL) +---- +(empty) + +query T +SELECT concat_ws(',', 'abcde', 2, NULL, 22) +---- +abcde,2,22 + +query T +SELECT concat_ws('|','a','b','c') +---- +a|b|c + +query T +SELECT concat_ws('|',NULL) +---- +(empty) + +query T +SELECT concat_ws(NULL,'a',NULL,'b','c') +---- +NULL + +query T +SELECT concat_ws('|','a',NULL) +---- +a + +query T +SELECT concat_ws('|','a',NULL,NULL) +---- +a + +query T +SELECT initcap('') +---- +(empty) + +query T +SELECT initcap('hi THOMAS') +---- +Hi Thomas + +query ? +SELECT initcap(NULL) +---- +NULL + +query T +SELECT lower('') +---- +(empty) + +query T +SELECT lower('TOM') +---- +tom + +query ? +SELECT lower(NULL) +---- +NULL + +query T +SELECT ltrim(' zzzytest ', NULL) +---- +NULL + +query T +SELECT ltrim(' zzzytest ') +---- +zzzytest + +query T +SELECT ltrim('zzzytest', 'xyz') +---- +test + +query ? +SELECT ltrim(NULL, 'xyz') +---- +NULL + +query I +SELECT octet_length('') +---- +0 + +query I +SELECT octet_length('chars') +---- +5 + +query I +SELECT octet_length('josé') +---- +5 + +query ? +SELECT octet_length(NULL) +---- +NULL + +query T +SELECT repeat('Pg', 4) +---- +PgPgPgPg + +query T +SELECT repeat('Pg', CAST(NULL AS INT)) +---- +NULL + +query ? +SELECT repeat(NULL, 4) +---- +NULL + +query T +SELECT replace('abcdefabcdef', 'cd', 'XX') +---- +abXXefabXXef + +query T +SELECT replace('abcdefabcdef', 'cd', NULL) +---- +NULL + +query T +SELECT replace('abcdefabcdef', 'notmatch', 'XX') +---- +abcdefabcdef + +query T +SELECT replace('abcdefabcdef', NULL, 'XX') +---- +NULL + +query ? +SELECT replace(NULL, 'cd', 'XX') +---- +NULL + +query T +SELECT rtrim(' testxxzx ') +---- + testxxzx + +query T +SELECT rtrim(' zzzytest ', NULL) +---- +NULL + +query T +SELECT rtrim('testxxzx', 'xyz') +---- +test + +query ? +SELECT rtrim(NULL, 'xyz') +---- +NULL + +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', 2) +---- +def + +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', 20) +---- +(empty) + +query ? 
+SELECT split_part(NULL, '~@~', 20) +---- +NULL + +query T +SELECT split_part('abc~@~def~@~ghi', NULL, 20) +---- +NULL + +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', CAST(NULL AS INT)) +---- +NULL + +query B +SELECT starts_with('alphabet', 'alph') +---- +true + +query B +SELECT starts_with('alphabet', 'blph') +---- +false + +query B +SELECT starts_with(NULL, 'blph') +---- +NULL + +query B +SELECT starts_with('alphabet', NULL) +---- +NULL + +query T +SELECT to_hex(2147483647) +---- +7fffffff + +query T +SELECT to_hex(9223372036854775807) +---- +7fffffffffffffff + +query T +SELECT to_hex(CAST(NULL AS int)) +---- +NULL + +query T +SELECT trim(' tom ') +---- +tom + +query T +SELECT trim(LEADING ' tom ') +---- +tom + +query T +SELECT trim(TRAILING ' tom ') +---- + tom + +query T +SELECT trim(BOTH ' tom ') +---- +tom + +query T +SELECT trim(LEADING ' ' FROM ' tom ') +---- +tom + +query T +SELECT trim(TRAILING ' ' FROM ' tom ') +---- + tom + +query T +SELECT trim(BOTH ' ' FROM ' tom ') +---- +tom + +query T +SELECT trim(' ' FROM ' tom ') +---- +tom + +query T +SELECT trim(LEADING 'x' FROM 'xxxtomxxx') +---- +tomxxx + +query T +SELECT trim(TRAILING 'x' FROM 'xxxtomxxx') +---- +xxxtom + +query T +SELECT trim(BOTH 'x' FROM 'xxxtomxx') +---- +tom + +query T +SELECT trim('x' FROM 'xxxtomxx') +---- +tom + + +query T +SELECT trim(LEADING 'xy' FROM 'xyxabcxyzdefxyx') +---- +abcxyzdefxyx + +query T +SELECT trim(TRAILING 'xy' FROM 'xyxabcxyzdefxyx') +---- +xyxabcxyzdef + +query T +SELECT trim(BOTH 'xy' FROM 'xyxabcxyzdefxyx') +---- +abcxyzdef + +query T +SELECT trim('xy' FROM 'xyxabcxyzdefxyx') +---- +abcxyzdef + +query T +SELECT trim(' tom') +---- +tom + +query T +SELECT trim('') +---- +(empty) + +query T +SELECT trim('tom ') +---- +tom + +query T +SELECT upper('') +---- +(empty) + +query T +SELECT upper('tom') +---- +TOM + +query ? +SELECT upper(NULL) +---- +NULL + +# TODO issue: https://github.com/apache/arrow-datafusion/issues/6596 +# query ?? 
+#SELECT +# CAST([1,2,3,4] AS INT[]) as a, +# CAST([1,2,3,4] AS NUMERIC(10,4)[]) as b +#---- +#[1, 2, 3, 4] [1.0000, 2.0000, 3.0000, 4.0000] + +# test_random_expression +query BB +SELECT + random() BETWEEN 0.0 AND 1.0, + random() = random() +---- +true false + +# test_uuid_expression +query II +SELECT octet_length(uuid()), length(uuid()) +---- +36 36 + +# test_cast_expressions +query IIII +SELECT + CAST('0' AS INT) as a, + CAST(NULL AS INT) as b, + TRY_CAST('0' AS INT) as c, + TRY_CAST('x' AS INT) as d +---- +0 NULL 0 NULL + +# test_extract_date_part + +query R +SELECT date_part('YEAR', CAST('2000-01-01' AS DATE)) +---- +2000 + +query R +SELECT EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query R +SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query R +SELECT date_part('MONTH', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query R +SELECT date_part('WEEK', CAST('2003-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query R +SELECT date_part('DAY', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query R +SELECT date_part('DOY', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query R +SELECT date_part('DOW', CAST('2000-01-01' AS DATE)) +---- +6 + +query R +SELECT EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query R +SELECT date_part('HOUR', CAST('2000-01-01' AS DATE)) +---- +0 + +query R +SELECT EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query R +SELECT EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query R +SELECT date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query R +SELECT EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12.12345678 + +query R +SELECT EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123.45678 + +query R +SELECT EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456.78 + +query R +SELECT EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456780 + +# Keep precision when coercing Utf8 to Timestamp +query R +SELECT date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12.12345678 + +query R +SELECT date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123.45678 + +query R +SELECT date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456.78 + +query R +SELECT date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456780 + +query R +SELECT date_part('second', '2020-09-08T12:00:12.12345678+00:00') +---- +12.12345678 + +query R +SELECT date_part('millisecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123.45678 + +query R +SELECT date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123456.78 + +query R +SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123456780 + +# test_extract_epoch + +query R +SELECT extract(epoch from '1870-01-01T07:29:10.256'::timestamp) +---- +-3155646649.744 + +query R +SELECT extract(epoch from 
'2000-01-01T00:00:00.000'::timestamp) +---- +946684800 + +query R +SELECT extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00')) +---- +946684800 + +query R +SELECT extract(epoch from NULL::timestamp) +---- +NULL + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date32')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date32')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date32')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date32')) +---- +-86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date64')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date64')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date64')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date64')) +---- +-86400 + +# test_extract_date_part_func + +query B +SELECT (date_part('year', now()) = EXTRACT(year FROM now())) +---- +true + +query B +SELECT (date_part('quarter', now()) = EXTRACT(quarter FROM now())) +---- +true + +query B +SELECT (date_part('month', now()) = EXTRACT(month FROM now())) +---- +true + +query B +SELECT (date_part('week', now()) = EXTRACT(week FROM now())) +---- +true + +query B +SELECT (date_part('day', now()) = EXTRACT(day FROM now())) +---- +true + +query B +SELECT (date_part('hour', now()) = EXTRACT(hour FROM now())) +---- +true + +query B +SELECT (date_part('minute', now()) = EXTRACT(minute FROM now())) +---- +true + +query B +SELECT (date_part('second', now()) = EXTRACT(second FROM now())) +---- +true + +query B +SELECT (date_part('millisecond', now()) = EXTRACT(millisecond FROM now())) +---- +true + +query B +SELECT (date_part('microsecond', now()) = EXTRACT(microsecond FROM now())) +---- +true + +query B +SELECT (date_part('nanosecond', now()) = EXTRACT(nanosecond FROM now())) +---- +true + +query B +SELECT 'a' IN ('a','b') +---- +true + +query B +SELECT 'c' IN ('a','b') +---- +false + +query B +SELECT 'c' NOT IN ('a','b') +---- +true + +query B +SELECT 'a' NOT IN ('a','b') +---- +false + +query B +SELECT NULL IN ('a','b') +---- +NULL + +query B +SELECT NULL NOT IN ('a','b') +---- +NULL + +query B +SELECT 'a' IN ('a','b',NULL) +---- +true + +query B +SELECT 'c' IN ('a','b',NULL) +---- +NULL + +query B +SELECT 'a' NOT IN ('a','b',NULL) +---- +false + +query B +SELECT 'c' NOT IN ('a','b',NULL) +---- +NULL + +query B +SELECT 0 IN (0,1,2) +---- +true + +query B +SELECT 3 IN (0,1,2) +---- +false + +query B +SELECT 3 NOT IN (0,1,2) +---- +true + +query B +SELECT 0 NOT IN (0,1,2) +---- +false + +query B +SELECT NULL IN (0,1,2) +---- +NULL + +query B +SELECT NULL NOT IN (0,1,2) +---- +NULL + +query B +SELECT 0 IN (0,1,2,NULL) +---- +true + +query B +SELECT 3 IN (0,1,2,NULL) +---- +NULL + +query B +SELECT 0 NOT IN (0,1,2,NULL) +---- +false + +query B +SELECT 3 NOT IN (0,1,2,NULL) +---- +NULL + +query B +SELECT 0.0 IN (0.0,0.1,0.2) +---- +true + +query B +SELECT 0.3 IN (0.0,0.1,0.2) +---- +false + +query B +SELECT 0.3 NOT IN (0.0,0.1,0.2) +---- +true + +query B +SELECT 0.0 NOT IN (0.0,0.1,0.2) +---- +false + +query B +SELECT NULL IN (0.0,0.1,0.2) +---- +NULL + +query B +SELECT NULL NOT IN (0.0,0.1,0.2) +---- +NULL + +query B +SELECT 0.0 IN (0.0,0.1,0.2,NULL) +---- +true + +query B +SELECT 0.3 IN (0.0,0.1,0.2,NULL) +---- +NULL + +query B +SELECT 0.0 NOT IN (0.0,0.1,0.2,NULL) +---- +false + +query B +SELECT 0.3 NOT IN (0.0,0.1,0.2,NULL) +---- +NULL + +query B +SELECT '1' IN ('a','b',1) 
+---- +true + +query B +SELECT '2' IN ('a','b',1) +---- +false + +query B +SELECT '2' NOT IN ('a','b',1) +---- +true + +query B +SELECT '1' NOT IN ('a','b',1) +---- +false + +query B +SELECT NULL IN ('a','b',1) +---- +NULL + +query B +SELECT NULL NOT IN ('a','b',1) +---- +NULL + +query B +SELECT '1' IN ('a','b',NULL,1) +---- +true + +query B +SELECT '2' IN ('a','b',NULL,1) +---- +NULL + +query B +SELECT '1' NOT IN ('a','b',NULL,1) +---- +false + +query B +SELECT '2' NOT IN ('a','b',NULL,1) +---- +NULL diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index c9dd7ca604ad..ca9b918ff3ee 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -626,6 +626,38 @@ Alice 100 Alice 1 Alice 50 Alice 2 Alice 100 Alice 2 +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +set datafusion.optimizer.repartition_joins = true; + +# make sure when target partition is 1, hash repartition is not added +# to the final plan. +query TT +EXPLAIN SELECT * +FROM t1, +t1 as t2 +WHERE t1.a=t2.a; +---- +logical_plan +Inner Join: t1.a = t2.a +--TableScan: t1 projection=[a, b] +--SubqueryAlias: t2 +----TableScan: t1 projection=[a, b] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0)] +----MemoryExec: partitions=1, partition_sizes=[1] +----MemoryExec: partitions=1, partition_sizes=[1] + +# Reset the configs to old values +statement ok +set datafusion.execution.target_partitions = 4; + +statement ok +set datafusion.optimizer.repartition_joins = false; + statement ok DROP TABLE t1; diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 02eccd7c5d06..9d4951c7ecac 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -63,6 +63,26 @@ CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 ----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..101], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:101..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:202..303], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:303..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 +# disable round robin repartitioning +statement ok +set datafusion.optimizer.enable_round_robin_repartition = false; + +## Expect to see the scan read the file as "4" groups with even sizes (offsets) again +query TT +EXPLAIN SELECT column1 FROM parquet_table WHERE column1 <> 42; +---- +logical_plan +Filter: parquet_table.column1 != Int32(42) +--TableScan: parquet_table projection=[column1], partial_filters=[parquet_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..101], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:101..202], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:202..303], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:303..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 + +# enable round robin repartitioning again +statement ok +set datafusion.optimizer.enable_round_robin_repartition = true; + # create a second parquet file statement ok COPY (VALUES (100), (200)) TO 'test_files/scratch/repartition_scan/parquet_table/1.parquet' @@ -147,7 +167,7 @@ WITH HEADER ROW LOCATION 'test_files/scratch/repartition_scan/csv_table/'; query I -select * from csv_table; +select * from csv_table ORDER BY column1; ---- 1 2 @@ -190,7 +210,7 @@ STORED AS json LOCATION 'test_files/scratch/repartition_scan/json_table/'; query I -select * from "json_table"; +select * from "json_table" ORDER BY column1; ---- 1 2 diff --git a/datafusion/sqllogictest/test_files/tpch/q2.slt.part b/datafusion/sqllogictest/test_files/tpch/q2.slt.part index ed439348d22d..ed950db190bb 100644 --- a/datafusion/sqllogictest/test_files/tpch/q2.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q2.slt.part @@ -238,7 +238,7 @@ order by p_partkey limit 10; ---- -9828.21 Supplier#000000647 UNITED KINGDOM 13120 Manufacturer#5 x5U7MBZmwfG9 33-258-202-4782 s the slyly even ideas poach fluffily +9828.21 Supplier#000000647 UNITED KINGDOM 13120 Manufacturer#5 x5U7MBZmwfG9 33-258-202-4782 s the slyly even ideas poach fluffily 9508.37 Supplier#000000070 FRANCE 3563 Manufacturer#1 INWNH2w,OOWgNDq0BRCcBwOMQc6PdFDc4 16-821-608-1166 ests sleep quickly express ideas. ironic ideas haggle about the final T 9508.37 Supplier#000000070 FRANCE 17268 Manufacturer#4 INWNH2w,OOWgNDq0BRCcBwOMQc6PdFDc4 16-821-608-1166 ests sleep quickly express ideas. ironic ideas haggle about the final T 9453.01 Supplier#000000802 ROMANIA 10021 Manufacturer#5 ,6HYXb4uaHITmtMBj4Ak57Pd 29-342-882-6463 gular frets. 
permanently special multipliers believe blithely alongs diff --git a/docs/logos/DataFUSION-Logo-Dark.svg b/docs/logos/DataFUSION-Logo-Dark.svg new file mode 100644 index 000000000000..e16f244430e6 --- /dev/null +++ b/docs/logos/DataFUSION-Logo-Dark.svg @@ -0,0 +1 @@ +DataFUSION-Logo-Dark \ No newline at end of file diff --git a/docs/logos/DataFUSION-Logo-Dark@2x.png b/docs/logos/DataFUSION-Logo-Dark@2x.png new file mode 100644 index 000000000000..cc60f12a0e4f Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Dark@2x.png differ diff --git a/docs/logos/DataFUSION-Logo-Dark@4x.png b/docs/logos/DataFUSION-Logo-Dark@4x.png new file mode 100644 index 000000000000..0503c216ac84 Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Dark@4x.png differ diff --git a/docs/logos/DataFUSION-Logo-Light.svg b/docs/logos/DataFUSION-Logo-Light.svg new file mode 100644 index 000000000000..b3bef2193dde --- /dev/null +++ b/docs/logos/DataFUSION-Logo-Light.svg @@ -0,0 +1 @@ +DataFUSION-Logo-Light \ No newline at end of file diff --git a/docs/logos/DataFUSION-Logo-Light@2x.png b/docs/logos/DataFUSION-Logo-Light@2x.png new file mode 100644 index 000000000000..8992213b0e60 Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Light@2x.png differ diff --git a/docs/logos/DataFUSION-Logo-Light@4x.png b/docs/logos/DataFUSION-Logo-Light@4x.png new file mode 100644 index 000000000000..bd329ca21956 Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Light@4x.png differ diff --git a/docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf b/docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf new file mode 100644 index 000000000000..4594c50f9044 Binary files /dev/null and b/docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf differ diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index a9d0f30bcf8e..9f7880049856 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -3,3 +3,24 @@ {# Silence the navbar #} {% block docs_navbar %} {% endblock %} + + +{% block footer %} + + + +{% endblock %} diff --git a/docs/source/conf.py b/docs/source/conf.py index 3fa6c6091d6f..becece330d1a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -33,9 +33,9 @@ # -- Project information ----------------------------------------------------- -project = 'Arrow DataFusion' -copyright = '2023, Apache Software Foundation' -author = 'Arrow DataFusion Authors' +project = 'Apache Arrow DataFusion' +copyright = '2019-2024, Apache Software Foundation' +author = 'Apache Software Foundation' # -- General configuration ---------------------------------------------------