From 4cf7a3d0f9abb2ca78dba705ac309095d8749d36 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 7 Apr 2024 10:58:48 -0700 Subject: [PATCH] More --- .../sql/comet/execution/shuffle/SpillWriter.java | 6 ++++-- .../org/apache/spark/shuffle/sort/RowPartition.scala | 12 ++---------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java index cc8c04fdd..b9f4dab39 100644 --- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java @@ -175,6 +175,10 @@ protected long doSpilling( long[] addresses = rowPartition.getRowAddresses(); int[] sizes = rowPartition.getRowSizes(); + // We exported the addresses and sizes, reset the row partition + // to release the memory as soon as possible. + rowPartition.reset(); + boolean checksumEnabled = checksum != -1; long currentChecksum = checksumEnabled ? checksum : 0L; @@ -195,8 +199,6 @@ protected long doSpilling( long written = results[0]; checksum = results[1]; - rowPartition.reset(); - // Update metrics // Other threads may be updating the metrics at the same time, so we need to synchronize it. synchronized (writeMetricsToUse) { diff --git a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala index 32d64fad4..4c75f118a 100644 --- a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala +++ b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala @@ -32,17 +32,9 @@ class RowPartition(initialSize: Int) { def getNumRows: Int = rowAddresses.size - def getRowAddresses: Array[Long] = { - val array = rowAddresses.toArray - rowAddresses = null - array - } + def getRowAddresses: Array[Long] = rowAddresses.toArray - def getRowSizes: Array[Int] = { - val array = rowSizes.toArray - rowSizes = null - array - } + def getRowSizes: Array[Int] = rowSizes.toArray def reset(): Unit = { rowAddresses = new ArrayBuffer[Long](initialSize)