From b66510639a809d26dc8aab3f0ffd1b0ecbc64871 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Mon, 23 Sep 2024 16:09:57 +0800 Subject: [PATCH 1/6] rename `AggCall.Type` to `AggCall.Kind` Signed-off-by: Richard Chien --- proto/expr.proto | 6 +++--- src/batch/src/executor/hash_agg.rs | 10 +++++----- src/expr/core/src/aggregate/def.rs | 2 +- src/expr/core/src/sig/mod.rs | 10 +++++----- src/expr/core/src/window_function/kind.rs | 4 ++-- src/expr/macro/src/gen.rs | 4 ++-- src/frontend/src/expr/type_inference/func.rs | 2 +- src/meta/src/stream/test_fragmenter.rs | 6 +++--- src/prost/src/lib.rs | 2 +- 9 files changed, 23 insertions(+), 23 deletions(-) diff --git a/proto/expr.proto b/proto/expr.proto index 53bba96cc587b..d4449648e6002 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -422,7 +422,7 @@ message FunctionCall { // Aggregate Function Calls for Aggregation message AggCall { - enum Type { + enum Kind { UNSPECIFIED = 0; SUM = 1; MIN = 2; @@ -458,7 +458,7 @@ message AggCall { // wraps a scalar function that takes a list as input as an aggregate function. WRAP_SCALAR = 101; } - Type type = 1; + Kind type = 1; repeated InputRef args = 2; data.DataType return_type = 3; bool distinct = 4; @@ -562,7 +562,7 @@ message WindowFunction { oneof type { GeneralType general = 1; - AggCall.Type aggregate = 2; + AggCall.Kind aggregate = 2; } repeated InputRef args = 3; data.DataType return_type = 4; diff --git a/src/batch/src/executor/hash_agg.rs b/src/batch/src/executor/hash_agg.rs index 00073217f7ead..2f54ca2b2876a 100644 --- a/src/batch/src/executor/hash_agg.rs +++ b/src/batch/src/executor/hash_agg.rs @@ -755,7 +755,7 @@ mod tests { use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::data::PbDataType; - use risingwave_pb::expr::agg_call::Type; + use risingwave_pb::expr::agg_call::PbKind as PbAggKind; use risingwave_pb::expr::{AggCall, InputRef}; use super::*; @@ -788,7 +788,7 @@ mod tests { )); let agg_call = AggCall { - r#type: Type::Sum as i32, + r#type: PbAggKind::Sum as i32, args: vec![InputRef { index: 2, r#type: Some(PbDataType { @@ -873,7 +873,7 @@ mod tests { ); let agg_call = AggCall { - r#type: Type::Count as i32, + r#type: PbAggKind::Count as i32, args: vec![], return_type: Some(PbDataType { type_name: TypeName::Int64 as i32, @@ -985,7 +985,7 @@ mod tests { ); let agg_call = AggCall { - r#type: Type::Sum as i32, + r#type: PbAggKind::Sum as i32, args: vec![InputRef { index: 2, r#type: Some(PbDataType { @@ -1078,7 +1078,7 @@ mod tests { )); let agg_call = AggCall { - r#type: Type::Sum as i32, + r#type: PbAggKind::Sum as i32, args: vec![InputRef { index: 2, r#type: Some(PbDataType { diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index 1cabd3a2f5ae5..e952ac53bdfc0 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -26,7 +26,7 @@ use risingwave_common::bail; use risingwave_common::types::{DataType, Datum}; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common::util::value_encoding::DatumFromProtoExt; -pub use risingwave_pb::expr::agg_call::PbType as PbAggKind; +pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind; use risingwave_pb::expr::{PbAggCall, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata}; use crate::expr::{ diff --git a/src/expr/core/src/sig/mod.rs b/src/expr/core/src/sig/mod.rs index aae3802489d09..48280c124c156 100644 --- a/src/expr/core/src/sig/mod.rs +++ b/src/expr/core/src/sig/mod.rs @@ -21,7 +21,7 @@ use std::sync::LazyLock; use itertools::Itertools; use risingwave_common::types::DataType; -use risingwave_pb::expr::agg_call::PbType as AggregateFunctionType; +use risingwave_pb::expr::agg_call::PbKind as PbAggKind; use risingwave_pb::expr::expr_node::PbType as ScalarFunctionType; use risingwave_pb::expr::table_function::PbType as TableFunctionType; @@ -354,7 +354,7 @@ impl FuncSign { pub enum FuncName { Scalar(ScalarFunctionType), Table(TableFunctionType), - Aggregate(AggregateFunctionType), + Aggregate(PbAggKind), Udf(String), } @@ -370,8 +370,8 @@ impl From for FuncName { } } -impl From for FuncName { - fn from(ty: AggregateFunctionType) -> Self { +impl From for FuncName { + fn from(ty: PbAggKind) -> Self { Self::Aggregate(ty) } } @@ -405,7 +405,7 @@ impl FuncName { } } - pub fn as_aggregate(&self) -> AggregateFunctionType { + pub fn as_aggregate(&self) -> PbAggKind { match self { Self::Aggregate(ty) => *ty, _ => panic!("Expected an aggregate function"), diff --git a/src/expr/core/src/window_function/kind.rs b/src/expr/core/src/window_function/kind.rs index 3042facb5cffc..9cd2c86416ede 100644 --- a/src/expr/core/src/window_function/kind.rs +++ b/src/expr/core/src/window_function/kind.rs @@ -38,7 +38,7 @@ impl WindowFuncKind { pub fn from_protobuf( window_function_type: &risingwave_pb::expr::window_function::PbType, ) -> Result { - use risingwave_pb::expr::agg_call::PbType as PbAggType; + use risingwave_pb::expr::agg_call::PbKind as PbAggKind; use risingwave_pb::expr::window_function::{PbGeneralType, PbType}; let kind = match window_function_type { @@ -51,7 +51,7 @@ impl WindowFuncKind { Ok(PbGeneralType::Lead) => Self::Lead, Err(_) => bail!("no such window function type"), }, - PbType::Aggregate(agg_type) => match PbAggType::try_from(*agg_type) { + PbType::Aggregate(agg_type) => match PbAggKind::try_from(*agg_type) { // TODO(runji): support UDAF and wrapped scalar functions Ok(agg_type) => Self::Aggregate(AggKind::from_protobuf(agg_type, None, None)?), Err(_) => bail!("no such aggregate function type"), diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index ce5c8a884abdf..8057f653c8d82 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -668,7 +668,7 @@ impl FunctionAttr { true => self.append_only, }; - let pb_type = format_ident!("{}", utils::to_camel_case(&name)); + let pb_kind = format_ident!("{}", utils::to_camel_case(&name)); let ctor_name = match append_only { false => format_ident!("{}", self.ident_name()), true => format_ident!("{}_append_only", self.ident_name()), @@ -707,7 +707,7 @@ impl FunctionAttr { use risingwave_expr::sig::{FuncSign, SigDataType, FuncBuilder}; FuncSign { - name: risingwave_pb::expr::agg_call::Type::#pb_type.into(), + name: risingwave_pb::expr::agg_call::PbKind::#pb_kind.into(), inputs_type: vec![#(#args),*], variadic: false, ret_type: #ret, diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 9ed7530499921..08825f6748bb1 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -19,7 +19,7 @@ use risingwave_common::hash::VirtualNode; use risingwave_common::types::{DataType, StructType}; use risingwave_common::util::iter_util::ZipEqFast; pub use risingwave_expr::sig::*; -use risingwave_pb::expr::agg_call::PbType as PbAggKind; +use risingwave_pb::expr::agg_call::PbKind as PbAggKind; use risingwave_pb::expr::table_function::PbType as PbTableFuncType; use super::{align_types, cast_ok_base, CastContext}; diff --git a/src/meta/src/stream/test_fragmenter.rs b/src/meta/src/stream/test_fragmenter.rs index 601c708efef34..7005105ce74ed 100644 --- a/src/meta/src/stream/test_fragmenter.rs +++ b/src/meta/src/stream/test_fragmenter.rs @@ -23,7 +23,7 @@ use risingwave_pb::common::{PbColumnOrder, PbDirection, PbNullsAre, PbOrderType, use risingwave_pb::data::data_type::TypeName; use risingwave_pb::data::DataType; use risingwave_pb::ddl_service::TableJobType; -use risingwave_pb::expr::agg_call::Type; +use risingwave_pb::expr::agg_call::PbKind as PbAggKind; use risingwave_pb::expr::expr_node::RexNode; use risingwave_pb::expr::expr_node::Type::{Add, GreaterThan}; use risingwave_pb::expr::{AggCall, ExprNode, FunctionCall, PbInputRef}; @@ -45,7 +45,7 @@ use crate::MetaResult; fn make_inputref(idx: u32) -> ExprNode { ExprNode { - function_type: Type::Unspecified as i32, + function_type: PbAggKind::Unspecified as i32, return_type: Some(DataType { type_name: TypeName::Int32 as i32, ..Default::default() @@ -56,7 +56,7 @@ fn make_inputref(idx: u32) -> ExprNode { fn make_sum_aggcall(idx: u32) -> AggCall { AggCall { - r#type: Type::Sum as i32, + r#type: PbAggKind::Sum as i32, args: vec![PbInputRef { index: idx, r#type: Some(DataType { diff --git a/src/prost/src/lib.rs b/src/prost/src/lib.rs index e965f76282da4..5974a05664721 100644 --- a/src/prost/src/lib.rs +++ b/src/prost/src/lib.rs @@ -189,7 +189,7 @@ impl FromStr for crate::expr::table_function::PbType { } } -impl FromStr for crate::expr::agg_call::PbType { +impl FromStr for crate::expr::agg_call::PbKind { type Err = (); fn from_str(s: &str) -> Result { From 3c453f94db7b0f40db99a109cd370a2747109e74 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Mon, 23 Sep 2024 16:37:12 +0800 Subject: [PATCH 2/6] rename proto field to `kind` Signed-off-by: Richard Chien --- proto/expr.proto | 6 +++--- src/batch/benches/hash_agg.rs | 2 +- src/batch/src/executor/hash_agg.rs | 8 ++++---- src/expr/core/src/aggregate/def.rs | 2 +- src/frontend/src/optimizer/plan_node/generic/agg.rs | 2 +- src/meta/src/stream/test_fragmenter.rs | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/proto/expr.proto b/proto/expr.proto index d4449648e6002..c5c66733d0d92 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -458,16 +458,16 @@ message AggCall { // wraps a scalar function that takes a list as input as an aggregate function. WRAP_SCALAR = 101; } - Kind type = 1; + Kind kind = 1; repeated InputRef args = 2; data.DataType return_type = 3; bool distinct = 4; repeated common.ColumnOrder order_by = 5; ExprNode filter = 6; repeated Constant direct_args = 7; - // optional. only used when the type is USER_DEFINED. + // optional. only used when the kind is USER_DEFINED. UserDefinedFunctionMetadata udf = 8; - // optional. only used when the type is WRAP_SCALAR. + // optional. only used when the kind is WRAP_SCALAR. ExprNode scalar = 9; } diff --git a/src/batch/benches/hash_agg.rs b/src/batch/benches/hash_agg.rs index e393f1ff1bf8c..b29873a106673 100644 --- a/src/batch/benches/hash_agg.rs +++ b/src/batch/benches/hash_agg.rs @@ -39,7 +39,7 @@ fn create_agg_call( return_type: DataType, ) -> PbAggCall { PbAggCall { - r#type: agg_kind.to_protobuf() as i32, + kind: agg_kind.to_protobuf() as i32, args: args .into_iter() .map(|col_idx| PbInputRef { diff --git a/src/batch/src/executor/hash_agg.rs b/src/batch/src/executor/hash_agg.rs index 2f54ca2b2876a..bde5d36bd8c64 100644 --- a/src/batch/src/executor/hash_agg.rs +++ b/src/batch/src/executor/hash_agg.rs @@ -788,7 +788,7 @@ mod tests { )); let agg_call = AggCall { - r#type: PbAggKind::Sum as i32, + kind: PbAggKind::Sum as i32, args: vec![InputRef { index: 2, r#type: Some(PbDataType { @@ -873,7 +873,7 @@ mod tests { ); let agg_call = AggCall { - r#type: PbAggKind::Count as i32, + kind: PbAggKind::Count as i32, args: vec![], return_type: Some(PbDataType { type_name: TypeName::Int64 as i32, @@ -985,7 +985,7 @@ mod tests { ); let agg_call = AggCall { - r#type: PbAggKind::Sum as i32, + kind: PbAggKind::Sum as i32, args: vec![InputRef { index: 2, r#type: Some(PbDataType { @@ -1078,7 +1078,7 @@ mod tests { )); let agg_call = AggCall { - r#type: PbAggKind::Sum as i32, + kind: PbAggKind::Sum as i32, args: vec![InputRef { index: 2, r#type: Some(PbDataType { diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index e952ac53bdfc0..12a01beecbb25 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -66,7 +66,7 @@ pub struct AggCall { impl AggCall { pub fn from_protobuf(agg_call: &PbAggCall) -> Result { let agg_kind = AggKind::from_protobuf( - agg_call.get_type()?, + agg_call.get_kind()?, agg_call.udf.as_ref(), agg_call.scalar.as_ref(), )?; diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index bf5e4700a78c0..dc10ffd8b0526 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -780,7 +780,7 @@ impl PlanAggCall { pub fn to_protobuf(&self) -> PbAggCall { PbAggCall { - r#type: match &self.agg_kind { + kind: match &self.agg_kind { AggKind::Builtin(kind) => *kind, AggKind::UserDefined(_) => PbAggKind::UserDefined, AggKind::WrapScalar(_) => PbAggKind::WrapScalar, diff --git a/src/meta/src/stream/test_fragmenter.rs b/src/meta/src/stream/test_fragmenter.rs index 7005105ce74ed..c313f8b4ade38 100644 --- a/src/meta/src/stream/test_fragmenter.rs +++ b/src/meta/src/stream/test_fragmenter.rs @@ -56,7 +56,7 @@ fn make_inputref(idx: u32) -> ExprNode { fn make_sum_aggcall(idx: u32) -> AggCall { AggCall { - r#type: PbAggKind::Sum as i32, + kind: PbAggKind::Sum as i32, args: vec![PbInputRef { index: idx, r#type: Some(DataType { From 4f6a907f91769e0f5a026fd235171385bee51faa Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Mon, 23 Sep 2024 16:40:31 +0800 Subject: [PATCH 3/6] rename Rust `AggKind` to `Agg Signed-off-by: Richard Chien --- src/batch/benches/hash_agg.rs | 6 +- src/expr/core/src/aggregate/def.rs | 88 +++++++++---------- src/expr/core/src/aggregate/mod.rs | 6 +- src/expr/core/src/window_function/kind.rs | 6 +- .../impl/src/window_function/aggregate.rs | 4 +- .../src/binder/expr/function/aggregate.rs | 22 ++--- src/frontend/src/binder/expr/function/mod.rs | 8 +- src/frontend/src/expr/agg_call.rs | 16 ++-- src/frontend/src/expr/window_function.rs | 8 +- .../optimizer/plan_node/batch_simple_agg.rs | 4 +- .../src/optimizer/plan_node/generic/agg.rs | 40 ++++----- .../src/optimizer/plan_node/logical_agg.rs | 12 +-- .../plan_node/logical_over_window.rs | 8 +- .../rule/apply_agg_transpose_rule.rs | 14 +-- .../src/optimizer/rule/distinct_agg_rule.rs | 4 +- .../rule/grouping_sets_to_expand_rule.rs | 4 +- .../optimizer/rule/min_max_on_index_rule.rs | 6 +- .../pull_up_correlated_predicate_agg_rule.rs | 6 +- src/stream/src/executor/aggregation/minput.rs | 16 ++-- src/stream/src/executor/aggregation/mod.rs | 4 +- src/stream/src/executor/test_utils.rs | 8 +- 21 files changed, 145 insertions(+), 145 deletions(-) diff --git a/src/batch/benches/hash_agg.rs b/src/batch/benches/hash_agg.rs index b29873a106673..6d3b24febcc25 100644 --- a/src/batch/benches/hash_agg.rs +++ b/src/batch/benches/hash_agg.rs @@ -25,7 +25,7 @@ use risingwave_common::catalog::{Field, Schema}; use risingwave_common::memory::MemoryContext; use risingwave_common::types::DataType; use risingwave_common::{enable_jemalloc, hash}; -use risingwave_expr::aggregate::{AggCall, AggKind, PbAggKind}; +use risingwave_expr::aggregate::{AggCall, AggType, PbAggKind}; use risingwave_pb::expr::{PbAggCall, PbInputRef}; use tokio::runtime::Runtime; use utils::{create_input, execute_executor}; @@ -34,7 +34,7 @@ enable_jemalloc!(); fn create_agg_call( input_schema: &Schema, - agg_kind: AggKind, + agg_kind: AggType, args: Vec, return_type: DataType, ) -> PbAggCall { @@ -59,7 +59,7 @@ fn create_agg_call( fn create_hash_agg_executor( group_key_columns: Vec, - agg_kind: AggKind, + agg_kind: AggType, arg_columns: Vec, return_type: DataType, chunk_size: usize, diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index 12a01beecbb25..3e1850641d14b 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -42,7 +42,7 @@ use crate::Result; #[derive(Debug, Clone)] pub struct AggCall { /// Aggregation kind for constructing agg state. - pub kind: AggKind, + pub kind: AggType, /// Arguments of aggregation function input. pub args: AggArgs, @@ -65,7 +65,7 @@ pub struct AggCall { impl AggCall { pub fn from_protobuf(agg_call: &PbAggCall) -> Result { - let agg_kind = AggKind::from_protobuf( + let agg_kind = AggType::from_protobuf( agg_call.get_kind()?, agg_call.udf.as_ref(), agg_call.scalar.as_ref(), @@ -160,7 +160,7 @@ impl> Parser { self.tokens.next(); // Consume the RParen AggCall { - kind: AggKind::from_protobuf(func, None, None).unwrap(), + kind: AggType::from_protobuf(func, None, None).unwrap(), args: AggArgs { data_types: children.iter().map(|(_, ty)| ty.clone()).collect(), val_indices: children.iter().map(|(idx, _)| *idx).collect(), @@ -216,7 +216,7 @@ impl> Parser { /// Aggregate function kind. #[derive(Debug, Clone, PartialEq, Eq, Hash, EnumAsInner)] -pub enum AggKind { +pub enum AggType { /// Built-in aggregate function. /// /// The associated value should not be `UserDefined` or `WrapScalar`. @@ -229,7 +229,7 @@ pub enum AggKind { WrapScalar(PbExprNode), } -impl Display for AggKind { +impl Display for AggType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Builtin(kind) => write!(f, "{}", kind.as_str_name().to_lowercase()), @@ -240,42 +240,42 @@ impl Display for AggKind { } /// `FromStr` for builtin aggregate functions. -impl FromStr for AggKind { +impl FromStr for AggType { type Err = (); fn from_str(s: &str) -> Result { let kind = PbAggKind::from_str(s)?; - Ok(AggKind::Builtin(kind)) + Ok(AggType::Builtin(kind)) } } -impl From for AggKind { +impl From for AggType { fn from(pb: PbAggKind) -> Self { assert!(!matches!( pb, PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar )); - AggKind::Builtin(pb) + AggType::Builtin(pb) } } -impl AggKind { +impl AggType { pub fn from_protobuf( - pb_type: PbAggKind, + pb_kind: PbAggKind, user_defined: Option<&PbUserDefinedFunctionMetadata>, scalar: Option<&PbExprNode>, ) -> Result { - match pb_type { + match pb_kind { PbAggKind::UserDefined => { let user_defined = user_defined.context("expect user defined")?; - Ok(AggKind::UserDefined(user_defined.clone())) + Ok(AggType::UserDefined(user_defined.clone())) } PbAggKind::WrapScalar => { let scalar = scalar.context("expect scalar")?; - Ok(AggKind::WrapScalar(scalar.clone())) + Ok(AggType::WrapScalar(scalar.clone())) } PbAggKind::Unspecified => bail!("Unrecognized agg."), - _ => Ok(AggKind::Builtin(pb_type)), + _ => Ok(AggType::Builtin(pb_kind)), } } @@ -288,27 +288,27 @@ impl AggKind { } } -/// Macros to generate match arms for [`AggKind`](AggKind). +/// Macros to generate match arms for `AggType`. /// IMPORTANT: These macros must be carefully maintained especially when adding new -/// [`AggKind`](AggKind) variants. +/// `AggType`/`PbAggKind` variants. pub mod agg_kinds { - /// [`AggKind`](super::AggKind)s that are currently not supported in streaming mode. + /// [`AggType`](super::AggType)s that are currently not supported in streaming mode. #[macro_export] macro_rules! unimplemented_in_stream { () => { - AggKind::Builtin( + AggType::Builtin( PbAggKind::PercentileCont | PbAggKind::PercentileDisc | PbAggKind::Mode, ) }; } pub use unimplemented_in_stream; - /// [`AggKind`](super::AggKind)s that should've been rewritten to other kinds. These kinds + /// [`AggType`](super::AggType)s that should've been rewritten to other kinds. These kinds /// should not appear when generating physical plan nodes. #[macro_export] macro_rules! rewritten { () => { - AggKind::Builtin( + AggType::Builtin( PbAggKind::Avg | PbAggKind::StddevPop | PbAggKind::StddevSamp @@ -323,12 +323,12 @@ pub mod agg_kinds { } pub use rewritten; - /// [`AggKind`](super::AggKind)s of which the aggregate results are not affected by the + /// [`AggType`](super::AggType)s of which the aggregate results are not affected by the /// user given ORDER BY clause. #[macro_export] macro_rules! result_unaffected_by_order_by { () => { - AggKind::Builtin(PbAggKind::BitAnd + AggType::Builtin(PbAggKind::BitAnd | PbAggKind::BitOr | PbAggKind::BitXor // XOR is commutative and associative | PbAggKind::BoolAnd @@ -348,13 +348,13 @@ pub mod agg_kinds { } pub use result_unaffected_by_order_by; - /// [`AggKind`](super::AggKind)s that must be called with ORDER BY clause. These are + /// [`AggType`](super::AggType)s that must be called with ORDER BY clause. These are /// slightly different from variants not in [`result_unaffected_by_order_by`], in that /// variants returned by this macro should be banned while the others should just be warned. #[macro_export] macro_rules! must_have_order_by { () => { - AggKind::Builtin( + AggType::Builtin( PbAggKind::FirstValue | PbAggKind::LastValue | PbAggKind::PercentileCont @@ -365,12 +365,12 @@ pub mod agg_kinds { } pub use must_have_order_by; - /// [`AggKind`](super::AggKind)s of which the aggregate results are not affected by the + /// [`AggType`](super::AggType)s of which the aggregate results are not affected by the /// user given DISTINCT keyword. #[macro_export] macro_rules! result_unaffected_by_distinct { () => { - AggKind::Builtin( + AggType::Builtin( PbAggKind::BitAnd | PbAggKind::BitOr | PbAggKind::BoolAnd @@ -383,11 +383,11 @@ pub mod agg_kinds { } pub use result_unaffected_by_distinct; - /// [`AggKind`](crate::aggregate::AggKind)s that are simply cannot 2-phased. + /// [`AggType`](crate::aggregate::AggType)s that are simply cannot 2-phased. #[macro_export] macro_rules! simply_cannot_two_phase { () => { - AggKind::Builtin( + AggType::Builtin( PbAggKind::StringAgg | PbAggKind::ApproxCountDistinct | PbAggKind::ArrayAgg @@ -405,18 +405,18 @@ pub mod agg_kinds { | PbAggKind::BitAnd | PbAggKind::BitOr ) - | AggKind::UserDefined(_) - | AggKind::WrapScalar(_) + | AggType::UserDefined(_) + | AggType::WrapScalar(_) }; } pub use simply_cannot_two_phase; - /// [`AggKind`](super::AggKind)s that are implemented with a single value state (so-called + /// [`AggType`](super::AggType)s that are implemented with a single value state (so-called /// stateless). #[macro_export] macro_rules! single_value_state { () => { - AggKind::Builtin( + AggType::Builtin( PbAggKind::Sum | PbAggKind::Sum0 | PbAggKind::Count @@ -428,26 +428,26 @@ pub mod agg_kinds { | PbAggKind::ApproxCountDistinct | PbAggKind::InternalLastSeenValue | PbAggKind::ApproxPercentile, - ) | AggKind::UserDefined(_) + ) | AggType::UserDefined(_) }; } pub use single_value_state; - /// [`AggKind`](super::AggKind)s that are implemented with a single value state (so-called + /// [`AggType`](super::AggType)s that are implemented with a single value state (so-called /// stateless) iff the input is append-only. #[macro_export] macro_rules! single_value_state_iff_in_append_only { () => { - AggKind::Builtin(PbAggKind::Max | PbAggKind::Min) + AggType::Builtin(PbAggKind::Max | PbAggKind::Min) }; } pub use single_value_state_iff_in_append_only; - /// [`AggKind`](super::AggKind)s that are implemented with a materialized input state. + /// [`AggType`](super::AggType)s that are implemented with a materialized input state. #[macro_export] macro_rules! materialized_input_state { () => { - AggKind::Builtin( + AggType::Builtin( PbAggKind::Min | PbAggKind::Max | PbAggKind::FirstValue @@ -456,7 +456,7 @@ pub mod agg_kinds { | PbAggKind::ArrayAgg | PbAggKind::JsonbAgg | PbAggKind::JsonbObjectAgg, - ) | AggKind::WrapScalar(_) + ) | AggType::WrapScalar(_) }; } pub use materialized_input_state; @@ -465,7 +465,7 @@ pub mod agg_kinds { #[macro_export] macro_rules! ordered_set { () => { - AggKind::Builtin( + AggType::Builtin( PbAggKind::PercentileCont | PbAggKind::PercentileDisc | PbAggKind::Mode @@ -476,24 +476,24 @@ pub mod agg_kinds { pub use ordered_set; } -impl AggKind { +impl AggType { /// Get the total phase agg kind from the partial phase agg kind. pub fn partial_to_total(&self) -> Option { match self { - AggKind::Builtin( + AggType::Builtin( PbAggKind::BitXor | PbAggKind::Min | PbAggKind::Max | PbAggKind::Sum | PbAggKind::InternalLastSeenValue, ) => Some(self.clone()), - AggKind::Builtin(PbAggKind::Sum0 | PbAggKind::Count) => { + AggType::Builtin(PbAggKind::Sum0 | PbAggKind::Count) => { Some(Self::Builtin(PbAggKind::Sum0)) } agg_kinds::simply_cannot_two_phase!() => None, agg_kinds::rewritten!() => None, // invalid variants - AggKind::Builtin( + AggType::Builtin( PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar, ) => None, } diff --git a/src/expr/core/src/aggregate/mod.rs b/src/expr/core/src/aggregate/mod.rs index e77c549a99db0..eb443de42978f 100644 --- a/src/expr/core/src/aggregate/mod.rs +++ b/src/expr/core/src/aggregate/mod.rs @@ -147,16 +147,16 @@ pub fn build_retractable(agg: &AggCall) -> Result { pub fn build(agg: &AggCall, prefer_append_only: bool) -> Result { // handle special kinds let kind = match &agg.kind { - AggKind::UserDefined(udf) => { + AggType::UserDefined(udf) => { return user_defined::new_user_defined(&agg.return_type, udf); } - AggKind::WrapScalar(scalar) => { + AggType::WrapScalar(scalar) => { return Ok(Box::new(scalar_wrapper::ScalarWrapper::new( agg.args.arg_types()[0].clone(), build_from_prost(scalar)?, ))); } - AggKind::Builtin(kind) => kind, + AggType::Builtin(kind) => kind, }; // find the signature for builtin aggregation diff --git a/src/expr/core/src/window_function/kind.rs b/src/expr/core/src/window_function/kind.rs index 9cd2c86416ede..32c5f746020d0 100644 --- a/src/expr/core/src/window_function/kind.rs +++ b/src/expr/core/src/window_function/kind.rs @@ -15,7 +15,7 @@ use parse_display::{Display, FromStr}; use risingwave_common::bail; -use crate::aggregate::AggKind; +use crate::aggregate::AggType; use crate::Result; /// Kind of window functions. @@ -31,7 +31,7 @@ pub enum WindowFuncKind { // Aggregate functions that are used with `OVER`. #[display("{0}")] - Aggregate(AggKind), + Aggregate(AggType), } impl WindowFuncKind { @@ -53,7 +53,7 @@ impl WindowFuncKind { }, PbType::Aggregate(agg_type) => match PbAggKind::try_from(*agg_type) { // TODO(runji): support UDAF and wrapped scalar functions - Ok(agg_type) => Self::Aggregate(AggKind::from_protobuf(agg_type, None, None)?), + Ok(agg_type) => Self::Aggregate(AggType::from_protobuf(agg_type, None, None)?), Err(_) => bail!("no such aggregate function type"), }, }; diff --git a/src/expr/impl/src/window_function/aggregate.rs b/src/expr/impl/src/window_function/aggregate.rs index da30d35a4e3e8..dbc09df0b3c69 100644 --- a/src/expr/impl/src/window_function/aggregate.rs +++ b/src/expr/impl/src/window_function/aggregate.rs @@ -21,7 +21,7 @@ use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::{bail, must_match}; use risingwave_common_estimate_size::{EstimateSize, KvSize}; use risingwave_expr::aggregate::{ - AggCall, AggKind, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction, + AggCall, AggType, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction, }; use risingwave_expr::sig::FUNCTION_REGISTRY; use risingwave_expr::window_function::{ @@ -64,7 +64,7 @@ pub(super) fn new(call: &WindowFuncCall) -> Result { direct_args: vec![], }; // TODO(runji): support UDAF and wrapped scalar function - let agg_kind = must_match!(agg_kind, AggKind::Builtin(agg_kind) => agg_kind); + let agg_kind = must_match!(agg_kind, AggType::Builtin(agg_kind) => agg_kind); let agg_func_sig = FUNCTION_REGISTRY .get(*agg_kind, &arg_data_types, &call.return_type) .expect("the agg func must exist"); diff --git a/src/frontend/src/binder/expr/function/aggregate.rs b/src/frontend/src/binder/expr/function/aggregate.rs index 1899846f7a450..586317af2de44 100644 --- a/src/frontend/src/binder/expr/function/aggregate.rs +++ b/src/frontend/src/binder/expr/function/aggregate.rs @@ -15,7 +15,7 @@ use itertools::Itertools; use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_common::{bail, bail_not_implemented}; -use risingwave_expr::aggregate::{agg_kinds, AggKind, PbAggKind}; +use risingwave_expr::aggregate::{agg_kinds, AggType, PbAggKind}; use risingwave_sqlparser::ast::{self, FunctionArgExpr}; use crate::binder::Clause; @@ -48,7 +48,7 @@ impl Binder { pub(super) fn bind_aggregate_function( &mut self, - kind: AggKind, + kind: AggType, distinct: bool, args: Vec, order_by: Vec, @@ -97,7 +97,7 @@ impl Binder { fn bind_ordered_set_agg( &mut self, - kind: &AggKind, + kind: &AggType, distinct: bool, args: Vec, order_by: Vec, @@ -138,10 +138,10 @@ impl Binder { // check signature and do implicit cast match (kind, direct_args.len(), args.as_mut_slice()) { - (AggKind::Builtin(PbAggKind::PercentileCont | PbAggKind::PercentileDisc), 1, [arg]) => { + (AggType::Builtin(PbAggKind::PercentileCont | PbAggKind::PercentileDisc), 1, [arg]) => { let fraction = &mut direct_args[0]; decimal_to_float64(fraction, kind)?; - if matches!(&kind, AggKind::Builtin(PbAggKind::PercentileCont)) { + if matches!(&kind, AggType::Builtin(PbAggKind::PercentileCont)) { arg.cast_implicit_mut(DataType::Float64).map_err(|_| { ErrorCode::InvalidInputSyntax(format!( "arg in `{}` must be castable to float64", @@ -150,8 +150,8 @@ impl Binder { })?; } } - (AggKind::Builtin(PbAggKind::Mode), 0, [_arg]) => {} - (AggKind::Builtin(PbAggKind::ApproxPercentile), 1..=2, [_percentile_col]) => { + (AggType::Builtin(PbAggKind::Mode), 0, [_arg]) => {} + (AggType::Builtin(PbAggKind::ApproxPercentile), 1..=2, [_percentile_col]) => { let percentile = &mut direct_args[0]; decimal_to_float64(percentile, kind)?; match direct_args.len() { @@ -207,7 +207,7 @@ impl Binder { fn bind_normal_agg( &mut self, - kind: &AggKind, + kind: &AggType, distinct: bool, args: Vec, order_by: Vec, @@ -242,8 +242,8 @@ impl Binder { if distinct { if matches!( kind, - AggKind::Builtin(PbAggKind::ApproxCountDistinct) - | AggKind::Builtin(PbAggKind::ApproxPercentile) + AggType::Builtin(PbAggKind::ApproxCountDistinct) + | AggType::Builtin(PbAggKind::ApproxPercentile) ) { return Err(ErrorCode::InvalidInputSyntax(format!( "DISTINCT is not allowed for approximate aggregation `{}`", @@ -283,7 +283,7 @@ impl Binder { } } -fn decimal_to_float64(decimal_expr: &mut ExprImpl, kind: &AggKind) -> Result<()> { +fn decimal_to_float64(decimal_expr: &mut ExprImpl, kind: &AggType) -> Result<()> { if decimal_expr.cast_implicit_mut(DataType::Float64).is_err() { return Err(ErrorCode::InvalidInputSyntax(format!( "direct arg in `{}` must be castable to float64", diff --git a/src/frontend/src/binder/expr/function/mod.rs b/src/frontend/src/binder/expr/function/mod.rs index 3505f14936c7d..fd54e65863a81 100644 --- a/src/frontend/src/binder/expr/function/mod.rs +++ b/src/frontend/src/binder/expr/function/mod.rs @@ -20,7 +20,7 @@ use itertools::Itertools; use risingwave_common::bail_not_implemented; use risingwave_common::catalog::{INFORMATION_SCHEMA_SCHEMA_NAME, PG_CATALOG_SCHEMA_NAME}; use risingwave_common::types::DataType; -use risingwave_expr::aggregate::AggKind; +use risingwave_expr::aggregate::AggType; use risingwave_expr::window_function::WindowFuncKind; use risingwave_sqlparser::ast::{self, Function, FunctionArg, FunctionArgExpr, Ident}; use risingwave_sqlparser::parser::ParserError; @@ -182,7 +182,7 @@ impl Binder { }; // now this is either an aggregate/window function call - Some(AggKind::WrapScalar(scalar_func_expr.to_expr_proto())) + Some(AggType::WrapScalar(scalar_func_expr.to_expr_proto())) } else { None }; @@ -233,8 +233,8 @@ impl Binder { && udf.kind.is_aggregate() { assert_ne!(udf.language, "sql", "SQL UDAF is not supported yet"); - Some(AggKind::UserDefined(udf.as_ref().into())) - } else if let Ok(kind) = AggKind::from_str(&func_name) { + Some(AggType::UserDefined(udf.as_ref().into())) + } else if let Ok(kind) = AggType::from_str(&func_name) { Some(kind) } else { None diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs index 452d37652d341..4996d4bab18b5 100644 --- a/src/frontend/src/expr/agg_call.rs +++ b/src/frontend/src/expr/agg_call.rs @@ -13,7 +13,7 @@ // limitations under the License. use risingwave_common::types::DataType; -use risingwave_expr::aggregate::AggKind; +use risingwave_expr::aggregate::AggType; use super::{infer_type, Expr, ExprImpl, Literal, OrderBy}; use crate::error::Result; @@ -21,7 +21,7 @@ use crate::utils::Condition; #[derive(Clone, Eq, PartialEq, Hash)] pub struct AggCall { - pub agg_kind: AggKind, + pub agg_kind: AggType, pub return_type: DataType, pub args: Vec, pub distinct: bool, @@ -56,7 +56,7 @@ impl AggCall { /// Returns error if the function name matches with an existing function /// but with illegal arguments. pub fn new( - agg_kind: AggKind, + agg_kind: AggType, mut args: Vec, distinct: bool, order_by: OrderBy, @@ -64,9 +64,9 @@ impl AggCall { direct_args: Vec, ) -> Result { let return_type = match &agg_kind { - AggKind::Builtin(kind) => infer_type((*kind).into(), &mut args)?, - AggKind::UserDefined(udf) => udf.return_type.as_ref().unwrap().into(), - AggKind::WrapScalar(expr) => expr.return_type.as_ref().unwrap().into(), + AggType::Builtin(kind) => infer_type((*kind).into(), &mut args)?, + AggType::UserDefined(udf) => udf.return_type.as_ref().unwrap().into(), + AggType::WrapScalar(expr) => expr.return_type.as_ref().unwrap().into(), }; Ok(AggCall { agg_kind, @@ -81,7 +81,7 @@ impl AggCall { /// Constructs an `AggCall` without type inference. pub fn new_unchecked( - agg_kind: AggKind, + agg_kind: AggType, args: Vec, return_type: DataType, ) -> Result { @@ -96,7 +96,7 @@ impl AggCall { }) } - pub fn agg_kind(&self) -> AggKind { + pub fn agg_kind(&self) -> AggType { self.agg_kind.clone() } diff --git a/src/frontend/src/expr/window_function.rs b/src/frontend/src/expr/window_function.rs index 8f2e6c66728dd..a4fd2b7b92a4a 100644 --- a/src/frontend/src/expr/window_function.rs +++ b/src/frontend/src/expr/window_function.rs @@ -15,7 +15,7 @@ use itertools::Itertools; use risingwave_common::bail_not_implemented; use risingwave_common::types::DataType; -use risingwave_expr::aggregate::AggKind; +use risingwave_expr::aggregate::AggType; use risingwave_expr::window_function::{Frame, WindowFuncKind}; use super::{Expr, ExprImpl, OrderBy, RwResult}; @@ -88,9 +88,9 @@ impl WindowFunction { } (Aggregate(agg_kind), args) => Ok(match agg_kind { - AggKind::Builtin(kind) => infer_type((*kind).into(), args)?, - AggKind::UserDefined(udf) => udf.return_type.as_ref().unwrap().into(), - AggKind::WrapScalar(expr) => expr.return_type.as_ref().unwrap().into(), + AggType::Builtin(kind) => infer_type((*kind).into(), args)?, + AggType::UserDefined(udf) => udf.return_type.as_ref().unwrap().into(), + AggType::WrapScalar(expr) => expr.return_type.as_ref().unwrap().into(), }), (_, args) => { diff --git a/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs b/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs index 894ad92011008..e01762fc56373 100644 --- a/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs +++ b/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_expr::aggregate::{AggKind, PbAggKind}; +use risingwave_expr::aggregate::{AggType, PbAggKind}; use risingwave_pb::batch_plan::plan_node::NodeBody; use risingwave_pb::batch_plan::SortAggNode; @@ -59,7 +59,7 @@ impl BatchSimpleAgg { .agg_calls .iter() .map(|agg_call| &agg_call.agg_kind) - .all(|agg_kind| !matches!(agg_kind, AggKind::Builtin(PbAggKind::ApproxPercentile))) + .all(|agg_kind| !matches!(agg_kind, AggType::Builtin(PbAggKind::ApproxPercentile))) && self.two_phase_agg_enabled() } } diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index dc10ffd8b0526..71c1567a7eafa 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -23,7 +23,7 @@ use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType}; use risingwave_common::util::value_encoding::DatumToProtoExt; -use risingwave_expr::aggregate::{agg_kinds, AggKind, PbAggKind}; +use risingwave_expr::aggregate::{agg_kinds, AggType, PbAggKind}; use risingwave_expr::sig::{FuncBuilder, FUNCTION_REGISTRY}; use risingwave_pb::expr::{PbAggCall, PbConstant}; use risingwave_pb::stream_plan::{agg_call_state, AggCallState as PbAggCallState}; @@ -108,7 +108,7 @@ impl Agg { let order_ok = matches!( call.agg_kind, agg_kinds::result_unaffected_by_order_by!() - | AggKind::Builtin(PbAggKind::ApproxPercentile) + | AggType::Builtin(PbAggKind::ApproxPercentile) ) || call.order_by.is_empty(); let distinct_ok = matches!(call.agg_kind, agg_kinds::result_unaffected_by_distinct!()) @@ -422,20 +422,20 @@ impl Agg { // columns with order requirement in state table let sort_keys = { match agg_call.agg_kind { - AggKind::Builtin(PbAggKind::Min) => { + AggType::Builtin(PbAggKind::Min) => { vec![(OrderType::ascending(), agg_call.inputs[0].index)] } - AggKind::Builtin(PbAggKind::Max) => { + AggType::Builtin(PbAggKind::Max) => { vec![(OrderType::descending(), agg_call.inputs[0].index)] } - AggKind::Builtin( + AggType::Builtin( PbAggKind::FirstValue | PbAggKind::LastValue | PbAggKind::StringAgg | PbAggKind::ArrayAgg | PbAggKind::JsonbAgg, ) - | AggKind::WrapScalar(_) => { + | AggType::WrapScalar(_) => { if agg_call.order_by.is_empty() { me.ctx().warn_to_user(format!( "{} without ORDER BY may produce non-deterministic result", @@ -449,7 +449,7 @@ impl Agg { ( if matches!( agg_call.agg_kind, - AggKind::Builtin(PbAggKind::LastValue) + AggType::Builtin(PbAggKind::LastValue) ) { o.order_type.reverse() } else { @@ -460,7 +460,7 @@ impl Agg { }) .collect() } - AggKind::Builtin(PbAggKind::JsonbObjectAgg) => agg_call + AggType::Builtin(PbAggKind::JsonbObjectAgg) => agg_call .order_by .iter() .map(|o| (o.order_type, o.column_index)) @@ -482,7 +482,7 @@ impl Agg { // other columns that should be contained in state table let include_keys = match agg_call.agg_kind { // `agg_kinds::materialized_input_state` except for `min`/`max` - AggKind::Builtin( + AggType::Builtin( PbAggKind::FirstValue | PbAggKind::LastValue | PbAggKind::StringAgg @@ -490,7 +490,7 @@ impl Agg { | PbAggKind::JsonbAgg | PbAggKind::JsonbObjectAgg, ) - | AggKind::WrapScalar(_) => { + | AggType::WrapScalar(_) => { agg_call.inputs.iter().map(|i| i.index).collect() } _ => vec![], @@ -505,7 +505,7 @@ impl Agg { agg_kinds::unimplemented_in_stream!() => { unreachable!("should have been banned") } - AggKind::Builtin( + AggType::Builtin( PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar, ) => { unreachable!("invalid agg kind") @@ -532,16 +532,16 @@ impl Agg { .zip_eq_fast(&mut out_fields[self.group_key.len()..]) { let agg_kind = match agg_call.agg_kind { - AggKind::UserDefined(_) => { + AggType::UserDefined(_) => { // for user defined aggregate, the state type is always BYTEA field.data_type = DataType::Bytea; continue; } - AggKind::WrapScalar(_) => { + AggType::WrapScalar(_) => { // for wrapped scalar function, the state is always NULL continue; } - AggKind::Builtin(kind) => kind, + AggType::Builtin(kind) => kind, }; let sig = FUNCTION_REGISTRY .get( @@ -705,7 +705,7 @@ impl_distill_unit_from_fields!(Agg, stream::StreamPlanRef); #[derive(Clone, PartialEq, Eq, Hash)] pub struct PlanAggCall { /// Kind of aggregation function - pub agg_kind: AggKind, + pub agg_kind: AggType, /// Data type of the returned column pub return_type: DataType, @@ -781,9 +781,9 @@ impl PlanAggCall { pub fn to_protobuf(&self) -> PbAggCall { PbAggCall { kind: match &self.agg_kind { - AggKind::Builtin(kind) => *kind, - AggKind::UserDefined(_) => PbAggKind::UserDefined, - AggKind::WrapScalar(_) => PbAggKind::WrapScalar, + AggType::Builtin(kind) => *kind, + AggType::UserDefined(_) => PbAggKind::UserDefined, + AggType::WrapScalar(_) => PbAggKind::WrapScalar, } .into(), return_type: Some(self.return_type.to_protobuf()), @@ -800,11 +800,11 @@ impl PlanAggCall { }) .collect(), udf: match &self.agg_kind { - AggKind::UserDefined(udf) => Some(udf.clone()), + AggType::UserDefined(udf) => Some(udf.clone()), _ => None, }, scalar: match &self.agg_kind { - AggKind::WrapScalar(expr) => Some(expr.clone()), + AggType::WrapScalar(expr) => Some(expr.clone()), _ => None, }, } diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index b0ad102ee693c..76af6595673b8 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -17,7 +17,7 @@ use itertools::Itertools; use risingwave_common::types::{DataType, Datum, ScalarImpl}; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common::{bail, bail_not_implemented, not_implemented}; -use risingwave_expr::aggregate::{agg_kinds, AggKind, PbAggKind}; +use risingwave_expr::aggregate::{agg_kinds, AggType, PbAggKind}; use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder}; use super::utils::impl_distill_by_unit; @@ -389,7 +389,7 @@ impl LogicalAgg { let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len); let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len); for (output_idx, agg_call) in self.agg_calls().iter().enumerate() { - if agg_call.agg_kind == AggKind::Builtin(PbAggKind::ApproxPercentile) { + if agg_call.agg_kind == AggType::Builtin(PbAggKind::ApproxPercentile) { approx_percentile_agg_calls.push(agg_call.clone()); approx_percentile_col_mapping.push(Some(output_idx)); } else { @@ -612,7 +612,7 @@ impl LogicalAggBuilder { ) -> Result { match agg_call.agg_kind { // Rewrite avg to cast(sum as avg_return_type) / count. - AggKind::Builtin(PbAggKind::Avg) => { + AggType::Builtin(PbAggKind::Avg) => { assert_eq!(agg_call.args.len(), 1); let sum = ExprImpl::from(push_agg_call(AggCall::new( @@ -644,7 +644,7 @@ impl LogicalAggBuilder { // which is in a sense more general than the pow function, especially when calculating // covariances in the future. Also we don't have the sqrt function for rooting, so we // use pow(x, 0.5) to simulate - AggKind::Builtin( + AggType::Builtin( kind @ (PbAggKind::StddevPop | PbAggKind::StddevSamp | PbAggKind::VarPop @@ -740,7 +740,7 @@ impl LogicalAggBuilder { _ => unreachable!(), } } - AggKind::Builtin(PbAggKind::ApproxPercentile) => { + AggType::Builtin(PbAggKind::ApproxPercentile) => { if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() { // Rewrite DESC into 1.0-percentile for approx_percentile. let prev_percentile = agg_call.direct_args[0].clone(); @@ -871,7 +871,7 @@ impl LogicalAggBuilder { agg_call.distinct = false; } - if matches!(agg_call.agg_kind, AggKind::Builtin(PbAggKind::Grouping)) { + if matches!(agg_call.agg_kind, AggType::Builtin(PbAggKind::Grouping)) { if self.grouping_sets.is_empty() { return Err(ErrorCode::NotSupported( "GROUPING must be used in a query with grouping sets".into(), diff --git a/src/frontend/src/optimizer/plan_node/logical_over_window.rs b/src/frontend/src/optimizer/plan_node/logical_over_window.rs index bb78380482752..f196aeca5f4a2 100644 --- a/src/frontend/src/optimizer/plan_node/logical_over_window.rs +++ b/src/frontend/src/optimizer/plan_node/logical_over_window.rs @@ -17,7 +17,7 @@ use itertools::Itertools; use risingwave_common::types::DataType; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common::{bail_not_implemented, not_implemented}; -use risingwave_expr::aggregate::{AggKind, PbAggKind}; +use risingwave_expr::aggregate::{AggType, PbAggKind}; use risingwave_expr::window_function::{Frame, FrameBound, WindowFuncKind}; use super::generic::{GenericPlanRef, OverWindow, PlanWindowFunction, ProjectBuilder}; @@ -108,7 +108,7 @@ impl<'a> LogicalOverWindowBuilder<'a> { let new_expr = if let WindowFuncKind::Aggregate(agg_kind) = &kind && matches!( agg_kind, - AggKind::Builtin( + AggType::Builtin( PbAggKind::Avg | PbAggKind::StddevPop | PbAggKind::StddevSamp @@ -191,7 +191,7 @@ impl<'a> OverWindowProjectBuilder<'a> { if let WindowFuncKind::Aggregate(agg_kind) = &window_function.kind && matches!( agg_kind, - AggKind::Builtin( + AggType::Builtin( PbAggKind::StddevPop | PbAggKind::StddevSamp | PbAggKind::VarPop @@ -379,7 +379,7 @@ impl LogicalOverWindow { }; ( - WindowFuncKind::Aggregate(AggKind::Builtin(PbAggKind::FirstValue)), + WindowFuncKind::Aggregate(AggType::Builtin(PbAggKind::FirstValue)), frame, ) } diff --git a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs index 33bb59e59bf13..0b70c57036dc9 100644 --- a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs +++ b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs @@ -13,7 +13,7 @@ // limitations under the License. use risingwave_common::types::DataType; -use risingwave_expr::aggregate::{AggKind, PbAggKind}; +use risingwave_expr::aggregate::{AggType, PbAggKind}; use risingwave_pb::plan_common::JoinType; use super::{ApplyOffsetRewriter, BoxedRule, Rule}; @@ -141,20 +141,20 @@ impl Rule for ApplyAggTransposeRule { let pos_of_constant_column = node.schema().len() - 1; agg_calls.iter_mut().for_each(|agg_call| { match agg_call.agg_kind { - AggKind::Builtin(PbAggKind::Count) if agg_call.inputs.is_empty() => { + AggType::Builtin(PbAggKind::Count) if agg_call.inputs.is_empty() => { let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32); agg_call.inputs.push(input_ref); } - AggKind::Builtin(PbAggKind::ArrayAgg + AggType::Builtin(PbAggKind::ArrayAgg | PbAggKind::JsonbAgg | PbAggKind::JsonbObjectAgg) - | AggKind::UserDefined(_) - | AggKind::WrapScalar(_) => { + | AggType::UserDefined(_) + | AggType::WrapScalar(_) => { let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32); let cond = FunctionCall::new(ExprType::IsNotNull, vec![input_ref.into()]).unwrap(); agg_call.filter.conjunctions.push(cond.into()); } - AggKind::Builtin(PbAggKind::Count + AggType::Builtin(PbAggKind::Count | PbAggKind::Sum | PbAggKind::Sum0 | PbAggKind::Avg @@ -186,7 +186,7 @@ impl Rule for ApplyAggTransposeRule { => { // no-op when `agg(0 rows) == agg(1 row of nulls)` } - AggKind::Builtin(PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar) => { + AggType::Builtin(PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar) => { panic!("Unexpected aggregate function: {:?}", agg_call.agg_kind) } } diff --git a/src/frontend/src/optimizer/rule/distinct_agg_rule.rs b/src/frontend/src/optimizer/rule/distinct_agg_rule.rs index 52aad1336a7bb..fe12de26473d2 100644 --- a/src/frontend/src/optimizer/rule/distinct_agg_rule.rs +++ b/src/frontend/src/optimizer/rule/distinct_agg_rule.rs @@ -18,7 +18,7 @@ use std::mem; use fixedbitset::FixedBitSet; use itertools::Itertools; use risingwave_common::types::DataType; -use risingwave_expr::aggregate::{agg_kinds, AggKind, PbAggKind}; +use risingwave_expr::aggregate::{agg_kinds, AggType, PbAggKind}; use super::{BoxedRule, Rule}; use crate::expr::{CollectInputRef, ExprType, FunctionCall, InputRef, Literal}; @@ -60,7 +60,7 @@ impl Rule for DistinctAggRule { let order_ok = matches!( c.agg_kind, agg_kinds::result_unaffected_by_order_by!() - | AggKind::Builtin(PbAggKind::ApproxPercentile) + | AggType::Builtin(PbAggKind::ApproxPercentile) ) || c.order_by.is_empty(); agg_kind_ok && order_ok }) { diff --git a/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs b/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs index c24ab4e806c32..7856442ad01f8 100644 --- a/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs +++ b/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs @@ -16,7 +16,7 @@ use fixedbitset::FixedBitSet; use itertools::Itertools; use risingwave_common::types::DataType; use risingwave_common::util::column_index_mapping::ColIndexMapping; -use risingwave_expr::aggregate::{AggKind, PbAggKind}; +use risingwave_expr::aggregate::{AggType, PbAggKind}; use super::super::plan_node::*; use super::{BoxedRule, Rule}; @@ -105,7 +105,7 @@ impl Rule for GroupingSetsToExpandRule { let mut new_agg_calls = vec![]; for agg_call in old_agg_calls { // Deal with grouping agg call for grouping sets. - if matches!(agg_call.agg_kind, AggKind::Builtin(PbAggKind::Grouping)) { + if matches!(agg_call.agg_kind, AggType::Builtin(PbAggKind::Grouping)) { let mut grouping_values = vec![]; let args = agg_call .inputs diff --git a/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs b/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs index 8782ca8f481ff..440eb253a9bce 100644 --- a/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs +++ b/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs @@ -23,7 +23,7 @@ use std::vec; use itertools::Itertools; use risingwave_common::types::DataType; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; -use risingwave_expr::aggregate::{AggKind, PbAggKind}; +use risingwave_expr::aggregate::{AggType, PbAggKind}; use super::{BoxedRule, Rule}; use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef}; @@ -51,7 +51,7 @@ impl Rule for MinMaxOnIndexRule { if matches!( first_call.agg_kind, - AggKind::Builtin(PbAggKind::Min | PbAggKind::Max) + AggType::Builtin(PbAggKind::Min | PbAggKind::Max) ) && !first_call.distinct && first_call.filter.always_true() && first_call.order_by.is_empty() @@ -64,7 +64,7 @@ impl Rule for MinMaxOnIndexRule { let order = Order { column_orders: vec![ColumnOrder::new( calls.first()?.inputs.first()?.index(), - if matches!(kind, AggKind::Builtin(PbAggKind::Min)) { + if matches!(kind, AggType::Builtin(PbAggKind::Min)) { OrderType::ascending() } else { OrderType::descending() diff --git a/src/frontend/src/optimizer/rule/pull_up_correlated_predicate_agg_rule.rs b/src/frontend/src/optimizer/rule/pull_up_correlated_predicate_agg_rule.rs index e136c96c6e086..41aa080cd5aa2 100644 --- a/src/frontend/src/optimizer/rule/pull_up_correlated_predicate_agg_rule.rs +++ b/src/frontend/src/optimizer/rule/pull_up_correlated_predicate_agg_rule.rs @@ -16,7 +16,7 @@ use fixedbitset::FixedBitSet; use itertools::{Either, Itertools}; use risingwave_common::types::DataType; use risingwave_common::util::column_index_mapping::ColIndexMapping; -use risingwave_expr::aggregate::{AggKind, PbAggKind}; +use risingwave_expr::aggregate::{AggType, PbAggKind}; use super::super::plan_node::*; use super::{BoxedRule, Rule}; @@ -162,14 +162,14 @@ impl Rule for PullUpCorrelatedPredicateAggRule { // sum is null, so avg is null. And null-rejected expression will be false, so we can still apply this rule and we don't need to generate a 0 value for count. let count_exists = agg_calls .iter() - .any(|agg_call| matches!(agg_call.agg_kind, AggKind::Builtin(PbAggKind::Count))); + .any(|agg_call| matches!(agg_call.agg_kind, AggType::Builtin(PbAggKind::Count))); if count_exists { // When group input is empty, not count agg would return null. let null_agg_pos = agg_calls .iter() .positions(|agg_call| { - !matches!(agg_call.agg_kind, AggKind::Builtin(PbAggKind::Count)) + !matches!(agg_call.agg_kind, AggType::Builtin(PbAggKind::Count)) }) .collect_vec(); diff --git a/src/stream/src/executor/aggregation/minput.rs b/src/stream/src/executor/aggregation/minput.rs index 393be1878412b..89fd8881a691e 100644 --- a/src/stream/src/executor/aggregation/minput.rs +++ b/src/stream/src/executor/aggregation/minput.rs @@ -24,7 +24,7 @@ use risingwave_common::types::Datum; use risingwave_common::util::row_serde::OrderedRowSerde; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common_estimate_size::EstimateSize; -use risingwave_expr::aggregate::{AggCall, AggKind, BoxedAggregateFunction, PbAggKind}; +use risingwave_expr::aggregate::{AggCall, AggType, BoxedAggregateFunction, PbAggKind}; use risingwave_pb::stream_plan::PbAggNodeVersion; use risingwave_storage::store::PrefetchOptions; use risingwave_storage::StateStore; @@ -125,19 +125,19 @@ impl MaterializedInputState { let cache_key_serializer = OrderedRowSerde::new(cache_key_data_types, order_types); let cache: Box = match agg_call.kind { - AggKind::Builtin( + AggType::Builtin( PbAggKind::Min | PbAggKind::Max | PbAggKind::FirstValue | PbAggKind::LastValue, ) => Box::new(GenericAggStateCache::new( TopNStateCache::new(extreme_cache_size), agg_call.args.arg_types(), )), - AggKind::Builtin( + AggType::Builtin( PbAggKind::StringAgg | PbAggKind::ArrayAgg | PbAggKind::JsonbAgg | PbAggKind::JsonbObjectAgg, ) - | AggKind::WrapScalar(_) => Box::new(GenericAggStateCache::new( + | AggType::WrapScalar(_) => Box::new(GenericAggStateCache::new( OrderedStateCache::new(), agg_call.args.arg_types(), )), @@ -148,7 +148,7 @@ impl MaterializedInputState { }; let output_first_value = matches!( agg_call.kind, - AggKind::Builtin( + AggType::Builtin( PbAggKind::Min | PbAggKind::Max | PbAggKind::FirstValue | PbAggKind::LastValue ) ); @@ -246,11 +246,11 @@ fn generate_order_columns_before_version_issue_13465( ) -> (Vec, Vec) { let (mut order_col_indices, mut order_types) = if matches!( agg_call.kind, - AggKind::Builtin(PbAggKind::Min | PbAggKind::Max) + AggType::Builtin(PbAggKind::Min | PbAggKind::Max) ) { // `min`/`max` need not to order by any other columns, but have to // order by the agg value implicitly. - let order_type = if matches!(agg_call.kind, AggKind::Builtin(PbAggKind::Min)) { + let order_type = if matches!(agg_call.kind, AggType::Builtin(PbAggKind::Min)) { OrderType::ascending() } else { OrderType::descending() @@ -263,7 +263,7 @@ fn generate_order_columns_before_version_issue_13465( .map(|p| { ( p.column_index, - if matches!(agg_call.kind, AggKind::Builtin(PbAggKind::LastValue)) { + if matches!(agg_call.kind, AggType::Builtin(PbAggKind::LastValue)) { p.order_type.reverse() } else { p.order_type diff --git a/src/stream/src/executor/aggregation/mod.rs b/src/stream/src/executor/aggregation/mod.rs index a5bd631e0caec..1695a0deab086 100644 --- a/src/stream/src/executor/aggregation/mod.rs +++ b/src/stream/src/executor/aggregation/mod.rs @@ -19,7 +19,7 @@ use risingwave_common::array::ArrayImpl::Bool; use risingwave_common::array::DataChunk; use risingwave_common::bail; use risingwave_common::bitmap::Bitmap; -use risingwave_expr::aggregate::{AggCall, AggKind, PbAggKind}; +use risingwave_expr::aggregate::{AggCall, AggType, PbAggKind}; use risingwave_expr::expr::{LogReport, NonStrictExpression}; use risingwave_storage::StateStore; @@ -39,7 +39,7 @@ pub async fn agg_call_filter_res( let mut vis = chunk.visibility().clone(); if matches!( agg_call.kind, - AggKind::Builtin(PbAggKind::Min | PbAggKind::Max | PbAggKind::StringAgg) + AggType::Builtin(PbAggKind::Min | PbAggKind::Max | PbAggKind::StringAgg) ) { // should skip NULL value for these kinds of agg function let agg_col_idx = agg_call.args.val_indices()[0]; // the first arg is the agg column for all these kinds diff --git a/src/stream/src/executor/test_utils.rs b/src/stream/src/executor/test_utils.rs index f4e0f40761aab..793dfb0e7266f 100644 --- a/src/stream/src/executor/test_utils.rs +++ b/src/stream/src/executor/test_utils.rs @@ -277,7 +277,7 @@ pub mod agg_executor { use risingwave_common::hash::SerializedKey; use risingwave_common::types::DataType; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; - use risingwave_expr::aggregate::{AggCall, AggKind, PbAggKind}; + use risingwave_expr::aggregate::{AggCall, AggType, PbAggKind}; use risingwave_pb::stream_plan::PbAggNodeVersion; use risingwave_storage::StateStore; @@ -329,7 +329,7 @@ pub mod agg_executor { is_append_only: bool, ) -> AggStateStorage { match agg_call.kind { - AggKind::Builtin(PbAggKind::Min | PbAggKind::Max) if !is_append_only => { + AggType::Builtin(PbAggKind::Min | PbAggKind::Max) if !is_append_only => { let mut column_descs = Vec::new(); let mut order_types = Vec::new(); let mut upstream_columns = Vec::new(); @@ -353,7 +353,7 @@ pub mod agg_executor { add_column(*idx, input_fields[*idx].data_type(), None); } - add_column(agg_call.args.val_indices()[0], agg_call.args.arg_types()[0].clone(), if matches!(agg_call.kind, AggKind::Builtin(PbAggKind::Max)) { + add_column(agg_call.args.val_indices()[0], agg_call.args.arg_types()[0].clone(), if matches!(agg_call.kind, AggType::Builtin(PbAggKind::Max)) { Some(OrderType::descending()) } else { Some(OrderType::ascending()) @@ -377,7 +377,7 @@ pub mod agg_executor { AggStateStorage::MaterializedInput { table: state_table, mapping: StateTableColumnMapping::new(upstream_columns, None), order_columns } } - AggKind::Builtin( + AggType::Builtin( PbAggKind::Min /* append only */ | PbAggKind::Max /* append only */ | PbAggKind::Sum From f9485f5ea1f7ea694bb61b3126088f80c8aca409 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Mon, 23 Sep 2024 17:05:48 +0800 Subject: [PATCH 4/6] rename `agg_kind` to `agg_type` Signed-off-by: Richard Chien --- src/batch/benches/hash_agg.rs | 12 ++-- src/expr/core/src/aggregate/def.rs | 10 +-- .../impl/src/window_function/aggregate.rs | 6 +- .../src/binder/expr/function/aggregate.rs | 16 ++--- src/frontend/src/binder/expr/function/mod.rs | 20 +++--- src/frontend/src/binder/mod.rs | 2 +- src/frontend/src/expr/agg_call.rs | 20 +++--- src/frontend/src/expr/expr_rewriter.rs | 4 +- src/frontend/src/expr/window_function.rs | 2 +- src/frontend/src/optimizer/mod.rs | 2 +- .../optimizer/plan_node/batch_simple_agg.rs | 4 +- .../src/optimizer/plan_node/generic/agg.rs | 62 +++++++++---------- .../plan_node/generic/over_window.rs | 2 +- .../src/optimizer/plan_node/logical_agg.rs | 52 ++++++++-------- .../plan_node/logical_over_window.rs | 12 ++-- src/frontend/src/optimizer/plan_node/utils.rs | 2 +- .../rule/agg_group_by_simplify_rule.rs | 2 +- .../rule/apply_agg_transpose_rule.rs | 4 +- .../src/optimizer/rule/distinct_agg_rule.rs | 18 +++--- .../rule/grouping_sets_to_expand_rule.rs | 2 +- .../optimizer/rule/min_max_on_index_rule.rs | 8 +-- .../pull_up_correlated_predicate_agg_rule.rs | 4 +- 22 files changed, 133 insertions(+), 133 deletions(-) diff --git a/src/batch/benches/hash_agg.rs b/src/batch/benches/hash_agg.rs index 6d3b24febcc25..1d77a3430c2a6 100644 --- a/src/batch/benches/hash_agg.rs +++ b/src/batch/benches/hash_agg.rs @@ -34,12 +34,12 @@ enable_jemalloc!(); fn create_agg_call( input_schema: &Schema, - agg_kind: AggType, + agg_type: AggType, args: Vec, return_type: DataType, ) -> PbAggCall { PbAggCall { - kind: agg_kind.to_protobuf() as i32, + kind: agg_type.to_protobuf() as i32, args: args .into_iter() .map(|col_idx| PbInputRef { @@ -59,7 +59,7 @@ fn create_agg_call( fn create_hash_agg_executor( group_key_columns: Vec, - agg_kind: AggType, + agg_type: AggType, arg_columns: Vec, return_type: DataType, chunk_size: usize, @@ -75,7 +75,7 @@ fn create_hash_agg_executor( let agg_calls = vec![create_agg_call( input_schema, - agg_kind, + agg_type, arg_columns, return_type, )]; @@ -131,7 +131,7 @@ fn bench_hash_agg(c: &mut Criterion) { (vec![0, 2], PbAggKind::Min, vec![1], DataType::Int64), ]; - for (group_key_columns, agg_kind, arg_columns, return_type) in bench_variants { + for (group_key_columns, agg_type, arg_columns, return_type) in bench_variants { for chunk_size in &[32, 128, 512, 1024, 2048, 4096] { c.bench_with_input( BenchmarkId::new("HashAggExecutor", chunk_size), @@ -142,7 +142,7 @@ fn bench_hash_agg(c: &mut Criterion) { || { create_hash_agg_executor( group_key_columns.clone(), - agg_kind.into(), + agg_type.into(), arg_columns.clone(), return_type.clone(), chunk_size, diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index 3e1850641d14b..9127170bed12a 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -65,7 +65,7 @@ pub struct AggCall { impl AggCall { pub fn from_protobuf(agg_call: &PbAggCall) -> Result { - let agg_kind = AggType::from_protobuf( + let agg_type = AggType::from_protobuf( agg_call.get_kind()?, agg_call.udf.as_ref(), agg_call.scalar.as_ref(), @@ -96,7 +96,7 @@ impl AggCall { }) .collect_vec(); Ok(AggCall { - kind: agg_kind, + kind: agg_type, args, return_type: DataType::from(agg_call.get_return_type()?), column_orders, @@ -291,7 +291,7 @@ impl AggType { /// Macros to generate match arms for `AggType`. /// IMPORTANT: These macros must be carefully maintained especially when adding new /// `AggType`/`PbAggKind` variants. -pub mod agg_kinds { +pub mod agg_types { /// [`AggType`](super::AggType)s that are currently not supported in streaming mode. #[macro_export] macro_rules! unimplemented_in_stream { @@ -490,8 +490,8 @@ impl AggType { AggType::Builtin(PbAggKind::Sum0 | PbAggKind::Count) => { Some(Self::Builtin(PbAggKind::Sum0)) } - agg_kinds::simply_cannot_two_phase!() => None, - agg_kinds::rewritten!() => None, + agg_types::simply_cannot_two_phase!() => None, + agg_types::rewritten!() => None, // invalid variants AggType::Builtin( PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar, diff --git a/src/expr/impl/src/window_function/aggregate.rs b/src/expr/impl/src/window_function/aggregate.rs index dbc09df0b3c69..b922a51103992 100644 --- a/src/expr/impl/src/window_function/aggregate.rs +++ b/src/expr/impl/src/window_function/aggregate.rs @@ -50,10 +50,10 @@ pub(super) fn new(call: &WindowFuncCall) -> Result { if call.frame.bounds.validate().is_err() { bail!("the window frame must be valid"); } - let agg_kind = must_match!(&call.kind, WindowFuncKind::Aggregate(agg_kind) => agg_kind); + let agg_type = must_match!(&call.kind, WindowFuncKind::Aggregate(agg_type) => agg_type); let arg_data_types = call.args.arg_types().to_vec(); let agg_call = AggCall { - kind: agg_kind.clone(), + kind: agg_type.clone(), args: call.args.clone(), return_type: call.return_type.clone(), column_orders: Vec::new(), // the input is already sorted @@ -64,7 +64,7 @@ pub(super) fn new(call: &WindowFuncCall) -> Result { direct_args: vec![], }; // TODO(runji): support UDAF and wrapped scalar function - let agg_kind = must_match!(agg_kind, AggType::Builtin(agg_kind) => agg_kind); + let agg_kind = must_match!(agg_type, AggType::Builtin(agg_kind) => agg_kind); let agg_func_sig = FUNCTION_REGISTRY .get(*agg_kind, &arg_data_types, &call.return_type) .expect("the agg func must exist"); diff --git a/src/frontend/src/binder/expr/function/aggregate.rs b/src/frontend/src/binder/expr/function/aggregate.rs index 586317af2de44..3dec3add1b8a9 100644 --- a/src/frontend/src/binder/expr/function/aggregate.rs +++ b/src/frontend/src/binder/expr/function/aggregate.rs @@ -15,7 +15,7 @@ use itertools::Itertools; use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_common::{bail, bail_not_implemented}; -use risingwave_expr::aggregate::{agg_kinds, AggType, PbAggKind}; +use risingwave_expr::aggregate::{agg_types, AggType, PbAggKind}; use risingwave_sqlparser::ast::{self, FunctionArgExpr}; use crate::binder::Clause; @@ -48,7 +48,7 @@ impl Binder { pub(super) fn bind_aggregate_function( &mut self, - kind: AggType, + agg_type: AggType, distinct: bool, args: Vec, order_by: Vec, @@ -57,10 +57,10 @@ impl Binder { ) -> Result { self.ensure_aggregate_allowed()?; - let (direct_args, args, order_by) = if matches!(kind, agg_kinds::ordered_set!()) { - self.bind_ordered_set_agg(&kind, distinct, args, order_by, within_group)? + let (direct_args, args, order_by) = if matches!(agg_type, agg_types::ordered_set!()) { + self.bind_ordered_set_agg(&agg_type, distinct, args, order_by, within_group)? } else { - self.bind_normal_agg(&kind, distinct, args, order_by, within_group)? + self.bind_normal_agg(&agg_type, distinct, args, order_by, within_group)? }; let filter = match filter { @@ -86,7 +86,7 @@ impl Binder { }; Ok(ExprImpl::AggCall(Box::new(AggCall::new( - kind, + agg_type, args, distinct, order_by, @@ -107,7 +107,7 @@ impl Binder { // aggregate_name ( [ expression [ , ... ] ] ) WITHIN GROUP ( order_by_clause ) [ FILTER // ( WHERE filter_clause ) ] - assert!(matches!(kind, agg_kinds::ordered_set!())); + assert!(matches!(kind, agg_types::ordered_set!())); if !order_by.is_empty() { return Err(ErrorCode::InvalidInputSyntax(format!( @@ -222,7 +222,7 @@ impl Binder { // filter_clause ) ] // aggregate_name ( * ) [ FILTER ( WHERE filter_clause ) ] - assert!(!matches!(kind, agg_kinds::ordered_set!())); + assert!(!matches!(kind, agg_types::ordered_set!())); if within_group.is_some() { return Err(ErrorCode::InvalidInputSyntax(format!( diff --git a/src/frontend/src/binder/expr/function/mod.rs b/src/frontend/src/binder/expr/function/mod.rs index fd54e65863a81..5d3dfb79300d2 100644 --- a/src/frontend/src/binder/expr/function/mod.rs +++ b/src/frontend/src/binder/expr/function/mod.rs @@ -153,7 +153,7 @@ impl Binder { .flatten_ok() .try_collect()?; - let wrapped_agg_kind = if scalar_as_agg { + let wrapped_agg_type = if scalar_as_agg { // Let's firstly try to apply the `AGGREGATE:` prefix. // We will reject functions that are not able to be wrapped as aggregate function. let mut array_args = args @@ -187,7 +187,7 @@ impl Binder { None }; - let udf = if wrapped_agg_kind.is_none() + let udf = if wrapped_agg_type.is_none() && let Ok(schema) = self.first_valid_schema() && let Some(func) = schema .get_function_by_name_inputs(&func_name, &mut args) @@ -227,15 +227,15 @@ impl Binder { None }; - let agg_kind = if let Some(wrapped_agg_kind) = wrapped_agg_kind { - Some(wrapped_agg_kind) + let agg_type = if let Some(wrapped_agg_type) = wrapped_agg_type { + Some(wrapped_agg_type) } else if let Some(ref udf) = udf && udf.kind.is_aggregate() { assert_ne!(udf.language, "sql", "SQL UDAF is not supported yet"); Some(AggType::UserDefined(udf.as_ref().into())) - } else if let Ok(kind) = AggType::from_str(&func_name) { - Some(kind) + } else if let Ok(agg_type) = AggType::from_str(&func_name) { + Some(agg_type) } else { None }; @@ -259,9 +259,9 @@ impl Binder { "`WITHIN GROUP` is not allowed in window function call" ); - let kind = if let Some(agg_kind) = agg_kind { + let kind = if let Some(agg_type) = agg_type { // aggregate as window function - WindowFuncKind::Aggregate(agg_kind) + WindowFuncKind::Aggregate(agg_type) } else if let Ok(kind) = WindowFuncKind::from_str(&func_name) { kind } else { @@ -277,13 +277,13 @@ impl Binder { ); // try to bind it as an aggregate function call - if let Some(agg_kind) = agg_kind { + if let Some(agg_type) = agg_type { reject_syntax!( arg_list.variadic, "`VARIADIC` is not allowed in aggregate function call" ); return self.bind_aggregate_function( - agg_kind, + agg_type, arg_list.distinct, args, arg_list.order_by, diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 5a4502a18ec05..adb7a1b9d0f2f 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -863,7 +863,7 @@ mod tests { select_items: [ AggCall( AggCall { - agg_kind: Builtin( + agg_type: Builtin( ApproxPercentile, ), return_type: Float64, diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs index 4996d4bab18b5..3d298c90652a0 100644 --- a/src/frontend/src/expr/agg_call.rs +++ b/src/frontend/src/expr/agg_call.rs @@ -21,7 +21,7 @@ use crate::utils::Condition; #[derive(Clone, Eq, PartialEq, Hash)] pub struct AggCall { - pub agg_kind: AggType, + pub agg_type: AggType, pub return_type: DataType, pub args: Vec, pub distinct: bool, @@ -34,7 +34,7 @@ impl std::fmt::Debug for AggCall { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if f.alternate() { f.debug_struct("AggCall") - .field("agg_kind", &self.agg_kind) + .field("agg_type", &self.agg_type) .field("return_type", &self.return_type) .field("args", &self.args) .field("filter", &self.filter) @@ -43,7 +43,7 @@ impl std::fmt::Debug for AggCall { .field("direct_args", &self.direct_args) .finish() } else { - let mut builder = f.debug_tuple(&format!("{}", self.agg_kind)); + let mut builder = f.debug_tuple(&format!("{}", self.agg_type)); self.args.iter().for_each(|child| { builder.field(child); }); @@ -56,20 +56,20 @@ impl AggCall { /// Returns error if the function name matches with an existing function /// but with illegal arguments. pub fn new( - agg_kind: AggType, + agg_type: AggType, mut args: Vec, distinct: bool, order_by: OrderBy, filter: Condition, direct_args: Vec, ) -> Result { - let return_type = match &agg_kind { + let return_type = match &agg_type { AggType::Builtin(kind) => infer_type((*kind).into(), &mut args)?, AggType::UserDefined(udf) => udf.return_type.as_ref().unwrap().into(), AggType::WrapScalar(expr) => expr.return_type.as_ref().unwrap().into(), }; Ok(AggCall { - agg_kind, + agg_type, return_type, args, distinct, @@ -81,12 +81,12 @@ impl AggCall { /// Constructs an `AggCall` without type inference. pub fn new_unchecked( - agg_kind: AggType, + agg_type: AggType, args: Vec, return_type: DataType, ) -> Result { Ok(AggCall { - agg_kind, + agg_type, return_type, args, distinct: false, @@ -96,8 +96,8 @@ impl AggCall { }) } - pub fn agg_kind(&self) -> AggType { - self.agg_kind.clone() + pub fn agg_type(&self) -> AggType { + self.agg_type.clone() } /// Get a reference to the agg call's arguments. diff --git a/src/frontend/src/expr/expr_rewriter.rs b/src/frontend/src/expr/expr_rewriter.rs index ccbf1b329bba5..15737f6528098 100644 --- a/src/frontend/src/expr/expr_rewriter.rs +++ b/src/frontend/src/expr/expr_rewriter.rs @@ -87,7 +87,7 @@ pub trait ExprRewriter { fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl { let AggCall { - agg_kind, + agg_type, return_type, args, distinct, @@ -102,7 +102,7 @@ pub trait ExprRewriter { let order_by = order_by.rewrite_expr(self); let filter = filter.rewrite_expr(self); AggCall { - agg_kind, + agg_type, return_type, args, distinct, diff --git a/src/frontend/src/expr/window_function.rs b/src/frontend/src/expr/window_function.rs index a4fd2b7b92a4a..e8e03e31606ce 100644 --- a/src/frontend/src/expr/window_function.rs +++ b/src/frontend/src/expr/window_function.rs @@ -87,7 +87,7 @@ impl WindowFunction { ); } - (Aggregate(agg_kind), args) => Ok(match agg_kind { + (Aggregate(agg_type), args) => Ok(match agg_type { AggType::Builtin(kind) => infer_type((*kind).into(), args)?, AggType::UserDefined(udf) => udf.return_type.as_ref().unwrap().into(), AggType::WrapScalar(expr) => expr.return_type.as_ref().unwrap().into(), diff --git a/src/frontend/src/optimizer/mod.rs b/src/frontend/src/optimizer/mod.rs index de5c3deaf0d6b..af9d589fe3df6 100644 --- a/src/frontend/src/optimizer/mod.rs +++ b/src/frontend/src/optimizer/mod.rs @@ -243,7 +243,7 @@ impl PlanRoot { let return_type = DataType::List(input_column_type.clone().into()); let agg = Agg::new( vec![PlanAggCall { - agg_kind: PbAggKind::ArrayAgg.into(), + agg_type: PbAggKind::ArrayAgg.into(), return_type: return_type.clone(), inputs: vec![InputRef::new(select_idx, input_column_type.clone())], distinct: false, diff --git a/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs b/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs index e01762fc56373..e8e8a222688dd 100644 --- a/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs +++ b/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs @@ -58,8 +58,8 @@ impl BatchSimpleAgg { // Ban two phase approx percentile. .agg_calls .iter() - .map(|agg_call| &agg_call.agg_kind) - .all(|agg_kind| !matches!(agg_kind, AggType::Builtin(PbAggKind::ApproxPercentile))) + .map(|agg_call| &agg_call.agg_type) + .all(|agg_type| !matches!(agg_type, AggType::Builtin(PbAggKind::ApproxPercentile))) && self.two_phase_agg_enabled() } } diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index 71c1567a7eafa..eb499279fe478 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -23,7 +23,7 @@ use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType}; use risingwave_common::util::value_encoding::DatumToProtoExt; -use risingwave_expr::aggregate::{agg_kinds, AggType, PbAggKind}; +use risingwave_expr::aggregate::{agg_types, AggType, PbAggKind}; use risingwave_expr::sig::{FuncBuilder, FUNCTION_REGISTRY}; use risingwave_pb::expr::{PbAggCall, PbConstant}; use risingwave_pb::stream_plan::{agg_call_state, AggCallState as PbAggCallState}; @@ -104,16 +104,16 @@ impl Agg { self.two_phase_agg_enabled() && !self.agg_calls.is_empty() && self.agg_calls.iter().all(|call| { - let agg_kind_ok = !matches!(call.agg_kind, agg_kinds::simply_cannot_two_phase!()); + let agg_type_ok = !matches!(call.agg_type, agg_types::simply_cannot_two_phase!()); let order_ok = matches!( - call.agg_kind, - agg_kinds::result_unaffected_by_order_by!() + call.agg_type, + agg_types::result_unaffected_by_order_by!() | AggType::Builtin(PbAggKind::ApproxPercentile) ) || call.order_by.is_empty(); let distinct_ok = - matches!(call.agg_kind, agg_kinds::result_unaffected_by_distinct!()) + matches!(call.agg_type, agg_types::result_unaffected_by_distinct!()) || !call.distinct; - agg_kind_ok && order_ok && distinct_ok + agg_type_ok && order_ok && distinct_ok }) } @@ -134,8 +134,8 @@ impl Agg { /// See if all stream aggregation calls have a stateless local agg counterpart. pub(crate) fn all_local_aggs_are_stateless(&self, stream_input_append_only: bool) -> bool { self.agg_calls.iter().all(|c| { - matches!(c.agg_kind, agg_kinds::single_value_state!()) - || (matches!(c.agg_kind, agg_kinds::single_value_state_iff_in_append_only!() if stream_input_append_only)) + matches!(c.agg_type, agg_types::single_value_state!()) + || (matches!(c.agg_type, agg_types::single_value_state_iff_in_append_only!() if stream_input_append_only)) }) } @@ -413,15 +413,15 @@ impl Agg { self.agg_calls .iter() - .map(|agg_call| match agg_call.agg_kind { - agg_kinds::single_value_state_iff_in_append_only!() if in_append_only => { + .map(|agg_call| match agg_call.agg_type { + agg_types::single_value_state_iff_in_append_only!() if in_append_only => { AggCallState::Value } - agg_kinds::single_value_state!() => AggCallState::Value, - agg_kinds::materialized_input_state!() => { + agg_types::single_value_state!() => AggCallState::Value, + agg_types::materialized_input_state!() => { // columns with order requirement in state table let sort_keys = { - match agg_call.agg_kind { + match agg_call.agg_type { AggType::Builtin(PbAggKind::Min) => { vec![(OrderType::ascending(), agg_call.inputs[0].index)] } @@ -439,7 +439,7 @@ impl Agg { if agg_call.order_by.is_empty() { me.ctx().warn_to_user(format!( "{} without ORDER BY may produce non-deterministic result", - agg_call.agg_kind, + agg_call.agg_type, )); } agg_call @@ -448,7 +448,7 @@ impl Agg { .map(|o| { ( if matches!( - agg_call.agg_kind, + agg_call.agg_type, AggType::Builtin(PbAggKind::LastValue) ) { o.order_type.reverse() @@ -480,8 +480,8 @@ impl Agg { }; // other columns that should be contained in state table - let include_keys = match agg_call.agg_kind { - // `agg_kinds::materialized_input_state` except for `min`/`max` + let include_keys = match agg_call.agg_type { + // `agg_types::materialized_input_state` except for `min`/`max` AggType::Builtin( PbAggKind::FirstValue | PbAggKind::LastValue @@ -499,10 +499,10 @@ impl Agg { let state = gen_materialized_input_state(sort_keys, extra_keys, include_keys); AggCallState::MaterializedInput(Box::new(state)) } - agg_kinds::rewritten!() => { + agg_types::rewritten!() => { unreachable!("should have been rewritten") } - agg_kinds::unimplemented_in_stream!() => { + agg_types::unimplemented_in_stream!() => { unreachable!("should have been banned") } AggType::Builtin( @@ -531,7 +531,7 @@ impl Agg { .iter() .zip_eq_fast(&mut out_fields[self.group_key.len()..]) { - let agg_kind = match agg_call.agg_kind { + let agg_kind = match agg_call.agg_type { AggType::UserDefined(_) => { // for user defined aggregate, the state type is always BYTEA field.data_type = DataType::Bytea; @@ -704,8 +704,8 @@ impl_distill_unit_from_fields!(Agg, stream::StreamPlanRef); /// for more details. #[derive(Clone, PartialEq, Eq, Hash)] pub struct PlanAggCall { - /// Kind of aggregation function - pub agg_kind: AggType, + /// Type of aggregation function + pub agg_type: AggType, /// Data type of the returned column pub return_type: DataType, @@ -730,7 +730,7 @@ pub struct PlanAggCall { impl fmt::Debug for PlanAggCall { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.agg_kind)?; + write!(f, "{}", self.agg_type)?; if !self.inputs.is_empty() { write!(f, "(")?; for (idx, input) in self.inputs.iter().enumerate() { @@ -780,7 +780,7 @@ impl PlanAggCall { pub fn to_protobuf(&self) -> PbAggCall { PbAggCall { - kind: match &self.agg_kind { + kind: match &self.agg_type { AggType::Builtin(kind) => *kind, AggType::UserDefined(_) => PbAggKind::UserDefined, AggType::WrapScalar(_) => PbAggKind::WrapScalar, @@ -799,11 +799,11 @@ impl PlanAggCall { r#type: Some(x.return_type().to_protobuf()), }) .collect(), - udf: match &self.agg_kind { + udf: match &self.agg_type { AggType::UserDefined(udf) => Some(udf.clone()), _ => None, }, - scalar: match &self.agg_kind { + scalar: match &self.agg_type { AggType::WrapScalar(expr) => Some(expr.clone()), _ => None, }, @@ -811,12 +811,12 @@ impl PlanAggCall { } pub fn partial_to_total_agg_call(&self, partial_output_idx: usize) -> PlanAggCall { - let total_agg_kind = self - .agg_kind + let total_agg_type = self + .agg_type .partial_to_total() .expect("unsupported kinds shouldn't get here"); PlanAggCall { - agg_kind: total_agg_kind, + agg_type: total_agg_type, inputs: vec![InputRef::new(partial_output_idx, self.return_type.clone())], order_by: vec![], // order must make no difference when we use 2-phase agg filter: Condition::true_cond(), @@ -826,7 +826,7 @@ impl PlanAggCall { pub fn count_star() -> Self { PlanAggCall { - agg_kind: PbAggKind::Count.into(), + agg_type: PbAggKind::Count.into(), return_type: DataType::Int64, inputs: vec![], distinct: false, @@ -854,7 +854,7 @@ pub struct PlanAggCallDisplay<'a> { impl fmt::Debug for PlanAggCallDisplay<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let that = self.plan_agg_call; - write!(f, "{}", that.agg_kind)?; + write!(f, "{}", that.agg_type)?; if !that.inputs.is_empty() { write!(f, "(")?; for (idx, input) in that.inputs.iter().enumerate() { diff --git a/src/frontend/src/optimizer/plan_node/generic/over_window.rs b/src/frontend/src/optimizer/plan_node/generic/over_window.rs index c39fd99be9895..c8d2f64ba6141 100644 --- a/src/frontend/src/optimizer/plan_node/generic/over_window.rs +++ b/src/frontend/src/optimizer/plan_node/generic/over_window.rs @@ -121,7 +121,7 @@ impl PlanWindowFunction { DenseRank => PbType::General(PbGeneralType::DenseRank as _), Lag => PbType::General(PbGeneralType::Lag as _), Lead => PbType::General(PbGeneralType::Lead as _), - Aggregate(agg_kind) => PbType::Aggregate(agg_kind.to_protobuf() as _), + Aggregate(agg_type) => PbType::Aggregate(agg_type.to_protobuf() as _), }; PbWindowFunction { diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index 76af6595673b8..4e2474287c969 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -17,7 +17,7 @@ use itertools::Itertools; use risingwave_common::types::{DataType, Datum, ScalarImpl}; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common::{bail, bail_not_implemented, not_implemented}; -use risingwave_expr::aggregate::{agg_kinds, AggType, PbAggKind}; +use risingwave_expr::aggregate::{agg_types, AggType, PbAggKind}; use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder}; use super::utils::impl_distill_by_unit; @@ -389,7 +389,7 @@ impl LogicalAgg { let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len); let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len); for (output_idx, agg_call) in self.agg_calls().iter().enumerate() { - if agg_call.agg_kind == AggType::Builtin(PbAggKind::ApproxPercentile) { + if agg_call.agg_type == AggType::Builtin(PbAggKind::ApproxPercentile) { approx_percentile_agg_calls.push(agg_call.clone()); approx_percentile_col_mapping.push(Some(output_idx)); } else { @@ -610,7 +610,7 @@ impl LogicalAggBuilder { agg_call: AggCall, mut push_agg_call: impl FnMut(AggCall) -> Result, ) -> Result { - match agg_call.agg_kind { + match agg_call.agg_type { // Rewrite avg to cast(sum as avg_return_type) / count. AggType::Builtin(PbAggKind::Avg) => { assert_eq!(agg_call.args.len(), 1); @@ -778,7 +778,7 @@ impl LogicalAggBuilder { /// For existing agg calls, return an `InputRef` to the existing one. fn push_agg_call(&mut self, agg_call: AggCall) -> Result { let AggCall { - agg_kind, + agg_type, return_type, args, distinct, @@ -815,7 +815,7 @@ impl LogicalAggBuilder { })?; let plan_agg_call = PlanAggCall { - agg_kind, + agg_type, return_type: return_type.clone(), inputs: args, distinct, @@ -846,32 +846,32 @@ impl LogicalAggBuilder { /// /// Note that the rewriter does not traverse into inputs of agg calls. fn try_rewrite_agg_call(&mut self, mut agg_call: AggCall) -> Result { - if matches!(agg_call.agg_kind, agg_kinds::must_have_order_by!()) + if matches!(agg_call.agg_type, agg_types::must_have_order_by!()) && agg_call.order_by.sort_exprs.is_empty() { return Err(ErrorCode::InvalidInputSyntax(format!( "Aggregation function {} requires ORDER BY clause", - agg_call.agg_kind + agg_call.agg_type )) .into()); } // try ignore ORDER BY if it doesn't affect the result if matches!( - agg_call.agg_kind, - agg_kinds::result_unaffected_by_order_by!() + agg_call.agg_type, + agg_types::result_unaffected_by_order_by!() ) { agg_call.order_by = OrderBy::any(); } // try ignore DISTINCT if it doesn't affect the result if matches!( - agg_call.agg_kind, - agg_kinds::result_unaffected_by_distinct!() + agg_call.agg_type, + agg_types::result_unaffected_by_distinct!() ) { agg_call.distinct = false; } - if matches!(agg_call.agg_kind, AggType::Builtin(PbAggKind::Grouping)) { + if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) { if self.grouping_sets.is_empty() { return Err(ErrorCode::NotSupported( "GROUPING must be used in a query with grouping sets".into(), @@ -1343,8 +1343,8 @@ impl ToStream for LogicalAgg { use super::stream::prelude::*; for agg_call in self.agg_calls() { - if matches!(agg_call.agg_kind, agg_kinds::unimplemented_in_stream!()) { - bail_not_implemented!("{} aggregation in materialized view", agg_call.agg_kind); + if matches!(agg_call.agg_type, agg_types::unimplemented_in_stream!()) { + bail_not_implemented!("{} aggregation in materialized view", agg_call.agg_type); } } let eowc = ctx.emit_on_window_close(); @@ -1517,7 +1517,7 @@ mod tests { assert_eq_input_ref!(&exprs[1], 1); assert_eq!(agg_calls.len(), 1); - assert_eq!(agg_calls[0].agg_kind, PbAggKind::Min.into()); + assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into()); assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]); assert_eq!(group_key, vec![0].into()); } @@ -1560,9 +1560,9 @@ mod tests { } assert_eq!(agg_calls.len(), 2); - assert_eq!(agg_calls[0].agg_kind, PbAggKind::Min.into()); + assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into()); assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]); - assert_eq!(agg_calls[1].agg_kind, PbAggKind::Max.into()); + assert_eq!(agg_calls[1].agg_type, PbAggKind::Max.into()); assert_eq!(input_ref_to_column_indices(&agg_calls[1].inputs), vec![2]); assert_eq!(group_key, vec![0].into()); } @@ -1592,7 +1592,7 @@ mod tests { assert_eq_input_ref!(&exprs[1], 1); assert_eq!(agg_calls.len(), 1); - assert_eq!(agg_calls[0].agg_kind, PbAggKind::Min.into()); + assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into()); assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]); assert_eq!(group_key, vec![0].into()); } @@ -1609,7 +1609,7 @@ mod tests { let values = LogicalValues::new(vec![], Schema { fields }, ctx); let agg_call = PlanAggCall { - agg_kind: PbAggKind::Min.into(), + agg_type: PbAggKind::Min.into(), return_type: ty.clone(), inputs: vec![InputRef::new(2, ty.clone())], distinct: false, @@ -1649,7 +1649,7 @@ mod tests { assert_eq!(agg_new.agg_calls().len(), 1); let agg_call_new = agg_new.agg_calls()[0].clone(); - assert_eq!(agg_call_new.agg_kind, PbAggKind::Min.into()); + assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into()); assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]); assert_eq!(agg_call_new.return_type, ty); @@ -1692,7 +1692,7 @@ mod tests { assert_eq!(agg_new.agg_calls().len(), 1); let agg_call_new = agg_new.agg_calls()[0].clone(); - assert_eq!(agg_call_new.agg_kind, PbAggKind::Min.into()); + assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into()); assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]); assert_eq!(agg_call_new.return_type, ty); @@ -1729,7 +1729,7 @@ mod tests { ctx, ); let agg_call = PlanAggCall { - agg_kind: PbAggKind::Min.into(), + agg_type: PbAggKind::Min.into(), return_type: ty.clone(), inputs: vec![InputRef::new(2, ty.clone())], distinct: false, @@ -1754,7 +1754,7 @@ mod tests { assert_eq!(agg_new.agg_calls().len(), 1); let agg_call_new = agg_new.agg_calls()[0].clone(); - assert_eq!(agg_call_new.agg_kind, PbAggKind::Min.into()); + assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into()); assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]); assert_eq!(agg_call_new.return_type, ty); @@ -1793,7 +1793,7 @@ mod tests { let agg_calls = vec![ PlanAggCall { - agg_kind: PbAggKind::Min.into(), + agg_type: PbAggKind::Min.into(), return_type: ty.clone(), inputs: vec![InputRef::new(2, ty.clone())], distinct: false, @@ -1802,7 +1802,7 @@ mod tests { direct_args: vec![], }, PlanAggCall { - agg_kind: PbAggKind::Max.into(), + agg_type: PbAggKind::Max.into(), return_type: ty.clone(), inputs: vec![InputRef::new(1, ty.clone())], distinct: false, @@ -1828,7 +1828,7 @@ mod tests { assert_eq!(agg_new.agg_calls().len(), 1); let agg_call_new = agg_new.agg_calls()[0].clone(); - assert_eq!(agg_call_new.agg_kind, PbAggKind::Max.into()); + assert_eq!(agg_call_new.agg_type, PbAggKind::Max.into()); assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![0]); assert_eq!(agg_call_new.return_type, ty); diff --git a/src/frontend/src/optimizer/plan_node/logical_over_window.rs b/src/frontend/src/optimizer/plan_node/logical_over_window.rs index f196aeca5f4a2..a790e006ddd97 100644 --- a/src/frontend/src/optimizer/plan_node/logical_over_window.rs +++ b/src/frontend/src/optimizer/plan_node/logical_over_window.rs @@ -105,9 +105,9 @@ impl<'a> LogicalOverWindowBuilder<'a> { window_func.frame, ); - let new_expr = if let WindowFuncKind::Aggregate(agg_kind) = &kind + let new_expr = if let WindowFuncKind::Aggregate(agg_type) = &kind && matches!( - agg_kind, + agg_type, AggType::Builtin( PbAggKind::Avg | PbAggKind::StddevPop @@ -117,7 +117,7 @@ impl<'a> LogicalOverWindowBuilder<'a> { ) ) { let agg_call = AggCall::new( - agg_kind.clone(), + agg_type.clone(), args, false, order_by, @@ -128,7 +128,7 @@ impl<'a> LogicalOverWindowBuilder<'a> { Ok(self.push_window_func( // AggCall -> WindowFunction WindowFunction::new( - WindowFuncKind::Aggregate(agg_call.agg_kind), + WindowFuncKind::Aggregate(agg_call.agg_type), partition_by.clone(), agg_call.order_by.clone(), agg_call.args.clone(), @@ -188,9 +188,9 @@ impl<'a> OverWindowProjectBuilder<'a> { &mut self, window_function: &WindowFunction, ) -> std::result::Result<(), ErrorCode> { - if let WindowFuncKind::Aggregate(agg_kind) = &window_function.kind + if let WindowFuncKind::Aggregate(agg_type) = &window_function.kind && matches!( - agg_kind, + agg_type, AggType::Builtin( PbAggKind::StddevPop | PbAggKind::StddevSamp diff --git a/src/frontend/src/optimizer/plan_node/utils.rs b/src/frontend/src/optimizer/plan_node/utils.rs index bc3c223c615e6..6a17d2507e558 100644 --- a/src/frontend/src/optimizer/plan_node/utils.rs +++ b/src/frontend/src/optimizer/plan_node/utils.rs @@ -288,7 +288,7 @@ pub(crate) fn sum_affected_row(dml: PlanRef) -> Result { let dml = RequiredDist::single().enforce_if_not_satisfies(dml, &Order::any())?; // Accumulate the affected rows. let sum_agg = PlanAggCall { - agg_kind: PbAggKind::Sum.into(), + agg_type: PbAggKind::Sum.into(), return_type: DataType::Int64, inputs: vec![InputRef::new(0, DataType::Int64)], distinct: false, diff --git a/src/frontend/src/optimizer/rule/agg_group_by_simplify_rule.rs b/src/frontend/src/optimizer/rule/agg_group_by_simplify_rule.rs index f7bc29611b618..af41bce3d3216 100644 --- a/src/frontend/src/optimizer/rule/agg_group_by_simplify_rule.rs +++ b/src/frontend/src/optimizer/rule/agg_group_by_simplify_rule.rs @@ -47,7 +47,7 @@ impl Rule for AggGroupBySimplifyRule { if !new_group_key.contains(i) { let data_type = agg_input.schema().fields[i].data_type(); new_agg_calls.push(PlanAggCall { - agg_kind: PbAggKind::InternalLastSeenValue.into(), + agg_type: PbAggKind::InternalLastSeenValue.into(), return_type: data_type.clone(), inputs: vec![InputRef::new(i, data_type)], distinct: false, diff --git a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs index 0b70c57036dc9..a17dd46031b3c 100644 --- a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs +++ b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs @@ -140,7 +140,7 @@ impl Rule for ApplyAggTransposeRule { // convert count(*) to count(1). let pos_of_constant_column = node.schema().len() - 1; agg_calls.iter_mut().for_each(|agg_call| { - match agg_call.agg_kind { + match agg_call.agg_type { AggType::Builtin(PbAggKind::Count) if agg_call.inputs.is_empty() => { let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32); agg_call.inputs.push(input_ref); @@ -187,7 +187,7 @@ impl Rule for ApplyAggTransposeRule { // no-op when `agg(0 rows) == agg(1 row of nulls)` } AggType::Builtin(PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar) => { - panic!("Unexpected aggregate function: {:?}", agg_call.agg_kind) + panic!("Unexpected aggregate function: {:?}", agg_call.agg_type) } } }); diff --git a/src/frontend/src/optimizer/rule/distinct_agg_rule.rs b/src/frontend/src/optimizer/rule/distinct_agg_rule.rs index fe12de26473d2..4bc4f51ecf223 100644 --- a/src/frontend/src/optimizer/rule/distinct_agg_rule.rs +++ b/src/frontend/src/optimizer/rule/distinct_agg_rule.rs @@ -18,7 +18,7 @@ use std::mem; use fixedbitset::FixedBitSet; use itertools::Itertools; use risingwave_common::types::DataType; -use risingwave_expr::aggregate::{agg_kinds, AggType, PbAggKind}; +use risingwave_expr::aggregate::{agg_types, AggType, PbAggKind}; use super::{BoxedRule, Rule}; use crate::expr::{CollectInputRef, ExprType, FunctionCall, InputRef, Literal}; @@ -52,17 +52,17 @@ impl Rule for DistinctAggRule { if !agg_calls.iter().all(|c| { assert!( - !matches!(c.agg_kind, agg_kinds::rewritten!()), + !matches!(c.agg_type, agg_types::rewritten!()), "We shouldn't see agg kind {} here", - c.agg_kind + c.agg_type ); - let agg_kind_ok = !matches!(c.agg_kind, agg_kinds::simply_cannot_two_phase!()); + let agg_type_ok = !matches!(c.agg_type, agg_types::simply_cannot_two_phase!()); let order_ok = matches!( - c.agg_kind, - agg_kinds::result_unaffected_by_order_by!() + c.agg_type, + agg_types::result_unaffected_by_order_by!() | AggType::Builtin(PbAggKind::ApproxPercentile) ) || c.order_by.is_empty(); - agg_kind_ok && order_ok + agg_type_ok && order_ok }) { tracing::warn!("DistinctAggRule: unsupported agg kind, fallback to backend impl"); return None; @@ -305,8 +305,8 @@ impl DistinctAggRule { // the filter of non-distinct agg has been calculated in middle agg. agg_call.filter = Condition::true_cond(); - // change final agg's agg_kind just like two-phase agg. - agg_call.agg_kind = agg_call.agg_kind.partial_to_total().expect( + // change final agg's agg_type just like two-phase agg. + agg_call.agg_type = agg_call.agg_type.partial_to_total().expect( "we should get a valid total phase agg kind here since unsupported cases have been filtered out" ); diff --git a/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs b/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs index 7856442ad01f8..75159971bb225 100644 --- a/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs +++ b/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs @@ -105,7 +105,7 @@ impl Rule for GroupingSetsToExpandRule { let mut new_agg_calls = vec![]; for agg_call in old_agg_calls { // Deal with grouping agg call for grouping sets. - if matches!(agg_call.agg_kind, AggType::Builtin(PbAggKind::Grouping)) { + if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) { let mut grouping_values = vec![]; let args = agg_call .inputs diff --git a/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs b/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs index 440eb253a9bce..4809c3c82a058 100644 --- a/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs +++ b/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs @@ -50,14 +50,14 @@ impl Rule for MinMaxOnIndexRule { let first_call = calls.iter().exactly_one().ok()?; if matches!( - first_call.agg_kind, + first_call.agg_type, AggType::Builtin(PbAggKind::Min | PbAggKind::Max) ) && !first_call.distinct && first_call.filter.always_true() && first_call.order_by.is_empty() { let logical_scan: LogicalScan = logical_agg.input().as_logical_scan()?.to_owned(); - let kind = &calls.first()?.agg_kind; + let kind = &calls.first()?.agg_type; if !logical_scan.predicate().always_true() { return None; } @@ -114,7 +114,7 @@ impl MinMaxOnIndexRule { let formatting_agg = Agg::new( vec![PlanAggCall { - agg_kind: logical_agg.agg_calls().first()?.agg_kind.clone(), + agg_type: logical_agg.agg_calls().first()?.agg_type.clone(), return_type: logical_agg.schema().fields[0].data_type.clone(), inputs: vec![InputRef::new( 0, @@ -184,7 +184,7 @@ impl MinMaxOnIndexRule { let formatting_agg = Agg::new( vec![PlanAggCall { - agg_kind: logical_agg.agg_calls().first()?.agg_kind.clone(), + agg_type: logical_agg.agg_calls().first()?.agg_type.clone(), return_type: logical_agg.schema().fields[0].data_type.clone(), inputs: vec![InputRef::new( 0, diff --git a/src/frontend/src/optimizer/rule/pull_up_correlated_predicate_agg_rule.rs b/src/frontend/src/optimizer/rule/pull_up_correlated_predicate_agg_rule.rs index 41aa080cd5aa2..4a59dcda785b8 100644 --- a/src/frontend/src/optimizer/rule/pull_up_correlated_predicate_agg_rule.rs +++ b/src/frontend/src/optimizer/rule/pull_up_correlated_predicate_agg_rule.rs @@ -162,14 +162,14 @@ impl Rule for PullUpCorrelatedPredicateAggRule { // sum is null, so avg is null. And null-rejected expression will be false, so we can still apply this rule and we don't need to generate a 0 value for count. let count_exists = agg_calls .iter() - .any(|agg_call| matches!(agg_call.agg_kind, AggType::Builtin(PbAggKind::Count))); + .any(|agg_call| matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Count))); if count_exists { // When group input is empty, not count agg would return null. let null_agg_pos = agg_calls .iter() .positions(|agg_call| { - !matches!(agg_call.agg_kind, AggType::Builtin(PbAggKind::Count)) + !matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Count)) }) .collect_vec(); From 9988c6089063286bd30ff37b8ad4f28bf8553400 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Mon, 23 Sep 2024 17:09:48 +0800 Subject: [PATCH 5/6] rename other agg `kind` to `agg_type` Signed-off-by: Richard Chien --- src/expr/core/src/aggregate/def.rs | 8 ++++---- src/expr/core/src/aggregate/mod.rs | 2 +- src/expr/core/src/sig/mod.rs | 2 +- src/expr/impl/benches/expr.rs | 2 +- src/expr/impl/src/window_function/aggregate.rs | 2 +- src/stream/src/executor/aggregation/minput.rs | 14 +++++++------- src/stream/src/executor/aggregation/mod.rs | 2 +- src/stream/src/executor/test_utils.rs | 4 ++-- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index 9127170bed12a..8a6ba3ead6fb9 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -41,8 +41,8 @@ use crate::Result; // advanced features like order by, filter, distinct, etc. should be handled by the upper layer. #[derive(Debug, Clone)] pub struct AggCall { - /// Aggregation kind for constructing agg state. - pub kind: AggType, + /// Aggregation type for constructing agg state. + pub agg_type: AggType, /// Arguments of aggregation function input. pub args: AggArgs, @@ -96,7 +96,7 @@ impl AggCall { }) .collect_vec(); Ok(AggCall { - kind: agg_type, + agg_type, args, return_type: DataType::from(agg_call.get_return_type()?), column_orders, @@ -160,7 +160,7 @@ impl> Parser { self.tokens.next(); // Consume the RParen AggCall { - kind: AggType::from_protobuf(func, None, None).unwrap(), + agg_type: AggType::from_protobuf(func, None, None).unwrap(), args: AggArgs { data_types: children.iter().map(|(_, ty)| ty.clone()).collect(), val_indices: children.iter().map(|(idx, _)| *idx).collect(), diff --git a/src/expr/core/src/aggregate/mod.rs b/src/expr/core/src/aggregate/mod.rs index eb443de42978f..436b96b418ce4 100644 --- a/src/expr/core/src/aggregate/mod.rs +++ b/src/expr/core/src/aggregate/mod.rs @@ -146,7 +146,7 @@ pub fn build_retractable(agg: &AggCall) -> Result { /// `AggCall`. Such operations should be done in batch or streaming executors. pub fn build(agg: &AggCall, prefer_append_only: bool) -> Result { // handle special kinds - let kind = match &agg.kind { + let kind = match &agg.agg_type { AggType::UserDefined(udf) => { return user_defined::new_user_defined(&agg.return_type, udf); } diff --git a/src/expr/core/src/sig/mod.rs b/src/expr/core/src/sig/mod.rs index 48280c124c156..fc15cfce36098 100644 --- a/src/expr/core/src/sig/mod.rs +++ b/src/expr/core/src/sig/mod.rs @@ -407,7 +407,7 @@ impl FuncName { pub fn as_aggregate(&self) -> PbAggKind { match self { - Self::Aggregate(ty) => *ty, + Self::Aggregate(kind) => *kind, _ => panic!("Expected an aggregate function"), } } diff --git a/src/expr/impl/benches/expr.rs b/src/expr/impl/benches/expr.rs index 06cf18e4fbd0c..a8895435a6a80 100644 --- a/src/expr/impl/benches/expr.rs +++ b/src/expr/impl/benches/expr.rs @@ -395,7 +395,7 @@ fn bench_expr(c: &mut Criterion) { continue; } let agg = match build_append_only(&AggCall { - kind: sig.name.as_aggregate().into(), + agg_type: sig.name.as_aggregate().into(), args: sig .inputs_type .iter() diff --git a/src/expr/impl/src/window_function/aggregate.rs b/src/expr/impl/src/window_function/aggregate.rs index b922a51103992..9a30d103ade1a 100644 --- a/src/expr/impl/src/window_function/aggregate.rs +++ b/src/expr/impl/src/window_function/aggregate.rs @@ -53,7 +53,7 @@ pub(super) fn new(call: &WindowFuncCall) -> Result { let agg_type = must_match!(&call.kind, WindowFuncKind::Aggregate(agg_type) => agg_type); let arg_data_types = call.args.arg_types().to_vec(); let agg_call = AggCall { - kind: agg_type.clone(), + agg_type: agg_type.clone(), args: call.args.clone(), return_type: call.return_type.clone(), column_orders: Vec::new(), // the input is already sorted diff --git a/src/stream/src/executor/aggregation/minput.rs b/src/stream/src/executor/aggregation/minput.rs index 89fd8881a691e..8e140c9bebae6 100644 --- a/src/stream/src/executor/aggregation/minput.rs +++ b/src/stream/src/executor/aggregation/minput.rs @@ -124,7 +124,7 @@ impl MaterializedInputState { .collect_vec(); let cache_key_serializer = OrderedRowSerde::new(cache_key_data_types, order_types); - let cache: Box = match agg_call.kind { + let cache: Box = match agg_call.agg_type { AggType::Builtin( PbAggKind::Min | PbAggKind::Max | PbAggKind::FirstValue | PbAggKind::LastValue, ) => Box::new(GenericAggStateCache::new( @@ -142,12 +142,12 @@ impl MaterializedInputState { agg_call.args.arg_types(), )), _ => panic!( - "Agg kind `{}` is not expected to have materialized input state", - agg_call.kind + "Agg type `{}` is not expected to have materialized input state", + agg_call.agg_type ), }; let output_first_value = matches!( - agg_call.kind, + agg_call.agg_type, AggType::Builtin( PbAggKind::Min | PbAggKind::Max | PbAggKind::FirstValue | PbAggKind::LastValue ) @@ -245,12 +245,12 @@ fn generate_order_columns_before_version_issue_13465( arg_col_indices: &[usize], ) -> (Vec, Vec) { let (mut order_col_indices, mut order_types) = if matches!( - agg_call.kind, + agg_call.agg_type, AggType::Builtin(PbAggKind::Min | PbAggKind::Max) ) { // `min`/`max` need not to order by any other columns, but have to // order by the agg value implicitly. - let order_type = if matches!(agg_call.kind, AggType::Builtin(PbAggKind::Min)) { + let order_type = if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Min)) { OrderType::ascending() } else { OrderType::descending() @@ -263,7 +263,7 @@ fn generate_order_columns_before_version_issue_13465( .map(|p| { ( p.column_index, - if matches!(agg_call.kind, AggType::Builtin(PbAggKind::LastValue)) { + if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::LastValue)) { p.order_type.reverse() } else { p.order_type diff --git a/src/stream/src/executor/aggregation/mod.rs b/src/stream/src/executor/aggregation/mod.rs index 1695a0deab086..676ecc3d7c4da 100644 --- a/src/stream/src/executor/aggregation/mod.rs +++ b/src/stream/src/executor/aggregation/mod.rs @@ -38,7 +38,7 @@ pub async fn agg_call_filter_res( ) -> StreamExecutorResult { let mut vis = chunk.visibility().clone(); if matches!( - agg_call.kind, + agg_call.agg_type, AggType::Builtin(PbAggKind::Min | PbAggKind::Max | PbAggKind::StringAgg) ) { // should skip NULL value for these kinds of agg function diff --git a/src/stream/src/executor/test_utils.rs b/src/stream/src/executor/test_utils.rs index 793dfb0e7266f..76301a2de5099 100644 --- a/src/stream/src/executor/test_utils.rs +++ b/src/stream/src/executor/test_utils.rs @@ -328,7 +328,7 @@ pub mod agg_executor { input_fields: Vec, is_append_only: bool, ) -> AggStateStorage { - match agg_call.kind { + match agg_call.agg_type { AggType::Builtin(PbAggKind::Min | PbAggKind::Max) if !is_append_only => { let mut column_descs = Vec::new(); let mut order_types = Vec::new(); @@ -353,7 +353,7 @@ pub mod agg_executor { add_column(*idx, input_fields[*idx].data_type(), None); } - add_column(agg_call.args.val_indices()[0], agg_call.args.arg_types()[0].clone(), if matches!(agg_call.kind, AggType::Builtin(PbAggKind::Max)) { + add_column(agg_call.args.val_indices()[0], agg_call.args.arg_types()[0].clone(), if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Max)) { Some(OrderType::descending()) } else { Some(OrderType::ascending()) From fbf4f06911a22757d5b39d8a823fa442ee9ea89c Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Mon, 23 Sep 2024 17:10:02 +0800 Subject: [PATCH 6/6] rename `AggType::to_protobuf` to `AggType::to_protobuf_simple` Signed-off-by: Richard Chien --- src/batch/benches/hash_agg.rs | 2 +- src/expr/core/src/aggregate/def.rs | 2 +- src/frontend/src/optimizer/plan_node/generic/over_window.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/batch/benches/hash_agg.rs b/src/batch/benches/hash_agg.rs index 1d77a3430c2a6..b4d773ae425f2 100644 --- a/src/batch/benches/hash_agg.rs +++ b/src/batch/benches/hash_agg.rs @@ -39,7 +39,7 @@ fn create_agg_call( return_type: DataType, ) -> PbAggCall { PbAggCall { - kind: agg_type.to_protobuf() as i32, + kind: agg_type.to_protobuf_simple() as i32, args: args .into_iter() .map(|col_idx| PbInputRef { diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index 8a6ba3ead6fb9..3abe80dcd4d31 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -279,7 +279,7 @@ impl AggType { } } - pub fn to_protobuf(&self) -> PbAggKind { + pub fn to_protobuf_simple(&self) -> PbAggKind { match self { Self::Builtin(pb) => *pb, Self::UserDefined(_) => PbAggKind::UserDefined, diff --git a/src/frontend/src/optimizer/plan_node/generic/over_window.rs b/src/frontend/src/optimizer/plan_node/generic/over_window.rs index c8d2f64ba6141..5622d1e8952cf 100644 --- a/src/frontend/src/optimizer/plan_node/generic/over_window.rs +++ b/src/frontend/src/optimizer/plan_node/generic/over_window.rs @@ -121,7 +121,7 @@ impl PlanWindowFunction { DenseRank => PbType::General(PbGeneralType::DenseRank as _), Lag => PbType::General(PbGeneralType::Lag as _), Lead => PbType::General(PbGeneralType::Lead as _), - Aggregate(agg_type) => PbType::Aggregate(agg_type.to_protobuf() as _), + Aggregate(agg_type) => PbType::Aggregate(agg_type.to_protobuf_simple() as _), }; PbWindowFunction {