// Review notes for this patch (added ahead of the first `diff --git` header; the patch text
// below is unchanged).
//
// CometDiskBlockWriter.java
//   - Removes the import of org.apache.spark.internal.config.package$.
//   - Replaces the field `initialBufferSize` (previously assigned from
//     conf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE())) with `columnarBatchSize`,
//     assigned from CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_BATCH_SIZE().get().
//   - insertRow: inside the synchronized block, the spill branch is now taken when
//     activeWriter.numRecords() is >= numElementsForSpillThreshold OR >= columnarBatchSize.
//     The logged threshold is Math.min(numElementsForSpillThreshold, columnarBatchSize).
//   - NOTE(review): no hunk here shows other uses of `initialBufferSize` being updated —
//     verify against the full file that none remain.
//
// RowPartition.scala
//   - `rowAddresses` and `rowSizes` change from `val` to `var`.
//   - reset() no longer calls clear() on the two buffers; it assigns null to both fields and
//     then assigns new ArrayBuffers constructed with `initialSize`.
//   - NOTE(review): the null assignments are overwritten immediately afterwards; presumably
//     they are meant to drop the old buffers before the new ones are allocated — confirm intent.
//
// NOTE(review): the line breaks and indentation of this patch appear to have been collapsed;
// the hunk line counts in the `@@` headers still describe the original one-record-per-line form.
diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java index d1593f725b..ad819ccb34 100644 --- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java @@ -36,7 +36,6 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.internal.config.package$; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; @@ -102,7 +101,7 @@ public final class CometDiskBlockWriter { private final File file; private long totalWritten = 0L; private boolean initialized = false; - private final int initialBufferSize; + private final int columnarBatchSize; private final boolean isAsync; private final int asyncThreadNum; private final ExecutorService threadPool; @@ -152,8 +151,7 @@ public final class CometDiskBlockWriter { this.asyncThreadNum = asyncThreadNum; this.threadPool = threadPool; - this.initialBufferSize = - (int) (long) conf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()); + this.columnarBatchSize = (int) CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_BATCH_SIZE().get(); this.numElementsForSpillThreshold = (int) CometConf$.MODULE$.COMET_EXEC_SHUFFLE_SPILL_THRESHOLD().get(); @@ -264,10 +262,11 @@ public void insertRow(UnsafeRow row, int partitionId) throws IOException { // While proceeding with possible spilling and inserting the record, we need to synchronize // it, because other threads may be spilling this writer at the same time. 
synchronized (CometDiskBlockWriter.this) { - if (activeWriter.numRecords() >= numElementsForSpillThreshold) { + if (activeWriter.numRecords() >= numElementsForSpillThreshold + || activeWriter.numRecords() >= columnarBatchSize) { + int threshold = Math.min(numElementsForSpillThreshold, columnarBatchSize); logger.info( - "Spilling data because number of spilledRecords crossed the threshold " - + numElementsForSpillThreshold); + "Spilling data because number of spilledRecords crossed the threshold " + threshold); // Spill the current writer doSpill(false); if (activeWriter.numRecords() != 0) { diff --git a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala index 873e422fb9..d650a3f1d9 100644 --- a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala +++ b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala @@ -22,8 +22,8 @@ package org.apache.spark.shuffle.sort import scala.collection.mutable.ArrayBuffer class RowPartition(initialSize: Int) { - private val rowAddresses: ArrayBuffer[Long] = new ArrayBuffer[Long](initialSize) - private val rowSizes: ArrayBuffer[Int] = new ArrayBuffer[Int](initialSize) + private var rowAddresses: ArrayBuffer[Long] = new ArrayBuffer[Long](initialSize) + private var rowSizes: ArrayBuffer[Int] = new ArrayBuffer[Int](initialSize) def addRow(addr: Long, size: Int): Unit = { rowAddresses += addr @@ -36,7 +36,10 @@ class RowPartition(initialSize: Int) { def getRowSizes: Array[Int] = rowSizes.toArray def reset(): Unit = { - rowAddresses.clear() - rowSizes.clear() + rowAddresses = null + rowSizes = null + + rowAddresses = new ArrayBuffer[Long](initialSize) + rowSizes = new ArrayBuffer[Int](initialSize) } }