Skip to content

Commit

Permalink
fix: Remove redundant data copy in columnar shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 26, 2024
1 parent 0826772 commit 686e382
Showing 1 changed file with 203 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,35 @@ package org.apache.spark.sql.comet.execution.shuffle

import java.nio.{ByteBuffer, ByteOrder}
import java.nio.file.{Files, Paths}
import java.util.function.Supplier

import scala.collection.JavaConverters.asJavaIterableConverter
import scala.collection.mutable
import scala.concurrent.Future

import org.apache.spark._
import org.apache.spark.internal.config
import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
import org.apache.spark.util.random.XORShiftRandom

import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde}
import org.apache.comet.serde.OperatorOuterClass.Operator
Expand Down Expand Up @@ -208,6 +217,49 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
dependency
}

/**
* This is copied from Spark `ShuffleExchangeExec.needToCopyObjectsBeforeShuffle`. The only
* difference is that we use `BosonShuffleManager` instead of `SortShuffleManager`.
*/
private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): Boolean = {
// Note: even though we only use the partitioner's `numPartitions` field, we require it to be
// passed instead of directly passing the number of partitions in order to guard against
// corner-cases where a partitioner constructed with `numPartitions` partitions may output
// fewer partitions (like RangePartitioner, for example).
val conf = SparkEnv.get.conf
val shuffleManager = SparkEnv.get.shuffleManager
val sortBasedShuffleOn = shuffleManager.isInstanceOf[CometShuffleManager]
val bypassMergeThreshold = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
val numParts = partitioner.numPartitions
if (sortBasedShuffleOn) {
if (numParts <= bypassMergeThreshold) {
// If we're using the original SortShuffleManager and the number of output partitions is
// sufficiently small, then Spark will fall back to the hash-based shuffle write path, which
// doesn't buffer deserialized records.
// Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
false
} else if (numParts <= SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
// SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records
// prior to sorting them. This optimization is only applied in cases where shuffle
// dependency does not specify an aggregator or ordering and the record serializer has
// certain properties and the number of partitions doesn't exceed the limitation. If this
// optimization is enabled, we can safely avoid the copy.
//
// Exchange never configures its ShuffledRDDs with aggregators or key orderings, and the
// serializer in Spark SQL always satisfy the properties, so we only need to check whether
// the number of partitions exceeds the limitation.
false
} else {
// Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must
// copy.
true
}
} else {
// Catch-all case to safely handle any future ShuffleManager implementations.
true
}
}

/**
* Returns a [[ShuffleDependency]] that will partition rows of its child based on the
* partitioning scheme defined in `newPartitioning`. Those partitions of the returned
Expand All @@ -219,21 +271,146 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
newPartitioning: Partitioning,
serializer: Serializer,
writeMetrics: Map[String, SQLMetric]): ShuffleDependency[Int, InternalRow, InternalRow] = {
val sparkShuffleDep = ShuffleExchangeExec.prepareShuffleDependency(
rdd,
outputAttributes,
newPartitioning,
serializer,
writeMetrics)
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
case HashPartitioning(_, n) =>
// For HashPartitioning, the partitioning key is already a valid partition ID, as we use
// `HashPartitioning.partitionIdExpression` to produce partitioning key.
new PartitionIdPassthrough(n)
case RangePartitioning(sortingExpressions, numPartitions) =>
// Extract only fields used for sorting to avoid collecting large fields that does not
// affect sorting result when deciding partition bounds in RangePartitioner
val rddForSampling = rdd.mapPartitionsInternal { iter =>
val projection =
UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
val mutablePair = new MutablePair[InternalRow, Null]()
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
iter.map(row => mutablePair.update(projection(row).copy(), null))
}
// Construct ordering on extracted sort key.
val orderingAttributes = sortingExpressions.zipWithIndex.map { case (ord, i) =>
ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
}
implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
new RangePartitioner(
numPartitions,
rddForSampling,
ascending = true,
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
case SinglePartition => new ConstantPartitioner
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
// TODO: Handle BroadcastPartitioning.
}
def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
case RoundRobinPartitioning(numPartitions) =>
// Distributes elements evenly across output partitions, starting from a random partition.
// nextInt(numPartitions) implementation has a special case when bound is a power of 2,
// which is basically taking several highest bits from the initial seed, with only a
// minimal scrambling. Due to deterministic seed, using the generator only once,
// and lack of scrambling, the position values for power-of-two numPartitions always
// end up being almost the same regardless of the index. substantially scrambling the
// seed by hashing will help. Refer to SPARK-21782 for more details.
val partitionId = TaskContext.get().partitionId()
var position = new XORShiftRandom(partitionId).nextInt(numPartitions)
(row: InternalRow) => {
// The HashPartitioner will handle the `mod` by the number of partitions
position += 1
position
}
case h: HashPartitioning =>
val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
row => projection(row).getInt(0)
case RangePartitioning(sortingExpressions, _) =>
val projection =
UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
}

val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
newPartitioning.numPartitions > 1

val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
// [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic,
// otherwise a retry task may output different rows and thus lead to data loss.
//
// Currently we following the most straight-forward way that perform a local sort before
// partitioning.
//
// Note that we don't perform local sort if the new partitioning has only 1 partition, under
// that case all output rows go to the same partition.
val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
rdd.mapPartitionsInternal { iter =>
val recordComparatorSupplier = new Supplier[RecordComparator] {
override def get: RecordComparator = new RecordBinaryComparator()
}
// The comparator for comparing row hashcode, which should always be Integer.
val prefixComparator = PrefixComparators.LONG

// The prefix computer generates row hashcode as the prefix, so we may decrease the
// probability that the prefixes are equal when input rows choose column values from a
// limited range.
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
override def computePrefix(
row: InternalRow): UnsafeExternalRowSorter.PrefixComputer.Prefix = {
// The hashcode generated from the binary form of a [[UnsafeRow]] should not be null.
result.isNull = false
result.value = row.hashCode()
result
}
}
val pageSize = SparkEnv.get.memoryManager.pageSizeBytes

val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
StructType.fromAttributes(outputAttributes),
recordComparatorSupplier,
prefixComparator,
prefixComputer,
pageSize,
// We are comparing binary here, which does not support radix sort.
// See more details in SPARK-28699.
false)
sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
}
} else {
rdd
}

// round-robin function is order sensitive if we don't sort the input.
val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
if (CometShuffleExchangeExec.needToCopyObjectsBeforeShuffle(part)) {
newRdd.mapPartitionsWithIndexInternal(
(_, iter) => {
val getPartitionKey = getPartitionKeyExtractor()
iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
},
isOrderSensitive = isOrderSensitive)
} else {
newRdd.mapPartitionsWithIndexInternal(
(_, iter) => {
val getPartitionKey = getPartitionKeyExtractor()
val mutablePair = new MutablePair[Int, InternalRow]()
iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
},
isOrderSensitive = isOrderSensitive)
}
}

// Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds
// are in the form of (partitionId, row) and every partitionId is in the expected range
// [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough.
val dependency =
new CometShuffleDependency[Int, InternalRow, InternalRow](
sparkShuffleDep.rdd,
sparkShuffleDep.partitioner,
sparkShuffleDep.serializer,
shuffleWriterProcessor = sparkShuffleDep.shuffleWriterProcessor,
rddWithPartitionIds,
new PartitionIdPassthrough(part.numPartitions),
serializer,
shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
shuffleType = CometColumnarShuffle,
schema = Some(StructType.fromAttributes(outputAttributes)))

dependency
}
}
Expand Down Expand Up @@ -379,3 +556,18 @@ class CometShuffleWriteProcessor(
}
}
}

/**
* Copied from Spark `PartitionIdPassthrough` as it is private in Spark 3.2.
*/
private[spark] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

/**
* Copied from Spark `ConstantPartitioner` as it doesn't exist in Spark 3.2.
*/
private[spark] class ConstantPartitioner extends Partitioner {
override def numPartitions: Int = 1
override def getPartition(key: Any): Int = 0
}

0 comments on commit 686e382

Please sign in to comment.