diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
index b9f4dab39..cc8c04fdd 100644
--- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
+++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
@@ -175,10 +175,6 @@ protected long doSpilling(
     long[] addresses = rowPartition.getRowAddresses();
     int[] sizes = rowPartition.getRowSizes();
 
-    // We exported the addresses and sizes, reset the row partition
-    // to release the memory as soon as possible.
-    rowPartition.reset();
-
     boolean checksumEnabled = checksum != -1;
     long currentChecksum = checksumEnabled ? checksum : 0L;
 
@@ -199,6 +195,8 @@ protected long doSpilling(
     long written = results[0];
    checksum = results[1];
 
+    rowPartition.reset();
+
     // Update metrics
     // Other threads may be updating the metrics at the same time, so we need to synchronize it.
     synchronized (writeMetricsToUse) {
diff --git a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
index 4c75f118a..32d64fad4 100644
--- a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
+++ b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
@@ -32,9 +32,17 @@ class RowPartition(initialSize: Int) {
 
   def getNumRows: Int = rowAddresses.size
 
-  def getRowAddresses: Array[Long] = rowAddresses.toArray
+  def getRowAddresses: Array[Long] = {
+    val array = rowAddresses.toArray
+    rowAddresses = null
+    array
+  }
 
-  def getRowSizes: Array[Int] = rowSizes.toArray
+  def getRowSizes: Array[Int] = {
+    val array = rowSizes.toArray
+    rowSizes = null
+    array
+  }
 
   def reset(): Unit = {
     rowAddresses = new ArrayBuffer[Long](initialSize)
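
For illustration, a minimal sketch of the ownership hand-off the RowPartition change adopts. The Rows class, its add/export methods, and the Demo object below are hypothetical stand-ins, not the real RowPartition API: the getter snapshots the buffer into an array and nulls the field, so the buffer becomes GC-eligible while the caller still holds the exported array, and reset() reallocates only after the consumer (in the patch, the spill write in SpillWriter.doSpilling) has finished with it.

import scala.collection.mutable.ArrayBuffer

// Hypothetical, simplified stand-in for RowPartition illustrating the hand-off:
// export() returns a snapshot array and drops the internal buffer, so the
// buffer can be collected while the caller still uses the exported array.
class Rows(initialSize: Int) {
  private var addresses = new ArrayBuffer[Long](initialSize)

  def add(addr: Long): Unit = addresses += addr

  def export(): Array[Long] = {
    val snapshot = addresses.toArray // caller now owns this copy
    addresses = null                 // internal buffer is GC-eligible immediately
    snapshot
  }

  // Called only after the consumer is done with the exported array,
  // mirroring the deferred rowPartition.reset() in the SpillWriter change.
  def reset(): Unit = addresses = new ArrayBuffer[Long](initialSize)
}

object Demo extends App {
  val rows = new Rows(16)
  rows.add(0x1000L)
  rows.add(0x2000L)
  val addrs = rows.export() // internal buffer is now null
  println(addrs.mkString(", "))
  rows.reset()              // safe to reuse once the consumer is done
}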