fix: Input batch to ShuffleRepartitioner.insert_batch should not be larger than configured batch size #523

Merged
3 commits merged on Jun 6, 2024
Changes from all commits
core/src/execution/datafusion/shuffle_writer.rs — 64 changes: 61 additions & 3 deletions
@@ -575,6 +575,8 @@ struct ShuffleRepartitioner {
hashes_buf: Vec<u32>,
/// Partition ids for each row in the current batch
partition_ids: Vec<u64>,
/// The configured batch size
batch_size: usize,
}

struct ShuffleRepartitionerMetrics {
@@ -642,17 +644,41 @@ impl ShuffleRepartitioner {
reservation,
hashes_buf,
partition_ids,
batch_size,
}
}

/// Shuffles rows in the input batch into the corresponding partition buffers.
/// This function slices the input batch according to the configured batch size
/// and then shuffles the rows of each slice into the corresponding partition buffers.
async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
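// `RecordBatch::slice` is zero-copy, so splitting a large batch does not copy row data.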
let mut start = 0;
while start < batch.num_rows() {
let end = (start + self.batch_size).min(batch.num_rows());
let batch = batch.slice(start, end - start);
self.partitioning_batch(batch).await?;
start = end;
}
Ok(())
}

/// Shuffles rows in the input batch into the corresponding partition buffers.
/// This function first computes a hash for each row, then collects the rows of each
/// partition into a record batch, which is appended to that partition's buffer.
- async fn insert_batch(&mut self, input: RecordBatch) -> Result<()> {
+ /// This should not be called directly. Use `insert_batch` instead.
+ async fn partitioning_batch(&mut self, input: RecordBatch) -> Result<()> {
if input.num_rows() == 0 {
// skip empty batch
return Ok(());
}

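// Defensive check: `insert_batch` slices its input down to `batch_size` rows,
// so a larger batch here means a caller bypassed that entry point.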
if input.num_rows() > self.batch_size {
return Err(DataFusionError::Internal(
"Input batch size exceeds configured batch size. Call `insert_batch` instead."
.to_string(),
));
}

let _timer = self.metrics.baseline.elapsed_compute().timer();

// NOTE: in shuffle writer exec, the output_rows metrics represents the
@@ -951,8 +977,7 @@ async fn external_shuffle(
);

while let Some(batch) = input.next().await {
- let batch = batch?;
- repartitioner.insert_batch(batch).await?;
+ repartitioner.insert_batch(batch?).await?;
}
repartitioner.shuffle_write().await
}
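Taken together, the change guarantees that `partitioning_batch` never sees more than `batch_size` rows at a time, while callers such as `external_shuffle` can keep handing arbitrarily large batches to `insert_batch`. A minimal standalone sketch of the slicing pattern, assuming only the `arrow` crate (`CHUNK` and `chunks` are illustrative names, not part of this PR):

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

/// Illustrative stand-in for the configured batch size.
const CHUNK: usize = 4096;

/// Split a batch into slices of at most `CHUNK` rows, mirroring the
/// loop in `insert_batch` above.
fn chunks(batch: &RecordBatch) -> Vec<RecordBatch> {
    let mut out = Vec::new();
    let mut start = 0;
    while start < batch.num_rows() {
        let end = (start + CHUNK).min(batch.num_rows());
        // `RecordBatch::slice` is zero-copy: the slice shares the input's buffers.
        out.push(batch.slice(start, end - start));
        start = end;
    }
    out
}

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let array = Int32Array::from((0..10_000).collect::<Vec<i32>>());
    let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap();

    let parts = chunks(&batch);
    assert_eq!(parts.len(), 3); // 4096 + 4096 + 1808 rows
    assert!(parts.iter().all(|p| p.num_rows() <= CHUNK));
    assert_eq!(parts.iter().map(|p| p.num_rows()).sum::<usize>(), batch.num_rows());
}
```

Because `slice` only adjusts an offset into shared buffers, the loop costs almost nothing regardless of input size; the actual copying happens later, when rows are appended to the partition buffers.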
@@ -1387,6 +1412,11 @@ impl RecordBatchStream for EmptyStream {
#[cfg(test)]
mod test {
use super::*;
use datafusion::physical_plan::common::collect;
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::prelude::SessionContext;
use datafusion_physical_expr::expressions::Column;
use tokio::runtime::Runtime;

#[test]
fn test_slot_size() {
@@ -1415,4 +1445,32 @@ mod test {
assert_eq!(slot_size, *expected);
})
}

#[test]
fn test_insert_larger_batch() {
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
let mut b = StringBuilder::new();
for i in 0..10000 {
b.append_value(format!("{i}"));
}
let array = b.finish();
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();

let mut batches = Vec::new();
batches.push(batch.clone());

let partitions = &[batches];
let exec = ShuffleWriterExec::try_new(
Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()),
Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 16),
"/tmp/data.out".to_string(),
"/tmp/index.out".to_string(),
)
.unwrap();
let ctx = SessionContext::new();
let task_ctx = ctx.task_ctx();
let stream = exec.execute(0, task_ctx).unwrap();
let rt = Runtime::new().unwrap();
rt.block_on(collect(stream)).unwrap();
}
}
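The new `test_insert_larger_batch` case drives `ShuffleWriterExec` end to end with a single 10,000-row input batch. Assuming the default DataFusion batch size of 8,192 rows, `insert_batch` splits that input into two slices, so the stream collecting successfully (rather than tripping the new `Internal` error) is the effective assertion.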