From f98837e3a896d097c7624d6311a25cbe313e5e8a Mon Sep 17 00:00:00 2001 From: sundyli <543950155@qq.com> Date: Tue, 10 Oct 2023 21:09:57 -0700 Subject: [PATCH] chore(query): refactor agg state merge (#13153) * chore(query): refactor agg state merge * chore(query): refactor agg state merge * chore(query): refactor agg state merge * chore(query): refactor agg state merge * chore(query): refactor agg state merge * chore(query): refactor agg state merge * chore(query): fix render bug * feat(query): add test test_box_render_block * feat(query): remove useless holder --- src/query/expression/src/utils/block_debug.rs | 45 ++++--------------- src/query/expression/tests/it/block.rs | 39 ++++++++++++++++ .../adaptors/aggregate_null_unary_adaptor.rs | 20 ++++++--- .../aggregate_null_variadic_adaptor.rs | 20 ++++++--- .../adaptors/aggregate_ornull_adaptor.rs | 14 +++--- .../aggregate_approx_count_distinct.rs | 13 +++--- .../src/aggregates/aggregate_arg_min_max.rs | 29 +++++------- .../src/aggregates/aggregate_array_agg.rs | 32 ++++--------- .../src/aggregates/aggregate_array_moving.rs | 42 +++++++++-------- .../functions/src/aggregates/aggregate_avg.rs | 17 +++---- .../src/aggregates/aggregate_bitmap.rs | 31 +++++++------ .../aggregate_combinator_distinct.rs | 12 ++--- .../src/aggregates/aggregate_combinator_if.rs | 8 ++-- .../aggregates/aggregate_combinator_state.rs | 8 ++-- .../src/aggregates/aggregate_count.rs | 14 +++--- .../src/aggregates/aggregate_covariance.rs | 12 ++--- .../aggregates/aggregate_distinct_state.rs | 38 ++++++++-------- .../src/aggregates/aggregate_function.rs | 4 +- .../src/aggregates/aggregate_kurtosis.rs | 40 ++++++++--------- .../src/aggregates/aggregate_min_max_any.rs | 17 ++++--- .../src/aggregates/aggregate_null_result.rs | 4 +- .../src/aggregates/aggregate_quantile_cont.rs | 13 +++--- .../src/aggregates/aggregate_quantile_disc.rs | 29 ++++++------ .../aggregates/aggregate_quantile_tdigest.rs | 16 ++++--- .../src/aggregates/aggregate_retention.rs | 26 +++++------ .../src/aggregates/aggregate_scalar_state.rs | 17 ++----- .../src/aggregates/aggregate_skewness.rs | 40 ++++++++--------- .../src/aggregates/aggregate_stddev.rs | 12 ++--- .../src/aggregates/aggregate_string_agg.rs | 11 ++--- .../functions/src/aggregates/aggregate_sum.rs | 31 +++++++------ .../src/aggregates/aggregate_window_funnel.rs | 25 ++++------- .../transforms/aggregator/aggregate_cell.rs | 6 --- .../transforms/aggregator/aggregate_meta.rs | 4 -- .../aggregator/transform_aggregate_final.rs | 16 +------ .../aggregator/transform_aggregate_partial.rs | 22 +-------- .../aggregator/transform_single_key.rs | 43 ++++-------------- .../group_by/aggregator_polymorphic_keys.rs | 2 - 37 files changed, 350 insertions(+), 422 deletions(-) diff --git a/src/query/expression/src/utils/block_debug.rs b/src/query/expression/src/utils/block_debug.rs index 916f44f1ce5c..67bd614b99f2 100644 --- a/src/query/expression/src/utils/block_debug.rs +++ b/src/query/expression/src/utils/block_debug.rs @@ -177,46 +177,17 @@ fn create_box_table( }; let mut res_vec: Vec> = vec![]; - let top_collection = results.first().unwrap(); - let top_rows = top_collection.num_rows().min(top_rows); - if bottom_rows == 0 { - for block in results { - for row in 0..block.num_rows() { - let mut v = vec![]; - for block_entry in block.columns() { - let value = block_entry.value.index(row).unwrap().to_string(); - if replace_newline { - v.push(value.to_string().replace('\n', "\\n")); - } else { - v.push(value.to_string()); - } - } - res_vec.push(v); - } - } - } else { - let bottom_collection = results.last().unwrap(); - for row in 0..top_rows { - let mut v = vec![]; - for block_entry in top_collection.columns() { - let value = block_entry.value.index(row).unwrap().to_string(); - if replace_newline { - v.push(value.to_string().replace('\n', "\\n")); - } else { - v.push(value.to_string()); - } + + let mut rows = 0; + for block in results { + for row in 0..block.num_rows() { + rows += 1; + if rows > top_rows && rows <= row_count - bottom_rows { + continue; } - res_vec.push(v); - } - let take_num = if bottom_collection.num_rows() > bottom_rows { - bottom_collection.num_rows() - bottom_rows - } else { - 0 - }; - for row in take_num..bottom_collection.num_rows() { let mut v = vec![]; - for block_entry in top_collection.columns() { + for block_entry in block.columns() { let value = block_entry.value.index(row).unwrap().to_string(); if replace_newline { v.push(value.to_string().replace('\n', "\\n")); diff --git a/src/query/expression/tests/it/block.rs b/src/query/expression/tests/it/block.rs index 8264f595949d..397c769028bb 100644 --- a/src/query/expression/tests/it/block.rs +++ b/src/query/expression/tests/it/block.rs @@ -1,5 +1,12 @@ +use common_expression::block_debug::box_render; use common_expression::types::string::StringColumnBuilder; +use common_expression::types::DataType; +use common_expression::types::Int32Type; +use common_expression::types::NumberDataType; use common_expression::Column; +use common_expression::DataField; +use common_expression::DataSchemaRefExt; +use common_expression::FromData; use crate::common::new_block; @@ -17,3 +24,35 @@ fn test_split_block() { .collect::>(); assert_eq!(sizes, vec![3, 3, 4]); } + +#[test] +fn test_box_render_block() { + let value = b"abc"; + let n = 10; + let block = new_block(&[ + Int32Type::from_data(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + Column::String(StringColumnBuilder::repeat(&value[..], n).build()), + ]); + + let schema = DataSchemaRefExt::create(vec![ + DataField::new("a", DataType::Number(NumberDataType::Int32)), + DataField::new("e", DataType::String), + ]); + let d = box_render(&schema, &[block], 5, 1000, 1000, true).unwrap(); + let expected = r#"┌────────────────────┐ +│ a │ e │ +│ Int32 │ String │ +├───────────┼────────┤ +│ 1 │ 'abc' │ +│ 2 │ 'abc' │ +│ 3 │ 'abc' │ +│ · │ · │ +│ · │ · │ +│ · │ · │ +│ 9 │ 'abc' │ +│ 10 │ 'abc' │ +│ 10 rows │ │ +│ (5 shown) │ │ +└────────────────────┘"#; + assert_eq!(d, expected); +} diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_unary_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_unary_adaptor.rs index d349cb0dec77..0f5833b8e1d9 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_unary_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_unary_adaptor.rs @@ -202,20 +202,26 @@ impl AggregateFunction for AggregateNullUnaryAdapto Ok(()) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + if self.get_flag(place) == 0 { + // initial the state to remove the dirty stats + self.init_state(place); + } + if NULLABLE_RESULT { let flag = reader[reader.len() - 1]; - self.nested - .deserialize(place, &mut &reader[..reader.len() - 1])?; - self.set_flag(place, flag); + if flag == 1 { + self.set_flag(place, 1); + self.nested.merge(place, &mut &reader[..reader.len() - 1])?; + } } else { - self.nested.deserialize(place, reader)?; + self.nested.merge(place, reader)?; } Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { if self.get_flag(place) == 0 { // initial the state to remove the dirty stats self.init_state(place); @@ -223,7 +229,7 @@ impl AggregateFunction for AggregateNullUnaryAdapto if self.get_flag(rhs) == 1 { self.set_flag(place, 1); - self.nested.merge(place, rhs)?; + self.nested.merge_states(place, rhs)?; } Ok(()) diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_variadic_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_variadic_adaptor.rs index e4e0317a7809..3c455fe7b1ec 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_variadic_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_variadic_adaptor.rs @@ -207,19 +207,25 @@ impl AggregateFunction Ok(()) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + if self.get_flag(place) == 0 { + // initial the state to remove the dirty stats + self.init_state(place); + } + if NULLABLE_RESULT { let flag = reader[reader.len() - 1]; - self.nested - .deserialize(place, &mut &reader[..reader.len() - 1])?; - self.set_flag(place, flag); + if flag == 1 { + self.set_flag(place, flag); + self.nested.merge(place, &mut &reader[..reader.len() - 1])?; + } } else { - self.nested.deserialize(place, reader)?; + self.nested.merge(place, reader)?; } Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { if self.get_flag(place) == 0 { // initial the state to remove the dirty stats self.init_state(place); @@ -227,7 +233,7 @@ impl AggregateFunction if self.get_flag(rhs) == 1 { self.set_flag(place, 1); - self.nested.merge(place, rhs)?; + self.nested.merge_states(place, rhs)?; } Ok(()) } diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index 442fc85cc356..80d74c1a7ee1 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -173,16 +173,16 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { } #[inline] - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - let flag = reader[reader.len() - 1]; - self.inner - .deserialize(place, &mut &reader[..reader.len() - 1])?; - self.set_flag(place, flag); + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + let flag = self.get_flag(place) > 0 || reader[reader.len() - 1] > 0; + + self.inner.merge(place, &mut &reader[..reader.len() - 1])?; + self.set_flag(place, flag as u8); Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - self.inner.merge(place, rhs)?; + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + self.inner.merge_states(place, rhs)?; let flag = self.get_flag(place) > 0 || self.get_flag(rhs) > 0; self.set_flag(place, u8::from(flag)); Ok(()) diff --git a/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs b/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs index 94a53dd09c28..94bba8b7c9d6 100644 --- a/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs +++ b/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs @@ -126,17 +126,18 @@ where for<'a> T::ScalarRef<'a>: Hash serialize_into_buf(writer, &state.hll) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::>>(); - state.hll = deserialize_from_slice(reader)?; + let hll: HyperLogLog> = deserialize_from_slice(reader)?; + state.hll.union(&hll); + Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::>>(); - let rhs = rhs.get::>>(); - state.hll.union(&rhs.hll); - + let other = rhs.get::>>(); + state.hll.union(&other.hll); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs index 7c2909c6e8ad..bcb72619ce15 100644 --- a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs +++ b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs @@ -49,7 +49,9 @@ use crate::with_simple_no_number_mapped_type; // State for arg_min(arg, val) and arg_max(arg, val) // A: ValueType for arg. // V: ValueType for val. -pub trait AggregateArgMinMaxState: Send + Sync + 'static { +pub trait AggregateArgMinMaxState: + Serialize + DeserializeOwned + Send + Sync + 'static +{ fn new() -> Self; fn add(&mut self, value: V::ScalarRef<'_>, data: Scalar); fn add_batch( @@ -60,8 +62,6 @@ pub trait AggregateArgMinMaxState: Send + Sync + 'st ) -> Result<()>; fn merge(&mut self, rhs: &Self) -> Result<()>; - fn serialize(&self, writer: &mut Vec) -> Result<()>; - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()>; fn merge_result(&mut self, column: &mut ColumnBuilder) -> Result<()>; } @@ -180,15 +180,6 @@ where Ok(()) } - fn serialize(&self, writer: &mut Vec) -> Result<()> { - serialize_into_buf(writer, self) - } - - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> { - *self = deserialize_from_slice(reader)?; - Ok(()) - } - fn merge_result(&mut self, builder: &mut ColumnBuilder) -> Result<()> { if self.value.is_some() { if let Some(inner) = A::try_downcast_builder(builder) { @@ -291,18 +282,20 @@ where fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { let state = place.get::(); - state.serialize(writer) + serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - state.deserialize(reader) + let rhs: State = deserialize_from_slice(reader)?; + + state.merge(&rhs) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs) + let other = rhs.get::(); + state.merge(other) } fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_array_agg.rs b/src/query/functions/src/aggregates/aggregate_array_agg.rs index 43930e3d2cf6..30c743e73d58 100644 --- a/src/query/functions/src/aggregates/aggregate_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_array_agg.rs @@ -124,15 +124,6 @@ where builder.push(array_value); Ok(()) } - - fn serialize(&self, writer: &mut Vec) -> Result<()> { - serialize_into_buf(writer, self) - } - - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> { - self.values = deserialize_from_slice(reader)?; - Ok(()) - } } #[derive(Serialize, Deserialize, Debug)] @@ -234,15 +225,6 @@ where builder.push(array_value); Ok(()) } - - fn serialize(&self, writer: &mut Vec) -> Result<()> { - serialize_into_buf(writer, self) - } - - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> { - self.values = deserialize_from_slice(reader)?; - Ok(()) - } } #[derive(Clone)] @@ -356,18 +338,20 @@ where fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { let state = place.get::(); - state.serialize(writer) + serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - state.deserialize(reader) + let rhs: State = deserialize_from_slice(reader)?; + + state.merge(&rhs) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs) + let other = rhs.get::(); + state.merge(other) } fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_array_moving.rs b/src/query/functions/src/aggregates/aggregate_array_moving.rs index 3a58e9640939..5242afbaa4ca 100644 --- a/src/query/functions/src/aggregates/aggregate_array_moving.rs +++ b/src/query/functions/src/aggregates/aggregate_array_moving.rs @@ -45,6 +45,7 @@ use common_io::prelude::*; use ethnum::i256; use num_traits::AsPrimitive; use serde::de::DeserializeOwned; +use serde::Deserialize; use serde::Serialize; use super::aggregate_function::AggregateFunction; @@ -56,9 +57,10 @@ use crate::aggregates::assert_unary_arguments; use crate::aggregates::assert_variadic_params; use crate::BUILTIN_FUNCTIONS; -#[derive(Default, Debug)] -pub struct NumberArrayMovingSumState { +#[derive(Default, Debug, Deserialize, Serialize)] +pub struct NumberArrayMovingSumState { values: Vec, + #[serde(skip)] _t: PhantomData, } @@ -137,7 +139,7 @@ where } #[inline(always)] - fn merge(&mut self, other: &mut Self) -> Result<()> { + fn merge(&mut self, other: &Self) -> Result<()> { self.values.extend_from_slice(&other.values); Ok(()) } @@ -200,8 +202,8 @@ where } } -#[derive(Default)] -pub struct DecimalArrayMovingSumState { +#[derive(Default, Deserialize, Serialize)] +pub struct DecimalArrayMovingSumState { pub values: Vec, } @@ -313,7 +315,7 @@ where T: Decimal } #[inline(always)] - fn merge(&mut self, other: &mut Self) -> Result<()> { + fn merge(&mut self, other: &Self) -> Result<()> { self.values.extend_from_slice(&other.values); Ok(()) } @@ -454,18 +456,20 @@ where State: SumState fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { let state = place.get::(); - state.serialize(writer) + serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - state.deserialize(reader) + let rhs: State = deserialize_from_slice(reader)?; + + state.merge(&rhs) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs) + let other = rhs.get::(); + state.merge(other) } fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { @@ -646,18 +650,20 @@ where State: SumState fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { let state = place.get::(); - state.serialize(writer) + serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - state.deserialize(reader) + let rhs: State = deserialize_from_slice(reader)?; + + state.merge(&rhs) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs) + let other = rhs.get::(); + state.merge(other) } fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_avg.rs b/src/query/functions/src/aggregates/aggregate_avg.rs index a4aa86c0fc0c..51656c0052b5 100644 --- a/src/query/functions/src/aggregates/aggregate_avg.rs +++ b/src/query/functions/src/aggregates/aggregate_avg.rs @@ -42,7 +42,7 @@ use crate::aggregates::AggregateFunction; use crate::aggregates::AggregateFunctionRef; #[derive(Serialize, Deserialize)] -struct AvgState { +struct AvgState { pub value: T, pub count: u64, } @@ -117,21 +117,22 @@ where T: SumState fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { let state = place.get::>(); - writer.write_scalar(&state.count)?; - state.value.serialize(writer) + serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::>(); - state.count = reader.read_scalar()?; - state.value.deserialize(reader) + let rhs: AvgState = deserialize_from_slice(reader)?; + + state.count += rhs.count; + state.value.merge(&rhs.value) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::>(); let rhs = rhs.get::>(); state.count += rhs.count; - state.value.merge(&mut rhs.value) + state.value.merge(&rhs.value) } #[allow(unused_mut)] diff --git a/src/query/functions/src/aggregates/aggregate_bitmap.rs b/src/query/functions/src/aggregates/aggregate_bitmap.rs index dd0286d80bd3..efc289ee44d2 100644 --- a/src/query/functions/src/aggregates/aggregate_bitmap.rs +++ b/src/query/functions/src/aggregates/aggregate_bitmap.rs @@ -14,6 +14,7 @@ use std::alloc::Layout; use std::fmt; +use std::io::BufRead; use std::marker::PhantomData; use std::ops::BitAndAssign; use std::ops::BitOrAssign; @@ -296,22 +297,24 @@ where Ok(()) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); + let flag = reader[0]; - state.rb = if flag == 1 { - Some(RoaringTreemap::deserialize_from(&reader[1..])?) - } else { - None - }; + reader.consume(1); + if flag == 1 { + let rb = RoaringTreemap::deserialize_from(reader)?; + state.add::(rb); + } Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - let rhs = rhs.get::(); - if let Some(rb) = &rhs.rb { - state.add::(rb.clone()); + let other = rhs.get::(); + + if let Some(rb) = other.rb.take() { + state.add::(rb); } Ok(()) } @@ -482,12 +485,12 @@ where self.inner.serialize(place, writer) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - self.inner.deserialize(place, reader) + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + self.inner.merge(place, reader) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - self.inner.merge(place, rhs) + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + self.inner.merge_states(place, rhs) } fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs index a8cba0a414c7..1e392d00a737 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs @@ -96,15 +96,17 @@ where State: DistinctStateFunc state.serialize(writer) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - state.deserialize(reader) + let rhs = State::deserialize(reader)?; + + state.merge(&rhs) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - let rhs = rhs.get::(); - state.merge(rhs) + let other = rhs.get::(); + state.merge(other) } #[allow(unused_mut)] diff --git a/src/query/functions/src/aggregates/aggregate_combinator_if.rs b/src/query/functions/src/aggregates/aggregate_combinator_if.rs index f02f8fc5936d..fe82c14dad37 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_if.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_if.rs @@ -153,12 +153,12 @@ impl AggregateFunction for AggregateIfCombinator { self.nested.serialize(place, writer) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - self.nested.deserialize(place, reader) + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + self.nested.merge(place, reader) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - self.nested.merge(place, rhs) + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + self.nested.merge_states(place, rhs) } fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs index 368264e5cfb5..59ada5c9a016 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs @@ -111,12 +111,12 @@ impl AggregateFunction for AggregateStateCombinator { self.nested.serialize(place, writer) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - self.nested.deserialize(place, reader) + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + self.nested.merge(place, reader) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - self.nested.merge(place, rhs) + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + self.nested.merge_states(place, rhs) } fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_count.rs b/src/query/functions/src/aggregates/aggregate_count.rs index bc15ad143cab..e4478bce05c0 100644 --- a/src/query/functions/src/aggregates/aggregate_count.rs +++ b/src/query/functions/src/aggregates/aggregate_count.rs @@ -32,7 +32,7 @@ use super::aggregate_function_factory::AggregateFunctionDescription; use super::StateAddr; use crate::aggregates::aggregator_common::assert_variadic_arguments; -pub struct AggregateCountState { +struct AggregateCountState { count: u64, } @@ -152,17 +152,17 @@ impl AggregateFunction for AggregateCountFunction { serialize_into_buf(writer, &state.count) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - state.count = deserialize_from_slice(reader)?; + let other: u64 = deserialize_from_slice(reader)?; + state.count += other; Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - let rhs = rhs.get::(); - state.count += rhs.count; - + let other = rhs.get::(); + state.count += other.count; Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_covariance.rs b/src/query/functions/src/aggregates/aggregate_covariance.rs index f28ae31a511f..4b13e100e04d 100644 --- a/src/query/functions/src/aggregates/aggregate_covariance.rs +++ b/src/query/functions/src/aggregates/aggregate_covariance.rs @@ -230,17 +230,17 @@ where serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - *state = deserialize_from_slice(reader)?; - + let rhs: AggregateCovarianceState = deserialize_from_slice(reader)?; + state.merge(&rhs); Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - let rhs = rhs.get::(); - state.merge(rhs); + let other = rhs.get::(); + state.merge(other); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_distinct_state.rs b/src/query/functions/src/aggregates/aggregate_distinct_state.rs index ce6e4253526e..f38c201951b8 100644 --- a/src/query/functions/src/aggregates/aggregate_distinct_state.rs +++ b/src/query/functions/src/aggregates/aggregate_distinct_state.rs @@ -15,6 +15,7 @@ use std::collections::hash_map::RandomState; use std::collections::HashSet; use std::hash::Hasher; +use std::io::BufRead; use std::marker::Send; use std::marker::Sync; use std::sync::Arc; @@ -44,10 +45,10 @@ use serde::Serialize; use siphasher::sip128::Hasher128; use siphasher::sip128::SipHasher24; -pub trait DistinctStateFunc: Send + Sync { +pub trait DistinctStateFunc: Sized + Send + Sync { fn new() -> Self; fn serialize(&self, writer: &mut Vec) -> Result<()>; - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()>; + fn deserialize(reader: &mut &[u8]) -> Result; fn is_empty(&self) -> bool; fn len(&self) -> usize; fn add(&mut self, columns: &[Column], row: usize) -> Result<()>; @@ -85,9 +86,9 @@ impl DistinctStateFunc for AggregateDistinctState { serialize_into_buf(writer, &self.set) } - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> { - self.set = deserialize_from_slice(reader)?; - Ok(()) + fn deserialize(reader: &mut &[u8]) -> Result { + let set = deserialize_from_slice(reader)?; + Ok(Self { set }) } fn is_empty(&self) -> bool { @@ -167,15 +168,16 @@ impl DistinctStateFunc for AggregateDistinctStringState { Ok(()) } - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> { + fn deserialize(reader: &mut &[u8]) -> Result { let size = reader.read_uvarint()?; - self.set = ShortStringHashSet::<[u8]>::with_capacity(size as usize, Arc::new(Bump::new())); + let mut set = + ShortStringHashSet::<[u8]>::with_capacity(size as usize, Arc::new(Bump::new())); for _ in 0..size { let s = reader.read_uvarint()? as usize; - let _ = self.set.set_insert(&reader[..s]); - *reader = &reader[s..]; + let _ = set.set_insert(&reader[..s]); + reader.consume(s); } - Ok(()) + Ok(Self { set }) } fn is_empty(&self) -> bool { @@ -252,14 +254,14 @@ where T: Number + Serialize + DeserializeOwned + HashtableKeyable Ok(()) } - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> { + fn deserialize(reader: &mut &[u8]) -> Result { let size = reader.read_uvarint()?; - self.set = CommonHashSet::with_capacity(size as usize); + let mut set = CommonHashSet::with_capacity(size as usize); for _ in 0..size { let t: T = deserialize_from_slice(reader)?; - let _ = self.set.set_insert(t).is_ok(); + let _ = set.set_insert(t).is_ok(); } - Ok(()) + Ok(Self { set }) } fn is_empty(&self) -> bool { @@ -333,14 +335,14 @@ impl DistinctStateFunc for AggregateUniqStringState { Ok(()) } - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> { + fn deserialize(reader: &mut &[u8]) -> Result { let size = reader.read_uvarint()?; - self.set = StackHashSet::with_capacity(size as usize); + let mut set = StackHashSet::with_capacity(size as usize); for _ in 0..size { let e = deserialize_from_slice(reader)?; - let _ = self.set.set_insert(e).is_ok(); + let _ = set.set_insert(e).is_ok(); } - Ok(()) + Ok(Self { set }) } fn is_empty(&self) -> bool { diff --git a/src/query/functions/src/aggregates/aggregate_function.rs b/src/query/functions/src/aggregates/aggregate_function.rs index 97a9de5d0bcc..5c8370552dd6 100644 --- a/src/query/functions/src/aggregates/aggregate_function.rs +++ b/src/query/functions/src/aggregates/aggregate_function.rs @@ -71,9 +71,9 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { // serialize the state into binary array fn serialize(&self, _place: StateAddr, _writer: &mut Vec) -> Result<()>; - fn deserialize(&self, _place: StateAddr, _reader: &mut &[u8]) -> Result<()>; + fn merge(&self, _place: StateAddr, _reader: &mut &[u8]) -> Result<()>; - fn merge(&self, _place: StateAddr, _rhs: StateAddr) -> Result<()>; + fn merge_states(&self, _place: StateAddr, _rhs: StateAddr) -> Result<()>; fn batch_merge_result(&self, places: &[StateAddr], builder: &mut ColumnBuilder) -> Result<()> { for place in places { diff --git a/src/query/functions/src/aggregates/aggregate_kurtosis.rs b/src/query/functions/src/aggregates/aggregate_kurtosis.rs index 2c1b1337686c..6cdab1bfa5f0 100644 --- a/src/query/functions/src/aggregates/aggregate_kurtosis.rs +++ b/src/query/functions/src/aggregates/aggregate_kurtosis.rs @@ -1,17 +1,17 @@ -// Copyright 2021 Datafuse Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::alloc::Layout; use std::fmt::Display; use std::fmt::Formatter; @@ -198,17 +198,17 @@ where T: Number + AsPrimitive serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - *state = deserialize_from_slice(reader)?; - + let rhs: KurtosisState = deserialize_from_slice(reader)?; + state.merge(&rhs); Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs); + let other = rhs.get::(); + state.merge(other); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_min_max_any.rs b/src/query/functions/src/aggregates/aggregate_min_max_any.rs index 4092c0940ef1..0169329b4c15 100644 --- a/src/query/functions/src/aggregates/aggregate_min_max_any.rs +++ b/src/query/functions/src/aggregates/aggregate_min_max_any.rs @@ -27,6 +27,8 @@ use common_expression::with_number_mapped_type; use common_expression::Column; use common_expression::ColumnBuilder; use common_expression::Scalar; +use common_io::prelude::deserialize_from_slice; +use common_io::prelude::serialize_into_buf; use ethnum::i256; use super::aggregate_function_factory::AggregateFunctionDescription; @@ -117,18 +119,21 @@ where fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { let state = place.get::(); - state.serialize(writer) + + serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - state.deserialize(reader) + let rhs: State = deserialize_from_slice(reader)?; + + state.merge(&rhs) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs) + let other = rhs.get::(); + state.merge(other) } fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_null_result.rs b/src/query/functions/src/aggregates/aggregate_null_result.rs index afdd510c3c5b..01e4dcbf71cf 100644 --- a/src/query/functions/src/aggregates/aggregate_null_result.rs +++ b/src/query/functions/src/aggregates/aggregate_null_result.rs @@ -81,11 +81,11 @@ impl AggregateFunction for AggregateNullResultFunction { Ok(()) } - fn deserialize(&self, _place: StateAddr, _reader: &mut &[u8]) -> Result<()> { + fn merge(&self, _place: StateAddr, _reader: &mut &[u8]) -> Result<()> { Ok(()) } - fn merge(&self, _place: StateAddr, _rhs: StateAddr) -> Result<()> { + fn merge_states(&self, _place: StateAddr, _rhs: StateAddr) -> Result<()> { Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_quantile_cont.rs b/src/query/functions/src/aggregates/aggregate_quantile_cont.rs index 1eabd6998db1..e307f63f9585 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_cont.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_cont.rs @@ -213,17 +213,16 @@ where T: Number + AsPrimitive serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - *state = deserialize_from_slice(reader)?; - - Ok(()) + let rhs: QuantileContState = deserialize_from_slice(reader)?; + state.merge(&rhs) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs) + let other = rhs.get::(); + state.merge(other) } fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_quantile_disc.rs b/src/query/functions/src/aggregates/aggregate_quantile_disc.rs index e14bb9203471..99b66305c76f 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_disc.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_disc.rs @@ -46,14 +46,14 @@ use crate::aggregates::StateAddr; use crate::with_simple_no_number_mapped_type; use crate::BUILTIN_FUNCTIONS; -pub trait QuantileStateFunc: Send + Sync + 'static { +pub trait QuantileStateFunc: + Serialize + DeserializeOwned + Send + Sync + 'static +{ fn new() -> Self; fn add(&mut self, other: T::ScalarRef<'_>); fn add_batch(&mut self, column: &T::Column, validity: Option<&Bitmap>) -> Result<()>; fn merge(&mut self, rhs: &Self) -> Result<()>; fn merge_result(&mut self, builder: &mut ColumnBuilder, levels: Vec) -> Result<()>; - fn serialize(&self, writer: &mut Vec) -> Result<()>; - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()>; } #[derive(Serialize, Deserialize)] struct QuantileState @@ -148,13 +148,6 @@ where } Ok(()) } - fn serialize(&self, writer: &mut Vec) -> Result<()> { - serialize_into_buf(writer, self) - } - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> { - self.value = deserialize_from_slice(reader)?; - Ok(()) - } } #[derive(Clone)] pub struct AggregateQuantileDiscFunction { @@ -229,17 +222,21 @@ where } fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { let state = place.get::(); - state.serialize(writer) + serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - state.deserialize(reader) + let rhs: State = deserialize_from_slice(reader)?; + state.merge(&rhs) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs) + let other = rhs.get::(); + state.merge(other) } + fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); state.merge_result(builder, self.levels.clone()) diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs index 2a9493e84a5e..767962ad5b97 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs @@ -359,17 +359,19 @@ where T: Number + AsPrimitive let state = place.get::(); serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - *state = deserialize_from_slice(reader)?; - Ok(()) + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + let state = place.get::(); + let mut rhs: QuantileTDigestState = deserialize_from_slice(reader)?; + state.merge(&mut rhs) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs) + let other = rhs.get::(); + state.merge(other) } + fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); state.merge_result(builder, self.levels.clone()) diff --git a/src/query/functions/src/aggregates/aggregate_retention.rs b/src/query/functions/src/aggregates/aggregate_retention.rs index 4ea9fb645a60..bb7a1ea55f58 100644 --- a/src/query/functions/src/aggregates/aggregate_retention.rs +++ b/src/query/functions/src/aggregates/aggregate_retention.rs @@ -27,6 +27,8 @@ use common_expression::Column; use common_expression::ColumnBuilder; use common_expression::Scalar; use common_io::prelude::*; +use serde::Deserialize; +use serde::Serialize; use super::aggregate_function::AggregateFunction; use super::aggregate_function::AggregateFunctionRef; @@ -34,6 +36,7 @@ use super::aggregate_function_factory::AggregateFunctionDescription; use super::StateAddr; use crate::aggregates::aggregator_common::assert_variadic_arguments; +#[derive(Serialize, Deserialize)] struct AggregateRetentionState { pub events: u32, } @@ -47,15 +50,6 @@ impl AggregateRetentionState { fn merge(&mut self, other: &Self) { self.events |= other.events; } - - fn serialize(&self, writer: &mut Vec) -> Result<()> { - serialize_into_buf(writer, &self.events) - } - - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> { - self.events = deserialize_from_slice(reader)?; - Ok(()) - } } #[derive(Clone)] @@ -144,18 +138,20 @@ impl AggregateFunction for AggregateRetentionFunction { fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { let state = place.get::(); - state.serialize(writer) + serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - state.deserialize(reader) + let rhs: AggregateRetentionState = deserialize_from_slice(reader)?; + state.merge(&rhs); + Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs); + let other = rhs.get::(); + state.merge(other); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_scalar_state.rs b/src/query/functions/src/aggregates/aggregate_scalar_state.rs index 0ee4039da433..0f9f7c8b68d6 100644 --- a/src/query/functions/src/aggregates/aggregate_scalar_state.rs +++ b/src/query/functions/src/aggregates/aggregate_scalar_state.rs @@ -20,8 +20,6 @@ use common_exception::Result; use common_expression::types::DataType; use common_expression::types::ValueType; use common_expression::ColumnBuilder; -use common_io::prelude::deserialize_from_slice; -use common_io::prelude::serialize_into_buf; use serde::de::DeserializeOwned; use serde::Deserialize; use serde::Serialize; @@ -111,14 +109,14 @@ impl ChangeIf for CmpAny { } } -pub trait ScalarStateFunc: Send + Sync + 'static { +pub trait ScalarStateFunc: + Serialize + DeserializeOwned + Send + Sync + 'static +{ fn new() -> Self; fn add(&mut self, other: Option>); fn add_batch(&mut self, column: &T::Column, validity: Option<&Bitmap>) -> Result<()>; fn merge(&mut self, rhs: &Self) -> Result<()>; fn merge_result(&mut self, builder: &mut ColumnBuilder) -> Result<()>; - fn serialize(&self, writer: &mut Vec) -> Result<()>; - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()>; } #[derive(Serialize, Deserialize)] @@ -237,15 +235,6 @@ where } Ok(()) } - - fn serialize(&self, writer: &mut Vec) -> Result<()> { - serialize_into_buf(writer, self) - } - - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> { - self.value = deserialize_from_slice(reader)?; - Ok(()) - } } pub fn need_manual_drop_state(data_type: &DataType) -> bool { diff --git a/src/query/functions/src/aggregates/aggregate_skewness.rs b/src/query/functions/src/aggregates/aggregate_skewness.rs index 0f3cb0969992..36984528804f 100644 --- a/src/query/functions/src/aggregates/aggregate_skewness.rs +++ b/src/query/functions/src/aggregates/aggregate_skewness.rs @@ -1,17 +1,17 @@ -// Copyright 2021 Datafuse Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::alloc::Layout; use std::fmt::Display; use std::fmt::Formatter; @@ -197,17 +197,17 @@ where T: Number + AsPrimitive serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - *state = deserialize_from_slice(reader)?; - + let rhs: SkewnessState = deserialize_from_slice(reader)?; + state.merge(&rhs); Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs); + let other = rhs.get::(); + state.merge(other); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_stddev.rs b/src/query/functions/src/aggregates/aggregate_stddev.rs index 9a3330be09c7..b34d57aec2be 100644 --- a/src/query/functions/src/aggregates/aggregate_stddev.rs +++ b/src/query/functions/src/aggregates/aggregate_stddev.rs @@ -173,17 +173,17 @@ where T: Number + AsPrimitive serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - *state = deserialize_from_slice(reader)?; - + let rhs: AggregateStddevState = deserialize_from_slice(reader)?; + state.merge(&rhs); Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - let rhs = rhs.get::(); - state.merge(rhs); + let other = rhs.get::(); + state.merge(other); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_string_agg.rs b/src/query/functions/src/aggregates/aggregate_string_agg.rs index 42c094b4f251..b8a06c519ca2 100644 --- a/src/query/functions/src/aggregates/aggregate_string_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_string_agg.rs @@ -126,16 +126,17 @@ impl AggregateFunction for AggregateStringAggFunction { Ok(()) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - state.values = deserialize_from_slice(reader)?; + let rhs: StringAggState = deserialize_from_slice(reader)?; + state.values.extend_from_slice(rhs.values.as_slice()); Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.values.extend_from_slice(rhs.values.as_slice()); + let other = rhs.get::(); + state.values.extend_from_slice(other.values.as_slice()); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_sum.rs b/src/query/functions/src/aggregates/aggregate_sum.rs index 9cc668b6672b..1bf34d6a4557 100644 --- a/src/query/functions/src/aggregates/aggregate_sum.rs +++ b/src/query/functions/src/aggregates/aggregate_sum.rs @@ -39,6 +39,7 @@ use common_io::prelude::*; use ethnum::i256; use num_traits::AsPrimitive; use serde::de::DeserializeOwned; +use serde::Deserialize; use serde::Serialize; use super::aggregate_function::AggregateFunction; @@ -47,8 +48,8 @@ use super::aggregate_function_factory::AggregateFunctionDescription; use super::StateAddr; use crate::aggregates::aggregator_common::assert_unary_arguments; -pub trait SumState: Send + Sync + Default + 'static { - fn merge(&mut self, other: &mut Self) -> Result<()>; +pub trait SumState: Serialize + DeserializeOwned + Send + Sync + Default + 'static { + fn merge(&mut self, other: &Self) -> Result<()>; fn serialize(&self, writer: &mut Vec) -> Result<()>; fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()>; fn accumulate(&mut self, column: &Column, validity: Option<&Bitmap>) -> Result<()>; @@ -71,9 +72,10 @@ pub trait SumState: Send + Sync + Default + 'static { ) -> Result<()>; } -#[derive(Default)] -pub struct NumberSumState { +#[derive(Default, Deserialize, Serialize)] +pub struct NumberSumState { pub value: TSum, + #[serde(skip)] _t: PhantomData, } @@ -114,7 +116,7 @@ where } #[inline(always)] - fn merge(&mut self, other: &mut Self) -> Result<()> { + fn merge(&mut self, other: &Self) -> Result<()> { self.value += other.value; Ok(()) } @@ -144,8 +146,8 @@ where } } -#[derive(Default)] -pub struct DecimalSumState { +#[derive(Default, Deserialize, Serialize)] +pub struct DecimalSumState { pub value: T, } @@ -227,7 +229,7 @@ where T: Decimal } #[inline(always)] - fn merge(&mut self, other: &mut Self) -> Result<()> { + fn merge(&mut self, other: &Self) -> Result<()> { self.add(other.value) } @@ -324,18 +326,19 @@ where State: SumState fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { let state = place.get::(); - state.serialize(writer) + serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); - state.deserialize(reader) + let rhs: State = deserialize_from_slice(reader)?; + state.merge(&rhs) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::(); - state.merge(rhs) + let other = rhs.get::(); + state.merge(other) } #[allow(unused_mut)] diff --git a/src/query/functions/src/aggregates/aggregate_window_funnel.rs b/src/query/functions/src/aggregates/aggregate_window_funnel.rs index b710255df07e..d709f0fdc664 100644 --- a/src/query/functions/src/aggregates/aggregate_window_funnel.rs +++ b/src/query/functions/src/aggregates/aggregate_window_funnel.rs @@ -147,16 +147,6 @@ where T: Ord self.events_list.sort_by(cmp); } } - - fn serialize(&self, writer: &mut Vec) -> Result<()> { - serialize_into_buf(writer, self) - } - - fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> { - *self = deserialize_from_slice(reader)?; - - Ok(()) - } } #[derive(Clone)] @@ -287,19 +277,20 @@ where fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { let state = place.get::>(); - AggregateWindowFunnelState::::serialize(state, writer) + serialize_into_buf(writer, state) } - fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { let state = place.get::>(); - state.deserialize(reader) + let mut rhs: AggregateWindowFunnelState = deserialize_from_slice(reader)?; + state.merge(&mut rhs); + Ok(()) } - fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let rhs = rhs.get::>(); + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { let state = place.get::>(); - - state.merge(rhs); + let other = rhs.get::>(); + state.merge(other); Ok(()) } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_cell.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_cell.rs index 0245729e701c..a9b37089805f 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_cell.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_cell.rs @@ -32,7 +32,6 @@ pub struct HashTableCell { pub hashtable: T::HashTable, pub arena: Area, pub arena_holders: Vec, - pub temp_values: Vec< as HashtableLike>::Value>, pub _dropper: Option>>, } @@ -44,10 +43,6 @@ impl Drop for HashTableCell fn drop(&mut self) { if let Some(dropper) = self._dropper.take() { dropper.destroy(&mut self.hashtable); - - for value in &self.temp_values { - dropper.destroy_value(value) - } } } } @@ -60,7 +55,6 @@ impl HashTableCell { HashTableCell:: { hashtable: inner, arena_holders: vec![], - temp_values: vec![], _dropper: Some(_dropper), arena: Area::create(), } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs index 57f3e4309d0d..e391f730b202 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs @@ -21,7 +21,6 @@ use common_expression::BlockMetaInfoPtr; use common_expression::Column; use common_expression::DataBlock; -use crate::pipelines::processors::transforms::group_by::ArenaHolder; use crate::pipelines::processors::transforms::group_by::HashMethodBounds; use crate::pipelines::processors::transforms::group_by::PartitionedHashMethod; use crate::pipelines::processors::transforms::HashTableCell; @@ -29,7 +28,6 @@ use crate::pipelines::processors::transforms::HashTableCell; pub struct HashTablePayload { pub bucket: isize, pub cell: HashTableCell, - pub arena_holder: ArenaHolder, } pub struct SerializedPayload { @@ -66,7 +64,6 @@ impl AggregateMeta::HashTable(HashTablePayload { cell, bucket, - arena_holder: ArenaHolder::create(None), })) } @@ -83,7 +80,6 @@ impl AggregateMeta::Spilling(HashTablePayload { cell, bucket: 0, - arena_holder: ArenaHolder::create(None), })) } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs index d738ed0905e9..e252df4b4e7b 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs @@ -71,8 +71,6 @@ where Method: HashMethodBounds let hashtable = self.method.create_hash_table::(arena)?; let _dropper = AggregateHashTableDropper::create(self.params.clone()); let mut hash_cell = HashTableCell::::create(hashtable, _dropper); - let temp_place = self.params.alloc_layout(&mut hash_cell.arena); - hash_cell.temp_values.push(temp_place.addr()); for bucket_data in data { match bucket_data { @@ -155,20 +153,10 @@ where Method: HashMethodBounds for (idx, aggregate_function) in aggregate_functions.iter().enumerate() { let final_place = place.next(offsets_aggregate_states[idx]); - let state_place = temp_place.next(offsets_aggregate_states[idx]); let mut data = unsafe { states_binary_columns[idx].index_unchecked(*row) }; - aggregate_function.deserialize(state_place, &mut data)?; - aggregate_function.merge(final_place, state_place)?; - if aggregate_function.need_manual_drop_state() { - unsafe { - // State may allocate memory out of the arena, - // drop state to avoid memory leak. - aggregate_function.drop_state(state_place); - } - aggregate_function.init_state(state_place); - } + aggregate_function.merge(final_place, &mut data)?; } } } @@ -193,7 +181,7 @@ where Method: HashMethodBounds { let final_place = place.next(offsets_aggregate_states[idx]); let state_place = old_place.next(offsets_aggregate_states[idx]); - aggregate_function.merge(final_place, state_place)?; + aggregate_function.merge_states(final_place, state_place)?; } } }, diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs index 4cf479fb5ace..1bca822ea3eb 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs @@ -110,9 +110,6 @@ pub struct TransformPartialAggregate { hash_table: HashTable, params: Arc, - - /// A temporary place to hold aggregating state from index data. - temp_place: StateAddr, } impl TransformPartialAggregate { @@ -143,7 +140,6 @@ impl TransformPartialAggregate { params, hash_table, settings: AggregateSettings::try_from(ctx)?, - temp_place: StateAddr::new(0), }, )) } @@ -220,19 +216,9 @@ impl TransformPartialAggregate { .unwrap() .as_string() .unwrap(); - let state_place = self.temp_place.next(offset); for (row, mut raw_state) in agg_state.iter().enumerate() { let place = &places[row]; - function.deserialize(state_place, &mut raw_state)?; - function.merge(place.next(offset), state_place)?; - if function.need_manual_drop_state() { - unsafe { - // State may allocate memory out of the arena, - // drop state to avoid memory leak. - function.drop_state(state_place); - } - function.init_state(state_place); - } + function.merge(place.next(offset), &mut raw_state)?; } } @@ -277,9 +263,6 @@ impl TransformPartialAggregate { } if is_agg_index_block { - if self.temp_place.addr() == 0 { - self.temp_place = self.params.alloc_layout(&mut hashtable.arena); - } self.execute_agg_index_block(&block, &places) } else { Self::execute(&self.params, &block, &places) @@ -300,9 +283,6 @@ impl TransformPartialAggregate { } if is_agg_index_block { - if self.temp_place.addr() == 0 { - self.temp_place = self.params.alloc_layout(&mut hashtable.arena); - } self.execute_agg_index_block(&block, &places) } else { Self::execute(&self.params, &block, &places) diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs index 73499a4cdf82..c71d7f87551e 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs @@ -21,6 +21,7 @@ use bumpalo::Bump; use common_catalog::plan::AggIndexMeta; use common_exception::ErrorCode; use common_exception::Result; +use common_expression::types::string::StringColumn; use common_expression::types::DataType; use common_expression::BlockEntry; use common_expression::BlockMetaInfoDowncast; @@ -45,9 +46,6 @@ pub struct PartialSingleStateAggregator { places: Vec, arg_indices: Vec>, funcs: Vec, - - /// A temporary place to hold aggregating state from index data. - temp_places: Vec, } impl PartialSingleStateAggregator { @@ -66,7 +64,6 @@ impl PartialSingleStateAggregator { let place: StateAddr = arena.alloc_layout(layout).into(); let temp_place: StateAddr = arena.alloc_layout(layout).into(); let mut places = Vec::with_capacity(params.offsets_aggregate_states.len()); - let mut temp_places = Vec::with_capacity(params.offsets_aggregate_states.len()); for (idx, func) in params.aggregate_functions.iter().enumerate() { let arg_place = place.next(params.offsets_aggregate_states[idx]); @@ -75,7 +72,6 @@ impl PartialSingleStateAggregator { let state_place = temp_place.next(params.offsets_aggregate_states[idx]); func.init_state(state_place); - temp_places.push(state_place); } Ok(AccumulatingTransformer::create( @@ -86,7 +82,6 @@ impl PartialSingleStateAggregator { places, funcs: params.aggregate_functions.clone(), arg_indices: params.aggregate_functions_arguments.clone(), - temp_places, }, )) } @@ -127,18 +122,8 @@ impl AccumulatingTransform for PartialSingleStateAggregator { .unwrap() .as_string() .unwrap(); - let state_place = self.temp_places[idx]; for (_, mut raw_state) in agg_state.iter().enumerate() { - func.deserialize(state_place, &mut raw_state)?; - func.merge(place, state_place)?; - if func.need_manual_drop_state() { - unsafe { - // State may allocate memory out of the arena, - // drop state to avoid memory leak. - func.drop_state(state_place); - } - func.init_state(state_place); - } + func.merge(place, &mut raw_state)?; } } else { func.accumulate(place, &arg_columns, None, block.num_rows())?; @@ -183,7 +168,7 @@ impl AccumulatingTransform for PartialSingleStateAggregator { pub struct FinalSingleStateAggregator { arena: Bump, layout: Layout, - to_merge_places: Vec>, + to_merge_data: Vec>, funcs: Vec, offsets_aggregate_states: Vec, } @@ -208,7 +193,7 @@ impl FinalSingleStateAggregator { arena, layout, funcs: params.aggregate_functions.clone(), - to_merge_places: vec![vec![]; params.aggregate_functions.len()], + to_merge_data: vec![vec![]; params.aggregate_functions.len()], offsets_aggregate_states: params.offsets_aggregate_states.clone(), }, )) @@ -234,17 +219,14 @@ impl AccumulatingTransform for FinalSingleStateAggregator { fn transform(&mut self, block: DataBlock) -> Result> { if !block.is_empty() { let block = block.convert_to_full(); - let places = self.new_places(); - for (index, func) in self.funcs.iter().enumerate() { + for (index, _) in self.funcs.iter().enumerate() { let binary_array = block.get_by_offset(index).value.as_column().unwrap(); let binary_array = binary_array.as_string().ok_or_else(|| { ErrorCode::IllegalDataType("binary array should be string type") })?; - let mut data = unsafe { binary_array.index_unchecked(0) }; - func.deserialize(places[index], &mut data)?; - self.to_merge_places[index].push(places[index]); + self.to_merge_data[index].push(binary_array.clone()); } } @@ -268,8 +250,9 @@ impl AccumulatingTransform for FinalSingleStateAggregator { for (index, func) in self.funcs.iter().enumerate() { let main_place = main_places[index]; - for place in self.to_merge_places[index].iter() { - func.merge(main_place, *place)?; + for col in self.to_merge_data[index].iter() { + let mut data = unsafe { col.index_unchecked(0) }; + func.merge(main_place, &mut data)?; } let array = aggr_values[index].borrow_mut(); @@ -291,14 +274,6 @@ impl AccumulatingTransform for FinalSingleStateAggregator { generate_data_block = vec![DataBlock::new_from_columns(columns)]; } - for (places, func) in self.to_merge_places.iter().zip(self.funcs.iter()) { - if func.need_manual_drop_state() { - for place in places { - unsafe { func.drop_state(*place) } - } - } - } - Ok(generate_data_block) } } diff --git a/src/query/service/src/pipelines/processors/transforms/group_by/aggregator_polymorphic_keys.rs b/src/query/service/src/pipelines/processors/transforms/group_by/aggregator_polymorphic_keys.rs index 19f96812298c..67b26e9169d4 100644 --- a/src/query/service/src/pipelines/processors/transforms/group_by/aggregator_polymorphic_keys.rs +++ b/src/query/service/src/pipelines/processors/transforms/group_by/aggregator_polymorphic_keys.rs @@ -582,7 +582,6 @@ impl PartitionedHashMethod { let arena = std::mem::replace(&mut cell.arena, Area::create()); cell.arena_holders.push(ArenaHolder::create(Some(arena))); - let temp_values = cell.temp_values.to_vec(); let arena_holders = cell.arena_holders.to_vec(); let _old_dropper = cell._dropper.clone().unwrap(); @@ -594,7 +593,6 @@ impl PartitionedHashMethod { // create new HashTableCell before take_old_dropper - may double free memory let _old_dropper = cell._dropper.take(); let mut cell = HashTableCell::create(partitioned_hashtable, _new_dropper); - cell.temp_values = temp_values; cell.arena_holders = arena_holders; Ok(cell) }