diff --git a/core/src/execution/datafusion/shuffle_writer.rs b/core/src/execution/datafusion/shuffle_writer.rs index 9673409795..99ac885b51 100644 --- a/core/src/execution/datafusion/shuffle_writer.rs +++ b/core/src/execution/datafusion/shuffle_writer.rs @@ -575,6 +575,8 @@ struct ShuffleRepartitioner { hashes_buf: Vec, /// Partition ids for each row in the current batch partition_ids: Vec, + /// The configured batch size + batch_size: usize, } struct ShuffleRepartitionerMetrics { @@ -642,17 +644,41 @@ impl ShuffleRepartitioner { reservation, hashes_buf, partition_ids, + batch_size, } } + /// Shuffles rows in input batch into corresponding partition buffer. + /// This function will slice input batch according to configured batch size and then + /// shuffle rows into corresponding partition buffer. + async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> { + let mut start = 0; + while start < batch.num_rows() { + let end = (start + self.batch_size).min(batch.num_rows()); + let batch = batch.slice(start, end - start); + self.partitioning_batch(batch).await?; + start = end; + } + Ok(()) + } + /// Shuffles rows in input batch into corresponding partition buffer. /// This function first calculates hashes for rows and then takes rows in same /// partition as a record batch which is appended into partition buffer. - async fn insert_batch(&mut self, input: RecordBatch) -> Result<()> { + /// This should not be called directly. Use `insert_batch` instead. + async fn partitioning_batch(&mut self, input: RecordBatch) -> Result<()> { if input.num_rows() == 0 { // skip empty batch return Ok(()); } + + if input.num_rows() > self.batch_size { + return Err(DataFusionError::Internal( + "Input batch size exceeds configured batch size. Call `insert_batch` instead." + .to_string(), + )); + } + let _timer = self.metrics.baseline.elapsed_compute().timer(); // NOTE: in shuffle writer exec, the output_rows metrics represents the @@ -951,8 +977,7 @@ async fn external_shuffle( ); while let Some(batch) = input.next().await { - let batch = batch?; - repartitioner.insert_batch(batch).await?; + repartitioner.insert_batch(batch?).await?; } repartitioner.shuffle_write().await } @@ -1387,6 +1412,11 @@ impl RecordBatchStream for EmptyStream { #[cfg(test)] mod test { use super::*; + use datafusion::physical_plan::common::collect; + use datafusion::physical_plan::memory::MemoryExec; + use datafusion::prelude::SessionContext; + use datafusion_physical_expr::expressions::Column; + use tokio::runtime::Runtime; #[test] fn test_slot_size() { @@ -1415,4 +1445,32 @@ mod test { assert_eq!(slot_size, *expected); }) } + + #[test] + fn test_insert_larger_batch() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..10000 { + b.append_value(format!("{i}")); + } + let array = b.finish(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap(); + + let mut batches = Vec::new(); + batches.push(batch.clone()); + + let partitions = &[batches]; + let exec = ShuffleWriterExec::try_new( + Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()), + Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 16), + "/tmp/data.out".to_string(), + "/tmp/index.out".to_string(), + ) + .unwrap(); + let ctx = SessionContext::new(); + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx).unwrap(); + let rt = Runtime::new().unwrap(); + rt.block_on(collect(stream)).unwrap(); + } }