Skip to content

Commit

Permalink
Refactor SortedStream trait, output block along with order column.
Browse files Browse the repository at this point in the history
  • Loading branch information
RinChanNOWWW committed Dec 15, 2023
1 parent e2649c9 commit 524d89b
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 103 deletions.
7 changes: 7 additions & 0 deletions src/query/expression/src/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,13 @@ impl DataBlock {
}
DataBlock::new_with_meta(columns, self.num_rows, self.meta)
}

/// Returns a reference to the underlying [`Column`] of the last entry in the block.
///
/// The caller must guarantee that the block is non-empty and that the last
/// entry holds a full column (not a scalar); both invariants are
/// debug-asserted before the unconditional unwraps.
#[inline]
pub fn get_last_column(&self) -> &Column {
    let last_entry = self.columns.last();
    debug_assert!(last_entry.is_some());
    let last_entry = last_entry.unwrap();
    debug_assert!(last_entry.value.as_column().is_some());
    last_entry.value.as_column().unwrap()
}
}

impl TryFrom<DataBlock> for ArrowChunk<ArrayRef> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,27 @@ use std::collections::VecDeque;
use std::sync::Arc;

use common_exception::Result;
use common_expression::Column;
use common_expression::DataBlock;
use common_expression::DataSchemaRef;
use common_expression::SortColumnDescription;

use super::utils::find_bigger_child_of_root;
use super::Cursor;
use super::Rows;
use crate::processors::sort::utils::get_ordered_rows;

#[async_trait::async_trait]
pub trait SortedStream {
/// Returns the next block and if it is pending.
/// Returns the next block with the order column and if it is pending.
///
/// If the block is [None] and it's not pending, it means the stream is finished.
/// If the block is [None] but it's pending, it means the stream is not finished yet.
fn next(&mut self) -> Result<(Option<DataBlock>, bool)> {
fn next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> {
Ok((None, false))
}

/// The async version of `next`.
async fn async_next(&mut self) -> Result<(Option<DataBlock>, bool)> {
async fn async_next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> {
self.next()
}
}
Expand All @@ -56,7 +56,6 @@ where
buffer: Vec<DataBlock>,
pending_stream: VecDeque<usize>,
batch_size: usize,
output_order_col: bool,

temp_sorted_num_rows: usize,
temp_output_indices: Vec<(usize, usize, usize)>,
Expand All @@ -73,11 +72,14 @@ where
streams: Vec<S>,
sort_desc: Arc<Vec<SortColumnDescription>>,
batch_size: usize,
output_order_col: bool,
) -> Self {
// We only create a merger when there are at least two streams.
debug_assert!(streams.len() > 1);
debug_assert!(schema.num_fields() > 0);
debug_assert!(streams.len() > 1, "streams.len() = {}", streams.len());
debug_assert!(
schema.num_fields() > 0,
"schema.num_fields = {}",
schema.num_fields()
);
debug_assert_eq!(schema.fields.last().unwrap().name(), "_order_col");
let heap = BinaryHeap::with_capacity(streams.len());
let buffer = vec![DataBlock::empty_with_schema(schema.clone()); streams.len()];
Expand All @@ -89,7 +91,6 @@ where
heap,
buffer,
batch_size,
output_order_col,
sort_desc,
pending_stream,
temp_sorted_num_rows: 0,
Expand All @@ -103,16 +104,13 @@ where
let mut continue_pendings = Vec::new();
while let Some(i) = self.pending_stream.pop_front() {
debug_assert!(self.buffer[i].is_empty());
let (block, pending) = self.unsorted_streams[i].async_next().await?;
let (input, pending) = self.unsorted_streams[i].async_next().await?;
if pending {
continue_pendings.push(i);
continue;
}
if let Some(mut block) = block {
let rows = get_ordered_rows(&block, &self.sort_desc)?;
if !self.output_order_col {
block.pop_columns(1);
}
if let Some((block, col)) = input {
let rows = R::from_column(&col, &self.sort_desc)?;
let cursor = Cursor::new(i, rows);
self.heap.push(Reverse(cursor));
self.buffer[i] = block;
Expand All @@ -127,16 +125,13 @@ where
let mut continue_pendings = Vec::new();
while let Some(i) = self.pending_stream.pop_front() {
debug_assert!(self.buffer[i].is_empty());
let (block, pending) = self.unsorted_streams[i].next()?;
let (input, pending) = self.unsorted_streams[i].next()?;
if pending {
continue_pendings.push(i);
continue;
}
if let Some(mut block) = block {
let rows = get_ordered_rows(&block, &self.sort_desc)?;
if !self.output_order_col {
block.pop_columns(1);
}
if let Some((block, col)) = input {
let rows = R::from_column(&col, &self.sort_desc)?;
let cursor = Cursor::new(i, rows);
self.heap.push(Reverse(cursor));
self.buffer[i] = block;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl Rows for StringColumn {
Column::String(self.clone())
}

fn from_column(col: Column, _: &[SortColumnDescription]) -> Option<Self> {
fn try_from_column(col: &Column, _: &[SortColumnDescription]) -> Option<Self> {
col.as_string().cloned()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod common;
mod simple;

pub use common::*;
use common_exception::ErrorCode;
use common_exception::Result;
use common_expression::types::DataType;
use common_expression::BlockEntry;
Expand Down Expand Up @@ -45,7 +46,17 @@ where Self: Sized + Clone
fn len(&self) -> usize;
fn row(&self, index: usize) -> Self::Item<'_>;
fn to_column(&self) -> Column;
/// Converts the order column back into `Self`.
///
/// # Errors
///
/// Returns [`ErrorCode::BadDataValueType`] when the column's type does not
/// match [`Rows::data_type`] (i.e. [`Self::try_from_column`] yields `None`).
fn from_column(col: &Column, desc: &[SortColumnDescription]) -> Result<Self> {
    Self::try_from_column(col, desc).ok_or_else(|| {
        ErrorCode::BadDataValueType(format!(
            "Order column type mismatched. Expected {} but got {}",
            Self::data_type(),
            col.data_type()
        ))
    })
}

/// Tries to convert the order column into `Self`; returns `None` on a type
/// mismatch. `desc` carries the sort directions some implementations need.
fn try_from_column(col: &Column, desc: &[SortColumnDescription]) -> Option<Self>;

fn data_type() -> DataType;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ where
T::upcast_column(self.inner.clone())
}

fn from_column(col: Column, desc: &[SortColumnDescription]) -> Option<Self> {
let inner = T::try_downcast_column(&col)?;
fn try_from_column(col: &Column, desc: &[SortColumnDescription]) -> Option<Self> {
let inner = T::try_downcast_column(col)?;
Some(Self {
inner,
desc: !desc[0].asc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@

use std::collections::BinaryHeap;

use common_exception::ErrorCode;
use common_exception::Result;
use common_expression::DataBlock;
use common_expression::SortColumnDescription;

use super::Rows;

/// Find the bigger child of the root of the heap.
#[inline(always)]
pub fn find_bigger_child_of_root<T: Ord>(heap: &BinaryHeap<T>) -> &T {
Expand All @@ -32,15 +25,3 @@ pub fn find_bigger_child_of_root<T: Ord>(heap: &BinaryHeap<T>) -> &T {
(&slice[1]).max(&slice[2])
}
}

/// Extracts the trailing order column of `block` and converts it into rows.
///
/// # Errors
///
/// Returns [`ErrorCode::BadDataValueType`] when the last column's type does
/// not match `R::data_type()`.
#[inline(always)]
pub fn get_ordered_rows<R: Rows>(block: &DataBlock, desc: &[SortColumnDescription]) -> Result<R> {
    // The order column is appended as the last column of the block.
    let order_col = block.columns().last().unwrap().value.as_column().unwrap();
    R::from_column(order_col.clone(), desc).ok_or_else(|| {
        let expected_ty = R::data_type();
        let ty = order_col.data_type();
        ErrorCode::BadDataValueType(format!(
            "Order column type mismatched. Expected {expected_ty} but got {ty}"
        ))
    })
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ use common_pipeline_core::Pipeline;
use common_profile::SharedProcessorProfiles;

use super::sort::utils::find_bigger_child_of_root;
use super::sort::utils::get_ordered_rows;
use super::sort::Cursor;
use super::sort::Rows;
use super::sort::SimpleRows;
Expand Down Expand Up @@ -513,7 +512,7 @@ where R: Rows + Send + 'static
continue;
}
let mut block = block.convert_to_full();
let rows = get_ordered_rows(&block, &self.sort_desc)?;
let rows = R::from_column(block.get_last_column(), &self.sort_desc)?;
// Remove the order column
if self.remove_order_col {
block.pop_columns(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ use common_base::runtime::GLOBAL_MEM_STAT;
use common_exception::ErrorCode;
use common_exception::Result;
use common_expression::row::RowConverter as CommonConverter;
use common_expression::BlockEntry;
use common_expression::BlockMetaInfo;
use common_expression::Column;
use common_expression::DataBlock;
use common_expression::DataSchemaRef;
use common_expression::SortColumnDescription;
use common_expression::Value;

use super::sort::CommonRows;
use super::sort::Cursor;
Expand Down Expand Up @@ -58,7 +57,7 @@ pub struct TransformSortMerge<R: Rows> {
sort_desc: Arc<Vec<SortColumnDescription>>,

block_size: usize,
buffer: Vec<Option<DataBlock>>,
buffer: Vec<Option<(DataBlock, Column)>>,

aborting: Arc<AtomicBool>,
// The following fields are used for spilling.
Expand All @@ -76,8 +75,6 @@ pub struct TransformSortMerge<R: Rows> {
/// The number of spilled blocks in each merge of the spill processor.
spill_num_merge: usize,

output_order_col: bool,

_r: PhantomData<R>,
}

Expand All @@ -88,7 +85,6 @@ impl<R: Rows> TransformSortMerge<R> {
block_size: usize,
max_memory_usage: usize,
spilling_bytes_threshold: usize,
output_order_col: bool,
) -> Self {
let may_spill = max_memory_usage != 0 && spilling_bytes_threshold != 0;
TransformSortMerge {
Expand All @@ -104,7 +100,6 @@ impl<R: Rows> TransformSortMerge<R> {
num_rows: 0,
spill_batch_size: 0,
spill_num_merge: 0,
output_order_col,
_r: PhantomData,
}
}
Expand All @@ -113,7 +108,7 @@ impl<R: Rows> TransformSortMerge<R> {
impl<R: Rows> MergeSort<R> for TransformSortMerge<R> {
const NAME: &'static str = "TransformSortMerge";

fn add_block(&mut self, mut block: DataBlock, init_cursor: Cursor<R>) -> Result<Status> {
fn add_block(&mut self, block: DataBlock, init_cursor: Cursor<R>) -> Result<Status> {
if unlikely(self.aborting.load(Ordering::Relaxed)) {
return Err(ErrorCode::AbortedQuery(
"Aborted query, because the server is shutting down or the query was killed.",
Expand All @@ -124,19 +119,9 @@ impl<R: Rows> MergeSort<R> for TransformSortMerge<R> {
return Ok(Status::Continue);
}

// If `self.output_order_col` is true, the order column will be removed outside the processor.
// In order to reuse codes, we add the order column back to the block.
// TODO: find a more elegant way to do this.
if !self.output_order_col {
let order_col = init_cursor.to_column();
block.add_column(BlockEntry {
data_type: order_col.data_type(),
value: Value::Column(order_col),
});
}
self.num_bytes += block.memory_size();
self.num_rows += block.num_rows();
self.buffer.push(Some(block));
self.buffer.push(Some((block, init_cursor.to_column())));

if self.may_spill
&& (self.num_bytes >= self.spilling_bytes_threshold
Expand Down Expand Up @@ -198,11 +183,15 @@ impl<R: Rows> TransformSortMerge<R> {
}

fn merge_sort(&mut self, batch_size: usize) -> Result<Vec<DataBlock>> {
if self.buffer.is_empty() {
return Ok(vec![]);
}

let size_hint = self.num_rows.div_ceil(batch_size);

if self.buffer.len() == 1 {
// If there is only one block, we don't need to merge.
let block = self.buffer.pop().unwrap().unwrap();
let (block, _) = self.buffer.pop().unwrap().unwrap();
let num_rows = block.num_rows();
if size_hint == 1 {
return Ok(vec![block]);
Expand All @@ -211,7 +200,7 @@ impl<R: Rows> TransformSortMerge<R> {
for i in 0..size_hint {
let start = i * batch_size;
let end = ((i + 1) * batch_size).min(num_rows);
let mut block = block.slice(start..end);
let block = block.slice(start..end);
result.push(block);
}
return Ok(result);
Expand All @@ -224,7 +213,6 @@ impl<R: Rows> TransformSortMerge<R> {
streams,
self.sort_desc.clone(),
batch_size,
self.output_order_col,
);

while let (Some(block), _) = merger.next_block()? {
Expand All @@ -240,10 +228,10 @@ impl<R: Rows> TransformSortMerge<R> {
}
}

type BlockStream = Option<DataBlock>;
type BlockStream = Option<(DataBlock, Column)>;

impl SortedStream for BlockStream {
fn next(&mut self) -> Result<(Option<DataBlock>, bool)> {
fn next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> {
Ok((self.take(), false))
}
}
Expand Down Expand Up @@ -275,7 +263,7 @@ pub fn sort_merge(
sort_desc.clone(),
false,
false,
MergeSortCommonImpl::create(schema, sort_desc, block_size, 0, 0, false),
MergeSortCommonImpl::create(schema, sort_desc, block_size, 0, 0),
)?;
for block in data_blocks {
processor.transform(block)?;
Expand Down
Loading

0 comments on commit 524d89b

Please sign in to comment.