From 7640e4c11383c0e50fe6eef8d60386ec277968c0 Mon Sep 17 00:00:00 2001 From: hezheyu Date: Tue, 10 Oct 2023 13:59:46 +0800 Subject: [PATCH] fix: fix the invalid memory write bug for aggregating index. --- .../tests/it/aggregating_index/index_scan.rs | 31 +++++++++-- .../aggregator/transform_aggregate_partial.rs | 54 +++++++++++++++---- 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/src/query/ee/tests/it/aggregating_index/index_scan.rs b/src/query/ee/tests/it/aggregating_index/index_scan.rs index 9316863ac46e..063ac7483682 100644 --- a/src/query/ee/tests/it/aggregating_index/index_scan.rs +++ b/src/query/ee/tests/it/aggregating_index/index_scan.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::fmt::Display; use std::sync::Arc; @@ -25,6 +26,7 @@ use common_sql::optimizer::SExpr; use common_sql::planner::plans::Plan; use common_sql::plans::RelOperator; use common_sql::Planner; +use common_storages_fuse::TableContext; use databend_query::interpreters::InterpreterFactory; use databend_query::sessions::QueryContext; use databend_query::test_kits::table_test_fixture::expects_ok; @@ -64,8 +66,14 @@ async fn test_index_scan_agg_args_are_expression() -> Result<()> { #[tokio::test(flavor = "multi_thread")] async fn test_fuzz() -> Result<()> { - test_fuzz_impl("parquet").await?; - test_fuzz_impl("native").await + test_fuzz_impl("parquet", false).await?; + test_fuzz_impl("native", false).await +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_fuzz_with_spill() -> Result<()> { + test_fuzz_impl("parquet", true).await?; + test_fuzz_impl("native", true).await } async fn plan_sql(ctx: Arc, sql: &str) -> Result { @@ -1038,12 +1046,29 @@ fn get_test_suites() -> Vec { ] } -async fn test_fuzz_impl(format: &str) -> Result<()> { +async fn test_fuzz_impl(format: &str, spill: bool) -> Result<()> { let test_suites = get_test_suites(); + let 
spill_settings = if spill { + Some(HashMap::from([ + ("spilling_memory_ratio".to_string(), "100".to_string()), + ( + "spilling_bytes_threshold_per_proc".to_string(), + "1".to_string(), + ), + ])) + } else { + None + }; for num_blocks in [1, 10] { for num_rows_per_block in [1, 50] { let (_guard, ctx, _) = create_ee_query_context(None).await.unwrap(); + if let Some(s) = spill_settings.as_ref() { + let settings = ctx.get_settings(); + // Make sure the operator will spill the aggregation. + settings.set_batch_settings(s)?; + } + let fixture = TestFixture::new_with_ctx(_guard, ctx).await; // Prepare table and data // Create random engine table to generate random data. diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs index 4cf479fb5ace..479a3eecee63 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::intrinsics::unlikely; use std::sync::Arc; use std::vec; @@ -39,6 +40,7 @@ use log::info; use crate::pipelines::processors::transforms::aggregator::aggregate_cell::AggregateHashTableDropper; use crate::pipelines::processors::transforms::aggregator::aggregate_meta::AggregateMeta; +use crate::pipelines::processors::transforms::group_by::Area; use crate::pipelines::processors::transforms::group_by::HashMethodBounds; use crate::pipelines::processors::transforms::group_by::PartitionedHashMethod; use crate::pipelines::processors::transforms::group_by::PolymorphicKeysHelper; @@ -103,6 +105,35 @@ impl TryFrom> for AggregateSettings { } } +/// An owned temporary memory. 
+struct TempMemory { + place: StateAddr, + arena: Area, +} + +impl TempMemory { + /// Create a lazy memory which will not be allocated until the first time it is used. + fn create_lazy() -> Self { + let arena = Area::create(); + Self { + place: StateAddr::new(0), + arena, + } + } + + #[inline(always)] + fn alloc_layout(&mut self, params: &AggregatorParams) { + if unlikely(self.place.addr() == 0) { + self.place = params.alloc_layout(&mut self.arena); + } + } + + #[inline(always)] + fn place(&self) -> &StateAddr { + &self.place + } +} + // SELECT column_name, agg(xxx) FROM table_name GROUP BY column_name pub struct TransformPartialAggregate { method: Method, @@ -111,8 +142,13 @@ pub struct TransformPartialAggregate { params: Arc, - /// A temporary place to hold aggregating state from index data. - temp_place: StateAddr, + /// A temporary memory to transform aggregating state from index data. + /// + /// **NOTES**: we should create a new [`Area`] to transform the aggregating index data. + /// We cannot use the [`Area`] in `hash_table` to hold the temporary memory, + /// because the [`Area`] may be moved out when spilling happens. + /// And this [`TransformPartialAggregate`] will lose control of the memory. 
+ temp_memory: TempMemory, } impl TransformPartialAggregate { @@ -143,7 +179,7 @@ impl TransformPartialAggregate { params, hash_table, settings: AggregateSettings::try_from(ctx)?, - temp_place: StateAddr::new(0), + temp_memory: TempMemory::create_lazy(), }, )) } @@ -204,9 +240,11 @@ impl TransformPartialAggregate { #[inline(always)] #[allow(clippy::ptr_arg)] // &[StateAddr] slower than &StateAddrs ~20% - fn execute_agg_index_block(&self, block: &DataBlock, places: &StateAddrs) -> Result<()> { + fn execute_agg_index_block(&mut self, block: &DataBlock, places: &StateAddrs) -> Result<()> { + self.temp_memory.alloc_layout(&self.params); let aggregate_functions = &self.params.aggregate_functions; let offsets_aggregate_states = &self.params.offsets_aggregate_states; + let temp_place = self.temp_memory.place(); for index in 0..aggregate_functions.len() { // Aggregation states are in the back of the block. @@ -220,7 +258,7 @@ impl TransformPartialAggregate { .unwrap() .as_string() .unwrap(); - let state_place = self.temp_place.next(offset); + let state_place = temp_place.next(offset); for (row, mut raw_state) in agg_state.iter().enumerate() { let place = &places[row]; function.deserialize(state_place, &mut raw_state)?; @@ -277,9 +315,6 @@ impl TransformPartialAggregate { } if is_agg_index_block { - if self.temp_place.addr() == 0 { - self.temp_place = self.params.alloc_layout(&mut hashtable.arena); - } self.execute_agg_index_block(&block, &places) } else { Self::execute(&self.params, &block, &places) @@ -300,9 +335,6 @@ impl TransformPartialAggregate { } if is_agg_index_block { - if self.temp_place.addr() == 0 { - self.temp_place = self.params.alloc_layout(&mut hashtable.arena); - } self.execute_agg_index_block(&block, &places) } else { Self::execute(&self.params, &block, &places)