diff --git a/e2e_test/udf/bug_fixes/17560_udaf_as_win_func.slt b/e2e_test/udf/bug_fixes/17560_udaf_as_win_func.slt index 3d8e3279a8b42..0818e8f1e67b1 100644 --- a/e2e_test/udf/bug_fixes/17560_udaf_as_win_func.slt +++ b/e2e_test/udf/bug_fixes/17560_udaf_as_win_func.slt @@ -17,3 +17,69 @@ select t.value, sum00(weight) OVER (PARTITION BY value) from (values (1, 1), (nu ---- 1 1 3 3 + +statement ok +drop aggregate sum00; + +# https://github.com/risingwavelabs/risingwave/issues/18436 + +statement ok +CREATE TABLE exam_scores ( + score_id int, + exam_id int, + student_id int, + score real, + exam_date timestamp +); + +statement ok +INSERT INTO exam_scores (score_id, exam_id, student_id, score, exam_date) +VALUES + (1, 101, 1001, 85.5, '2022-01-10'), + (2, 101, 1002, 92.0, '2022-01-10'), + (3, 101, 1003, 78.5, '2022-01-10'), + (4, 102, 1001, 91.2, '2022-02-15'), + (5, 102, 1003, 88.9, '2022-02-15'); + +statement ok +create aggregate weighted_avg(value float, weight float) returns float language python as $$ +def create_state(): + return (0, 0) +def accumulate(state, value, weight): + if value is None or weight is None: + return state + (s, w) = state + s += value * weight + w += weight + return (s, w) +def retract(state, value, weight): + if value is None or weight is None: + return state + (s, w) = state + s -= value * weight + w -= weight + return (s, w) +def finish(state): + (sum, weight) = state + if weight == 0: + return None + else: + return sum / weight +$$; + +query +SELECT + *, + weighted_avg(score, 1) OVER ( + PARTITION BY "student_id" + ORDER BY "exam_date" + ROWS 2 PRECEDING + ) AS "weighted_avg" +FROM exam_scores +ORDER BY "student_id", "exam_date"; +---- +1 101 1001 85.5 2022-01-10 00:00:00 85.5 +4 102 1001 91.2 2022-02-15 00:00:00 88.3499984741211 +2 101 1002 92 2022-01-10 00:00:00 92 +3 101 1003 78.5 2022-01-10 00:00:00 78.5 +5 102 1003 88.9 2022-02-15 00:00:00 83.70000076293945 diff --git a/proto/expr.proto b/proto/expr.proto index c5c66733d0d92..808f402a77aa8 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -471,6 +471,20 @@ message AggCall { ExprNode scalar = 9; } +// The aggregation type. +// +// Ideally this should be used to encode the Rust `AggCall::agg_type` field, but historically we +// flattened it into multiple fields in proto `AggCall` - `kind` + `udf` + `scalar`. So this +// `AggType` proto type is only used by `WindowFunction` currently. +message AggType { + AggCall.Kind kind = 1; + + // UDF metadata. Only present when the kind is `USER_DEFINED`. + optional UserDefinedFunctionMetadata udf_meta = 8; + // Wrapped scalar expression. Only present when the kind is `WRAP_SCALAR`. + optional ExprNode scalar_expr = 9; +} + message WindowFrame { enum Type { TYPE_UNSPECIFIED = 0; @@ -562,7 +576,8 @@ message WindowFunction { oneof type { GeneralType general = 1; - AggCall.Kind aggregate = 2; + AggCall.Kind aggregate = 2 [deprecated = true]; // Deprecated since we have a new `aggregate2` variant. + AggType aggregate2 = 103; } repeated InputRef args = 3; data.DataType return_type = 4; diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index 3abe80dcd4d31..8ee002c7e92d4 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -27,7 +27,9 @@ use risingwave_common::types::{DataType, Datum}; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common::util::value_encoding::DatumFromProtoExt; pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind; -use risingwave_pb::expr::{PbAggCall, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata}; +use risingwave_pb::expr::{ + PbAggCall, PbAggType, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata, +}; use crate::expr::{ build_from_prost, BoxedExpression, ExpectExt, Expression, LiteralExpression, Token, @@ -65,7 +67,7 @@ pub struct AggCall { impl AggCall { pub fn from_protobuf(agg_call: &PbAggCall) -> Result { - let agg_type = AggType::from_protobuf( + let agg_type = AggType::from_protobuf_flatten( agg_call.get_kind()?, agg_call.udf.as_ref(), agg_call.scalar.as_ref(), @@ -160,7 +162,7 @@ impl> Parser { self.tokens.next(); // Consume the RParen AggCall { - agg_type: AggType::from_protobuf(func, None, None).unwrap(), + agg_type: AggType::from_protobuf_flatten(func, None, None).unwrap(), args: AggArgs { data_types: children.iter().map(|(_, ty)| ty.clone()).collect(), val_indices: children.iter().map(|(idx, _)| *idx).collect(), @@ -260,7 +262,7 @@ impl From for AggType { } impl AggType { - pub fn from_protobuf( + pub fn from_protobuf_flatten( pb_kind: PbAggKind, user_defined: Option<&PbUserDefinedFunctionMetadata>, scalar: Option<&PbExprNode>, @@ -286,6 +288,35 @@ impl AggType { Self::WrapScalar(_) => PbAggKind::WrapScalar, } } + + pub fn from_protobuf(pb_type: &PbAggType) -> Result { + match PbAggKind::try_from(pb_type.kind).context("no such aggregate function type")? { + PbAggKind::Unspecified => bail!("Unrecognized agg."), + PbAggKind::UserDefined => Ok(AggType::UserDefined(pb_type.get_udf_meta()?.clone())), + PbAggKind::WrapScalar => Ok(AggType::WrapScalar(pb_type.get_scalar_expr()?.clone())), + kind => Ok(AggType::Builtin(kind)), + } + } + + pub fn to_protobuf(&self) -> PbAggType { + match self { + Self::Builtin(kind) => PbAggType { + kind: *kind as _, + udf_meta: None, + scalar_expr: None, + }, + Self::UserDefined(udf_meta) => PbAggType { + kind: PbAggKind::UserDefined as _, + udf_meta: Some(udf_meta.clone()), + scalar_expr: None, + }, + Self::WrapScalar(scalar_expr) => PbAggType { + kind: PbAggKind::WrapScalar as _, + udf_meta: None, + scalar_expr: Some(scalar_expr.clone()), + }, + } + } } /// Macros to generate match arms for `AggType`. diff --git a/src/expr/core/src/window_function/kind.rs b/src/expr/core/src/window_function/kind.rs index 32c5f746020d0..9abea9b8655a4 100644 --- a/src/expr/core/src/window_function/kind.rs +++ b/src/expr/core/src/window_function/kind.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::Context; use parse_display::{Display, FromStr}; use risingwave_common::bail; @@ -51,11 +52,12 @@ impl WindowFuncKind { Ok(PbGeneralType::Lead) => Self::Lead, Err(_) => bail!("no such window function type"), }, - PbType::Aggregate(agg_type) => match PbAggKind::try_from(*agg_type) { - // TODO(runji): support UDAF and wrapped scalar functions - Ok(agg_type) => Self::Aggregate(AggType::from_protobuf(agg_type, None, None)?), - Err(_) => bail!("no such aggregate function type"), - }, + PbType::Aggregate(kind) => Self::Aggregate(AggType::from_protobuf_flatten( + PbAggKind::try_from(*kind).context("no such aggregate function type")?, + None, + None, + )?), + PbType::Aggregate2(agg_type) => Self::Aggregate(AggType::from_protobuf(agg_type)?), }; Ok(kind) } diff --git a/src/expr/impl/src/window_function/aggregate.rs b/src/expr/impl/src/window_function/aggregate.rs index 9a30d103ade1a..c173ae8c82995 100644 --- a/src/expr/impl/src/window_function/aggregate.rs +++ b/src/expr/impl/src/window_function/aggregate.rs @@ -21,7 +21,8 @@ use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::{bail, must_match}; use risingwave_common_estimate_size::{EstimateSize, KvSize}; use risingwave_expr::aggregate::{ - AggCall, AggType, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction, + build_append_only, AggCall, AggType, AggregateFunction, AggregateState as AggImplState, + BoxedAggregateFunction, }; use risingwave_expr::sig::FUNCTION_REGISTRY; use risingwave_expr::window_function::{ @@ -63,19 +64,34 @@ pub(super) fn new(call: &WindowFuncCall) -> Result { distinct: false, direct_args: vec![], }; - // TODO(runji): support UDAF and wrapped scalar function - let agg_kind = must_match!(agg_type, AggType::Builtin(agg_kind) => agg_kind); - let agg_func_sig = FUNCTION_REGISTRY - .get(*agg_kind, &arg_data_types, &call.return_type) - .expect("the agg func must exist"); - let agg_func = agg_func_sig.build_aggregate(&agg_call)?; - let (agg_impl, enable_delta) = - if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() { - let init_state = agg_func.create_state()?; - (AggImpl::Incremental(init_state), true) - } else { - (AggImpl::Full, false) - }; + + let (agg_func, agg_impl, enable_delta) = match agg_type { + AggType::Builtin(kind) => { + let agg_func_sig = FUNCTION_REGISTRY + .get(*kind, &arg_data_types, &call.return_type) + .expect("the agg func must exist"); + let agg_func = agg_func_sig.build_aggregate(&agg_call)?; + let (agg_impl, enable_delta) = + if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() { + let init_state = agg_func.create_state()?; + (AggImpl::Incremental(init_state), true) + } else { + (AggImpl::Full, false) + }; + (agg_func, agg_impl, enable_delta) + } + AggType::UserDefined(_) => { + // TODO(rc): utilize `retract` method of embedded UDAF to do incremental aggregation + let agg_func = build_append_only(&agg_call)?; + (agg_func, AggImpl::Full, false) + } + AggType::WrapScalar(_) => { + // we have to feed the wrapped scalar function with all the rows in the window, + // instead of doing incremental aggregation + let agg_func = build_append_only(&agg_call)?; + (agg_func, AggImpl::Full, false) + } + }; let this = match &call.frame.bounds { FrameBounds::Rows(frame_bounds) => Box::new(AggregateState { diff --git a/src/frontend/src/binder/expr/function/mod.rs b/src/frontend/src/binder/expr/function/mod.rs index 5d3dfb79300d2..00f2438cb35af 100644 --- a/src/frontend/src/binder/expr/function/mod.rs +++ b/src/frontend/src/binder/expr/function/mod.rs @@ -227,8 +227,8 @@ impl Binder { None }; - let agg_type = if let Some(wrapped_agg_type) = wrapped_agg_type { - Some(wrapped_agg_type) + let agg_type = if wrapped_agg_type.is_some() { + wrapped_agg_type } else if let Some(ref udf) = udf && udf.kind.is_aggregate() { diff --git a/src/frontend/src/optimizer/plan_node/generic/over_window.rs b/src/frontend/src/optimizer/plan_node/generic/over_window.rs index 5622d1e8952cf..fc10df60421bb 100644 --- a/src/frontend/src/optimizer/plan_node/generic/over_window.rs +++ b/src/frontend/src/optimizer/plan_node/generic/over_window.rs @@ -121,7 +121,7 @@ impl PlanWindowFunction { DenseRank => PbType::General(PbGeneralType::DenseRank as _), Lag => PbType::General(PbGeneralType::Lag as _), Lead => PbType::General(PbGeneralType::Lead as _), - Aggregate(agg_type) => PbType::Aggregate(agg_type.to_protobuf_simple() as _), + Aggregate(agg_type) => PbType::Aggregate2(agg_type.to_protobuf()), }; PbWindowFunction {