Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(query): refactor agg state merge #13153

Merged
merged 13 commits into from
Oct 11, 2023
45 changes: 8 additions & 37 deletions src/query/expression/src/utils/block_debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,46 +177,17 @@ fn create_box_table(
};

let mut res_vec: Vec<Vec<String>> = vec![];
let top_collection = results.first().unwrap();
let top_rows = top_collection.num_rows().min(top_rows);
if bottom_rows == 0 {
for block in results {
for row in 0..block.num_rows() {
let mut v = vec![];
for block_entry in block.columns() {
let value = block_entry.value.index(row).unwrap().to_string();
if replace_newline {
v.push(value.to_string().replace('\n', "\\n"));
} else {
v.push(value.to_string());
}
}
res_vec.push(v);
}
}
} else {
let bottom_collection = results.last().unwrap();
for row in 0..top_rows {
let mut v = vec![];
for block_entry in top_collection.columns() {
let value = block_entry.value.index(row).unwrap().to_string();
if replace_newline {
v.push(value.to_string().replace('\n', "\\n"));
} else {
v.push(value.to_string());
}

sundy-li marked this conversation as resolved.
Show resolved Hide resolved
let mut rows = 0;
for block in results {
for row in 0..block.num_rows() {
rows += 1;
if rows > top_rows && rows <= row_count - bottom_rows {
continue;
}
res_vec.push(v);
}
let take_num = if bottom_collection.num_rows() > bottom_rows {
bottom_collection.num_rows() - bottom_rows
} else {
0
};

for row in take_num..bottom_collection.num_rows() {
let mut v = vec![];
for block_entry in top_collection.columns() {
for block_entry in block.columns() {
let value = block_entry.value.index(row).unwrap().to_string();
if replace_newline {
v.push(value.to_string().replace('\n', "\\n"));
Expand Down
39 changes: 39 additions & 0 deletions src/query/expression/tests/it/block.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
use common_expression::block_debug::box_render;
use common_expression::types::string::StringColumnBuilder;
use common_expression::types::DataType;
use common_expression::types::Int32Type;
use common_expression::types::NumberDataType;
use common_expression::Column;
use common_expression::DataField;
use common_expression::DataSchemaRefExt;
use common_expression::FromData;

use crate::common::new_block;

Expand All @@ -17,3 +24,35 @@ fn test_split_block() {
.collect::<Vec<_>>();
assert_eq!(sizes, vec![3, 3, 4]);
}

/// Checks that `box_render` produces the expected boxed table for a single
/// two-column block when fewer rows are shown than the block contains.
///
/// The block has an Int32 column `a` holding 1..=10 and a String column `e`
/// built from the bytes `abc`. The expected rendering lists the first three
/// and the last two rows, separated by three rows of `·` placeholders, and
/// ends with a `10 rows` / `(5 shown)` footer.
#[test]
fn test_box_render_block() {
let value = b"abc";
let n = 10;
// Two columns: the integers 1..=10 and a string column built from `value`.
// NOTE(review): presumably `repeat` yields `n` copies of `value` — the
// expected output below shows 'abc' on every displayed row; confirm.
let block = new_block(&[
Int32Type::from_data(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
Column::String(StringColumnBuilder::repeat(&value[..], n).build()),
]);

// Schema naming the columns `a` (Int32) and `e` (String); these names and
// types appear in the header rows of the expected table.
let schema = DataSchemaRefExt::create(vec![
DataField::new("a", DataType::Number(NumberDataType::Int32)),
DataField::new("e", DataType::String),
]);
// NOTE(review): the positional arguments are not defined in this file.
// Presumably `5` is the maximum number of rows to display (the footer reads
// "(5 shown)"), the two `1000` values are size limits, and `true` enables
// newline escaping — verify against the `box_render` signature.
let d = box_render(&schema, &[block], 5, 1000, 1000, true).unwrap();
let expected = r#"┌────────────────────┐
│ a │ e │
│ Int32 │ String │
├───────────┼────────┤
│ 1 │ 'abc' │
│ 2 │ 'abc' │
│ 3 │ 'abc' │
│ · │ · │
│ · │ · │
│ · │ · │
│ 9 │ 'abc' │
│ 10 │ 'abc' │
│ 10 rows │ │
│ (5 shown) │ │
└────────────────────┘"#;
// The rendered text must match the expected table exactly.
assert_eq!(d, expected);
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,28 +202,34 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdapto
Ok(())
}

fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
if self.get_flag(place) == 0 {
// initialize the state to remove the dirty stats
self.init_state(place);
}

if NULLABLE_RESULT {
let flag = reader[reader.len() - 1];
self.nested
.deserialize(place, &mut &reader[..reader.len() - 1])?;
self.set_flag(place, flag);
if flag == 1 {
self.set_flag(place, 1);
self.nested.merge(place, &mut &reader[..reader.len() - 1])?;
}
} else {
self.nested.deserialize(place, reader)?;
self.nested.merge(place, reader)?;
}

Ok(())
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
if self.get_flag(place) == 0 {
// initialize the state to remove the dirty stats
self.init_state(place);
}

if self.get_flag(rhs) == 1 {
self.set_flag(place, 1);
self.nested.merge(place, rhs)?;
self.nested.merge_states(place, rhs)?;
}

Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,27 +207,33 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction
Ok(())
}

fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
if self.get_flag(place) == 0 {
// initialize the state to remove the dirty stats
self.init_state(place);
}

if NULLABLE_RESULT {
let flag = reader[reader.len() - 1];
self.nested
.deserialize(place, &mut &reader[..reader.len() - 1])?;
self.set_flag(place, flag);
if flag == 1 {
self.set_flag(place, flag);
self.nested.merge(place, &mut &reader[..reader.len() - 1])?;
}
} else {
self.nested.deserialize(place, reader)?;
self.nested.merge(place, reader)?;
}
Ok(())
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
if self.get_flag(place) == 0 {
// initialize the state to remove the dirty stats
self.init_state(place);
}

if self.get_flag(rhs) == 1 {
self.set_flag(place, 1);
self.nested.merge(place, rhs)?;
self.nested.merge_states(place, rhs)?;
}
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,16 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
}

#[inline]
fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
let flag = reader[reader.len() - 1];
sundy-li marked this conversation as resolved.
Show resolved Hide resolved
self.inner
.deserialize(place, &mut &reader[..reader.len() - 1])?;
self.set_flag(place, flag);
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
let flag = self.get_flag(place) > 0 || reader[reader.len() - 1] > 0;

self.inner.merge(place, &mut &reader[..reader.len() - 1])?;
self.set_flag(place, flag as u8);
Ok(())
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
self.inner.merge(place, rhs)?;
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
self.inner.merge_states(place, rhs)?;
let flag = self.get_flag(place) > 0 || self.get_flag(rhs) > 0;
self.set_flag(place, u8::from(flag));
Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,18 @@ where for<'a> T::ScalarRef<'a>: Hash
serialize_into_buf(writer, &state.hll)
}

fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
let state = place.get::<AggregateApproxCountDistinctState<T::ScalarRef<'_>>>();
state.hll = deserialize_from_slice(reader)?;
let hll: HyperLogLog<T::ScalarRef<'_>> = deserialize_from_slice(reader)?;
state.hll.union(&hll);

Ok(())
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<AggregateApproxCountDistinctState<T::ScalarRef<'_>>>();
let rhs = rhs.get::<AggregateApproxCountDistinctState<T::ScalarRef<'_>>>();
state.hll.union(&rhs.hll);

let other = rhs.get::<AggregateApproxCountDistinctState<T::ScalarRef<'_>>>();
state.hll.union(&other.hll);
Ok(())
}

Expand Down
29 changes: 11 additions & 18 deletions src/query/functions/src/aggregates/aggregate_arg_min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ use crate::with_simple_no_number_mapped_type;
// State for arg_min(arg, val) and arg_max(arg, val)
// A: ValueType for arg.
// V: ValueType for val.
pub trait AggregateArgMinMaxState<A: ValueType, V: ValueType>: Send + Sync + 'static {
pub trait AggregateArgMinMaxState<A: ValueType, V: ValueType>:
Serialize + DeserializeOwned + Send + Sync + 'static
{
fn new() -> Self;
fn add(&mut self, value: V::ScalarRef<'_>, data: Scalar);
fn add_batch(
Expand All @@ -60,8 +62,6 @@ pub trait AggregateArgMinMaxState<A: ValueType, V: ValueType>: Send + Sync + 'st
) -> Result<()>;

fn merge(&mut self, rhs: &Self) -> Result<()>;
fn serialize(&self, writer: &mut Vec<u8>) -> Result<()>;
fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()>;
fn merge_result(&mut self, column: &mut ColumnBuilder) -> Result<()>;
}

Expand Down Expand Up @@ -180,15 +180,6 @@ where
Ok(())
}

fn serialize(&self, writer: &mut Vec<u8>) -> Result<()> {
serialize_into_buf(writer, self)
}

fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> {
*self = deserialize_from_slice(reader)?;
Ok(())
}

fn merge_result(&mut self, builder: &mut ColumnBuilder) -> Result<()> {
if self.value.is_some() {
if let Some(inner) = A::try_downcast_builder(builder) {
Expand Down Expand Up @@ -291,18 +282,20 @@ where

fn serialize(&self, place: StateAddr, writer: &mut Vec<u8>) -> Result<()> {
let state = place.get::<State>();
state.serialize(writer)
serialize_into_buf(writer, state)
}

fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
let state = place.get::<State>();
state.deserialize(reader)
let rhs: State = deserialize_from_slice(reader)?;

state.merge(&rhs)
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let rhs = rhs.get::<State>();
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
state.merge(rhs)
let other = rhs.get::<State>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
Expand Down
32 changes: 8 additions & 24 deletions src/query/functions/src/aggregates/aggregate_array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,6 @@ where
builder.push(array_value);
Ok(())
}

fn serialize(&self, writer: &mut Vec<u8>) -> Result<()> {
serialize_into_buf(writer, self)
}

fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> {
self.values = deserialize_from_slice(reader)?;
Ok(())
}
}

#[derive(Serialize, Deserialize, Debug)]
Expand Down Expand Up @@ -234,15 +225,6 @@ where
builder.push(array_value);
Ok(())
}

fn serialize(&self, writer: &mut Vec<u8>) -> Result<()> {
serialize_into_buf(writer, self)
}

fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> {
self.values = deserialize_from_slice(reader)?;
Ok(())
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -356,18 +338,20 @@ where

fn serialize(&self, place: StateAddr, writer: &mut Vec<u8>) -> Result<()> {
let state = place.get::<State>();
state.serialize(writer)
serialize_into_buf(writer, state)
}

fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
let state = place.get::<State>();
state.deserialize(reader)
let rhs: State = deserialize_from_slice(reader)?;

state.merge(&rhs)
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let rhs = rhs.get::<State>();
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
state.merge(rhs)
let other = rhs.get::<State>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
Expand Down
Loading
Loading