refactor: refactor join spilling settings (#14781)
* refactor: refactor join spilling settings

* fix lint
xudong963 authored Feb 28, 2024
1 parent f9e0e1c commit 4def8bb
Showing 12 changed files with 191 additions and 98 deletions.
5 changes: 1 addition & 4 deletions src/query/service/src/pipelines/builders/builder_join.rs
@@ -145,16 +145,13 @@ impl PipelineBuilder {
assert!(build_res.main_pipeline.is_pulling_pipeline()?);
let output_len = build_res.main_pipeline.output_len();
let spill_coordinator = BuildSpillCoordinator::create(output_len);
let barrier = Barrier::new(output_len);
let restore_barrier = Barrier::new(output_len);
let build_state = HashJoinBuildState::try_create(
self.ctx.clone(),
self.func_ctx.clone(),
&hash_join_plan.build_keys,
&hash_join_plan.build_projections,
join_state.clone(),
barrier,
restore_barrier,
output_len,
)?;

let create_sink_processor = |input| {
@@ -46,6 +46,7 @@ pub struct BuildSpillCoordinator {
/// If this is the last active processor, send `true` to the watcher channel
pub(crate) ready_spill_watcher: Sender<bool>,
pub(crate) dummy_ready_spill_receiver: Receiver<bool>,
pub(crate) mutex: Mutex<()>,
}

impl BuildSpillCoordinator {
@@ -59,6 +60,7 @@ impl BuildSpillCoordinator {
non_spill_processors: Default::default(),
ready_spill_watcher,
dummy_ready_spill_receiver,
mutex: Default::default(),
})
}

@@ -70,6 +72,7 @@ impl BuildSpillCoordinator {

// If the current builder waiting to spill is the last one, spill all builders.
pub(crate) fn wait_spill(&self) -> Result<bool> {
let _lock = self.mutex.lock();
if *self.dummy_ready_spill_receiver.borrow() {
self.ready_spill_watcher
.send(false)
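
The new `mutex` exists to make the check-then-send on the watch channel atomic: without it, two processors could both read a stale value from `dummy_ready_spill_receiver` and race on `ready_spill_watcher.send`. A minimal sketch of the pattern, not from this commit, assuming the channel is `tokio::sync::watch` and the lock is `parking_lot::Mutex` (consistent with the unwrapped `lock()` above); the `Coordinator` type and method name are illustrative:

use parking_lot::Mutex;
use tokio::sync::watch;

struct Coordinator {
    ready_spill_watcher: watch::Sender<bool>,
    dummy_ready_spill_receiver: watch::Receiver<bool>,
    mutex: Mutex<()>,
}

impl Coordinator {
    fn reset_watcher_if_ready(&self) {
        // Hold the lock across both the read and the send so no other
        // processor can interleave between them.
        let _lock = self.mutex.lock();
        if *self.dummy_ready_spill_receiver.borrow() {
            // The dummy receiver keeps the channel open, so the send should
            // not fail; ignore the error defensively.
            let _ = self.ready_spill_watcher.send(false);
        }
    }
}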
@@ -36,14 +36,14 @@ use crate::spillers::Spiller;
use crate::spillers::SpillerConfig;
use crate::spillers::SpillerType;

/// Define some states for hash join build spilling
/// Each processor owns its `BuildSpillState`
// Define some states for hash join build spilling
// Each processor owns its `BuildSpillState`
pub struct BuildSpillState {
/// Hash join build state
// Hash join build state
pub build_state: Arc<HashJoinBuildState>,
/// Hash join build spilling coordinator
// Hash join build spilling coordinator
pub spill_coordinator: Arc<BuildSpillCoordinator>,
/// Spiller, responsible for specific spill work
// Spiller, responsible for specific spill work
pub spiller: Spiller,
}

@@ -133,11 +133,24 @@ impl BuildSpillState {

// Check if we need to spill.
// Note: even if this method returns false, another processor may still need to spill, in which case this one must wait for the spill.
pub(crate) fn check_need_spill(&self) -> Result<bool> {
pub(crate) fn check_need_spill(&self, input: &Option<DataBlock>) -> Result<bool> {
if self.spiller.is_all_spilled() {
return Ok(false);
}

// Check if input data size is bigger than `spilling_threshold_per_proc`
if let Some(input_data) = input {
let input_data_bytes = input_data.memory_size();
let spill_threshold_per_proc = self.build_state.spilling_threshold_per_proc;
if input_data_bytes > spill_threshold_per_proc {
info!(
"input data: {:?} bytes, spilling threshold per processor: {:?} bytes",
input_data_bytes, spill_threshold_per_proc
);
return Ok(true);
}
}

// Check if there are rows in `RowSpace`'s buffer and `Chunks`.
// If not, directly return false, no need to spill.
let buffer = self.build_state.hash_join_state.row_space.buffer.read();
@@ -155,17 +168,17 @@
if global_used < 0 {
global_used = 0;
}
let max_memory_usage = self.build_state.max_memory_usage;
let byte = Byte::from_unit(global_used as f64, ByteUnit::B).unwrap();
let total_gb = byte.get_appropriate_unit(false).format(3);
let spill_threshold = self
.build_state
.ctx
.get_settings()
.get_join_spilling_threshold()?;
if global_used as usize > spill_threshold {
if global_used as usize > max_memory_usage {
let spill_threshold_gb = Byte::from_unit(max_memory_usage as f64, ByteUnit::B)
.unwrap()
.get_appropriate_unit(false)
.format(3);
info!(
"need to spill due to global memory usage {:?} is greater than spill threshold",
total_gb
"need to spill due to global memory usage {:?} is greater than spill threshold {:?}",
total_gb, spill_threshold_gb
);
return Ok(true);
}
@@ -179,7 +192,7 @@
total_bytes += block.memory_size();
}

if total_bytes * 3 > spill_threshold {
if total_bytes * 3 > max_memory_usage {
return Ok(true);
}
Ok(false)
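
Taken together, `check_need_spill` now has three triggers: a single input block larger than the per-processor threshold, global memory usage above `max_memory_usage`, and buffered build data exceeding roughly a third of `max_memory_usage`. A condensed, self-contained sketch of that decision, not from this commit (function and parameter names are illustrative, not the actual fields):

// Simplified mirror of the spill decision above.
fn need_spill(
    input_bytes: Option<usize>,         // size of the incoming block, if any
    global_used: usize,                 // current global memory usage in bytes
    buffered_bytes: usize,              // bytes buffered in `RowSpace` and `Chunks`
    max_memory_usage: usize,            // derived from join_spilling_memory_ratio
    spilling_threshold_per_proc: usize, // derived from join_spilling_bytes_threshold_per_proc
) -> bool {
    // Trigger 1: one input block alone exceeds the per-processor threshold.
    if matches!(input_bytes, Some(bytes) if bytes > spilling_threshold_per_proc) {
        return true;
    }
    // Trigger 2: global memory usage is past the configured cap.
    if global_used > max_memory_usage {
        return true;
    }
    // Trigger 3: buffered build data is already past a third of the cap.
    buffered_bytes * 3 > max_memory_usage
}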
@@ -188,13 +201,9 @@
// Pick partitions which need to spill
#[allow(unused)]
fn pick_partitions(&self, partition_blocks: &mut HashMap<u8, Vec<DataBlock>>) -> Result<()> {
let mut memory_limit = self
.build_state
.ctx
.get_settings()
.get_join_spilling_threshold()?;
let mut max_memory_usage = self.build_state.max_memory_usage;
let global_used = GLOBAL_MEM_STAT.get_memory_usage();
if global_used as usize > memory_limit {
if global_used as usize > max_memory_usage {
return Ok(());
}
// Compute each partition's data size
@@ -210,7 +219,7 @@
partition_sizes.sort_by_key(|&(_id, size)| size);

for (id, size) in partition_sizes.into_iter() {
if size as f64 <= memory_limit as f64 / 3.0 {
if size as f64 <= max_memory_usage as f64 / 3.0 {
// Put the partition's data to chunks
let chunks =
&mut unsafe { &mut *self.build_state.hash_join_state.build_state.get() }
@@ -223,7 +232,7 @@
unsafe { &mut *self.build_state.hash_join_state.build_state.get() };
build_state.generation_state.build_num_rows += rows_num;
partition_blocks.remove(&id);
memory_limit -= size;
max_memory_usage -= size;
} else {
break;
}
@@ -20,7 +20,6 @@ use databend_common_exception::Result;
use databend_common_expression::DataBlock;
use log::info;

use crate::pipelines::processors::transforms::hash_join::transform_hash_join_build::HashJoinBuildStep;
use crate::pipelines::processors::transforms::BuildSpillState;

pub struct BuildSpillHandler {
@@ -68,6 +67,21 @@ impl BuildSpillHandler {
self.spill_data = Some(spill_data);
}

pub(crate) fn spill_data(&mut self) -> &mut Option<DataBlock> {
&mut self.spill_data
}

pub(crate) fn need_to_wait_probe(&self) -> bool {
if !self.enabled_spill() {
return false;
}
// Spilling didn't actually happen.
if self.spill_state().spiller.spilled_partition_set.is_empty() {
return false;
}
true
}

// Get `spilled_partition_set` from spiller and set `sent_partition_set` to true
pub(crate) fn spilled_partition_set(&mut self) -> Option<HashSet<u8>> {
if !self.enabled_spill() {
@@ -83,11 +97,11 @@
// Request `spill_coordinator` to spill; there are two possible outcomes:
// 1. WaitSpill
// 2. Start the first spill if this processor is the last one waiting to spill
pub(crate) fn request_spill(&mut self) -> Result<HashJoinBuildStep> {
pub(crate) fn request_spill(&mut self) -> Result<()> {
let spill_state = self.spill_state_mut();
let wait = spill_state.spill_coordinator.wait_spill()?;
if wait {
return Ok(HashJoinBuildStep::WaitSpill);
return Ok(());
}
// Before notifying all processors to spill, we need to collect all data buffered in `RowSpace` and `Chunks`.
// Partition all rows and record how many partitions there are and how many rows each contains.
@@ -103,28 +117,35 @@
.ready_spill_watcher
.send(true)
.map_err(|_| ErrorCode::TokioError("ready_spill_watcher channel is closed"))?;
Ok(HashJoinBuildStep::FirstSpill)
Ok(())
}

// Check if fit into memory and return next step
// If step doesn't be changed, return None.
pub(crate) fn check_memory_and_next_step(&mut self) -> Result<Option<HashJoinBuildStep>> {
// Check whether this processor needs to spill or wait for a spill
pub(crate) fn check_need_spill(&mut self, input: &mut Option<DataBlock>) -> Result<bool> {
if !self.enabled_spill() || self.after_spill() {
return Ok(None);
return Ok(false);
}
let spill_state = self.spill_state();
if spill_state.spiller.is_all_spilled() {
return Ok(None);
return Ok(false);
}
if spill_state.check_need_spill()? {
if spill_state.check_need_spill(input)? {
spill_state.spill_coordinator.need_spill()?;
return Ok(Some(self.request_spill()?));
if let Some(input) = input.take() {
spill_state.build_state.build(input)?;
}
self.request_spill()?;
return Ok(true);
} else if spill_state.spill_coordinator.get_need_spill() {
if let Some(input) = input.take() {
spill_state.build_state.build(input)?;
}
// Even if the input fits in memory, another processor may need to spill,
// in which case this one must also wait for the spill.
return Ok(Some(self.request_spill()?));
self.request_spill()?;
return Ok(true);
}
Ok(None)
Ok(false)
}

// Check if current processor is the last processor that is responsible for notifying spilling
@@ -180,18 +201,19 @@ impl BuildSpillHandler {
}

// Spill data block and return data that wasn't spilled
pub(crate) async fn spill(&mut self, processor_id: usize) -> Result<DataBlock> {
let mut unspilled_data = DataBlock::empty();
if let Some(block) = self.spill_data.take() {
let mut hashes = Vec::with_capacity(block.num_rows());
let spill_state = self.spill_state_mut();
spill_state.get_hashes(&block, &mut hashes)?;
let spilled_partition_set = spill_state.spiller.spilled_partition_set.clone();
unspilled_data = spill_state
.spiller
.spill_input(block, &hashes, &spilled_partition_set, processor_id)
.await?;
}
pub(crate) async fn spill(
&mut self,
block: DataBlock,
processor_id: usize,
) -> Result<DataBlock> {
let mut hashes = Vec::with_capacity(block.num_rows());
let spill_state = self.spill_state_mut();
spill_state.get_hashes(&block, &mut hashes)?;
let spilled_partition_set = spill_state.spiller.spilled_partition_set.clone();
let unspilled_data = spill_state
.spiller
.spill_input(block, &hashes, &spilled_partition_set, processor_id)
.await?;
Ok(unspilled_data)
}

@@ -113,6 +113,12 @@ pub struct HashJoinBuildState {
pub(crate) send_val: AtomicU8,
/// Wait for all processors to finish reading spilled data, then start a new round of build
pub(crate) restore_barrier: Barrier,
// Max memory usage for join
pub max_memory_usage: usize,
// Spilling threshold for each processor
pub spilling_threshold_per_proc: usize,

/// Runtime filter related states
pub(crate) enable_inlist_runtime_filter: bool,
pub(crate) enable_min_max_runtime_filter: bool,
/// Need to open runtime filter setting.
@@ -127,8 +133,7 @@ impl HashJoinBuildState {
build_keys: &[RemoteExpr],
build_projections: &ColumnSet,
hash_join_state: Arc<HashJoinState>,
barrier: Barrier,
restore_barrier: Barrier,
num_threads: usize,
) -> Result<Arc<HashJoinBuildState>> {
let hash_key_types = build_keys
.iter()
@@ -144,7 +149,7 @@
let mut enable_inlist_runtime_filter = false;
let mut enable_min_max_runtime_filter = false;
if supported_join_type_for_runtime_filter(&hash_join_state.hash_join_desc.join_type)
&& ctx.get_settings().get_join_spilling_threshold()? == 0
&& ctx.get_settings().get_join_spilling_memory_ratio()? == 0
{
let is_cluster = !ctx.get_cluster().is_empty();
// For cluster, only support runtime filter for broadcast join.
Expand All @@ -157,14 +162,16 @@ impl HashJoinBuildState {
}
}
let chunk_size_limit = ctx.get_settings().get_max_block_size()? as usize * 16;

let (max_memory_usage, spilling_threshold_per_proc) =
Self::max_memory_usage(ctx.clone(), num_threads)?;
Ok(Arc::new(Self {
ctx: ctx.clone(),
func_ctx,
hash_join_state,
chunk_size_limit,
barrier,
restore_barrier,
barrier: Barrier::new(num_threads),
restore_barrier: Barrier::new(num_threads),
max_memory_usage,
row_space_builders: Default::default(),
method,
entry_size: Default::default(),
@@ -177,9 +184,35 @@
enable_bloom_runtime_filter,
enable_inlist_runtime_filter,
enable_min_max_runtime_filter,
spilling_threshold_per_proc,
}))
}

// Derive the max memory usage and per-processor spill threshold from settings
fn max_memory_usage(ctx: Arc<QueryContext>, num_threads: usize) -> Result<(usize, usize)> {
debug_assert!(num_threads != 0);
let settings = ctx.get_settings();
let spilling_threshold_per_proc = settings.get_join_spilling_bytes_threshold_per_proc()?;
let mut memory_ratio = settings.get_join_spilling_memory_ratio()? as f64 / 100_f64;
if memory_ratio > 1_f64 {
memory_ratio = 1_f64;
}
let max_memory_usage = match settings.get_max_memory_usage()? {
0 => usize::MAX,
max_memory_usage => match memory_ratio {
mr if mr == 0_f64 => usize::MAX,
mr => (max_memory_usage as f64 * mr) as usize,
},
};

let spilling_threshold_per_proc = match spilling_threshold_per_proc {
0 => max_memory_usage / num_threads,
bytes => bytes,
};

Ok((max_memory_usage, spilling_threshold_per_proc))
}
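
For intuition, here is the same derivation as a standalone function with sample numbers, not from this commit (names and values are illustrative): with `max_memory_usage` = 16 GiB, `join_spilling_memory_ratio` = 60, and 8 build processors, the join may use up to ~9.6 GiB, and each processor defaults to a ~1.2 GiB spill threshold.

const GIB: usize = 1 << 30;

// Standalone mirror of `max_memory_usage` above, for illustration only.
fn derive_limits(
    max_memory_setting: usize,       // max_memory_usage setting; 0 means unlimited
    memory_ratio_percent: u64,       // join_spilling_memory_ratio, clamped to 100
    bytes_threshold_per_proc: usize, // join_spilling_bytes_threshold_per_proc; 0 = auto
    num_threads: usize,
) -> (usize, usize) {
    let ratio = (memory_ratio_percent as f64 / 100.0).min(1.0);
    let cap = if max_memory_setting == 0 || ratio == 0.0 {
        usize::MAX
    } else {
        (max_memory_setting as f64 * ratio) as usize
    };
    // When no per-processor threshold is configured, split the cap evenly.
    let per_proc = if bytes_threshold_per_proc == 0 {
        cap / num_threads
    } else {
        bytes_threshold_per_proc
    };
    (cap, per_proc)
}

fn main() {
    let (cap, per_proc) = derive_limits(16 * GIB, 60, 0, 8);
    assert_eq!(cap, (16.0 * GIB as f64 * 0.6) as usize); // ~9.6 GiB
    assert_eq!(per_proc, cap / 8);                       // ~1.2 GiB per processor
}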

/// Add input `DataBlock` to `hash_join_state.row_space`.
pub fn build(&self, input: DataBlock) -> Result<()> {
let mut buffer = self.hash_join_state.row_space.buffer.write();
@@ -304,7 +337,7 @@ impl HashJoinBuildState {
JoinType::LeftMark | JoinType::RightMark
)
&& self.ctx.get_cluster().is_empty()
&& self.ctx.get_settings().get_join_spilling_threshold()? == 0
&& self.ctx.get_settings().get_join_spilling_memory_ratio()? == 0
{
self.hash_join_state
.fast_return
@@ -146,7 +146,7 @@ impl HashJoinState {
let (continue_build_watcher, _continue_build_dummy_receiver) = watch::channel(false);
let mut enable_spill = false;
if hash_join_desc.join_type == JoinType::Inner
&& ctx.get_settings().get_join_spilling_threshold()? != 0
&& ctx.get_settings().get_join_spilling_memory_ratio()? != 0
{
enable_spill = true;
}