Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(query): prune unused flatten result columns #13935

Merged
merged 15 commits into from
Dec 8, 2023
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 1 addition & 4 deletions src/query/functions/src/srfs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use common_expression::FunctionRegistry;

mod array;
mod variant;

pub use variant::FlattenGenerator;
pub use variant::FlattenMode;
use common_expression::FunctionRegistry;

pub fn register(registry: &mut FunctionRegistry) {
array::register(registry);
Expand Down
341 changes: 252 additions & 89 deletions src/query/functions/src/srfs/variant.rs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,14 @@ impl BlockingTransform for TransformSRF {
}
}
} else {
let data_type = srf_expr.data_type();
let inner_tys = data_type.as_tuple().unwrap();
let inner_vals = vec![ScalarRef::Null; inner_tys.len()];
row_result = Value::Column(
ColumnBuilder::repeat(
&ScalarRef::Tuple(vec![ScalarRef::Null]),
&ScalarRef::Tuple(inner_vals),
self.num_rows[i],
srf_expr.data_type(),
data_type,
)
.build(),
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use common_exception::Result;
use common_expression::ConstantFolder;
use common_expression::DataField;
Expand All @@ -25,6 +27,10 @@ use crate::executor::physical_plan::PhysicalPlan;
use crate::executor::physical_plan_builder::PhysicalPlanBuilder;
use crate::optimizer::ColumnSet;
use crate::optimizer::SExpr;
use crate::plans::FunctionCall;
use crate::plans::ProjectSet;
use crate::plans::RelOperator;
use crate::plans::ScalarExpr;
use crate::IndexType;
use crate::TypeCheck;

Expand Down Expand Up @@ -91,7 +97,17 @@ impl PhysicalPlanBuilder {
if used.is_empty() {
self.build(s_expr.child(0)?, required).await
} else {
let input = self.build(s_expr.child(0)?, required).await?;
let child = s_expr.child(0)?;
let input = if let RelOperator::ProjectSet(project_set) = child.plan() {
let new_project_set =
self.prune_flatten_columns(eval_scalar, project_set, &required);
let mut new_child = child.clone();
new_child.plan = Arc::new(new_project_set.into());
self.build(&new_child, required).await?
} else {
self.build(child, required).await?
};

let eval_scalar = crate::plans::EvalScalar { items: used };
self.create_eval_scalar(&eval_scalar, column_projections, input, stat_info)
}
Expand Down Expand Up @@ -149,4 +165,46 @@ impl PhysicalPlanBuilder {
stat_info: Some(stat_info),
}))
}

// The flatten function returns a tuple, which contains 6 columns.
// Only keep columns required by parent plan, other columns can be pruned
// to reduce the memory usage.
fn prune_flatten_columns(
&mut self,
eval_scalar: &crate::plans::EvalScalar,
project_set: &ProjectSet,
required: &ColumnSet,
) -> ProjectSet {
let mut project_set = project_set.clone();
for srf_item in &mut project_set.srfs {
if let ScalarExpr::FunctionCall(srf_func) = &srf_item.scalar {
if srf_func.func_name == "flatten" {
// Store the columns required by the parent plan in params.
let mut params = Vec::new();
for item in &eval_scalar.items {
if !required.contains(&item.index) {
continue;
}
if let ScalarExpr::FunctionCall(func) = &item.scalar {
if func.func_name == "get" && !func.arguments.is_empty() {
if let ScalarExpr::BoundColumnRef(column_ref) = &func.arguments[0] {
if column_ref.column.index == srf_item.index {
params.push(func.params[0]);
}
}
}
}
}

srf_item.scalar = ScalarExpr::FunctionCall(FunctionCall {
span: srf_func.span,
func_name: srf_func.func_name.clone(),
params,
arguments: srf_func.arguments.clone(),
});
}
}
}
project_set
}
}
34 changes: 17 additions & 17 deletions src/query/sql/src/planner/binder/project_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ use common_ast::ast::Lambda;
use common_ast::ast::Literal;
use common_ast::ast::Window;
use common_ast::Visitor;
use common_exception::ErrorCode;
use common_exception::Result;
use common_exception::Span;
use common_expression::FunctionKind;
Expand Down Expand Up @@ -148,12 +147,6 @@ impl Binder {
let srf_expr = srf_scalar.as_expr()?;
let return_types = srf_expr.data_type().as_tuple().unwrap();

if return_types.len() > 1 {
return Err(ErrorCode::Unimplemented(
"set-returning functions with more than one return type are not supported yet",
));
}

// Add result column to metadata
let column_index = self
.metadata
Expand All @@ -173,20 +166,27 @@ impl Binder {
};
items.push(item);

// Flatten the tuple fields of the srfs to the top level columns
// TODO(andylokandy/leisky): support multiple return types
let flatten_result = ScalarExpr::FunctionCall(FunctionCall {
span: srf.span(),
func_name: "get".to_string(),
params: vec![1],
arguments: vec![ScalarExpr::BoundColumnRef(BoundColumnRef {
// If tuple has more than one field, return the tuple column,
// otherwise, extract the tuple field to top level column.
let result_column = if return_types.len() > 1 {
ScalarExpr::BoundColumnRef(BoundColumnRef {
span: srf.span(),
column,
})],
});
})
} else {
ScalarExpr::FunctionCall(FunctionCall {
span: srf.span(),
func_name: "get".to_string(),
params: vec![1],
arguments: vec![ScalarExpr::BoundColumnRef(BoundColumnRef {
span: srf.span(),
column,
})],
})
};

// Add the srf to bind context, so we can replace the srfs later.
bind_context.srfs.insert(srf.to_string(), flatten_result);
bind_context.srfs.insert(srf.to_string(), result_column);
}

let project_set = ProjectSet { srfs: items };
Expand Down
47 changes: 26 additions & 21 deletions src/query/sql/src/planner/binder/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ impl Binder {
plan.items.len()
)));
}
// Delete result tuple column
// Delete srf result tuple column, extract tuple inner columns instead
let _ = bind_context.columns.pop();
let scalar = &plan.items[0].scalar;

Expand Down Expand Up @@ -462,28 +462,34 @@ impl Binder {
.bind_project_set(&mut bind_context, &srfs, child)
.await?;

if let Some((_, flatten_scalar)) = bind_context.srfs.remove(&srf.to_string()) {
// Add result column to metadata
let data_type = flatten_scalar.data_type()?;
let index = self
.metadata
.write()
.add_derived_column(srf.to_string(), data_type.clone());
let column_binding = ColumnBindingBuilder::new(
srf.to_string(),
index,
Box::new(data_type),
Visibility::Visible,
)
.build();
bind_context.add_column_binding(column_binding);
if let Some((_, srf_result)) = bind_context.srfs.remove(&srf.to_string()) {
let column_binding =
if let ScalarExpr::BoundColumnRef(column_ref) = &srf_result {
column_ref.column.clone()
} else {
// Add result column to metadata
let data_type = srf_result.data_type()?;
let index = self
.metadata
.write()
.add_derived_column(srf.to_string(), data_type.clone());
ColumnBindingBuilder::new(
srf.to_string(),
index,
Box::new(data_type),
Visibility::Visible,
)
.build()
};

let eval_scalar = EvalScalar {
items: vec![ScalarItem {
scalar: flatten_scalar,
index,
scalar: srf_result,
index: column_binding.index,
}],
};
// Add srf result column
bind_context.add_column_binding(column_binding);

let flatten_expr =
SExpr::create_unary(Arc::new(eval_scalar.into()), Arc::new(srf_expr));
Expand All @@ -505,9 +511,8 @@ impl Binder {

return Ok((new_expr, bind_context));
} else {
return Err(
ErrorCode::Internal("srf flatten result is missing").set_span(*span)
);
return Err(ErrorCode::Internal("lateral join bind project_set failed")
.set_span(*span));
}
} else {
return Err(ErrorCode::InvalidArgument(format!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,53 @@ ProjectSet
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 10.00

statement ok
drop table if exists t;

statement ok
create table t(a int, b variant);

query T
EXPLAIN SELECT t.a, f.seq, f.value FROM t, LATERAL FLATTEN(input => t.b) f
----
EvalScalar
├── output columns: [t.a (#0), seq (#3), value (#7)]
├── expressions: [get(1)(flatten (#2)), get(5)(flatten (#2))]
├── estimated rows: 0.00
└── ProjectSet
├── output columns: [t.a (#0), flatten (#2)]
├── estimated rows: 0.00
├── set returning functions: flatten(1, 5)(t.b (#1))
└── TableScan
├── table: default.project_set.t
├── output columns: [a (#0), b (#1)]
├── read rows: 0
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 0.00

query T
EXPLAIN SELECT json_each(t.b), unnest(t.b) FROM t
----
EvalScalar
├── output columns: [json_each (#2), unnest(t.b) (#4)]
├── expressions: [get(1)(unnest (#3))]
├── estimated rows: 0.00
└── ProjectSet
├── output columns: [json_each (#2), unnest (#3)]
├── estimated rows: 0.00
├── set returning functions: json_each(t.b (#1)), unnest(t.b (#1))
└── TableScan
├── table: default.project_set.t
├── output columns: [b (#1)]
├── read rows: 0
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 0.00

statement ok
drop database project_set
14 changes: 14 additions & 0 deletions tests/sqllogictests/suites/query/lateral.test
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,19 @@ GROUP BY p.id ORDER BY p.id
12712555 2
98127771 2

query IT
SELECT u.user_id, f.value from
user_activities u,
LATERAL unnest(u.activities) f
----
1 "reading"
1 "swimming"
1 "cycling"
2 "painting"
2 "running"
3 "cooking"
3 "climbing"
3 "writing"

statement ok
drop database test_lateral
Loading