Skip to content

Commit

Permalink
Add modeling support for subqueries
Browse files Browse the repository at this point in the history
  • Loading branch information
asenac committed Jan 2, 2024
1 parent 8797efd commit 2bfd990
Show file tree
Hide file tree
Showing 17 changed files with 550 additions and 32 deletions.
6 changes: 6 additions & 0 deletions src/query_graph/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ impl<'a> Explainer<'a> {
let mut explain = ExplainVisitor::new(self);
self.query_graph
.visit_subgraph(&mut explain, self.entry_point);
let subquery_roots = self.query_graph.subquery_roots();
for subquery_root in subquery_roots {
explain.result += "\n";
self.query_graph.visit_subgraph(&mut explain, subquery_root);
}
explain.result
}
}
Expand Down Expand Up @@ -156,6 +161,7 @@ impl<'a> QueryGraphPrePostVisitor for ExplainVisitor<'a> {
.join(", "),
),
QueryNode::Union { .. } => format!("{}Union\n", prefix),
QueryNode::SubqueryRoot { .. } => format!("{}SubqueryRoot\n", prefix),
};
self.result += &node;

Expand Down
21 changes: 20 additions & 1 deletion src/query_graph/json.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
//! JSON serializer for generating visual representations of the plans.
use std::collections::VecDeque;

use crate::{
query_graph::{explain::explain_scalar_expr_vec, *},
scalar_expr::ScalarExpr,
Expand All @@ -11,6 +13,7 @@ pub struct JsonSerializer<'a> {
annotators: Vec<&'a dyn Fn(&QueryGraph, NodeId) -> Option<String>>,
included_nodes: HashSet<NodeId>,
graph: Graph,
queue: VecDeque<NodeId>,
}

impl<'a> JsonSerializer<'a> {
Expand All @@ -19,6 +22,7 @@ impl<'a> JsonSerializer<'a> {
annotators,
included_nodes: HashSet::new(),
graph: Graph::new(),
queue: VecDeque::new(),
}
}

Expand All @@ -28,7 +32,10 @@ impl<'a> JsonSerializer<'a> {

/// Ensure the given subgraph is included in the output graph.
///
/// The traversal is driven by `self.queue`: `node_id` is enqueued first and
/// queued roots are visited in FIFO order until the queue is empty. The
/// pre-order visitor of this serializer pushes every subquery root referenced
/// by a visited node onto the same queue, so those plans are serialized too.
///
/// The subgraph must be visited only through the queue; visiting `node_id`
/// directly as well would traverse the root subgraph twice.
pub fn add_subgraph(&mut self, query_graph: &QueryGraph, node_id: NodeId) {
    self.queue.push_back(node_id);
    // Each visit may enqueue further roots, so keep draining until empty.
    while let Some(next_root) = self.queue.pop_front() {
        query_graph.visit_subgraph(self, next_root);
    }
}

pub fn add_node_replacement(
Expand Down Expand Up @@ -100,6 +107,7 @@ impl<'a> QueryGraphPrePostVisitor for JsonSerializer<'a> {
.join(", "),
),
QueryNode::Union { .. } => format!("{}Union", prefix),
QueryNode::SubqueryRoot { .. } => format!("{}SubqueryRoot", prefix),
};
let mut annotations = Vec::new();
for annotator in self.annotators.iter() {
Expand All @@ -121,6 +129,17 @@ impl<'a> QueryGraphPrePostVisitor for JsonSerializer<'a> {
label: format!("input {}", i),
});
}

// Link the current node with the subqueries it references
let subqueries = node.collect_subqueries();
for subquery_root in subqueries {
self.queue.push_back(subquery_root);
self.graph.edges.push(Edge {
from: node_id.to_string(),
to: subquery_root.to_string(),
label: format!("subquery({})", subquery_root),
});
}
return PreOrderVisitationResult::VisitInputs;
}

Expand Down
100 changes: 97 additions & 3 deletions src/query_graph/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use itertools::Itertools;

use crate::{
data_type::DataType,
scalar_expr::{AggregateExprRef, ScalarExprRef},
scalar_expr::{visitor::visit_expr_pre, AggregateExprRef, ScalarExprRef},
visitor_utils::PreOrderVisitationResult,
};
use std::{
cell::RefCell,
Expand Down Expand Up @@ -59,6 +62,10 @@ pub enum QueryNode {
Union {
inputs: Vec<NodeId>,
},
/// Root of a subquery plan, wrapping the plan's `input` node.
SubqueryRoot {
input: NodeId,
},
}

pub struct QueryGraph {
Expand All @@ -72,6 +79,8 @@ pub struct QueryGraph {
/// For each node, it contains a set with the nodes pointing to it through any of their
/// inputs.
parents: HashMap<NodeId, BTreeSet<NodeId>>,
/// Shared handles to the root node of every subquery plan registered via `add_subquery`.
subqueries: Vec<Rc<NodeId>>,
/// Keeps track of the number of node replacements the query graph has gone through.
pub gen_number: usize,
pub property_cache: RefCell<PropertyCache>,
Expand All @@ -85,6 +94,7 @@ impl QueryNode {
Self::TableScan { .. } => 0,
Self::Join { .. } => 2,
Self::Union { inputs } => inputs.len(),
Self::SubqueryRoot { .. } => 1,
}
}

Expand All @@ -95,7 +105,8 @@ impl QueryNode {
match self {
Self::Project { input, .. }
| Self::Filter { input, .. }
| Self::Aggregate { input, .. } => *input,
| Self::Aggregate { input, .. }
| Self::SubqueryRoot { input } => *input,
Self::TableScan { .. } => panic!(),
Self::Join { left, right, .. } => {
if input_idx == 0 {
Expand All @@ -116,7 +127,8 @@ impl QueryNode {
match self {
Self::Project { input, .. }
| Self::Filter { input, .. }
| Self::Aggregate { input, .. } => *input = node_id,
| Self::Aggregate { input, .. }
| Self::SubqueryRoot { input } => *input = node_id,
Self::TableScan { .. } => panic!(),
Self::Join { left, right, .. } => {
if input_idx == 0 {
Expand All @@ -128,6 +140,53 @@ impl QueryNode {
Self::Union { inputs } => inputs[input_idx] = node_id,
}
}

/// Visit the scalar expressions within the node.
pub fn visit_scalar_expr<F>(&self, visitor: &mut F)
where
F: FnMut(&ScalarExprRef),
{
match self {
QueryNode::Project { outputs: exprs, .. }
| QueryNode::Filter {
conditions: exprs, ..
}
| QueryNode::Join {
conditions: exprs, ..
} => {
for expr in exprs {
visitor(expr);
}
}
QueryNode::TableScan { .. }
| QueryNode::Aggregate { .. }
| QueryNode::Union { .. }
| QueryNode::SubqueryRoot { .. } => {}
}
}

/// Returns the root node ids of all the subqueries referenced by the scalar
/// expressions of this node.
///
/// Every expression of the node is walked in pre-order, so subquery
/// references nested inside other expressions are found as well. The result
/// is an ordered set, hence each root appears at most once.
pub fn collect_subqueries(&self) -> BTreeSet<NodeId> {
    use crate::scalar_expr::ScalarExpr;

    let mut referenced_roots = BTreeSet::new();
    self.visit_scalar_expr(&mut |top_level_expr| {
        visit_expr_pre(top_level_expr, &mut |sub_expr| {
            match sub_expr.as_ref() {
                // Expressions pointing at a subquery plan: record its root.
                ScalarExpr::ScalarSubquery { subquery }
                | ScalarExpr::ExistsSubquery { subquery }
                | ScalarExpr::ScalarSubqueryCmp { subquery, .. } => {
                    referenced_roots.insert(**subquery);
                }
                // Everything else carries no subquery reference of its own.
                ScalarExpr::Literal(_)
                | ScalarExpr::InputRef { .. }
                | ScalarExpr::BinaryOp { .. }
                | ScalarExpr::NaryOp { .. } => {}
            }
            // Always descend: operands may contain further subqueries.
            PreOrderVisitationResult::VisitInputs
        });
    });
    referenced_roots
}
}

impl QueryGraph {
Expand All @@ -138,6 +197,7 @@ impl QueryGraph {
next_node_id: 0,
gen_number: 0,
parents: HashMap::new(),
subqueries: Vec::new(),
property_cache: RefCell::new(PropertyCache::new()),
}
}
Expand Down Expand Up @@ -171,6 +231,19 @@ impl QueryGraph {
node_id
}

/// Registers a new subquery plan whose top node is `input`.
///
/// A `SubqueryRoot` node is added on top of `input` and its id is wrapped in
/// an `Rc`. One handle is kept in `self.subqueries` and the other is handed
/// back to the caller.
pub fn add_subquery(&mut self, input: NodeId) -> Rc<NodeId> {
    let root_node = QueryNode::SubqueryRoot { input };
    let shared_root = Rc::new(self.add_node(root_node));
    self.subqueries.push(Rc::clone(&shared_root));
    shared_root
}

pub fn subquery_roots(&self) -> Vec<NodeId> {
self.subqueries
.iter()
.map(|root_id| **root_id)
.collect_vec()
}

/// Finds whether there is an existing node exactly like the given one.
fn find_node(&self, node: &QueryNode) -> Option<NodeId> {
self.nodes.iter().find_map(|(node_id, existing_node)| {
Expand Down Expand Up @@ -229,6 +302,7 @@ impl QueryGraph {
}

self.remove_detached_nodes(node_id);
self.garbage_collect_subqueries();
self.gen_number += 1;
}

Expand All @@ -251,6 +325,25 @@ impl QueryGraph {
.filter(|(x, _)| visited_nodes.contains(x))
.collect();
}

// Drops the subquery plans whose root handle is no longer shared with
// anything outside `self.subqueries`, then removes the nodes that became
// detached as a result.
//
// A root is considered alive while the strong count of its `Rc` is greater
// than one, i.e. while something other than this list holds a handle to it.
// NOTE(review): `QueryGraph::clone` copies `subqueries`, which looks like it
// also bumps these counts - confirm clones cannot keep dead roots alive.
pub fn garbage_collect_subqueries(&mut self) {
    let mut unreferenced = HashSet::new();
    self.subqueries.retain(|root_handle| {
        // The root node is only expected to be referenced by subquery
        // expressions, so a lone handle means nobody uses the plan.
        let still_referenced = std::rc::Rc::<NodeId>::strong_count(root_handle) > 1;
        if !still_referenced {
            unreferenced.insert(**root_handle);
        }
        still_referenced
    });
    unreferenced
        .into_iter()
        .for_each(|root_id| self.remove_detached_nodes(root_id));
}
}

/// Useful node construction methods.
Expand Down Expand Up @@ -307,6 +400,7 @@ impl Clone for QueryGraph {
next_node_id: self.next_node_id,
gen_number: self.gen_number,
parents: self.parents.clone(),
subqueries: self.subqueries.clone(),
// Cached metadata is not cloned
property_cache: RefCell::new(PropertyCache::new()),
}
Expand Down
5 changes: 4 additions & 1 deletion src/query_graph/optimizer/rules/equality_propagation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ impl SingleReplacementRule for EqualityPropagationRule {

let new_left_predicates = if from_right_to_left_allowed {
Self::propagate_predicates(
&query_graph,
&right_predicates,
&right_to_left,
&left_predicates,
Expand All @@ -72,6 +73,7 @@ impl SingleReplacementRule for EqualityPropagationRule {
};
let new_right_predicates = if from_left_to_right_allowed {
Self::propagate_predicates(
&query_graph,
&left_predicates,
&left_to_right,
&right_predicates,
Expand Down Expand Up @@ -165,6 +167,7 @@ impl EqualityPropagationRule {
/// the rewritten predicate only references columns from the other side using `validate_input_ref`
/// and that the resulting predicate is not already known.
fn propagate_predicates<F>(
query_graph: &QueryGraph,
predicates: &Vec<ScalarExprRef>,
translation_map: &HashMap<ScalarExprRef, ScalarExprRef>,
other_side_predicates: &Vec<ScalarExprRef>,
Expand All @@ -188,7 +191,7 @@ impl EqualityPropagationRule {
)
.unwrap();
let rewritten_predicate =
reduce_expr_recursively(&rewritten_predicate, cross_product_row_type);
reduce_expr_recursively(&rewritten_predicate, query_graph, cross_product_row_type);

if !other_side_predicates.contains(&rewritten_predicate)
&& collect_input_dependencies(&rewritten_predicate)
Expand Down
6 changes: 3 additions & 3 deletions src/query_graph/optimizer/rules/expression_reduction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl SingleReplacementRule for ExpressionReductionRule {
*input,
outputs
.iter()
.map(|e| reduce_expr_recursively(e, &row_type))
.map(|e| reduce_expr_recursively(e, &query_graph, &row_type))
.collect_vec(),
)
}
Expand All @@ -34,7 +34,7 @@ impl SingleReplacementRule for ExpressionReductionRule {
*input,
conditions
.iter()
.map(|e| reduce_expr_recursively(e, &row_type))
.map(|e| reduce_expr_recursively(e, &query_graph, &row_type))
.collect_vec(),
)
}
Expand All @@ -55,7 +55,7 @@ impl SingleReplacementRule for ExpressionReductionRule {
join_type: join_type.clone(),
conditions: conditions
.iter()
.map(|e| reduce_expr_recursively(e, &row_type))
.map(|e| reduce_expr_recursively(e, &query_graph, &row_type))
.collect_vec(),
left: *left,
right: *right,
Expand Down
7 changes: 5 additions & 2 deletions src/query_graph/optimizer/rules/outer_to_inner_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ fn do_all_parents_reject_null_from_non_preserving(
let input_row_type = row_type(query_graph, *input);
// 4.) and 5.)
if any_condition_rejects_nulls(
query_graph,
&rewrite_map,
&input_row_type,
conditions,
Expand Down Expand Up @@ -147,6 +148,7 @@ fn do_all_parents_reject_null_from_non_preserving(
cross_product_row_type(query_graph, node_id).unwrap();
// 4.) and 5.)
if any_condition_rejects_nulls(
query_graph,
&rewrite_map,
&input_row_type,
conditions,
Expand Down Expand Up @@ -232,7 +234,7 @@ fn build_rewrite_map(
.filter_map(|(i, e)| {
if let Some(e) = e {
// Reduce the expression containing nulls instead of input refs
let reduced_expr = reduce_expr_recursively(&e, &non_prev_row_type);
let reduced_expr = reduce_expr_recursively(&e, &query_graph, &non_prev_row_type);
// If the expression can be reduced to a literal, we can add it to
// the replacement map.
if reduced_expr.is_literal() {
Expand Down Expand Up @@ -260,6 +262,7 @@ fn replace_with_nulls(expr: &ScalarExprRef, row_type: &[DataType]) -> ScalarExpr
/// Check whether any of the given conditions evaluates to null or false after
/// applying the replacements in the given map.
fn any_condition_rejects_nulls(
query_graph: &QueryGraph,
rewrite_map: &HashMap<usize, ScalarExprRef>,
row_type: &[DataType],
conditions: &Vec<ScalarExprRef>,
Expand All @@ -276,7 +279,7 @@ fn any_condition_rejects_nulls(
condition,
);
// 5.)
let reduced_expr = reduce_expr_recursively(&rewritten_expr, row_type);
let reduced_expr = reduce_expr_recursively(&rewritten_expr, query_graph, row_type);
match reduced_expr.as_ref() {
ScalarExpr::Literal(Literal {
value: Value::Bool(false),
Expand Down
4 changes: 3 additions & 1 deletion src/query_graph/properties/input_dependencies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ pub fn input_dependencies(query_graph: &QueryGraph, node_id: NodeId) -> HashSet<
dependencies.extend(aggregate.operands.iter());
}
}
QueryNode::Union { .. } => dependencies.extend(0..num_columns(query_graph, node_id)),
QueryNode::Union { .. } | QueryNode::SubqueryRoot { .. } => {
dependencies.extend(0..num_columns(query_graph, node_id))
}
}
dependencies
}
3 changes: 3 additions & 0 deletions src/query_graph/properties/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ impl Keys {
}
}
}
QueryNode::SubqueryRoot { input } => {
keys.extend(self.keys_unchecked(query_graph, *input).iter().cloned());
}
};
// Normalize the keys, remove constants
// TODO(asenac) consider removing the non-normalized version
Expand Down
4 changes: 3 additions & 1 deletion src/query_graph/properties/num_columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ impl NumColumns {
fn compute_num_columns_for_node(&self, query_graph: &QueryGraph, node_id: NodeId) -> usize {
match query_graph.node(node_id) {
QueryNode::Project { outputs, .. } => outputs.len(),
QueryNode::Filter { input, .. } => self.num_columns_unchecked(query_graph, *input),
QueryNode::Filter { input, .. } | QueryNode::SubqueryRoot { input } => {
self.num_columns_unchecked(query_graph, *input)
}
QueryNode::TableScan { row_type, .. } => row_type.len(),
QueryNode::Join {
join_type,
Expand Down
Loading

0 comments on commit 2bfd990

Please sign in to comment.