From 31016e9e3fd5d27d5003588920c3746ab3d46ac9 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 26 Mar 2024 15:11:46 -0700
Subject: [PATCH] fix: Remove redundant data copy in columnar shuffle

---
 .../shuffle/CometShuffleExchangeExec.scala | 200 +++++++++++++++++-
 1 file changed, 189 insertions(+), 11 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 47e6dc70c1..8d2c846c82 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -21,17 +21,23 @@ package org.apache.spark.sql.comet.execution.shuffle
 
 import java.nio.{ByteBuffer, ByteOrder}
 import java.nio.file.{Files, Paths}
+import java.util.function.Supplier
 
 import scala.collection.JavaConverters.asJavaIterableConverter
+import scala.collection.mutable
 import scala.concurrent.Future
 
 import org.apache.spark._
+import org.apache.spark.internal.config
 import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
+import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan}
@@ -39,8 +45,12 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.MutablePair
+import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
+import org.apache.spark.util.random.XORShiftRandom
 
 import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde}
 import org.apache.comet.serde.OperatorOuterClass.Operator
@@ -208,6 +218,49 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
     dependency
   }
 
+  /**
+   * This is copied from Spark `ShuffleExchangeExec.needToCopyObjectsBeforeShuffle`. The only
+   * difference is that we use `CometShuffleManager` instead of `SortShuffleManager`.
+   */
+  private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): Boolean = {
+    // Note: even though we only use the partitioner's `numPartitions` field, we require it to be
+    // passed instead of directly passing the number of partitions in order to guard against
+    // corner-cases where a partitioner constructed with `numPartitions` partitions may output
+    // fewer partitions (like RangePartitioner, for example).
+    val conf = SparkEnv.get.conf
+    val shuffleManager = SparkEnv.get.shuffleManager
+    val sortBasedShuffleOn = shuffleManager.isInstanceOf[CometShuffleManager]
+    val bypassMergeThreshold = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
+    val numParts = partitioner.numPartitions
+    if (sortBasedShuffleOn) {
+      if (numParts <= bypassMergeThreshold) {
+        // If we're using the original SortShuffleManager and the number of output partitions is
+        // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which
+        // doesn't buffer deserialized records.
+        // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
+        false
+      } else if (numParts <= SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
+        // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records
+        // prior to sorting them. This optimization is only applied in cases where shuffle
+        // dependency does not specify an aggregator or ordering and the record serializer has
+        // certain properties and the number of partitions doesn't exceed the limitation. If this
+        // optimization is enabled, we can safely avoid the copy.
+        //
+        // Exchange never configures its ShuffledRDDs with aggregators or key orderings, and the
+        // serializer in Spark SQL always satisfies the properties, so we only need to check whether
+        // the number of partitions exceeds the limitation.
+        false
+      } else {
+        // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must
+        // copy.
+        true
+      }
+    } else {
+      // Catch-all case to safely handle any future ShuffleManager implementations.
+      true
+    }
+  }
+
   /**
    * Returns a [[ShuffleDependency]] that will partition rows of its child based on the
    * partitioning scheme defined in `newPartitioning`. Those partitions of the returned
@@ -219,21 +272,146 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
       newPartitioning: Partitioning,
       serializer: Serializer,
       writeMetrics: Map[String, SQLMetric]): ShuffleDependency[Int, InternalRow, InternalRow] = {
-    val sparkShuffleDep = ShuffleExchangeExec.prepareShuffleDependency(
-      rdd,
-      outputAttributes,
-      newPartitioning,
-      serializer,
-      writeMetrics)
+    val part: Partitioner = newPartitioning match {
+      case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
+      case HashPartitioning(_, n) =>
+        // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
+        // `HashPartitioning.partitionIdExpression` to produce the partitioning key.
+        new PartitionIdPassthrough(n)
+      case RangePartitioning(sortingExpressions, numPartitions) =>
+        // Extract only the fields used for sorting, to avoid collecting large fields that do not
+        // affect the sorting result when deciding partition bounds in RangePartitioner.
+        val rddForSampling = rdd.mapPartitionsInternal { iter =>
+          val projection =
+            UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
+          val mutablePair = new MutablePair[InternalRow, Null]()
+          // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+          // partition bounds. To get accurate samples, we need to copy the mutable keys.
+          iter.map(row => mutablePair.update(projection(row).copy(), null))
+        }
+        // Construct an ordering on the extracted sort key.
+        val orderingAttributes = sortingExpressions.zipWithIndex.map { case (ord, i) =>
+          ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
+        }
+        implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
+        new RangePartitioner(
+          numPartitions,
+          rddForSampling,
+          ascending = true,
+          samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
+      case SinglePartition => new ConstantPartitioner
+      case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
+      // TODO: Handle BroadcastPartitioning.
+    }
+    def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
+      case RoundRobinPartitioning(numPartitions) =>
+        // Distributes elements evenly across output partitions, starting from a random partition.
+        // nextInt(numPartitions) implementation has a special case when bound is a power of 2,
+        // which is basically taking several highest bits from the initial seed, with only a
+        // minimal scrambling. Due to deterministic seed, using the generator only once,
+        // and lack of scrambling, the position values for power-of-two numPartitions always
+        // end up being almost the same regardless of the index. Substantially scrambling the
+        // seed by hashing will help. Refer to SPARK-21782 for more details.
+        val partitionId = TaskContext.get().partitionId()
+        var position = new XORShiftRandom(partitionId).nextInt(numPartitions)
+        (row: InternalRow) => {
+          // The HashPartitioner will handle the `mod` by the number of partitions.
+          position += 1
+          position
+        }
+      case h: HashPartitioning =>
+        val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
+        row => projection(row).getInt(0)
+      case RangePartitioning(sortingExpressions, _) =>
+        val projection =
+          UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
+        row => projection(row)
+      case SinglePartition => identity
+      case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
+    }
+    val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
+      newPartitioning.numPartitions > 1
+
+    val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
+      // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic,
+      // otherwise a retry task may output different rows and thus lead to data loss.
+      //
+      // Currently we follow the most straightforward way, performing a local sort before
+      // partitioning.
+      //
+      // Note that we don't perform a local sort if the new partitioning has only 1 partition, in
+      // which case all output rows go to the same partition.
+      val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
+        rdd.mapPartitionsInternal { iter =>
+          val recordComparatorSupplier = new Supplier[RecordComparator] {
+            override def get: RecordComparator = new RecordBinaryComparator()
+          }
+          // The comparator for comparing row hashcodes, which should always be integers.
+          val prefixComparator = PrefixComparators.LONG
+
+          // The prefix computer generates the row hashcode as the prefix, so we may decrease the
+          // probability that the prefixes are equal when input rows choose column values from a
+          // limited range.
+          val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+            private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+            override def computePrefix(
+                row: InternalRow): UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+              // The hashcode generated from the binary form of an [[UnsafeRow]] should not be null.
+              result.isNull = false
+              result.value = row.hashCode()
+              result
+            }
+          }
+          val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+
+          val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
+            StructType.fromAttributes(outputAttributes),
+            recordComparatorSupplier,
+            prefixComparator,
+            prefixComputer,
+            pageSize,
+            // We are comparing binary here, which does not support radix sort.
+            // See more details in SPARK-28699.
+            false)
+          sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+        }
+      } else {
+        rdd
+      }
+
+      // The round-robin function is order sensitive if we don't sort the input.
+      val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
+      if (CometShuffleExchangeExec.needToCopyObjectsBeforeShuffle(part)) {
+        newRdd.mapPartitionsWithIndexInternal(
+          (_, iter) => {
+            val getPartitionKey = getPartitionKeyExtractor()
+            iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
+          },
+          isOrderSensitive = isOrderSensitive)
+      } else {
+        newRdd.mapPartitionsWithIndexInternal(
+          (_, iter) => {
+            val getPartitionKey = getPartitionKeyExtractor()
+            val mutablePair = new MutablePair[Int, InternalRow]()
+            iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
+          },
+          isOrderSensitive = isOrderSensitive)
+      }
+    }
+
+    // Now, we manually create a ShuffleDependency, because pairs in rddWithPartitionIds
+    // are in the form of (partitionId, row) and every partitionId is in the expected range
+    // [0, part.numPartitions - 1]. The partitioner of this RDD is a PartitionIdPassthrough.
     val dependency = new CometShuffleDependency[Int, InternalRow, InternalRow](
-      sparkShuffleDep.rdd,
-      sparkShuffleDep.partitioner,
-      sparkShuffleDep.serializer,
-      shuffleWriterProcessor = sparkShuffleDep.shuffleWriterProcessor,
+      rddWithPartitionIds,
+      new PartitionIdPassthrough(part.numPartitions),
+      serializer,
+      shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
       shuffleType = CometColumnarShuffle,
       schema = Some(StructType.fromAttributes(outputAttributes)))
+
     dependency
   }
 }
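
Why the `row.copy()` branch exists: Spark's `InternalRow` iterators typically reuse a single mutable row object, so any shuffle path for which `needToCopyObjectsBeforeShuffle` returns `true` (i.e. one that buffers deserialized records) must snapshot each row before buffering it. The following is a minimal, dependency-free Scala sketch of that reuse hazard; `MutableRow` and the sample values are illustrative and not part of the patch.

// Illustrative only: MutableRow mimics how Spark reuses one mutable row object per iterator.
final class MutableRow(var value: Int) {
  def copy(): MutableRow = new MutableRow(value)
  override def toString: String = s"MutableRow($value)"
}

object CopyBeforeBufferSketch {
  def main(args: Array[String]): Unit = {
    // The producer mutates and returns the same object for every element.
    def rows(reused: MutableRow): Iterator[MutableRow] =
      Iterator(1, 2, 3).map { v => reused.value = v; reused }

    // Buffering without a copy keeps three references to one object: prints MutableRow(3) three times.
    println(rows(new MutableRow(0)).toList)
    // Copying before buffering (what the copy branch guards against) preserves 1, 2, 3.
    println(rows(new MutableRow(0)).map(_.copy()).toList)
  }
}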
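
The round-robin key extractor in the patch seeds a per-task starting position from the task's partition id (via `XORShiftRandom`, which scrambles the seed; see SPARK-21782) and then simply increments per row, relying on the downstream `HashPartitioner` to fold the value back into the partition range. A standalone sketch of that scheme, with `scala.util.Random` standing in for `XORShiftRandom` and all names being illustrative:

import scala.util.Random

// Illustrative only: mirrors the shape of the round-robin key extractor, with scala.util.Random
// standing in for Spark's XORShiftRandom.
object RoundRobinSketch {
  def roundRobinKeys(taskPartitionId: Int, numPartitions: Int, numRows: Int): Seq[Int] = {
    // Per-task starting offset derived from the task's partition id (deterministic seed).
    var position = new Random(taskPartitionId.toLong).nextInt(numPartitions)
    (0 until numRows).map { _ =>
      position += 1
      // The downstream HashPartitioner effectively reduces the key modulo numPartitions.
      position % numPartitions
    }
  }

  def main(args: Array[String]): Unit = {
    // Two map tasks start at different offsets but both cycle evenly over the 4 partitions.
    println(roundRobinKeys(taskPartitionId = 0, numPartitions = 4, numRows = 8))
    println(roundRobinKeys(taskPartitionId = 1, numPartitions = 4, numRows = 8))
  }
}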
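
The SPARK-23207 comment is about determinism under task retries: a retried round-robin task must emit the same (partition, row) pairs, which is why the input is locally sorted when `SQLConf.get.sortBeforeRepartition` is enabled. A dependency-free sketch of why a stable pre-sort restores that determinism; plain string sorting stands in for the hashcode-prefixed binary-row sort, and the data is illustrative.

// Illustrative only: shows why the SPARK-23207 pre-sort makes round-robin output deterministic
// across task retries.
object SortBeforeRepartitionSketch {
  def roundRobin(rows: Seq[String], numPartitions: Int): Seq[(Int, String)] =
    rows.zipWithIndex.map { case (row, i) => (i % numPartitions, row) }

  def main(args: Array[String]): Unit = {
    // A retried task may see the same rows in a different order.
    val attempt1 = Seq("b", "a", "c")
    val attempt2 = Seq("c", "b", "a")

    // Without sorting, a row can land in a different partition on retry: prints false.
    println(roundRobin(attempt1, 2) == roundRobin(attempt2, 2))
    // With a stable sort first, both attempts emit identical (partition, row) pairs: prints true.
    println(roundRobin(attempt1.sorted, 2) == roundRobin(attempt2.sorted, 2))
  }
}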