diff --git a/Cargo.toml b/Cargo.toml index ff231178a2b3..48e555bd5527 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ exclude = ["datafusion-cli"] members = [ "datafusion/common", + "datafusion/common_runtime", "datafusion/core", "datafusion/expr", "datafusion/execution", @@ -65,13 +66,14 @@ arrow-ord = { version = "50.0.0", default-features = false } arrow-schema = { version = "50.0.0", default-features = false } arrow-string = { version = "50.0.0", default-features = false } async-trait = "0.1.73" -bigdecimal = "0.4.1" +bigdecimal = "=0.4.1" bytes = "1.4" chrono = { version = "0.4.34", default-features = false } ctor = "0.2.0" dashmap = "5.4.0" datafusion = { path = "datafusion/core", version = "36.0.0", default-features = false } datafusion-common = { path = "datafusion/common", version = "36.0.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common_runtime", version = "36.0.0" } datafusion-execution = { path = "datafusion/execution", version = "36.0.0" } datafusion-expr = { path = "datafusion/expr", version = "36.0.0" } datafusion-functions = { path = "datafusion/functions", version = "36.0.0" } @@ -97,7 +99,7 @@ parquet = { version = "50.0.0", default-features = false, features = ["arrow", " rand = "0.8" rstest = "0.18.0" serde_json = "1" -sqlparser = { version = "0.43.0", features = ["visitor"] } +sqlparser = { version = "0.44.0", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" tokio = { version = "1.36", features = ["macros", "rt", "sync"] } diff --git a/README.md b/README.md index 634aa426bdff..e5ac9503be44 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ [API Docs](https://docs.rs/datafusion/latest/datafusion/) | [Chat](https://discord.com/channels/885562378132000778/885562378132000781) -logo +logo DataFusion is a very fast, extensible query engine for building high-quality data-centric systems in [Rust](http://rustlang.org), using the [Apache Arrow](https://arrow.apache.org) diff --git a/clippy.toml b/clippy.toml index 6eb9906c89cf..62d8263085df 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1,6 +1,6 @@ disallowed-methods = [ { path = "tokio::task::spawn", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" }, - { path = "tokio::task::spawn_blocking", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" }, + { path = "tokio::task::spawn_blocking", reason = "To provide cancel-safety, use `SpawnedTask::spawn_blocking` instead (https://github.com/apache/arrow-datafusion/issues/6513)" }, ] disallowed-types = [ diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 2379a30ce10f..46484be0e195 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -25,9 +25,9 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "ahash" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b79b82693f705137f8fb9b37871d99e4f9a7df12b917eed79c3d3954830a60b" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "const-random", @@ -270,7 +270,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.2.3", + "indexmap 2.2.5", "lexical-core", "num", "serde", @@ -384,7 +384,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", 
- "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -874,7 +874,7 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets 0.52.3", + "windows-targets 0.52.4", ] [[package]] @@ -962,9 +962,9 @@ dependencies = [ [[package]] name = "const-random" -version = "0.1.17" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaf16c9c2c612020bcfd042e170f6e32de9b9d75adb5277cdbbd2e2c8c8299a" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" dependencies = [ "const-random-macro", ] @@ -1073,7 +1073,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad291aa74992b9b7a7e88c38acbbf6ad7e107f1d90ee8775b7bc1fc3394f485c" dependencies = [ "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -1112,6 +1112,7 @@ dependencies = [ "chrono", "dashmap", "datafusion-common", + "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", "datafusion-functions", @@ -1125,7 +1126,7 @@ dependencies = [ "glob", "half", "hashbrown 0.14.3", - "indexmap 2.2.3", + "indexmap 2.2.5", "itertools", "log", "num-traits", @@ -1193,6 +1194,13 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "datafusion-common-runtime" +version = "36.0.0" +dependencies = [ + "tokio", +] + [[package]] name = "datafusion-execution" version = "36.0.0" @@ -1291,7 +1299,7 @@ dependencies = [ "half", "hashbrown 0.14.3", "hex", - "indexmap 2.2.3", + "indexmap 2.2.5", "itertools", "log", "md-5", @@ -1316,13 +1324,14 @@ dependencies = [ "async-trait", "chrono", "datafusion-common", + "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", "futures", "half", "hashbrown 0.14.3", - "indexmap 2.2.3", + "indexmap 2.2.5", "itertools", "log", "once_cell", @@ -1610,7 +1619,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -1694,7 +1703,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 2.2.3", + "indexmap 2.2.5", "slab", "tokio", "tokio-util", @@ -1754,9 +1763,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "379dada1584ad501b383485dd706b8afb7a70fcbc7f4da7d780638a5a6124a60" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "hex" @@ -1911,9 +1920,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.3" +version = "2.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "233cf39063f058ea2caae4091bf4a3ef70a653afbc026f5c4a4135d114e3c177" +checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -2112,9 +2121,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.20" +version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] name = "lz4_flex" @@ -2178,9 +2187,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" 
dependencies = [ "libc", "wasi", @@ -2301,7 +2310,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.8", + "hermit-abi 0.3.9", "libc", ] @@ -2334,7 +2343,7 @@ dependencies = [ "rand", "reqwest", "ring 0.17.8", - "rustls-pemfile 2.1.0", + "rustls-pemfile 2.1.1", "serde", "serde_json", "snafu", @@ -2463,7 +2472,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset", - "indexmap 2.2.3", + "indexmap 2.2.5", ] [[package]] @@ -2521,7 +2530,7 @@ checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -2917,9 +2926,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c333bb734fcdedcea57de1602543590f545f127dc8b533324318fd492c5c70b" +checksum = "f48172685e6ff52a556baa527774f61fcaa884f59daf3375c62a3f1cd2549dab" dependencies = [ "base64", "rustls-pki-types", @@ -3062,7 +3071,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -3181,9 +3190,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.43.1" +version = "0.44.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f95c4bae5aba7cd30bd506f7140026ade63cff5afd778af8854026f9606bf5d4" +checksum = "aaf9c7ff146298ffda83a200f8d5084f08dcee1edfc135fcc1d646a45d50ffd6" dependencies = [ "log", "sqlparser_derive", @@ -3197,7 +3206,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -3243,7 +3252,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -3256,7 +3265,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -3278,9 +3287,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.51" +version = "2.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ab617d94515e94ae53b8406c628598680aa0c9587474ecbe58188f7b345d66c" +checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" dependencies = [ "proc-macro2", "quote", @@ -3364,7 +3373,7 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -3459,7 +3468,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -3556,7 +3565,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -3601,7 +3610,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -3711,9 +3720,9 @@ dependencies = [ [[package]] name = "walkdir" -version = "2.4.0" +version = "2.5.0" source 
= "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" dependencies = [ "same-file", "winapi-util", @@ -3755,7 +3764,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", "wasm-bindgen-shared", ] @@ -3789,7 +3798,7 @@ checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3870,7 +3879,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.3", + "windows-targets 0.52.4", ] [[package]] @@ -3888,7 +3897,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.3", + "windows-targets 0.52.4", ] [[package]] @@ -3908,17 +3917,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d380ba1dc7187569a8a9e91ed34b8ccfc33123bbacb8c0aed2d1ad7f3ef2dc5f" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" dependencies = [ - "windows_aarch64_gnullvm 0.52.3", - "windows_aarch64_msvc 0.52.3", - "windows_i686_gnu 0.52.3", - "windows_i686_msvc 0.52.3", - "windows_x86_64_gnu 0.52.3", - "windows_x86_64_gnullvm 0.52.3", - "windows_x86_64_msvc 0.52.3", + "windows_aarch64_gnullvm 0.52.4", + "windows_aarch64_msvc 0.52.4", + "windows_i686_gnu 0.52.4", + "windows_i686_msvc 0.52.4", + "windows_x86_64_gnu 0.52.4", + "windows_x86_64_gnullvm 0.52.4", + "windows_x86_64_msvc 0.52.4", ] [[package]] @@ -3929,9 +3938,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68e5dcfb9413f53afd9c8f86e56a7b4d86d9a2fa26090ea2dc9e40fba56c6ec6" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" [[package]] name = "windows_aarch64_msvc" @@ -3941,9 +3950,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8dab469ebbc45798319e69eebf92308e541ce46760b49b18c6b3fe5e8965b30f" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" [[package]] name = "windows_i686_gnu" @@ -3953,9 +3962,9 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a4e9b6a7cac734a8b4138a4e1044eac3404d8326b6c0f939276560687a033fb" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" [[package]] name = "windows_i686_msvc" @@ -3965,9 +3974,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.3" +version = "0.52.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "28b0ec9c422ca95ff34a78755cfa6ad4a51371da2a5ace67500cf7ca5f232c58" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" [[package]] name = "windows_x86_64_gnu" @@ -3977,9 +3986,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704131571ba93e89d7cd43482277d6632589b18ecf4468f591fbae0a8b101614" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" [[package]] name = "windows_x86_64_gnullvm" @@ -3989,9 +3998,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42079295511643151e98d61c38c0acc444e52dd42ab456f7ccfd5152e8ecf21c" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" [[package]] name = "windows_x86_64_msvc" @@ -4001,9 +4010,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0770833d60a970638e989b3fa9fd2bb1aaadcf88963d1659fd7d9990196ed2d6" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "winreg" @@ -4047,7 +4056,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 826abc28e174..41c6381df5d4 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -169,7 +169,7 @@ async fn main() -> Result<()> { // creating a new `PartitionEvaluator`) // // `ORDER BY time`: within each partition ('green' or 'red') the - // rows will be be ordered by the value in the `time` column + // rows will be ordered by the value in the `time` column // // `evaluate_inside_range` is invoked with a window defined by the // SQL. 
In this case:

diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs
index 8d13d1201881..cc1396f770e4 100644
--- a/datafusion-examples/examples/rewrite_expr.rs
+++ b/datafusion-examples/examples/rewrite_expr.rs
@@ -17,7 +17,7 @@
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::{plan_err, Result, ScalarValue};
 use datafusion_expr::{
     AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF,
@@ -95,14 +95,15 @@ impl MyAnalyzerRule {
             Ok(match plan {
                 LogicalPlan::Filter(filter) => {
                     let predicate = Self::analyze_expr(filter.predicate.clone())?;
-                    Transformed::Yes(LogicalPlan::Filter(Filter::try_new(
+                    Transformed::yes(LogicalPlan::Filter(Filter::try_new(
                         predicate,
                         filter.input,
                     )?))
                 }
-                _ => Transformed::No(plan),
+                _ => Transformed::no(plan),
             })
         })
+        .data()
     }

     fn analyze_expr(expr: Expr) -> Result<Expr> {
@@ -111,13 +112,14 @@ impl MyAnalyzerRule {
             Ok(match expr {
                 Expr::Literal(ScalarValue::Int64(i)) => {
                     // transform to UInt64
-                    Transformed::Yes(Expr::Literal(ScalarValue::UInt64(
+                    Transformed::yes(Expr::Literal(ScalarValue::UInt64(
                         i.map(|i| i as u64),
                     )))
                 }
-                _ => Transformed::No(expr),
+                _ => Transformed::no(expr),
             })
         })
+        .data()
     }
 }
@@ -175,14 +177,15 @@ fn my_rewrite(expr: Expr) -> Result<Expr> {
                 let low: Expr = *low;
                 let high: Expr = *high;
                 if negated {
-                    Transformed::Yes(expr.clone().lt(low).or(expr.gt(high)))
+                    Transformed::yes(expr.clone().lt(low).or(expr.gt(high)))
                 } else {
-                    Transformed::Yes(expr.clone().gt_eq(low).and(expr.lt_eq(high)))
+                    Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high)))
                 }
             }
-            _ => Transformed::No(expr),
+            _ => Transformed::no(expr),
         })
     })
+    .data()
 }

 #[derive(Default)]
diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs
index a6149d661e75..5555e873aeb7 100644
--- a/datafusion-examples/examples/simple_udwf.rs
+++ b/datafusion-examples/examples/simple_udwf.rs
@@ -72,7 +72,7 @@ async fn main() -> Result<()> {
     // creating a new `PartitionEvaluator`)
     //
     // `ORDER BY time`: within each partition ('green' or 'red') the
-    // rows will be be ordered by the value in the `time` column
+    // rows will be ordered by the value in the `time` column
     //
     // `evaluate_inside_range` is invoked with a window defined by the
     // SQL. In this case:
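The rewrite_expr.rs changes above show the shape of the migration: closures passed to `transform_up` now return `Transformed::yes`/`Transformed::no` values instead of the old `Transformed::Yes`/`No` enum variants, and the final tree is unwrapped with `TransformedResult::data()`. A minimal sketch of the new calling convention (the literal-doubling rule is hypothetical, not part of this PR):

```rust
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{lit, Expr};

// Rewrite every Int64 literal to its doubled value; leave other nodes as-is.
fn double_literals(expr: Expr) -> Result<Expr> {
    expr.transform_up(&|e| {
        Ok(match e {
            Expr::Literal(ScalarValue::Int64(Some(i))) => Transformed::yes(lit(2 * i)),
            other => Transformed::no(other),
        })
    })
    // `.data()` (from `TransformedResult`) discards the `transformed`/`tnr`
    // bookkeeping and yields the rewritten expression itself.
    .data()
}
```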
diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs
index ea2508f8c455..9583ecbdb733 100644
--- a/datafusion/common/src/parsers.rs
+++ b/datafusion/common/src/parsers.rs
@@ -17,6 +17,7 @@
 //! Interval parsing logic
 use sqlparser::parser::ParserError;
+use std::fmt::Display;
 use std::result;
 use std::str::FromStr;
@@ -54,16 +55,16 @@ impl FromStr for CompressionTypeVariant {
     }
 }

-impl ToString for CompressionTypeVariant {
-    fn to_string(&self) -> String {
-        match self {
+impl Display for CompressionTypeVariant {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let str = match self {
             Self::GZIP => "GZIP",
             Self::BZIP2 => "BZIP2",
             Self::XZ => "XZ",
             Self::ZSTD => "ZSTD",
             Self::UNCOMPRESSED => "",
-        }
-        .to_string()
+        };
+        write!(f, "{}", str)
     }
 }
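Replacing the manual `ToString` implementation with `Display` keeps `to_string()` available (via the standard library's blanket `impl<T: Display> ToString for T`) while also enabling direct use in formatting macros. A quick sketch of what the change permits:

```rust
use datafusion_common::parsers::CompressionTypeVariant;

fn demo() {
    let v = CompressionTypeVariant::GZIP;
    // Now works directly in formatting macros:
    println!("compression = {}", v);
    // Still works, through the std blanket impl for `Display` types:
    assert_eq!(v.to_string(), "GZIP");
}
```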
diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index 6ab4507f949c..f431e6264367 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -4451,7 +4451,7 @@ mod tests {
         // per distinct value.
         //
         // The alignment requirements differ across architectures and
-        // thus the size of the enum appears to as as well
+        // thus the size of the enum appears to as well
         assert_eq!(std::mem::size_of::<ScalarValue>(), 48);
     }
diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs
index c5c4ee824d61..2d653a27c47b 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -22,29 +22,74 @@ use std::sync::Arc;

 use crate::Result;

-/// If the function returns [`VisitRecursion::Continue`], the normal execution of the
-/// function continues. If it returns [`VisitRecursion::Skip`], the function returns
-/// with [`VisitRecursion::Continue`] to jump next recursion step, bypassing further
-/// exploration of the current step. In case of [`VisitRecursion::Stop`], the function
-/// return with [`VisitRecursion::Stop`] and recursion halts.
+/// This macro is used to control continuation behaviors during tree traversals
+/// based on the specified direction. Depending on `$DIRECTION` and the value of
+/// the given expression (`$EXPR`), which should be a variant of [`TreeNodeRecursion`],
+/// the macro results in the following behavior:
+///
+/// - If the expression returns [`TreeNodeRecursion::Continue`], normal execution
+///   continues.
+/// - If it returns [`TreeNodeRecursion::Stop`], recursion halts and propagates
+///   [`TreeNodeRecursion::Stop`].
+/// - If it returns [`TreeNodeRecursion::Jump`], the continuation behavior depends
+///   on the traversal direction:
+///   - For `UP` direction, the function returns with [`TreeNodeRecursion::Jump`],
+///     bypassing further bottom-up closures until the next top-down closure.
+///   - For `DOWN` direction, the function returns with [`TreeNodeRecursion::Continue`],
+///     skipping further exploration.
+///   - If no direction is specified, `Jump` is treated like `Continue`.
 #[macro_export]
-macro_rules! handle_tree_recursion {
-    ($EXPR:expr) => {
+macro_rules! handle_visit_recursion {
+    // Internal helper macro for handling the `Jump` case based on the direction:
+    (@handle_jump UP) => {
+        return Ok(TreeNodeRecursion::Jump)
+    };
+    (@handle_jump DOWN) => {
+        return Ok(TreeNodeRecursion::Continue)
+    };
+    (@handle_jump) => {
+        {} // Treat `Jump` like `Continue`, do nothing and continue execution.
+    };
+
+    // Main macro logic with variables to handle directionality.
+    ($EXPR:expr $(, $DIRECTION:ident)?) => {
         match $EXPR {
-            VisitRecursion::Continue => {}
-            // If the recursion should skip, do not apply to its children, let
-            // the recursion continue:
-            VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
-            // If the recursion should stop, do not apply to its children:
-            VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
+            TreeNodeRecursion::Continue => {}
+            TreeNodeRecursion::Jump => handle_visit_recursion!(@handle_jump $($DIRECTION)?),
+            TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
         }
     };
 }

-/// Defines a visitable and rewriteable a tree node. This trait is
-/// implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as
-/// well as expression trees ([`PhysicalExpr`], [`Expr`]) in
-/// DataFusion
+/// This macro is used to determine continuation during combined transforming
+/// traversals.
+///
+/// Depending on the [`TreeNodeRecursion`] the bottom-up closure returns,
+/// [`Transformed::try_transform_node_with()`] decides recursion continuation
+/// and if state propagation is necessary. Then, the same procedure recursively
+/// applies to the children of the node in question.
+macro_rules! handle_transform_recursion {
+    ($F_DOWN:expr, $F_SELF:expr, $F_UP:expr) => {{
+        let pre_visited = $F_DOWN?;
+        match pre_visited.tnr {
+            TreeNodeRecursion::Continue => pre_visited
+                .data
+                .map_children($F_SELF)?
+                .try_transform_node_with($F_UP, TreeNodeRecursion::Jump),
+            #[allow(clippy::redundant_closure_call)]
+            TreeNodeRecursion::Jump => $F_UP(pre_visited.data),
+            TreeNodeRecursion::Stop => return Ok(pre_visited),
+        }
+        .map(|mut post_visited| {
+            post_visited.transformed |= pre_visited.transformed;
+            post_visited
+        })
+    }};
+}
+
+/// Defines a visitable and rewriteable tree node. This trait is implemented
+/// for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as well as expression
+/// trees ([`PhysicalExpr`], [`Expr`]) in DataFusion.
 ///
 ///
 /// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html
@@ -52,283 +97,507 @@
 /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html
 /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html
 pub trait TreeNode: Sized {
-    /// Applies `op` to the node and its children. `op` is applied in a preoder way,
-    /// and it is controlled by [`VisitRecursion`], which means result of the `op`
-    /// on the self node can cause an early return.
+    /// Visit the tree node using the given [`TreeNodeVisitor`], performing a
+    /// depth-first walk of the node and its children.
+    ///
+    /// Consider the following tree structure:
+    /// ```text
+    /// ParentNode
+    ///    left: ChildNode1
+    ///    right: ChildNode2
+    /// ```
     ///
-    /// The `op` closure can be used to collect some info from the
-    /// tree node or do some checking for the tree node.
-    fn apply<F: FnMut(&Self) -> Result<VisitRecursion>>(
+    /// Here, the nodes would be visited using the following order:
+    /// ```text
+    /// TreeNodeVisitor::f_down(ParentNode)
+    /// TreeNodeVisitor::f_down(ChildNode1)
+    /// TreeNodeVisitor::f_up(ChildNode1)
+    /// TreeNodeVisitor::f_down(ChildNode2)
+    /// TreeNodeVisitor::f_up(ChildNode2)
+    /// TreeNodeVisitor::f_up(ParentNode)
+    /// ```
+    ///
+    /// See [`TreeNodeRecursion`] for more details on controlling the traversal.
+    ///
+    /// If [`TreeNodeVisitor::f_down()`] or [`TreeNodeVisitor::f_up()`] returns [`Err`],
+    /// the recursion stops immediately.
+    ///
+    /// If using the default [`TreeNodeVisitor::f_up`] that does nothing, consider using
+    /// [`Self::apply`].
+    fn visit<V: TreeNodeVisitor<Node = Self>>(
         &self,
-        op: &mut F,
-    ) -> Result<VisitRecursion> {
-        handle_tree_recursion!(op(self)?);
-        self.apply_children(&mut |node| node.apply(op))
+        visitor: &mut V,
+    ) -> Result<TreeNodeRecursion> {
+        match visitor.f_down(self)? {
+            TreeNodeRecursion::Continue => {
+                handle_visit_recursion!(
+                    self.apply_children(&mut |n| n.visit(visitor))?,
+                    UP
+                );
+                visitor.f_up(self)
+            }
+            TreeNodeRecursion::Jump => visitor.f_up(self),
+            TreeNodeRecursion::Stop => Ok(TreeNodeRecursion::Stop),
+        }
     }

-    /// Visit the tree node using the given [TreeNodeVisitor]
-    /// It performs a depth first walk of an node and its children.
+    /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
+    /// recursively transforming [`TreeNode`]s.
     ///
-    /// For an node tree such as
+    /// Consider the following tree structure:
     /// ```text
     /// ParentNode
     ///    left: ChildNode1
     ///    right: ChildNode2
     /// ```
     ///
-    /// The nodes are visited using the following order
+    /// Here, the nodes would be visited using the following order:
     /// ```text
-    /// pre_visit(ParentNode)
-    /// pre_visit(ChildNode1)
-    /// post_visit(ChildNode1)
-    /// pre_visit(ChildNode2)
-    /// post_visit(ChildNode2)
-    /// post_visit(ParentNode)
+    /// TreeNodeRewriter::f_down(ParentNode)
+    /// TreeNodeRewriter::f_down(ChildNode1)
+    /// TreeNodeRewriter::f_up(ChildNode1)
+    /// TreeNodeRewriter::f_down(ChildNode2)
+    /// TreeNodeRewriter::f_up(ChildNode2)
+    /// TreeNodeRewriter::f_up(ParentNode)
     /// ```
     ///
-    /// If an Err result is returned, recursion is stopped immediately
+    /// See [`TreeNodeRecursion`] for more details on controlling the traversal.
     ///
-    /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no
-    /// children of that node will be visited, nor is post_visit
-    /// called on that node. Details see [`TreeNodeVisitor`]
+    /// If [`TreeNodeVisitor::f_down()`] or [`TreeNodeVisitor::f_up()`] returns [`Err`],
+    /// the recursion stops immediately.
+    fn rewrite<R: TreeNodeRewriter<Node = Self>>(
+        self,
+        rewriter: &mut R,
+    ) -> Result<Transformed<Self>> {
+        handle_transform_recursion!(rewriter.f_down(self), |c| c.rewrite(rewriter), |n| {
+            rewriter.f_up(n)
+        })
+    }
+
+    /// Applies `f` to the node and its children. `f` is applied in a pre-order
+    /// way, and it is controlled by [`TreeNodeRecursion`], which means result
+    /// of the `f` on a node can cause an early return.
     ///
-    /// If using the default [`TreeNodeVisitor::post_visit`] that does
-    /// nothing, [`Self::apply`] should be preferred.
-    fn visit<V: TreeNodeVisitor<N = Self>>(
+    /// The `f` closure can be used to collect some information from tree nodes
+    /// or run a check on the tree.
+    fn apply<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
         &self,
-        visitor: &mut V,
-    ) -> Result<VisitRecursion> {
-        handle_tree_recursion!(visitor.pre_visit(self)?);
-        handle_tree_recursion!(self.apply_children(&mut |node| node.visit(visitor))?);
-        visitor.post_visit(self)
-    }
-
-    /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree.
-    /// When `op` does not apply to a given node, it is left unchanged.
-    /// The default tree traversal direction is transform_up(Postorder Traversal).
-    fn transform<F>(self, op: &F) -> Result<Self>
-    where
-        F: Fn(Self) -> Result<Transformed<Self>>,
-    {
-        self.transform_up(op)
-    }
-
-    /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its
-    /// children(Preorder Traversal).
-    /// When the `op` does not apply to a given node, it is left unchanged.
-    fn transform_down<F>(self, op: &F) -> Result<Self>
-    where
-        F: Fn(Self) -> Result<Transformed<Self>>,
-    {
-        let after_op = op(self)?.into();
-        after_op.map_children(|node| node.transform_down(op))
-    }
-
-    /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its
-    /// children(Preorder Traversal) using a mutable function, `F`.
-    /// When the `op` does not apply to a given node, it is left unchanged.
-    fn transform_down_mut<F>(self, op: &mut F) -> Result<Self>
-    where
-        F: FnMut(Self) -> Result<Transformed<Self>>,
-    {
-        let after_op = op(self)?.into();
-        after_op.map_children(|node| node.transform_down_mut(op))
-    }
-
-    /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its
-    /// children and then itself(Postorder Traversal).
-    /// When the `op` does not apply to a given node, it is left unchanged.
-    fn transform_up<F>(self, op: &F) -> Result<Self>
-    where
-        F: Fn(Self) -> Result<Transformed<Self>>,
-    {
-        let after_op_children = self.map_children(|node| node.transform_up(op))?;
-        let new_node = op(after_op_children)?.into();
-        Ok(new_node)
-    }
-
-    /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its
-    /// children and then itself(Postorder Traversal) using a mutable function, `F`.
-    /// When the `op` does not apply to a given node, it is left unchanged.
-    fn transform_up_mut<F>(self, op: &mut F) -> Result<Self>
-    where
-        F: FnMut(Self) -> Result<Transformed<Self>>,
-    {
-        let after_op_children = self.map_children(|node| node.transform_up_mut(op))?;
-        let new_node = op(after_op_children)?.into();
-        Ok(new_node)
-    }
-
-    /// Transform the tree node using the given [TreeNodeRewriter]
-    /// It performs a depth first walk of an node and its children.
+        f: &mut F,
+    ) -> Result<TreeNodeRecursion> {
+        handle_visit_recursion!(f(self)?, DOWN);
+        self.apply_children(&mut |n| n.apply(f))
+    }
+
+    /// Convenience utility for writing optimizer rules: Recursively apply the
+    /// given function `f` to the tree in a bottom-up (post-order) fashion. When
+    /// `f` does not apply to a given node, it is left unchanged.
+    fn transform<F: Fn(Self) -> Result<Transformed<Self>>>(
+        self,
+        f: &F,
+    ) -> Result<Transformed<Self>> {
+        self.transform_up(f)
+    }
+
+    /// Convenience utility for writing optimizer rules: Recursively apply the
+    /// given function `f` to a node and then to its children (pre-order traversal).
+    /// When `f` does not apply to a given node, it is left unchanged.
+    fn transform_down<F: Fn(Self) -> Result<Transformed<Self>>>(
+        self,
+        f: &F,
+    ) -> Result<Transformed<Self>> {
+        f(self)?.try_transform_node_with(
+            |n| n.map_children(|c| c.transform_down(f)),
+            TreeNodeRecursion::Continue,
+        )
+    }
+
+    /// Convenience utility for writing optimizer rules: Recursively apply the
+    /// given mutable function `f` to a node and then to its children (pre-order
+    /// traversal). When `f` does not apply to a given node, it is left unchanged.
+    fn transform_down_mut<F: FnMut(Self) -> Result<Transformed<Self>>>(
+        self,
+        f: &mut F,
+    ) -> Result<Transformed<Self>> {
+        f(self)?.try_transform_node_with(
+            |n| n.map_children(|c| c.transform_down_mut(f)),
+            TreeNodeRecursion::Continue,
+        )
+    }
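Returning `TreeNodeRecursion::Jump` from a `transform_down` closure prunes the subtree below the current node. A minimal sketch against `Expr` (which implements `TreeNode`); the "don't descend into subqueries" rule is illustrative, not from this PR:

```rust
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_common::Result;
use datafusion_expr::Expr;

// Top-down pass that refuses to descend into scalar subqueries.
fn rewrite_above_subqueries(expr: Expr) -> Result<Transformed<Expr>> {
    expr.transform_down(&|e| {
        if matches!(e, Expr::ScalarSubquery(_)) {
            // Keep the subquery node itself, but skip everything below it.
            Ok(Transformed::new(e, false, TreeNodeRecursion::Jump))
        } else {
            Ok(Transformed::no(e))
        }
    })
}
```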
+
+    /// Convenience utility for writing optimizer rules: Recursively apply the
+    /// given function `f` to all children of a node, and then to the node itself
+    /// (post-order traversal). When `f` does not apply to a given node, it is
+    /// left unchanged.
+    fn transform_up<F: Fn(Self) -> Result<Transformed<Self>>>(
+        self,
+        f: &F,
+    ) -> Result<Transformed<Self>> {
+        self.map_children(|c| c.transform_up(f))?
+            .try_transform_node_with(f, TreeNodeRecursion::Jump)
+    }
+
+    /// Convenience utility for writing optimizer rules: Recursively apply the
+    /// given mutable function `f` to all children of a node, and then to the
+    /// node itself (post-order traversal). When `f` does not apply to a given
+    /// node, it is left unchanged.
+    fn transform_up_mut<F: FnMut(Self) -> Result<Transformed<Self>>>(
+        self,
+        f: &mut F,
+    ) -> Result<Transformed<Self>> {
+        self.map_children(|c| c.transform_up_mut(f))?
+            .try_transform_node_with(f, TreeNodeRecursion::Jump)
+    }
+
+    /// Transforms the tree using `f_down` while traversing the tree top-down
+    /// (pre-order), and using `f_up` while traversing the tree bottom-up
+    /// (post-order).
+    ///
+    /// Use this method if you want to start the `f_up` process right where `f_down` jumps.
+    /// This can make the whole process faster by reducing the number of `f_up` steps.
+    /// If you don't need this, it's just like using `transform_down_mut` followed by
+    /// `transform_up_mut` on the same tree.
     ///
-    /// For an node tree such as
+    /// Consider the following tree structure:
     /// ```text
     /// ParentNode
     ///    left: ChildNode1
     ///    right: ChildNode2
     /// ```
     ///
-    /// The nodes are visited using the following order
+    /// The nodes are visited using the following order:
     /// ```text
-    /// pre_visit(ParentNode)
-    /// pre_visit(ChildNode1)
-    /// mutate(ChildNode1)
-    /// pre_visit(ChildNode2)
-    /// mutate(ChildNode2)
-    /// mutate(ParentNode)
+    /// f_down(ParentNode)
+    /// f_down(ChildNode1)
+    /// f_up(ChildNode1)
+    /// f_down(ChildNode2)
+    /// f_up(ChildNode2)
+    /// f_up(ParentNode)
     /// ```
     ///
-    /// If an Err result is returned, recursion is stopped immediately
+    /// See [`TreeNodeRecursion`] for more details on controlling the traversal.
     ///
-    /// If [`false`] is returned on a call to pre_visit, no
-    /// children of that node will be visited, nor is mutate
-    /// called on that node
+    /// If `f_down` or `f_up` returns [`Err`], the recursion stops immediately.
     ///
-    /// If using the default [`TreeNodeRewriter::pre_visit`] which
-    /// returns `true`, [`Self::transform`] should be preferred.
-    fn rewrite<R: TreeNodeRewriter<N = Self>>(self, rewriter: &mut R) -> Result<Self> {
-        let need_mutate = match rewriter.pre_visit(&self)? {
-            RewriteRecursion::Mutate => return rewriter.mutate(self),
-            RewriteRecursion::Stop => return Ok(self),
-            RewriteRecursion::Continue => true,
-            RewriteRecursion::Skip => false,
-        };
-
-        let after_op_children = self.map_children(|node| node.rewrite(rewriter))?;
-
-        // now rewrite this node itself
-        if need_mutate {
-            rewriter.mutate(after_op_children)
-        } else {
-            Ok(after_op_children)
-        }
+    /// Example:
+    /// ```text
+    ///                                   |   +---+
+    ///                                   |   | J |
+    ///                                   |   +---+
+    ///                                   |     |
+    ///                                   |   +---+
+    ///  TreeNodeRecursion::Continue      |   | I |
+    ///                                   |   +---+
+    ///                                   |     |
+    ///                                   |   +---+
+    ///                                  \|/  | F |
+    ///                                   '   +---+
+    ///                                      /     \ ___________________
+    ///  When `f_down` is             +---+                             \ ---+
+    ///  applied on node "E",         | E |                              | G |
+    ///  it returns with "Jump".      +---+                              +---+
+    ///                                 |                                  |
+    ///                               +---+                              +---+
+    ///                               | C |                              | H |
+    ///                               +---+                              +---+
+    ///                               /   \
+    ///                            +---+ +---+
+    ///                            | B | | D |
+    ///                            +---+ +---+
+    ///                                    |
+    ///                                  +---+
+    ///                                  | A |
+    ///                                  +---+
+    ///
+    /// Instead of starting from leaf nodes, `f_up` starts from the node "E".
+    ///                                     +---+
+    ///                                 |   | J |
+    ///                                 |   +---+
+    ///                                 |     |
+    ///                                 |   +---+
+    ///                                 |   | I |
+    ///                                 |   +---+
+    ///                                 |     |
+    ///                                /    +---+
+    ///                               /     | F |
+    ///                              /      +---+
+    ///                             /      /     \ ______________________
+    ///                            |     +---+          .                \ ---+
+    ///                            |     | E |         /|\  After `f_down` jumps  | G |
+    ///                            |     +---+          |   on node E, `f_up`     +---+
+    ///                             \------| ---/           if applied on node E.   |
+    ///                                  +---+                                    +---+
+    ///                                  | C |                                    | H |
+    ///                                  +---+                                    +---+
+    ///                                  /   \
+    ///                               +---+ +---+
+    ///                               | B | | D |
+    ///                               +---+ +---+
+    ///                                       |
+    ///                                     +---+
+    ///                                     | A |
+    ///                                     +---+
+    /// ```
+    fn transform_down_up<
+        FD: FnMut(Self) -> Result<Transformed<Self>>,
+        FU: FnMut(Self) -> Result<Transformed<Self>>,
+    >(
+        self,
+        f_down: &mut FD,
+        f_up: &mut FU,
+    ) -> Result<Transformed<Self>> {
+        handle_transform_recursion!(
+            f_down(self),
+            |c| c.transform_down_up(f_down, f_up),
+            f_up
+        )
     }
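A sketch of a combined traversal using `transform_down_up`, again over `Expr`; both rules are illustrative only, chosen to show a pre-order `Jump` handing control straight to the post-order closure:

```rust
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_common::Result;
use datafusion_expr::Expr;

// One combined pass: prune descent below aliases on the way down,
// and strip arithmetic negation wrappers on the way up.
fn demo(expr: Expr) -> Result<Transformed<Expr>> {
    expr.transform_down_up(
        &mut |e| {
            Ok(if matches!(e, Expr::Alias(_)) {
                // Children of aliases are skipped; `f_up` resumes from here.
                Transformed::new(e, false, TreeNodeRecursion::Jump)
            } else {
                Transformed::no(e)
            })
        },
        &mut |e| {
            Ok(match e {
                Expr::Negative(inner) => Transformed::yes(*inner),
                other => Transformed::no(other),
            })
        },
    )
}
```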
-    /// Apply the closure `F` to the node's children
-    fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
-    where
-        F: FnMut(&Self) -> Result<VisitRecursion>;
+    /// Apply the closure `F` to the node's children.
+    fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
+        &self,
+        f: &mut F,
+    ) -> Result<TreeNodeRecursion>;

-    /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder)
-    fn map_children<F>(self, transform: F) -> Result<Self>
-    where
-        F: FnMut(Self) -> Result<Self>;
+    /// Apply transform `F` to the node's children. Note that the transform `F`
+    /// might have a direction (pre-order or post-order).
+    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>>;
 }

-/// Implements the [visitor
-/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s.
-///
-/// [`TreeNodeVisitor`] allows keeping the algorithms
-/// separate from the code to traverse the structure of the `TreeNode`
-/// tree and makes it easier to add new types of tree node and
-/// algorithms.
-///
-/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::pre_visit`]
-/// and [`TreeNodeVisitor::post_visit`] are invoked recursively
-/// on an node tree.
+/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern)
+/// for recursively walking [`TreeNode`]s.
 ///
-/// If an [`Err`] result is returned, recursion is stopped
-/// immediately.
+/// A [`TreeNodeVisitor`] allows one to express algorithms separately from the
+/// code traversing the structure of the `TreeNode` tree, making it easier to
+/// add new types of tree nodes and algorithms.
 ///
-/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no
-/// children of that tree node are visited, nor is post_visit
-/// called on that tree node
-///
-/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no
-/// siblings of that tree node are visited, nor is post_visit
-/// called on its parent tree node
-///
-/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no
-/// children of that tree node are visited.
+/// When passed to [`TreeNode::visit`], [`TreeNodeVisitor::f_down`] and
+/// [`TreeNodeVisitor::f_up`] are invoked recursively on the tree.
+/// See [`TreeNodeRecursion`] for more details on controlling the traversal.
 pub trait TreeNodeVisitor: Sized {
     /// The node type which is visitable.
-    type N: TreeNode;
+    type Node: TreeNode;

     /// Invoked before any children of `node` are visited.
-    fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion>;
+    /// Default implementation simply continues the recursion.
+    fn f_down(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
+        Ok(TreeNodeRecursion::Continue)
+    }

-    /// Invoked after all children of `node` are visited. Default
-    /// implementation does nothing.
-    fn post_visit(&mut self, _node: &Self::N) -> Result<VisitRecursion> {
-        Ok(VisitRecursion::Continue)
+    /// Invoked after all children of `node` are visited.
+    /// Default implementation simply continues the recursion.
+    fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
+        Ok(TreeNodeRecursion::Continue)
     }
 }
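A minimal visitor under the new `f_down`/`f_up` names, shown over `Expr`; the statistics collected are arbitrary, chosen only to exercise both hooks:

```rust
use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
use datafusion_common::Result;
use datafusion_expr::Expr;

// Count the nodes of an expression tree and track its maximum depth.
#[derive(Default)]
struct CountingVisitor {
    nodes: usize,
    depth: usize,
    max_depth: usize,
}

impl TreeNodeVisitor for CountingVisitor {
    type Node = Expr;

    fn f_down(&mut self, _node: &Expr) -> Result<TreeNodeRecursion> {
        self.nodes += 1;
        self.depth += 1;
        self.max_depth = self.max_depth.max(self.depth);
        Ok(TreeNodeRecursion::Continue)
    }

    fn f_up(&mut self, _node: &Expr) -> Result<TreeNodeRecursion> {
        self.depth -= 1;
        Ok(TreeNodeRecursion::Continue)
    }
}

// Usage: let mut v = CountingVisitor::default();
//        expr.visit(&mut v)?;
```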
-/// Trait for potentially recursively transform an [`TreeNode`] node
-/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is
-/// invoked recursively on all nodes of a tree.
+/// Trait for potentially recursively transforming a tree of [`TreeNode`]s.
 pub trait TreeNodeRewriter: Sized {
     /// The node type which is rewritable.
-    type N: TreeNode;
+    type Node: TreeNode;

-    /// Invoked before (Preorder) any children of `node` are rewritten /
-    /// visited. Default implementation returns `Ok(Recursion::Continue)`
-    fn pre_visit(&mut self, _node: &Self::N) -> Result<RewriteRecursion> {
-        Ok(RewriteRecursion::Continue)
+    /// Invoked while traversing down the tree before any children are rewritten.
+    /// Default implementation returns the node as is and continues recursion.
+    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+        Ok(Transformed::no(node))
     }

-    /// Invoked after (Postorder) all children of `node` have been mutated and
-    /// returns a potentially modified node.
-    fn mutate(&mut self, node: Self::N) -> Result<Self::N>;
-}
-
-/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::rewrite`].
-#[derive(Debug)]
-pub enum RewriteRecursion {
-    /// Continue rewrite this node tree.
-    Continue,
-    /// Call 'op' immediately and return.
-    Mutate,
-    /// Do not rewrite the children of this node.
-    Stop,
-    /// Keep recursive but skip apply op on this node
-    Skip,
+    /// Invoked while traversing up the tree after all children have been rewritten.
+    /// Default implementation returns the node as is and continues recursion.
+    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+        Ok(Transformed::no(node))
+    }
 }

-/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit`].
-#[derive(Debug)]
-pub enum VisitRecursion {
-    /// Continue the visit to this node tree.
+/// Controls how [`TreeNode`] recursions should proceed.
+#[derive(Debug, PartialEq, Clone, Copy)]
+pub enum TreeNodeRecursion {
+    /// Continue recursion with the next node.
     Continue,
-    /// Keep recursive but skip applying op on the children
-    Skip,
-    /// Stop the visit to this node tree.
+    /// In top-down traversals, skip recursing into children but continue with
+    /// the next node, which actually means pruning of the subtree.
+    ///
+    /// In bottom-up traversals, bypass calling bottom-up closures till the next
+    /// leaf node.
+    ///
+    /// In combined traversals, if it is the `f_down` (pre-order) phase, execution
+    /// "jumps" to the next `f_up` (post-order) phase by shortcutting its children.
+    /// If it is the `f_up` (post-order) phase, execution "jumps" to the next `f_down`
+    /// (pre-order) phase by shortcutting its parent nodes until the first parent node
+    /// having unvisited children path.
+    Jump,
+    /// Stop recursion.
     Stop,
 }
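A stateful rewriter under the new API, replacing the old `pre_visit`/`mutate` pair; the boolean-flipping rule is illustrative, not part of this PR:

```rust
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{lit, Expr};

// Flips boolean literals on the way up and records how many it touched.
struct BoolLiteralFlipper {
    flipped: usize,
}

impl TreeNodeRewriter for BoolLiteralFlipper {
    type Node = Expr;

    fn f_up(&mut self, node: Expr) -> Result<Transformed<Expr>> {
        Ok(match node {
            Expr::Literal(ScalarValue::Boolean(Some(b))) => {
                self.flipped += 1;
                Transformed::yes(lit(!b))
            }
            other => Transformed::no(other),
        })
    }
}

// Usage: let mut rewriter = BoolLiteralFlipper { flipped: 0 };
//        let rewritten = expr.rewrite(&mut rewriter)?.data;
```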
-pub enum Transformed<T> {
-    /// The item was transformed / rewritten somehow
-    Yes(T),
-    /// The item was not transformed
-    No(T),
+/// This struct is used by tree transformation APIs such as
+/// - [`TreeNode::rewrite`],
+/// - [`TreeNode::transform_down`],
+/// - [`TreeNode::transform_down_mut`],
+/// - [`TreeNode::transform_up`],
+/// - [`TreeNode::transform_up_mut`],
+/// - [`TreeNode::transform_down_up`]
+///
+/// to control the transformation and return the transformed result.
+///
+/// Specifically, API users can provide transformation closures or [`TreeNodeRewriter`]
+/// implementations to control the transformation by returning:
+/// - The resulting (possibly transformed) node,
+/// - A flag indicating whether any change was made to the node, and
+/// - A flag specifying how to proceed with the recursion.
+///
+/// At the end of the transformation, the return value will contain:
+/// - The final (possibly transformed) tree,
+/// - A flag indicating whether any change was made to the tree, and
+/// - A flag specifying how the recursion ended.
+#[derive(PartialEq, Debug)]
+pub struct Transformed<T> {
+    pub data: T,
+    pub transformed: bool,
+    pub tnr: TreeNodeRecursion,
 }

 impl<T> Transformed<T> {
-    pub fn into(self) -> T {
-        match self {
-            Transformed::Yes(t) => t,
-            Transformed::No(t) => t,
+    /// Create a new `Transformed` object with the given information.
+    pub fn new(data: T, transformed: bool, tnr: TreeNodeRecursion) -> Self {
+        Self {
+            data,
+            transformed,
+            tnr,
         }
     }

-    pub fn into_pair(self) -> (T, bool) {
-        match self {
-            Transformed::Yes(t) => (t, true),
-            Transformed::No(t) => (t, false),
+    /// Wrapper for transformed data with [`TreeNodeRecursion::Continue`] statement.
+    pub fn yes(data: T) -> Self {
+        Self::new(data, true, TreeNodeRecursion::Continue)
+    }
+
+    /// Wrapper for unchanged data with [`TreeNodeRecursion::Continue`] statement.
+    pub fn no(data: T) -> Self {
+        Self::new(data, false, TreeNodeRecursion::Continue)
+    }
+
+    /// Applies the given `f` to the data of this [`Transformed`] object.
+    pub fn update_data<U, F: FnOnce(T) -> U>(self, f: F) -> Transformed<U> {
+        Transformed::new(f(self.data), self.transformed, self.tnr)
+    }
+
+    /// Maps the data of [`Transformed`] object to the result of the given `f`.
+    pub fn map_data<U, F: FnOnce(T) -> Result<U>>(self, f: F) -> Result<Transformed<U>> {
+        f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr))
+    }
+
+    /// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`]
+    /// is straightforward, but [`TreeNodeRecursion::Jump`] can behave differently
+    /// when we are traversing down or up on a tree. If [`TreeNodeRecursion`] of
+    /// the node is [`TreeNodeRecursion::Jump`], recursion stops with the given
+    /// `return_if_jump` value.
+    fn try_transform_node_with<F: FnOnce(T) -> Result<Transformed<T>>>(
+        mut self,
+        f: F,
+        return_if_jump: TreeNodeRecursion,
+    ) -> Result<Transformed<T>> {
+        match self.tnr {
+            TreeNodeRecursion::Continue => {
+                return f(self.data).map(|mut t| {
+                    t.transformed |= self.transformed;
+                    t
+                });
+            }
+            TreeNodeRecursion::Jump => {
+                self.tnr = return_if_jump;
+            }
+            TreeNodeRecursion::Stop => {}
+        }
+        Ok(self)
+    }
+
+    /// If [`TreeNodeRecursion`] of the node is [`TreeNodeRecursion::Continue`] or
+    /// [`TreeNodeRecursion::Jump`], transformation is applied to the node.
+    /// Otherwise, it remains as it is.
+    pub fn try_transform_node<F: FnOnce(T) -> Result<Transformed<T>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<T>> {
+        if self.tnr == TreeNodeRecursion::Stop {
+            Ok(self)
+        } else {
+            f(self.data).map(|mut t| {
+                t.transformed |= self.transformed;
+                t
+            })
         }
     }
 }
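The `update_data`/`map_data` helpers make it easy to change the payload of a `Transformed` without disturbing its bookkeeping flags. A small sketch:

```rust
use datafusion_common::tree_node::Transformed;
use datafusion_common::Result;

fn demo() -> Result<()> {
    // `update_data` maps the payload, keeping `transformed` and `tnr` intact.
    let t = Transformed::yes(21).update_data(|v| v * 2);
    assert_eq!((t.data, t.transformed), (42, true));

    // `map_data` does the same for fallible mappings.
    let t = Transformed::no("7").map_data(|s| Ok(s.parse::<i32>().unwrap()))?;
    assert_eq!((t.data, t.transformed), (7, false));
    Ok(())
}
```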
-/// Helper trait for implementing [`TreeNode`] that have children stored as Arc's
-///
-/// If some trait object, such as `dyn T`, implements this trait,
-/// its related `Arc<dyn T>` will automatically implement [`TreeNode`]
+/// Transformation helper to process tree nodes that are siblings.
+pub trait TransformedIterator: Iterator {
+    fn map_until_stop_and_collect<
+        F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
+    >(
+        self,
+        f: F,
+    ) -> Result<Transformed<Vec<Self::Item>>>;
+}
+
+impl<I: Iterator> TransformedIterator for I {
+    fn map_until_stop_and_collect<
+        F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
+    >(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Vec<Self::Item>>> {
+        let mut tnr = TreeNodeRecursion::Continue;
+        let mut transformed = false;
+        let data = self
+            .map(|item| match tnr {
+                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
+                    f(item).map(|result| {
+                        tnr = result.tnr;
+                        transformed |= result.transformed;
+                        result.data
+                    })
+                }
+                TreeNodeRecursion::Stop => Ok(item),
+            })
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Transformed::new(data, transformed, tnr))
+    }
+}
+
+/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
+pub trait TransformedResult<T> {
+    fn data(self) -> Result<T>;
+
+    fn transformed(self) -> Result<bool>;
+
+    fn tnr(self) -> Result<TreeNodeRecursion>;
+}
+
+impl<T> TransformedResult<T> for Result<Transformed<T>> {
+    fn data(self) -> Result<T> {
+        self.map(|t| t.data)
+    }
+
+    fn transformed(self) -> Result<bool> {
+        self.map(|t| t.transformed)
+    }
+
+    fn tnr(self) -> Result<TreeNodeRecursion> {
+        self.map(|t| t.tnr)
+    }
+}
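`map_until_stop_and_collect` applies a fallible transform across siblings, short-circuits the transform (but not the collection) on `Stop`, and folds the `transformed` flags together. A sketch of its contract:

```rust
use datafusion_common::tree_node::{Transformed, TransformedIterator, TreeNodeRecursion};
use datafusion_common::Result;

fn demo() -> Result<()> {
    let result = vec![1, 2, 3].into_iter().map_until_stop_and_collect(|i| {
        Ok(if i == 2 {
            Transformed::new(i * 10, true, TreeNodeRecursion::Stop)
        } else {
            Transformed::no(i)
        })
    })?;
    // `1` is visited, `2` triggers `Stop`, `3` is passed through untouched.
    assert_eq!(result.data, vec![1, 20, 3]);
    assert!(result.transformed);
    assert_eq!(result.tnr, TreeNodeRecursion::Stop);
    Ok(())
}
```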
+/// Helper trait for implementing [`TreeNode`] that have children stored as
+/// `Arc`s. If some trait object, such as `dyn T`, implements this trait,
+/// its related `Arc<dyn T>` will automatically implement [`TreeNode`].
 pub trait DynTreeNode {
-    /// Returns all children of the specified TreeNode
+    /// Returns all children of the specified `TreeNode`.
     fn arc_children(&self) -> Vec<Arc<Self>>;

-    /// construct a new self with the specified children
+    /// Constructs a new node with the specified children.
     fn with_new_arc_children(
         &self,
         arc_self: Arc<Self>,
@@ -336,32 +605,40 @@ pub trait DynTreeNode {
     ) -> Result<Arc<Self>>;
 }

-/// Blanket implementation for Arc for any tye that implements
-/// [`DynTreeNode`] (such as [`Arc<dyn PhysicalExpr>`])
+/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
+/// (such as [`Arc<dyn PhysicalExpr>`]).
 impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
-    /// Apply the closure `F` to the node's children
-    fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
-    where
-        F: FnMut(&Self) -> Result<VisitRecursion>,
-    {
+    fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
+        &self,
+        f: &mut F,
+    ) -> Result<TreeNodeRecursion> {
+        let mut tnr = TreeNodeRecursion::Continue;
         for child in self.arc_children() {
-            handle_tree_recursion!(op(&child)?)
+            tnr = f(&child)?;
+            handle_visit_recursion!(tnr)
         }
-        Ok(VisitRecursion::Continue)
+        Ok(tnr)
     }

-    fn map_children<F>(self, transform: F) -> Result<Self>
-    where
-        F: FnMut(Self) -> Result<Self>,
-    {
+    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
         let children = self.arc_children();
         if !children.is_empty() {
-            let new_children =
-                children.into_iter().map(transform).collect::<Result<Vec<_>>>()?;
-            let arc_self = Arc::clone(&self);
-            self.with_new_arc_children(arc_self, new_children)
+            let new_children = children.into_iter().map_until_stop_and_collect(f)?;
+            // Propagate up `new_children.transformed` and `new_children.tnr`
+            // along with the node containing transformed children.
+            if new_children.transformed {
+                let arc_self = Arc::clone(&self);
+                new_children.map_data(|new_children| {
+                    self.with_new_arc_children(arc_self, new_children)
+                })
+            } else {
+                Ok(Transformed::new(self, false, new_children.tnr))
+            }
         } else {
-            Ok(self)
+            Ok(Transformed::no(self))
        }
    }
}
@@ -381,28 +658,1016 @@ pub trait ConcreteTreeNode: Sized {
 }

 impl<T: ConcreteTreeNode> TreeNode for T {
-    /// Apply the closure `F` to the node's children
-    fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
-    where
-        F: FnMut(&Self) -> Result<VisitRecursion>,
-    {
+    fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
+        &self,
+        f: &mut F,
+    ) -> Result<TreeNodeRecursion> {
+        let mut tnr = TreeNodeRecursion::Continue;
         for child in self.children() {
-            handle_tree_recursion!(op(child)?)
+            tnr = f(child)?;
+            handle_visit_recursion!(tnr)
         }
-        Ok(VisitRecursion::Continue)
+        Ok(tnr)
     }

-    fn map_children<F>(self, transform: F) -> Result<Self>
-    where
-        F: FnMut(Self) -> Result<Self>,
-    {
+    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
         let (new_self, children) = self.take_children();
         if !children.is_empty() {
-            let new_children =
-                children.into_iter().map(transform).collect::<Result<Vec<_>>>()?;
-            new_self.with_new_children(new_children)
+            let new_children = children.into_iter().map_until_stop_and_collect(f)?;
+            // Propagate up `new_children.transformed` and `new_children.tnr` along with
+            // the node containing transformed children.
+            new_children.map_data(|new_children| new_self.with_new_children(new_children))
         } else {
-            Ok(new_self)
+            Ok(Transformed::no(new_self))
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::fmt::Display;
+
+    use crate::tree_node::{
+        Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
+        TreeNodeVisitor,
+    };
+    use crate::Result;
+
+    #[derive(PartialEq, Debug)]
+    struct TestTreeNode<T> {
+        children: Vec<TestTreeNode<T>>,
+        data: T,
+    }
+
+    impl<T> TestTreeNode<T> {
+        fn new(children: Vec<TestTreeNode<T>>, data: T) -> Self {
+            Self { children, data }
+        }
+    }
+
+    impl<T> TreeNode for TestTreeNode<T> {
+        fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
+        where
+            F: FnMut(&Self) -> Result<TreeNodeRecursion>,
+        {
+            let mut tnr = TreeNodeRecursion::Continue;
+            for child in &self.children {
+                tnr = f(child)?;
+                handle_visit_recursion!(tnr);
+            }
+            Ok(tnr)
+        }
+
+        fn map_children<F>(self, f: F) -> Result<Transformed<Self>>
+        where
+            F: FnMut(Self) -> Result<Transformed<Self>>,
+        {
+            Ok(self
+                .children
+                .into_iter()
+                .map_until_stop_and_collect(f)?
+                .update_data(|new_children| Self {
+                    children: new_children,
+                    ..self
+                }))
+        }
+    }
+
+    //      J
+    //      |
+    //      I
+    //      |
+    //      F
+    //    /   \
+    //   E     G
+    //   |     |
+    //   C     H
+    //  / \
+    // B   D
+    //     |
+    //     A
+    fn test_tree() -> TestTreeNode<String> {
+        let node_a = TestTreeNode::new(vec![], "a".to_string());
+        let node_b = TestTreeNode::new(vec![], "b".to_string());
+        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
+        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
+        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
+        let node_h = TestTreeNode::new(vec![], "h".to_string());
+        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
+        let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
+        let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
+        TestTreeNode::new(vec![node_i], "j".to_string())
+    }
+
+    // Continue on all nodes
+    // Expected visits in a combined traversal
+    fn all_visits() -> Vec<String> {
+        vec![
+            "f_down(j)",
+            "f_down(i)",
+            "f_down(f)",
+            "f_down(e)",
+            "f_down(c)",
+            "f_down(b)",
+            "f_up(b)",
+            "f_down(d)",
+            "f_down(a)",
+            "f_up(a)",
+            "f_up(d)",
+            "f_up(c)",
+            "f_up(e)",
+            "f_down(g)",
+            "f_down(h)",
+            "f_up(h)",
+            "f_up(g)",
+            "f_up(f)",
+            "f_up(i)",
+            "f_up(j)",
+        ]
+        .into_iter()
+        .map(|s| s.to_string())
+        .collect()
+    }
+
+    // Expected transformed tree after a combined traversal
+    fn transformed_tree() -> TestTreeNode<String> {
+        let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string());
+        let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
+        let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string());
+        let node_c =
+            TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string());
+        let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
+        let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string());
+        let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
+        let node_f =
+            TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
+        let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
+        TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
+    }
+
+    // Expected transformed tree after a top-down traversal
+    fn transformed_down_tree() -> TestTreeNode<String> {
+        let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
+        let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string());
+        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
+        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
+        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
+        let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string());
+        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
+        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
+        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
+        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
+    }
+
+    // Expected transformed tree after a bottom-up traversal
+    fn transformed_up_tree() -> TestTreeNode<String> {
+        let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string());
+        let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string());
+        let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string());
+        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string());
+        let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string());
+        let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string());
+        let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string());
+        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string());
+        let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string());
+        TestTreeNode::new(vec![node_i], "f_up(j)".to_string())
+    }
+
+    // f_down Jump on A node
+    fn f_down_jump_on_a_visits() -> Vec<String> {
+        vec![
+            "f_down(j)",
+            "f_down(i)",
+            "f_down(f)",
+            "f_down(e)",
+            "f_down(c)",
+            "f_down(b)",
+            "f_up(b)",
+            "f_down(d)",
+            "f_down(a)",
+            "f_up(a)",
+            "f_up(d)",
+            "f_up(c)",
+            "f_up(e)",
+            "f_down(g)",
+            "f_down(h)",
+            "f_up(h)",
+            "f_up(g)",
+            "f_up(f)",
+            "f_up(i)",
+            "f_up(j)",
+        ]
+        .into_iter()
+        .map(|s| s.to_string())
+        .collect()
+    }
+
+    fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode<String> {
+        let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
+        let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string());
+        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
+        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
+        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
+        let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string());
+        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
+        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
+        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
+        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
+    }
+
+    // f_down Jump on E node
+    fn f_down_jump_on_e_visits() -> Vec<String> {
+        vec![
+            "f_down(j)",
+            "f_down(i)",
+            "f_down(f)",
+            "f_down(e)",
+            "f_up(e)",
+            "f_down(g)",
+            "f_down(h)",
+            "f_up(h)",
+            "f_up(g)",
+            "f_up(f)",
+            "f_up(i)",
+            "f_up(j)",
+        ]
+        .into_iter()
+        .map(|s| s.to_string())
+        .collect()
+    }
+
+    fn f_down_jump_on_e_transformed_tree() -> TestTreeNode<String> {
+        let node_a = TestTreeNode::new(vec![], "a".to_string());
+        let node_b = TestTreeNode::new(vec![], "b".to_string());
+        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
+        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
+        let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
+        let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string());
+        let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
+        let node_f =
+            TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
+        let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
+        TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
+    }
+
+    fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode<String> {
+        let node_a = TestTreeNode::new(vec![], "a".to_string());
+        let node_b = TestTreeNode::new(vec![], "b".to_string());
+        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
+        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
+        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
+        let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string());
+        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
+        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
+        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
+        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
+    }
+
+    // f_up Jump on A node
+    fn f_up_jump_on_a_visits() -> Vec<String> {
+        vec![
+            "f_down(j)",
+            "f_down(i)",
+            "f_down(f)",
+            "f_down(e)",
+            "f_down(c)",
+            "f_down(b)",
+            "f_up(b)",
+            "f_down(d)",
+            "f_down(a)",
+            "f_up(a)",
"f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_up_jump_on_a_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); + let node_f = + TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string()); + TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string()) + } + + fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_up(j)".to_string()) + } + + // f_up Jump on E node + fn f_up_jump_on_e_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_up(d)", + "f_up(c)", + "f_up(e)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_up_jump_on_e_transformed_tree() -> TestTreeNode { + transformed_tree() + } + + fn f_up_jump_on_e_transformed_up_tree() -> TestTreeNode { + transformed_up_tree() + } + + // f_down Stop on A node + + fn f_down_stop_on_a_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_down_stop_on_a_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + 
let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + // f_down Stop on E node + fn f_down_stop_on_e_visits() -> Vec<String> { + vec!["f_down(j)", "f_down(i)", "f_down(f)", "f_down(e)"] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_down_stop_on_e_transformed_tree() -> TestTreeNode<String> { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode<String> { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + // f_up Stop on A node + fn f_up_stop_on_a_visits() -> Vec<String> { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_up_stop_on_a_transformed_tree() -> TestTreeNode<String> { + let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode<String> { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = 
TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); + TestTreeNode::new(vec![node_i], "j".to_string()) + } + + // f_up Stop on E node + fn f_up_stop_on_e_visits() -> Vec<String> { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_up(d)", + "f_up(c)", + "f_up(e)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_up_stop_on_e_transformed_tree() -> TestTreeNode<String> { + let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); + let node_c = + TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode<String> { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); + TestTreeNode::new(vec![node_i], "j".to_string()) + } + + fn down_visits(visits: Vec<String>) -> Vec<String> { + visits + .into_iter() + .filter(|v| v.starts_with("f_down")) + .collect() + } + + type TestVisitorF = Box<dyn FnMut(&TestTreeNode<String>) -> Result<TreeNodeRecursion>>; + + struct TestVisitor { + visits: Vec<String>, + f_down: TestVisitorF, + f_up: TestVisitorF, + } + + impl TestVisitor { + fn new(f_down: TestVisitorF, f_up: TestVisitorF) -> Self { + Self { + visits: vec![], + f_down, + f_up, + } + } + } + + impl TreeNodeVisitor for TestVisitor { + type Node = TestTreeNode<String>; + + fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> { + self.visits.push(format!("f_down({})", node.data)); + (*self.f_down)(node) + } + + fn f_up(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> { + self.visits.push(format!("f_up({})", node.data)); + (*self.f_up)(node) + } + } + + fn visit_continue(_: &TestTreeNode<String>) -> Result<TreeNodeRecursion> { + Ok(TreeNodeRecursion::Continue) + } + + fn visit_event_on<D: Into<String>>( + data: D, + event: TreeNodeRecursion, + ) -> impl FnMut(&TestTreeNode<String>) -> Result<TreeNodeRecursion> { + let d = data.into(); + move |node| { + Ok(if node.data == d { + event + } else { + TreeNodeRecursion::Continue + }) + } + } + + macro_rules! 
visit_test { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_VISITS:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + let mut visitor = TestVisitor::new(Box::new($F_DOWN), Box::new($F_UP)); + tree.visit(&mut visitor)?; + assert_eq!(visitor.visits, $EXPECTED_VISITS); + + Ok(()) + } + }; + } + + macro_rules! test_apply { + ($NAME:ident, $F:expr, $EXPECTED_VISITS:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + let mut visits = vec![]; + tree.apply(&mut |node| { + visits.push(format!("f_down({})", node.data)); + $F(node) + })?; + assert_eq!(visits, $EXPECTED_VISITS); + + Ok(()) + } + }; + } + + type TestRewriterF = + Box<dyn FnMut(TestTreeNode<String>) -> Result<Transformed<TestTreeNode<String>>>>; + + struct TestRewriter { + f_down: TestRewriterF, + f_up: TestRewriterF, + } + + impl TestRewriter { + fn new(f_down: TestRewriterF, f_up: TestRewriterF) -> Self { + Self { f_down, f_up } + } + } + + impl TreeNodeRewriter for TestRewriter { + type Node = TestTreeNode<String>; + + fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> { + (*self.f_down)(node) + } + + fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> { + (*self.f_up)(node) + } + } + + fn transform_yes<N: Display, T: PartialEq + Display + From<String>>( + transformation_name: N, + ) -> impl FnMut(TestTreeNode<T>) -> Result<Transformed<TestTreeNode<T>>> { + move |node| { + Ok(Transformed::yes(TestTreeNode::new( + node.children, + format!("{}({})", transformation_name, node.data).into(), + ))) + } + } + + fn transform_and_event_on< + N: Display, + T: PartialEq + Display + From<String>, + D: Into<T>, + >( + transformation_name: N, + data: D, + event: TreeNodeRecursion, + ) -> impl FnMut(TestTreeNode<T>) -> Result<Transformed<TestTreeNode<T>>> { + let d = data.into(); + move |node| { + let new_node = TestTreeNode::new( + node.children, + format!("{}({})", transformation_name, node.data).into(), + ); + Ok(if node.data == d { + Transformed::new(new_node, true, event) + } else { + Transformed::yes(new_node) + }) + } + } + + macro_rules! rewrite_test { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + let mut rewriter = TestRewriter::new(Box::new($F_DOWN), Box::new($F_UP)); + assert_eq!(tree.rewrite(&mut rewriter)?, $EXPECTED_TREE); + + Ok(()) + } + }; + } + + macro_rules! transform_test { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + assert_eq!( + tree.transform_down_up(&mut $F_DOWN, &mut $F_UP,)?, + $EXPECTED_TREE + ); + + Ok(()) + } + }; + } + + macro_rules! transform_down_test { + ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + assert_eq!(tree.transform_down_mut(&mut $F)?, $EXPECTED_TREE); + + Ok(()) + } + }; + } + + macro_rules! 
transform_up_test { + ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + assert_eq!(tree.transform_up_mut(&mut $F)?, $EXPECTED_TREE); + + Ok(()) + } + }; + } + + visit_test!(test_visit, visit_continue, visit_continue, all_visits()); + visit_test!( + test_visit_f_down_jump_on_a, + visit_event_on("a", TreeNodeRecursion::Jump), + visit_continue, + f_down_jump_on_a_visits() + ); + visit_test!( + test_visit_f_down_jump_on_e, + visit_event_on("e", TreeNodeRecursion::Jump), + visit_continue, + f_down_jump_on_e_visits() + ); + visit_test!( + test_visit_f_up_jump_on_a, + visit_continue, + visit_event_on("a", TreeNodeRecursion::Jump), + f_up_jump_on_a_visits() + ); + visit_test!( + test_visit_f_up_jump_on_e, + visit_continue, + visit_event_on("e", TreeNodeRecursion::Jump), + f_up_jump_on_e_visits() + ); + visit_test!( + test_visit_f_down_stop_on_a, + visit_event_on("a", TreeNodeRecursion::Stop), + visit_continue, + f_down_stop_on_a_visits() + ); + visit_test!( + test_visit_f_down_stop_on_e, + visit_event_on("e", TreeNodeRecursion::Stop), + visit_continue, + f_down_stop_on_e_visits() + ); + visit_test!( + test_visit_f_up_stop_on_a, + visit_continue, + visit_event_on("a", TreeNodeRecursion::Stop), + f_up_stop_on_a_visits() + ); + visit_test!( + test_visit_f_up_stop_on_e, + visit_continue, + visit_event_on("e", TreeNodeRecursion::Stop), + f_up_stop_on_e_visits() + ); + + test_apply!(test_apply, visit_continue, down_visits(all_visits())); + test_apply!( + test_apply_f_down_jump_on_a, + visit_event_on("a", TreeNodeRecursion::Jump), + down_visits(f_down_jump_on_a_visits()) + ); + test_apply!( + test_apply_f_down_jump_on_e, + visit_event_on("e", TreeNodeRecursion::Jump), + down_visits(f_down_jump_on_e_visits()) + ); + test_apply!( + test_apply_f_down_stop_on_a, + visit_event_on("a", TreeNodeRecursion::Stop), + down_visits(f_down_stop_on_a_visits()) + ); + test_apply!( + test_apply_f_down_stop_on_e, + visit_event_on("e", TreeNodeRecursion::Stop), + down_visits(f_down_stop_on_e_visits()) + ); + + rewrite_test!( + test_rewrite, + transform_yes("f_down"), + transform_yes("f_up"), + Transformed::yes(transformed_tree()) + ); + rewrite_test!( + test_rewrite_f_down_jump_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump), + transform_yes("f_up"), + Transformed::yes(transformed_tree()) + ); + rewrite_test!( + test_rewrite_f_down_jump_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump), + transform_yes("f_up"), + Transformed::yes(f_down_jump_on_e_transformed_tree()) + ); + rewrite_test!( + test_rewrite_f_up_jump_on_a, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_a_transformed_tree()) + ); + rewrite_test!( + test_rewrite_f_up_jump_on_e, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_e_transformed_tree()) + ); + rewrite_test!( + test_rewrite_f_down_stop_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + rewrite_test!( + test_rewrite_f_down_stop_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + rewrite_test!( + test_rewrite_f_up_stop_on_a, + 
transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + rewrite_test!( + test_rewrite_f_up_stop_on_e, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + + transform_test!( + test_transform, + transform_yes("f_down"), + transform_yes("f_up"), + Transformed::yes(transformed_tree()) + ); + transform_test!( + test_transform_f_down_jump_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump), + transform_yes("f_up"), + Transformed::yes(transformed_tree()) + ); + transform_test!( + test_transform_f_down_jump_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump), + transform_yes("f_up"), + Transformed::yes(f_down_jump_on_e_transformed_tree()) + ); + transform_test!( + test_transform_f_up_jump_on_a, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_a_transformed_tree()) + ); + transform_test!( + test_transform_f_up_jump_on_e, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_e_transformed_tree()) + ); + transform_test!( + test_transform_f_down_stop_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_test!( + test_transform_f_down_stop_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_test!( + test_transform_f_up_stop_on_a, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_test!( + test_transform_f_up_stop_on_e, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + + transform_down_test!( + test_transform_down, + transform_yes("f_down"), + Transformed::yes(transformed_down_tree()) + ); + transform_down_test!( + test_transform_down_f_down_jump_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump), + Transformed::yes(f_down_jump_on_a_transformed_down_tree()) + ); + transform_down_test!( + test_transform_down_f_down_jump_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump), + Transformed::yes(f_down_jump_on_e_transformed_down_tree()) + ); + transform_down_test!( + test_transform_down_f_down_stop_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop), + Transformed::new( + f_down_stop_on_a_transformed_down_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_down_test!( + test_transform_down_f_down_stop_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop), + Transformed::new( + f_down_stop_on_e_transformed_down_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + + transform_up_test!( + test_transform_up, + transform_yes("f_up"), + Transformed::yes(transformed_up_tree()) + ); + transform_up_test!( + 
test_transform_up_f_up_jump_on_a, + transform_and_event_on("f_up", "a", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_a_transformed_up_tree()) + ); + transform_up_test!( + test_transform_up_f_up_jump_on_e, + transform_and_event_on("f_up", "e", TreeNodeRecursion::Jump), + Transformed::yes(f_up_jump_on_e_transformed_up_tree()) + ); + transform_up_test!( + test_transform_up_f_up_stop_on_a, + transform_and_event_on("f_up", "a", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_a_transformed_up_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_up_test!( + test_transform_up_f_up_stop_on_e, + transform_and_event_on("f_up", "e", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_e_transformed_up_tree(), + true, + TreeNodeRecursion::Stop + ) + ); +} diff --git a/datafusion/common_runtime/Cargo.toml b/datafusion/common_runtime/Cargo.toml new file mode 100644 index 000000000000..7ed8b2cf2975 --- /dev/null +++ b/datafusion/common_runtime/Cargo.toml @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-common-runtime" +description = "Common Runtime functionality for DataFusion query engine" +keywords = ["arrow", "query", "sql"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "datafusion_common_runtime" +path = "src/lib.rs" + +[dependencies] +tokio = { workspace = true } diff --git a/datafusion/common_runtime/README.md b/datafusion/common_runtime/README.md new file mode 100644 index 000000000000..77100e52603c --- /dev/null +++ b/datafusion/common_runtime/README.md @@ -0,0 +1,26 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +# DataFusion Common Runtime + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that provides common utilities. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/common_runtime/src/common.rs b/datafusion/common_runtime/src/common.rs new file mode 100644 index 000000000000..2f7ddb972f42 --- /dev/null +++ b/datafusion/common_runtime/src/common.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::future::Future; + +use tokio::task::{JoinError, JoinSet}; + +/// Helper that provides a simple API to spawn a single task and join it. +/// Provides guarantees of aborting on `Drop` to keep it cancel-safe. +/// +/// Technically, it's just a wrapper of `JoinSet` (with size=1). +#[derive(Debug)] +pub struct SpawnedTask<R> { + inner: JoinSet<R>, +} + +impl<R: 'static> SpawnedTask<R> { + pub fn spawn<T>(task: T) -> Self + where + T: Future<Output = R>, + T: Send + 'static, + R: Send, + { + let mut inner = JoinSet::new(); + inner.spawn(task); + Self { inner } + } + + pub fn spawn_blocking<T>(task: T) -> Self + where + T: FnOnce() -> R, + T: Send + 'static, + R: Send, + { + let mut inner = JoinSet::new(); + inner.spawn_blocking(task); + Self { inner } + } + + /// Joins the task, returning the result of join (`Result<R, JoinError>`). + pub async fn join(mut self) -> Result<R, JoinError> { + self.inner + .join_next() + .await + .expect("`SpawnedTask` instance always contains exactly 1 task") + } + + /// Joins the task and unwinds the panic if it happens. + pub async fn join_unwind(self) -> R { + self.join().await.unwrap_or_else(|e| { + // `JoinError` can be caused either by panic or cancellation. We have to handle panics: + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + // Cancellation may be caused by two reasons: + // 1. Abort is called, but since we consumed `self`, it's not our case (`JoinHandle` not accessible outside). + // 2. The runtime is shutting down. + // So we consider this branch as unreachable. + unreachable!("SpawnedTask was cancelled unexpectedly"); + } + }) + } +} diff --git a/datafusion/common_runtime/src/lib.rs b/datafusion/common_runtime/src/lib.rs new file mode 100644 index 000000000000..e8624163f224 --- /dev/null +++ b/datafusion/common_runtime/src/lib.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
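For orientation, here is a minimal usage sketch of the new helper (our own example, not part of the patch; the literal values and the `#[tokio::main]` harness are invented):

```rust
use datafusion_common_runtime::SpawnedTask;

#[tokio::main]
async fn main() {
    // The future starts running as soon as `spawn` is called.
    let task = SpawnedTask::spawn(async { 40 + 2 });

    // `join` yields `Result<R, JoinError>`. `join_unwind` instead resumes any
    // panic and returns `R` directly, which is why call sites below can write
    // `demux_task.join_unwind().await?` when the task itself returns a Result.
    assert_eq!(task.join().await.unwrap(), 42);

    // Dropping a `SpawnedTask` drops its inner `JoinSet`, aborting the task;
    // this is the cancel-safety guarantee described in the doc comment above.
    drop(SpawnedTask::spawn(std::future::pending::<()>()));
}
```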
+ +pub mod common; + +pub use common::SpawnedTask; diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 662d95a9323c..0c378d9d83f5 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -89,6 +89,7 @@ bzip2 = { version = "0.4.3", optional = true } chrono = { workspace = true } dashmap = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } +datafusion-common-runtime = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index d7c31b9bd6b3..3bdf2af4552d 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1510,6 +1510,7 @@ mod tests { use arrow::array::{self, Int32Array}; use arrow::datatypes::DataType; use datafusion_common::{Constraint, Constraints}; + use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, @@ -2169,15 +2170,14 @@ mod tests { } #[tokio::test] - #[allow(clippy::disallowed_methods)] async fn sendable() { let df = test_table().await.unwrap(); // dataframes should be sendable between threads/tasks - let task = tokio::task::spawn(async move { + let task = SpawnedTask::spawn(async move { df.select_columns(&["c1"]) .expect("should be usable in a task") }); - task.await.expect("task completed successfully"); + task.join().await.expect("task completed successfully"); } #[tokio::test] diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index d5f07d11bee9..90417a978137 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -295,16 +295,7 @@ impl DataSink for ArrowFileSink { } } - match demux_task.join().await { - Ok(r) => r?, - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } - } - } + demux_task.join_unwind().await?; Ok(row_count as u64) } } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 739850115370..3824177cb363 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -40,9 +40,9 @@ use arrow::datatypes::SchemaRef; use arrow::datatypes::{Fields, Schema}; use bytes::{BufMut, BytesMut}; use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; -use datafusion_physical_plan::common::SpawnedTask; use futures::{StreamExt, TryStreamExt}; use hashbrown::HashMap; use object_store::path::Path; @@ -729,16 +729,7 @@ impl DataSink for ParquetSink { } } - match demux_task.join().await { - Ok(r) => r?, - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } - } - } + demux_task.join_unwind().await?; Ok(row_count as u64) } @@ -831,19 +822,8 @@ fn spawn_rg_join_and_finalize_task( let num_cols = column_writer_tasks.len(); let mut finalized_rg = Vec::with_capacity(num_cols); for task in column_writer_tasks.into_iter() { - match task.join().await { - Ok(r) => { - let 
w = r?; - finalized_rg.push(w.close()?); - } - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()) - } else { - unreachable!() - } - } - } + let writer = task.join_unwind().await?; + finalized_rg.push(writer.close()?); } Ok((finalized_rg, rg_rows)) @@ -952,31 +932,21 @@ async fn concatenate_parallel_row_groups( let mut row_count = 0; while let Some(task) = serialize_rx.recv().await { - match task.join().await { - Ok(result) => { - let mut rg_out = parquet_writer.next_row_group()?; - let (serialized_columns, cnt) = result?; - row_count += cnt; - for chunk in serialized_columns { - chunk.append_to_row_group(&mut rg_out)?; - let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); - if buff_to_flush.len() > BUFFER_FLUSH_BYTES { - object_store_writer - .write_all(buff_to_flush.as_slice()) - .await?; - buff_to_flush.clear(); - } - } - rg_out.close()?; - } - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } + let result = task.join_unwind().await; + let mut rg_out = parquet_writer.next_row_group()?; + let (serialized_columns, cnt) = result?; + row_count += cnt; + for chunk in serialized_columns { + chunk.append_to_row_group(&mut rg_out)?; + let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { + object_store_writer + .write_all(buff_to_flush.as_slice()) + .await?; + buff_to_flush.clear(); } } + rg_out.close()?; } let inner_writer = parquet_writer.into_inner()?; @@ -1020,18 +990,7 @@ async fn output_single_parquet_file_parallelized( ) .await?; - match launch_serialization_task.join().await { - Ok(Ok(_)) => (), - Ok(Err(e)) => return Err(e), - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()) - } else { - unreachable!() - } - } - } - + launch_serialization_task.join_unwind().await?; Ok(row_count) } diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index d70b4811da5b..396da96332f6 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -33,7 +33,7 @@ use arrow_array::{downcast_dictionary_array, RecordBatch, StringArray, StructArr use arrow_schema::{DataType, Schema}; use datafusion_common::cast::as_string_array; use datafusion_common::{exec_datafusion_err, DataFusionError}; - +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; use futures::StreamExt; @@ -41,7 +41,6 @@ use object_store::path::Path; use rand::distributions::DistString; -use datafusion_physical_plan::common::SpawnedTask; use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; type RecordBatchReceiver = Receiver<RecordBatch>; diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 05406d3751c9..b7f268959311 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -30,11 +30,11 @@ use crate::physical_plan::SendableRecordBatchStream; use arrow_array::RecordBatch; use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError}; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; use bytes::Bytes; -use datafusion_physical_plan::common::SpawnedTask; -use futures::try_join; +use futures::join; use tokio::io::{AsyncWrite, 
AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver}; use tokio::task::JoinSet; @@ -264,19 +264,12 @@ pub(crate) async fn stateless_multipart_put( // Signal to the write coordinator that no more files are coming drop(tx_file_bundle); - match try_join!(write_coordinator_task.join(), demux_task.join()) { - Ok((r1, r2)) => { - r1?; - r2?; - } - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } - } - } + let (r1, r2) = join!( + write_coordinator_task.join_unwind(), + demux_task.join_unwind() + ); + r1?; + r2?; let total_count = rx_row_cnt.await.map_err(|_| { internal_datafusion_err!("Did not receieve row count from write coordinater") diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 077356b716b0..eef25792d00a 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -19,29 +19,27 @@ use std::sync::Arc; -use arrow::compute::{and, cast, prep_null_mask_filter}; +use super::PartitionedFile; +use crate::datasource::listing::ListingTableUrl; +use crate::execution::context::SessionState; +use crate::{error::Result, scalar::ScalarValue}; + use arrow::{ - array::{ArrayRef, StringBuilder}, + array::{Array, ArrayRef, AsArray, StringBuilder}, + compute::{and, cast, prep_null_mask_filter}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use arrow_array::cast::AsArray; -use arrow_array::Array; use arrow_schema::Fields; -use futures::stream::FuturesUnordered; -use futures::{stream::BoxStream, StreamExt, TryStreamExt}; -use log::{debug, trace}; - -use crate::{error::Result, scalar::ScalarValue}; - -use super::PartitionedFile; -use crate::datasource::listing::ListingTableUrl; -use crate::execution::context::SessionState; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; + +use futures::stream::{BoxStream, FuturesUnordered}; +use futures::{StreamExt, TryStreamExt}; +use log::{debug, trace}; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; @@ -57,9 +55,9 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - Ok(VisitRecursion::Skip) + Ok(TreeNodeRecursion::Jump) } else { - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } Expr::Literal(_) @@ -88,27 +86,27 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => Ok(VisitRecursion::Continue), + | Expr::Case { .. 
} => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match &scalar_function.func_def { ScalarFunctionDefinition::BuiltIn(fun) => { match fun.volatility() { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } ScalarFunctionDefinition::UDF(fun) => { match fun.signature().volatility { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } @@ -129,7 +127,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::Unnest { .. } | Expr::Placeholder(_) => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } }) diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 82774a6e831c..96b3adf968b8 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -177,7 +177,7 @@ impl ExecutionPlan for ArrowExec { let opener = ArrowOpener { object_store, - projection: self.base_config.projection.clone(), + projection: self.base_config.file_column_projection_indices(), }; let stream = FileStream::new(&self.base_config, partition, opener, &self.metrics)?; diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 4a814c5b9b2c..370ca91a0b0e 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -157,6 +157,22 @@ impl FileScanConfig { }) } + /// Projects only file schema, ignoring partition columns + pub(crate) fn projected_file_schema(&self) -> SchemaRef { + let fields = self.file_column_projection_indices().map(|indices| { + indices + .iter() + .map(|col_idx| self.file_schema.field(*col_idx)) + .cloned() + .collect::<Vec<_>>() + }); + + fields.map_or_else( + || Arc::clone(&self.file_schema), + |f| Arc::new(Schema::new(f).with_metadata(self.file_schema.metadata.clone())), + ) + } + pub(crate) fn file_column_projection_indices(&self) -> Option<Vec<usize>> { self.projection.as_ref().map(|p| { p.iter() @@ -686,6 +702,66 @@ mod tests { crate::assert_batches_eq!(expected, &[projected_batch]); } + #[test] + fn test_projected_file_schema_with_partition_col() { + let schema = aggr_test_schema(); + let partition_cols = vec![ + ( + "part1".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ( + "part2".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ]; + + // Projected file schema for config with projection including partition column + let projection = config_for_projection( + schema.clone(), + Some(vec![0, 3, 5, schema.fields().len()]), + Statistics::new_unknown(&schema), + to_partition_cols(partition_cols.clone()), + ) + .projected_file_schema(); + + // Assert partition column filtered out in projected file schema + let expected_columns = vec!["c1", "c4", "c6"]; + let actual_columns = projection + 
.fields() + .iter() + .map(|f| f.name().clone()) + .collect::<Vec<_>>(); + assert_eq!(expected_columns, actual_columns); + } + + #[test] + fn test_projected_file_schema_without_projection() { + let schema = aggr_test_schema(); + let partition_cols = vec![ + ( + "part1".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ( + "part2".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ]; + + // Projected file schema for config without projection + let projection = config_for_projection( + schema.clone(), + None, + Statistics::new_unknown(&schema), + to_partition_cols(partition_cols.clone()), + ) + .projected_file_schema(); + + // Assert projected file schema is equal to file schema + assert_eq!(projection.fields(), schema.fields()); + } + // sets default for configs that play no role in projections fn config_for_projection( file_schema: SchemaRef, diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 62b96ea3aefb..ca466b5c6a92 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -174,14 +174,13 @@ impl ExecutionPlan for NdJsonExec { context: Arc<TaskContext>, ) -> Result<SendableRecordBatchStream> { let batch_size = context.session_config().batch_size(); - let (projected_schema, ..) = self.base_config.project(); let object_store = context .runtime_env() .object_store(&self.base_config.object_store_url)?; let opener = JsonOpener { batch_size, - projected_schema, + projected_schema: self.base_config.projected_file_schema(), file_compression_type: self.file_compression_type.to_owned(), object_store, }; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index 556ae35d48fb..064a8e1fff33 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -323,7 +323,7 @@ fn prune_pages_in_one_row_group( assert_eq!(row_vec.len(), values.len()); let mut sum_row = *row_vec.first().unwrap(); let mut selected = *values.first().unwrap(); - trace!("Pruned to to {:?} using {:?}", values, pruning_stats); + trace!("Pruned to {:?} using {:?}", values, pruning_stats); for (i, &f) in values.iter().enumerate().skip(1) { if f == selected { sum_row += *row_vec.get(i).unwrap(); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 3c40509a86d2..c0e37a7150d9 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -15,26 +15,28 @@ // specific language governing permissions and limitations // under the License. 
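The row_filter.rs changes below are the most self-contained example of the rewriter migration this PR performs everywhere. As a reading aid, here is a condensed sketch of the old and new trait shapes (our own condensation; `MyRewriter` is a hypothetical type, and the real impl follows in the hunks):

```rust
// Old API: two methods split control flow and mutation:
//   fn pre_visit(&mut self, node: &Self::N) -> Result<RewriteRecursion>;
//   fn mutate(&mut self, node: Self::N) -> Result<Self::N>;
// New API: each phase owns the node and reports mutation and recursion together.
impl TreeNodeRewriter for MyRewriter {
    type Node = Arc<dyn PhysicalExpr>;

    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        // `TreeNodeRecursion::Jump` takes over the old `RewriteRecursion::Stop`
        // role here: keep the node as-is and skip rewriting its subtree.
        Ok(Transformed::new(node, false, TreeNodeRecursion::Jump))
    }

    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        // `Transformed::no`/`Transformed::yes` record whether anything changed.
        Ok(Transformed::no(node))
    }
}
```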
+use std::collections::BTreeSet; +use std::sync::Arc; + +use super::ParquetFileMetrics; +use crate::physical_plan::metrics; + use arrow::array::BooleanArray; use arrow::datatypes::{DataType, Schema}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; -use std::collections::BTreeSet; - use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; + use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; use parquet::arrow::ProjectionMask; use parquet::file::metadata::ParquetMetaData; -use std::sync::Arc; - -use crate::physical_plan::metrics; - -use super::ParquetFileMetrics; /// This module contains utilities for enabling the pushdown of DataFusion filter predicates (which /// can be any DataFusion `Expr` that evaluates to a `BooleanArray`) to the parquet decoder level in `arrow-rs`. @@ -188,8 +190,7 @@ impl<'a> FilterCandidateBuilder<'a> { mut self, metadata: &ParquetMetaData, ) -> Result<Option<FilterCandidate>> { - let expr = self.expr.clone(); - let expr = expr.rewrite(&mut self)?; + let expr = self.expr.clone().rewrite(&mut self).data()?; if self.non_primitive_columns || self.projected_columns { Ok(None) @@ -209,29 +210,35 @@ impl<'a> FilterCandidateBuilder<'a> { } impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { - type N = Arc<dyn PhysicalExpr>; + type Node = Arc<dyn PhysicalExpr>; - fn pre_visit(&mut self, node: &Arc<dyn PhysicalExpr>) -> Result<RewriteRecursion> { + fn f_down( + &mut self, + node: Arc<dyn PhysicalExpr>, + ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> { if let Some(column) = node.as_any().downcast_ref::<Column>() { if let Ok(idx) = self.file_schema.index_of(column.name()) { self.required_column_indices.insert(idx); if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; - return Ok(RewriteRecursion::Stop); + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); } } else if self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. 
self.projected_columns = true; - return Ok(RewriteRecursion::Stop); + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); } } - Ok(RewriteRecursion::Continue) + Ok(Transformed::no(node)) } - fn mutate(&mut self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> { + fn f_up( + &mut self, + expr: Arc<dyn PhysicalExpr>, + ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> { if let Some(column) = expr.as_any().downcast_ref::<Column>() { if self.file_schema.field_with_name(column.name()).is_err() { // the column expr must be in the table schema @@ -239,7 +246,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { Ok(field) => { // return the null value corresponding to the data type let null_value = ScalarValue::try_from(field.data_type())?; - Ok(Arc::new(Literal::new(null_value))) + Ok(Transformed::yes(Arc::new(Literal::new(null_value)))) } Err(e) => { // If the column is not in the table schema, should throw the error @@ -249,7 +256,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { } } - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index 6dc59e4a5c65..079c1a891d14 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -31,9 +31,9 @@ use async_trait::async_trait; use futures::StreamExt; use datafusion_common::{plan_err, Constraints, DataFusionError, Result}; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{CreateExternalTable, Expr, TableType}; -use datafusion_physical_plan::common::SpawnedTask; use datafusion_physical_plan::insert::{DataSink, FileSinkExec}; use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; @@ -359,6 +359,6 @@ impl DataSink for StreamWrite { } } drop(sender); - write_task.join().await.unwrap() + write_task.join_unwind().await } } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 3aa4edfe3adc..2144cd3c7736 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -40,7 +40,7 @@ use arrow_schema::Schema; use datafusion_common::{ alias::AliasGenerator, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, + tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ @@ -2189,9 +2189,9 @@ impl<'a> BadPlanVisitor<'a> { } impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> { + fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> { match node { LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { plan_err!("DDL not supported: {}", ddl.name()) @@ -2205,7 +2205,7 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { LogicalPlan::Statement(stmt) if !self.options.allow_statements => { plan_err!("Statement not supported: {}", stmt.name()) } - _ => Ok(VisitRecursion::Continue), + _ => Ok(TreeNodeRecursion::Continue), } } } @@ -2222,6 +2222,7 @@ mod tests { use crate::test_util::{plan_and_collect, populate_csv_partitions}; use crate::variable::VarType; use async_trait::async_trait; + use datafusion_common_runtime::SpawnedTask; use datafusion_expr::Expr; use std::env; use std::path::PathBuf; @@ -2321,7 +2322,6 @@ mod tests { } #[tokio::test] - #[allow(clippy::disallowed_methods)] async fn 
send_context_to_threads() -> Result<()> { // ensure SessionContexts can be used in a multi-threaded // environment. Usecase is for concurrent planing. @@ -2332,7 +2332,7 @@ mod tests { let threads: Vec<_> = (0..2) .map(|_| ctx.clone()) .map(|ctx| { - tokio::spawn(async move { + SpawnedTask::spawn(async move { // Ensure we can create logical plan code on a separate thread. ctx.sql("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3") .await @@ -2341,7 +2341,7 @@ mod tests { .collect(); for handle in threads { - handle.await.unwrap().unwrap(); + handle.join().await.unwrap().unwrap(); } Ok(()) } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index d78d7a38a1c3..2b565ece7568 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -480,6 +480,11 @@ pub use parquet; /// re-export of [`datafusion_common`] crate pub mod common { pub use datafusion_common::*; + + /// re-export of [`datafusion_common_runtime`] crate + pub mod runtime { + pub use datafusion_common_runtime::*; + } } // Backwards compatibility @@ -524,7 +529,7 @@ pub mod functions { /// re-export of [`datafusion_functions_array`] crate, if "array_expressions" feature is enabled pub mod functions_array { #[cfg(feature = "array_expressions")] - pub use datafusion_functions::*; + pub use datafusion_functions_array::*; } #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 4fe11c14a758..df54222270ce 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -27,7 +27,7 @@ use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics use crate::scalar::ScalarValue; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -85,10 +85,14 @@ impl PhysicalOptimizerRule for AggregateStatistics { Arc::new(PlaceholderRowExec::new(plan.schema())), )?)) } else { - plan.map_children(|child| self.optimize(child, _config)) + plan.map_children(|child| { + self.optimize(child, _config).map(Transformed::yes) + }) + .data() } } else { - plan.map_children(|child| self.optimize(child, _config)) + plan.map_children(|child| self.optimize(child, _config).map(Transformed::yes)) + .data() } } diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 7b66ca529094..7c0082037da0 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -18,8 +18,10 @@ //! CoalesceBatches optimizer that groups batches together rows //! 
in bigger batches to avoid overhead with small batches -use crate::config::ConfigOptions; +use std::sync::Arc; + use crate::{ + config::ConfigOptions, error::Result, physical_optimizer::PhysicalOptimizerRule, physical_plan::{ @@ -27,8 +29,8 @@ use crate::{ repartition::RepartitionExec, Partitioning, }, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use std::sync::Arc; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; /// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that /// are produced by highly selective filters @@ -71,14 +73,15 @@ impl PhysicalOptimizerRule for CoalesceBatches { }) .unwrap_or(false); if wrap_in_coalesce { - Ok(Transformed::Yes(Arc::new(CoalesceBatchesExec::new( + Ok(Transformed::yes(Arc::new(CoalesceBatchesExec::new( plan, target_batch_size, )))) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } }) + .data() } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 61eb2381c63b..c45e14100e82 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -26,7 +26,7 @@ use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGro use crate::physical_plan::ExecutionPlan; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; @@ -109,11 +109,12 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { }); Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(plan) + Transformed::no(plan) }) }) + .data() } fn name(&self) -> &str { @@ -185,11 +186,12 @@ fn discard_column_index(group_expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> { _ => None, }; Ok(if let Some(normalized_form) = normalized_form { - Transformed::Yes(normalized_form) + Transformed::yes(normalized_form) } else { - Transformed::No(expr) + Transformed::no(expr) }) }) + .data() .unwrap_or(group_expr) } diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index eb221a28e2cf..822cd0541ae2 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -45,7 +45,7 @@ use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning}; use arrow::compute::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; @@ -197,22 +197,25 @@ impl PhysicalOptimizerRule for EnforceDistribution { let adjusted = if top_down_join_key_reordering { // Run a top-down process to adjust input key ordering recursively let plan_requirements = PlanWithKeyRequirements::new_default(plan); - let adjusted = - plan_requirements.transform_down(&adjust_input_keys_ordering)?; + let adjusted = plan_requirements + 
.transform_down(&adjust_input_keys_ordering) + .data()?; adjusted.plan } else { // Run a bottom-up process plan.transform_up(&|plan| { - Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) - })? + Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?)) + }) + .data()? }; let distribution_context = DistributionContext::new_default(adjusted); // Distribution enforcement needs to be applied bottom-up. - let distribution_context = - distribution_context.transform_up(&|distribution_context| { + let distribution_context = distribution_context + .transform_up(&|distribution_context| { ensure_distribution(distribution_context, config) - })?; + }) + .data()?; Ok(distribution_context.plan) } @@ -306,7 +309,7 @@ fn adjust_input_keys_ordering( vec![], &join_constructor, ) - .map(Transformed::Yes); + .map(Transformed::yes); } PartitionMode::CollectLeft => { // Push down requirements to the right side @@ -370,18 +373,18 @@ fn adjust_input_keys_ordering( sort_options.clone(), &join_constructor, ) - .map(Transformed::Yes); + .map(Transformed::yes); } else if let Some(aggregate_exec) = plan.as_any().downcast_ref::<AggregateExec>() { if !requirements.data.is_empty() { if aggregate_exec.mode() == &AggregateMode::FinalPartitioned { return reorder_aggregate_keys(requirements, aggregate_exec) - .map(Transformed::Yes); + .map(Transformed::yes); } else { requirements.data.clear(); } } else { // Keep everything unchanged - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } } else if let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() { let expr = proj.expr(); @@ -409,7 +412,7 @@ fn adjust_input_keys_ordering( child.data = requirements.data.clone(); } } - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } fn reorder_partitioned_join_keys( @@ -1057,7 +1060,7 @@ fn ensure_distribution( let dist_context = update_children(dist_context)?; if dist_context.plan.children().is_empty() { - return Ok(Transformed::No(dist_context)); + return Ok(Transformed::no(dist_context)); } let target_partitions = config.execution.target_partitions; @@ -1237,7 +1240,7 @@ fn ensure_distribution( plan.with_new_children(children_plans)? }; - Ok(Transformed::Yes(DistributionContext::new( + Ok(Transformed::yes(DistributionContext::new( plan, data, children, ))) } @@ -1323,6 +1326,7 @@ pub(crate) mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::tree_node::TransformedResult; use datafusion_common::ScalarValue; use datafusion_expr::logical_plan::JoinType; use datafusion_expr::Operator; @@ -1716,7 +1720,7 @@ pub(crate) mod tests { config.optimizer.repartition_file_scans = false; config.optimizer.repartition_file_min_size = 1024; config.optimizer.prefer_existing_sort = prefer_existing_sort; - ensure_distribution(distribution_context, &config).map(|item| item.into().plan) + ensure_distribution(distribution_context, &config).map(|item| item.data.plan) } /// Test whether plan matches with expected plan macro_rules! assert_optimized { ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr) => { assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, false, 1024, false); }; @@ -1785,14 +1789,16 @@ pub(crate) mod tests { PlanWithKeyRequirements::new_default($PLAN.clone()); let adjusted = plan_requirements .transform_down(&adjust_input_keys_ordering) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. adjusted.plan } else { // Run reorder_join_keys_to_inputs rule $PLAN.clone().transform_up(&|plan| { - Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) - })? + Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?)) + }) + .data()? 
}; // Then run ensure_distribution rule @@ -1800,6 +1806,7 @@ pub(crate) mod tests { .transform_up(&|distribution_context| { ensure_distribution(distribution_context, &config) }) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. } diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 5fac1397e023..79dd5758cc2f 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -60,7 +60,7 @@ use crate::physical_plan::windows::{ use crate::physical_plan::{Distribution, ExecutionPlan, InputOrderMode}; use datafusion_common::plan_err; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::partial_sort::PartialSortExec; @@ -160,37 +160,40 @@ impl PhysicalOptimizerRule for EnforceSorting { let plan_requirements = PlanWithCorrespondingSort::new_default(plan); // Execute a bottom-up traversal to enforce sorting requirements, // remove unnecessary sorts, and optimize sort-sensitive operators: - let adjusted = plan_requirements.transform_up(&ensure_sorting)?; + let adjusted = plan_requirements.transform_up(&ensure_sorting)?.data; let new_plan = if config.optimizer.repartition_sorts { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); - let parallel = - plan_with_coalesce_partitions.transform_up(¶llelize_sorts)?; + let parallel = plan_with_coalesce_partitions + .transform_up(¶llelize_sorts) + .data()?; parallel.plan } else { adjusted.plan }; let plan_with_pipeline_fixer = OrderPreservationContext::new_default(new_plan); - let updated_plan = - plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { + let updated_plan = plan_with_pipeline_fixer + .transform_up(&|plan_with_pipeline_fixer| { replace_with_order_preserving_variants( plan_with_pipeline_fixer, false, true, config, ) - })?; + }) + .data()?; // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); - let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; + let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?.data; adjusted .plan - .transform_up(&|plan| Ok(Transformed::Yes(replace_with_partial_sort(plan)?))) + .transform_up(&|plan| Ok(Transformed::yes(replace_with_partial_sort(plan)?))) + .data() } fn name(&self) -> &str { @@ -262,7 +265,7 @@ fn parallelize_sorts( // `SortPreservingMergeExec` or a `CoalescePartitionsExec`, and they // all have a single child. Therefore, if the first child has no // connection, we can return immediately. 
- Ok(Transformed::No(requirements)) + Ok(Transformed::no(requirements)) } else if (is_sort(&requirements.plan) || is_sort_preserving_merge(&requirements.plan)) && requirements.plan.output_partitioning().partition_count() <= 1 @@ -291,7 +294,7 @@ fn parallelize_sorts( } let spm = SortPreservingMergeExec::new(sort_exprs, requirements.plan.clone()); - Ok(Transformed::Yes( + Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( Arc::new(spm.with_fetch(fetch)), false, @@ -305,7 +308,7 @@ fn parallelize_sorts( // For the removal of self node which is also a `CoalescePartitionsExec`. requirements = requirements.children.swap_remove(0); - Ok(Transformed::Yes( + Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( Arc::new(CoalescePartitionsExec::new(requirements.plan.clone())), false, @@ -313,7 +316,7 @@ fn parallelize_sorts( ), )) } else { - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } } @@ -326,10 +329,12 @@ fn ensure_sorting( // Perform naive analysis at the beginning -- remove already-satisfied sorts: if requirements.children.is_empty() { - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } let maybe_requirements = analyze_immediate_sort_removal(requirements); - let Transformed::No(mut requirements) = maybe_requirements else { + requirements = if !maybe_requirements.transformed { + maybe_requirements.data + } else { return Ok(maybe_requirements); }; @@ -368,17 +373,17 @@ fn ensure_sorting( // calculate the result in reverse: let child_node = &requirements.children[0]; if is_window(plan) && child_node.data { - return adjust_window_sort_removal(requirements).map(Transformed::Yes); + return adjust_window_sort_removal(requirements).map(Transformed::yes); } else if is_sort_preserving_merge(plan) && child_node.plan.output_partitioning().partition_count() <= 1 { // This `SortPreservingMergeExec` is unnecessary, input already has a // single partition. let child_node = requirements.children.swap_remove(0); - return Ok(Transformed::Yes(child_node)); + return Ok(Transformed::yes(child_node)); } - update_sort_ctx_children(requirements, false).map(Transformed::Yes) + update_sort_ctx_children(requirements, false).map(Transformed::yes) } /// Analyzes a given [`SortExec`] (`plan`) to determine whether its input @@ -408,10 +413,10 @@ fn analyze_immediate_sort_removal( child.data = false; } node.data = false; - return Transformed::Yes(node); + return Transformed::yes(node); } } - Transformed::No(node) + Transformed::no(node) } /// Adjusts a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine @@ -683,6 +688,7 @@ mod tests { let plan_requirements = PlanWithCorrespondingSort::new_default($PLAN.clone()); let adjusted = plan_requirements .transform_up(&ensure_sorting) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. @@ -691,6 +697,7 @@ mod tests { PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); let parallel = plan_with_coalesce_partitions .transform_up(¶llelize_sorts) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. parallel.plan @@ -708,6 +715,7 @@ mod tests { state.config_options(), ) }) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. @@ -715,6 +723,7 @@ mod tests { assign_initial_requirements(&mut sort_pushdown); sort_pushdown .transform_down(&pushdown_sorts) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. 
} diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index ee60c65ead0b..20104285e44a 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -37,7 +37,7 @@ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use arrow_schema::Schema; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, JoinSide, JoinType}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::sort_properties::SortProperties; @@ -57,7 +57,7 @@ impl JoinSelection { } // TODO: We need some performance test for Right Semi/Right Join swap to Left Semi/Left Join in case that the right side is smaller but not much smaller. -// TODO: In PrestoSQL, the optimizer flips join sides only if one side is much smaller than the other by more than SIZE_DIFFERENCE_THRESHOLD times, by default is is 8 times. +// TODO: In PrestoSQL, the optimizer flips join sides only if one side is much smaller than the other by more than SIZE_DIFFERENCE_THRESHOLD times, by default is 8 times. /// Checks statistics for join swap. fn should_swap_join_order( left: &dyn ExecutionPlan, @@ -236,7 +236,9 @@ impl PhysicalOptimizerRule for JoinSelection { Box::new(hash_join_convert_symmetric_subrule), Box::new(hash_join_swap_subrule), ]; - let new_plan = plan.transform_up(&|p| apply_subrules(p, &subrules, config))?; + let new_plan = plan + .transform_up(&|p| apply_subrules(p, &subrules, config)) + .data()?; // Next, we apply another subrule that tries to optimize joins using any // statistics their inputs might have. // - For a hash join with partition mode [`PartitionMode::Auto`], we will @@ -251,13 +253,15 @@ impl PhysicalOptimizerRule for JoinSelection { let config = &config.optimizer; let collect_threshold_byte_size = config.hash_join_single_partition_threshold; let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows; - new_plan.transform_up(&|plan| { - statistical_join_selection_subrule( - plan, - collect_threshold_byte_size, - collect_threshold_num_rows, - ) - }) + new_plan + .transform_up(&|plan| { + statistical_join_selection_subrule( + plan, + collect_threshold_byte_size, + collect_threshold_num_rows, + ) + }) + .data() } fn name(&self) -> &str { @@ -433,9 +437,9 @@ fn statistical_join_selection_subrule( }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(plan) + Transformed::no(plan) }) } @@ -647,7 +651,7 @@ fn apply_subrules( for subrule in subrules { input = subrule(input, config_options)?; } - Ok(Transformed::Yes(input)) + Ok(Transformed::yes(input)) } #[cfg(test)] @@ -808,8 +812,9 @@ mod tests_statistical { Box::new(hash_join_convert_symmetric_subrule), Box::new(hash_join_swap_subrule), ]; - let new_plan = - plan.transform_up(&|p| apply_subrules(p, &subrules, &ConfigOptions::new()))?; + let new_plan = plan + .transform_up(&|p| apply_subrules(p, &subrules, &ConfigOptions::new())) + .data()?; // TODO: End state payloads will be checked here. 
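The bulk of this patch is one mechanical migration applied file by file: the `Transformed::Yes(..)` / `Transformed::No(..)` enum variants become the `Transformed::yes` / `Transformed::no` constructors, and the `Result<Transformed<T>>` returned by the `transform*` methods is unwrapped with the new `TransformedResult::data()`. A minimal sketch of the new shape, using `Expr` and a made-up column rename (the rewrite logic is illustrative, not from this patch):

```rust
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{Column, Result};
use datafusion_expr::{col, Expr};

/// Rewrite every reference to column "a" into a reference to "b".
fn rename_a_to_b(expr: Expr) -> Result<Expr> {
    expr.transform(&|e| {
        Ok(if matches!(&e, Expr::Column(c) if c.name == "a") {
            // `yes` marks the node as rewritten ...
            Transformed::yes(Expr::Column(Column::from_name("b")))
        } else {
            // ... `no` passes it through unchanged.
            Transformed::no(e)
        })
    })
    // `data()` collapses Result<Transformed<Expr>> back into Result<Expr>.
    .data()
}

fn main() -> Result<()> {
    assert_eq!(rename_a_to_b(col("a").eq(col("c")))?, col("b").eq(col("c")));
    Ok(())
}
```

The payoff of the struct-based `Transformed` is visible in hunks like `ensure_sorting` above: the result now carries a `transformed` flag and a `data` payload that can be inspected directly instead of pattern-matched.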
let config = ConfigOptions::new().optimizer; let collect_left_threshold = config.hash_join_single_partition_threshold; diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 7be9acec5092..9509d4e4c828 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -26,7 +26,7 @@ use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use itertools::Itertools; @@ -109,7 +109,7 @@ impl LimitedDistinctAggregation { let mut rewrite_applicable = true; let mut closure = |plan: Arc| { if !rewrite_applicable { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } if let Some(aggr) = plan.as_any().downcast_ref::() { if found_match_aggr { @@ -120,7 +120,7 @@ impl LimitedDistinctAggregation { // a partial and final aggregation with different groupings disqualifies // rewriting the child aggregation rewrite_applicable = false; - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } } } @@ -131,14 +131,14 @@ impl LimitedDistinctAggregation { Some(new_aggr) => { match_aggr = plan; found_match_aggr = true; - return Ok(Transformed::Yes(new_aggr)); + return Ok(Transformed::yes(new_aggr)); } } } rewrite_applicable = false; - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; - let child = child.clone().transform_down_mut(&mut closure).ok()?; + let child = child.clone().transform_down_mut(&mut closure).data().ok()?; if is_global_limit { return Some(Arc::new(GlobalLimitExec::new( child, @@ -162,22 +162,22 @@ impl PhysicalOptimizerRule for LimitedDistinctAggregation { plan: Arc, config: &ConfigOptions, ) -> Result> { - let plan = if config.optimizer.enable_distinct_aggregation_soft_limit { + if config.optimizer.enable_distinct_aggregation_soft_limit { plan.transform_down(&|plan| { Ok( if let Some(plan) = LimitedDistinctAggregation::transform_limit(plan.clone()) { - Transformed::Yes(plan) + Transformed::yes(plan) } else { - Transformed::No(plan) + Transformed::no(plan) }, ) - })? 
+ }) + .data() } else { - plan - }; - Ok(plan) + Ok(plan) + } } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index 7fea375725a5..bd71b3e8ed80 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -29,7 +29,7 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, Statistics}; use datafusion_physical_expr::{Distribution, LexRequirement, PhysicalSortRequirement}; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -193,15 +193,17 @@ impl PhysicalOptimizerRule for OutputRequirements { ) -> Result> { match self.mode { RuleMode::Add => require_top_ordering(plan), - RuleMode::Remove => plan.transform_up(&|plan| { - if let Some(sort_req) = - plan.as_any().downcast_ref::() - { - Ok(Transformed::Yes(sort_req.input())) - } else { - Ok(Transformed::No(plan)) - } - }), + RuleMode::Remove => plan + .transform_up(&|plan| { + if let Some(sort_req) = + plan.as_any().downcast_ref::() + { + Ok(Transformed::yes(sort_req.input())) + } else { + Ok(Transformed::no(plan)) + } + }) + .data(), } } diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index e783f75378b1..1dc8bc5042bf 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -28,7 +28,7 @@ use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use datafusion_common::config::OptimizerOptions; use datafusion_common::plan_err; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; @@ -51,6 +51,7 @@ impl PhysicalOptimizerRule for PipelineChecker { config: &ConfigOptions, ) -> Result> { plan.transform_up(&|p| check_finiteness_requirements(p, &config.optimizer)) + .data() } fn name(&self) -> &str { @@ -82,7 +83,7 @@ pub fn check_finiteness_requirements( input ) } else { - Ok(Transformed::No(input)) + Ok(Transformed::no(input)) } } diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 4ed265d59526..17d30a2b4ec1 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -43,7 +43,9 @@ use crate::physical_plan::{Distribution, ExecutionPlan, ExecutionPlanProperties} use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{DataFusionError, JoinSide}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ @@ -73,7 +75,7 @@ impl PhysicalOptimizerRule for ProjectionPushdown { plan: Arc, _config: &ConfigOptions, 
) -> Result> { - plan.transform_down(&remove_unnecessary_projections) + plan.transform_down(&remove_unnecessary_projections).data() } fn name(&self) -> &str { @@ -98,7 +100,7 @@ pub fn remove_unnecessary_projections( // If the projection does not cause any change on the input, we can // safely remove it: if is_projection_removable(projection) { - return Ok(Transformed::Yes(projection.input().clone())); + return Ok(Transformed::yes(projection.input().clone())); } // If it does, check if we can push it under its child(ren): let input = projection.input().as_any(); @@ -111,8 +113,10 @@ pub fn remove_unnecessary_projections( return if let Some(new_plan) = maybe_unified { // To unify 3 or more sequential projections: remove_unnecessary_projections(new_plan) + .data() + .map(Transformed::yes) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; } else if let Some(output_req) = input.downcast_ref::() { try_swapping_with_output_req(projection, output_req)? @@ -148,10 +152,10 @@ pub fn remove_unnecessary_projections( None } } else { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); }; - Ok(maybe_modified.map_or(Transformed::No(plan), Transformed::Yes)) + Ok(maybe_modified.map_or(Transformed::no(plan), Transformed::yes)) } /// Tries to embed `projection` to its input (`csv`). If possible, returns @@ -271,7 +275,7 @@ fn try_unifying_projections( if let Some(column) = expr.as_any().downcast_ref::() { *column_ref_map.entry(column.clone()).or_default() += 1; } - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); @@ -893,16 +897,16 @@ fn update_expr( .clone() .transform_up_mut(&mut |expr: Arc| { if state == RewriteState::RewrittenInvalid { - return Ok(Transformed::No(expr)); + return Ok(Transformed::no(expr)); } let Some(column) = expr.as_any().downcast_ref::() else { - return Ok(Transformed::No(expr)); + return Ok(Transformed::no(expr)); }; if sync_with_child { state = RewriteState::RewrittenValid; // Update the index of `column`: - Ok(Transformed::Yes(projected_exprs[column.index()].0.clone())) + Ok(Transformed::yes(projected_exprs[column.index()].0.clone())) } else { // default to invalid, in case we can't find the relevant column state = RewriteState::RewrittenInvalid; @@ -923,11 +927,12 @@ fn update_expr( ) }) .map_or_else( - || Ok(Transformed::No(expr)), - |c| Ok(Transformed::Yes(c)), + || Ok(Transformed::no(expr)), + |c| Ok(Transformed::yes(c)), ) } - }); + }) + .data(); new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) } @@ -1044,7 +1049,7 @@ fn new_columns_for_join_on( }) .map(|(index, (_, alias))| Column::new(alias, index)); if let Some(new_column) = new_column { - Ok(Transformed::Yes(Arc::new(new_column))) + Ok(Transformed::yes(Arc::new(new_column))) } else { // If the column is not found in the projection expressions, // it means that the column is not projected. 
In this case, @@ -1055,9 +1060,10 @@ fn new_columns_for_join_on( ))) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } }) + .data() .ok() }) .collect::>(); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 37f705d8a82f..05d2d852e057 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -29,20 +29,22 @@ use crate::{ logical_expr::Operator, physical_plan::{ColumnarValue, PhysicalExpr}, }; -use arrow::record_batch::RecordBatchOptions; + use arrow::{ array::{new_null_array, ArrayRef, BooleanArray}, datatypes::{DataType, Field, Schema, SchemaRef}, - record_batch::RecordBatch, + record_batch::{RecordBatch, RecordBatchOptions}, }; use arrow_array::cast::AsArray; +use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ - internal_err, plan_err, + internal_err, plan_datafusion_err, plan_err, tree_node::{Transformed, TreeNode}, + ScalarValue, }; -use datafusion_common::{plan_datafusion_err, ScalarValue}; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; + use log::trace; /// A source of runtime statistical information to [`PruningPredicate`]s. @@ -1034,12 +1036,13 @@ fn rewrite_column_expr( e.transform(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { - return Ok(Transformed::Yes(Arc::new(column_new.clone()))); + return Ok(Transformed::yes(Arc::new(column_new.clone()))); } } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .data() } fn reverse_operator(op: Operator) -> Result { diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index c0abde26c300..e8b6a78b929e 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -236,7 +236,7 @@ pub(crate) fn replace_with_order_preserving_variants( ) -> Result> { update_children(&mut requirements); if !(is_sort(&requirements.plan) && requirements.children[0].data) { - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } // For unbounded cases, we replace with the order-preserving variant in any @@ -260,13 +260,13 @@ pub(crate) fn replace_with_order_preserving_variants( for child in alternate_plan.children.iter_mut() { child.data = false; } - Ok(Transformed::Yes(alternate_plan)) + Ok(Transformed::yes(alternate_plan)) } else { // The alternate plan does not help, use faster order-breaking variants: alternate_plan = plan_with_order_breaking_variants(alternate_plan)?; alternate_plan.data = false; requirements.children = vec![alternate_plan]; - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } } @@ -293,7 +293,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::tree_node::TreeNode; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{Result, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{JoinType, Operator}; @@ -395,7 +395,7 @@ mod tests { // Run the rule top-down let config = SessionConfig::new().with_prefer_existing_sort($PREFER_EXISTING_SORT); let 
plan_with_pipeline_fixer = OrderPreservationContext::new_default(physical_plan); - let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options())).and_then(check_integrity)?; + let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options())).data().and_then(check_integrity)?; let optimized_physical_plan = parallel.plan; // Get string representation of the plan diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index ff82319fba19..c527819e7746 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -87,7 +87,7 @@ pub(crate) fn pushdown_sorts( } // Can push down requirements child.data = None; - return Ok(Transformed::Yes(child)); + return Ok(Transformed::yes(child)); } else { // Can not push down requirements requirements.children = vec![child]; @@ -112,7 +112,7 @@ pub(crate) fn pushdown_sorts( requirements = add_sort_above(requirements, sort_reqs, None); assign_initial_requirements(&mut requirements); } - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } fn pushdown_requirement_to_children( diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 3898fb6345f0..d944cedb0f96 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -40,7 +40,7 @@ use crate::physical_plan::{ExecutionPlan, InputOrderMode, Partitioning}; use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{JoinType, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; @@ -376,15 +376,19 @@ pub fn sort_exec( /// TODO: Once [`ExecutionPlan`] implements [`PartialEq`], string comparisons should be /// replaced with direct plan equality checks. 
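Several rules drive the same API through `transform_down_mut` — the stateful closures in `limited_distinct_aggregation.rs` above and `topk_aggregation.rs` below — which takes an `FnMut` closure so the traversal can carry mutable state such as a "rewrite still applicable" flag. A sketch of that pattern with invented logic (negate only the first literal encountered):

```rust
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_expr::{lit, Expr};

/// Negate the first literal found in a top-down walk, then stop rewriting,
/// mirroring the one-shot flags used by the optimizer rules in this patch.
fn negate_first_literal(expr: Expr) -> Result<Expr> {
    let mut done = false;
    let mut closure = |e: Expr| {
        if done {
            return Ok(Transformed::no(e));
        }
        if matches!(&e, Expr::Literal(_)) {
            done = true;
            return Ok(Transformed::yes(Expr::Negative(Box::new(e))));
        }
        Ok(Transformed::no(e))
    };
    expr.transform_down_mut(&mut closure).data()
}

fn main() -> Result<()> {
    // The left literal is negated; the flag suppresses the right one.
    let rewritten = negate_first_literal(lit(1i64) + lit(2i64))?;
    assert_eq!(rewritten, Expr::Negative(Box::new(lit(1i64))) + lit(2i64));
    Ok(())
}
```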
pub fn check_integrity(context: PlanContext) -> Result> { - context.transform_up(&|node| { - let children_plans = node.plan.children(); - assert_eq!(node.children.len(), children_plans.len()); - for (child_plan, child_node) in children_plans.iter().zip(node.children.iter()) { - assert_eq!( - displayable(child_plan.as_ref()).one_line().to_string(), - displayable(child_node.plan.as_ref()).one_line().to_string() - ); - } - Ok(Transformed::No(node)) - }) + context + .transform_up(&|node| { + let children_plans = node.plan.children(); + assert_eq!(node.children.len(), children_plans.len()); + for (child_plan, child_node) in + children_plans.iter().zip(node.children.iter()) + { + assert_eq!( + displayable(child_plan.as_ref()).one_line().to_string(), + displayable(child_node.plan.as_ref()).one_line().to_string() + ); + } + Ok(Transformed::no(node)) + }) + .data() } diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index 0ca709e56bcb..c47e5e25d143 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -29,7 +29,7 @@ use crate::physical_plan::ExecutionPlan; use arrow_schema::DataType; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalSortExpr; @@ -104,13 +104,13 @@ impl TopKAggregation { let mut cardinality_preserved = true; let mut closure = |plan: Arc| { if !cardinality_preserved { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } if let Some(aggr) = plan.as_any().downcast_ref::() { // either we run into an Aggregate and transform it match Self::transform_agg(aggr, order, limit) { None => cardinality_preserved = false, - Some(plan) => return Ok(Transformed::Yes(plan)), + Some(plan) => return Ok(Transformed::yes(plan)), } } else { // or we continue down whitelisted nodes of other types @@ -118,9 +118,9 @@ impl TopKAggregation { cardinality_preserved = false; } } - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; - let child = child.clone().transform_down_mut(&mut closure).ok()?; + let child = child.clone().transform_down_mut(&mut closure).data().ok()?; let sort = SortExec::new(sort.expr().to_vec(), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); @@ -140,20 +140,20 @@ impl PhysicalOptimizerRule for TopKAggregation { plan: Arc, config: &ConfigOptions, ) -> Result> { - let plan = if config.optimizer.enable_topk_aggregation { + if config.optimizer.enable_topk_aggregation { plan.transform_down(&|plan| { Ok( if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) { - Transformed::Yes(plan) + Transformed::yes(plan) } else { - Transformed::No(plan) + Transformed::no(plan) }, ) - })? 
+ }) + .data() } else { - plan - }; - Ok(plan) + Ok(plan) + } } fn name(&self) -> &str { diff --git a/datafusion/core/tests/data/partitioned_table_arrow/part=123/data.arrow b/datafusion/core/tests/data/partitioned_table_arrow/part=123/data.arrow new file mode 100644 index 000000000000..48151a2ed240 Binary files /dev/null and b/datafusion/core/tests/data/partitioned_table_arrow/part=123/data.arrow differ diff --git a/datafusion/core/tests/data/partitioned_table_arrow/part=456/data.arrow b/datafusion/core/tests/data/partitioned_table_arrow/part=456/data.arrow new file mode 100644 index 000000000000..be932c7f656a Binary files /dev/null and b/datafusion/core/tests/data/partitioned_table_arrow/part=456/data.arrow differ diff --git a/datafusion/core/tests/data/partitioned_table_json/part=1/data.json b/datafusion/core/tests/data/partitioned_table_json/part=1/data.json new file mode 100644 index 000000000000..466c5b3dc4ab --- /dev/null +++ b/datafusion/core/tests/data/partitioned_table_json/part=1/data.json @@ -0,0 +1,2 @@ +{"id": 1, "value": "foo"} +{"id": 2, "value": "bar"} diff --git a/datafusion/core/tests/data/partitioned_table_json/part=2/data.json b/datafusion/core/tests/data/partitioned_table_json/part=2/data.json new file mode 100644 index 000000000000..857d70e1f397 --- /dev/null +++ b/datafusion/core/tests/data/partitioned_table_json/part=2/data.json @@ -0,0 +1,2 @@ +{"id": 3, "value": "baz"} +{"id": 4, "value": "qux"} diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index ff553a48888b..c857202c237e 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -282,26 +282,6 @@ async fn test_fn_initcap() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_fn_instr() -> Result<()> { - let expr = instr(col("a"), lit("b")); - - let expected = [ - "+-------------------------+", - "| instr(test.a,Utf8(\"b\")) |", - "+-------------------------+", - "| 2 |", - "| 2 |", - "| 0 |", - "| 5 |", - "+-------------------------+", - ]; - - assert_fn_batches!(expr, expected); - - Ok(()) -} - #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_left() -> Result<()> { @@ -446,7 +426,7 @@ async fn test_fn_md5() -> Result<()> { #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_regexp_like() -> Result<()> { - let expr = regexp_like(vec![col("a"), lit("[a-z]")]); + let expr = regexp_like(col("a"), lit("[a-z]")); let expected = [ "+-----------------------------------+", diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 6b371b782cb5..59905d859dc8 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -17,19 +17,12 @@ use std::sync::Arc; -use arrow::array::{ArrayRef, Int64Array}; +use arrow::array::{Array, ArrayRef, AsArray, Int64Array}; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use arrow_array::cast::AsArray; use arrow_array::types::Int64Type; -use arrow_array::Array; -use hashbrown::HashMap; -use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; -use tokio::task::JoinSet; - use datafusion::common::Result; use datafusion::datasource::MemTable; use datafusion::physical_plan::aggregates::{ @@ -38,12 +31,17 @@ use datafusion::physical_plan::aggregates::{ use 
datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; -use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion_physical_expr::expressions::{col, Sum}; use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr}; use datafusion_physical_plan::InputOrderMode; use test_utils::{add_empty_batches, StringBatchGenerator}; +use hashbrown::HashMap; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use tokio::task::JoinSet; + /// Tests that streaming aggregate and batch (non streaming) aggregate produce /// same results #[tokio::test(flavor = "multi_thread")] @@ -315,8 +313,9 @@ async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { let mut visitor = Visitor { expected_sort }; impl TreeNodeVisitor for Visitor { - type N = Arc; - fn pre_visit(&mut self, node: &Self::N) -> Result { + type Node = Arc; + + fn f_down(&mut self, node: &Self::Node) -> Result { if let Some(exec) = node.as_any().downcast_ref::() { if self.expected_sort { assert!(matches!( @@ -327,7 +326,7 @@ async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { assert!(matches!(exec.input_order_mode(), InputOrderMode::Linear)); } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } diff --git a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs index c38ff41f5783..95cd75f50a00 100644 --- a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -89,7 +89,7 @@ async fn test_merge_3_gaps() { /// Merge a set of input streams using SortPreservingMergeExec and /// `Vec::sort` and ensure the results are the same. 
/// -/// For each case, the `input` streams are turned into a set of of +/// For each case, the `input` streams are turned into a set of /// streams which are then merged together by [SortPreservingMerge] /// /// Each `Vec` in `input` must be sorted and have a diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 1cab4d5c2f98..ee5e34bd703f 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -30,6 +30,7 @@ use datafusion::physical_plan::windows::{ use datafusion::physical_plan::{collect, ExecutionPlan, InputOrderMode}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{Result, ScalarValue}; +use datafusion_common_runtime::SpawnedTask; use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, @@ -123,8 +124,7 @@ async fn window_bounded_window_random_comparison() -> Result<()> { for i in 0..n { let idx = i % test_cases.len(); let (pb_cols, ob_cols, search_mode) = test_cases[idx].clone(); - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests - let job = tokio::spawn(run_window_test( + let job = SpawnedTask::spawn(run_window_test( make_staggered_batches::(1000, n_distinct, i as u64), i as u64, pb_cols, @@ -134,7 +134,7 @@ async fn window_bounded_window_random_comparison() -> Result<()> { handles.push(job); } for job in handles { - job.await.unwrap()?; + job.join().await.unwrap()?; } } Ok(()) diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs index a98d097856fb..4735a97fee0c 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit.rs @@ -473,7 +473,7 @@ enum Scenario { /// [`StreamingTable`] AccessLogStreaming, - /// N partitions of of sorted, dictionary encoded strings. + /// N partitions of sorted, dictionary encoded strings. 
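Besides the tree-node migration, the `window_fuzz.rs` hunk above moves the test off a raw `tokio::spawn` onto `datafusion_common_runtime::SpawnedTask`, joining through `join().await` instead of awaiting the `JoinHandle` directly. A minimal sketch of the replacement (the async body is a placeholder; my understanding is that dropping the handle aborts the task, which is the cancel-safety motivation):

```rust
use datafusion_common_runtime::SpawnedTask;

#[tokio::main]
async fn main() {
    // SpawnedTask::spawn stands in for tokio::spawn.
    let task = SpawnedTask::spawn(async { 21 * 2 });
    // join() consumes the handle and resolves to Result<T, JoinError>.
    let value = task.join().await.unwrap();
    assert_eq!(value, 42);
}
```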
DictionaryStrings { partitions: usize, /// If true, splits all input batches into 1 row each diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index e13ab0c86329..2d70f6051e15 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -128,8 +128,6 @@ pub enum BuiltinScalarFunction { ArrayPopFront, /// array_pop_back ArrayPopBack, - /// array_dims - ArrayDims, /// array_distinct ArrayDistinct, /// array_element @@ -138,8 +136,6 @@ pub enum BuiltinScalarFunction { ArrayEmpty, /// array_length ArrayLength, - /// array_ndims - ArrayNdims, /// array_position ArrayPosition, /// array_positions @@ -170,8 +166,6 @@ pub enum BuiltinScalarFunction { ArrayUnion, /// array_except ArrayExcept, - /// cardinality - Cardinality, /// array_resize ArrayResize, /// construct an array from columns @@ -208,8 +202,6 @@ pub enum BuiltinScalarFunction { EndsWith, /// initcap InitCap, - /// InStr - InStr, /// left Left, /// lpad @@ -224,8 +216,6 @@ pub enum BuiltinScalarFunction { OctetLength, /// random Random, - /// regexp_like - RegexpLike, /// regexp_match /// regexp_replace RegexpReplace, @@ -383,12 +373,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable, BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable, BuiltinScalarFunction::ArrayHas => Volatility::Immutable, - BuiltinScalarFunction::ArrayDims => Volatility::Immutable, BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable, BuiltinScalarFunction::ArrayElement => Volatility::Immutable, BuiltinScalarFunction::ArrayExcept => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, - BuiltinScalarFunction::ArrayNdims => Volatility::Immutable, BuiltinScalarFunction::ArrayPopFront => Volatility::Immutable, BuiltinScalarFunction::ArrayPopBack => Volatility::Immutable, BuiltinScalarFunction::ArrayPosition => Volatility::Immutable, @@ -407,7 +395,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, BuiltinScalarFunction::ArrayUnion => Volatility::Immutable, BuiltinScalarFunction::ArrayResize => Volatility::Immutable, - BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, BuiltinScalarFunction::BitLength => Volatility::Immutable, @@ -421,7 +408,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::DateBin => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, - BuiltinScalarFunction::InStr => Volatility::Immutable, BuiltinScalarFunction::Left => Volatility::Immutable, BuiltinScalarFunction::Lpad => Volatility::Immutable, BuiltinScalarFunction::Lower => Volatility::Immutable, @@ -429,7 +415,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::MD5 => Volatility::Immutable, BuiltinScalarFunction::OctetLength => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, - BuiltinScalarFunction::RegexpLike => Volatility::Immutable, BuiltinScalarFunction::RegexpReplace => Volatility::Immutable, BuiltinScalarFunction::Repeat => Volatility::Immutable, BuiltinScalarFunction::Replace => Volatility::Immutable, @@ -559,9 +544,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::ArrayHasAny | BuiltinScalarFunction::ArrayHas | BuiltinScalarFunction::ArrayEmpty => Ok(Boolean), - BuiltinScalarFunction::ArrayDims => { - 
Ok(List(Arc::new(Field::new("item", UInt64, true)))) - } BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { List(field) @@ -572,7 +554,6 @@ impl BuiltinScalarFunction { ), }, BuiltinScalarFunction::ArrayLength => Ok(UInt64), - BuiltinScalarFunction::ArrayNdims => Ok(UInt64), BuiltinScalarFunction::ArrayPopFront => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayPopBack => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayPosition => Ok(UInt64), @@ -620,7 +601,6 @@ impl BuiltinScalarFunction { (dt, _) => Ok(dt), } } - BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), _ => { @@ -677,9 +657,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => { utf8_to_str_type(&input_expr_types[0], "initcap") } - BuiltinScalarFunction::InStr => { - utf8_to_int_type(&input_expr_types[0], "instr/position") - } BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), BuiltinScalarFunction::Lower => { utf8_to_str_type(&input_expr_types[0], "lower") @@ -740,7 +717,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::StartsWith => Ok(Boolean), BuiltinScalarFunction::EndsWith => Ok(Boolean), BuiltinScalarFunction::Strpos => { - utf8_to_int_type(&input_expr_types[0], "strpos") + utf8_to_int_type(&input_expr_types[0], "strpos/instr/position") } BuiltinScalarFunction::Substr => { utf8_to_str_type(&input_expr_types[0], "substr") @@ -772,15 +749,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Upper => { utf8_to_str_type(&input_expr_types[0], "upper") } - BuiltinScalarFunction::RegexpLike => Ok(match &input_expr_types[0] { - LargeUtf8 | Utf8 => Boolean, - Null => Null, - other => { - return plan_err!( - "The regexp_like function can only accept strings. 
Got {other}" - ); - } - }), BuiltinScalarFunction::Factorial | BuiltinScalarFunction::Gcd @@ -882,7 +850,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayConcat => { Signature::variadic_any(self.volatility()) } - BuiltinScalarFunction::ArrayDims => Signature::array(self.volatility()), BuiltinScalarFunction::ArrayEmpty => Signature::array(self.volatility()), BuiltinScalarFunction::ArrayElement => { Signature::array_and_index(self.volatility()) @@ -898,7 +865,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayLength => { Signature::variadic_any(self.volatility()) } - BuiltinScalarFunction::ArrayNdims => Signature::array(self.volatility()), BuiltinScalarFunction::ArrayDistinct => Signature::array(self.volatility()), BuiltinScalarFunction::ArrayPosition => { Signature::array_and_element_and_optional_index(self.volatility()) @@ -929,7 +895,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), - BuiltinScalarFunction::Cardinality => Signature::array(self.volatility()), BuiltinScalarFunction::ArrayResize => { Signature::variadic_any(self.volatility()) } @@ -1157,7 +1122,6 @@ impl BuiltinScalarFunction { ), BuiltinScalarFunction::EndsWith - | BuiltinScalarFunction::InStr | BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => Signature::one_of( vec![ @@ -1194,15 +1158,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) } - BuiltinScalarFunction::RegexpLike => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8, Utf8]), - ], - self.volatility(), - ), BuiltinScalarFunction::RegexpReplace => Signature::one_of( vec![ Exact(vec![Utf8, Utf8, Utf8]), @@ -1413,7 +1368,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Chr => &["chr"], BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::InStr => &["instr", "position"], BuiltinScalarFunction::Left => &["left"], BuiltinScalarFunction::Lower => &["lower"], BuiltinScalarFunction::Lpad => &["lpad"], @@ -1430,7 +1384,7 @@ impl BuiltinScalarFunction { &["string_to_array", "string_to_list"] } BuiltinScalarFunction::StartsWith => &["starts_with"], - BuiltinScalarFunction::Strpos => &["strpos"], + BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], BuiltinScalarFunction::Substr => &["substr"], BuiltinScalarFunction::ToHex => &["to_hex"], BuiltinScalarFunction::Translate => &["translate"], @@ -1442,7 +1396,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::FindInSet => &["find_in_set"], // regex functions - BuiltinScalarFunction::RegexpLike => &["regexp_like"], BuiltinScalarFunction::RegexpReplace => &["regexp_replace"], // time/date functions @@ -1478,7 +1431,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayConcat => { &["array_concat", "array_cat", "list_concat", "list_cat"] } - BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"], BuiltinScalarFunction::ArrayEmpty => &["empty"], BuiltinScalarFunction::ArrayElement => &[ @@ -1495,7 +1447,6 @@ impl BuiltinScalarFunction { &["array_has", "list_has", "array_contains", "list_contains"] } BuiltinScalarFunction::ArrayLength => 
&["array_length", "list_length"], - BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], BuiltinScalarFunction::ArrayPopFront => { &["array_pop_front", "list_pop_front"] } @@ -1531,7 +1482,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReverse => &["array_reverse", "list_reverse"], BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"], BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"], - BuiltinScalarFunction::Cardinality => &["cardinality"], BuiltinScalarFunction::ArrayResize => &["array_resize", "list_resize"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], BuiltinScalarFunction::ArrayIntersect => { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index d1687cbd6f29..68b123ab1f28 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,28 +17,27 @@ //! Logical Expressions: [`Expr`] +use std::collections::HashSet; +use std::fmt::{self, Display, Formatter, Write}; +use std::hash::Hash; +use std::str::FromStr; +use std::sync::Arc; + use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; +use crate::{ + aggregate_function, built_in_function, built_in_window_function, udaf, + BuiltinScalarFunction, ExprSchemable, Operator, Signature, +}; -use crate::Operator; -use crate::{aggregate_function, ExprSchemable}; -use crate::{built_in_function, BuiltinScalarFunction}; -use crate::{built_in_window_function, udaf}; use arrow::datatypes::DataType; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; -use datafusion_common::{plan_err, Column, Result, ScalarValue}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{ + internal_err, plan_err, Column, DFSchema, OwnedTableReference, Result, ScalarValue, +}; use sqlparser::ast::NullTreatment; -use std::collections::HashSet; -use std::fmt; -use std::fmt::{Display, Formatter, Write}; -use std::hash::Hash; -use std::str::FromStr; -use std::sync::Arc; - -use crate::Signature; /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS @@ -1275,8 +1274,9 @@ impl Expr { rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; } - Ok(Transformed::Yes(expr)) + Ok(Transformed::yes(expr)) }) + .data() } /// Returns true if some of this `exprs` subexpressions may not be evaluated diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 157b8b0989df..ec53fd4ef1de 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -628,12 +628,6 @@ scalar_expr!( array, "flattens an array of arrays into a single array." ); -scalar_expr!( - ArrayDims, - array_dims, - array, - "returns an array of the array's dimensions." -); scalar_expr!( ArrayElement, array_element, @@ -652,12 +646,6 @@ scalar_expr!( array dimension, "returns the length of the array dimension." ); -scalar_expr!( - ArrayNdims, - array_ndims, - array, - "returns the number of dimensions of the array." 
-); scalar_expr!( ArrayDistinct, array_distinct, @@ -738,13 +726,6 @@ scalar_expr!( ); scalar_expr!(ArrayUnion, array_union, array1 array2, "returns an array of the elements in the union of array1 and array2 without duplicates."); -scalar_expr!( - Cardinality, - cardinality, - array, - "returns the total number of elements in the array." -); - scalar_expr!( ArrayResize, array_resize, @@ -786,7 +767,6 @@ scalar_expr!( ); scalar_expr!(Digest, digest, input algorithm, "compute the binary hash of `input`, using the `algorithm`"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); -scalar_expr!(InStr, instr, string substring, "returns the position of the first occurrence of `substring` in `string`"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); scalar_expr!(Lower, lower, string, "convert the string to lower case"); scalar_expr!( @@ -842,11 +822,6 @@ nary_scalar_expr!( rpad, "fill up a string to the length by appending the characters" ); -nary_scalar_expr!( - RegexpLike, - regexp_like, - "matches a regular expression against a string and returns true or false if there was at least one match or not" -); nary_scalar_expr!( RegexpReplace, regexp_replace, @@ -1332,7 +1307,6 @@ mod test { test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); - test_scalar_expr!(InStr, instr, string, substring); test_scalar_expr!(Left, left, string, count); test_scalar_expr!(Lower, lower, string); test_nary_scalar_expr!(Lpad, lpad, string, count); @@ -1340,8 +1314,6 @@ mod test { test_scalar_expr!(Ltrim, ltrim, string); test_scalar_expr!(MD5, md5, string); test_scalar_expr!(OctetLength, octet_length, string); - test_nary_scalar_expr!(RegexpLike, regexp_like, string, pattern); - test_nary_scalar_expr!(RegexpLike, regexp_like, string, pattern, flags); test_nary_scalar_expr!( RegexpReplace, regexp_replace, @@ -1389,9 +1361,7 @@ mod test { test_scalar_expr!(ArraySort, array_sort, array, desc, null_first); test_scalar_expr!(ArrayPopFront, array_pop_front, array); test_scalar_expr!(ArrayPopBack, array_pop_back, array); - test_unary_scalar_expr!(ArrayDims, array_dims); test_scalar_expr!(ArrayLength, array_length, array, dimension); - test_unary_scalar_expr!(ArrayNdims, array_ndims); test_scalar_expr!(ArrayPosition, array_position, array, element, index); test_scalar_expr!(ArrayPositions, array_positions, array, element); test_scalar_expr!(ArrayPrepend, array_prepend, array, element); @@ -1402,7 +1372,6 @@ mod test { test_scalar_expr!(ArrayReplace, array_replace, array, from, to); test_scalar_expr!(ArrayReplaceN, array_replace_n, array, from, to, max); test_scalar_expr!(ArrayReplaceAll, array_replace_all, array, from, to); - test_unary_scalar_expr!(Cardinality, cardinality); test_nary_scalar_expr!(MakeArray, array, input); test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 3f7388c3c3d5..cd9a8344dec4 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -17,16 +17,19 @@ //! 
Expression rewriter -use crate::expr::{Alias, Unnest}; -use crate::logical_plan::Projection; -use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; -use datafusion_common::Result; -use datafusion_common::{Column, DFSchema}; use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; +use crate::expr::{Alias, Unnest}; +use crate::logical_plan::Projection; +use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; + +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRewriter, +}; +use datafusion_common::{Column, DFSchema, Result}; + mod order_by; pub use order_by::rewrite_sort_cols_by_aggs; @@ -37,12 +40,13 @@ pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() } /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions @@ -61,12 +65,13 @@ pub fn normalize_col_with_schemas( Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, using_columns)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() } /// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage @@ -90,12 +95,13 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( if let Expr::Column(c) = expr { let col = c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() } /// Recursively normalize all [`Column`] expressions in a list of expression trees @@ -116,14 +122,15 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { - Some(new_c) => Transformed::Yes(Expr::Column((*new_c).to_owned())), - None => Transformed::No(expr), + Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())), + None => Transformed::no(expr), } } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() } /// Recursively 'unnormalize' (remove all qualifiers) from an @@ -139,12 +146,13 @@ pub fn unnormalize_col(expr: Expr) -> Expr { relation: None, name: c.name, }; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() .expect("Unnormalize is infallable") } @@ -177,12 +185,13 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { expr.transform(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .data() .expect("strip_outer_reference is infallable") } @@ -260,22 +269,24 @@ pub fn unalias(expr: Expr) -> Expr { /// schema of plan nodes don't change after optimization pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result where - R: TreeNodeRewriter, + R: TreeNodeRewriter, { let original_name = expr.name_for_alias()?; - let expr = expr.rewrite(rewriter)?; + let expr = expr.rewrite(rewriter)?.data; expr.alias_if_changed(original_name) } #[cfg(test)] mod 
test { + use std::ops::Add; + use super::*; use crate::expr::Sort; use crate::{col, lit, Cast}; + use arrow::datatypes::DataType; - use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; + use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter}; use datafusion_common::{DFField, DFSchema, ScalarValue}; - use std::ops::Add; #[derive(Default)] struct RecordingRewriter { @@ -283,16 +294,16 @@ mod test { } impl TreeNodeRewriter for RecordingRewriter { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result> { self.v.push(format!("Previsited {expr}")); - Ok(RewriteRecursion::Continue) + Ok(Transformed::no(expr)) } - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { self.v.push(format!("Mutated {expr}")); - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -307,19 +318,27 @@ mod test { } else { utf8_val }; - Ok(Transformed::Yes(lit(utf8_val))) + Ok(Transformed::yes(lit(utf8_val))) } // otherwise, return None - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } }; // rewrites "foo" --> "bar" - let rewritten = col("state").eq(lit("foo")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("foo")) + .transform(&transformer) + .data() + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); // doesn't rewrite - let rewritten = col("state").eq(lit("baz")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("baz")) + .transform(&transformer) + .data() + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } @@ -454,10 +473,10 @@ mod test { } impl TreeNodeRewriter for TestRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, _: Expr) -> Result { - Ok(self.rewrite_to.clone()) + fn f_up(&mut self, _: Expr) -> Result> { + Ok(Transformed::yes(self.rewrite_to.clone())) } } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index c87a724d5646..b1bc11a83f90 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -20,7 +20,8 @@ use crate::expr::{Alias, Sort}; use crate::expr_rewriter::normalize_col; use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; -use datafusion_common::tree_node::{Transformed, TreeNode}; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; /// Rewrite sort on aggregate expressions to sort on the column of aggregate output @@ -91,7 +92,7 @@ fn rewrite_in_terms_of_projection( .to_field(input.schema()) .map(|f| f.qualified_column())?, ); - return Ok(Transformed::Yes(col)); + return Ok(Transformed::yes(col)); } // if that doesn't work, try to match the expression as an @@ -103,7 +104,7 @@ fn rewrite_in_terms_of_projection( e } else { // The expr is not based on Aggregate plan output. Skip it. 
diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs
index c87a724d5646..b1bc11a83f90 100644
--- a/datafusion/expr/src/expr_rewriter/order_by.rs
+++ b/datafusion/expr/src/expr_rewriter/order_by.rs
@@ -20,7 +20,8 @@
 use crate::expr::{Alias, Sort};
 use crate::expr_rewriter::normalize_col;
 use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast};
-use datafusion_common::tree_node::{Transformed, TreeNode};
+
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::{Column, Result};
 
 /// Rewrite sort on aggregate expressions to sort on the column of aggregate output
@@ -91,7 +92,7 @@ fn rewrite_in_terms_of_projection(
                     .to_field(input.schema())
                     .map(|f| f.qualified_column())?,
             );
-            return Ok(Transformed::Yes(col));
+            return Ok(Transformed::yes(col));
         }
 
         // if that doesn't work, try to match the expression as an
@@ -103,7 +104,7 @@ fn rewrite_in_terms_of_projection(
             e
         } else {
             // The expr is not based on Aggregate plan output. Skip it.
-            return Ok(Transformed::No(expr));
+            return Ok(Transformed::no(expr));
         };
 
         // expr is an actual expr like min(t.c2), but we are looking
@@ -118,7 +119,7 @@ fn rewrite_in_terms_of_projection(
         // look for the column named the same as this expr
         if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) {
             let found = found.clone();
-            return Ok(Transformed::Yes(match normalized_expr {
+            return Ok(Transformed::yes(match normalized_expr {
                 Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast {
                     expr: Box::new(found),
                     data_type,
@@ -131,8 +132,9 @@ fn rewrite_in_terms_of_projection(
             }));
         }
 
-        Ok(Transformed::No(expr))
+        Ok(Transformed::no(expr))
     })
+    .data()
 }
 
 /// Does the underlying expr match e?
diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs
index 8039a211c9e4..f0ce61ee9bbb 100644
--- a/datafusion/expr/src/field_util.rs
+++ b/datafusion/expr/src/field_util.rs
@@ -78,10 +78,11 @@ impl GetFieldAccessSchema {
             Self::ListIndex{ key_dt } => {
                 match (data_type, key_dt) {
                     (DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)),
-                    (DataType::List(_), _) => plan_err!(
-                        "Only ints are valid as an indexed field in a list"
+                    (DataType::LargeList(lt), DataType::Int64) => Ok(Field::new("large_list", lt.data_type().clone(), true)),
+                    (DataType::List(_), _) | (DataType::LargeList(_), _) => plan_err!(
+                        "Only ints are valid as an indexed field in a List/LargeList"
                     ),
-                    (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
+                    (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `LargeList` or `Struct` types, got {other}"),
                 }
             }
             Self::ListRange { start_dt, stop_dt, stride_dt } => {
@@ -89,7 +90,7 @@ impl GetFieldAccessSchema {
                     (DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
                     (DataType::LargeList(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("large_list", data_type.clone(), true)),
                     (DataType::List(_), _, _, _) | (DataType::LargeList(_), _, _, _)=> plan_err!(
-                        "Only ints are valid as an indexed field in a list"
+                        "Only ints are valid as an indexed field in a List/LargeList"
                     ),
                     (other, _, _, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `LargeList` or `Struct` types, got {other}"),
                 }
diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs
index 112dbf74dba1..e0cb44626e24 100644
--- a/datafusion/expr/src/logical_plan/display.rs
+++ b/datafusion/expr/src/logical_plan/display.rs
@@ -16,12 +16,14 @@
 // under the License.
 //! This module provides logic for displaying LogicalPlans in various styles
 
+use std::fmt;
+
 use crate::LogicalPlan;
+
 use arrow::datatypes::Schema;
 use datafusion_common::display::GraphvizBuilder;
-use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion};
+use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
 use datafusion_common::DataFusionError;
-use std::fmt;
 
 /// Formats plans with a single line per node.
For example: /// @@ -49,12 +51,12 @@ impl<'a, 'b> IndentVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit( + fn f_down( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { if self.indent > 0 { writeln!(self.f)?; } @@ -69,15 +71,15 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { } self.indent += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit( + fn f_up( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { self.indent -= 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -171,12 +173,12 @@ impl<'a, 'b> GraphvizVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit( + fn f_down( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { let id = self.graphviz_builder.next_id(); // Create a new graph node for `plan` such as @@ -204,18 +206,18 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { } self.parent_ids.push(id); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit( + fn f_up( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { // always be non-empty as pre_visit always pushes // So it should always be Ok(true) let res = self.parent_ids.pop(); res.ok_or(DataFusionError::Internal("Fail to format".to_string())) - .map(|_| VisitRecursion::Continue) + .map(|_| TreeNodeRecursion::Continue) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 5cce8f9cd45c..ca021c4bfc28 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -46,8 +46,7 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, TreeNodeVisitor, - VisitRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, @@ -476,7 +475,7 @@ impl LogicalPlan { })?; using_columns.push(columns); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(using_columns) @@ -649,31 +648,24 @@ impl LogicalPlan { // Decimal128(Some(69999999999999),30,15) // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - struct RemoveAliases {} - - impl TreeNodeRewriter for RemoveAliases { - type N = Expr; - - fn pre_visit(&mut self, expr: &Expr) -> Result { + let predicate = predicate + .transform_down(&|expr| { match expr { Expr::Exists { .. 
} | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { // subqueries could contain aliases so we don't recurse into those - Ok(RewriteRecursion::Stop) + Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) } - Expr::Alias(_) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), + Expr::Alias(_) => Ok(Transformed::new( + expr.unalias(), + true, + TreeNodeRecursion::Jump, + )), + _ => Ok(Transformed::no(expr)), } - } - - fn mutate(&mut self, expr: Expr) -> Result { - Ok(expr.unalias()) - } - } - - let mut remove_aliases = RemoveAliases {}; - let predicate = predicate.rewrite(&mut remove_aliases)?; + }) + .data()?; Filter::try_new(predicate, Arc::new(inputs.swap_remove(0))) .map(LogicalPlan::Filter) @@ -1125,9 +1117,9 @@ impl LogicalPlan { impl LogicalPlan { /// applies `op` to any subqueries in the plan - pub(crate) fn apply_subqueries(&self, op: &mut F) -> datafusion_common::Result<()> + pub(crate) fn apply_subqueries(&self, op: &mut F) -> Result<()> where - F: FnMut(&Self) -> datafusion_common::Result, + F: FnMut(&Self) -> Result, { self.inspect_expressions(|expr| { // recursively look for subqueries @@ -1151,9 +1143,9 @@ impl LogicalPlan { } /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries(&self, v: &mut V) -> datafusion_common::Result<()> + pub(crate) fn visit_subqueries(&self, v: &mut V) -> Result<()> where - V: TreeNodeVisitor, + V: TreeNodeVisitor, { self.inspect_expressions(|expr| { // recursively look for subqueries @@ -1226,11 +1218,11 @@ impl LogicalPlan { _ => {} } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok::<(), DataFusionError>(()) })?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(param_types) @@ -1247,19 +1239,20 @@ impl LogicalPlan { Expr::Placeholder(Placeholder { id, .. }) => { let value = param_values.get_placeholders_with_values(id)?; // Replace the placeholder with the value - Ok(Transformed::Yes(Expr::Literal(value))) + Ok(Transformed::yes(Expr::Literal(value))) } Expr::ScalarSubquery(qry) => { let subquery = Arc::new(qry.subquery.replace_params_with_values(param_values)?); - Ok(Transformed::Yes(Expr::ScalarSubquery(Subquery { + Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery, outer_ref_columns: qry.outer_ref_columns.clone(), }))) } - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } }) + .data() } } @@ -2842,9 +2835,9 @@ digraph { } impl TreeNodeVisitor for OkVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "pre_visit Projection", LogicalPlan::Filter { .. } => "pre_visit Filter", @@ -2855,10 +2848,10 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "post_visit Projection", LogicalPlan::Filter { .. 
} => "post_visit Filter", @@ -2869,7 +2862,7 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -2925,23 +2918,23 @@ digraph { } impl TreeNodeVisitor for StoppingVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_pre_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } - self.inner.pre_visit(plan)?; + self.inner.f_down(plan)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_post_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } - self.inner.post_visit(plan) + self.inner.f_up(plan) } } @@ -2994,22 +2987,22 @@ digraph { } impl TreeNodeVisitor for ErrorVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { return not_impl_err!("Error in pre_visit"); } - self.inner.pre_visit(plan) + self.inner.f_down(plan) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { return not_impl_err!("Error in post_visit"); } - self.inner.post_visit(plan) + self.inner.f_up(plan) } } @@ -3317,10 +3310,11 @@ digraph { Arc::new(LogicalPlan::TableScan(table)), ) .unwrap(); - Ok(Transformed::Yes(LogicalPlan::Filter(filter))) + Ok(Transformed::yes(LogicalPlan::Filter(filter))) } - x => Ok(Transformed::No(x)), + x => Ok(Transformed::no(x)), }) + .data() .unwrap(); let expected = "Explain\ diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 81949f2178f6..67d48f986f13 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -24,14 +24,16 @@ use crate::expr::{ }; use crate::{Expr, GetFieldAccess}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::tree_node::{ + Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, +}; +use datafusion_common::{handle_visit_recursion, internal_err, Result}; impl TreeNode for Expr { - fn apply_children Result>( + fn apply_children Result>( &self, - op: &mut F, - ) -> Result { + f: &mut F, + ) -> Result { let children = match self { Expr::Alias(Alias{expr,..}) | Expr::Not(expr) @@ -129,21 +131,19 @@ impl TreeNode for Expr { } }; + let mut tnr = TreeNodeRecursion::Continue; for child in children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } + tnr = f(child)?; + handle_visit_recursion!(tnr, DOWN); } - Ok(VisitRecursion::Continue) + Ok(tnr) } - fn map_children Result>( - self, - mut transform: F, - ) -> Result { + fn map_children(self, mut f: F) -> Result> + where + F: FnMut(Self) -> Result>, + { Ok(match self { Expr::Column(_) | Expr::Wildcard { .. 
} @@ -153,27 +153,29 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Unnest(_) - | Expr::Literal(_) => self, + | Expr::Literal(_) => Transformed::no(self), Expr::Alias(Alias { expr, relation, name, - }) => Expr::Alias(Alias::new(transform(*expr)?, relation, name)), + }) => f(*expr)?.update_data(|e| Expr::Alias(Alias::new(e, relation, name))), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => Expr::InSubquery(InSubquery::new( - transform_boxed(expr, &mut transform)?, - subquery, - negated, - )), + }) => transform_box(expr, &mut f)?.update_data(|be| { + Expr::InSubquery(InSubquery::new(be, subquery, negated)) + }), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - Expr::BinaryExpr(BinaryExpr::new( - transform_boxed(left, &mut transform)?, - op, - transform_boxed(right, &mut transform)?, - )) + transform_box(left, &mut f)? + .update_data(|new_left| (new_left, right)) + .try_transform_node(|(new_left, right)| { + Ok(transform_box(right, &mut f)? + .update_data(|new_right| (new_left, new_right))) + })? + .update_data(|(new_left, new_right)| { + Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) + }) } Expr::Like(Like { negated, @@ -181,102 +183,136 @@ impl TreeNode for Expr { pattern, escape_char, case_insensitive, - }) => Expr::Like(Like::new( - negated, - transform_boxed(expr, &mut transform)?, - transform_boxed(pattern, &mut transform)?, - escape_char, - case_insensitive, - )), + }) => transform_box(expr, &mut f)? + .update_data(|new_expr| (new_expr, pattern)) + .try_transform_node(|(new_expr, pattern)| { + Ok(transform_box(pattern, &mut f)? + .update_data(|new_pattern| (new_expr, new_pattern))) + })? + .update_data(|(new_expr, new_pattern)| { + Expr::Like(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }), Expr::SimilarTo(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => Expr::SimilarTo(Like::new( - negated, - transform_boxed(expr, &mut transform)?, - transform_boxed(pattern, &mut transform)?, - escape_char, - case_insensitive, - )), - Expr::Not(expr) => Expr::Not(transform_boxed(expr, &mut transform)?), + }) => transform_box(expr, &mut f)? + .update_data(|new_expr| (new_expr, pattern)) + .try_transform_node(|(new_expr, pattern)| { + Ok(transform_box(pattern, &mut f)? + .update_data(|new_pattern| (new_expr, new_pattern))) + })? + .update_data(|(new_expr, new_pattern)| { + Expr::SimilarTo(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }), + Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not), Expr::IsNotNull(expr) => { - Expr::IsNotNull(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.update_data(Expr::IsNotNull) + } + Expr::IsNull(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsNull), + Expr::IsTrue(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsTrue), + Expr::IsFalse(expr) => { + transform_box(expr, &mut f)?.update_data(Expr::IsFalse) } - Expr::IsNull(expr) => Expr::IsNull(transform_boxed(expr, &mut transform)?), - Expr::IsTrue(expr) => Expr::IsTrue(transform_boxed(expr, &mut transform)?), - Expr::IsFalse(expr) => Expr::IsFalse(transform_boxed(expr, &mut transform)?), Expr::IsUnknown(expr) => { - Expr::IsUnknown(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.update_data(Expr::IsUnknown) } Expr::IsNotTrue(expr) => { - Expr::IsNotTrue(transform_boxed(expr, &mut transform)?) 
+ transform_box(expr, &mut f)?.update_data(Expr::IsNotTrue) } Expr::IsNotFalse(expr) => { - Expr::IsNotFalse(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.update_data(Expr::IsNotFalse) } Expr::IsNotUnknown(expr) => { - Expr::IsNotUnknown(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.update_data(Expr::IsNotUnknown) } Expr::Negative(expr) => { - Expr::Negative(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.update_data(Expr::Negative) } Expr::Between(Between { expr, negated, low, high, - }) => Expr::Between(Between::new( - transform_boxed(expr, &mut transform)?, - negated, - transform_boxed(low, &mut transform)?, - transform_boxed(high, &mut transform)?, - )), - Expr::Case(case) => { - let expr = transform_option_box(case.expr, &mut transform)?; - let when_then_expr = case - .when_then_expr - .into_iter() - .map(|(when, then)| { - Ok(( - transform_boxed(when, &mut transform)?, - transform_boxed(then, &mut transform)?, - )) - }) - .collect::>>()?; - let else_expr = transform_option_box(case.else_expr, &mut transform)?; - - Expr::Case(Case::new(expr, when_then_expr, else_expr)) - } - Expr::Cast(Cast { expr, data_type }) => { - Expr::Cast(Cast::new(transform_boxed(expr, &mut transform)?, data_type)) - } - Expr::TryCast(TryCast { expr, data_type }) => Expr::TryCast(TryCast::new( - transform_boxed(expr, &mut transform)?, - data_type, - )), + }) => transform_box(expr, &mut f)? + .update_data(|new_expr| (new_expr, low, high)) + .try_transform_node(|(new_expr, low, high)| { + Ok(transform_box(low, &mut f)? + .update_data(|new_low| (new_expr, new_low, high))) + })? + .try_transform_node(|(new_expr, new_low, high)| { + Ok(transform_box(high, &mut f)? + .update_data(|new_high| (new_expr, new_low, new_high))) + })? + .update_data(|(new_expr, new_low, new_high)| { + Expr::Between(Between::new(new_expr, negated, new_low, new_high)) + }), + Expr::Case(Case { + expr, + when_then_expr, + else_expr, + }) => transform_option_box(expr, &mut f)? + .update_data(|new_expr| (new_expr, when_then_expr, else_expr)) + .try_transform_node(|(new_expr, when_then_expr, else_expr)| { + Ok(when_then_expr + .into_iter() + .map_until_stop_and_collect(|(when, then)| { + transform_box(when, &mut f)? + .update_data(|new_when| (new_when, then)) + .try_transform_node(|(new_when, then)| { + Ok(transform_box(then, &mut f)? + .update_data(|new_then| (new_when, new_then))) + }) + })? + .update_data(|new_when_then_expr| { + (new_expr, new_when_then_expr, else_expr) + })) + })? + .try_transform_node(|(new_expr, new_when_then_expr, else_expr)| { + Ok(transform_option_box(else_expr, &mut f)?.update_data( + |new_else_expr| (new_expr, new_when_then_expr, new_else_expr), + )) + })? + .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { + Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) + }), + Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)? + .update_data(|be| Expr::Cast(Cast::new(be, data_type))), + Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? 
+ .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), Expr::Sort(Sort { expr, asc, nulls_first, - }) => Expr::Sort(Sort::new( - transform_boxed(expr, &mut transform)?, - asc, - nulls_first, - )), - Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction( - ScalarFunction::new(fun, transform_vec(args, &mut transform)?), - ), - ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction( - ScalarFunction::new_udf(fun, transform_vec(args, &mut transform)?), - ), - ScalarFunctionDefinition::Name(_) => { - return internal_err!("Function `Expr` with name should be resolved.") - } - }, + }) => transform_box(expr, &mut f)? + .update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))), + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + transform_vec(args, &mut f)?.map_data(|new_args| match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + } + ScalarFunctionDefinition::UDF(fun) => { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_args))) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + })? + } Expr::WindowFunction(WindowFunction { args, fun, @@ -284,112 +320,139 @@ impl TreeNode for Expr { order_by, window_frame, null_treatment, - }) => Expr::WindowFunction(WindowFunction::new( - fun, - transform_vec(args, &mut transform)?, - transform_vec(partition_by, &mut transform)?, - transform_vec(order_by, &mut transform)?, - window_frame, - null_treatment, - )), + }) => transform_vec(args, &mut f)? + .update_data(|new_args| (new_args, partition_by, order_by)) + .try_transform_node(|(new_args, partition_by, order_by)| { + Ok(transform_vec(partition_by, &mut f)?.update_data( + |new_partition_by| (new_args, new_partition_by, order_by), + )) + })? + .try_transform_node(|(new_args, new_partition_by, order_by)| { + Ok( + transform_vec(order_by, &mut f)?.update_data(|new_order_by| { + (new_args, new_partition_by, new_order_by) + }), + ) + })? + .update_data(|(new_args, new_partition_by, new_order_by)| { + Expr::WindowFunction(WindowFunction::new( + fun, + new_args, + new_partition_by, + new_order_by, + window_frame, + null_treatment, + )) + }), Expr::AggregateFunction(AggregateFunction { args, func_def, distinct, filter, order_by, - }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - Expr::AggregateFunction(AggregateFunction::new( - fun, - transform_vec(args, &mut transform)?, - distinct, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } - AggregateFunctionDefinition::UDF(fun) => { - let order_by = order_by - .map(|order_by| transform_vec(order_by, &mut transform)) - .transpose()?; - Expr::AggregateFunction(AggregateFunction::new_udf( - fun, - transform_vec(args, &mut transform)?, - false, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } - AggregateFunctionDefinition::Name(_) => { - return internal_err!("Function `Expr` with name should be resolved.") - } - }, + }) => transform_vec(args, &mut f)? + .update_data(|new_args| (new_args, filter, order_by)) + .try_transform_node(|(new_args, filter, order_by)| { + Ok(transform_option_box(filter, &mut f)? + .update_data(|new_filter| (new_args, new_filter, order_by))) + })? + .try_transform_node(|(new_args, new_filter, order_by)| { + Ok(transform_option_vec(order_by, &mut f)? 
+ .update_data(|new_order_by| (new_args, new_filter, new_order_by))) + })? + .map_data(|(new_args, new_filter, new_order_by)| match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun, + new_args, + distinct, + new_filter, + new_order_by, + ))) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + fun, + new_args, + false, + new_filter, + new_order_by, + ))) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + })?, Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( - transform_vec(exprs, &mut transform)?, - )), - GroupingSet::Cube(exprs) => Expr::GroupingSet(GroupingSet::Cube( - transform_vec(exprs, &mut transform)?, - )), - GroupingSet::GroupingSets(lists_of_exprs) => { - Expr::GroupingSet(GroupingSet::GroupingSets( - lists_of_exprs - .into_iter() - .map(|exprs| transform_vec(exprs, &mut transform)) - .collect::>>()?, - )) - } + GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? + .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), + GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)? + .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), + GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs + .into_iter() + .map_until_stop_and_collect(|exprs| transform_vec(exprs, &mut f))? + .update_data(|new_lists_of_exprs| { + Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs)) + }), }, Expr::InList(InList { expr, list, negated, - }) => Expr::InList(InList::new( - transform_boxed(expr, &mut transform)?, - transform_vec(list, &mut transform)?, - negated, - )), + }) => transform_box(expr, &mut f)? + .update_data(|new_expr| (new_expr, list)) + .try_transform_node(|(new_expr, list)| { + Ok(transform_vec(list, &mut f)? + .update_data(|new_list| (new_expr, new_list))) + })? + .update_data(|(new_expr, new_list)| { + Expr::InList(InList::new(new_expr, new_list, negated)) + }), Expr::GetIndexedField(GetIndexedField { expr, field }) => { - Expr::GetIndexedField(GetIndexedField::new( - transform_boxed(expr, &mut transform)?, - field, - )) + transform_box(expr, &mut f)?.update_data(|be| { + Expr::GetIndexedField(GetIndexedField::new(be, field)) + }) } }) } } -fn transform_boxed Result>( - boxed_expr: Box, - transform: &mut F, -) -> Result> { - // TODO: It might be possible to avoid an allocation (the Box::new) below by reusing the box. 
- transform(*boxed_expr).map(Box::new) +fn transform_box(be: Box, f: &mut F) -> Result>> +where + F: FnMut(Expr) -> Result>, +{ + Ok(f(*be)?.update_data(Box::new)) } -fn transform_option_box Result>( - option_box: Option>, - transform: &mut F, -) -> Result>> { - option_box - .map(|expr| transform_boxed(expr, transform)) - .transpose() +fn transform_option_box( + obe: Option>, + f: &mut F, +) -> Result>>> +where + F: FnMut(Expr) -> Result>, +{ + obe.map_or(Ok(Transformed::no(None)), |be| { + Ok(transform_box(be, f)?.update_data(Some)) + }) } /// &mut transform a Option<`Vec` of `Expr`s> -fn transform_option_vec Result>( - option_box: Option>, - transform: &mut F, -) -> Result>> { - option_box - .map(|exprs| transform_vec(exprs, transform)) - .transpose() +fn transform_option_vec( + ove: Option>, + f: &mut F, +) -> Result>>> +where + F: FnMut(Expr) -> Result>, +{ + ove.map_or(Ok(Transformed::no(None)), |ve| { + Ok(transform_vec(ve, f)?.update_data(Some)) + }) } /// &mut transform a `Vec` of `Expr`s -fn transform_vec Result>( - v: Vec, - transform: &mut F, -) -> Result> { - v.into_iter().map(transform).collect() +fn transform_vec(ve: Vec, f: &mut F) -> Result>> +where + F: FnMut(Expr) -> Result>, +{ + ve.into_iter().map_until_stop_and_collect(f) } diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index c35a09874a62..02d5d1851289 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -19,19 +19,21 @@ use crate::LogicalPlan; -use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; -use datafusion_common::{handle_tree_recursion, Result}; +use datafusion_common::tree_node::{ + Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeVisitor, +}; +use datafusion_common::{handle_visit_recursion, Result}; impl TreeNode for LogicalPlan { - fn apply Result>( + fn apply Result>( &self, - op: &mut F, - ) -> Result { + f: &mut F, + ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::apply_subqueries`] before visiting its children - handle_tree_recursion!(op(self)?); - self.apply_subqueries(op)?; - self.apply_children(&mut |node| node.apply(op)) + handle_visit_recursion!(f(self)?, DOWN); + self.apply_subqueries(f)?; + self.apply_children(&mut |n| n.apply(f)) } /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke @@ -54,48 +56,58 @@ impl TreeNode for LogicalPlan { /// visitor.post_visit(Filter) /// visitor.post_visit(Projection) /// ``` - fn visit>( + fn visit>( &self, visitor: &mut V, - ) -> Result { + ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::visit_subqueries`] before visiting its children - handle_tree_recursion!(visitor.pre_visit(self)?); - self.visit_subqueries(visitor)?; - handle_tree_recursion!(self.apply_children(&mut |node| node.visit(visitor))?); - visitor.post_visit(self) + match visitor.f_down(self)? 
{ + TreeNodeRecursion::Continue => { + self.visit_subqueries(visitor)?; + handle_visit_recursion!( + self.apply_children(&mut |n| n.visit(visitor))?, + UP + ); + visitor.f_up(self) + } + TreeNodeRecursion::Jump => { + self.visit_subqueries(visitor)?; + visitor.f_up(self) + } + TreeNodeRecursion::Stop => Ok(TreeNodeRecursion::Stop), + } } - fn apply_children Result>( + fn apply_children Result>( &self, - op: &mut F, - ) -> Result { + f: &mut F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; for child in self.inputs() { - handle_tree_recursion!(op(child)?) + tnr = f(child)?; + handle_visit_recursion!(tnr, DOWN) } - Ok(VisitRecursion::Continue) + Ok(tnr) } - fn map_children(self, transform: F) -> Result + fn map_children(self, f: F) -> Result> where - F: FnMut(Self) -> Result, + F: FnMut(Self) -> Result>, { - let old_children = self.inputs(); - let new_children = old_children + let new_children = self + .inputs() .iter() .map(|&c| c.clone()) - .map(transform) - .collect::>>()?; - - // if any changes made, make a new child - if old_children - .into_iter() - .zip(new_children.iter()) - .any(|(c1, c2)| c1 != c2) - { - self.with_new_exprs(self.expressions(), new_children) + .map_until_stop_and_collect(f)?; + // Propagate up `new_children.transformed` and `new_children.tnr` + // along with the node containing transformed children. + if new_children.transformed { + new_children.map_data(|new_children| { + self.with_new_exprs(self.expressions(), new_children) + }) } else { - Ok(self) + Ok(new_children.update_data(|_| self)) } } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index fe9297b32a8e..dfd90e470965 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -31,7 +31,7 @@ use crate::{ }; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::utils::get_at_indices; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, @@ -665,10 +665,10 @@ where exprs.push(expr.clone()) } // stop recursing down this expr once we find a match - return Ok(VisitRecursion::Skip); + return Ok(TreeNodeRecursion::Jump); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); @@ -685,10 +685,10 @@ where if let Err(e) = f(expr) { // save the error for later (it may not be a DataFusionError err = Err(e); - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } else { // keep going - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } }) // The closure always returns OK, so this will always too diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml index 70d676c6d270..088babdf50e3 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-array/Cargo.toml @@ -38,7 +38,7 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } -datafusion-common = { workspace = true, default-features = true } +datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } log = { workspace = true } diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/kernels.rs index b9a68b466605..70c778f34082 100644 --- a/datafusion/functions-array/src/kernels.rs +++ 
b/datafusion/functions-array/src/kernels.rs @@ -17,16 +17,21 @@ //! implementation kernels for array functions +use arrow::array::ListArray; use arrow::array::{ - Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericListArray, - Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, OffsetSizeTrait, - StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Array, ArrayRef, BooleanArray, Date32Array, Float32Array, Float64Array, + GenericListArray, Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, + OffsetSizeTrait, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; -use arrow::datatypes::DataType; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::Field; +use arrow::datatypes::UInt64Type; +use arrow::datatypes::{DataType, Date32Type, IntervalMonthDayNanoType}; use datafusion_common::cast::{ - as_int64_array, as_large_list_array, as_list_array, as_string_array, + as_date32_array, as_int64_array, as_interval_mdn_array, as_large_list_array, + as_list_array, as_string_array, }; -use datafusion_common::{exec_err, DataFusionError}; +use datafusion_common::{exec_err, DataFusionError, Result}; use std::any::type_name; use std::sync::Arc; macro_rules! downcast_arg { @@ -102,7 +107,7 @@ macro_rules! call_array_function { } /// Array_to_string SQL function -pub(super) fn array_to_string(args: &[ArrayRef]) -> datafusion_common::Result { +pub(super) fn array_to_string(args: &[ArrayRef]) -> Result { if args.len() < 2 || args.len() > 3 { return exec_err!("array_to_string expects two or three arguments"); } @@ -254,9 +259,6 @@ pub(super) fn array_to_string(args: &[ArrayRef]) -> datafusion_common::Result [0, 1, 2] /// gen_range(1, 4) => [1, 2, 3] /// gen_range(1, 7, 2) => [1, 3, 5] -pub fn gen_range( - args: &[ArrayRef], - include_upper: i64, -) -> datafusion_common::Result { +pub fn gen_range(args: &[ArrayRef], include_upper: i64) -> Result { let (start_array, stop_array, step_array) = match args.len() { 1 => (None, as_int64_array(&args[0])?, None), 2 => ( @@ -319,3 +318,168 @@ pub fn gen_range( )?); Ok(arr) } + +/// Returns the length of each array dimension +fn compute_array_dims(arr: Option) -> Result>>> { + let mut value = match arr { + Some(arr) => arr, + None => return Ok(None), + }; + if value.is_empty() { + return Ok(None); + } + let mut res = vec![Some(value.len() as u64)]; + + loop { + match value.data_type() { + DataType::List(..) => { + value = downcast_arg!(value, ListArray).value(0); + res.push(Some(value.len() as u64)); + } + _ => return Ok(Some(res)), + } + } +} + +fn generic_list_cardinality( + array: &GenericListArray, +) -> Result { + let result = array + .iter() + .map(|arr| match compute_array_dims(arr)? 
{ + Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), + None => Ok(None), + }) + .collect::>()?; + Ok(Arc::new(result) as ArrayRef) +} + +/// Cardinality SQL function +pub fn cardinality(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("cardinality expects one argument"); + } + + match &args[0].data_type() { + DataType::List(_) => { + let list_array = as_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + other => { + exec_err!("cardinality does not support type '{:?}'", other) + } + } +} + +/// Array_dims SQL function +pub fn array_dims(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_dims needs one argument"); + } + + let data = match args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? + } + array_type => { + return exec_err!("array_dims does not support type '{array_type:?}'"); + } + }; + + let result = ListArray::from_iter_primitive::(data); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Array_ndims SQL function +pub fn array_ndims(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_ndims needs one argument"); + } + + fn general_list_ndims( + array: &GenericListArray, + ) -> Result { + let mut data = Vec::new(); + let ndims = datafusion_common::utils::list_ndims(array.data_type()); + + for arr in array.iter() { + if arr.is_some() { + data.push(Some(ndims)) + } else { + data.push(None) + } + } + + Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) + } + match args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_list_ndims::(array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_list_ndims::(array) + } + array_type => exec_err!("array_ndims does not support type {array_type:?}"), + } +} +pub fn gen_range_date( + args: &[ArrayRef], + include_upper: i32, +) -> datafusion_common::Result { + if args.len() != 3 { + return exec_err!("arguments length does not match"); + } + let (start_array, stop_array, step_array) = ( + Some(as_date32_array(&args[0])?), + as_date32_array(&args[1])?, + Some(as_interval_mdn_array(&args[2])?), + ); + + let mut values = vec![]; + let mut offsets = vec![0]; + for (idx, stop) in stop_array.iter().enumerate() { + let mut stop = stop.unwrap_or(0); + let start = start_array.as_ref().map(|x| x.value(idx)).unwrap_or(0); + let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1); + let (months, days, _) = IntervalMonthDayNanoType::to_parts(step); + let neg = months < 0 || days < 0; + if include_upper == 0 { + stop = Date32Type::subtract_month_day_nano(stop, step); + } + let mut new_date = start; + loop { + if neg && new_date < stop || !neg && new_date > stop { + break; + } + values.push(new_date); + new_date = Date32Type::add_month_day_nano(new_date, step); + } + offsets.push(values.len() as i32); + } + + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", DataType::Date32, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Date32Array::from(values)), + None, + )?); + Ok(arr) +} diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs 
index e3515ccf9f72..e4cdf69aa93a 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -39,7 +39,10 @@ use std::sync::Arc; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::udf::array_dims; + pub use super::udf::array_ndims; pub use super::udf::array_to_string; + pub use super::udf::cardinality; pub use super::udf::gen_series; pub use super::udf::range; } @@ -50,6 +53,9 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { udf::array_to_string_udf(), udf::range_udf(), udf::gen_series_udf(), + udf::array_dims_udf(), + udf::cardinality_udf(), + udf::array_ndims_udf(), ]; functions.into_iter().try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs index 17769419c0b2..709a33cc4506 100644 --- a/datafusion/functions-array/src/udf.rs +++ b/datafusion/functions-array/src/udf.rs @@ -19,6 +19,8 @@ use arrow::datatypes::DataType; use arrow::datatypes::Field; +use arrow::datatypes::IntervalUnit::MonthDayNano; +use datafusion_common::exec_err; use datafusion_common::plan_err; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; @@ -26,6 +28,7 @@ use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; + // Create static instances of ScalarUDFs for each function make_udf_function!(ArrayToString, array_to_string, @@ -106,6 +109,7 @@ impl Range { Exact(vec![Int64]), Exact(vec![Int64, Int64]), Exact(vec![Int64, Int64, Int64]), + Exact(vec![Date32, Date32, Interval(MonthDayNano)]), ], Volatility::Immutable, ), @@ -136,7 +140,17 @@ impl ScalarUDFImpl for Range { fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::gen_range(&args, 0).map(ColumnarValue::Array) + match args[0].data_type() { + arrow::datatypes::DataType::Int64 => { + crate::kernels::gen_range(&args, 0).map(ColumnarValue::Array) + } + arrow::datatypes::DataType::Date32 => { + crate::kernels::gen_range_date(&args, 0).map(ColumnarValue::Array) + } + _ => { + exec_err!("unsupported type for range") + } + } } fn aliases(&self) -> &[String] { @@ -165,6 +179,7 @@ impl GenSeries { Exact(vec![Int64]), Exact(vec![Int64, Int64]), Exact(vec![Int64, Int64, Int64]), + Exact(vec![Date32, Date32, Interval(MonthDayNano)]), ], Volatility::Immutable, ), @@ -195,7 +210,182 @@ impl ScalarUDFImpl for GenSeries { fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::gen_range(&args, 1).map(ColumnarValue::Array) + match args[0].data_type() { + arrow::datatypes::DataType::Int64 => { + crate::kernels::gen_range(&args, 1).map(ColumnarValue::Array) + } + arrow::datatypes::DataType::Date32 => { + crate::kernels::gen_range_date(&args, 1).map(ColumnarValue::Array) + } + _ => { + exec_err!("unsupported type for range") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +make_udf_function!( + ArrayDims, + array_dims, + array, + "returns an array of the array's dimensions.", + array_dims_udf +); + +#[derive(Debug)] +pub(super) struct ArrayDims { + signature: Signature, + aliases: Vec, +} + +impl ArrayDims { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec!["array_dims".to_string(), "list_dims".to_string()], + } + } +} + 
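+// `make_udf_function!` (above) generates the singleton `ScalarUDF` plus the
+// `array_dims` fluent-API helper re-exported from `expr_fn` in lib.rs; the
+// impl below supplies the planner-facing metadata (name, signature, return
+// type) and the runtime `invoke` that dispatches to the kernel.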
+impl ScalarUDFImpl for ArrayDims { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_dims" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + use DataType::*; + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => { + List(Arc::new(Field::new("item", UInt64, true))) + } + _ => { + return plan_err!("The array_dims function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + let args = ColumnarValue::values_to_arrays(args)?; + crate::kernels::array_dims(&args).map(ColumnarValue::Array) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +make_udf_function!( + Cardinality, + cardinality, + array, + "returns the total number of elements in the array.", + cardinality_udf +); + +impl Cardinality { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![String::from("cardinality")], + } + } +} + +#[derive(Debug)] +pub(super) struct Cardinality { + signature: Signature, + aliases: Vec, +} +impl ScalarUDFImpl for Cardinality { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "cardinality" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + use DataType::*; + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, + _ => { + return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + let args = ColumnarValue::values_to_arrays(args)?; + crate::kernels::cardinality(&args).map(ColumnarValue::Array) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +make_udf_function!( + ArrayNdims, + array_ndims, + array, + "returns the number of dimensions of the array.", + array_ndims_udf +); + +#[derive(Debug)] +pub(super) struct ArrayNdims { + signature: Signature, + aliases: Vec, +} +impl ArrayNdims { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![String::from("array_ndims"), String::from("list_ndims")], + } + } +} + +impl ScalarUDFImpl for ArrayNdims { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_ndims" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + use DataType::*; + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, + _ => { + return plan_err!("The array_ndims function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + let args = ColumnarValue::values_to_arrays(args)?; + crate::kernels::array_ndims(&args).map(ColumnarValue::Array) } fn aliases(&self) -> &[String] { diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index e890c9623ca3..502c6923019b 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -68,3 +68,7 @@ tokio = { workspace = true, features = ["macros", "rt", "sync"] } [[bench]] harness = false name = "to_timestamp" + +[[bench]] +harness = false +name = "regx" diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs new file 
mode 100644 index 000000000000..390676f8f249 --- /dev/null +++ b/datafusion/functions/benches/regx.rs @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use std::sync::Arc; + +use arrow_array::builder::StringBuilder; +use arrow_array::{ArrayRef, StringArray}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_functions::regex::regexplike::regexp_like; +use datafusion_functions::regex::regexpmatch::regexp_match; +use rand::distributions::Alphanumeric; +use rand::rngs::ThreadRng; +use rand::seq::SliceRandom; +use rand::Rng; +fn data(rng: &mut ThreadRng) -> StringArray { + let mut data: Vec = vec![]; + for _ in 0..1000 { + data.push( + rng.sample_iter(&Alphanumeric) + .take(7) + .map(char::from) + .collect(), + ); + } + + StringArray::from(data) +} + +fn regex(rng: &mut ThreadRng) -> StringArray { + let samples = vec![ + ".*([A-Z]{1}).*".to_string(), + "^(A).*".to_string(), + r#"[\p{Letter}-]+"#.to_string(), + r#"[\p{L}-]+"#.to_string(), + "[a-zA-Z]_[a-zA-Z]{2}".to_string(), + ]; + let mut data: Vec = vec![]; + for _ in 0..1000 { + data.push(samples.choose(rng).unwrap().to_string()); + } + + StringArray::from(data) +} + +fn flags(rng: &mut ThreadRng) -> StringArray { + let samples = vec![Some("i".to_string()), Some("im".to_string()), None]; + let mut sb = StringBuilder::new(); + for _ in 0..1000 { + let sample = samples.choose(rng).unwrap(); + if sample.is_some() { + sb.append_value(sample.clone().unwrap()); + } else { + sb.append_null(); + } + } + + sb.finish() +} + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("regexp_like_1000", |b| { + let mut rng = rand::thread_rng(); + let data = Arc::new(data(&mut rng)) as ArrayRef; + let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + + b.iter(|| { + black_box( + regexp_like::(&[data.clone(), regex.clone(), flags.clone()]) + .expect("regexp_like should work on valid values"), + ) + }) + }); + + c.bench_function("regexp_match_1000", |b| { + let mut rng = rand::thread_rng(); + let data = Arc::new(data(&mut rng)) as ArrayRef; + let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + + b.iter(|| { + black_box( + regexp_match::(&[data.clone(), regex.clone(), flags.clone()]) + .expect("regexp_match should work on valid values"), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index adba84af72ae..cd8593337c7a 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -561,6 +561,31 @@ 
mod tests {
         Ok(())
     }
 
+    #[test]
+    fn to_timestamp_with_invalid_tz() -> Result<()> {
+        let mut date_string_builder = StringBuilder::with_capacity(2, 1024);
+
+        date_string_builder.append_null();
+
+        date_string_builder.append_value("2020-09-08T13:42:29ZZ");
+
+        let string_array =
+            ColumnarValue::Array(Arc::new(date_string_builder.finish()) as ArrayRef);
+
+        let expected_err =
+            "Arrow error: Parser error: Invalid timezone \"ZZ\": 'ZZ' is not a valid timezone";
+        match to_timestamp(&[string_array]) {
+            Ok(_) => panic!("Expected error but got success"),
+            Err(e) => {
+                assert!(
+                    e.to_string().contains(expected_err),
+                    "Can not find expected error '{expected_err}'. Actual error '{e}'"
+                );
+            }
+        }
+        Ok(())
+    }
+
     #[test]
     fn to_timestamp_with_no_matching_formats() -> Result<()> {
         let mut date_string_builder = StringBuilder::with_capacity(2, 1024);
diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs
index 862e8b77a2d6..1e0c7799c6a5 100644
--- a/datafusion/functions/src/regex/mod.rs
+++ b/datafusion/functions/src/regex/mod.rs
@@ -17,13 +17,18 @@
 
 //! "regx" DataFusion functions
 
-mod regexpmatch;
+pub mod regexplike;
+pub mod regexpmatch;
+
 // create UDFs
 make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match);
-
+make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like);
 export_functions!((
     regexp_match,
-    input_arg1
-    input_arg2,
+    input_arg1 input_arg2,
     "returns a list of regular expression matches in a string. "
+),(
+    regexp_like,
+    input_arg1 input_arg2,
+    "returns true if the string matches the regular expression at least once, false otherwise."
 ));
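For reference, the new kernel is exercised directly in the benchmarks and tests below; a minimal sketch of calling it on equal-length arrays (array contents here are illustrative):

    use std::sync::Arc;
    use arrow_array::{ArrayRef, BooleanArray, StringArray};
    use datafusion_functions::regex::regexplike::regexp_like;

    fn regexp_like_demo() -> datafusion_common::Result<()> {
        let values = Arc::new(StringArray::from(vec!["abc", "abc"])) as ArrayRef;
        let patterns = Arc::new(StringArray::from(vec!["^a", "^A"])) as ArrayRef;

        // Two-argument form: no flags, so matching is case sensitive.
        let matched = regexp_like::<i32>(&[values, patterns])?;
        let expected = BooleanArray::from(vec![true, false]);
        assert_eq!(matched.as_ref(), &expected);
        Ok(())
    }

Passing a third array of per-row flags (for example `"i"` for case-insensitive matching) switches to the three-argument path, which rejects the unsupported `"g"` flag at plan time.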
diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs
new file mode 100644
index 000000000000..b0abad318058
--- /dev/null
+++ b/datafusion/functions/src/regex/regexplike.rs
@@ -0,0 +1,252 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Regex expressions
+use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
+use arrow::compute::kernels::regexp;
+use arrow::datatypes::DataType;
+use datafusion_common::exec_err;
+use datafusion_common::ScalarValue;
+use datafusion_common::{arrow_datafusion_err, plan_err};
+use datafusion_common::{
+    cast::as_generic_string_array, internal_err, DataFusionError, Result,
+};
+use datafusion_expr::ColumnarValue;
+use datafusion_expr::TypeSignature::*;
+use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
+use std::any::Any;
+use std::sync::Arc;
+
+#[derive(Debug)]
+pub(super) struct RegexpLikeFunc {
+    signature: Signature,
+}
+
+impl RegexpLikeFunc {
+    pub fn new() -> Self {
+        use DataType::*;
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    Exact(vec![Utf8, Utf8]),
+                    Exact(vec![LargeUtf8, Utf8]),
+                    Exact(vec![Utf8, Utf8, Utf8]),
+                    Exact(vec![LargeUtf8, Utf8, Utf8]),
+                ],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for RegexpLikeFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "regexp_like"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        use DataType::*;
+
+        Ok(match &arg_types[0] {
+            LargeUtf8 | Utf8 => Boolean,
+            Null => Null,
+            other => {
+                return plan_err!(
+                    "The regexp_like function can only accept strings. Got {other}"
+                );
+            }
+        })
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        let len = args
+            .iter()
+            .fold(Option::<usize>::None, |acc, arg| match arg {
+                ColumnarValue::Scalar(_) => acc,
+                ColumnarValue::Array(a) => Some(a.len()),
+            });
+
+        let is_scalar = len.is_none();
+        let inferred_length = len.unwrap_or(1);
+        let args = args
+            .iter()
+            .map(|arg| arg.clone().into_array(inferred_length))
+            .collect::<Result<Vec<_>>>()?;
+
+        let result = regexp_like_func(&args);
+        if is_scalar {
+            // If all inputs are scalar, keeps output as scalar
+            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
+            result.map(ColumnarValue::Scalar)
+        } else {
+            result.map(ColumnarValue::Array)
+        }
+    }
+}
+
+fn regexp_like_func(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args[0].data_type() {
+        DataType::Utf8 => regexp_like::<i32>(args),
+        DataType::LargeUtf8 => regexp_like::<i64>(args),
+        other => {
+            internal_err!("Unsupported data type {other:?} for function regexp_like")
+        }
+    }
+}
+
+/// Tests a string using a regular expression, returning true if there is
+/// at least one match, false otherwise.
+/// +/// The full list of supported features and syntax can be found at +/// +/// +/// Supported flags can be found at +/// +/// +/// # Examples +/// +/// ```ignore +/// # use datafusion::prelude::*; +/// # use datafusion::error::Result; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// let ctx = SessionContext::new(); +/// let df = ctx.read_csv("tests/data/regex.csv", CsvReadOptions::new()).await?; +/// +/// // use the regexp_like function to test col 'values', +/// // against patterns in col 'patterns' without flags +/// let df = df.with_column( +/// "a", +/// regexp_like(vec![col("values"), col("patterns")]) +/// )?; +/// // use the regexp_like function to test col 'values', +/// // against patterns in col 'patterns' with flags +/// let df = df.with_column( +/// "b", +/// regexp_like(vec![col("values"), col("patterns"), col("flags")]) +/// )?; +/// // literals can be used as well with dataframe calls +/// let df = df.with_column( +/// "c", +/// regexp_like(vec![lit("foobarbequebaz"), lit("(bar)(beque)")]) +/// )?; +/// +/// df.show().await?; +/// +/// # Ok(()) +/// # } +/// ``` +pub fn regexp_like(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let values = as_generic_string_array::(&args[0])?; + let regex = as_generic_string_array::(&args[1])?; + let array = regexp::regexp_is_match_utf8(values, regex, None) + .map_err(|e| arrow_datafusion_err!(e))?; + + Ok(Arc::new(array) as ArrayRef) + } + 3 => { + let values = as_generic_string_array::(&args[0])?; + let regex = as_generic_string_array::(&args[1])?; + let flags = as_generic_string_array::(&args[2])?; + + if flags.iter().any(|s| s == Some("g")) { + return plan_err!("regexp_like() does not support the \"global\" option"); + } + + let array = regexp::regexp_is_match_utf8(values, regex, Some(flags)) + .map_err(|e| arrow_datafusion_err!(e))?; + + Ok(Arc::new(array) as ArrayRef) + } + other => exec_err!( + "regexp_like was called with {other} arguments. It requires at least 2 and at most 3." 
+ ), + } +} +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::BooleanBuilder; + use arrow_array::StringArray; + + use crate::regex::regexplike::regexp_like; + + #[test] + fn test_case_sensitive_regexp_like() { + let values = StringArray::from(vec!["abc"; 5]); + + let patterns = + StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + + let mut expected_builder: BooleanBuilder = BooleanBuilder::new(); + expected_builder.append_value(true); + expected_builder.append_value(false); + expected_builder.append_value(true); + expected_builder.append_value(false); + expected_builder.append_value(false); + let expected = expected_builder.finish(); + + let re = regexp_like::(&[Arc::new(values), Arc::new(patterns)]).unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_case_insensitive_regexp_like() { + let values = StringArray::from(vec!["abc"; 5]); + let patterns = + StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + let flags = StringArray::from(vec!["i"; 5]); + + let mut expected_builder: BooleanBuilder = BooleanBuilder::new(); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(false); + let expected = expected_builder.finish(); + + let re = + regexp_like::(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_unsupported_global_flag_regexp_like() { + let values = StringArray::from(vec!["abc"]); + let patterns = StringArray::from(vec!["^(a)"]); + let flags = StringArray::from(vec!["g"]); + + let re_err = + regexp_like::(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + .expect_err("unsupported flag should have failed"); + + assert_eq!( + re_err.strip_backtrace(), + "Error during planning: regexp_like() does not support the \"global\" option" + ); + } +} diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index 8a2180f00be7..f34502af35b7 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Encoding expressions +//! 
diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs
index 8a2180f00be7..f34502af35b7 100644
--- a/datafusion/functions/src/regex/regexpmatch.rs
+++ b/datafusion/functions/src/regex/regexpmatch.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Encoding expressions
+//! Regex expressions
 use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
 use arrow::compute::kernels::regexp;
 use arrow::datatypes::DataType;
@@ -139,3 +139,72 @@ pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
         ),
     }
 }
+#[cfg(test)]
+mod tests {
+    use crate::regex::regexpmatch::regexp_match;
+    use arrow::array::{GenericStringBuilder, ListBuilder};
+    use arrow_array::StringArray;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_case_sensitive_regexp_match() {
+        let values = StringArray::from(vec!["abc"; 5]);
+        let patterns =
+            StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
+
+        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
+        let mut expected_builder = ListBuilder::new(elem_builder);
+        expected_builder.values().append_value("a");
+        expected_builder.append(true);
+        expected_builder.append(false);
+        expected_builder.values().append_value("b");
+        expected_builder.append(true);
+        expected_builder.append(false);
+        expected_builder.append(false);
+        let expected = expected_builder.finish();
+
+        let re = regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns)]).unwrap();
+
+        assert_eq!(re.as_ref(), &expected);
+    }
+
+    #[test]
+    fn test_case_insensitive_regexp_match() {
+        let values = StringArray::from(vec!["abc"; 5]);
+        let patterns =
+            StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
+        let flags = StringArray::from(vec!["i"; 5]);
+
+        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
+        let mut expected_builder = ListBuilder::new(elem_builder);
+        expected_builder.values().append_value("a");
+        expected_builder.append(true);
+        expected_builder.values().append_value("a");
+        expected_builder.append(true);
+        expected_builder.values().append_value("b");
+        expected_builder.append(true);
+        expected_builder.values().append_value("b");
+        expected_builder.append(true);
+        expected_builder.append(false);
+        let expected = expected_builder.finish();
+
+        let re =
+            regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
+                .unwrap();
+
+        assert_eq!(re.as_ref(), &expected);
+    }
+
+    #[test]
+    fn test_unsupported_global_flag_regexp_match() {
+        let values = StringArray::from(vec!["abc"]);
+        let patterns = StringArray::from(vec!["^(a)"]);
+        let flags = StringArray::from(vec!["g"]);
+
+        let re_err =
+            regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
+                .expect_err("unsupported flag should have failed");
+
+        assert_eq!(re_err.strip_backtrace(), "Error during planning: regexp_match() does not support the \"global\" option");
+    }
+}
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index 9242e68562c6..93b24d71c496 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -15,9 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
+use std::sync::Arc; + use crate::analyzer::AnalyzerRule; + use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRewriter, +}; use datafusion_common::Result; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; @@ -27,7 +32,6 @@ use datafusion_expr::{ aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, Sort, Subquery, }; -use std::sync::Arc; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// @@ -43,7 +47,7 @@ impl CountWildcardRule { impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down(&analyze_internal) + plan.transform_down(&analyze_internal).data() } fn name(&self) -> &str { @@ -61,7 +65,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes( + Ok(Transformed::yes( LogicalPlanBuilder::from((*window.input).clone()) .window(window_expr)? .build()?, @@ -74,7 +78,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Aggregate( + Ok(Transformed::yes(LogicalPlan::Aggregate( Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?, ))) } @@ -83,7 +87,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .iter() .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Sort(Sort { + Ok(Transformed::yes(LogicalPlan::Sort(Sort { expr: sort_expr, input, fetch, @@ -95,7 +99,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .iter() .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Projection( + Ok(Transformed::yes(LogicalPlan::Projection( Projection::try_new(projection_expr, projection.input)?, ))) } @@ -103,22 +107,22 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { predicate, input, .. 
}) => { let predicate = rewrite_preserving_name(predicate, &mut rewriter)?; - Ok(Transformed::Yes(LogicalPlan::Filter(Filter::try_new( + Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( predicate, input, )?))) } - _ => Ok(Transformed::No(plan)), + _ => Ok(Transformed::no(plan)), } } struct CountWildcardRewriter {} impl TreeNodeRewriter for CountWildcardRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, old_expr: Expr) -> Result { - let new_expr = match old_expr.clone() { + fn f_up(&mut self, old_expr: Expr) -> Result> { + Ok(match old_expr.clone() { Expr::WindowFunction(expr::WindowFunction { fun: expr::WindowFunctionDefinition::AggregateFunction( @@ -131,7 +135,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { null_treatment, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { - Expr::WindowFunction(expr::WindowFunction { + Transformed::yes(Expr::WindowFunction(expr::WindowFunction { fun: expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), @@ -140,10 +144,10 @@ impl TreeNodeRewriter for CountWildcardRewriter { order_by, window_frame, null_treatment, - }) + })) } - _ => old_expr, + _ => Transformed::no(old_expr), }, Expr::AggregateFunction(AggregateFunction { func_def: @@ -156,68 +160,65 @@ impl TreeNodeRewriter for CountWildcardRewriter { order_by, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { - Expr::AggregateFunction(AggregateFunction::new( + Transformed::yes(Expr::AggregateFunction(AggregateFunction::new( aggregate_function::AggregateFunction::Count, vec![lit(COUNT_STAR_EXPANSION)], distinct, filter, order_by, - )) + ))) } - _ => old_expr, + _ => Transformed::no(old_expr), }, ScalarSubquery(Subquery { subquery, outer_ref_columns, - }) => { - let new_plan = subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - }) - } + }) => subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? + .update_data(|new_plan| { + ScalarSubquery(Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns, + }) + }), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => { - let new_plan = subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - - Expr::InSubquery(InSubquery::new( - expr, - Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - )) - } - Expr::Exists(expr::Exists { subquery, negated }) => { - let new_plan = subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - - Expr::Exists(expr::Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - }) - } - _ => old_expr, - }; - Ok(new_expr) + }) => subquery + .subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? + .update_data(|new_plan| { + Expr::InSubquery(InSubquery::new( + expr, + Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, + negated, + )) + }), + Expr::Exists(expr::Exists { subquery, negated }) => subquery + .subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? 
+ .update_data(|new_plan| { + Expr::Exists(expr::Exists { + subquery: Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, + negated, + }) + }), + _ => Transformed::no(old_expr), + }) } } diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 90af7aec8293..b21ec851dfcd 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -20,11 +20,11 @@ use std::sync::Arc; use crate::analyzer::AnalyzerRule; + use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; -use datafusion_expr::expr::Exists; -use datafusion_expr::expr::InSubquery; +use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::{ logical_plan::LogicalPlan, Expr, Filter, LogicalPlanBuilder, TableScan, }; @@ -42,7 +42,7 @@ impl InlineTableScan { impl AnalyzerRule for InlineTableScan { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_up(&analyze_internal) + plan.transform_up(&analyze_internal).data() } fn name(&self) -> &str { @@ -51,7 +51,7 @@ impl AnalyzerRule for InlineTableScan { } fn analyze_internal(plan: LogicalPlan) -> Result> { - Ok(match plan { + match plan { // Match only on scans without filter / projection / fetch // Views and DataFrames won't have those added // during the early stage of planning @@ -64,33 +64,31 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { }) if filters.is_empty() && source.get_logical_plan().is_some() => { let sub_plan = source.get_logical_plan().unwrap(); let projection_exprs = generate_projection_expr(&projection, sub_plan)?; - let plan = LogicalPlanBuilder::from(sub_plan.clone()) + LogicalPlanBuilder::from(sub_plan.clone()) .project(projection_exprs)? // Ensures that the reference to the inlined table remains the // same, meaning we don't have to change any of the parent nodes // that reference this table. .alias(table_name)? 
- .build()?; - Transformed::Yes(plan) + .build() + .map(Transformed::yes) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform(&rewrite_subquery)?; - Transformed::Yes(LogicalPlan::Filter(Filter::try_new( - new_expr, - filter.input, - )?)) + let new_expr = filter.predicate.transform(&rewrite_subquery).data()?; + Filter::try_new(new_expr, filter.input) + .map(|e| Transformed::yes(LogicalPlan::Filter(e))) } - _ => Transformed::No(plan), - }) + _ => Ok(Transformed::no(plan)), + } } fn rewrite_subquery(expr: Expr) -> Result> { match expr { Expr::Exists(Exists { subquery, negated }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal).data()?; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::Exists(Exists { subquery, negated }))) + Ok(Transformed::yes(Expr::Exists(Exists { subquery, negated }))) } Expr::InSubquery(InSubquery { expr, @@ -98,19 +96,19 @@ fn rewrite_subquery(expr: Expr) -> Result> { negated, }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal).data()?; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::InSubquery(InSubquery::new( + Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( expr, subquery, negated, )))) } Expr::ScalarSubquery(subquery) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal).data()?; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::ScalarSubquery(subquery))) + Ok(Transformed::yes(Expr::ScalarSubquery(subquery))) } - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } } @@ -135,13 +133,12 @@ fn generate_projection_expr( mod tests { use std::{sync::Arc, vec}; - use arrow::datatypes::{DataType, Field, Schema}; - - use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, TableSource}; - use crate::analyzer::inline_table_scan::InlineTableScan; use crate::test::assert_analyzed_plan_eq; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, TableSource}; + pub struct RawTableSource {} impl TableSource for RawTableSource { diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 4c480017fc3a..08caa4be60a9 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -29,7 +29,7 @@ use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; @@ -136,7 +136,7 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { })?; } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index eedfc40a7f80..41ebcd8e501a 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -19,21 +19,19 @@ use std::sync::Arc; +use 
super::AnalyzerRule; + use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::utils::list_ndims; -use datafusion_common::DFSchema; -use datafusion_common::DFSchemaRef; -use datafusion_common::Result; +use datafusion_common::{DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::merge_schema; -use datafusion_expr::BuiltinScalarFunction; -use datafusion_expr::Operator; -use datafusion_expr::ScalarFunctionDefinition; -use datafusion_expr::{BinaryExpr, Expr, LogicalPlan}; - -use super::AnalyzerRule; +use datafusion_expr::{ + BinaryExpr, BuiltinScalarFunction, Expr, LogicalPlan, Operator, + ScalarFunctionDefinition, +}; #[derive(Default)] pub struct OperatorToFunction {} @@ -94,41 +92,34 @@ pub(crate) struct OperatorToFunctionRewriter { } impl TreeNodeRewriter for OperatorToFunctionRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { - match expr { - Expr::BinaryExpr(BinaryExpr { - ref left, + fn f_up(&mut self, expr: Expr) -> Result> { + if let Expr::BinaryExpr(BinaryExpr { + ref left, + op, + ref right, + }) = expr + { + if let Some(fun) = rewrite_array_concat_operator_to_func_for_column( + left.as_ref(), op, - ref right, - }) => { - if let Some(fun) = rewrite_array_concat_operator_to_func_for_column( - left.as_ref(), - op, - right.as_ref(), - self.schema.as_ref(), - )? - .or_else(|| { - rewrite_array_concat_operator_to_func( - left.as_ref(), - op, - right.as_ref(), - ) - }) { - // Convert &Box -> Expr - let left = (**left).clone(); - let right = (**right).clone(); - return Ok(Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - args: vec![left, right], - })); - } - - Ok(expr) + right.as_ref(), + self.schema.as_ref(), + )? + .or_else(|| { + rewrite_array_concat_operator_to_func(left.as_ref(), op, right.as_ref()) + }) { + // Convert &Box -> Expr + let left = (**left).clone(); + let right = (**right).clone(); + return Ok(Transformed::yes(Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args: vec![left, right], + }))); } - _ => Ok(expr), } + Ok(Transformed::no(expr)) } } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index a0e972fc703c..b7f513727d39 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -15,9 +15,12 @@ // specific language governing permissions and limitations // under the License. 
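Zooming out from the individual hunks: the mechanical pattern repeated across these optimizer files is that `TreeNodeRewriter::pre_visit`/`mutate` become `f_down`/`f_up`, both returning `Transformed<Node>` instead of a bare node, and `type N` becomes `type Node`. A minimal sketch of a rewriter written against the new API, as the hunks above use it (the `ColumnRenamer` rule itself is illustrative, not part of this PR):

```rust
use datafusion_common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
};
use datafusion_common::Result;
use datafusion_expr::{col, Expr};

/// Illustrative rule: rewrite references to column "a" into column "b".
struct ColumnRenamer;

impl TreeNodeRewriter for ColumnRenamer {
    // `type N` was renamed to `type Node`.
    type Node = Expr;

    // `f_up` replaces `mutate`: it must now report *whether* it changed the node.
    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        match &expr {
            Expr::Column(c) if c.name == "a" => Ok(Transformed::yes(col("b"))),
            _ => Ok(Transformed::no(expr)),
        }
    }
}

fn rename(expr: Expr) -> Result<Expr> {
    // `rewrite` now returns `Result<Transformed<Expr>>`; `.data()` from the
    // `TransformedResult` extension trait unwraps it back to the node, which
    // is why `.data()` trails nearly every `rewrite`/`transform_*` call in
    // this diff.
    expr.rewrite(&mut ColumnRenamer).data()
}
```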
+use std::ops::Deref; + use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; + +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::utils::split_conjunction; @@ -25,7 +28,6 @@ use datafusion_expr::{ Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, Window, }; -use std::ops::Deref; /// Do necessary check on subquery expressions and fail the invalid plan /// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions, @@ -146,7 +148,7 @@ fn check_inner_plan( LogicalPlan::Aggregate(_) => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -171,7 +173,7 @@ fn check_inner_plan( check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -188,7 +190,7 @@ fn check_inner_plan( | LogicalPlan::SubqueryAlias(_) => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -206,7 +208,7 @@ fn check_inner_plan( is_aggregate, can_contain_outer_ref, )?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -221,7 +223,7 @@ fn check_inner_plan( JoinType::Full => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, false)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -287,12 +289,11 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { .into_iter() .partition(|e| e.contains_outer()); - correlated - .into_iter() - .for_each(|expr| exprs.push(strip_outer_reference(expr.clone()))); - return Ok(VisitRecursion::Continue); + for expr in correlated { + exprs.push(strip_outer_reference(expr.clone())); + } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(exprs) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index d469e0f8ce0d..08f49ed15b09 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -19,10 +19,11 @@ use std::sync::Arc; -use arrow::datatypes::{DataType, IntervalUnit}; +use crate::analyzer::AnalyzerRule; +use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -50,8 +51,6 @@ use datafusion_expr::{ WindowFrameBound, WindowFrameUnits, }; -use crate::analyzer::AnalyzerRule; - #[derive(Default)] pub struct TypeCoercion {} @@ -126,13 +125,9 @@ pub(crate) struct TypeCoercionRewriter { } impl TreeNodeRewriter for TypeCoercionRewriter { - type N = Expr; - - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) - } + type Node = Expr; - fn mutate(&mut self, expr: 
Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match expr { Expr::Unnest(_) => internal_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" @@ -142,20 +137,20 @@ impl TreeNodeRewriter for TypeCoercionRewriter { outer_ref_columns, }) => { let new_plan = analyze_internal(&self.schema, &subquery)?; - Ok(Expr::ScalarSubquery(Subquery { + Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, - })) + }))) } Expr::Exists(Exists { subquery, negated }) => { let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - Ok(Expr::Exists(Exists { + Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }, negated, - })) + }))) } Expr::InSubquery(InSubquery { expr, @@ -173,42 +168,34 @@ impl TreeNodeRewriter for TypeCoercionRewriter { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }; - Ok(Expr::InSubquery(InSubquery::new( + Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( Box::new(expr.cast_to(&common_type, &self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, - ))) - } - Expr::Not(expr) => { - let expr = not(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsTrue(expr) => { - let expr = is_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotTrue(expr) => { - let expr = is_not_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsFalse(expr) => { - let expr = is_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotFalse(expr) => { - let expr = - is_not_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsUnknown(expr) => { - let expr = is_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotUnknown(expr) => { - let expr = - is_not_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) + )))) } + Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( + &expr, + &self.schema, + )?))), + Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), Expr::Like(Like { negated, expr, @@ -230,14 +217,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter { })?; let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); - let expr = Expr::Like(Like::new( + Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, pattern, escape_char, case_insensitive, - )); - Ok(expr) + )))) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left_type, right_type) = get_input_types( @@ -245,12 +231,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &op, &right.get_type(&self.schema)?, )?; - - Ok(Expr::BinaryExpr(BinaryExpr::new( + 
Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( Box::new(left.cast_to(&left_type, &self.schema)?), op, Box::new(right.cast_to(&right_type, &self.schema)?), - ))) + )))) } Expr::Between(Between { expr, @@ -280,13 +265,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) })?; - let expr = Expr::Between(Between::new( + Ok(Transformed::yes(Expr::Between(Between::new( Box::new(expr.cast_to(&coercion_type, &self.schema)?), negated, Box::new(low.cast_to(&coercion_type, &self.schema)?), Box::new(high.cast_to(&coercion_type, &self.schema)?), - )); - Ok(expr) + )))) } Expr::InList(InList { expr, @@ -313,18 +297,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter { list_expr.cast_to(&coerced_type, &self.schema) }) .collect::>>()?; - let expr = Expr::InList(InList ::new( + Ok(Transformed::yes(Expr::InList(InList ::new( Box::new(cast_expr), cast_list_expr, negated, - )); - Ok(expr) + )))) } } } Expr::Case(case) => { let case = coerce_case_expression(case, &self.schema)?; - Ok(Expr::Case(case)) + Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { @@ -338,7 +321,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun, )?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + Ok(Transformed::yes(Expr::ScalarFunction(ScalarFunction::new( + fun, new_args, + )))) } ScalarFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -346,7 +331,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + Ok(Transformed::yes(Expr::ScalarFunction( + ScalarFunction::new_udf(fun, new_expr), + ))) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -366,10 +353,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) + Ok(Transformed::yes(Expr::AggregateFunction( + expr::AggregateFunction::new( + fun, new_expr, distinct, filter, order_by, + ), + ))) } AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -377,10 +365,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fun, new_expr, false, filter, order_by, - )); - Ok(expr) + Ok(Transformed::yes(Expr::AggregateFunction( + expr::AggregateFunction::new_udf( + fun, new_expr, false, filter, order_by, + ), + ))) } AggregateFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -409,15 +398,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { _ => args, }; - let expr = Expr::WindowFunction(WindowFunction::new( + Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( fun, args, partition_by, order_by, window_frame, null_treatment, - )); - Ok(expr) + )))) } Expr::Alias(_) | Expr::Column(_) @@ -434,7 +422,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { | Expr::Wildcard { .. 
} | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::OuterReferenceColumn(_, _) => Ok(expr), + | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), } } } @@ -764,31 +752,26 @@ mod test { use std::any::Any; use std::sync::{Arc, OnceLock}; - use arrow::array::{FixedSizeListArray, Int32Array}; - use arrow::datatypes::{DataType, TimeUnit}; + use crate::analyzer::type_coercion::{ + cast_expr, coerce_case_expression, TypeCoercion, TypeCoercionRewriter, + }; + use crate::test::assert_analyzed_plan_eq; - use arrow::datatypes::Field; - use datafusion_common::tree_node::TreeNode; + use arrow::array::{FixedSizeListArray, Int32Array}; + use arrow::datatypes::{DataType, Field, TimeUnit}; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; + use datafusion_expr::logical_plan::{EmptyRelation, Projection}; use datafusion_expr::{ - cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, - AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, - SimpleAggregateUDF, Subquery, - }; - use datafusion_expr::{ - lit, - logical_plan::{EmptyRelation, Projection}, - Expr, LogicalPlan, ScalarUDF, Signature, Volatility, + cast, col, concat, concat_ws, create_udaf, is_true, lit, + AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, + BuiltinScalarFunction, Case, ColumnarValue, Expr, ExprSchemable, Filter, + LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + Subquery, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; - use crate::analyzer::type_coercion::{ - cast_expr, coerce_case_expression, TypeCoercion, TypeCoercionRewriter, - }; - use crate::test::assert_analyzed_plan_eq; - fn empty() -> Arc { Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -1289,7 +1272,7 @@ mod test { std::collections::HashMap::new(), )?); let mut rewriter = TypeCoercionRewriter { schema }; - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter).data()?; let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( @@ -1324,7 +1307,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // eq @@ -1335,7 +1318,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // lt @@ -1346,7 +1329,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index ae720bc68998..30c184a28e33 100644 --- 
a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -25,10 +25,12 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{ - RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + TreeNodeVisitor, }; use datafusion_common::{ - internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, + internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef, + DataFusionError, Result, }; use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{ @@ -642,51 +644,52 @@ impl ExprIdentifierVisitor<'_> { /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it. - fn pop_enter_mark(&mut self) -> (usize, Identifier) { + fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> { let mut desc = String::new(); while let Some(item) = self.visit_stack.pop() { match item { VisitRecord::EnterMark(idx) => { - return (idx, desc); + return Some((idx, desc)); } VisitRecord::ExprItem(s) => { desc.push_str(&s); } } } - - unreachable!("Enter mark should paired with node number"); + None } } impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: &Expr) -> Result { // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(VisitRecursion::Skip); + return Ok(TreeNodeRecursion::Jump); } self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; // put placeholder self.id_array.push((0, "".to_string())); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, expr: &Expr) -> Result { + fn f_up(&mut self, expr: &Expr) -> Result { self.series_number += 1; - let (idx, sub_expr_desc) = self.pop_enter_mark(); + let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else { + return Ok(TreeNodeRecursion::Continue); + }; // skip exprs should not be recognize. if self.expr_mask.ignores(expr) { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } let mut desc = Self::desc_expr(expr); desc.push_str(&sub_expr_desc); @@ -700,7 +703,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) .1 += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -743,74 +746,83 @@ struct CommonSubexprRewriter<'a> { } impl TreeNodeRewriter for CommonSubexprRewriter<'_> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result> { // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate // the `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. - if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(RewriteRecursion::Stop); + if expr.short_circuits() || is_volatile_expression(&expr)? 
{ + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { - return Ok(RewriteRecursion::Stop); + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } let curr_id = &self.id_array[self.curr_index].1; // skip `Expr`s without identifier (empty identifier). if curr_id.is_empty() { self.curr_index += 1; - return Ok(RewriteRecursion::Skip); + return Ok(Transformed::no(expr)); } match self.expr_set.get(curr_id) { Some((_, counter, _)) => { if *counter > 1 { self.affected_id.insert(curr_id.clone()); - Ok(RewriteRecursion::Mutate) + + // This expr tree is finished. + if self.curr_index >= self.id_array.len() { + return Ok(Transformed::new( + expr, + false, + TreeNodeRecursion::Jump, + )); + } + + let (series_number, id) = &self.id_array[self.curr_index]; + self.curr_index += 1; + // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. + let expr_set_item = self.expr_set.get(id).ok_or_else(|| { + internal_datafusion_err!("expr_set invalid state") + })?; + if *series_number < self.max_series_number + || id.is_empty() + || expr_set_item.1 <= 1 + { + return Ok(Transformed::new( + expr, + false, + TreeNodeRecursion::Jump, + )); + } + + self.max_series_number = *series_number; + // step index to skip all sub-node (which has smaller series number). + while self.curr_index < self.id_array.len() + && *series_number > self.id_array[self.curr_index].0 + { + self.curr_index += 1; + } + + let expr_name = expr.display_name()?; + // Alias this `Column` expr to it original "expr name", + // `projection_push_down` optimizer use "expr name" to eliminate useless + // projections. + Ok(Transformed::new( + col(id).alias(expr_name), + true, + TreeNodeRecursion::Jump, + )) } else { self.curr_index += 1; - Ok(RewriteRecursion::Skip) + Ok(Transformed::no(expr)) } } _ => internal_err!("expr_set invalid state"), } } - - fn mutate(&mut self, expr: Expr) -> Result { - // This expr tree is finished. - if self.curr_index >= self.id_array.len() { - return Ok(expr); - } - - let (series_number, id) = &self.id_array[self.curr_index]; - self.curr_index += 1; - // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. - let expr_set_item = self.expr_set.get(id).ok_or_else(|| { - DataFusionError::Internal("expr_set invalid state".to_string()) - })?; - if *series_number < self.max_series_number - || id.is_empty() - || expr_set_item.1 <= 1 - { - return Ok(expr); - } - - self.max_series_number = *series_number; - // step index to skip all sub-node (which has smaller series number). - while self.curr_index < self.id_array.len() - && *series_number > self.id_array[self.curr_index].0 - { - self.curr_index += 1; - } - - let expr_name = expr.display_name()?; - // Alias this `Column` expr to it original "expr name", - // `projection_push_down` optimizer use "expr name" to eliminate useless - // projections. - Ok(col(id).alias(expr_name)) - } } fn replace_common_expr( @@ -826,6 +838,7 @@ fn replace_common_expr( max_series_number: 0, curr_index: 0, }) + .data() } #[cfg(test)] diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 0f4b39d9eee3..fd548ba4948e 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -15,19 +15,20 @@ // specific language governing permissions and limitations // under the License. 
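The visitor side changes in lockstep: `VisitRecursion::{Continue, Skip, Stop}` becomes `TreeNodeRecursion::{Continue, Jump, Stop}`, shared by `TreeNode::apply` and `TreeNodeVisitor::f_down`/`f_up`. A sketch in the spirit of `get_node_number` from `plan_signature.rs`, with an illustrative early-exit arm added to show `Stop`:

```rust
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::Result;
use datafusion_expr::LogicalPlan;

/// Counts plan nodes, but aborts the walk at the first Limit node.
fn count_until_limit(plan: &LogicalPlan) -> Result<usize> {
    let mut n = 0;
    plan.apply(&mut |plan| {
        if matches!(plan, LogicalPlan::Limit(_)) {
            // `Stop` aborts the entire traversal immediately.
            return Ok(TreeNodeRecursion::Stop);
        }
        n += 1;
        // `Continue` descends into children as usual; `Jump` (the successor
        // of the old `Skip`) would skip this node's subtree but keep
        // visiting its siblings.
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(n)
}
```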
+use std::collections::{BTreeSet, HashMap}; +use std::ops::Deref; + use crate::simplify_expressions::{ExprSimplifier, SimplifyContext}; use crate::utils::collect_subquery_cols; + use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{plan_err, Result}; -use datafusion_common::{Column, DFSchemaRef, ScalarValue}; +use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; -use std::collections::{BTreeSet, HashMap}; -use std::ops::Deref; /// This struct rewrite the sub query plan by pull up the correlated expressions(contains outer reference columns) from the inner subquery's 'Filter'. /// It adds the inner reference columns to the 'Projection' or 'Aggregate' of the subquery if they are missing, so that they can be evaluated by the parent operator as the join condition. @@ -56,19 +57,19 @@ pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true"; pub type ExprResultMap = HashMap; impl TreeNodeRewriter for PullUpCorrelatedExpr { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: LogicalPlan) -> Result> { match plan { - LogicalPlan::Filter(_) => Ok(RewriteRecursion::Continue), + LogicalPlan::Filter(_) => Ok(Transformed::no(plan)), LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); if plan_hold_outer { // the unsupported case self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) } else { - Ok(RewriteRecursion::Continue) + Ok(Transformed::no(plan)) } } LogicalPlan::Limit(_) => { @@ -77,21 +78,21 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { (false, true) => { // the unsupported case self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) } - _ => Ok(RewriteRecursion::Continue), + _ => Ok(Transformed::no(plan)), } } _ if plan.expressions().iter().any(|expr| expr.contains_outer()) => { // the unsupported cases, the plan expressions contain out reference columns(like window expressions) self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) } - _ => Ok(RewriteRecursion::Continue), + _ => Ok(Transformed::no(plan)), } } - fn mutate(&mut self, plan: LogicalPlan) -> Result { + fn f_up(&mut self, plan: LogicalPlan) -> Result> { let subquery_schema = plan.schema().clone(); match &plan { LogicalPlan::Filter(plan_filter) => { @@ -140,7 +141,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { .build()?; self.correlated_subquery_cols_map .insert(new_plan.clone(), correlated_subquery_cols); - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } (None, _) => { // if the subquery still has filter expressions, restore them. 
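Where the old rewriters returned `RewriteRecursion::Stop` from `pre_visit` to prune a subtree, the new `f_down` expresses the same intent with the three-argument `Transformed::new(node, transformed, recursion)` constructor, as `PullUpCorrelatedExpr` does in the hunks above. Distilled into a standalone sketch (the outer-reference check mirrors the real rule; the surrounding struct is illustrative):

```rust
use datafusion_common::tree_node::{Transformed, TreeNodeRecursion, TreeNodeRewriter};
use datafusion_common::Result;
use datafusion_expr::LogicalPlan;

struct PullUpSketch {
    can_pull_up: bool,
}

impl TreeNodeRewriter for PullUpSketch {
    type Node = LogicalPlan;

    // `f_up` has a default implementation, so a pre-order-only rewriter
    // can implement `f_down` alone.
    fn f_down(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
        // Same predicate the real rule uses for the unsupported case.
        if plan.expressions().iter().any(|e| e.contains_outer()) {
            self.can_pull_up = false;
            // Old API: `Ok(RewriteRecursion::Stop)`.
            // New API: return the node unchanged (`transformed = false`)
            // and `Jump`, so the rewrite does not descend into children.
            Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
        } else {
            Ok(Transformed::no(plan))
        }
    }
}
```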
@@ -152,7 +153,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = plan.build()?; self.correlated_subquery_cols_map .insert(new_plan.clone(), correlated_subquery_cols); - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } } } @@ -196,7 +197,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(new_plan.clone(), expr_result_map_for_count_bug); } - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } LogicalPlan::Aggregate(aggregate) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => @@ -240,7 +241,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(new_plan.clone(), expr_result_map_for_count_bug); } - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } LogicalPlan::SubqueryAlias(alias) => { let mut local_correlated_cols = BTreeSet::new(); @@ -262,7 +263,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(plan.clone(), input_map.clone()); } - Ok(plan) + Ok(Transformed::no(plan)) } LogicalPlan::Limit(limit) => { let input_expr_map = self @@ -273,7 +274,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) { // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) - (true, false) => { + (true, false) => Transformed::yes( if limit.fetch.filter(|limit_row| *limit_row == 0).is_some() { LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -281,17 +282,17 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { }) } else { LogicalPlanBuilder::from((*limit.input).clone()).build()? - } - } - _ => plan, + }, + ), + _ => Transformed::no(plan), }; if let Some(input_map) = input_expr_map { self.collected_count_expr_map - .insert(new_plan.clone(), input_map); + .insert(new_plan.data.clone(), input_map); } Ok(new_plan) } - _ => Ok(plan), + _ => Ok(Transformed::no(plan)), } } } @@ -370,31 +371,34 @@ fn agg_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for e in agg_expr.iter() { - let result_expr = e.clone().transform_up(&|expr| { - let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { - match func_def { + let result_expr = e + .clone() + .transform_up(&|expr| { + let new_expr = match expr { + Expr::AggregateFunction(expr::AggregateFunction { + func_def, .. + }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( + Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( 0, )))) } else { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } } AggregateFunctionDefinition::UDF { .. 
} => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } AggregateFunctionDefinition::Name(_) => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } - } - } - _ => Transformed::No(expr), - }; - Ok(new_expr) - })?; + }, + _ => Transformed::no(expr), + }; + Ok(new_expr) + }) + .data()?; let result_expr = result_expr.unalias(); let props = ExecutionProps::new(); @@ -415,17 +419,23 @@ fn proj_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for expr in proj_expr.iter() { - let result_expr = expr.clone().transform_up(&|expr| { - if let Expr::Column(Column { name, .. }) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { - Ok(Transformed::Yes(result_expr.clone())) + let result_expr = expr + .clone() + .transform_up(&|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = + input_expr_result_map_for_count_bug.get(name) + { + Ok(Transformed::yes(result_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::No(expr)) - } - })?; + }) + .data()?; + if result_expr.ne(expr) { let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema.clone()); @@ -448,17 +458,21 @@ fn filter_exprs_evaluation_result_on_empty_batch( input_expr_result_map_for_count_bug: &ExprResultMap, expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result> { - let result_expr = filter_expr.clone().transform_up(&|expr| { - if let Expr::Column(Column { name, .. }) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { - Ok(Transformed::Yes(result_expr.clone())) + let result_expr = filter_expr + .clone() + .transform_up(&|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { + Ok(Transformed::yes(result_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::No(expr)) - } - })?; + }) + .data()?; + let pull_up_expr = if result_expr.ne(filter_expr) { let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema); diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index a9e1f1228e5e..b94cf37c5c12 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -15,12 +15,17 @@ // specific language governing permissions and limitations // under the License. 
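The closure-based `transform_up`/`transform_down` calls follow the same shape: the closure now returns `Transformed<Expr>` built with the lowercase `yes`/`no` constructors, and `TransformedResult::data()` peels the wrapper off the final `Result`, exactly as in the `*_evaluation_result_on_empty_batch` helpers above. A self-contained sketch with an illustrative rewrite rule:

```rust
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_expr::{lit, Expr};

/// Illustrative rule: replace every column reference with the literal 0,
/// visiting the expression tree bottom-up.
fn zero_out_columns(expr: Expr) -> Result<Expr> {
    expr.transform_up(&|e| {
        Ok(if matches!(e, Expr::Column(_)) {
            Transformed::yes(lit(0))
        } else {
            Transformed::no(e)
        })
    })
    // `transform_up` returns `Result<Transformed<Expr>>`; `data()` from the
    // `TransformedResult` extension trait unwraps it to `Result<Expr>`.
    .data()
}
```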
+use std::collections::BTreeSet; +use std::ops::Deref; +use std::sync::Arc; + use crate::decorrelate::PullUpCorrelatedExpr; use crate::optimizer::ApplyOrder; use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::alias::AliasGenerator; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{plan_err, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; @@ -30,10 +35,8 @@ use datafusion_expr::{ exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, }; + use log::debug; -use std::collections::BTreeSet; -use std::ops::Deref; -use std::sync::Arc; /// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins #[derive(Default)] @@ -228,7 +231,7 @@ fn build_join( collected_count_expr_map: Default::default(), pull_up_having_expr: None, }; - let new_plan = subquery.clone().rewrite(&mut pull_up)?; + let new_plan = subquery.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); } @@ -321,8 +324,11 @@ impl SubqueryInfo { #[cfg(test)] mod tests { + use std::ops::Add; + use super::*; use crate::test::*; + use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ @@ -330,7 +336,6 @@ mod tests { logical_plan::LogicalPlanBuilder, not_exists, not_in_subquery, or, out_ref_col, Operator, }; - use std::ops::Add; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 07f495a7262d..4143d52a053e 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, num::NonZeroUsize, }; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_expr::LogicalPlan; /// Non-unique identifier of a [`LogicalPlan`]. @@ -75,7 +75,7 @@ fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize { let mut node_number = 0; plan.apply(&mut |_plan| { node_number += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // Closure always return Ok .unwrap(); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 40156d43c572..a63133c5166f 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -22,7 +22,9 @@ use crate::optimizer::ApplyOrder; use crate::utils::is_volatile_expression; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DFSchemaRef, JoinConstraint, Result, @@ -222,7 +224,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::Column(_) | Expr::Literal(_) | Expr::Placeholder(_) - | Expr::ScalarVariable(_, _) => Ok(VisitRecursion::Skip), + | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. 
} | Expr::InSubquery(_) | Expr::ScalarSubquery(_) @@ -233,7 +235,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { .. }) => { is_evaluate = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } Expr::Alias(_) | Expr::BinaryExpr(_) @@ -255,7 +257,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::ScalarFunction(..) - | Expr::InList { .. } => Ok(VisitRecursion::Continue), + | Expr::InList { .. } => Ok(TreeNodeRecursion::Continue), Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) @@ -992,13 +994,14 @@ pub fn replace_cols_by_name( e.transform_up(&|expr| { Ok(if let Expr::Column(c) = &expr { match replace_map.get(&c.flat_name()) { - Some(new_c) => Transformed::Yes(new_c.clone()), - None => Transformed::No(expr), + Some(new_c) => Transformed::yes(new_c.clone()), + None => Transformed::no(expr), } } else { - Transformed::No(expr) + Transformed::no(expr) }) }) + .data() } /// check whether the expression uses the columns in `check_map`. @@ -1009,12 +1012,12 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { match check_map.get(&c.flat_name()) { Some(_) => { is_contain = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } - None => VisitRecursion::Continue, + None => TreeNodeRecursion::Continue, } } else { - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 9aa08c37fa35..8acc36e479ca 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -15,21 +15,23 @@ // specific language governing permissions and limitations // under the License. +use std::collections::{BTreeSet, HashMap}; +use std::sync::Arc; + use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; use crate::optimizer::ApplyOrder; use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; -use std::collections::{BTreeSet, HashMap}; -use std::sync::Arc; /// Optimizer rule for rewriting subquery filters to joins #[derive(Default)] @@ -56,8 +58,11 @@ impl ScalarSubqueryToJoin { sub_query_info: vec![], alias_gen, }; - let new_expr = predicate.clone().rewrite(&mut extract)?; - Ok((extract.sub_query_info, new_expr)) + predicate + .clone() + .rewrite(&mut extract) + .data() + .map(|new_expr| (extract.sub_query_info, new_expr)) } } @@ -86,20 +91,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { build_join(&subquery, &cur_input, &alias)? 
{ if !expr_check_map.is_empty() { - rewrite_expr = - rewrite_expr.clone().transform_up(&|expr| { + rewrite_expr = rewrite_expr + .clone() + .transform_up(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) { - Ok(Transformed::Yes(map_expr.clone())) + Ok(Transformed::yes(map_expr.clone())) } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - })?; + }) + .data()?; } cur_input = optimized_subquery; } else { @@ -141,20 +148,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { if let Some(rewrite_expr) = expr_to_rewrite_expr_map.get(expr) { - let new_expr = - rewrite_expr.clone().transform_up(&|expr| { + let new_expr = rewrite_expr + .clone() + .transform_up(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) { - Ok(Transformed::Yes(map_expr.clone())) + Ok(Transformed::yes(map_expr.clone())) } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - })?; + }) + .data()?; expr_to_rewrite_expr_map.insert(expr, new_expr); } } @@ -201,16 +210,9 @@ struct ExtractScalarSubQuery { } impl TreeNodeRewriter for ExtractScalarSubQuery { - type N = Expr; - - fn pre_visit(&mut self, expr: &Expr) -> Result { - match expr { - Expr::ScalarSubquery(_) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), - } - } + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result> { match expr { Expr::ScalarSubquery(subquery) => { let subqry_alias = self.alias_gen.next("__scalar_sq"); @@ -220,12 +222,16 @@ impl TreeNodeRewriter for ExtractScalarSubQuery { .subquery .head_output_expr()? 
.map_or(plan_err!("single expression required."), Ok)?; - Ok(Expr::Column(create_col_from_scalar_expr( - &scalar_expr, - subqry_alias, - )?)) + Ok(Transformed::new( + Expr::Column(create_col_from_scalar_expr( + &scalar_expr, + subqry_alias, + )?), + true, + TreeNodeRecursion::Jump, + )) } - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } } @@ -282,7 +288,7 @@ fn build_join( collected_count_expr_map: Default::default(), pull_up_having_expr: None, }; - let new_plan = subquery_plan.clone().rewrite(&mut pull_up)?; + let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); } @@ -371,15 +377,17 @@ fn build_join( #[cfg(test)] mod tests { + use std::ops::Add; + use super::*; use crate::test::*; + use arrow::datatypes::DataType; use datafusion_common::Result; + use datafusion_expr::logical_plan::LogicalPlanBuilder; use datafusion_expr::{ - col, lit, logical_plan::LogicalPlanBuilder, max, min, out_ref_col, - scalar_subquery, sum, Between, + col, lit, max, min, out_ref_col, scalar_subquery, sum, Between, }; - use std::ops::Add; /// Test multiple correlated subqueries #[test] diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 3f65c68bc45b..6b5dd1b4681e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -19,15 +19,21 @@ use std::ops::Not; +use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier}; +use super::utils::*; +use crate::analyzer::type_coercion::TypeCoercionRewriter; +use crate::simplify_expressions::guarantees::GuaranteeRewriter; +use crate::simplify_expressions::regex::simplify_regex_expr; +use crate::simplify_expressions::SimplifyInfo; + use arrow::{ array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; - use datafusion_common::{ cast::{as_large_list_array, as_list_array}, - tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -39,14 +45,6 @@ use datafusion_expr::{ use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; -use crate::analyzer::type_coercion::TypeCoercionRewriter; -use crate::simplify_expressions::guarantees::GuaranteeRewriter; -use crate::simplify_expressions::regex::simplify_regex_expr; -use crate::simplify_expressions::SimplifyInfo; - -use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier}; -use super::utils::*; - /// This structure handles API for expression simplification pub struct ExprSimplifier { info: S, @@ -131,31 +129,36 @@ impl ExprSimplifier { /// let expr = simplifier.simplify(expr).unwrap(); /// assert_eq!(expr, b_lt_2); /// ``` - pub fn simplify(&self, expr: Expr) -> Result { + pub fn simplify(&self, mut expr: Expr) -> Result { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); let mut inlist_simplifier = InListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); - let expr = if self.canonicalize { - expr.rewrite(&mut Canonicalizer::new())? 
- } else { - expr - }; + if self.canonicalize { + expr = expr.rewrite(&mut Canonicalizer::new()).data()? + } // TODO iterate until no changes are made during rewrite // (evaluating constants can enable new simplifications and // simplifications can enable new constant evaluation) // https://github.com/apache/arrow-datafusion/issues/1160 - expr.rewrite(&mut const_evaluator)? - .rewrite(&mut simplifier)? - .rewrite(&mut inlist_simplifier)? - .rewrite(&mut shorten_in_list_simplifier)? - .rewrite(&mut guarantee_rewriter)? + expr.rewrite(&mut const_evaluator) + .data()? + .rewrite(&mut simplifier) + .data()? + .rewrite(&mut inlist_simplifier) + .data()? + .rewrite(&mut shorten_in_list_simplifier) + .data()? + .rewrite(&mut guarantee_rewriter) + .data()? // run both passes twice to try and minimize simplifications that we missed - .rewrite(&mut const_evaluator)? + .rewrite(&mut const_evaluator) + .data()? .rewrite(&mut simplifier) + .data() } /// Apply type coercion to an [`Expr`] so that it can be @@ -171,7 +174,7 @@ impl ExprSimplifier { pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite) + expr.rewrite(&mut expr_rewrite).data() } /// Input guarantees about the values of columns. @@ -303,32 +306,36 @@ impl Canonicalizer { } impl TreeNodeRewriter for Canonicalizer { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else { - return Ok(expr); + return Ok(Transformed::no(expr)); }; match (left.as_ref(), right.as_ref(), op.swap()) { // (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op)) if right_col > left_col => { - Ok(Expr::BinaryExpr(BinaryExpr { + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, right: left, - })) + }))) } // (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => { - Ok(Expr::BinaryExpr(BinaryExpr { + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, right: left, - })) + }))) } - _ => Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })), + _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))), } } } @@ -367,9 +374,9 @@ enum ConstSimplifyResult { } impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result> { // Default to being able to evaluate this node self.can_evaluate.push(true); @@ -377,7 +384,7 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { // stack as not ok (as all parents have at least one child or // descendant that can not be evaluated - if !Self::can_evaluate(expr) { + if !Self::can_evaluate(&expr) { // walk back up stack, marking first parent that is not mutable let parent_iter = self.can_evaluate.iter_mut().rev(); for p in parent_iter { @@ -393,10 +400,10 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { // NB: do not short circuit recursion even if we find a non // evaluatable node (so we can fold other children, args to // functions, etc) - Ok(RewriteRecursion::Continue) + Ok(Transformed::no(expr)) } - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match self.can_evaluate.pop() { // Certain expressions such as `CASE` and `COALESCE` are short circuiting // and may not evaluate all their sub expressions.
Thus if @@ -405,11 +412,15 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { Some(true) => { let result = self.evaluate_to_scalar(expr); match result { - ConstSimplifyResult::Simplified(s) => Ok(Expr::Literal(s)), - ConstSimplifyResult::SimplifyRuntimeError(_, expr) => Ok(expr), + ConstSimplifyResult::Simplified(s) => { + Ok(Transformed::yes(Expr::Literal(s))) + } + ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { + Ok(Transformed::yes(expr)) + } } } - Some(false) => Ok(expr), + Some(false) => Ok(Transformed::no(expr)), _ => internal_err!("Failed to pop can_evaluate"), } } @@ -566,10 +577,10 @@ impl<'a, S> Simplifier<'a, S> { } impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { - type N = Expr; + type Node = Expr; /// rewrite the expression simplifying any constant expressions - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { use datafusion_expr::Operator::{ And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor, Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch, @@ -577,7 +588,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { }; let info = self.info; - let new_expr = match expr { + Ok(match expr { // // Rules for Eq // @@ -590,11 +601,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Eq, right, }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { - match as_bool_lit(*left)? { + Transformed::yes(match as_bool_lit(*left)? { Some(true) => *right, Some(false) => Expr::Not(right), None => lit_bool_null(), - } + }) } // A = true --> A // A = false --> !A @@ -604,11 +615,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Eq, right, }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { - match as_bool_lit(*right)? { + Transformed::yes(match as_bool_lit(*right)? { Some(true) => *left, Some(false) => Expr::Not(left), None => lit_bool_null(), - } + }) } // Rules for NotEq // @@ -621,11 +632,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: NotEq, right, }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { - match as_bool_lit(*left)? { + Transformed::yes(match as_bool_lit(*left)? { Some(true) => Expr::Not(right), Some(false) => *right, None => lit_bool_null(), - } + }) } // A != true --> !A // A != false --> A @@ -635,11 +646,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: NotEq, right, }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { - match as_bool_lit(*right)? { + Transformed::yes(match as_bool_lit(*right)? 
{ Some(true) => Expr::Not(left), Some(false) => *left, None => lit_bool_null(), - } + }) } // @@ -651,32 +662,32 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Or, right: _, - }) if is_true(&left) => *left, + }) if is_true(&left) => Transformed::yes(*left), // false OR A --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if is_false(&left) => *right, + }) if is_false(&left) => Transformed::yes(*right), // A OR true --> true (even if A is null) Expr::BinaryExpr(BinaryExpr { left: _, op: Or, right, - }) if is_true(&right) => *right, + }) if is_true(&right) => Transformed::yes(*right), // A OR false --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if is_false(&right) => *left, + }) if is_false(&right) => Transformed::yes(*left), // A OR !A ---> true (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, }) if is_not_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::Boolean(Some(true))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true)))) } // !A OR A ---> true (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -684,32 +695,36 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Or, right, }) if is_not_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::Boolean(Some(true))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true)))) } // (..A..) OR A --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if expr_contains(&left, &right, Or) => *left, + }) if expr_contains(&left, &right, Or) => Transformed::yes(*left), // A OR (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if expr_contains(&right, &left, Or) => *right, + }) if expr_contains(&right, &left, Or) => Transformed::yes(*right), // A OR (A AND B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => *left, + }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => { + Transformed::yes(*left) + } // (A AND B) OR A --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&left)? && is_op_with(And, &left, &right) => *right, + }) if !info.nullable(&left)? && is_op_with(And, &left, &right) => { + Transformed::yes(*right) + } // // Rules for AND @@ -720,32 +735,32 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: And, right, - }) if is_true(&left) => *right, + }) if is_true(&left) => Transformed::yes(*right), // false AND A --> false (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: And, right: _, - }) if is_false(&left) => *left, + }) if is_false(&left) => Transformed::yes(*left), // A AND true --> A Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if is_true(&right) => *left, + }) if is_true(&right) => Transformed::yes(*left), // A AND false --> false (even if A is null) Expr::BinaryExpr(BinaryExpr { left: _, op: And, right, - }) if is_false(&right) => *right, + }) if is_false(&right) => Transformed::yes(*right), // A AND !A ---> false (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: And, right, }) if is_not_of(&right, &left) && !info.nullable(&left)? 
=> { - Expr::Literal(ScalarValue::Boolean(Some(false))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false)))) } // !A AND A ---> false (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -753,32 +768,36 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: And, right, }) if is_not_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::Boolean(Some(false))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false)))) } // (..A..) AND A --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if expr_contains(&left, &right, And) => *left, + }) if expr_contains(&left, &right, And) => Transformed::yes(*left), // A AND (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if expr_contains(&right, &left, And) => *right, + }) if expr_contains(&right, &left, And) => Transformed::yes(*right), // A AND (A OR B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => *left, + }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => { + Transformed::yes(*left) + } // (A OR B) AND A --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&left)? && is_op_with(Or, &left, &right) => *right, + }) if !info.nullable(&left)? && is_op_with(Or, &left, &right) => { + Transformed::yes(*right) + } // // Rules for Multiply @@ -789,25 +808,25 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Multiply, right, - }) if is_one(&right) => *left, + }) if is_one(&right) => Transformed::yes(*left), // 1 * A --> A Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right, - }) if is_one(&left) => *right, + }) if is_one(&left) => Transformed::yes(*right), // A * null --> null Expr::BinaryExpr(BinaryExpr { left: _, op: Multiply, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null * A --> null Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN) Expr::BinaryExpr(BinaryExpr { @@ -818,7 +837,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_zero(&right) => { - *right + Transformed::yes(*right) } // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN) Expr::BinaryExpr(BinaryExpr { @@ -829,7 +848,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&right)?.is_floating() && is_zero(&left) => { - *left + Transformed::yes(*left) } // @@ -841,19 +860,19 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Divide, right, - }) if is_one(&right) => *left, + }) if is_one(&right) => Transformed::yes(*left), // null / A --> null Expr::BinaryExpr(BinaryExpr { left, op: Divide, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A / null --> null Expr::BinaryExpr(BinaryExpr { left: _, op: Divide, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // // Rules for Modulo @@ -864,13 +883,13 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: Modulo, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null % A --> null Expr::BinaryExpr(BinaryExpr { left, 
op: Modulo, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN) Expr::BinaryExpr(BinaryExpr { left, @@ -880,7 +899,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { - lit(0) + Transformed::yes(lit(0)) } // @@ -892,28 +911,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseAnd, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null & A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A & 0 -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if !info.nullable(&left)? && is_zero(&right) => *right, + }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right), // 0 & A -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if !info.nullable(&right)? && is_zero(&left) => *left, + }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left), // !A & A -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -921,7 +940,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?, + )?)) } // A & !A -> 0 (if A not nullable) @@ -930,7 +951,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?, + )?)) } // (..A..) & A --> (..A..) @@ -938,14 +961,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, - }) if expr_contains(&left, &right, BitwiseAnd) => *left, + }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left), // A & (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if expr_contains(&right, &left, BitwiseAnd) => *right, + }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right), // A & (A | B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { @@ -953,7 +976,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => { - *left + Transformed::yes(*left) } // (A | B) & A --> A (if B not null) @@ -962,7 +985,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if !info.nullable(&left)? 
&& is_op_with(BitwiseOr, &left, &right) => { - *right + Transformed::yes(*right) } // @@ -974,28 +997,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseOr, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null | A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A | 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // 0 | A -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if is_zero(&left) => *right, + }) if is_zero(&left) => Transformed::yes(*right), // !A | A -> -1 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -1003,7 +1026,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // A | !A -> -1 (if A not nullable) @@ -1012,7 +1037,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // (..A..) | A --> (..A..) @@ -1020,14 +1047,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, - }) if expr_contains(&left, &right, BitwiseOr) => *left, + }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left), // A | (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if expr_contains(&right, &left, BitwiseOr) => *right, + }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right), // A | (A & B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { @@ -1035,7 +1062,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => { - *left + Transformed::yes(*left) } // (A & B) | A --> A (if B not null) @@ -1044,7 +1071,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => { - *right + Transformed::yes(*right) } // @@ -1056,28 +1083,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseXor, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null ^ A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A ^ 0 -> A (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right, - }) if !info.nullable(&left)? && is_zero(&right) => *left, + }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left), // 0 ^ A -> A (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right, - }) if !info.nullable(&right)? && is_zero(&left) => *right, + }) if !info.nullable(&right)? 
&& is_zero(&left) => Transformed::yes(*right), // !A ^ A -> -1 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -1085,7 +1112,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // A ^ !A -> -1 (if A not nullable) @@ -1094,7 +1123,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) @@ -1104,11 +1135,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { right, }) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); - if expr == *right { + Transformed::yes(if expr == *right { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) } else { expr - } + }) } // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A) @@ -1118,11 +1149,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { right, }) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); - if expr == *left { + Transformed::yes(if expr == *left { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) } else { expr - } + }) } // @@ -1134,21 +1165,21 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseShiftRight, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null >> A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftRight, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A >> 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftRight, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // // Rules for BitwiseShiftLeft // @@ -1159,31 +1190,31 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseShiftLeft, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null << A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftLeft, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A << 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftLeft, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // // Rules for Not // - Expr::Not(inner) => negate_clause(*inner), + Expr::Not(inner) => Transformed::yes(negate_clause(*inner)), // // Rules for Negative // - Expr::Negative(inner) => distribute_negation(*inner), + Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)), // // Rules for Case @@ -1237,19 +1268,19 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), args, - }) => simpl_log(args, <&S>::clone(&info))?, + }) =>
Transformed::yes(simpl_log(args, info)?), // power Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), args, - }) => simpl_power(args, <&S>::clone(&info))?, + }) => Transformed::yes(simpl_power(args, info)?), // concat Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), args, - }) => simpl_concat(args)?, + }) => Transformed::yes(simpl_concat(args)?), // concat_ws Expr::ScalarFunction(ScalarFunction { @@ -1259,11 +1290,13 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ), args, }) => match &args[..] { - [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, - _ => Expr::ScalarFunction(ScalarFunction::new( + [delimiter, vals @ ..] => { + Transformed::yes(simpl_concat_ws(delimiter, vals)?) + } + _ => Transformed::yes(Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::ConcatWithSeparator, args, - )), + ))), }, // @@ -1272,18 +1305,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // a between 3 and 5 --> a >= 3 AND a <=5 // a not between 3 and 5 --> a < 3 OR a > 5 - Expr::Between(between) => { - if between.negated { - let l = *between.expr.clone(); - let r = *between.expr; - or(l.lt(*between.low), r.gt(*between.high)) - } else { - and( - between.expr.clone().gt_eq(*between.low), - between.expr.lt_eq(*between.high), - ) - } - } + Expr::Between(between) => Transformed::yes(if between.negated { + let l = *between.expr.clone(); + let r = *between.expr; + or(l.lt(*between.low), r.gt(*between.high)) + } else { + and( + between.expr.clone().gt_eq(*between.low), + between.expr.lt_eq(*between.high), + ) + }), // // Rules for regexes @@ -1292,7 +1323,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), right, - }) => simplify_regex_expr(left, op, right)?, + }) => Transformed::yes(simplify_regex_expr(left, op, right)?), // Rules for Like Expr::Like(Like { @@ -1307,25 +1338,24 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::Literal(ScalarValue::Utf8(Some(pattern_str))) if pattern_str == "%" ) => { - lit(!negated) + Transformed::yes(lit(!negated)) } // a is not null/unknown --> true (if a is not nullable) Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr) if !info.nullable(&expr)? => { - lit(true) + Transformed::yes(lit(true)) } // a is null/unknown --> false (if a is not nullable) Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? 
=> { - lit(false) + Transformed::yes(lit(false)) } // no additional rewrites possible - expr => expr, - }; - Ok(new_expr) + expr => Transformed::no(expr), + }) } } @@ -1337,16 +1367,15 @@ mod tests { sync::Arc, }; + use super::*; + use crate::simplify_expressions::SimplifyContext; + use crate::test::test_table_scan_with_name; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{assert_contains, DFField, ToDFSchema}; use datafusion_expr::{interval_arithmetic::Interval, *}; use datafusion_physical_expr::execution_props::ExecutionProps; - use crate::simplify_expressions::SimplifyContext; - use crate::test::test_table_scan_with_name; - - use super::*; - // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index aa7bb4f78a93..6eb583257dcb 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -21,7 +21,8 @@ use std::{borrow::Cow, collections::HashMap}; -use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; @@ -57,23 +58,25 @@ impl<'a> GuaranteeRewriter<'a> { } impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { if self.guarantees.is_empty() { - return Ok(expr); + return Ok(Transformed::no(expr)); } match &expr { Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(lit(true)), - Some(NullableInterval::NotNull { .. }) => Ok(lit(false)), - _ => Ok(expr), + Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))), + Some(NullableInterval::NotNull { .. }) => { + Ok(Transformed::yes(lit(false))) + } + _ => Ok(Transformed::no(expr)), }, Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(lit(false)), - Some(NullableInterval::NotNull { .. }) => Ok(lit(true)), - _ => Ok(expr), + Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(false))), + Some(NullableInterval::NotNull { .. 
}) => Ok(Transformed::yes(lit(true))), + _ => Ok(Transformed::no(expr)), }, Expr::Between(Between { expr: inner, @@ -93,14 +96,14 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { let contains = expr_interval.contains(*interval)?; if contains.is_certainly_true() { - Ok(lit(!negated)) + Ok(Transformed::yes(lit(!negated))) } else if contains.is_certainly_false() { - Ok(lit(*negated)) + Ok(Transformed::yes(lit(*negated))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } else { - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -135,23 +138,23 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { let result = left_interval.apply_operator(op, right_interval.as_ref())?; if result.is_certainly_true() { - Ok(lit(true)) + Ok(Transformed::yes(lit(true))) } else if result.is_certainly_false() { - Ok(lit(false)) + Ok(Transformed::yes(lit(false))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } // Columns (if interval is collapsed to a single value) Expr::Column(_) => { if let Some(interval) = self.guarantees.get(&expr) { - Ok(interval.single_value().map_or(expr, lit)) + Ok(Transformed::yes(interval.single_value().map_or(expr, lit))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -181,17 +184,17 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { }) .collect::>()?; - Ok(Expr::InList(InList { + Ok(Transformed::yes(Expr::InList(InList { expr: inner.clone(), list: new_list, negated: *negated, - })) + }))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } } @@ -201,7 +204,8 @@ mod tests { use super::*; use arrow::datatypes::DataType; - use datafusion_common::{tree_node::TreeNode, ScalarValue}; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; + use datafusion_common::ScalarValue; use datafusion_expr::{col, lit, Operator}; #[test] @@ -221,12 +225,12 @@ mod tests { // x IS NULL => guaranteed false let expr = col("x").is_null(); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); assert_eq!(output, lit(false)); // x IS NOT NULL => guaranteed true let expr = col("x").is_not_null(); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); assert_eq!(output, lit(true)); } @@ -236,7 +240,7 @@ mod tests { T: Clone, { for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).unwrap(); + let output = expr.clone().rewrite(rewriter).data().unwrap(); let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, @@ -248,7 +252,7 @@ mod tests { fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { for expr in cases { - let output = expr.clone().rewrite(rewriter).unwrap(); + let output = expr.clone().rewrite(rewriter).data().unwrap(); assert_eq!( &output, expr, "{} was simplified to {}, but expected it to be unchanged", @@ -478,7 +482,7 @@ mod tests { let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - let output = col("x").rewrite(&mut rewriter).unwrap(); + let output = col("x").rewrite(&mut rewriter).data().unwrap(); assert_eq!(output, Expr::Literal(scalar.clone())); } } @@ -522,7 +526,7 @@ mod tests { .collect(), *negated, ); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); 
let expected_list = expected_list .iter() .map(|v| lit(ScalarValue::Int32(Some(*v)))) diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 710c24f66e33..fa1d7cfc1239 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -17,17 +17,17 @@ //! This module implements a rule that simplifies the values for `InList`s +use super::utils::{is_null, lit_bool_null}; +use super::THRESHOLD_INLINE_INLIST; + use std::borrow::Cow; use std::collections::HashSet; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::expr::{InList, InSubquery}; use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; -use super::utils::{is_null, lit_bool_null}; -use super::THRESHOLD_INLINE_INLIST; - pub(super) struct ShortenInListSimplifier {} impl ShortenInListSimplifier { @@ -37,9 +37,9 @@ impl ShortenInListSimplifier { } impl TreeNodeRewriter for ShortenInListSimplifier { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { // if expr is a single column reference: // expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR (expr = C) if let Expr::InList(InList { @@ -61,7 +61,7 @@ impl TreeNodeRewriter for ShortenInListSimplifier { { let first_val = list[0].clone(); if negated { - return Ok(list.into_iter().skip(1).fold( + return Ok(Transformed::yes(list.into_iter().skip(1).fold( (*expr.clone()).not_eq(first_val), |acc, y| { // Note that `A and B and C and D` is a left-deep tree structure @@ -83,20 +83,20 @@ impl TreeNodeRewriter for ShortenInListSimplifier { // The code below maintain the left-deep tree structure. acc.and((*expr.clone()).not_eq(y)) }, - )); + ))); } else { - return Ok(list.into_iter().skip(1).fold( + return Ok(Transformed::yes(list.into_iter().skip(1).fold( (*expr.clone()).eq(first_val), |acc, y| { // Same reasoning as above acc.or((*expr.clone()).eq(y)) }, - )); + ))); } } } - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -109,9 +109,9 @@ impl InListSimplifier { } impl TreeNodeRewriter for InListSimplifier { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { if let Expr::InList(InList { expr, mut list, @@ -121,11 +121,11 @@ impl TreeNodeRewriter for InListSimplifier { // expr IN () --> false // expr NOT IN () --> true if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) { - return Ok(lit(negated)); + return Ok(Transformed::yes(lit(negated))); // null in (x, y, z) --> null // null not in (x, y, z) --> null } else if is_null(&expr) { - return Ok(lit_bool_null()); + return Ok(Transformed::yes(lit_bool_null())); // expr IN ((subquery)) -> expr IN (subquery), see ##5529 } else if list.len() == 1 && matches!(list.first(), Some(Expr::ScalarSubquery { .. 
})) { let Expr::ScalarSubquery(subquery) = list.remove(0) else { unreachable!() }; - return Ok(Expr::InSubquery(InSubquery::new(expr, subquery, negated))); + return Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( + expr, subquery, negated, + )))); } } // Combine multiple OR expressions into a single IN list expression if possible @@ -165,7 +167,7 @@ impl TreeNodeRewriter for InListSimplifier { list, negated: false, }; - return Ok(Expr::InList(merged_inlist)); + return Ok(Transformed::yes(Expr::InList(merged_inlist))); } } } @@ -191,40 +193,40 @@ impl TreeNodeRewriter for InListSimplifier { (Expr::InList(l1), Operator::And, Expr::InList(l2)) if l1.expr == l2.expr && !l1.negated && !l2.negated => { - return inlist_intersection(l1, l2, false); + return inlist_intersection(l1, l2, false).map(Transformed::yes); } (Expr::InList(l1), Operator::And, Expr::InList(l2)) if l1.expr == l2.expr && l1.negated && l2.negated => { - return inlist_union(l1, l2, true); + return inlist_union(l1, l2, true).map(Transformed::yes); } (Expr::InList(l1), Operator::And, Expr::InList(l2)) if l1.expr == l2.expr && !l1.negated && l2.negated => { - return inlist_except(l1, l2); + return inlist_except(l1, l2).map(Transformed::yes); } (Expr::InList(l1), Operator::And, Expr::InList(l2)) if l1.expr == l2.expr && l1.negated && !l2.negated => { - return inlist_except(l2, l1); + return inlist_except(l2, l1).map(Transformed::yes); } (Expr::InList(l1), Operator::Or, Expr::InList(l2)) if l1.expr == l2.expr && l1.negated && l2.negated => { - return inlist_intersection(l1, l2, true); + return inlist_intersection(l1, l2, true).map(Transformed::yes); } (left, op, right) => { // put the expression back together - return Ok(Expr::BinaryExpr(BinaryExpr { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { left: Box::new(left), op, right: Box::new(right), - })); + }))); } } } - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 4c22742c8635..196a35ee9ae8 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -18,13 +18,18 @@ //! The unwrap-cast binary comparison rule currently applies to binary and inlist comparison exprs; other types //! of expr can be added if needed. //! This rule avoids adding an `Expr::Cast` to the expr by instead adding the `Expr::Cast` to the literal expr.
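To make the rule just described concrete, here is a minimal sketch of the rewrite it performs, written against the `datafusion_expr` builders; the column name `c1` and the Int32/Int64 types are illustrative assumptions, not taken from this patch:

use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_expr::expr::Cast;
use datafusion_expr::{col, lit, Expr};

fn unwrap_cast_sketch() -> (Expr, Expr) {
    // Before: the column is cast up to the literal's type, so every row pays
    // for the cast at execution time.
    let before = Expr::Cast(Cast::new(Box::new(col("c1")), DataType::Int64))
        .gt(lit(ScalarValue::Int64(Some(10))));
    // After (assuming `c1: Int32` and that the literal 10 fits in Int32):
    // the literal is cast once at plan time and the bare column is compared
    // directly, keeping `c1` visible to other optimizations such as pruning.
    let after = col("c1").gt(lit(ScalarValue::Int32(Some(10))));
    (before, after)
}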
+ +use std::cmp::Ordering; +use std::sync::Arc; + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; + use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; @@ -32,8 +37,6 @@ use datafusion_expr::utils::merge_schema; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; -use std::cmp::Ordering; -use std::sync::Arc; /// [`UnwrapCastInComparison`] attempts to remove casts from /// comparisons to literals ([`ScalarValue`]s) by applying the casts @@ -125,13 +128,9 @@ struct UnwrapCastExprRewriter { } impl TreeNodeRewriter for UnwrapCastExprRewriter { - type N = Expr; - - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) - } + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match &expr { // For case: // try_cast/cast(expr as data_type) op literal @@ -159,11 +158,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(left_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the right expr - return Ok(binary_expr( + return Ok(Transformed::yes(binary_expr( lit(value), *op, expr.as_ref().clone(), - )); + ))); } } ( @@ -178,11 +177,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(right_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the left expr - return Ok(binary_expr( + return Ok(Transformed::yes(binary_expr( expr.as_ref().clone(), *op, lit(value), - )); + ))); } } (_, _) => { @@ -191,7 +190,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }; } // return the new binary op - Ok(binary_expr(left, *op, right)) + Ok(Transformed::yes(binary_expr(left, *op, right))) } // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) @@ -215,12 +214,12 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { let internal_left_type = internal_left.get_type(&self.schema); if internal_left_type.is_err() { // error data type - return Ok(expr); + return Ok(Transformed::no(expr)); } let internal_left_type = internal_left_type?; if !is_support_data_type(&internal_left_type) { // not supported data type - return Ok(expr); + return Ok(Transformed::no(expr)); } let right_exprs = list .iter() @@ -255,17 +254,19 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }) .collect::>>(); match right_exprs { - Ok(right_exprs) => { - Ok(in_list(internal_left, right_exprs, *negated)) - } - Err(_) => Ok(expr), + Ok(right_exprs) => Ok(Transformed::yes(in_list( + internal_left, + right_exprs, + *negated, + ))), + Err(_) => Ok(Transformed::no(expr)), } } else { - Ok(expr) + Ok(Transformed::no(expr)) } } // TODO: handle other expr type and dfs visit them - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } } @@ -474,15 +475,17 @@ fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option DFSchemaRef { diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 
6189f9a57942..0df79550f143 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -17,16 +17,18 @@ //! Collection of utility functions that are leveraged by the query optimizer rules +use std::collections::{BTreeSet, HashMap}; + use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{Column, DFSchemaRef}; -use datafusion_common::{DFSchema, Result}; + +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{Column, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::is_volatile; use datafusion_expr::expr_rewriter::replace_col; use datafusion_expr::utils as expr_utils; use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator}; + use log::{debug, trace}; -use std::collections::{BTreeSet, HashMap}; /// Convenience rule for writing optimizers: recursively invoke /// optimize on plan's children and then return a node of the same @@ -101,9 +103,9 @@ pub(crate) fn is_volatile_expression(e: &Expr) -> Result { e.apply(&mut |expr| { Ok(if is_volatile(expr)? { is_volatile_expr = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } else { - VisitRecursion::Continue + TreeNodeRecursion::Continue }) })?; Ok(is_volatile_expr) diff --git a/datafusion/physical-expr/benches/regexp.rs b/datafusion/physical-expr/benches/regexp.rs index 0371b6bf28a9..32acd6ca8f28 100644 --- a/datafusion/physical-expr/benches/regexp.rs +++ b/datafusion/physical-expr/benches/regexp.rs @@ -23,15 +23,11 @@ use std::sync::Arc; use arrow_array::builder::StringBuilder; use arrow_array::{ArrayRef, StringArray}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_physical_expr::regex_expressions::{regexp_match, regexp_replace}; use rand::distributions::Alphanumeric; use rand::rngs::ThreadRng; use rand::seq::SliceRandom; use rand::Rng; - -use datafusion_physical_expr::regex_expressions::{ - regexp_like, regexp_match, regexp_replace, -}; - fn data(rng: &mut ThreadRng) -> StringArray { let mut data: Vec = vec![]; for _ in 0..1000 { @@ -78,20 +74,6 @@ fn flags(rng: &mut ThreadRng) -> StringArray { } fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("regexp_like_1000", |b| { - let mut rng = rand::thread_rng(); - let data = Arc::new(data(&mut rng)) as ArrayRef; - let regex = Arc::new(regex(&mut rng)) as ArrayRef; - let flags = Arc::new(flags(&mut rng)) as ArrayRef; - - b.iter(|| { - black_box( - regexp_like::(&[data.clone(), regex.clone(), flags.clone()]) - .expect("regexp_like should work on valid values"), - ) - }) - }); - c.bench_function("regexp_match_1000", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 01b2ae13c8d4..c10f5df54027 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -193,28 +193,6 @@ fn compute_array_length( } } -/// Returns the length of each array dimension -fn compute_array_dims(arr: Option) -> Result>>> { - let mut value = match arr { - Some(arr) => arr, - None => return Ok(None), - }; - if value.is_empty() { - return Ok(None); - } - let mut res = vec![Some(value.len() as u64)]; - - loop { - match value.data_type() { - DataType::List(..) 
=> { - value = downcast_arg!(value, ListArray).value(0); - res.push(Some(value.len() as u64)); - } - _ => return Ok(Some(res)), - } - } -} - fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); if !args.iter().all(|arg| { @@ -1938,40 +1916,6 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result { general_set_op(array1, array2, SetOp::Intersect) } -/// Cardinality SQL function -pub fn cardinality(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("cardinality expects one argument"); - } - - match &args[0].data_type() { - DataType::List(_) => { - let list_array = as_list_array(&args[0])?; - generic_list_cardinality::(list_array) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&args[0])?; - generic_list_cardinality::(list_array) - } - other => { - exec_err!("cardinality does not support type '{:?}'", other) - } - } -} - -fn generic_list_cardinality( - array: &GenericListArray, -) -> Result { - let result = array - .iter() - .map(|arr| match compute_array_dims(arr)? { - Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), - None => Ok(None), - }) - .collect::>()?; - Ok(Arc::new(result) as ArrayRef) -} - // Create new offsets that are equivalent to `flatten` the array. fn get_offsets_for_flatten( offsets: OffsetBuffer, @@ -2074,72 +2018,6 @@ pub fn array_length(args: &[ArrayRef]) -> Result { } } -/// Array_dims SQL function -pub fn array_dims(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_dims needs one argument"); - } - - let data = match args[0].data_type() { - DataType::List(_) => { - let array = as_list_array(&args[0])?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? - } - DataType::LargeList(_) => { - let array = as_large_list_array(&args[0])?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? - } - array_type => { - return exec_err!("array_dims does not support type '{array_type:?}'"); - } - }; - - let result = ListArray::from_iter_primitive::(data); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Array_ndims SQL function -pub fn array_ndims(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_ndims needs one argument"); - } - - fn general_list_ndims( - array: &GenericListArray, - ) -> Result { - let mut data = Vec::new(); - let ndims = datafusion_common::utils::list_ndims(array.data_type()); - - for arr in array.iter() { - if arr.is_some() { - data.push(Some(ndims)) - } else { - data.push(None) - } - } - - Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) - } - match args[0].data_type() { - DataType::List(_) => { - let array = as_list_array(&args[0])?; - general_list_ndims::(array) - } - DataType::LargeList(_) => { - let array = as_large_list_array(&args[0])?; - general_list_ndims::(array) - } - array_type => exec_err!("array_ndims does not support type {array_type:?}"), - } -} - /// Represents the type of comparison for array_has. #[derive(Debug, PartialEq)] enum ComparisonType { diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 1f797018719b..280535f5e6be 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License.
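The `TreeNodeRecursion`-based traversal that `is_volatile_expression` uses earlier in this patch is a general early-exit pattern for `TreeNode::apply`. A minimal sketch under the same API; the `contains_column` helper is invented for illustration and is not part of this patch:

use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::Result;
use datafusion_expr::Expr;

fn contains_column(e: &Expr) -> Result<bool> {
    let mut found = false;
    e.apply(&mut |expr| {
        Ok(if matches!(expr, Expr::Column(_)) {
            found = true;
            // Stop walking the tree as soon as the answer is known.
            TreeNodeRecursion::Stop
        } else {
            TreeNodeRecursion::Continue
        })
    })?;
    Ok(found)
}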
+use std::sync::Arc; + use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; use crate::{ expressions::Column, physical_expr::deduplicate_physical_exprs, @@ -22,9 +24,9 @@ use crate::{ LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use datafusion_common::tree_node::TreeNode; -use datafusion_common::{tree_node::Transformed, JoinType}; -use std::sync::Arc; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::JoinType; /// An `EquivalenceClass` is a set of [`Arc`]s that are known /// to have the same value for all tuples in a relation. These are generated by @@ -263,11 +265,12 @@ impl EquivalenceGroup { .transform(&|expr| { for cls in self.iter() { if cls.contains(&expr) { - return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); + return Ok(Transformed::yes(cls.canonical_expr().unwrap())); } } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .data() .unwrap_or(expr) } @@ -458,11 +461,12 @@ impl EquivalenceGroup { column.index() + left_size, )) as _; - return Ok(Transformed::Yes(new_column)); + return Ok(Transformed::yes(new_column)); } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .data() .unwrap(); result.add_equal_conditions(&new_lhs, &new_rhs); } @@ -477,15 +481,14 @@ impl EquivalenceGroup { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::equivalence::tests::create_test_params; use crate::equivalence::{EquivalenceClass, EquivalenceGroup}; - use crate::expressions::lit; - use crate::expressions::Column; - use crate::expressions::Literal; - use datafusion_common::Result; - use datafusion_common::ScalarValue; - use std::sync::Arc; + use crate::expressions::{lit, Column, Literal}; + + use datafusion_common::{Result, ScalarValue}; #[test] fn test_bridge_groups() -> Result<()> { diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index a31be06ecf0b..46909f23616f 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -15,18 +15,22 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + +use crate::expressions::Column; +use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; + mod class; mod ordering; mod projection; mod properties; -use crate::expressions::Column; -use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; + pub use class::{EquivalenceClass, EquivalenceGroup}; -use datafusion_common::tree_node::{Transformed, TreeNode}; pub use ordering::OrderingEquivalenceClass; pub use projection::ProjectionMapping; pub use properties::{join_equivalence_properties, EquivalenceProperties}; -use std::sync::Arc; /// This function constructs a duplicate-free `LexOrderingReq` by filtering out /// duplicate entries that have same physical expression inside. For example, @@ -48,12 +52,13 @@ pub fn add_offset_to_expr( offset: usize, ) -> Arc { expr.transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( + Some(col) => Ok(Transformed::yes(Arc::new(Column::new( col.name(), offset + col.index(), )))), - None => Ok(Transformed::No(e)), + None => Ok(Transformed::no(e)), }) + .data() .unwrap() // Note that we can safely unwrap here since our transform always returns // an `Ok` value. 
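The `transform_down(...).data()` chain in `add_offset_to_expr` above is the closure-based counterpart of the rewriter API: the closure returns a `Transformed` value per node, and `TransformedResult::data` collapses the resulting `Result<Transformed<T>>` back into a plain `Result<T>`. A minimal sketch of the same pattern; `zero_out_literals` is invented for illustration and is not part of this patch:

use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_expr::{lit, Expr};

fn zero_out_literals(expr: Expr) -> Result<Expr> {
    expr.transform_up(&|e| {
        Ok(match e {
            // Substitute a new node and mark it as changed.
            Expr::Literal(_) => Transformed::yes(lit(0)),
            // Keep the node unchanged; traversal still continues.
            other => Transformed::no(other),
        })
    })
    .data()
}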
@@ -61,19 +66,22 @@ pub fn add_offset_to_expr( #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::expressions::{col, Column}; use crate::PhysicalSortExpr; + use arrow::compute::{lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::{plan_datafusion_err, Result}; + use itertools::izip; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; - use std::sync::Arc; pub fn output_schema( mapping: &ProjectionMapping, diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 1a414592ce4c..64000937448e 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -479,7 +479,7 @@ mod tests { vec![ (col_a, options), (col_c, options), - (&floor_a, options), + (floor_a, options), (&a_plus_b, options), ], // expected: requirement is not satisfied. @@ -505,8 +505,8 @@ mod tests { vec![ (col_a, options), (col_b, options), - (&col_c, options), - (&floor_a, options), + (col_c, options), + (floor_a, options), ], // expected: requirement is satisfied. true, diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 0f92b2c2f431..96c919667d84 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -21,7 +21,7 @@ use crate::expressions::Column; use crate::PhysicalExpr; use arrow::datatypes::SchemaRef; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; /// Stores the mapping between source expressions and target expressions for a @@ -68,10 +68,11 @@ impl ProjectionMapping { let matching_input_field = input_schema.field(idx); let matching_input_column = Column::new(matching_input_field.name(), idx); - Ok(Transformed::Yes(Arc::new(matching_input_column))) + Ok(Transformed::yes(Arc::new(matching_input_column))) } - None => Ok(Transformed::No(e)), + None => Ok(Transformed::no(e)), }) + .data() .map(|source_expr| (source_expr, target_expr)) }) .collect::>>() @@ -108,6 +109,8 @@ impl ProjectionMapping { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::equivalence::tests::{ apply_projection, convert_to_orderings, convert_to_orderings_owned, @@ -119,12 +122,13 @@ mod tests { use crate::expressions::{col, BinaryExpr, Literal}; use crate::functions::create_physical_expr; use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{SortOptions, TimeUnit}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; - use std::sync::Arc; #[test] fn project_orderings() -> Result<()> { @@ -283,7 +287,7 @@ mod tests { // orderings vec![ // [a + b ASC, c ASC] - vec![(&a_plus_b, option_asc), (&col_c, option_asc)], + vec![(&a_plus_b, option_asc), (col_c, option_asc)], ], // projection exprs vec![ @@ -546,7 +550,7 @@ mod tests { vec![ (col_a, option_asc), (col_c, option_asc), - (&col_b, option_asc), + (col_b, option_asc), ], ], // proj exprs @@ -805,7 +809,7 @@ mod tests { // [a+b ASC, round(c) ASC, c_new ASC] vec![ (&a_new_plus_b_new, option_asc), - (&col_round_c_res, option_asc), + 
(col_round_c_res, option_asc), ], // [a+b ASC, c_new ASC] vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)], diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 5a9a4f64876d..f234a1fa08cd 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -15,11 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::CastExpr; -use arrow_schema::SchemaRef; -use datafusion_common::{JoinSide, JoinType, Result}; -use indexmap::{IndexMap, IndexSet}; -use itertools::Itertools; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -27,7 +22,7 @@ use super::ordering::collapse_lex_ordering; use crate::equivalence::{ collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; -use crate::expressions::Literal; +use crate::expressions::{CastExpr, Literal}; use crate::sort_properties::{ExprOrdering, SortProperties}; use crate::{ physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, @@ -35,8 +30,12 @@ use crate::{ PhysicalSortRequirement, }; -use arrow_schema::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use arrow_schema::{SchemaRef, SortOptions}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{JoinSide, JoinType, Result}; + +use indexmap::{IndexMap, IndexSet}; +use itertools::Itertools; /// A `EquivalenceProperties` object stores useful information related to a schema. /// Currently, it keeps track of: @@ -848,6 +847,7 @@ impl EquivalenceProperties { pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { ExprOrdering::new_default(expr.clone()) .transform_up(&|expr| Ok(update_ordering(expr, self))) + .data() // Guaranteed to always return `Ok`. 
.unwrap() } @@ -886,9 +886,9 @@ fn update_ordering( // We have a Literal, which is the other possible leaf node type: node.data = node.expr.get_ordering(&[]); } else { - return Transformed::No(node); + return Transformed::no(node); } - Transformed::Yes(node) + Transformed::yes(node) } /// This function determines whether the provided expression is constant @@ -1297,10 +1297,12 @@ mod tests { use crate::expressions::{col, BinaryExpr, Column}; use crate::functions::create_physical_expr; use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{Fields, SortOptions, TimeUnit}; use datafusion_common::Result; use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; #[test] diff --git a/datafusion/physical-expr/src/execution_props.rs b/datafusion/physical-expr/src/execution_props.rs index 8fdbbb7c5452..20999ab8d3db 100644 --- a/datafusion/physical-expr/src/execution_props.rs +++ b/datafusion/physical-expr/src/execution_props.rs @@ -99,7 +99,7 @@ impl ExecutionProps { ) -> Option> { self.var_providers .as_ref() - .and_then(|var_providers| var_providers.get(&var_type).map(Arc::clone)) + .and_then(|var_providers| var_providers.get(&var_type).cloned()) } } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 6a168e2f1e5f..609349509b86 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -19,18 +19,18 @@ use std::borrow::Cow; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::expressions::try_cast; -use crate::expressions::NoOp; +use crate::expressions::{try_cast, NoOp}; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; + use arrow::array::*; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{cast::as_boolean_array, internal_err, DataFusionError, Result}; -use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; use itertools::Itertools; @@ -414,17 +414,15 @@ pub fn case( #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; - use crate::expressions::lit; - use crate::expressions::{binary, cast}; + use crate::expressions::{binary, cast, col, lit}; + use arrow::array::StringArray; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; use datafusion_common::cast::{as_float64_array, as_int32_array}; - use datafusion_common::plan_err; - use datafusion_common::tree_node::{Transformed, TreeNode}; - use datafusion_common::ScalarValue; + use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; + use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::Operator; @@ -972,11 +970,12 @@ mod tests { _ => None, }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(e) + Transformed::no(e) }) }) + .data() .unwrap(); let expr3 = expr @@ -993,11 +992,12 @@ mod tests { _ => None, }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + 
Transformed::yes(transformed) } else { - Transformed::No(e) + Transformed::no(e) }) }) + .data() .unwrap(); assert!(expr.ne(&expr2)); diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 773387bf7421..c93090c4946f 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -252,14 +252,14 @@ impl PhysicalExpr for GetIndexedFieldExpr { GetFieldAccessExpr::ListIndex{key} => { let key = key.evaluate(batch)?.into_array(batch.num_rows())?; match (array.data_type(), key.data_type()) { - (DataType::List(_), DataType::Int64) => Ok(ColumnarValue::Array(array_element(&[ + (DataType::List(_), DataType::Int64) | (DataType::LargeList(_), DataType::Int64) => Ok(ColumnarValue::Array(array_element(&[ array, key ])?)), - (DataType::List(_), key) => exec_err!( - "get indexed field is only possible on lists with int64 indexes. \ + (DataType::List(_), key) | (DataType::LargeList(_), key) => exec_err!( + "get indexed field is only possible on List/LargeList with int64 indexes. \ Tried with {key:?} index"), (dt, key) => exec_err!( - "get indexed field is only possible on lists with int64 indexes or struct \ + "get indexed field is only possible on List/LargeList with int64 indexes or struct \ with utf8 indexes. Tried {dt:?} with {key:?} index"), } }, diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 56ad92082d9f..81013882ad89 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -105,14 +105,6 @@ macro_rules! invoke_if_crypto_expressions_feature_flag { }; } -#[cfg(feature = "regex_expressions")] -macro_rules! invoke_on_array_if_regex_expressions_feature_flag { - ($FUNC:ident, $T:tt, $NAME:expr) => {{ - use crate::regex_expressions; - regex_expressions::$FUNC::<$T> - }}; -} - #[cfg(not(feature = "regex_expressions"))] macro_rules! 
invoke_on_array_if_regex_expressions_feature_flag { ($FUNC:ident, $T:tt, $NAME:expr) => { @@ -339,9 +331,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayHas => Arc::new(|args| { make_scalar_function_inner(array_expressions::array_has)(args) }), - BuiltinScalarFunction::ArrayDims => Arc::new(|args| { - make_scalar_function_inner(array_expressions::array_dims)(args) - }), BuiltinScalarFunction::ArrayDistinct => Arc::new(|args| { make_scalar_function_inner(array_expressions::array_distinct)(args) }), @@ -357,9 +346,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Flatten => { Arc::new(|args| make_scalar_function_inner(array_expressions::flatten)(args)) } - BuiltinScalarFunction::ArrayNdims => Arc::new(|args| { - make_scalar_function_inner(array_expressions::array_ndims)(args) - }), BuiltinScalarFunction::ArrayPopFront => Arc::new(|args| { make_scalar_function_inner(array_expressions::array_pop_front)(args) }), @@ -405,9 +391,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| { make_scalar_function_inner(array_expressions::array_intersect)(args) }), - BuiltinScalarFunction::Cardinality => Arc::new(|args| { - make_scalar_function_inner(array_expressions::cardinality)(args) - }), BuiltinScalarFunction::ArrayResize => Arc::new(|args| { make_scalar_function_inner(array_expressions::array_resize)(args) }), @@ -519,15 +502,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function initcap") } }), - BuiltinScalarFunction::InStr => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::instr::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::instr::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function instr"), - }), BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); @@ -578,27 +552,6 @@ pub fn create_physical_fun( _ => unreachable!(), }, }), - BuiltinScalarFunction::RegexpLike => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_on_array_if_regex_expressions_feature_flag!( - regexp_like, - i32, - "regexp_like" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_on_array_if_regex_expressions_feature_flag!( - regexp_like, - i64, - "regexp_like" - ); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function regexp_like") - } - }), BuiltinScalarFunction::RegexpReplace => { Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -967,8 +920,8 @@ fn func_order_in_one_dimension( #[cfg(test)] mod tests { use super::*; + use crate::expressions::lit; use crate::expressions::try_cast; - use crate::expressions::{col, lit}; use arrow::{ array::{ Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, @@ -977,7 +930,7 @@ mod tests { datatypes::Field, record_batch::RecordBatch, }; - use datafusion_common::cast::{as_boolean_array, as_uint64_array}; + use datafusion_common::cast::as_uint64_array; use datafusion_common::{exec_err, internal_err, plan_err}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::functions::data_types; @@ -1364,95 +1317,6 @@ mod tests { Utf8, StringArray ); - test_function!( - InStr, - &[lit("abc"), lit("b")], - Ok(Some(2)), - i32, - Int32, - Int32Array - ); 
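The `InStr` implementation and tests removed in this hunk are not lost functionality: `instr` behaved as an alias for `strpos`, and the same expectations reappear below as newly added `Strpos` test cases. As a rough sketch of the shared 1-based position semantics (the helper name is hypothetical, and the real `strpos` kernel counts characters while this sketch uses byte offsets, so they only agree on ASCII input):

```rust
// Hypothetical helper, not a DataFusion API: illustrates the 1-based
// position contract that both `instr` and `strpos` satisfy.
fn position_one_based(haystack: &str, needle: &str) -> i32 {
    match haystack.find(needle) {
        Some(idx) => (idx + 1) as i32, // `str::find` is 0-based; SQL positions are 1-based
        None => 0,                     // not found maps to 0, matching the `Ok(Some(0))` cases
    }
}

fn main() {
    assert_eq!(position_one_based("abc", "b"), 2);
    assert_eq!(position_one_based("abc", "d"), 0);
    assert_eq!(position_one_based("abc", ""), 1); // an empty needle matches at position 1
    assert_eq!(position_one_based("Helloworld", "world"), 6);
}
```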
- test_function!( - InStr, - &[lit("abc"), lit("c")], - Ok(Some(3)), - i32, - Int32, - Int32Array - ); - test_function!( - InStr, - &[lit("abc"), lit("d")], - Ok(Some(0)), - i32, - Int32, - Int32Array - ); - test_function!( - InStr, - &[lit("abc"), lit("")], - Ok(Some(1)), - i32, - Int32, - Int32Array - ); - test_function!( - InStr, - &[lit("Helloworld"), lit("world")], - Ok(Some(6)), - i32, - Int32, - Int32Array - ); - test_function!( - InStr, - &[lit("Helloworld"), lit(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array - ); - test_function!( - InStr, - &[lit(ScalarValue::Utf8(None)), lit("Hello")], - Ok(None), - i32, - Int32, - Int32Array - ); - test_function!( - InStr, - &[ - lit(ScalarValue::LargeUtf8(Some("Helloworld".to_string()))), - lit(ScalarValue::LargeUtf8(Some("world".to_string()))) - ], - Ok(Some(6)), - i64, - Int64, - Int64Array - ); - test_function!( - InStr, - &[ - lit(ScalarValue::LargeUtf8(None)), - lit(ScalarValue::LargeUtf8(Some("world".to_string()))) - ], - Ok(None), - i64, - Int64, - Int64Array - ); - test_function!( - InStr, - &[ - lit(ScalarValue::LargeUtf8(Some("Helloworld".to_string()))), - lit(ScalarValue::LargeUtf8(None)) - ], - Ok(None), - i64, - Int64, - Int64Array - ); #[cfg(feature = "unicode_expressions")] test_function!( Left, @@ -2613,6 +2477,87 @@ mod tests { Int32Array ); #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[lit("abc"), lit("d")], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[lit("abc"), lit("")], + Ok(Some(1)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[lit("Helloworld"), lit("world")], + Ok(Some(6)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[lit("Helloworld"), lit(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[lit(ScalarValue::Utf8(None)), lit("Hello")], + Ok(None), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[ + lit(ScalarValue::LargeUtf8(Some("Helloworld".to_string()))), + lit(ScalarValue::LargeUtf8(Some("world".to_string()))) + ], + Ok(Some(6)), + i64, + Int64, + Int64Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[ + lit(ScalarValue::LargeUtf8(None)), + lit(ScalarValue::LargeUtf8(Some("world".to_string()))) + ], + Ok(None), + i64, + Int64, + Int64Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[ + lit(ScalarValue::LargeUtf8(Some("Helloworld".to_string()))), + lit(ScalarValue::LargeUtf8(None)) + ], + Ok(None), + i64, + Int64, + Int64Array + ); + #[cfg(feature = "unicode_expressions")] test_function!( Strpos, &[lit("josé"), lit("é"),], @@ -3070,74 +3015,6 @@ mod tests { Ok(()) } - #[test] - #[cfg(feature = "regex_expressions")] - fn test_regexp_like() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); - let execution_props = ExecutionProps::new(); - - let col_value: ArrayRef = Arc::new(StringArray::from(vec!["aaa-555"])); - let pattern = lit(r".*-(\d*)"); - let columns: Vec = vec![col_value]; - let expr = create_physical_expr_with_type_coercion( - &BuiltinScalarFunction::RegexpLike, - &[col("a", &schema)?, pattern], - &schema, - &execution_props, - )?; - - // type is correct - assert_eq!(expr.data_type(&schema)?, 
DataType::Boolean); - - // evaluate works - let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - - let result = as_boolean_array(&result)?; - - // value is correct - assert!(result.value(0)); - - Ok(()) - } - - #[test] - #[cfg(feature = "regex_expressions")] - fn test_regexp_like_all_literals() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let execution_props = ExecutionProps::new(); - - let col_value = lit("aaa-555"); - let pattern = lit(r".*-(\d*)"); - let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; - let expr = create_physical_expr_with_type_coercion( - &BuiltinScalarFunction::RegexpLike, - &[col_value, pattern], - &schema, - &execution_props, - )?; - - // type is correct - assert_eq!(expr.data_type(&schema)?, DataType::Boolean); - - // evaluate works - let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - - let result = as_boolean_array(&result)?; - - // value is correct - assert!(result.value(0)); - - Ok(()) - } - // Helper function just for testing. // Returns `expressions` coerced to types compatible with // `signature`, if possible. diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 567054e2b59e..39b8de81af56 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -252,7 +252,7 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) { - expr.with_new_children(children) + Ok(expr.with_new_children(children)?) } else { Ok(expr) } diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index 846e5801af1c..99e6597dad82 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -53,78 +53,6 @@ macro_rules! fetch_string_arg { }}; } -/// Tests a string using a regular expression returning true if at -/// least one match, false otherwise. 
-/// -/// The full list of supported features and syntax can be found at -/// -/// -/// Supported flags can be found at -/// -/// -/// # Examples -/// -/// ```ignore -/// # use datafusion::prelude::*; -/// # use datafusion::error::Result; -/// # #[tokio::main] -/// # async fn main() -> Result<()> { -/// let ctx = SessionContext::new(); -/// let df = ctx.read_csv("tests/data/regex.csv", CsvReadOptions::new()).await?; -/// -/// // use the regexp_like function to test col 'values', -/// // against patterns in col 'patterns' without flags -/// let df = df.with_column( -/// "a", -/// regexp_like(vec![col("values"), col("patterns")]) -/// )?; -/// // use the regexp_like function to test col 'values', -/// // against patterns in col 'patterns' with flags -/// let df = df.with_column( -/// "b", -/// regexp_like(vec![col("values"), col("patterns"), col("flags")]) -/// )?; -/// // literals can be used as well with dataframe calls -/// let df = df.with_column( -/// "c", -/// regexp_like(vec![lit("foobarbequebaz"), lit("(bar)(beque)")]) -/// )?; -/// -/// df.show().await?; -/// -/// # Ok(()) -/// # } -/// ``` -pub fn regexp_like(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let values = as_generic_string_array::(&args[0])?; - let regex = as_generic_string_array::(&args[1])?; - let array = arrow_string::regexp::regexp_is_match_utf8(values, regex, None) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(Arc::new(array) as ArrayRef) - } - 3 => { - let values = as_generic_string_array::(&args[0])?; - let regex = as_generic_string_array::(&args[1])?; - let flags = as_generic_string_array::(&args[2])?; - - if flags.iter().any(|s| s == Some("g")) { - return plan_err!("regexp_like() does not support the \"global\" option"); - } - - let array = arrow_string::regexp::regexp_is_match_utf8(values, regex, Some(flags)) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(Arc::new(array) as ArrayRef) - } - other => exec_err!( - "regexp_like was called with {other} arguments. It requires at least 2 and at most 3." - ), - } -} - /// Extract a specific group from a string column, using a regular expression. 
/// -/// The full list of supported features and syntax can be found at @@ -487,64 +415,6 @@ mod tests { use super::*; - #[test] - fn test_case_sensitive_regexp_like() { - let values = StringArray::from(vec!["abc"; 5]); - - let patterns = - StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); - - let mut expected_builder: BooleanBuilder = BooleanBuilder::new(); - expected_builder.append_value(true); - expected_builder.append_value(false); - expected_builder.append_value(true); - expected_builder.append_value(false); - expected_builder.append_value(false); - let expected = expected_builder.finish(); - - let re = regexp_like::<i32>(&[Arc::new(values), Arc::new(patterns)]).unwrap(); - - assert_eq!(re.as_ref(), &expected); - } - - #[test] - fn test_case_insensitive_regexp_like() { - let values = StringArray::from(vec!["abc"; 5]); - let patterns = - StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); - let flags = StringArray::from(vec!["i"; 5]); - - let mut expected_builder: BooleanBuilder = BooleanBuilder::new(); - expected_builder.append_value(true); - expected_builder.append_value(true); - expected_builder.append_value(true); - expected_builder.append_value(true); - expected_builder.append_value(false); - let expected = expected_builder.finish(); - - let re = - regexp_like::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) - .unwrap(); - - assert_eq!(re.as_ref(), &expected); - } - - #[test] - fn test_unsupported_global_flag_regexp_like() { - let values = StringArray::from(vec!["abc"]); - let patterns = StringArray::from(vec!["^(a)"]); - let flags = StringArray::from(vec!["g"]); - - let re_err = - regexp_like::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) - .expect_err("unsupported flag should have failed"); - - assert_eq!( - re_err.strip_backtrace(), - "Error during planning: regexp_like() does not support the \"global\" option" - ); - } - #[test] fn test_case_sensitive_regexp_match() { let values = StringArray::from(vec!["abc"; 5]); diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 26ee95f4793c..c249af232bf5 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -470,7 +470,7 @@ mod test { test_analyze( col("b").not_eq(lit(1)).and(col("b").eq(lit(2))), vec![ - // can only be true of b is not 1 and b is is 2 (even though it is redundant) + // can only be true if b is not 1 and b is 2 (even though it is redundant) not_in_guarantee("b", [1]), in_guarantee("b", [2]), ], diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index e14ff2692146..b8e99403d695 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -30,7 +30,7 @@ use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::Result; use datafusion_expr::Operator; @@ -130,11 +130,10 @@ pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>( pub type ExprTreeNode<T> = ExprContext<Option<T>>; -/// This struct facilitates the [TreeNodeRewriter] mechanism to convert a -/// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting -/// identical expressions in one node.
Caller specifies the node type in the -/// DAEG via the `constructor` argument, which constructs nodes in the DAEG -/// from the [ExprTreeNode] ancillary object. +/// This struct is used to convert a [`PhysicalExpr`] tree into a DAEG (i.e. an expression +/// DAG) by collecting identical expressions in one node. Caller specifies the node type +/// in the DAEG via the `constructor` argument, which constructs nodes in the DAEG from +/// the [`ExprTreeNode`] ancillary object. struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> { // The resulting DAEG (expression DAG). graph: StableGraph<T, usize>, @@ -144,16 +143,15 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> TreeNodeRewriter - for PhysicalExprDAEGBuilder<'a, T, F> +impl<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> + PhysicalExprDAEGBuilder<'a, T, F> { - type N = ExprTreeNode<NodeIndex>; // This method mutates an expression node by transforming it to a physical expression // and adding it to the graph. The method returns the mutated expression node. fn mutate( &mut self, mut node: ExprTreeNode<NodeIndex>, - ) -> Result<ExprTreeNode<NodeIndex>> { + ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> { // Get the expression associated with the input expression node. let expr = &node.expr; @@ -176,7 +174,7 @@ impl<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> TreeNodeRewriter // Set the data field of the input expression node to the corresponding node index. node.data = Some(node_idx); // Return the mutated expression node. - Ok(node) + Ok(Transformed::yes(node)) } } @@ -197,7 +195,9 @@ where constructor, }; // Use the builder to transform the expression tree node into a DAG. - let root = init.rewrite(&mut builder)?; + let root = init + .transform_up_mut(&mut |node| builder.mutate(node)) + .data()?; // Return a tuple containing the root node index and the DAG.
Ok((root.data.unwrap(), builder.graph)) } @@ -211,7 +211,7 @@ pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> { columns.insert(column.clone()); } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); @@ -234,13 +234,14 @@ pub fn reassign_predicate_columns( Err(_) if ignore_not_found => usize::MAX, Err(e) => return Err(e.into()), }; - return Ok(Transformed::Yes(Arc::new(Column::new( + return Ok(Transformed::yes(Arc::new(Column::new( column.name(), index, )))); } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .data() } /// Reverses the ORDER BY expression, which is useful during equivalent window diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 6e1aad575f6a..1d6cfc6b0418 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -206,6 +206,7 @@ impl PartitionEvaluator for WindowShiftEvaluator { fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> { if self.is_lag() { let start = if self.non_null_offsets.len() == self.shift_offset as usize { + // How many rows before the current row are needed to get the necessary lag result let offset: usize = self.non_null_offsets.iter().sum(); idx.saturating_sub(offset + 1) } else { @@ -214,8 +215,13 @@ impl PartitionEvaluator for WindowShiftEvaluator { let end = idx + 1; Ok(Range { start, end }) } else { - let offset = (-self.shift_offset) as usize; - let end = min(idx + offset, n_rows); + let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize { + // How many rows after the current row are needed to get the necessary lead result + let offset: usize = self.non_null_offsets.iter().sum(); + min(idx + offset + 1, n_rows) + } else { + n_rows + }; Ok(Range { start: idx, end }) } } @@ -244,10 +250,10 @@ impl PartitionEvaluator for WindowShiftEvaluator { range.start as i64 - self.shift_offset }; - // Support LAG only for now, as LEAD requires some brainstorm first // LAG with IGNORE NULLS calculated as the current row index - offset, but only for non-NULL rows // If current row index points to NULL value the row is NOT counted if self.ignore_nulls && self.is_lag() { + // LAG when NULLS are ignored. // Find the nonNULL row index that shifted by offset comparing to current row index idx = if self.non_null_offsets.len() == self.shift_offset as usize { let total_offset: usize = self.non_null_offsets.iter().sum(); @@ -270,10 +276,55 @@ impl PartitionEvaluator for WindowShiftEvaluator { self.non_null_offsets[end_idx] += 1; } } else if self.ignore_nulls && !self.is_lag() { - // IGNORE NULLS and LEAD mode. - return Err(exec_datafusion_err!( - "IGNORE NULLS mode for LEAD is not supported for BoundedWindowAggExec" - )); + // LEAD when NULLS are ignored. + // The number of non-null entries needed after the current row. + let non_null_row_count = (-self.shift_offset) as usize; + + if self.non_null_offsets.is_empty() { + // When empty, fill `non_null_offsets` from the data after the current row. + let mut offset_val = 1; + for idx in range.start + 1..range.end { + if array.is_valid(idx) { + self.non_null_offsets.push_back(offset_val); + offset_val = 1; + } else { + offset_val += 1; + } + // It is enough to keep track of `non_null_row_count + 1` non-null offsets; + // further data is unnecessary for the result.
+ if self.non_null_offsets.len() == non_null_row_count + 1 { + break; + } + } + } else if range.end < len as usize && array.is_valid(range.end) { + // Update `non_null_offsets` with the new end data. + if array.is_valid(range.end) { + // When non-null, append a new offset. + self.non_null_offsets.push_back(1); + } else { + // When null, increment offset count of the last entry + let last_idx = self.non_null_offsets.len() - 1; + self.non_null_offsets[last_idx] += 1; + } + } + + // Find the non-NULL row index that is shifted by the offset relative to the current row index + idx = if self.non_null_offsets.len() >= non_null_row_count { + let total_offset: usize = + self.non_null_offsets.iter().take(non_null_row_count).sum(); + (range.start + total_offset) as i64 + } else { + -1 + }; + // Prune `self.non_null_offsets` from the start so that at the next iteration + // the start of `self.non_null_offsets` matches the current row. + if !self.non_null_offsets.is_empty() { + self.non_null_offsets[0] -= 1; + if self.non_null_offsets[0] == 0 { + // When the offset reaches 0, remove it. + self.non_null_offsets.pop_front(); + } + } } // Set the default value if diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index b4621109d2b1..72ee4fb3ef7e 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -43,6 +43,7 @@ arrow-schema = { workspace = true } async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } +datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 47cdf3e400e3..656bffd4a799 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -36,9 +36,8 @@ use datafusion_execution::memory_pool::MemoryReservation; use datafusion_physical_expr::expressions::{BinaryExpr, Column}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use futures::{Future, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use parking_lot::Mutex; -use tokio::task::{JoinError, JoinSet}; /// [`MemoryReservation`] used across query execution streams pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>; @@ -172,46 +171,6 @@ pub fn compute_record_batch_statistics( } } -/// Helper that provides a simple API to spawn a single task and join it. -/// Provides guarantees of aborting on `Drop` to keep it cancel-safe. -/// -/// Technically, it's just a wrapper of `JoinSet` (with size=1). -#[derive(Debug)] -pub struct SpawnedTask<R> { - inner: JoinSet<R>, } -impl<R: 'static> SpawnedTask<R> { - pub fn spawn<T>(task: T) -> Self - where - T: Future<Output = R>, - T: Send + 'static, - R: Send, - { - let mut inner = JoinSet::new(); - inner.spawn(task); - Self { inner } - } - - pub fn spawn_blocking<T>(task: T) -> Self - where - T: FnOnce() -> R, - T: Send + 'static, - R: Send, - { - let mut inner = JoinSet::new(); - inner.spawn_blocking(task); - Self { inner } - } - - pub async fn join(mut self) -> Result<R, JoinError> { - self.inner - .join_next() - .await - .expect("`SpawnedTask` instance always contains exactly 1 task") - } -} - /// Transposes the given vector of vectors.
pub fn transpose(original: Vec>) -> Vec> { match original.as_slice() { diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index ebd92efb4cd2..4ff79cdaae70 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -182,7 +182,7 @@ mod tests { let schema = test::aggr_test_schema(); let empty = Arc::new(EmptyExec::new(schema.clone())); - let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.into(); + let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?; assert_eq!(empty.schema(), empty2.schema()); let too_many_kids = vec![empty2]; diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 9a4c98927683..9824c723d9d1 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -31,7 +31,7 @@ use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, @@ -284,14 +284,16 @@ pub fn convert_sort_expr_with_filter_schema( if all_columns_are_included { // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. - let converted_filter_expr = expr.transform_up(&|p| { - convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { - match transformed { - Some(transformed) => Transformed::Yes(transformed), - None => Transformed::No(p), - } + let converted_filter_expr = expr + .transform_up(&|p| { + convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { + match transformed { + Some(transformed) => Transformed::yes(transformed), + None => Transformed::no(p), + } + }) }) - })?; + .data()?; // Search the converted `PhysicalExpr` in filter expression; if an exact // match is found, use this sorted expression in graph traversals. 
if check_filter_expr_contains_sort_information( diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 3dac0107d3ef..1cb2b100e2d6 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -39,6 +39,7 @@ use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray}; use arrow_buffer::ArrowNativeType; use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult, }; @@ -50,7 +51,6 @@ use datafusion_physical_expr::{ LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use hashbrown::raw::RawTable; @@ -475,13 +475,17 @@ fn replace_on_columns_of_right_ordering( ) -> Result<()> { for (left_col, right_col) in on_columns { for item in right_ordering.iter_mut() { - let new_expr = item.expr.clone().transform(&|e| { - if e.eq(right_col) { - Ok(Transformed::Yes(left_col.clone())) - } else { - Ok(Transformed::No(e)) - } - })?; + let new_expr = item + .expr + .clone() + .transform(&|e| { + if e.eq(right_col) { + Ok(Transformed::yes(left_col.clone())) + } else { + Ok(Transformed::no(e)) + } + }) + .data()?; item.expr = new_expr; } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 006cd646b0ca..6334a4a211d4 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -30,7 +30,6 @@ use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::Transformed; use datafusion_common::utils::DataPtr; use datafusion_common::Result; use datafusion_execution::TaskContext; @@ -264,7 +263,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// /// [`spawn`]: tokio::task::spawn /// [`JoinSet`]: tokio::task::JoinSet - /// [`SpawnedTask`]: crate::common::SpawnedTask + /// [`SpawnedTask`]: datafusion_common_runtime::SpawnedTask /// [`RecordBatchReceiverStreamBuilder`]: crate::stream::RecordBatchReceiverStreamBuilder /// /// # Implementation Examples @@ -341,7 +340,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// /// ## Lazily (async) create a Stream /// - /// If you need to to create the return `Stream` using an `async` function, + /// If you need to create the return `Stream` using an `async` function, /// you can do so by flattening the result: /// /// ``` @@ -652,7 +651,7 @@ pub fn need_data_exchange(plan: Arc) -> bool { pub fn with_new_children_if_necessary( plan: Arc, children: Vec>, -) -> Result>> { +) -> Result> { let old_children = plan.children(); if children.len() != old_children.len() { internal_err!("Wrong number of children") @@ -662,9 +661,9 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) { - Ok(Transformed::Yes(plan.with_new_children(children)?)) + plan.with_new_children(children) } else { - Ok(Transformed::No(plan)) + Ok(plan) } } diff --git a/datafusion/physical-plan/src/metrics/builder.rs b/datafusion/physical-plan/src/metrics/builder.rs index beecc13e0029..5e8ff72df35c 100644 --- 
a/datafusion/physical-plan/src/metrics/builder.rs +++ b/datafusion/physical-plan/src/metrics/builder.rs @@ -183,7 +183,7 @@ impl<'a> MetricBuilder<'a> { } /// Consumes self and creates a new Timer for recording some - /// subset of of an operators execution time. + /// subset of an operator's execution time. pub fn subset_time( self, subset_name: impl Into<Cow<'static, str>>, diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 37d209a3b473..3880cf3d77af 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -184,8 +184,7 @@ mod tests { let placeholder = Arc::new(PlaceholderRowExec::new(schema)); - let placeholder_2 = - with_new_children_if_necessary(placeholder.clone(), vec![])?.into(); + let placeholder_2 = with_new_children_if_necessary(placeholder.clone(), vec![])?; assert_eq!(placeholder.schema(), placeholder_2.schema()); let too_many_kids = vec![placeholder_2]; diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 9786b1cbf6fd..2e4b97bc224b 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -30,7 +30,7 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; @@ -317,16 +317,17 @@ fn assign_work_table( ) } else { work_table_refs += 1; - Ok(Transformed::Yes(Arc::new( + Ok(Transformed::yes(Arc::new( exec.with_work_table(work_table.clone()), ))) } } else if plan.as_any().is::<RecursiveQueryExec>() { not_impl_err!("Recursive queries cannot be nested") } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } }) + .data() } impl Stream for RecursiveQueryStream { diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index fe93ea131506..7ac70949f893 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -29,7 +29,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream, }; -use crate::common::{transpose, SpawnedTask}; +use crate::common::transpose; use crate::hash_utils::create_hashes; use crate::metrics::BaselineMetrics; use crate::repartition::distributor_channels::{ @@ -42,6 +42,7 @@ use arrow::array::{ArrayRef, UInt64Builder}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, Result}; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; @@ -946,7 +947,6 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use futures::FutureExt; - use tokio::task::JoinHandle; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1060,10 +1060,9 @@ mod tests { } #[tokio::test] - #[allow(clippy::disallowed_methods)] async fn
many_to_many_round_robin_within_tokio_task() -> Result<()> { - let join_handle: JoinHandle<Result<Vec<Vec<RecordBatch>>>> = - tokio::spawn(async move { + let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> = + SpawnedTask::spawn(async move { // define input partitions let schema = test_schema(); let partition = create_vec_batches(50); @@ -1074,7 +1073,7 @@ mod tests { repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await }); - let output_partitions = join_handle.await.unwrap().unwrap(); + let output_partitions = handle.join().await.unwrap().unwrap(); assert_eq!(5, output_partitions.len()); assert_eq!(30, output_partitions[0].len()); diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index f46958663252..db352bb2c86f 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -27,7 +27,7 @@ use std::io::BufReader; use std::path::{Path, PathBuf}; use std::sync::Arc; -use crate::common::{spawn_buffered, IPCWriter, SpawnedTask}; +use crate::common::{spawn_buffered, IPCWriter}; use crate::expressions::PhysicalSortExpr; use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, @@ -41,11 +41,15 @@ use crate::{ SendableRecordBatchStream, Statistics, }; -use arrow::compute::{concat_batches, lexsort_to_indices, take}; +use arrow::compute::{concat_batches, lexsort_to_indices, take, SortColumn}; use arrow::datatypes::SchemaRef; use arrow::ipc::reader::FileReader; use arrow::record_batch::RecordBatch; +use arrow::row::{RowConverter, SortField}; +use arrow_array::{Array, UInt32Array}; +use arrow_schema::DataType; use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{ human_readable_size, MemoryConsumer, MemoryReservation, @@ -587,7 +591,13 @@ pub(crate) fn sort_batch( .map(|expr| expr.evaluate_to_sort_column(batch)) .collect::<Result<Vec<_>>>()?; - let indices = lexsort_to_indices(&sort_columns, fetch)?; + let indices = if is_multi_column_with_lists(&sort_columns) { + // lexsort_to_indices doesn't support List with more than one column + // https://github.com/apache/arrow-rs/issues/5454 + lexsort_to_indices_multi_columns(sort_columns, fetch)? + } else { + lexsort_to_indices(&sort_columns, fetch)? + }; let columns = batch .columns() @@ -598,6 +608,48 @@ pub(crate) fn sort_batch( Ok(RecordBatch::try_new(batch.schema(), columns)?) } +#[inline] +fn is_multi_column_with_lists(sort_columns: &[SortColumn]) -> bool { + sort_columns.iter().any(|c| { + matches!( + c.values.data_type(), + DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) + ) + }) +} + +pub(crate) fn lexsort_to_indices_multi_columns( + sort_columns: Vec<SortColumn>, + limit: Option<usize>, +) -> Result<UInt32Array> { + let (fields, columns) = sort_columns.into_iter().fold( + (vec![], vec![]), + |(mut fields, mut columns), sort_column| { + fields.push(SortField::new_with_options( + sort_column.values.data_type().clone(), + sort_column.options.unwrap_or_default(), + )); + columns.push(sort_column.values); + (fields, columns) + }, + ); + + // TODO reuse converter and rows, refer to TopK.
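The construction that continues below is the heart of the workaround: each `SortColumn` becomes a `SortField`, and Arrow's row format turns a multi-column (and list-aware) lexicographic comparison into a plain byte comparison. A condensed, self-contained sketch of the same row-format sort, assuming only the `arrow` crate (simplified relative to the function being added here):

```rust
use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array, UInt32Array};
use arrow::error::ArrowError;
use arrow::row::{RowConverter, SortField};

fn main() -> Result<(), ArrowError> {
    // Two sort keys; logically, row i is the tuple (a[i], b[i]).
    let a: ArrayRef = Arc::new(Int32Array::from(vec![2, 1, 2]));
    let b: ArrayRef = Arc::new(Int32Array::from(vec![3, 1, 0]));

    let converter = RowConverter::new(vec![
        SortField::new(a.data_type().clone()),
        SortField::new(b.data_type().clone()),
    ])?;
    // The row encoding is order-preserving, so byte-wise comparison of rows
    // equals lexicographic comparison of the underlying column values.
    let rows = converter.convert_columns(&[a, b])?;
    let mut sort: Vec<_> = rows.iter().enumerate().collect();
    sort.sort_unstable_by(|(_, x), (_, y)| x.cmp(y));

    let indices = UInt32Array::from_iter_values(sort.iter().map(|(i, _)| *i as u32));
    assert_eq!(indices.values().to_vec(), vec![1, 2, 0]); // (1,1) < (2,0) < (2,3)
    Ok(())
}
```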
+ let converter = RowConverter::new(fields)?; + let rows = converter.convert_columns(&columns)?; + let mut sort: Vec<_> = rows.iter().enumerate().collect(); + sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); + + let mut len = rows.num_rows(); + if let Some(limit) = limit { + len = limit.min(len); + } + let indices = + UInt32Array::from_iter_values(sort.iter().take(len).map(|(i, _)| *i as u32)); + + Ok(indices) +} + async fn spill_sorted_batches( batches: Vec, path: &Path, @@ -1158,6 +1210,82 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_lex_sort_by_mixed_types() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new( + "b", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + ])); + + // define data. + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![Some(2), None, Some(1), Some(2)])), + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(3)]), + Some(vec![Some(1)]), + Some(vec![Some(6), None]), + Some(vec![Some(5)]), + ])), + ], + )?; + + let sort_exec = Arc::new(SortExec::new( + vec![ + PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }, + PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions { + descending: true, + nulls_first: false, + }, + }, + ], + Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None)?), + )); + + assert_eq!(DataType::Int32, *sort_exec.schema().field(0).data_type()); + assert_eq!( + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + *sort_exec.schema().field(1).data_type() + ); + + let result: Vec = collect(sort_exec.clone(), task_ctx).await?; + let metrics = sort_exec.metrics().unwrap(); + assert!(metrics.elapsed_compute().unwrap() > 0); + assert_eq!(metrics.output_rows().unwrap(), 4); + assert_eq!(result.len(), 1); + + let expected = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![None, Some(1), Some(2), Some(2)])), + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1)]), + Some(vec![Some(6), None]), + Some(vec![Some(5)]), + Some(vec![Some(3)]), + ])), + ], + )?; + + assert_eq!(expected, result[0]); + + Ok(()) + } + #[tokio::test] async fn test_lex_sort_by_float() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index 23df3753e817..b4f1eac0a655 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -118,7 +118,7 @@ pub struct MockExec { /// the results to send back data: Vec>, schema: SchemaRef, - /// if true (the default), sends data using a separate task to to ensure the + /// if true (the default), sends data using a separate task to ensure the /// batches are not available without this stream yielding first use_task: bool, cache: PlanProperties, diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index b8a5f95c5325..52a52f81bdaf 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use crate::{displayable, with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode, Transformed}; +use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; 
use datafusion_common::Result; impl DynTreeNode for dyn ExecutionPlan { @@ -35,7 +35,7 @@ impl DynTreeNode for dyn ExecutionPlan { arc_self: Arc, new_children: Vec>, ) -> Result> { - with_new_children_if_necessary(arc_self, new_children).map(Transformed::into) + with_new_children_if_necessary(arc_self, new_children) } } @@ -63,7 +63,7 @@ impl PlanContext { pub fn update_plan_from_children(mut self) -> Result { let children_plans = self.children.iter().map(|c| c.plan.clone()).collect(); - self.plan = with_new_children_if_necessary(self.plan, children_plans)?.into(); + self.plan = with_new_children_if_necessary(self.plan, children_plans)?; Ok(self) } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 1ad7a2c3afaf..c47b9abadb0e 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -636,17 +636,17 @@ enum ScalarFunction { Gcd = 85; ArrayAppend = 86; ArrayConcat = 87; - ArrayDims = 88; + // 88 was ArrayDims ArrayRepeat = 89; ArrayLength = 90; - ArrayNdims = 91; + // 91 was ArrayNdims ArrayPosition = 92; ArrayPositions = 93; ArrayPrepend = 94; ArrayRemove = 95; ArrayReplace = 96; // 97 was ArrayToString - Cardinality = 98; + // 98 was Cardinality ArrayElement = 99; ArraySlice = 100; Cot = 103; @@ -678,10 +678,10 @@ enum ScalarFunction { ArrayDistinct = 129; ArrayResize = 130; EndsWith = 131; - InStr = 132; + /// 132 was InStr MakeDate = 133; ArrayReverse = 134; - RegexpLike = 135; + /// 135 is RegexpLike ToChar = 136; /// 137 was ToDate } @@ -751,6 +751,7 @@ message AggregateUDFExprNode { message ScalarUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; + optional bytes fun_definition = 3; } enum BuiltInWindowFunction { diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index d4abb9ed9c6f..610c533d574c 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -16,6 +16,7 @@ // under the License. //! 
Serialization / Deserialization to Bytes +use crate::logical_plan::to_proto::serialize_expr; use crate::logical_plan::{ self, AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; @@ -87,8 +88,8 @@ pub trait Serializeable: Sized { impl Serializeable for Expr { fn to_bytes(&self) -> Result { let mut buffer = BytesMut::new(); - let protobuf: protobuf::LogicalExprNode = self - .try_into() + let extension_codec = DefaultLogicalExtensionCodec {}; + let protobuf: protobuf::LogicalExprNode = serialize_expr(self, &extension_codec) .map_err(|e| plan_datafusion_err!("Error encoding expr as protobuf: {e}"))?; protobuf @@ -177,7 +178,8 @@ impl Serializeable for Expr { let protobuf = protobuf::LogicalExprNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - logical_plan::from_proto::parse_expr(&protobuf, registry) + let extension_codec = DefaultLogicalExtensionCodec {}; + logical_plan::from_proto::parse_expr(&protobuf, registry, &extension_codec) .map_err(|e| plan_datafusion_err!("Error parsing protobuf into Expr: {e}")) } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 33ebdf310ae0..c9be1bb7f3e5 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22401,16 +22401,13 @@ impl serde::Serialize for ScalarFunction { Self::Gcd => "Gcd", Self::ArrayAppend => "ArrayAppend", Self::ArrayConcat => "ArrayConcat", - Self::ArrayDims => "ArrayDims", Self::ArrayRepeat => "ArrayRepeat", Self::ArrayLength => "ArrayLength", - Self::ArrayNdims => "ArrayNdims", Self::ArrayPosition => "ArrayPosition", Self::ArrayPositions => "ArrayPositions", Self::ArrayPrepend => "ArrayPrepend", Self::ArrayRemove => "ArrayRemove", Self::ArrayReplace => "ArrayReplace", - Self::Cardinality => "Cardinality", Self::ArrayElement => "ArrayElement", Self::ArraySlice => "ArraySlice", Self::Cot => "Cot", @@ -22439,10 +22436,8 @@ impl serde::Serialize for ScalarFunction { Self::ArrayDistinct => "ArrayDistinct", Self::ArrayResize => "ArrayResize", Self::EndsWith => "EndsWith", - Self::InStr => "InStr", Self::MakeDate => "MakeDate", Self::ArrayReverse => "ArrayReverse", - Self::RegexpLike => "RegexpLike", Self::ToChar => "ToChar", }; serializer.serialize_str(variant) @@ -22535,16 +22530,13 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Gcd", "ArrayAppend", "ArrayConcat", - "ArrayDims", "ArrayRepeat", "ArrayLength", - "ArrayNdims", "ArrayPosition", "ArrayPositions", "ArrayPrepend", "ArrayRemove", "ArrayReplace", - "Cardinality", "ArrayElement", "ArraySlice", "Cot", @@ -22573,10 +22565,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayDistinct", "ArrayResize", "EndsWith", - "InStr", "MakeDate", "ArrayReverse", - "RegexpLike", "ToChar", ]; @@ -22698,16 +22688,13 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Gcd" => Ok(ScalarFunction::Gcd), "ArrayAppend" => Ok(ScalarFunction::ArrayAppend), "ArrayConcat" => Ok(ScalarFunction::ArrayConcat), - "ArrayDims" => Ok(ScalarFunction::ArrayDims), "ArrayRepeat" => Ok(ScalarFunction::ArrayRepeat), "ArrayLength" => Ok(ScalarFunction::ArrayLength), - "ArrayNdims" => Ok(ScalarFunction::ArrayNdims), "ArrayPosition" => Ok(ScalarFunction::ArrayPosition), "ArrayPositions" => Ok(ScalarFunction::ArrayPositions), "ArrayPrepend" => Ok(ScalarFunction::ArrayPrepend), "ArrayRemove" => Ok(ScalarFunction::ArrayRemove), "ArrayReplace" => Ok(ScalarFunction::ArrayReplace), - "Cardinality" => Ok(ScalarFunction::Cardinality), 
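Note that the serde arms for the removed functions (`ArrayDims`, `Cardinality`, `InStr`, `RegexpLike`, ...) are dropped without renumbering their neighbors, matching the `// 88 was ArrayDims`-style tombstones in `datafusion.proto`: a retired protobuf enum number must never be reused, or plans serialized by an older release would silently decode to a different function. A minimal illustrative sketch of that decoding contract (not the generated code):

```rust
// Illustrative miniature of the prost-generated pattern: removed variants
// leave permanent gaps in the discriminant space.
#[derive(Clone, Copy, Debug, PartialEq)]
enum MiniScalarFunction {
    ArrayConcat = 87,
    // 88 was ArrayDims: retired, never reused.
    ArrayRepeat = 89,
}

fn from_i32(value: i32) -> Option<MiniScalarFunction> {
    match value {
        87 => Some(MiniScalarFunction::ArrayConcat),
        89 => Some(MiniScalarFunction::ArrayRepeat),
        // 88 stays unmapped, so a plan serialized before the removal fails
        // loudly instead of aliasing to some other function.
        _ => None,
    }
}

fn main() {
    assert_eq!(from_i32(88), None);
    assert_eq!(from_i32(89), Some(MiniScalarFunction::ArrayRepeat));
}
```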
"ArrayElement" => Ok(ScalarFunction::ArrayElement), "ArraySlice" => Ok(ScalarFunction::ArraySlice), "Cot" => Ok(ScalarFunction::Cot), @@ -22736,10 +22723,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayDistinct" => Ok(ScalarFunction::ArrayDistinct), "ArrayResize" => Ok(ScalarFunction::ArrayResize), "EndsWith" => Ok(ScalarFunction::EndsWith), - "InStr" => Ok(ScalarFunction::InStr), "MakeDate" => Ok(ScalarFunction::MakeDate), "ArrayReverse" => Ok(ScalarFunction::ArrayReverse), - "RegexpLike" => Ok(ScalarFunction::RegexpLike), "ToChar" => Ok(ScalarFunction::ToChar), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } @@ -23381,6 +23366,9 @@ impl serde::Serialize for ScalarUdfExprNode { if !self.args.is_empty() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.ScalarUDFExprNode", len)?; if !self.fun_name.is_empty() { struct_ser.serialize_field("funName", &self.fun_name)?; @@ -23388,6 +23376,10 @@ impl serde::Serialize for ScalarUdfExprNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } struct_ser.end() } } @@ -23401,12 +23393,15 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { "fun_name", "funName", "args", + "fun_definition", + "funDefinition", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { FunName, Args, + FunDefinition, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23430,6 +23425,7 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { match value { "funName" | "fun_name" => Ok(GeneratedField::FunName), "args" => Ok(GeneratedField::Args), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -23451,6 +23447,7 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { { let mut fun_name__ = None; let mut args__ = None; + let mut fun_definition__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::FunName => { @@ -23465,11 +23462,20 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { } args__ = Some(map_.next_value()?); } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } } } Ok(ScalarUdfExprNode { fun_name: fun_name__.unwrap_or_default(), args: args__.unwrap_or_default(), + fun_definition: fun_definition__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2d21f15570dd..4d19b79a3b2c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -895,6 +895,8 @@ pub struct ScalarUdfExprNode { pub fun_name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", optional, tag = "3")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -2723,17 +2725,17 @@ pub enum ScalarFunction { Gcd = 85, ArrayAppend = 86, ArrayConcat = 87, - ArrayDims = 88, + /// 88 was ArrayDims ArrayRepeat = 89, ArrayLength = 90, - ArrayNdims = 91, + /// 91 was ArrayNdims ArrayPosition = 92, ArrayPositions = 93, ArrayPrepend = 94, ArrayRemove = 95, ArrayReplace = 96, /// 97 was ArrayToString - Cardinality = 98, + /// 98 was Cardinality ArrayElement = 99, ArraySlice = 100, Cot = 103, @@ -2765,10 +2767,11 @@ pub enum ScalarFunction { ArrayDistinct = 129, ArrayResize = 130, EndsWith = 131, - InStr = 132, + /// / 132 was InStr MakeDate = 133, ArrayReverse = 134, - RegexpLike = 135, + /// / 135 is RegexpLike + /// /// / 137 was ToDate ToChar = 136, } @@ -2859,16 +2862,13 @@ impl ScalarFunction { ScalarFunction::Gcd => "Gcd", ScalarFunction::ArrayAppend => "ArrayAppend", ScalarFunction::ArrayConcat => "ArrayConcat", - ScalarFunction::ArrayDims => "ArrayDims", ScalarFunction::ArrayRepeat => "ArrayRepeat", ScalarFunction::ArrayLength => "ArrayLength", - ScalarFunction::ArrayNdims => "ArrayNdims", ScalarFunction::ArrayPosition => "ArrayPosition", ScalarFunction::ArrayPositions => "ArrayPositions", ScalarFunction::ArrayPrepend => "ArrayPrepend", ScalarFunction::ArrayRemove => "ArrayRemove", ScalarFunction::ArrayReplace => "ArrayReplace", - ScalarFunction::Cardinality => "Cardinality", ScalarFunction::ArrayElement => "ArrayElement", ScalarFunction::ArraySlice => "ArraySlice", ScalarFunction::Cot => "Cot", @@ -2897,10 +2897,8 @@ impl ScalarFunction { ScalarFunction::ArrayDistinct => "ArrayDistinct", ScalarFunction::ArrayResize => "ArrayResize", ScalarFunction::EndsWith => "EndsWith", - ScalarFunction::InStr => "InStr", ScalarFunction::MakeDate => "MakeDate", ScalarFunction::ArrayReverse => "ArrayReverse", - ScalarFunction::RegexpLike => "RegexpLike", ScalarFunction::ToChar => "ToChar", } } @@ -2987,16 +2985,13 @@ impl ScalarFunction { "Gcd" => Some(Self::Gcd), "ArrayAppend" => Some(Self::ArrayAppend), "ArrayConcat" => Some(Self::ArrayConcat), - "ArrayDims" => Some(Self::ArrayDims), "ArrayRepeat" => Some(Self::ArrayRepeat), "ArrayLength" => Some(Self::ArrayLength), - "ArrayNdims" => Some(Self::ArrayNdims), "ArrayPosition" => Some(Self::ArrayPosition), "ArrayPositions" => Some(Self::ArrayPositions), "ArrayPrepend" => Some(Self::ArrayPrepend), "ArrayRemove" => Some(Self::ArrayRemove), 
"ArrayReplace" => Some(Self::ArrayReplace), - "Cardinality" => Some(Self::Cardinality), "ArrayElement" => Some(Self::ArrayElement), "ArraySlice" => Some(Self::ArraySlice), "Cot" => Some(Self::Cot), @@ -3025,10 +3020,8 @@ impl ScalarFunction { "ArrayDistinct" => Some(Self::ArrayDistinct), "ArrayResize" => Some(Self::ArrayResize), "EndsWith" => Some(Self::EndsWith), - "InStr" => Some(Self::InStr), "MakeDate" => Some(Self::MakeDate), "ArrayReverse" => Some(Self::ArrayReverse), - "RegexpLike" => Some(Self::RegexpLike), "ToChar" => Some(Self::ToChar), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ab7065cfbd85..aee53849c806 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -47,26 +47,26 @@ use datafusion_common::{ use datafusion_expr::expr::Unnest; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, array, array_append, array_concat, array_dims, array_distinct, array_element, - array_empty, array_except, array_has, array_has_all, array_has_any, array_intersect, - array_length, array_ndims, array_pop_back, array_pop_front, array_position, - array_positions, array_prepend, array_remove, array_remove_all, array_remove_n, - array_repeat, array_replace, array_replace_all, array_replace_n, array_resize, - array_slice, array_sort, array_union, arrow_typeof, ascii, asinh, atan, atan2, atanh, - bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, - concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, - date_part, date_trunc, degrees, digest, ends_with, exp, + acosh, array, array_append, array_concat, array_distinct, array_element, array_empty, + array_except, array_has, array_has_all, array_has_any, array_intersect, array_length, + array_pop_back, array_pop_front, array_position, array_positions, array_prepend, + array_remove, array_remove_all, array_remove_n, array_repeat, array_replace, + array_replace_all, array_replace_n, array_resize, array_slice, array_sort, + array_union, arrow_typeof, ascii, asinh, atan, atan2, atanh, bit_length, btrim, cbrt, + ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, + current_date, current_time, date_bin, date_part, date_trunc, degrees, digest, + ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, flatten, floor, from_unixtime, gcd, initcap, instr, iszero, - lcm, left, levenshtein, ln, log, log10, log2, + factorial, find_in_set, flatten, floor, from_unixtime, gcd, initcap, iszero, lcm, + left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, octet_length, overlay, pi, power, radians, - random, regexp_like, regexp_replace, repeat, replace, reverse, right, round, rpad, - rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, - starts_with, string_to_array, strpos, struct_fun, substr, substr_index, substring, - tan, tanh, to_hex, translate, trim, trunc, upper, uuid, AggregateFunction, Between, - BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, - GetFieldAccess, GetIndexedField, GroupingSet, + random, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, + sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with, + string_to_array, strpos, struct_fun, substr, substr_index, substring, tan, 
tanh, + to_hex, translate, trim, trunc, upper, uuid, AggregateFunction, Between, BinaryExpr, + BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, + GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -76,6 +76,8 @@ use datafusion_expr::{ expr::{Alias, Placeholder}, }; +use super::LogicalExtensionCodec; + #[derive(Debug)] pub enum Error { General(String), @@ -484,12 +486,10 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayHasAll => Self::ArrayHasAll, ScalarFunction::ArrayHasAny => Self::ArrayHasAny, ScalarFunction::ArrayHas => Self::ArrayHas, - ScalarFunction::ArrayDims => Self::ArrayDims, ScalarFunction::ArrayDistinct => Self::ArrayDistinct, ScalarFunction::ArrayElement => Self::ArrayElement, ScalarFunction::Flatten => Self::Flatten, ScalarFunction::ArrayLength => Self::ArrayLength, - ScalarFunction::ArrayNdims => Self::ArrayNdims, ScalarFunction::ArrayPopFront => Self::ArrayPopFront, ScalarFunction::ArrayPopBack => Self::ArrayPopBack, ScalarFunction::ArrayPosition => Self::ArrayPosition, @@ -507,7 +507,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayIntersect => Self::ArrayIntersect, ScalarFunction::ArrayUnion => Self::ArrayUnion, ScalarFunction::ArrayResize => Self::ArrayResize, - ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, ScalarFunction::DatePart => Self::DatePart, ScalarFunction::DateTrunc => Self::DateTrunc, @@ -528,11 +527,9 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, - ScalarFunction::InStr => Self::InStr, ScalarFunction::Left => Self::Left, ScalarFunction::Lpad => Self::Lpad, ScalarFunction::Random => Self::Random, - ScalarFunction::RegexpLike => Self::RegexpLike, ScalarFunction::RegexpReplace => Self::RegexpReplace, ScalarFunction::Repeat => Self::Repeat, ScalarFunction::Replace => Self::Replace, @@ -976,6 +973,7 @@ pub fn parse_i32_to_aggregate_function(value: &i32) -> Result Result { use protobuf::{logical_expr_node::ExprType, window_expr_node, ScalarFunction}; @@ -990,7 +988,7 @@ pub fn parse_expr( let operands = binary_expr .operands .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?; if operands.len() < 2 { @@ -1009,8 +1007,12 @@ pub fn parse_expr( .expect("Binary expression could not be reduced to a single expression.")) } ExprType::GetIndexedField(get_indexed_field) => { - let expr = - parse_required_expr(get_indexed_field.expr.as_deref(), registry, "expr")?; + let expr = parse_required_expr( + get_indexed_field.expr.as_deref(), + registry, + "expr", + codec, + )?; let field = match &get_indexed_field.field { Some(protobuf::get_indexed_field::Field::NamedStructField( named_struct_field, @@ -1027,6 +1029,7 @@ pub fn parse_expr( list_index.key.as_deref(), registry, "key", + codec, )?), } } @@ -1036,16 +1039,19 @@ pub fn parse_expr( list_range.start.as_deref(), registry, "start", + codec, )?), stop: Box::new(parse_required_expr( list_range.stop.as_deref(), registry, "stop", + codec, )?), stride: Box::new(parse_required_expr( list_range.stride.as_deref(), registry, "stride", + codec, )?), } } @@ -1070,12 +1076,12 @@ pub fn parse_expr( let partition_by = expr .partition_by 
.iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?; let mut order_by = expr .order_by .iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?; let window_frame = expr .window_frame @@ -1103,7 +1109,7 @@ pub fn parse_expr( datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction( aggr_function, ), - vec![parse_required_expr(expr.expr.as_deref(), registry, "expr")?], + vec![parse_required_expr(expr.expr.as_deref(), registry, "expr", codec)?], partition_by, order_by, window_frame, @@ -1115,9 +1121,10 @@ pub fn parse_expr( .map_err(|_| Error::unknown("BuiltInWindowFunction", *i))? .into(); - let args = parse_optional_expr(expr.expr.as_deref(), registry)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = + parse_optional_expr(expr.expr.as_deref(), registry, codec)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction( @@ -1132,9 +1139,10 @@ pub fn parse_expr( } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = registry.udaf(udaf_name)?; - let args = parse_optional_expr(expr.expr.as_deref(), registry)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = + parse_optional_expr(expr.expr.as_deref(), registry, codec)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF( udaf_function, @@ -1148,9 +1156,10 @@ pub fn parse_expr( } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = registry.udwf(udwf_name)?; - let args = parse_optional_expr(expr.expr.as_deref(), registry)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = + parse_optional_expr(expr.expr.as_deref(), registry, codec)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( udwf_function, @@ -1171,15 +1180,16 @@ pub fn parse_expr( fun, expr.expr .iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?, expr.distinct, - parse_optional_expr(expr.filter.as_deref(), registry)?.map(Box::new), - parse_vec_expr(&expr.order_by, registry)?, + parse_optional_expr(expr.filter.as_deref(), registry, codec)? 
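Every recursive call in this file gains a trailing `codec` argument. Reconstructed from the helper signatures near the end of the file (which take `registry: &dyn FunctionRegistry` and `codec: &dyn LogicalExtensionCodec`), the updated entry point is approximately the following sketch; the doc comment is illustrative:

    /// Parse a protobuf expression node back into an `Expr`, resolving any
    /// user-defined functions through the `FunctionRegistry` or, when the
    /// node carries an inline payload, through the extension codec.
    pub fn parse_expr(
        proto: &protobuf::LogicalExprNode,
        registry: &dyn FunctionRegistry,
        codec: &dyn LogicalExtensionCodec, // new parameter, threaded below
    ) -> Result<Expr, Error> {
        // match on proto.expr_type, passing `codec` into every recursive
        // parse_expr / parse_required_expr / parse_optional_expr call
    }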
+ .map(Box::new), + parse_vec_expr(&expr.order_by, registry, codec)?, ))) } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( - parse_required_expr(alias.expr.as_deref(), registry, "expr")?, + parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, alias .relation .first() @@ -1191,90 +1201,118 @@ pub fn parse_expr( is_null.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new( - parse_required_expr(is_not_null.expr.as_deref(), registry, "expr")?, + parse_required_expr(is_not_null.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr( not.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsTrue(msg) => Ok(Expr::IsTrue(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsFalse(msg) => Ok(Expr::IsFalse(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsUnknown(msg) => Ok(Expr::IsUnknown(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotTrue(msg) => Ok(Expr::IsNotTrue(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotFalse(msg) => Ok(Expr::IsNotFalse(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotUnknown(msg) => Ok(Expr::IsNotUnknown(Box::new( - parse_required_expr(msg.expr.as_deref(), registry, "expr")?, + parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::Between(between) => Ok(Expr::Between(Between::new( Box::new(parse_required_expr( between.expr.as_deref(), registry, "expr", + codec, )?), between.negated, Box::new(parse_required_expr( between.low.as_deref(), registry, "expr", + codec, )?), Box::new(parse_required_expr( between.high.as_deref(), registry, "expr", + codec, )?), ))), ExprType::Like(like) => Ok(Expr::Like(Like::new( like.negated, - Box::new(parse_required_expr(like.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + like.expr.as_deref(), + registry, + "expr", + codec, + )?), Box::new(parse_required_expr( like.pattern.as_deref(), registry, "pattern", + codec, )?), parse_escape_char(&like.escape_char)?, false, ))), ExprType::Ilike(like) => Ok(Expr::Like(Like::new( like.negated, - Box::new(parse_required_expr(like.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + like.expr.as_deref(), + registry, + "expr", + codec, + )?), Box::new(parse_required_expr( like.pattern.as_deref(), registry, "pattern", + codec, )?), parse_escape_char(&like.escape_char)?, true, ))), ExprType::SimilarTo(like) => Ok(Expr::SimilarTo(Like::new( like.negated, - Box::new(parse_required_expr(like.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + like.expr.as_deref(), + registry, + "expr", + codec, + )?), Box::new(parse_required_expr( like.pattern.as_deref(), registry, "pattern", + codec, )?), parse_escape_char(&like.escape_char)?, false, @@ -1284,44 +1322,66 @@ pub fn parse_expr( .when_then_expr .iter() .map(|e| { - let when_expr = - parse_required_expr(e.when_expr.as_ref(), registry, "when_expr")?; - let then_expr = - parse_required_expr(e.then_expr.as_ref(), registry, "then_expr")?; + let when_expr = parse_required_expr( + e.when_expr.as_ref(), + registry, + "when_expr", + codec, + )?; + let then_expr = parse_required_expr( + e.then_expr.as_ref(), + registry, + "then_expr", + codec, + )?; Ok((Box::new(when_expr), 
Box::new(then_expr))) }) .collect::, Box)>, Error>>()?; Ok(Expr::Case(Case::new( - parse_optional_expr(case.expr.as_deref(), registry)?.map(Box::new), + parse_optional_expr(case.expr.as_deref(), registry, codec)?.map(Box::new), when_then_expr, - parse_optional_expr(case.else_expr.as_deref(), registry)?.map(Box::new), + parse_optional_expr(case.else_expr.as_deref(), registry, codec)? + .map(Box::new), ))) } ExprType::Cast(cast) => { - let expr = - Box::new(parse_required_expr(cast.expr.as_deref(), registry, "expr")?); + let expr = Box::new(parse_required_expr( + cast.expr.as_deref(), + registry, + "expr", + codec, + )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; Ok(Expr::Cast(Cast::new(expr, data_type))) } ExprType::TryCast(cast) => { - let expr = - Box::new(parse_required_expr(cast.expr.as_deref(), registry, "expr")?); + let expr = Box::new(parse_required_expr( + cast.expr.as_deref(), + registry, + "expr", + codec, + )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; Ok(Expr::TryCast(TryCast::new(expr, data_type))) } ExprType::Sort(sort) => Ok(Expr::Sort(Sort::new( - Box::new(parse_required_expr(sort.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + sort.expr.as_deref(), + registry, + "expr", + codec, + )?), sort.asc, sort.nulls_first, ))), ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( - parse_required_expr(negative.expr.as_deref(), registry, "expr")?, + parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::Unnest(unnest) => { let exprs = unnest .exprs .iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?; Ok(Expr::Unnest(Unnest { exprs })) } @@ -1330,11 +1390,12 @@ pub fn parse_expr( in_list.expr.as_deref(), registry, "expr", + codec, )?), in_list .list .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, in_list.negated, ))), @@ -1352,330 +1413,351 @@ pub fn parse_expr( match scalar_function { ScalarFunction::Unknown => Err(proto_error("Unknown scalar function")), - ScalarFunction::Asinh => Ok(asinh(parse_expr(&args[0], registry)?)), - ScalarFunction::Acosh => Ok(acosh(parse_expr(&args[0], registry)?)), + ScalarFunction::Asinh => { + Ok(asinh(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Acosh => { + Ok(acosh(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::Array => Ok(array( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::ArrayAppend => Ok(array_append( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArraySort => Ok(array_sort( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayPopFront => { - Ok(array_pop_front(parse_expr(&args[0], registry)?)) + Ok(array_pop_front(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayPopBack => { - Ok(array_pop_back(parse_expr(&args[0], registry)?)) + Ok(array_pop_back(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayPrepend => Ok(array_prepend( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], 
registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayConcat => Ok(array_concat( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::ArrayExcept => Ok(array_except( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayHasAll => Ok(array_has_all( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayHasAny => Ok(array_has_any( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayHas => Ok(array_has( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayIntersect => Ok(array_intersect( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayPosition => Ok(array_position( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayPositions => Ok(array_positions( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayRepeat => Ok(array_repeat( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayRemove => Ok(array_remove( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayRemoveN => Ok(array_remove_n( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayRemoveAll => Ok(array_remove_all( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayReplace => Ok(array_replace( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayReplaceN => Ok(array_replace_n( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, - parse_expr(&args[3], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, + parse_expr(&args[3], registry, codec)?, )), ScalarFunction::ArrayReplaceAll => Ok(array_replace_all( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + 
parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayReverse => { - Ok(array_reverse(parse_expr(&args[0], registry)?)) + Ok(array_reverse(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArraySlice => Ok(array_slice( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, - parse_expr(&args[3], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, + parse_expr(&args[3], registry, codec)?, )), - ScalarFunction::Cardinality => { - Ok(cardinality(parse_expr(&args[0], registry)?)) - } ScalarFunction::ArrayLength => Ok(array_length( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::ArrayDims => { - Ok(array_dims(parse_expr(&args[0], registry)?)) - } ScalarFunction::ArrayDistinct => { - Ok(array_distinct(parse_expr(&args[0], registry)?)) + Ok(array_distinct(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayElement => Ok(array_element( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayEmpty => { - Ok(array_empty(parse_expr(&args[0], registry)?)) - } - ScalarFunction::ArrayNdims => { - Ok(array_ndims(parse_expr(&args[0], registry)?)) + Ok(array_empty(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayUnion => Ok(array_union( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayResize => Ok(array_resize( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), - ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)), - ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)), - ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)), - ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry)?)), - ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], registry)?)), - ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry)?)), - ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry)?)), - ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry)?)), - ScalarFunction::Tanh => Ok(tanh(parse_expr(&args[0], registry)?)), - ScalarFunction::Atanh => Ok(atanh(parse_expr(&args[0], registry)?)), - ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry)?)), - ScalarFunction::Degrees => Ok(degrees(parse_expr(&args[0], registry)?)), - ScalarFunction::Radians => Ok(radians(parse_expr(&args[0], registry)?)), - ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], registry)?)), - ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], registry)?)), - ScalarFunction::Log10 => Ok(log10(parse_expr(&args[0], registry)?)), - ScalarFunction::Floor => Ok(floor(parse_expr(&args[0], registry)?)), + ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry, codec)?)), + 
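Besides the mechanical `codec` threading, these hunks also delete the `Cardinality`, `ArrayDims`, and `ArrayNdims` arms (and, earlier, `InStr` and `RegexpLike`): those builtins are being migrated out of `BuiltinScalarFunction` into `ScalarUDF`s, so they now round-trip through the UDF path instead. A hypothetical illustration of what that means for a serialized plan:

    // Before: serialized as a builtin enum variant
    //   ScalarFunctionNode { fun: ScalarFunction::Cardinality, args }
    // After: serialized as a named UDF and resolved at decode time
    //   ScalarUdfExprNode { fun_name: "cardinality", fun_definition: None, args }
    // Decoding then falls back to registry.udf("cardinality") unless the
    // extension codec supplies the implementation via try_decode_udf.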
ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Tanh => Ok(tanh(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Atanh => { + Ok(atanh(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Degrees => { + Ok(degrees(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Radians => { + Ok(radians(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Log10 => { + Ok(log10(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Floor => { + Ok(floor(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::Factorial => { - Ok(factorial(parse_expr(&args[0], registry)?)) + Ok(factorial(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry)?)), + ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Round => Ok(round( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Trunc => Ok(trunc( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), - ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], registry)?)), + ScalarFunction::Signum => { + Ok(signum(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::OctetLength => { - Ok(octet_length(parse_expr(&args[0], registry)?)) + Ok(octet_length(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Lower => { + Ok(lower(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Upper => { + Ok(upper(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Ltrim => { + Ok(ltrim(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Rtrim => { + Ok(rtrim(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Lower => Ok(lower(parse_expr(&args[0], registry)?)), - ScalarFunction::Upper => Ok(upper(parse_expr(&args[0], registry)?)), - ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], registry)?)), - ScalarFunction::Ltrim => Ok(ltrim(parse_expr(&args[0], registry)?)), - ScalarFunction::Rtrim => Ok(rtrim(parse_expr(&args[0], registry)?)), ScalarFunction::DatePart => Ok(date_part( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::DateTrunc => Ok(date_trunc( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::DateBin => Ok(date_bin( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), - ScalarFunction::Sha224 => Ok(sha224(parse_expr(&args[0], registry)?)), - ScalarFunction::Sha256 => Ok(sha256(parse_expr(&args[0], 
registry)?)), - ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], registry)?)), - ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], registry)?)), - ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry)?)), + ScalarFunction::Sha224 => { + Ok(sha224(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Sha256 => { + Ok(sha256(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Sha384 => { + Ok(sha384(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Sha512 => { + Ok(sha512(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Digest => Ok(digest( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Ascii => Ok(ascii(parse_expr(&args[0], registry)?)), + ScalarFunction::Ascii => { + Ok(ascii(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::BitLength => { - Ok(bit_length(parse_expr(&args[0], registry)?)) + Ok(bit_length(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::CharacterLength => { - Ok(character_length(parse_expr(&args[0], registry)?)) + Ok(character_length(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::InitCap => { + Ok(initcap(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry)?)), - ScalarFunction::InitCap => Ok(initcap(parse_expr(&args[0], registry)?)), - ScalarFunction::InStr => Ok(instr( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), ScalarFunction::Gcd => Ok(gcd( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Lcm => Ok(lcm( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Left => Ok(left( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Random => Ok(random()), ScalarFunction::Uuid => Ok(uuid()), ScalarFunction::Repeat => Ok(repeat( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Replace => Ok(replace( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), - ScalarFunction::Reverse => Ok(reverse(parse_expr(&args[0], registry)?)), + ScalarFunction::Reverse => { + Ok(reverse(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::Right => Ok(right( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Concat => Ok(concat_expr( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::ConcatWithSeparator => Ok(concat_ws_expr( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, 
_>>()?, )), ScalarFunction::Lpad => Ok(lpad( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Rpad => Ok(rpad( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) - .collect::, _>>()?, - )), - ScalarFunction::RegexpLike => Ok(regexp_like( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::RegexpReplace => Ok(regexp_replace( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Btrim => Ok(btrim( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::SplitPart => Ok(split_part( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::StartsWith => Ok(starts_with( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::EndsWith => Ok(ends_with( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Strpos => Ok(strpos( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Substr => { if args.len() > 2 { assert_eq!(args.len(), 3); Ok(substring( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )) } else { Ok(substr( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )) } } ScalarFunction::Levenshtein => Ok(levenshtein( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)), + ScalarFunction::ToHex => { + Ok(to_hex(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::MakeDate => { let args: Vec<_> = args .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::>()?; Ok(Expr::ScalarFunction(expr::ScalarFunction::new( BuiltinScalarFunction::MakeDate, @@ -1685,7 +1767,7 @@ pub fn parse_expr( ScalarFunction::ToChar => { let args: Vec<_> = args .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::>()?; Ok(Expr::ScalarFunction(expr::ScalarFunction::new( BuiltinScalarFunction::ToChar, @@ -1694,75 +1776,86 @@ pub fn parse_expr( } ScalarFunction::Now => Ok(now()), ScalarFunction::Translate => Ok(translate( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::Coalesce => Ok(coalesce( 
args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Pi => Ok(pi()), ScalarFunction::Power => Ok(power( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Log => Ok(log( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::FromUnixtime => { - Ok(from_unixtime(parse_expr(&args[0], registry)?)) + Ok(from_unixtime(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::Atan2 => Ok(atan2( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::CurrentDate => Ok(current_date()), ScalarFunction::CurrentTime => Ok(current_time()), - ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry)?)), + ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Nanvl => Ok(nanvl( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Iszero => Ok(iszero(parse_expr(&args[0], registry)?)), + ScalarFunction::Iszero => { + Ok(iszero(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::ArrowTypeof => { - Ok(arrow_typeof(parse_expr(&args[0], registry)?)) + Ok(arrow_typeof(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Flatten => { + Ok(flatten(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Flatten => Ok(flatten(parse_expr(&args[0], registry)?)), ScalarFunction::StringToArray => Ok(string_to_array( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::OverLay => Ok(overlay( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::SubstrIndex => Ok(substr_index( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::FindInSet => Ok(find_in_set( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::StructFun => { - Ok(struct_fun(parse_expr(&args[0], registry)?)) + Ok(struct_fun(parse_expr(&args[0], registry, codec)?)) } } } - ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => { - let scalar_fn = registry.udf(fun_name.as_str())?; + ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { + fun_name, + args, + fun_definition, + }) => { + let scalar_fn = match fun_definition { + Some(buf) => codec.try_decode_udf(fun_name, buf)?, + None => registry.udf(fun_name.as_str())?, + }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, args.iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, ))) } @@ -1773,11 +1866,11 @@ pub fn parse_expr( agg_fn, pb.args .iter() - .map(|expr| 
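The `ScalarUdfExpr` branch just below is the heart of this change: `ScalarUdfExprNode` now carries an optional `fun_definition` payload, and decoding prefers the codec whenever that payload is present. Restating the logic from the hunk, with comments:

    let scalar_fn = match fun_definition {
        // An inline payload was serialized: let the extension codec rebuild
        // the ScalarUDF (e.g. by deserializing captured state).
        Some(buf) => codec.try_decode_udf(fun_name, buf)?,
        // No payload: look the function up by name in the session's
        // FunctionRegistry, as before.
        None => registry.udf(fun_name.as_str())?,
    };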
parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, false, - parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new), - parse_vec_expr(&pb.order_by, registry)?, + parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), + parse_vec_expr(&pb.order_by, registry, codec)?, ))) } @@ -1788,7 +1881,7 @@ pub fn parse_expr( expr_list .expr .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>() }) .collect::, Error>>()?, @@ -1796,13 +1889,13 @@ pub fn parse_expr( } ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( expr.iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, ))), ExprType::Rollup(RollupNode { expr }) => { Ok(Expr::GroupingSet(GroupingSet::Rollup( expr.iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, ))) } @@ -1870,10 +1963,13 @@ pub fn from_proto_binary_op(op: &str) -> Result { fn parse_vec_expr( p: &[protobuf::LogicalExprNode], registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, ) -> Result>, Error> { let res = p .iter() - .map(|elem| parse_expr(elem, registry).map_err(|e| plan_datafusion_err!("{}", e))) + .map(|elem| { + parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) + }) .collect::>>()?; // Convert empty vector to None. Ok((!res.is_empty()).then_some(res)) @@ -1882,9 +1978,10 @@ fn parse_vec_expr( fn parse_optional_expr( p: Option<&protobuf::LogicalExprNode>, registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, ) -> Result, Error> { match p { - Some(expr) => parse_expr(expr, registry).map(Some), + Some(expr) => parse_expr(expr, registry, codec).map(Some), None => Ok(None), } } @@ -1893,9 +1990,10 @@ fn parse_required_expr( p: Option<&protobuf::LogicalExprNode>, registry: &dyn FunctionRegistry, field: impl Into, + codec: &dyn LogicalExtensionCodec, ) -> Result { match p { - Some(expr) => parse_expr(expr, registry), + Some(expr) => parse_expr(expr, registry, codec), None => Err(Error::required(field)), } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index f107af757a71..7c9ead27e3b5 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -17,6 +17,7 @@ use arrow::csv::WriterBuilder; use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; +use datafusion_expr::ScalarUDF; use std::collections::HashMap; use std::fmt::Debug; use std::str::FromStr; @@ -72,6 +73,8 @@ use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; use prost::Message; +use self::to_proto::serialize_expr; + pub mod from_proto; pub mod to_proto; @@ -133,6 +136,14 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { node: Arc, buf: &mut Vec, ) -> Result<()>; + + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug, Clone)] @@ -241,7 +252,9 @@ impl AsLogicalPlan for LogicalPlanNode { .chunks_exact(n_cols) .map(|r| { r.iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| { + from_proto::parse_expr(expr, ctx, extension_codec) + }) .collect::, from_proto::Error>>() }) .collect::, _>>() @@ -255,7 +268,7 @@ 
impl AsLogicalPlan for LogicalPlanNode { let expr: Vec = projection .expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let new_proj = project(input, expr)?; @@ -277,7 +290,7 @@ impl AsLogicalPlan for LogicalPlanNode { let expr: Expr = selection .expr .as_ref() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .transpose()? .ok_or_else(|| { DataFusionError::Internal("expression required".to_string()) @@ -291,7 +304,7 @@ impl AsLogicalPlan for LogicalPlanNode { let window_expr = window .window_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; LogicalPlanBuilder::from(input).window(window_expr)?.build() } @@ -301,12 +314,12 @@ impl AsLogicalPlan for LogicalPlanNode { let group_expr = aggregate .group_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let aggr_expr = aggregate .aggr_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; LogicalPlanBuilder::from(input) .aggregate(group_expr, aggr_expr)? @@ -328,7 +341,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filters = scan .filters .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let mut all_sort_orders = vec![]; @@ -336,7 +349,7 @@ impl AsLogicalPlan for LogicalPlanNode { let file_sort_order = order .logical_expr_nodes .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; all_sort_orders.push(file_sort_order) } @@ -436,7 +449,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filters = scan .filters .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let provider = extension_codec.try_decode_table_provider( &scan.custom_table_data, @@ -461,7 +474,7 @@ impl AsLogicalPlan for LogicalPlanNode { let sort_expr: Vec = sort .expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; LogicalPlanBuilder::from(input).sort(sort_expr)?.build() } @@ -483,7 +496,9 @@ impl AsLogicalPlan for LogicalPlanNode { }) => Partitioning::Hash( pb_hash_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| { + from_proto::parse_expr(expr, ctx, extension_codec) + }) .collect::, _>>()?, *partition_count as usize, ), @@ -527,7 +542,7 @@ impl AsLogicalPlan for LogicalPlanNode { let order_expr = expr .logical_expr_nodes .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; order_exprs.push(order_expr) } @@ -535,7 +550,7 @@ impl AsLogicalPlan for LogicalPlanNode { let mut column_defaults = HashMap::with_capacity(create_extern_table.column_defaults.len()); for (col_name, expr) in &create_extern_table.column_defaults { - let expr = from_proto::parse_expr(expr, ctx)?; + let expr = from_proto::parse_expr(expr, ctx, extension_codec)?; column_defaults.insert(col_name.clone(), expr); } @@ -663,12 +678,12 @@ impl AsLogicalPlan for LogicalPlanNode { let left_keys: Vec = 
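Together, the two new default methods on `LogicalExtensionCodec` (in the mod.rs hunk above) let a codec round-trip UDFs that the destination's `FunctionRegistry` does not know about. A minimal sketch of a custom codec follows; `make_my_udf` and the payload bytes are hypothetical stand-ins, and the trait's other required methods (extension nodes, table providers) are elided:

    use std::sync::Arc;
    use datafusion_common::{not_impl_err, Result};
    use datafusion_expr::ScalarUDF;
    use datafusion_proto::logical_plan::LogicalExtensionCodec;

    #[derive(Debug)]
    struct MyExtensionCodec;

    impl LogicalExtensionCodec for MyExtensionCodec {
        // ... try_decode / try_encode and the table-provider methods elided ...

        fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
            if name == "my_udf" {
                // Hypothetical constructor that rebuilds the UDF from its payload.
                Ok(Arc::new(make_my_udf(buf)?))
            } else {
                not_impl_err!("codec cannot decode scalar function {name}")
            }
        }

        fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
            if node.name() == "my_udf" {
                // Hypothetical payload; a real codec would serialize the UDF's state.
                buf.extend_from_slice(b"my_udf_state_v1");
            }
            // Leaving `buf` empty means "serialize by name only"; the decode
            // side then falls back to the FunctionRegistry.
            Ok(())
        }
    }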
join .left_join_key .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let right_keys: Vec = join .right_join_key .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { @@ -689,7 +704,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filter: Option = join .filter .as_ref() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .map_or(Ok(None), |v| v.map(Some))?; let builder = LogicalPlanBuilder::from(into_logical_plan!( @@ -769,12 +784,12 @@ impl AsLogicalPlan for LogicalPlanNode { let on_expr = distinct_on .on_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let select_expr = distinct_on .select_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let sort_expr = match distinct_on.sort_expr.len() { 0 => None, @@ -782,7 +797,9 @@ impl AsLogicalPlan for LogicalPlanNode { distinct_on .sort_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| { + from_proto::parse_expr(expr, ctx, extension_codec) + }) .collect::, _>>()?, ), }; @@ -944,7 +961,7 @@ impl AsLogicalPlan for LogicalPlanNode { let values_list = values .iter() .flatten() - .map(|v| v.try_into()) + .map(|v| serialize_expr(v, extension_codec)) .collect::, _>>()?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Values( @@ -982,7 +999,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filters: Vec = filters .iter() - .map(|filter| filter.try_into()) + .map(|filter| serialize_expr(filter, extension_codec)) .collect::, _>>()?; if let Some(listing_table) = source.downcast_ref::() { @@ -1039,7 +1056,7 @@ impl AsLogicalPlan for LogicalPlanNode { let expr_vec = LogicalExprNodeCollection { logical_expr_nodes: order .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?, }; exprs_vec.push(expr_vec); @@ -1120,7 +1137,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), expr: expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?, optional_alias: None, }, @@ -1137,7 +1154,10 @@ impl AsLogicalPlan for LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), - expr: Some((&filter.predicate).try_into()?), + expr: Some(serialize_expr( + &filter.predicate, + extension_codec, + )?), }, ))), }) @@ -1172,7 +1192,7 @@ impl AsLogicalPlan for LogicalPlanNode { None => vec![], Some(sort_expr) => sort_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, }; Ok(protobuf::LogicalPlanNode { @@ -1180,11 +1200,11 @@ impl AsLogicalPlan for LogicalPlanNode { protobuf::DistinctOnNode { on_expr: on_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, select_expr: select_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, sort_expr, input: Some(Box::new(input)), @@ -1206,7 +1226,7 @@ impl AsLogicalPlan for 
LogicalPlanNode { input: Some(Box::new(input)), window_expr: window_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, }, ))), @@ -1229,11 +1249,11 @@ impl AsLogicalPlan for LogicalPlanNode { input: Some(Box::new(input)), group_expr: group_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, aggr_expr: aggr_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, }, ))), @@ -1261,7 +1281,12 @@ impl AsLogicalPlan for LogicalPlanNode { )?; let (left_join_key, right_join_key) = on .iter() - .map(|(l, r)| Ok((l.try_into()?, r.try_into()?))) + .map(|(l, r)| { + Ok(( + serialize_expr(l, extension_codec)?, + serialize_expr(r, extension_codec)?, + )) + }) .collect::, to_proto::Error>>()? .into_iter() .unzip(); @@ -1270,7 +1295,7 @@ impl AsLogicalPlan for LogicalPlanNode { join_constraint.to_owned().into(); let filter = filter .as_ref() - .map(|e| e.try_into()) + .map(|e| serialize_expr(e, extension_codec)) .map_or(Ok(None), |v| v.map(Some))?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( @@ -1329,7 +1354,7 @@ impl AsLogicalPlan for LogicalPlanNode { )?; let selection_expr: Vec = expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Sort(Box::new( @@ -1361,7 +1386,7 @@ impl AsLogicalPlan for LogicalPlanNode { PartitionMethod::Hash(protobuf::HashRepartition { hash_expr: exprs .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?, partition_count: *partition_count as u64, }) @@ -1416,9 +1441,8 @@ impl AsLogicalPlan for LogicalPlanNode { let temp = LogicalExprNodeCollection { logical_expr_nodes: order .iter() - .map(|expr| expr.try_into()) - .collect::, to_proto::Error>>( - )?, + .map(|expr| serialize_expr(expr, extension_codec)) + .collect::, to_proto::Error>>()?, }; converted_order_exprs.push(temp); } @@ -1426,7 +1450,8 @@ impl AsLogicalPlan for LogicalPlanNode { let mut converted_column_defaults = HashMap::with_capacity(column_defaults.len()); for (col_name, expr) in column_defaults { - converted_column_defaults.insert(col_name.clone(), expr.try_into()?); + converted_column_defaults + .insert(col_name.clone(), serialize_expr(expr, extension_codec)?); } let file_compression_type = diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index c913119ff9ed..a4e9fd423bbf 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -56,6 +56,8 @@ use datafusion_expr::{ TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use super::LogicalExtensionCodec; + #[derive(Debug)] pub enum Error { General(String), @@ -480,615 +482,612 @@ impl TryFrom<&WindowFrame> for protobuf::WindowFrame { } } -impl TryFrom<&Expr> for protobuf::LogicalExprNode { - type Error = Error; +pub fn serialize_expr( + expr: &Expr, + codec: &dyn LogicalExtensionCodec, +) -> Result { + use protobuf::logical_expr_node::ExprType; + + let expr_node = match expr { + Expr::Column(c) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Column(c.into())), + }, + Expr::Alias(Alias { + expr, + relation, + name, + }) => { + let alias = 
Box::new(protobuf::AliasNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + relation: relation + .to_owned() + .map(|r| vec![r.into()]) + .unwrap_or(vec![]), + alias: name.to_owned(), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Alias(alias)), + } + } + Expr::Literal(value) => { + let pb_value: protobuf::ScalarValue = value.try_into()?; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Literal(pb_value)), + } + } + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + // Try to linerize a nested binary expression tree of the same operator + // into a flat vector of expressions. + let mut exprs = vec![right.as_ref()]; + let mut current_expr = left.as_ref(); + while let Expr::BinaryExpr(BinaryExpr { + left, + op: current_op, + right, + }) = current_expr + { + if current_op == op { + exprs.push(right.as_ref()); + current_expr = left.as_ref(); + } else { + break; + } + } + exprs.push(current_expr); - fn try_from(expr: &Expr) -> Result { - use protobuf::logical_expr_node::ExprType; + let binary_expr = protobuf::BinaryExprNode { + // We need to reverse exprs since operands are expected to be + // linearized from left innermost to right outermost (but while + // traversing the chain we do the exact opposite). + operands: exprs + .into_iter() + .rev() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + op: format!("{op:?}"), + }; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::BinaryExpr(binary_expr)), + } + } + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => { + if *case_insensitive { + let pb = Box::new(protobuf::ILikeNode { + negated: *negated, + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + pattern: Some(Box::new(serialize_expr(pattern.as_ref(), codec)?)), + escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), + }); - let expr_node = match expr { - Expr::Column(c) => Self { - expr_type: Some(ExprType::Column(c.into())), - }, - Expr::Alias(Alias { - expr, - relation, - name, - }) => { - let alias = Box::new(protobuf::AliasNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - relation: relation - .to_owned() - .map(|r| vec![r.into()]) - .unwrap_or(vec![]), - alias: name.to_owned(), + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Ilike(pb)), + } + } else { + let pb = Box::new(protobuf::LikeNode { + negated: *negated, + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + pattern: Some(Box::new(serialize_expr(pattern.as_ref(), codec)?)), + escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), }); - Self { - expr_type: Some(ExprType::Alias(alias)), + + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Like(pb)), } } - Expr::Literal(value) => { - let pb_value: protobuf::ScalarValue = value.try_into()?; - Self { - expr_type: Some(ExprType::Literal(pb_value)), - } + } + Expr::SimilarTo(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }) => { + let pb = Box::new(protobuf::SimilarToNode { + negated: *negated, + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + pattern: Some(Box::new(serialize_expr(pattern.as_ref(), codec)?)), + escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::SimilarTo(pb)), } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - // Try to linerize a nested binary expression tree of the same operator - // into a flat vector of expressions. 
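Because `TryFrom::try_from` takes a single argument, the old `impl TryFrom<&Expr> for protobuf::LogicalExprNode` had no way to accept a codec; replacing it with the free function `serialize_expr` is what makes the threading possible. Call sites migrate as in this sketch, where `expr`, `ctx`, and `codec` stand in for whatever is in scope:

    // Before: let node: protobuf::LogicalExprNode = (&expr).try_into()?;
    let node = serialize_expr(&expr, &codec)?;

    // And the matching decode side from from_proto.rs:
    let round_tripped = from_proto::parse_expr(&node, &ctx, &codec)?;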
- let mut exprs = vec![right.as_ref()]; - let mut current_expr = left.as_ref(); - while let Expr::BinaryExpr(BinaryExpr { - left, - op: current_op, - right, - }) = current_expr - { - if current_op == op { - exprs.push(right.as_ref()); - current_expr = left.as_ref(); - } else { - break; - } + } + Expr::WindowFunction(expr::WindowFunction { + ref fun, + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, + }) => { + let window_function = match fun { + WindowFunctionDefinition::AggregateFunction(fun) => { + protobuf::window_expr_node::WindowFunction::AggrFunction( + protobuf::AggregateFunction::from(fun).into(), + ) } - exprs.push(current_expr); - - let binary_expr = protobuf::BinaryExprNode { - // We need to reverse exprs since operands are expected to be - // linearized from left innermost to right outermost (but while - // traversing the chain we do the exact opposite). - operands: exprs - .into_iter() - .rev() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - op: format!("{op:?}"), - }; - Self { - expr_type: Some(ExprType::BinaryExpr(binary_expr)), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + protobuf::window_expr_node::WindowFunction::BuiltInFunction( + protobuf::BuiltInWindowFunction::from(fun).into(), + ) } - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { - if *case_insensitive { - let pb = Box::new(protobuf::ILikeNode { - negated: *negated, - expr: Some(Box::new(expr.as_ref().try_into()?)), - pattern: Some(Box::new(pattern.as_ref().try_into()?)), - escape_char: escape_char - .map(|ch| ch.to_string()) - .unwrap_or_default(), - }); - - Self { - expr_type: Some(ExprType::Ilike(pb)), - } - } else { - let pb = Box::new(protobuf::LikeNode { - negated: *negated, - expr: Some(Box::new(expr.as_ref().try_into()?)), - pattern: Some(Box::new(pattern.as_ref().try_into()?)), - escape_char: escape_char - .map(|ch| ch.to_string()) - .unwrap_or_default(), - }); - - Self { - expr_type: Some(ExprType::Like(pb)), - } + WindowFunctionDefinition::AggregateUDF(aggr_udf) => { + protobuf::window_expr_node::WindowFunction::Udaf( + aggr_udf.name().to_string(), + ) } - } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => { - let pb = Box::new(protobuf::SimilarToNode { - negated: *negated, - expr: Some(Box::new(expr.as_ref().try_into()?)), - pattern: Some(Box::new(pattern.as_ref().try_into()?)), - escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), - }); - Self { - expr_type: Some(ExprType::SimilarTo(pb)), + WindowFunctionDefinition::WindowUDF(window_udf) => { + protobuf::window_expr_node::WindowFunction::Udwf( + window_udf.name().to_string(), + ) } + }; + let arg_expr: Option> = if !args.is_empty() { + let arg = &args[0]; + Some(Box::new(serialize_expr(arg, codec)?)) + } else { + None + }; + let partition_by = partition_by + .iter() + .map(|e| serialize_expr(e, codec)) + .collect::, _>>()?; + let order_by = order_by + .iter() + .map(|e| serialize_expr(e, codec)) + .collect::, _>>()?; + + let window_frame: Option = + Some(window_frame.try_into()?); + let window_expr = Box::new(protobuf::WindowExprNode { + expr: arg_expr, + window_function: Some(window_function), + partition_by, + order_by, + window_frame, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::WindowExpr(window_expr)), } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - ref args, - ref partition_by, - ref 
order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }) => { - let window_function = match fun { - WindowFunctionDefinition::AggregateFunction(fun) => { - protobuf::window_expr_node::WindowFunction::AggrFunction( - protobuf::AggregateFunction::from(fun).into(), - ) + } + Expr::AggregateFunction(expr::AggregateFunction { + ref func_def, + ref args, + ref distinct, + ref filter, + ref order_by, + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let aggr_function = match fun { + AggregateFunction::ApproxDistinct => { + protobuf::AggregateFunction::ApproxDistinct + } + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont } - WindowFunctionDefinition::BuiltInWindowFunction(fun) => { - protobuf::window_expr_node::WindowFunction::BuiltInFunction( - protobuf::BuiltInWindowFunction::from(fun).into(), - ) + AggregateFunction::ApproxPercentileContWithWeight => { + protobuf::AggregateFunction::ApproxPercentileContWithWeight } - WindowFunctionDefinition::AggregateUDF(aggr_udf) => { - protobuf::window_expr_node::WindowFunction::Udaf( - aggr_udf.name().to_string(), - ) + AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, + AggregateFunction::Min => protobuf::AggregateFunction::Min, + AggregateFunction::Max => protobuf::AggregateFunction::Max, + AggregateFunction::Sum => protobuf::AggregateFunction::Sum, + AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, + AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, + AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, + AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, + AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, + AggregateFunction::Avg => protobuf::AggregateFunction::Avg, + AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop } - WindowFunctionDefinition::WindowUDF(window_udf) => { - protobuf::window_expr_node::WindowFunction::Udwf( - window_udf.name().to_string(), - ) + AggregateFunction::Covariance => { + protobuf::AggregateFunction::Covariance + } + AggregateFunction::CovariancePop => { + protobuf::AggregateFunction::CovariancePop + } + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } + AggregateFunction::Correlation => { + protobuf::AggregateFunction::Correlation + } + AggregateFunction::RegrSlope => { + protobuf::AggregateFunction::RegrSlope + } + AggregateFunction::RegrIntercept => { + protobuf::AggregateFunction::RegrIntercept + } + AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, + AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, + AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, + AggregateFunction::RegrCount => { + protobuf::AggregateFunction::RegrCount + } + AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, + AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, + AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, + AggregateFunction::ApproxMedian => { + protobuf::AggregateFunction::ApproxMedian + } + AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, + AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::FirstValue => { + 
protobuf::AggregateFunction::FirstValueAgg + } + AggregateFunction::LastValue => { + protobuf::AggregateFunction::LastValueAgg + } + AggregateFunction::NthValue => { + protobuf::AggregateFunction::NthValueAgg + } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg } }; - let arg_expr: Option> = if !args.is_empty() { - let arg = &args[0]; - Some(Box::new(arg.try_into()?)) - } else { - None + + let aggregate_expr = protobuf::AggregateExprNode { + aggr_function: aggr_function.into(), + expr: args + .iter() + .map(|v| serialize_expr(v, codec)) + .collect::, _>>()?, + distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(serialize_expr(e, codec)?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, _>>()?, + None => vec![], + }, }; - let partition_by = partition_by - .iter() - .map(|e| e.try_into()) - .collect::, _>>()?; - let order_by = order_by - .iter() - .map(|e| e.try_into()) - .collect::, _>>()?; - - let window_frame: Option = - Some(window_frame.try_into()?); - let window_expr = Box::new(protobuf::WindowExprNode { - expr: arg_expr, - window_function: Some(window_function), - partition_by, - order_by, - window_frame, - }); - Self { - expr_type: Some(ExprType::WindowExpr(window_expr)), + protobuf::LogicalExprNode { + expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), } } - Expr::AggregateFunction(expr::AggregateFunction { - ref func_def, - ref args, - ref distinct, - ref filter, - ref order_by, - }) => { - match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let aggr_function = match fun { - AggregateFunction::ApproxDistinct => { - protobuf::AggregateFunction::ApproxDistinct - } - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Sum => protobuf::AggregateFunction::Sum, - AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, - AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, - AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - AggregateFunction::VariancePop => { - protobuf::AggregateFunction::VariancePop - } - AggregateFunction::Covariance => { - protobuf::AggregateFunction::Covariance - } - AggregateFunction::CovariancePop => { - protobuf::AggregateFunction::CovariancePop - } - AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => { - protobuf::AggregateFunction::StddevPop - } - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - AggregateFunction::RegrAvgx => 
protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, - AggregateFunction::ApproxMedian => { - protobuf::AggregateFunction::ApproxMedian - } - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, - AggregateFunction::FirstValue => { - protobuf::AggregateFunction::FirstValueAgg - } - AggregateFunction::LastValue => { - protobuf::AggregateFunction::LastValueAgg - } - AggregateFunction::NthValue => { - protobuf::AggregateFunction::NthValueAgg - } - AggregateFunction::StringAgg => { - protobuf::AggregateFunction::StringAgg - } - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: args + AggregateFunctionDefinition::UDF(fun) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::AggregateUdfExpr(Box::new( + protobuf::AggregateUdfExprNode { + fun_name: fun.name().to_string(), + args: args + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + filter: match filter { + Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), + None => None, + }, + order_by: match order_by { + Some(e) => e .iter() - .map(|v| v.try_into()) + .map(|expr| serialize_expr(expr, codec)) .collect::, _>>()?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }; - Self { - expr_type: Some(ExprType::AggregateExpr(Box::new( - aggregate_expr, - ))), - } - } - AggregateFunctionDefinition::UDF(fun) => Self { - expr_type: Some(ExprType::AggregateUdfExpr(Box::new( - protobuf::AggregateUdfExprNode { - fun_name: fun.name().to_string(), - args: args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }, - ))), + None => vec![], + }, }, - AggregateFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( + ))), + }, + AggregateFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( "Proto serialization error: Trying to serialize a unresolved function" .to_string(), )); - } - } } + }, - Expr::ScalarVariable(_, _) => { - return Err(Error::General( - "Proto serialization error: Scalar Variable not supported" - .to_string(), - )) - } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - let args = args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?; - match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { - let fun: protobuf::ScalarFunction = fun.try_into()?; - Self { - expr_type: Some(ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { - fun: fun.into(), - args, - }, - )), - } + Expr::ScalarVariable(_, _) => { + return Err(Error::General( + "Proto serialization error: Scalar Variable not supported".to_string(), + )) + } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let args = args + .iter() + 
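// every argument is run through the codec-aware serializer, so nested
// UDF calls keep any custom payload the codec attaches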
.map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let fun: protobuf::ScalarFunction = fun.try_into()?; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarFunction( + protobuf::ScalarFunctionNode { + fun: fun.into(), + args, + }, + )), } - ScalarFunctionDefinition::UDF(fun) => Self { + } + ScalarFunctionDefinition::UDF(fun) => { + let mut buf = Vec::new(); + let _ = codec.try_encode_udf(fun.as_ref(), &mut buf); + + let fun_definition = if buf.is_empty() { None } else { Some(buf) }; + + protobuf::LogicalExprNode { expr_type: Some(ExprType::ScalarUdfExpr( protobuf::ScalarUdfExprNode { fun_name: fun.name().to_string(), + fun_definition, args, }, )), - }, - ScalarFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( + } + } + ScalarFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( "Proto serialization error: Trying to serialize a unresolved function" .to_string(), )); - } } } - Expr::Not(expr) => { - let expr = Box::new(protobuf::Not { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::NotExpr(expr)), - } + } + Expr::Not(expr) => { + let expr = Box::new(protobuf::Not { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::NotExpr(expr)), } - Expr::IsNull(expr) => { - let expr = Box::new(protobuf::IsNull { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNullExpr(expr)), - } + } + Expr::IsNull(expr) => { + let expr = Box::new(protobuf::IsNull { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNullExpr(expr)), } - Expr::IsNotNull(expr) => { - let expr = Box::new(protobuf::IsNotNull { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotNullExpr(expr)), - } + } + Expr::IsNotNull(expr) => { + let expr = Box::new(protobuf::IsNotNull { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotNullExpr(expr)), } - Expr::IsTrue(expr) => { - let expr = Box::new(protobuf::IsTrue { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsTrue(expr)), - } + } + Expr::IsTrue(expr) => { + let expr = Box::new(protobuf::IsTrue { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsTrue(expr)), } - Expr::IsFalse(expr) => { - let expr = Box::new(protobuf::IsFalse { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsFalse(expr)), - } + } + Expr::IsFalse(expr) => { + let expr = Box::new(protobuf::IsFalse { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsFalse(expr)), } - Expr::IsUnknown(expr) => { - let expr = Box::new(protobuf::IsUnknown { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsUnknown(expr)), - } + } + Expr::IsUnknown(expr) => { + let expr = Box::new(protobuf::IsUnknown { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsUnknown(expr)), } - Expr::IsNotTrue(expr) => { - let expr = Box::new(protobuf::IsNotTrue { - expr: 
Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotTrue(expr)), - } + } + Expr::IsNotTrue(expr) => { + let expr = Box::new(protobuf::IsNotTrue { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotTrue(expr)), } - Expr::IsNotFalse(expr) => { - let expr = Box::new(protobuf::IsNotFalse { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotFalse(expr)), - } + } + Expr::IsNotFalse(expr) => { + let expr = Box::new(protobuf::IsNotFalse { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotFalse(expr)), } - Expr::IsNotUnknown(expr) => { - let expr = Box::new(protobuf::IsNotUnknown { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotUnknown(expr)), - } + } + Expr::IsNotUnknown(expr) => { + let expr = Box::new(protobuf::IsNotUnknown { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotUnknown(expr)), } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - let expr = Box::new(protobuf::BetweenNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - negated: *negated, - low: Some(Box::new(low.as_ref().try_into()?)), - high: Some(Box::new(high.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::Between(expr)), - } + } + Expr::Between(Between { + expr, + negated, + low, + high, + }) => { + let expr = Box::new(protobuf::BetweenNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + negated: *negated, + low: Some(Box::new(serialize_expr(low.as_ref(), codec)?)), + high: Some(Box::new(serialize_expr(high.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Between(expr)), } - Expr::Case(case) => { - let when_then_expr = case - .when_then_expr - .iter() - .map(|(w, t)| { - Ok(protobuf::WhenThen { - when_expr: Some(w.as_ref().try_into()?), - then_expr: Some(t.as_ref().try_into()?), - }) + } + Expr::Case(case) => { + let when_then_expr = case + .when_then_expr + .iter() + .map(|(w, t)| { + Ok(protobuf::WhenThen { + when_expr: Some(serialize_expr(w.as_ref(), codec)?), + then_expr: Some(serialize_expr(t.as_ref(), codec)?), }) - .collect::, Error>>()?; - let expr = Box::new(protobuf::CaseNode { - expr: match &case.expr { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - when_then_expr, - else_expr: match &case.else_expr { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - }); - Self { - expr_type: Some(ExprType::Case(expr)), - } + }) + .collect::, Error>>()?; + let expr = Box::new(protobuf::CaseNode { + expr: match &case.expr { + Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), + None => None, + }, + when_then_expr, + else_expr: match &case.else_expr { + Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), + None => None, + }, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Case(expr)), } - Expr::Cast(Cast { expr, data_type }) => { - let expr = Box::new(protobuf::CastNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - arrow_type: Some(data_type.try_into()?), - }); - Self { - expr_type: Some(ExprType::Cast(expr)), - } + } + Expr::Cast(Cast { expr, data_type }) => { + let expr = Box::new(protobuf::CastNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), 
codec)?)), + arrow_type: Some(data_type.try_into()?), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Cast(expr)), } - Expr::TryCast(TryCast { expr, data_type }) => { - let expr = Box::new(protobuf::TryCastNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - arrow_type: Some(data_type.try_into()?), - }); - Self { - expr_type: Some(ExprType::TryCast(expr)), - } + } + Expr::TryCast(TryCast { expr, data_type }) => { + let expr = Box::new(protobuf::TryCastNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + arrow_type: Some(data_type.try_into()?), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::TryCast(expr)), } - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let expr = Box::new(protobuf::SortExprNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - asc: *asc, - nulls_first: *nulls_first, - }); - Self { - expr_type: Some(ExprType::Sort(expr)), - } + } + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => { + let expr = Box::new(protobuf::SortExprNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + asc: *asc, + nulls_first: *nulls_first, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Sort(expr)), } - Expr::Negative(expr) => { - let expr = Box::new(protobuf::NegativeNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::Negative(expr)), - } + } + Expr::Negative(expr) => { + let expr = Box::new(protobuf::NegativeNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Negative(expr)), } - Expr::Unnest(Unnest { exprs }) => { - let expr = protobuf::Unnest { - exprs: exprs.iter().map(|expr| expr.try_into()).collect::, - Error, - >>( - )?, - }; - Self { - expr_type: Some(ExprType::Unnest(expr)), - } + } + Expr::Unnest(Unnest { exprs }) => { + let expr = protobuf::Unnest { + exprs: exprs + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + }; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Unnest(expr)), } - Expr::InList(InList { - expr, - list, - negated, - }) => { - let expr = Box::new(protobuf::InListNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - list: list - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - negated: *negated, - }); - Self { - expr_type: Some(ExprType::InList(expr)), - } + } + Expr::InList(InList { + expr, + list, + negated, + }) => { + let expr = Box::new(protobuf::InListNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + list: list + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + negated: *negated, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::InList(expr)), } - Expr::Wildcard { qualifier } => Self { - expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { - qualifier: qualifier.clone().unwrap_or("".to_string()), - })), - }, - Expr::ScalarSubquery(_) - | Expr::InSubquery(_) - | Expr::Exists { .. } - | Expr::OuterReferenceColumn { .. } => { - // we would need to add logical plan operators to datafusion.proto to support this - // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. 
} | Exp:OuterReferenceColumn not supported".to_string())); - } - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let field = match field { - GetFieldAccess::NamedStructField { name } => { - protobuf::get_indexed_field::Field::NamedStructField( - protobuf::NamedStructField { - name: Some(name.try_into()?), - }, - ) - } - GetFieldAccess::ListIndex { key } => { - protobuf::get_indexed_field::Field::ListIndex(Box::new( - protobuf::ListIndex { - key: Some(Box::new(key.as_ref().try_into()?)), - }, - )) - } - GetFieldAccess::ListRange { - start, - stop, - stride, - } => protobuf::get_indexed_field::Field::ListRange(Box::new( - protobuf::ListRange { - start: Some(Box::new(start.as_ref().try_into()?)), - stop: Some(Box::new(stop.as_ref().try_into()?)), - stride: Some(Box::new(stride.as_ref().try_into()?)), + } + Expr::Wildcard { qualifier } => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { + qualifier: qualifier.clone().unwrap_or("".to_string()), + })), + }, + Expr::ScalarSubquery(_) + | Expr::InSubquery(_) + | Expr::Exists { .. } + | Expr::OuterReferenceColumn { .. } => { + // we would need to add logical plan operators to datafusion.proto to support this + // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 + return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); + } + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + let field = match field { + GetFieldAccess::NamedStructField { name } => { + protobuf::get_indexed_field::Field::NamedStructField( + protobuf::NamedStructField { + name: Some(name.try_into()?), }, - )), - }; - - Self { - expr_type: Some(ExprType::GetIndexedField(Box::new( - protobuf::GetIndexedField { - expr: Some(Box::new(expr.as_ref().try_into()?)), - field: Some(field), + ) + } + GetFieldAccess::ListIndex { key } => { + protobuf::get_indexed_field::Field::ListIndex(Box::new( + protobuf::ListIndex { + key: Some(Box::new(serialize_expr(key.as_ref(), codec)?)), }, - ))), + )) } + GetFieldAccess::ListRange { + start, + stop, + stride, + } => protobuf::get_indexed_field::Field::ListRange(Box::new( + protobuf::ListRange { + start: Some(Box::new(serialize_expr(start.as_ref(), codec)?)), + stop: Some(Box::new(serialize_expr(stop.as_ref(), codec)?)), + stride: Some(Box::new(serialize_expr(stride.as_ref(), codec)?)), + }, + )), + }; + + protobuf::LogicalExprNode { + expr_type: Some(ExprType::GetIndexedField(Box::new( + protobuf::GetIndexedField { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + field: Some(field), + }, + ))), } + } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => Self { - expr_type: Some(ExprType::Cube(CubeNode { - expr: exprs.iter().map(|expr| expr.try_into()).collect::, - Self::Error, - >>( - )?, - })), - }, - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => Self { - expr_type: Some(ExprType::Rollup(RollupNode { - expr: exprs.iter().map(|expr| expr.try_into()).collect::, - Self::Error, - >>( - )?, - })), - }, - Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => Self { + Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Cube(CubeNode { + expr: exprs + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + })), + }, + Expr::GroupingSet(GroupingSet::Rollup(exprs)) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Rollup(RollupNode { + expr: exprs + .iter() + 
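// grouping-set members (CUBE / ROLLUP / GROUPING SETS alike) reuse the
// same codec-aware expression serializer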
.map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + })), + }, + Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => { + protobuf::LogicalExprNode { expr_type: Some(ExprType::GroupingSet(GroupingSetNode { expr: exprs .iter() @@ -1096,29 +1095,29 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Ok(LogicalExprList { expr: expr_list .iter() - .map(|expr| expr.try_into()) - .collect::, Self::Error>>()?, + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, }) }) - .collect::, Self::Error>>()?, + .collect::, Error>>()?, })), - }, - Expr::Placeholder(Placeholder { id, data_type }) => { - let data_type = match data_type { - Some(data_type) => Some(data_type.try_into()?), - None => None, - }; - Self { - expr_type: Some(ExprType::Placeholder(PlaceholderNode { - id: id.clone(), - data_type, - })), - } } - }; + } + Expr::Placeholder(Placeholder { id, data_type }) => { + let data_type = match data_type { + Some(data_type) => Some(data_type.try_into()?), + None => None, + }; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Placeholder(PlaceholderNode { + id: id.clone(), + data_type, + })), + } + } + }; - Ok(expr_node) - } + Ok(expr_node) } impl TryFrom<&ScalarValue> for protobuf::ScalarValue { @@ -1464,12 +1463,10 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayHasAll => Self::ArrayHasAll, BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny, BuiltinScalarFunction::ArrayHas => Self::ArrayHas, - BuiltinScalarFunction::ArrayDims => Self::ArrayDims, BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct, BuiltinScalarFunction::ArrayElement => Self::ArrayElement, BuiltinScalarFunction::Flatten => Self::Flatten, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, - BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims, BuiltinScalarFunction::ArrayPopFront => Self::ArrayPopFront, BuiltinScalarFunction::ArrayPopBack => Self::ArrayPopBack, BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition, @@ -1487,7 +1484,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArraySlice => Self::ArraySlice, BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect, BuiltinScalarFunction::ArrayUnion => Self::ArrayUnion, - BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, BuiltinScalarFunction::DatePart => Self::DatePart, BuiltinScalarFunction::DateTrunc => Self::DateTrunc, @@ -1508,12 +1504,10 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, - BuiltinScalarFunction::InStr => Self::InStr, BuiltinScalarFunction::Left => Self::Left, BuiltinScalarFunction::Lpad => Self::Lpad, BuiltinScalarFunction::Random => Self::Random, BuiltinScalarFunction::Uuid => Self::Uuid, - BuiltinScalarFunction::RegexpLike => Self::RegexpLike, BuiltinScalarFunction::RegexpReplace => Self::RegexpReplace, BuiltinScalarFunction::Repeat => Self::Repeat, BuiltinScalarFunction::Replace => Self::Replace, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index d2961875d89a..a20baeb4e941 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -60,6 +60,7 @@ use datafusion::physical_plan::{ WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, 
DataFusionError, Result}; +use datafusion_expr::ScalarUDF; use prost::bytes::BufMut; use prost::Message; @@ -1911,6 +1912,14 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { ) -> Result>; fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()>; + + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("PhysicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug)] diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e3bd2cb1dc47..702ae99babd8 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -28,6 +28,8 @@ use arrow::datatypes::{ }; use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; +use datafusion_expr::{ScalarUDF, ScalarUDFImpl}; +use datafusion_proto::logical_plan::to_proto::serialize_expr; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -62,8 +64,8 @@ use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; -use datafusion_proto::logical_plan::from_proto; use datafusion_proto::logical_plan::LogicalExtensionCodec; +use datafusion_proto::logical_plan::{from_proto, DefaultLogicalExtensionCodec}; use datafusion_proto::protobuf; #[cfg(feature = "json")] @@ -78,13 +80,15 @@ fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {} // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test // equality. -fn roundtrip_expr_test(initial_struct: T, ctx: SessionContext) -where - for<'a> &'a T: TryInto + Debug, - E: Debug, -{ - let proto: protobuf::LogicalExprNode = (&initial_struct).try_into().unwrap(); - let round_trip: Expr = from_proto::parse_expr(&proto, &ctx).unwrap(); +fn roundtrip_expr_test(initial_struct: Expr, ctx: SessionContext) { + let extension_codec = DefaultLogicalExtensionCodec {}; + let proto: protobuf::LogicalExprNode = + match serialize_expr(&initial_struct, &extension_codec) { + Ok(p) => p, + Err(e) => panic!("Error serializing expression: {:?}", e), + }; + let round_trip: Expr = + from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); @@ -579,6 +583,11 @@ async fn roundtrip_expr_api() -> Result<()> { encode(col("a").cast_to(&DataType::Utf8, &schema)?, lit("hex")), decode(lit("1234"), lit("hex")), array_to_string(array(vec![lit(1), lit(2), lit(3)]), lit(",")), + array_dims(array(vec![lit(1), lit(2), lit(3)])), + array_ndims(array(vec![lit(1), lit(2), lit(3)])), + cardinality(array(vec![lit(1), lit(2), lit(3)])), + range(lit(1), lit(10), lit(2)), + gen_series(lit(1), lit(10), lit(2)), ]; // ensure expressions created with the expr api can be round tripped @@ -631,6 +640,12 @@ pub mod proto { #[prost(uint64, tag = "1")] pub k: u64, } + + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct MyRegexUdfNode { + #[prost(string, tag = "1")] + pub pattern: String, + } } #[derive(PartialEq, Eq, Hash)] @@ -707,7 +722,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec { let node = TopKPlanNode::new( proto.k as usize, input.clone(), - from_proto::parse_expr(expr, ctx)?, + from_proto::parse_expr(expr, ctx, self)?, ); Ok(Extension { @@ -725,7 +740,7 @@ impl 
LogicalExtensionCodec for TopKExtensionCodec { if let Some(exec) = node.node.as_any().downcast_ref::() { let proto = proto::TopKPlanProto { k: exec.k as u64, - expr: Some((&exec.expr).try_into()?), + expr: Some(serialize_expr(&exec.expr, self)?), }; proto.encode(buf).map_err(|e| { @@ -756,6 +771,109 @@ impl LogicalExtensionCodec for TopKExtensionCodec { } } +#[derive(Debug)] +struct MyRegexUdf { + signature: Signature, + // regex as original string + pattern: String, +} + +impl MyRegexUdf { + fn new(pattern: String) -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + pattern, + } + } +} + +/// Implement the ScalarUDFImpl trait for MyRegexUdf +impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "regex_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, args: &[DataType]) -> Result { + if !matches!(args.first(), Some(&DataType::Utf8)) { + return plan_err!("regex_udf only accepts Utf8 arguments"); + } + Ok(DataType::Int32) + } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } +} + +#[derive(Debug)] +pub struct ScalarUDFExtensionCodec {} + +impl LogicalExtensionCodec for ScalarUDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &SessionContext, + ) -> Result { + not_impl_err!("No extension codec provided") + } + + fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { + not_impl_err!("No extension codec provided") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result> { + internal_err!("unsupported plan type") + } + + fn try_encode_table_provider( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + internal_err!("unsupported plan type") + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "regex_udf" { + let proto = proto::MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to decode regex_udf: {}", err)) + })?; + + Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( + proto.pattern, + )))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") + } + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + let udf = binding.as_any().downcast_ref::().unwrap(); + let proto = proto::MyRegexUdfNode { + pattern: udf.pattern.clone(), + }; + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to encode udf: {e:?}")) + })?; + Ok(()) + } +} + #[test] fn round_trip_scalar_values() { let should_pass: Vec = vec![ @@ -1664,6 +1782,30 @@ fn roundtrip_scalar_udf() { roundtrip_expr_test(test_expr, ctx); } +#[test] +fn roundtrip_scalar_udf_extension_codec() { + let pattern = ".*"; + let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); + let test_expr = + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf.clone()), vec![])); + + let ctx = SessionContext::new(); + ctx.register_udf(udf); + + let extension_codec = ScalarUDFExtensionCodec {}; + let proto: protobuf::LogicalExprNode = + match serialize_expr(&test_expr, &extension_codec) { + Ok(p) => p, + Err(e) => panic!("Error serializing expression: {:?}", e), + }; + let round_trip: Expr = + from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); + + assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); + + 
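// the proto is also pushed through the JSON helper below; when the
// `json` feature is off that helper is a no-op stub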
roundtrip_json_test(&proto); +} + #[test] fn roundtrip_grouping_sets() { let test_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index 7dd0333909ee..d4a1ab44a6ea 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -25,6 +25,8 @@ use datafusion::prelude::SessionContext; use datafusion_expr::{col, create_udf, lit, ColumnarValue}; use datafusion_expr::{Expr, Volatility}; use datafusion_proto::bytes::Serializeable; +use datafusion_proto::logical_plan::to_proto::serialize_expr; +use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; #[test] #[should_panic( @@ -252,7 +254,6 @@ fn test_expression_serialization_roundtrip() { use datafusion_expr::expr::ScalarFunction; use datafusion_expr::BuiltinScalarFunction; use datafusion_proto::logical_plan::from_proto::parse_expr; - use datafusion_proto::protobuf::LogicalExprNode; use strum::IntoEnumIterator; let ctx = SessionContext::new(); @@ -266,8 +267,9 @@ fn test_expression_serialization_roundtrip() { let args: Vec<_> = std::iter::repeat(&lit).take(num_args).cloned().collect(); let expr = Expr::ScalarFunction(ScalarFunction::new(builtin_fun, args)); - let proto = LogicalExprNode::try_from(&expr).unwrap(); - let deserialize = parse_expr(&proto, &ctx).unwrap(); + let extension_codec = DefaultLogicalExtensionCodec {}; + let proto = serialize_expr(&expr, &extension_codec).unwrap(); + let deserialize = parse_expr(&proto, &ctx, &extension_codec).unwrap(); let serialize_name = extract_function_name(&expr); let deserialize_name = extract_function_name(&deserialize); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index bb73e69ba9f4..ad1d2db70cf4 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -258,10 +258,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { FunctionArg::Named { name: _, arg: FunctionArgExpr::Expr(arg), + operator: _, } => self.sql_expr_to_logical_expr(arg, schema, planner_context), FunctionArg::Named { name: _, arg: FunctionArgExpr::Wildcard, + operator: _, } => Ok(Expr::Wildcard { qualifier: None }), FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { self.sql_expr_to_logical_expr(arg, schema, planner_context) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index b058fb79b4a1..d36d973cbee6 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -750,7 +750,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let fun = BuiltinScalarFunction::InStr; + let fun = BuiltinScalarFunction::Strpos; let substr = self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index c0870cc54106..15524b9ffab1 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -52,6 +52,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Invalid HexStringLiteral '{s}'") } } + Value::EscapedStringLiteral(s) => Ok(lit(s)), _ => plan_err!("Unsupported Value '{value:?}'"), } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 1f21299d8559..2db2c01c5ee1 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -359,11 +359,8 @@ 
impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_sql_type)) => { // Arrays may be multi-dimensional. - let data_type = self.convert_data_type(inner_sql_type)?; - - Ok(DataType::List(Arc::new(Field::new( - "field", data_type, true, - )))) + let inner_data_type = self.convert_data_type(inner_sql_type)?; + Ok(DataType::new_list(inner_data_type, true)) } SQLDataType::Array(ArrayElemTypeDef::None) => { not_impl_err!("Arrays with unspecified type is not supported") diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index bf15146a92f7..dfac8367e912 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -52,10 +52,10 @@ use datafusion_expr::{ }; use sqlparser::ast; use sqlparser::ast::{ - Assignment, ColumnDef, CreateTableOptions, Expr as SQLExpr, Expr, Ident, ObjectName, - ObjectType, Query, SchemaName, SetExpr, ShowCreateObject, ShowStatementFilter, - Statement, TableConstraint, TableFactor, TableWithJoins, TransactionMode, - UnaryOperator, Value, + Assignment, ColumnDef, CreateTableOptions, DescribeAlias, Expr as SQLExpr, Expr, + FromTable, Ident, ObjectName, ObjectType, Query, SchemaName, SetExpr, + ShowCreateObject, ShowStatementFilter, Statement, TableConstraint, TableFactor, + TableWithJoins, TransactionMode, UnaryOperator, Value, }; use sqlparser::parser::ParserError::ParserError; @@ -177,7 +177,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let sql = Some(statement.to_string()); match statement { Statement::ExplainTable { - describe_alias: true, // only parse 'DESCRIBE table_name' and not 'EXPLAIN table_name' + describe_alias: DescribeAlias::Describe, // only parse 'DESCRIBE table_name' and not 'EXPLAIN table_name' + hive_format: _, table_name, } => self.describe_table_to_plan(table_name), Statement::Explain { @@ -630,7 +631,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - fn get_delete_target(&self, mut from: Vec) -> Result { + fn get_delete_target(&self, from: FromTable) -> Result { + let mut from = match from { + FromTable::WithFromKeyword(v) => v, + FromTable::WithoutKeyword(v) => v, + }; + if from.len() != 1 { return not_impl_err!( "DELETE FROM only supports single table, got {}: {from:?}", diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 0dc1258ebabe..abb896ab113e 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -17,34 +17,36 @@ //! 
SQL Utility Functions +use std::collections::HashMap; + use arrow_schema::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use sqlparser::ast::Ident; - -use datafusion_common::{exec_err, internal_err, plan_err}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{ + exec_err, internal_err, plan_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::expr::{Alias, GroupingSet, WindowFunction}; -use datafusion_expr::expr_vec_fmt; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; -use datafusion_expr::{Expr, LogicalPlan}; -use std::collections::HashMap; +use datafusion_expr::{expr_vec_fmt, Expr, LogicalPlan}; +use sqlparser::ast::Ident; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { - expr.clone().transform_up(&|nested_expr| { - match nested_expr { - Expr::Column(col) => { - let field = plan.schema().field_from_column(&col)?; - Ok(Transformed::Yes(Expr::Column(field.qualified_column()))) - } - _ => { - // keep recursing - Ok(Transformed::No(nested_expr)) + expr.clone() + .transform_up(&|nested_expr| { + match nested_expr { + Expr::Column(col) => { + let field = plan.schema().field_from_column(&col)?; + Ok(Transformed::yes(Expr::Column(field.qualified_column()))) + } + _ => { + // keep recursing + Ok(Transformed::no(nested_expr)) + } } - } - }) + }) + .data() } /// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s. @@ -66,13 +68,15 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { - expr.clone().transform_down(&|nested_expr| { - if base_exprs.contains(&nested_expr) { - Ok(Transformed::Yes(expr_as_column_expr(&nested_expr, plan)?)) - } else { - Ok(Transformed::No(nested_expr)) - } - }) + expr.clone() + .transform_down(&|nested_expr| { + if base_exprs.contains(&nested_expr) { + Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?)) + } else { + Ok(Transformed::no(nested_expr)) + } + }) + .data() } /// Determines if the set of `Expr`'s are a valid projection on the input @@ -170,16 +174,18 @@ pub(crate) fn resolve_aliases_to_exprs( expr: &Expr, aliases: &HashMap, ) -> Result { - expr.clone().transform_up(&|nested_expr| match nested_expr { - Expr::Column(c) if c.relation.is_none() => { - if let Some(aliased_expr) = aliases.get(&c.name) { - Ok(Transformed::Yes(aliased_expr.clone())) - } else { - Ok(Transformed::No(Expr::Column(c))) + expr.clone() + .transform_up(&|nested_expr| match nested_expr { + Expr::Column(c) if c.relation.is_none() => { + if let Some(aliased_expr) = aliases.get(&c.name) { + Ok(Transformed::yes(aliased_expr.clone())) + } else { + Ok(Transformed::no(Expr::Column(c))) + } } - } - _ => Ok(Transformed::No(nested_expr)), - }) + _ => Ok(Transformed::no(nested_expr)), + }) + .data() } /// given a slice of window expressions sharing the same sort key, find their common partition diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 55551d1d25a3..db1beb94446b 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2881,11 +2881,11 @@ impl ContextProvider for MockContextProvider { } fn get_function_meta(&self, name: &str) -> Option> { - self.udfs.get(name).map(Arc::clone) + 
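// `.cloned()` on an `Option<&Arc<_>>` clones the `Arc` itself (a cheap
// refcount bump), so this is equivalent to `map(Arc::clone)`, just terser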
self.udfs.get(name).cloned() } fn get_aggregate_meta(&self, name: &str) -> Option> { - self.udafs.get(name).map(Arc::clone) + self.udafs.get(name).cloned() } fn get_variable_type(&self, _: &[String]) -> Option { @@ -4423,6 +4423,31 @@ fn test_field_not_found_window_function() { quick_test(qualified_sql, expected); } +#[test] +fn test_parse_escaped_string_literal_value() { + let sql = r"SELECT length('\r\n') AS len"; + let expected = "Projection: character_length(Utf8(\"\\r\\n\")) AS len\ + \n EmptyRelation"; + quick_test(sql, expected); + + let sql = r"SELECT length(E'\r\n') AS len"; + let expected = "Projection: character_length(Utf8(\"\r\n\")) AS len\ + \n EmptyRelation"; + quick_test(sql, expected); + + let sql = r"SELECT length(E'\445') AS len, E'\x4B' AS hex, E'\u0001' AS unicode"; + let expected = + "Projection: character_length(Utf8(\"%\")) AS len, Utf8(\"\u{004b}\") AS hex, Utf8(\"\u{0001}\") AS unicode\ + \n EmptyRelation"; + quick_test(sql, expected); + + let sql = r"SELECT length(E'\000') AS len"; + assert_eq!( + logical_plan(sql).unwrap_err().strip_backtrace(), + "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column 15\")" + ) +} + fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. } => { diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 12c4c96d5236..c348f2cddc93 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -39,6 +39,7 @@ chrono = { workspace = true, optional = true } clap = { version = "4.4.8", features = ["derive", "env"] } datafusion = { workspace = true, default-features = true } datafusion-common = { workspace = true, default-features = true } +datafusion-common-runtime = { workspace = true, default-features = true } futures = { workspace = true } half = { workspace = true, default-features = true } itertools = { workspace = true } diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 41c33deec643..268d09681c72 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -28,6 +28,7 @@ use log::info; use sqllogictest::strict_column_validator; use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError, Result}; +use datafusion_common_runtime::SpawnedTask; const TEST_DIRECTORY: &str = "test_files/"; const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; @@ -88,8 +89,7 @@ async fn run_tests() -> Result<()> { // modifying shared state like `/tmp/`) let errors: Vec<_> = futures::stream::iter(read_test_files(&options)?) .map(|test_file| { - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests - tokio::task::spawn(async move { + SpawnedTask::spawn(async move { println!("Running {:?}", test_file.relative_path); if options.complete { run_complete_file(test_file).await?; @@ -100,6 +100,7 @@ async fn run_tests() -> Result<()> { } Ok(()) as Result<()> }) + .join() }) // run up to num_cpus streams in parallel .buffer_unordered(num_cpus::get()) diff --git a/datafusion/sqllogictest/src/engines/postgres_engine/types.rs b/datafusion/sqllogictest/src/engines/postgres_engine/types.rs index 0c66150d1bb4..510462befb08 100644 --- a/datafusion/sqllogictest/src/engines/postgres_engine/types.rs +++ b/datafusion/sqllogictest/src/engines/postgres_engine/types.rs @@ -16,6 +16,7 @@ // under the License. 
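One note on the `types.rs` hunk below: swapping a hand-written `ToString` for a `Display` impl is the idiomatic fix, because the standard library's blanket `impl<T: Display> ToString for T` then supplies `to_string()` for free, and the type also becomes usable with `format!`/`println!`. A minimal sketch of the pattern — the `PgRegtype` shape is taken from the hunk, the `main` usage is illustrative only:

```rust
use std::fmt;

struct PgRegtype {
    value: String,
}

// Implement Display instead of ToString: the stdlib blanket impl
// `impl<T: Display> ToString for T` derives `to_string()` from this.
impl fmt::Display for PgRegtype {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.value)
    }
}

fn main() {
    let ty = PgRegtype { value: "boolean".into() };
    assert_eq!(ty.to_string(), "boolean"); // provided by the blanket impl
    println!("{ty}"); // Display also unlocks the formatting macros
}
```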
use postgres_types::Type; +use std::fmt::Display; use tokio_postgres::types::FromSql; pub struct PgRegtype { @@ -37,8 +38,8 @@ impl<'a> FromSql<'a> for PgRegtype { } } -impl ToString for PgRegtype { - fn to_string(&self) -> String { - self.value.clone() +impl Display for PgRegtype { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.value) } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 4e6cb4d59d14..68a7a3474680 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -883,6 +883,11 @@ select arrow_cast([1, 2, 3], 'LargeList(Int64)')[0:0], ---- [] [1, 2] [h, e, l, l, o] +query I +select arrow_cast([1, 2, 3], 'LargeList(Int64)')[1]; +---- +1 + # TODO: support multiple negative index # multiple index with columns #3 (negative index) # query II @@ -5549,26 +5554,74 @@ from arrays_range; [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [3, 4, 5, 6, 7, 8, 9] [3, 5, 7, 9] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 7, 10] -query ?????? +query ?????????? select range(5), range(2, 5), range(2, 10, 3), range(1, 5, -1), range(1, -5, 1), - range(1, -5, -1) + range(1, -5, -1), + range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH), + range(DATE '1993-02-01', DATE '1993-01-01', INTERVAL '-1' DAY), + range(DATE '1989-04-01', DATE '1993-03-01', INTERVAL '1' YEAR), + range(DATE '1993-03-01', DATE '1989-04-01', INTERVAL '1' YEAR) ; ---- -[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] [] [] [1, 0, -1, -2, -3, -4] +[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] [] [] [1, 0, -1, -2, -3, -4] [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01] [1993-02-01, 1993-01-31, 1993-01-30, 1993-01-29, 1993-01-28, 1993-01-27, 1993-01-26, 1993-01-25, 1993-01-24, 1993-01-23, 1993-01-22, 1993-01-21, 1993-01-20, 1993-01-19, 1993-01-18, 1993-01-17, 1993-01-16, 1993-01-15, 1993-01-14, 1993-01-13, 1993-01-12, 1993-01-11, 1993-01-10, 1993-01-09, 1993-01-08, 1993-01-07, 1993-01-06, 1993-01-05, 1993-01-04, 1993-01-03, 1993-01-02] [1989-04-01, 1990-04-01, 1991-04-01] [] + +## should throw error +query error +select range(DATE '1992-09-01', NULL, INTERVAL '1' YEAR); + +query error +select range(DATE '1992-09-01', DATE '1993-03-01', NULL); + +query error +select range(NULL, DATE '1993-03-01', INTERVAL '1' YEAR); + +query ? +select range(DATE '1989-04-01', DATE '1993-03-01', INTERVAL '-1' YEAR) +---- +[] + +query ? +select range(DATE '1993-03-01', DATE '1989-04-01', INTERVAL '1' YEAR) +---- +[] -query ????? +query ???????? 
select generate_series(5), generate_series(2, 5), generate_series(2, 10, 3), generate_series(1, 5, 1), - generate_series(5, 1, -1) + generate_series(5, 1, -1), + generate_series(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH), + generate_series(DATE '1993-02-01', DATE '1993-01-01', INTERVAL '-1' DAY), + generate_series(DATE '1989-04-01', DATE '1993-03-01', INTERVAL '1' YEAR) ; ---- -[0, 1, 2, 3, 4, 5] [2, 3, 4, 5] [2, 5, 8] [1, 2, 3, 4, 5] [5, 4, 3, 2, 1] +[0, 1, 2, 3, 4, 5] [2, 3, 4, 5] [2, 5, 8] [1, 2, 3, 4, 5] [5, 4, 3, 2, 1] [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01, 1993-03-01] [1993-02-01, 1993-01-31, 1993-01-30, 1993-01-29, 1993-01-28, 1993-01-27, 1993-01-26, 1993-01-25, 1993-01-24, 1993-01-23, 1993-01-22, 1993-01-21, 1993-01-20, 1993-01-19, 1993-01-18, 1993-01-17, 1993-01-16, 1993-01-15, 1993-01-14, 1993-01-13, 1993-01-12, 1993-01-11, 1993-01-10, 1993-01-09, 1993-01-08, 1993-01-07, 1993-01-06, 1993-01-05, 1993-01-04, 1993-01-03, 1993-01-02, 1993-01-01] [1989-04-01, 1990-04-01, 1991-04-01, 1992-04-01] + +## should throw error +query error +select generate_series(DATE '1992-09-01', NULL, INTERVAL '1' YEAR); + +query error +select generate_series(DATE '1992-09-01', DATE '1993-03-01', NULL); + +query error +select generate_series(NULL, DATE '1993-03-01', INTERVAL '1' YEAR); + + +query ? +select generate_series(DATE '1989-04-01', DATE '1993-03-01', INTERVAL '-1' YEAR) +---- +[] + +query ? +select generate_series(DATE '1993-03-01', DATE '1989-04-01', INTERVAL '1' YEAR) +---- +[] ## array_except @@ -6208,12 +6261,19 @@ select * from test_create_array_table; query T select arrow_typeof(a) from test_create_array_table; ---- -List(Field { name: "field", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) query T select arrow_typeof(c) from test_create_array_table; ---- -List(Field { name: "field", data_type: List(Field { name: "field", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +# Test casting to array types +# issue: https://github.com/apache/arrow-datafusion/issues/9440 +query ??T +select [1,2,3]::int[], [['1']]::int[][], arrow_typeof([]::text[]); +---- +[1, 2, 3] [[1]] List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) ### Delete tables diff --git a/datafusion/sqllogictest/test_files/arrow_files.slt b/datafusion/sqllogictest/test_files/arrow_files.slt index 5c1b6fb726ed..8cf3550fdb25 100644 --- a/datafusion/sqllogictest/test_files/arrow_files.slt +++ b/datafusion/sqllogictest/test_files/arrow_files.slt @@ -42,3 +42,74 @@ SELECT * FROM arrow_simple 2 bar NULL 3 baz false 4 NULL true + +# ARROW partitioned table +statement ok +CREATE EXTERNAL TABLE arrow_partitioned ( + part Int, + f0 Bigint, + f1 String, + f2 Boolean +) +STORED AS ARROW +LOCATION '../core/tests/data/partitioned_table_arrow/' +PARTITIONED BY (part); + +# select wildcard +query ITBI +SELECT * FROM arrow_partitioned ORDER BY f0; +---- +1 foo true 123 +2 bar false 123 +3 baz true 456 +4 NULL NULL 456 + +# select all fields +query IITB +SELECT part, f0, 
f1, f2 FROM arrow_partitioned ORDER BY f0; +---- +123 1 foo true +123 2 bar false +456 3 baz true +456 4 NULL NULL + +# select without partition column +query IB +SELECT f0, f2 FROM arrow_partitioned ORDER BY f0 +---- +1 true +2 false +3 true +4 NULL + +# select only partition column +query I +SELECT part FROM arrow_partitioned ORDER BY part +---- +123 +123 +456 +456 + +# select without any table-related columns in projection +query I +SELECT 1 FROM arrow_partitioned +---- +1 +1 +1 +1 + +# select with partition filter +query I +SELECT f0 FROM arrow_partitioned WHERE part = 123 ORDER BY f0 +---- +1 +2 + +# select with partition filter should scan only one directory +query TT +EXPLAIN SELECT f0 FROM arrow_partitioned WHERE part = 456 +---- +logical_plan TableScan: arrow_partitioned projection=[f0], full_filters=[arrow_partitioned.part = Int32(456)] +physical_plan ArrowExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_table_arrow/part=456/data.arrow]]}, projection=[f0] diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 213a3f6b52ec..906926a5a9ab 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -3207,20 +3207,20 @@ ORDER BY first_c 9 15 -query ITIPTR +query ITIPTR rowsort SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate FROM sales_global AS s JOIN sales_global AS e ON s.currency = e.currency AND s.ts >= e.ts GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency -ORDER BY s.sn +ORDER BY s.sn, s.zip_code ---- 0 GRC 0 2022-01-01T06:00:00 EUR 30 +0 GRC 4 2022-01-03T10:00:00 EUR 80 1 FRA 1 2022-01-01T08:00:00 EUR 50 -1 TUR 2 2022-01-01T11:30:00 TRY 75 1 FRA 3 2022-01-02T12:00:00 EUR 200 -0 GRC 4 2022-01-03T10:00:00 EUR 80 +1 TUR 2 2022-01-01T11:30:00 TRY 75 1 TUR 4 2022-01-03T10:00:00 TRY 100 # create a table for testing diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index c0f86ac76320..0082f2ecefb9 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -99,9 +99,11 @@ select * from dictionary_encoded_arrow_test_readback; ---- b -# https://github.com/apache/arrow-datafusion/issues/7816 -query error DataFusion error: Arrow error: Schema error: project index 1 out of bounds, max field 1 +query TT select * from dictionary_encoded_arrow_partitioned order by (a); +---- +a foo +b bar # test_insert_into @@ -195,9 +197,15 @@ INSERT INTO partitioned_insert_test_json values (1, 2), (3, 4), (5, 6), (1, 2), ---- 6 -# Issue open for this error: https://github.com/apache/arrow-datafusion/issues/7816 -query error DataFusion error: Arrow error: Json error: Encountered unmasked nulls in non\-nullable StructArray child: Field \{ name: "a", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: \{\} \} +query TT select * from partitioned_insert_test_json order by a,b +---- +1 2 +1 2 +3 4 +3 4 +5 6 +5 6 statement ok CREATE EXTERNAL TABLE diff --git a/datafusion/sqllogictest/test_files/json.slt b/datafusion/sqllogictest/test_files/json.slt index c0d5e895f0f2..24c97816fe7f 100644 --- a/datafusion/sqllogictest/test_files/json.slt +++ b/datafusion/sqllogictest/test_files/json.slt @@ -68,3 +68,74 @@ DROP TABLE json_test statement ok DROP TABLE single_nan + +# JSON partitioned table +statement ok +CREATE EXTERNAL 
TABLE json_partitioned_test ( + part Int, + id Int, + value String, +) +STORED AS JSON +LOCATION '../core/tests/data/partitioned_table_json' +PARTITIONED BY (part); + +# select wildcard always returns partition columns as the last ones +query ITI +SELECT * FROM json_partitioned_test ORDER BY id +---- +1 foo 1 +2 bar 1 +3 baz 2 +4 qux 2 + + +# select all fields +query IIT +SELECT part, id, value FROM json_partitioned_test ORDER BY id +---- +1 1 foo +1 2 bar +2 3 baz +2 4 qux + +# select without partition column +query I +SELECT id FROM json_partitioned_test ORDER BY id +---- +1 +2 +3 +4 + +# select only partition column +query I +SELECT part FROM json_partitioned_test ORDER BY part +---- +1 +1 +2 +2 + +# select without any table-related columns in projection +query T +SELECT 'x' FROM json_partitioned_test +---- +x +x +x +x + +# select with partition filter +query I +SELECT id FROM json_partitioned_test WHERE part = 1 ORDER BY id +---- +1 +2 + +# select with partition filter should scan only one directory +query TT +EXPLAIN SELECT id FROM json_partitioned_test WHERE part = 2 +---- +logical_plan TableScan: json_partitioned_test projection=[id], full_filters=[json_partitioned_test.part = Int32(2)] +physical_plan JsonExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_table_json/part=2/data.json]]}, projection=[id] diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index e063d6e8960a..37a0360f8c26 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -400,7 +400,7 @@ AggregateExec: mode=FinalPartitioned, gby=[i@0 as i], aggr=[] --CoalesceBatchesExec: target_batch_size=8192 ----RepartitionExec: partitioning=Hash([i@0], 4), input_partitions=4 ------AggregateExec: mode=Partial, gby=[i@0 as i], aggr=[] ---------MemoryExec: partitions=4, partition_sizes=[1, 1, 2, 1] +--------MemoryExec: partitions=4, partition_sizes=[1, 2, 1, 1] query I SELECT i FROM t1000 ORDER BY i DESC LIMIT 3; diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 2ea78448b940..f63179a369c5 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -784,3 +784,110 @@ SortPreservingMergeExec: [m@0 ASC NULLS LAST,t@1 ASC NULLS LAST] ----------------AggregateExec: mode=Partial, gby=[1 as Int64(1), t@0 as t], aggr=[], ordering_mode=PartiallySorted([0]) ------------------ProjectionExec: expr=[column1@0 as t] --------------------ValuesExec + +##### +# Multi column sorting with lists +##### + +statement ok +create table foo as values (2, [0]), (4, [1]), (2, [6]), (1, [2, 5]), (0, [3]), (3, [4]), (2, [2, 5]), (2, [7]); + +query I? +select column1, column2 from foo ORDER BY column1, column2; +---- +0 [3] +1 [2, 5] +2 [0] +2 [2, 5] +2 [6] +2 [7] +3 [4] +4 [1] + +query I? +select column1, column2 from foo ORDER BY column1 desc, column2; +---- +4 [1] +3 [4] +2 [0] +2 [2, 5] +2 [6] +2 [7] +1 [2, 5] +0 [3] + +query I? +select column1, column2 from foo ORDER BY column1, column2 desc; +---- +0 [3] +1 [2, 5] +2 [7] +2 [6] +2 [2, 5] +2 [0] +3 [4] +4 [1] + +query I? 
+select column1, column2 from foo ORDER BY column1 desc, column2 desc; +---- +4 [1] +3 [4] +2 [7] +2 [6] +2 [2, 5] +2 [0] +1 [2, 5] +0 [3] + +query ?I +select column2, column1 from foo ORDER BY column2, column1; +---- +[0] 2 +[1] 4 +[2, 5] 1 +[2, 5] 2 +[3] 0 +[4] 3 +[6] 2 +[7] 2 + +query ?I +select column2, column1 from foo ORDER BY column2 desc, column1; +---- +[7] 2 +[6] 2 +[4] 3 +[3] 0 +[2, 5] 1 +[2, 5] 2 +[1] 4 +[0] 2 + +query ?I +select column2, column1 from foo ORDER BY column2, column1 desc; +---- +[0] 2 +[1] 4 +[2, 5] 2 +[2, 5] 1 +[3] 0 +[4] 3 +[6] 2 +[7] 2 + +query ?I +select column2, column1 from foo ORDER BY column2 desc, column1 desc; +---- +[7] 2 +[6] 2 +[4] 3 +[3] 0 +[2, 5] 2 +[2, 5] 1 +[1] 4 +[0] 2 + +# Cleanup +statement ok +drop table foo; diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index a80b08c41ee3..19966be2095b 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -124,6 +124,10 @@ SELECT regexp_like('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); ---- false +query B +select regexp_like('aaa-555', '.*-(\d*)'); +---- +true # # regexp_match tests diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index a3f773e607c7..a64fcbbdbca2 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -941,6 +941,25 @@ select round(sqrt(a), 5), round(sqrt(b), 5), round(sqrt(c), 5) from signed_integ NaN 10 NaN NaN 100 NaN +# sqrt scalar fraction +query RR rowsort +select sqrt(1.4), sqrt(2.0/3); +---- +1.18321595662 0.816496580928 + +# sqrt scalar cast +query R rowsort +select sqrt(cast(10e8 as double)); +---- +31622.776601683792 + + +# sqrt scalar negative +query R rowsort +select sqrt(-1); +---- +NaN + ## tan # tan scalar function @@ -2068,5 +2087,5 @@ select position('' in '') 1 -query error DataFusion error: Error during planning: The INSTR/POSITION function can only accept strings, but got Int64. +query error DataFusion error: Error during planning: The STRPOS/INSTR/POSITION function can only accept strings, but got Int64. 
select position(1 in 1) diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index d7085631777c..f0483aec8946 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -465,6 +465,10 @@ SELECT TIME 'not a time' as time; query error Cannot cast string '24:01:02' to value of Time64\(Nanosecond\) type SELECT TIME '24:01:02' as time; +# invalid timezone +query error Arrow error: Parser error: Invalid timezone "ZZ": 'ZZ' is not a valid timezone +SELECT TIMESTAMP '2023-12-05T21:58:10.45ZZ'; + statement ok set datafusion.optimizer.skip_failed_rules = true diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index a541d0370184..92d2208029b1 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4169,43 +4169,103 @@ select lag(a, 2, null) ignore nulls over (order by id desc) as x1, sum(id) over (order by id desc ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as sum_id from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x') +# LEAD window function IGNORE/RESPECT NULLS support with ascending order and default offset 1 +query TTTTTT +select lead(a) ignore nulls over (order by id) as x, + lead(a, 1, null) ignore nulls over (order by id) as x1, + lead(a, 1, 'def') ignore nulls over (order by id) as x2, + lead(a) respect nulls over (order by id) as x3, + lead(a, 1, null) respect nulls over (order by id) as x4, + lead(a, 1, 'def') respect nulls over (order by id) as x5 +from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x') +---- +b b b b b b +x x x NULL NULL NULL +x x x x x x +NULL NULL def NULL NULL def + +# LEAD window function IGNORE/RESPECT NULLS support with descending order and default offset 1 +query TTTTTT +select lead(a) ignore nulls over (order by id desc) as x, + lead(a, 1, null) ignore nulls over (order by id desc) as x1, + lead(a, 1, 'def') ignore nulls over (order by id desc) as x2, + lead(a) respect nulls over (order by id desc) as x3, + lead(a, 1, null) respect nulls over (order by id desc) as x4, + lead(a, 1, 'def') respect nulls over (order by id desc) as x5 +from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x') +---- +b b b NULL NULL NULL +b b b b b b +NULL NULL def NULL NULL NULL +NULL NULL def NULL NULL def + +# LEAD window function IGNORE/RESPECT NULLS support with ascending order and nondefault offset +query TTTT +select lead(a, 2, null) ignore nulls over (order by id) as x1, + lead(a, 2, 'def') ignore nulls over (order by id) as x2, + lead(a, 2, null) respect nulls over (order by id) as x4, + lead(a, 2, 'def') respect nulls over (order by id) as x5 +from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x') +---- +x x NULL NULL +NULL def x x +NULL def NULL def +NULL def NULL def + # LEAD window function IGNORE/RESPECT NULLS support with descending order and nondefault offset -statement error Execution error: IGNORE NULLS mode for LEAD is not supported for BoundedWindowAggExec +query TTTT select lead(a, 2, null) ignore nulls over (order by id desc) as x1, lead(a, 2, 'def') ignore nulls over (order by id desc) as x2, lead(a, 2, null) respect nulls over (order by id desc) as x4, lead(a, 2, 'def') respect nulls over (order 
+----
+NULL def b b
+NULL def NULL NULL
+NULL def NULL def
+NULL def NULL def
+
+# LEAD window function IGNORE/RESPECT NULLS support with descending order and nondefault offset.
+# To trigger WindowAggExec, we add a sum window function with an unbounded window frame.
+statement error Execution error: IGNORE NULLS mode for LAG and LEAD is not supported for WindowAggExec
+select lead(a, 2, null) ignore nulls over (order by id desc) as x1,
+  lead(a, 2, 'def') ignore nulls over (order by id desc) as x2,
+  lead(a, 2, null) respect nulls over (order by id desc) as x4,
+  lead(a, 2, 'def') respect nulls over (order by id desc) as x5,
+  sum(id) over (order by id desc ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as sum_id
+from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x')
 
 statement ok
 set datafusion.execution.batch_size = 1000;
 
-query I
-SELECT LAG(c1, 2) IGNORE NULLS OVER()
+query II
+SELECT LAG(c1, 2) IGNORE NULLS OVER(),
+  LEAD(c1, 2) IGNORE NULLS OVER()
 FROM null_cases
 ORDER BY c2
 LIMIT 5;
 ----
-78
-63
-3
-24
-14
+78 50
+63 38
+3 53
+24 31
+14 94
 
-# result should be same with above, when lag algorithm work with pruned data.
+# result should be the same as above when the LAG/LEAD algorithm works with pruned data.
 # decreasing batch size, causes data to be produced in smaller chunks at the source.
 # Hence sliding window algorithm is used during calculations.
 statement ok
 set datafusion.execution.batch_size = 1;
 
-query I
-SELECT LAG(c1, 2) IGNORE NULLS OVER()
+query II
+SELECT LAG(c1, 2) IGNORE NULLS OVER(),
+  LEAD(c1, 2) IGNORE NULLS OVER()
 FROM null_cases
 ORDER BY c2
 LIMIT 5;
 ----
-78
-63
-3
-24
-14
+78 50
+63 38
+3 53
+24 31
+14 94
diff --git a/docs/logos/Datafusion_Branding_Guideline.pdf b/docs/logos/Datafusion_Branding_Guideline.pdf
new file mode 100644
index 000000000000..dcf0a09dba9f
Binary files /dev/null and b/docs/logos/Datafusion_Branding_Guideline.pdf differ
diff --git a/docs/logos/DataFUSION-Logo-Dark.svg b/docs/logos/old_logo/DataFUSION-Logo-Dark.svg
similarity index 100%
rename from docs/logos/DataFUSION-Logo-Dark.svg
rename to docs/logos/old_logo/DataFUSION-Logo-Dark.svg
diff --git a/docs/logos/DataFUSION-Logo-Dark@2x.png b/docs/logos/old_logo/DataFUSION-Logo-Dark@2x.png
similarity index 100%
rename from docs/logos/DataFUSION-Logo-Dark@2x.png
rename to docs/logos/old_logo/DataFUSION-Logo-Dark@2x.png
diff --git a/docs/logos/DataFUSION-Logo-Dark@4x.png b/docs/logos/old_logo/DataFUSION-Logo-Dark@4x.png
similarity index 100%
rename from docs/logos/DataFUSION-Logo-Dark@4x.png
rename to docs/logos/old_logo/DataFUSION-Logo-Dark@4x.png
diff --git a/docs/logos/DataFUSION-Logo-Light.svg b/docs/logos/old_logo/DataFUSION-Logo-Light.svg
similarity index 100%
rename from docs/logos/DataFUSION-Logo-Light.svg
rename to docs/logos/old_logo/DataFUSION-Logo-Light.svg
diff --git a/docs/logos/DataFUSION-Logo-Light@2x.png b/docs/logos/old_logo/DataFUSION-Logo-Light@2x.png
similarity index 100%
rename from docs/logos/DataFUSION-Logo-Light@2x.png
rename to docs/logos/old_logo/DataFUSION-Logo-Light@2x.png
diff --git a/docs/logos/DataFUSION-Logo-Light@4x.png b/docs/logos/old_logo/DataFUSION-Logo-Light@4x.png
similarity index 100%
rename from docs/logos/DataFUSION-Logo-Light@4x.png
rename to docs/logos/old_logo/DataFUSION-Logo-Light@4x.png
diff --git a/docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf b/docs/logos/old_logo/DataFusion-LogoAndColorPaletteExploration_v01.pdf
similarity index 100%
rename from docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf
rename to docs/logos/old_logo/DataFusion-LogoAndColorPaletteExploration_v01.pdf
diff --git a/docs/logos/primary_mark/black.png b/docs/logos/primary_mark/black.png
new file mode 100644
index 000000000000..053a798720d8
Binary files /dev/null and b/docs/logos/primary_mark/black.png differ
diff --git a/docs/logos/primary_mark/black.svg b/docs/logos/primary_mark/black.svg
new file mode 100644
index 000000000000..0b0a890e1eec
--- /dev/null
+++ b/docs/logos/primary_mark/black.svg
@@ -0,0 +1,27 @@
[27 added lines of SVG markup not preserved in this copy of the diff]
diff --git a/docs/logos/primary_mark/black2x.png b/docs/logos/primary_mark/black2x.png
new file mode 100644
index 000000000000..18ce390da8b2
Binary files /dev/null and b/docs/logos/primary_mark/black2x.png differ
diff --git a/docs/logos/primary_mark/black4x.png b/docs/logos/primary_mark/black4x.png
new file mode 100644
index 000000000000..cfcbd9c8ed59
Binary files /dev/null and b/docs/logos/primary_mark/black4x.png differ
diff --git a/docs/logos/primary_mark/mixed.png b/docs/logos/primary_mark/mixed.png
new file mode 100644
index 000000000000..4a24495f879a
Binary files /dev/null and b/docs/logos/primary_mark/mixed.png differ
diff --git a/docs/logos/primary_mark/mixed.svg b/docs/logos/primary_mark/mixed.svg
new file mode 100644
index 000000000000..306450bbbf58
--- /dev/null
+++ b/docs/logos/primary_mark/mixed.svg
@@ -0,0 +1,31 @@
[31 added lines of SVG markup not preserved in this copy of the diff]
diff --git a/docs/logos/primary_mark/mixed2x.png b/docs/logos/primary_mark/mixed2x.png
new file mode 100644
index 000000000000..16e1f5687127
Binary files /dev/null and b/docs/logos/primary_mark/mixed2x.png differ
diff --git a/docs/logos/primary_mark/mixed4x.png b/docs/logos/primary_mark/mixed4x.png
new file mode 100644
index 000000000000..ada80821508f
Binary files /dev/null and b/docs/logos/primary_mark/mixed4x.png differ
diff --git a/docs/logos/primary_mark/original.png b/docs/logos/primary_mark/original.png
new file mode 100644
index 000000000000..687f946760b0
Binary files /dev/null and b/docs/logos/primary_mark/original.png differ
diff --git a/docs/logos/primary_mark/original.svg b/docs/logos/primary_mark/original.svg
new file mode 100644
index 000000000000..6ba0ece995a3
--- /dev/null
+++ b/docs/logos/primary_mark/original.svg
@@ -0,0 +1,31 @@
[31 added lines of SVG markup not preserved in this copy of the diff]
diff --git a/docs/logos/primary_mark/original2x.png b/docs/logos/primary_mark/original2x.png
new file mode 100644
index 000000000000..a7402109b211
Binary files /dev/null and b/docs/logos/primary_mark/original2x.png differ
diff --git a/docs/logos/primary_mark/original4x.png b/docs/logos/primary_mark/original4x.png
new file mode 100644
index 000000000000..ae1000635cc6
Binary files /dev/null and b/docs/logos/primary_mark/original4x.png differ
diff --git a/docs/logos/primary_mark/white.png b/docs/logos/primary_mark/white.png
new file mode 100644
index 000000000000..cdb66f1f7c10
Binary files /dev/null and b/docs/logos/primary_mark/white.png differ
diff --git a/docs/logos/primary_mark/white.svg b/docs/logos/primary_mark/white.svg
new file mode 100644
index 000000000000..6f900590ce39
--- /dev/null
+++ b/docs/logos/primary_mark/white.svg
@@ -0,0 +1,27 @@
[27 added lines of SVG markup not preserved in this copy of the diff]
diff --git a/docs/logos/primary_mark/white2x.png b/docs/logos/primary_mark/white2x.png
new file mode 100644
index 000000000000..d54606e667e4
Binary files /dev/null and b/docs/logos/primary_mark/white2x.png differ
diff --git a/docs/logos/primary_mark/white4x.png b/docs/logos/primary_mark/white4x.png
new file mode 100644
index 000000000000..bc867fb1b92b
Binary files /dev/null and b/docs/logos/primary_mark/white4x.png differ
diff --git a/docs/logos/standalone_logo/logo_black.png b/docs/logos/standalone_logo/logo_black.png
new file mode 100644
index 000000000000..46cfd58e0d61
Binary files /dev/null and b/docs/logos/standalone_logo/logo_black.png differ
diff --git a/docs/logos/standalone_logo/logo_black.svg b/docs/logos/standalone_logo/logo_black.svg
new file mode 100644
index 000000000000..f82a47e1cf6d
--- /dev/null
+++ b/docs/logos/standalone_logo/logo_black.svg
@@ -0,0 +1,4 @@
[4 added lines of SVG markup not preserved in this copy of the diff]
diff --git a/docs/logos/standalone_logo/logo_black2x.png b/docs/logos/standalone_logo/logo_black2x.png
new file mode 100644
index 000000000000..34731a637736
Binary files /dev/null and b/docs/logos/standalone_logo/logo_black2x.png differ
diff --git a/docs/logos/standalone_logo/logo_black4x.png b/docs/logos/standalone_logo/logo_black4x.png
new file mode 100644
index 000000000000..6a6ee3c06fad
Binary files /dev/null and b/docs/logos/standalone_logo/logo_black4x.png differ
diff --git a/docs/logos/standalone_logo/logo_original.png b/docs/logos/standalone_logo/logo_original.png
new file mode 100644
index 000000000000..381265e62d7b
Binary files /dev/null and b/docs/logos/standalone_logo/logo_original.png differ
diff --git a/docs/logos/standalone_logo/logo_original.svg b/docs/logos/standalone_logo/logo_original.svg
new file mode 100644
index 000000000000..bf174719bcf2
--- /dev/null
+++ b/docs/logos/standalone_logo/logo_original.svg
@@ -0,0 +1,10 @@
[10 added lines of SVG markup not preserved in this copy of the diff]
diff --git a/docs/logos/standalone_logo/logo_original2x.png b/docs/logos/standalone_logo/logo_original2x.png
new file mode 100644
index 000000000000..7d5b25bd2e8b
Binary files /dev/null and b/docs/logos/standalone_logo/logo_original2x.png differ
diff --git a/docs/logos/standalone_logo/logo_original4x.png b/docs/logos/standalone_logo/logo_original4x.png
new file mode 100644
index 000000000000..10dd50978e37
Binary files /dev/null and b/docs/logos/standalone_logo/logo_original4x.png differ
diff --git a/docs/logos/standalone_logo/logo_white.png b/docs/logos/standalone_logo/logo_white.png
new file mode 100644
index 000000000000..a48ef890d6f4
Binary files /dev/null and b/docs/logos/standalone_logo/logo_white.png differ
diff --git a/docs/logos/standalone_logo/logo_white.svg b/docs/logos/standalone_logo/logo_white.svg
new file mode 100644
index 000000000000..9f1954ed82e5
--- /dev/null
+++ b/docs/logos/standalone_logo/logo_white.svg
@@ -0,0 +1,4 @@
[4 added lines of SVG markup not preserved in this copy of the diff]
diff --git a/docs/logos/standalone_logo/logo_white2x.png b/docs/logos/standalone_logo/logo_white2x.png
new file mode 100644
index 000000000000..c26de0fe5a5c
Binary files /dev/null and b/docs/logos/standalone_logo/logo_white2x.png differ
diff --git a/docs/logos/standalone_logo/logo_white4x.png b/docs/logos/standalone_logo/logo_white4x.png
new file mode 100644
index 000000000000..22bbb4892ed7
Binary files /dev/null and b/docs/logos/standalone_logo/logo_white4x.png differ
diff --git a/docs/source/_static/images/2x_bgwhite_original.png b/docs/source/_static/images/2x_bgwhite_original.png
new file mode 100644
index 000000000000..abb5fca6e461
Binary files /dev/null and b/docs/source/_static/images/2x_bgwhite_original.png differ
diff --git a/docs/source/_static/images/DataFusion-Logo-Background-White.png b/docs/source/_static/images/old_logo/DataFusion-Logo-Background-White.png
similarity index 100%
rename from docs/source/_static/images/DataFusion-Logo-Background-White.png
rename to docs/source/_static/images/old_logo/DataFusion-Logo-Background-White.png
diff --git a/docs/source/_static/images/DataFusion-Logo-Background-White.svg b/docs/source/_static/images/old_logo/DataFusion-Logo-Background-White.svg
similarity index 100%
rename from docs/source/_static/images/DataFusion-Logo-Background-White.svg
rename to docs/source/_static/images/old_logo/DataFusion-Logo-Background-White.svg
diff --git a/docs/source/_static/images/DataFusion-Logo-Dark.png b/docs/source/_static/images/old_logo/DataFusion-Logo-Dark.png
similarity index 100%
rename from docs/source/_static/images/DataFusion-Logo-Dark.png
rename to docs/source/_static/images/old_logo/DataFusion-Logo-Dark.png
diff --git a/docs/source/_static/images/DataFusion-Logo-Dark.svg b/docs/source/_static/images/old_logo/DataFusion-Logo-Dark.svg
similarity index 100%
rename from docs/source/_static/images/DataFusion-Logo-Dark.svg
rename to docs/source/_static/images/old_logo/DataFusion-Logo-Dark.svg
diff --git a/docs/source/_static/images/DataFusion-Logo-Light.png b/docs/source/_static/images/old_logo/DataFusion-Logo-Light.png
similarity index 100%
rename from docs/source/_static/images/DataFusion-Logo-Light.png
rename to docs/source/_static/images/old_logo/DataFusion-Logo-Light.png
diff --git a/docs/source/_static/images/DataFusion-Logo-Light.svg b/docs/source/_static/images/old_logo/DataFusion-Logo-Light.svg
similarity index 100%
rename from docs/source/_static/images/DataFusion-Logo-Light.svg
rename to docs/source/_static/images/old_logo/DataFusion-Logo-Light.svg
diff --git a/docs/source/_static/images/original.png b/docs/source/_static/images/original.png
new file mode 100644
index 000000000000..687f946760b0
Binary files /dev/null and b/docs/source/_static/images/original.png differ
diff --git a/docs/source/_static/images/original.svg b/docs/source/_static/images/original.svg
new file mode 100644
index 000000000000..6ba0ece995a3
--- /dev/null
+++ b/docs/source/_static/images/original.svg
@@ -0,0 +1,31 @@
[31 added lines of SVG markup not preserved in this copy of the diff]
diff --git a/docs/source/_static/images/original2x.png b/docs/source/_static/images/original2x.png
new file mode 100644
index 000000000000..a7402109b211
Binary files /dev/null and b/docs/source/_static/images/original2x.png differ
diff --git a/docs/source/_templates/docs-sidebar.html b/docs/source/_templates/docs-sidebar.html
index 2b400b4dcade..7c3ecc3d802e 100644
--- a/docs/source/_templates/docs-sidebar.html
+++ b/docs/source/_templates/docs-sidebar.html
@@ -15,7 +15,7 @@
[-/+ lines swapping the sidebar logo image; the HTML markup is not preserved in this copy of the diff]
diff --git a/docs/source/conf.py b/docs/source/conf.py
index becece330d1a..a203bfbb10d5 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -100,7 +100,7 @@
 # so a file named "default.css" will overwrite the builtin "default.css".
 html_static_path = ['_static']
 
-html_logo = "_static/images/DataFusion-Logo-Background-White.png"
+html_logo = "_static/images/2x_bgwhite_original.png"
 
 html_css_files = [
     "theme_overrides.css"
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 385371661716..f7c0873f3a5f 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -15,7 +15,7 @@
 .. specific language governing permissions and limitations
 .. under the License.
 
-.. image:: _static/images/DataFusion-Logo-Background-White.png
+.. image:: _static/images/2x_bgwhite_original.png
   :alt: DataFusion Logo
 
=======================
diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md
index b7e9248a7c1f..a839420aa5b2 100644
--- a/docs/source/library-user-guide/working-with-exprs.md
+++ b/docs/source/library-user-guide/working-with-exprs.md
@@ -92,7 +92,7 @@ In our example, we'll use rewriting to update our `add_one` UDF, to be rewritten
 
 ### Rewriting with `transform`
 
-To implement the inlining, we'll need to write a function that takes an `Expr` and returns a `Result<Expr>`. If the expression is _not_ to be rewritten `Transformed::No` is used to wrap the original `Expr`. If the expression _is_ to be rewritten, `Transformed::Yes` is used to wrap the new `Expr`.
+To implement the inlining, we'll need to write a function that takes an `Expr` and returns a `Result<Expr>`. If the expression is _not_ to be rewritten `Transformed::no` is used to wrap the original `Expr`. If the expression _is_ to be rewritten, `Transformed::yes` is used to wrap the new `Expr`.
 
 ```rust
 fn rewrite_add_one(expr: Expr) -> Result<Expr> {
@@ -102,9 +102,9 @@ fn rewrite_add_one(expr: Expr) -> Result<Expr> {
             let input_arg = scalar_fun.args[0].clone();
             let new_expression = input_arg + lit(1i64);
 
-            Transformed::Yes(new_expression)
+            Transformed::yes(new_expression)
         }
-        _ => Transformed::No(expr),
+        _ => Transformed::no(expr),
     })
 })
}
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index cd1fbdabea1c..b0385b492365 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -832,12 +832,7 @@ initcap(str)
 
 ### `instr`
 
-Returns the location where substr first appeared in str (counting from 1).
-If substr does not appear in str, return 0.
-
-```
-instr(str, substr)
-```
+_Alias of [strpos](#strpos)._
 
 #### Arguments
 
@@ -1108,6 +1103,10 @@ strpos(str, substr)
 - **substr**: Substring expression to search for. Can be a constant, column, or
   function, and any combination of string operators.
 
+#### Aliases
+
+- instr
+
 ### `substr`
 
 Extracts a substring of a specified number of characters from a specific
@@ -1440,7 +1439,8 @@ Additional examples can be found [here](https://github.com/apache/arrow-datafusi
 
 ### `position`
 
-Returns the position of substr in orig_str
+Returns the position of `substr` in `origstr` (counting from 1). If `substr` does
+not appear in `origstr`, return 0.
 
 ```
 position(substr in origstr)
@@ -1448,7 +1448,7 @@ position(substr in origstr)
 
 #### Arguments
 
-- **substr**: The pattern string.
+- **substr**: The pattern string.
 - **origstr**: The model string.
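As a quick illustration of the substring-search semantics documented above (an editorial sketch, not part of the patch; the return values follow the 1-based, 0-when-absent behavior described in the `strpos` and `position` entries):

```sql
-- 'fus' begins at character 5 of 'datafusion' (counting from 1)
SELECT strpos('datafusion', 'fus');      -- 5
SELECT instr('datafusion', 'fus');       -- 5, since instr is an alias of strpos
SELECT position('fus' in 'datafusion');  -- 5, the equivalent SQL-standard syntax
-- a substring that does not occur returns 0
SELECT position('xyz' in 'datafusion');  -- 0
```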
 ## Time and Date Functions
 
@@ -1949,8 +1949,13 @@ from_unixtime(expression)
 - [array_concat](#array_concat)
 - [array_contains](#array_contains)
 - [array_dims](#array_dims)
+- [array_has](#array_has)
+- [array_has_all](#array_has_all)
+- [array_has_any](#array_has_any)
 - [array_element](#array_element)
+- [array_except](#array_except)
 - [array_extract](#array_extract)
+- [array_fill](#array_fill)
 - [array_indexof](#array_indexof)
 - [array_join](#array_join)
 - [array_length](#array_length)
@@ -1972,6 +1977,7 @@ from_unixtime(expression)
 - [array_reverse](#array_reverse)
 - [array_slice](#array_slice)
 - [array_to_string](#array_to_string)
+- [array_union](#array_union)
 - [cardinality](#cardinality)
 - [empty](#empty)
 - [flatten](#flatten)
@@ -3144,12 +3150,39 @@ trim_array(array, n)
 
 ### `range`
 
-Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]`
+Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` or
+`SELECT range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH);`
 
 The range start..end contains all values with start <= x < end. It is empty if start >= end.
 
 Step can not be 0 (then the range will be nonsense.).
 
+Note that for numeric ranges, `range` accepts (stop), (start, stop), and (start, stop, step)
+as argument forms, but for date ranges all three arguments must be provided and non-NULL.
+For example,
+
+```
+SELECT range(3);
+SELECT range(1,5);
+SELECT range(1,5,1);
+```
+
+are all allowed for numeric ranges,
+
+but for date ranges, only
+
+```
+SELECT range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH);
+```
+
+is allowed, and
+
+```
+SELECT range(DATE '1992-09-01', DATE '1993-03-01', NULL);
+SELECT range(NULL, DATE '1993-03-01', INTERVAL '1' MONTH);
+SELECT range(DATE '1992-09-01', NULL, INTERVAL '1' MONTH);
+```
+
+are not allowed.
+
 #### Arguments
 
 - **start**: start of the range