Skip to content

Commit

Permalink
chore(query): refactor agg state merge
Browse files Browse the repository at this point in the history
  • Loading branch information
sundy-li committed Oct 9, 2023
1 parent 776db07 commit 8395f90
Show file tree
Hide file tree
Showing 32 changed files with 204 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,20 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdapto
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
if self.get_flag(place) == 0 {
// initial the state to remove the dirty stats
self.init_state(place);
}

if self.get_flag(rhs) == 1 {
self.set_flag(place, 1);
self.nested.merge_states(place, rhs)?;
}

Ok(())
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
if NULLABLE_RESULT {
if self.get_flag(place) == 1 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,19 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
if self.get_flag(place) == 0 {
// initial the state to remove the dirty stats
self.init_state(place);
}

if self.get_flag(rhs) == 1 {
self.set_flag(place, 1);
self.nested.merge_states(place, rhs)?;
}
Ok(())
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
if NULLABLE_RESULT {
if self.get_flag(place) == 1 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
self.inner.merge_states(place, rhs)?;
let flag = self.get_flag(place) > 0 || self.get_flag(rhs) > 0;
self.set_flag(place, u8::from(flag));
Ok(())
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
match builder {
ColumnBuilder::Nullable(inner_mut) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ where for<'a> T::ScalarRef<'a>: Hash
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<AggregateApproxCountDistinctState<T::ScalarRef<'_>>>();
let other = rhs.get::<AggregateApproxCountDistinctState<T::ScalarRef<'_>>>();
state.hll.union(&other.hll);
Ok(())
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<AggregateApproxCountDistinctState<T::ScalarRef<'_>>>();
let builder = NumberType::<u64>::try_downcast_builder(builder).unwrap();
Expand Down
6 changes: 6 additions & 0 deletions src/query/functions/src/aggregates/aggregate_arg_min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ where
state.merge(&rhs)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
let other = rhs.get::<State>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<State>();
state.merge_result(builder)
Expand Down
6 changes: 6 additions & 0 deletions src/query/functions/src/aggregates/aggregate_array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,12 @@ where
state.merge(&rhs)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
let other = rhs.get::<State>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<State>();
state.merge_result(builder)
Expand Down
12 changes: 12 additions & 0 deletions src/query/functions/src/aggregates/aggregate_array_moving.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,12 @@ where State: SumState
state.merge(&rhs)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
let other = rhs.get::<State>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<State>();
state.merge_avg_result(builder, 0_u64, self.scale_add, &self.window_size)
Expand Down Expand Up @@ -654,6 +660,12 @@ where State: SumState
state.merge(&rhs)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
let other = rhs.get::<State>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<State>();
state.merge_result(builder, &self.window_size)
Expand Down
7 changes: 7 additions & 0 deletions src/query/functions/src/aggregates/aggregate_avg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ where T: SumState
state.value.merge(&rhs.value)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<AvgState<T>>();
let rhs = rhs.get::<AvgState<T>>();
state.count += rhs.count;
state.value.merge(&rhs.value)
}

#[allow(unused_mut)]
fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<AvgState<T>>();
Expand Down
14 changes: 14 additions & 0 deletions src/query/functions/src/aggregates/aggregate_bitmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,16 @@ where
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<BitmapAggState>();
let other = rhs.get::<BitmapAggState>();

if let Some(rb) = other.rb.take() {
state.add::<OP>(rb);
}
Ok(())
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
AGG::merge_result(place, builder)
}
Expand Down Expand Up @@ -479,6 +489,10 @@ where
self.inner.merge(place, reader)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
self.inner.merge_states(place, rhs)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
self.inner.merge_result(place, builder)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ where State: DistinctStateFunc
state.merge(&rhs)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
let other = rhs.get::<State>();
state.merge(other)
}

#[allow(unused_mut)]
fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<State>();
Expand Down
4 changes: 4 additions & 0 deletions src/query/functions/src/aggregates/aggregate_combinator_if.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ impl AggregateFunction for AggregateIfCombinator {
self.nested.merge(place, reader)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
self.nested.merge_states(place, rhs)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
self.nested.merge_result(place, builder)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ impl AggregateFunction for AggregateStateCombinator {
self.nested.merge(place, reader)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
self.nested.merge_states(place, rhs)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let str_builder = builder.as_string_mut().unwrap();
self.serialize(place, &mut str_builder.data)?;
Expand Down
7 changes: 7 additions & 0 deletions src/query/functions/src/aggregates/aggregate_count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ impl AggregateFunction for AggregateCountFunction {
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<AggregateCountState>();
let other = rhs.get::<AggregateCountState>();
state.count += other.count;
Ok(())
}

fn batch_merge_result(&self, places: &[StateAddr], builder: &mut ColumnBuilder) -> Result<()> {
match builder {
ColumnBuilder::Number(NumberColumnBuilder::UInt64(builder)) => {
Expand Down
7 changes: 7 additions & 0 deletions src/query/functions/src/aggregates/aggregate_covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,13 @@ where
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<AggregateCovarianceState>();
let other = rhs.get::<AggregateCovarianceState>();
state.merge(other);
Ok(())
}

#[allow(unused_mut)]
fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<AggregateCovarianceState>();
Expand Down
2 changes: 2 additions & 0 deletions src/query/functions/src/aggregates/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ pub trait AggregateFunction: fmt::Display + Sync + Send {

fn merge(&self, _place: StateAddr, _reader: &mut &[u8]) -> Result<()>;

fn merge_states(&self, _place: StateAddr, _rhs: StateAddr) -> Result<()>;

fn batch_merge_result(&self, places: &[StateAddr], builder: &mut ColumnBuilder) -> Result<()> {
for place in places {
self.merge_result(*place, builder)?;
Expand Down
7 changes: 7 additions & 0 deletions src/query/functions/src/aggregates/aggregate_kurtosis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ where T: Number + AsPrimitive<f64>
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<KurtosisState>();
let other = rhs.get::<KurtosisState>();
state.merge(other);
Ok(())
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<KurtosisState>();
state.merge_result(builder)
Expand Down
6 changes: 6 additions & 0 deletions src/query/functions/src/aggregates/aggregate_min_max_any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ where
state.merge(&rhs)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
let other = rhs.get::<State>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<State>();
state.merge_result(builder)
Expand Down
4 changes: 4 additions & 0 deletions src/query/functions/src/aggregates/aggregate_null_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ impl AggregateFunction for AggregateNullResultFunction {
Ok(())
}

fn merge_states(&self, _place: StateAddr, _rhs: StateAddr) -> Result<()> {
Ok(())
}

fn merge_result(&self, _place: StateAddr, array: &mut ColumnBuilder) -> Result<()> {
AnyType::push_default(array);
Ok(())
Expand Down
6 changes: 6 additions & 0 deletions src/query/functions/src/aggregates/aggregate_quantile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ where T: Number + AsPrimitive<f64>
state.merge(&rhs)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<QuantileContState>();
let other = rhs.get::<QuantileContState>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<QuantileContState>();
state.merge_result(builder, self.levels.clone())
Expand Down
6 changes: 6 additions & 0 deletions src/query/functions/src/aggregates/aggregate_quantile_disc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ where
state.merge(&rhs)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
let other = rhs.get::<State>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<State>();
state.merge_result(builder, self.levels.clone())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,12 @@ where T: Number + AsPrimitive<f64>
state.merge(&mut rhs)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<QuantileTDigestState>();
let other = rhs.get::<QuantileTDigestState>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<QuantileTDigestState>();
state.merge_result(builder, self.levels.clone())
Expand Down
7 changes: 7 additions & 0 deletions src/query/functions/src/aggregates/aggregate_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ impl AggregateFunction for AggregateRetentionFunction {
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<AggregateRetentionState>();
let other = rhs.get::<AggregateRetentionState>();
state.merge(other);
Ok(())
}

#[allow(unused_mut)]
fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<AggregateRetentionState>();
Expand Down
7 changes: 7 additions & 0 deletions src/query/functions/src/aggregates/aggregate_skewness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,13 @@ where T: Number + AsPrimitive<f64>
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<SkewnessState>();
let other = rhs.get::<SkewnessState>();
state.merge(other);
Ok(())
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<SkewnessState>();
state.merge_result(builder)
Expand Down
7 changes: 7 additions & 0 deletions src/query/functions/src/aggregates/aggregate_stddev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,13 @@ where T: Number + AsPrimitive<f64>
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<AggregateStddevState>();
let other = rhs.get::<AggregateStddevState>();
state.merge(other);
Ok(())
}

#[allow(unused_mut)]
fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<AggregateStddevState>();
Expand Down
7 changes: 7 additions & 0 deletions src/query/functions/src/aggregates/aggregate_string_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ impl AggregateFunction for AggregateStringAggFunction {
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<StringAggState>();
let other = rhs.get::<StringAggState>();
state.values.extend_from_slice(other.values.as_slice());
Ok(())
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<StringAggState>();
let builder = StringType::try_downcast_builder(builder).unwrap();
Expand Down
6 changes: 6 additions & 0 deletions src/query/functions/src/aggregates/aggregate_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,12 @@ where State: SumState
state.merge(&rhs)
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
let other = rhs.get::<State>();
state.merge(other)
}

#[allow(unused_mut)]
fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let state = place.get::<State>();
Expand Down
7 changes: 7 additions & 0 deletions src/query/functions/src/aggregates/aggregate_window_funnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,13 @@ where
Ok(())
}

fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<AggregateWindowFunnelState<T::Scalar>>();
let other = rhs.get::<AggregateWindowFunnelState<T::Scalar>>();
state.merge(other);
Ok(())
}

#[allow(unused_mut)]
fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
let builder = UInt8Type::try_downcast_builder(builder).unwrap();
Expand Down
Loading

0 comments on commit 8395f90

Please sign in to comment.