Skip to content

Commit

Permalink
chore(query): refactor agg state merge (databendlabs#13153)
Browse files Browse the repository at this point in the history
* chore(query): refactor agg state merge

* chore(query): refactor agg state merge

* chore(query): refactor agg state merge

* chore(query): refactor agg state merge

* chore(query): refactor agg state merge

* chore(query): refactor agg state merge

* chore(query): fix render bug

* feat(query): add test test_box_render_block

* feat(query): remove useless holder
  • Loading branch information
sundy-li authored and andylokandy committed Nov 27, 2023
1 parent 89af54d commit 0d880f8
Show file tree
Hide file tree
Showing 37 changed files with 350 additions and 422 deletions.
45 changes: 8 additions & 37 deletions src/query/expression/src/utils/block_debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,46 +177,17 @@ fn create_box_table(
};

let mut res_vec: Vec<Vec<String>> = vec![];
let top_collection = results.first().unwrap();
let top_rows = top_collection.num_rows().min(top_rows);
if bottom_rows == 0 {
for block in results {
for row in 0..block.num_rows() {
let mut v = vec![];
for block_entry in block.columns() {
let value = block_entry.value.index(row).unwrap().to_string();
if replace_newline {
v.push(value.to_string().replace('\n', "\\n"));
} else {
v.push(value.to_string());
}
}
res_vec.push(v);
}
}
} else {
let bottom_collection = results.last().unwrap();
for row in 0..top_rows {
let mut v = vec![];
for block_entry in top_collection.columns() {
let value = block_entry.value.index(row).unwrap().to_string();
if replace_newline {
v.push(value.to_string().replace('\n', "\\n"));
} else {
v.push(value.to_string());
}

let mut rows = 0;
for block in results {
for row in 0..block.num_rows() {
rows += 1;
if rows > top_rows && rows <= row_count - bottom_rows {
continue;
}
res_vec.push(v);
}
let take_num = if bottom_collection.num_rows() > bottom_rows {
bottom_collection.num_rows() - bottom_rows
} else {
0
};

for row in take_num..bottom_collection.num_rows() {
let mut v = vec![];
for block_entry in top_collection.columns() {
for block_entry in block.columns() {
let value = block_entry.value.index(row).unwrap().to_string();
if replace_newline {
v.push(value.to_string().replace('\n', "\\n"));
Expand Down
39 changes: 39 additions & 0 deletions src/query/expression/tests/it/block.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
use common_expression::block_debug::box_render;
use common_expression::types::string::StringColumnBuilder;
use common_expression::types::DataType;
use common_expression::types::Int32Type;
use common_expression::types::NumberDataType;
use common_expression::Column;
use common_expression::DataField;
use common_expression::DataSchemaRefExt;
use common_expression::FromData;

use crate::common::new_block;

Expand All @@ -17,3 +24,35 @@ fn test_split_block() {
.collect::<Vec<_>>();
assert_eq!(sizes, vec![3, 3, 4]);
}

#[test]
fn test_box_render_block() {
let value = b"abc";
let n = 10;
let block = new_block(&[
Int32Type::from_data(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
Column::String(StringColumnBuilder::repeat(&value[..], n).build()),
]);

let schema = DataSchemaRefExt::create(vec![
DataField::new("a", DataType::Number(NumberDataType::Int32)),
DataField::new("e", DataType::String),
]);
let d = box_render(&schema, &[block], 5, 1000, 1000, true).unwrap();
let expected = r#"┌────────────────────┐
│ a │ e │
│ Int32 │ String │
├───────────┼────────┤
│ 1 │ 'abc' │
│ 2 │ 'abc' │
│ 3 │ 'abc' │
│ · │ · │
│ · │ · │
│ · │ · │
│ 9 │ 'abc' │
│ 10 │ 'abc' │
│ 10 rows │ │
│ (5 shown) │ │
└────────────────────┘"#;
assert_eq!(d, expected);
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,28 +202,34 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdapto
Ok(())
}

fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
if self.get_flag(place) == 0 {
// initial the state to remove the dirty stats
self.init_state(place);
}

if NULLABLE_RESULT {
let flag = reader[reader.len() - 1];
self.nested
.deserialize(place, &mut &reader[..reader.len() - 1])?;
self.set_flag(place, flag);
if flag == 1 {
self.set_flag(place, 1);
self.nested.merge(place, &mut &reader[..reader.len() - 1])?;
}
} else {
self.nested.deserialize(place, reader)?;
self.nested.merge(place, reader)?;
}

Ok(())
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
if self.get_flag(place) == 0 {
// initial the state to remove the dirty stats
self.init_state(place);
}

if self.get_flag(rhs) == 1 {
self.set_flag(place, 1);
self.nested.merge(place, rhs)?;
self.nested.merge_states(place, rhs)?;
}

Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,27 +207,33 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction
Ok(())
}

fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
if self.get_flag(place) == 0 {
// initial the state to remove the dirty stats
self.init_state(place);
}

if NULLABLE_RESULT {
let flag = reader[reader.len() - 1];
self.nested
.deserialize(place, &mut &reader[..reader.len() - 1])?;
self.set_flag(place, flag);
if flag == 1 {
self.set_flag(place, flag);
self.nested.merge(place, &mut &reader[..reader.len() - 1])?;
}
} else {
self.nested.deserialize(place, reader)?;
self.nested.merge(place, reader)?;
}
Ok(())
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
if self.get_flag(place) == 0 {
// initial the state to remove the dirty stats
self.init_state(place);
}

if self.get_flag(rhs) == 1 {
self.set_flag(place, 1);
self.nested.merge(place, rhs)?;
self.nested.merge_states(place, rhs)?;
}
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,16 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
}

#[inline]
fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
let flag = reader[reader.len() - 1];
self.inner
.deserialize(place, &mut &reader[..reader.len() - 1])?;
self.set_flag(place, flag);
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
let flag = self.get_flag(place) > 0 || reader[reader.len() - 1] > 0;

self.inner.merge(place, &mut &reader[..reader.len() - 1])?;
self.set_flag(place, flag as u8);
Ok(())
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
self.inner.merge(place, rhs)?;
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
self.inner.merge_states(place, rhs)?;
let flag = self.get_flag(place) > 0 || self.get_flag(rhs) > 0;
self.set_flag(place, u8::from(flag));
Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,18 @@ where for<'a> T::ScalarRef<'a>: Hash
serialize_into_buf(writer, &state.hll)
}

fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
let state = place.get::<AggregateApproxCountDistinctState<T::ScalarRef<'_>>>();
state.hll = deserialize_from_slice(reader)?;
let hll: HyperLogLog<T::ScalarRef<'_>> = deserialize_from_slice(reader)?;
state.hll.union(&hll);

Ok(())
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<AggregateApproxCountDistinctState<T::ScalarRef<'_>>>();
let rhs = rhs.get::<AggregateApproxCountDistinctState<T::ScalarRef<'_>>>();
state.hll.union(&rhs.hll);

let other = rhs.get::<AggregateApproxCountDistinctState<T::ScalarRef<'_>>>();
state.hll.union(&other.hll);
Ok(())
}

Expand Down
29 changes: 11 additions & 18 deletions src/query/functions/src/aggregates/aggregate_arg_min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ use crate::with_simple_no_number_mapped_type;
// State for arg_min(arg, val) and arg_max(arg, val)
// A: ValueType for arg.
// V: ValueType for val.
pub trait AggregateArgMinMaxState<A: ValueType, V: ValueType>: Send + Sync + 'static {
pub trait AggregateArgMinMaxState<A: ValueType, V: ValueType>:
Serialize + DeserializeOwned + Send + Sync + 'static
{
fn new() -> Self;
fn add(&mut self, value: V::ScalarRef<'_>, data: Scalar);
fn add_batch(
Expand All @@ -60,8 +62,6 @@ pub trait AggregateArgMinMaxState<A: ValueType, V: ValueType>: Send + Sync + 'st
) -> Result<()>;

fn merge(&mut self, rhs: &Self) -> Result<()>;
fn serialize(&self, writer: &mut Vec<u8>) -> Result<()>;
fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()>;
fn merge_result(&mut self, column: &mut ColumnBuilder) -> Result<()>;
}

Expand Down Expand Up @@ -180,15 +180,6 @@ where
Ok(())
}

fn serialize(&self, writer: &mut Vec<u8>) -> Result<()> {
serialize_into_buf(writer, self)
}

fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> {
*self = deserialize_from_slice(reader)?;
Ok(())
}

fn merge_result(&mut self, builder: &mut ColumnBuilder) -> Result<()> {
if self.value.is_some() {
if let Some(inner) = A::try_downcast_builder(builder) {
Expand Down Expand Up @@ -291,18 +282,20 @@ where

fn serialize(&self, place: StateAddr, writer: &mut Vec<u8>) -> Result<()> {
let state = place.get::<State>();
state.serialize(writer)
serialize_into_buf(writer, state)
}

fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
let state = place.get::<State>();
state.deserialize(reader)
let rhs: State = deserialize_from_slice(reader)?;

state.merge(&rhs)
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let rhs = rhs.get::<State>();
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
state.merge(rhs)
let other = rhs.get::<State>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
Expand Down
32 changes: 8 additions & 24 deletions src/query/functions/src/aggregates/aggregate_array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,6 @@ where
builder.push(array_value);
Ok(())
}

fn serialize(&self, writer: &mut Vec<u8>) -> Result<()> {
serialize_into_buf(writer, self)
}

fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> {
self.values = deserialize_from_slice(reader)?;
Ok(())
}
}

#[derive(Serialize, Deserialize, Debug)]
Expand Down Expand Up @@ -234,15 +225,6 @@ where
builder.push(array_value);
Ok(())
}

fn serialize(&self, writer: &mut Vec<u8>) -> Result<()> {
serialize_into_buf(writer, self)
}

fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> {
self.values = deserialize_from_slice(reader)?;
Ok(())
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -356,18 +338,20 @@ where

fn serialize(&self, place: StateAddr, writer: &mut Vec<u8>) -> Result<()> {
let state = place.get::<State>();
state.serialize(writer)
serialize_into_buf(writer, state)
}

fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
let state = place.get::<State>();
state.deserialize(reader)
let rhs: State = deserialize_from_slice(reader)?;

state.merge(&rhs)
}

fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let rhs = rhs.get::<State>();
fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
let state = place.get::<State>();
state.merge(rhs)
let other = rhs.get::<State>();
state.merge(other)
}

fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
Expand Down
Loading

0 comments on commit 0d880f8

Please sign in to comment.