Skip to content

Commit

Permalink
fix: fix the invalid memory write bug for aggregating index.
Browse files Browse the repository at this point in the history
  • Loading branch information
RinChanNOWWW committed Oct 10, 2023
1 parent 534002d commit 7640e4c
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 14 deletions.
31 changes: 28 additions & 3 deletions src/query/ee/tests/it/aggregating_index/index_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::fmt::Display;
use std::sync::Arc;

Expand All @@ -25,6 +26,7 @@ use common_sql::optimizer::SExpr;
use common_sql::planner::plans::Plan;
use common_sql::plans::RelOperator;
use common_sql::Planner;
use common_storages_fuse::TableContext;
use databend_query::interpreters::InterpreterFactory;
use databend_query::sessions::QueryContext;
use databend_query::test_kits::table_test_fixture::expects_ok;
Expand Down Expand Up @@ -64,8 +66,14 @@ async fn test_index_scan_agg_args_are_expression() -> Result<()> {

/// Runs the fuzz suite against the "parquet" format and then the "native" format.
///
/// The two-argument calls pass `false` for `spill`, so `test_fuzz_impl` applies
/// no spill settings to the query context.
#[tokio::test(flavor = "multi_thread")]
async fn test_fuzz() -> Result<()> {
test_fuzz_impl("parquet").await?;
test_fuzz_impl("native").await
test_fuzz_impl("parquet", false).await?;
test_fuzz_impl("native", false).await
}

/// Runs the fuzz suite with spill settings enabled, first for the "parquet"
/// format and then for the "native" format, stopping at the first error.
#[tokio::test(flavor = "multi_thread")]
async fn test_fuzz_with_spill() -> Result<()> {
    for storage_format in ["parquet", "native"] {
        test_fuzz_impl(storage_format, true).await?;
    }
    Ok(())
}

async fn plan_sql(ctx: Arc<QueryContext>, sql: &str) -> Result<Plan> {
Expand Down Expand Up @@ -1038,12 +1046,29 @@ fn get_test_suites() -> Vec<TestSuite> {
]
}

async fn test_fuzz_impl(format: &str) -> Result<()> {
async fn test_fuzz_impl(format: &str, spill: bool) -> Result<()> {
let test_suites = get_test_suites();
let spill_settings = if spill {
Some(HashMap::from([
("spilling_memory_ratio".to_string(), "100".to_string()),
(
"spilling_bytes_threshold_per_proc".to_string(),
"1".to_string(),
),
]))
} else {
None
};

for num_blocks in [1, 10] {
for num_rows_per_block in [1, 50] {
let (_guard, ctx, _) = create_ee_query_context(None).await.unwrap();
if let Some(s) = spill_settings.as_ref() {
let settings = ctx.get_settings();
// Make sure the operator will spill the aggregation.
settings.set_batch_settings(s)?;
}

let fixture = TestFixture::new_with_ctx(_guard, ctx).await;
// Prepare table and data
// Create random engine table to generate random data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::intrinsics::unlikely;
use std::sync::Arc;
use std::vec;

Expand Down Expand Up @@ -39,6 +40,7 @@ use log::info;

use crate::pipelines::processors::transforms::aggregator::aggregate_cell::AggregateHashTableDropper;
use crate::pipelines::processors::transforms::aggregator::aggregate_meta::AggregateMeta;
use crate::pipelines::processors::transforms::group_by::Area;
use crate::pipelines::processors::transforms::group_by::HashMethodBounds;
use crate::pipelines::processors::transforms::group_by::PartitionedHashMethod;
use crate::pipelines::processors::transforms::group_by::PolymorphicKeysHelper;
Expand Down Expand Up @@ -103,6 +105,35 @@ impl TryFrom<Arc<QueryContext>> for AggregateSettings {
}
}

/// An owned temporary memory.
///
/// Pairs a state address with the [`Area`] that is handed to
/// `AggregatorParams::alloc_layout` when the address is first allocated.
struct TempMemory {
// Address of the temporary state; it is zero until `alloc_layout` has run.
place: StateAddr,
// Arena passed to `params.alloc_layout` when `place` is allocated.
arena: Area,
}

impl TempMemory {
    /// Builds a `TempMemory` with a fresh [`Area`] and a zero `place`.
    ///
    /// The memory is lazy: nothing is allocated for the state until
    /// `alloc_layout` is called for the first time.
    fn create_lazy() -> Self {
        Self {
            arena: Area::create(),
            place: StateAddr::new(0),
        }
    }

    /// Asks `params` to allocate its layout in the owned arena, but only while
    /// `place` is still zero; later calls leave `place` untouched.
    #[inline(always)]
    fn alloc_layout(&mut self, params: &AggregatorParams) {
        let needs_alloc = self.place.addr() == 0;
        // Branch hint: the allocating path is expected to be the rare one.
        if unlikely(needs_alloc) {
            self.place = params.alloc_layout(&mut self.arena);
        }
    }

    /// Borrows the address of the temporary state.
    #[inline(always)]
    fn place(&self) -> &StateAddr {
        let Self { place, .. } = self;
        place
    }
}

// SELECT column_name, agg(xxx) FROM table_name GROUP BY column_name
pub struct TransformPartialAggregate<Method: HashMethodBounds> {
method: Method,
Expand All @@ -111,8 +142,13 @@ pub struct TransformPartialAggregate<Method: HashMethodBounds> {

params: Arc<AggregatorParams>,

/// A temporary place to hold aggregating state from index data.
temp_place: StateAddr,
/// A temporary memory to transform aggregating state from index data.
///
/// **NOTE**: we must create a dedicated [`Area`] to transform the aggregating index data.
/// We cannot use the [`Area`] in `hash_table` to hold the temporary memory,
/// because that [`Area`] may be moved out when spilling happens,
/// and this [`TransformPartialAggregate`] would then lose control of the memory.
temp_memory: TempMemory,
}

impl<Method: HashMethodBounds> TransformPartialAggregate<Method> {
Expand Down Expand Up @@ -143,7 +179,7 @@ impl<Method: HashMethodBounds> TransformPartialAggregate<Method> {
params,
hash_table,
settings: AggregateSettings::try_from(ctx)?,
temp_place: StateAddr::new(0),
temp_memory: TempMemory::create_lazy(),
},
))
}
Expand Down Expand Up @@ -204,9 +240,11 @@ impl<Method: HashMethodBounds> TransformPartialAggregate<Method> {

#[inline(always)]
#[allow(clippy::ptr_arg)] // &[StateAddr] slower than &StateAddrs ~20%
fn execute_agg_index_block(&self, block: &DataBlock, places: &StateAddrs) -> Result<()> {
fn execute_agg_index_block(&mut self, block: &DataBlock, places: &StateAddrs) -> Result<()> {
self.temp_memory.alloc_layout(&self.params);
let aggregate_functions = &self.params.aggregate_functions;
let offsets_aggregate_states = &self.params.offsets_aggregate_states;
let temp_place = self.temp_memory.place();

for index in 0..aggregate_functions.len() {
// Aggregation states are in the back of the block.
Expand All @@ -220,7 +258,7 @@ impl<Method: HashMethodBounds> TransformPartialAggregate<Method> {
.unwrap()
.as_string()
.unwrap();
let state_place = self.temp_place.next(offset);
let state_place = temp_place.next(offset);
for (row, mut raw_state) in agg_state.iter().enumerate() {
let place = &places[row];
function.deserialize(state_place, &mut raw_state)?;
Expand Down Expand Up @@ -277,9 +315,6 @@ impl<Method: HashMethodBounds> TransformPartialAggregate<Method> {
}

if is_agg_index_block {
if self.temp_place.addr() == 0 {
self.temp_place = self.params.alloc_layout(&mut hashtable.arena);
}
self.execute_agg_index_block(&block, &places)
} else {
Self::execute(&self.params, &block, &places)
Expand All @@ -300,9 +335,6 @@ impl<Method: HashMethodBounds> TransformPartialAggregate<Method> {
}

if is_agg_index_block {
if self.temp_place.addr() == 0 {
self.temp_place = self.params.alloc_layout(&mut hashtable.arena);
}
self.execute_agg_index_block(&block, &places)
} else {
Self::execute(&self.params, &block, &places)
Expand Down

0 comments on commit 7640e4c

Please sign in to comment.