fix: Reduce RowPartition memory allocation
viirya committed Apr 6, 2024
1 parent d76c113 commit 63f5554
Showing 2 changed files with 11 additions and 12 deletions.

CometDiskBlockWriter.java

@@ -36,7 +36,6 @@
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.internal.config.package$;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.SerializerInstance;
@@ -102,7 +101,7 @@ public final class CometDiskBlockWriter {
   private final File file;
   private long totalWritten = 0L;
   private boolean initialized = false;
-  private final int initialBufferSize;
+  private final int columnarBatchSize;
   private final boolean isAsync;
   private final int asyncThreadNum;
   private final ExecutorService threadPool;
@@ -152,8 +151,7 @@ public final class CometDiskBlockWriter {
     this.asyncThreadNum = asyncThreadNum;
     this.threadPool = threadPool;

-    this.initialBufferSize =
-        (int) (long) conf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
+    this.columnarBatchSize = (int) CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_BATCH_SIZE().get();

     this.numElementsForSpillThreshold =
         (int) CometConf$.MODULE$.COMET_EXEC_SHUFFLE_SPILL_THRESHOLD().get();
@@ -264,10 +262,11 @@ public void insertRow(UnsafeRow row, int partitionId) throws IOException {
     // While proceeding with possible spilling and inserting the record, we need to synchronize
     // it, because other threads may be spilling this writer at the same time.
     synchronized (CometDiskBlockWriter.this) {
-      if (activeWriter.numRecords() >= numElementsForSpillThreshold) {
+      if (activeWriter.numRecords() >= numElementsForSpillThreshold
+          || activeWriter.numRecords() >= columnarBatchSize) {
+        int threshold = Math.min(numElementsForSpillThreshold, columnarBatchSize);
         logger.info(
-            "Spilling data because number of spilledRecords crossed the threshold "
-                + numElementsForSpillThreshold);
+            "Spilling data because number of spilledRecords crossed the threshold " + threshold);
         // Spill the current writer
         doSpill(false);
         if (activeWriter.numRecords() != 0) {
@@ -347,7 +346,7 @@ class ArrowIPCWriter extends SpillWriter {
    private final RowPartition rowPartition;

    ArrowIPCWriter() {
-     rowPartition = new RowPartition(initialBufferSize);
+     rowPartition = new RowPartition(columnarBatchSize);

      this.allocatedPages = new LinkedList<>();
      this.allocator = CometDiskBlockWriter.this.allocator;
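
With the batch size wired into the writer, a spill is triggered by whichever limit is reached first, so a RowPartition never accumulates more than columnarBatchSize rows before it is flushed. A minimal Scala sketch of the resulting trigger (the constants are illustrative; the real values come from COMET_EXEC_SHUFFLE_SPILL_THRESHOLD and COMET_COLUMNAR_SHUFFLE_BATCH_SIZE):

  val numElementsForSpillThreshold = 1024 * 1024 // illustrative value
  val columnarBatchSize = 8192                   // illustrative value

  def shouldSpill(numRecords: Int): Boolean =
    numRecords >= numElementsForSpillThreshold || numRecords >= columnarBatchSize

  // The single threshold reported in the writer's log message
  val threshold = math.min(numElementsForSpillThreshold, columnarBatchSize)
  assert(shouldSpill(threshold) && !shouldSpill(threshold - 1))

Because the batch size bounds the number of buffered rows, pre-sizing RowPartition with columnarBatchSize (rather than SHUFFLE_SORT_INIT_BUFFER_SIZE) avoids repeated buffer growth.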
RowPartition.scala

@@ -22,8 +22,8 @@ package org.apache.spark.shuffle.sort
 import scala.collection.mutable.ArrayBuffer

 class RowPartition(initialSize: Int) {
-  private val rowAddresses: ArrayBuffer[Long] = new ArrayBuffer[Long](initialSize)
-  private val rowSizes: ArrayBuffer[Int] = new ArrayBuffer[Int](initialSize)
+  private var rowAddresses: ArrayBuffer[Long] = new ArrayBuffer[Long](initialSize)
+  private var rowSizes: ArrayBuffer[Int] = new ArrayBuffer[Int](initialSize)

   def addRow(addr: Long, size: Int): Unit = {
     rowAddresses += addr
@@ -36,7 +36,7 @@ class RowPartition(initialSize: Int) {
   def getRowSizes: Array[Int] = rowSizes.toArray

   def reset(): Unit = {
-    rowAddresses.clear()
-    rowSizes.clear()
+    rowAddresses = new ArrayBuffer[Long](initialSize)
+    rowSizes = new ArrayBuffer[Int](initialSize)
   }
 }
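
The reset() change is the memory reduction the commit title refers to: ArrayBuffer.clear() empties the buffer but keeps the grown backing array, so a partition that once held many rows retains that capacity. Reassigning a fresh buffer sized to initialSize lets the old array be garbage-collected, which is also why the fields switch from val to var. A small standalone Scala sketch of the difference (illustrative names, not code from this commit):

  import scala.collection.mutable.ArrayBuffer

  var buf = new ArrayBuffer[Long](16)
  (1L to 1000000L).foreach(buf += _) // backing array grows to ~1M slots
  buf.clear()                        // length is 0, but the enlarged array is retained

  buf = new ArrayBuffer[Long](16)    // what the new reset() does: old array becomes GC-eligible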
