fix: Remove redundant data copy in columnar shuffle #233

Merged · 2 commits · Mar 27, 2024
@@ -21,26 +21,34 @@ package org.apache.spark.sql.comet.execution.shuffle

import java.nio.{ByteBuffer, ByteOrder}
import java.nio.file.{Files, Paths}
import java.util.function.Supplier

import scala.collection.JavaConverters.asJavaIterableConverter
import scala.concurrent.Future

import org.apache.spark._
import org.apache.spark.internal.config
import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
import org.apache.spark.util.random.XORShiftRandom

import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde}
import org.apache.comet.serde.OperatorOuterClass.Operator
@@ -208,6 +216,50 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
dependency
}

/**
* This is copied from Spark `ShuffleExchangeExec.needToCopyObjectsBeforeShuffle`. The only
* difference is that we check for `CometShuffleManager` instead of `SortShuffleManager`.
*/
private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): Boolean = {
@viirya (Member, Author) commented on Mar 26, 2024:

Spark's needToCopyObjectsBeforeShuffle returns true if it finds that the shuffle manager is not Spark's SortShuffleManager. So it incurs additional row copying in all cases, which is redundant for Comet and increases memory usage (doubling the input size).

@viirya (Member, Author):

This could also improve shuffle time for columnar shuffle, as it removes an expensive copying operation.

A contributor asked:

So not a memory leak, just double the memory usage?

@viirya (Member, Author):

Yeah, I think so.

A reviewer asked:

[a bit orthogonal] How did you find it? By profiling, or by analyzing the shuffle size? What was the scale of TPC-DS (3 TB)? Curious.

@viirya (Member, Author):

I did memory profiling when running TPC-DS queries locally. Because it is a local benchmark, the scale is small. The issue is not specific to larger scales, so it can be profiled at a small scale too.

// Note: even though we only use the partitioner's `numPartitions` field, we require it to be
// passed instead of directly passing the number of partitions in order to guard against
// corner-cases where a partitioner constructed with `numPartitions` partitions may output
// fewer partitions (like RangePartitioner, for example).
val conf = SparkEnv.get.conf
val shuffleManager = SparkEnv.get.shuffleManager
val sortBasedShuffleOn = shuffleManager.isInstanceOf[CometShuffleManager]
val bypassMergeThreshold = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
val numParts = partitioner.numPartitions
if (sortBasedShuffleOn) {
if (numParts <= bypassMergeThreshold) {
// If we're using the original SortShuffleManager and the number of output partitions is
// sufficiently small, then Spark will fall back to the hash-based shuffle write path, which
// doesn't buffer deserialized records.
// Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
false
} else if (numParts <= SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
// SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records
// prior to sorting them. This optimization is only applied in cases where shuffle
// dependency does not specify an aggregator or ordering and the record serializer has
// certain properties and the number of partitions doesn't exceed the limitation. If this
// optimization is enabled, we can safely avoid the copy.
//
// Exchange never configures its ShuffledRDDs with aggregators or key orderings, and the
// serializer in Spark SQL always satisfy the properties, so we only need to check whether
// the number of partitions exceeds the limitation.
false
} else {
// This is different from Spark `SortShuffleManager`.
// Comet doesn't use Spark's `ExternalSorter` to buffer records in memory, so we don't need to
// copy.
false
@viirya (Member, Author) commented on Mar 26, 2024:

Basically, Comet shuffle doesn't go through the code paths that require additional copying of row objects the way Spark does, so all of these branches return false. I keep them anyway to make the reasoning clear.

}
} else {
// Catch-all case to safely handle any future ShuffleManager implementations.
true
}
}
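
For context on the discussion above (this sketch is not part of the PR): Spark forces the row copy because operators such as `UnsafeProjection` reuse a single mutable output row, so buffering references without copying makes every buffered entry point at the same object. A minimal, self-contained illustration using only stock Catalyst classes:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{DataType, LongType}

object RowReuseSketch {
  def main(args: Array[String]): Unit = {
    // UnsafeProjection writes every result into one reused output buffer.
    val proj = UnsafeProjection.create(Array[DataType](LongType))

    // Buffering the projected rows WITHOUT copying: all three array slots
    // reference the same mutable UnsafeRow, so they all show the last value.
    val noCopy = (1L to 3L).iterator.map(i => proj(InternalRow(i))).toArray
    println(noCopy.map(_.getLong(0)).mkString(","))   // 3,3,3

    // Buffering WITH the copy that needToCopyObjectsBeforeShuffle forces:
    // each element is an independent row, so the original values survive.
    val withCopy = (1L to 3L).iterator.map(i => proj(InternalRow(i)).copy()).toArray
    println(withCopy.map(_.getLong(0)).mkString(",")) // 1,2,3
  }
}
```

Per the comments above, Comet's columnar shuffle does not buffer these mutable rows (it does not go through Spark's `ExternalSorter` path), which is why the copy, and the doubled memory, can be dropped here.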

/**
* Returns a [[ShuffleDependency]] that will partition rows of its child based on the
* partitioning scheme defined in `newPartitioning`. Those partitions of the returned
@@ -219,21 +271,146 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
newPartitioning: Partitioning,
serializer: Serializer,
writeMetrics: Map[String, SQLMetric]): ShuffleDependency[Int, InternalRow, InternalRow] = {
val sparkShuffleDep = ShuffleExchangeExec.prepareShuffleDependency(
rdd,
outputAttributes,
newPartitioning,
serializer,
writeMetrics)
val part: Partitioner = newPartitioning match {
A contributor asked:

Is there a difference between this and calling ShuffleExchangeExec.prepareShuffleDependency directly?

@viirya (Member, Author):

We called ShuffleExchangeExec.prepareShuffleDependency before. But now we need to override needToCopyObjectsBeforeShuffle, which is a private method in Spark, so we have to copy the code of ShuffleExchangeExec.prepareShuffleDependency here.

case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
case HashPartitioning(_, n) =>
// For HashPartitioning, the partitioning key is already a valid partition ID, as we use
// `HashPartitioning.partitionIdExpression` to produce partitioning key.
new PartitionIdPassthrough(n)
case RangePartitioning(sortingExpressions, numPartitions) =>
// Extract only the fields used for sorting, to avoid collecting large fields that do not
// affect the sorting result when deciding partition bounds in RangePartitioner.
val rddForSampling = rdd.mapPartitionsInternal { iter =>
val projection =
UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
val mutablePair = new MutablePair[InternalRow, Null]()
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
iter.map(row => mutablePair.update(projection(row).copy(), null))
}
// Construct ordering on extracted sort key.
val orderingAttributes = sortingExpressions.zipWithIndex.map { case (ord, i) =>
ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
}
implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
new RangePartitioner(
numPartitions,
rddForSampling,
ascending = true,
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
case SinglePartition => new ConstantPartitioner
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
// TODO: Handle BroadcastPartitioning.
}
def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
case RoundRobinPartitioning(numPartitions) =>
// Distributes elements evenly across output partitions, starting from a random partition.
// nextInt(numPartitions) implementation has a special case when bound is a power of 2,
// which is basically taking several highest bits from the initial seed, with only
// minimal scrambling. Due to the deterministic seed, using the generator only once,
// and the lack of scrambling, the position values for power-of-two numPartitions always
// end up being almost the same regardless of the index. Substantially scrambling the
// seed by hashing will help. Refer to SPARK-21782 for more details.
val partitionId = TaskContext.get().partitionId()
var position = new XORShiftRandom(partitionId).nextInt(numPartitions)
(row: InternalRow) => {
// The HashPartitioner will handle the `mod` by the number of partitions
position += 1
position
}
case h: HashPartitioning =>
val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
row => projection(row).getInt(0)
case RangePartitioning(sortingExpressions, _) =>
val projection =
UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
}

val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
newPartitioning.numPartitions > 1

val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
// [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic,
// otherwise a retry task may output different rows and thus lead to data loss.
//
// Currently we follow the most straightforward approach: perform a local sort before
// partitioning.
//
// Note that we don't perform a local sort if the new partitioning has only 1 partition, in
// which case all output rows go to the same partition.
val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
rdd.mapPartitionsInternal { iter =>
val recordComparatorSupplier = new Supplier[RecordComparator] {
override def get: RecordComparator = new RecordBinaryComparator()
}
// The comparator for comparing row hashcode, which should always be Integer.
val prefixComparator = PrefixComparators.LONG

// The prefix computer generates row hashcode as the prefix, so we may decrease the
// probability that the prefixes are equal when input rows choose column values from a
// limited range.
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
override def computePrefix(
row: InternalRow): UnsafeExternalRowSorter.PrefixComputer.Prefix = {
// The hashcode generated from the binary form of a [[UnsafeRow]] should not be null.
result.isNull = false
result.value = row.hashCode()
result
}
}
val pageSize = SparkEnv.get.memoryManager.pageSizeBytes

val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
StructType.fromAttributes(outputAttributes),
recordComparatorSupplier,
prefixComparator,
prefixComputer,
pageSize,
// We are comparing binary here, which does not support radix sort.
// See more details in SPARK-28699.
false)
sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
}
} else {
rdd
}

// round-robin function is order sensitive if we don't sort the input.
val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
if (CometShuffleExchangeExec.needToCopyObjectsBeforeShuffle(part)) {
newRdd.mapPartitionsWithIndexInternal(
(_, iter) => {
val getPartitionKey = getPartitionKeyExtractor()
iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
},
isOrderSensitive = isOrderSensitive)
} else {
newRdd.mapPartitionsWithIndexInternal(
(_, iter) => {
val getPartitionKey = getPartitionKeyExtractor()
val mutablePair = new MutablePair[Int, InternalRow]()
iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
},
isOrderSensitive = isOrderSensitive)
}
}

// Now, we manually create a ShuffleDependency, because pairs in rddWithPartitionIds
// are in the form of (partitionId, row) and every partitionId is in the expected range
// [0, part.numPartitions - 1]. The partitioner used here is a PartitionIdPassthrough.
val dependency =
new CometShuffleDependency[Int, InternalRow, InternalRow](
sparkShuffleDep.rdd,
sparkShuffleDep.partitioner,
sparkShuffleDep.serializer,
shuffleWriterProcessor = sparkShuffleDep.shuffleWriterProcessor,
rddWithPartitionIds,
new PartitionIdPassthrough(part.numPartitions),
serializer,
shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
shuffleType = CometColumnarShuffle,
schema = Some(StructType.fromAttributes(outputAttributes)))

dependency
}
}
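
A side note, not from the PR: the branches in needToCopyObjectsBeforeShuffle and prepareShuffleDependency consult several Spark settings. The sketch below maps them to the user-facing config keys as I understand them from upstream Spark; the key names and defaults are assumptions to verify against your Spark version.

```scala
import org.apache.spark.sql.SparkSession

object ShuffleConfigSketch {
  def main(args: Array[String]): Unit = {
    // Values shown are the upstream defaults I am aware of (200, true, 100).
    val spark = SparkSession.builder()
      .appName("comet-shuffle-config-sketch")
      .master("local[*]")
      // Threshold consulted by the first branch of needToCopyObjectsBeforeShuffle.
      .config("spark.shuffle.sort.bypassMergeThreshold", "200")
      // Controls the SPARK-23207 local sort before round-robin repartitioning.
      .config("spark.sql.execution.sortBeforeRepartition", "true")
      // Sample size per partition used when building the RangePartitioner.
      .config("spark.sql.execution.rangeExchange.sampleSizePerPartition", "100")
      .getOrCreate()

    spark.stop()
  }
}
```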
@@ -379,3 +556,18 @@ class CometShuffleWriteProcessor(
}
}
}

/**
* Copied from Spark `PartitionIdPassthrough` as it is private in Spark 3.2.
*/
private[spark] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

/**
* Copied from Spark `ConstantPartitioner` as it doesn't exist in Spark 3.2.
*/
private[spark] class ConstantPartitioner extends Partitioner {
override def numPartitions: Int = 1
override def getPartition(key: Any): Int = 0
}
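
For illustration only (not part of this PR): PartitionIdPassthrough works for HashPartitioning because, as the comment in prepareShuffleDependency notes, the partition key is produced by HashPartitioning.partitionIdExpression, which in the Spark sources I'm aware of expands to roughly `Pmod(new Murmur3Hash(expressions), Literal(numPartitions))`. Treat that expansion as an assumption; the sketch evaluates an equivalent expression to show the key is already a valid partition id.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, Murmur3Hash, Pmod}
import org.apache.spark.sql.types.LongType

object PartitionIdPassthroughSketch {
  def main(args: Array[String]): Unit = {
    val numPartitions = 8
    // Roughly what HashPartitioning.partitionIdExpression produces for a single
    // long column: pmod(murmur3hash(col), numPartitions).
    val partitionIdExpr = Pmod(
      new Murmur3Hash(Seq(BoundReference(0, LongType, nullable = false))),
      Literal(numPartitions))

    (1L to 5L).foreach { v =>
      val pid = partitionIdExpr.eval(InternalRow(v)).asInstanceOf[Int]
      // The key is already in [0, numPartitions), so a pass-through partitioner
      // such as PartitionIdPassthrough can return it unchanged.
      assert(pid >= 0 && pid < numPartitions)
      println(s"row value $v -> partition $pid")
    }
  }
}
```

This also mirrors why the RoundRobinPartitioning case needs a real HashPartitioner instead: its keys are raw positions, not precomputed partition ids, so the partitioner still has to take the modulo.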
@@ -950,6 +950,7 @@ class CometShuffleSuite extends CometColumnarShuffleSuite {
.filter($"a" > 4)
.repartition(10)
.sortWithinPartitions($"a")
.filter($"a" >= 10)
checkSparkAnswerAndOperator(shuffled, classOf[ShuffleQueryStageExec])
}
}