From 2da06d52c8af8d805b10220c294f730be2f0f625 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 28 Dec 2023 15:35:36 -0800 Subject: [PATCH] Refactor TreeNode and cleanup some implementations --- datafusion/common/src/tree_node.rs | 26 ++++-- .../enforce_distribution.rs | 31 ++----- .../src/physical_optimizer/enforce_sorting.rs | 21 ++--- .../physical_optimizer/pipeline_checker.rs | 17 +--- .../replace_with_order_preserving_variants.rs | 16 +--- .../src/physical_optimizer/sort_pushdown.rs | 18 +--- datafusion/expr/src/tree_node/expr.rs | 89 +++++++++++-------- datafusion/expr/src/tree_node/plan.rs | 19 +--- .../physical-expr/src/sort_properties.rs | 18 +--- datafusion/physical-expr/src/utils/mod.rs | 17 +--- 10 files changed, 104 insertions(+), 168 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 5da9636ffe18..bc4dc03dabc5 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -33,6 +33,9 @@ use crate::Result; /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html pub trait TreeNode: Sized { + /// Returns all children of the TreeNode + fn children_nodes(&self) -> Vec<&Self>; + /// Use preorder to iterate the node on the tree so that we can /// stop fast for some cases. /// @@ -211,7 +214,17 @@ pub trait TreeNode: Sized { /// Apply the closure `F` to the node's children fn apply_children(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result; + F: FnMut(&Self) -> Result, + { + for child in self.children_nodes() { + match op(child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result @@ -342,18 +355,21 @@ pub trait DynTreeNode { /// Blanket implementation for Arc for any tye that implements /// [`DynTreeNode`] (such as [`Arc`]) impl TreeNode for Arc { + fn children_nodes(&self) -> Vec<&Arc> { + unimplemented!("Call arc_children instead") + } + fn apply_children(&self, op: &mut F) -> Result where F: FnMut(&Self) -> Result, { - for child in self.arc_children() { - match op(&child)? { + for child in &self.arc_children() { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), } } - Ok(VisitRecursion::Continue) } @@ -368,7 +384,7 @@ impl TreeNode for Arc { let arc_self = Arc::clone(&self); self.with_new_arc_children(arc_self, new_children?) } else { - Ok(self) + Ok(self.clone()) } } } diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index d5a086227323..f3af6d2c0d34 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -47,7 +47,7 @@ use crate::physical_plan::{ }; use arrow::compute::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; @@ -1409,18 +1409,8 @@ impl DistributionContext { } impl TreeNode for DistributionContext { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec<&Self> { + self.children_nodes.iter().collect() } fn map_children(mut self, transform: F) -> Result @@ -1483,19 +1473,8 @@ impl PlanWithKeyRequirements { } impl TreeNode for PlanWithKeyRequirements { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec<&Self> { + self.children.iter().collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 77d04a61c59e..27bd71d41393 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -145,11 +145,15 @@ impl PlanWithCorrespondingSort { } impl TreeNode for PlanWithCorrespondingSort { + fn children_nodes(&self) -> Vec<&Self> { + self.children_nodes.iter().collect() + } + fn apply_children(&self, op: &mut F) -> Result where F: FnMut(&Self) -> Result, { - for child in &self.children_nodes { + for child in self.children_nodes() { match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), @@ -237,19 +241,8 @@ impl PlanWithCorrespondingCoalescePartitions { } impl TreeNode for PlanWithCorrespondingCoalescePartitions { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec<&Self> { + self.children_nodes.iter().collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 9e9f647d073f..cd2f7c716a2e 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -27,7 +27,7 @@ use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::OptimizerOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; @@ -91,19 +91,8 @@ impl PipelineStatePropagator { } impl TreeNode for PipelineStatePropagator { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec<&Self> { + self.children.iter().collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 91f3d2abc6ff..f1fb01fc0aee 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -29,7 +29,7 @@ use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_physical_plan::unbounded_output; /// For a given `plan`, this object carries the information one needs from its @@ -104,18 +104,8 @@ impl OrderPreservationContext { } impl TreeNode for OrderPreservationContext { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec<&Self> { + self.children_nodes.iter().collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index b0013863010a..7b33a526d364 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -28,7 +28,7 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; @@ -71,20 +71,10 @@ impl SortPushDown { } impl TreeNode for SortPushDown { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec<&Self> { + self.children_nodes.iter().collect() } + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 1098842716b9..939e7b120b27 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -28,11 +28,8 @@ use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = match self { + fn children_nodes(&self) -> Vec<&Self> { + match self { Expr::Alias(Alias{expr,..}) | Expr::Not(expr) | Expr::IsNotNull(expr) @@ -47,15 +44,15 @@ impl TreeNode for Expr { | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref().clone()], + | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()], Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let expr = expr.as_ref().clone(); + let expr = expr.as_ref(); match field { GetFieldAccess::ListIndex {key} => { - vec![key.as_ref().clone(), expr] + vec![key.as_ref(), expr] }, GetFieldAccess::ListRange {start, stop} => { - vec![start.as_ref().clone(), stop.as_ref().clone(), expr] + vec![start.as_ref(), stop.as_ref(), expr] } GetFieldAccess::NamedStructField {name: _name} => { vec![expr] @@ -63,12 +60,12 @@ impl TreeNode for Expr { } } Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(), Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { - args.clone() + args.iter().collect() } Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.clone().into_iter().flatten().collect() + lists_of_exprs.iter().map(|a| a.iter().collect::>()).flatten().collect() } Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression @@ -80,66 +77,80 @@ impl TreeNode for Expr { | Expr::Wildcard {..} | Expr::Placeholder (_) => vec![], Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - vec![left.as_ref().clone(), right.as_ref().clone()] + vec![left.as_ref(), right.as_ref()] } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - vec![expr.as_ref().clone(), pattern.as_ref().clone()] + vec![expr.as_ref(), pattern.as_ref()] } Expr::Between(Between { - expr, low, high, .. - }) => vec![ - expr.as_ref().clone(), - low.as_ref().clone(), - high.as_ref().clone(), + expr, low, high, .. + }) => vec![ + expr.as_ref(), + low.as_ref(), + high.as_ref(), ], Expr::Case(case) => { let mut expr_vec = vec![]; if let Some(expr) = case.expr.as_ref() { - expr_vec.push(expr.as_ref().clone()); + expr_vec.push(expr.as_ref()); }; for (when, then) in case.when_then_expr.iter() { - expr_vec.push(when.as_ref().clone()); - expr_vec.push(then.as_ref().clone()); + expr_vec.push(when.as_ref()); + expr_vec.push(then.as_ref()); } if let Some(else_expr) = case.else_expr.as_ref() { - expr_vec.push(else_expr.as_ref().clone()); + expr_vec.push(else_expr.as_ref()); } expr_vec } Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - => { - let mut expr_vec = args.clone(); + => { + let mut expr_vec: Vec<&Expr> = args.iter().collect(); if let Some(f) = filter { - expr_vec.push(f.as_ref().clone()); + expr_vec.push(f.as_ref()); } if let Some(o) = order_by { - expr_vec.extend(o.clone()); + for x in o { + expr_vec.push(x); + } } expr_vec } Expr::WindowFunction(WindowFunction { - args, - partition_by, - order_by, - .. - }) => { - let mut expr_vec = args.clone(); - expr_vec.extend(partition_by.clone()); - expr_vec.extend(order_by.clone()); + args, + partition_by, + order_by, + .. + }) => { + let mut expr_vec: Vec<&Expr> = args.iter().collect(); + for x in partition_by { + expr_vec.push(x); + } + for x in order_by { + expr_vec.push(x); + } expr_vec } Expr::InList(InList { expr, list, .. }) => { let mut expr_vec = vec![]; - expr_vec.push(expr.as_ref().clone()); - expr_vec.extend(list.clone()); + expr_vec.push(expr.as_ref()); + for x in list { + expr_vec.push(x); + } expr_vec } - }; + } + } - for child in children.iter() { + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + let children = self.children_nodes(); + for child in children.into_iter() { match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index c7621bc17833..8cd2ac39b252 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -22,6 +22,10 @@ use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; use datafusion_common::{tree_node::TreeNode, Result}; impl TreeNode for LogicalPlan { + fn children_nodes(&self) -> Vec<&Self> { + self.inputs() + } + fn apply(&self, op: &mut F) -> Result where F: FnMut(&Self) -> Result, @@ -91,21 +95,6 @@ impl TreeNode for LogicalPlan { visitor.post_visit(self) } - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.inputs() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) - } - fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index 91238e5b04b4..259b4bf89a46 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -20,7 +20,7 @@ use std::{ops::Neg, sync::Arc}; use arrow_schema::SortOptions; use crate::PhysicalExpr; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::TreeNode; use datafusion_common::Result; /// To propagate [`SortOptions`] across the [`PhysicalExpr`], it is insufficient @@ -158,7 +158,7 @@ impl ExprOrdering { /// Creates a new [`ExprOrdering`] with [`SortProperties::Unordered`] states /// for `expr` and its children. pub fn new(expr: Arc) -> Self { - let children = expr.children(); + let children = PhysicalExpr::children(expr.as_ref()); Self { expr, state: Default::default(), @@ -173,18 +173,8 @@ impl ExprOrdering { } impl TreeNode for ExprOrdering { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec<&Self> { + self.children.iter().collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 87ef36558b96..c2335ab979da 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -136,7 +136,7 @@ pub struct ExprTreeNode { impl ExprTreeNode { pub fn new(expr: Arc) -> Self { - let children = expr.children(); + let children = PhysicalExpr::children(expr.as_ref()); ExprTreeNode { expr, data: None, @@ -154,19 +154,8 @@ impl ExprTreeNode { } impl TreeNode for ExprTreeNode { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec<&Self> { + self.children().iter().collect() } fn map_children(mut self, transform: F) -> Result