From 6fa6f0abb865f11a90b5c2f28d0ea8653ec964de Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Mon, 19 Feb 2024 13:40:21 +0800 Subject: [PATCH] fix(udf): fix "udf returned no data" error (#15076) Signed-off-by: Runji Wang --- src/expr/core/src/expr/expr_udf.rs | 61 +++++++++++------------------- src/expr/udf/src/external.rs | 40 +++++++++++--------- 2 files changed, 46 insertions(+), 55 deletions(-) diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index b162691896a4..246944ae5d9d 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -44,6 +44,7 @@ pub struct UserDefinedFunction { children: Vec, arg_types: Vec, return_type: DataType, + #[expect(dead_code)] arg_schema: Arc, imp: UdfImpl, identifier: String, @@ -72,13 +73,19 @@ impl Expression for UserDefinedFunction { } async fn eval(&self, input: &DataChunk) -> Result { - let vis = input.visibility(); + if input.cardinality() == 0 { + // early return for empty input + let mut builder = self.return_type.create_array_builder(input.capacity()); + builder.append_n_null(input.capacity()); + return Ok(builder.finish().into_ref()); + } let mut columns = Vec::with_capacity(self.children.len()); for child in &self.children { let array = child.eval(input).await?; columns.push(array); } - self.eval_inner(columns, vis).await + let chunk = DataChunk::new(columns, input.visibility().clone()); + self.eval_inner(&chunk).await } async fn eval_row(&self, input: &OwnedRow) -> Result { @@ -89,51 +96,29 @@ impl Expression for UserDefinedFunction { } let arg_row = OwnedRow::new(columns); let chunk = DataChunk::from_rows(std::slice::from_ref(&arg_row), &self.arg_types); - let arg_columns = chunk.columns().to_vec(); - let output_array = self.eval_inner(arg_columns, chunk.visibility()).await?; + let output_array = self.eval_inner(&chunk).await?; Ok(output_array.to_datum()) } } impl UserDefinedFunction { - async fn eval_inner( - &self, - columns: Vec, - vis: &risingwave_common::buffer::Bitmap, - ) -> Result { - let chunk = DataChunk::new(columns, vis.clone()); - let compacted_chunk = chunk.compact_cow(); - let compacted_columns: Vec = compacted_chunk - .columns() - .iter() - .map(|c| { - c.as_ref() - .try_into() - .expect("failed covert ArrayRef to arrow_array::ArrayRef") - }) - .collect(); - let opts = arrow_array::RecordBatchOptions::default() - .with_row_count(Some(compacted_chunk.capacity())); - let input = arrow_array::RecordBatch::try_new_with_options( - self.arg_schema.clone(), - compacted_columns, - &opts, - ) - .expect("failed to build record batch"); + async fn eval_inner(&self, input: &DataChunk) -> Result { + // this will drop invisible rows + let arrow_input = arrow_array::RecordBatch::try_from(input)?; - let output: arrow_array::RecordBatch = match &self.imp { - UdfImpl::Wasm(runtime) => runtime.call(&self.identifier, &input)?, - UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &input)?, + let arrow_output: arrow_array::RecordBatch = match &self.imp { + UdfImpl::Wasm(runtime) => runtime.call(&self.identifier, &arrow_input)?, + UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &arrow_input)?, UdfImpl::External(client) => { let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed); let result = if disable_retry_count != 0 { client - .call(&self.identifier, input) + .call(&self.identifier, arrow_input) .instrument_await(self.span.clone()) .await } else { client - .call_with_retry(&self.identifier, input) + .call_with_retry(&self.identifier, arrow_input) .instrument_await(self.span.clone()) .await }; @@ -155,16 +140,16 @@ impl UserDefinedFunction { result? } }; - if output.num_rows() != vis.count_ones() { + if arrow_output.num_rows() != input.cardinality() { bail!( "UDF returned {} rows, but expected {}", - output.num_rows(), - vis.len(), + arrow_output.num_rows(), + input.cardinality(), ); } - let data_chunk = DataChunk::try_from(&output)?; - let output = data_chunk.uncompact(vis.clone()); + let output = DataChunk::try_from(&arrow_output)?; + let output = output.uncompact(input.visibility().clone()); let Some(array) = output.columns().first() else { bail!("UDF returned no columns"); diff --git a/src/expr/udf/src/external.rs b/src/expr/udf/src/external.rs index f8d4cf6cc379..046d681485c3 100644 --- a/src/expr/udf/src/external.rs +++ b/src/expr/udf/src/external.rs @@ -139,21 +139,17 @@ impl ArrowFlightUdfClient { } async fn call_internal(&self, id: &str, input: RecordBatch) -> Result { - let mut output_stream = self.call_stream(id, stream::once(async { input })).await?; - // TODO: support no output - let head = output_stream - .next() - .await - .ok_or_else(Error::no_returned)??; - let remaining = output_stream.try_collect::>().await?; - if remaining.is_empty() { - Ok(head) - } else { - Ok(arrow_select::concat::concat_batches( - &head.schema(), - std::iter::once(&head).chain(remaining.iter()), - )?) + let mut output_stream = self + .call_stream_internal(id, stream::once(async { input })) + .await?; + let mut batches = vec![]; + while let Some(batch) = output_stream.next().await { + batches.push(batch?); } + Ok(arrow_select::concat::concat_batches( + output_stream.schema().ok_or_else(Error::no_returned)?, + batches.iter(), + )?) } /// Call a function, retry up to 5 times / 3s if connection is broken. @@ -179,6 +175,17 @@ impl ArrowFlightUdfClient { id: &str, inputs: impl Stream + Send + 'static, ) -> Result> + Send + 'static> { + Ok(self + .call_stream_internal(id, inputs) + .await? + .map_err(|e| e.into())) + } + + async fn call_stream_internal( + &self, + id: &str, + inputs: impl Stream + Send + 'static, + ) -> Result { let descriptor = FlightDescriptor::new_path(vec![id.into()]); let flight_data_stream = FlightDataEncoderBuilder::new() @@ -194,11 +201,10 @@ impl ArrowFlightUdfClient { // decode response let stream = response.into_inner(); - let record_batch_stream = FlightRecordBatchStream::new_from_flight_data( + Ok(FlightRecordBatchStream::new_from_flight_data( // convert tonic::Status to FlightError stream.map_err(|e| e.into()), - ); - Ok(record_batch_stream.map_err(|e| e.into())) + )) } }