From 0fc2a5d2f4de307e212aa25f94db78207d3db18c Mon Sep 17 00:00:00 2001
From: hezheyu
Date: Fri, 15 Dec 2023 10:22:49 +0800
Subject: [PATCH] Refactor `SortedStream` trait: output the block along with
 its order column.

---
 src/query/expression/src/block.rs             |  7 ++
 src/query/pipeline/transforms/src/lib.rs      |  1 +
 .../src/processors/transforms/sort/merger.rs  | 34 +++-----
 .../processors/transforms/sort/rows/common.rs |  2 +-
 .../processors/transforms/sort/rows/mod.rs    | 13 ++-
 .../processors/transforms/sort/rows/simple.rs |  4 +-
 .../src/processors/transforms/sort/utils.rs   | 45 ++++++++---
 .../transforms/transform_multi_sort_merge.rs  |  3 +-
 .../transforms/transform_sort_merge.rs        | 54 +++++++------
 .../transforms/transform_sort_merge_base.rs   |  9 +--
 .../pipeline/transforms/tests/it/merger.rs    | 81 ++++++++++++++++---
 .../src/pipelines/builders/builder_sort.rs    | 41 +++-------
 .../transforms/range_join/ie_join_state.rs    |  2 +-
 .../transforms/transform_sort_spill.rs        | 20 ++---
 .../executor/physical_plans/physical_sort.rs  |  5 +-
 15 files changed, 191 insertions(+), 130 deletions(-)

diff --git a/src/query/expression/src/block.rs b/src/query/expression/src/block.rs
index 383e344d8c227..1412397b915c3 100644
--- a/src/query/expression/src/block.rs
+++ b/src/query/expression/src/block.rs
@@ -553,6 +553,13 @@ impl DataBlock {
         }
         DataBlock::new_with_meta(columns, self.num_rows, self.meta)
     }
+
+    #[inline]
+    pub fn get_last_column(&self) -> &Column {
+        debug_assert!(!self.columns.is_empty());
+        debug_assert!(self.columns.last().unwrap().value.as_column().is_some());
+        self.columns.last().unwrap().value.as_column().unwrap()
+    }
 }
 
 impl TryFrom<DataBlock> for ArrowChunk<ArrayRef> {
diff --git a/src/query/pipeline/transforms/src/lib.rs b/src/query/pipeline/transforms/src/lib.rs
index addeeb815e238..339e62d63f846 100644
--- a/src/query/pipeline/transforms/src/lib.rs
+++ b/src/query/pipeline/transforms/src/lib.rs
@@ -15,5 +15,6 @@
 #![feature(core_intrinsics)]
 #![feature(int_roundings)]
 #![feature(binary_heap_as_slice)]
+#![feature(let_chains)]
 
 pub mod processors;
diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/merger.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/merger.rs
index 0b1bc8d7c0845..9ddab650c5bac 100644
--- a/src/query/pipeline/transforms/src/processors/transforms/sort/merger.rs
+++ b/src/query/pipeline/transforms/src/processors/transforms/sort/merger.rs
@@ -18,6 +18,7 @@ use std::collections::VecDeque;
 use std::sync::Arc;
 
 use common_exception::Result;
+use common_expression::Column;
 use common_expression::DataBlock;
 use common_expression::DataSchemaRef;
 use common_expression::SortColumnDescription;
@@ -25,20 +26,19 @@ use common_expression::SortColumnDescription;
 use super::utils::find_bigger_child_of_root;
 use super::Cursor;
 use super::Rows;
-use crate::processors::sort::utils::get_ordered_rows;
 
 #[async_trait::async_trait]
 pub trait SortedStream {
-    /// Returns the next block and if it is pending.
+    /// Returns the next block together with its order column, and whether the stream is pending.
     ///
     /// If the block is [None] and it's not pending, it means the stream is finished.
     /// If the block is [None] but it's pending, it means the stream is not finished yet.
-    fn next(&mut self) -> Result<(Option<DataBlock>, bool)> {
+    fn next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> {
         Ok((None, false))
     }
 
     /// The async version of `next`.
-    async fn async_next(&mut self) -> Result<(Option<DataBlock>, bool)> {
+    async fn async_next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> {
         self.next()
     }
 }
@@ -56,7 +56,6 @@ where
     buffer: Vec<DataBlock>,
     pending_stream: VecDeque<usize>,
     batch_size: usize,
-    output_order_col: bool,
 
     temp_sorted_num_rows: usize,
     temp_output_indices: Vec<(usize, usize, usize)>,
@@ -73,12 +72,10 @@
         streams: Vec<S>,
         sort_desc: Arc<Vec<SortColumnDescription>>,
         batch_size: usize,
-        output_order_col: bool,
     ) -> Self {
         // We only create a merger when there are at least two streams.
-        debug_assert!(streams.len() > 1);
-        debug_assert!(schema.num_fields() > 0);
-        debug_assert_eq!(schema.fields.last().unwrap().name(), "_order_col");
+        debug_assert!(streams.len() > 1, "streams.len() = {}", streams.len());
+
         let heap = BinaryHeap::with_capacity(streams.len());
         let buffer = vec![DataBlock::empty_with_schema(schema.clone()); streams.len()];
         let pending_stream = (0..streams.len()).collect();
@@ -89,7 +86,6 @@
             heap,
             buffer,
             batch_size,
-            output_order_col,
             sort_desc,
             pending_stream,
             temp_sorted_num_rows: 0,
@@ -103,16 +99,13 @@
         let mut continue_pendings = Vec::new();
         while let Some(i) = self.pending_stream.pop_front() {
             debug_assert!(self.buffer[i].is_empty());
-            let (block, pending) = self.unsorted_streams[i].async_next().await?;
+            let (input, pending) = self.unsorted_streams[i].async_next().await?;
             if pending {
                 continue_pendings.push(i);
                 continue;
             }
-            if let Some(mut block) = block {
-                let rows = get_ordered_rows(&block, &self.sort_desc)?;
-                if !self.output_order_col {
-                    block.pop_columns(1);
-                }
+            if let Some((block, col)) = input {
+                let rows = R::from_column(&col, &self.sort_desc)?;
                 let cursor = Cursor::new(i, rows);
                 self.heap.push(Reverse(cursor));
                 self.buffer[i] = block;
@@ -127,16 +120,13 @@
         let mut continue_pendings = Vec::new();
         while let Some(i) = self.pending_stream.pop_front() {
             debug_assert!(self.buffer[i].is_empty());
-            let (block, pending) = self.unsorted_streams[i].next()?;
+            let (input, pending) = self.unsorted_streams[i].next()?;
             if pending {
                 continue_pendings.push(i);
                 continue;
             }
-            if let Some(mut block) = block {
-                let rows = get_ordered_rows(&block, &self.sort_desc)?;
-                if !self.output_order_col {
-                    block.pop_columns(1);
-                }
+            if let Some((block, col)) = input {
+                let rows = R::from_column(&col, &self.sort_desc)?;
                 let cursor = Cursor::new(i, rows);
                 self.heap.push(Reverse(cursor));
                 self.buffer[i] = block;
diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/common.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/common.rs
index 3ed8489c05850..6d28815177aa4 100644
--- a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/common.rs
+++ b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/common.rs
@@ -48,7 +48,7 @@ impl Rows for StringColumn {
         Column::String(self.clone())
     }
 
-    fn from_column(col: Column, _: &[SortColumnDescription]) -> Option<Self> {
+    fn try_from_column(col: &Column, _: &[SortColumnDescription]) -> Option<Self> {
         col.as_string().cloned()
     }
 
diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/mod.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/mod.rs
index dd14b385e09ac..9cd70599fb001 100644
--- a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/mod.rs
+++ b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/mod.rs
@@ -16,6 +16,7 @@ mod common;
 mod simple;
 
 pub use common::*;
+use common_exception::ErrorCode;
 use common_exception::Result;
 use common_expression::types::DataType;
 use common_expression::BlockEntry;
@@ -45,7 +46,17 @@ where Self: Sized + Clone
     fn len(&self) -> usize;
     fn row(&self, index: usize) -> Self::Item<'_>;
     fn to_column(&self) -> Column;
-    fn from_column(col: Column, desc: &[SortColumnDescription]) -> Option<Self>;
+
+    fn from_column(col: &Column, desc: &[SortColumnDescription]) -> Result<Self> {
+        Self::try_from_column(col, desc).ok_or_else(|| {
+            ErrorCode::BadDataValueType(format!(
+                "Order column type mismatched. Expected {} but got {}",
+                Self::data_type(),
+                col.data_type()
+            ))
+        })
+    }
+
+    fn try_from_column(col: &Column, desc: &[SortColumnDescription]) -> Option<Self>;
 
     fn data_type() -> DataType;
 
diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/simple.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/simple.rs
index f4045df5e44c3..0909fd258b9d0 100644
--- a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/simple.rs
+++ b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/simple.rs
@@ -115,8 +115,8 @@ where
         T::upcast_column(self.inner.clone())
     }
 
-    fn from_column(col: Column, desc: &[SortColumnDescription]) -> Option<Self> {
-        let inner = T::try_downcast_column(&col)?;
+    fn try_from_column(col: &Column, desc: &[SortColumnDescription]) -> Option<Self> {
+        let inner = T::try_downcast_column(col)?;
         Some(Self {
             inner,
             desc: !desc[0].asc,
diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/utils.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/utils.rs
index d020885ac0b81..c56e291aa12e1 100644
--- a/src/query/pipeline/transforms/src/processors/transforms/sort/utils.rs
+++ b/src/query/pipeline/transforms/src/processors/transforms/sort/utils.rs
@@ -14,12 +14,14 @@
 
 use std::collections::BinaryHeap;
 
-use common_exception::ErrorCode;
-use common_exception::Result;
-use common_expression::DataBlock;
+use common_expression::types::DataType;
+use common_expression::DataField;
+use common_expression::DataSchema;
+use common_expression::DataSchemaRef;
+use common_expression::DataSchemaRefExt;
 use common_expression::SortColumnDescription;
 
-use super::Rows;
+pub const ORDER_COL_NAME: &str = "_order_col";
 
 /// Find the bigger child of the root of the heap.
 #[inline(always)]
@@ -34,13 +36,30 @@ pub fn find_bigger_child_of_root<T: Ord>(heap: &BinaryHeap<T>) -> &T {
 }
 
 #[inline(always)]
-pub fn get_ordered_rows<R: Rows>(block: &DataBlock, desc: &[SortColumnDescription]) -> Result<R> {
-    let order_col = block.columns().last().unwrap().value.as_column().unwrap();
-    R::from_column(order_col.clone(), desc).ok_or_else(|| {
-        let expected_ty = R::data_type();
-        let ty = order_col.data_type();
-        ErrorCode::BadDataValueType(format!(
-            "Order column type mismatched. Expecetd {expected_ty} but got {ty}"
-        ))
-    })
+fn order_field_type(schema: &DataSchema, desc: &[SortColumnDescription]) -> DataType {
+    debug_assert!(!desc.is_empty());
+    if desc.len() == 1 {
+        let order_by_field = schema.field(desc[0].offset);
+        if matches!(
+            order_by_field.data_type(),
+            DataType::Number(_) | DataType::Date | DataType::Timestamp | DataType::String
+        ) {
+            return order_by_field.data_type().clone();
+        }
+    }
+    DataType::String
+}
+
+#[inline(always)]
+pub fn add_order_field(schema: DataSchemaRef, desc: &[SortColumnDescription]) -> DataSchemaRef {
+    if let Some(f) = schema.fields.last() && f.name() == ORDER_COL_NAME {
+        schema
+    } else {
+        let mut fields = schema.fields().clone();
+        fields.push(DataField::new(
+            ORDER_COL_NAME,
+            order_field_type(&schema, desc),
+        ));
+        DataSchemaRefExt::create(fields)
+    }
+}
diff --git a/src/query/pipeline/transforms/src/processors/transforms/transform_multi_sort_merge.rs b/src/query/pipeline/transforms/src/processors/transforms/transform_multi_sort_merge.rs
index f850822c9b525..3684993ba0f50 100644
--- a/src/query/pipeline/transforms/src/processors/transforms/transform_multi_sort_merge.rs
+++ b/src/query/pipeline/transforms/src/processors/transforms/transform_multi_sort_merge.rs
@@ -42,7 +42,6 @@ use common_pipeline_core::Pipeline;
 use common_profile::SharedProcessorProfiles;
 
 use super::sort::utils::find_bigger_child_of_root;
-use super::sort::utils::get_ordered_rows;
 use super::sort::Cursor;
 use super::sort::Rows;
 use super::sort::SimpleRows;
@@ -513,7 +512,7 @@ where R: Rows + Send + 'static
                 continue;
             }
             let mut block = block.convert_to_full();
-            let rows = get_ordered_rows(&block, &self.sort_desc)?;
+            let rows = R::from_column(block.get_last_column(), &self.sort_desc)?;
             // Remove the order column
             if self.remove_order_col {
                 block.pop_columns(1);
diff --git a/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge.rs b/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge.rs
index b14029c061fdf..5e3cb84c069b8 100644
--- a/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge.rs
+++ b/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge.rs
@@ -22,12 +22,13 @@ use common_base::runtime::GLOBAL_MEM_STAT;
 use common_exception::ErrorCode;
 use common_exception::Result;
 use common_expression::row::RowConverter as CommonConverter;
-use common_expression::BlockEntry;
+use common_expression::types::DataType;
 use common_expression::BlockMetaInfo;
+use common_expression::Column;
 use common_expression::DataBlock;
+use common_expression::DataSchema;
 use common_expression::DataSchemaRef;
 use common_expression::SortColumnDescription;
-use common_expression::Value;
 
 use super::sort::CommonRows;
 use super::sort::Cursor;
@@ -58,7 +59,7 @@ pub struct TransformSortMerge<R: Rows> {
     sort_desc: Arc<Vec<SortColumnDescription>>,
 
     block_size: usize,
-    buffer: Vec<Option<DataBlock>>,
+    buffer: Vec<Option<(DataBlock, Column)>>,
 
     aborting: Arc<AtomicBool>,
 
     // The following fields are used for spilling.
@@ -76,8 +77,6 @@ pub struct TransformSortMerge<R: Rows> {
     /// The number of spilled blocks in each merge of the spill processor.
     spill_num_merge: usize,
 
-    output_order_col: bool,
-
     _r: PhantomData<R>,
 }
 
@@ -88,7 +87,6 @@ impl<R: Rows> TransformSortMerge<R> {
         block_size: usize,
         max_memory_usage: usize,
         spilling_bytes_threshold: usize,
-        output_order_col: bool,
     ) -> Self {
         let may_spill = max_memory_usage != 0 && spilling_bytes_threshold != 0;
         TransformSortMerge {
@@ -104,7 +102,6 @@
             num_rows: 0,
             spill_batch_size: 0,
             spill_num_merge: 0,
-            output_order_col,
             _r: PhantomData,
         }
     }
@@ -113,7 +110,7 @@
 impl<R: Rows> MergeSort<R> for TransformSortMerge<R> {
     const NAME: &'static str = "TransformSortMerge";
 
-    fn add_block(&mut self, mut block: DataBlock, init_cursor: Cursor<R>) -> Result<Status> {
+    fn add_block(&mut self, block: DataBlock, init_cursor: Cursor<R>) -> Result<Status> {
         if unlikely(self.aborting.load(Ordering::Relaxed)) {
             return Err(ErrorCode::AbortedQuery(
                 "Aborted query, because the server is shutting down or the query was killed.",
@@ -124,19 +121,9 @@
             return Ok(Status::Continue);
         }
 
-        // If `self.output_order_col` is true, the order column will be removed outside the processor.
-        // In order to reuse codes, we add the order column back to the block.
-        // TODO: find a more elegant way to do this.
-        if !self.output_order_col {
-            let order_col = init_cursor.to_column();
-            block.add_column(BlockEntry {
-                data_type: order_col.data_type(),
-                value: Value::Column(order_col),
-            });
-        }
         self.num_bytes += block.memory_size();
         self.num_rows += block.num_rows();
-        self.buffer.push(Some(block));
+        self.buffer.push(Some((block, init_cursor.to_column())));
 
         if self.may_spill
             && (self.num_bytes >= self.spilling_bytes_threshold
@@ -198,11 +185,15 @@
     }
 
     fn merge_sort(&mut self, batch_size: usize) -> Result<Vec<DataBlock>> {
+        if self.buffer.is_empty() {
+            return Ok(vec![]);
+        }
+
         let size_hint = self.num_rows.div_ceil(batch_size);
 
         if self.buffer.len() == 1 {
             // If there is only one block, we don't need to merge.
-            let block = self.buffer.pop().unwrap().unwrap();
+            let (block, _) = self.buffer.pop().unwrap().unwrap();
             let num_rows = block.num_rows();
             if size_hint == 1 {
                 return Ok(vec![block]);
@@ -211,7 +202,7 @@
             for i in 0..size_hint {
                 let start = i * batch_size;
                 let end = ((i + 1) * batch_size).min(num_rows);
-                let mut block = block.slice(start..end);
+                let block = block.slice(start..end);
                 result.push(block);
             }
             return Ok(result);
@@ -224,7 +215,6 @@
             streams,
             self.sort_desc.clone(),
             batch_size,
-            self.output_order_col,
         );
 
         while let (Some(block), _) = merger.next_block()? {
             result.push(block);
         }
@@ -240,10 +230,10 @@
     }
 }
 
-type BlockStream = Option<DataBlock>;
+type BlockStream = Option<(DataBlock, Column)>;
 
 impl SortedStream for BlockStream {
-    fn next(&mut self) -> Result<(Option<DataBlock>, bool)> {
+    fn next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> {
         Ok((self.take(), false))
     }
 }
@@ -263,6 +253,20 @@
 pub(super) type MergeSortCommonImpl = TransformSortMerge<CommonRows>;
 pub(super) type MergeSortCommon =
     TransformSortMergeBase<MergeSortCommonImpl, CommonRows, CommonConverter>;
 
+pub fn order_column_type(desc: &[SortColumnDescription], schema: &DataSchema) -> DataType {
+    debug_assert!(!desc.is_empty());
+    if desc.len() == 1 {
+        let order_by_field = schema.field(desc[0].offset);
+        if matches!(
+            order_by_field.data_type(),
+            DataType::Number(_) | DataType::Date | DataType::Timestamp | DataType::String
+        ) {
+            return order_by_field.data_type().clone();
+        }
+    }
+    DataType::String
+}
+
 pub fn sort_merge(
     schema: DataSchemaRef,
     block_size: usize,
@@ -275,7 +279,7 @@
         sort_desc.clone(),
         false,
         false,
-        MergeSortCommonImpl::create(schema, sort_desc, block_size, 0, 0, false),
+        MergeSortCommonImpl::create(schema, sort_desc, block_size, 0, 0),
     )?;
     for block in data_blocks {
         processor.transform(block)?;
diff --git a/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge_base.rs b/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge_base.rs
index 360ef0138bba8..926d35ac1304b 100644
--- a/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge_base.rs
+++ b/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge_base.rs
@@ -29,7 +29,6 @@ use common_pipeline_core::processors::InputPort;
 use common_pipeline_core::processors::OutputPort;
 use common_pipeline_core::processors::Processor;
 
-use super::sort::utils::get_ordered_rows;
 use super::sort::Cursor;
 use super::sort::RowConverter;
 use super::sort::Rows;
@@ -134,7 +133,7 @@ where
 
     fn transform(&mut self, mut block: DataBlock) -> Result<Option<DataBlock>> {
         let rows = if self.order_col_generated {
-            let rows = get_ordered_rows(&block, &self.sort_desc)?;
+            let rows = R::from_column(block.get_last_column(), &self.sort_desc)?;
             if !self.output_order_col {
                 // The next processor could be a sort spill processor which need order column.
                 // And the order column will be removed in that processor.
@@ -284,7 +283,6 @@ impl TransformSortMergeBuilder {
                         block_size,
                         max_memory_usage,
                         spilling_bytes_threshold_per_core,
-                        output_order_col,
                     ),
                 )?,
             ),
@@ -303,7 +301,6 @@
                         block_size,
                         max_memory_usage,
                         spilling_bytes_threshold_per_core,
-                        output_order_col,
                     ),
                 )?,
             ),
@@ -321,7 +318,6 @@
                         block_size,
                         max_memory_usage,
                         spilling_bytes_threshold_per_core,
-                        output_order_col,
                     ),
                 )?,
             ),
@@ -339,7 +335,6 @@
                         block_size,
                         max_memory_usage,
                         spilling_bytes_threshold_per_core,
-                        output_order_col,
                     ),
                 )?,
             ),
@@ -357,7 +352,6 @@
                         block_size,
                         max_memory_usage,
                         spilling_bytes_threshold_per_core,
-                        output_order_col,
                     ),
                 )?,
             ),
@@ -377,7 +371,6 @@
                         block_size,
                         max_memory_usage,
                         spilling_bytes_threshold_per_core,
-                        output_order_col,
                     ),
                 )?,
             )
diff --git a/src/query/pipeline/transforms/tests/it/merger.rs b/src/query/pipeline/transforms/tests/it/merger.rs
index 097b14d2641dd..68ffdbe242a53 100644
--- a/src/query/pipeline/transforms/tests/it/merger.rs
+++ b/src/query/pipeline/transforms/tests/it/merger.rs
@@ -15,11 +15,13 @@
 use std::collections::VecDeque;
 use std::sync::Arc;
 
+use common_base::base::tokio;
 use common_exception::Result;
 use common_expression::block_debug::pretty_format_blocks;
 use common_expression::types::DataType;
 use common_expression::types::Int32Type;
 use common_expression::types::NumberDataType;
+use common_expression::Column;
 use common_expression::DataBlock;
 use common_expression::DataField;
 use common_expression::DataSchemaRefExt;
@@ -48,13 +50,19 @@ impl TestStream {
 }
 
 impl SortedStream for TestStream {
-    fn next(&mut self) -> Result<(Option<DataBlock>, bool)> {
+    fn next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> {
         // To simulate the real scenario, we randomly decide whether the stream is pending or not.
         let pending = self.rng.gen_bool(0.5);
         if pending {
             Ok((None, true))
         } else {
-            Ok((self.data.pop_front(), false))
+            Ok((
+                self.data.pop_front().map(|b| {
+                    let col = b.get_last_column().clone();
+                    (b, col)
+                }),
+                false,
+            ))
         }
     }
 }
@@ -126,19 +134,10 @@ fn create_test_merger(input: Vec<Vec<DataBlock>>) -> TestMerger {
         .map(|v| TestStream::new(v.into_iter().collect::<VecDeque<_>>()))
         .collect::<Vec<_>>();
 
-    TestMerger::create(schema, streams, sort_desc, 4, true)
+    TestMerger::create(schema, streams, sort_desc, 4)
 }
 
-fn test(mut merger: TestMerger, expected: DataBlock) -> Result<()> {
-    let mut result = Vec::new();
-
-    while let (Some(block), pending) = merger.next_block()? {
-        if pending {
-            continue;
-        }
-        result.push(block);
-    }
-
+fn check_result(result: Vec<DataBlock>, expected: DataBlock) {
     let result_rows = result.iter().map(|v| v.num_rows()).sum::<usize>();
     let result = pretty_format_blocks(&result).unwrap();
     let expected_rows = expected.num_rows();
@@ -148,6 +147,42 @@
         "expected (num_rows = {}):\n{}\nactual (num_rows = {}):\n{}",
         expected_rows, expected, result_rows, result
     );
+}
+
+fn test(mut merger: TestMerger, expected: DataBlock) -> Result<()> {
+    let mut result = Vec::new();
+
+    loop {
+        let (block, pending) = merger.next_block()?;
+        if pending {
+            continue;
+        }
+        if block.is_none() {
+            break;
+        }
+        result.push(block.unwrap());
+    }
+
+    check_result(result, expected);
+
+    Ok(())
+}
+
+async fn test_async(mut merger: TestMerger, expected: DataBlock) -> Result<()> {
+    let mut result = Vec::new();
+
+    loop {
+        let (block, pending) = merger.async_next_block().await?;
+        if pending {
+            continue;
+        }
+        if block.is_none() {
+            break;
+        }
+        result.push(block.unwrap());
+    }
+
+    check_result(result, expected);
 
     Ok(())
 }
@@ -171,3 +206,23 @@
 
     Ok(())
 }
+
+#[tokio::test(flavor = "multi_thread")]
+async fn test_basic_async() -> Result<()> {
+    let (input, expected) = basic_test_data();
+    let merger = create_test_merger(input);
+    test_async(merger, expected).await
+}
+
+#[tokio::test(flavor = "multi_thread")]
+async fn test_fuzz_async() -> Result<()> {
+    let mut rng = rand::thread_rng();
+
+    for _ in 0..10 {
+        let (input, expected) = random_test_data(&mut rng);
+        let merger = create_test_merger(input);
+        test_async(merger, expected).await?;
+    }
+
+    Ok(())
+}
diff --git a/src/query/service/src/pipelines/builders/builder_sort.rs b/src/query/service/src/pipelines/builders/builder_sort.rs
index 92ead930a85b2..060167ca7b09a 100644
--- a/src/query/service/src/pipelines/builders/builder_sort.rs
+++ b/src/query/service/src/pipelines/builders/builder_sort.rs
@@ -15,15 +15,12 @@
 use std::sync::Arc;
 
 use common_exception::Result;
-use common_expression::types::DataType;
-use common_expression::DataField;
-use common_expression::DataSchema;
 use common_expression::DataSchemaRef;
-use common_expression::DataSchemaRefExt;
 use common_expression::SortColumnDescription;
 use common_pipeline_core::processors::ProcessorPtr;
 use common_pipeline_core::query_spill_prefix;
 use common_pipeline_core::Pipeline;
+use common_pipeline_transforms::processors::sort::utils::add_order_field;
 use common_pipeline_transforms::processors::try_add_multi_sort_merge;
 use common_pipeline_transforms::processors::ProcessorProfileWrapper;
 use common_pipeline_transforms::processors::TransformSortMergeBuilder;
@@ -293,11 +290,17 @@
 
         let may_spill = max_memory_usage != 0 && bytes_limit_per_proc != 0;
 
+        let sort_merge_output_schema = if output_order_col || may_spill {
+            add_order_field(self.schema.clone(), &self.sort_desc)
+        } else {
+            self.schema.clone()
+        };
+
         pipeline.add_transform(|input, output| {
             let builder = TransformSortMergeBuilder::create(
                 input,
                 output,
-                self.schema.clone(),
+                sort_merge_output_schema.clone(),
                 self.sort_desc.clone(),
                 self.partial_block_size,
             )
@@ -320,18 +323,8 @@
         })?;
 
         if may_spill {
+            let schema = add_order_field(sort_merge_output_schema.clone(), &self.sort_desc);
             let config = SpillerConfig::create(query_spill_prefix(&self.ctx.get_tenant()));
-            // The input of the processor must contain an order column.
-            let schema = if let Some(f) = self.schema.fields.last() && f.name() == "_order_col" {
-                self.schema.clone()
-            } else {
-                let mut fields = self.schema.fields().clone();
-                fields.push(DataField::new(
-                    "_order_col",
-                    order_column_type(&self.sort_desc, &self.schema),
-                ));
-                DataSchemaRefExt::create(fields)
-            };
             pipeline.add_transform(|input, output| {
                 let op = DataOperator::instance().operator();
                 let spiller =
@@ -360,7 +353,7 @@
         // Multi-pipelines merge sort
         try_add_multi_sort_merge(
             pipeline,
-            self.schema,
+            sort_merge_output_schema,
             self.final_block_size,
             self.limit,
             self.sort_desc,
@@ -372,17 +365,3 @@
         Ok(())
     }
 }
-
-fn order_column_type(desc: &[SortColumnDescription], schema: &DataSchema) -> DataType {
-    debug_assert!(!desc.is_empty());
-    if desc.len() == 1 {
-        let order_by_field = schema.field(desc[0].offset);
-        if matches!(
-            order_by_field.data_type(),
-            DataType::Number(_) | DataType::Date | DataType::Timestamp | DataType::String
-        ) {
-            return order_by_field.data_type().clone();
-        }
-    }
-    DataType::String
-}
diff --git a/src/query/service/src/pipelines/processors/transforms/range_join/ie_join_state.rs b/src/query/service/src/pipelines/processors/transforms/range_join/ie_join_state.rs
index cb0004be27e05..1ff87720b95ae 100644
--- a/src/query/service/src/pipelines/processors/transforms/range_join/ie_join_state.rs
+++ b/src/query/service/src/pipelines/processors/transforms/range_join/ie_join_state.rs
@@ -200,7 +200,7 @@ impl RangeJoinState {
         );
 
         left_sorted_blocks = sort_merge(
-            data_schema,
+            data_schema.clone(),
             block_size,
             ie_join_state.l1_sort_descriptions.clone(),
             left_sorted_blocks,
diff --git a/src/query/service/src/pipelines/processors/transforms/transform_sort_spill.rs b/src/query/service/src/pipelines/processors/transforms/transform_sort_spill.rs
index 411a00fbd5c50..67257a4c42191 100644
--- a/src/query/service/src/pipelines/processors/transforms/transform_sort_spill.rs
+++ b/src/query/service/src/pipelines/processors/transforms/transform_sort_spill.rs
@@ -24,6 +24,7 @@ use common_expression::types::NumberDataType;
 use common_expression::types::NumberType;
 use common_expression::with_number_mapped_type;
 use common_expression::BlockMetaInfoDowncast;
+use common_expression::Column;
 use common_expression::DataBlock;
 use common_expression::DataSchemaRef;
 use common_expression::SortColumnDescription;
@@ -328,7 +329,6 @@ where R: Rows + Sync + Send + 'static
             streams,
             self.sort_desc.clone(),
             self.batch_size,
-            true,
         );
 
         let mut spilled = VecDeque::new();
@@ -358,7 +358,7 @@ enum BlockStream {
 
 #[async_trait::async_trait]
 impl SortedStream for BlockStream {
-    async fn async_next(&mut self) -> Result<(Option<DataBlock>, bool)> {
+    async fn async_next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> {
         let block = match self {
             BlockStream::Block(block) => block.take(),
             BlockStream::Spilled((files, spiller)) => {
@@ -379,7 +379,13 @@
             }
         };
-        Ok((block, false))
+        Ok((
+            block.map(|b| {
+                let col = b.get_last_column().clone();
+                (b, col)
+            }),
+            false,
+        ))
     }
 }
@@ -596,12 +602,8 @@ mod tests {
 
         let mut result = Vec::new();
 
-        loop {
-            let (block, _) = block_stream.async_next().await?;
-            if block.is_none() {
-                break;
-            }
-            result.push(block.unwrap());
+        while let (Some((block, _)), _) = block_stream.async_next().await? {
+            result.push(block);
         }
 
         let result = pretty_format_blocks(&result).unwrap();
diff --git a/src/query/sql/src/executor/physical_plans/physical_sort.rs b/src/query/sql/src/executor/physical_plans/physical_sort.rs
index acb15d7b57a43..51c074b5e6535 100644
--- a/src/query/sql/src/executor/physical_plans/physical_sort.rs
+++ b/src/query/sql/src/executor/physical_plans/physical_sort.rs
@@ -18,6 +18,7 @@ use common_expression::DataField;
 use common_expression::DataSchema;
 use common_expression::DataSchemaRef;
 use common_expression::DataSchemaRefExt;
+use common_pipeline_transforms::processors::sort::utils::ORDER_COL_NAME;
 use itertools::Itertools;
 
 use crate::executor::explain::PlanStatsInfo;
@@ -65,7 +66,7 @@ impl Sort {
         if matches!(self.after_exchange, Some(true)) {
             // If the plan is after exchange plan in cluster mode,
             // the order column is at the last of the input schema.
-            debug_assert_eq!(fields.last().unwrap().name(), "_order_col");
+            debug_assert_eq!(fields.last().unwrap().name(), ORDER_COL_NAME);
             debug_assert_eq!(
                 fields.last().unwrap().data_type(),
                 &self.order_col_type(&input_schema)?
@@ -88,7 +89,7 @@
             // If the plan is before exchange plan in cluster mode,
             // the order column should be added to the output schema.
             fields.push(DataField::new(
-                "_order_col",
+                ORDER_COL_NAME,
                 self.order_col_type(&input_schema)?,
            ));
        }
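
Note for reviewers (not part of the patch): the self-contained sketch below illustrates the refactored `SortedStream` contract. `Block`, `OrderCol`, and `VecStream` are hypothetical stand-ins for databend's `DataBlock`, `Column`, and a concrete block source; only the `(Option<(block, order column)>, pending)` protocol mirrors the trait changed above.

    use std::collections::VecDeque;

    type Result<T> = std::result::Result<T, String>;

    // Hypothetical stand-ins for `DataBlock` and the order `Column`.
    #[derive(Clone, Debug)]
    struct Block(Vec<i32>);
    #[derive(Clone, Debug)]
    struct OrderCol(Vec<i32>);

    /// Mirrors the refactored trait: `(None, false)` means the stream is
    /// finished, while `(None, true)` means it is not ready yet.
    trait SortedStream {
        fn next(&mut self) -> Result<(Option<(Block, OrderCol)>, bool)>;
    }

    struct VecStream {
        data: VecDeque<Block>,
    }

    impl SortedStream for VecStream {
        fn next(&mut self) -> Result<(Option<(Block, OrderCol)>, bool)> {
            // The stream hands the order column back together with the block,
            // so the consumer never has to strip it off or re-derive it.
            Ok((
                self.data.pop_front().map(|b| {
                    let col = OrderCol(b.0.clone()); // order key: the data itself here
                    (b, col)
                }),
                false,
            ))
        }
    }

    fn main() -> Result<()> {
        let mut stream = VecStream {
            data: VecDeque::from(vec![Block(vec![1, 2, 3]), Block(vec![4, 5])]),
        };
        while let (Some((block, col)), _) = stream.next()? {
            println!("block = {:?}, order column = {:?}", block, col);
        }
        Ok(())
    }

Returning the order column alongside the block is what lets the merger build cursors directly via `R::from_column` instead of peeling a trailing `_order_col` off every input block, which is exactly the `output_order_col` plumbing this patch removes.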