Skip to content

Commit

Permalink
Merge branch 'mergeinto-runtime-filter' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
JackTan25 authored Jan 3, 2024
2 parents 8ce9622 + a226452 commit ff555dd
Show file tree
Hide file tree
Showing 23 changed files with 871 additions and 333 deletions.
79 changes: 58 additions & 21 deletions src/query/service/src/interpreters/interpreter_merge_into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ use databend_common_exception::Result;
use databend_common_expression::types::UInt32Type;
use databend_common_expression::ConstantFolder;
use databend_common_expression::DataBlock;
use databend_common_expression::DataField;
use databend_common_expression::DataSchema;
use databend_common_expression::DataSchemaRef;
use databend_common_expression::FieldIndex;
use databend_common_expression::FromData;
use databend_common_expression::RemoteExpr;
use databend_common_expression::SendableDataBlockStream;
use databend_common_expression::ROW_ID_COL_NAME;
use databend_common_expression::ROW_NUMBER_COL_NAME;
use databend_common_functions::BUILTIN_FUNCTIONS;
use databend_common_meta_app::schema::TableInfo;
Expand Down Expand Up @@ -162,6 +164,7 @@ impl MergeIntoInterpreter {
field_index_map,
merge_type,
distributed,
change_join_order,
..
} = &self.plan;

Expand All @@ -181,20 +184,18 @@ impl MergeIntoInterpreter {
let table_name = table_name.clone();
let input = input.clone();

let input = if let RelOperator::Exchange(_) = input.plan() {
Box::new(input.child(0)?.clone())
// We need to extract the join plan from under the exchange, but the
// exchange has to be put back at the end.
let (input, extract_exchange) = if let RelOperator::Exchange(_) = input.plan() {
(Box::new(input.child(0)?.clone()), true)
} else {
input
(input, false)
};

let optimized_input =
Self::build_static_filter(&input, meta_data, self.ctx.clone(), check_table).await?;
let mut builder = PhysicalPlanBuilder::new(meta_data.clone(), self.ctx.clone(), false);

// build source for MergeInto
let join_input = builder
.build(&optimized_input, *columns_set.clone())
.await?;
let join_input = builder.build(&input, *columns_set.clone()).await?;


// find row_id column index
let join_output_schema = join_input.output_schema()?;
Expand Down Expand Up @@ -227,7 +228,7 @@ impl MergeIntoInterpreter {
}
}

if *distributed {
if *distributed && !*change_join_order {
row_number_idx = Some(join_output_schema.index_of(ROW_NUMBER_COL_NAME)?);
}

Expand All @@ -238,7 +239,7 @@ impl MergeIntoInterpreter {
));
}

if *distributed && row_number_idx.is_none() {
if *distributed && row_number_idx.is_none() && !*change_join_order {
return Err(ErrorCode::InvalidRowIdIndex(
"can't get internal row_number_idx when running merge into",
));
Expand All @@ -258,11 +259,28 @@ impl MergeIntoInterpreter {

// merge_into_source receives the join's data blocks and splits them into matched and not matched
// data blocks.
let merge_into_source = PhysicalPlan::MergeIntoSource(MergeIntoSource {
input: Box::new(join_input),
row_id_idx: row_id_idx as u32,
merge_type: merge_type.clone(),
});
let merge_into_source = if !*distributed && extract_exchange {
// if distributed merge into is not used, we should put the merge exchange back.
let rollback_join_input = PhysicalPlan::Exchange(Exchange {
plan_id: 0,
input: Box::new(join_input),
kind: FragmentKind::Merge,
keys: vec![],
allow_adjust_parallelism: true,
ignore_exchange: false,
});
PhysicalPlan::MergeIntoSource(MergeIntoSource {
input: Box::new(rollback_join_input),
row_id_idx: row_id_idx as u32,
merge_type: merge_type.clone(),
})
} else {
PhysicalPlan::MergeIntoSource(MergeIntoSource {
input: Box::new(join_input),
row_id_idx: row_id_idx as u32,
merge_type: merge_type.clone(),
})
};

// transform unmatched for insert
// reference to func `build_eval_scalar`
Expand Down Expand Up @@ -399,6 +417,7 @@ impl MergeIntoInterpreter {
distributed: false,
output_schema: DataSchemaRef::default(),
merge_type: merge_type.clone(),
change_join_order: *change_join_order,
}))
} else {
let merge_append = PhysicalPlan::MergeInto(Box::new(MergeInto {
Expand All @@ -409,14 +428,30 @@ impl MergeIntoInterpreter {
matched,
field_index_of_input_schema,
row_id_idx,
segments,
segments: segments.clone(),
distributed: true,
output_schema: DataSchemaRef::new(DataSchema::new(vec![
join_output_schema.fields[row_number_idx.unwrap()].clone(),
])),
output_schema: match *change_join_order {
false => DataSchemaRef::new(DataSchema::new(vec![
join_output_schema.fields[row_number_idx.unwrap()].clone(),
])),
true => DataSchemaRef::new(DataSchema::new(vec![DataField::new(
ROW_ID_COL_NAME,
databend_common_expression::types::DataType::Number(
databend_common_expression::types::NumberDataType::UInt64,
),
)])),
},
merge_type: merge_type.clone(),
change_join_order: *change_join_order,
}));

// If change_join_order is true, the target is the build side.
// In that case the matched and not-matched operations are done
// locally on every node, and the main node only receives row ids to apply.
let segments = if *change_join_order {
segments.clone()
} else {
vec![]
};
PhysicalPlan::MergeIntoAppendNotMatched(Box::new(MergeIntoAppendNotMatched {
input: Box::new(PhysicalPlan::Exchange(Exchange {
plan_id: 0,
Expand All @@ -431,6 +466,8 @@ impl MergeIntoInterpreter {
unmatched: unmatched.clone(),
input_schema: merge_into_source.output_schema()?,
merge_type: merge_type.clone(),
change_join_order: *change_join_order,
segments,
}))
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ use crate::interpreters::interpreter_merge_into::MergeIntoInterpreter;
use crate::interpreters::InterpreterFactory;
use crate::sessions::QueryContext;

#[allow(dead_code)]
struct MergeStyleJoin<'a> {
source_conditions: &'a [ScalarExpr],
target_conditions: &'a [ScalarExpr],
source_sexpr: &'a SExpr,
target_sexpr: &'a SExpr,
build_conditions: &'a [ScalarExpr],
probe_conditions: &'a [ScalarExpr],
build_sexpr: &'a SExpr,
probe_sexpr: &'a SExpr,
}

#[allow(dead_code)]
impl MergeStyleJoin<'_> {
pub fn new(join: &SExpr) -> MergeStyleJoin {
let join_op = match join.plan() {
Expand All @@ -73,25 +75,27 @@ impl MergeStyleJoin<'_> {
join_op.join_type == JoinType::Right
|| join_op.join_type == JoinType::RightAnti
|| join_op.join_type == JoinType::Inner
|| join_op.join_type == JoinType::Left
|| join_op.join_type == JoinType::LeftAnti
);
let source_conditions = &join_op.right_conditions;
let target_conditions = &join_op.left_conditions;
let source_sexpr = join.child(1).unwrap();
let target_sexpr = join.child(0).unwrap();
let build_conditions = &join_op.right_conditions;
let probe_conditions = &join_op.left_conditions;
let build_sexpr = join.child(1).unwrap();
let probe_sexpr = join.child(0).unwrap();
MergeStyleJoin {
source_conditions,
target_conditions,
source_sexpr,
target_sexpr,
build_conditions,
probe_conditions,
build_sexpr,
probe_sexpr,
}
}

pub fn collect_column_map(&self) -> HashMap<String, ColumnBinding> {
let mut column_map = HashMap::new();
for (t, s) in self
.target_conditions
.probe_conditions
.iter()
.zip(self.source_conditions.iter())
.zip(self.build_conditions.iter())
{
if let (ScalarExpr::BoundColumnRef(t_col), ScalarExpr::BoundColumnRef(s_col)) = (t, s) {
column_map.insert(t_col.column.column_name.clone(), s_col.column.clone());
Expand All @@ -101,6 +105,7 @@ impl MergeStyleJoin<'_> {
}
}

#[allow(dead_code)]
impl MergeIntoInterpreter {
pub async fn build_static_filter(
join: &SExpr,
Expand All @@ -119,7 +124,7 @@ impl MergeIntoInterpreter {
// \
// SourcePlan
let m_join = MergeStyleJoin::new(join);
if m_join.source_conditions.is_empty() {
if m_join.build_conditions.is_empty() {
return Ok(Box::new(join.clone()));
}
let column_map = m_join.collect_column_map();
Expand Down Expand Up @@ -181,9 +186,9 @@ impl MergeIntoInterpreter {

// 2. build filter and push down to target side
ctx.set_status_info("building pushdown filters");
let mut filters = Vec::with_capacity(m_join.target_conditions.len());
let mut filters = Vec::with_capacity(m_join.probe_conditions.len());

for (i, target_side_expr) in m_join.target_conditions.iter().enumerate() {
for (i, target_side_expr) in m_join.probe_conditions.iter().enumerate() {
let mut filter_parts = vec![];
for block in blocks.iter() {
let block = block.convert_to_full();
Expand Down Expand Up @@ -225,11 +230,11 @@ impl MergeIntoInterpreter {
}
filters.extend(Self::combine_filter_parts(&filter_parts).into_iter());
}
let mut target_plan = m_join.target_sexpr.clone();
Self::push_down_filters(&mut target_plan, &filters)?;
let source_plan = m_join.source_sexpr;
let mut probe_plan = m_join.probe_sexpr.clone();
Self::push_down_filters(&mut probe_plan, &filters)?;
let build_plan = m_join.build_sexpr;
let new_sexpr =
join.replace_children(vec![Arc::new(target_plan), Arc::new(source_plan.clone())]);
join.replace_children(vec![Arc::new(probe_plan), Arc::new(build_plan.clone())]);

ctx.set_status_info("join expression replaced");
Ok(Box::new(new_sexpr))
Expand Down Expand Up @@ -381,9 +386,9 @@ impl MergeIntoInterpreter {
metadata: &MetadataRef,
group_expr: ScalarExpr,
) -> Result<Plan> {
let mut eval_scalar_items = Vec::with_capacity(m_join.source_conditions.len());
let mut min_max_binding = Vec::with_capacity(m_join.source_conditions.len() * 2);
let mut min_max_scalar_items = Vec::with_capacity(m_join.source_conditions.len() * 2);
let mut eval_scalar_items = Vec::with_capacity(m_join.build_conditions.len());
let mut min_max_binding = Vec::with_capacity(m_join.build_conditions.len() * 2);
let mut min_max_scalar_items = Vec::with_capacity(m_join.build_conditions.len() * 2);
let mut group_items = vec![];

let index = metadata
Expand All @@ -407,46 +412,46 @@ impl MergeIntoInterpreter {
scalar: evaled,
index,
});
for source_side_expr in m_join.source_conditions {
for build_side_expr in m_join.build_conditions {
// eval source side join expr
let index = metadata
.write()
.add_derived_column("".to_string(), source_side_expr.data_type()?);
.add_derived_column("".to_string(), build_side_expr.data_type()?);
let evaled = ScalarExpr::BoundColumnRef(BoundColumnRef {
span: None,
column: ColumnBindingBuilder::new(
"".to_string(),
index,
Box::new(source_side_expr.data_type()?),
Box::new(build_side_expr.data_type()?),
Visibility::Visible,
)
.build(),
});
eval_scalar_items.push(ScalarItem {
scalar: source_side_expr.clone(),
scalar: build_side_expr.clone(),
index,
});

// eval min/max of source side join expr
let min_display_name = format!("min({:?})", source_side_expr);
let max_display_name = format!("max({:?})", source_side_expr);
let min_display_name = format!("min({:?})", build_side_expr);
let max_display_name = format!("max({:?})", build_side_expr);
let min_index = metadata
.write()
.add_derived_column(min_display_name.clone(), source_side_expr.data_type()?);
.add_derived_column(min_display_name.clone(), build_side_expr.data_type()?);
let max_index = metadata
.write()
.add_derived_column(max_display_name.clone(), source_side_expr.data_type()?);
.add_derived_column(max_display_name.clone(), build_side_expr.data_type()?);
let min_binding = ColumnBindingBuilder::new(
min_display_name.clone(),
min_index,
Box::new(source_side_expr.data_type()?),
Box::new(build_side_expr.data_type()?),
Visibility::Visible,
)
.build();
let max_binding = ColumnBindingBuilder::new(
max_display_name.clone(),
max_index,
Box::new(source_side_expr.data_type()?),
Box::new(build_side_expr.data_type()?),
Visibility::Visible,
)
.build();
Expand All @@ -458,7 +463,7 @@ impl MergeIntoInterpreter {
distinct: false,
params: vec![],
args: vec![evaled.clone()],
return_type: Box::new(source_side_expr.data_type()?),
return_type: Box::new(build_side_expr.data_type()?),
display_name: min_display_name.clone(),
}),
index: min_index,
Expand All @@ -469,7 +474,7 @@ impl MergeIntoInterpreter {
distinct: false,
params: vec![],
args: vec![evaled],
return_type: Box::new(source_side_expr.data_type()?),
return_type: Box::new(build_side_expr.data_type()?),
display_name: max_display_name.clone(),
}),
index: max_index,
Expand All @@ -478,21 +483,26 @@ impl MergeIntoInterpreter {
min_max_scalar_items.push(max);
}

let eval_source_side_join_expr_op = EvalScalar {
let eval_build_side_join_expr_op = EvalScalar {
items: eval_scalar_items,
};
let source_plan = m_join.source_sexpr;
let eval_target_side_condition_sexpr = if let RelOperator::Exchange(_) = source_plan.plan()
{
let build_plan = m_join.build_sexpr;
let eval_probe_side_condition_sexpr = if let RelOperator::Exchange(_) = build_plan.plan() {
// there is another row_number operator here
SExpr::create_unary(
Arc::new(eval_source_side_join_expr_op.into()),
Arc::new(source_plan.child(0)?.child(0)?.clone()),
Arc::new(eval_build_side_join_expr_op.into()),
Arc::new(SExpr::create_unary(
// merge data here
Arc::new(RelOperator::Exchange(
databend_common_sql::plans::Exchange::Merge,
)),
Arc::new(build_plan.child(0)?.child(0)?.clone()),
)),
)
} else {
SExpr::create_unary(
Arc::new(eval_source_side_join_expr_op.into()),
Arc::new(source_plan.clone()),
Arc::new(eval_build_side_join_expr_op.into()),
Arc::new(build_plan.clone()),
)
};

Expand All @@ -509,7 +519,7 @@ impl MergeIntoInterpreter {
};
let agg_partial_sexpr = SExpr::create_unary(
Arc::new(agg_partial_op.into()),
Arc::new(eval_target_side_condition_sexpr),
Arc::new(eval_probe_side_condition_sexpr),
);
let agg_final_op = Aggregate {
mode: AggregateMode::Final,
Expand Down
1 change: 0 additions & 1 deletion src/query/service/src/interpreters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ mod interpreter_index_refresh;
mod interpreter_insert;
mod interpreter_kill;
mod interpreter_merge_into;
mod interpreter_merge_into_static_filter;
mod interpreter_metrics;
mod interpreter_network_policies_show;
mod interpreter_network_policy_alter;
Expand Down
Loading

0 comments on commit ff555dd

Please sign in to comment.