chore: reduce duplicate column bindings for aggregate functions.
RinChanNOWWW committed Oct 8, 2023
1 parent c9977a7 commit 89ee55a
Showing 4 changed files with 322 additions and 150 deletions.
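
In short: when the same aggregate expression is bound more than once (for example, `avg(b)` appearing in both the select list and an `ORDER BY`), the binder previously created a fresh derived `ColumnBinding` each time, and the plan then needed an extra `EvalScalar` just to project the duplicate column. After this commit the binder first looks the aggregate up by its display name in `AggregateInfo` and reuses the binding registered when the function was first rewritten. Below is a minimal, runnable sketch of that memoization idea; `AggregateInfo` and `ColumnBinding` here are simplified stand-ins, not the real Databend types:

```rust
use std::collections::HashMap;

// Simplified stand-ins for the planner types touched by this commit.
#[derive(Clone, Debug, PartialEq)]
struct ColumnBinding {
    column_name: String,
    index: usize,
}

#[derive(Default)]
struct AggregateInfo {
    // Display name of the aggregate (e.g. "avg(b)") -> slot in `aggregate_functions`.
    aggregate_functions_map: HashMap<String, usize>,
    aggregate_functions: Vec<ColumnBinding>,
}

impl AggregateInfo {
    // Reuse the binding of an already-registered aggregate instead of
    // deriving a duplicate column for the same expression.
    fn replace_aggregate(&mut self, display_name: &str) -> ColumnBinding {
        if let Some(&i) = self.aggregate_functions_map.get(display_name) {
            return self.aggregate_functions[i].clone();
        }
        let binding = ColumnBinding {
            column_name: display_name.to_string(),
            index: self.aggregate_functions.len(),
        };
        self.aggregate_functions_map
            .insert(display_name.to_string(), binding.index);
        self.aggregate_functions.push(binding.clone());
        binding
    }
}

fn main() {
    let mut info = AggregateInfo::default();
    let first = info.replace_aggregate("avg(b)");
    let second = info.replace_aggregate("avg(b)"); // same expression, bound twice
    assert_eq!(first, second); // the binding is reused; no duplicate column
    println!("avg(b) -> column index {}", first.index);
}
```
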
30 changes: 30 additions & 0 deletions src/query/sql/src/planner/binder/aggregate.rs
@@ -319,6 +319,13 @@ impl<'a> AggregateRewriter<'a> {
/// add the replaced aggregate function and the arguments into `AggregateInfo`.
fn replace_aggregate_function(&mut self, aggregate: &AggregateFunction) -> Result<ScalarExpr> {
let agg_info = &mut self.bind_context.aggregate_info;

if let Some(column) =
find_replaced_aggregate_function(&agg_info, aggregate, &aggregate.display_name)
{
return Ok(BoundColumnRef { span: None, column }.into());
}

let mut replaced_args: Vec<ScalarExpr> = Vec::with_capacity(aggregate.args.len());

for (i, arg) in aggregate.args.iter().enumerate() {
@@ -922,3 +929,26 @@ impl Binder {
}
}
}

/// Find the [`ColumnBinding`] for an [`AggregateFunction`] that has already been replaced,
/// so the caller can rebind to that column instead of creating a duplicate.
pub fn find_replaced_aggregate_function(
agg_info: &AggregateInfo,
agg: &AggregateFunction,
new_name: &str,
) -> Option<ColumnBinding> {
agg_info
.aggregate_functions_map
.get(&agg.display_name)
.map(|i| {
// This expression is already replaced.
let scalar_item = &agg_info.aggregate_functions[*i];
debug_assert_eq!(scalar_item.scalar.data_type().unwrap(), *agg.return_type);
ColumnBindingBuilder::new(
new_name.to_string(),
scalar_item.index,
agg.return_type.clone(),
Visibility::Visible,
)
.build()
})
}
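
With this helper in place, `replace_aggregate_function` (see the hunk above) can return early with a `BoundColumnRef` to the existing column whenever it encounters an aggregate whose display name is already registered, so repeated occurrences of the same aggregate all resolve to a single binding.
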
26 changes: 18 additions & 8 deletions src/query/sql/src/planner/binder/project.rs
@@ -28,6 +28,7 @@ use common_exception::Result;
use common_exception::Span;

use super::AggregateInfo;
use crate::binder::aggregate::find_replaced_aggregate_function;
use crate::binder::select::SelectItem;
use crate::binder::select::SelectList;
use crate::binder::ExprContext;
@@ -61,15 +62,24 @@ impl Binder {
// This item is a grouping sets item, its data type should be nullable.
let is_grouping_sets_item = agg_info.grouping_sets.is_some()
&& agg_info.group_items_map.contains_key(&item.scalar);
let mut column_binding = if let ScalarExpr::BoundColumnRef(ref column_ref) = item.scalar
{
let mut column_binding = column_ref.column.clone();
// We should apply the alias to the ColumnBinding, since it comes from a table
column_binding.column_name = item.alias.clone();
column_binding
} else {
self.create_derived_column_binding(item.alias.clone(), item.scalar.data_type()?)

let mut column_binding = match &item.scalar {
ScalarExpr::BoundColumnRef(column_ref) => {
let mut column_binding = column_ref.column.clone();
// We should apply the alias to the ColumnBinding, since it comes from a table
column_binding.column_name = item.alias.clone();
column_binding
}
ScalarExpr::AggregateFunction(agg) => {
// Replace with a bound column to reduce duplicate derived column bindings.
debug_assert!(!is_grouping_sets_item);
find_replaced_aggregate_function(agg_info, agg, &item.alias).unwrap()
}
_ => {
self.create_derived_column_binding(item.alias.clone(), item.scalar.data_type()?)
}
};

if is_grouping_sets_item {
column_binding.data_type = Box::new(column_binding.data_type.wrap_nullable());
}
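
The select-list binding now dispatches on the kind of scalar: a bare column keeps its own binding (renamed to the item's alias), an aggregate reuses the binding registered during aggregate rewriting via `find_replaced_aggregate_function`, and only other expressions still get a fresh derived column. The `debug_assert!(!is_grouping_sets_item)` presumably holds because a grouping-sets item is a group-by expression, which cannot itself be an aggregate function. A minimal sketch of the dispatch, with simplified stand-in types and a hypothetical `lookup_aggregate` in place of the real `find_replaced_aggregate_function`:

```rust
// Simplified stand-ins for the select-list dispatch added in project.rs.
#[derive(Clone, Debug, PartialEq)]
struct ColumnBinding {
    column_name: String,
    index: usize,
}

enum ScalarExpr {
    BoundColumnRef(ColumnBinding),
    AggregateFunction { display_name: String },
    Other,
}

// Hypothetical lookup standing in for `find_replaced_aggregate_function`:
// pretend "avg(b)" was registered at column index 2 during aggregate rewriting.
fn lookup_aggregate(display_name: &str, alias: &str) -> Option<ColumnBinding> {
    (display_name == "avg(b)").then(|| ColumnBinding {
        column_name: alias.to_string(),
        index: 2,
    })
}

fn binding_for_item(scalar: &ScalarExpr, alias: &str, next_derived_index: usize) -> ColumnBinding {
    match scalar {
        // A plain column keeps its binding, renamed to the select alias.
        ScalarExpr::BoundColumnRef(column) => {
            let mut binding = column.clone();
            binding.column_name = alias.to_string();
            binding
        }
        // An aggregate reuses the binding registered during rewriting
        // instead of deriving a duplicate column.
        ScalarExpr::AggregateFunction { display_name } => {
            lookup_aggregate(display_name, alias).expect("aggregate was rewritten first")
        }
        // Anything else still gets a fresh derived column.
        ScalarExpr::Other => ColumnBinding {
            column_name: alias.to_string(),
            index: next_derived_index,
        },
    }
}

fn main() {
    let item = ScalarExpr::AggregateFunction { display_name: "avg(b)".to_string() };
    let binding = binding_for_item(&item, "avg_b", 5);
    assert_eq!(binding.index, 2); // reused, not newly derived
}
```

The `.expect(...)` mirrors the `.unwrap()` in the hunk above: by the time the select list is bound, every aggregate has already been rewritten, so the lookup should not fail.
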
260 changes: 202 additions & 58 deletions tests/sqllogictests/suites/mode/standalone/explain/aggregate.test
@@ -97,8 +97,8 @@ query T
explain select count(3), type, name, trim(name) as a from system.columns group by name, type, a, concat(name, trim(name)), concat(type, name), length(name);
----
EvalScalar
├── output columns: [columns.name (#0), columns.type (#3), count(3) (#14), a (#15)]
├── expressions: [count(3) (#13), trim_both(columns.name (#0), ' ')]
├── output columns: [count(3) (#13), columns.name (#0), columns.type (#3), a (#14)]
├── expressions: [trim_both(columns.name (#0), ' ')]
├── estimated rows: 0.00
└── AggregateFinal
├── output columns: [count(3) (#13), columns.name (#0), columns.type (#3)]
@@ -181,19 +181,46 @@ AggregateFinal
query T
explain select a, max(b) from explain_agg_t1 group by a having a > 1;
----
EvalScalar
├── output columns: [explain_agg_t1.a (#0), max(b) (#3)]
├── expressions: [max(b) (#2)]
AggregateFinal
├── output columns: [max(b) (#2), explain_agg_t1.a (#0)]
├── group by: [a]
├── aggregate functions: [max(b)]
├── estimated rows: 0.00
└── AggregateFinal
├── output columns: [max(b) (#2), explain_agg_t1.a (#0)]
└── AggregatePartial
├── output columns: [max(b) (#2), #_group_by_key]
├── group by: [a]
├── aggregate functions: [max(b)]
├── estimated rows: 0.00
└── Filter
├── output columns: [explain_agg_t1.a (#0), explain_agg_t1.b (#1)]
├── filters: [explain_agg_t1.a (#0) > 1]
├── estimated rows: 0.00
└── TableScan
├── table: default.default.explain_agg_t1
├── output columns: [a (#0), b (#1)]
├── read rows: 0
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [explain_agg_t1.a (#0) > 1], limit: NONE]
└── estimated rows: 0.00

query T
explain select a, avg(b) from explain_agg_t1 group by a having a > 1 and max(b) > 10;
----
Filter
├── output columns: [avg(b) (#2), explain_agg_t1.a (#0)]
├── filters: [is_true(max(b) (#3) > 10)]
├── estimated rows: 0.00
└── AggregateFinal
├── output columns: [avg(b) (#2), max(b) (#3), explain_agg_t1.a (#0)]
├── group by: [a]
├── aggregate functions: [avg(b), max(b)]
├── estimated rows: 0.00
└── AggregatePartial
├── output columns: [max(b) (#2), #_group_by_key]
├── output columns: [avg(b) (#2), max(b) (#3), #_group_by_key]
├── group by: [a]
├── aggregate functions: [max(b)]
├── aggregate functions: [avg(b), max(b)]
├── estimated rows: 0.00
└── Filter
├── output columns: [explain_agg_t1.a (#0), explain_agg_t1.b (#1)]
@@ -210,29 +237,150 @@ EvalScalar
└── estimated rows: 0.00

query T
explain select a, avg(b) from explain_agg_t1 group by a having a > 1 and max(b) > 10;
explain select avg(b) from explain_agg_t1 group by a order by avg(b);
----
EvalScalar
├── output columns: [explain_agg_t1.a (#0), avg(b) (#3)]
├── expressions: [avg(b) (#2)]
Sort
├── output columns: [avg(b) (#2)]
├── sort keys: [avg(b) ASC NULLS LAST]
├── estimated rows: 0.00
└── Filter
└── AggregateFinal
├── output columns: [avg(b) (#2), explain_agg_t1.a (#0)]
├── filters: [is_true(max(b) (#4) > 10)]
├── group by: [a]
├── aggregate functions: [avg(b)]
├── estimated rows: 0.00
└── AggregatePartial
├── output columns: [avg(b) (#2), #_group_by_key]
├── group by: [a]
├── aggregate functions: [avg(b)]
├── estimated rows: 0.00
└── TableScan
├── table: default.default.explain_agg_t1
├── output columns: [a (#0), b (#1)]
├── read rows: 0
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 0.00


query T
explain select avg(b) + 1 from explain_agg_t1 group by a order by avg(b);
----
EvalScalar
├── output columns: [(avg(b) + 1) (#3)]
├── expressions: [avg(b) (#2) + 1]
├── estimated rows: 0.00
└── Sort
├── output columns: [avg(b) (#2)]
├── sort keys: [avg(b) ASC NULLS LAST]
├── estimated rows: 0.00
└── AggregateFinal
├── output columns: [avg(b) (#2), max(b) (#4), explain_agg_t1.a (#0)]
├── output columns: [avg(b) (#2), explain_agg_t1.a (#0)]
├── group by: [a]
├── aggregate functions: [avg(b), max(b)]
├── aggregate functions: [avg(b)]
├── estimated rows: 0.00
└── AggregatePartial
├── output columns: [avg(b) (#2), #_group_by_key]
├── group by: [a]
├── aggregate functions: [avg(b)]
├── estimated rows: 0.00
└── TableScan
├── table: default.default.explain_agg_t1
├── output columns: [a (#0), b (#1)]
├── read rows: 0
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 0.00

query T
explain select avg(b), avg(b) + 1 from explain_agg_t1 group by a order by avg(b);
----
EvalScalar
├── output columns: [avg(b) (#2), (avg(b) + 1) (#3)]
├── expressions: [avg(b) (#2) + 1]
├── estimated rows: 0.00
└── Sort
├── output columns: [avg(b) (#2)]
├── sort keys: [avg(b) ASC NULLS LAST]
├── estimated rows: 0.00
└── AggregateFinal
├── output columns: [avg(b) (#2), explain_agg_t1.a (#0)]
├── group by: [a]
├── aggregate functions: [avg(b)]
├── estimated rows: 0.00
└── AggregatePartial
├── output columns: [avg(b) (#2), #_group_by_key]
├── group by: [a]
├── aggregate functions: [avg(b)]
├── estimated rows: 0.00
└── TableScan
├── table: default.default.explain_agg_t1
├── output columns: [a (#0), b (#1)]
├── read rows: 0
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 0.00

query T
explain select avg(b) + 1, avg(b) from explain_agg_t1 group by a order by avg(b);
----
EvalScalar
├── output columns: [avg(b) (#2), (avg(b) + 1) (#3)]
├── expressions: [avg(b) (#2) + 1]
├── estimated rows: 0.00
└── Sort
├── output columns: [avg(b) (#2)]
├── sort keys: [avg(b) ASC NULLS LAST]
├── estimated rows: 0.00
└── AggregateFinal
├── output columns: [avg(b) (#2), explain_agg_t1.a (#0)]
├── group by: [a]
├── aggregate functions: [avg(b)]
├── estimated rows: 0.00
└── AggregatePartial
├── output columns: [avg(b) (#2), max(b) (#4), #_group_by_key]
├── output columns: [avg(b) (#2), #_group_by_key]
├── group by: [a]
├── aggregate functions: [avg(b), max(b)]
├── aggregate functions: [avg(b)]
├── estimated rows: 0.00
└── Filter
├── output columns: [explain_agg_t1.a (#0), explain_agg_t1.b (#1)]
├── filters: [explain_agg_t1.a (#0) > 1]
└── TableScan
├── table: default.default.explain_agg_t1
├── output columns: [a (#0), b (#1)]
├── read rows: 0
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 0.00

query T
explain select avg(b), avg(b) + 1 from explain_agg_t1 group by a order by avg(b) + 1;
----
EvalScalar
├── output columns: [avg(b) (#2), (avg(b) + 1) (#3)]
├── expressions: [avg(b) (#2) + 1]
├── estimated rows: 0.00
└── Sort
├── output columns: [avg(b) (#2), (avg(b) + 1) (#4)]
├── sort keys: [(avg(b) + 1) ASC NULLS LAST]
├── estimated rows: 0.00
└── EvalScalar
├── output columns: [avg(b) (#2), (avg(b) + 1) (#4)]
├── expressions: [avg(b) (#2) + 1]
├── estimated rows: 0.00
└── AggregateFinal
├── output columns: [avg(b) (#2), explain_agg_t1.a (#0)]
├── group by: [a]
├── aggregate functions: [avg(b)]
├── estimated rows: 0.00
└── AggregatePartial
├── output columns: [avg(b) (#2), #_group_by_key]
├── group by: [a]
├── aggregate functions: [avg(b)]
├── estimated rows: 0.00
└── TableScan
├── table: default.default.explain_agg_t1
@@ -241,7 +389,7 @@ EvalScalar
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [explain_agg_t1.a (#0) > 1], limit: NONE]
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 0.00

statement ok
Expand All @@ -262,46 +410,42 @@ create table t2 as select number as a from numbers(100)
query T
explain select count() from t1, t2 where t1.a > t2.a;
----
EvalScalar
├── output columns: [count() (#3)]
├── expressions: [count() (#2)]
AggregateFinal
├── output columns: [count() (#2)]
├── group by: []
├── aggregate functions: [count()]
├── estimated rows: 1.00
└── AggregateFinal
└── AggregatePartial
├── output columns: [count() (#2)]
├── group by: []
├── aggregate functions: [count()]
├── estimated rows: 1.00
└── AggregatePartial
├── output columns: [count() (#2)]
├── group by: []
├── aggregate functions: [count()]
├── estimated rows: 1.00
└── MergeJoin
├── output columns: [t1.a (#0), t2.a (#1)]
├── join type: INNER
├── range join conditions: [t1.a (#0) "gt" t2.a (#1)]
├── other conditions: []
├── estimated rows: 1000.00
├── TableScan(Left)
│ ├── table: default.default.t1
│ ├── output columns: [a (#0)]
│ ├── read rows: 10
│ ├── read bytes: 65
│ ├── partitions total: 1
│ ├── partitions scanned: 1
│ ├── pruning stats: [segments: <range pruning: 1 to 1>, blocks: <range pruning: 1 to 1, bloom pruning: 0 to 0>]
│ ├── push downs: [filters: [], limit: NONE]
│ └── estimated rows: 10.00
└── TableScan(Right)
├── table: default.default.t2
├── output columns: [a (#1)]
├── read rows: 100
├── read bytes: 172
├── partitions total: 1
├── partitions scanned: 1
├── pruning stats: [segments: <range pruning: 1 to 1>, blocks: <range pruning: 1 to 1, bloom pruning: 0 to 0>]
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 100.00
└── MergeJoin
├── output columns: [t1.a (#0), t2.a (#1)]
├── join type: INNER
├── range join conditions: [t1.a (#0) "gt" t2.a (#1)]
├── other conditions: []
├── estimated rows: 1000.00
├── TableScan(Left)
│ ├── table: default.default.t1
│ ├── output columns: [a (#0)]
│ ├── read rows: 10
│ ├── read bytes: 65
│ ├── partitions total: 1
│ ├── partitions scanned: 1
│ ├── pruning stats: [segments: <range pruning: 1 to 1>, blocks: <range pruning: 1 to 1, bloom pruning: 0 to 0>]
│ ├── push downs: [filters: [], limit: NONE]
│ └── estimated rows: 10.00
└── TableScan(Right)
├── table: default.default.t2
├── output columns: [a (#1)]
├── read rows: 100
├── read bytes: 172
├── partitions total: 1
├── partitions scanned: 1
├── pruning stats: [segments: <range pruning: 1 to 1>, blocks: <range pruning: 1 to 1, bloom pruning: 0 to 0>]
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 100.00

statement ok
drop table t1;
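
Taken together, the updated `EXPLAIN` expectations show the effect of the change: plans that previously wrapped the aggregate in an extra `EvalScalar` only to copy, say, `max(b) (#2)` into a new column `max(b) (#3)` now expose the `AggregateFinal` output directly, and queries that mention the same aggregate several times (such as `select avg(b), avg(b) + 1 from explain_agg_t1 group by a order by avg(b)`) share one `avg(b) (#2)` binding instead of deriving a new column per occurrence.
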