From 43177964a713eb117475b77f7ba83b6b7c8149a8 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sat, 6 Apr 2024 23:29:42 -0700
Subject: [PATCH 1/3] fix: Deallocate row addresses and size arrays after exporting

---
 .../apache/spark/shuffle/sort/RowPartition.scala     | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
index bce24be1f..32d64fad4 100644
--- a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
+++ b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
@@ -32,8 +32,17 @@ class RowPartition(initialSize: Int) {

   def getNumRows: Int = rowAddresses.size

-  def getRowAddresses: Array[Long] = rowAddresses.toArray
-  def getRowSizes: Array[Int] = rowSizes.toArray
+  def getRowAddresses: Array[Long] = {
+    val array = rowAddresses.toArray
+    rowAddresses = null
+    array
+  }
+
+  def getRowSizes: Array[Int] = {
+    val array = rowSizes.toArray
+    rowSizes = null
+    array
+  }

   def reset(): Unit = {
     rowAddresses = new ArrayBuffer[Long](initialSize)

From 4cf7a3d0f9abb2ca78dba705ac309095d8749d36 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sun, 7 Apr 2024 10:58:48 -0700
Subject: [PATCH 2/3] More

---
 .../sql/comet/execution/shuffle/SpillWriter.java     |  6 ++++--
 .../org/apache/spark/shuffle/sort/RowPartition.scala | 12 ++----------
 2 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
index cc8c04fdd..b9f4dab39 100644
--- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
+++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
@@ -175,6 +175,10 @@ protected long doSpilling(
     long[] addresses = rowPartition.getRowAddresses();
     int[] sizes = rowPartition.getRowSizes();

+    // We exported the addresses and sizes, reset the row partition
+    // to release the memory as soon as possible.
+    rowPartition.reset();
+
     boolean checksumEnabled = checksum != -1;
     long currentChecksum = checksumEnabled ? checksum : 0L;

@@ -195,8 +199,6 @@ protected long doSpilling(
     long written = results[0];
     checksum = results[1];

-    rowPartition.reset();
-
     // Update metrics
     // Other threads may be updating the metrics at the same time, so we need to synchronize it.
     synchronized (writeMetricsToUse) {
diff --git a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
index 32d64fad4..4c75f118a 100644
--- a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
+++ b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
@@ -32,17 +32,9 @@ class RowPartition(initialSize: Int) {

   def getNumRows: Int = rowAddresses.size

-  def getRowAddresses: Array[Long] = {
-    val array = rowAddresses.toArray
-    rowAddresses = null
-    array
-  }
+  def getRowAddresses: Array[Long] = rowAddresses.toArray

-  def getRowSizes: Array[Int] = {
-    val array = rowSizes.toArray
-    rowSizes = null
-    array
-  }
+  def getRowSizes: Array[Int] = rowSizes.toArray

   def reset(): Unit = {
     rowAddresses = new ArrayBuffer[Long](initialSize)

From fa5d0b79cd2f5dc7f711707838589565e261945e Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sun, 7 Apr 2024 11:00:23 -0700
Subject: [PATCH 3/3] Revert "More"

This reverts commit 4cf7a3d0f9abb2ca78dba705ac309095d8749d36.
---
 .../sql/comet/execution/shuffle/SpillWriter.java     |  6 ++----
 .../org/apache/spark/shuffle/sort/RowPartition.scala | 12 ++++++++++--
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
index b9f4dab39..cc8c04fdd 100644
--- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
+++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
@@ -175,10 +175,6 @@ protected long doSpilling(
     long[] addresses = rowPartition.getRowAddresses();
     int[] sizes = rowPartition.getRowSizes();

-    // We exported the addresses and sizes, reset the row partition
-    // to release the memory as soon as possible.
-    rowPartition.reset();
-
     boolean checksumEnabled = checksum != -1;
     long currentChecksum = checksumEnabled ? checksum : 0L;

@@ -199,6 +195,8 @@ protected long doSpilling(
     long written = results[0];
     checksum = results[1];

+    rowPartition.reset();
+
     // Update metrics
     // Other threads may be updating the metrics at the same time, so we need to synchronize it.
     synchronized (writeMetricsToUse) {
diff --git a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
index 4c75f118a..32d64fad4 100644
--- a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
+++ b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala
@@ -32,9 +32,17 @@ class RowPartition(initialSize: Int) {

   def getNumRows: Int = rowAddresses.size

-  def getRowAddresses: Array[Long] = rowAddresses.toArray
+  def getRowAddresses: Array[Long] = {
+    val array = rowAddresses.toArray
+    rowAddresses = null
+    array
+  }

-  def getRowSizes: Array[Int] = rowSizes.toArray
+  def getRowSizes: Array[Int] = {
+    val array = rowSizes.toArray
+    rowSizes = null
+    array
+  }

   def reset(): Unit = {
     rowAddresses = new ArrayBuffer[Long](initialSize)
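
Net effect of the series: patch 1 lands, and patch 3 reverts patch 2, so RowPartition keeps the getters that null out their backing buffers while SpillWriter calls reset() only after the native spill results have been consumed. A minimal sketch of the resulting class follows. The field declarations and the addRow method are assumptions inferred from the diff context; the patches themselves only show getNumRows, the two getters, and reset().

    import scala.collection.mutable.ArrayBuffer

    class RowPartition(initialSize: Int) {
      // Assumed fields: the hunks reference these as mutable vars
      // that reset() re-assigns.
      private var rowAddresses: ArrayBuffer[Long] = new ArrayBuffer[Long](initialSize)
      private var rowSizes: ArrayBuffer[Int] = new ArrayBuffer[Int](initialSize)

      // Hypothetical writer-side method, not shown in the patches.
      def addRow(addr: Long, size: Int): Unit = {
        rowAddresses += addr
        rowSizes += size
      }

      def getNumRows: Int = rowAddresses.size

      // Copy the contents out, then drop the buffer reference so the
      // old buffer becomes garbage-collectible immediately. Callers
      // must invoke reset() before using the partition again.
      def getRowAddresses: Array[Long] = {
        val array = rowAddresses.toArray
        rowAddresses = null
        array
      }

      def getRowSizes: Array[Int] = {
        val array = rowSizes.toArray
        rowSizes = null
        array
      }

      def reset(): Unit = {
        rowAddresses = new ArrayBuffer[Long](initialSize)
        // Assumed: the hunks end at the rowAddresses line, but rowSizes
        // presumably needs the same re-allocation.
        rowSizes = new ArrayBuffer[Int](initialSize)
      }
    }

On the design choice visible in the diffs: patch 2 moved reset() to immediately after the export in doSpilling, per its comment "to release the memory as soon as possible", but patch 3 restores the original ordering, so reset() runs only after the spill's written-bytes and checksum results have been read.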