From 57b3be4297a47aa45094c16e37ddf0141d723bf0 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 15 Apr 2024 17:29:45 -0700
Subject: [PATCH] fix: Specify row count in sort_batch for empty batch

---
 datafusion/physical-plan/src/sorts/sort.rs | 27 ++++++++++++++++++++--
 1 file changed, 25 insertions(+), 2 deletions(-)

diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index a6f47d3d2fc9..e74fcc40aa13 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -46,7 +46,7 @@ use arrow::datatypes::SchemaRef;
 use arrow::ipc::reader::FileReader;
 use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, SortField};
-use arrow_array::{Array, UInt32Array};
+use arrow_array::{Array, RecordBatchOptions, UInt32Array};
 use arrow_schema::DataType;
 use datafusion_common::{exec_err, DataFusionError, Result};
 use datafusion_common_runtime::SpawnedTask;
@@ -605,7 +605,12 @@ pub(crate) fn sort_batch(
         .map(|c| take(c.as_ref(), &indices, None))
         .collect::<Result<_, _>>()?;
 
-    Ok(RecordBatch::try_new(batch.schema(), columns)?)
+    let options = RecordBatchOptions::new().with_row_count(Some(indices.len()));
+    Ok(RecordBatch::try_new_with_options(
+        batch.schema(),
+        columns,
+        &options,
+    )?)
 }
 
 #[inline]
@@ -993,6 +998,8 @@ mod tests {
     use datafusion_execution::config::SessionConfig;
     use datafusion_execution::runtime_env::RuntimeConfig;
 
+    use datafusion_common::ScalarValue;
+    use datafusion_physical_expr::expressions::Literal;
     use futures::FutureExt;
 
     #[tokio::test]
@@ -1399,4 +1406,20 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_empty_sort_batch() {
+        let schema = Arc::new(Schema::empty());
+        let options = RecordBatchOptions::new().with_row_count(Some(1));
+        let batch =
+            RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
+
+        let expressions = vec![PhysicalSortExpr {
+            expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
+            options: SortOptions::default(),
+        }];
+
+        let result = sort_batch(&batch, &expressions, None).unwrap();
+        assert_eq!(result.num_rows(), 1);
+    }
 }