From 4cde998f1e6dc3f572f2b3ed09994190f5d770ce Mon Sep 17 00:00:00 2001 From: Huaijin Date: Tue, 16 Jan 2024 10:16:30 +0800 Subject: [PATCH 01/39] fix: don't extract common sub expr in `CASE WHEN` clause (#8833) * fix: don't extract common sub expr in CASE WHEN clause * fix ci * fix --- .../optimizer/src/common_subexpr_eliminate.rs | 24 +++++++----- datafusion/optimizer/src/push_down_filter.rs | 39 +++---------------- datafusion/optimizer/src/utils.rs | 16 ++++++++ .../sqllogictest/test_files/functions.slt | 2 +- datafusion/sqllogictest/test_files/select.slt | 19 +++++++++ .../sqllogictest/test_files/tpch/q14.slt.part | 6 +-- 6 files changed, 58 insertions(+), 48 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 000329d0d078..fc867df23c36 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -20,6 +20,7 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; +use crate::utils::is_volatile_expression; use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; @@ -29,7 +30,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; -use datafusion_expr::expr::{is_volatile, Alias}; +use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; @@ -518,7 +519,7 @@ enum ExprMask { } impl ExprMask { - fn ignores(&self, expr: &Expr) -> Result { + fn ignores(&self, expr: &Expr) -> bool { let is_normal_minus_aggregates = matches!( expr, Expr::Literal(..) @@ -529,14 +530,12 @@ impl ExprMask { | Expr::Wildcard { .. } ); - let is_volatile = is_volatile(expr)?; - let is_aggr = matches!(expr, Expr::AggregateFunction(..)); - Ok(match self { - Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr, - Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates, - }) + match self { + Self::Normal => is_normal_minus_aggregates || is_aggr, + Self::NormalAndAggregates => is_normal_minus_aggregates, + } } } @@ -614,7 +613,12 @@ impl ExprIdentifierVisitor<'_> { impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { type N = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { + fn pre_visit(&mut self, expr: &Expr) -> Result { + // related to https://github.com/apache/arrow-datafusion/issues/8814 + // If the expr contain volatile expression or is a case expression, skip it. + if matches!(expr, Expr::Case(..)) || is_volatile_expression(expr)? { + return Ok(VisitRecursion::Skip); + } self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; @@ -628,7 +632,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { let (idx, sub_expr_desc) = self.pop_enter_mark(); // skip exprs should not be recognize. - if self.expr_mask.ignores(expr)? { + if self.expr_mask.ignores(expr) { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 4eb925ac0629..7086c5cda56f 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -19,6 +19,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::optimizer::ApplyOrder; +use crate::utils::is_volatile_expression; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; @@ -34,7 +35,7 @@ use datafusion_expr::logical_plan::{ use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; use datafusion_expr::{ and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, - ScalarFunctionDefinition, TableProviderFilterPushDown, Volatility, + ScalarFunctionDefinition, TableProviderFilterPushDown, }; use itertools::Itertools; @@ -739,7 +740,9 @@ impl OptimizerRule for PushDownFilter { (field.qualified_name(), expr) }) - .partition(|(_, value)| is_volatile_expression(value)); + .partition(|(_, value)| { + is_volatile_expression(value).unwrap_or(true) + }); let mut push_predicates = vec![]; let mut keep_predicates = vec![]; @@ -1028,38 +1031,6 @@ pub fn replace_cols_by_name( }) } -/// check whether the expression is volatile predicates -fn is_volatile_expression(e: &Expr) -> bool { - let mut is_volatile = false; - e.apply(&mut |expr| { - Ok(match expr { - Expr::ScalarFunction(f) => match &f.func_def { - ScalarFunctionDefinition::BuiltIn(fun) - if fun.volatility() == Volatility::Volatile => - { - is_volatile = true; - VisitRecursion::Stop - } - ScalarFunctionDefinition::UDF(fun) - if fun.signature().volatility == Volatility::Volatile => - { - is_volatile = true; - VisitRecursion::Stop - } - ScalarFunctionDefinition::Name(_) => { - return internal_err!( - "Function `Expr` with name should be resolved." - ); - } - _ => VisitRecursion::Continue, - }, - _ => VisitRecursion::Continue, - }) - }) - .unwrap(); - is_volatile -} - /// check whether the expression uses the columns in `check_map`. fn contain(e: &Expr, check_map: &HashMap) -> bool { let mut is_contain = false; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 44f2404afade..5671dc6ae94d 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,8 +18,10 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{Column, DFSchemaRef}; use datafusion_common::{DFSchema, Result}; +use datafusion_expr::expr::is_volatile; use datafusion_expr::expr_rewriter::replace_col; use datafusion_expr::utils as expr_utils; use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator}; @@ -92,6 +94,20 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) { trace!("{description}::\n{}\n", plan.display_indent_schema()); } +/// check whether the expression is volatile predicates +pub(crate) fn is_volatile_expression(e: &Expr) -> Result { + let mut is_volatile_expr = false; + e.apply(&mut |expr| { + Ok(if is_volatile(expr)? { + is_volatile_expr = true; + VisitRecursion::Stop + } else { + VisitRecursion::Continue + }) + })?; + Ok(is_volatile_expr) +} + /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// See [`split_conjunction_owned`] for more details and an example. diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 1903088b0748..7bd60a3a154b 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -998,6 +998,6 @@ NULL # Verify that multiple calls to volatile functions like `random()` are not combined / optimized away query B -SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0) +SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0) ---- false diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 132bcdd246fe..ca48c07b0914 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1112,3 +1112,22 @@ SELECT abs(x), abs(x) + abs(y) FROM t; statement ok DROP TABLE t; + +# related to https://github.com/apache/arrow-datafusion/issues/8814 +statement ok +create table t(x int, y int) as values (1,1), (2,2), (3,3), (0,0), (4,0); + +query II +SELECT +CASE WHEN B.x > 0 THEN A.x / B.x ELSE 0 END AS value1, +CASE WHEN B.x > 0 AND B.y > 0 THEN A.x / B.x ELSE 0 END AS value3 +FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B; +---- +0 0 +0 0 +0 0 +0 0 +0 0 + +statement ok +DROP TABLE t; diff --git a/datafusion/sqllogictest/test_files/tpch/q14.slt.part b/datafusion/sqllogictest/test_files/tpch/q14.slt.part index b584972c25bc..7e614ab49e38 100644 --- a/datafusion/sqllogictest/test_files/tpch/q14.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q14.slt.part @@ -33,8 +33,8 @@ where ---- logical_plan Projection: Float64(100) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue ---Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] -----Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, part.p_type +--Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +----Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_type ------Inner Join: lineitem.l_partkey = part.p_partkey --------Projection: lineitem.l_partkey, lineitem.l_extendedprice, lineitem.l_discount ----------Filter: lineitem.l_shipdate >= Date32("9374") AND lineitem.l_shipdate < Date32("9404") @@ -45,7 +45,7 @@ ProjectionExec: expr=[100 * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") --AggregateExec: mode=Final, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ----CoalescePartitionsExec ------AggregateExec: mode=Partial, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ---------ProjectionExec: expr=[l_extendedprice@1 * (Some(1),20,0 - l_discount@2) as lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, p_type@4 as p_type] +--------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, p_type@4 as p_type] ----------CoalesceBatchesExec: target_batch_size=8192 ------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)] --------------CoalesceBatchesExec: target_batch_size=8192 From 08de64d3778ef6ad79d0f60f462c2c66ed84b12e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 16 Jan 2024 03:51:19 -0500 Subject: [PATCH 02/39] Add "Extended" clickbench queries (#8861) --- benchmarks/bench.sh | 17 ++++++ benchmarks/queries/clickbench/README.md | 33 +++++++++++ benchmarks/queries/clickbench/README.txt | 1 - benchmarks/queries/clickbench/extended.sql | 1 + benchmarks/src/clickbench.rs | 66 ++++++++++++++-------- 5 files changed, 92 insertions(+), 26 deletions(-) create mode 100644 benchmarks/queries/clickbench/README.md delete mode 100644 benchmarks/queries/clickbench/README.txt create mode 100644 benchmarks/queries/clickbench/extended.sql diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index bdbdc0e51762..ccaf26eb798d 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -74,6 +74,7 @@ parquet: Benchmark of parquet reader's filtering speed sort: Benchmark of sorting speed clickbench_1: ClickBench queries against a single parquet file clickbench_partitioned: ClickBench queries against a partitioned (100 files) parquet +clickbench_extended: ClickBench "inspired" queries against a single parquet (DataFusion specific) ********** * Supported Configuration (Environment Variables) @@ -155,6 +156,9 @@ main() { clickbench_partitioned) data_clickbench_partitioned ;; + clickbench_extended) + data_clickbench_1 + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" usage @@ -193,6 +197,7 @@ main() { run_sort run_clickbench_1 run_clickbench_partitioned + run_clickbench_extended ;; tpch) run_tpch "1" @@ -218,6 +223,9 @@ main() { clickbench_partitioned) run_clickbench_partitioned ;; + clickbench_extended) + run_clickbench_extended + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" usage @@ -401,6 +409,15 @@ run_clickbench_partitioned() { $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE} } +# Runs the clickbench "extended" benchmark with a single large parquet file +run_clickbench_extended() { + RESULTS_FILE="${RESULTS_DIR}/clickbench_extended.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running clickbench (1 file) extended benchmark..." + $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o ${RESULTS_FILE} +} + + compare_benchmarks() { BASE_RESULTS_DIR="${SCRIPT_DIR}/results" BRANCH1="${ARG2}" diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md new file mode 100644 index 000000000000..d5105afd4832 --- /dev/null +++ b/benchmarks/queries/clickbench/README.md @@ -0,0 +1,33 @@ +# ClickBench queries + +This directory contains queries for the ClickBench benchmark https://benchmark.clickhouse.com/ + +ClickBench is focused on aggregation and filtering performance (though it has no Joins) + +## Files: +* `queries.sql` - Actual ClickBench queries, downloaded from the [ClickBench repository] +* `extended.sql` - "Extended" DataFusion specific queries. + +[ClickBench repository]: https://github.com/ClickHouse/ClickBench/blob/main/datafusion/queries.sql + +## "Extended" Queries +The "extended" queries are not part of the official ClickBench benchmark. +Instead they are used to test other DataFusion features that are not +covered by the standard benchmark + +Each description below is for the corresponding line in `extended.sql` (line 1 +is `Q0`, line 2 is `Q1`, etc.) + +### Q0 +Models initial Data exploration, to understand some statistics of data. +Import Query Properties: multiple `COUNT DISTINCT` on strings + +```sql +SELECT + COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") +FROM hits; +``` + + + + diff --git a/benchmarks/queries/clickbench/README.txt b/benchmarks/queries/clickbench/README.txt deleted file mode 100644 index b46900956e54..000000000000 --- a/benchmarks/queries/clickbench/README.txt +++ /dev/null @@ -1 +0,0 @@ -Downloaded from https://github.com/ClickHouse/ClickBench/blob/main/datafusion/queries.sql diff --git a/benchmarks/queries/clickbench/extended.sql b/benchmarks/queries/clickbench/extended.sql new file mode 100644 index 000000000000..82c0266af61a --- /dev/null +++ b/benchmarks/queries/clickbench/extended.sql @@ -0,0 +1 @@ +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; \ No newline at end of file diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index a6d32eb39f31..69a650a106c7 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. +use std::path::Path; use std::{path::PathBuf, time::Instant}; use datafusion::{ - common::exec_err, error::{DataFusionError, Result}, prelude::SessionContext, }; +use datafusion_common::exec_datafusion_err; use structopt::StructOpt; use crate::{BenchmarkRun, CommonOpt}; @@ -69,15 +70,49 @@ pub struct RunOpt { output_path: Option, } -const CLICKBENCH_QUERY_START_ID: usize = 0; -const CLICKBENCH_QUERY_END_ID: usize = 42; +struct AllQueries { + queries: Vec, +} + +impl AllQueries { + fn try_new(path: &Path) -> Result { + // ClickBench has all queries in a single file identified by line number + let all_queries = std::fs::read_to_string(path) + .map_err(|e| exec_datafusion_err!("Could not open {path:?}: {e}"))?; + Ok(Self { + queries: all_queries.lines().map(|s| s.to_string()).collect(), + }) + } + + /// Returns the text of query `query_id` + fn get_query(&self, query_id: usize) -> Result<&str> { + self.queries + .get(query_id) + .ok_or_else(|| { + let min_id = self.min_query_id(); + let max_id = self.max_query_id(); + exec_datafusion_err!( + "Invalid query id {query_id}. Must be between {min_id} and {max_id}" + ) + }) + .map(|s| s.as_str()) + } + + fn min_query_id(&self) -> usize { + 0 + } + fn max_query_id(&self) -> usize { + self.queries.len() - 1 + } +} impl RunOpt { pub async fn run(self) -> Result<()> { println!("Running benchmarks with the following options: {self:?}"); + let queries = AllQueries::try_new(self.queries_path.as_path())?; let query_range = match self.query { Some(query_id) => query_id..=query_id, - None => CLICKBENCH_QUERY_START_ID..=CLICKBENCH_QUERY_END_ID, + None => queries.min_query_id()..=queries.max_query_id(), }; let config = self.common.config(); @@ -88,12 +123,12 @@ impl RunOpt { let mut benchmark_run = BenchmarkRun::new(); for query_id in query_range { benchmark_run.start_new_case(&format!("Query {query_id}")); - let sql = self.get_query(query_id)?; + let sql = queries.get_query(query_id)?; println!("Q{query_id}: {sql}"); for i in 0..iterations { let start = Instant::now(); - let results = ctx.sql(&sql).await?.collect().await?; + let results = ctx.sql(sql).await?.collect().await?; let elapsed = start.elapsed(); let ms = elapsed.as_secs_f64() * 1000.0; let row_count: usize = results.iter().map(|b| b.num_rows()).sum(); @@ -120,23 +155,4 @@ impl RunOpt { ) }) } - - /// Returns the text of query `query_id` - fn get_query(&self, query_id: usize) -> Result { - if query_id > CLICKBENCH_QUERY_END_ID { - return exec_err!( - "Invalid query id {query_id}. Must be between {CLICKBENCH_QUERY_START_ID} and {CLICKBENCH_QUERY_END_ID}" - ); - } - - let path = self.queries_path.as_path(); - - // ClickBench has all queries in a single file identified by line number - let all_queries = std::fs::read_to_string(path).map_err(|e| { - DataFusionError::Execution(format!("Could not open {path:?}: {e}")) - })?; - let all_queries: Vec<_> = all_queries.lines().collect(); - - Ok(all_queries.get(query_id).map(|s| s.to_string()).unwrap()) - } } From 5433b52e93ec1031cf08d73d20556421f604a1e0 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 16 Jan 2024 03:53:07 -0800 Subject: [PATCH 03/39] Change cli to propagate error to exit code (#8856) * feat(cli): use error exit code when command or file * fix: remove prints * style: rust fmt * refactor: use CLI specific error for display * refactor: better use statements * Revert "refactor: better use statements" This reverts commit fac8c3a2e9c3072307679b149543ead834ef1035. * Revert "refactor: use CLI specific error for display" This reverts commit e58d331439438566acc4779bca9a873ed77d0818. * refactor: wrap main_inner, use ExitCode --- datafusion-cli/src/command.rs | 3 ++- datafusion-cli/src/exec.rs | 26 ++++++++++++++------------ datafusion-cli/src/main.rs | 20 ++++++++++++++++---- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index f7f36b6f9d51..feef137e6195 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -79,7 +79,8 @@ impl Command { filename, e )) })?; - exec_from_lines(ctx, &mut BufReader::new(file), print_options).await; + exec_from_lines(ctx, &mut BufReader::new(file), print_options) + .await?; Ok(()) } else { exec_err!("Required filename argument is missing") diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 637fc7e4d9e8..aabf69aac888 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -53,13 +53,12 @@ pub async fn exec_from_commands( ctx: &mut SessionContext, commands: Vec, print_options: &PrintOptions, -) { +) -> Result<()> { for sql in commands { - match exec_and_print(ctx, print_options, sql).await { - Ok(_) => {} - Err(err) => println!("{err}"), - } + exec_and_print(ctx, print_options, sql).await?; } + + Ok(()) } /// run and execute SQL statements and commands from a file, against a context with the given print options @@ -67,7 +66,7 @@ pub async fn exec_from_lines( ctx: &mut SessionContext, reader: &mut BufReader, print_options: &PrintOptions, -) { +) -> Result<()> { let mut query = "".to_owned(); for line in reader.lines() { @@ -97,26 +96,28 @@ pub async fn exec_from_lines( // run the left over query if the last statement doesn't contain ‘;’ // ignore if it only consists of '\n' if query.contains(|c| c != '\n') { - match exec_and_print(ctx, print_options, query).await { - Ok(_) => {} - Err(err) => println!("{err}"), - } + exec_and_print(ctx, print_options, query).await?; } + + Ok(()) } pub async fn exec_from_files( ctx: &mut SessionContext, files: Vec, print_options: &PrintOptions, -) { +) -> Result<()> { let files = files .into_iter() .map(|file_path| File::open(file_path).unwrap()) .collect::>(); + for file in files { let mut reader = BufReader::new(file); - exec_from_lines(ctx, &mut reader, print_options).await; + exec_from_lines(ctx, &mut reader, print_options).await?; } + + Ok(()) } /// run and execute SQL statements and commands against a context with the given print options @@ -215,6 +216,7 @@ async fn exec_and_print( MsSQL, ClickHouse, BigQuery, Ansi." ) })?; + let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { let mut plan = ctx.state().statement_to_plan(statement).await?; diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index dcfd28df1cb0..a9082f2e5351 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -18,6 +18,7 @@ use std::collections::HashMap; use std::env; use std::path::Path; +use std::process::ExitCode; use std::str::FromStr; use std::sync::{Arc, OnceLock}; @@ -138,7 +139,18 @@ struct Args { } #[tokio::main] -pub async fn main() -> Result<()> { +/// Calls [`main_inner`], then handles printing errors and returning the correct exit code +pub async fn main() -> ExitCode { + if let Err(e) = main_inner().await { + println!("Error: {e}"); + return ExitCode::FAILURE; + } + + ExitCode::SUCCESS +} + +/// Main CLI entrypoint +async fn main_inner() -> Result<()> { env_logger::init(); let args = Args::parse(); @@ -216,7 +228,7 @@ pub async fn main() -> Result<()> { if commands.is_empty() && files.is_empty() { if !rc.is_empty() { - exec::exec_from_files(&mut ctx, rc, &print_options).await + exec::exec_from_files(&mut ctx, rc, &print_options).await?; } // TODO maybe we can have thiserror for cli but for now let's keep it simple return exec::exec_from_repl(&mut ctx, &mut print_options) @@ -225,11 +237,11 @@ pub async fn main() -> Result<()> { } if !files.is_empty() { - exec::exec_from_files(&mut ctx, files, &print_options).await; + exec::exec_from_files(&mut ctx, files, &print_options).await?; } if !commands.is_empty() { - exec::exec_from_commands(&mut ctx, commands, &print_options).await; + exec::exec_from_commands(&mut ctx, commands, &print_options).await?; } Ok(()) From b64ad79f8bb70f44e6744538619eccee661b2874 Mon Sep 17 00:00:00 2001 From: Dejan Simic <10134699+simicd@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:12:28 +0100 Subject: [PATCH 04/39] test: Port tests in `predicates.rs` to sqllogictest (#8879) --- datafusion/core/tests/sql/mod.rs | 202 +---------------- datafusion/core/tests/sql/predicates.rs | 186 ---------------- .../sqllogictest/test_files/predicates.slt | 206 +++++++++++++++++- 3 files changed, 206 insertions(+), 388 deletions(-) delete mode 100644 datafusion/core/tests/sql/predicates.rs diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 80c6c81ef955..0960f93ae103 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -21,10 +21,9 @@ use arrow::{ array::*, datatypes::*, record_batch::RecordBatch, util::display::array_value_to_string, }; -use chrono::prelude::*; use datafusion::datasource::TableProvider; -use datafusion::error::{DataFusionError, Result}; +use datafusion::error::Result; use datafusion::logical_expr::{Aggregate, LogicalPlan, TableScan}; use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::ExecutionPlan; @@ -34,12 +33,10 @@ use datafusion::test_util; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion::{datasource::MemTable, physical_plan::collect}; use datafusion::{execution::context::SessionContext, physical_plan::displayable}; -use datafusion_common::plan_err; use datafusion_common::{assert_contains, assert_not_contains}; use object_store::path::Path; use std::fs::File; use std::io::Write; -use std::ops::Sub; use std::path::PathBuf; use tempfile::TempDir; @@ -77,7 +74,6 @@ pub mod explain_analyze; pub mod expr; pub mod joins; pub mod partitioned_csv; -pub mod predicates; pub mod references; pub mod repartition; pub mod select; @@ -211,202 +207,6 @@ fn create_left_semi_anti_join_context_with_null_ids( Ok(ctx) } -fn get_tpch_table_schema(table: &str) -> Schema { - match table { - "customer" => Schema::new(vec![ - Field::new("c_custkey", DataType::Int64, false), - Field::new("c_name", DataType::Utf8, false), - Field::new("c_address", DataType::Utf8, false), - Field::new("c_nationkey", DataType::Int64, false), - Field::new("c_phone", DataType::Utf8, false), - Field::new("c_acctbal", DataType::Decimal128(15, 2), false), - Field::new("c_mktsegment", DataType::Utf8, false), - Field::new("c_comment", DataType::Utf8, false), - ]), - - "orders" => Schema::new(vec![ - Field::new("o_orderkey", DataType::Int64, false), - Field::new("o_custkey", DataType::Int64, false), - Field::new("o_orderstatus", DataType::Utf8, false), - Field::new("o_totalprice", DataType::Decimal128(15, 2), false), - Field::new("o_orderdate", DataType::Date32, false), - Field::new("o_orderpriority", DataType::Utf8, false), - Field::new("o_clerk", DataType::Utf8, false), - Field::new("o_shippriority", DataType::Int32, false), - Field::new("o_comment", DataType::Utf8, false), - ]), - - "lineitem" => Schema::new(vec![ - Field::new("l_orderkey", DataType::Int64, false), - Field::new("l_partkey", DataType::Int64, false), - Field::new("l_suppkey", DataType::Int64, false), - Field::new("l_linenumber", DataType::Int32, false), - Field::new("l_quantity", DataType::Decimal128(15, 2), false), - Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), - Field::new("l_discount", DataType::Decimal128(15, 2), false), - Field::new("l_tax", DataType::Decimal128(15, 2), false), - Field::new("l_returnflag", DataType::Utf8, false), - Field::new("l_linestatus", DataType::Utf8, false), - Field::new("l_shipdate", DataType::Date32, false), - Field::new("l_commitdate", DataType::Date32, false), - Field::new("l_receiptdate", DataType::Date32, false), - Field::new("l_shipinstruct", DataType::Utf8, false), - Field::new("l_shipmode", DataType::Utf8, false), - Field::new("l_comment", DataType::Utf8, false), - ]), - - "nation" => Schema::new(vec![ - Field::new("n_nationkey", DataType::Int64, false), - Field::new("n_name", DataType::Utf8, false), - Field::new("n_regionkey", DataType::Int64, false), - Field::new("n_comment", DataType::Utf8, false), - ]), - - "supplier" => Schema::new(vec![ - Field::new("s_suppkey", DataType::Int64, false), - Field::new("s_name", DataType::Utf8, false), - Field::new("s_address", DataType::Utf8, false), - Field::new("s_nationkey", DataType::Int64, false), - Field::new("s_phone", DataType::Utf8, false), - Field::new("s_acctbal", DataType::Decimal128(15, 2), false), - Field::new("s_comment", DataType::Utf8, false), - ]), - - "partsupp" => Schema::new(vec![ - Field::new("ps_partkey", DataType::Int64, false), - Field::new("ps_suppkey", DataType::Int64, false), - Field::new("ps_availqty", DataType::Int32, false), - Field::new("ps_supplycost", DataType::Decimal128(15, 2), false), - Field::new("ps_comment", DataType::Utf8, false), - ]), - - "part" => Schema::new(vec![ - Field::new("p_partkey", DataType::Int64, false), - Field::new("p_name", DataType::Utf8, false), - Field::new("p_mfgr", DataType::Utf8, false), - Field::new("p_brand", DataType::Utf8, false), - Field::new("p_type", DataType::Utf8, false), - Field::new("p_size", DataType::Int32, false), - Field::new("p_container", DataType::Utf8, false), - Field::new("p_retailprice", DataType::Decimal128(15, 2), false), - Field::new("p_comment", DataType::Utf8, false), - ]), - - "region" => Schema::new(vec![ - Field::new("r_regionkey", DataType::Int64, false), - Field::new("r_name", DataType::Utf8, false), - Field::new("r_comment", DataType::Utf8, false), - ]), - - _ => unimplemented!("Table: {}", table), - } -} - -async fn register_tpch_csv(ctx: &SessionContext, table: &str) -> Result<()> { - let schema = get_tpch_table_schema(table); - - ctx.register_csv( - table, - format!("tests/tpch-csv/{table}.csv").as_str(), - CsvReadOptions::new().schema(&schema), - ) - .await?; - Ok(()) -} - -async fn register_tpch_csv_data( - ctx: &SessionContext, - table_name: &str, - data: &str, -) -> Result<()> { - let schema = Arc::new(get_tpch_table_schema(table_name)); - - let mut reader = ::csv::ReaderBuilder::new() - .has_headers(false) - .from_reader(data.as_bytes()); - let records: Vec<_> = reader.records().map(|it| it.unwrap()).collect(); - - let mut cols: Vec> = vec![]; - for field in schema.fields().iter() { - match field.data_type() { - DataType::Utf8 => cols.push(Box::new(StringBuilder::new())), - DataType::Date32 => { - cols.push(Box::new(Date32Builder::with_capacity(records.len()))) - } - DataType::Int32 => { - cols.push(Box::new(Int32Builder::with_capacity(records.len()))) - } - DataType::Int64 => { - cols.push(Box::new(Int64Builder::with_capacity(records.len()))) - } - DataType::Decimal128(_, _) => { - cols.push(Box::new(Decimal128Builder::with_capacity(records.len()))) - } - _ => plan_err!("Not implemented: {}", field.data_type())?, - } - } - - for record in records.iter() { - for (idx, val) in record.iter().enumerate() { - let col = cols.get_mut(idx).unwrap(); - let field = schema.field(idx); - match field.data_type() { - DataType::Utf8 => { - let sb = col.as_any_mut().downcast_mut::().unwrap(); - sb.append_value(val); - } - DataType::Date32 => { - let sb = col.as_any_mut().downcast_mut::().unwrap(); - let dt = NaiveDate::parse_from_str(val.trim(), "%Y-%m-%d").unwrap(); - let dt = dt - .sub(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()) - .num_days() as i32; - sb.append_value(dt); - } - DataType::Int32 => { - let sb = col.as_any_mut().downcast_mut::().unwrap(); - sb.append_value(val.trim().parse().unwrap()); - } - DataType::Int64 => { - let sb = col.as_any_mut().downcast_mut::().unwrap(); - sb.append_value(val.trim().parse().unwrap()); - } - DataType::Decimal128(_, _) => { - let sb = col - .as_any_mut() - .downcast_mut::() - .unwrap(); - let val = val.trim().replace('.', ""); - let value_i128 = val.parse::().unwrap(); - sb.append_value(value_i128); - } - _ => plan_err!("Not implemented: {}", field.data_type())?, - } - } - } - let cols: Vec = cols - .iter_mut() - .zip(schema.fields()) - .map(|(it, field)| match field.data_type() { - DataType::Decimal128(p, s) => Arc::new( - it.as_any_mut() - .downcast_mut::() - .unwrap() - .finish() - .with_precision_and_scale(*p, *s) - .unwrap(), - ), - _ => it.finish(), - }) - .collect(); - - let batch = RecordBatch::try_new(Arc::clone(&schema), cols)?; - - let _ = ctx.register_batch(table_name, batch).unwrap(); - - Ok(()) -} - async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { let testdata = datafusion::test_util::arrow_test_data(); diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs deleted file mode 100644 index fe735bf6b828..000000000000 --- a/datafusion/core/tests/sql/predicates.rs +++ /dev/null @@ -1,186 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; - -#[tokio::test] -async fn string_coercion() -> Result<()> { - let vendor_id_utf8: StringArray = - vec![Some("124"), Some("345")].into_iter().collect(); - - let vendor_id_dict: DictionaryArray = - vec![Some("124"), Some("345")].into_iter().collect(); - - let batch = RecordBatch::try_from_iter(vec![ - ("vendor_id_utf8", Arc::new(vendor_id_utf8) as _), - ("vendor_id_dict", Arc::new(vendor_id_dict) as _), - ]) - .unwrap(); - - let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; - - let expected = [ - "+----------------+----------------+", - "| vendor_id_utf8 | vendor_id_dict |", - "+----------------+----------------+", - "| 124 | 124 |", - "+----------------+----------------+", - ]; - - // Compare utf8 column with numeric constant - let sql = "SELECT * from t where vendor_id_utf8 = 124"; - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - - // Compare dictionary encoded utf8 column with numeric constant - let sql = "SELECT * from t where vendor_id_dict = 124"; - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - - // Compare dictionary encoded utf8 column with numeric constant with explicit cast - let sql = "SELECT * from t where cast(vendor_id_utf8 as varchar) = 124"; - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -// Test issue: https://github.com/apache/arrow-datafusion/issues/3635 -async fn multiple_or_predicates() -> Result<()> { - let ctx = SessionContext::new(); - register_tpch_csv(&ctx, "lineitem").await?; - register_tpch_csv(&ctx, "part").await?; - let sql = "explain select - l_partkey - from - lineitem, - part - where - ( - p_partkey = l_partkey - and p_brand = 'Brand#12' - and l_quantity >= 1 and l_quantity <= 1 + 10 - and p_size between 1 and 5 - ) - or - ( - p_partkey = l_partkey - and p_brand = 'Brand#23' - and l_quantity >= 10 and l_quantity <= 10 + 10 - and p_size between 1 and 10 - ) - or - ( - p_partkey = l_partkey - and p_brand = 'Brand#34' - and l_quantity >= 20 and l_quantity <= 20 + 10 - and p_size between 1 and 15 - )"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - // Note that we expect `part.p_partkey = lineitem.l_partkey` to have been - // factored out and appear only once in the following plan - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: lineitem.l_partkey [l_partkey:Int64]", - " Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", - " Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", - " TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - Ok(()) -} - -// Fix for issue#78 join predicates from inside of OR expr also pulled up properly. -#[tokio::test] -async fn tpch_q19_pull_predicates_to_innerjoin_simplified() -> Result<()> { - let ctx = SessionContext::new(); - - register_tpch_csv(&ctx, "part").await?; - register_tpch_csv(&ctx, "lineitem").await?; - - let partsupp = r#"63700,7311,100,993.49,ven ideas. quickly even packages print. pending multipliers must have to are fluff"#; - register_tpch_csv_data(&ctx, "partsupp", partsupp).await?; - - let sql = r#" -select - p_partkey, - sum(l_extendedprice), - avg(l_discount), - count(distinct ps_suppkey) -from - lineitem, - part, - partsupp -where - ( - p_partkey = l_partkey - and p_brand = 'Brand#12' - and p_partkey = ps_partkey - ) - or - ( - p_partkey = l_partkey - and p_brand = 'Brand#23' - and ps_partkey = p_partkey - ) - group by p_partkey - ;"#; - - let dataframe = ctx.sql(sql).await.unwrap(); - let plan = dataframe.into_optimized_plan().unwrap(); - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - let expected = vec![ - "Aggregate: groupBy=[[part.p_partkey]], aggr=[[SUM(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(DISTINCT partsupp.ps_suppkey)]] [p_partkey:Int64, SUM(lineitem.l_extendedprice):Decimal128(25, 2);N, AVG(lineitem.l_discount):Decimal128(19, 6);N, COUNT(DISTINCT partsupp.ps_suppkey):Int64;N]", - " Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_partkey, partsupp.ps_suppkey [l_extendedprice:Decimal128(15, 2), l_discount:Decimal128(15, 2), p_partkey:Int64, ps_suppkey:Int64]", - " Inner Join: part.p_partkey = partsupp.ps_partkey [l_extendedprice:Decimal128(15, 2), l_discount:Decimal128(15, 2), p_partkey:Int64, ps_partkey:Int64, ps_suppkey:Int64]", - " Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_partkey [l_extendedprice:Decimal128(15, 2), l_discount:Decimal128(15, 2), p_partkey:Int64]", - " Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_extendedprice:Decimal128(15, 2), l_discount:Decimal128(15, 2), p_partkey:Int64]", - " TableScan: lineitem projection=[l_partkey, l_extendedprice, l_discount] [l_partkey:Int64, l_extendedprice:Decimal128(15, 2), l_discount:Decimal128(15, 2)]", - " Projection: part.p_partkey [p_partkey:Int64]", - " Filter: part.p_brand = Utf8(\"Brand#12\") OR part.p_brand = Utf8(\"Brand#23\") [p_partkey:Int64, p_brand:Utf8]", - " TableScan: part projection=[p_partkey, p_brand], partial_filters=[part.p_brand = Utf8(\"Brand#12\") OR part.p_brand = Utf8(\"Brand#23\")] [p_partkey:Int64, p_brand:Utf8]", - " TableScan: partsupp projection=[ps_partkey, ps_suppkey] [ps_partkey:Int64, ps_suppkey:Int64]", - ]; - - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - // assert data - let results = execute_to_batches(&ctx, sql).await; - let expected = ["+-----------+-------------------------------+--------------------------+-------------------------------------+", - "| p_partkey | SUM(lineitem.l_extendedprice) | AVG(lineitem.l_discount) | COUNT(DISTINCT partsupp.ps_suppkey) |", - "+-----------+-------------------------------+--------------------------+-------------------------------------+", - "| 63700 | 13309.60 | 0.100000 | 1 |", - "+-----------+-------------------------------+--------------------------+-------------------------------------+"]; - assert_batches_eq!(expected, &results); - - Ok(()) -} diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index e992a440d0a2..e32e415338a7 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -16,7 +16,7 @@ # under the License. ########## -## Limit Tests +## Predicates Tests ########## statement ok @@ -521,3 +521,207 @@ set datafusion.execution.parquet.bloom_filter_enabled=false; ######## statement ok DROP TABLE data_index_bloom_encoding_stats; + + +######## +# String coercion +######## + +statement error DataFusion error: SQL error: ParserError\("Expected a data type name, found: ,"\) +CREATE TABLE t(vendor_id_utf8, vendor_id_dict) +AS VALUES +(arrow_cast('124', 'Utf8'), arrow_cast('124', 'Dictionary(Int16, Utf8)')), +(arrow_cast('345', 'Utf8'), arrow_cast('345', 'Dictionary(Int16, Utf8)')); + +query error DataFusion error: Error during planning: table 'datafusion\.public\.t' not found +SELECT * FROM t WHERE vendor_id_utf8 = 124; + +query error DataFusion error: Error during planning: table 'datafusion\.public\.t' not found +SELECT * FROM t WHERE vendor_id_dict = 124 + +query error DataFusion error: Error during planning: table 'datafusion\.public\.t' not found +SELECT * FROM t WHERE cast(vendor_id_utf8 as varchar) = 124 + +######## +# Multiple OR predicates +######## + +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( + l_orderkey BIGINT, + l_partkey BIGINT, + l_suppkey BIGINT, + l_linenumber INTEGER, + l_quantity DECIMAL(15, 2), + l_extendedprice DECIMAL(15, 2), + l_discount DECIMAL(15, 2), + l_tax DECIMAL(15, 2), + l_returnflag VARCHAR, + l_linestatus VARCHAR, + l_shipdate DATE, + l_commitdate DATE, + l_receiptdate DATE, + l_shipinstruct VARCHAR, + l_shipmode VARCHAR, + l_comment VARCHAR, +) STORED AS CSV DELIMITER ',' WITH HEADER ROW LOCATION '../core/tests/tpch-csv/lineitem.csv'; + +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS part ( + p_partkey BIGINT, + p_name VARCHAR, + p_mfgr VARCHAR, + p_brand VARCHAR, + p_type VARCHAR, + p_size INT, + p_container VARCHAR, + p_retailprice DECIMAL(15, 2), + p_comment VARCHAR, +) STORED AS CSV DELIMITER ',' WITH HEADER ROW LOCATION '../core/tests/tpch-csv/part.csv'; + +query TT +EXPLAIN SELECT l_partkey FROM +lineitem, part WHERE +( + p_partkey = l_partkey + AND p_brand = 'Brand#12' + AND l_quantity >= 1 AND l_quantity <= 1 + 10 + AND p_size BETWEEN 1 AND 5 +) +OR +( + p_partkey = l_partkey + AND p_brand = 'Brand#23' + AND l_quantity >= 10 AND l_quantity <= 10 + 10 + AND p_size BETWEEN 1 AND 10 +) +OR +( + p_partkey = l_partkey + AND p_brand = 'Brand#34' + AND l_quantity >= 20 AND l_quantity <= 20 + 10 + AND p_size BETWEEN 1 AND 15 +) +---- +logical_plan +Projection: lineitem.l_partkey +--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) +----Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) +------TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] +----Filter: (part.p_brand = Utf8("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) +------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_size <= Int32(15)] +physical_plan +ProjectionExec: expr=[l_partkey@0 as l_partkey] +--CoalesceBatchesExec: target_batch_size=8192 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15 +------CoalesceBatchesExec: target_batch_size=8192 +--------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 +----------CoalesceBatchesExec: target_batch_size=8192 +------------FilterExec: l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2 +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_partkey, l_quantity], has_header=true +------CoalesceBatchesExec: target_batch_size=8192 +--------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +----------CoalesceBatchesExec: target_batch_size=8192 +------------FilterExec: (p_brand@1 = Brand#12 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_size@2 <= 15) AND p_size@2 >= 1 +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/part.csv]]}, projection=[p_partkey, p_brand, p_size], has_header=true + + +######## +# TPCH Q19 - Pull predicates to inner join (simplified) +######## +statement ok +CREATE TABLE IF NOT EXISTS partsupp ( + ps_partkey BIGINT, + ps_suppkey BIGINT, + ps_availqty INTEGER, + ps_supplycost DECIMAL(15, 2), + ps_comment VARCHAR, +) AS VALUES +(63700, 7311, 100, 993.49, 'ven ideas. quickly even packages print. pending multipliers must have to are fluff'); + +query IRRI +SELECT + p_partkey, + sum(l_extendedprice), + avg(l_discount), + count(distinct ps_suppkey) +FROM + lineitem, + part, + partsupp +WHERE +( + p_partkey = l_partkey + AND p_brand = 'Brand#12' + AND p_partkey = ps_partkey +) +OR +( + p_partkey = l_partkey + AND p_brand = 'Brand#23' + AND ps_partkey = p_partkey +) +GROUP BY p_partkey; +---- +63700 13309.6 0.1 1 + + +query TT +EXPLAIN SELECT + p_partkey, + sum(l_extendedprice), + avg(l_discount), + count(distinct ps_suppkey) +FROM + lineitem, + part, + partsupp +WHERE +( + p_partkey = l_partkey + AND p_brand = 'Brand#12' + AND p_partkey = ps_partkey +) +OR +( + p_partkey = l_partkey + AND p_brand = 'Brand#23' + AND ps_partkey = p_partkey +) +GROUP BY p_partkey; +---- +logical_plan +Aggregate: groupBy=[[part.p_partkey]], aggr=[[SUM(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(DISTINCT partsupp.ps_suppkey)]] +--Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_partkey, partsupp.ps_suppkey +----Inner Join: part.p_partkey = partsupp.ps_partkey +------Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_partkey +--------Inner Join: lineitem.l_partkey = part.p_partkey +----------TableScan: lineitem projection=[l_partkey, l_extendedprice, l_discount] +----------Projection: part.p_partkey +------------Filter: part.p_brand = Utf8("Brand#12") OR part.p_brand = Utf8("Brand#23") +--------------TableScan: part projection=[p_partkey, p_brand], partial_filters=[part.p_brand = Utf8("Brand#12") OR part.p_brand = Utf8("Brand#23")] +------TableScan: partsupp projection=[ps_partkey, ps_suppkey] +physical_plan +AggregateExec: mode=SinglePartitioned, gby=[p_partkey@2 as p_partkey], aggr=[SUM(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(DISTINCT partsupp.ps_suppkey)] +--ProjectionExec: expr=[l_extendedprice@0 as l_extendedprice, l_discount@1 as l_discount, p_partkey@2 as p_partkey, ps_suppkey@4 as ps_suppkey] +----CoalesceBatchesExec: target_batch_size=8192 +------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@2, ps_partkey@0)] +--------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, p_partkey@3 as p_partkey] +----------CoalesceBatchesExec: target_batch_size=8192 +------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)] +--------------CoalesceBatchesExec: target_batch_size=8192 +----------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 +------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_partkey, l_extendedprice, l_discount], has_header=true +--------------CoalesceBatchesExec: target_batch_size=8192 +----------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------FilterExec: p_brand@1 = Brand#12 OR p_brand@1 = Brand#23 +------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/part.csv]]}, projection=[p_partkey, p_brand], has_header=true +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] From 08cee667637048284a2ac21ad06f40f426e60057 Mon Sep 17 00:00:00 2001 From: Cancai Cai <77189278+caicancai@users.noreply.github.com> Date: Tue, 16 Jan 2024 21:22:59 +0800 Subject: [PATCH 05/39] docs: Update contributor guide with installation instructions (#8876) --- docs/source/contributor-guide/index.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index cb0fe63abd91..11dcbd935a02 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -95,9 +95,16 @@ Compiling DataFusion from sources requires an installed version of the protobuf On most platforms this can be installed from your system's package manager ``` +# Ubuntu $ sudo apt install -y protobuf-compiler + +# Fedora $ dnf install -y protobuf-devel + +# Arch Linux $ pacman -S protobuf + +# macOS $ brew install protobuf ``` From 7b7e80d5090b97deb75ce3252327b956eb9b773b Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 16 Jan 2024 21:35:19 +0800 Subject: [PATCH 06/39] add tests between nested list and largelist (#8882) --- datafusion/sqllogictest/test_files/arrow_typeof.slt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 6a623e6c92f9..5e9e7ff03d8b 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -358,6 +358,10 @@ select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'List(Int64)')); ---- List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +query T +select arrow_typeof(arrow_cast(arrow_cast(make_array([1, 2, 3]), 'LargeList(LargeList(Int64))'), 'List(List(Int64))')); +---- +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) ## LargeList @@ -376,3 +380,8 @@ query T select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')); ---- LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +query T +select arrow_typeof(arrow_cast(make_array([1, 2, 3]), 'LargeList(LargeList(Int64))')); +---- +LargeList(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) \ No newline at end of file From d2ff112de073f63490f049425f197d6066ea1980 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Tue, 16 Jan 2024 09:05:19 -0500 Subject: [PATCH 07/39] Disable Parallel Parquet Writer by Default, Improve Writing Test Coverage (#8854) * disable parallel writer and add test * more tests * --complete sqllogic tests * make rows distinct and add make_array of struct --- datafusion/common/src/config.rs | 2 +- datafusion/sqllogictest/test_files/copy.slt | 19 +++++++++++++++++++ .../test_files/information_schema.slt | 4 ++-- .../test_files/repartition_scan.slt | 8 ++++---- docs/source/user-guide/configs.md | 2 +- 5 files changed, 27 insertions(+), 8 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 5c051a7dee82..996a505dea80 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -403,7 +403,7 @@ config_namespace! { /// parquet files by serializing them in parallel. Each column /// in each row group in each output file are serialized in parallel /// leveraging a maximum possible core count of n_files*n_row_groups*n_columns. - pub allow_single_file_parallelism: bool, default = true + pub allow_single_file_parallelism: bool, default = false /// By default parallel parquet writer is tuned for minimum /// memory usage in a streaming execution plan. You may see diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 89b23917884c..9f5b7af41577 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -64,6 +64,25 @@ select * from validate_parquet; 1 Foo 2 Bar +query ?? +COPY +(values (struct ('foo', (struct ('foo', make_array(struct('a',1), struct('b',2))))), make_array(timestamp '2023-01-01 01:00:01',timestamp '2023-01-01 01:00:01')), +(struct('bar', (struct ('foo', make_array(struct('aa',10), struct('bb',20))))), make_array(timestamp '2024-01-01 01:00:01', timestamp '2024-01-01 01:00:01'))) +to 'test_files/scratch/copy/table_nested' (format parquet, single_file_output false); +---- +2 + +# validate multiple parquet file output +statement ok +CREATE EXTERNAL TABLE validate_parquet_nested STORED AS PARQUET LOCATION 'test_files/scratch/copy/table_nested/'; + +query ?? +select * from validate_parquet_nested; +---- +{c0: foo, c1: {c0: foo, c1: [{c0: a, c1: 1}, {c0: b, c1: 2}]}} [2023-01-01T01:00:01, 2023-01-01T01:00:01] +{c0: bar, c1: {c0: foo, c1: [{c0: aa, c1: 10}, {c0: bb, c1: 20}]}} [2024-01-01T01:00:01, 2024-01-01T01:00:01] + + # Copy parquet with all supported statment overrides query IT COPY source_table diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index f8893bf7ae5c..44daa5141677 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -154,7 +154,7 @@ datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 datafusion.execution.minimum_parallel_output_files 4 -datafusion.execution.parquet.allow_single_file_parallelism true +datafusion.execution.parquet.allow_single_file_parallelism false datafusion.execution.parquet.bloom_filter_enabled false datafusion.execution.parquet.bloom_filter_fpp NULL datafusion.execution.parquet.bloom_filter_ndv NULL @@ -229,7 +229,7 @@ datafusion.execution.listing_table_ignore_subdirectory true Should sub directori datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. -datafusion.execution.parquet.allow_single_file_parallelism true Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. +datafusion.execution.parquet.allow_single_file_parallelism false Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. datafusion.execution.parquet.bloom_filter_enabled false Sets if bloom filter is enabled for any column datafusion.execution.parquet.bloom_filter_fpp NULL Sets bloom filter false positive probability. If NULL, uses default parquet writer setting datafusion.execution.parquet.bloom_filter_ndv NULL Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 4b8c8f2f084e..5ee0da2d33e8 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -61,7 +61,7 @@ Filter: parquet_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..153], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:153..306], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:306..459], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:459..610]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] # disable round robin repartitioning statement ok @@ -77,7 +77,7 @@ Filter: parquet_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..153], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:153..306], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:306..459], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:459..610]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] # enable round robin repartitioning again statement ok @@ -102,7 +102,7 @@ SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --SortExec: expr=[column1@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=8192 ------FilterExec: column1@0 != 42 ---------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..205], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:205..405, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..210], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:210..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..303], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:303..601, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..308], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:308..610]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] ## Read the files as though they are ordered @@ -138,7 +138,7 @@ physical_plan SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --CoalesceBatchesExec: target_batch_size=8192 ----FilterExec: column1@0 != 42 -------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..207], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:207..414], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:202..405]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..300], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..305], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:305..610], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:300..601]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] # Cleanup statement ok diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 7111ea1d0ab5..5e26e2b205dd 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -71,7 +71,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.bloom_filter_enabled | false | Sets if bloom filter is enabled for any column | | datafusion.execution.parquet.bloom_filter_fpp | NULL | Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.bloom_filter_ndv | NULL | Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.allow_single_file_parallelism | true | Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | +| datafusion.execution.parquet.allow_single_file_parallelism | false | Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | | datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | | datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | | datafusion.execution.aggregate.scalar_update_factor | 10 | Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. | From 8cf1abb03c1f6aaad32e59b3c59b202acb55259b Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Tue, 16 Jan 2024 19:16:48 +0300 Subject: [PATCH 08/39] Support for order sensitive `NTH_VALUE` aggregation, make reverse `ARRAY_AGG` more efficient (#8841) * Initial commit * minor changes * Parse index argument * Move nth_value to array_agg * Initial implementation (with redundant data) * Add new test * Add reverse support * Add new slt tests * Add multi partition support * Minor changes * Minor changes * Add new aggregator to the proto * Remove redundant tests * Keep n entries in the state for nth value * Change implementation * Move nth value to its own file * Minor changes * minor changes * Review * Update comments * Use drain method to remove from the beginning. * Add reverse support, convert buffer to vecdeque * Minor changes * Minor changes * Review Part 2 * Review Part 3 * Add new_list from iter * Convert API to receive vecdeque * Receive mutable argument * Refactor merge implementation * Fix doctest --------- Co-authored-by: Mehmet Ozan Kabak --- datafusion/common/src/scalar.rs | 83 +++- datafusion/expr/src/aggregate_function.rs | 41 +- .../expr/src/type_coercion/aggregates.rs | 13 +- .../src/aggregate/array_agg_ordered.rs | 344 ++++++++------- .../physical-expr/src/aggregate/build_in.rs | 37 +- datafusion/physical-expr/src/aggregate/mod.rs | 4 +- .../physical-expr/src/aggregate/nth_value.rs | 400 ++++++++++++++++++ .../physical-expr/src/aggregate/utils.rs | 4 +- .../physical-expr/src/expressions/mod.rs | 11 +- .../physical-plan/src/aggregates/mod.rs | 67 ++- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 27 +- datafusion/proto/src/logical_plan/to_proto.rs | 5 + .../sqllogictest/test_files/group_by.slt | 140 ++++++ 16 files changed, 943 insertions(+), 240 deletions(-) create mode 100644 datafusion/physical-expr/src/aggregate/nth_value.rs diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index cc5b70796e88..9cbd9e292ff3 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -19,11 +19,13 @@ use std::borrow::Borrow; use std::cmp::Ordering; -use std::collections::HashSet; -use std::convert::{Infallible, TryInto}; +use std::collections::{HashSet, VecDeque}; +use std::convert::{Infallible, TryFrom, TryInto}; +use std::fmt; use std::hash::Hash; +use std::iter::repeat; use std::str::FromStr; -use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; +use std::sync::Arc; use crate::arrow_datafusion_err; use crate::cast::{ @@ -33,23 +35,22 @@ use crate::cast::{ use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; use crate::utils::{array_into_large_list_array, array_into_list_array}; + use arrow::compute::kernels::numeric::*; -use arrow::datatypes::{i256, Fields, SchemaBuilder}; use arrow::util::display::{ArrayFormatter, FormatOptions}; use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type, - Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, - IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, + i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, + Field, Fields, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, + IntervalYearMonthType, SchemaBuilder, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION, }, }; use arrow_array::cast::as_list_array; -use arrow_array::types::ArrowTimestampType; -use arrow_array::{ArrowNativeTypeOp, Scalar}; /// A dynamically typed, nullable single value, (the single-valued counter-part /// to arrow's [`Array`]) @@ -1728,6 +1729,43 @@ impl ScalarValue { Arc::new(array_into_list_array(values)) } + /// Converts `IntoIterator` where each element has type corresponding to + /// `data_type`, to a [`ListArray`]. + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{ListArray, Int32Array}; + /// use arrow::datatypes::{DataType, Int32Type}; + /// use datafusion_common::cast::as_list_array; + /// + /// let scalars = vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)) + /// ]; + /// + /// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32); + /// + /// let expected = ListArray::from_iter_primitive::( + /// vec![ + /// Some(vec![Some(1), None, Some(2)]) + /// ]); + /// + /// assert_eq!(*result, expected); + /// ``` + pub fn new_list_from_iter( + values: impl IntoIterator + ExactSizeIterator, + data_type: &DataType, + ) -> Arc { + let values = if values.len() == 0 { + new_empty_array(data_type) + } else { + Self::iter_to_array(values).unwrap() + }; + Arc::new(array_into_list_array(values)) + } + /// Converts `Vec` where each element has type corresponding to /// `data_type`, to a [`LargeListArray`]. /// @@ -2626,6 +2664,18 @@ impl ScalarValue { .sum::() } + /// Estimates [size](Self::size) of [`VecDeque`] in bytes. + /// + /// Includes the size of the [`VecDeque`] container itself. + pub fn size_of_vec_deque(vec_deque: &VecDeque) -> usize { + std::mem::size_of_val(vec_deque) + + (std::mem::size_of::() * vec_deque.capacity()) + + vec_deque + .iter() + .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .sum::() + } + /// Estimates [size](Self::size) of [`HashSet`] in bytes. /// /// Includes the size of the [`HashSet`] container itself. @@ -3151,22 +3201,19 @@ impl ScalarType for TimestampNanosecondType { #[cfg(test)] mod tests { - use super::*; - use std::cmp::Ordering; use std::sync::Arc; - use chrono::NaiveDate; - use rand::Rng; + use super::*; + use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; use arrow::buffer::OffsetBuffer; - use arrow::compute::kernels; - use arrow::compute::{concat, is_null}; - use arrow::datatypes::ArrowPrimitiveType; + use arrow::compute::{concat, is_null, kernels}; + use arrow::datatypes::{ArrowNumericType, ArrowPrimitiveType}; use arrow::util::pretty::pretty_format_columns; - use arrow_array::ArrowNumericType; - use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; + use chrono::NaiveDate; + use rand::Rng; #[test] fn test_to_array_of_size_for_list() { diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 9db7635d99a0..574de3e7082a 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -17,12 +17,15 @@ //! Aggregate function module contains all built-in aggregate functions definitions +use std::sync::Arc; +use std::{fmt, str::FromStr}; + use crate::utils; use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility}; + use arrow::datatypes::{DataType, Field}; use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; -use std::sync::Arc; -use std::{fmt, str::FromStr}; + use strum_macros::EnumIter; /// Enum of all built-in aggregate functions @@ -30,26 +33,28 @@ use strum_macros::EnumIter; // https://arrow.apache.org/datafusion/contributor-guide/index.html#how-to-add-a-new-aggregate-function #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum AggregateFunction { - /// count + /// Count Count, - /// sum + /// Sum Sum, - /// min + /// Minimum Min, - /// max + /// Maximum Max, - /// avg + /// Average Avg, - /// median + /// Median Median, - /// Approximate aggregate function + /// Approximate distinct function ApproxDistinct, - /// array_agg + /// Aggregation into an array ArrayAgg, - /// first_value + /// First value in a group according to some ordering FirstValue, - /// last_value + /// Last value in a group according to some ordering LastValue, + /// N'th value in a group according to some ordering + NthValue, /// Variance (Sample) Variance, /// Variance (Population) @@ -100,7 +105,7 @@ pub enum AggregateFunction { BoolAnd, /// Bool Or BoolOr, - /// string_agg + /// String aggregation StringAgg, } @@ -118,6 +123,7 @@ impl AggregateFunction { ArrayAgg => "ARRAY_AGG", FirstValue => "FIRST_VALUE", LastValue => "LAST_VALUE", + NthValue => "NTH_VALUE", Variance => "VAR", VariancePop => "VAR_POP", Stddev => "STDDEV", @@ -174,6 +180,7 @@ impl FromStr for AggregateFunction { "array_agg" => AggregateFunction::ArrayAgg, "first_value" => AggregateFunction::FirstValue, "last_value" => AggregateFunction::LastValue, + "nth_value" => AggregateFunction::NthValue, "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, @@ -300,9 +307,9 @@ impl AggregateFunction { Ok(coerced_data_types[0].clone()) } AggregateFunction::Grouping => Ok(DataType::Int32), - AggregateFunction::FirstValue | AggregateFunction::LastValue => { - Ok(coerced_data_types[0].clone()) - } + AggregateFunction::FirstValue + | AggregateFunction::LastValue + | AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), } } @@ -371,6 +378,7 @@ impl AggregateFunction { | AggregateFunction::LastValue => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } + AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), AggregateFunction::Covariance | AggregateFunction::CovariancePop | AggregateFunction::Correlation @@ -428,6 +436,7 @@ impl AggregateFunction { #[cfg(test)] mod tests { use super::*; + use strum::IntoEnumIterator; #[test] diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 56bb5c9b69c4..ab994c143ac2 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -15,17 +15,16 @@ // specific language governing permissions and limitations // under the License. +use std::ops::Deref; + +use super::functions::can_coerce_from; +use crate::{AggregateFunction, Signature, TypeSignature}; + use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; - use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; -use std::ops::Deref; - -use crate::{AggregateFunction, Signature, TypeSignature}; - -use super::functions::can_coerce_from; pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8]; @@ -297,6 +296,7 @@ pub fn coerce_types( AggregateFunction::Median | AggregateFunction::FirstValue | AggregateFunction::LastValue => Ok(input_types.to_vec()), + AggregateFunction::NthValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), AggregateFunction::StringAgg => { if !is_string_agg_supported_arg_type(&input_types[0]) { @@ -584,6 +584,7 @@ pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { #[cfg(test)] mod tests { use super::*; + use arrow::datatypes::DataType; #[test] diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index eb5ae8b0b0c3..34f8d20628dc 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -20,46 +20,43 @@ use std::any::Any; use std::cmp::Ordering; -use std::collections::BinaryHeap; +use std::collections::{BinaryHeap, VecDeque}; use std::fmt::Debug; use std::sync::Arc; use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; use crate::expressions::format_state_name; -use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr}; +use crate::{ + reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, +}; -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use arrow_array::cast::AsArray; -use arrow_array::Array; use arrow_schema::{Fields, SortOptions}; use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; -use itertools::izip; - -/// Expression for a ARRAY_AGG(ORDER BY) aggregation. -/// When aggregation works in multiple partitions -/// aggregations are split into multiple partitions, -/// then their results are merged. This aggregator -/// is a version of ARRAY_AGG that can support producing -/// intermediate aggregation (with necessary side information) -/// and that can merge aggregations from multiple partitions. +/// Expression for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi +/// partition setting, partial aggregations are computed for every partition, +/// and then their results are merged. #[derive(Debug)] pub struct OrderSensitiveArrayAgg { /// Column name name: String, - /// The DataType for the input expression + /// The `DataType` for the input expression input_data_type: DataType, /// The input expression expr: Arc, - /// If the input expression can have NULLs + /// If the input expression can have `NULL`s nullable: bool, /// Ordering data types order_by_data_types: Vec, /// Ordering requirement ordering_req: LexOrdering, + /// Whether the aggregation is running in reverse + reverse: bool, } impl OrderSensitiveArrayAgg { @@ -79,6 +76,7 @@ impl OrderSensitiveArrayAgg { nullable, order_by_data_types, ordering_req, + reverse: false, } } } @@ -98,11 +96,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(OrderSensitiveArrayAggAccumulator::try_new( + OrderSensitiveArrayAggAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + self.reverse, + ) + .map(|acc| Box::new(acc) as _) } fn state_fields(&self) -> Result> { @@ -125,16 +125,25 @@ impl AggregateExpr for OrderSensitiveArrayAgg { } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - if self.ordering_req.is_empty() { - None - } else { - Some(&self.ordering_req) - } + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } fn name(&self) -> &str { &self.name } + + fn reverse_expr(&self) -> Option> { + Some(Arc::new(Self { + name: self.name.to_string(), + input_data_type: self.input_data_type.clone(), + expr: self.expr.clone(), + nullable: self.nullable, + order_by_data_types: self.order_by_data_types.clone(), + // Reverse requirement: + ordering_req: reverse_order_bys(&self.ordering_req), + reverse: !self.reverse, + })) + } } impl PartialEq for OrderSensitiveArrayAgg { @@ -153,19 +162,20 @@ impl PartialEq for OrderSensitiveArrayAgg { #[derive(Debug)] pub(crate) struct OrderSensitiveArrayAggAccumulator { - // `values` stores entries in the ARRAY_AGG result. + /// Stores entries in the `ARRAY_AGG` result. values: Vec, - // `ordering_values` stores values of ordering requirement expression - // corresponding to each value in the ARRAY_AGG. - // For each `ScalarValue` inside `values`, there will be a corresponding - // `Vec` inside `ordering_values` which stores it ordering. - // This information is used during merging results of the different partitions. - // For detailed information how merging is done see [`merge_ordered_arrays`] + /// Stores values of ordering requirement expressions corresponding to each + /// entry in `values`. This information is used when merging results from + /// different partitions. For detailed information how merging is done, see + /// [`merge_ordered_arrays`]. ordering_values: Vec>, - // `datatypes` stores, datatype of expression inside ARRAY_AGG and ordering requirement expressions. + /// Stores datatypes of expressions inside values and ordering requirement + /// expressions. datatypes: Vec, - // Stores ordering requirement of the Accumulator + /// Stores the ordering requirement of the `Accumulator`. ordering_req: LexOrdering, + /// Whether the aggregation is running in reverse. + reverse: bool, } impl OrderSensitiveArrayAggAccumulator { @@ -175,6 +185,7 @@ impl OrderSensitiveArrayAggAccumulator { datatype: &DataType, ordering_dtypes: &[DataType], ordering_req: LexOrdering, + reverse: bool, ) -> Result { let mut datatypes = vec![datatype.clone()]; datatypes.extend(ordering_dtypes.iter().cloned()); @@ -183,6 +194,7 @@ impl OrderSensitiveArrayAggAccumulator { ordering_values: vec![], datatypes, ordering_req, + reverse, }) } } @@ -207,63 +219,63 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { if states.is_empty() { return Ok(()); } - // First entry in the state is the aggregation result. - let array_agg_values = &states[0]; - // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside ARRAY_AGG list. - // For each `StructArray` inside ARRAY_AGG list, we will receive an `Array` that stores - // values received from its ordering requirement expression. (This information is necessary for during merging). - let agg_orderings = &states[1]; - - if let Some(agg_orderings) = agg_orderings.as_list_opt::() { - // Stores ARRAY_AGG results coming from each partition - let mut partition_values = vec![]; - // Stores ordering requirement expression results coming from each partition - let mut partition_ordering_values = vec![]; - - // Existing values should be merged also. - partition_values.push(self.values.clone()); - partition_ordering_values.push(self.ordering_values.clone()); - - let array_agg_res = - ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; - - for v in array_agg_res.into_iter() { - partition_values.push(v); - } + // First entry in the state is the aggregation result. Second entry + // stores values received for ordering requirement columns for each + // aggregation value inside `ARRAY_AGG` list. For each `StructArray` + // inside `ARRAY_AGG` list, we will receive an `Array` that stores values + // received from its ordering requirement expression. (This information + // is necessary for during merging). + let [array_agg_values, agg_orderings, ..] = &states else { + return exec_err!("State should have two elements"); + }; + let Some(agg_orderings) = agg_orderings.as_list_opt::() else { + return exec_err!("Expects to receive a list array"); + }; - let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; - - for partition_ordering_rows in orderings.into_iter() { - // Extract value from struct to ordering_rows for each group/partition - let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| { - if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row { - Ok(ordering_columns_per_row) - } else { - exec_err!( - "Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}", - ordering_row.data_type() - ) - } - }).collect::>>()?; - - partition_ordering_values.push(ordering_value); - } + // Stores ARRAY_AGG results coming from each partition + let mut partition_values = vec![]; + // Stores ordering requirement expression results coming from each partition + let mut partition_ordering_values = vec![]; - let sort_options = self - .ordering_req - .iter() - .map(|sort_expr| sort_expr.options) - .collect::>(); - let (new_values, new_orderings) = merge_ordered_arrays( - &partition_values, - &partition_ordering_values, - &sort_options, - )?; - self.values = new_values; - self.ordering_values = new_orderings; - } else { - return exec_err!("Expects to receive a list array"); + // Existing values should be merged also. + partition_values.push(self.values.clone().into()); + partition_ordering_values.push(self.ordering_values.clone().into()); + + let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; + + for v in array_agg_res.into_iter() { + partition_values.push(v.into()); + } + + let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; + + let ordering_values = orderings.into_iter().map(|partition_ordering_rows| { + // Extract value from struct to ordering_rows for each group/partition + partition_ordering_rows.into_iter().map(|ordering_row| { + if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row { + Ok(ordering_columns_per_row) + } else { + exec_err!( + "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}", + ordering_row.data_type() + ) + } + }).collect::>>() + }).collect::>>()?; + for ordering_values in ordering_values.into_iter() { + partition_ordering_values.push(ordering_values); } + + let sort_options = self + .ordering_req + .iter() + .map(|sort_expr| sort_expr.options) + .collect::>(); + (self.values, self.ordering_values) = merge_ordered_arrays( + &mut partition_values, + &mut partition_ordering_values, + &sort_options, + )?; Ok(()) } @@ -274,8 +286,13 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } fn evaluate(&self) -> Result { - let arr = ScalarValue::new_list(&self.values, &self.datatypes[0]); - Ok(ScalarValue::List(arr)) + let values = self.values.clone(); + let array = if self.reverse { + ScalarValue::new_list_from_iter(values.into_iter().rev(), &self.datatypes[0]) + } else { + ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0]) + }; + Ok(ScalarValue::List(array)) } fn size(&self) -> usize { @@ -306,7 +323,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { impl OrderSensitiveArrayAggAccumulator { fn evaluate_orderings(&self) -> Result { let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); - let struct_field = Fields::from(fields.clone()); + let struct_field = Fields::from(fields); let orderings: Vec = self .ordering_values @@ -315,7 +332,7 @@ impl OrderSensitiveArrayAggAccumulator { ScalarValue::Struct(Some(ordering.clone()), struct_field.clone()) }) .collect(); - let struct_type = DataType::Struct(Fields::from(fields)); + let struct_type = DataType::Struct(struct_field); // Wrap in List, so we have the same data structure ListArray(StructArray..) for group by cases let arr = ScalarValue::new_list(&orderings, &struct_type); @@ -323,20 +340,19 @@ impl OrderSensitiveArrayAggAccumulator { } } -/// This is a wrapper struct to be able to correctly merge ARRAY_AGG -/// data from multiple partitions using `BinaryHeap`. -/// When used inside `BinaryHeap` this struct returns smallest `CustomElement`, -/// where smallest is determined by `ordering` values (`Vec`) -/// according to `sort_options` +/// This is a wrapper struct to be able to correctly merge `ARRAY_AGG` data from +/// multiple partitions using `BinaryHeap`. When used inside `BinaryHeap`, this +/// struct returns smallest `CustomElement`, where smallest is determined by +/// `ordering` values (`Vec`) according to `sort_options`. #[derive(Debug, PartialEq, Eq)] struct CustomElement<'a> { - // Stores from which partition entry is received + /// Stores the partition this entry came from branch_idx: usize, - // values to be merged + /// Values to merge value: ScalarValue, - // according to `ordering` values, comparisons will be done. + // Comparison "key" ordering: Vec, - // `sort_options` defines, desired ordering by the user + /// Options defining the ordering semantics sort_options: &'a [SortOptions], } @@ -411,87 +427,86 @@ impl<'a> PartialOrd for CustomElement<'a> { /// For each ScalarValue in the `values` we have a corresponding `Vec` (like timestamp of it) /// for the example above `sort_options` will have size two, that defines ordering requirement of the merge. /// Inner `Vec`s of the `ordering_values` will be compared according `sort_options` (Their sizes should match) -fn merge_ordered_arrays( +pub(crate) fn merge_ordered_arrays( // We will merge values into single `Vec`. - values: &[Vec], + values: &mut [VecDeque], // `values` will be merged according to `ordering_values`. // Inner `Vec` can be thought as ordering information for the // each `ScalarValue` in the values`. - ordering_values: &[Vec>], + ordering_values: &mut [VecDeque>], // Defines according to which ordering comparisons should be done. sort_options: &[SortOptions], ) -> Result<(Vec, Vec>)> { // Keep track the most recent data of each branch, in binary heap data structure. - let mut heap: BinaryHeap = BinaryHeap::new(); + let mut heap = BinaryHeap::::new(); - if !(values.len() == ordering_values.len() - && values + if values.len() != ordering_values.len() + || values .iter() .zip(ordering_values.iter()) - .all(|(vals, ordering_vals)| vals.len() == ordering_vals.len())) + .any(|(vals, ordering_vals)| vals.len() != ordering_vals.len()) { return exec_err!( "Expects values arguments and/or ordering_values arguments to have same size" ); } let n_branch = values.len(); - // For each branch we keep track of indices of next will be merged entry - let mut indices = vec![0_usize; n_branch]; - // Keep track of sizes of each branch. - let end_indices = (0..n_branch) - .map(|idx| values[idx].len()) - .collect::>(); let mut merged_values = vec![]; let mut merged_orderings = vec![]; // Continue iterating the loop until consuming data of all branches. loop { - let min_elem = if let Some(min_elem) = heap.pop() { - min_elem + let minimum = if let Some(minimum) = heap.pop() { + minimum } else { // Heap is empty, fill it with the next entries from each branch. - for (idx, end_idx, ordering, branch_index) in izip!( - indices.iter(), - end_indices.iter(), - ordering_values.iter(), - 0..n_branch - ) { - // We consumed this branch, skip it - if idx == end_idx { - continue; + for branch_idx in 0..n_branch { + if let Some(orderings) = ordering_values[branch_idx].pop_front() { + // Their size should be same, we can safely .unwrap here. + let value = values[branch_idx].pop_front().unwrap(); + // Push the next element to the heap: + heap.push(CustomElement::new( + branch_idx, + value, + orderings, + sort_options, + )); } - - // Push the next element to the heap. - let elem = CustomElement::new( - branch_index, - values[branch_index][*idx].clone(), - ordering[*idx].to_vec(), - sort_options, - ); - heap.push(elem); + // If None, we consumed this branch, skip it. } - // Now we have filled the heap, get the largest entry (this will be the next element in merge) - if let Some(min_elem) = heap.pop() { - min_elem + + // Now we have filled the heap, get the largest entry (this will be + // the next element in merge). + if let Some(minimum) = heap.pop() { + minimum } else { - // Heap is empty, this means that all indices are same with end_indices. e.g - // We have consumed all of the branches. Merging is completed - // Exit from the loop + // Heap is empty, this means that all indices are same with + // `end_indices`. We have consumed all of the branches, merge + // is completed, exit from the loop: break; } }; - let branch_idx = min_elem.branch_idx; - // Increment the index of merged branch, - indices[branch_idx] += 1; - let row_idx = indices[branch_idx]; - merged_values.push(min_elem.value.clone()); - merged_orderings.push(min_elem.ordering.clone()); - if row_idx < end_indices[branch_idx] { - // Push next entry in the most recently consumed branch to the heap - // If there is an available entry - let value = values[branch_idx][row_idx].clone(); - let ordering_row = ordering_values[branch_idx][row_idx].to_vec(); - let elem = CustomElement::new(branch_idx, value, ordering_row, sort_options); - heap.push(elem); + let CustomElement { + branch_idx, + value, + ordering, + .. + } = minimum; + // Add minimum value in the heap to the result + merged_values.push(value); + merged_orderings.push(ordering); + + // If there is an available entry, push next entry in the most + // recently consumed branch to the heap. + if let Some(orderings) = ordering_values[branch_idx].pop_front() { + // Their size should be same, we can safely .unwrap here. + let value = values[branch_idx].pop_front().unwrap(); + // Push the next element to the heap: + heap.push(CustomElement::new( + branch_idx, + value, + orderings, + sort_options, + )); } } @@ -500,12 +515,15 @@ fn merge_ordered_arrays( #[cfg(test)] mod tests { + use std::collections::VecDeque; + use std::sync::Arc; + use crate::aggregate::array_agg_ordered::merge_ordered_arrays; + use arrow_array::{Array, ArrayRef, Int64Array}; use arrow_schema::SortOptions; use datafusion_common::utils::get_row_at_idx; use datafusion_common::{Result, ScalarValue}; - use std::sync::Arc; #[test] fn test_merge_asc() -> Result<()> { @@ -516,7 +534,7 @@ mod tests { let n_row = lhs_arrays[0].len(); let lhs_orderings = (0..n_row) .map(|idx| get_row_at_idx(&lhs_arrays, idx)) - .collect::>>()?; + .collect::>>()?; let rhs_arrays: Vec = vec![ Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])), @@ -525,7 +543,7 @@ mod tests { let n_row = rhs_arrays[0].len(); let rhs_orderings = (0..n_row) .map(|idx| get_row_at_idx(&rhs_arrays, idx)) - .collect::>>()?; + .collect::>>()?; let sort_options = vec![ SortOptions { descending: false, @@ -540,12 +558,12 @@ mod tests { let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef; let lhs_vals = (0..lhs_vals_arr.len()) .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx)) - .collect::>>()?; + .collect::>>()?; let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef; let rhs_vals = (0..rhs_vals_arr.len()) .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx)) - .collect::>>()?; + .collect::>>()?; let expected = Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef; let expected_ts = vec![ @@ -554,8 +572,8 @@ mod tests { ]; let (merged_vals, merged_ts) = merge_ordered_arrays( - &[lhs_vals, rhs_vals], - &[lhs_orderings, rhs_orderings], + &mut [lhs_vals, rhs_vals], + &mut [lhs_orderings, rhs_orderings], &sort_options, )?; let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?; @@ -583,7 +601,7 @@ mod tests { let n_row = lhs_arrays[0].len(); let lhs_orderings = (0..n_row) .map(|idx| get_row_at_idx(&lhs_arrays, idx)) - .collect::>>()?; + .collect::>>()?; let rhs_arrays: Vec = vec![ Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])), @@ -592,7 +610,7 @@ mod tests { let n_row = rhs_arrays[0].len(); let rhs_orderings = (0..n_row) .map(|idx| get_row_at_idx(&rhs_arrays, idx)) - .collect::>>()?; + .collect::>>()?; let sort_options = vec![ SortOptions { descending: true, @@ -608,12 +626,12 @@ mod tests { let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef; let lhs_vals = (0..lhs_vals_arr.len()) .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx)) - .collect::>>()?; + .collect::>>()?; let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef; let rhs_vals = (0..rhs_vals_arr.len()) .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx)) - .collect::>>()?; + .collect::>>()?; let expected = Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 1, 1, 2, 2])) as ArrayRef; let expected_ts = vec![ @@ -621,8 +639,8 @@ mod tests { Arc::new(Int64Array::from(vec![4, 4, 3, 3, 2, 2, 1, 1, 0, 0])) as ArrayRef, ]; let (merged_vals, merged_ts) = merge_ordered_arrays( - &[lhs_vals, rhs_vals], - &[lhs_orderings, rhs_orderings], + &mut [lhs_vals, rhs_vals], + &mut [lhs_orderings, rhs_orderings], &sort_options, )?; let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index c40f0db19405..1a3d21fc40bc 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -26,12 +26,15 @@ //! * Signature: see `Signature` //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. +use std::sync::Arc; + use crate::aggregate::regr::RegrType; -use crate::{expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr}; +use crate::expressions::{self, Literal}; +use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; + use arrow::datatypes::Schema; -use datafusion_common::{not_impl_err, DataFusionError, Result}; -pub use datafusion_expr::AggregateFunction; -use std::sync::Arc; +use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_expr::AggregateFunction; /// Create a physical aggregation expression. /// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. @@ -369,6 +372,28 @@ pub fn create_aggregate_expr( ordering_req.to_vec(), ordering_types, )), + (AggregateFunction::NthValue, _) => { + let expr = &input_phy_exprs[0]; + let Some(n) = input_phy_exprs[1] + .as_any() + .downcast_ref::() + .map(|literal| literal.value()) + else { + return internal_err!( + "Second argument of NTH_VALUE needs to be a literal" + ); + }; + let nullable = expr.nullable(input_schema)?; + Arc::new(expressions::NthValueAgg::new( + expr.clone(), + n.clone().try_into()?, + name, + input_phy_types[0].clone(), + nullable, + ordering_types, + ordering_req.to_vec(), + )) + } (AggregateFunction::StringAgg, false) => { if !ordering_req.is_empty() { return not_impl_err!( @@ -396,9 +421,9 @@ mod tests { BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Correlation, Count, Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, }; + use arrow::datatypes::{DataType, Field}; - use datafusion_common::plan_err; - use datafusion_common::ScalarValue; + use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::{type_coercion, Signature}; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 5bd1fca385b1..270a8e6f7705 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -20,7 +20,7 @@ use std::fmt::Debug; use std::sync::Arc; use self::groups_accumulator::GroupsAccumulator; -use crate::expressions::OrderSensitiveArrayAgg; +use crate::expressions::{NthValueAgg, OrderSensitiveArrayAgg}; use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::Field; @@ -47,6 +47,7 @@ pub(crate) mod covariance; pub(crate) mod first_last; pub(crate) mod grouping; pub(crate) mod median; +pub(crate) mod nth_value; pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; @@ -140,4 +141,5 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { /// However, an `ARRAY_AGG` with `ORDER BY` depends on the input ordering. pub fn is_order_sensitive(aggr_expr: &Arc) -> bool { aggr_expr.as_any().is::() + || aggr_expr.as_any().is::() } diff --git a/datafusion/physical-expr/src/aggregate/nth_value.rs b/datafusion/physical-expr/src/aggregate/nth_value.rs new file mode 100644 index 000000000000..5a1ca90b7f5e --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/nth_value.rs @@ -0,0 +1,400 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines NTH_VALUE aggregate expression which may specify ordering requirement +//! that can evaluated at runtime during query execution + +use std::any::Any; +use std::collections::VecDeque; +use std::sync::Arc; + +use crate::aggregate::array_agg_ordered::merge_ordered_arrays; +use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; +use crate::expressions::format_state_name; +use crate::{ + reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, +}; + +use arrow_array::cast::AsArray; +use arrow_array::ArrayRef; +use arrow_schema::{DataType, Field, Fields}; +use datafusion_common::utils::get_row_at_idx; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::Accumulator; + +/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi +/// partition setting, partial aggregations are computed for every partition, +/// and then their results are merged. +#[derive(Debug)] +pub struct NthValueAgg { + /// Column name + name: String, + /// The `DataType` for the input expression + input_data_type: DataType, + /// The input expression + expr: Arc, + /// The `N` value. + n: i64, + /// If the input expression can have `NULL`s + nullable: bool, + /// Ordering data types + order_by_data_types: Vec, + /// Ordering requirement + ordering_req: LexOrdering, +} + +impl NthValueAgg { + /// Create a new `NthValueAgg` aggregate function + pub fn new( + expr: Arc, + n: i64, + name: impl Into, + input_data_type: DataType, + nullable: bool, + order_by_data_types: Vec, + ordering_req: LexOrdering, + ) -> Self { + Self { + name: name.into(), + input_data_type, + expr, + n, + nullable, + order_by_data_types, + ordering_req, + } + } +} + +impl AggregateExpr for NthValueAgg { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.input_data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(NthValueAccumulator::try_new( + self.n, + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + )?)) + } + + fn state_fields(&self) -> Result> { + let mut fields = vec![Field::new_list( + format_state_name(&self.name, "nth_value"), + Field::new("item", self.input_data_type.clone(), true), + self.nullable, // This should be the same as field() + )]; + if !self.ordering_req.is_empty() { + let orderings = + ordering_fields(&self.ordering_req, &self.order_by_data_types); + fields.push(Field::new_list( + format_state_name(&self.name, "nth_value_orderings"), + Field::new("item", DataType::Struct(Fields::from(orderings)), true), + self.nullable, + )); + } + Ok(fields) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + } + + fn name(&self) -> &str { + &self.name + } + + fn reverse_expr(&self) -> Option> { + Some(Arc::new(Self { + name: self.name.to_string(), + input_data_type: self.input_data_type.clone(), + expr: self.expr.clone(), + // index should be from the opposite side + n: -self.n, + nullable: self.nullable, + order_by_data_types: self.order_by_data_types.clone(), + // reverse requirement + ordering_req: reverse_order_bys(&self.ordering_req), + }) as _) + } +} + +impl PartialEq for NthValueAgg { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.input_data_type == x.input_data_type + && self.order_by_data_types == x.order_by_data_types + && self.expr.eq(&x.expr) + }) + .unwrap_or(false) + } +} + +#[derive(Debug)] +pub(crate) struct NthValueAccumulator { + n: i64, + /// Stores entries in the `NTH_VALUE` result. + values: VecDeque, + /// Stores values of ordering requirement expressions corresponding to each + /// entry in `values`. This information is used when merging results from + /// different partitions. For detailed information how merging is done, see + /// [`merge_ordered_arrays`]. + ordering_values: VecDeque>, + /// Stores datatypes of expressions inside values and ordering requirement + /// expressions. + datatypes: Vec, + /// Stores the ordering requirement of the `Accumulator`. + ordering_req: LexOrdering, +} + +impl NthValueAccumulator { + /// Create a new order-sensitive NTH_VALUE accumulator based on the given + /// item data type. + pub fn try_new( + n: i64, + datatype: &DataType, + ordering_dtypes: &[DataType], + ordering_req: LexOrdering, + ) -> Result { + if n == 0 { + // n cannot be 0 + return internal_err!("Nth value indices are 1 based. 0 is invalid index"); + } + let mut datatypes = vec![datatype.clone()]; + datatypes.extend(ordering_dtypes.iter().cloned()); + Ok(Self { + n, + values: VecDeque::new(), + ordering_values: VecDeque::new(), + datatypes, + ordering_req, + }) + } +} + +impl Accumulator for NthValueAccumulator { + /// Updates its state with the `values`. Assumes data in the `values` satisfies the required + /// ordering for the accumulator (across consecutive batches, not just batch-wise). + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let n_required = self.n.unsigned_abs() as usize; + let from_start = self.n > 0; + if from_start { + // direction is from start + let n_remaining = n_required.saturating_sub(self.values.len()); + self.append_new_data(values, Some(n_remaining))?; + } else { + // direction is from end + self.append_new_data(values, None)?; + let start_offset = self.values.len().saturating_sub(n_required); + if start_offset > 0 { + self.values.drain(0..start_offset); + self.ordering_values.drain(0..start_offset); + } + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + // First entry in the state is the aggregation result. + let array_agg_values = &states[0]; + let n_required = self.n.unsigned_abs() as usize; + if self.ordering_req.is_empty() { + let array_agg_res = + ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; + for v in array_agg_res.into_iter() { + self.values.extend(v); + if self.values.len() > n_required { + // There is enough data collected can stop merging + break; + } + } + } else if let Some(agg_orderings) = states[1].as_list_opt::() { + // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside NTH_VALUE list. + // For each `StructArray` inside NTH_VALUE list, we will receive an `Array` that stores + // values received from its ordering requirement expression. (This information is necessary for during merging). + + // Stores NTH_VALUE results coming from each partition + let mut partition_values: Vec> = vec![]; + // Stores ordering requirement expression results coming from each partition + let mut partition_ordering_values: Vec>> = vec![]; + + // Existing values should be merged also. + partition_values.push(self.values.clone()); + + partition_ordering_values.push(self.ordering_values.clone()); + + let array_agg_res = + ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; + + for v in array_agg_res.into_iter() { + partition_values.push(v.into()); + } + + let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; + + let ordering_values = orderings.into_iter().map(|partition_ordering_rows| { + // Extract value from struct to ordering_rows for each group/partition + partition_ordering_rows.into_iter().map(|ordering_row| { + if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row { + Ok(ordering_columns_per_row) + } else { + exec_err!( + "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}", + ordering_row.data_type() + ) + } + }).collect::>>() + }).collect::>>()?; + for ordering_values in ordering_values.into_iter() { + partition_ordering_values.push(ordering_values.into()); + } + + let sort_options = self + .ordering_req + .iter() + .map(|sort_expr| sort_expr.options) + .collect::>(); + let (new_values, new_orderings) = merge_ordered_arrays( + &mut partition_values, + &mut partition_ordering_values, + &sort_options, + )?; + self.values = new_values.into(); + self.ordering_values = new_orderings.into(); + } else { + return exec_err!("Expects to receive a list array"); + } + Ok(()) + } + + fn state(&self) -> Result> { + let mut result = vec![self.evaluate_values()]; + if !self.ordering_req.is_empty() { + result.push(self.evaluate_orderings()); + } + Ok(result) + } + + fn evaluate(&self) -> Result { + let n_required = self.n.unsigned_abs() as usize; + let from_start = self.n > 0; + let nth_value_idx = if from_start { + // index is from start + let forward_idx = n_required - 1; + (forward_idx < self.values.len()).then_some(forward_idx) + } else { + // index is from end + self.values.len().checked_sub(n_required) + }; + if let Some(idx) = nth_value_idx { + Ok(self.values[idx].clone()) + } else { + ScalarValue::try_from(self.datatypes[0].clone()) + } + } + + fn size(&self) -> usize { + let mut total = std::mem::size_of_val(self) + + ScalarValue::size_of_vec_deque(&self.values) + - std::mem::size_of_val(&self.values); + + // Add size of the `self.ordering_values` + total += + std::mem::size_of::>() * self.ordering_values.capacity(); + for row in &self.ordering_values { + total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); + } + + // Add size of the `self.datatypes` + total += std::mem::size_of::() * self.datatypes.capacity(); + for dtype in &self.datatypes { + total += dtype.size() - std::mem::size_of_val(dtype); + } + + // Add size of the `self.ordering_req` + total += std::mem::size_of::() * self.ordering_req.capacity(); + // TODO: Calculate size of each `PhysicalSortExpr` more accurately. + total + } +} + +impl NthValueAccumulator { + fn evaluate_orderings(&self) -> ScalarValue { + let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + let struct_field = Fields::from(fields); + + let orderings = self + .ordering_values + .iter() + .map(|ordering| { + ScalarValue::Struct(Some(ordering.clone()), struct_field.clone()) + }) + .collect::>(); + let struct_type = DataType::Struct(struct_field); + + // Wrap in List, so we have the same data structure ListArray(StructArray..) for group by cases + ScalarValue::List(ScalarValue::new_list(&orderings, &struct_type)) + } + + fn evaluate_values(&self) -> ScalarValue { + let mut values_cloned = self.values.clone(); + let values_slice = values_cloned.make_contiguous(); + ScalarValue::List(ScalarValue::new_list(values_slice, &self.datatypes[0])) + } + + /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete + /// None represents all of the new `values` need to be added to the state. + fn append_new_data( + &mut self, + values: &[ArrayRef], + fetch: Option, + ) -> Result<()> { + let n_row = values[0].len(); + let n_to_add = if let Some(fetch) = fetch { + std::cmp::min(fetch, n_row) + } else { + n_row + }; + for index in 0..n_to_add { + let row = get_row_at_idx(values, index)?; + self.values.push_back(row[0].clone()); + self.ordering_values.push_back(row[1..].to_vec()); + } + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index d73c46a0f687..6dd586bfb8ce 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -196,9 +196,9 @@ pub(crate) fn ordering_fields( ordering_req .iter() .zip(data_types.iter()) - .map(|(expr, dtype)| { + .map(|(sort_expr, dtype)| { Field::new( - expr.to_string().as_str(), + sort_expr.expr.to_string().as_str(), dtype.clone(), // Multi partitions may be empty hence field should be nullable. true, diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index b6d0ad5b9104..bbfba4ad8310 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -60,6 +60,7 @@ pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; +pub use crate::aggregate::nth_value::NthValueAgg; pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::stddev::{Stddev, StddevPop}; @@ -67,7 +68,6 @@ pub use crate::aggregate::string_agg::StringAgg; pub use crate::aggregate::sum::Sum; pub use crate::aggregate::sum_distinct::DistinctSum; pub use crate::aggregate::variance::{Variance, VariancePop}; - pub use crate::window::cume_dist::cume_dist; pub use crate::window::cume_dist::CumeDist; pub use crate::window::lead_lag::WindowShift; @@ -77,6 +77,7 @@ pub use crate::window::ntile::Ntile; pub use crate::window::rank::{dense_rank, percent_rank, rank}; pub use crate::window::rank::{Rank, RankType}; pub use crate::window::row_number::RowNumber; +pub use crate::PhysicalSortExpr; pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; @@ -98,20 +99,20 @@ pub use try_cast::{try_cast, TryCastExpr}; pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{name}[{state_name}]") } -pub use crate::PhysicalSortExpr; #[cfg(test)] pub(crate) mod tests { + use std::sync::Arc; + use crate::expressions::{col, create_aggregate_expr, try_cast}; use crate::{AggregateExpr, EmitTo}; + use arrow::record_batch::RecordBatch; use arrow_array::ArrayRef; use arrow_schema::{Field, Schema}; - use datafusion_common::Result; - use datafusion_common::ScalarValue; + use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::AggregateFunction; - use std::sync::Arc; /// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the /// result. diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index facd601955b6..d3ae0d5ce01f 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -25,7 +25,6 @@ use crate::aggregates::{ no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream, topk_stream::GroupedTopKAggregateStream, }; - use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::windows::get_ordered_partition_by_indices; use crate::{ @@ -909,6 +908,7 @@ fn get_aggregate_exprs_requirement( let aggr_req = PhysicalSortRequirement::from_sort_exprs(aggr_req); let reverse_aggr_req = PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_req); + if let Some(first_value) = aggr_expr.as_any().downcast_ref::() { let mut first_value = first_value.clone(); if eq_properties.ordering_satisfy_requirement(&concat_slices( @@ -931,7 +931,9 @@ fn get_aggregate_exprs_requirement( first_value = first_value.with_requirement_satisfied(false); *aggr_expr = Arc::new(first_value) as _; } - } else if let Some(last_value) = aggr_expr.as_any().downcast_ref::() { + continue; + } + if let Some(last_value) = aggr_expr.as_any().downcast_ref::() { let mut last_value = last_value.clone(); if eq_properties.ordering_satisfy_requirement(&concat_slices( prefix_requirement, @@ -953,18 +955,63 @@ fn get_aggregate_exprs_requirement( last_value = last_value.with_requirement_satisfied(false); *aggr_expr = Arc::new(last_value) as _; } - } else if let Some(finer_ordering) = + continue; + } + if let Some(finer_ordering) = + finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) + { + if eq_properties.ordering_satisfy(&finer_ordering) { + // Requirement is satisfied by existing ordering + requirement = finer_ordering; + continue; + } + } + if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() { + if let Some(finer_ordering) = finer_ordering( + &requirement, + &reverse_aggr_expr, + group_by, + eq_properties, + agg_mode, + ) { + if eq_properties.ordering_satisfy(&finer_ordering) { + // Reverse requirement is satisfied by exiting ordering. + // Hence reverse the aggregator + requirement = finer_ordering; + *aggr_expr = reverse_aggr_expr; + continue; + } + } + } + if let Some(finer_ordering) = finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) { + // There is a requirement that both satisfies existing requirement and current + // aggregate requirement. Use updated requirement requirement = finer_ordering; - } else { - // If neither of the requirements satisfy the other, this means - // requirements are conflicting. Currently, we do not support - // conflicting requirements. - return not_impl_err!( - "Conflicting ordering requirements in aggregate functions is not supported" - ); + continue; + } + if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() { + if let Some(finer_ordering) = finer_ordering( + &requirement, + &reverse_aggr_expr, + group_by, + eq_properties, + agg_mode, + ) { + // There is a requirement that both satisfies existing requirement and reverse + // aggregate requirement. Use updated requirement + requirement = finer_ordering; + *aggr_expr = reverse_aggr_expr; + continue; + } } + // Neither the existing requirement and current aggregate requirement satisfy the other, this means + // requirements are conflicting. Currently, we do not support + // conflicting requirements. + return not_impl_err!( + "Conflicting ordering requirements in aggregate functions is not supported" + ); } Ok(PhysicalSortRequirement::from_sort_exprs(&requirement)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c95465b5ae44..8bde0da133eb 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -715,6 +715,7 @@ enum AggregateFunction { REGR_SYY = 33; REGR_SXY = 34; STRING_AGG = 35; + NTH_VALUE_AGG = 36; } message AggregateExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index d5d86b2179fa..528761136ca3 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -457,6 +457,7 @@ impl serde::Serialize for AggregateFunction { Self::RegrSyy => "REGR_SYY", Self::RegrSxy => "REGR_SXY", Self::StringAgg => "STRING_AGG", + Self::NthValueAgg => "NTH_VALUE_AGG", }; serializer.serialize_str(variant) } @@ -504,6 +505,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SYY", "REGR_SXY", "STRING_AGG", + "NTH_VALUE_AGG", ]; struct GeneratedVisitor; @@ -580,6 +582,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SYY" => Ok(AggregateFunction::RegrSyy), "REGR_SXY" => Ok(AggregateFunction::RegrSxy), "STRING_AGG" => Ok(AggregateFunction::StringAgg), + "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 7e262e620fa7..9a0b7ab332a6 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -3079,6 +3079,7 @@ pub enum AggregateFunction { RegrSyy = 33, RegrSxy = 34, StringAgg = 35, + NthValueAgg = 36, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -3125,6 +3126,7 @@ impl AggregateFunction { AggregateFunction::RegrSyy => "REGR_SYY", AggregateFunction::RegrSxy => "REGR_SXY", AggregateFunction::StringAgg => "STRING_AGG", + AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -3168,6 +3170,7 @@ impl AggregateFunction { "REGR_SYY" => Some(Self::RegrSyy), "REGR_SXY" => Some(Self::RegrSxy), "STRING_AGG" => Some(Self::StringAgg), + "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2d9c7be46bc9..9185bdb80429 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ @@ -26,6 +28,7 @@ use crate::protobuf::{ AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; + use arrow::{ array::AsArray, buffer::Buffer, @@ -41,17 +44,19 @@ use datafusion_common::{ Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, ScalarValue, }; +use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ abs, acos, acosh, array, array_append, array_concat, array_dims, array_distinct, - array_element, array_except, array_has, array_has_all, array_has_any, - array_intersect, array_length, array_ndims, array_position, array_positions, - array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat, - array_replace, array_replace_all, array_replace_n, array_resize, array_slice, - array_sort, array_to_string, array_union, arrow_typeof, ascii, asin, asinh, atan, - atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, - coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, - date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp, + array_element, array_empty, array_except, array_has, array_has_all, array_has_any, + array_intersect, array_length, array_ndims, array_pop_back, array_pop_front, + array_position, array_positions, array_prepend, array_remove, array_remove_all, + array_remove_n, array_repeat, array_replace, array_replace_all, array_replace_n, + array_resize, array_slice, array_sort, array_to_string, array_union, arrow_typeof, + ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, + character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, + current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest, + encode, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -68,11 +73,6 @@ use datafusion_expr::{ JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, }; -use datafusion_expr::{ - array_empty, array_pop_back, array_pop_front, - expr::{Alias, Placeholder}, -}; -use std::sync::Arc; #[derive(Debug)] pub enum Error { @@ -617,6 +617,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Median => Self::Median, protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, protobuf::AggregateFunction::LastValueAgg => Self::LastValue, + protobuf::AggregateFunction::NthValueAgg => Self::NthValue, protobuf::AggregateFunction::StringAgg => Self::StringAgg, } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ec9b886c1f22..7eef3da9519f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -31,6 +31,7 @@ use crate::protobuf::{ AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; + use arrow::{ array::ArrayRef, datatypes::{ @@ -409,6 +410,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Median => Self::Median, AggregateFunction::FirstValue => Self::FirstValueAgg, AggregateFunction::LastValue => Self::LastValueAgg, + AggregateFunction::NthValue => Self::NthValueAgg, AggregateFunction::StringAgg => Self::StringAgg, } } @@ -728,6 +730,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { AggregateFunction::LastValue => { protobuf::AggregateFunction::LastValueAgg } + AggregateFunction::NthValue => { + protobuf::AggregateFunction::NthValueAgg + } AggregateFunction::StringAgg => { protobuf::AggregateFunction::StringAgg } diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 7c5803d38594..79e6a9357b40 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4714,3 +4714,143 @@ statement ok DROP TABLE uint64_dict; ### END Group By with Dictionary Variants ### + +statement ok +set datafusion.execution.target_partitions = 1; + +query III? +SELECT a, b, NTH_VALUE(c, 2), ARRAY_AGG(c) +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; +---- +0 0 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] +0 1 26 [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] +1 2 51 [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74] +1 3 76 [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] + +query III? +SELECT a, b, NTH_VALUE(c, 2 ORDER BY c ASC), ARRAY_AGG(c ORDER BY c ASC) +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; +---- +0 0 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] +0 1 26 [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] +1 2 51 [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74] +1 3 76 [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] + +query II?I +SELECT a, b, ARRAY_AGG(c ORDER BY c ASC), NTH_VALUE(c, 2 ORDER BY c DESC) +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; +---- +0 0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] 23 +0 1 [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] 48 +1 2 [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74] 73 +1 3 [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] 98 + +query IIIIII +SELECT a, b, NTH_VALUE(c, 2 ORDER BY c ASC), NTH_VALUE(c, 3 ORDER BY c ASC), NTH_VALUE(c, 2 ORDER BY c DESC), NTH_VALUE(c, 3 ORDER BY c DESC) +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; +---- +0 0 1 2 23 22 +0 1 26 27 48 47 +1 2 51 52 73 72 +1 3 76 77 98 97 + +# we should be able to reverse array agg requirement, if it helps to remove a SortExec from plan. +query TT +EXPLAIN SELECT a, b, ARRAY_AGG(c ORDER BY c DESC) +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; +---- +logical_plan +Sort: multiple_ordered_table.a ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST +--Aggregate: groupBy=[[multiple_ordered_table.a, multiple_ordered_table.b]], aggr=[[ARRAY_AGG(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]]] +----TableScan: multiple_ordered_table projection=[a, b, c] +physical_plan +AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[ARRAY_AGG(multiple_ordered_table.c)], ordering_mode=Sorted +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true + +query II? +SELECT a, b, ARRAY_AGG(c ORDER BY c DESC) +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; +---- +0 0 [24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0] +0 1 [49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25] +1 2 [74, 73, 72, 71, 70, 69, 68, 67, 66, 65, 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50] +1 3 [99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 89, 88, 87, 86, 85, 84, 83, 82, 81, 80, 79, 78, 77, 76, 75] + +query II?II +SELECT a, b, ARRAY_AGG(d ORDER BY d DESC), NTH_VALUE(d, 1 ORDER BY d DESC), NTH_VALUE(d, 1 ORDER BY d ASC) +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; +---- +0 0 [4, 4, 4, 4, 3, 3, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] 4 0 +0 1 [4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] 4 0 +1 2 [4, 4, 4, 4, 4, 4, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0] 4 0 +1 3 [4, 4, 4, 4, 4, 4, 3, 3, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] 4 0 + +# increase partition to 8 +statement ok +set datafusion.execution.target_partitions = 8; + +# NTH_VALUE(c, 2) and ARRAY_AGG(c)[2] should produce same result +query III +SELECT a, b, NTH_VALUE(c, 2) - ARRAY_AGG(c)[2] +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; +---- +0 0 0 +0 1 0 +1 2 0 +1 3 0 + +query III? +SELECT a, b, NTH_VALUE(c, 2 ORDER BY c ASC), ARRAY_AGG(c ORDER BY c ASC) +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; +---- +0 0 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] +0 1 26 [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] +1 2 51 [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74] +1 3 76 [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] + +query II?I +SELECT a, b, ARRAY_AGG(c ORDER BY c ASC), NTH_VALUE(c, 2 ORDER BY c DESC) +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; +---- +0 0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] 23 +0 1 [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] 48 +1 2 [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74] 73 +1 3 [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] 98 + +query IIIIII +SELECT a, b, NTH_VALUE(c, 2 ORDER BY c ASC), NTH_VALUE(c, 3 ORDER BY c ASC), NTH_VALUE(c, 2 ORDER BY c DESC), NTH_VALUE(c, 3 ORDER BY c DESC) +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; +---- +0 0 1 2 23 22 +0 1 26 27 48 47 +1 2 51 52 73 72 +1 3 76 77 98 97 + +# nth value cannot work with conflicting requirements +statement error DataFusion error: This feature is not implemented: Conflicting ordering requirements in aggregate functions is not supported +SELECT a, b, NTH_VALUE(c, 2 ORDER BY c ASC), NTH_VALUE(c, 3 ORDER BY d ASC) +FROM multiple_ordered_table +GROUP BY a, b +ORDER BY a, b; \ No newline at end of file From 3f219bc929cfd418b0e3d3501f8eba1d5a2c87ae Mon Sep 17 00:00:00 2001 From: Dejan Simic <10134699+simicd@users.noreply.github.com> Date: Tue, 16 Jan 2024 22:02:05 +0100 Subject: [PATCH 09/39] Remove migrated unit tests (#8885) --- datafusion/core/tests/sql/csv_files.rs | 125 ------------------------- datafusion/core/tests/sql/mod.rs | 1 - 2 files changed, 126 deletions(-) delete mode 100644 datafusion/core/tests/sql/csv_files.rs diff --git a/datafusion/core/tests/sql/csv_files.rs b/datafusion/core/tests/sql/csv_files.rs deleted file mode 100644 index 5ed0068d6135..000000000000 --- a/datafusion/core/tests/sql/csv_files.rs +++ /dev/null @@ -1,125 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; - -#[tokio::test] -async fn csv_custom_quote() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = SessionContext::new(); - - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Utf8, false), - ])); - let filename = format!("partition.{}", "csv"); - let file_path = tmp_dir.path().join(filename); - let mut file = File::create(file_path)?; - - // generate some data - for index in 0..10 { - let text1 = format!("id{index:}"); - let text2 = format!("value{index:}"); - let data = format!("~{text1}~,~{text2}~\r\n"); - file.write_all(data.as_bytes())?; - } - ctx.register_csv( - "test", - tmp_dir.path().to_str().unwrap(), - CsvReadOptions::new() - .schema(&schema) - .has_header(false) - .quote(b'~'), - ) - .await?; - - let results = plan_and_collect(&ctx, "SELECT * from test").await?; - - let expected = vec![ - "+-----+--------+", - "| c1 | c2 |", - "+-----+--------+", - "| id0 | value0 |", - "| id1 | value1 |", - "| id2 | value2 |", - "| id3 | value3 |", - "| id4 | value4 |", - "| id5 | value5 |", - "| id6 | value6 |", - "| id7 | value7 |", - "| id8 | value8 |", - "| id9 | value9 |", - "+-----+--------+", - ]; - - assert_batches_sorted_eq!(expected, &results); - Ok(()) -} - -#[tokio::test] -async fn csv_custom_escape() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = SessionContext::new(); - - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Utf8, false), - ])); - let filename = format!("partition.{}", "csv"); - let file_path = tmp_dir.path().join(filename); - let mut file = File::create(file_path)?; - - // generate some data - for index in 0..10 { - let text1 = format!("id{index:}"); - let text2 = format!("value\\\"{index:}"); - let data = format!("\"{text1}\",\"{text2}\"\r\n"); - file.write_all(data.as_bytes())?; - } - - ctx.register_csv( - "test", - tmp_dir.path().to_str().unwrap(), - CsvReadOptions::new() - .schema(&schema) - .has_header(false) - .escape(b'\\'), - ) - .await?; - - let results = plan_and_collect(&ctx, "SELECT * from test").await?; - - let expected = vec![ - "+-----+---------+", - "| c1 | c2 |", - "+-----+---------+", - "| id0 | value\"0 |", - "| id1 | value\"1 |", - "| id2 | value\"2 |", - "| id3 | value\"3 |", - "| id4 | value\"4 |", - "| id5 | value\"5 |", - "| id6 | value\"6 |", - "| id7 | value\"7 |", - "| id8 | value\"8 |", - "| id9 | value\"9 |", - "+-----+---------+", - ]; - - assert_batches_sorted_eq!(expected, &results); - Ok(()) -} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 0960f93ae103..140eeb91d1a9 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -69,7 +69,6 @@ macro_rules! test_expression { pub mod aggregates; pub mod create_drop; -pub mod csv_files; pub mod explain_analyze; pub mod expr; pub mod joins; From ffaa67904ed0ca454267ccc5832582bcb669a5c0 Mon Sep 17 00:00:00 2001 From: Dejan Simic <10134699+simicd@users.noreply.github.com> Date: Wed, 17 Jan 2024 07:15:25 +0100 Subject: [PATCH 10/39] test: Port tests in `references.rs` to sqllogictest (#8877) * Migrate references unit tests to sqllogictest * Remove unused import --- datafusion/core/tests/sql/mod.rs | 56 ------- datafusion/core/tests/sql/references.rs | 140 ------------------ .../sqllogictest/test_files/references.slt | 134 +++++++++++++++++ 3 files changed, 134 insertions(+), 196 deletions(-) delete mode 100644 datafusion/core/tests/sql/references.rs create mode 100644 datafusion/sqllogictest/test_files/references.slt diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 140eeb91d1a9..981bdf34f539 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -73,66 +73,10 @@ pub mod explain_analyze; pub mod expr; pub mod joins; pub mod partitioned_csv; -pub mod references; pub mod repartition; pub mod select; mod sql_api; -fn create_join_context( - column_left: &str, - column_right: &str, - repartition_joins: bool, -) -> Result { - let ctx = SessionContext::new_with_config( - SessionConfig::new() - .with_repartition_joins(repartition_joins) - .with_target_partitions(2) - .with_batch_size(4096), - ); - - let t1_schema = Arc::new(Schema::new(vec![ - Field::new(column_left, DataType::UInt32, true), - Field::new("t1_name", DataType::Utf8, true), - Field::new("t1_int", DataType::UInt32, true), - ])); - let t1_data = RecordBatch::try_new( - t1_schema, - vec![ - Arc::new(UInt32Array::from(vec![11, 22, 33, 44])), - Arc::new(StringArray::from(vec![ - Some("a"), - Some("b"), - Some("c"), - Some("d"), - ])), - Arc::new(UInt32Array::from(vec![1, 2, 3, 4])), - ], - )?; - ctx.register_batch("t1", t1_data)?; - - let t2_schema = Arc::new(Schema::new(vec![ - Field::new(column_right, DataType::UInt32, true), - Field::new("t2_name", DataType::Utf8, true), - Field::new("t2_int", DataType::UInt32, true), - ])); - let t2_data = RecordBatch::try_new( - t2_schema, - vec![ - Arc::new(UInt32Array::from(vec![11, 22, 44, 55])), - Arc::new(StringArray::from(vec![ - Some("z"), - Some("y"), - Some("x"), - Some("w"), - ])), - Arc::new(UInt32Array::from(vec![3, 1, 3, 3])), - ], - )?; - ctx.register_batch("t2", t2_data)?; - - Ok(ctx) -} - fn create_left_semi_anti_join_context_with_null_ids( column_left: &str, column_right: &str, diff --git a/datafusion/core/tests/sql/references.rs b/datafusion/core/tests/sql/references.rs deleted file mode 100644 index f465e8a2dacc..000000000000 --- a/datafusion/core/tests/sql/references.rs +++ /dev/null @@ -1,140 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; - -#[tokio::test] -async fn qualified_table_references() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - - for table_ref in &[ - "aggregate_test_100", - "public.aggregate_test_100", - "datafusion.public.aggregate_test_100", - ] { - let sql = format!("SELECT COUNT(*) FROM {table_ref}"); - let actual = execute_to_batches(&ctx, &sql).await; - let expected = [ - "+----------+", - "| COUNT(*) |", - "+----------+", - "| 100 |", - "+----------+", - ]; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn qualified_table_references_and_fields() -> Result<()> { - let ctx = SessionContext::new(); - - let c1: StringArray = vec!["foofoo", "foobar", "foobaz"] - .into_iter() - .map(Some) - .collect(); - let c2: Int64Array = vec![1, 2, 3].into_iter().map(Some).collect(); - let c3: Int64Array = vec![10, 20, 30].into_iter().map(Some).collect(); - - let batch = RecordBatch::try_from_iter(vec![ - ("f.c1", Arc::new(c1) as ArrayRef), - // evil -- use the same name as the table - ("test.c2", Arc::new(c2) as ArrayRef), - // more evil still - ("....", Arc::new(c3) as ArrayRef), - ])?; - - ctx.register_batch("test", batch)?; - - // referring to the unquoted column is an error - let sql = r#"SELECT f1.c1 from test"#; - let error = ctx.sql(sql).await.unwrap_err(); - assert_contains!( - error.to_string(), - r#"No field named f1.c1. Valid fields are test."f.c1", test."test.c2""# - ); - - // however, enclosing it in double quotes is ok - let sql = r#"SELECT "f.c1" from test"#; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+--------+", - "| f.c1 |", - "+--------+", - "| foofoo |", - "| foobar |", - "| foobaz |", - "+--------+", - ]; - assert_batches_eq!(expected, &actual); - // Works fully qualified too - let sql = r#"SELECT test."f.c1" from test"#; - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - - // check that duplicated table name and column name are ok - let sql = r#"SELECT "test.c2" as expr1, test."test.c2" as expr2 from test"#; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------+-------+", - "| expr1 | expr2 |", - "+-------+-------+", - "| 1 | 1 |", - "| 2 | 2 |", - "| 3 | 3 |", - "+-------+-------+", - ]; - assert_batches_eq!(expected, &actual); - - // check that '....' is also an ok column name (in the sense that - // datafusion should run the query, not that someone should write - // this - let sql = r#"SELECT "....", "...." as c3 from test order by "....""#; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+------+----+", - "| .... | c3 |", - "+------+----+", - "| 10 | 10 |", - "| 20 | 20 |", - "| 30 | 30 |", - "+------+----+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn test_partial_qualified_name() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - let sql = "SELECT t1.t1_id, t1_name FROM public.t1"; - let expected = [ - "+-------+---------+", - "| t1_id | t1_name |", - "+-------+---------+", - "| 11 | a |", - "| 22 | b |", - "| 33 | c |", - "| 44 | d |", - "+-------+---------+", - ]; - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - Ok(()) -} diff --git a/datafusion/sqllogictest/test_files/references.slt b/datafusion/sqllogictest/test_files/references.slt new file mode 100644 index 000000000000..c09addb3eec2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/references.slt @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## References Tests +########## + + +# Qualified table references +# Query tables with catalog prefix +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +query I +SELECT COUNT(*) FROM aggregate_test_100; +---- +100 + +query I +SELECT COUNT(*) FROM public.aggregate_test_100; +---- +100 + +query I +SELECT COUNT(*) FROM datafusion.public.aggregate_test_100; +---- +100 + + +# Qualified table references and fields +# Query fields with prefixes +statement ok +CREATE TABLE test("f.c1" TEXT, "test.c2" INT, "...." INT) AS VALUES +('foofoo', 1, 10), +('foobar', 2, 20), +('foobaz', 3, 30); + +query error DataFusion error: Schema error: No field named f1\.c1\. Valid fields are test\."f\.c1", test\."test\.c2", test\."\.\.\.\."\. +SELECT f1.c1 FROM test; + +query T +SELECT "f.c1" FROM test +---- +foofoo +foobar +foobaz + +query T +SELECT test."f.c1" FROM test +---- +foofoo +foobar +foobaz + +query II +SELECT "test.c2" AS expr1, test."test.c2" AS expr2 FROM test +---- +1 1 +2 2 +3 3 + +query II +SELECT "....", "...." AS c3 FROM test ORDER BY "...." +---- +10 10 +20 20 +30 30 + +query TT +EXPLAIN (SELECT "....", "...." AS c3 FROM test ORDER BY "...."); +---- +logical_plan +Sort: test..... ASC NULLS LAST +--Projection: test....., test..... AS c3 +----TableScan: test projection=[....] +physical_plan +SortExec: expr=[....@0 ASC NULLS LAST] +--ProjectionExec: expr=[....@0 as ...., ....@0 as c3] +----MemoryExec: partitions=1, partition_sizes=[1] + + +# Partial qualified name +statement ok +CREATE TABLE t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 3); + +statement ok +CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3); + +query IT +SELECT t1.t1_id, t1_name FROM public.t1; +---- +11 a +22 b +33 c +44 d From 31094b00e2e5f764a89a2e9806e98acf0576729f Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Thu, 18 Jan 2024 00:04:59 +0800 Subject: [PATCH 11/39] fix bug with `to_timestamp` and `InitCap` logical serialization, add roundtrip test between expression and proto, (#8868) * add roundtrip test between expression and proto --------- Co-authored-by: Andrew Lamb --- datafusion/proto/Cargo.toml | 1 + .../proto/src/logical_plan/from_proto.rs | 17 +++++++-- datafusion/proto/tests/cases/serialize.rs | 37 +++++++++++++++++++ 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index f9f24b28db81..e42322021630 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -54,4 +54,5 @@ serde_json = { workspace = true, optional = true } [dev-dependencies] doc-comment = { workspace = true } +strum = { version = "0.25.0", features = ["derive"] } tokio = "1.18" diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 9185bdb80429..973e366d0bbd 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -58,8 +58,8 @@ use datafusion_expr::{ current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, - lcm, left, levenshtein, ln, log, log10, log2, + factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, initcap, + isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, @@ -1585,7 +1585,7 @@ pub fn parse_expr( Ok(character_length(parse_expr(&args[0], registry)?)) } ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry)?)), - ScalarFunction::InitCap => Ok(ascii(parse_expr(&args[0], registry)?)), + ScalarFunction::InitCap => Ok(initcap(parse_expr(&args[0], registry)?)), ScalarFunction::Gcd => Ok(gcd( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1742,7 +1742,16 @@ pub fn parse_expr( Ok(arrow_typeof(parse_expr(&args[0], registry)?)) } ScalarFunction::ToTimestamp => { - Ok(to_timestamp_seconds(parse_expr(&args[0], registry)?)) + let args: Vec<_> = args + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::>()?; + Ok(Expr::ScalarFunction( + datafusion_expr::expr::ScalarFunction::new( + BuiltinScalarFunction::ToTimestamp, + args, + ), + )) } ScalarFunction::Flatten => Ok(flatten(parse_expr(&args[0], registry)?)), ScalarFunction::StringToArray => Ok(string_to_array( diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index 5b890accd81f..222d1a3a629c 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -243,3 +243,40 @@ fn context_with_udf() -> SessionContext { ctx } + +#[test] +fn test_expression_serialization_roundtrip() { + use datafusion_common::ScalarValue; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::BuiltinScalarFunction; + use datafusion_proto::logical_plan::from_proto::parse_expr; + use datafusion_proto::protobuf::LogicalExprNode; + use strum::IntoEnumIterator; + + let ctx = SessionContext::new(); + let lit = Expr::Literal(ScalarValue::Utf8(None)); + for builtin_fun in BuiltinScalarFunction::iter() { + // default to 4 args (though some exprs like substr have error checking) + let num_args = match builtin_fun { + BuiltinScalarFunction::Substr => 3, + _ => 4, + }; + let args: Vec<_> = std::iter::repeat(&lit).take(num_args).cloned().collect(); + let expr = Expr::ScalarFunction(ScalarFunction::new(builtin_fun, args)); + + let proto = LogicalExprNode::try_from(&expr).unwrap(); + let deserialize = parse_expr(&proto, &ctx).unwrap(); + + let serialize_name = extract_function_name(&expr); + let deserialize_name = extract_function_name(&deserialize); + + assert_eq!(serialize_name, deserialize_name); + } + + /// Extracts the first part of a function name + /// 'foo(bar)' -> 'foo' + fn extract_function_name(expr: &Expr) -> String { + let name = expr.display_name().unwrap(); + name.split('(').next().unwrap().to_string() + } +} From 9004eb4e4851a92e9bba60f0ca027514918ba1ae Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Thu, 18 Jan 2024 04:26:12 +0800 Subject: [PATCH 12/39] uncomment tests (#8881) --- datafusion/common/src/scalar.rs | 14 ++ .../physical-expr/src/array_expressions.rs | 22 ++- datafusion/sqllogictest/test_files/array.slt | 166 +++++++++--------- 3 files changed, 108 insertions(+), 94 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 9cbd9e292ff3..20d03c70960a 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -2957,6 +2957,20 @@ impl TryFrom<&DataType> for ScalarValue { .to_owned() .into(), ), + // 'ScalarValue::LargeList' contains single element `LargeListArray + DataType::LargeList(field) => ScalarValue::LargeList( + new_null_array( + &DataType::LargeList(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + ))), + 1, + ) + .as_list::() + .to_owned() + .into(), + ), DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), DataType::Null => ScalarValue::Null, _ => { diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 5b35c4b9d8fb..af6587631df5 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1074,7 +1074,9 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result { } } -fn align_array_dimensions(args: Vec) -> Result> { +fn align_array_dimensions( + args: Vec, +) -> Result> { let args_ndim = args .iter() .map(|arg| datafusion_common::utils::list_ndims(arg.data_type())) @@ -1091,9 +1093,9 @@ fn align_array_dimensions(args: Vec) -> Result> { for _ in 0..(max_ndim - ndim) { let data_type = aligned_array.data_type().to_owned(); let array_lengths = vec![1; aligned_array.len()]; - let offsets = OffsetBuffer::::from_lengths(array_lengths); + let offsets = OffsetBuffer::::from_lengths(array_lengths); - aligned_array = Arc::new(ListArray::try_new( + aligned_array = Arc::new(GenericListArray::::try_new( Arc::new(Field::new("item", data_type, true)), offsets, aligned_array, @@ -1112,13 +1114,12 @@ fn align_array_dimensions(args: Vec) -> Result> { // Concatenate arrays on the same row. fn concat_internal(args: &[ArrayRef]) -> Result { - let args = align_array_dimensions(args.to_vec())?; + let args = align_array_dimensions::(args.to_vec())?; let list_arrays = args .iter() .map(|arg| as_generic_list_array::(arg)) .collect::>>()?; - // Assume number of rows is the same for all arrays let row_count = list_arrays[0].len(); @@ -2733,9 +2734,11 @@ mod tests { let array2d_1 = Arc::new(array_into_list_array(array1d_1.clone())) as ArrayRef; let array2d_2 = Arc::new(array_into_list_array(array1d_2.clone())) as ArrayRef; - let res = - align_array_dimensions(vec![array1d_1.to_owned(), array2d_2.to_owned()]) - .unwrap(); + let res = align_array_dimensions::(vec![ + array1d_1.to_owned(), + array2d_2.to_owned(), + ]) + .unwrap(); let expected = as_list_array(&array2d_1).unwrap(); let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type()); @@ -2748,7 +2751,8 @@ mod tests { let array3d_1 = Arc::new(array_into_list_array(array2d_1)) as ArrayRef; let array3d_2 = array_into_list_array(array2d_2.to_owned()); let res = - align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())]).unwrap(); + align_array_dimensions::(vec![array1d_1, Arc::new(array3d_2.clone())]) + .unwrap(); let expected = as_list_array(&array3d_1).unwrap(); let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type()); diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 6b45f204fefc..342fcb5bec3f 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -107,18 +107,17 @@ AS VALUES (make_array(make_array(4, 5, 6), make_array(10, 11, 12), make_array(4, 9, 8), make_array(7, 8, 9), make_array(10, 11, 12), make_array(1, 8, 7)), make_array(10, 11, 12), 3, make_array([[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]), make_array(121, 131, 141)) ; -# TODO: add this when #8305 is fixed -# statement ok -# CREATE TABLE large_nested_arrays -# AS -# SELECT -# arrow_cast(column1, 'LargeList(LargeList(Int64))') AS column1, -# arrow_cast(column2, 'LargeList(Int64)') AS column2, -# column3, -# arrow_cast(column4, 'LargeList(LargeList(List(Int64)))') AS column4, -# arrow_cast(column5, 'LargeList(Int64)') AS column5 -# FROM nested_arrays -# ; +statement ok +CREATE TABLE large_nested_arrays +AS + SELECT + arrow_cast(column1, 'LargeList(LargeList(Int64))') AS column1, + arrow_cast(column2, 'LargeList(Int64)') AS column2, + column3, + arrow_cast(column4, 'LargeList(LargeList(List(Int64)))') AS column4, + arrow_cast(column5, 'LargeList(Int64)') AS column5 + FROM nested_arrays +; statement ok CREATE TABLE arrays_values @@ -155,16 +154,15 @@ AS VALUES (NULL, NULL, NULL, NULL) ; -# TODO: add this when #8305 is fixed -# statement ok -# CREATE TABLE large_arrays_values_v2 -# AS SELECT -# arrow_cast(column1, 'LargeList(Int64)') AS column1, -# arrow_cast(column2, 'LargeList(Int64)') AS column2, -# column3, -# arrow_cast(column4, 'LargeList(LargeList(Int64))') AS column4 -# FROM arrays_values_v2 -# ; +statement ok +CREATE TABLE large_arrays_values_v2 +AS SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + arrow_cast(column2, 'LargeList(Int64)') AS column2, + column3, + arrow_cast(column4, 'LargeList(LargeList(Int64))') AS column4 +FROM arrays_values_v2 +; statement ok CREATE TABLE flatten_table @@ -1576,16 +1574,15 @@ select ---- [4] [] [1, , 3, 4] [, , 1] -# TODO: add this when #8305 is fixed -# query ???? -# select -# array_append(arrow_cast(make_array(), 'LargeList(Null)'), 4), -# array_append(make_array(), null), -# array_append(make_array(1, null, 3), 4), -# array_append(make_array(null, null), 1) -# ; -# ---- -# [4] [] [1, , 3, 4] [, , 1] +query ???? +select + array_append(arrow_cast(make_array(), 'LargeList(Null)'), 4), + array_append(make_array(), null), + array_append(make_array(1, null, 3), 4), + array_append(make_array(null, null), 1) +; +---- +[4] [] [1, , 3, 4] [, , 1] # test invalid (non-null) query error @@ -1604,13 +1601,12 @@ select ---- [[1, , 3], []] [[1, , 3], ] -# TODO: add this when #8305 is fixed -# query ?? -# select -# array_append(arrow_cast(make_array(make_array(1, null, 3), 'LargeList(LargeList(Int64))')), arrow_cast(make_array(null), 'LargeList(Int64)')), -# array_append(arrow_cast(make_array(make_array(1, null, 3), 'LargeList(LargeList(Int64))')), null); -# ---- -# [[1, , 3], []] [[1, , 3], ] +query ?? +select + array_append(arrow_cast(make_array(make_array(1, null, 3)), 'LargeList(LargeList(Int64))'), arrow_cast(make_array(null), 'LargeList(Int64)')), + array_append(arrow_cast(make_array(make_array(1, null, 3)), 'LargeList(LargeList(Int64))'), null); +---- +[[1, , 3], []] [[1, , 3], ] # array_append scalar function #3 query ??? @@ -1629,11 +1625,10 @@ select array_append(make_array([1], [2], [3]), make_array(4)), array_append(make ---- [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] -# TODO: add this when #8305 is fixed -# query ??? -# select array_append(arrow_cast(make_array([1], [2], [3]), 'LargeList(LargeList(Int64))'), arrow_cast(make_array(4), 'LargeList(Int64)')), array_append(arrow_cast(make_array([1.0], [2.0], [3.0]), 'LargeList(LargeList(Float64))'), arrow_cast(make_array(4.0), 'LargeList(Float64)')), array_append(arrow_cast(make_array(['h'], ['e'], ['l'], ['l']), 'LargeList(LargeList(Utf8))'), arrow_cast(make_array('o'), 'LargeList(Utf8)')); -# ---- -# [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] +query ??? +select array_append(arrow_cast(make_array([1], [2], [3]), 'LargeList(LargeList(Int64))'), arrow_cast(make_array(4), 'LargeList(Int64)')), array_append(arrow_cast(make_array([1.0], [2.0], [3.0]), 'LargeList(LargeList(Float64))'), arrow_cast(make_array(4.0), 'LargeList(Float64)')), array_append(arrow_cast(make_array(['h'], ['e'], ['l'], ['l']), 'LargeList(LargeList(Utf8))'), arrow_cast(make_array('o'), 'LargeList(Utf8)')); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] # list_append scalar function #5 (function alias `array_append`) query ??? @@ -1700,12 +1695,11 @@ select array_append(column1, column2) from nested_arrays; [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] -# TODO: add this when #8305 is fixed -# query ? -# select array_append(column1, column2) from large_nested_arrays; -# ---- -# [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] -# [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] +query ? +select array_append(column1, column2) from large_nested_arrays; +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] +[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] # array_append with columns and scalars #1 query ?? @@ -1737,12 +1731,11 @@ select array_append(column1, make_array(1, 11, 111)), array_append(make_array(ma [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] -# TODO: add this when #8305 is fixed -# query ?? -# select array_append(column1, arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)')), array_append(arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))'), column2) from large_nested_arrays; -# ---- -# [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] -# [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] +query ?? +select array_append(column1, arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)')), array_append(arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))'), column2) from large_nested_arrays; +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] +[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] ## array_prepend (aliases: `list_prepend`, `array_push_front`, `list_push_front`) @@ -1816,11 +1809,12 @@ select array_prepend(make_array(1), make_array(make_array(2), make_array(3), mak ---- [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] -# TODO: add this when #8305 is fixed -# query ??? -# select array_prepend(arrow_cast(make_array(1), 'LargeList(Int64)'), arrow_cast(make_array(make_array(2), make_array(3), make_array(4)), 'LargeList(LargeList(Int64))')), array_prepend(arrow_cast(make_array(1.0), 'LargeList(Float64)'), arrow_cast(make_array([2.0], [3.0], [4.0]), 'LargeList(LargeList(Float64))')), array_prepend(arrow_cast(make_array('h'), 'LargeList(Utf8)'), arrow_cast(make_array(['e'], ['l'], ['l'], ['o']), 'LargeList(LargeList(Utf8))'')); -# ---- -# [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] +query ??? +select array_prepend(arrow_cast(make_array(1), 'LargeList(Int64)'), arrow_cast(make_array(make_array(2), make_array(3), make_array(4)), 'LargeList(LargeList(Int64))')), + array_prepend(arrow_cast(make_array(1.0), 'LargeList(Float64)'), arrow_cast(make_array([2.0], [3.0], [4.0]), 'LargeList(LargeList(Float64))')), + array_prepend(arrow_cast(make_array('h'), 'LargeList(Utf8)'), arrow_cast(make_array(['e'], ['l'], ['l'], ['o']), 'LargeList(LargeList(Utf8))')); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] # list_prepend scalar function #5 (function alias `array_prepend`) query ??? @@ -1887,12 +1881,11 @@ select array_prepend(column2, column1) from nested_arrays; [[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] -# TODO: add this when #8305 is fixed -# query ? -# select array_prepend(column2, column1) from large_nested_arrays; -# ---- -# [[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] -# [[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] +query ? +select array_prepend(column2, column1) from large_nested_arrays; +---- +[[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] +[[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] # array_prepend with columns and scalars #1 query ?? @@ -1924,12 +1917,11 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] -# TODO: add this when #8305 is fixed -# query ?? -# select array_prepend(arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)'), column1), array_prepend(column2, arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))')) from large_nested_arrays; -# ---- -# [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] -# [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] +query ?? +select array_prepend(arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)'), column1), array_prepend(column2, arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))')) from large_nested_arrays; +---- +[[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] +[[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] ## array_repeat (aliases: `list_repeat`) @@ -2372,12 +2364,11 @@ select array_position(column1, column2), array_position(column1, column2, column 3 3 2 5 -#TODO: add this test when #8305 is fixed -#query II -#select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays; -#---- -#3 3 -#2 5 +query II +select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays; +---- +3 3 +2 5 # array_position with columns and scalars #1 query III @@ -2403,12 +2394,11 @@ select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), NULL 6 4 NULL 1 NULL -#TODO: add this test when #8305 is fixed -#query III -#select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), 'LargeList(List(Int64))'), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from large_nested_arrays; -#---- -#NULL 6 4 -#NULL 1 NULL +query III +select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), 'LargeList(LargeList(Int64))'), column2), array_position(column1, arrow_cast(make_array(4, 5, 6), 'LargeList(Int64)')), array_position(column1, arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 2) from large_nested_arrays; +---- +NULL 6 4 +NULL 1 NULL ## array_positions (aliases: `list_positions`) @@ -4989,6 +4979,9 @@ drop table values_without_nulls; statement ok drop table nested_arrays; +statement ok +drop table large_nested_arrays; + statement ok drop table arrays; @@ -5007,6 +5000,9 @@ drop table arrays_values; statement ok drop table arrays_values_v2; +statement ok +drop table large_arrays_values_v2; + statement ok drop table array_has_table_1D; From 81d9d8869c16d878375faeefb1b99f8cbd323785 Mon Sep 17 00:00:00 2001 From: SteveLauC Date: Thu, 18 Jan 2024 04:26:38 +0800 Subject: [PATCH 13/39] refactor: rename FileStream.file_reader to file_opener & update doc (#8883) --- .../src/datasource/physical_plan/file_stream.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index 353662397648..9cb58e7032db 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -83,11 +83,9 @@ pub struct FileStream { projected_schema: SchemaRef, /// The remaining number of records to parse, None if no limit remain: Option, - /// A closure that takes a reader and an optional remaining number of lines - /// (before reaching the limit) and returns a batch iterator. If the file reader - /// is not capable of limiting the number of records in the last batch, the file - /// stream will take care of truncating it. - file_reader: F, + /// A generic [`FileOpener`]. Calling `open()` returns a [`FileOpenFuture`], + /// which can be resolved to a stream of `RecordBatch`. + file_opener: F, /// The partition column projector pc_projector: PartitionColumnProjector, /// The stream state @@ -250,7 +248,7 @@ impl FileStream { pub fn new( config: &FileScanConfig, partition: usize, - file_reader: F, + file_opener: F, metrics: &ExecutionPlanMetricsSet, ) -> Result { let (projected_schema, ..) = config.project(); @@ -269,7 +267,7 @@ impl FileStream { file_iter: files.into(), projected_schema, remain: config.limit, - file_reader, + file_opener, pc_projector, state: FileStreamState::Idle, file_stream_metrics: FileStreamMetrics::new(metrics, partition), @@ -301,7 +299,7 @@ impl FileStream { }; Some( - self.file_reader + self.file_opener .open(file_meta) .map(|future| (future, part_file.partition_values)), ) From da776d9265f32ada57bdd72609c730f2fd155fa4 Mon Sep 17 00:00:00 2001 From: SteveLauC Date: Thu, 18 Jan 2024 05:13:51 +0800 Subject: [PATCH 14/39] docs: fix wrong name in sub-crates' README (#8889) --- datafusion/core/README.md | 2 +- datafusion/execution/README.md | 2 +- datafusion/physical-plan/README.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/core/README.md b/datafusion/core/README.md index 5a9493d086cd..aa5dc08eaaa5 100644 --- a/datafusion/core/README.md +++ b/datafusion/core/README.md @@ -17,7 +17,7 @@ under the License. --> -# DataFusion Common +# DataFusion Core [DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. diff --git a/datafusion/execution/README.md b/datafusion/execution/README.md index 67aac6be82b3..8a03255ee4ad 100644 --- a/datafusion/execution/README.md +++ b/datafusion/execution/README.md @@ -17,7 +17,7 @@ under the License. --> -# DataFusion Common +# DataFusion Execution [DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. diff --git a/datafusion/physical-plan/README.md b/datafusion/physical-plan/README.md index 366a6b555150..ec604253fd2e 100644 --- a/datafusion/physical-plan/README.md +++ b/datafusion/physical-plan/README.md @@ -17,7 +17,7 @@ under the License. --> -# DataFusion Common +# DataFusion Physical Plan [DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. From 89d22b4b5c62f26e6cc0fc86dfa631ada1f44567 Mon Sep 17 00:00:00 2001 From: Matt Gapp <61894094+matthewgapp@users.noreply.github.com> Date: Wed, 17 Jan 2024 13:24:20 -0800 Subject: [PATCH 15/39] Recursive CTEs: Stage 1 - add config flag (#8828) * add config flag for recursive ctes update docs from script update slt test for doc change * restore testing pin --- datafusion/common/src/config.rs | 5 +++++ datafusion/sql/src/query.rs | 13 ++++++++++++- .../sqllogictest/test_files/information_schema.slt | 2 ++ docs/source/user-guide/configs.md | 1 + 4 files changed, 20 insertions(+), 1 deletion(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 996a505dea80..e00c17930850 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -290,6 +290,11 @@ config_namespace! { /// Hive. Note that this setting does not affect reading partitioned /// tables (e.g. `/table/year=2021/month=01/data.parquet`). pub listing_table_ignore_subdirectory: bool, default = true + + /// Should DataFusion support recursive CTEs + /// Defaults to false since this feature is a work in progress and may not + /// behave as expected + pub enable_recursive_ctes: bool, default = false } } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index dd4cab126261..388377e3ee6b 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -54,7 +54,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Process CTEs from top to bottom // do not allow self-references if with.recursive { - return not_impl_err!("Recursive CTEs are not supported"); + if self + .context_provider + .options() + .execution + .enable_recursive_ctes + { + return plan_err!( + "Recursive CTEs are enabled but are not yet supported" + ); + } else { + return not_impl_err!("Recursive CTEs are not supported"); + } } for cte in with.cte_tables { diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 44daa5141677..b37b78ab6d79 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -150,6 +150,7 @@ datafusion.execution.aggregate.scalar_update_factor 10 datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true datafusion.execution.collect_statistics false +datafusion.execution.enable_recursive_ctes false datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 @@ -225,6 +226,7 @@ datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold f datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files +datafusion.execution.enable_recursive_ctes false Should DataFusion support recursive CTEs Defaults to false since this feature is a work in progress and may not behave as expected datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 5e26e2b205dd..a812b74284cf 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -83,6 +83,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | | datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | | datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | +| datafusion.execution.enable_recursive_ctes | false | Should DataFusion support recursive CTEs Defaults to false since this feature is a work in progress and may not behave as expected | | datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | | datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | From 57e38fb16a9918fb16c2c50e2188cc22c0afe4ad Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 18 Jan 2024 11:19:23 +0800 Subject: [PATCH 16/39] Support array literal with scalar function (#8884) * fix struct Signed-off-by: jayzhan211 * add test Signed-off-by: jayzhan211 * support functions Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/sql/src/expr/value.rs | 40 ++++---------------- datafusion/sqllogictest/test_files/array.slt | 26 ++++++++++++- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 9f88318ab21a..c0870cc54106 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -24,8 +24,8 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{BinaryExpr, Placeholder}; +use datafusion_expr::BuiltinScalarFunction; use datafusion_expr::{lit, Expr, Operator}; -use datafusion_expr::{BuiltinScalarFunction, ScalarFunctionDefinition}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; use sqlparser::parser::ParserError::ParserError; @@ -135,38 +135,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { elements: Vec, schema: &DFSchema, ) -> Result { - let mut values = Vec::with_capacity(elements.len()); - - for element in elements { - let value = self.sql_expr_to_logical_expr( - element, - schema, - &mut PlannerContext::new(), - )?; - - match value { - Expr::Literal(_) => { - values.push(value); - } - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - .. - }) => { - if fun == BuiltinScalarFunction::MakeArray { - values.push(value); - } else { - return not_impl_err!( - "ScalarFunctions without MakeArray are not supported: {value}" - ); - } - } - _ => { - return not_impl_err!( - "Arrays with elements other than literal are not supported: {value}" - ); - } - } - } + let values = elements + .into_iter() + .map(|element| { + self.sql_expr_to_logical_expr(element, schema, &mut PlannerContext::new()) + }) + .collect::>>()?; Ok(Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::MakeArray, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 342fcb5bec3f..55cd17724565 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -454,11 +454,33 @@ AS FROM nested_arrays_with_repeating_elements ; +# Array literal + +## boolean coercion is not supported query error select [1, true, null] -query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now() -SELECT [now()] +## wrapped in array_length to get deterministic results +query I +SELECT array_length([now()]) +---- +1 + +## array literal with functions +query ? +select [abs(-1.2), sin(-1), log(2), ceil(3.141)] +---- +[1.2, -0.8414709848078965, 0.3010299801826477, 4.0] + +## array literal with nested types +query ??? +select + [struct('foo', 1)], + [struct('foo', [1,2,3])], + [struct('foo', [struct(3, 'x')])] +; +---- +[{c0: foo, c1: 1}] [{c0: foo, c1: [1, 2, 3]}] [{c0: foo, c1: [{c0: 3, c1: x}]}] query TTT select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from arrays; From 5adfc15df494093294d3ef9568277faefc69a1a1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 18 Jan 2024 10:43:47 +0100 Subject: [PATCH 17/39] Bump actions/cache from 3 to 4 (#8903) Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/rust.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 62992e7acf68..375c9f2c2c5a 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -53,7 +53,7 @@ jobs: rust-version: stable - name: Cache Cargo - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: | ~/.cargo/bin/ @@ -383,7 +383,7 @@ jobs: # rustup default stable # rustup component add rustfmt clippy # - name: Cache Cargo - # uses: actions/cache@v3 + # uses: actions/cache@v4 # with: # path: /home/runner/.cargo # # this key is not equal because the user is different than on a container (runner vs github) From bdf5e9d81091971328fbcf9f9517ae38c07d64c4 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 18 Jan 2024 06:34:04 -0500 Subject: [PATCH 18/39] Fix `datafusion-cli` print output (#8895) * Fix datafusion-cli print output * fmt * Do not print header if only empty batches, test for same --- datafusion-cli/src/print_format.rs | 127 +++++++++++++++++++++++++++-- 1 file changed, 120 insertions(+), 7 deletions(-) diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index ea418562495d..0a8c7b4b3e3a 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -161,23 +161,29 @@ impl PrintFormat { maxrows: MaxRows, with_header: bool, ) -> Result<()> { - if batches.is_empty() || batches[0].num_rows() == 0 { + // filter out any empty batches + let batches: Vec<_> = batches + .iter() + .filter(|b| b.num_rows() > 0) + .cloned() + .collect(); + if batches.is_empty() { return Ok(()); } match self { Self::Csv | Self::Automatic => { - print_batches_with_sep(writer, batches, b',', with_header) + print_batches_with_sep(writer, &batches, b',', with_header) } - Self::Tsv => print_batches_with_sep(writer, batches, b'\t', with_header), + Self::Tsv => print_batches_with_sep(writer, &batches, b'\t', with_header), Self::Table => { if maxrows == MaxRows::Limited(0) { return Ok(()); } - format_batches_with_maxrows(writer, batches, maxrows) + format_batches_with_maxrows(writer, &batches, maxrows) } - Self::Json => batches_to_json!(ArrayWriter, writer, batches), - Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, batches), + Self::Json => batches_to_json!(ArrayWriter, writer, &batches), + Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, &batches), } } } @@ -189,7 +195,7 @@ mod tests { use super::*; - use arrow::array::Int32Array; + use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::error::Result; @@ -351,4 +357,111 @@ mod tests { Ok(()) } + + #[test] + fn test_print_batches_empty_batches() -> Result<()> { + let batch = one_column_batch(); + let empty_batch = RecordBatch::new_empty(batch.schema()); + + #[rustfmt::skip] + let expected =&[ + "+---+", + "| a |", + "+---+", + "| 1 |", + "| 2 |", + "| 3 |", + "+---+\n", + ]; + + PrintBatchesTest::new() + .with_format(PrintFormat::Table) + .with_batches(vec![empty_batch.clone(), batch, empty_batch]) + .with_expected(expected) + .run(); + Ok(()) + } + + #[test] + fn test_print_batches_empty_batches_no_header() -> Result<()> { + let empty_batch = RecordBatch::new_empty(one_column_batch().schema()); + + // empty batches should not print a header + let expected = &[""]; + + PrintBatchesTest::new() + .with_format(PrintFormat::Table) + .with_batches(vec![empty_batch]) + .with_header(true) + .with_expected(expected) + .run(); + Ok(()) + } + + struct PrintBatchesTest { + format: PrintFormat, + batches: Vec, + maxrows: MaxRows, + with_header: bool, + expected: Vec<&'static str>, + } + + impl PrintBatchesTest { + fn new() -> Self { + Self { + format: PrintFormat::Table, + batches: vec![], + maxrows: MaxRows::Unlimited, + with_header: false, + expected: vec![], + } + } + + /// set the format + fn with_format(mut self, format: PrintFormat) -> Self { + self.format = format; + self + } + + /// set the batches to convert + fn with_batches(mut self, batches: Vec) -> Self { + self.batches = batches; + self + } + + /// set whether to include a header + fn with_header(mut self, with_header: bool) -> Self { + self.with_header = with_header; + self + } + + /// set expected output + fn with_expected(mut self, expected: &[&'static str]) -> Self { + self.expected = expected.to_vec(); + self + } + + /// run the test + fn run(self) { + let mut buffer: Vec = vec![]; + self.format + .print_batches(&mut buffer, &self.batches, self.maxrows, self.with_header) + .unwrap(); + let actual = String::from_utf8(buffer).unwrap(); + let expected = self.expected.join("\n"); + assert_eq!( + actual, expected, + "actual:\n\n{actual}expected:\n\n{expected}" + ); + } + } + + /// return a batch with one column and three rows + fn one_column_batch() -> RecordBatch { + RecordBatch::try_from_iter(vec![( + "a", + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + )]) + .unwrap() + } } From eb9bbe817479747f9ec5fc537f4b08a729c5e84b Mon Sep 17 00:00:00 2001 From: SteveLauC Date: Thu, 18 Jan 2024 20:07:25 +0800 Subject: [PATCH 19/39] docs: add an example for RecordBatchReceiverStreamBuilder (#8888) --- datafusion/physical-plan/src/stream.rs | 68 ++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index fdf32620ca50..e4ef6c423865 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -156,14 +156,62 @@ impl ReceiverStreamBuilder { } } -/// Builder for [`RecordBatchReceiverStream`] that propagates errors +/// Builder for `RecordBatchReceiverStream` that propagates errors /// and panic's correctly. /// -/// [`RecordBatchReceiverStream`] is used to spawn one or more tasks -/// that produce `RecordBatch`es and send them to a single +/// [`RecordBatchReceiverStreamBuilder`] is used to spawn one or more tasks +/// that produce [`RecordBatch`]es and send them to a single /// `Receiver` which can improve parallelism. /// /// This also handles propagating panic`s and canceling the tasks. +/// +/// # Example +/// +/// The following example spawns 2 tasks that will write [`RecordBatch`]es to +/// the `tx` end of the builder, after building the stream, we can receive +/// those batches with calling `.next()` +/// +/// ``` +/// # use std::sync::Arc; +/// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType}; +/// # use datafusion_common::arrow::array::RecordBatch; +/// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; +/// # use futures::stream::StreamExt; +/// # use tokio::runtime::Builder; +/// # let rt = Builder::new_current_thread().build().unwrap(); +/// # +/// # rt.block_on(async { +/// let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int8, false)])); +/// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 10); +/// +/// // task 1 +/// let tx_1 = builder.tx(); +/// let schema_1 = Arc::clone(&schema); +/// builder.spawn(async move { +/// // Your task needs to send batches to the tx +/// tx_1.send(Ok(RecordBatch::new_empty(schema_1))).await.unwrap(); +/// +/// Ok(()) +/// }); +/// +/// // task 2 +/// let tx_2 = builder.tx(); +/// let schema_2 = Arc::clone(&schema); +/// builder.spawn(async move { +/// // Your task needs to send batches to the tx +/// tx_2.send(Ok(RecordBatch::new_empty(schema_2))).await.unwrap(); +/// +/// Ok(()) +/// }); +/// +/// let mut stream = builder.build(); +/// while let Some(res_batch) = stream.next().await { +/// // `res_batch` can either from task 1 or 2 +/// +/// // do something with `res_batch` +/// } +/// # }); +/// ``` pub struct RecordBatchReceiverStreamBuilder { schema: SchemaRef, inner: ReceiverStreamBuilder, @@ -186,8 +234,9 @@ impl RecordBatchReceiverStreamBuilder { /// Spawn task that will be aborted if this builder (or the stream /// built from it) are dropped /// - /// this is often used to spawn tasks that write to the sender - /// retrieved from `Self::tx` + /// This is often used to spawn tasks that write to the sender + /// retrieved from [`Self::tx`], for examples, see the document + /// of this type. pub fn spawn(&mut self, task: F) where F: Future>, @@ -199,8 +248,9 @@ impl RecordBatchReceiverStreamBuilder { /// Spawn a blocking task that will be aborted if this builder (or the stream /// built from it) are dropped /// - /// this is often used to spawn tasks that write to the sender - /// retrieved from `Self::tx` + /// This is often used to spawn tasks that write to the sender + /// retrieved from [`Self::tx`], for examples, see the document + /// of this type. pub fn spawn_blocking(&mut self, f: F) where F: FnOnce() -> Result<()>, @@ -209,7 +259,7 @@ impl RecordBatchReceiverStreamBuilder { self.inner.spawn_blocking(f) } - /// runs the input_partition of the `input` ExecutionPlan on the + /// runs the `partition` of the `input` ExecutionPlan on the /// tokio threadpool and writes its outputs to this stream /// /// If the input partition produces an error, the error will be @@ -339,7 +389,7 @@ where } } -/// EmptyRecordBatchStream can be used to create a RecordBatchStream +/// `EmptyRecordBatchStream` can be used to create a [`RecordBatchStream`] /// that will produce no results pub struct EmptyRecordBatchStream { /// Schema wrapped by Arc From d14f766b5d3c117d5027f032b38837ec4e285b8a Mon Sep 17 00:00:00 2001 From: Zhong Xu Date: Thu, 18 Jan 2024 05:57:34 -0800 Subject: [PATCH 20/39] Fix "Projection references non-aggregate values" by updating `rebase_expr` to use `transform_down` (#8890) * transform_down * add test --- datafusion/sql/src/utils.rs | 2 +- datafusion/sqllogictest/test_files/group_by.slt | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 616a2fc74932..0dc1258ebabe 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -66,7 +66,7 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { - expr.clone().transform_up(&|nested_expr| { + expr.clone().transform_down(&|nested_expr| { if base_exprs.contains(&nested_expr) { Ok(Transformed::Yes(expr_as_column_expr(&nested_expr, plan)?)) } else { diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 79e6a9357b40..c75929049c18 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4853,4 +4853,9 @@ statement error DataFusion error: This feature is not implemented: Conflicting o SELECT a, b, NTH_VALUE(c, 2 ORDER BY c ASC), NTH_VALUE(c, 3 ORDER BY d ASC) FROM multiple_ordered_table GROUP BY a, b -ORDER BY a, b; \ No newline at end of file +ORDER BY a, b; + +query II +SELECT a + 1 AS d, a + 1 + b AS c FROM (SELECT 1 AS a, 2 AS b) GROUP BY a + 1, a + 1 + b; +---- +2 4 From 9b78efaeccd57ab8b8e20c29174a121af8130376 Mon Sep 17 00:00:00 2001 From: Tushushu Date: Fri, 19 Jan 2024 04:24:48 +0800 Subject: [PATCH 21/39] Add serde support for Arrow FileTypeWriterOptions (#8850) * refactor * generated files * feat * feat * feat * feat * tests * clippy --------- Co-authored-by: Andrew Lamb --- .../common/src/file_options/arrow_writer.rs | 12 +++ datafusion/proto/proto/datafusion.proto | 3 + datafusion/proto/src/generated/pbjson.rs | 85 +++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 7 +- datafusion/proto/src/logical_plan/mod.rs | 19 +++++ .../proto/src/physical_plan/from_proto.rs | 5 ++ .../tests/cases/roundtrip_logical_plan.rs | 40 +++++++++ 7 files changed, 170 insertions(+), 1 deletion(-) diff --git a/datafusion/common/src/file_options/arrow_writer.rs b/datafusion/common/src/file_options/arrow_writer.rs index a30e6d800e20..cb921535aba5 100644 --- a/datafusion/common/src/file_options/arrow_writer.rs +++ b/datafusion/common/src/file_options/arrow_writer.rs @@ -27,6 +27,18 @@ use super::StatementOptions; #[derive(Clone, Debug)] pub struct ArrowWriterOptions {} +impl ArrowWriterOptions { + pub fn new() -> Self { + Self {} + } +} + +impl Default for ArrowWriterOptions { + fn default() -> Self { + Self::new() + } +} + impl TryFrom<(&ConfigOptions, &StatementOptions)> for ArrowWriterOptions { type Error = DataFusionError; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 8bde0da133eb..d79879e57a7d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1213,6 +1213,7 @@ message FileTypeWriterOptions { JsonWriterOptions json_options = 1; ParquetWriterOptions parquet_options = 2; CsvWriterOptions csv_options = 3; + ArrowWriterOptions arrow_options = 4; } } @@ -1243,6 +1244,8 @@ message CsvWriterOptions { string null_value = 8; } +message ArrowWriterOptions {} + message WriterProperties { uint64 data_page_size_limit = 1; uint64 dictionary_page_size_limit = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 528761136ca3..d7ad6fb03c92 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1929,6 +1929,77 @@ impl<'de> serde::Deserialize<'de> for ArrowType { deserializer.deserialize_struct("datafusion.ArrowType", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ArrowWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let len = 0; + let struct_ser = serializer.serialize_struct("datafusion.ArrowWriterOptions", len)?; + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ArrowWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + Err(serde::de::Error::unknown_field(value, FIELDS)) + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ArrowWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ArrowWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; + } + Ok(ArrowWriterOptions { + }) + } + } + deserializer.deserialize_struct("datafusion.ArrowWriterOptions", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for AvroFormat { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -8354,6 +8425,9 @@ impl serde::Serialize for FileTypeWriterOptions { file_type_writer_options::FileType::CsvOptions(v) => { struct_ser.serialize_field("csvOptions", v)?; } + file_type_writer_options::FileType::ArrowOptions(v) => { + struct_ser.serialize_field("arrowOptions", v)?; + } } } struct_ser.end() @@ -8372,6 +8446,8 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { "parquetOptions", "csv_options", "csvOptions", + "arrow_options", + "arrowOptions", ]; #[allow(clippy::enum_variant_names)] @@ -8379,6 +8455,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { JsonOptions, ParquetOptions, CsvOptions, + ArrowOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8403,6 +8480,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions), "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions), "csvOptions" | "csv_options" => Ok(GeneratedField::CsvOptions), + "arrowOptions" | "arrow_options" => Ok(GeneratedField::ArrowOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8444,6 +8522,13 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { return Err(serde::de::Error::duplicate_field("csvOptions")); } file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::CsvOptions) +; + } + GeneratedField::ArrowOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::ArrowOptions) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 9a0b7ab332a6..d594da90879c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1646,7 +1646,7 @@ pub struct PartitionColumn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileTypeWriterOptions { - #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3")] + #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3, 4")] pub file_type: ::core::option::Option, } /// Nested message and enum types in `FileTypeWriterOptions`. @@ -1660,6 +1660,8 @@ pub mod file_type_writer_options { ParquetOptions(super::ParquetWriterOptions), #[prost(message, tag = "3")] CsvOptions(super::CsvWriterOptions), + #[prost(message, tag = "4")] + ArrowOptions(super::ArrowWriterOptions), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1704,6 +1706,9 @@ pub struct CsvWriterOptions { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct ArrowWriterOptions {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct WriterProperties { #[prost(uint64, tag = "1")] pub data_page_size_limit: u64, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 6ca95519a9b1..f10f11c1c093 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -16,6 +16,7 @@ // under the License. use arrow::csv::WriterBuilder; +use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; use std::collections::HashMap; use std::fmt::Debug; use std::str::FromStr; @@ -858,6 +859,13 @@ impl AsLogicalPlan for LogicalPlanNode { Some(copy_to_node::CopyOptions::WriterOptions(opt)) => { match &opt.file_type { Some(ft) => match ft { + file_type_writer_options::FileType::ArrowOptions(_) => { + CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::Arrow( + ArrowWriterOptions::new(), + ), + )) + } file_type_writer_options::FileType::CsvOptions( writer_options, ) => { @@ -1659,6 +1667,17 @@ impl AsLogicalPlan for LogicalPlanNode { } CopyOptions::WriterOptions(opt) => { match opt.as_ref() { + FileTypeWriterOptions::Arrow(_) => { + let arrow_writer_options = + file_type_writer_options::FileType::ArrowOptions( + protobuf::ArrowWriterOptions {}, + ); + Some(copy_to_node::CopyOptions::WriterOptions( + protobuf::FileTypeWriterOptions { + file_type: Some(arrow_writer_options), + }, + )) + } FileTypeWriterOptions::CSV(csv_opts) => { let csv_options = &csv_opts.writer_options; let csv_writer_options = csv_writer_options_to_proto( diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index ea28eeee8810..dc827d02bf25 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -42,6 +42,7 @@ use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::{ functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; @@ -834,6 +835,10 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { .ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?; match file_type { + protobuf::file_type_writer_options::FileType::ArrowOptions(_) => { + Ok(Self::Arrow(ArrowWriterOptions::new())) + } + protobuf::file_type_writer_options::FileType::JsonOptions(opts) => { let compression: CompressionTypeVariant = opts.compression().into(); Ok(Self::JSON(JsonWriterOptions::new(compression))) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ed21124a9e22..2d38cfd400ad 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -27,6 +27,7 @@ use arrow::datatypes::{ IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; +use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -394,6 +395,45 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.arrow".to_string(), + file_format: FileType::ARROW, + single_file_output: true, + copy_options: CopyOptions::WriterOptions(Box::new(FileTypeWriterOptions::Arrow( + ArrowWriterOptions::new(), + ))), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.arrow", copy_to.output_url); + assert_eq!(FileType::ARROW, copy_to.file_format); + assert!(copy_to.single_file_output); + match ©_to.copy_options { + CopyOptions::WriterOptions(y) => match y.as_ref() { + FileTypeWriterOptions::Arrow(_) => {} + _ => panic!(), + }, + _ => panic!(), + } + } + _ => panic!(), + } + + Ok(()) +} + #[tokio::test] async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { let ctx = SessionContext::new(); From 3a9e23d138c935ffea68408899016c9323aa0f36 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 19 Jan 2024 04:56:58 -0500 Subject: [PATCH 22/39] Improve datafusion-cli print format tests (#8896) --- datafusion-cli/src/print_format.rs | 415 ++++++++++++++++++++--------- 1 file changed, 283 insertions(+), 132 deletions(-) diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 0a8c7b4b3e3a..2de52be612bb 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -190,117 +190,212 @@ impl PrintFormat { #[cfg(test)] mod tests { - use std::io::{Cursor, Read, Write}; - use std::sync::Arc; - use super::*; + use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion::error::Result; - - fn run_test(batches: &[RecordBatch], test_fn: F) -> Result - where - F: Fn(&mut Cursor>, &[RecordBatch]) -> Result<()>, - { - let mut buffer = Cursor::new(Vec::new()); - test_fn(&mut buffer, batches)?; - buffer.set_position(0); - let mut contents = String::new(); - buffer.read_to_string(&mut contents)?; - Ok(contents) + + #[test] + fn print_empty() { + for format in [ + PrintFormat::Csv, + PrintFormat::Tsv, + PrintFormat::Table, + PrintFormat::Json, + PrintFormat::NdJson, + PrintFormat::Automatic, + ] { + // no output for empty batches, even with header set + PrintBatchesTest::new() + .with_format(format) + .with_batches(vec![]) + .with_expected(&[""]) + .run(); + } } #[test] - fn test_print_batches_with_sep() -> Result<()> { - let contents = run_test(&[], |buffer, batches| { - print_batches_with_sep(buffer, batches, b',', true) - })?; - assert_eq!(contents, ""); + fn print_csv_no_header() { + #[rustfmt::skip] + let expected = &[ + "1,4,7", + "2,5,8", + "3,6,9", + ]; - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ])); - let batch = RecordBatch::try_new( - schema, - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(Int32Array::from(vec![4, 5, 6])), - Arc::new(Int32Array::from(vec![7, 8, 9])), - ], - )?; + PrintBatchesTest::new() + .with_format(PrintFormat::Csv) + .with_batches(split_batch(three_column_batch())) + .with_header(WithHeader::No) + .with_expected(expected) + .run(); + } - let contents = run_test(&[batch], |buffer, batches| { - print_batches_with_sep(buffer, batches, b',', true) - })?; - assert_eq!(contents, "a,b,c\n1,4,7\n2,5,8\n3,6,9\n"); + #[test] + fn print_csv_with_header() { + #[rustfmt::skip] + let expected = &[ + "a,b,c", + "1,4,7", + "2,5,8", + "3,6,9", + ]; - Ok(()) + PrintBatchesTest::new() + .with_format(PrintFormat::Csv) + .with_batches(split_batch(three_column_batch())) + .with_header(WithHeader::Yes) + .with_expected(expected) + .run(); } #[test] - fn test_print_batches_to_json_empty() -> Result<()> { - let contents = run_test(&[], |buffer, batches| { - batches_to_json!(ArrayWriter, buffer, batches) - })?; - assert_eq!(contents, ""); + fn print_tsv_no_header() { + #[rustfmt::skip] + let expected = &[ + "1\t4\t7", + "2\t5\t8", + "3\t6\t9", + ]; - let contents = run_test(&[], |buffer, batches| { - batches_to_json!(LineDelimitedWriter, buffer, batches) - })?; - assert_eq!(contents, ""); + PrintBatchesTest::new() + .with_format(PrintFormat::Tsv) + .with_batches(split_batch(three_column_batch())) + .with_header(WithHeader::No) + .with_expected(expected) + .run(); + } - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ])); - let batch = RecordBatch::try_new( - schema, - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(Int32Array::from(vec![4, 5, 6])), - Arc::new(Int32Array::from(vec![7, 8, 9])), - ], - )?; - let batches = vec![batch]; + #[test] + fn print_tsv_with_header() { + #[rustfmt::skip] + let expected = &[ + "a\tb\tc", + "1\t4\t7", + "2\t5\t8", + "3\t6\t9", + ]; + + PrintBatchesTest::new() + .with_format(PrintFormat::Tsv) + .with_batches(split_batch(three_column_batch())) + .with_header(WithHeader::Yes) + .with_expected(expected) + .run(); + } - let contents = run_test(&batches, |buffer, batches| { - batches_to_json!(ArrayWriter, buffer, batches) - })?; - assert_eq!(contents, "[{\"a\":1,\"b\":4,\"c\":7},{\"a\":2,\"b\":5,\"c\":8},{\"a\":3,\"b\":6,\"c\":9}]\n"); + #[test] + fn print_table() { + let expected = &[ + "+---+---+---+", + "| a | b | c |", + "+---+---+---+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 3 | 6 | 9 |", + "+---+---+---+", + ]; - let contents = run_test(&batches, |buffer, batches| { - batches_to_json!(LineDelimitedWriter, buffer, batches) - })?; - assert_eq!(contents, "{\"a\":1,\"b\":4,\"c\":7}\n{\"a\":2,\"b\":5,\"c\":8}\n{\"a\":3,\"b\":6,\"c\":9}\n"); + PrintBatchesTest::new() + .with_format(PrintFormat::Table) + .with_batches(split_batch(three_column_batch())) + .with_header(WithHeader::Ignored) + .with_expected(expected) + .run(); + } + #[test] + fn print_json() { + let expected = + &[r#"[{"a":1,"b":4,"c":7},{"a":2,"b":5,"c":8},{"a":3,"b":6,"c":9}]"#]; - Ok(()) + PrintBatchesTest::new() + .with_format(PrintFormat::Json) + .with_batches(split_batch(three_column_batch())) + .with_header(WithHeader::Ignored) + .with_expected(expected) + .run(); } #[test] - fn test_format_batches_with_maxrows() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - let batch = RecordBatch::try_new( - schema, - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; + fn print_ndjson() { + let expected = &[ + r#"{"a":1,"b":4,"c":7}"#, + r#"{"a":2,"b":5,"c":8}"#, + r#"{"a":3,"b":6,"c":9}"#, + ]; + + PrintBatchesTest::new() + .with_format(PrintFormat::NdJson) + .with_batches(split_batch(three_column_batch())) + .with_header(WithHeader::Ignored) + .with_expected(expected) + .run(); + } + #[test] + fn print_automatic_no_header() { #[rustfmt::skip] - let all_rows_expected = [ + let expected = &[ + "1,4,7", + "2,5,8", + "3,6,9", + ]; + + PrintBatchesTest::new() + .with_format(PrintFormat::Automatic) + .with_batches(split_batch(three_column_batch())) + .with_header(WithHeader::No) + .with_expected(expected) + .run(); + } + #[test] + fn print_automatic_with_header() { + #[rustfmt::skip] + let expected = &[ + "a,b,c", + "1,4,7", + "2,5,8", + "3,6,9", + ]; + + PrintBatchesTest::new() + .with_format(PrintFormat::Automatic) + .with_batches(split_batch(three_column_batch())) + .with_header(WithHeader::Yes) + .with_expected(expected) + .run(); + } + + #[test] + fn print_maxrows_unlimited() { + #[rustfmt::skip] + let expected = &[ "+---+", "| a |", "+---+", "| 1 |", "| 2 |", "| 3 |", - "+---+\n", - ].join("\n"); + "+---+", + ]; + + // should print out entire output with no truncation if unlimited or + // limit greater than number of batches or equal to the number of batches + for max_rows in [MaxRows::Unlimited, MaxRows::Limited(5), MaxRows::Limited(3)] { + PrintBatchesTest::new() + .with_format(PrintFormat::Table) + .with_batches(vec![one_column_batch()]) + .with_maxrows(max_rows) + .with_expected(expected) + .run(); + } + } + #[test] + fn print_maxrows_limited_one_batch() { #[rustfmt::skip] - let one_row_expected = [ + let expected = &[ "+---+", "| a |", "+---+", @@ -308,11 +403,21 @@ mod tests { "| . |", "| . |", "| . |", - "+---+\n", - ].join("\n"); + "+---+", + ]; + PrintBatchesTest::new() + .with_format(PrintFormat::Table) + .with_batches(vec![one_column_batch()]) + .with_maxrows(MaxRows::Limited(1)) + .with_expected(expected) + .run(); + } + + #[test] + fn print_maxrows_limited_multi_batched() { #[rustfmt::skip] - let multi_batches_expected = [ + let expected = &[ "+---+", "| a |", "+---+", @@ -324,42 +429,23 @@ mod tests { "| . |", "| . |", "| . |", - "+---+\n", - ].join("\n"); - - let no_limit = run_test(&[batch.clone()], |buffer, batches| { - format_batches_with_maxrows(buffer, batches, MaxRows::Unlimited) - })?; - assert_eq!(no_limit, all_rows_expected); - - let maxrows_less_than_actual = run_test(&[batch.clone()], |buffer, batches| { - format_batches_with_maxrows(buffer, batches, MaxRows::Limited(1)) - })?; - assert_eq!(maxrows_less_than_actual, one_row_expected); - - let maxrows_more_than_actual = run_test(&[batch.clone()], |buffer, batches| { - format_batches_with_maxrows(buffer, batches, MaxRows::Limited(5)) - })?; - assert_eq!(maxrows_more_than_actual, all_rows_expected); - - let maxrows_equals_actual = run_test(&[batch.clone()], |buffer, batches| { - format_batches_with_maxrows(buffer, batches, MaxRows::Limited(3)) - })?; - assert_eq!(maxrows_equals_actual, all_rows_expected); - - let multi_batches = run_test( - &[batch.clone(), batch.clone(), batch.clone()], - |buffer, batches| { - format_batches_with_maxrows(buffer, batches, MaxRows::Limited(5)) - }, - )?; - assert_eq!(multi_batches, multi_batches_expected); - - Ok(()) + "+---+", + ]; + + PrintBatchesTest::new() + .with_format(PrintFormat::Table) + .with_batches(vec![ + one_column_batch(), + one_column_batch(), + one_column_batch(), + ]) + .with_maxrows(MaxRows::Limited(5)) + .with_expected(expected) + .run(); } #[test] - fn test_print_batches_empty_batches() -> Result<()> { + fn test_print_batches_empty_batches() { let batch = one_column_batch(); let empty_batch = RecordBatch::new_empty(batch.schema()); @@ -371,7 +457,7 @@ mod tests { "| 1 |", "| 2 |", "| 3 |", - "+---+\n", + "+---+", ]; PrintBatchesTest::new() @@ -379,11 +465,10 @@ mod tests { .with_batches(vec![empty_batch.clone(), batch, empty_batch]) .with_expected(expected) .run(); - Ok(()) } #[test] - fn test_print_batches_empty_batches_no_header() -> Result<()> { + fn test_print_batches_empty_batches_no_header() { let empty_batch = RecordBatch::new_empty(one_column_batch().schema()); // empty batches should not print a header @@ -392,27 +477,36 @@ mod tests { PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_batches(vec![empty_batch]) - .with_header(true) + .with_header(WithHeader::Yes) .with_expected(expected) .run(); - Ok(()) } + #[derive(Debug)] struct PrintBatchesTest { format: PrintFormat, batches: Vec, maxrows: MaxRows, - with_header: bool, + with_header: WithHeader, expected: Vec<&'static str>, } + /// How to test with_header + #[derive(Debug, Clone)] + enum WithHeader { + Yes, + No, + /// output should be the same with or without header + Ignored, + } + impl PrintBatchesTest { fn new() -> Self { Self { format: PrintFormat::Table, batches: vec![], maxrows: MaxRows::Unlimited, - with_header: false, + with_header: WithHeader::Ignored, expected: vec![], } } @@ -429,8 +523,14 @@ mod tests { self } - /// set whether to include a header - fn with_header(mut self, with_header: bool) -> Self { + /// set maxrows + fn with_maxrows(mut self, maxrows: MaxRows) -> Self { + self.maxrows = maxrows; + self + } + + /// set with_header + fn with_header(mut self, with_header: WithHeader) -> Self { self.with_header = with_header; self } @@ -443,17 +543,58 @@ mod tests { /// run the test fn run(self) { - let mut buffer: Vec = vec![]; - self.format - .print_batches(&mut buffer, &self.batches, self.maxrows, self.with_header) - .unwrap(); - let actual = String::from_utf8(buffer).unwrap(); - let expected = self.expected.join("\n"); + let actual = self.output(); + let actual: Vec<_> = actual.trim_end().split('\n').collect(); + let expected = self.expected; assert_eq!( actual, expected, - "actual:\n\n{actual}expected:\n\n{expected}" + "\n\nactual:\n{actual:#?}\n\nexpected:\n{expected:#?}" ); } + + /// formats batches using parameters and returns the resulting output + fn output(&self) -> String { + match self.with_header { + WithHeader::Yes => self.output_with_header(true), + WithHeader::No => self.output_with_header(false), + WithHeader::Ignored => { + let output = self.output_with_header(true); + // ensure the output is the same without header + let output_without_header = self.output_with_header(false); + assert_eq!( + output, output_without_header, + "Expected output to be the same with or without header" + ); + output + } + } + } + + fn output_with_header(&self, with_header: bool) -> String { + let mut buffer: Vec = vec![]; + self.format + .print_batches(&mut buffer, &self.batches, self.maxrows, with_header) + .unwrap(); + String::from_utf8(buffer).unwrap() + } + } + + /// Return a batch with three columns and three rows + fn three_column_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![4, 5, 6])), + Arc::new(Int32Array::from(vec![7, 8, 9])), + ], + ) + .unwrap() } /// return a batch with one column and three rows @@ -464,4 +605,14 @@ mod tests { )]) .unwrap() } + + /// Slice the record batch into 2 batches + fn split_batch(batch: RecordBatch) -> Vec { + assert!(batch.num_rows() > 1); + let split = batch.num_rows() / 2; + vec![ + batch.slice(0, split), + batch.slice(split, batch.num_rows() - split), + ] + } } From a78692127ae4620025e624c94c25f1996af48999 Mon Sep 17 00:00:00 2001 From: Matt Gapp <61894094+matthewgapp@users.noreply.github.com> Date: Fri, 19 Jan 2024 12:07:43 -0800 Subject: [PATCH 23/39] Recursive CTEs: Stage 2 - add support for sql -> logical plan generation (#8839) * add config flag for recursive ctes update docs from script update slt test for doc change * restore testing pin * add sql -> logical plan support * impl cte as work table * move SharedState to continuance * impl WorkTableState wip: readying pr to implement only logical plan fix sql integration test wip: add sql test for logical plan wip: format test assertion * wip: remove uncessary with qualifier method some docs more docs * Add comments to `RecursiveQuery` * Update datfusion-cli Cargo.lock * Fix clippy * better errors and comments * add doc comment with rationale for create_cte_worktable method * wip: tweak --------- Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 66 ++++---- .../core/src/datasource/cte_worktable.rs | 97 ++++++++++++ datafusion/core/src/datasource/mod.rs | 1 + datafusion/core/src/execution/context/mod.rs | 13 ++ datafusion/core/src/physical_planner.rs | 7 +- datafusion/expr/src/logical_plan/builder.rs | 24 +++ datafusion/expr/src/logical_plan/mod.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 69 ++++++++ .../optimizer/src/common_subexpr_eliminate.rs | 1 + .../optimizer/src/optimize_projections.rs | 1 + datafusion/proto/src/logical_plan/mod.rs | 3 + datafusion/sql/src/planner.rs | 13 ++ datafusion/sql/src/query.rs | 148 +++++++++++++++--- datafusion/sql/tests/sql_integration.rs | 51 +++++- 14 files changed, 431 insertions(+), 67 deletions(-) create mode 100644 datafusion/core/src/datasource/cte_worktable.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 5663e736dbd8..db5913503aaa 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -360,9 +360,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc2d0cfb2a7388d34f590e76686704c494ed7aaceed62ee1ba35cbf363abc2a5" +checksum = "a116f46a969224200a0a97f29cfd4c50e7534e4b4826bd23ea2c3c533039c82c" dependencies = [ "bzip2", "flate2", @@ -733,9 +733,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.1" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" [[package]] name = "blake2" @@ -1125,7 +1125,7 @@ dependencies = [ "half", "hashbrown 0.14.3", "indexmap 2.1.0", - "itertools 0.12.0", + "itertools", "log", "num-traits", "num_cpus", @@ -1235,7 +1235,7 @@ dependencies = [ "datafusion-expr", "datafusion-physical-expr", "hashbrown 0.14.3", - "itertools 0.12.0", + "itertools", "log", "regex-syntax", ] @@ -1260,7 +1260,7 @@ dependencies = [ "hashbrown 0.14.3", "hex", "indexmap 2.1.0", - "itertools 0.12.0", + "itertools", "log", "md-5", "paste", @@ -1291,7 +1291,7 @@ dependencies = [ "half", "hashbrown 0.14.3", "indexmap 2.1.0", - "itertools 0.12.0", + "itertools", "log", "once_cell", "parking_lot", @@ -1652,9 +1652,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.23" +version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b553656127a00601c8ae5590fcfdc118e4083a7924b6cf4ffc1ea4b99dc429d7" +checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" dependencies = [ "bytes", "fnv", @@ -1722,9 +1722,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +checksum = "5d3d0e0f38255e7fa3cf31335b3a56f05febd18025f4db5ef7a0cfb4f8da651f" [[package]] name = "hex" @@ -1908,15 +1908,6 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" -[[package]] -name = "itertools" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.12.0" @@ -2072,16 +2063,16 @@ version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "libc", "redox_syscall", ] [[package]] name = "linux-raw-sys" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "lock_api" @@ -2279,7 +2270,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.3", + "hermit-abi 0.3.4", "libc", ] @@ -2305,7 +2296,7 @@ dependencies = [ "futures", "humantime", "hyper", - "itertools 0.12.0", + "itertools", "parking_lot", "percent-encoding", "quick-xml", @@ -2516,9 +2507,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" +checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" [[package]] name = "powerfmt" @@ -2534,14 +2525,13 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "predicates" -version = "3.0.4" +version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dfc28575c2e3f19cb3c73b93af36460ae898d426eba6fc15b9bd2a5220758a0" +checksum = "68b87bfd4605926cdfefc1c3b5f8fe560e3feca9d5552cf68c466d3d8236c7e8" dependencies = [ "anstyle", "difflib", "float-cmp", - "itertools 0.11.0", "normalize-line-endings", "predicates-core", "regex", @@ -2836,11 +2826,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.29" +version = "0.38.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a1a81a2478639a14e68937903356dbac62cf52171148924f754bb8a8cd7a96c" +checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "errno", "libc", "linux-raw-sys", @@ -3102,9 +3092,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.2" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" +checksum = "2593d31f82ead8df961d8bd23a64c2ccf2eb5dd34b0a34bfb4dd54011c72009e" [[package]] name = "snafu" @@ -3563,9 +3553,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" [[package]] name = "unicode-ident" diff --git a/datafusion/core/src/datasource/cte_worktable.rs b/datafusion/core/src/datasource/cte_worktable.rs new file mode 100644 index 000000000000..de13e73e003b --- /dev/null +++ b/datafusion/core/src/datasource/cte_worktable.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! CteWorkTable implementation used for recursive queries + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use async_trait::async_trait; +use datafusion_common::not_impl_err; + +use crate::{ + error::Result, + logical_expr::{Expr, LogicalPlan, TableProviderFilterPushDown}, + physical_plan::ExecutionPlan, +}; + +use datafusion_common::DataFusionError; + +use crate::datasource::{TableProvider, TableType}; +use crate::execution::context::SessionState; + +/// The temporary working table where the previous iteration of a recursive query is stored +/// Naming is based on PostgreSQL's implementation. +/// See here for more details: www.postgresql.org/docs/11/queries-with.html#id-1.5.6.12.5.4 +pub struct CteWorkTable { + /// The name of the CTE work table + // WIP, see https://github.com/apache/arrow-datafusion/issues/462 + #[allow(dead_code)] + name: String, + /// This schema must be shared across both the static and recursive terms of a recursive query + table_schema: SchemaRef, +} + +impl CteWorkTable { + /// construct a new CteWorkTable with the given name and schema + /// This schema must match the schema of the recursive term of the query + /// Since the scan method will contain an physical plan that assumes this schema + pub fn new(name: &str, table_schema: SchemaRef) -> Self { + Self { + name: name.to_owned(), + table_schema, + } + } +} + +#[async_trait] +impl TableProvider for CteWorkTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_logical_plan(&self) -> Option<&LogicalPlan> { + None + } + + fn schema(&self) -> SchemaRef { + self.table_schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Temporary + } + + async fn scan( + &self, + _state: &SessionState, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + not_impl_err!("scan not implemented for CteWorkTable yet") + } + + fn supports_filter_pushdown( + &self, + _filter: &Expr, + ) -> Result { + // TODO: should we support filter pushdown? + Ok(TableProviderFilterPushDown::Unsupported) + } +} diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 2e516cc36a01..8f20da183a93 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -20,6 +20,7 @@ //! [`ListingTable`]: crate::datasource::listing::ListingTable pub mod avro_to_arrow; +pub mod cte_worktable; pub mod default_table_source; pub mod empty; pub mod file_format; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 1e378541b624..9b623d7a51ec 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -26,6 +26,7 @@ mod parquet; use crate::{ catalog::{CatalogList, MemoryCatalogList}, datasource::{ + cte_worktable::CteWorkTable, function::{TableFunction, TableFunctionImpl}, listing::{ListingOptions, ListingTable}, provider::TableProviderFactory, @@ -1899,6 +1900,18 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { Ok(provider_as_source(provider)) } + /// Create a new CTE work table for a recursive CTE logical plan + /// This table will be used in conjunction with a Worktable physical plan + /// to read and write each iteration of a recursive CTE + fn create_cte_work_table( + &self, + name: &str, + schema: SchemaRef, + ) -> Result> { + let table = Arc::new(CteWorkTable::new(name, schema)); + Ok(provider_as_source(table)) + } + fn get_function_meta(&self, name: &str) -> Option> { self.state.scalar_functions().get(name).cloned() } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 98390ac271d0..bc448fe06fcf 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -87,8 +87,8 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, - WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, RecursiveQuery, ScalarFunctionDefinition, + StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -1290,6 +1290,9 @@ impl DefaultPhysicalPlanner { Ok(plan) } } + LogicalPlan::RecursiveQuery(RecursiveQuery { name: _, static_term: _, recursive_term: _, is_distinct: _,.. }) => { + not_impl_err!("Physical counterpart of RecursiveQuery is not implemented yet") + } }; exec_plan }.boxed() diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 847fbbbf61c7..eb5e5bd42634 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -55,6 +55,8 @@ use datafusion_common::{ ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; +use super::plan::RecursiveQuery; + /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -121,6 +123,28 @@ impl LogicalPlanBuilder { })) } + /// Convert a regular plan into a recursive query. + /// `is_distinct` indicates whether the recursive term should be de-duplicated (`UNION`) after each iteration or not (`UNION ALL`). + pub fn to_recursive_query( + &self, + name: String, + recursive_term: LogicalPlan, + is_distinct: bool, + ) -> Result { + // TODO: we need to do a bunch of validation here. Maybe more. + if is_distinct { + return Err(DataFusionError::NotImplemented( + "Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported".to_string(), + )); + } + Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term: Arc::new(self.plan.clone()), + recursive_term: Arc::new(recursive_term), + is_distinct, + }))) + } + /// Create a values list based relation, and the schema is inferred from data, consuming /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index bc722dd69ace..f6e6000897a5 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -36,8 +36,8 @@ pub use plan::{ projection_schema, Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, - Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, - ToStringifiedPlan, Union, Unnest, Values, Window, + RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, + TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 93a38fb40df5..5ab8a9c99cd0 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -154,6 +154,8 @@ pub enum LogicalPlan { /// Unnest a column that contains a nested list type such as an /// ARRAY. This is used to implement SQL `UNNEST` Unnest(Unnest), + /// A variadic query (e.g. "Recursive CTEs") + RecursiveQuery(RecursiveQuery), } impl LogicalPlan { @@ -191,6 +193,10 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(), LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + // we take the schema of the static term as the schema of the entire recursive query + static_term.schema() + } } } @@ -243,6 +249,10 @@ impl LogicalPlan { | LogicalPlan::TableScan(_) => { vec![self.schema()] } + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + // return only the schema of the static term + static_term.all_schemas() + } // return children schemas LogicalPlan::Limit(_) | LogicalPlan::Subquery(_) @@ -384,6 +394,7 @@ impl LogicalPlan { .try_for_each(f), // plans without expressions LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Limit(_) @@ -430,6 +441,11 @@ impl LogicalPlan { LogicalPlan::Ddl(ddl) => ddl.inputs(), LogicalPlan::Unnest(Unnest { input, .. }) => vec![input], LogicalPlan::Prepare(Prepare { input, .. }) => vec![input], + LogicalPlan::RecursiveQuery(RecursiveQuery { + static_term, + recursive_term, + .. + }) => vec![static_term, recursive_term], // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::Statement { .. } @@ -510,6 +526,9 @@ impl LogicalPlan { cross.left.head_output_expr() } } + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + static_term.head_output_expr() + } LogicalPlan::Union(union) => Ok(Some(Expr::Column( union.schema.fields()[0].qualified_column(), ))), @@ -835,6 +854,14 @@ impl LogicalPlan { }; Ok(LogicalPlan::Distinct(distinct)) } + LogicalPlan::RecursiveQuery(RecursiveQuery { + name, is_distinct, .. + }) => Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + name: name.clone(), + static_term: Arc::new(inputs[0].clone()), + recursive_term: Arc::new(inputs[1].clone()), + is_distinct: *is_distinct, + })), LogicalPlan::Analyze(a) => { assert!(expr.is_empty()); assert_eq!(inputs.len(), 1); @@ -1073,6 +1100,7 @@ impl LogicalPlan { }), LogicalPlan::TableScan(TableScan { fetch, .. }) => *fetch, LogicalPlan::EmptyRelation(_) => Some(0), + LogicalPlan::RecursiveQuery(_) => None, LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, @@ -1408,6 +1436,11 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self.0 { LogicalPlan::EmptyRelation(_) => write!(f, "EmptyRelation"), + LogicalPlan::RecursiveQuery(RecursiveQuery { + is_distinct, .. + }) => { + write!(f, "RecursiveQuery: is_distinct={}", is_distinct) + } LogicalPlan::Values(Values { ref values, .. }) => { let str_values: Vec<_> = values .iter() @@ -1718,6 +1751,42 @@ pub struct EmptyRelation { pub schema: DFSchemaRef, } +/// A variadic query operation, Recursive CTE. +/// +/// # Recursive Query Evaluation +/// +/// From the [Postgres Docs]: +/// +/// 1. Evaluate the non-recursive term. For `UNION` (but not `UNION ALL`), +/// discard duplicate rows. Include all remaining rows in the result of the +/// recursive query, and also place them in a temporary working table. +// +/// 2. So long as the working table is not empty, repeat these steps: +/// +/// * Evaluate the recursive term, substituting the current contents of the +/// working table for the recursive self-reference. For `UNION` (but not `UNION +/// ALL`), discard duplicate rows and rows that duplicate any previous result +/// row. Include all remaining rows in the result of the recursive query, and +/// also place them in a temporary intermediate table. +/// +/// * Replace the contents of the working table with the contents of the +/// intermediate table, then empty the intermediate table. +/// +/// [Postgres Docs]: https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct RecursiveQuery { + /// Name of the query + pub name: String, + /// The static term (initial contents of the working table) + pub static_term: Arc, + /// The recursive term (evaluated on the contents of the working table until + /// it returns an empty set) + pub recursive_term: Arc, + /// Should the output of the recursive term be deduplicated (`UNION`) or + /// not (`UNION ALL`). + pub is_distinct: bool, +} + /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index fc867df23c36..f29c7406acc9 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -365,6 +365,7 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Prepare(_) => { // apply the optimization to all inputs of the plan utils::optimize_children(self, plan, config)? diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index d9c45510972c..ab0cb0a26551 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -163,6 +163,7 @@ fn optimize_projections( .collect::>() } LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) | LogicalPlan::Extension(_) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index f10f11c1c093..d95d69780301 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1734,6 +1734,9 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::DescribeTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), + LogicalPlan::RecursiveQuery(_) => Err(proto_error( + "LogicalPlan serde is not yet implemented for RecursiveQuery", + )), } } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index a04df5589b85..d4dd42edcd39 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -61,6 +61,19 @@ pub trait ContextProvider { not_impl_err!("Table Functions are not supported") } + /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) + /// We don't directly implement this in the logical plan's ['SqlToRel`] + /// because the sql code needs access to a table that contains execution-related types that can't be a direct dependency + /// of the sql crate (namely, the `CteWorktable`). + /// The [`ContextProvider`] provides a way to "hide" this dependency. + fn create_cte_work_table( + &self, + _name: &str, + _schema: SchemaRef, + ) -> Result> { + not_impl_err!("Recursive CTE is not implemented") + } + /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 388377e3ee6b..af0b91ae6c7e 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use arrow::datatypes::Schema; use datafusion_common::{ not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result, ScalarValue, }; @@ -26,7 +27,8 @@ use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ - Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, + Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, SetOperator, + SetQuantifier, Value, }; use sqlparser::parser::ParserError::ParserError; @@ -52,21 +54,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let set_expr = query.body; if let Some(with) = query.with { // Process CTEs from top to bottom - // do not allow self-references - if with.recursive { - if self - .context_provider - .options() - .execution - .enable_recursive_ctes - { - return plan_err!( - "Recursive CTEs are enabled but are not yet supported" - ); - } else { - return not_impl_err!("Recursive CTEs are not supported"); - } - } + + let is_recursive = with.recursive; for cte in with.cte_tables { // A `WITH` block can't use the same name more than once @@ -76,16 +65,127 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { "WITH query name {cte_name:?} specified more than once" ))); } - // create logical plan & pass backreferencing CTEs - // CTE expr don't need extend outer_query_schema - let logical_plan = - self.query_to_plan(*cte.query, &mut planner_context.clone())?; - // Each `WITH` block can change the column names in the last - // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). - let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; + if is_recursive { + if !self + .context_provider + .options() + .execution + .enable_recursive_ctes + { + return not_impl_err!("Recursive CTEs are not enabled"); + } + + match *cte.query.body { + SetExpr::SetOperation { + op: SetOperator::Union, + left, + right, + set_quantifier, + } => { + let distinct = set_quantifier != SetQuantifier::All; + + // Each recursive CTE consists from two parts in the logical plan: + // 1. A static term (the left hand side on the SQL, where the + // referencing to the same CTE is not allowed) + // + // 2. A recursive term (the right hand side, and the recursive + // part) + + // Since static term does not have any specific properties, it can + // be compiled as if it was a regular expression. This will + // allow us to infer the schema to be used in the recursive term. + + // ---------- Step 1: Compile the static term ------------------ + let static_plan = self + .set_expr_to_plan(*left, &mut planner_context.clone())?; + + // Since the recursive CTEs include a component that references a + // table with its name, like the example below: + // + // WITH RECURSIVE values(n) AS ( + // SELECT 1 as n -- static term + // UNION ALL + // SELECT n + 1 + // FROM values -- self reference + // WHERE n < 100 + // ) + // + // We need a temporary 'relation' to be referenced and used. PostgreSQL + // calls this a 'working table', but it is entirely an implementation + // detail and a 'real' table with that name might not even exist (as + // in the case of DataFusion). + // + // Since we can't simply register a table during planning stage (it is + // an execution problem), we'll use a relation object that preserves the + // schema of the input perfectly and also knows which recursive CTE it is + // bound to. + + // ---------- Step 2: Create a temporary relation ------------------ + // Step 2.1: Create a table source for the temporary relation + let work_table_source = + self.context_provider.create_cte_work_table( + &cte_name, + Arc::new(Schema::from(static_plan.schema().as_ref())), + )?; - planner_context.insert_cte(cte_name, logical_plan); + // Step 2.2: Create a temporary relation logical plan that will be used + // as the input to the recursive term + let work_table_plan = LogicalPlanBuilder::scan( + cte_name.to_string(), + work_table_source, + None, + )? + .build()?; + + let name = cte_name.clone(); + + // Step 2.3: Register the temporary relation in the planning context + // For all the self references in the variadic term, we'll replace it + // with the temporary relation we created above by temporarily registering + // it as a CTE. This temporary relation in the planning context will be + // replaced by the actual CTE plan once we're done with the planning. + planner_context.insert_cte(cte_name.clone(), work_table_plan); + + // ---------- Step 3: Compile the recursive term ------------------ + // this uses the named_relation we inserted above to resolve the + // relation. This ensures that the recursive term uses the named relation logical plan + // and thus the 'continuance' physical plan as its input and source + let recursive_plan = self + .set_expr_to_plan(*right, &mut planner_context.clone())?; + + // ---------- Step 4: Create the final plan ------------------ + // Step 4.1: Compile the final plan + let logical_plan = LogicalPlanBuilder::from(static_plan) + .to_recursive_query(name, recursive_plan, distinct)? + .build()?; + + let final_plan = + self.apply_table_alias(logical_plan, cte.alias)?; + + // Step 4.2: Remove the temporary relation from the planning context and replace it + // with the final plan. + planner_context.insert_cte(cte_name.clone(), final_plan); + } + _ => { + return Err(DataFusionError::SQL( + ParserError(format!("Unsupported CTE: {cte}")), + None, + )); + } + }; + } else { + // create logical plan & pass backreferencing CTEs + // CTE expr don't need extend outer_query_schema + let logical_plan = + self.query_to_plan(*cte.query, &mut planner_context.clone())?; + + // Each `WITH` block can change the column names in the last + // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). + let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; + + planner_context.insert_cte(cte_name, logical_plan); + } } } let plan = self.set_expr_to_plan(*(set_expr.clone()), planner_context)?; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 44da4cd4d836..c88e2d1130ed 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -1394,11 +1394,46 @@ fn recursive_ctes() { select * from numbers;"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "This feature is not implemented: Recursive CTEs are not supported", + "This feature is not implemented: Recursive CTEs are not enabled", err.strip_backtrace() ); } +#[test] +fn recursive_ctes_enabled() { + let sql = " + WITH RECURSIVE numbers AS ( + select 1 as n + UNION ALL + select n + 1 FROM numbers WHERE N < 10 + ) + select * from numbers;"; + + // manually setting up test here so that we can enable recursive ctes + let mut context = MockContextProvider::default(); + context.options_mut().execution.enable_recursive_ctes = true; + + let planner = SqlToRel::new_with_options(&context, ParserOptions::default()); + let result = DFParser::parse_sql_with_dialect(sql, &GenericDialect {}); + let mut ast = result.unwrap(); + + let plan = planner + .statement_to_plan(ast.pop_front().unwrap()) + .expect("recursive cte plan creation failed"); + + assert_eq!( + format!("{plan:?}"), + "Projection: numbers.n\ + \n SubqueryAlias: numbers\ + \n RecursiveQuery: is_distinct=false\ + \n Projection: Int64(1) AS n\ + \n EmptyRelation\ + \n Projection: numbers.n + Int64(1)\ + \n Filter: numbers.n < Int64(10)\ + \n TableScan: numbers" + ); +} + #[test] fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { quick_test( @@ -2692,6 +2727,12 @@ struct MockContextProvider { udafs: HashMap>, } +impl MockContextProvider { + fn options_mut(&mut self) -> &mut ConfigOptions { + &mut self.options + } +} + impl ContextProvider for MockContextProvider { fn get_table_source(&self, name: TableReference) -> Result> { let schema = match name.table() { @@ -2801,6 +2842,14 @@ impl ContextProvider for MockContextProvider { fn options(&self) -> &ConfigOptions { &self.options } + + fn create_cte_work_table( + &self, + _name: &str, + schema: SchemaRef, + ) -> Result> { + Ok(Arc::new(EmptyTable::new(schema))) + } } #[test] From 72b81f1f0a5095cc642b91559b0906d6aedf1752 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Sat, 20 Jan 2024 04:08:39 +0800 Subject: [PATCH 24/39] remove null in array-append adn array-prepend (#8901) --- datafusion/physical-expr/src/array_expressions.rs | 6 ------ datafusion/sqllogictest/test_files/array.slt | 6 +++--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index af6587631df5..b9cdcff20659 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1031,12 +1031,6 @@ where let res = match list_array.value_type() { DataType::List(_) => concat_internal::(args)?, DataType::LargeList(_) => concat_internal::(args)?, - DataType::Null => { - return make_array(&[ - list_array.values().to_owned(), - element_array.to_owned(), - ]); - } data_type => { return generic_append_and_prepend::( list_array, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 55cd17724565..ee9168de6482 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1599,9 +1599,9 @@ select query ???? select array_append(arrow_cast(make_array(), 'LargeList(Null)'), 4), - array_append(make_array(), null), - array_append(make_array(1, null, 3), 4), - array_append(make_array(null, null), 1) + array_append(arrow_cast(make_array(), 'LargeList(Null)'), null), + array_append(arrow_cast(make_array(1, null, 3), 'LargeList(Int64)'), 4), + array_append(arrow_cast(make_array(null, null), 'LargeList(Null)'), 1) ; ---- [4] [] [1, , 3, 4] [, , 1] From ae0f401d89cf9dc8b717e6b95f73ed4f3be9798b Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Sat, 20 Jan 2024 04:14:44 +0800 Subject: [PATCH 25/39] Add support for FixedSizeList type in `arrow_cast`, hashing (#8344) * Add support for parsing FixedSizeList type * fix fmt * support cast fixedsizelist from list * clean comment * support cast between NULL and FixedSizedLisr * add test for FixedSizeList hash * add test for cast fixedsizelist --- datafusion/common/src/hash_utils.rs | 65 ++++++++++++++++++- datafusion/common/src/scalar.rs | 24 +++++-- datafusion/common/src/utils.rs | 17 ++++- datafusion/expr/src/utils.rs | 1 + datafusion/sql/src/expr/arrow_cast.rs | 17 +++++ .../sqllogictest/test_files/arrow_typeof.slt | 39 ++++++++++- 6 files changed, 155 insertions(+), 8 deletions(-) diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 8dcc00ca1c29..d5a1b3ee363b 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -27,8 +27,9 @@ use arrow::{downcast_dictionary_array, downcast_primitive_array}; use arrow_buffer::i256; use crate::cast::{ - as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, - as_primitive_array, as_string_array, as_struct_array, + as_boolean_array, as_fixed_size_list_array, as_generic_binary_array, + as_large_list_array, as_list_array, as_primitive_array, as_string_array, + as_struct_array, }; use crate::error::{DataFusionError, Result, _internal_err}; @@ -267,6 +268,38 @@ where Ok(()) } +fn hash_fixed_list_array( + array: &FixedSizeListArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let values = array.values().clone(); + let value_len = array.value_length(); + let offset_size = value_len as usize / array.len(); + let nulls = array.nulls(); + let mut values_hashes = vec![0u64; values.len()]; + create_hashes(&[values], random_state, &mut values_hashes)?; + if let Some(nulls) = nulls { + for i in 0..array.len() { + if nulls.is_valid(i) { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[i * offset_size..(i + 1) * offset_size] + { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + } else { + for i in 0..array.len() { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[i * offset_size..(i + 1) * offset_size] { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + Ok(()) +} + /// Test version of `create_hashes` that produces the same value for /// all hashes (to test collisions) /// @@ -366,6 +399,10 @@ pub fn create_hashes<'a>( let array = as_large_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } + DataType::FixedSizeList(_,_) => { + let array = as_fixed_size_list_array(array)?; + hash_fixed_list_array(array, random_state, hashes_buffer)?; + } _ => { // This is internal because we should have caught this before. return _internal_err!( @@ -546,6 +583,30 @@ mod tests { assert_eq!(hashes[2], hashes[3]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_fixed_size_list_arrays() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(3), None, Some(5)]), + None, + Some(vec![Some(0), Some(1), Some(2)]), + ]; + let list_array = + Arc::new(FixedSizeListArray::from_iter_primitive::( + data, 3, + )) as ArrayRef; + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; list_array.len()]; + create_hashes(&[list_array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[5]); + assert_eq!(hashes[1], hashes[4]); + assert_eq!(hashes[2], hashes[3]); + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 20d03c70960a..99b8cff20de7 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -34,8 +34,9 @@ use crate::cast::{ }; use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; -use crate::utils::{array_into_large_list_array, array_into_list_array}; - +use crate::utils::{ + array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array, +}; use arrow::compute::kernels::numeric::*; use arrow::util::display::{ArrayFormatter, FormatOptions}; use arrow::{ @@ -2223,9 +2224,11 @@ impl ScalarValue { let list_array = as_fixed_size_list_array(array)?; let nested_array = list_array.value(index); // Produces a single element `ListArray` with the value at `index`. - let arr = Arc::new(array_into_list_array(nested_array)); + let list_size = nested_array.len(); + let arr = + Arc::new(array_into_fixed_size_list_array(nested_array, list_size)); - ScalarValue::List(arr) + ScalarValue::FixedSizeList(arr) } DataType::Date32 => typed_cast!(array, index, Date32Array, Date32)?, DataType::Date64 => typed_cast!(array, index, Date64Array, Date64)?, @@ -2971,6 +2974,19 @@ impl TryFrom<&DataType> for ScalarValue { .to_owned() .into(), ), + // `ScalaValue::FixedSizeList` contains single element `FixedSizeList`. + DataType::FixedSizeList(field, _) => ScalarValue::FixedSizeList( + new_null_array( + &DataType::FixedSizeList( + Arc::new(Field::new("item", field.data_type().clone(), true)), + 1, + ), + 1, + ) + .as_fixed_size_list() + .to_owned() + .into(), + ), DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), DataType::Null => ScalarValue::Null, _ => { diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 0a61fce15482..d21bd464f850 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -25,7 +25,9 @@ use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; -use arrow_array::{Array, LargeListArray, ListArray, RecordBatchOptions}; +use arrow_array::{ + Array, FixedSizeListArray, LargeListArray, ListArray, RecordBatchOptions, +}; use arrow_schema::DataType; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; @@ -368,6 +370,19 @@ pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { ) } +pub fn array_into_fixed_size_list_array( + arr: ArrayRef, + list_size: usize, +) -> FixedSizeListArray { + let list_size = list_size as i32; + FixedSizeListArray::new( + Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + list_size, + arr, + None, + ) +} + /// Wrap arrays into a single element `ListArray`. /// /// Example: diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 40c2c4705362..02479c0765bd 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -911,6 +911,7 @@ pub fn can_hash(data_type: &DataType) -> bool { } DataType::List(_) => true, DataType::LargeList(_) => true, + DataType::FixedSizeList(_, _) => true, _ => false, } } diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index ade8b96b5cc2..9a0d61f41c01 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -150,6 +150,7 @@ impl<'a> Parser<'a> { Token::Dictionary => self.parse_dictionary(), Token::List => self.parse_list(), Token::LargeList => self.parse_large_list(), + Token::FixedSizeList => self.parse_fixed_size_list(), tok => Err(make_error( self.val, &format!("finding next type, got unexpected '{tok}'"), @@ -177,6 +178,19 @@ impl<'a> Parser<'a> { )))) } + /// Parses the FixedSizeList type + fn parse_fixed_size_list(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let length = self.parse_i32("FixedSizeList")?; + self.expect_token(Token::Comma)?; + let data_type = self.parse_next_type()?; + self.expect_token(Token::RParen)?; + Ok(DataType::FixedSizeList( + Arc::new(Field::new("item", data_type, true)), + length, + )) + } + /// Parses the next timeunit fn parse_time_unit(&mut self, context: &str) -> Result { match self.next_token()? { @@ -508,6 +522,7 @@ impl<'a> Tokenizer<'a> { "List" => Token::List, "LargeList" => Token::LargeList, + "FixedSizeList" => Token::FixedSizeList, "Second" => Token::TimeUnit(TimeUnit::Second), "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), @@ -598,6 +613,7 @@ enum Token { DoubleQuotedString(String), List, LargeList, + FixedSizeList, } impl Display for Token { @@ -606,6 +622,7 @@ impl Display for Token { Token::SimpleType(t) => write!(f, "{t}"), Token::List => write!(f, "List"), Token::LargeList => write!(f, "LargeList"), + Token::FixedSizeList => write!(f, "FixedSizeList"), Token::Timestamp => write!(f, "Timestamp"), Token::Time32 => write!(f, "Time32"), Token::Time64 => write!(f, "Time64"), diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 5e9e7ff03d8b..afc28ecc39dc 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -384,4 +384,41 @@ LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, di query T select arrow_typeof(arrow_cast(make_array([1, 2, 3]), 'LargeList(LargeList(Int64))')); ---- -LargeList(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) \ No newline at end of file +LargeList(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +## FixedSizeList + +query ? +select arrow_cast(null, 'FixedSizeList(1, Int64)'); +---- +NULL + +#TODO: arrow-rs doesn't support it yet +#query ? +#select arrow_cast('1', 'FixedSizeList(1, Int64)'); +#---- +#[1] + + +query ? +select arrow_cast([1], 'FixedSizeList(1, Int64)'); +---- +[1] + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed +select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(4, Int64)'); + +query ? +select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 'FixedSizeList(3, Int64)')); +---- +FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3) + +query ? +select arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)'); +---- +[1, 2, 3] \ No newline at end of file From d0c84cc3f585906b7bdcfe7c074a95f97da20275 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 20 Jan 2024 00:05:31 -0800 Subject: [PATCH 26/39] aggregate_statistics should only optimize MIN/MAX when relation is not empty (#8914) * Fix aggregate_statistics * Add more test --- .../aggregate_statistics.rs | 62 +++++++++++++------ .../sqllogictest/test_files/aggregate.slt | 34 ++++++++++ 2 files changed, 76 insertions(+), 20 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 86a8cdb7b3d4..0a53c775aa89 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -197,17 +197,28 @@ fn take_optimizable_min( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, String)> { - let col_stats = &stats.column_statistics; - if let Some(casted_expr) = agg_expr.as_any().downcast_ref::() { - if casted_expr.expressions().len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() + if let Precision::Exact(num_rows) = &stats.num_rows { + if *num_rows > 0 { + let col_stats = &stats.column_statistics; + if let Some(casted_expr) = + agg_expr.as_any().downcast_ref::() { - if let Precision::Exact(val) = &col_stats[col_expr.index()].min_value { - if !val.is_null() { - return Some((val.clone(), casted_expr.name().to_string())); + if casted_expr.expressions().len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = casted_expr.expressions()[0] + .as_any() + .downcast_ref::() + { + if let Precision::Exact(val) = + &col_stats[col_expr.index()].min_value + { + if !val.is_null() { + return Some(( + val.clone(), + casted_expr.name().to_string(), + )); + } + } } } } @@ -221,17 +232,28 @@ fn take_optimizable_max( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, String)> { - let col_stats = &stats.column_statistics; - if let Some(casted_expr) = agg_expr.as_any().downcast_ref::() { - if casted_expr.expressions().len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() + if let Precision::Exact(num_rows) = &stats.num_rows { + if *num_rows > 0 { + let col_stats = &stats.column_statistics; + if let Some(casted_expr) = + agg_expr.as_any().downcast_ref::() { - if let Precision::Exact(val) = &col_stats[col_expr.index()].max_value { - if !val.is_null() { - return Some((val.clone(), casted_expr.name().to_string())); + if casted_expr.expressions().len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = casted_expr.expressions()[0] + .as_any() + .downcast_ref::() + { + if let Precision::Exact(val) = + &col_stats[col_expr.index()].max_value + { + if !val.is_null() { + return Some(( + val.clone(), + casted_expr.name().to_string(), + )); + } + } } } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 50cdebd054a7..a098c8de0d3c 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3260,3 +3260,37 @@ query I select count(*) from (select count(*) a, count(*) b from (select 1)); ---- 1 + +# rule `aggregate_statistics` should not optimize MIN/MAX to wrong values on empty relation + +statement ok +CREATE TABLE empty(col0 INTEGER); + +query I +SELECT MIN(col0) FROM empty WHERE col0=1; +---- +NULL + +query I +SELECT MAX(col0) FROM empty WHERE col0=1; +---- +NULL + +statement ok +DROP TABLE empty; + +statement ok +CREATE TABLE t(col0 INTEGER) as VALUES(2); + +query I +SELECT MIN(col0) FROM t WHERE col0=1; +---- +NULL + +query I +SELECT MAX(col0) FROM t WHERE col0=1; +---- +NULL + +statement ok +DROP TABLE t; From e7c0482d05a4251db83eb7d4897318020df6f0b2 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sat, 20 Jan 2024 05:28:24 -0500 Subject: [PATCH 27/39] support to_timestamp with optional chrono formats (#8886) * Support to_timestamp with chrono formatting #5398 * Updated user guide's to_timestamp to include chrono formatting information #5398 * Minor comment update. * Small documentation updates for to_timestamp functions. * Cargo fmt and clippy improvements. * Switched to assert and unwrap_err based on feedback * Fixed assert, code compiles and runs as expected now. * Fix fmt (again). * Add additional to_timestamp tests covering usage with tables with and without valid formats. * to_timestamp documentation fixes. * - Changed internal_err! -> exec_err! for unsupported data type errors. - Extracted out to_timestamp_impl method to reduce code duplication as per PR feedback. - Extracted out validate_to_timestamp_data_types to reduce code duplication as per PR feedback. - Added additional tests for argument validation and invalid arguments. - Removed unnecessary shim function 'string_to_timestamp_nanos_with_format_shim' * Resolved merge conflict, updated toStringXXX methods to reflect upstream change * prettier * Fix clippy --------- Co-authored-by: Andrew Lamb --- .../examples/dataframe_to_timestamp.rs | 109 +++ datafusion/expr/src/built_in_function.rs | 68 +- datafusion/expr/src/expr_fn.rs | 25 +- .../physical-expr/src/datetime_expressions.rs | 817 ++++++++++++++++-- .../proto/src/logical_plan/from_proto.rs | 69 +- .../sqllogictest/test_files/timestamps.slt | 143 ++- .../source/user-guide/sql/scalar_functions.md | 63 +- 7 files changed, 1115 insertions(+), 179 deletions(-) create mode 100644 datafusion-examples/examples/dataframe_to_timestamp.rs diff --git a/datafusion-examples/examples/dataframe_to_timestamp.rs b/datafusion-examples/examples/dataframe_to_timestamp.rs new file mode 100644 index 000000000000..8caa9245596b --- /dev/null +++ b/datafusion-examples/examples/dataframe_to_timestamp.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion::arrow::array::StringArray; +use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::assert_contains; + +/// This example demonstrates how to use the to_timestamp function in the DataFrame API as well as via sql. +#[tokio::main] +async fn main() -> Result<()> { + // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + + // define data. + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(StringArray::from(vec![ + "2020-09-08T13:42:29Z", + "2020-09-08T13:42:29.190855-05:00", + "2020-08-09 12:13:29", + "2020-01-02", + ])), + Arc::new(StringArray::from(vec![ + "2020-09-08T13:42:29Z", + "2020-09-08T13:42:29.190855-05:00", + "08-09-2020 13/42/29", + "09-27-2020 13:42:29-05:30", + ])), + ], + )?; + + // declare a new context. In spark API, this corresponds to a new spark SQLsession + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + // use to_timestamp function to convert col 'a' to timestamp type using the default parsing + let df = df.with_column("a", to_timestamp(vec![col("a")]))?; + // use to_timestamp_seconds function to convert col 'b' to timestamp(Seconds) type using a list of chrono formats to try + let df = df.with_column( + "b", + to_timestamp_seconds(vec![ + col("b"), + lit("%+"), + lit("%d-%m-%Y %H/%M/%S"), + lit("%m-%d-%Y %H:%M:%S%#z"), + ]), + )?; + + let df = df.select_columns(&["a", "b"])?; + + // print the results + df.show().await?; + + // use sql to convert col 'a' to timestamp using the default parsing + let df = ctx.sql("select to_timestamp(a) from t").await?; + + // print the results + df.show().await?; + + // use sql to convert col 'b' to timestamp using a list of chrono formats to try + let df = ctx.sql("select to_timestamp(b, '%+', '%d-%m-%Y %H/%M/%S', '%m-%d-%Y %H:%M:%S%#z') from t").await?; + + // print the results + df.show().await?; + + // use sql to convert a static string to a timestamp using a list of chrono formats to try + let df = ctx.sql("select to_timestamp('01-14-2023 01:01:30+05:30', '%+', '%d-%m-%Y %H/%M/%S', '%m-%d-%Y %H:%M:%S%#z')").await?; + + // print the results + df.show().await?; + + // use sql to convert a static string to a timestamp using a non-matching chrono format to try + let result = ctx + .sql("select to_timestamp('01-14-2023 01/01/30', '%d-%m-%Y %H:%M:%S')") + .await? + .collect() + .await; + + let expected = "Error parsing timestamp from '01-14-2023 01/01/30' using format '%d-%m-%Y %H:%M:%S': input contains invalid characters"; + assert_contains!(result.unwrap_err().to_string(), expected); + + Ok(()) +} diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 6f64642f60d9..b54cd68164c1 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -1053,67 +1053,13 @@ impl BuiltinScalarFunction { vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], self.volatility(), ), - BuiltinScalarFunction::ToTimestamp => Signature::uniform( - 1, - vec![ - Int64, - Float64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - self.volatility(), - ), - BuiltinScalarFunction::ToTimestampMillis => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - self.volatility(), - ), - BuiltinScalarFunction::ToTimestampMicros => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - self.volatility(), - ), - BuiltinScalarFunction::ToTimestampNanos => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - self.volatility(), - ), - BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - self.volatility(), - ), + BuiltinScalarFunction::ToTimestamp + | BuiltinScalarFunction::ToTimestampSeconds + | BuiltinScalarFunction::ToTimestampMillis + | BuiltinScalarFunction::ToTimestampMicros + | BuiltinScalarFunction::ToTimestampNanos => { + Signature::variadic_any(self.volatility()) + } BuiltinScalarFunction::FromUnixtime => { Signature::uniform(1, vec![Int64], self.volatility()) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 834420e413b0..ae534f4bb44b 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -885,29 +885,30 @@ nary_scalar_expr!( scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the date"); scalar_expr!(DateTrunc, date_trunc, part date, "truncates the date to a specified level of precision"); scalar_expr!(DateBin, date_bin, stride source origin, "coerces an arbitrary timestamp to the start of the nearest specified interval"); -scalar_expr!( +nary_scalar_expr!( + ToTimestamp, + to_timestamp, + "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`" +); +nary_scalar_expr!( ToTimestampMillis, to_timestamp_millis, - date, - "converts a string to a `Timestamp(Milliseconds, None)`" + "converts a string and optional formats to a `Timestamp(Milliseconds, None)`" ); -scalar_expr!( +nary_scalar_expr!( ToTimestampMicros, to_timestamp_micros, - date, - "converts a string to a `Timestamp(Microseconds, None)`" + "converts a string and optional formats to a `Timestamp(Microseconds, None)`" ); -scalar_expr!( +nary_scalar_expr!( ToTimestampNanos, to_timestamp_nanos, - date, - "converts a string to a `Timestamp(Nanoseconds, None)`" + "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`" ); -scalar_expr!( +nary_scalar_expr!( ToTimestampSeconds, to_timestamp_seconds, - date, - "converts a string to a `Timestamp(Seconds, None)`" + "converts a string and optional formats to a `Timestamp(Seconds, None)`" ); scalar_expr!( FromUnixtime, diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 589bbc8a952b..d21d89c19d2e 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -17,7 +17,6 @@ //! DateTime expressions -use crate::datetime_expressions; use crate::expressions::cast_column; use arrow::compute::cast; use arrow::{ @@ -37,7 +36,9 @@ use arrow::{ use arrow_array::temporal_conversions::NANOSECONDS; use arrow_array::timezone::Tz; use arrow_array::types::ArrowTimestampType; +use arrow_array::GenericStringArray; use chrono::prelude::*; +use chrono::LocalResult::Single; use chrono::{Duration, Months, NaiveDate}; use datafusion_common::cast::{ as_date32_array, as_date64_array, as_generic_string_array, as_primitive_array, @@ -49,9 +50,96 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::ColumnarValue; +use itertools::Either; use std::str::FromStr; use std::sync::Arc; +/// Error message if nanosecond conversion request beyond supported interval +const ERR_NANOSECONDS_NOT_SUPPORTED: &str = "The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804"; + +/// Accepts a string with a `chrono` format and converts it to a +/// nanosecond precision timestamp. +/// +/// See [`chrono::format::strftime`] for the full set of supported formats. +/// +/// Implements the `to_timestamp` function to convert a string to a +/// timestamp, following the model of spark SQL’s to_`timestamp`. +/// +/// Internally, this function uses the `chrono` library for the +/// datetime parsing +/// +/// ## Timestamp Precision +/// +/// Function uses the maximum precision timestamps supported by +/// Arrow (nanoseconds stored as a 64-bit integer) timestamps. This +/// means the range of dates that timestamps can represent is ~1677 AD +/// to 2262 AM +/// +/// ## Timezone / Offset Handling +/// +/// Numerical values of timestamps are stored compared to offset UTC. +/// +/// Any timestamp in the formatting string is handled according to the rules +/// defined by `chrono`. +/// +/// [`chrono::format::strftime`]: https://docs.rs/chrono/latest/chrono/format/strftime/index.html +/// +#[inline] +pub(crate) fn string_to_timestamp_nanos_formatted( + s: &str, + format: &str, +) -> Result { + string_to_datetime_formatted(&Utc, s, format)? + .naive_utc() + .timestamp_nanos_opt() + .ok_or_else(|| { + DataFusionError::Execution(ERR_NANOSECONDS_NOT_SUPPORTED.to_string()) + }) +} + +/// Accepts a string and parses it using the [`chrono::format::strftime`] specifiers +/// relative to the provided `timezone` +/// +/// [IANA timezones] are only supported if the `arrow-array/chrono-tz` feature is enabled +/// +/// * `2023-01-01 040506 America/Los_Angeles` +/// +/// If a timestamp is ambiguous, for example as a result of daylight-savings time, an error +/// will be returned +/// +/// [`chrono::format::strftime`]: https://docs.rs/chrono/latest/chrono/format/strftime/index.html +/// [IANA timezones]: https://www.iana.org/time-zones +pub(crate) fn string_to_datetime_formatted( + timezone: &T, + s: &str, + format: &str, +) -> Result, DataFusionError> { + let err = |err_ctx: &str| { + DataFusionError::Execution(format!( + "Error parsing timestamp from '{s}' using format '{format}': {err_ctx}" + )) + }; + + // attempt to parse the string assuming it has a timezone + let dt = DateTime::parse_from_str(s, format); + + if let Err(e) = &dt { + // no timezone or other failure, try without a timezone + let ndt = NaiveDateTime::parse_from_str(s, format); + if let Err(e) = &ndt { + return Err(err(&e.to_string())); + } + + if let Single(e) = &timezone.from_local_datetime(&ndt.unwrap()) { + Ok(e.to_owned()) + } else { + Err(err(&e.to_string())) + } + } else { + Ok(dt.unwrap().with_timezone(timezone)) + } +} + /// given a function `op` that maps a `&str` to a Result of an arrow native type, /// returns a `PrimitiveArray` after the application /// of the function to `args[0]`. @@ -84,7 +172,96 @@ where array.iter().map(|x| x.map(&op).transpose()).collect() } -// given an function that maps a `&str` to a arrow native type, +/// given a function `op` that maps `&str`, `&str` to the first successful Result +/// of an arrow native type, returns a `PrimitiveArray` after the application of the +/// function to `args` and the subsequence application of the `op2` function to any +/// successful result. This function calls the `op` function with the first and second +/// argument and if not successful continues with first and third, first and fourth, +/// etc until the result was successful or no more arguments are present. +/// # Errors +/// This function errors iff: +/// * the number of arguments is not > 1 or +/// * the array arguments are not castable to a `GenericStringArray` or +/// * the function `op` errors for all input +pub(crate) fn strings_to_primitive_function<'a, T, O, F, F2>( + args: &'a [ColumnarValue], + op: F, + op2: F2, + name: &str, +) -> Result> +where + O: ArrowPrimitiveType, + T: OffsetSizeTrait, + F: Fn(&'a str, &'a str) -> Result, + F2: Fn(O::Native) -> O::Native, +{ + if args.len() < 2 { + return internal_err!( + "{:?} args were supplied but {} takes 2 or more arguments", + args.len(), + name + ); + } + + // this will throw the error if any of the array args are not castable to GenericStringArray + let data = args + .iter() + .map(|a| match a { + ColumnarValue::Array(a) => { + Ok(Either::Left(as_generic_string_array::(a.as_ref())?)) + } + ColumnarValue::Scalar(s) => match s { + ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => Ok(Either::Right(a)), + other => internal_err!( + "Unexpected scalar type encountered '{other}' for function '{name}'" + ), + }, + }) + .collect::, &Option>>>>()?; + + let first_arg = &data.first().unwrap().left().unwrap(); + + first_arg + .iter() + .enumerate() + .map(|(pos, x)| { + let mut val = None; + + if let Some(x) = x { + let param_args = data.iter().skip(1); + + // go through the args and find the first successful result. Only the last + // failure will be returned if no successful result was received. + for param_arg in param_args { + // param_arg is an array, use the corresponding index into the array as the arg + // we're currently parsing + let p = *param_arg; + let r = if p.is_left() { + let p = p.left().unwrap(); + op(x, p.value(pos)) + } + // args is a scalar, use it directly + else if let Some(p) = p.right().unwrap() { + op(x, p.as_str()) + } else { + continue; + }; + + if r.is_ok() { + val = Some(Ok(op2(r.unwrap()))); + break; + } else { + val = Some(r); + } + } + }; + + val.transpose() + }) + .collect() +} + +// given an function that maps a `&str` to an arrow native type, // returns a `ColumnarValue` where the function is applied to either a `ArrayRef` or `ScalarValue` // depending on the `args`'s variant. fn handle<'a, O, F, S>( @@ -99,24 +276,112 @@ where { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new( + DataType::Utf8 | DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( unary_string_to_primitive_function::(&[a.as_ref()], op, name)?, ))), - DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( - unary_string_to_primitive_function::(&[a.as_ref()], op, name)?, - ))), - other => internal_err!("Unsupported data type {other:?} for function {name}"), + other => exec_err!("Unsupported data type {other:?} for function {name}"), }, ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) => { + ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { let result = a.as_ref().map(|x| (op)(x)).transpose()?; Ok(ColumnarValue::Scalar(S::scalar(result))) } - ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| (op)(x)).transpose()?; - Ok(ColumnarValue::Scalar(S::scalar(result))) + other => exec_err!("Unsupported data type {other:?} for function {name}"), + }, + } +} + +// given an function that maps a `&str`, `&str` to an arrow native type, +// returns a `ColumnarValue` where the function is applied to either a `ArrayRef` or `ScalarValue` +// depending on the `args`'s variant. +fn handle_multiple<'a, O, F, S, M>( + args: &'a [ColumnarValue], + op: F, + op2: M, + name: &str, +) -> Result +where + O: ArrowPrimitiveType, + S: ScalarType, + F: Fn(&'a str, &'a str) -> Result, + M: Fn(O::Native) -> O::Native, +{ + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8 | DataType::LargeUtf8 => { + // validate the column types + for (pos, arg) in args.iter().enumerate() { + match arg { + ColumnarValue::Array(arg) => match arg.data_type() { + DataType::Utf8 | DataType::LargeUtf8 => { + // all good + }, + other => return exec_err!("Unsupported data type {other:?} for function {name}, arg # {pos}"), + }, + ColumnarValue::Scalar(arg) => { match arg.data_type() { + DataType::Utf8 | DataType::LargeUtf8 => { + // all good + }, + other => return exec_err!("Unsupported data type {other:?} for function {name}, arg # {pos}"), + }} + } + } + + Ok(ColumnarValue::Array(Arc::new( + strings_to_primitive_function::(args, op, op2, name)?, + ))) + } + other => { + exec_err!("Unsupported data type {other:?} for function {name}") + } + }, + // if the first argument is a scalar utf8 all arguments are expected to be scalar utf8 + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { + let mut val: Option> = None; + let mut err: Option = None; + + match a { + Some(a) => { + // enumerate all the values finding the first one that returns an Ok result + for (pos, v) in args.iter().enumerate().skip(1) { + if let ColumnarValue::Scalar(s) = v { + if let ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x) = + s + { + if let Some(s) = x { + match op(a.as_str(), s.as_str()) { + Ok(r) => { + val = Some(Ok(ColumnarValue::Scalar( + S::scalar(Some(op2(r))), + ))); + break; + } + Err(e) => { + err = Some(e); + } + } + } + } else { + return exec_err!("Unsupported data type {s:?} for function {name}, arg # {pos}"); + } + } else { + return exec_err!("Unsupported data type {v:?} for function {name}, arg # {pos}"); + } + } + } + None => (), + } + + if let Some(v) = val { + v + } else { + Err(err.unwrap()) + } + } + other => { + exec_err!("Unsupported data type {other:?} for function {name}") } - other => internal_err!("Unsupported data type {other:?} for function {name}"), }, } } @@ -126,53 +391,61 @@ fn string_to_timestamp_nanos_shim(s: &str) -> Result { string_to_timestamp_nanos(s).map_err(|e| e.into()) } +fn to_timestamp_impl>( + args: &[ColumnarValue], + name: &str, +) -> Result { + let factor = match T::UNIT { + TimeUnit::Second => 1_000_000_000, + TimeUnit::Millisecond => 1_000_000, + TimeUnit::Microsecond => 1_000, + TimeUnit::Nanosecond => 1, + }; + + match args.len() { + 1 => handle::( + args, + |s| string_to_timestamp_nanos_shim(s).map(|n| n / factor), + name, + ), + n if n >= 2 => handle_multiple::( + args, + string_to_timestamp_nanos_formatted, + |n| n / factor, + name, + ), + _ => internal_err!("Unsupported 0 argument count for function {name}"), + } +} + /// to_timestamp SQL function /// -/// Note: `to_timestamp` returns `Timestamp(Nanosecond)` though its arguments are interpreted as **seconds**. The supported range for integer input is between `-9223372037` and `9223372036`. +/// Note: `to_timestamp` returns `Timestamp(Nanosecond)` though its arguments are interpreted as **seconds**. +/// The supported range for integer input is between `-9223372037` and `9223372036`. /// Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. /// Please use `to_timestamp_seconds` for the input outside of supported bounds. pub fn to_timestamp(args: &[ColumnarValue]) -> Result { - handle::( - args, - string_to_timestamp_nanos_shim, - "to_timestamp", - ) + to_timestamp_impl::(args, "to_timestamp") } /// to_timestamp_millis SQL function pub fn to_timestamp_millis(args: &[ColumnarValue]) -> Result { - handle::( - args, - |s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000_000), - "to_timestamp_millis", - ) + to_timestamp_impl::(args, "to_timestamp_millis") } /// to_timestamp_micros SQL function pub fn to_timestamp_micros(args: &[ColumnarValue]) -> Result { - handle::( - args, - |s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000), - "to_timestamp_micros", - ) + to_timestamp_impl::(args, "to_timestamp_micros") } /// to_timestamp_nanos SQL function pub fn to_timestamp_nanos(args: &[ColumnarValue]) -> Result { - handle::( - args, - string_to_timestamp_nanos_shim, - "to_timestamp_nanos", - ) + to_timestamp_impl::(args, "to_timestamp_nanos") } /// to_timestamp_seconds SQL function pub fn to_timestamp_seconds(args: &[ColumnarValue]) -> Result { - handle::( - args, - |s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000_000_000), - "to_timestamp_seconds", - ) + to_timestamp_impl::(args, "to_timestamp_seconds") } /// Create an implementation of `now()` that always returns the @@ -915,22 +1188,51 @@ where Ok(b) } -/// to_timestammp() SQL function implementation +fn validate_to_timestamp_data_types( + args: &[ColumnarValue], + name: &str, +) -> Option> { + for (idx, a) in args.iter().skip(1).enumerate() { + match a.data_type() { + DataType::Utf8 | DataType::LargeUtf8 => { + // all good + } + _ => { + return Some(internal_err!( + "{name} function unsupported data type at index {}: {}", + idx + 1, + a.data_type() + )); + } + } + } + + None +} + +/// to_timestamp() SQL function implementation pub fn to_timestamp_invoke(args: &[ColumnarValue]) -> Result { - if args.len() != 1 { + if args.is_empty() { return internal_err!( - "to_timestamp function requires 1 arguments, got {}", + "to_timestamp function requires 1 or more arguments, got {}", args.len() ); } + // validate that any args after the first one are Utf8 + if args.len() > 1 { + if let Some(value) = validate_to_timestamp_data_types(args, "to_timestamp") { + return value; + } + } + match args[0].data_type() { - DataType::Int64 => cast_column( + DataType::Int32 | DataType::Int64 => cast_column( &cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None)?, &DataType::Timestamp(TimeUnit::Nanosecond, None), None, ), - DataType::Float64 => cast_column( + DataType::Null | DataType::Float64 => cast_column( &args[0], &DataType::Timestamp(TimeUnit::Nanosecond, None), None, @@ -940,7 +1242,7 @@ pub fn to_timestamp_invoke(args: &[ColumnarValue]) -> Result { &DataType::Timestamp(TimeUnit::Nanosecond, None), None, ), - DataType::Utf8 => datetime_expressions::to_timestamp(args), + DataType::Utf8 => to_timestamp(args), other => { internal_err!( "Unsupported data type {:?} for function to_timestamp", @@ -952,20 +1254,31 @@ pub fn to_timestamp_invoke(args: &[ColumnarValue]) -> Result { /// to_timestamp_millis() SQL function implementation pub fn to_timestamp_millis_invoke(args: &[ColumnarValue]) -> Result { - if args.len() != 1 { + if args.is_empty() { return internal_err!( - "to_timestamp_millis function requires 1 argument, got {}", + "to_timestamp_millis function requires 1 or more arguments, got {}", args.len() ); } + // validate that any args after the first one are Utf8 + if args.len() > 1 { + if let Some(value) = validate_to_timestamp_data_types(args, "to_timestamp_millis") + { + return value; + } + } + match args[0].data_type() { - DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + DataType::Null + | DataType::Int32 + | DataType::Int64 + | DataType::Timestamp(_, None) => cast_column( &args[0], &DataType::Timestamp(TimeUnit::Millisecond, None), None, ), - DataType::Utf8 => datetime_expressions::to_timestamp_millis(args), + DataType::Utf8 => to_timestamp_millis(args), other => { internal_err!( "Unsupported data type {:?} for function to_timestamp_millis", @@ -977,20 +1290,31 @@ pub fn to_timestamp_millis_invoke(args: &[ColumnarValue]) -> Result Result { - if args.len() != 1 { + if args.is_empty() { return internal_err!( - "to_timestamp_micros function requires 1 argument, got {}", + "to_timestamp_micros function requires 1 or more arguments, got {}", args.len() ); } + // validate that any args after the first one are Utf8 + if args.len() > 1 { + if let Some(value) = validate_to_timestamp_data_types(args, "to_timestamp_micros") + { + return value; + } + } + match args[0].data_type() { - DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + DataType::Null + | DataType::Int32 + | DataType::Int64 + | DataType::Timestamp(_, None) => cast_column( &args[0], &DataType::Timestamp(TimeUnit::Microsecond, None), None, ), - DataType::Utf8 => datetime_expressions::to_timestamp_micros(args), + DataType::Utf8 => to_timestamp_micros(args), other => { internal_err!( "Unsupported data type {:?} for function to_timestamp_micros", @@ -1002,20 +1326,31 @@ pub fn to_timestamp_micros_invoke(args: &[ColumnarValue]) -> Result Result { - if args.len() != 1 { + if args.is_empty() { return internal_err!( - "to_timestamp_nanos function requires 1 argument, got {}", + "to_timestamp_nanos function requires 1 or more arguments, got {}", args.len() ); } + // validate that any args after the first one are Utf8 + if args.len() > 1 { + if let Some(value) = validate_to_timestamp_data_types(args, "to_timestamp_nanos") + { + return value; + } + } + match args[0].data_type() { - DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + DataType::Null + | DataType::Int32 + | DataType::Int64 + | DataType::Timestamp(_, None) => cast_column( &args[0], &DataType::Timestamp(TimeUnit::Nanosecond, None), None, ), - DataType::Utf8 => datetime_expressions::to_timestamp_nanos(args), + DataType::Utf8 => to_timestamp_nanos(args), other => { internal_err!( "Unsupported data type {:?} for function to_timestamp_nanos", @@ -1027,18 +1362,30 @@ pub fn to_timestamp_nanos_invoke(args: &[ColumnarValue]) -> Result Result { - if args.len() != 1 { + if args.is_empty() { return internal_err!( - "to_timestamp_seconds function requires 1 argument, got {}", + "to_timestamp_seconds function requires 1 or more arguments, got {}", args.len() ); } + // validate that any args after the first one are Utf8 + if args.len() > 1 { + if let Some(value) = + validate_to_timestamp_data_types(args, "to_timestamp_seconds") + { + return value; + } + } + match args[0].data_type() { - DataType::Int64 | DataType::Timestamp(_, None) => { + DataType::Null + | DataType::Int32 + | DataType::Int64 + | DataType::Timestamp(_, None) => { cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) } - DataType::Utf8 => datetime_expressions::to_timestamp_seconds(args), + DataType::Utf8 => to_timestamp_seconds(args), other => { internal_err!( "Unsupported data type {:?} for function to_timestamp_seconds", @@ -1077,7 +1424,13 @@ mod tests { use arrow::array::{ as_primitive_array, ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder, }; - use arrow_array::TimestampNanosecondArray; + use arrow_array::types::Int64Type; + use arrow_array::{ + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, + }; + use datafusion_common::assert_contains; + use datafusion_expr::ScalarFunctionImplementation; use super::*; @@ -1108,6 +1461,47 @@ mod tests { Ok(()) } + #[test] + fn to_timestamp_with_formats_arrays_and_nulls() -> Result<()> { + // ensure that arrow array implementation is wired up and handles nulls correctly + + let mut date_string_builder = StringBuilder::with_capacity(2, 1024); + let mut format1_builder = StringBuilder::with_capacity(2, 1024); + let mut format2_builder = StringBuilder::with_capacity(2, 1024); + let mut format3_builder = StringBuilder::with_capacity(2, 1024); + let mut ts_builder = TimestampNanosecondArray::builder(2); + + date_string_builder.append_null(); + format1_builder.append_null(); + format2_builder.append_null(); + format3_builder.append_null(); + ts_builder.append_null(); + + date_string_builder.append_value("2020-09-08T13:42:29.19085Z"); + format1_builder.append_value("%s"); + format2_builder.append_value("%c"); + format3_builder.append_value("%+"); + ts_builder.append_value(1599572549190850000); + + let expected_timestamps = &ts_builder.finish() as &dyn Array; + + let string_array = [ + ColumnarValue::Array(Arc::new(date_string_builder.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format1_builder.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format2_builder.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format3_builder.finish()) as ArrayRef), + ]; + let parsed_timestamps = to_timestamp(&string_array) + .expect("that to_timestamp with format args parsed values without error"); + if let ColumnarValue::Array(parsed_array) = parsed_timestamps { + assert_eq!(parsed_array.len(), 2); + assert_eq!(expected_timestamps, parsed_array.as_ref()); + } else { + panic!("Expected a columnar array") + } + Ok(()) + } + #[test] fn date_trunc_test() { let cases = vec![ @@ -1663,7 +2057,7 @@ mod tests { let int64array = ColumnarValue::Array(Arc::new(builder.finish())); let expected_err = - "Internal error: Unsupported data type Int64 for function to_timestamp"; + "Execution error: Unsupported data type Int64 for function to_timestamp"; match to_timestamp(&[int64array]) { Ok(_) => panic!("Expected error but got success"), Err(e) => { @@ -1675,4 +2069,303 @@ mod tests { } Ok(()) } + + #[test] + fn to_timestamp_with_formats_invalid_input_type() -> Result<()> { + // pass the wrong type of input array to to_timestamp and test + // that we get an error. + + let mut builder = Int64Array::builder(1); + builder.append_value(1); + let int64array = [ + ColumnarValue::Array(Arc::new(builder.finish())), + ColumnarValue::Array(Arc::new(builder.finish())), + ]; + + let expected_err = + "Execution error: Unsupported data type Int64 for function to_timestamp"; + match to_timestamp(&int64array) { + Ok(_) => panic!("Expected error but got success"), + Err(e) => { + assert!( + e.to_string().contains(expected_err), + "Can not find expected error '{expected_err}'. Actual error '{e}'" + ); + } + } + Ok(()) + } + + #[test] + fn to_timestamp_with_unparseable_data() -> Result<()> { + let mut date_string_builder = StringBuilder::with_capacity(2, 1024); + + date_string_builder.append_null(); + + date_string_builder.append_value("2020-09-08 - 13:42:29.19085Z"); + + let string_array = + ColumnarValue::Array(Arc::new(date_string_builder.finish()) as ArrayRef); + + let expected_err = + "Arrow error: Parser error: Error parsing timestamp from '2020-09-08 - 13:42:29.19085Z': error parsing time"; + match to_timestamp(&[string_array]) { + Ok(_) => panic!("Expected error but got success"), + Err(e) => { + assert!( + e.to_string().contains(expected_err), + "Can not find expected error '{expected_err}'. Actual error '{e}'" + ); + } + } + Ok(()) + } + + #[test] + fn to_timestamp_with_no_matching_formats() -> Result<()> { + let mut date_string_builder = StringBuilder::with_capacity(2, 1024); + let mut format1_builder = StringBuilder::with_capacity(2, 1024); + let mut format2_builder = StringBuilder::with_capacity(2, 1024); + let mut format3_builder = StringBuilder::with_capacity(2, 1024); + + date_string_builder.append_null(); + format1_builder.append_null(); + format2_builder.append_null(); + format3_builder.append_null(); + + date_string_builder.append_value("2020-09-08T13:42:29.19085Z"); + format1_builder.append_value("%s"); + format2_builder.append_value("%c"); + format3_builder.append_value("%H:%M:%S"); + + let string_array = [ + ColumnarValue::Array(Arc::new(date_string_builder.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format1_builder.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format2_builder.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format3_builder.finish()) as ArrayRef), + ]; + + let expected_err = + "Execution error: Error parsing timestamp from '2020-09-08T13:42:29.19085Z' using format '%H:%M:%S': input contains invalid characters"; + match to_timestamp(&string_array) { + Ok(_) => panic!("Expected error but got success"), + Err(e) => { + assert!( + e.to_string().contains(expected_err), + "Can not find expected error '{expected_err}'. Actual error '{e}'" + ); + } + } + Ok(()) + } + + #[test] + fn string_to_timestamp_formatted() { + // Explicit timezone + assert_eq!( + 1599572549190855000, + parse_timestamp_formatted("2020-09-08T13:42:29.190855+00:00", "%+").unwrap() + ); + assert_eq!( + 1599572549190855000, + parse_timestamp_formatted("2020-09-08T13:42:29.190855Z", "%+").unwrap() + ); + assert_eq!( + 1599572549000000000, + parse_timestamp_formatted("2020-09-08T13:42:29Z", "%+").unwrap() + ); // no fractional part + assert_eq!( + 1599590549190855000, + parse_timestamp_formatted("2020-09-08T13:42:29.190855-05:00", "%+").unwrap() + ); + assert_eq!( + 1599590549000000000, + parse_timestamp_formatted("1599590549", "%s").unwrap() + ); + assert_eq!( + 1599572549000000000, + parse_timestamp_formatted("09-08-2020 13/42/29", "%m-%d-%Y %H/%M/%S") + .unwrap() + ); + } + + fn parse_timestamp_formatted(s: &str, format: &str) -> Result { + let result = string_to_timestamp_nanos_formatted(s, format); + if let Err(e) = &result { + eprintln!("Error parsing timestamp '{s}' using format '{format}': {e:?}"); + } + result + } + + #[test] + fn string_to_timestamp_formatted_invalid() { + // Test parsing invalid formats + let cases = [ + ("", "%Y%m%d %H%M%S", "premature end of input"), + ("SS", "%c", "premature end of input"), + ("Wed, 18 Feb 2015 23:16:09 GMT", "", "trailing input"), + ( + "Wed, 18 Feb 2015 23:16:09 GMT", + "%XX", + "input contains invalid characters", + ), + ( + "Wed, 18 Feb 2015 23:16:09 GMT", + "%Y%m%d %H%M%S", + "input contains invalid characters", + ), + ]; + + for (s, f, ctx) in cases { + let expected = format!("Execution error: Error parsing timestamp from '{s}' using format '{f}': {ctx}"); + let actual = string_to_datetime_formatted(&Utc, s, f) + .unwrap_err() + .to_string(); + assert_eq!(actual, expected) + } + } + + #[test] + fn string_to_timestamp_invalid_arguments() { + // Test parsing invalid formats + let cases = [ + ("", "%Y%m%d %H%M%S", "premature end of input"), + ("SS", "%c", "premature end of input"), + ("Wed, 18 Feb 2015 23:16:09 GMT", "", "trailing input"), + ( + "Wed, 18 Feb 2015 23:16:09 GMT", + "%XX", + "input contains invalid characters", + ), + ( + "Wed, 18 Feb 2015 23:16:09 GMT", + "%Y%m%d %H%M%S", + "input contains invalid characters", + ), + ]; + + for (s, f, ctx) in cases { + let expected = format!("Execution error: Error parsing timestamp from '{s}' using format '{f}': {ctx}"); + let actual = string_to_datetime_formatted(&Utc, s, f) + .unwrap_err() + .to_string(); + assert_eq!(actual, expected) + } + } + + #[test] + fn test_to_timestamp_arg_validation() { + let mut date_string_builder = StringBuilder::with_capacity(2, 1024); + date_string_builder.append_value("2020-09-08T13:42:29.19085Z"); + + let data = date_string_builder.finish(); + + let funcs: Vec<(ScalarFunctionImplementation, TimeUnit)> = vec![ + (Arc::new(to_timestamp), TimeUnit::Nanosecond), + (Arc::new(to_timestamp_micros), TimeUnit::Microsecond), + (Arc::new(to_timestamp_millis), TimeUnit::Millisecond), + (Arc::new(to_timestamp_nanos), TimeUnit::Nanosecond), + (Arc::new(to_timestamp_seconds), TimeUnit::Second), + ]; + + let mut nanos_builder = TimestampNanosecondArray::builder(2); + let mut millis_builder = TimestampMillisecondArray::builder(2); + let mut micros_builder = TimestampMicrosecondArray::builder(2); + let mut sec_builder = TimestampSecondArray::builder(2); + + nanos_builder.append_value(1599572549190850000); + millis_builder.append_value(1599572549190); + micros_builder.append_value(1599572549190850); + sec_builder.append_value(1599572549); + + let nanos_expected_timestamps = &nanos_builder.finish() as &dyn Array; + let millis_expected_timestamps = &millis_builder.finish() as &dyn Array; + let micros_expected_timestamps = µs_builder.finish() as &dyn Array; + let sec_expected_timestamps = &sec_builder.finish() as &dyn Array; + + for (func, time_unit) in funcs { + // test UTF8 + let string_array = [ + ColumnarValue::Array(Arc::new(data.clone()) as ArrayRef), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%s".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%c".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%+".to_string()))), + ]; + let parsed_timestamps = func(&string_array) + .expect("that to_timestamp with format args parsed values without error"); + if let ColumnarValue::Array(parsed_array) = parsed_timestamps { + assert_eq!(parsed_array.len(), 1); + match time_unit { + TimeUnit::Nanosecond => { + assert_eq!(nanos_expected_timestamps, parsed_array.as_ref()) + } + TimeUnit::Millisecond => { + assert_eq!(millis_expected_timestamps, parsed_array.as_ref()) + } + TimeUnit::Microsecond => { + assert_eq!(micros_expected_timestamps, parsed_array.as_ref()) + } + TimeUnit::Second => { + assert_eq!(sec_expected_timestamps, parsed_array.as_ref()) + } + }; + } else { + panic!("Expected a columnar array") + } + + // test LargeUTF8 + let string_array = [ + ColumnarValue::Array(Arc::new(data.clone()) as ArrayRef), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("%s".to_string()))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("%c".to_string()))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("%+".to_string()))), + ]; + let parsed_timestamps = func(&string_array) + .expect("that to_timestamp with format args parsed values without error"); + if let ColumnarValue::Array(parsed_array) = parsed_timestamps { + assert_eq!(parsed_array.len(), 1); + match time_unit { + TimeUnit::Nanosecond => { + assert_eq!(nanos_expected_timestamps, parsed_array.as_ref()) + } + TimeUnit::Millisecond => { + assert_eq!(millis_expected_timestamps, parsed_array.as_ref()) + } + TimeUnit::Microsecond => { + assert_eq!(micros_expected_timestamps, parsed_array.as_ref()) + } + TimeUnit::Second => { + assert_eq!(sec_expected_timestamps, parsed_array.as_ref()) + } + }; + } else { + panic!("Expected a columnar array") + } + + // test other types + let string_array = [ + ColumnarValue::Array(Arc::new(data.clone()) as ArrayRef), + ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(3))), + ]; + + let expected = "Unsupported data type Int32 for function".to_string(); + let actual = func(&string_array).unwrap_err().to_string(); + assert_contains!(actual, expected); + + // test other types + let string_array = [ + ColumnarValue::Array(Arc::new(data.clone()) as ArrayRef), + ColumnarValue::Array(Arc::new(PrimitiveArray::::new( + vec![1i64].into(), + None, + )) as ArrayRef), + ]; + + let expected = "Unsupported data type".to_string(); + let actual = func(&string_array).unwrap_err().to_string(); + assert_contains!(actual, expected); + } + } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 973e366d0bbd..aae19a15b89a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -65,10 +65,9 @@ use datafusion_expr::{ radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, - substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, - to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, - AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, - Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, + substring, tan, tanh, to_hex, translate, trim, trunc, upper, uuid, AggregateFunction, + Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, + GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -476,7 +475,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Trim => Self::Trim, ScalarFunction::Ltrim => Self::Ltrim, ScalarFunction::Rtrim => Self::Rtrim, - ScalarFunction::ToTimestamp => Self::ToTimestamp, ScalarFunction::ArrayAppend => Self::ArrayAppend, ScalarFunction::ArraySort => Self::ArraySort, ScalarFunction::ArrayConcat => Self::ArrayConcat, @@ -523,7 +521,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Digest => Self::Digest, ScalarFunction::Encode => Self::Encode, ScalarFunction::Decode => Self::Decode, - ScalarFunction::ToTimestampMillis => Self::ToTimestampMillis, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, ScalarFunction::Ascii => Self::Ascii, @@ -548,6 +545,8 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Strpos => Self::Strpos, ScalarFunction::Substr => Self::Substr, ScalarFunction::ToHex => Self::ToHex, + ScalarFunction::ToTimestamp => Self::ToTimestamp, + ScalarFunction::ToTimestampMillis => Self::ToTimestampMillis, ScalarFunction::ToTimestampMicros => Self::ToTimestampMicros, ScalarFunction::ToTimestampNanos => Self::ToTimestampNanos, ScalarFunction::ToTimestampSeconds => Self::ToTimestampSeconds, @@ -1689,17 +1688,55 @@ pub fn parse_expr( parse_expr(&args[1], registry)?, )), ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)), + ScalarFunction::ToTimestamp => { + let args: Vec<_> = args + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::>()?; + Ok(Expr::ScalarFunction(expr::ScalarFunction::new( + BuiltinScalarFunction::ToTimestamp, + args, + ))) + } ScalarFunction::ToTimestampMillis => { - Ok(to_timestamp_millis(parse_expr(&args[0], registry)?)) + let args: Vec<_> = args + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::>()?; + Ok(Expr::ScalarFunction(expr::ScalarFunction::new( + BuiltinScalarFunction::ToTimestampMillis, + args, + ))) } ScalarFunction::ToTimestampMicros => { - Ok(to_timestamp_micros(parse_expr(&args[0], registry)?)) + let args: Vec<_> = args + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::>()?; + Ok(Expr::ScalarFunction(expr::ScalarFunction::new( + BuiltinScalarFunction::ToTimestampMicros, + args, + ))) } ScalarFunction::ToTimestampNanos => { - Ok(to_timestamp_nanos(parse_expr(&args[0], registry)?)) + let args: Vec<_> = args + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::>()?; + Ok(Expr::ScalarFunction(expr::ScalarFunction::new( + BuiltinScalarFunction::ToTimestampNanos, + args, + ))) } ScalarFunction::ToTimestampSeconds => { - Ok(to_timestamp_seconds(parse_expr(&args[0], registry)?)) + let args: Vec<_> = args + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::>()?; + Ok(Expr::ScalarFunction(expr::ScalarFunction::new( + BuiltinScalarFunction::ToTimestampSeconds, + args, + ))) } ScalarFunction::Now => Ok(now()), ScalarFunction::Translate => Ok(translate( @@ -1741,18 +1778,6 @@ pub fn parse_expr( ScalarFunction::ArrowTypeof => { Ok(arrow_typeof(parse_expr(&args[0], registry)?)) } - ScalarFunction::ToTimestamp => { - let args: Vec<_> = args - .iter() - .map(|expr| parse_expr(expr, registry)) - .collect::>()?; - Ok(Expr::ScalarFunction( - datafusion_expr::expr::ScalarFunction::new( - BuiltinScalarFunction::ToTimestamp, - args, - ), - )) - } ScalarFunction::Flatten => Ok(flatten(parse_expr(&args[0], registry)?)), ScalarFunction::StringToArray => Ok(string_to_array( parse_expr(&args[0], registry)?, diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 2ab3dbdac61b..5c7687aa27b2 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -18,7 +18,7 @@ ########## ## Common timestamp data # -# ts_data: Int64 nanosecods +# ts_data: Int64 nanoseconds # ts_data_nanos: Timestamp(Nanosecond, None) # ts_data_micros: Timestamp(Microsecond, None) # ts_data_millis: Timestamp(Millisecond, None) @@ -331,6 +331,35 @@ SELECT COUNT(*) FROM ts_data_secs where ts > to_timestamp_seconds('2020-09-08T12 ---- 2 +# to_timestamp with formatting +query I +SELECT COUNT(*) FROM ts_data_nanos where ts > to_timestamp('2020-09-08T12:00:00+00:00', '2020-09-08 12/00/00+00:00', '%c', '%+', '%Y-%m-%d %H/%M/%s%#z') +---- +2 + +# to_timestamp_nanos with formatting +query I +SELECT COUNT(*) FROM ts_data_nanos where ts > to_timestamp_nanos('2020-09-08 12/00/00+00:00', '%c', '%+', '%Y-%m-%d %H/%M/%S%#z') +---- +2 + +# to_timestamp_millis with formatting +query I +SELECT COUNT(*) FROM ts_data_millis where ts > to_timestamp_millis('2020-09-08 12/00/00+00:00', '%c', '%+', '%Y-%m-%d %H/%M/%S%#z') +---- +2 + +# to_timestamp_micros with formatting +query I +SELECT COUNT(*) FROM ts_data_micros where ts > to_timestamp_micros('2020-09-08 12/00/00+00:00', '%c', '%+', '%Y-%m-%d %H/%M/%S%#z') +---- +2 + +# to_timestamp_seconds with formatting +query I +SELECT COUNT(*) FROM ts_data_secs where ts > to_timestamp_seconds('2020-09-08 12/00/00+00:00', '%c', '%+', '%Y-%m-%d %H/%M/%S%#z') +---- +2 # to_timestamp float inputs @@ -1880,7 +1909,7 @@ SELECT to_timestamp(null), to_timestamp(0), to_timestamp(1926632005), to_timesta ---- NULL 1970-01-01T00:00:00 2031-01-19T23:33:25 1970-01-01T00:00:01 1969-12-31T23:59:59 1969-12-31T23:59:59 -# verify timestamp syntax stlyes are consistent +# verify timestamp syntax styles are consistent query BBBBBBBBBBBBB SELECT to_timestamp(null) is null as c1, null::timestamp is null as c2, @@ -1922,6 +1951,116 @@ true true true true true true #---- #0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 +# verify timestamp data with formatting options +query PPPPPP +SELECT to_timestamp(null, '%+'), to_timestamp(0, '%s'), to_timestamp(1926632005, '%s'), to_timestamp(1, '%+', '%s'), to_timestamp(-1, '%c', '%+', '%s'), to_timestamp(0-1, '%c', '%+', '%s') +---- +NULL 1970-01-01T00:00:00 2031-01-19T23:33:25 1970-01-01T00:00:01 1969-12-31T23:59:59 1969-12-31T23:59:59 + +# verify timestamp data with formatting options +query PPPPPP +SELECT to_timestamp(null, '%+'), to_timestamp(0, '%s'), to_timestamp(1926632005, '%s'), to_timestamp(1, '%+', '%s'), to_timestamp(-1, '%c', '%+', '%s'), to_timestamp(0-1, '%c', '%+', '%s') +---- +NULL 1970-01-01T00:00:00 2031-01-19T23:33:25 1970-01-01T00:00:01 1969-12-31T23:59:59 1969-12-31T23:59:59 + +# verify timestamp output types with formatting options +query TTT +SELECT arrow_typeof(to_timestamp(1, '%c', '%s')), arrow_typeof(to_timestamp(null, '%+', '%s')), arrow_typeof(to_timestamp('2023-01-10 12:34:56.000', '%Y-%m-%d %H:%M:%S%.f')) +---- +Timestamp(Nanosecond, None) Timestamp(Nanosecond, None) Timestamp(Nanosecond, None) + +# to_timestamp with invalid formatting +query error input contains invalid characters +SELECT to_timestamp('2020-09-08 12/00/00+00:00', '%c', '%+') + +# to_timestamp_nanos with invalid formatting +query error input contains invalid characters +SELECT to_timestamp_nanos('2020-09-08 12/00/00+00:00', '%c', '%+') + +# to_timestamp_millis with invalid formatting +query error input contains invalid characters +SELECT to_timestamp_millis('2020-09-08 12/00/00+00:00', '%c', '%+') + +# to_timestamp_micros with invalid formatting +query error input contains invalid characters +SELECT to_timestamp_micros('2020-09-08 12/00/00+00:00', '%c', '%+') + +# to_timestamp_seconds with invalid formatting +query error input contains invalid characters +SELECT to_timestamp_seconds('2020-09-08 12/00/00+00:00', '%c', '%+') + +# to_timestamp with broken formatting +query error bad or unsupported format string +SELECT to_timestamp('2020-09-08 12/00/00+00:00', '%q') + +# to_timestamp_nanos with broken formatting +query error bad or unsupported format string +SELECT to_timestamp_nanos('2020-09-08 12/00/00+00:00', '%q') + +# to_timestamp_millis with broken formatting +query error bad or unsupported format string +SELECT to_timestamp_millis('2020-09-08 12/00/00+00:00', '%q') + +# to_timestamp_micros with broken formatting +query error bad or unsupported format string +SELECT to_timestamp_micros('2020-09-08 12/00/00+00:00', '%q') + +# to_timestamp_seconds with broken formatting +query error bad or unsupported format string +SELECT to_timestamp_seconds('2020-09-08 12/00/00+00:00', '%q') + +# Create string timestamp table with different formats +# including a few very non-standard formats + +statement ok +create table ts_utf8_data(ts varchar(100), format varchar(100)) as values + ('2020-09-08 12/00/00+00:00', '%Y-%m-%d %H/%M/%S%#z'), + ('2031-01-19T23:33:25+05:00', '%+'), + ('08-09-2020 12:00:00+00:00', '%d-%m-%Y %H:%M:%S%#z'), + ('1926632005', '%s'), + ('2000-01-01T01:01:01+07:00', '%+'); + +# verify timestamp data using tables with formatting options +query P +SELECT to_timestamp(t.ts, t.format) from ts_utf8_data as t +---- +2020-09-08T12:00:00 +2031-01-19T18:33:25 +2020-09-08T12:00:00 +2031-01-19T23:33:25 +1999-12-31T18:01:01 + +# verify timestamp data using tables with formatting options +query P +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t +---- +2020-09-08T12:00:00 +2031-01-19T18:33:25 +2020-09-08T12:00:00 +2031-01-19T23:33:25 +1999-12-31T18:01:01 + +# verify timestamp data using tables with formatting options where at least one column cannot be parsed +query error Error parsing timestamp from '1926632005' using format '%d-%m-%Y %H:%M:%S%#z': input contains invalid characters +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t + +# verify timestamp data using tables with formatting options where one of the formats is invalid +query P +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_utf8_data as t +---- +2020-09-08T12:00:00 +2031-01-19T18:33:25 +2020-09-08T12:00:00 +2031-01-19T23:33:25 +1999-12-31T18:01:01 + +# timestamp data using tables with formatting options in an array is not supported at this time +query error function unsupported data type at index 1: +SELECT to_timestamp(t.ts, make_array('%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+')) from ts_utf8_data as t + +statement ok +drop table ts_utf8_data + ########## ## Test binary temporal coercion for Date and Timestamp ########## diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 9dd008f8fc44..c72ef94f42ea 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1471,84 +1471,107 @@ extract(field FROM source) Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, unsigned integer, and double types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) -return the corresponding timestamp. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. +Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). +Returns the corresponding timestamp. Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. -Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. +Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` +for the input outside of supported bounds. ``` -to_timestamp(expression) +to_timestamp(expression[, ..., format_n]) ``` #### Arguments - **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order + they appear with the first successful one being returned. If none of the formats successfully parse the expression + an error will be returned. + +[chrono format]: https://docs.rs/chrono/latest/chrono/format/strftime/index.html ### `to_timestamp_millis` Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`) -return the corresponding timestamp. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. +Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). +Returns the corresponding timestamp. ``` -to_timestamp_millis(expression) +to_timestamp_millis(expression[, ..., format_n]) ``` #### Arguments - **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order + they appear with the first successful one being returned. If none of the formats successfully parse the expression + an error will be returned. ### `to_timestamp_micros` Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) -return the corresponding timestamp. +Returns the corresponding timestamp. ``` -to_timestamp_nanos(expression) +to_timestamp_micros(expression[, ..., format_n]) ``` +#### Arguments + +- **expression**: Expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order + they appear with the first successful one being returned. If none of the formats successfully parse the expression + an error will be returned. + ### `to_timestamp_nanos` Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) -return the corresponding timestamp. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. +Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). +Returns the corresponding timestamp. ``` -to_timestamp_nanos(expression) +to_timestamp_nanos(expression[, ..., format_n]) ``` #### Arguments - **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order + they appear with the first successful one being returned. If none of the formats successfully parse the expression + an error will be returned. ### `to_timestamp_seconds` Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) -return the corresponding timestamp. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. +Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). +Returns the corresponding timestamp. ``` -to_timestamp_seconds(expression) +to_timestamp_seconds(expression[, ..., format_n]) ``` #### Arguments - **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order + they appear with the first successful one being returned. If none of the formats successfully parse the expression + an error will be returned. ### `from_unixtime` From a4a94291268f5d9cab094618bb1dffa6bba9b290 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 20 Jan 2024 05:34:52 -0500 Subject: [PATCH 28/39] Minor: Document third argument of `date_bin` as optional and default value (#8912) * Minor: Document third argument of date_bin as optional # Rationale @mhilton noticed this as part of an internal discussion where we were confused about the output values of `date_bin` for timestamps with timezones (as the default argument is for the unix EPOCH in UTC) # Changes: Document the third parameter to `date_bin` better * prettier --- docs/source/user-guide/sql/scalar_functions.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c72ef94f42ea..50e1cbc3d622 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1347,7 +1347,8 @@ date_bin(interval, expression, origin-timestamp) - **interval**: Bin interval. - **expression**: Time expression to operate on. Can be a constant, column, or function. -- **timestamp**: Starting point used to determine bin boundaries. +- **origin-timestamp**: Optional. Starting point used to determine bin boundaries. If not specified + defaults `1970-01-01T00:00:00Z` (the UNIX epoch in UTC). The following intervals are supported: From 95e739cb605307d3337c54ef3f0ab8c72cca5717 Mon Sep 17 00:00:00 2001 From: Yang Jiang Date: Sat, 20 Jan 2024 18:52:53 +0800 Subject: [PATCH 29/39] Minor: distinguish parquet row group pruning test type (#8921) --- .../core/tests/parquet/row_group_pruning.rs | 122 ++++++++++++------ 1 file changed, 83 insertions(+), 39 deletions(-) diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 2bc5bd3f1ca7..fc1b66efed87 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -26,11 +26,12 @@ use crate::parquet::Unit::RowGroup; use crate::parquet::{ContextWithParquet, Scenario}; use datafusion_expr::{col, lit}; -async fn test_prune( +async fn test_row_group_prune( case_data_type: Scenario, sql: &str, expected_errors: Option, - expected_row_group_pruned: Option, + expected_row_group_pruned_by_statistics: Option, + expected_row_group_pruned_by_bloom_filter: Option, expected_results: usize, ) { let output = ContextWithParquet::new(case_data_type, RowGroup) @@ -40,7 +41,14 @@ async fn test_prune( println!("{}", output.description()); assert_eq!(output.predicate_evaluation_errors(), expected_errors); - assert_eq!(output.row_groups_pruned(), expected_row_group_pruned); + assert_eq!( + output.row_groups_pruned_statistics(), + expected_row_group_pruned_by_statistics + ); + assert_eq!( + output.row_groups_pruned_bloom_filter(), + expected_row_group_pruned_by_bloom_filter + ); assert_eq!( output.result_rows, expected_results, @@ -83,11 +91,12 @@ async fn test_prune_verbose( #[tokio::test] async fn prune_timestamps_nanos() { - test_prune( + test_row_group_prune( Scenario::Timestamps, "SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')", Some(0), Some(1), + Some(0), 10, ) .await; @@ -95,11 +104,12 @@ async fn prune_timestamps_nanos() { #[tokio::test] async fn prune_timestamps_micros() { - test_prune( + test_row_group_prune( Scenario::Timestamps, "SELECT * FROM t where micros < to_timestamp_micros('2020-01-02 01:01:11Z')", Some(0), Some(1), + Some(0), 10, ) .await; @@ -107,11 +117,12 @@ async fn prune_timestamps_micros() { #[tokio::test] async fn prune_timestamps_millis() { - test_prune( + test_row_group_prune( Scenario::Timestamps, "SELECT * FROM t where millis < to_timestamp_millis('2020-01-02 01:01:11Z')", Some(0), Some(1), + Some(0), 10, ) .await; @@ -119,11 +130,12 @@ async fn prune_timestamps_millis() { #[tokio::test] async fn prune_timestamps_seconds() { - test_prune( + test_row_group_prune( Scenario::Timestamps, "SELECT * FROM t where seconds < to_timestamp_seconds('2020-01-02 01:01:11Z')", Some(0), Some(1), + Some(0), 10, ) .await; @@ -131,11 +143,12 @@ async fn prune_timestamps_seconds() { #[tokio::test] async fn prune_date32() { - test_prune( + test_row_group_prune( Scenario::Dates, "SELECT * FROM t where date32 < cast('2020-01-02' as date)", Some(0), Some(3), + Some(0), 1, ) .await; @@ -168,11 +181,12 @@ async fn prune_date64() { #[tokio::test] async fn prune_disabled() { - test_prune( + test_row_group_prune( Scenario::Timestamps, "SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')", Some(0), Some(1), + Some(0), 10, ) .await; @@ -201,21 +215,23 @@ async fn prune_disabled() { #[tokio::test] async fn prune_int32_lt() { - test_prune( + test_row_group_prune( Scenario::Int32, "SELECT * FROM t where i < 1", Some(0), Some(1), + Some(0), 11, ) .await; // result of sql "SELECT * FROM t where i < 1" is same as // "SELECT * FROM t where -i > -1" - test_prune( + test_row_group_prune( Scenario::Int32, "SELECT * FROM t where -i > -1", Some(0), Some(1), + Some(0), 11, ) .await; @@ -223,22 +239,24 @@ async fn prune_int32_lt() { #[tokio::test] async fn prune_int32_eq() { - test_prune( + test_row_group_prune( Scenario::Int32, "SELECT * FROM t where i = 1", Some(0), Some(3), + Some(0), 1, ) .await; } #[tokio::test] async fn prune_int32_scalar_fun_and_eq() { - test_prune( + test_row_group_prune( Scenario::Int32, "SELECT * FROM t where abs(i) = 1 and i = 1", Some(0), Some(3), + Some(0), 1, ) .await; @@ -246,11 +264,12 @@ async fn prune_int32_scalar_fun_and_eq() { #[tokio::test] async fn prune_int32_scalar_fun() { - test_prune( + test_row_group_prune( Scenario::Int32, "SELECT * FROM t where abs(i) = 1", Some(0), Some(0), + Some(0), 3, ) .await; @@ -258,11 +277,12 @@ async fn prune_int32_scalar_fun() { #[tokio::test] async fn prune_int32_complex_expr() { - test_prune( + test_row_group_prune( Scenario::Int32, "SELECT * FROM t where i+1 = 1", Some(0), Some(0), + Some(0), 2, ) .await; @@ -270,11 +290,12 @@ async fn prune_int32_complex_expr() { #[tokio::test] async fn prune_int32_complex_expr_subtract() { - test_prune( + test_row_group_prune( Scenario::Int32, "SELECT * FROM t where 1-i > 1", Some(0), Some(0), + Some(0), 9, ) .await; @@ -282,19 +303,21 @@ async fn prune_int32_complex_expr_subtract() { #[tokio::test] async fn prune_f64_lt() { - test_prune( + test_row_group_prune( Scenario::Float64, "SELECT * FROM t where f < 1", Some(0), Some(1), + Some(0), 11, ) .await; - test_prune( + test_row_group_prune( Scenario::Float64, "SELECT * FROM t where -f > -1", Some(0), Some(1), + Some(0), 11, ) .await; @@ -304,11 +327,12 @@ async fn prune_f64_lt() { async fn prune_f64_scalar_fun_and_gt() { // result of sql "SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1" // only use "f >= 0" to prune - test_prune( + test_row_group_prune( Scenario::Float64, "SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1", Some(0), Some(2), + Some(0), 1, ) .await; @@ -317,11 +341,12 @@ async fn prune_f64_scalar_fun_and_gt() { #[tokio::test] async fn prune_f64_scalar_fun() { // result of sql "SELECT * FROM t where abs(f-1) <= 0.000001" is not supported - test_prune( + test_row_group_prune( Scenario::Float64, "SELECT * FROM t where abs(f-1) <= 0.000001", Some(0), Some(0), + Some(0), 1, ) .await; @@ -330,11 +355,12 @@ async fn prune_f64_scalar_fun() { #[tokio::test] async fn prune_f64_complex_expr() { // result of sql "SELECT * FROM t where f+1 > 1.1"" is not supported - test_prune( + test_row_group_prune( Scenario::Float64, "SELECT * FROM t where f+1 > 1.1", Some(0), Some(0), + Some(0), 9, ) .await; @@ -343,11 +369,12 @@ async fn prune_f64_complex_expr() { #[tokio::test] async fn prune_f64_complex_expr_subtract() { // result of sql "SELECT * FROM t where 1-f > 1" is not supported - test_prune( + test_row_group_prune( Scenario::Float64, "SELECT * FROM t where 1-f > 1", Some(0), Some(0), + Some(0), 9, ) .await; @@ -356,11 +383,12 @@ async fn prune_f64_complex_expr_subtract() { #[tokio::test] async fn prune_int32_eq_in_list() { // result of sql "SELECT * FROM t where in (1)" - test_prune( + test_row_group_prune( Scenario::Int32, "SELECT * FROM t where i in (1)", Some(0), Some(3), + Some(0), 1, ) .await; @@ -404,11 +432,12 @@ async fn prune_int32_eq_large_in_list() { #[tokio::test] async fn prune_int32_eq_in_list_negated() { // result of sql "SELECT * FROM t where not in (1)" prune nothing - test_prune( + test_row_group_prune( Scenario::Int32, "SELECT * FROM t where i not in (1)", Some(0), Some(0), + Some(0), 19, ) .await; @@ -419,39 +448,43 @@ async fn prune_decimal_lt() { // The data type of decimal_col is decimal(9,2) // There are three row groups: // [1.00, 6.00], [-5.00,6.00], [20.00,60.00] - test_prune( + test_row_group_prune( Scenario::Decimal, "SELECT * FROM t where decimal_col < 4", Some(0), Some(1), + Some(0), 6, ) .await; // compare with the casted decimal value - test_prune( + test_row_group_prune( Scenario::Decimal, "SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))", Some(0), Some(1), + Some(0), 8, ) .await; // The data type of decimal_col is decimal(38,2) - test_prune( + test_row_group_prune( Scenario::DecimalLargePrecision, "SELECT * FROM t where decimal_col < 4", Some(0), Some(1), + Some(0), 6, ) .await; // compare with the casted decimal value - test_prune( + test_row_group_prune( Scenario::DecimalLargePrecision, "SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))", Some(0), Some(1), + Some(0), 8, ) .await; @@ -462,37 +495,41 @@ async fn prune_decimal_eq() { // The data type of decimal_col is decimal(9,2) // There are three row groups: // [1.00, 6.00], [-5.00,6.00], [20.00,60.00] - test_prune( + test_row_group_prune( Scenario::Decimal, "SELECT * FROM t where decimal_col = 4", Some(0), Some(1), + Some(0), 2, ) .await; - test_prune( + test_row_group_prune( Scenario::Decimal, "SELECT * FROM t where decimal_col = 4.00", Some(0), Some(1), + Some(0), 2, ) .await; // The data type of decimal_col is decimal(38,2) - test_prune( + test_row_group_prune( Scenario::DecimalLargePrecision, "SELECT * FROM t where decimal_col = 4", Some(0), Some(1), + Some(0), 2, ) .await; - test_prune( + test_row_group_prune( Scenario::DecimalLargePrecision, "SELECT * FROM t where decimal_col = 4.00", Some(0), Some(1), + Some(0), 2, ) .await; @@ -503,37 +540,41 @@ async fn prune_decimal_in_list() { // The data type of decimal_col is decimal(9,2) // There are three row groups: // [1.00, 6.00], [-5.00,6.00], [20.00,60.00] - test_prune( + test_row_group_prune( Scenario::Decimal, "SELECT * FROM t where decimal_col in (4,3,2,123456789123)", Some(0), Some(1), + Some(0), 5, ) .await; - test_prune( + test_row_group_prune( Scenario::Decimal, "SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)", Some(0), Some(1), + Some(0), 6, ) .await; // The data type of decimal_col is decimal(38,2) - test_prune( + test_row_group_prune( Scenario::DecimalLargePrecision, "SELECT * FROM t where decimal_col in (4,3,2,123456789123)", Some(0), Some(1), + Some(0), 5, ) .await; - test_prune( + test_row_group_prune( Scenario::DecimalLargePrecision, "SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)", Some(0), Some(1), + Some(0), 6, ) .await; @@ -545,28 +586,31 @@ async fn prune_periods_in_column_names() { // name = "HTTP GET / DISPATCH", service.name = ['frontend', 'frontend'], // name = "HTTP PUT / DISPATCH", service.name = ['backend', 'frontend'], // name = "HTTP GET / DISPATCH", service.name = ['backend', 'backend' ], - test_prune( + test_row_group_prune( Scenario::PeriodsInColumnNames, // use double quotes to use column named "service.name" "SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend'", Some(0), Some(1), // prune out last row group + Some(0), 7, ) .await; - test_prune( + test_row_group_prune( Scenario::PeriodsInColumnNames, "SELECT \"name\", \"service.name\" FROM t WHERE \"name\" != 'HTTP GET / DISPATCH'", Some(0), Some(2), // prune out first and last row group + Some(0), 5, ) .await; - test_prune( + test_row_group_prune( Scenario::PeriodsInColumnNames, "SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend' AND \"name\" != 'HTTP GET / DISPATCH'", Some(0), Some(2), // prune out middle and last row group + Some(0), 2, ) .await; From f5a97d58d484e93c2f79c5b624b10bd9e75c45f0 Mon Sep 17 00:00:00 2001 From: Eugene Marushchenko Date: Sun, 21 Jan 2024 02:35:00 +1000 Subject: [PATCH 30/39] Add hash_join_single_partition_threshold_rows config (#8720) Co-authored-by: Andrew Lamb --- datafusion/common/src/config.rs | 4 + .../src/physical_optimizer/join_selection.rs | 284 ++++++++++-------- .../test_files/information_schema.slt | 2 + docs/source/user-guide/configs.md | 1 + 4 files changed, 165 insertions(+), 126 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index e00c17930850..eb516f97a48f 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -561,6 +561,10 @@ config_namespace! { /// will be collected into a single partition pub hash_join_single_partition_threshold: usize, default = 1024 * 1024 + /// The maximum estimated size in rows for one input side of a HashJoin + /// will be collected into a single partition + pub hash_join_single_partition_threshold_rows: usize, default = 1024 * 128 + /// The default filter selectivity used by Filter Statistics /// when an exact selectivity cannot be determined. Valid values are /// between 0 (no selectivity) and 100 (all rows are selected). diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index ba66dca55b35..f9b9fdf85cfa 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -87,9 +87,10 @@ fn should_swap_join_order( } } -fn supports_collect_by_size( +fn supports_collect_by_thresholds( plan: &dyn ExecutionPlan, - collection_size_threshold: usize, + threshold_byte_size: usize, + threshold_num_rows: usize, ) -> bool { // Currently we do not trust the 0 value from stats, due to stats collection might have bug // TODO check the logic in datasource::get_statistics_with_limit() @@ -97,10 +98,10 @@ fn supports_collect_by_size( return false; }; - if let Some(size) = stats.total_byte_size.get_value() { - *size != 0 && *size < collection_size_threshold - } else if let Some(row_count) = stats.num_rows.get_value() { - *row_count != 0 && *row_count < collection_size_threshold + if let Some(byte_size) = stats.total_byte_size.get_value() { + *byte_size != 0 && *byte_size < threshold_byte_size + } else if let Some(num_rows) = stats.num_rows.get_value() { + *num_rows != 0 && *num_rows < threshold_num_rows } else { false } @@ -251,9 +252,14 @@ impl PhysicalOptimizerRule for JoinSelection { // - We will also swap left and right sides for cross joins so that the left // side is the small side. let config = &config.optimizer; - let collect_left_threshold = config.hash_join_single_partition_threshold; + let collect_threshold_byte_size = config.hash_join_single_partition_threshold; + let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows; state.plan.transform_up(&|plan| { - statistical_join_selection_subrule(plan, collect_left_threshold) + statistical_join_selection_subrule( + plan, + collect_threshold_byte_size, + collect_threshold_num_rows, + ) }) } @@ -270,8 +276,8 @@ impl PhysicalOptimizerRule for JoinSelection { /// /// This function will first consider the given join type and check whether the /// `CollectLeft` mode is applicable. Otherwise, it will try to swap the join sides. -/// When the `collect_threshold` is provided, this function will also check left -/// and right sizes. +/// When the `ignore_threshold` is false, this function will also check left +/// and right sizes in bytes or rows. /// /// For [`JoinType::Full`], it can not use `CollectLeft` mode and will return `None`. /// For [`JoinType::Left`] and [`JoinType::LeftAnti`], it can not run `CollectLeft` @@ -279,7 +285,9 @@ impl PhysicalOptimizerRule for JoinSelection { /// and [`JoinType::RightAnti`], respectively. fn try_collect_left( hash_join: &HashJoinExec, - collect_threshold: Option, + ignore_threshold: bool, + threshold_byte_size: usize, + threshold_num_rows: usize, ) -> Result>> { let left = hash_join.left(); let right = hash_join.right(); @@ -291,9 +299,14 @@ fn try_collect_left( | JoinType::LeftSemi | JoinType::Right | JoinType::RightSemi - | JoinType::RightAnti => collect_threshold.map_or(true, |threshold| { - supports_collect_by_size(&**left, threshold) - }), + | JoinType::RightAnti => { + ignore_threshold + || supports_collect_by_thresholds( + &**left, + threshold_byte_size, + threshold_num_rows, + ) + } }; let right_can_collect = match join_type { JoinType::Right | JoinType::Full | JoinType::RightAnti => false, @@ -301,9 +314,14 @@ fn try_collect_left( | JoinType::RightSemi | JoinType::Left | JoinType::LeftSemi - | JoinType::LeftAnti => collect_threshold.map_or(true, |threshold| { - supports_collect_by_size(&**right, threshold) - }), + | JoinType::LeftAnti => { + ignore_threshold + || supports_collect_by_thresholds( + &**right, + threshold_byte_size, + threshold_num_rows, + ) + } }; match (left_can_collect, right_can_collect) { (true, true) => { @@ -366,52 +384,56 @@ fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result, - collect_left_threshold: usize, + collect_threshold_byte_size: usize, + collect_threshold_num_rows: usize, ) -> Result>> { - let transformed = if let Some(hash_join) = - plan.as_any().downcast_ref::() - { - match hash_join.partition_mode() { - PartitionMode::Auto => { - try_collect_left(hash_join, Some(collect_left_threshold))?.map_or_else( - || partitioned_hash_join(hash_join).map(Some), - |v| Ok(Some(v)), + let transformed = + if let Some(hash_join) = plan.as_any().downcast_ref::() { + match hash_join.partition_mode() { + PartitionMode::Auto => try_collect_left( + hash_join, + false, + collect_threshold_byte_size, + collect_threshold_num_rows, )? - } - PartitionMode::CollectLeft => try_collect_left(hash_join, None)? .map_or_else( || partitioned_hash_join(hash_join).map(Some), |v| Ok(Some(v)), )?, - PartitionMode::Partitioned => { - let left = hash_join.left(); - let right = hash_join.right(); - if should_swap_join_order(&**left, &**right)? - && supports_swap(*hash_join.join_type()) - { - swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)? - } else { - None + PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)? + .map_or_else( + || partitioned_hash_join(hash_join).map(Some), + |v| Ok(Some(v)), + )?, + PartitionMode::Partitioned => { + let left = hash_join.left(); + let right = hash_join.right(); + if should_swap_join_order(&**left, &**right)? + && supports_swap(*hash_join.join_type()) + { + swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)? + } else { + None + } } } - } - } else if let Some(cross_join) = plan.as_any().downcast_ref::() { - let left = cross_join.left(); - let right = cross_join.right(); - if should_swap_join_order(&**left, &**right)? { - let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); - // TODO avoid adding ProjectionExec again and again, only adding Final Projection - let proj: Arc = Arc::new(ProjectionExec::try_new( - swap_reverting_projection(&left.schema(), &right.schema()), - Arc::new(new_join), - )?); - Some(proj) + } else if let Some(cross_join) = plan.as_any().downcast_ref::() { + let left = cross_join.left(); + let right = cross_join.right(); + if should_swap_join_order(&**left, &**right)? { + let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); + // TODO avoid adding ProjectionExec again and again, only adding Final Projection + let proj: Arc = Arc::new(ProjectionExec::try_new( + swap_reverting_projection(&left.schema(), &right.schema()), + Arc::new(new_join), + )?); + Some(proj) + } else { + None + } } else { None - } - } else { - None - }; + }; Ok(if let Some(transformed) = transformed { Transformed::Yes(transformed) @@ -682,22 +704,62 @@ mod tests_statistical { use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalExpr; + /// Return statistcs for empty table + fn empty_statistics() -> Statistics { + Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics::new_unknown()], + } + } + + /// Get table thresholds: (num_rows, byte_size) + fn get_thresholds() -> (usize, usize) { + let optimizer_options = ConfigOptions::new().optimizer; + ( + optimizer_options.hash_join_single_partition_threshold_rows, + optimizer_options.hash_join_single_partition_threshold, + ) + } + + /// Return statistcs for small table + fn small_statistics() -> Statistics { + let (threshold_num_rows, threshold_byte_size) = get_thresholds(); + Statistics { + num_rows: Precision::Inexact(threshold_num_rows / 128), + total_byte_size: Precision::Inexact(threshold_byte_size / 128), + column_statistics: vec![ColumnStatistics::new_unknown()], + } + } + + /// Return statistcs for big table + fn big_statistics() -> Statistics { + let (threshold_num_rows, threshold_byte_size) = get_thresholds(); + Statistics { + num_rows: Precision::Inexact(threshold_num_rows * 2), + total_byte_size: Precision::Inexact(threshold_byte_size * 2), + column_statistics: vec![ColumnStatistics::new_unknown()], + } + } + + /// Return statistcs for big table + fn bigger_statistics() -> Statistics { + let (threshold_num_rows, threshold_byte_size) = get_thresholds(); + Statistics { + num_rows: Precision::Inexact(threshold_num_rows * 4), + total_byte_size: Precision::Inexact(threshold_byte_size * 4), + column_statistics: vec![ColumnStatistics::new_unknown()], + } + } + fn create_big_and_small() -> (Arc, Arc) { let big = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Precision::Inexact(10), - total_byte_size: Precision::Inexact(100000), - column_statistics: vec![ColumnStatistics::new_unknown()], - }, + big_statistics(), Schema::new(vec![Field::new("big_col", DataType::Int32, false)]), )); let small = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Precision::Inexact(100000), - total_byte_size: Precision::Inexact(10), - column_statistics: vec![ColumnStatistics::new_unknown()], - }, + small_statistics(), Schema::new(vec![Field::new("small_col", DataType::Int32, false)]), )); (big, small) @@ -821,11 +883,11 @@ mod tests_statistical { assert_eq!( swapped_join.left().statistics().unwrap().total_byte_size, - Precision::Inexact(10) + Precision::Inexact(8192) ); assert_eq!( swapped_join.right().statistics().unwrap().total_byte_size, - Precision::Inexact(100000) + Precision::Inexact(2097152) ); } @@ -872,11 +934,11 @@ mod tests_statistical { assert_eq!( swapped_join.left().statistics().unwrap().total_byte_size, - Precision::Inexact(100000) + Precision::Inexact(2097152) ); assert_eq!( swapped_join.right().statistics().unwrap().total_byte_size, - Precision::Inexact(10) + Precision::Inexact(8192) ); } @@ -917,11 +979,11 @@ mod tests_statistical { assert_eq!( swapped_join.left().statistics().unwrap().total_byte_size, - Precision::Inexact(10) + Precision::Inexact(8192) ); assert_eq!( swapped_join.right().statistics().unwrap().total_byte_size, - Precision::Inexact(100000) + Precision::Inexact(2097152) ); assert_eq!(original_schema, swapped_join.schema()); @@ -1032,11 +1094,11 @@ mod tests_statistical { assert_eq!( swapped_join.left().statistics().unwrap().total_byte_size, - Precision::Inexact(10) + Precision::Inexact(8192) ); assert_eq!( swapped_join.right().statistics().unwrap().total_byte_size, - Precision::Inexact(100000) + Precision::Inexact(2097152) ); } @@ -1078,29 +1140,17 @@ mod tests_statistical { #[tokio::test] async fn test_join_selection_collect_left() { let big = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Precision::Inexact(10000000), - total_byte_size: Precision::Inexact(10000000), - column_statistics: vec![ColumnStatistics::new_unknown()], - }, + big_statistics(), Schema::new(vec![Field::new("big_col", DataType::Int32, false)]), )); let small = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Precision::Inexact(10), - total_byte_size: Precision::Inexact(10), - column_statistics: vec![ColumnStatistics::new_unknown()], - }, + small_statistics(), Schema::new(vec![Field::new("small_col", DataType::Int32, false)]), )); let empty = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Precision::Absent, - total_byte_size: Precision::Absent, - column_statistics: vec![ColumnStatistics::new_unknown()], - }, + empty_statistics(), Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]), )); @@ -1121,7 +1171,7 @@ mod tests_statistical { Column::new_with_schema("small_col", &small.schema()).unwrap(), )]; check_join_partition_mode( - big, + big.clone(), small.clone(), join_on, true, @@ -1145,8 +1195,8 @@ mod tests_statistical { Column::new_with_schema("small_col", &small.schema()).unwrap(), )]; check_join_partition_mode( - empty, - small, + empty.clone(), + small.clone(), join_on, true, PartitionMode::CollectLeft, @@ -1155,52 +1205,40 @@ mod tests_statistical { #[tokio::test] async fn test_join_selection_partitioned() { - let big1 = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Precision::Inexact(10000000), - total_byte_size: Precision::Inexact(10000000), - column_statistics: vec![ColumnStatistics::new_unknown()], - }, - Schema::new(vec![Field::new("big_col1", DataType::Int32, false)]), + let bigger = Arc::new(StatisticsExec::new( + bigger_statistics(), + Schema::new(vec![Field::new("bigger_col", DataType::Int32, false)]), )); - let big2 = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Precision::Inexact(20000000), - total_byte_size: Precision::Inexact(20000000), - column_statistics: vec![ColumnStatistics::new_unknown()], - }, - Schema::new(vec![Field::new("big_col2", DataType::Int32, false)]), + let big = Arc::new(StatisticsExec::new( + big_statistics(), + Schema::new(vec![Field::new("big_col", DataType::Int32, false)]), )); let empty = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Precision::Absent, - total_byte_size: Precision::Absent, - column_statistics: vec![ColumnStatistics::new_unknown()], - }, + empty_statistics(), Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]), )); let join_on = vec![( - Column::new_with_schema("big_col1", &big1.schema()).unwrap(), - Column::new_with_schema("big_col2", &big2.schema()).unwrap(), + Column::new_with_schema("big_col", &big.schema()).unwrap(), + Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(), )]; check_join_partition_mode( - big1.clone(), - big2.clone(), + big.clone(), + bigger.clone(), join_on, false, PartitionMode::Partitioned, ); let join_on = vec![( - Column::new_with_schema("big_col2", &big2.schema()).unwrap(), - Column::new_with_schema("big_col1", &big1.schema()).unwrap(), + Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(), + Column::new_with_schema("big_col", &big.schema()).unwrap(), )]; check_join_partition_mode( - big2, - big1.clone(), + bigger.clone(), + big.clone(), join_on, true, PartitionMode::Partitioned, @@ -1208,27 +1246,21 @@ mod tests_statistical { let join_on = vec![( Column::new_with_schema("empty_col", &empty.schema()).unwrap(), - Column::new_with_schema("big_col1", &big1.schema()).unwrap(), + Column::new_with_schema("big_col", &big.schema()).unwrap(), )]; check_join_partition_mode( empty.clone(), - big1.clone(), + big.clone(), join_on, false, PartitionMode::Partitioned, ); let join_on = vec![( - Column::new_with_schema("big_col1", &big1.schema()).unwrap(), + Column::new_with_schema("big_col", &big.schema()).unwrap(), Column::new_with_schema("empty_col", &empty.schema()).unwrap(), )]; - check_join_partition_mode( - big1, - empty, - join_on, - false, - PartitionMode::Partitioned, - ); + check_join_partition_mode(big, empty, join_on, false, PartitionMode::Partitioned); } fn check_join_partition_mode( diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index b37b78ab6d79..768292d3d4b4 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -196,6 +196,7 @@ datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_topk_aggregation true datafusion.optimizer.filter_null_join_keys false datafusion.optimizer.hash_join_single_partition_threshold 1048576 +datafusion.optimizer.hash_join_single_partition_threshold_rows 131072 datafusion.optimizer.max_passes 3 datafusion.optimizer.prefer_existing_sort false datafusion.optimizer.prefer_hash_join true @@ -272,6 +273,7 @@ datafusion.optimizer.enable_round_robin_repartition true When set to true, the p datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible datafusion.optimizer.filter_null_join_keys false When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. datafusion.optimizer.hash_join_single_partition_threshold 1048576 The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition +datafusion.optimizer.hash_join_single_partition_threshold_rows 131072 The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition datafusion.optimizer.max_passes 3 Number of times that the optimizer will attempt to optimize the plan datafusion.optimizer.prefer_existing_sort false When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. datafusion.optimizer.prefer_hash_join true When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index a812b74284cf..7a7460799b1a 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -101,6 +101,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | | datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | | datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | | datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | | datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | | datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | From b7e13a0af711477ad41450566c14430089edd3f2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 20 Jan 2024 15:00:05 -0700 Subject: [PATCH 31/39] Prepare 35.0.0-rc1 (#8924) --- Cargo.toml | 24 +-- benchmarks/Cargo.toml | 8 +- datafusion-cli/Cargo.lock | 26 +-- datafusion-cli/Cargo.toml | 4 +- datafusion/CHANGELOG.md | 1 + datafusion/core/Cargo.toml | 6 +- datafusion/optimizer/Cargo.toml | 4 +- datafusion/proto/Cargo.toml | 2 +- datafusion/sqllogictest/Cargo.toml | 2 +- dev/changelog/35.0.0.md | 295 +++++++++++++++++++++++++++++ dev/release/README.md | 2 +- docs/Cargo.toml | 2 +- docs/source/user-guide/configs.md | 2 +- 13 files changed, 337 insertions(+), 41 deletions(-) create mode 100644 dev/changelog/35.0.0.md diff --git a/Cargo.toml b/Cargo.toml index cc1861677476..cd88e18fe17c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/arrow-datafusion" rust-version = "1.70" -version = "34.0.0" +version = "35.0.0" [workspace.dependencies] arrow = { version = "50.0.0", features = ["prettyprint"] } @@ -45,17 +45,17 @@ bytes = "1.4" chrono = { version = "0.4.31", default-features = false } ctor = "0.2.0" dashmap = "5.4.0" -datafusion = { path = "datafusion/core", version = "34.0.0" } -datafusion-common = { path = "datafusion/common", version = "34.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "34.0.0" } -datafusion-expr = { path = "datafusion/expr", version = "34.0.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "34.0.0" } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "34.0.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "34.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "34.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "34.0.0" } -datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "34.0.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "34.0.0" } +datafusion = { path = "datafusion/core", version = "35.0.0" } +datafusion-common = { path = "datafusion/common", version = "35.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "35.0.0" } +datafusion-expr = { path = "datafusion/expr", version = "35.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "35.0.0" } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "35.0.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "35.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "35.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "35.0.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "35.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "35.0.0" } doc-comment = "0.3" env_logger = "0.10" futures = "0.3" diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 94c1ebe7ee47..50b79b4b0661 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-benchmarks" description = "DataFusion Benchmarks" -version = "34.0.0" +version = "35.0.0" edition = { workspace = true } authors = ["Apache Arrow "] homepage = "https://github.com/apache/arrow-datafusion" @@ -33,8 +33,8 @@ snmalloc = ["snmalloc-rs"] [dependencies] arrow = { workspace = true } -datafusion = { path = "../datafusion/core", version = "34.0.0" } -datafusion-common = { path = "../datafusion/common", version = "34.0.0" } +datafusion = { path = "../datafusion/core", version = "35.0.0" } +datafusion-common = { path = "../datafusion/common", version = "35.0.0" } env_logger = { workspace = true } futures = { workspace = true } log = { workspace = true } @@ -49,4 +49,4 @@ test-utils = { path = "../test-utils/", version = "0.1.0" } tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread", "parking_lot"] } [dev-dependencies] -datafusion-proto = { path = "../datafusion/proto", version = "34.0.0" } +datafusion-proto = { path = "../datafusion/proto", version = "35.0.0" } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index db5913503aaa..c90b59b924f6 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1098,7 +1098,7 @@ dependencies = [ [[package]] name = "datafusion" -version = "34.0.0" +version = "35.0.0" dependencies = [ "ahash", "apache-avro", @@ -1146,7 +1146,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "34.0.0" +version = "35.0.0" dependencies = [ "arrow", "assert_cmd", @@ -1174,7 +1174,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "34.0.0" +version = "35.0.0" dependencies = [ "ahash", "apache-avro", @@ -1193,7 +1193,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "34.0.0" +version = "35.0.0" dependencies = [ "arrow", "chrono", @@ -1212,7 +1212,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "34.0.0" +version = "35.0.0" dependencies = [ "ahash", "arrow", @@ -1226,7 +1226,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "34.0.0" +version = "35.0.0" dependencies = [ "arrow", "async-trait", @@ -1242,7 +1242,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "34.0.0" +version = "35.0.0" dependencies = [ "ahash", "arrow", @@ -1274,7 +1274,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "34.0.0" +version = "35.0.0" dependencies = [ "ahash", "arrow", @@ -1303,7 +1303,7 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "34.0.0" +version = "35.0.0" dependencies = [ "arrow", "arrow-schema", @@ -3092,9 +3092,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.12.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2593d31f82ead8df961d8bd23a64c2ccf2eb5dd34b0a34bfb4dd54011c72009e" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "snafu" @@ -3621,9 +3621,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" +checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" dependencies = [ "getrandom", "serde", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index d084938030b1..07ee65e3f6cd 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." -version = "34.0.0" +version = "35.0.0" authors = ["Apache Arrow "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] @@ -34,7 +34,7 @@ async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" clap = { version = "3", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "34.0.0", features = ["avro", "crypto_expressions", "encoding_expressions", "parquet", "regex_expressions", "unicode_expressions", "compression"] } +datafusion = { path = "../datafusion/core", version = "35.0.0", features = ["avro", "crypto_expressions", "encoding_expressions", "parquet", "regex_expressions", "unicode_expressions", "compression"] } datafusion-common = { path = "../datafusion/common" } dirs = "4.0.0" env_logger = "0.9" diff --git a/datafusion/CHANGELOG.md b/datafusion/CHANGELOG.md index d64bbeda877d..ae9da0e865e9 100644 --- a/datafusion/CHANGELOG.md +++ b/datafusion/CHANGELOG.md @@ -19,6 +19,7 @@ # Changelog +- [35.0.0](../dev/changelog/35.0.0.md) - [34.0.0](../dev/changelog/34.0.0.md) - [33.0.0](../dev/changelog/33.0.0.md) - [32.0.0](../dev/changelog/32.0.0.md) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index f5496d4c4700..69b18a326951 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -62,11 +62,11 @@ bytes = { workspace = true } bzip2 = { version = "0.4.3", optional = true } chrono = { workspace = true } dashmap = { workspace = true } -datafusion-common = { path = "../common", version = "34.0.0", features = ["object_store"], default-features = false } +datafusion-common = { path = "../common", version = "35.0.0", features = ["object_store"], default-features = false } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-optimizer = { path = "../optimizer", version = "34.0.0", default-features = false } -datafusion-physical-expr = { path = "../physical-expr", version = "34.0.0", default-features = false } +datafusion-optimizer = { path = "../optimizer", version = "35.0.0", default-features = false } +datafusion-physical-expr = { path = "../physical-expr", version = "35.0.0", default-features = false } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } flate2 = { version = "1.0.24", optional = true } diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index b350d41d3fe3..6aec52ad70d1 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -44,7 +44,7 @@ async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } -datafusion-physical-expr = { path = "../physical-expr", version = "34.0.0", default-features = false } +datafusion-physical-expr = { path = "../physical-expr", version = "35.0.0", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } itertools = { workspace = true } log = { workspace = true } @@ -52,5 +52,5 @@ regex-syntax = "0.8.0" [dev-dependencies] ctor = { workspace = true } -datafusion-sql = { path = "../sql", version = "34.0.0" } +datafusion-sql = { path = "../sql", version = "35.0.0" } env_logger = "0.10.0" diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index e42322021630..f9d54dba5756 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -43,7 +43,7 @@ parquet = ["datafusion/parquet", "datafusion-common/parquet"] [dependencies] arrow = { workspace = true } chrono = { workspace = true } -datafusion = { path = "../core", version = "34.0.0" } +datafusion = { path = "../core", version = "35.0.0" } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } object_store = { workspace = true } diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 7085e1ada09a..911b46c0bcf4 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -37,7 +37,7 @@ bigdecimal = { workspace = true } bytes = { version = "1.4.0", optional = true } chrono = { workspace = true, optional = true } clap = { version = "4.4.8", features = ["derive", "env"] } -datafusion = { path = "../core", version = "34.0.0" } +datafusion = { path = "../core", version = "35.0.0" } datafusion-common = { workspace = true } futures = { version = "0.3.28" } half = { workspace = true } diff --git a/dev/changelog/35.0.0.md b/dev/changelog/35.0.0.md new file mode 100644 index 000000000000..b48b2b7aaa12 --- /dev/null +++ b/dev/changelog/35.0.0.md @@ -0,0 +1,295 @@ + + +## [35.0.0](https://github.com/apache/arrow-datafusion/tree/35.0.0) (2024-01-20) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/34.0.0...35.0.0) + +**Breaking changes:** + +- Minor: make SubqueryAlias::try_new take Arc [#8542](https://github.com/apache/arrow-datafusion/pull/8542) (sadboy) +- Remove ListingTable and FileScanConfig Unbounded (#8540) [#8573](https://github.com/apache/arrow-datafusion/pull/8573) (tustvold) +- Rename `ParamValues::{LIST -> List,MAP -> Map}` [#8611](https://github.com/apache/arrow-datafusion/pull/8611) (kawadakk) +- Rename `expr::window_function::WindowFunction` to `WindowFunctionDefinition`, make structure consistent with ScalarFunction [#8382](https://github.com/apache/arrow-datafusion/pull/8382) (edmondop) +- Implement `ScalarUDF` in terms of `ScalarUDFImpl` trait [#8713](https://github.com/apache/arrow-datafusion/pull/8713) (alamb) +- Change `ScalarValue::{List, LargeList, FixedSizedList}` to take specific types rather than `ArrayRef` [#8562](https://github.com/apache/arrow-datafusion/pull/8562) (rspears74) +- Remove unused array_expression.rs and `SUPPORTED_ARRAY_TYPES` [#8807](https://github.com/apache/arrow-datafusion/pull/8807) (alamb) +- Simplify physical expression creation API (not require schema) [#8823](https://github.com/apache/arrow-datafusion/pull/8823) (comphead) +- Determine causal window frames to produce early results. [#8842](https://github.com/apache/arrow-datafusion/pull/8842) (mustafasrepo) + +**Implemented enhancements:** + +- feat: implement Unary Expr in substrait [#8534](https://github.com/apache/arrow-datafusion/pull/8534) (waynexia) +- feat: implement Repartition plan in substrait [#8526](https://github.com/apache/arrow-datafusion/pull/8526) (waynexia) +- feat: support largelist in array_slice [#8561](https://github.com/apache/arrow-datafusion/pull/8561) (Weijun-H) +- feat: support `LargeList` in `array_positions` [#8571](https://github.com/apache/arrow-datafusion/pull/8571) (Weijun-H) +- feat: support `LargeList` in `array_element` [#8570](https://github.com/apache/arrow-datafusion/pull/8570) (Weijun-H) +- feat: support `LargeList` in `array_dims` [#8592](https://github.com/apache/arrow-datafusion/pull/8592) (Weijun-H) +- feat: support `LargeList` in `array_remove` [#8595](https://github.com/apache/arrow-datafusion/pull/8595) (Weijun-H) +- feat: support inlist in LiteralGurantee for pruning [#8654](https://github.com/apache/arrow-datafusion/pull/8654) (my-vegetable-has-exploded) +- feat: support 'LargeList' in `array_pop_front` and `array_pop_back` [#8569](https://github.com/apache/arrow-datafusion/pull/8569) (Weijun-H) +- feat: support `LargeList` in `array_position` [#8714](https://github.com/apache/arrow-datafusion/pull/8714) (Weijun-H) +- feat: support `LargeList` in `array_ndims` [#8716](https://github.com/apache/arrow-datafusion/pull/8716) (Weijun-H) +- feat: remove filters with null constants [#8700](https://github.com/apache/arrow-datafusion/pull/8700) (asimsedhain) +- feat: support LargeList in array_repeat [#8725](https://github.com/apache/arrow-datafusion/pull/8725) (Weijun-H) +- feat: native types in `DistinctCountAccumulator` for primitive types [#8721](https://github.com/apache/arrow-datafusion/pull/8721) (korowa) +- feat: support `LargeList` in `cardinality` [#8726](https://github.com/apache/arrow-datafusion/pull/8726) (Weijun-H) +- feat: support `largelist` in `array_to_string` [#8729](https://github.com/apache/arrow-datafusion/pull/8729) (Weijun-H) +- feat: Add bloom filter metric to ParquetExec [#8772](https://github.com/apache/arrow-datafusion/pull/8772) (my-vegetable-has-exploded) +- feat: support `array_resize` [#8744](https://github.com/apache/arrow-datafusion/pull/8744) (Weijun-H) +- feat: add more components to the wasm-pack compatible list [#8843](https://github.com/apache/arrow-datafusion/pull/8843) (waynexia) + +**Fixed bugs:** + +- fix: make sure CASE WHEN pick first true branch when WHEN clause is true [#8477](https://github.com/apache/arrow-datafusion/pull/8477) (haohuaijin) +- fix: `Antarctica/Vostok` tz offset changed in chrono-tz 0.8.5 [#8677](https://github.com/apache/arrow-datafusion/pull/8677) (korowa) +- fix: struct field don't push down to TableScan [#8774](https://github.com/apache/arrow-datafusion/pull/8774) (haohuaijin) +- fix: failed to create ValuesExec with non-nullable schema [#8776](https://github.com/apache/arrow-datafusion/pull/8776) (jonahgao) +- fix: fix markdown table in docs [#8812](https://github.com/apache/arrow-datafusion/pull/8812) (tshauck) +- fix: don't extract common sub expr in `CASE WHEN` clause [#8833](https://github.com/apache/arrow-datafusion/pull/8833) (haohuaijin) + +**Documentation updates:** + +- docs: update udf docs for udtf [#8546](https://github.com/apache/arrow-datafusion/pull/8546) (tshauck) +- Doc: Clarify When Limit is Pushed Down to TableProvider::Scan [#8686](https://github.com/apache/arrow-datafusion/pull/8686) (devinjdangelo) +- Minor: Improve `PruningPredicate` docstrings [#8748](https://github.com/apache/arrow-datafusion/pull/8748) (alamb) +- Minor: Add documentation about stream cancellation [#8747](https://github.com/apache/arrow-datafusion/pull/8747) (alamb) +- docs: add sudo for install commands [#8804](https://github.com/apache/arrow-datafusion/pull/8804) (caicancai) +- docs: document SessionConfig [#8771](https://github.com/apache/arrow-datafusion/pull/8771) (wjones127) +- Upgrade to object_store `0.9.0` and arrow `50.0.0` [#8758](https://github.com/apache/arrow-datafusion/pull/8758) (tustvold) +- docs: fix wrong pushdown name & a typo [#8875](https://github.com/apache/arrow-datafusion/pull/8875) (SteveLauC) +- docs: Update contributor guide with installation instructions [#8876](https://github.com/apache/arrow-datafusion/pull/8876) (caicancai) +- docs: fix wrong name in sub-crates' README [#8889](https://github.com/apache/arrow-datafusion/pull/8889) (SteveLauC) +- docs: add an example for RecordBatchReceiverStreamBuilder [#8888](https://github.com/apache/arrow-datafusion/pull/8888) (SteveLauC) + +**Merged pull requests:** + +- Remove order_bys from AggregateExec state [#8537](https://github.com/apache/arrow-datafusion/pull/8537) (mustafasrepo) +- Fix count(null) and count(distinct null) [#8511](https://github.com/apache/arrow-datafusion/pull/8511) (joroKr21) +- Minor: reduce code duplication in `date_bin_impl` [#8528](https://github.com/apache/arrow-datafusion/pull/8528) (Weijun-H) +- Add metrics for UnnestExec [#8482](https://github.com/apache/arrow-datafusion/pull/8482) (simonvandel) +- Prepare 34.0.0-rc3 [#8549](https://github.com/apache/arrow-datafusion/pull/8549) (andygrove) +- fix: make sure CASE WHEN pick first true branch when WHEN clause is true [#8477](https://github.com/apache/arrow-datafusion/pull/8477) (haohuaijin) +- Minor: make SubqueryAlias::try_new take Arc [#8542](https://github.com/apache/arrow-datafusion/pull/8542) (sadboy) +- Fallback on null empty value in ExprBoundaries::try_from_column [#8501](https://github.com/apache/arrow-datafusion/pull/8501) (razeghi71) +- Add test for DataFrame::write_table [#8531](https://github.com/apache/arrow-datafusion/pull/8531) (devinjdangelo) +- [MINOR]: Generate empty column at placeholder exec [#8553](https://github.com/apache/arrow-datafusion/pull/8553) (mustafasrepo) +- Minor: Remove now dead `SUPPORTED_STRUCT_TYPES` [#8480](https://github.com/apache/arrow-datafusion/pull/8480) (alamb) +- [MINOR]: Add getter methods to first and last value [#8555](https://github.com/apache/arrow-datafusion/pull/8555) (mustafasrepo) +- [MINOR]: Some code changes and a new empty batch guard for SHJ [#8557](https://github.com/apache/arrow-datafusion/pull/8557) (metesynnada) +- docs: update udf docs for udtf [#8546](https://github.com/apache/arrow-datafusion/pull/8546) (tshauck) +- feat: implement Unary Expr in substrait [#8534](https://github.com/apache/arrow-datafusion/pull/8534) (waynexia) +- Fix `compute_record_batch_statistics` wrong with `projection` [#8489](https://github.com/apache/arrow-datafusion/pull/8489) (Asura7969) +- Minor: Cleanup warning in scalar.rs test [#8563](https://github.com/apache/arrow-datafusion/pull/8563) (jayzhan211) +- Minor: move some invariants out of the loop [#8564](https://github.com/apache/arrow-datafusion/pull/8564) (haohuaijin) +- feat: implement Repartition plan in substrait [#8526](https://github.com/apache/arrow-datafusion/pull/8526) (waynexia) +- Fix sort order aware file group parallelization [#8517](https://github.com/apache/arrow-datafusion/pull/8517) (alamb) +- feat: support largelist in array_slice [#8561](https://github.com/apache/arrow-datafusion/pull/8561) (Weijun-H) +- minor: fix to support scalars [#8559](https://github.com/apache/arrow-datafusion/pull/8559) (comphead) +- refactor: `HashJoinStream` state machine [#8538](https://github.com/apache/arrow-datafusion/pull/8538) (korowa) +- Remove ListingTable and FileScanConfig Unbounded (#8540) [#8573](https://github.com/apache/arrow-datafusion/pull/8573) (tustvold) +- Update substrait requirement from 0.20.0 to 0.21.0 [#8574](https://github.com/apache/arrow-datafusion/pull/8574) (dependabot[bot]) +- [minor]: Fix rank calculation bug when empty order by is seen [#8567](https://github.com/apache/arrow-datafusion/pull/8567) (mustafasrepo) +- Add `LiteralGuarantee` on columns to extract conditions required for `PhysicalExpr` expressions to evaluate to true [#8437](https://github.com/apache/arrow-datafusion/pull/8437) (alamb) +- [MINOR]: Parametrize sort-preservation tests to exercise all situations (unbounded/bounded sources and flag behavior) [#8575](https://github.com/apache/arrow-datafusion/pull/8575) (mustafasrepo) +- Minor: Add some comments to scalar_udf example [#8576](https://github.com/apache/arrow-datafusion/pull/8576) (alamb) +- Move Coercion for MakeArray to `coerce_arguments_for_signature` and introduce another one for ArrayAppend [#8317](https://github.com/apache/arrow-datafusion/pull/8317) (jayzhan211) +- feat: support `LargeList` in `array_positions` [#8571](https://github.com/apache/arrow-datafusion/pull/8571) (Weijun-H) +- feat: support `LargeList` in `array_element` [#8570](https://github.com/apache/arrow-datafusion/pull/8570) (Weijun-H) +- Increase test coverage for unbounded and bounded cases [#8581](https://github.com/apache/arrow-datafusion/pull/8581) (mustafasrepo) +- Port tests in `parquet.rs` to sqllogictest [#8560](https://github.com/apache/arrow-datafusion/pull/8560) (hiltontj) +- Minor: avoid a copy in Expr::unalias [#8588](https://github.com/apache/arrow-datafusion/pull/8588) (alamb) +- Minor: support complex expr as the arg in the ApproxPercentileCont function [#8580](https://github.com/apache/arrow-datafusion/pull/8580) (liukun4515) +- Bugfix: Add functional dependency check and aggregate try_new schema [#8584](https://github.com/apache/arrow-datafusion/pull/8584) (mustafasrepo) +- Remove GroupByOrderMode [#8593](https://github.com/apache/arrow-datafusion/pull/8593) (ozankabak) +- Minor: replace` not-impl-err` in `array_expression` [#8589](https://github.com/apache/arrow-datafusion/pull/8589) (Weijun-H) +- Substrait insubquery [#8363](https://github.com/apache/arrow-datafusion/pull/8363) (tgujar) +- Minor: port last test from parquet.rs [#8587](https://github.com/apache/arrow-datafusion/pull/8587) (alamb) +- Minor: consolidate map sqllogictest tests [#8550](https://github.com/apache/arrow-datafusion/pull/8550) (alamb) +- feat: support `LargeList` in `array_dims` [#8592](https://github.com/apache/arrow-datafusion/pull/8592) (Weijun-H) +- Fix regression in regenerating protobuf source [#8603](https://github.com/apache/arrow-datafusion/pull/8603) (andygrove) +- Remove unbounded_input from FileSinkOptions [#8605](https://github.com/apache/arrow-datafusion/pull/8605) (devinjdangelo) +- Add `arrow_err!` macros, optional backtrace to ArrowError [#8586](https://github.com/apache/arrow-datafusion/pull/8586) (comphead) +- Add examples of DataFrame::write\* methods without S3 dependency [#8606](https://github.com/apache/arrow-datafusion/pull/8606) (devinjdangelo) +- Implement logical plan serde for CopyTo [#8618](https://github.com/apache/arrow-datafusion/pull/8618) (andygrove) +- Fix InListExpr to return the correct number of rows [#8601](https://github.com/apache/arrow-datafusion/pull/8601) (alamb) +- Remove ListingTable single_file option [#8604](https://github.com/apache/arrow-datafusion/pull/8604) (devinjdangelo) +- feat: support `LargeList` in `array_remove` [#8595](https://github.com/apache/arrow-datafusion/pull/8595) (Weijun-H) +- Rename `ParamValues::{LIST -> List,MAP -> Map}` [#8611](https://github.com/apache/arrow-datafusion/pull/8611) (kawadakk) +- Support binary temporal coercion for Date64 and Timestamp types [#8616](https://github.com/apache/arrow-datafusion/pull/8616) (Asura7969) +- Add new configuration item `listing_table_ignore_subdirectory` [#8565](https://github.com/apache/arrow-datafusion/pull/8565) (Asura7969) +- Optimize the parameter types of `ParamValues`'s methods [#8613](https://github.com/apache/arrow-datafusion/pull/8613) (kawadakk) +- Do not panic on zero placeholders in `ParamValues::get_placeholders_with_values` [#8615](https://github.com/apache/arrow-datafusion/pull/8615) (kawadakk) +- Fix #8507: Non-null sub-field on nullable struct-field has wrong nullity [#8623](https://github.com/apache/arrow-datafusion/pull/8623) (marvinlanhenke) +- Implement `contained` API in PruningPredicate [#8440](https://github.com/apache/arrow-datafusion/pull/8440) (alamb) +- Add partial serde support for ParquetWriterOptions [#8627](https://github.com/apache/arrow-datafusion/pull/8627) (andygrove) +- Minor: add arguments length check in `array_expressions` [#8622](https://github.com/apache/arrow-datafusion/pull/8622) (Weijun-H) +- Minor: improve dataframe functional dependency tests [#8630](https://github.com/apache/arrow-datafusion/pull/8630) (alamb) +- Improve regexp_match performance by avoiding cloning Regex [#8631](https://github.com/apache/arrow-datafusion/pull/8631) (viirya) +- Minor: improve `listing_table_ignore_subdirectory` config documentation [#8634](https://github.com/apache/arrow-datafusion/pull/8634) (alamb) +- Support Writing Arrow files [#8608](https://github.com/apache/arrow-datafusion/pull/8608) (devinjdangelo) +- Filter pushdown into cross join [#8626](https://github.com/apache/arrow-datafusion/pull/8626) (mustafasrepo) +- [MINOR] Remove duplicate test utility and move one utility function for better organization [#8652](https://github.com/apache/arrow-datafusion/pull/8652) (metesynnada) +- [MINOR]: Add new test for filter pushdown into cross join [#8648](https://github.com/apache/arrow-datafusion/pull/8648) (mustafasrepo) +- Rewrite bloom filters to use contains API [#8442](https://github.com/apache/arrow-datafusion/pull/8442) (alamb) +- Split equivalence code into smaller modules. [#8649](https://github.com/apache/arrow-datafusion/pull/8649) (tushushu) +- Move parquet_schema.rs from sql to parquet tests [#8644](https://github.com/apache/arrow-datafusion/pull/8644) (alamb) +- Fix group by aliased expression in LogicalPLanBuilder::aggregate [#8629](https://github.com/apache/arrow-datafusion/pull/8629) (alamb) +- Refactor `array_union` and `array_intersect` functions to one general function [#8516](https://github.com/apache/arrow-datafusion/pull/8516) (Weijun-H) +- Minor: avoid extra clone in datafusion-proto::physical_plan [#8650](https://github.com/apache/arrow-datafusion/pull/8650) (ongchi) +- Minor: name some constant values in arrow writer, parquet writer [#8642](https://github.com/apache/arrow-datafusion/pull/8642) (alamb) +- TreeNode Refactor Part 2 [#8653](https://github.com/apache/arrow-datafusion/pull/8653) (berkaysynnada) +- feat: support inlist in LiteralGurantee for pruning [#8654](https://github.com/apache/arrow-datafusion/pull/8654) (my-vegetable-has-exploded) +- Streaming CLI support [#8651](https://github.com/apache/arrow-datafusion/pull/8651) (berkaysynnada) +- Add serde support for CSV FileTypeWriterOptions [#8641](https://github.com/apache/arrow-datafusion/pull/8641) (andygrove) +- Add trait based ScalarUDF API [#8578](https://github.com/apache/arrow-datafusion/pull/8578) (alamb) +- Handle ordering of first last aggregation inside aggregator [#8662](https://github.com/apache/arrow-datafusion/pull/8662) (mustafasrepo) +- feat: support 'LargeList' in `array_pop_front` and `array_pop_back` [#8569](https://github.com/apache/arrow-datafusion/pull/8569) (Weijun-H) +- chore: rename ceresdb to apache horaedb [#8674](https://github.com/apache/arrow-datafusion/pull/8674) (tanruixiang) +- Minor: clean up code [#8671](https://github.com/apache/arrow-datafusion/pull/8671) (Weijun-H) +- fix: `Antarctica/Vostok` tz offset changed in chrono-tz 0.8.5 [#8677](https://github.com/apache/arrow-datafusion/pull/8677) (korowa) +- Make the BatchSerializer behind Arc to avoid unnecessary struct creation [#8666](https://github.com/apache/arrow-datafusion/pull/8666) (metesynnada) +- Implement serde for CSV and Parquet FileSinkExec [#8646](https://github.com/apache/arrow-datafusion/pull/8646) (andygrove) +- [pruning] Add shortcut when all units have been pruned [#8675](https://github.com/apache/arrow-datafusion/pull/8675) (Ted-Jiang) +- Change first/last implementation to prevent redundant comparisons when data is already sorted [#8678](https://github.com/apache/arrow-datafusion/pull/8678) (mustafasrepo) +- minor: remove useless conversion [#8684](https://github.com/apache/arrow-datafusion/pull/8684) (comphead) +- refactor: modified `JoinHashMap` build order for `HashJoinStream` [#8658](https://github.com/apache/arrow-datafusion/pull/8658) (korowa) +- Start setting up tpch planning benchmarks [#8665](https://github.com/apache/arrow-datafusion/pull/8665) (matthewmturner) +- Doc: Clarify When Limit is Pushed Down to TableProvider::Scan [#8686](https://github.com/apache/arrow-datafusion/pull/8686) (devinjdangelo) +- Closes #8502: Parallel NDJSON file reading [#8659](https://github.com/apache/arrow-datafusion/pull/8659) (marvinlanhenke) +- Improve `array_prepend` signature for null and empty array [#8625](https://github.com/apache/arrow-datafusion/pull/8625) (jayzhan211) +- Cleanup TreeNode implementations [#8672](https://github.com/apache/arrow-datafusion/pull/8672) (viirya) +- Update sqlparser requirement from 0.40.0 to 0.41.0 [#8647](https://github.com/apache/arrow-datafusion/pull/8647) (dependabot[bot]) +- Update scalar functions doc for extract/datepart [#8682](https://github.com/apache/arrow-datafusion/pull/8682) (Jefffrey) +- Remove DescribeTableStmt in parser in favour of existing functionality from sqlparser-rs [#8703](https://github.com/apache/arrow-datafusion/pull/8703) (Jefffrey) +- Simplify `NULL [NOT] IN (..)` expressions [#8691](https://github.com/apache/arrow-datafusion/pull/8691) (asimsedhain) +- Rename `expr::window_function::WindowFunction` to `WindowFunctionDefinition`, make structure consistent with ScalarFunction [#8382](https://github.com/apache/arrow-datafusion/pull/8382) (edmondop) +- Deprecate duplicate function `LogicalPlan::with_new_inputs` [#8707](https://github.com/apache/arrow-datafusion/pull/8707) (viirya) +- Minor: refactor bloom filter tests to reduce duplication [#8435](https://github.com/apache/arrow-datafusion/pull/8435) (alamb) +- Minor: clean up code based on `Clippy` [#8715](https://github.com/apache/arrow-datafusion/pull/8715) (Weijun-H) +- Minor: Unbounded Output of AnalyzeExec [#8717](https://github.com/apache/arrow-datafusion/pull/8717) (berkaysynnada) +- feat: support `LargeList` in `array_position` [#8714](https://github.com/apache/arrow-datafusion/pull/8714) (Weijun-H) +- feat: support `LargeList` in `array_ndims` [#8716](https://github.com/apache/arrow-datafusion/pull/8716) (Weijun-H) +- feat: remove filters with null constants [#8700](https://github.com/apache/arrow-datafusion/pull/8700) (asimsedhain) +- support `LargeList` in `array_prepend` and `array_append` [#8679](https://github.com/apache/arrow-datafusion/pull/8679) (Weijun-H) +- Support for `extract(epoch from date)` for Date32 and Date64 [#8695](https://github.com/apache/arrow-datafusion/pull/8695) (Jefffrey) +- Implement trait based API for defining WindowUDF [#8719](https://github.com/apache/arrow-datafusion/pull/8719) (guojidan) +- Minor: Introduce utils::hash for StructArray [#8552](https://github.com/apache/arrow-datafusion/pull/8552) (jayzhan211) +- [CI] Improve windows machine CI test time [#8730](https://github.com/apache/arrow-datafusion/pull/8730) (comphead) +- fix guarantees in allways_true of PruningPredicate [#8732](https://github.com/apache/arrow-datafusion/pull/8732) (my-vegetable-has-exploded) +- Minor: Avoid memory copy in construct window exprs [#8718](https://github.com/apache/arrow-datafusion/pull/8718) (Ted-Jiang) +- feat: support LargeList in array_repeat [#8725](https://github.com/apache/arrow-datafusion/pull/8725) (Weijun-H) +- Minor: Ctrl+C Termination in CLI [#8739](https://github.com/apache/arrow-datafusion/pull/8739) (berkaysynnada) +- Add support for functional dependency for ROW_NUMBER window function. [#8737](https://github.com/apache/arrow-datafusion/pull/8737) (mustafasrepo) +- Minor: reduce code duplication in PruningPredicate test [#8441](https://github.com/apache/arrow-datafusion/pull/8441) (alamb) +- feat: native types in `DistinctCountAccumulator` for primitive types [#8721](https://github.com/apache/arrow-datafusion/pull/8721) (korowa) +- [MINOR]: Add a test case for when target partition is 1, no hash repartition is added to the plan. [#8757](https://github.com/apache/arrow-datafusion/pull/8757) (mustafasrepo) +- Minor: Improve `PruningPredicate` docstrings [#8748](https://github.com/apache/arrow-datafusion/pull/8748) (alamb) +- feat: support `LargeList` in `cardinality` [#8726](https://github.com/apache/arrow-datafusion/pull/8726) (Weijun-H) +- Add reproducer for #8738 [#8750](https://github.com/apache/arrow-datafusion/pull/8750) (alamb) +- Minor: Use faster check for column name in schema merge [#8765](https://github.com/apache/arrow-datafusion/pull/8765) (matthewmturner) +- Minor: Add documentation about stream cancellation [#8747](https://github.com/apache/arrow-datafusion/pull/8747) (alamb) +- Move `repartition_file_scans` out of `enable_round_robin` check in `EnforceDistribution` rule [#8731](https://github.com/apache/arrow-datafusion/pull/8731) (viirya) +- Clean internal implementation of WindowUDF [#8746](https://github.com/apache/arrow-datafusion/pull/8746) (guojidan) +- feat: support `largelist` in `array_to_string` [#8729](https://github.com/apache/arrow-datafusion/pull/8729) (Weijun-H) +- [MINOR] CLI error handling on streaming use cases [#8761](https://github.com/apache/arrow-datafusion/pull/8761) (metesynnada) +- Convert Binary Operator `StringConcat` to Function for `array_concat`, `array_append` and `array_prepend` [#8636](https://github.com/apache/arrow-datafusion/pull/8636) (jayzhan211) +- Minor: Fix incorrect indices for hashing struct [#8775](https://github.com/apache/arrow-datafusion/pull/8775) (jayzhan211) +- Minor: Improve library docs to mention TreeNode, ExprSimplifier, PruningPredicate and cp_solver [#8749](https://github.com/apache/arrow-datafusion/pull/8749) (alamb) +- [MINOR] Add logo source files [#8762](https://github.com/apache/arrow-datafusion/pull/8762) (andygrove) +- Add Apache attribution to site footer [#8760](https://github.com/apache/arrow-datafusion/pull/8760) (alamb) +- ci: speed up win64 test [#8728](https://github.com/apache/arrow-datafusion/pull/8728) (Jefffrey) +- Add `schema_err!` error macros with optional backtrace [#8620](https://github.com/apache/arrow-datafusion/pull/8620) (comphead) +- Fix regression by reverting Materialize dictionaries in group keys [#8740](https://github.com/apache/arrow-datafusion/pull/8740) (alamb) +- fix: struct field don't push down to TableScan [#8774](https://github.com/apache/arrow-datafusion/pull/8774) (haohuaijin) +- Implement `ScalarUDF` in terms of `ScalarUDFImpl` trait [#8713](https://github.com/apache/arrow-datafusion/pull/8713) (alamb) +- Minor: Fix error messages in array expressions [#8781](https://github.com/apache/arrow-datafusion/pull/8781) (Weijun-H) +- Move tests from `expr.rs` to sqllogictests. Part1 [#8773](https://github.com/apache/arrow-datafusion/pull/8773) (comphead) +- Permit running `sqllogictest` as a rust test in IDEs (+ use clap for sqllogicttest parsing, accept (and ignore) rust test harness arguments) [#8288](https://github.com/apache/arrow-datafusion/pull/8288) (alamb) +- Minor: Use standard tree walk in Projection Pushdown [#8787](https://github.com/apache/arrow-datafusion/pull/8787) (alamb) +- Implement trait based API for define AggregateUDF [#8733](https://github.com/apache/arrow-datafusion/pull/8733) (guojidan) +- Minor: Improve `DataFusionError` documentation [#8792](https://github.com/apache/arrow-datafusion/pull/8792) (alamb) +- fix: failed to create ValuesExec with non-nullable schema [#8776](https://github.com/apache/arrow-datafusion/pull/8776) (jonahgao) +- Update substrait requirement from 0.21.0 to 0.22.1 [#8796](https://github.com/apache/arrow-datafusion/pull/8796) (dependabot[bot]) +- Bump follow-redirects from 1.15.3 to 1.15.4 in /datafusion/wasmtest/datafusion-wasm-app [#8798](https://github.com/apache/arrow-datafusion/pull/8798) (dependabot[bot]) +- Minor: array_pop_first should be array_pop_front in documentation [#8797](https://github.com/apache/arrow-datafusion/pull/8797) (ongchi) +- feat: Add bloom filter metric to ParquetExec [#8772](https://github.com/apache/arrow-datafusion/pull/8772) (my-vegetable-has-exploded) +- Add note on using larger row group size [#8745](https://github.com/apache/arrow-datafusion/pull/8745) (twitu) +- Change `ScalarValue::{List, LargeList, FixedSizedList}` to take specific types rather than `ArrayRef` [#8562](https://github.com/apache/arrow-datafusion/pull/8562) (rspears74) +- fix: fix markdown table in docs [#8812](https://github.com/apache/arrow-datafusion/pull/8812) (tshauck) +- docs: add sudo for install commands [#8804](https://github.com/apache/arrow-datafusion/pull/8804) (caicancai) +- Standardize `CompressionTypeVariant` encoding in protobuf [#8785](https://github.com/apache/arrow-datafusion/pull/8785) (tushushu) +- Make benefits_from_input_partitioning Default in SHJ [#8801](https://github.com/apache/arrow-datafusion/pull/8801) (metesynnada) +- refactor: standardize exec_from funcs arg order [#8809](https://github.com/apache/arrow-datafusion/pull/8809) (tshauck) +- [Minor] extract const and add doc and more tests for in_list pruning [#8815](https://github.com/apache/arrow-datafusion/pull/8815) (Ted-Jiang) +- [MINOR]: Add size check for aggregate [#8813](https://github.com/apache/arrow-datafusion/pull/8813) (mustafasrepo) +- Minor: chores: Update clippy in pre-commit.sh [#8810](https://github.com/apache/arrow-datafusion/pull/8810) (my-vegetable-has-exploded) +- Cleanup the usage of round-robin repartitioning [#8794](https://github.com/apache/arrow-datafusion/pull/8794) (viirya) +- Implement monotonicity for ScalarUDF [#8799](https://github.com/apache/arrow-datafusion/pull/8799) (guojidan) +- Remove unused array_expression.rs and `SUPPORTED_ARRAY_TYPES` [#8807](https://github.com/apache/arrow-datafusion/pull/8807) (alamb) +- feat: support `array_resize` [#8744](https://github.com/apache/arrow-datafusion/pull/8744) (Weijun-H) +- Minor: typo in `arrays.slt` [#8831](https://github.com/apache/arrow-datafusion/pull/8831) (Weijun-H) +- docs: document SessionConfig [#8771](https://github.com/apache/arrow-datafusion/pull/8771) (wjones127) +- Minor: Improve `datafusion-proto` documentation [#8822](https://github.com/apache/arrow-datafusion/pull/8822) (alamb) +- [CI] Refactor CI builders [#8826](https://github.com/apache/arrow-datafusion/pull/8826) (comphead) +- Serialize function signature simplifications [#8802](https://github.com/apache/arrow-datafusion/pull/8802) (metesynnada) +- Port tests in `group_by.rs` to sqllogictest [#8834](https://github.com/apache/arrow-datafusion/pull/8834) (hiltontj) +- Simplify physical expression creation API (not require schema) [#8823](https://github.com/apache/arrow-datafusion/pull/8823) (comphead) +- feat: add more components to the wasm-pack compatible list [#8843](https://github.com/apache/arrow-datafusion/pull/8843) (waynexia) +- Port tests in timestamp.rs to sqllogictest. Part 1 [#8818](https://github.com/apache/arrow-datafusion/pull/8818) (caicancai) +- Upgrade to object_store `0.9.0` and arrow `50.0.0` [#8758](https://github.com/apache/arrow-datafusion/pull/8758) (tustvold) +- Fix ApproxPercentileCont signature [#8825](https://github.com/apache/arrow-datafusion/pull/8825) (joroKr21) +- Minor: Update `with_column_rename` method doc [#8858](https://github.com/apache/arrow-datafusion/pull/8858) (comphead) +- Minor: Document `parquet_metadata` function [#8852](https://github.com/apache/arrow-datafusion/pull/8852) (alamb) +- Speedup new_with_metadata by removing sort [#8855](https://github.com/apache/arrow-datafusion/pull/8855) (simonvandel) +- Minor: fix wrong function call [#8847](https://github.com/apache/arrow-datafusion/pull/8847) (Weijun-H) +- Add options of parquet bloom filter and page index in Session config [#8869](https://github.com/apache/arrow-datafusion/pull/8869) (Ted-Jiang) +- Port tests in timestamp.rs to sqllogictest [#8859](https://github.com/apache/arrow-datafusion/pull/8859) (caicancai) +- test: Port `order.rs` tests to sqllogictest [#8857](https://github.com/apache/arrow-datafusion/pull/8857) (simicd) +- Determine causal window frames to produce early results. [#8842](https://github.com/apache/arrow-datafusion/pull/8842) (mustafasrepo) +- docs: fix wrong pushdown name & a typo [#8875](https://github.com/apache/arrow-datafusion/pull/8875) (SteveLauC) +- fix: don't extract common sub expr in `CASE WHEN` clause [#8833](https://github.com/apache/arrow-datafusion/pull/8833) (haohuaijin) +- Add "Extended" clickbench queries [#8861](https://github.com/apache/arrow-datafusion/pull/8861) (alamb) +- Change cli to propagate error to exit code [#8856](https://github.com/apache/arrow-datafusion/pull/8856) (tshauck) +- test: Port tests in `predicates.rs` to sqllogictest [#8879](https://github.com/apache/arrow-datafusion/pull/8879) (simicd) +- docs: Update contributor guide with installation instructions [#8876](https://github.com/apache/arrow-datafusion/pull/8876) (caicancai) +- Minor: add tests for casts between nested `List` and `LargeList` [#8882](https://github.com/apache/arrow-datafusion/pull/8882) (Weijun-H) +- Disable Parallel Parquet Writer by Default, Improve Writing Test Coverage [#8854](https://github.com/apache/arrow-datafusion/pull/8854) (devinjdangelo) +- Support for order sensitive `NTH_VALUE` aggregation, make reverse `ARRAY_AGG` more efficient [#8841](https://github.com/apache/arrow-datafusion/pull/8841) (mustafasrepo) +- test: Port tests in `csv_files.rs` to sqllogictest [#8885](https://github.com/apache/arrow-datafusion/pull/8885) (simicd) +- test: Port tests in `references.rs` to sqllogictest [#8877](https://github.com/apache/arrow-datafusion/pull/8877) (simicd) +- fix bug with `to_timestamp` and `InitCap` logical serialization, add roundtrip test between expression and proto, [#8868](https://github.com/apache/arrow-datafusion/pull/8868) (Weijun-H) +- Support `LargeListArray` scalar values and `align_array_dimensions` [#8881](https://github.com/apache/arrow-datafusion/pull/8881) (Weijun-H) +- refactor: rename FileStream.file_reader to file_opener & update doc [#8883](https://github.com/apache/arrow-datafusion/pull/8883) (SteveLauC) +- docs: fix wrong name in sub-crates' README [#8889](https://github.com/apache/arrow-datafusion/pull/8889) (SteveLauC) +- Recursive CTEs: Stage 1 - add config flag [#8828](https://github.com/apache/arrow-datafusion/pull/8828) (matthewgapp) +- Support array literal with scalar function [#8884](https://github.com/apache/arrow-datafusion/pull/8884) (jayzhan211) +- Bump actions/cache from 3 to 4 [#8903](https://github.com/apache/arrow-datafusion/pull/8903) (dependabot[bot]) +- Fix `datafusion-cli` print output [#8895](https://github.com/apache/arrow-datafusion/pull/8895) (alamb) +- docs: add an example for RecordBatchReceiverStreamBuilder [#8888](https://github.com/apache/arrow-datafusion/pull/8888) (SteveLauC) +- Fix "Projection references non-aggregate values" by updating `rebase_expr` to use `transform_down` [#8890](https://github.com/apache/arrow-datafusion/pull/8890) (wizardxz) +- Add serde support for Arrow FileTypeWriterOptions [#8850](https://github.com/apache/arrow-datafusion/pull/8850) (tushushu) +- Improve `datafusion-cli` print format tests [#8896](https://github.com/apache/arrow-datafusion/pull/8896) (alamb) +- Recursive CTEs: Stage 2 - add support for sql -> logical plan generation [#8839](https://github.com/apache/arrow-datafusion/pull/8839) (matthewgapp) +- Minor: remove null in `array-append` and `array-prepend` [#8901](https://github.com/apache/arrow-datafusion/pull/8901) (Weijun-H) +- Add support for FixedSizeList type in `arrow_cast`, hashing [#8344](https://github.com/apache/arrow-datafusion/pull/8344) (Weijun-H) +- aggregate_statistics should only optimize MIN/MAX when relation is not empty [#8914](https://github.com/apache/arrow-datafusion/pull/8914) (viirya) +- support to_timestamp with optional chrono formats [#8886](https://github.com/apache/arrow-datafusion/pull/8886) (Omega359) +- Minor: Document third argument of `date_bin` as optional and default value [#8912](https://github.com/apache/arrow-datafusion/pull/8912) (alamb) +- Minor: distinguish parquet row group pruning type in unit test [#8921](https://github.com/apache/arrow-datafusion/pull/8921) (Ted-Jiang) diff --git a/dev/release/README.md b/dev/release/README.md index 53487678aa69..9cf241355e4d 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -19,7 +19,7 @@ # Release Process -DataFusion typically has major releases every two weeks, including breaking API changes. +DataFusion typically has major releases around once per month, including breaking API changes. Patch releases are made on an adhoc basis, but we try and avoid them given the frequent major releases. diff --git a/docs/Cargo.toml b/docs/Cargo.toml index 813335e30f77..3a8c90cae085 100644 --- a/docs/Cargo.toml +++ b/docs/Cargo.toml @@ -29,4 +29,4 @@ authors = { workspace = true } rust-version = "1.70" [dependencies] -datafusion = { path = "../datafusion/core", version = "34.0.0", default-features = false } +datafusion = { path = "../datafusion/core", version = "35.0.0", default-features = false } diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 7a7460799b1a..9d914aaaf15f 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -64,7 +64,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_row_group_size | 1048576 | Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 34.0.0 | Sets "created by" property | +| datafusion.execution.parquet.created_by | datafusion version 35.0.0 | Sets "created by" property | | datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index truncate length | | datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | | datafusion.execution.parquet.encoding | NULL | Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | From 0116e2a9b4a3ed4491802e19195769b96b7a971a Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Sun, 21 Jan 2024 20:30:30 +0800 Subject: [PATCH 32/39] feat: support `stride` in `array_slice`, change indexes to be`1` based (#8829) * support array slice * fix argument * fix typo * support from and to is negative * fix conflict * modify user doc * refactor code * fix clippy * add 1-index test --------- Co-authored-by: Andrew Lamb --- datafusion/expr/src/built_in_function.rs | 5 +- datafusion/expr/src/expr_fn.rs | 2 +- .../physical-expr/src/array_expressions.rs | 93 +++++++++++++++---- .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/sqllogictest/test_files/array.slt | 55 +++++++---- docs/source/user-guide/expressions.md | 71 +++++++------- 6 files changed, 155 insertions(+), 72 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index b54cd68164c1..81c8f67cc67b 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -977,7 +977,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => { Signature::any(3, self.volatility()) } - BuiltinScalarFunction::ArraySlice => Signature::any(3, self.volatility()), + BuiltinScalarFunction::ArraySlice => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayToString => { Signature::variadic_any(self.volatility()) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index ae534f4bb44b..1d45fa4facd0 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -731,7 +731,7 @@ scalar_expr!( scalar_expr!( ArraySlice, array_slice, - array offset length, + array begin end stride, "returns a slice of the array." ); scalar_expr!( diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index b9cdcff20659..a3dec2762c10 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use arrow::array::*; use arrow::buffer::OffsetBuffer; -use arrow::compute; +use arrow::compute::{self}; use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow::row::{RowConverter, SortField}; use arrow_buffer::{ArrowNativeType, NullBuffer}; @@ -575,23 +575,31 @@ pub fn array_except(args: &[ArrayRef]) -> Result { /// /// See test cases in `array.slt` for more details. pub fn array_slice(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_slice needs three arguments"); + let args_len = args.len(); + if args_len != 3 && args_len != 4 { + return exec_err!("array_slice needs three or four arguments"); } + let stride = if args_len == 4 { + Some(as_int64_array(&args[3])?) + } else { + None + }; + + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + let array_data_type = args[0].data_type(); match array_data_type { DataType::List(_) => { let array = as_list_array(&args[0])?; - let from_array = as_int64_array(&args[1])?; - let to_array = as_int64_array(&args[2])?; - general_array_slice::(array, from_array, to_array) + general_array_slice::(array, from_array, to_array, stride) } DataType::LargeList(_) => { let array = as_large_list_array(&args[0])?; let from_array = as_int64_array(&args[1])?; let to_array = as_int64_array(&args[2])?; - general_array_slice::(array, from_array, to_array) + general_array_slice::(array, from_array, to_array, stride) } _ => exec_err!("array_slice does not support type: {:?}", array_data_type), } @@ -601,6 +609,7 @@ fn general_array_slice( array: &GenericListArray, from_array: &Int64Array, to_array: &Int64Array, + stride: Option<&Int64Array>, ) -> Result where i64: TryInto, @@ -652,7 +661,7 @@ where let adjusted_zero_index = if index < 0 { // array_slice in duckdb with negative to_index is python-like, so index itself is exclusive if let Ok(index) = index.try_into() { - index + len - O::usize_as(1) + index + len } else { return exec_err!("array_slice got invalid index: {}", index); } @@ -700,17 +709,67 @@ where }; if let (Some(from), Some(to)) = (from_index, to_index) { + let stride = stride.map(|s| s.value(row_index)); + // array_slice with stride in duckdb, return empty array if stride is not supported and from > to. + if stride.is_none() && from > to { + // return empty array + offsets.push(offsets[row_index]); + continue; + } + let stride = stride.unwrap_or(1); + if stride.is_zero() { + return exec_err!( + "array_slice got invalid stride: {:?}, it cannot be 0", + stride + ); + } else if from <= to && stride.is_negative() { + // return empty array + offsets.push(offsets[row_index]); + continue; + } + + let stride: O = stride.try_into().map_err(|_| { + internal_datafusion_err!("array_slice got invalid stride: {}", stride) + })?; + if from <= to { assert!(start + to <= end); - mutable.extend( - 0, - (start + from).to_usize().unwrap(), - (start + to + O::usize_as(1)).to_usize().unwrap(), - ); - offsets.push(offsets[row_index] + (to - from + O::usize_as(1))); + if stride.eq(&O::one()) { + // stride is default to 1 + mutable.extend( + 0, + (start + from).to_usize().unwrap(), + (start + to + O::usize_as(1)).to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (to - from + O::usize_as(1))); + continue; + } + let mut index = start + from; + let mut cnt = 0; + while index <= start + to { + mutable.extend( + 0, + index.to_usize().unwrap(), + index.to_usize().unwrap() + 1, + ); + index += stride; + cnt += 1; + } + offsets.push(offsets[row_index] + O::usize_as(cnt)); } else { + let mut index = start + from; + let mut cnt = 0; + while index >= start + to { + mutable.extend( + 0, + index.to_usize().unwrap(), + index.to_usize().unwrap() + 1, + ); + index += stride; + cnt += 1; + } // invalid range, return empty array - offsets.push(offsets[row_index]); + offsets.push(offsets[row_index] + O::usize_as(cnt)); } } else { // invalid range, return empty array @@ -741,7 +800,7 @@ where .map(|arr| arr.map_or(0, |arr| arr.len() as i64)) .collect::>(), ); - general_array_slice::(array, &from_array, &to_array) + general_array_slice::(array, &from_array, &to_array, None) } fn general_pop_back_list( @@ -757,7 +816,7 @@ where .map(|arr| arr.map_or(0, |arr| arr.len() as i64 - 1)) .collect::>(), ); - general_array_slice::(array, &from_array, &to_array) + general_array_slice::(array, &from_array, &to_array, None) } /// array_pop_front SQL function diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index aae19a15b89a..8db5ccdfd604 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1456,6 +1456,7 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, parse_expr(&args[2], registry)?, + parse_expr(&args[3], registry)?, )), ScalarFunction::ArrayToString => Ok(array_to_string( parse_expr(&args[0], registry)?, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index ee9168de6482..b7d92aec88e6 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1234,6 +1234,25 @@ select array_slice(make_array(1, 2, 3, 4, 5), 2, 4), array_slice(make_array('h', ---- [2, 3, 4] [h, e] +query ???? +select array_slice(make_array(1, 2, 3, 4, 5), 1, 5, 2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 5, 2), + array_slice(make_array(1, 2, 3, 4, 5), 0, 5, 2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 5, 2); +---- +[1, 3, 5] [h, l, o] [1, 3, 5] [h, l, o] + +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 1, 5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 5, -1); +---- +[] [] + +query error Execution error: array_slice got invalid stride: 0, it cannot be 0 +select array_slice(make_array(1, 2, 3, 4, 5), 1, 5, 0), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 5, 0); + +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 5, 1, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 5, 1, -2); +---- +[5, 3, 1] [o, l, h] + query ?? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2); ---- @@ -1342,12 +1361,12 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NU query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, -3); ---- -[1] [h, e] +[1, 2] [h, e, l] query ?? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, -3); ---- -[1] [h, e] +[1, 2] [h, e, l] # array_slice scalar function #13 (with negative number and NULL) query error @@ -1367,34 +1386,34 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NU query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -1); ---- -[2, 3, 4] [l, l] +[2, 3, 4, 5] [l, l, o] query ?? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -1); ---- -[2, 3, 4] [l, l] +[2, 3, 4, 5] [l, l, o] # array_slice scalar function #16 (with negative indexes; almost full array (only with negative indices cannot return full array)) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -5, -1); ---- -[1, 2, 3, 4] [h, e, l, l] +[1, 2, 3, 4, 5] [h, e, l, l, o] query ?? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -5, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -5, -1); ---- -[1, 2, 3, 4] [h, e, l, l] +[1, 2, 3, 4, 5] [h, e, l, l, o] # array_slice scalar function #17 (with negative indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -3); ---- -[] [] +[2] [l] query ?? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -3); ---- -[] [] +[2] [l] # array_slice scalar function #18 (with negative indexes; first index > second_index) query ?? @@ -1422,24 +1441,24 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -7 query ?? select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1), array_slice(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), -1, -1); ---- -[[1, 2, 3, 4, 5]] [] +[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] [[6, 7, 8]] query ?? select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), -2, -1), array_slice(arrow_cast(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), 'LargeList(List(Int64))'), -1, -1); ---- -[[1, 2, 3, 4, 5]] [] +[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] [[6, 7, 8]] # array_slice scalar function #21 (with first positive index and last negative index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, -3), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 2, -2); ---- -[2] [e, l] +[2, 3] [e, l, l] query ?? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, -3), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 2, -2); ---- -[2] [e, l] +[2, 3] [e, l, l] # array_slice scalar function #22 (with first negative index and last positive index) query ?? @@ -1468,7 +1487,7 @@ query ? select array_slice(column1, column2, column3) from slices; ---- [] -[12, 13, 14, 15, 16] +[12, 13, 14, 15, 16, 17] [] [] [] @@ -1479,7 +1498,7 @@ query ? select array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) from slices; ---- [] -[12, 13, 14, 15, 16] +[12, 13, 14, 15, 16, 17] [] [] [] @@ -1492,9 +1511,9 @@ query ??? select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(column1, 3, column3), array_slice(column1, column2, 5) from slices; ---- [1] [] [, 2, 3, 4, 5] -[] [13, 14, 15, 16] [12, 13, 14, 15] +[2] [13, 14, 15, 16, 17] [12, 13, 14, 15] [] [] [21, 22, 23, , 25] -[] [33] [] +[] [33, 34] [] [4, 5] [] [] [1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] [5] [, 54, 55, 56, 57, 58, 59, 60] [55] @@ -1503,9 +1522,9 @@ query ??? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), 3, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, 5) from slices; ---- [1] [] [, 2, 3, 4, 5] -[] [13, 14, 15, 16] [12, 13, 14, 15] +[2] [13, 14, 15, 16, 17] [12, 13, 14, 15] [] [] [21, 22, 23, , 25] -[] [33] [] +[] [33, 34] [] [4, 5] [] [] [1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] [5] [, 54, 55, 56, 57, 58, 59, 60] [55] diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 85322d9fa766..f01750e56ae0 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -207,41 +207,42 @@ Unlike to some databases the math functions in Datafusion works the same way as ## Array Expressions -| Syntax | Description | -| ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| array_append(array, element) | Appends an element to the end of an array. `array_append([1, 2, 3], 4) -> [1, 2, 3, 4]` | -| array_concat(array[, ..., array_n]) | Concatenates arrays. `array_concat([1, 2, 3], [4, 5, 6]) -> [1, 2, 3, 4, 5, 6]` | -| array_has(array, element) | Returns true if the array contains the element `array_has([1,2,3], 1) -> true` | -| array_has_all(array, sub-array) | Returns true if all elements of sub-array exist in array `array_has_all([1,2,3], [1,3]) -> true` | -| array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` | -| array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | -| array_distinct(array) | Returns distinct values from the array after removing duplicates. `array_distinct([1, 3, 2, 3, 1, 2, 4]) -> [1, 2, 3, 4]` | -| array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | -| flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` | -| array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | -| array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | -| array_pop_front(array) | Returns the array without the first element. `array_pop_front([1, 2, 3]) -> [2, 3]` | -| array_pop_back(array) | Returns the array without the last element. `array_pop_back([1, 2, 3]) -> [1, 2]` | -| array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | -| array_positions(array, element) | Searches for an element in the array, returns all occurrences. `array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` | -| array_prepend(array, element) | Prepends an element to the beginning of an array. `array_prepend(1, [2, 3, 4]) -> [1, 2, 3, 4]` | -| array_repeat(element, count) | Returns an array containing element `count` times. `array_repeat(1, 3) -> [1, 1, 1]` | -| array_remove(array, element) | Removes the first element from the array equal to the given value. `array_remove([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 2, 3, 2, 1, 4]` | -| array_remove_n(array, element, max) | Removes the first `max` elements from the array equal to the given value. `array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2) -> [1, 3, 2, 1, 4]` | -| array_remove_all(array, element) | Removes all elements from the array equal to the given value. `array_remove_all([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 3, 1, 4]` | -| array_replace(array, from, to) | Replaces the first occurrence of the specified element with another specified element. `array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 2, 3, 2, 1, 4]` | -| array_replace_n(array, from, to, max) | Replaces the first `max` occurrences of the specified element with another specified element. `array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2) -> [1, 5, 5, 3, 2, 1, 4]` | -| array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` | -| array_slice(array, index) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` | -| array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | -| array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | -| array_union(array1, array2) | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]` | -| array_except(array1, array2) | Returns an array of the elements that appear in the first array but not in the second. `array_except([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | -| array_resize(array, size, value) | Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set. `array_resize([1, 2, 3], 5, 0) -> [1, 2, 3, 4, 5, 6]` | -| cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | -| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | -| range(start [, stop, step]) | Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` | -| trim_array(array, n) | Deprecated | +| Syntax | Description | +| -------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| array_append(array, element) | Appends an element to the end of an array. `array_append([1, 2, 3], 4) -> [1, 2, 3, 4]` | +| array_concat(array[, ..., array_n]) | Concatenates arrays. `array_concat([1, 2, 3], [4, 5, 6]) -> [1, 2, 3, 4, 5, 6]` | +| array_has(array, element) | Returns true if the array contains the element `array_has([1,2,3], 1) -> true` | +| array_has_all(array, sub-array) | Returns true if all elements of sub-array exist in array `array_has_all([1,2,3], [1,3]) -> true` | +| array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` | +| array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | +| array_distinct(array) | Returns distinct values from the array after removing duplicates. `array_distinct([1, 3, 2, 3, 1, 2, 4]) -> [1, 2, 3, 4]` | +| array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | +| flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` | +| array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | +| array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | +| array_pop_front(array) | Returns the array without the first element. `array_pop_front([1, 2, 3]) -> [2, 3]` | +| array_pop_back(array) | Returns the array without the last element. `array_pop_back([1, 2, 3]) -> [1, 2]` | +| array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | +| array_positions(array, element) | Searches for an element in the array, returns all occurrences. `array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` | +| array_prepend(array, element) | Prepends an element to the beginning of an array. `array_prepend(1, [2, 3, 4]) -> [1, 2, 3, 4]` | +| array_repeat(element, count) | Returns an array containing element `count` times. `array_repeat(1, 3) -> [1, 1, 1]` | +| array_remove(array, element) | Removes the first element from the array equal to the given value. `array_remove([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 2, 3, 2, 1, 4]` | +| array_remove_n(array, element, max) | Removes the first `max` elements from the array equal to the given value. `array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2) -> [1, 3, 2, 1, 4]` | +| array_remove_all(array, element) | Removes all elements from the array equal to the given value. `array_remove_all([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 3, 1, 4]` | +| array_replace(array, from, to) | Replaces the first occurrence of the specified element with another specified element. `array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 2, 3, 2, 1, 4]` | +| array_replace_n(array, from, to, max) | Replaces the first `max` occurrences of the specified element with another specified element. `array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2) -> [1, 5, 5, 3, 2, 1, 4]` | +| array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` | +| array_slice(array, begin,end) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` | +| array_slice(array, begin, end, stride) | Returns a slice of the array with added stride feature. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6, 2) -> [3, 5, 6]` | +| array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | +| array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | +| array_union(array1, array2) | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]` | +| array_except(array1, array2) | Returns an array of the elements that appear in the first array but not in the second. `array_except([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | +| array_resize(array, size, value) | Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set. `array_resize([1, 2, 3], 5, 0) -> [1, 2, 3, 4, 5, 6]` | +| cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | +| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | +| range(start [, stop, step]) | Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` | +| trim_array(array, n) | Deprecated | ## Regular Expressions From 97441cca553304d9a2d939ee744425de2efb54fc Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Mon, 22 Jan 2024 10:46:25 +0800 Subject: [PATCH 33/39] fix: recursive initialize method (#8937) Signed-off-by: Ruihang Xia --- datafusion/execution/src/object_store.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/execution/src/object_store.rs b/datafusion/execution/src/object_store.rs index 7626f8bef162..c0c58a87dcc6 100644 --- a/datafusion/execution/src/object_store.rs +++ b/datafusion/execution/src/object_store.rs @@ -180,7 +180,8 @@ impl DefaultObjectStoreRegistry { /// Default without any backend registered. #[cfg(target_arch = "wasm32")] pub fn new() -> Self { - Self::default() + let object_stores: DashMap> = DashMap::new(); + Self { object_stores } } } From c0a69a7124b27a74bc26f5b9afb8c6ac8980b240 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Mon, 22 Jan 2024 10:20:05 +0300 Subject: [PATCH 34/39] Fix expr partial ord test (#8908) * Fix expr partial ord test * Add comment * Resolve linter error --- datafusion/expr/src/expr.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 40d40692e593..9aeebb190e81 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1869,10 +1869,14 @@ mod test { let exp2 = col("a") + lit(2); let exp3 = !(col("a") + lit(2)); - assert!(exp1 < exp2); - assert!(exp2 > exp1); - assert!(exp2 > exp3); - assert!(exp3 < exp2); + // Since comparisons are done using hash value of the expression + // expr < expr2 may return false, or true. There is no guaranteed result. + // The only guarantee is "<" operator should have the opposite result of ">=" operator + let greater_or_equal = exp1 >= exp2; + assert_eq!(exp1 < exp2, !greater_or_equal); + + let greater_or_equal = exp3 >= exp2; + assert_eq!(exp3 < exp2, !greater_or_equal); } #[test] From 2b218be67a6c412629530b812836a6cec76efc32 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 22 Jan 2024 00:34:43 -0800 Subject: [PATCH 35/39] Simplify windows builtin functions return type (#8920) * Simplify windows builtin functions * add field comments --- datafusion/core/src/physical_planner.rs | 25 ++++--- .../core/tests/fuzz_cases/window_fuzz.rs | 41 +++++++++++- .../expr/src/built_in_window_function.rs | 4 +- .../physical-expr/src/window/cume_dist.rs | 14 ++-- .../physical-expr/src/window/lead_lag.rs | 1 + .../physical-expr/src/window/nth_value.rs | 1 + datafusion/physical-expr/src/window/ntile.rs | 13 ++-- datafusion/physical-expr/src/window/rank.rs | 23 +++---- .../physical-expr/src/window/row_number.rs | 16 +++-- datafusion/physical-plan/src/windows/mod.rs | 42 +++++++----- .../tests/cases/roundtrip_physical_plan.rs | 3 +- datafusion/sqllogictest/test_files/window.slt | 66 +++++++++++++++++++ 12 files changed, 186 insertions(+), 63 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index bc448fe06fcf..ed92688559fb 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -86,6 +86,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; +use datafusion_expr::utils::exprlist_to_fields; use datafusion_expr::{ DescribeTable, DmlStatement, RecursiveQuery, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, @@ -719,14 +720,16 @@ impl DefaultPhysicalPlanner { } let logical_input_schema = input.schema(); - let physical_input_schema = input_exec.schema(); + // Extend the schema to include window expression fields as builtin window functions derives its datatype from incoming schema + let mut window_fields = logical_input_schema.fields().clone(); + window_fields.extend_from_slice(&exprlist_to_fields(window_expr.iter(), input)?); + let extended_schema = &DFSchema::new_with_metadata(window_fields, HashMap::new())?; let window_expr = window_expr .iter() .map(|e| { create_window_expr( e, - logical_input_schema, - &physical_input_schema, + extended_schema, session_state.execution_props(), ) }) @@ -1529,7 +1532,7 @@ fn get_physical_expr_pair( /// queries like: /// OVER (ORDER BY a RANGES BETWEEN 3 PRECEDING AND 5 PRECEDING) /// OVER (ORDER BY a RANGES BETWEEN INTERVAL '3 DAY' PRECEDING AND '5 DAY' PRECEDING) are rejected -pub fn is_window_valid(window_frame: &WindowFrame) -> bool { +pub fn is_window_frame_bound_valid(window_frame: &WindowFrame) -> bool { match (&window_frame.start_bound, &window_frame.end_bound) { (WindowFrameBound::Following(_), WindowFrameBound::Preceding(_)) | (WindowFrameBound::Following(_), WindowFrameBound::CurrentRow) @@ -1549,10 +1552,10 @@ pub fn create_window_expr_with_name( e: &Expr, name: impl Into, logical_input_schema: &DFSchema, - physical_input_schema: &Schema, execution_props: &ExecutionProps, ) -> Result> { let name = name.into(); + let physical_input_schema: &Schema = &logical_input_schema.into(); match e { Expr::WindowFunction(WindowFunction { fun, @@ -1575,7 +1578,8 @@ pub fn create_window_expr_with_name( create_physical_sort_expr(e, logical_input_schema, execution_props) }) .collect::>>()?; - if !is_window_valid(window_frame) { + + if !is_window_frame_bound_valid(window_frame) { return plan_err!( "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", window_frame.start_bound, window_frame.end_bound @@ -1601,7 +1605,6 @@ pub fn create_window_expr_with_name( pub fn create_window_expr( e: &Expr, logical_input_schema: &DFSchema, - physical_input_schema: &Schema, execution_props: &ExecutionProps, ) -> Result> { // unpack aliased logical expressions, e.g. "sum(col) over () as total" @@ -1609,13 +1612,7 @@ pub fn create_window_expr( Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()), _ => (e.display_name()?, e), }; - create_window_expr_with_name( - e, - name, - logical_input_schema, - physical_input_schema, - execution_props, - ) + create_window_expr_with_name(e, name, logical_input_schema, execution_props) } type AggregateExprWithOptionalArgs = ( diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 6e5c5f8eb95e..4c440d6a5bfd 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -22,6 +22,7 @@ use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; +use arrow_schema::{Field, Schema}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ @@ -37,6 +38,7 @@ use datafusion_expr::{ }; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use itertools::Itertools; use test_utils::add_empty_batches; use hashbrown::HashMap; @@ -482,7 +484,6 @@ async fn run_window_test( let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); let (window_fn, args, fn_name) = get_random_function(&schema, &mut rng, is_linear); - let window_frame = get_random_window_frame(&mut rng, is_linear); let mut orderby_exprs = vec![]; for column in &orderby_columns { @@ -532,6 +533,40 @@ async fn run_window_test( if is_linear { exec1 = Arc::new(SortExec::new(sort_keys.clone(), exec1)) as _; } + + // The schema needs to be enriched before the `create_window_expr` + // The reason for this is window expressions datatypes are derived from the schema + // The datafusion code enriches the schema on physical planner and this test copies the same behavior manually + // Also bunch of functions dont require input arguments thus just send an empty vec for such functions + let data_types = if [ + "row_number", + "rank", + "dense_rank", + "percent_rank", + "ntile", + "cume_dist", + ] + .contains(&fn_name.as_str()) + { + vec![] + } else { + args.iter() + .map(|e| e.clone().as_ref().data_type(&schema)) + .collect::>>()? + }; + let window_expr_return_type = window_fn.return_type(&data_types)?; + let mut window_fields = schema + .fields() + .iter() + .map(|f| f.as_ref().clone()) + .collect_vec(); + window_fields.extend_from_slice(&[Field::new( + &fn_name, + window_expr_return_type, + true, + )]); + let extended_schema = Arc::new(Schema::new(window_fields)); + let usual_window_exec = Arc::new( WindowAggExec::try_new( vec![create_window_expr( @@ -541,7 +576,7 @@ async fn run_window_test( &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), - schema.as_ref(), + &extended_schema, ) .unwrap()], exec1, @@ -563,7 +598,7 @@ async fn run_window_test( &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), - schema.as_ref(), + extended_schema.as_ref(), ) .unwrap()], exec2, diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs index a03e3d2d24a9..f4b1cd03db1f 100644 --- a/datafusion/expr/src/built_in_window_function.rs +++ b/datafusion/expr/src/built_in_window_function.rs @@ -133,11 +133,11 @@ impl BuiltInWindowFunction { match self { BuiltInWindowFunction::RowNumber | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), + | BuiltInWindowFunction::DenseRank + | BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { Ok(DataType::Float64) } - BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead | BuiltInWindowFunction::FirstValue diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs index edef77c51c31..9720187ea83d 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/physical-expr/src/window/cume_dist.rs @@ -34,11 +34,16 @@ use std::sync::Arc; #[derive(Debug)] pub struct CumeDist { name: String, + /// Output data type + data_type: DataType, } /// Create a cume_dist window function -pub fn cume_dist(name: String) -> CumeDist { - CumeDist { name } +pub fn cume_dist(name: String, data_type: &DataType) -> CumeDist { + CumeDist { + name, + data_type: data_type.clone(), + } } impl BuiltInWindowFunctionExpr for CumeDist { @@ -49,8 +54,7 @@ impl BuiltInWindowFunctionExpr for CumeDist { fn field(&self) -> Result { let nullable = false; - let data_type = DataType::Float64; - Ok(Field::new(self.name(), data_type, nullable)) + Ok(Field::new(self.name(), self.data_type.clone(), nullable)) } fn expressions(&self) -> Vec> { @@ -119,7 +123,7 @@ mod tests { #[test] #[allow(clippy::single_range_in_vec_init)] fn test_cume_dist() -> Result<()> { - let r = cume_dist("arr".into()); + let r = cume_dist("arr".into(), &DataType::Float64); let expected = vec![0.0; 0]; test_i32_result(&r, 0, vec![], expected)?; diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 7ee736ce9caa..054a4c13e6b6 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -35,6 +35,7 @@ use std::sync::Arc; #[derive(Debug)] pub struct WindowShift { name: String, + /// Output data type data_type: DataType, shift_offset: i64, expr: Arc, diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index b3c89122ebad..05909ab25a07 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -39,6 +39,7 @@ use datafusion_expr::PartitionEvaluator; pub struct NthValue { name: String, expr: Arc, + /// Output data type data_type: DataType, kind: NthValueKind, } diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index f5442e1b0fee..fb7a7ad84fb7 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -35,11 +35,17 @@ use std::sync::Arc; pub struct Ntile { name: String, n: u64, + /// Output data type + data_type: DataType, } impl Ntile { - pub fn new(name: String, n: u64) -> Self { - Self { name, n } + pub fn new(name: String, n: u64, data_type: &DataType) -> Self { + Self { + name, + n, + data_type: data_type.clone(), + } } pub fn get_n(&self) -> u64 { @@ -54,8 +60,7 @@ impl BuiltInWindowFunctionExpr for Ntile { fn field(&self) -> Result { let nullable = false; - let data_type = DataType::UInt64; - Ok(Field::new(self.name(), data_type, nullable)) + Ok(Field::new(self.name(), self.data_type.clone(), nullable)) } fn expressions(&self) -> Vec> { diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 86af5b322133..1f643f0280dc 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -41,6 +41,8 @@ use std::sync::Arc; pub struct Rank { name: String, rank_type: RankType, + /// Output data type + data_type: DataType, } impl Rank { @@ -58,26 +60,29 @@ pub enum RankType { } /// Create a rank window function -pub fn rank(name: String) -> Rank { +pub fn rank(name: String, data_type: &DataType) -> Rank { Rank { name, rank_type: RankType::Basic, + data_type: data_type.clone(), } } /// Create a dense rank window function -pub fn dense_rank(name: String) -> Rank { +pub fn dense_rank(name: String, data_type: &DataType) -> Rank { Rank { name, rank_type: RankType::Dense, + data_type: data_type.clone(), } } /// Create a percent rank window function -pub fn percent_rank(name: String) -> Rank { +pub fn percent_rank(name: String, data_type: &DataType) -> Rank { Rank { name, rank_type: RankType::Percent, + data_type: data_type.clone(), } } @@ -89,11 +94,7 @@ impl BuiltInWindowFunctionExpr for Rank { fn field(&self) -> Result { let nullable = false; - let data_type = match self.rank_type { - RankType::Basic | RankType::Dense => DataType::UInt64, - RankType::Percent => DataType::Float64, - }; - Ok(Field::new(self.name(), data_type, nullable)) + Ok(Field::new(self.name(), self.data_type.clone(), nullable)) } fn expressions(&self) -> Vec> { @@ -268,7 +269,7 @@ mod tests { #[test] fn test_dense_rank() -> Result<()> { - let r = dense_rank("arr".into()); + let r = dense_rank("arr".into(), &DataType::UInt64); test_without_rank(&r, vec![1; 8])?; test_with_rank(&r, vec![1, 1, 2, 3, 3, 3, 4, 5])?; Ok(()) @@ -276,7 +277,7 @@ mod tests { #[test] fn test_rank() -> Result<()> { - let r = rank("arr".into()); + let r = rank("arr".into(), &DataType::UInt64); test_without_rank(&r, vec![1; 8])?; test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?; Ok(()) @@ -285,7 +286,7 @@ mod tests { #[test] #[allow(clippy::single_range_in_vec_init)] fn test_percent_rank() -> Result<()> { - let r = percent_rank("arr".into()); + let r = percent_rank("arr".into(), &DataType::Float64); // empty case let expected = vec![0.0; 0]; diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index f5e2f65a656e..759f447ab0f8 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -36,12 +36,17 @@ use std::sync::Arc; #[derive(Debug)] pub struct RowNumber { name: String, + /// Output data type + data_type: DataType, } impl RowNumber { /// Create a new ROW_NUMBER function - pub fn new(name: impl Into) -> Self { - Self { name: name.into() } + pub fn new(name: impl Into, data_type: &DataType) -> Self { + Self { + name: name.into(), + data_type: data_type.clone(), + } } } @@ -53,8 +58,7 @@ impl BuiltInWindowFunctionExpr for RowNumber { fn field(&self) -> Result { let nullable = false; - let data_type = DataType::UInt64; - Ok(Field::new(self.name(), data_type, nullable)) + Ok(Field::new(self.name(), self.data_type.clone(), nullable)) } fn expressions(&self) -> Vec> { @@ -127,7 +131,7 @@ mod tests { ])); let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, true)]); let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; - let row_number = RowNumber::new("row_number".to_owned()); + let row_number = RowNumber::new("row_number".to_owned(), &DataType::UInt64); let values = row_number.evaluate_args(&batch)?; let result = row_number .create_evaluator()? @@ -145,7 +149,7 @@ mod tests { ])); let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]); let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; - let row_number = RowNumber::new("row_number".to_owned()); + let row_number = RowNumber::new("row_number".to_owned(), &DataType::UInt64); let values = row_number.evaluate_args(&batch)?; let result = row_number .create_evaluator()? diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index a85e5cc31c58..e55cc7fca7a6 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -160,12 +160,13 @@ fn create_built_in_window_expr( input_schema: &Schema, name: String, ) -> Result> { + let data_type = input_schema.field_with_name(&name)?.data_type(); Ok(match fun { - BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name)), - BuiltInWindowFunction::Rank => Arc::new(rank(name)), - BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name)), - BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name)), - BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name)), + BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, data_type)), + BuiltInWindowFunction::Rank => Arc::new(rank(name, data_type)), + BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, data_type)), + BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, data_type)), + BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, data_type)), BuiltInWindowFunction::Ntile => { let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { DataFusionError::Execution( @@ -179,32 +180,42 @@ fn create_built_in_window_expr( if n.is_unsigned() { let n: u64 = n.try_into()?; - Arc::new(Ntile::new(name, n)) + Arc::new(Ntile::new(name, n, data_type)) } else { let n: i64 = n.try_into()?; if n <= 0 { return exec_err!("NTILE requires a positive integer"); } - Arc::new(Ntile::new(name, n as u64)) + Arc::new(Ntile::new(name, n as u64, data_type)) } } BuiltInWindowFunction::Lag => { let arg = args[0].clone(); - let data_type = args[0].data_type(input_schema)?; let shift_offset = get_scalar_value_from_args(args, 1)? .map(|v| v.try_into()) .and_then(|v| v.ok()); let default_value = get_scalar_value_from_args(args, 2)?; - Arc::new(lag(name, data_type, arg, shift_offset, default_value)) + Arc::new(lag( + name, + data_type.clone(), + arg, + shift_offset, + default_value, + )) } BuiltInWindowFunction::Lead => { let arg = args[0].clone(); - let data_type = args[0].data_type(input_schema)?; let shift_offset = get_scalar_value_from_args(args, 1)? .map(|v| v.try_into()) .and_then(|v| v.ok()); let default_value = get_scalar_value_from_args(args, 2)?; - Arc::new(lead(name, data_type, arg, shift_offset, default_value)) + Arc::new(lead( + name, + data_type.clone(), + arg, + shift_offset, + default_value, + )) } BuiltInWindowFunction::NthValue => { let arg = args[0].clone(); @@ -214,18 +225,15 @@ fn create_built_in_window_expr( .try_into() .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; let n: u32 = n as u32; - let data_type = args[0].data_type(input_schema)?; - Arc::new(NthValue::nth(name, arg, data_type, n)?) + Arc::new(NthValue::nth(name, arg, data_type.clone(), n)?) } BuiltInWindowFunction::FirstValue => { let arg = args[0].clone(); - let data_type = args[0].data_type(input_schema)?; - Arc::new(NthValue::first(name, arg, data_type)) + Arc::new(NthValue::first(name, arg, data_type.clone())) } BuiltInWindowFunction::LastValue => { let arg = args[0].clone(); - let data_type = args[0].data_type(input_schema)?; - Arc::new(NthValue::last(name, arg, data_type)) + Arc::new(NthValue::last(name, arg, data_type.clone())) } }) } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 3a13dc887f0c..8e0f75ce7d11 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -253,7 +253,8 @@ fn roundtrip_nested_loop_join() -> Result<()> { fn roundtrip_window() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let field_c = Field::new("FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c])); let window_frame = WindowFrame::new_bounds( datafusion_expr::WindowFrameUnits::Range, diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index f8337e21d703..f6d8a1ce8fff 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3906,3 +3906,69 @@ ProjectionExec: expr=[sn@0 as sn, ts@1 as ts, currency@2 as currency, amount@3 a --BoundedWindowAggExec: wdw=[SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] ----SortExec: expr=[sn@0 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] + +# test ROW_NUMBER window function returns correct data_type +query T +select arrow_typeof(row_number() over ()) from (select 1 a) +---- +UInt64 + +# test RANK window function returns correct data_type +query T +select arrow_typeof(rank() over ()) from (select 1 a) +---- +UInt64 + +# test DENSE_RANK window function returns correct data_type +query T +select arrow_typeof(dense_rank() over ()) from (select 1 a) +---- +UInt64 + +# test PERCENT_RANK window function returns correct data_type +query T +select arrow_typeof(percent_rank() over ()) from (select 1 a) +---- +Float64 + +# test CUME_DIST window function returns correct data_type +query T +select arrow_typeof(cume_dist() over ()) from (select 1 a) +---- +Float64 + +# test NTILE window function returns correct data_type +query T +select arrow_typeof(ntile(1) over ()) from (select 1 a) +---- +UInt64 + +# test LAG window function returns correct data_type +query T +select arrow_typeof(lag(a) over ()) from (select 1 a) +---- +Int64 + +# test LEAD window function returns correct data_type +query T +select arrow_typeof(lead(a) over ()) from (select 1 a) +---- +Int64 + +# test FIRST_VALUE window function returns correct data_type +query T +select arrow_typeof(first_value(a) over ()) from (select 1 a) +---- +Int64 + +# test LAST_VALUE window function returns correct data_type +query T +select arrow_typeof(last_value(a) over ()) from (select 1 a) +---- +Int64 + +# test NTH_VALUE window function returns correct data_type +query T +select arrow_typeof(nth_value(a, 1) over ()) from (select 1 a) +---- +Int64 \ No newline at end of file From 38d5f75de45ae3a7e1602456da4f86e127ed319f Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Mon, 22 Jan 2024 13:47:08 -0500 Subject: [PATCH 36/39] Fix handling of nested leaf columns in parallel parquet writer (#8923) * fix handling of nested columns * lint * add suggested tests --- datafusion/common/src/config.rs | 2 +- .../src/datasource/file_format/parquet.rs | 21 ++++---- datafusion/sqllogictest/test_files/copy.slt | 54 ++++++++++++++++++- .../test_files/information_schema.slt | 4 +- .../test_files/repartition_scan.slt | 8 +-- docs/source/user-guide/configs.md | 2 +- 6 files changed, 71 insertions(+), 20 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index eb516f97a48f..0d773ddb2b4c 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -408,7 +408,7 @@ config_namespace! { /// parquet files by serializing them in parallel. Each column /// in each row group in each output file are serialized in parallel /// leveraging a maximum possible core count of n_files*n_row_groups*n_columns. - pub allow_single_file_parallelism: bool, default = false + pub allow_single_file_parallelism: bool, default = true /// By default parallel parquet writer is tuned for minimum /// memory usage in a streaming execution plan. You may see diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 9729bfa163af..fdf6277a5ed2 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -885,16 +885,17 @@ async fn send_arrays_to_col_writers( rb: &RecordBatch, schema: Arc, ) -> Result<()> { - for (tx, array, field) in col_array_channels - .iter() - .zip(rb.columns()) - .zip(schema.fields()) - .map(|((a, b), c)| (a, b, c)) - { + // Each leaf column has its own channel, increment next_channel for each leaf column sent. + let mut next_channel = 0; + for (array, field) in rb.columns().iter().zip(schema.fields()) { for c in compute_leaves(field, array)? { - tx.send(c).await.map_err(|_| { - DataFusionError::Internal("Unable to send array to writer!".into()) - })?; + col_array_channels[next_channel] + .send(c) + .await + .map_err(|_| { + DataFusionError::Internal("Unable to send array to writer!".into()) + })?; + next_channel += 1; } } @@ -902,7 +903,7 @@ async fn send_arrays_to_col_writers( } /// Spawns a tokio task which joins the parallel column writer tasks, -/// and finalizes the row group. +/// and finalizes the row group fn spawn_rg_join_and_finalize_task( column_writer_handles: Vec>>, rg_rows: usize, diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 9f5b7af41577..c9b3bdfa338b 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -64,6 +64,24 @@ select * from validate_parquet; 1 Foo 2 Bar +query ? +copy (values (struct(timestamp '2021-01-01 01:00:01', 1)), (struct(timestamp '2022-01-01 01:00:01', 2)), +(struct(timestamp '2023-01-03 01:00:01', 3)), (struct(timestamp '2024-01-01 01:00:01', 4))) +to 'test_files/scratch/copy/table_nested2' (format parquet, single_file_output false); +---- +4 + +statement ok +CREATE EXTERNAL TABLE validate_parquet_nested2 STORED AS PARQUET LOCATION 'test_files/scratch/copy/table_nested2/'; + +query ? +select * from validate_parquet_nested2; +---- +{c0: 2021-01-01T01:00:01, c1: 1} +{c0: 2022-01-01T01:00:01, c1: 2} +{c0: 2023-01-03T01:00:01, c1: 3} +{c0: 2024-01-01T01:00:01, c1: 4} + query ?? COPY (values (struct ('foo', (struct ('foo', make_array(struct('a',1), struct('b',2))))), make_array(timestamp '2023-01-01 01:00:01',timestamp '2023-01-01 01:00:01')), @@ -72,9 +90,9 @@ to 'test_files/scratch/copy/table_nested' (format parquet, single_file_output fa ---- 2 -# validate multiple parquet file output statement ok -CREATE EXTERNAL TABLE validate_parquet_nested STORED AS PARQUET LOCATION 'test_files/scratch/copy/table_nested/'; +CREATE EXTERNAL TABLE validate_parquet_nested STORED AS PARQUET +LOCATION 'test_files/scratch/copy/table_nested/'; query ?? select * from validate_parquet_nested; @@ -82,6 +100,38 @@ select * from validate_parquet_nested; {c0: foo, c1: {c0: foo, c1: [{c0: a, c1: 1}, {c0: b, c1: 2}]}} [2023-01-01T01:00:01, 2023-01-01T01:00:01] {c0: bar, c1: {c0: foo, c1: [{c0: aa, c1: 10}, {c0: bb, c1: 20}]}} [2024-01-01T01:00:01, 2024-01-01T01:00:01] +query ? +copy (values ([struct('foo', 1), struct('bar', 2)])) +to 'test_files/scratch/copy/array_of_struct/' +(format parquet, single_file_output false); +---- +1 + +statement ok +CREATE EXTERNAL TABLE validate_array_of_struct +STORED AS PARQUET LOCATION 'test_files/scratch/copy/array_of_struct/'; + +query ? +select * from validate_array_of_struct; +---- +[{c0: foo, c1: 1}, {c0: bar, c1: 2}] + +query ? +copy (values (struct('foo', [1,2,3], struct('bar', [2,3,4])))) +to 'test_files/scratch/copy/struct_with_array/' +(format parquet, single_file_output false); +---- +1 + +statement ok +CREATE EXTERNAL TABLE validate_struct_with_array +STORED AS PARQUET LOCATION 'test_files/scratch/copy/struct_with_array/'; + +query ? +select * from validate_struct_with_array; +---- +{c0: foo, c1: [1, 2, 3], c2: {c0: bar, c1: [2, 3, 4]}} + # Copy parquet with all supported statment overrides query IT diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 768292d3d4b4..43899f756735 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -155,7 +155,7 @@ datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 datafusion.execution.minimum_parallel_output_files 4 -datafusion.execution.parquet.allow_single_file_parallelism false +datafusion.execution.parquet.allow_single_file_parallelism true datafusion.execution.parquet.bloom_filter_enabled false datafusion.execution.parquet.bloom_filter_fpp NULL datafusion.execution.parquet.bloom_filter_ndv NULL @@ -232,7 +232,7 @@ datafusion.execution.listing_table_ignore_subdirectory true Should sub directori datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. -datafusion.execution.parquet.allow_single_file_parallelism false Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. +datafusion.execution.parquet.allow_single_file_parallelism true Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. datafusion.execution.parquet.bloom_filter_enabled false Sets if bloom filter is enabled for any column datafusion.execution.parquet.bloom_filter_fpp NULL Sets bloom filter false positive probability. If NULL, uses default parquet writer setting datafusion.execution.parquet.bloom_filter_ndv NULL Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 5ee0da2d33e8..4b8c8f2f084e 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -61,7 +61,7 @@ Filter: parquet_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..153], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:153..306], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:306..459], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:459..610]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] # disable round robin repartitioning statement ok @@ -77,7 +77,7 @@ Filter: parquet_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..153], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:153..306], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:306..459], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:459..610]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] # enable round robin repartitioning again statement ok @@ -102,7 +102,7 @@ SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --SortExec: expr=[column1@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=8192 ------FilterExec: column1@0 != 42 ---------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..303], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:303..601, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..308], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:308..610]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..205], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:205..405, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..210], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:210..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] ## Read the files as though they are ordered @@ -138,7 +138,7 @@ physical_plan SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --CoalesceBatchesExec: target_batch_size=8192 ----FilterExec: column1@0 != 42 -------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..300], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..305], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:305..610], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:300..601]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..207], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:207..414], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:202..405]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] # Cleanup statement ok diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 9d914aaaf15f..8b039102d4d7 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -71,7 +71,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.bloom_filter_enabled | false | Sets if bloom filter is enabled for any column | | datafusion.execution.parquet.bloom_filter_fpp | NULL | Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.bloom_filter_ndv | NULL | Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.allow_single_file_parallelism | false | Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | +| datafusion.execution.parquet.allow_single_file_parallelism | true | Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | | datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | | datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | | datafusion.execution.aggregate.scalar_update_factor | 10 | Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. | From f2e67019c3655d9dffb4769a19ab32c97a70bdce Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Mon, 22 Jan 2024 20:53:46 +0200 Subject: [PATCH 37/39] feat: emitting partial join results in `HashJoinStream` (#8020) * HashJoin partial batch emitting * batch splitting tests * stream state & extended tests * fmt & clippy warns fixed * Apply suggestions from code review Co-authored-by: Andrew Lamb * review comments * ported join limited output * comments & formatting * Reuse hashes buffer * Apply suggestions from code review Co-authored-by: Andrew Lamb * fixed metrics and updated comment * precalculate hashes & remove iterators * draft: preparing for review --------- Co-authored-by: Andrew Lamb --- datafusion/physical-plan/Cargo.toml | 1 + .../physical-plan/src/joins/hash_join.rs | 623 ++++++++++++------ .../src/joins/nested_loop_join.rs | 18 +- .../src/joins/symmetric_hash_join.rs | 131 +++- datafusion/physical-plan/src/joins/utils.rs | 238 +++++-- datafusion/physical-plan/src/lib.rs | 3 + datafusion/physical-plan/src/test/exec.rs | 6 +- 7 files changed, 736 insertions(+), 284 deletions(-) diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 357e036b6f39..1c638d9c184e 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -61,6 +61,7 @@ uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] rstest = { workspace = true } +rstest_reuse = "0.6.0" termtree = "0.4.1" tokio = { version = "1.28", features = [ "macros", diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 374a0ad50700..0c213f425785 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -26,7 +26,7 @@ use std::{any::Any, usize, vec}; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, calculate_join_output_ordering, get_final_indices_from_bit_map, - need_produce_result_in_final, JoinHashMap, JoinHashMapType, + need_produce_result_in_final, JoinHashMap, JoinHashMapOffset, JoinHashMapType, }; use crate::{ coalesce_partitions::CoalescePartitionsExec, @@ -61,7 +61,8 @@ use arrow::util::bit_util; use arrow_array::cast::downcast_array; use arrow_schema::ArrowError; use datafusion_common::{ - exec_err, internal_err, plan_err, DataFusionError, JoinSide, JoinType, Result, + internal_datafusion_err, internal_err, plan_err, DataFusionError, JoinSide, JoinType, + Result, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; @@ -644,6 +645,8 @@ impl ExecutionPlan for HashJoinExec { } }; + let batch_size = context.session_config().batch_size(); + let reservation = MemoryConsumer::new(format!("HashJoinStream[{partition}]")) .register(context.memory_pool()); @@ -665,6 +668,8 @@ impl ExecutionPlan for HashJoinExec { reservation, state: HashJoinStreamState::WaitBuildSide, build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), + batch_size, + hashes_buffer: vec![], })) } @@ -908,16 +913,10 @@ enum HashJoinStreamState { Completed, } -/// Container for HashJoinStreamState::ProcessProbeBatch related data -struct ProcessProbeBatchState { - /// Current probe-side batch - batch: RecordBatch, -} - impl HashJoinStreamState { /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum. /// Returns an error if state is not ProcessProbeBatchState. - fn try_as_process_probe_batch(&self) -> Result<&ProcessProbeBatchState> { + fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> { match self { HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"), @@ -925,6 +924,25 @@ impl HashJoinStreamState { } } +/// Container for HashJoinStreamState::ProcessProbeBatch related data +struct ProcessProbeBatchState { + /// Current probe-side batch + batch: RecordBatch, + /// Starting offset for JoinHashMap lookups + offset: JoinHashMapOffset, + /// Max joined probe-side index from current batch + joined_probe_idx: Option, +} + +impl ProcessProbeBatchState { + fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option) { + self.offset = offset; + if joined_probe_idx.is_some() { + self.joined_probe_idx = joined_probe_idx; + } + } +} + /// [`Stream`] for [`HashJoinExec`] that does the actual join. /// /// This stream: @@ -960,6 +978,10 @@ struct HashJoinStream { state: HashJoinStreamState, /// Build side build_side: BuildSide, + /// Maximum output batch size + batch_size: usize, + /// Scratch space for computing hashes + hashes_buffer: Vec, } impl RecordBatchStream for HashJoinStream { @@ -968,7 +990,10 @@ impl RecordBatchStream for HashJoinStream { } } -/// Returns build/probe indices satisfying the equality condition. +/// Executes lookups by hash against JoinHashMap and resolves potential +/// hash collisions. +/// Returns build/probe indices satisfying the equality condition, along with +/// (optional) starting point for next iteration. /// /// # Example /// @@ -1014,20 +1039,17 @@ impl RecordBatchStream for HashJoinStream { /// Probe indices: 3, 3, 4, 5 /// ``` #[allow(clippy::too_many_arguments)] -pub fn build_equal_condition_join_indices( - build_hashmap: &T, +fn lookup_join_hashmap( + build_hashmap: &JoinHashMap, build_input_buffer: &RecordBatch, probe_batch: &RecordBatch, build_on: &[Column], probe_on: &[Column], - random_state: &RandomState, null_equals_null: bool, - hashes_buffer: &mut Vec, - filter: Option<&JoinFilter>, - build_side: JoinSide, - deleted_offset: Option, - fifo_hashmap: bool, -) -> Result<(UInt64Array, UInt32Array)> { + hashes_buffer: &[u64], + limit: usize, + offset: JoinHashMapOffset, +) -> Result<(UInt64Array, UInt32Array, Option)> { let keys_values = probe_on .iter() .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())) @@ -1039,76 +1061,24 @@ pub fn build_equal_condition_join_indices( .into_array(build_input_buffer.num_rows()) }) .collect::>>()?; - hashes_buffer.clear(); - hashes_buffer.resize(probe_batch.num_rows(), 0); - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - // In case build-side input has not been inverted while JoinHashMap creation, the chained list algorithm - // will return build indices for each probe row in a reverse order as such: - // Build Indices: [5, 4, 3] - // Probe Indices: [1, 1, 1] - // - // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side. - // Let's consider probe rows [0,1] as an example: - // - // When the probe iteration sequence is reversed, the following pairings can be derived: - // - // For probe row 1: - // (5, 1) - // (4, 1) - // (3, 1) - // - // For probe row 0: - // (5, 0) - // (4, 0) - // (3, 0) - // - // After reversing both sets of indices, we obtain reversed indices: - // - // (3,0) - // (4,0) - // (5,0) - // (3,1) - // (4,1) - // (5,1) - // - // With this approach, the lexicographic order on both the probe side and the build side is preserved. - let (mut probe_indices, mut build_indices) = if fifo_hashmap { - build_hashmap.get_matched_indices(hash_values.iter().enumerate(), deleted_offset) - } else { - let (mut matched_probe, mut matched_build) = build_hashmap - .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset); - - matched_probe.as_slice_mut().reverse(); - matched_build.as_slice_mut().reverse(); - - (matched_probe, matched_build) - }; + let (mut probe_builder, mut build_builder, next_offset) = build_hashmap + .get_matched_indices_with_limit_offset(hashes_buffer, None, limit, offset); - let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None); - let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None); + let build_indices: UInt64Array = + PrimitiveArray::new(build_builder.finish().into(), None); + let probe_indices: UInt32Array = + PrimitiveArray::new(probe_builder.finish().into(), None); - let (left, right) = if let Some(filter) = filter { - // Filter the indices which satisfy the non-equal join condition, like `left.b1 = 10` - apply_join_filter_to_indices( - build_input_buffer, - probe_batch, - left, - right, - filter, - build_side, - )? - } else { - (left, right) - }; - - equal_rows_arr( - &left, - &right, + let (build_indices, probe_indices) = equal_rows_arr( + &build_indices, + &probe_indices, &build_join_values, &keys_values, null_equals_null, - ) + )?; + + Ok((build_indices, probe_indices, next_offset)) } // version of eq_dyn supporting equality on null arrays @@ -1253,9 +1223,25 @@ impl HashJoinStream { self.state = HashJoinStreamState::ExhaustedProbeSide; } Some(Ok(batch)) => { + // Precalculate hash values for fetched batch + let keys_values = self + .on_right + .iter() + .map(|c| c.evaluate(&batch)?.into_array(batch.num_rows())) + .collect::>>()?; + + self.hashes_buffer.clear(); + self.hashes_buffer.resize(batch.num_rows(), 0); + create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + self.state = HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { batch, + offset: (0, None), + joined_probe_idx: None, }); } Some(Err(err)) => return Poll::Ready(Err(err)), @@ -1270,70 +1256,108 @@ impl HashJoinStream { fn process_probe_batch( &mut self, ) -> Result>> { - let state = self.state.try_as_process_probe_batch()?; + let state = self.state.try_as_process_probe_batch_mut()?; let build_side = self.build_side.try_as_ready_mut()?; - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(state.batch.num_rows()); let timer = self.join_metrics.join_time.timer(); - let mut hashes_buffer = vec![]; - // get the matched two indices for the on condition - let left_right_indices = build_equal_condition_join_indices( + // get the matched by join keys indices + let (left_indices, right_indices, next_offset) = lookup_join_hashmap( build_side.left_data.hash_map(), build_side.left_data.batch(), &state.batch, &self.on_left, &self.on_right, - &self.random_state, self.null_equals_null, - &mut hashes_buffer, - self.filter.as_ref(), - JoinSide::Left, - None, - true, - ); + &self.hashes_buffer, + self.batch_size, + state.offset, + )?; - let result = match left_right_indices { - Ok((left_side, right_side)) => { - // set the left bitmap - // and only left, full, left semi, left anti need the left bitmap - if need_produce_result_in_final(self.join_type) { - left_side.iter().flatten().for_each(|x| { - build_side.visited_left_side.set_bit(x as usize, true); - }); - } + // apply join filter if exists + let (left_indices, right_indices) = if let Some(filter) = &self.filter { + apply_join_filter_to_indices( + build_side.left_data.batch(), + &state.batch, + left_indices, + right_indices, + filter, + JoinSide::Left, + )? + } else { + (left_indices, right_indices) + }; - // adjust the two side indices base on the join type - let (left_side, right_side) = adjust_indices_by_join_type( - left_side, - right_side, - state.batch.num_rows(), - self.join_type, - ); + // mark joined left-side indices as visited, if required by join type + if need_produce_result_in_final(self.join_type) { + left_indices.iter().flatten().for_each(|x| { + build_side.visited_left_side.set_bit(x as usize, true); + }); + } - let result = build_batch_from_indices( - &self.schema, - build_side.left_data.batch(), - &state.batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(state.batch.num_rows()); - result - } - Err(err) => { - exec_err!("Fail to build join indices in HashJoinExec, error:{err}") - } + // The goals of index alignment for different join types are: + // + // 1) Right & FullJoin -- to append all missing probe-side indices between + // previous (excluding) and current joined indices. + // 2) SemiJoin -- deduplicate probe indices in range between previous + // (excluding) and current joined indices. + // 3) AntiJoin -- return only missing indices in range between + // previous and current joined indices. + // Inclusion/exclusion of the indices themselves don't matter + // + // As a summary -- alignment range can be produced based only on + // joined (matched with filters applied) probe side indices, excluding starting one + // (left from previous iteration). + + // if any rows have been joined -- get last joined probe-side (right) row + // it's important that index counts as "joined" after hash collisions checks + // and join filters applied. + let last_joined_right_idx = match right_indices.len() { + 0 => None, + n => Some(right_indices.value(n - 1) as usize), }; + + // Calculate range and perform alignment. + // In case probe batch has been processed -- align all remaining rows. + let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1); + let index_alignment_range_end = if next_offset.is_none() { + state.batch.num_rows() + } else { + last_joined_right_idx.map_or(0, |v| v + 1) + }; + + let (left_indices, right_indices) = adjust_indices_by_join_type( + left_indices, + right_indices, + index_alignment_range_start..index_alignment_range_end, + self.join_type, + ); + + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &left_indices, + &right_indices, + &self.column_indices, + JoinSide::Left, + )?; + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(result.num_rows()); timer.done(); - self.state = HashJoinStreamState::FetchProbeBatch; + if next_offset.is_none() { + self.state = HashJoinStreamState::FetchProbeBatch; + } else { + state.advance( + next_offset + .ok_or_else(|| internal_datafusion_err!("unexpected None offset"))?, + last_joined_right_idx, + ) + }; - Ok(StatefulStreamResult::Ready(Some(result?))) + Ok(StatefulStreamResult::Ready(Some(result))) } /// Processes unmatched build-side rows for certain join types and produces output batch @@ -1399,15 +1423,15 @@ mod tests { use super::*; use crate::{ - common, expressions::Column, hash_utils::create_hashes, - joins::hash_join::build_equal_condition_join_indices, memory::MemoryExec, + common, expressions::Column, hash_utils::create_hashes, memory::MemoryExec, repartition::RepartitionExec, test::build_table_i32, test::exec::MockExec, }; use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue, + assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, + ScalarValue, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; @@ -1415,6 +1439,21 @@ mod tests { use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; use hashbrown::raw::RawTable; + use rstest::*; + use rstest_reuse::{self, *}; + + fn div_ceil(a: usize, b: usize) -> usize { + (a + b - 1) / b + } + + #[template] + #[rstest] + fn batch_sizes(#[values(8192, 10, 5, 2, 1)] batch_size: usize) {} + + fn prepare_task_ctx(batch_size: usize) -> Arc { + let session_config = SessionConfig::default().with_batch_size(batch_size); + Arc::new(TaskContext::default().with_session_config(session_config)) + } fn build_table( a: (&str, &Vec), @@ -1533,9 +1572,10 @@ mod tests { Ok((columns, batches)) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_inner_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_inner_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1580,9 +1620,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn partitioned_join_inner_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn partitioned_join_inner_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1703,9 +1744,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_inner_two() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_inner_two(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 2]), ("b2", &vec![1, 2, 2]), @@ -1732,7 +1774,13 @@ mod tests { assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); - assert_eq!(batches.len(), 1); + // expected joined records = 3 + // in case batch_size is 1 - additional empty batch for remaining 3-2 row + let mut expected_batch_count = div_ceil(3, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + assert_eq!(batches.len(), expected_batch_count); let expected = [ "+----+----+----+----+----+----+", @@ -1751,9 +1799,10 @@ mod tests { } /// Test where the left has 2 parts, the right with 1 part => 1 part + #[apply(batch_sizes)] #[tokio::test] - async fn join_inner_one_two_parts_left() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_inner_one_two_parts_left(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let batch1 = build_table_i32( ("a1", &vec![1, 2]), ("b2", &vec![1, 2]), @@ -1787,7 +1836,13 @@ mod tests { assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); - assert_eq!(batches.len(), 1); + // expected joined records = 3 + // in case batch_size is 1 - additional empty batch for remaining 3-2 row + let mut expected_batch_count = div_ceil(3, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + assert_eq!(batches.len(), expected_batch_count); let expected = [ "+----+----+----+----+----+----+", @@ -1856,9 +1911,10 @@ mod tests { } /// Test where the left has 1 part, the right has 2 parts => 2 parts + #[apply(batch_sizes)] #[tokio::test] - async fn join_inner_one_two_parts_right() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_inner_one_two_parts_right(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1890,7 +1946,14 @@ mod tests { // first part let stream = join.execute(0, task_ctx.clone())?; let batches = common::collect(stream).await?; - assert_eq!(batches.len(), 1); + + // expected joined records = 1 (first right batch) + // and additional empty batch for non-joined 20-6-80 + let mut expected_batch_count = div_ceil(1, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + assert_eq!(batches.len(), expected_batch_count); let expected = [ "+----+----+----+----+----+----+", @@ -1906,7 +1969,11 @@ mod tests { // second part let stream = join.execute(1, task_ctx.clone())?; let batches = common::collect(stream).await?; - assert_eq!(batches.len(), 1); + + // expected joined records = 2 (second right batch) + let expected_batch_count = div_ceil(2, batch_size); + assert_eq!(batches.len(), expected_batch_count); + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -1934,9 +2001,10 @@ mod tests { ) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_multi_batch() { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_multi_batch(batch_size: usize) { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1975,9 +2043,10 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); } + #[apply(batch_sizes)] #[tokio::test] - async fn join_full_multi_batch() { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_full_multi_batch(batch_size: usize) { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2019,9 +2088,10 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_empty_right() { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_empty_right(batch_size: usize) { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2055,9 +2125,10 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); } + #[apply(batch_sizes)] #[tokio::test] - async fn join_full_empty_right() { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_full_empty_right(batch_size: usize) { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2091,9 +2162,10 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2134,9 +2206,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn partitioned_join_left_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn partitioned_join_left_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2197,9 +2270,10 @@ mod tests { ) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_semi() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_semi(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left semi join right_table on left_table.b1 = right_table.b2 @@ -2231,9 +2305,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_semi_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_semi_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2317,9 +2392,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_semi() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_semi(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2353,9 +2429,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_semi_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_semi_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2442,9 +2519,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_anti() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_anti(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left anti join right_table on left_table.b1 = right_table.b2 @@ -2475,9 +2553,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_anti_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_anti_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2!=8 @@ -2568,9 +2647,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_anti() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_anti(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); let on = vec![( @@ -2601,9 +2681,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_anti_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_anti_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table right anti join right_table on left_table.b1 = right_table.b2 and left_table.a1!=13 @@ -2701,9 +2782,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2739,9 +2821,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn partitioned_join_right_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn partitioned_join_right_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2778,9 +2861,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_full_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_full_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2845,21 +2929,26 @@ mod tests { ("c", &vec![30, 40]), ); + // Join key column for both join sides + let key_column = Column::new("a", 0); + let join_hash_map = JoinHashMap::new(hashmap_left, next); - let (l, r) = build_equal_condition_join_indices( + let right_keys_values = + key_column.evaluate(&right)?.into_array(right.num_rows())?; + let mut hashes_buffer = vec![0; right.num_rows()]; + create_hashes(&[right_keys_values], &random_state, &mut hashes_buffer)?; + + let (l, r, _) = lookup_join_hashmap( &join_hash_map, &left, &right, - &[Column::new("a", 0)], - &[Column::new("a", 0)], - &random_state, - false, - &mut vec![0; right.num_rows()], - None, - JoinSide::Left, - None, + &[key_column.clone()], + &[key_column], false, + &hashes_buffer, + 8192, + (0, None), )?; let mut left_ids = UInt64Builder::with_capacity(0); @@ -2941,9 +3030,10 @@ mod tests { JoinFilter::new(filter_expression, column_indices, intermediate_schema) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_inner_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_inner_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -2981,9 +3071,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3024,9 +3115,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3066,9 +3158,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_full_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_full_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3211,6 +3304,140 @@ mod tests { } } + #[tokio::test] + async fn join_splitted_batch() { + let left = build_table( + ("a1", &vec![1, 2, 3, 4]), + ("b1", &vec![1, 1, 1, 1]), + ("c1", &vec![0, 0, 0, 0]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40, 50]), + ("b2", &vec![1, 1, 1, 1, 1]), + ("c2", &vec![0, 0, 0, 0, 0]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::RightSemi, + JoinType::RightAnti, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + let expected_resultset_records = 20; + let common_result = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 0 | 10 | 1 | 0 |", + "| 2 | 1 | 0 | 10 | 1 | 0 |", + "| 3 | 1 | 0 | 10 | 1 | 0 |", + "| 4 | 1 | 0 | 10 | 1 | 0 |", + "| 1 | 1 | 0 | 20 | 1 | 0 |", + "| 2 | 1 | 0 | 20 | 1 | 0 |", + "| 3 | 1 | 0 | 20 | 1 | 0 |", + "| 4 | 1 | 0 | 20 | 1 | 0 |", + "| 1 | 1 | 0 | 30 | 1 | 0 |", + "| 2 | 1 | 0 | 30 | 1 | 0 |", + "| 3 | 1 | 0 | 30 | 1 | 0 |", + "| 4 | 1 | 0 | 30 | 1 | 0 |", + "| 1 | 1 | 0 | 40 | 1 | 0 |", + "| 2 | 1 | 0 | 40 | 1 | 0 |", + "| 3 | 1 | 0 | 40 | 1 | 0 |", + "| 4 | 1 | 0 | 40 | 1 | 0 |", + "| 1 | 1 | 0 | 50 | 1 | 0 |", + "| 2 | 1 | 0 | 50 | 1 | 0 |", + "| 3 | 1 | 0 | 50 | 1 | 0 |", + "| 4 | 1 | 0 | 50 | 1 | 0 |", + "+----+----+----+----+----+----+", + ]; + let left_batch = [ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 1 | 0 |", + "| 2 | 1 | 0 |", + "| 3 | 1 | 0 |", + "| 4 | 1 | 0 |", + "+----+----+----+", + ]; + let right_batch = [ + "+----+----+----+", + "| a2 | b2 | c2 |", + "+----+----+----+", + "| 10 | 1 | 0 |", + "| 20 | 1 | 0 |", + "| 30 | 1 | 0 |", + "| 40 | 1 | 0 |", + "| 50 | 1 | 0 |", + "+----+----+----+", + ]; + let right_empty = [ + "+----+----+----+", + "| a2 | b2 | c2 |", + "+----+----+----+", + "+----+----+----+", + ]; + let left_empty = [ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "+----+----+----+", + ]; + + // validation of partial join results output for different batch_size setting + for join_type in join_types { + for batch_size in (1..21).rev() { + let task_ctx = prepare_task_ctx(batch_size); + + let join = + join(left.clone(), right.clone(), on.clone(), &join_type, false) + .unwrap(); + + let stream = join.execute(0, task_ctx).unwrap(); + let batches = common::collect(stream).await.unwrap(); + + // For inner/right join expected batch count equals dev_ceil result, + // as there is no need to append non-joined build side data. + // For other join types it'll be div_ceil + 1 -- for additional batch + // containing not visited build side rows (empty in this test case). + let expected_batch_count = match join_type { + JoinType::Inner + | JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti => { + (expected_resultset_records + batch_size - 1) / batch_size + } + _ => (expected_resultset_records + batch_size - 1) / batch_size + 1, + }; + assert_eq!( + batches.len(), + expected_batch_count, + "expected {} output batches for {} join with batch_size = {}", + expected_batch_count, + join_type, + batch_size + ); + + let expected = match join_type { + JoinType::RightSemi => right_batch.to_vec(), + JoinType::RightAnti => right_empty.to_vec(), + JoinType::LeftSemi => left_batch.to_vec(), + JoinType::LeftAnti => left_empty.to_vec(), + _ => common_result.to_vec(), + }; + assert_batches_eq!(expected, &batches); + } + } + } + #[tokio::test] async fn single_partition_join_overallocation() -> Result<()> { let left = build_table( diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 6951642ff801..f89a2445fd07 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -28,9 +28,9 @@ use crate::coalesce_batches::concat_batches; use crate::joins::utils::{ append_right_indices, apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, check_join_is_valid, estimate_join_statistics, get_anti_indices, - get_anti_u64_indices, get_final_indices_from_bit_map, get_semi_indices, - get_semi_u64_indices, partitioned_join_output_partitioning, BuildProbeJoinMetrics, - ColumnIndex, JoinFilter, OnceAsync, OnceFut, + get_final_indices_from_bit_map, get_semi_indices, + partitioned_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, + OnceAsync, OnceFut, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ @@ -649,20 +649,20 @@ fn adjust_indices_by_join_type( // matched // unmatched left row will be produced in this batch let left_unmatched_indices = - get_anti_u64_indices(count_left_batch, &left_indices); + get_anti_indices(0..count_left_batch, &left_indices); // combine the matched and unmatched left result together append_left_indices(left_indices, right_indices, left_unmatched_indices) } JoinType::LeftSemi => { // need to remove the duplicated record in the left side - let left_indices = get_semi_u64_indices(count_left_batch, &left_indices); + let left_indices = get_semi_indices(0..count_left_batch, &left_indices); // the right_indices will not be used later for the `left semi` join (left_indices, right_indices) } JoinType::LeftAnti => { // need to remove the duplicated record in the left side // get the anti index for the left side - let left_indices = get_anti_u64_indices(count_left_batch, &left_indices); + let left_indices = get_anti_indices(0..count_left_batch, &left_indices); // the right_indices will not be used later for the `left anti` join (left_indices, right_indices) } @@ -671,20 +671,20 @@ fn adjust_indices_by_join_type( // matched // unmatched right row will be produced in this batch let right_unmatched_indices = - get_anti_indices(count_right_batch, &right_indices); + get_anti_indices(0..count_right_batch, &right_indices); // combine the matched and unmatched right result together append_right_indices(left_indices, right_indices, right_unmatched_indices) } JoinType::RightSemi => { // need to remove the duplicated record in the right side - let right_indices = get_semi_indices(count_right_batch, &right_indices); + let right_indices = get_semi_indices(0..count_right_batch, &right_indices); // the left_indices will not be used later for the `right semi` join (left_indices, right_indices) } JoinType::RightAnti => { // need to remove the duplicated record in the right side // get the anti index for the right side - let right_indices = get_anti_indices(count_right_batch, &right_indices); + let right_indices = get_anti_indices(0..count_right_batch, &right_indices); // the left_indices will not be used later for the `right anti` join (left_indices, right_indices) } diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 7719c72774d6..00950f082582 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -32,7 +32,7 @@ use std::task::Poll; use std::{usize, vec}; use crate::common::SharedMemoryReservation; -use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash}; +use crate::joins::hash_join::{equal_rows_arr, update_hash}; use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, @@ -41,22 +41,26 @@ use crate::joins::stream_join_utils::{ StreamJoinMetrics, }; use crate::joins::utils::{ - build_batch_from_indices, build_join_schema, check_join_is_valid, - partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn, - StatefulStreamResult, + apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, + check_join_is_valid, partitioned_join_output_partitioning, ColumnIndex, JoinFilter, + JoinHashMapType, JoinOn, StatefulStreamResult, }; use crate::{ expressions::{Column, PhysicalSortExpr}, joins::StreamJoinPartitionMode, metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, - Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, + Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use arrow::array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder}; +use arrow::array::{ + ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder, UInt32Array, + UInt64Array, +}; use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::bisect; use datafusion_common::{ internal_err, plan_err, DataFusionError, JoinSide, JoinType, Result, @@ -785,7 +789,7 @@ pub(crate) fn join_with_probe_batch( if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(None); } - let (build_indices, probe_indices) = build_equal_condition_join_indices( + let (build_indices, probe_indices) = lookup_join_hashmap( &build_hash_joiner.hashmap, &build_hash_joiner.input_buffer, probe_batch, @@ -794,11 +798,22 @@ pub(crate) fn join_with_probe_batch( random_state, null_equals_null, &mut build_hash_joiner.hashes_buffer, - filter, - build_hash_joiner.build_side, Some(build_hash_joiner.deleted_offset), - false, )?; + + let (build_indices, probe_indices) = if let Some(filter) = filter { + apply_join_filter_to_indices( + &build_hash_joiner.input_buffer, + probe_batch, + build_indices, + probe_indices, + filter, + build_hash_joiner.build_side, + )? + } else { + (build_indices, probe_indices) + }; + if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { record_visited_indices( &mut build_hash_joiner.visited_rows, @@ -835,6 +850,102 @@ pub(crate) fn join_with_probe_batch( } } +/// This method performs lookups against JoinHashMap by hash values of join-key columns, and handles potential +/// hash collisions. +/// +/// # Arguments +/// +/// * `build_hashmap` - hashmap collected from build side data. +/// * `build_batch` - Build side record batch. +/// * `probe_batch` - Probe side record batch. +/// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join. +/// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join. +/// * `random_state` - The random state for the join. +/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. +/// * `hashes_buffer` - Buffer used for probe side keys hash calculation. +/// * `deleted_offset` - deleted offset for build side data. +/// +/// # Returns +/// +/// A [Result] containing a tuple with two equal length arrays, representing indices of rows from build and probe side, +/// matched by join key columns. +#[allow(clippy::too_many_arguments)] +fn lookup_join_hashmap( + build_hashmap: &PruningJoinHashMap, + build_batch: &RecordBatch, + probe_batch: &RecordBatch, + build_on: &[Column], + probe_on: &[Column], + random_state: &RandomState, + null_equals_null: bool, + hashes_buffer: &mut Vec, + deleted_offset: Option, +) -> Result<(UInt64Array, UInt32Array)> { + let keys_values = probe_on + .iter() + .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())) + .collect::>>()?; + let build_join_values = build_on + .iter() + .map(|c| c.evaluate(build_batch)?.into_array(build_batch.num_rows())) + .collect::>>()?; + + hashes_buffer.clear(); + hashes_buffer.resize(probe_batch.num_rows(), 0); + let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; + + // As SymmetricHashJoin uses LIFO JoinHashMap, the chained list algorithm + // will return build indices for each probe row in a reverse order as such: + // Build Indices: [5, 4, 3] + // Probe Indices: [1, 1, 1] + // + // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side. + // Let's consider probe rows [0,1] as an example: + // + // When the probe iteration sequence is reversed, the following pairings can be derived: + // + // For probe row 1: + // (5, 1) + // (4, 1) + // (3, 1) + // + // For probe row 0: + // (5, 0) + // (4, 0) + // (3, 0) + // + // After reversing both sets of indices, we obtain reversed indices: + // + // (3,0) + // (4,0) + // (5,0) + // (3,1) + // (4,1) + // (5,1) + // + // With this approach, the lexicographic order on both the probe side and the build side is preserved. + let (mut matched_probe, mut matched_build) = build_hashmap + .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset); + + matched_probe.as_slice_mut().reverse(); + matched_build.as_slice_mut().reverse(); + + let build_indices: UInt64Array = + PrimitiveArray::new(matched_build.finish().into(), None); + let probe_indices: UInt32Array = + PrimitiveArray::new(matched_probe.finish().into(), None); + + let (build_indices, probe_indices) = equal_rows_arr( + &build_indices, + &probe_indices, + &build_join_values, + &keys_values, + null_equals_null, + )?; + + Ok((build_indices, probe_indices)) +} + pub struct OneSideHashJoiner { /// Build side build_side: JoinSide, diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 1e3cf5abb477..6ab08d3db022 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -20,7 +20,7 @@ use std::collections::HashSet; use std::fmt::{self, Debug}; use std::future::Future; -use std::ops::IndexMut; +use std::ops::{IndexMut, Range}; use std::sync::Arc; use std::task::{Context, Poll}; use std::usize; @@ -35,6 +35,8 @@ use arrow::array::{ use arrow::compute; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray}; +use arrow_buffer::ArrowNativeType; use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; use datafusion_common::{ @@ -136,6 +138,53 @@ impl JoinHashMap { } } +// Type of offsets for obtaining indices from JoinHashMap. +pub(crate) type JoinHashMapOffset = (usize, Option); + +// Macro for traversing chained values with limit. +// Early returns in case of reacing output tuples limit. +macro_rules! chain_traverse { + ( + $input_indices:ident, $match_indices:ident, $hash_values:ident, $next_chain:ident, + $input_idx:ident, $chain_idx:ident, $deleted_offset:ident, $remaining_output:ident + ) => { + let mut i = $chain_idx - 1; + loop { + let match_row_idx = if let Some(offset) = $deleted_offset { + // This arguments means that we prune the next index way before here. + if i < offset as u64 { + // End of the list due to pruning + break; + } + i - offset as u64 + } else { + i + }; + $match_indices.append(match_row_idx); + $input_indices.append($input_idx as u32); + $remaining_output -= 1; + // Follow the chain to get the next index value + let next = $next_chain[match_row_idx as usize]; + + if $remaining_output == 0 { + // In case current input index is the last, and no more chain values left + // returning None as whole input has been scanned + let next_offset = if $input_idx == $hash_values.len() - 1 && next == 0 { + None + } else { + Some(($input_idx, Some(next))) + }; + return ($input_indices, $match_indices, next_offset); + } + if next == 0 { + // end of list + break; + } + i = next - 1; + } + }; +} + // Trait defining methods that must be implemented by a hash map type to be used for joins. pub trait JoinHashMapType { /// The type of list used to store the next list @@ -224,6 +273,78 @@ pub trait JoinHashMapType { (input_indices, match_indices) } + + /// Matches hashes with taking limit and offset into account. + /// Returns pairs of matched indices along with the starting point for next + /// matching iteration (`None` if limit has not been reached). + /// + /// This method only compares hashes, so additional further check for actual values + /// equality may be required. + fn get_matched_indices_with_limit_offset( + &self, + hash_values: &[u64], + deleted_offset: Option, + limit: usize, + offset: JoinHashMapOffset, + ) -> ( + UInt32BufferBuilder, + UInt64BufferBuilder, + Option, + ) { + let mut input_indices = UInt32BufferBuilder::new(0); + let mut match_indices = UInt64BufferBuilder::new(0); + + let mut remaining_output = limit; + + let hash_map: &RawTable<(u64, u64)> = self.get_map(); + let next_chain = self.get_list(); + + // Calculate initial `hash_values` index before iterating + let to_skip = match offset { + // None `initial_next_idx` indicates that `initial_idx` processing has'n been started + (initial_idx, None) => initial_idx, + // Zero `initial_next_idx` indicates that `initial_idx` has been processed during + // previous iteration, and it should be skipped + (initial_idx, Some(0)) => initial_idx + 1, + // Otherwise, process remaining `initial_idx` matches by traversing `next_chain`, + // to start with the next index + (initial_idx, Some(initial_next_idx)) => { + chain_traverse!( + input_indices, + match_indices, + hash_values, + next_chain, + initial_idx, + initial_next_idx, + deleted_offset, + remaining_output + ); + + initial_idx + 1 + } + }; + + let mut row_idx = to_skip; + for hash_value in &hash_values[to_skip..] { + if let Some((_, index)) = + hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) + { + chain_traverse!( + input_indices, + match_indices, + hash_values, + next_chain, + row_idx, + index, + deleted_offset, + remaining_output + ); + } + row_idx += 1; + } + + (input_indices, match_indices, None) + } } /// Implementation of `JoinHashMapType` for `JoinHashMap`. @@ -1079,7 +1200,7 @@ pub(crate) fn build_batch_from_indices( pub(crate) fn adjust_indices_by_join_type( left_indices: UInt64Array, right_indices: UInt32Array, - count_right_batch: usize, + adjust_range: Range, join_type: JoinType, ) -> (UInt64Array, UInt32Array) { match join_type { @@ -1095,21 +1216,20 @@ pub(crate) fn adjust_indices_by_join_type( JoinType::Right | JoinType::Full => { // matched // unmatched right row will be produced in this batch - let right_unmatched_indices = - get_anti_indices(count_right_batch, &right_indices); + let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); // combine the matched and unmatched right result together append_right_indices(left_indices, right_indices, right_unmatched_indices) } JoinType::RightSemi => { // need to remove the duplicated record in the right side - let right_indices = get_semi_indices(count_right_batch, &right_indices); + let right_indices = get_semi_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right semi` join (left_indices, right_indices) } JoinType::RightAnti => { // need to remove the duplicated record in the right side // get the anti index for the right side - let right_indices = get_anti_indices(count_right_batch, &right_indices); + let right_indices = get_anti_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right anti` join (left_indices, right_indices) } @@ -1151,72 +1271,62 @@ pub(crate) fn append_right_indices( } } -/// Get unmatched and deduplicated indices -pub(crate) fn get_anti_indices( - row_count: usize, - input_indices: &UInt32Array, -) -> UInt32Array { - let mut bitmap = BooleanBufferBuilder::new(row_count); - bitmap.append_n(row_count, false); - input_indices.iter().flatten().for_each(|v| { - bitmap.set_bit(v as usize, true); - }); - - // get the anti index - (0..row_count) - .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u32)) - .collect::() -} +/// Returns `range` indices which are not present in `input_indices` +pub(crate) fn get_anti_indices( + range: Range, + input_indices: &PrimitiveArray, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); + input_indices + .iter() + .flatten() + .map(|v| v.as_usize()) + .filter(|v| range.contains(v)) + .for_each(|v| { + bitmap.set_bit(v - range.start, true); + }); -/// Get unmatched and deduplicated indices -pub(crate) fn get_anti_u64_indices( - row_count: usize, - input_indices: &UInt64Array, -) -> UInt64Array { - let mut bitmap = BooleanBufferBuilder::new(row_count); - bitmap.append_n(row_count, false); - input_indices.iter().flatten().for_each(|v| { - bitmap.set_bit(v as usize, true); - }); + let offset = range.start; // get the anti index - (0..row_count) - .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u64)) - .collect::() + (range) + .filter_map(|idx| { + (!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx)) + }) + .collect::>() } -/// Get matched and deduplicated indices -pub(crate) fn get_semi_indices( - row_count: usize, - input_indices: &UInt32Array, -) -> UInt32Array { - let mut bitmap = BooleanBufferBuilder::new(row_count); - bitmap.append_n(row_count, false); - input_indices.iter().flatten().for_each(|v| { - bitmap.set_bit(v as usize, true); - }); - - // get the semi index - (0..row_count) - .filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u32)) - .collect::() -} +/// Returns intersection of `range` and `input_indices` omitting duplicates +pub(crate) fn get_semi_indices( + range: Range, + input_indices: &PrimitiveArray, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); + input_indices + .iter() + .flatten() + .map(|v| v.as_usize()) + .filter(|v| range.contains(v)) + .for_each(|v| { + bitmap.set_bit(v - range.start, true); + }); -/// Get matched and deduplicated indices -pub(crate) fn get_semi_u64_indices( - row_count: usize, - input_indices: &UInt64Array, -) -> UInt64Array { - let mut bitmap = BooleanBufferBuilder::new(row_count); - bitmap.append_n(row_count, false); - input_indices.iter().flatten().for_each(|v| { - bitmap.set_bit(v as usize, true); - }); + let offset = range.start; // get the semi index - (0..row_count) - .filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u64)) - .collect::() + (range) + .filter_map(|idx| { + (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx)) + }) + .collect::>() } /// Metrics for build & probe joins diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 1dd1392b9d86..01d4f8941802 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -610,4 +610,7 @@ pub fn get_plan_string(plan: &Arc) -> Vec { } #[cfg(test)] +#[allow(clippy::single_component_path_imports)] +use rstest_reuse; + pub mod test; diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index 1f6ee1f117aa..5a8ef2db77c2 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -61,7 +61,7 @@ impl BatchIndex { /// Iterator over batches #[derive(Debug, Default)] -pub(crate) struct TestStream { +pub struct TestStream { /// Vector of record batches data: Vec, /// Index into the data that has been returned so far @@ -684,7 +684,7 @@ pub struct PanicExec { } impl PanicExec { - /// Create new [`PanickingExec`] with a give schema and number of + /// Create new [`PanicExec`] with a give schema and number of /// partitions, which will each panic immediately. pub fn new(schema: SchemaRef, n_partitions: usize) -> Self { Self { @@ -708,7 +708,7 @@ impl DisplayAs for PanicExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "PanickingExec",) + write!(f, "PanicExec",) } } } From c9935ae52ebdc54a3578d789e2c1c4cd29ba54bd Mon Sep 17 00:00:00 2001 From: Huaijin Date: Tue, 23 Jan 2024 03:16:39 +0800 Subject: [PATCH 38/39] fix: common_subexpr_eliminate rule should not apply to short-circuit expression (#8928) * fix: common_subexpr_eliminate rule should not apply to short-circuit expression * add more tests * format * minor * apply reviews * add some commont * fmt --- datafusion/expr/src/expr.rs | 48 +++++++++++++++++++ .../optimizer/src/common_subexpr_eliminate.rs | 17 ++++--- datafusion/sqllogictest/test_files/select.slt | 44 +++++++++++++++++ 3 files changed, 103 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9aeebb190e81..c5d158d87638 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1266,6 +1266,54 @@ impl Expr { Ok(Transformed::Yes(expr)) }) } + + /// Returns true if some of this `exprs` subexpressions may not be evaluated + /// and thus any side effects (like divide by zero) may not be encountered + pub fn short_circuits(&self) -> bool { + match self { + Expr::ScalarFunction(ScalarFunction { func_def, .. }) => { + matches!(func_def, ScalarFunctionDefinition::BuiltIn(fun) if *fun == BuiltinScalarFunction::Coalesce) + } + Expr::BinaryExpr(BinaryExpr { op, .. }) => { + matches!(op, Operator::And | Operator::Or) + } + Expr::Case { .. } => true, + // Use explicit pattern match instead of a default + // implementation, so that in the future if someone adds + // new Expr types, they will check here as well + Expr::AggregateFunction(..) + | Expr::Alias(..) + | Expr::Between(..) + | Expr::Cast(..) + | Expr::Column(..) + | Expr::Exists(..) + | Expr::GetIndexedField(..) + | Expr::GroupingSet(..) + | Expr::InList(..) + | Expr::InSubquery(..) + | Expr::IsFalse(..) + | Expr::IsNotFalse(..) + | Expr::IsNotNull(..) + | Expr::IsNotTrue(..) + | Expr::IsNotUnknown(..) + | Expr::IsNull(..) + | Expr::IsTrue(..) + | Expr::IsUnknown(..) + | Expr::Like(..) + | Expr::ScalarSubquery(..) + | Expr::ScalarVariable(_, _) + | Expr::SimilarTo(..) + | Expr::Not(..) + | Expr::Negative(..) + | Expr::OuterReferenceColumn(_, _) + | Expr::TryCast(..) + | Expr::Wildcard { .. } + | Expr::WindowFunction(..) + | Expr::Literal(..) + | Expr::Sort(..) + | Expr::Placeholder(..) => false, + } + } } // modifies expr if it is a placeholder with datatype of right diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index f29c7406acc9..fe71171ce545 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -616,8 +616,8 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { fn pre_visit(&mut self, expr: &Expr) -> Result { // related to https://github.com/apache/arrow-datafusion/issues/8814 - // If the expr contain volatile expression or is a case expression, skip it. - if matches!(expr, Expr::Case(..)) || is_volatile_expression(expr)? { + // If the expr contain volatile expression or is a short-circuit expression, skip it. + if expr.short_circuits() || is_volatile_expression(expr)? { return Ok(VisitRecursion::Skip); } self.visit_stack @@ -696,7 +696,13 @@ struct CommonSubexprRewriter<'a> { impl TreeNodeRewriter for CommonSubexprRewriter<'_> { type N = Expr; - fn pre_visit(&mut self, _: &Expr) -> Result { + fn pre_visit(&mut self, expr: &Expr) -> Result { + // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate + // the `id_array`, which records the expr's identifier used to rewrite expr. So if we + // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. + if expr.short_circuits() || is_volatile_expression(expr)? { + return Ok(RewriteRecursion::Stop); + } if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { @@ -1249,12 +1255,11 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(lit(1).gt(col("a")).and(lit(1).gt(col("a"))))? + .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))? .build()?; let expected = "Projection: test.a, test.b, test.c\ - \n Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\ - \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\ + \n Filter: Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a - Int32(10) > Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\ \n TableScan: test"; assert_optimized_plan_eq(expected, &plan); diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index ca48c07b0914..9ffddc6e2d46 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1129,5 +1129,49 @@ FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B; 0 0 0 0 +# Expressions that short circuit should not be refactored out as that may cause side effects (divide by zero) +# at plan time that would not actually happen during execution, so the follow three query should not be extract +# the common sub-expression +query TT +explain select coalesce(1, y/x), coalesce(2, y/x) from t; +---- +logical_plan +Projection: coalesce(Int64(1), CAST(t.y / t.x AS Int64)), coalesce(Int64(2), CAST(t.y / t.x AS Int64)) +--TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[coalesce(1, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(1),t.y / t.x), coalesce(2, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(2),t.y / t.x)] +--MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t; +---- +logical_plan +Projection: t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x +--TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 as t.y > Int64(0) AND Int64(1) / t.y < Int64(1), x@0 > 0 AND y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x] +--MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t; +---- +logical_plan +Projection: t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y = Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x +--TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 as t.y = Int64(0) OR Int64(1) / t.y < Int64(1), x@0 = 0 OR y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x] +--MemoryExec: partitions=1, partition_sizes=[1] + +# due to the reason describe in https://github.com/apache/arrow-datafusion/issues/8927, +# the following queries will fail +query error +select coalesce(1, y/x), coalesce(2, y/x) from t; + +query error +SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t; + +query error +SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t; + statement ok DROP TABLE t; From edec4189242ab07ac65967490537d77e776aad5c Mon Sep 17 00:00:00 2001 From: junxiangMu <63799833+guojidan@users.noreply.github.com> Date: Tue, 23 Jan 2024 03:39:45 +0800 Subject: [PATCH 39/39] Support GroupsAccumulator accumulator for udaf (#8892) * Support GroupsAccumulator accumulator for udaf * del some outdated content * add example in advanced_udaf.rs * fix doc * fix logic err && add test case --- datafusion-examples/Cargo.toml | 1 + datafusion-examples/examples/advanced_udaf.rs | 220 +++++++++++++++++- .../user_defined/user_defined_aggregates.rs | 128 +++++++++- datafusion/expr/src/groups_accumulator.rs | 153 ++++++++++++ datafusion/expr/src/lib.rs | 2 + datafusion/expr/src/udaf.rs | 29 ++- .../physical-expr/src/aggregate/average.rs | 5 +- .../src/aggregate/bit_and_or_xor.rs | 4 +- .../src/aggregate/bool_and_or.rs | 4 +- .../physical-expr/src/aggregate/count.rs | 5 +- .../groups_accumulator/accumulate.rs | 13 +- .../aggregate/groups_accumulator/adapter.rs | 3 +- .../aggregate/groups_accumulator/bool_op.rs | 5 +- .../src/aggregate/groups_accumulator/mod.rs | 138 +---------- .../aggregate/groups_accumulator/prim_op.rs | 5 +- .../physical-expr/src/aggregate/min_max.rs | 4 +- datafusion/physical-expr/src/aggregate/mod.rs | 3 +- datafusion/physical-expr/src/aggregate/sum.rs | 4 +- .../physical-expr/src/expressions/mod.rs | 5 +- datafusion/physical-expr/src/lib.rs | 4 +- .../src/aggregates/group_values/mod.rs | 2 +- .../src/aggregates/group_values/primitive.rs | 2 +- .../src/aggregates/group_values/row.rs | 2 +- .../src/aggregates/order/full.rs | 2 +- .../physical-plan/src/aggregates/order/mod.rs | 3 +- .../src/aggregates/order/partial.rs | 2 +- .../physical-plan/src/aggregates/row_hash.rs | 3 +- datafusion/physical-plan/src/udaf.rs | 9 + 28 files changed, 568 insertions(+), 192 deletions(-) create mode 100644 datafusion/expr/src/groups_accumulator.rs diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 59580bcb6a05..45c9709a342e 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -40,6 +40,7 @@ datafusion = { path = "../datafusion/core", features = ["avro"] } datafusion-common = { path = "../datafusion/common" } datafusion-expr = { path = "../datafusion/expr" } datafusion-optimizer = { path = "../datafusion/optimizer" } +datafusion-physical-expr = { workspace = true } datafusion-sql = { path = "../datafusion/sql" } env_logger = { workspace = true } futures = { workspace = true } diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 8d5314bfbea5..e5433013d9a7 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -16,16 +16,22 @@ // under the License. use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use datafusion_physical_expr::NullState; use std::{any::Any, sync::Arc}; use arrow::{ - array::{ArrayRef, Float32Array}, + array::{ + ArrayRef, AsArray, Float32Array, PrimitiveArray, PrimitiveBuilder, UInt32Array, + }, + datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt32Type}, record_batch::RecordBatch, }; use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{cast::as_float64_array, ScalarValue}; -use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature}; +use datafusion_expr::{ + Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, +}; /// This example shows how to use the full AggregateUDFImpl API to implement a user /// defined aggregate function. As in the `simple_udaf.rs` example, this struct implements @@ -33,12 +39,12 @@ use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature}; /// /// To do so, we must implement the `AggregateUDFImpl` trait. #[derive(Debug, Clone)] -struct GeoMeanUdf { +struct GeoMeanUdaf { signature: Signature, } -impl GeoMeanUdf { - /// Create a new instance of the GeoMeanUdf struct +impl GeoMeanUdaf { + /// Create a new instance of the GeoMeanUdaf struct fn new() -> Self { Self { signature: Signature::exact( @@ -52,7 +58,7 @@ impl GeoMeanUdf { } } -impl AggregateUDFImpl for GeoMeanUdf { +impl AggregateUDFImpl for GeoMeanUdaf { /// We implement as_any so that we can downcast the AggregateUDFImpl trait object fn as_any(&self) -> &dyn Any { self @@ -74,6 +80,11 @@ impl AggregateUDFImpl for GeoMeanUdf { } /// This is the accumulator factory; DataFusion uses it to create new accumulators. + /// + /// This is the accumulator factory for row wise accumulation; Even when `GroupsAccumulator` + /// is supported, DataFusion will use this row oriented + /// accumulator when the aggregate function is used as a window function + /// or when there are only aggregates (no GROUP BY columns) in the plan. fn accumulator(&self, _arg: &DataType) -> Result> { Ok(Box::new(GeometricMean::new())) } @@ -82,6 +93,16 @@ impl AggregateUDFImpl for GeoMeanUdf { fn state_type(&self, _return_type: &DataType) -> Result> { Ok(vec![DataType::Float64, DataType::UInt32]) } + + /// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator` + /// which is used for cases when there are grouping columns in the query + fn groups_accumulator_supported(&self) -> bool { + true + } + + fn create_groups_accumulator(&self) -> Result> { + Ok(Box::new(GeometricMeanGroupsAccumulator::new())) + } } /// A UDAF has state across multiple rows, and thus we require a `struct` with that state. @@ -173,16 +194,25 @@ fn create_context() -> Result { use datafusion::arrow::datatypes::{Field, Schema}; use datafusion::datasource::MemTable; // define a schema. - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ])); // define data in two partitions let batch1 = RecordBatch::try_new( schema.clone(), - vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], + vec![ + Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])), + Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])), + ], )?; let batch2 = RecordBatch::try_new( schema.clone(), - vec![Arc::new(Float32Array::from(vec![64.0]))], + vec![ + Arc::new(Float32Array::from(vec![64.0])), + Arc::new(Float32Array::from(vec![2.0])), + ], )?; // declare a new context. In spark API, this corresponds to a new spark SQLsession @@ -194,15 +224,183 @@ fn create_context() -> Result { Ok(ctx) } +// Define a `GroupsAccumulator` for GeometricMean +/// which handles accumulator state for multiple groups at once. +/// This API is significantly more complicated than `Accumulator`, which manages +/// the state for a single group, but for queries with a large number of groups +/// can be significantly faster. See the `GroupsAccumulator` documentation for +/// more information. +struct GeometricMeanGroupsAccumulator { + /// The type of the internal sum + prod_data_type: DataType, + + /// The type of the returned sum + return_data_type: DataType, + + /// Count per group (use u32 to make UInt32Array) + counts: Vec, + + /// product per group, stored as the native type (not `ScalarValue`) + prods: Vec, + + /// Track nulls in the input / filters + null_state: NullState, +} + +impl GeometricMeanGroupsAccumulator { + fn new() -> Self { + Self { + prod_data_type: DataType::Float64, + return_data_type: DataType::Float64, + counts: vec![], + prods: vec![], + null_state: NullState::new(), + } + } +} + +impl GroupsAccumulator for GeometricMeanGroupsAccumulator { + /// Updates the accumulator state given input. DataFusion provides `group_indices`, + /// the groups that each row in `values` belongs to as well as an optional filter of which rows passed. + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values[0].as_primitive::(); + + // increment counts, update sums + self.counts.resize(total_num_groups, 0); + self.prods.resize(total_num_groups, 1.0); + // Use the `NullState` structure to generate specialized code for null / non null input elements + self.null_state.accumulate( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value| { + let prod = &mut self.prods[group_index]; + *prod = prod.mul_wrapping(new_value); + + self.counts[group_index] += 1; + }, + ); + + Ok(()) + } + + /// Merge the results from previous invocations of `evaluate` into this accumulator's state + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 2, "two arguments to merge_batch"); + // first batch is counts, second is partial sums + let partial_prods = values[0].as_primitive::(); + let partial_counts = values[1].as_primitive::(); + // update counts with partial counts + self.counts.resize(total_num_groups, 0); + self.null_state.accumulate( + group_indices, + partial_counts, + opt_filter, + total_num_groups, + |group_index, partial_count| { + self.counts[group_index] += partial_count; + }, + ); + + // update prods + self.prods.resize(total_num_groups, 1.0); + self.null_state.accumulate( + group_indices, + partial_prods, + opt_filter, + total_num_groups, + |group_index, new_value: ::Native| { + let prod = &mut self.prods[group_index]; + *prod = prod.mul_wrapping(new_value); + }, + ); + + Ok(()) + } + + /// Generate output, as specififed by `emit_to` and update the intermediate state + fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result { + let counts = emit_to.take_needed(&mut self.counts); + let prods = emit_to.take_needed(&mut self.prods); + let nulls = self.null_state.build(emit_to); + + assert_eq!(nulls.len(), prods.len()); + assert_eq!(counts.len(), prods.len()); + + // don't evaluate geometric mean with null inputs to avoid errors on null values + + let array: PrimitiveArray = if nulls.null_count() > 0 { + let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); + let iter = prods.into_iter().zip(counts).zip(nulls.iter()); + + for ((prod, count), is_valid) in iter { + if is_valid { + builder.append_value(prod.powf(1.0 / count as f64)) + } else { + builder.append_null(); + } + } + builder.finish() + } else { + let geo_mean: Vec<::Native> = prods + .into_iter() + .zip(counts) + .map(|(prod, count)| prod.powf(1.0 / count as f64)) + .collect::>(); + PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy + .with_data_type(self.return_data_type.clone()) + }; + + Ok(Arc::new(array)) + } + + // return arrays for counts and prods + fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result> { + let nulls = self.null_state.build(emit_to); + let nulls = Some(nulls); + + let counts = emit_to.take_needed(&mut self.counts); + let counts = UInt32Array::new(counts.into(), nulls.clone()); // zero copy + + let prods = emit_to.take_needed(&mut self.prods); + let prods = PrimitiveArray::::new(prods.into(), nulls) // zero copy + .with_data_type(self.prod_data_type.clone()); + + Ok(vec![ + Arc::new(prods) as ArrayRef, + Arc::new(counts) as ArrayRef, + ]) + } + + fn size(&self) -> usize { + self.counts.capacity() * std::mem::size_of::() + + self.prods.capacity() * std::mem::size_of::() + } +} + #[tokio::main] async fn main() -> Result<()> { let ctx = create_context()?; // create the AggregateUDF - let geometric_mean = AggregateUDF::from(GeoMeanUdf::new()); + let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new()); ctx.register_udaf(geometric_mean.clone()); - let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?; + let sql_df = ctx.sql("SELECT geo_mean(a) FROM t group by b").await?; sql_df.show().await?; // get a DataFrame from the context diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 5882718acefd..5b578daa7e34 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -19,7 +19,7 @@ //! user defined aggregate functions use arrow::{array::AsArray, datatypes::Fields}; -use arrow_array::Int32Array; +use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray}; use arrow_schema::Schema; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -45,7 +45,9 @@ use datafusion::{ use datafusion_common::{ assert_contains, cast::as_primitive_array, exec_err, DataFusionError, }; -use datafusion_expr::{create_udaf, SimpleAggregateUDF}; +use datafusion_expr::{ + create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF, +}; use datafusion_physical_expr::expressions::AvgAccumulator; /// Test to show the contents of the setup @@ -297,6 +299,25 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_groups_accumulator() -> Result<()> { + let ctx = SessionContext::new(); + let arr = Int32Array::from(vec![1]); + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(arr) as _)])?; + ctx.register_batch("t", batch).unwrap(); + + let udaf = AggregateUDF::from(TestGroupsAccumulator { + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + result: 1, + }); + ctx.register_udaf(udaf.clone()); + + let sql_df = ctx.sql("SELECT geo_mean(a) FROM t group by a").await?; + sql_df.show().await?; + + Ok(()) +} + /// Returns an context with a table "t" and the "first" and "time_sum" /// aggregate functions registered. /// @@ -621,3 +642,106 @@ impl Accumulator for FirstSelector { std::mem::size_of_val(self) } } + +#[derive(Debug, Clone)] +struct TestGroupsAccumulator { + signature: Signature, + result: u64, +} + +impl AggregateUDFImpl for TestGroupsAccumulator { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "geo_mean" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::UInt64) + } + + fn accumulator(&self, _arg: &DataType) -> Result> { + // should use groups accumulator + panic!("accumulator shouldn't invoke"); + } + + fn state_type(&self, _return_type: &DataType) -> Result> { + Ok(vec![DataType::UInt64]) + } + + fn groups_accumulator_supported(&self) -> bool { + true + } + + fn create_groups_accumulator(&self) -> Result> { + Ok(Box::new(self.clone())) + } +} + +impl Accumulator for TestGroupsAccumulator { + fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { + Ok(()) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::from(self.result)) + } + + fn size(&self) -> usize { + std::mem::size_of::() + } + + fn state(&self) -> Result> { + Ok(vec![ScalarValue::from(self.result)]) + } + + fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { + Ok(()) + } +} + +impl GroupsAccumulator for TestGroupsAccumulator { + fn update_batch( + &mut self, + _values: &[ArrayRef], + _group_indices: &[usize], + _opt_filter: Option<&arrow_array::BooleanArray>, + _total_num_groups: usize, + ) -> Result<()> { + Ok(()) + } + + fn evaluate(&mut self, _emit_to: datafusion_expr::EmitTo) -> Result { + Ok(Arc::new(PrimitiveArray::::new( + vec![self.result].into(), + None, + )) as ArrayRef) + } + + fn state(&mut self, _emit_to: datafusion_expr::EmitTo) -> Result> { + Ok(vec![Arc::new(PrimitiveArray::::new( + vec![self.result].into(), + None, + )) as ArrayRef]) + } + + fn merge_batch( + &mut self, + _values: &[ArrayRef], + _group_indices: &[usize], + _opt_filter: Option<&arrow_array::BooleanArray>, + _total_num_groups: usize, + ) -> Result<()> { + Ok(()) + } + + fn size(&self) -> usize { + std::mem::size_of::() + } +} diff --git a/datafusion/expr/src/groups_accumulator.rs b/datafusion/expr/src/groups_accumulator.rs new file mode 100644 index 000000000000..6580de19bc68 --- /dev/null +++ b/datafusion/expr/src/groups_accumulator.rs @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Vectorized [`GroupsAccumulator`] + +use arrow_array::{ArrayRef, BooleanArray}; +use datafusion_common::Result; + +/// Describes how many rows should be emitted during grouping. +#[derive(Debug, Clone, Copy)] +pub enum EmitTo { + /// Emit all groups + All, + /// Emit only the first `n` groups and shift all existing group + /// indexes down by `n`. + /// + /// For example, if `n=10`, group_index `0, 1, ... 9` are emitted + /// and group indexes '`10, 11, 12, ...` become `0, 1, 2, ...`. + First(usize), +} + +impl EmitTo { + /// Removes the number of rows from `v` required to emit the right + /// number of rows, returning a `Vec` with elements taken, and the + /// remaining values in `v`. + /// + /// This avoids copying if Self::All + pub fn take_needed(&self, v: &mut Vec) -> Vec { + match self { + Self::All => { + // Take the entire vector, leave new (empty) vector + std::mem::take(v) + } + Self::First(n) => { + // get end n+1,.. values into t + let mut t = v.split_off(*n); + // leave n+1,.. in v + std::mem::swap(v, &mut t); + t + } + } + } +} + +/// `GroupAccumulator` implements a single aggregate (e.g. AVG) and +/// stores the state for *all* groups internally. +/// +/// Each group is assigned a `group_index` by the hash table and each +/// accumulator manages the specific state, one per group_index. +/// +/// group_indexes are contiguous (there aren't gaps), and thus it is +/// expected that each GroupAccumulator will use something like `Vec<..>` +/// to store the group states. +pub trait GroupsAccumulator: Send { + /// Updates the accumulator's state from its arguments, encoded as + /// a vector of [`ArrayRef`]s. + /// + /// * `values`: the input arguments to the accumulator + /// + /// * `group_indices`: To which groups do the rows in `values` + /// belong, group id) + /// + /// * `opt_filter`: if present, only update aggregate state using + /// `values[i]` if `opt_filter[i]` is true + /// + /// * `total_num_groups`: the number of groups (the largest + /// group_index is thus `total_num_groups - 1`). + /// + /// Note that subsequent calls to update_batch may have larger + /// total_num_groups as new groups are seen. + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()>; + + /// Returns the final aggregate value for each group as a single + /// `RecordBatch`, resetting the internal state. + /// + /// The rows returned *must* be in group_index order: The value + /// for group_index 0, followed by 1, etc. Any group_index that + /// did not have values, should be null. + /// + /// For example, a `SUM` accumulator maintains a running sum for + /// each group, and `evaluate` will produce that running sum as + /// its output for all groups, in group_index order + /// + /// If `emit_to`` is [`EmitTo::All`], the accumulator should + /// return all groups and release / reset its internal state + /// equivalent to when it was first created. + /// + /// If `emit_to` is [`EmitTo::First`], only the first `n` groups + /// should be emitted and the state for those first groups + /// removed. State for the remaining groups must be retained for + /// future use. The group_indices on subsequent calls to + /// `update_batch` or `merge_batch` will be shifted down by + /// `n`. See [`EmitTo::First`] for more details. + fn evaluate(&mut self, emit_to: EmitTo) -> Result; + + /// Returns the intermediate aggregate state for this accumulator, + /// used for multi-phase grouping, resetting its internal state. + /// + /// For example, `AVG` might return two arrays: `SUM` and `COUNT` + /// but the `MIN` aggregate would just return a single array. + /// + /// Note more sophisticated internal state can be passed as + /// single `StructArray` rather than multiple arrays. + /// + /// See [`Self::evaluate`] for details on the required output + /// order and `emit_to`. + fn state(&mut self, emit_to: EmitTo) -> Result>; + + /// Merges intermediate state (the output from [`Self::state`]) + /// into this accumulator's values. + /// + /// For some aggregates (such as `SUM`), `merge_batch` is the same + /// as `update_batch`, but for some aggregates (such as `COUNT`, + /// where the partial counts must be summed) the operations + /// differ. See [`Self::state`] for more details on how state is + /// used and merged. + /// + /// * `values`: arrays produced from calling `state` previously to the accumulator + /// + /// Other arguments are the same as for [`Self::update_batch`]; + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()>; + + /// Amount of memory used to store the state of this accumulator, + /// in bytes. This function is called once per batch, so it should + /// be `O(n)` to compute, not `O(num_groups)` + fn size(&self) -> usize; +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 21647f384159..c29535456327 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -47,6 +47,7 @@ pub mod expr_rewriter; pub mod expr_schema; pub mod field_util; pub mod function; +pub mod groups_accumulator; pub mod interval_arithmetic; pub mod logical_plan; pub mod tree_node; @@ -70,6 +71,7 @@ pub use function::{ AccumulatorFactoryFunction, PartitionEvaluatorFactory, ReturnTypeFunction, ScalarFunctionImplementation, StateTypeFunction, }; +pub use groups_accumulator::{EmitTo, GroupsAccumulator}; pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; pub use logical_plan::*; pub use nullif::SUPPORTED_NULLIF_TYPES; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 4983f6247d24..444a4f1e8099 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,12 +17,13 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions +use crate::groups_accumulator::GroupsAccumulator; use crate::{Accumulator, Expr}; use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; use arrow::datatypes::DataType; -use datafusion_common::Result; +use datafusion_common::{not_impl_err, DataFusionError, Result}; use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -163,6 +164,16 @@ impl AggregateUDF { pub fn state_type(&self, return_type: &DataType) -> Result> { self.inner.state_type(return_type) } + + /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details. + pub fn groups_accumulator_supported(&self) -> bool { + self.inner.groups_accumulator_supported() + } + + /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details. + pub fn create_groups_accumulator(&self) -> Result> { + self.inner.create_groups_accumulator() + } } impl From for AggregateUDF @@ -250,6 +261,22 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Return the type used to serialize the [`Accumulator`]'s intermediate state. /// See [`Accumulator::state()`] for more details fn state_type(&self, return_type: &DataType) -> Result>; + + /// If the aggregate expression has a specialized + /// [`GroupsAccumulator`] implementation. If this returns true, + /// `[Self::create_groups_accumulator`] will be called. + fn groups_accumulator_supported(&self) -> bool { + false + } + + /// Return a specialized [`GroupsAccumulator`] that manages state + /// for all groups. + /// + /// For maximum performance, a [`GroupsAccumulator`] should be + /// implemented in addition to [`Accumulator`]. + fn create_groups_accumulator(&self) -> Result> { + not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") + } } /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 91f2fb952dce..187373e14f99 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -27,7 +27,7 @@ use std::sync::Arc; use crate::aggregate::groups_accumulator::accumulate::NullState; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; -use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; +use crate::{AggregateExpr, PhysicalExpr}; use arrow::compute::sum; use arrow::datatypes::{DataType, Decimal128Type, Float64Type, UInt64Type}; use arrow::{ @@ -41,9 +41,8 @@ use arrow_array::{ use arrow_buffer::{i256, ArrowNativeType}; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::avg_return_type; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; -use super::groups_accumulator::EmitTo; use super::utils::DecimalAverager; /// AVG aggregate expression diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs index 6c97d620616a..92883d8049d2 100644 --- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs +++ b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs @@ -22,11 +22,11 @@ use datafusion_common::cast::as_list_array; use std::any::Any; use std::sync::Arc; -use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; +use crate::{AggregateExpr, PhysicalExpr}; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, GroupsAccumulator}; use std::collections::HashSet; use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; diff --git a/datafusion/physical-expr/src/aggregate/bool_and_or.rs b/datafusion/physical-expr/src/aggregate/bool_and_or.rs index 9757d314b6aa..ae205141b4b9 100644 --- a/datafusion/physical-expr/src/aggregate/bool_and_or.rs +++ b/datafusion/physical-expr/src/aggregate/bool_and_or.rs @@ -17,7 +17,7 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; +use crate::{AggregateExpr, PhysicalExpr}; use arrow::datatypes::DataType; use arrow::{ array::{ArrayRef, BooleanArray}, @@ -26,7 +26,7 @@ use arrow::{ use datafusion_common::{ downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, GroupsAccumulator}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 8e9ae5cea36b..b6d4b7300434 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -23,7 +23,7 @@ use std::ops::BitAnd; use std::sync::Arc; use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; +use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::{Array, Int64Array}; use arrow::compute; use arrow::datatypes::DataType; @@ -34,12 +34,11 @@ use arrow_array::PrimitiveArray; use arrow_buffer::BooleanBuffer; use datafusion_common::{downcast_value, ScalarValue}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; use crate::expressions::format_state_name; use super::groups_accumulator::accumulate::accumulate_indices; -use super::groups_accumulator::EmitTo; /// COUNT aggregate expression /// Returns the amount of non-null values of the given expression. diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index 596265a737da..7080ea40039d 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -17,14 +17,13 @@ //! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`] //! -//! [`GroupsAccumulator`]: crate::GroupsAccumulator +//! [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator use arrow::datatypes::ArrowPrimitiveType; use arrow_array::{Array, BooleanArray, PrimitiveArray}; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; -use crate::EmitTo; - +use datafusion_expr::EmitTo; /// Track the accumulator null state per row: if any values for that /// group were null and if any values have been seen at all for that group. /// @@ -49,7 +48,7 @@ use crate::EmitTo; /// had at least one value to accumulate so they do not need to track /// if they have seen values for a particular group. /// -/// [`GroupsAccumulator`]: crate::GroupsAccumulator +/// [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator #[derive(Debug)] pub struct NullState { /// Have we seen any non-filtered input values for `group_index`? @@ -62,6 +61,12 @@ pub struct NullState { seen_values: BooleanBufferBuilder, } +impl Default for NullState { + fn default() -> Self { + Self::new() + } +} + impl NullState { pub fn new() -> Self { Self { diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs index c6fd17a69b39..b4e6d2ebc5fc 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -17,7 +17,6 @@ //! Adapter that makes [`GroupsAccumulator`] out of [`Accumulator`] -use super::{EmitTo, GroupsAccumulator}; use arrow::{ array::{AsArray, UInt32Builder}, compute, @@ -28,7 +27,7 @@ use datafusion_common::{ arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; /// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`] /// diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs index 21b6cc29e83d..f40c661a7a2f 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs @@ -21,10 +21,9 @@ use arrow::array::AsArray; use arrow_array::{ArrayRef, BooleanArray}; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder}; use datafusion_common::Result; +use datafusion_expr::{EmitTo, GroupsAccumulator}; -use crate::GroupsAccumulator; - -use super::{accumulate::NullState, EmitTo}; +use super::accumulate::NullState; /// An accumulator that implements a single operation over a /// [`BooleanArray`] where the accumulated state is also boolean (such diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs index d2e64d373be2..de090badd349 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs @@ -15,146 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! Vectorized [`GroupsAccumulator`] - pub(crate) mod accumulate; mod adapter; +pub use accumulate::NullState; pub use adapter::GroupsAccumulatorAdapter; pub(crate) mod bool_op; pub(crate) mod prim_op; - -use arrow_array::{ArrayRef, BooleanArray}; -use datafusion_common::Result; - -/// Describes how many rows should be emitted during grouping. -#[derive(Debug, Clone, Copy)] -pub enum EmitTo { - /// Emit all groups - All, - /// Emit only the first `n` groups and shift all existing group - /// indexes down by `n`. - /// - /// For example, if `n=10`, group_index `0, 1, ... 9` are emitted - /// and group indexes '`10, 11, 12, ...` become `0, 1, 2, ...`. - First(usize), -} - -impl EmitTo { - /// Removes the number of rows from `v` required to emit the right - /// number of rows, returning a `Vec` with elements taken, and the - /// remaining values in `v`. - /// - /// This avoids copying if Self::All - pub fn take_needed(&self, v: &mut Vec) -> Vec { - match self { - Self::All => { - // Take the entire vector, leave new (empty) vector - std::mem::take(v) - } - Self::First(n) => { - // get end n+1,.. values into t - let mut t = v.split_off(*n); - // leave n+1,.. in v - std::mem::swap(v, &mut t); - t - } - } - } -} - -/// `GroupAccumulator` implements a single aggregate (e.g. AVG) and -/// stores the state for *all* groups internally. -/// -/// Each group is assigned a `group_index` by the hash table and each -/// accumulator manages the specific state, one per group_index. -/// -/// group_indexes are contiguous (there aren't gaps), and thus it is -/// expected that each GroupAccumulator will use something like `Vec<..>` -/// to store the group states. -pub trait GroupsAccumulator: Send { - /// Updates the accumulator's state from its arguments, encoded as - /// a vector of [`ArrayRef`]s. - /// - /// * `values`: the input arguments to the accumulator - /// - /// * `group_indices`: To which groups do the rows in `values` - /// belong, group id) - /// - /// * `opt_filter`: if present, only update aggregate state using - /// `values[i]` if `opt_filter[i]` is true - /// - /// * `total_num_groups`: the number of groups (the largest - /// group_index is thus `total_num_groups - 1`). - /// - /// Note that subsequent calls to update_batch may have larger - /// total_num_groups as new groups are seen. - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> Result<()>; - - /// Returns the final aggregate value for each group as a single - /// `RecordBatch`, resetting the internal state. - /// - /// The rows returned *must* be in group_index order: The value - /// for group_index 0, followed by 1, etc. Any group_index that - /// did not have values, should be null. - /// - /// For example, a `SUM` accumulator maintains a running sum for - /// each group, and `evaluate` will produce that running sum as - /// its output for all groups, in group_index order - /// - /// If `emit_to`` is [`EmitTo::All`], the accumulator should - /// return all groups and release / reset its internal state - /// equivalent to when it was first created. - /// - /// If `emit_to` is [`EmitTo::First`], only the first `n` groups - /// should be emitted and the state for those first groups - /// removed. State for the remaining groups must be retained for - /// future use. The group_indices on subsequent calls to - /// `update_batch` or `merge_batch` will be shifted down by - /// `n`. See [`EmitTo::First`] for more details. - fn evaluate(&mut self, emit_to: EmitTo) -> Result; - - /// Returns the intermediate aggregate state for this accumulator, - /// used for multi-phase grouping, resetting its internal state. - /// - /// For example, `AVG` might return two arrays: `SUM` and `COUNT` - /// but the `MIN` aggregate would just return a single array. - /// - /// Note more sophisticated internal state can be passed as - /// single `StructArray` rather than multiple arrays. - /// - /// See [`Self::evaluate`] for details on the required output - /// order and `emit_to`. - fn state(&mut self, emit_to: EmitTo) -> Result>; - - /// Merges intermediate state (the output from [`Self::state`]) - /// into this accumulator's values. - /// - /// For some aggregates (such as `SUM`), `merge_batch` is the same - /// as `update_batch`, but for some aggregates (such as `COUNT`, - /// where the partial counts must be summed) the operations - /// differ. See [`Self::state`] for more details on how state is - /// used and merged. - /// - /// * `values`: arrays produced from calling `state` previously to the accumulator - /// - /// Other arguments are the same as for [`Self::update_batch`]; - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> Result<()>; - - /// Amount of memory used to store the state of this accumulator, - /// in bytes. This function is called once per batch, so it should - /// be `O(n)` to compute, not `O(num_groups)` - fn size(&self) -> usize; -} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs index 130d56271280..994f5447d7c0 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs @@ -21,10 +21,9 @@ use arrow::{array::AsArray, datatypes::ArrowPrimitiveType}; use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; use arrow_schema::DataType; use datafusion_common::Result; +use datafusion_expr::{EmitTo, GroupsAccumulator}; -use crate::GroupsAccumulator; - -use super::{accumulate::NullState, EmitTo}; +use super::accumulate::NullState; /// An accumulator that implements a single operation over /// [`ArrowPrimitiveType`] where the accumulated state is the same as diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 7e3ef2a2abab..ba3e70855355 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -22,7 +22,7 @@ use std::convert::TryFrom; use std::sync::Arc; use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; +use crate::{AggregateExpr, PhysicalExpr}; use arrow::compute; use arrow::datatypes::{ DataType, Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, @@ -47,7 +47,7 @@ use arrow_array::types::{ use datafusion_common::internal_err; use datafusion_common::ScalarValue; use datafusion_common::{downcast_value, DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, GroupsAccumulator}; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 270a8e6f7705..2bb205ce90dc 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -19,13 +19,12 @@ use std::any::Any; use std::fmt::Debug; use std::sync::Arc; -use self::groups_accumulator::GroupsAccumulator; use crate::expressions::{NthValueAgg, OrderSensitiveArrayAgg}; use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::Field; use datafusion_common::{not_impl_err, DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, GroupsAccumulator}; mod hyperloglog; mod tdigest; diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 03f666cc4e5d..a770b3874ce0 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use super::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; -use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; +use crate::{AggregateExpr, PhysicalExpr}; use arrow::compute::sum; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::Field}; @@ -35,7 +35,7 @@ use arrow_array::{Array, ArrowNativeTypeOp, ArrowNumericType}; use arrow_buffer::ArrowNativeType; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::sum_return_type; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, GroupsAccumulator}; /// SUM aggregate expression #[derive(Debug, Clone)] diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index bbfba4ad8310..007a03985f45 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -105,14 +105,13 @@ pub(crate) mod tests { use std::sync::Arc; use crate::expressions::{col, create_aggregate_expr, try_cast}; - use crate::{AggregateExpr, EmitTo}; - + use crate::AggregateExpr; use arrow::record_batch::RecordBatch; use arrow_array::ArrayRef; use arrow_schema::{Field, Schema}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::coerce_types; - use datafusion_expr::AggregateFunction; + use datafusion_expr::{AggregateFunction, EmitTo}; /// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the /// result. diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index fffa8f602d87..6f55f56916e7 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -48,9 +48,7 @@ pub mod utils; pub mod var_provider; pub mod window; -pub use aggregate::groups_accumulator::{ - EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter, -}; +pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, NullState}; pub use aggregate::AggregateExpr; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; pub use equivalence::EquivalenceProperties; diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index cafa385eac39..ef9aac3d3ef0 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -19,9 +19,9 @@ use arrow::record_batch::RecordBatch; use arrow_array::{downcast_primitive, ArrayRef}; use arrow_schema::SchemaRef; use datafusion_common::Result; -use datafusion_physical_expr::EmitTo; pub(crate) mod primitive; +use datafusion_expr::EmitTo; use primitive::GroupValuesPrimitive; mod row; diff --git a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs index e3ba284797d1..18d20f3c47e6 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs @@ -26,7 +26,7 @@ use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArra use arrow_schema::DataType; use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; -use datafusion_physical_expr::EmitTo; +use datafusion_expr::EmitTo; use half::f16; use hashbrown::raw::RawTable; use std::sync::Arc; diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 10ff9edb8912..3b7480cd292a 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -25,7 +25,7 @@ use arrow_schema::{DataType, SchemaRef}; use datafusion_common::hash_utils::create_hashes; use datafusion_common::{DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; -use datafusion_physical_expr::EmitTo; +use datafusion_expr::EmitTo; use hashbrown::raw::RawTable; /// A [`GroupValues`] making use of [`Rows`] diff --git a/datafusion/physical-plan/src/aggregates/order/full.rs b/datafusion/physical-plan/src/aggregates/order/full.rs index f46ee687faf1..c15538e8ab8e 100644 --- a/datafusion/physical-plan/src/aggregates/order/full.rs +++ b/datafusion/physical-plan/src/aggregates/order/full.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_physical_expr::EmitTo; +use datafusion_expr::EmitTo; /// Tracks grouping state when the data is ordered entirely by its /// group keys diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index b258b97a9e84..4f1914b12c96 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -18,7 +18,8 @@ use arrow_array::ArrayRef; use arrow_schema::Schema; use datafusion_common::Result; -use datafusion_physical_expr::{EmitTo, PhysicalSortExpr}; +use datafusion_expr::EmitTo; +use datafusion_physical_expr::PhysicalSortExpr; mod full; mod partial; diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index ff8a75b9b28b..ecd37c913e98 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -20,7 +20,7 @@ use arrow_array::ArrayRef; use arrow_schema::Schema; use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; -use datafusion_physical_expr::EmitTo; +use datafusion_expr::EmitTo; use datafusion_physical_expr::PhysicalSortExpr; /// Tracks grouping state when the data is ordered by some subset of diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 6a0c02f5caf3..f9db0a050cfc 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -44,9 +44,10 @@ use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; +use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{ - AggregateExpr, EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter, PhysicalSortExpr, + AggregateExpr, GroupsAccumulatorAdapter, PhysicalSortExpr, }; use futures::ready; diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 94017efe97aa..a82bbe1d0705 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -17,6 +17,7 @@ //! This module contains functions and structs supporting user-defined aggregate functions. +use datafusion_expr::GroupsAccumulator; use fmt::Debug; use std::any::Any; use std::fmt; @@ -166,6 +167,14 @@ impl AggregateExpr for AggregateFunctionExpr { fn name(&self) -> &str { &self.name } + + fn groups_accumulator_supported(&self) -> bool { + self.fun.groups_accumulator_supported() + } + + fn create_groups_accumulator(&self) -> Result> { + self.fun.create_groups_accumulator() + } } impl PartialEq for AggregateFunctionExpr {